# HyperCLOVAX-SEED-Omni-8B / preprocessor.py
import ast
import copy
import datetime
import gc
import io
import json
import math
import mimetypes
import os
import random
import re
import sys
import tarfile
import tempfile
import zipfile
from collections import defaultdict, deque
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import av
import cv2
import numpy as np
import PIL
import pkg_resources
import scipy.signal as scsig
import torch
from decord import VideoReader, cpu
from PIL import Image, ImageDraw
from smart_open import open
from torchvision.transforms.functional import to_tensor
from hcxvlm.dataset.base_dataset import image_decoder
from hcxvlm.dataset.hcx_vision_prompter import HCXVisionPrompter
CHOICES = list(map(chr, range(97, 123)))
IGNORE_INDEX = -100
DEFAULT_SAMPLE_RATE = 16000
MIN_DISCRETE_AUDIO_CHUNK_SAMPLES = 1600
DEFAULT_VOLUME_LEVEL = 10 ** (-26 / 20)
hcx_vision_prompter = HCXVisionPrompter()
def hpf_normalize(
wav: np.ndarray,
sr: int = DEFAULT_SAMPLE_RATE,
volume_level: float = DEFAULT_VOLUME_LEVEL,
) -> np.ndarray:
assert (wav**2).mean() > 0, "Error in the wav file"
filter_ = scsig.butter(2, 70, "highpass", fs=sr, output="sos")
wav = scsig.sosfilt(filter_, wav)
wav = wav.astype(np.float32)
gain = min(volume_level / (wav**2).mean() ** 0.5, 1 / np.max(np.abs(wav)))
wav *= gain
return wav
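# Illustrative usage sketch (not part of the pipeline): hpf_normalize applies a
# 2nd-order Butterworth high-pass filter at 70 Hz, then scales the waveform so its
# RMS approaches DEFAULT_VOLUME_LEVEL without pushing the peak above 1.0.
#
#   sr = DEFAULT_SAMPLE_RATE
#   t = np.linspace(0, 1.0, sr, endpoint=False)
#   wav = 0.1 * np.sin(2 * np.pi * 440.0 * t).astype(np.float32)
#   normalized = hpf_normalize(wav, sr=sr)  # float32, gain-limited output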
def convert_bboxes(img, img_meta):
for k, v in img_meta.items():
if k == "region":
bbox_key = "bbox" if "bbox" in img_meta[k] else "boundingBox"
img_meta[k] = reform_bbox(
img_meta[k][bbox_key], img.size, format=img_meta[k]["format"]
)
return img_meta
def reform_bbox(bbox, image_size, format="REL_XYXY"):
w, h = image_size
if format == "REL_XYXY":
x1, y1, x2, y2 = bbox[0] * w, bbox[1] * h, bbox[2] * w, bbox[3] * h
elif format == "REL_XYWH":
x1, y1 = bbox[0] * w, bbox[1] * h
x2, y2 = x1 + bbox[2] * w, y1 + bbox[3] * h
else:
raise NotImplementedError
new_bbox = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
return new_bbox
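# Illustrative sketch: reform_bbox expands a normalized box into four absolute corner
# points ordered top-left, top-right, bottom-right, bottom-left. For a 640x480 image:
#
#   reform_bbox([0.1, 0.2, 0.5, 0.6], (640, 480), format="REL_XYWH")
#   # -> [[64.0, 96.0], [384.0, 96.0], [384.0, 384.0], [64.0, 384.0]]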
def generate_random_color(use_alpha=True, seed=None):
if seed is None:
seed = np.random.default_rng()
if use_alpha:
color_list = [
("빨강", (255, 127, 127, 100)),
("노랑", (255, 255, 127, 100)),
("초록", (127, 255, 125, 100)),
("하늘", (127, 255, 255, 100)),
("파랑", (127, 127, 255, 100)),
("보라", (255, 127, 255, 100)),
]
else:
color_list = [
("빨강", (255, 0, 0)),
("노랑", (255, 255, 0)),
("초록", (0, 128, 0)),
("하늘", (135, 206, 235)),
("파랑", (0, 0, 255)),
("보라", (128, 0, 128)),
]
return color_list[seed.integers(0, len(color_list))]
EN_COLOR = {
"빨강": "red",
"노랑": "yellow",
"초록": "green",
"하늘": "sky blue",
"파랑": "blue",
"보라": "purple",
}
def overlay_rectangle(image, words, lang, seed=None):
color_str, color = generate_random_color(seed=seed)
draw = ImageDraw.Draw(image, "RGBA")
for word in words:
shape_rect = word["bbox"]
shape_rect = [(round(x[0]), round(x[1])) for x in shape_rect]
draw.polygon(shape_rect, color)
del draw
if lang == "en":
color_str = EN_COLOR[color_str]
return image, color_str
def convert_tags_for_video(img, json):
"""video 데이터에는 <image_xx> 태그 대신 <video_00> tag가 있음.
img 숫자 만큼 <video_00> tag 대신 <image_xx> tag를 변환하여 넣음
"""
image_tag = "".join([f"<image_{idx:02d}>" for idx in range(len(img))])
for json_key in json:
if "qa_pairs" in json_key:
new_qa_pairs = []
for qa_pair in json[json_key]:
question = qa_pair[0]
question = question.replace("<video_00>", image_tag)
new_qa_pairs.append([question, qa_pair[1]])
json[json_key] = new_qa_pairs
return img, json
def sampling_multiturn_single_img(
seq,
count,
multiturn_preserve_order=True,
multiturn_continuous=False,
is_train: bool = True,
seed=None,
):
if seed is None:
seed = np.random.default_rng()
n_sample = min(count, len(seq))
if multiturn_continuous:
if len(seq) <= n_sample:
start_index = 0
else:
start_index = seed.integers(0, len(seq) - n_sample)
indices = range(start_index, start_index + n_sample)
elif multiturn_preserve_order:
indices = sorted(seed.choice(range(len(seq)), size=n_sample, replace=False))
else:
indices = seed.choice(range(len(seq)), size=n_sample, replace=False)
return [seq[i] for i in indices]
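# Illustrative sketch: pick at most `count` turns from a multiturn sequence.
# multiturn_continuous=True takes a contiguous window; otherwise a random subset is
# drawn, in original order when multiturn_preserve_order=True.
#
#   rng = np.random.default_rng(0)
#   sampling_multiturn_single_img(list(range(10)), 3, multiturn_continuous=True, seed=rng)
#   # -> three consecutive elements, e.g. [4, 5, 6]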
def draw_bbox(image, bbox, lang="en", line_width=5, seed=None):
if seed is None:
seed = np.random.default_rng()
color_str, color = generate_random_color(use_alpha=False, seed=seed)
draw = ImageDraw.Draw(image, "RGB")
rect_bbox = (bbox[0][0], bbox[0][1], bbox[2][0], bbox[2][1])
draw.rectangle(rect_bbox, outline=color, width=line_width)
del draw
if lang == "en":
color_str = EN_COLOR[color_str]
return image, color_str
def bbox_process(bbox, detection_precision=2):
    # Format already-normalized coordinates with `detection_precision` decimals.
    bbox_str = "["
    for idx, point in enumerate(bbox):
        if idx < len(bbox) - 1:
            bbox_str += format(point, f".{detection_precision}f") + ", "
        else:
            bbox_str += format(point, f".{detection_precision}f")
    bbox_str += "]"
    return bbox_str
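# Illustrative sketch: bbox_process renders a flat coordinate list as a fixed-precision
# string, e.g. bbox_process([64.0, 96.0, 384.0, 384.0]) -> "[64.00, 96.00, 384.00, 384.00]".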
def load_txt(file_path):
lines_list = []
with open(file_path, "r") as file:
for line in file:
lines_list.append(line.replace("\\n", "\n").strip())
return lines_list
def convert_format_for_multi_image(
img, json, convert_key_list=["words", "text", "objects", "entities"]
):
"""single image dataset 과 multi image dataset 에서 읽어온 img, json format 이 다름.
따라서 single image dataset 에서 읽어온 img, json 을 multi image dataset 의 format (dict) 으로 convert
"""
is_multi_image_dataset = isinstance(img, dict)
if not is_multi_image_dataset:
img = {"00": img}
for convert_key in convert_key_list:
if convert_key in json:
json[convert_key] = {"00": json[convert_key]}
for json_key in json:
if "region" in json_key:
json[json_key] = {"00": json[json_key]}
else:
for convert_key in convert_key_list:
if convert_key in json:
if isinstance(json[convert_key], list):
json[convert_key] = {"00": json[convert_key]}
for json_key in json:
if "region" in json_key:
if isinstance(json[json_key], list):
json[json_key] = {"00": json[json_key]}
return is_multi_image_dataset, img, json
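# Illustrative sketch: a single-image sample is wrapped into the multi-image layout,
# keyed by image index "00".
#
#   is_multi, img, meta = convert_format_for_multi_image(
#       Image.new("RGB", (32, 32)), {"words": [{"text": "hi"}]}
#   )
#   # is_multi is False, img == {"00": <PIL.Image>}, meta["words"] == {"00": [{"text": "hi"}]}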
class ConditionalError(Exception):
def __init__(self, message="Our assertion error"):
super().__init__(message)
def get_wds_default_config(default_config, existing_default_config=None):
if existing_default_config is None:
default_config_check_dict = {
"subtask": "",
"reasoning": False,
"use_task_prompt": True,
"get_random": True,
"add_instruct_prompts": [],
"multiturn_n_samples": 0,
"multiturn_preserve_order": True,
"multiturn_continuous": False,
"insert_ocr": 200,
"ocr_filter_strategy": "confidence",
"ocr_use_ratio": 1.0,
"entity_top_k": 0,
"entity_keyword_threshold": 100,
"entity_keyword_fashion_threshold": 100,
"entity_use_ratio": 0.0,
"llava_pretrain": False,
"random_system_prob": 0.0,
"random_system_path": "",
"random_tool_prob": 0.005,
}
else:
default_config_check_dict = existing_default_config
if default_config is None:
default_config = default_config_check_dict
else:
for key, value in default_config_check_dict.items():
if key not in default_config:
default_config[key] = value
return default_config
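# Illustrative sketch: keys missing from a wds config are filled in from the defaults
# above, while explicitly provided values win.
#
#   cfg = get_wds_default_config({"insert_ocr": 0})
#   # cfg["insert_ocr"] == 0, cfg["use_task_prompt"] is True (default)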
def get_datalake_default_config(default_config):
default_config_check_dict = {
"multiturn_n_samples": 0,
"multiturn_preserve_order": True,
"multiturn_continuous": True,
"insert_ocr": 0,
"ocr_filter_strategy": "confidence",
"entity_top_k": 0,
"entity_keyword_threshold": 0,
"entity_keyword_fashion_threshold": 0,
"entity_use_ratio": 0.0,
"ocr_use_ratio": 0.0,
"llava_pretrain": False,
"random_system_prob": 0.0,
"random_system_path": "",
"random_tool_prob": 0.005,
}
if default_config is None:
default_config = default_config_check_dict
else:
for key, value in default_config_check_dict.items():
if key not in default_config:
default_config[key] = value
return default_config
@dataclass
class Processed_sample:
input_str: str = None
input_ids: list = None
label_ids: list = None
imgs: list = None
discrete_imgs: list = None
videos: list = None
videos_duration: List[dict] = None
video_audios: list = None
audios: list = None
audios_duration: List[dict] = None
discrete_audios: list = None
sample_mm_counter: dict = None
from hcxvlm.dataset.bbox_processor import (
extract_bboxes,
insert_bboxes_to_json,
is_bbox_padded,
)
class Preprocessor:
prompt_head = ""
va_prefix = "\n<|im_start|>"
new_line = "\n"
turn_prefix = "<|im_start|>"
turn_suffix = "<|im_end|>"
mime_start = "<|mime_start|>"
mime_end = "<|mime_end|>"
aux_img_start = "<|image_aux_start|>"
aux_img_end = "<|image_aux_end|>"
aux_video_start = "<|video_aux_start|>"
aux_video_end = "<|video_aux_end|>"
aux_audio_start = "<|audio_aux_start|>"
aux_audio_end = "<|audio_aux_end|>"
image_start = "<|image_start|>"
image_end = "<|image_end|>"
image_pad = "<|IMAGE_PAD|>"
video_start = "<|video_start|>"
video_end = "<|video_end|>"
video_pad = "<|VIDEO_PAD|>"
audio_start = "<|audio_start|>"
audio_end = "<|audio_end|>"
audio_pad = "<|AUDIO_PAD|>"
discrete_image_start = "<|discrete_image_start|>"
discrete_image_end = "<|discrete_image_end|>"
discrete_image_pad = "<|DISCRETE_IMAGE_PAD|>"
video_audio_pad = "<|VIDEO_AUDIO_PAD|>"
discrete_audio_start = "<|discrete_audio_start|>"
discrete_audio_end = "<|discrete_audio_end|>"
discrete_audio_pad = "<|DISCRETE_AUDIO_PAD|>"
discrete_image_eol = "<|vision_eol|>"
discrete_image_eof = "<|vision_eof|>"
discrete_image_ratios = {
(1, 1): "<|vision_ratio_1:1|>",
(1, 2): "<|vision_ratio_1:2|>",
(2, 1): "<|vision_ratio_2:1|>",
(3, 4): "<|vision_ratio_3:4|>",
(4, 3): "<|vision_ratio_4:3|>",
(3, 5): "<|vision_ratio_3:5|>",
(5, 3): "<|vision_ratio_5:3|>",
(4, 5): "<|vision_ratio_4:5|>",
(5, 4): "<|vision_ratio_5:4|>",
(6, 9): "<|vision_ratio_6:9|>",
(9, 6): "<|vision_ratio_9:6|>",
(9, 16): "<|vision_ratio_9:16|>",
(16, 9): "<|vision_ratio_16:9|>",
}
aux_vid_prompt = (
"다음 중 video_duration은 비디오 길이 정보입니다. 참고하여 답변하세요. "
)
aux_audio_prompt = (
"다음 중 audio_duration은 오디오 길이 정보입니다. 참고하여 답변하세요. "
)
def __init__(
self,
tokenizer=None,
prepare_input_fn=None,
prepare_audio_input_fn=None,
sample_min_length=0,
decoder_max_length=None,
mode="train",
model=None,
datalake_default_config=None,
wds_default_config=None,
video_config=None,
train_video=False,
train_audio=False,
sequence_parallel_size=1,
video_audio_compressor_type=None,
):
self.sequence_parallel_size = sequence_parallel_size
if sequence_parallel_size > 1:
self.rng = np.random.default_rng(seed=42)
else:
self.rng = np.random.default_rng()
if model is not None:
tokenizer = model.tokenizer
decoder_max_length = 16000
if model is not None and prepare_input_fn is None:
raise "please give ImageProcessor!"
self.prepare_input_fn = prepare_input_fn
self.prepare_audio_input_fn = prepare_audio_input_fn
try:
from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import (
Qwen2_5_VLProcessor,
)
self.is_qwen_visual = isinstance(prepare_input_fn, Qwen2_5_VLProcessor)
except Exception as e:
self.is_qwen_visual = False
try:
if not self.is_qwen_visual:
from hcxvlm.models.processing_vlm import HCXVisionV2Processor
self.is_qwen_visual = isinstance(prepare_input_fn, HCXVisionV2Processor)
except Exception as e:
self.is_qwen_visual = False
assert self.is_qwen_visual, "qwen2.5-vl visual prepare_input_fn import error"
self.video_max_num_frames = (
video_config["video_max_num_frames"]
if video_config and "video_max_num_frames" in video_config
else 120
)
self.video_max_pixels = (
video_config["video_max_pixels"]
if video_config and "video_max_pixels" in video_config
else 378 * 378
)
self.tokenizer = tokenizer
self.sample_min_length = sample_min_length
self.decoder_max_length = decoder_max_length
self.mode = mode
self.default_config = get_datalake_default_config(datalake_default_config)
self.wds_default_config = get_wds_default_config(wds_default_config)
self.train_video = train_video
self.train_audio = train_audio
self.video_audio_compressor_type = video_audio_compressor_type
self.img_token = self.tokenizer.encode(Preprocessor.image_pad)[0]
assert (
len(self.tokenizer.encode(Preprocessor.image_pad)) == 1
), "img_token is not configured in tokenizer"
self.discrete_image_token = self.tokenizer.encode(
Preprocessor.discrete_image_pad
)[0]
assert (
len(self.tokenizer.encode(Preprocessor.discrete_image_pad)) == 1
), "discrete_image_token is not configured in tokenizer"
self.discrete_image_eol_token = self.tokenizer.encode(
Preprocessor.discrete_image_eol
)[0]
assert (
len(self.tokenizer.encode(Preprocessor.discrete_image_eol)) == 1
), "discrete_image_eol_token is not configured in tokenizer"
self.discrete_image_eof_token = self.tokenizer.encode(
Preprocessor.discrete_image_eof
)[0]
assert (
len(self.tokenizer.encode(Preprocessor.discrete_image_eof)) == 1
), "discrete_image_eof_token is not configured in tokenizer"
self.discrete_image_ratio_tokens = dict()
for ratio, token_str in Preprocessor.discrete_image_ratios.items():
token_id = self.tokenizer.encode(token_str)[0]
assert (
len(self.tokenizer.encode(token_str)) == 1
), f"discrete_image_ratio_token {token_str} is not configured in tokenizer"
self.discrete_image_ratio_tokens[ratio] = token_id
self.video_token = self.tokenizer.encode(Preprocessor.video_pad)[0]
assert (
len(self.tokenizer.encode(Preprocessor.video_pad)) == 1
), "video_token is not configured in tokenizer"
self.video_audio_token = self.tokenizer.encode(Preprocessor.video_audio_pad)[0]
assert (
len(self.tokenizer.encode(Preprocessor.video_audio_pad)) == 1
), "video_audio_token is not configured in tokenizer"
def resize_min_edge(img: Image.Image) -> Image.Image:
w, h = img.size
min_size = 28
if min(w, h) >= min_size:
return img
if w < h:
new_w = min_size
new_h = int(h * (min_size / w))
else:
new_h = min_size
new_w = int(w * (min_size / h))
return img.resize((new_w, new_h), Image.BICUBIC)
self._resize_min_edge = resize_min_edge
self.audio_token = self.tokenizer.encode(Preprocessor.audio_pad)[0]
assert (
len(self.tokenizer.encode(Preprocessor.audio_pad)) == 1
), "audio_token is not configured in tokenizer"
self.discrete_audio_token = self.tokenizer.encode(
Preprocessor.discrete_audio_pad
)[0]
assert (
len(self.tokenizer.encode(Preprocessor.discrete_audio_pad)) == 1
), "audio_token is not configured in tokenizer"
from hcxvlm.dataset.json_processer import generate_prompt
self.generate_prompt = generate_prompt
self.mimes = list()
for mime_filename in [
"words_alpha.txt",
"korean-366506-wordslistUnique.txt",
]:
self.mimes += (
pkg_resources.resource_string(
"hcxvlm", f"dataset/hcx_vision_prompter/prompts/{mime_filename}"
)
.decode("utf-8")
.split("\r\n")
)
self.common_tools = []
try:
common_tools_bytes = pkg_resources.resource_string(
"hcxvlm",
"dataset/hcx_vision_prompter/prompts/common_tools.jsonl",
)
for line in common_tools_bytes.decode("utf-8").splitlines():
line = line.strip()
if not line:
continue
try:
self.common_tools.append(json.loads(line))
except Exception:
continue
except Exception:
self.common_tools = []
self.random_system_prompt = ""
if self.default_config["random_system_path"] != "":
self.random_system_prompt = ""
with open(self.default_config["random_system_path"], "r") as f:
for line in f:
self.random_system_prompt += line
if (
self.random_system_prompt != ""
and self.wds_default_config["random_system_path"] != ""
):
assert (
self.wds_default_config["random_system_path"]
== self.default_config["random_system_path"]
), "random_system_path in both default_config and wds_default_config should be the same"
def _find_best_ratio_token(self, original_size):
"""Find the best ratio token based on original_size"""
base_ratios = list(self.discrete_image_ratio_tokens.keys())
vision_aspect_ratios = [
r for ratio in base_ratios for r in [ratio, ratio[::-1]]
][1:]
if not isinstance(original_size, list) or len(original_size) != 2:
return self.discrete_image_ratio_tokens[(1, 1)]
h, w = original_size
if h == 0 or w == 0:
return self.discrete_image_ratio_tokens[(1, 1)]
ratios = [i / j for i, j in vision_aspect_ratios]
best_size_idx = np.argmin([abs(w / h - r) for r in ratios])
i, j = vision_aspect_ratios[best_size_idx]
return self.discrete_image_ratio_tokens[(i, j)]
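    # Illustrative sketch: for original_size = [H, W] = [720, 1280] the closest
    # supported aspect ratio is 16:9, so the id of "<|vision_ratio_16:9|>" is returned;
    # malformed sizes fall back to the 1:1 ratio token.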
@classmethod
def prompt_mime(
cls,
mimes: Optional[list[str]] = None,
file_name: str = None,
tag_idx: int = 1,
fixed_mime: bool = False,
is_video: bool = False,
is_audio: bool = False,
seed: np.random.Generator = None,
) -> list[dict]:
assert mimes or file_name
if seed is None:
seed = np.random.default_rng()
if file_name:
name, ext = os.path.splitext(file_name)
ext = ext.lstrip(".")
elif fixed_mime:
ext = "jpeg"
name = mimes[tag_idx]
elif not fixed_mime and seed is not None:
ext = seed.choice(["png", "jpeg"])
name = mimes[seed.integers(0, len(mimes))]
else:
ext = "jpeg"
name = mimes[tag_idx]
if is_video:
ext_candidates = ["mp4", "mov", "avi", "webm"]
if fixed_mime:
ext = "mp4"
elif ext not in ext_candidates:
ext = seed.choice(ext_candidates)
filename = f"{name}.{ext}"
mime_type = mimetypes.guess_type(filename)[0]
mime_prompt = {
"id": f"video_{str(tag_idx).zfill(2)}",
"type": f"{mime_type}",
"filename": f"{filename}",
}
return mime_prompt
if is_audio:
ext_candidates = ["mp3", "wav", "aac", "flac", "pcm"]
if fixed_mime:
ext = "wav"
elif ext not in ext_candidates:
ext = seed.choice(ext_candidates)
filename = f"{name}.{ext}"
mime_type = mimetypes.guess_type(filename)[0]
mime_prompt = {
"id": f"audio_{str(tag_idx).zfill(2)}",
"type": f"{mime_type}",
"filename": f"{filename}",
}
return mime_prompt
if file_name:
filename = f"{name}.{ext}"
mime_type = mimetypes.guess_type(filename)[0]
mime_prompt = {
"id": f"image_{str(tag_idx).zfill(2)}",
"type": f"{mime_type}",
"filename": f"{filename}",
}
else:
mime_prompt = {
"id": f"image_{str(tag_idx).zfill(2)}",
"type": f"image/{ext}",
"filename": f"{name}.{'jpg' if ext == 'jpeg' else 'png'}",
}
return mime_prompt
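    # Illustrative sketch: prompt_mime builds the MIME header dict that is later
    # serialized between <|mime_start|> and <|mime_end|>. With an explicit file name
    # the extension drives the reported type:
    #
    #   Preprocessor.prompt_mime(file_name="cat.png", tag_idx=0)
    #   # -> {"id": "image_00", "type": "image/png", "filename": "cat.png"}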
@classmethod
def ocr_preprocess(
cls,
words: list[dict],
n_insert_ocr_tokens: int = 2000,
insert_ocr: int = 200,
ocr_use_ratio: float = 0.5,
tokenizer=None,
seed=None,
) -> list[str]:
if seed is None:
seed = np.random.default_rng()
if ocr_use_ratio < seed.random():
return None
if insert_ocr == 0:
return None
confidence_list = []
insert_ocr_prompt = []
for word in words:
if "confidence" in word:
confidence_list.append(word["confidence"])
has_ocr_confidence = len(confidence_list) >= insert_ocr
if len(words) <= insert_ocr or not has_ocr_confidence:
insert_ocr_prompt += [
d["text"].strip() for d in words if d["text"].strip()
][:insert_ocr]
else:
confidence_threshold = 0.3
cnt = 0
for word in words:
if word["text"] == "":
continue
if word["confidence"] >= confidence_threshold:
insert_ocr_prompt.append(word["text"])
cnt += 1
if cnt >= insert_ocr:
break
ocr_inputs = " ".join(insert_ocr_prompt)
if tokenizer:
ocr_inputs = tokenizer.decode(
tokenizer.encode(ocr_inputs)[:n_insert_ocr_tokens]
)
return ocr_inputs
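    # Illustrative sketch (words follow the OCR schema used above): with
    # ocr_use_ratio=1.0 the texts are always joined; when there are more words than
    # `insert_ocr` and every word has a confidence score, words below 0.3 confidence
    # are skipped.
    #
    #   words = [{"text": "hello", "confidence": 0.9}, {"text": "world", "confidence": 0.2}]
    #   Preprocessor.ocr_preprocess(words, insert_ocr=200, ocr_use_ratio=1.0)
    #   # -> "hello world"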
@classmethod
def lens_preprocess(
cls,
lens: list[dict],
entity_top_k: int = 100,
entity_keyword_threshold: float = 0.0,
entity_keyword_fashion_threshold: float = 0.0,
entity_use_ratio: float = 0.0,
seed=None,
):
if seed is None:
seed = np.random.default_rng()
if seed.uniform(0, 1) > entity_use_ratio:
return None
entities = lens
filter_idx = []
insert_entity_prompt = {}
for idx, entity in enumerate(entities):
if entity["type"] != "naver_lens_api":
filter_idx.append(idx)
continue
if (
isinstance(entity_keyword_threshold, (int, float))
and entity["confidence"] < entity_keyword_threshold
):
filter_idx.append(idx)
continue
if (
isinstance(entity_keyword_fashion_threshold, (int, float))
and ("fashion" in entity["info"]["classes"])
and entity["confidence"] < entity_keyword_fashion_threshold
):
filter_idx.append(idx)
continue
entityvalue = [
keyword for idx, keyword in enumerate(entities) if idx not in filter_idx
]
entityvalue = sorted(entityvalue, key=lambda x: x["confidence"], reverse=True)
important_entity_list = []
local_entity_str_list = []
keywords_and_bbox_per_detector = {}
for keyword_dict in entityvalue[:entity_top_k]:
object_class = "/".join(keyword_dict["info"]["classes"])
if object_class not in keywords_and_bbox_per_detector.keys():
keywords_and_bbox_per_detector[object_class] = []
keywords_and_bbox_per_detector[object_class].append(keyword_dict)
for object_class in keywords_and_bbox_per_detector.keys():
entities_per_object = keywords_and_bbox_per_detector[object_class]
normalized_bbox = bbox_process(
[*entities_per_object[0]["bbox"][0], *entities_per_object[0]["bbox"][2]]
)
entities = [entity["text"] for entity in entities_per_object]
if "context" in object_class:
important_entity_list += entities
else:
local_entity_str_list += [
str(normalized_bbox) + " " + ", ".join(entities)
]
if len(important_entity_list) > 0:
insert_entity_prompt["lens_keywords"] = ", ".join(important_entity_list)
if len(local_entity_str_list) > 0:
insert_entity_prompt["lens_local_keywords"] = " ".join(
local_entity_str_list
)
return insert_entity_prompt
@classmethod
def prompt_toollist(
cls,
output,
tokenizer=None,
turn: Optional[dict] = None,
content: Optional[list[dict]] = None,
):
assert content or turn
if turn is None:
turn = {
"role": "tool_list",
"content": content,
}
toollist_str = (
cls.turn_prefix.strip()
+ turn["role"]
+ "\n"
+ turn["content"]
+ cls.turn_suffix
)
if hasattr(output, "input_str"):
output.input_str += toollist_str
if getattr(output, "input_ids", None) is not None:
token_ids = tokenizer.encode(toollist_str, truncation=False)
output.input_ids += token_ids
output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))]
return output
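    # Illustrative sketch: prompt_toollist appends a non-trainable tool_list turn such as
    # "<|im_start|>tool_list\n{...tools JSON...}<|im_end|>" to output.input_str and
    # output.input_ids, with every label set to IGNORE_INDEX.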
@classmethod
def prompt_system(
cls,
output,
tokenizer=None,
turn: Optional[dict] = None,
content: Optional[str] = None,
seed=None,
tool_prompt=None,
system_role_count=0,
):
assert content or turn
if seed is None:
seed = np.random.default_rng()
        system_prompt = ""
        if turn is None:
            # Build a minimal system turn so the role lookup below does not fail.
            turn = {"role": "system", "content": content}
            system_prompt = content
else:
if "candidates" in turn:
if len(turn["candidates"]) > 0:
system_prompt = seed.choice(turn["candidates"])
if type(system_prompt) is dict:
system_prompt = system_prompt["content"]
else:
system_prompt = ""
elif isinstance(turn["content"], str):
system_prompt = turn["content"]
elif len(turn["content"]) > 0:
system_prompt = seed.choice(turn["content"])
system_str = cls.turn_prefix + turn["role"] + "\n"
system_str += system_prompt.strip()
if system_role_count == 0:
if system_prompt.strip():
system_str += "\n"
            system_str += tool_prompt or ""
system_str += cls.turn_suffix
if hasattr(output, "input_str"):
output.input_str += system_str
if getattr(output, "input_ids", None) is not None:
token_ids = tokenizer.encode(system_str, truncation=False)
output.input_ids += token_ids
output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))]
return output
@classmethod
def load_mm(
cls,
output,
img_dir: str = "",
turn: Optional[dict] = None,
image_urls: Optional[list[str]] = None,
image_metas: Optional[list[dict]] = None,
video_urls: Optional[list[str]] = None,
video_metas: Optional[list[dict]] = None,
audio_urls: Optional[list[str]] = None,
audio_metas: Optional[list[dict]] = None,
prepare_input_fn=None,
prepare_audio_input_fn=None,
max_image_cnt=21,
video_max_num_frames=None,
video_max_pixels=None,
use_audio: bool = False,
audio_sample_rate: int = 16000,
):
assert (image_urls or video_urls or audio_urls) or turn
if turn is None:
turn = {}
if image_urls:
turn.update({"image_urls": image_urls})
turn.update({"image_metas": image_metas})
if video_urls:
turn.update({"video_urls": video_urls})
turn.update({"video_metas": video_metas})
if audio_urls:
turn.update({"audio_urls": audio_urls})
turn.update({"audio_metas": audio_metas})
if "video_urls" in turn:
if len(turn["video_urls"]) and (prepare_input_fn is None):
raise ConditionalError("video processing needs 'prepare_input_fn'")
if not isinstance(turn["content"], str):
raise ConditionalError(f"turn['content'] must be a string")
turn["content"] = re.sub(r"<image_\d+>", "<|image|>", turn["content"])
pattern = re.compile(
r"<\|video\|>|<\|image\|>|<\|t2i_model_generation_target_discrete_image\|>|<\|audio\|>|<\|discrete_audio\|>"
)
tags = [match.group() for match in pattern.finditer(turn["content"])]
img_idx = 0
vid_idx = 0
aud_idx = 0
if "image_urls" not in turn:
turn["image_urls"] = []
if "video_urls" not in turn:
turn["video_urls"] = []
if "audio_urls" not in turn:
turn["audio_urls"] = []
for tag in tags:
if (
tag == "<|image|>"
or tag == "<|t2i_model_generation_target_discrete_image|>"
):
img_path = turn["image_urls"][img_idx]
if isinstance(img_path, str):
if "#" in img_path:
compression_path, img_path = img_path.split("#", 1)
compression_path = os.path.join(img_dir, compression_path)
assert compression_path[-4:] in [
".zip",
".tar",
], f"unsupported compression format: {compression_path}"
with open(compression_path, "rb") as comp_file:
if compression_path.endswith(".zip"):
with zipfile.ZipFile(comp_file, "r") as zip_file:
with zip_file.open(img_path) as img_file:
img_binary = img_file.read()
elif compression_path.endswith(".tar"):
with tarfile.open(
fileobj=comp_file, mode="r"
) as tar_file:
img_file = tar_file.extractfile(img_path)
img_binary = img_file.read()
else:
with open(os.path.join(img_dir, img_path), "rb") as f:
img_binary = f.read()
img = image_decoder(img_binary)
else:
if isinstance(img_path, (bytes, bytearray)):
img = io.BytesIO(img_path)
img = Image.open(img).convert("RGB")
else:
img = img_path
if not isinstance(img, Image.Image):
img = Image.fromarray(np.uint8(img)).convert("RGB")
if "image_metas" in turn and turn["image_metas"]:
turn["image_metas"][img_idx] = convert_bboxes(
img, turn["image_metas"][img_idx]
)
if tag == "<|image|>":
output.imgs.append(img)
output.discrete_imgs.append(img)
img_idx += 1
elif tag == "<|video|>":
video_path = turn["video_urls"][vid_idx]
if isinstance(video_path, str):
if "#" in video_path:
compression_path, video_path = video_path.split("#", 1)
compression_path = os.path.join(img_dir, compression_path)
assert compression_path[-4:] in [
".zip",
".tar",
], f"unsupported compression format: {compression_path}"
with open(compression_path, "rb") as comp_file:
if compression_path.endswith(".zip"):
with zipfile.ZipFile(comp_file, "r") as zip_file:
video_file = zip_file.open(video_path)
video_binary = video_file.read()
elif compression_path.endswith(".tar"):
with tarfile.open(
fileobj=comp_file, mode="r"
) as tar_file:
video_file = tar_file.extractfile(video_path)
video_binary = video_file.read()
else:
with open(os.path.join(img_dir, video_path), "rb") as f:
video_binary = f.read()
video_binary = io.BytesIO(video_binary)
else:
video_binary = video_path
assert isinstance(video_binary, io.BytesIO), "video binary read error"
try:
from hcxvlm.dataset.qwen_vision_process import process_vision_info
                except ImportError:
from qwen_vl_utils import process_vision_info
if video_max_num_frames is None:
video_max_num_frames = 120
if video_max_pixels is None:
video_max_pixels = 378 * 378
messages = [
[
{
"role": "user",
"content": [
{
"type": "video",
"video": video_binary,
"max_frames": video_max_num_frames,
"max_pixels": video_max_pixels,
}
],
}
],
]
_, videos, video_kwargs = process_vision_info(
messages,
return_video_kwargs=True,
use_audio=use_audio,
audio_sample_rate=audio_sample_rate,
)
output.videos.append(videos[0])
video_len = round(videos[0].shape[0] / video_kwargs["fps"][0], 2)
output.videos_duration.append(
{
"video_duration": f"{video_len}s",
}
)
if use_audio and "audio_chunks" in video_kwargs:
audio_chunks = video_kwargs["audio_chunks"][0]
if audio_chunks is not None:
output.video_audios.append(audio_chunks)
else:
output.video_audios.append([])
elif use_audio:
output.video_audios.append([])
vid_idx += 1
elif tag == "<|audio|>" or tag == "<|discrete_audio|>":
audio_path = turn["audio_urls"][aud_idx]
if isinstance(audio_path, str):
if "#" in audio_path:
compression_path, inner_path = audio_path.split("#", 1)
compression_path = os.path.join(img_dir, compression_path)
assert compression_path[-4:] in [
".zip",
".tar",
], f"unsupported compression format: {compression_path}"
with open(compression_path, "rb") as comp_file:
if compression_path.endswith(".zip"):
with zipfile.ZipFile(comp_file, "r") as zip_file:
with zip_file.open(inner_path) as audio_file:
audio_binary = audio_file.read()
elif compression_path.endswith(".tar"):
with tarfile.open(
fileobj=comp_file, mode="r"
) as tar_file:
audio_file = tar_file.extractfile(inner_path)
audio_binary = audio_file.read()
else:
with open(os.path.join(img_dir, audio_path), "rb") as f:
audio_binary = f.read()
audio_stream = io.BytesIO(audio_binary)
else:
if isinstance(audio_path, (bytes, bytearray)):
audio_stream = io.BytesIO(audio_path)
else:
audio_stream = audio_path
try:
import librosa
y, sr = librosa.load(
audio_stream, sr=DEFAULT_SAMPLE_RATE, mono=True
)
assert (
DEFAULT_SAMPLE_RATE == sr
), f"librosa resampling failed: {DEFAULT_SAMPLE_RATE} != {sr}"
except Exception as e:
raise ConditionalError(
f"audio decoding failed for {audio_path}: {e}"
)
audio_duration = len(y) / sr
if audio_duration < 0.5:
raise ConditionalError(
f"Audio too short ({audio_duration:.2f}s). Minimum 0.5s required."
)
if audio_duration > 600:
raise ConditionalError(
f"Audio duration ({audio_duration:.2f}s) exceeds maximum allowed duration (600s)"
)
if len(y) < MIN_DISCRETE_AUDIO_CHUNK_SAMPLES:
raise ConditionalError(
f"Audio too short ({len(y)} samples = {audio_duration:.4f}s < 0.1s). "
f"Minimum {MIN_DISCRETE_AUDIO_CHUNK_SAMPLES} samples required for CosyVoice encoder."
)
if not hasattr(output, "audios"):
output.audios = []
if not hasattr(output, "discrete_audios"):
output.discrete_audios = []
normalized_y = hpf_normalize(y)
normalized_y = torch.from_numpy(normalized_y).float()
output.discrete_audios.append(normalized_y)
if tag == "<|audio|>":
output.audios.append(y)
total_duration = len(y) / sr
output.audios_duration.append(
{
"duration": f"{(total_duration):.2f}s",
}
)
aud_idx += 1
else:
                raise ConditionalError(f"unsupported media tag in content: {tag}")
return output
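    # Illustrative sketch: load_mm resolves each <|image|>, <|video|> and <|audio|> tag in
    # turn["content"] against the corresponding URL list, decodes the media (optionally a
    # "#"-separated member of a .zip/.tar archive, or raw bytes passed in place of a path)
    # and appends the results to the Processed_sample accumulators (imgs, discrete_imgs,
    # videos, video_audios, audios, discrete_audios, plus the *_duration metadata).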
@classmethod
def prompt_user(
cls,
output,
tokenizer=None,
turn: Optional[dict] = None,
content: Optional[str] = None,
is_train=False,
fixed_mime=False,
insert_ocr=300,
file_names: Optional[list[str]] = None,
mimes: Optional[list[str]] = None,
mm_tokens: Optional[list[str]] = None,
words: Optional[list] = None,
lens: Optional[list] = None,
query_template: Optional[list[str]] = None,
config: Optional[dict] = None,
seed: np.random.Generator = None,
):
assert content or turn
if turn is None:
image_metas = [
{"words": words[i], "lens": lens[i]} for i in range(len(words))
]
turn = {
"content": content,
"image_metas": image_metas,
}
if seed is None:
seed = np.random.default_rng()
turn["content"] = re.sub(r"<image_\d+>", "<|image|>", turn["content"])
turn["content"] = re.sub(r"<video_\d+>", "<|video|>", turn["content"])
turn["content"] = re.sub(r"<audio_\d+>", "<|audio|>", turn["content"])
pattern = re.compile(r"(<\|video\|>|<\|image\|>|<\|audio\|>)")
all_tags_in_order = [
match.group() for match in pattern.finditer(turn["content"])
]
n_vids = sum(1 for tag in all_tags_in_order if tag == "<|video|>")
n_audios = sum(1 for tag in all_tags_in_order if tag == "<|audio|>")
assert (
len(turn.get("image_urls", []))
+ len(turn.get("video_urls", []))
+ len(turn.get("audio_urls", []))
) == len(
all_tags_in_order
), f"Number of media URLs does not match number of media tags."
if mm_tokens is None:
mm_tokens = [
cls.audio_pad if tag == "<|audio|>" else cls.image_pad
for tag in all_tags_in_order
]
assert len(mm_tokens) == len(all_tags_in_order)
if config.get("llava_pretrain", False):
mm_str = "".join([mm_tokens[i] for i in range(len(all_tags_in_order))])
if hasattr(output, "input_str"):
output.input_str += mm_str
if getattr(output, "input_ids", None) is not None:
token_ids = tokenizer.encode(mm_str, truncation=False)
output.input_ids += token_ids
output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))]
return output
if query_template:
processed_content = seed.choice(query_template).format(turn["content"])
tags_after_template = pattern.findall(processed_content)
if len(all_tags_in_order) != len(tags_after_template):
cleaned_template_text = pattern.sub("", processed_content)
processed_content = "".join(all_tags_in_order) + cleaned_template_text
turn["content"] = processed_content
content_parts = pattern.split(turn["content"].strip())
if hasattr(output, "input_str"):
output.input_str += f"{cls.new_line}{cls.turn_prefix}{turn['role']}"
if getattr(output, "input_ids", None) is not None:
role_encoded = tokenizer.encode(
f"{cls.new_line}{cls.turn_prefix}{turn['role']}", truncation=False
)
output.input_ids += role_encoded
if turn.get("trainable_role", False):
output.label_ids += role_encoded
else:
output.label_ids += [IGNORE_INDEX for _ in range(len(role_encoded))]
tag_cursor = 0
for part in content_parts:
part = part.strip()
if not part:
continue
if part not in ["<|image|>", "<|video|>", "<|audio|>"]:
content_text = part
if hasattr(output, "input_str"):
output.input_str += "\n" + content_text
if getattr(output, "input_ids", None) is not None:
content_encoded = tokenizer.encode(
"\n" + content_text, truncation=False
)
output.input_ids += content_encoded
if turn.get("trainable_content", False):
output.label_ids += content_encoded
else:
output.label_ids += [
IGNORE_INDEX for _ in range(len(content_encoded))
]
continue
if part == "<|image|>":
mime = Preprocessor.prompt_mime(
mimes=mimes if not file_names else None,
fixed_mime=fixed_mime if not file_names else False,
file_name=file_names[tag_cursor] if file_names else None,
tag_idx=output.sample_mm_counter["image"],
is_video=False,
is_audio=False,
seed=seed,
)
mime_str = f"{cls.mime_start}{json.dumps(mime, ensure_ascii=False)}{cls.mime_end}"
discrete_image_str = f"{cls.discrete_image_start}{cls.discrete_image_pad}{cls.discrete_image_end}"
vector_str = f"{cls.image_start}{cls.image_pad}{cls.image_end}"
mm_str = (
cls.new_line
+ mime_str
+ cls.new_line
+ discrete_image_str
+ cls.new_line
+ vector_str
)
if hasattr(output, "input_str"):
output.input_str += mm_str
if getattr(output, "input_ids", None) is not None:
token_ids = tokenizer.encode(mm_str, truncation=False)
output.input_ids += token_ids
output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))]
output.sample_mm_counter["image"] += 1
tag_cursor += 1
elif part == "<|video|>":
mime = Preprocessor.prompt_mime(
mimes=mimes if not file_names else None,
fixed_mime=fixed_mime if not file_names else False,
file_name=file_names[tag_cursor] if file_names else None,
tag_idx=output.sample_mm_counter["video"],
is_video=True,
is_audio=False,
seed=seed,
)
mm_str = ""
aux_inputs = {
"video_duration": output.videos_duration[
output.sample_mm_counter["video"]
]["video_duration"],
}
mime_str = f"{cls.mime_start}{json.dumps(mime, ensure_ascii=False)}{cls.mime_end}"
aux_str = f"{cls.aux_video_start}{cls.aux_vid_prompt}{json.dumps(aux_inputs, ensure_ascii=False)}{cls.aux_video_end}"
vector_str = f"{cls.video_start}{cls.video_pad}{cls.video_end}"
mm_str += (
cls.new_line
+ mime_str
+ cls.new_line
+ aux_str
+ cls.new_line
+ vector_str
)
if hasattr(output, "input_str"):
output.input_str += mm_str
if getattr(output, "input_ids", None) is not None:
token_ids = tokenizer.encode(mm_str, truncation=False)
output.input_ids += token_ids
output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))]
output.sample_mm_counter["video"] += 1
tag_cursor += 1
elif part == "<|audio|>":
mime = Preprocessor.prompt_mime(
mimes=mimes if not file_names else None,
fixed_mime=fixed_mime if not file_names else False,
file_name=file_names[tag_cursor] if file_names else None,
tag_idx=output.sample_mm_counter["audio"],
is_video=False,
is_audio=True,
seed=seed,
)
mm_str = ""
aux_inputs = {
"audio_duration": output.audios_duration[
output.sample_mm_counter["audio"]
]["duration"],
}
mime_str = f"{cls.mime_start}{json.dumps(mime, ensure_ascii=False)}{cls.mime_end}"
aux_str = f"{cls.aux_audio_start}{cls.aux_audio_prompt}{json.dumps(aux_inputs, ensure_ascii=False)}{cls.aux_audio_end}"
discrete_audio_str = f"{cls.discrete_audio_start}{cls.discrete_audio_pad}{cls.discrete_audio_end}"
vector_str = f"{cls.audio_start}{cls.audio_pad}{cls.audio_end}"
mm_str += (
cls.new_line
+ mime_str
+ cls.new_line
+ aux_str
+ cls.new_line
+ discrete_audio_str
+ cls.new_line
+ vector_str
)
if hasattr(output, "input_str"):
output.input_str += mm_str
if getattr(output, "input_ids", None) is not None:
token_ids = tokenizer.encode(mm_str, truncation=False)
output.input_ids += token_ids
output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))]
output.sample_mm_counter["audio"] += 1
tag_cursor += 1
if hasattr(output, "input_str"):
output.input_str += cls.turn_suffix
if getattr(output, "input_ids", None) is not None:
token_ids = tokenizer.encode(cls.turn_suffix, truncation=False)
output.input_ids += token_ids
output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))]
return output
@classmethod
def prompt_assistant(
cls,
output,
tokenizer=None,
turn: Optional[dict] = None,
role: Optional[str] = "assistant",
content: Optional[str] = None,
is_last_turn=False,
is_eval=True,
is_llava_pretrain=False,
is_after_last_user_turn=False,
):
assert content or turn
if turn is None:
turn = {
"content": content,
"role": role,
}
if is_llava_pretrain:
if hasattr(output, "input_str"):
output.input_str += turn["content"]
if getattr(output, "input_ids", None) is not None:
content_encoded = tokenizer.encode(turn["content"], truncation=False)
output.input_ids += content_encoded
output.label_ids += content_encoded
return output
reasoning_content = turn.get("reasoning_content", "")
if (
not reasoning_content
and isinstance(turn["content"], str)
and "</think>" in turn["content"]
):
parts = turn["content"].split("</think>", 1)
reasoning_content = parts[0].split("<think>", 1)[-1].lstrip("\n")
turn["content"] = parts[1].lstrip("\n")
if is_after_last_user_turn and (is_last_turn or reasoning_content):
content_to_strip = turn.get("content") or ""
stripped_content = content_to_strip.lstrip("\n")
if reasoning_content is None:
reasoning_content = ""
turn["content"] = (
f"<think>\n{reasoning_content.strip()}\n</think>\n\n{stripped_content}"
)
if turn.get("tool_calls"):
for tool_call in turn["tool_calls"]:
func_name = tool_call.get("function", {}).get("name", "")
args = tool_call.get("function", {}).get("arguments", {})
if isinstance(args, str):
try:
args = json.loads(args)
except Exception:
pass
if not isinstance(args, dict):
                    print(
                        f"[error] tool_call.function.arguments is not a dict: type={type(args)}, value={str(args)}"
                    )
                    assert (
                        False
                    ), "tool_call.function.arguments must be a dict or a JSON string describing a dict."
tool_turn_content = f"\n<tool_call>{func_name}\n"
for key, value in args.items():
arg_value = (
json.dumps(value, ensure_ascii=False)
if not isinstance(value, str)
else value
)
tool_turn_content += f"<arg_key>{key}</arg_key>\n<arg_value>{arg_value}</arg_value>\n"
tool_turn_content += "</tool_call>"
if func_name == "t2i_model_generation":
assert (
"<|t2i_model_generation_target_discrete_image|>"
in turn["content"]
), "t2i_model_generation tool call must have target discrete image tag in content."
turn["content"] = turn["content"].replace(
"<|t2i_model_generation_target_discrete_image|>",
tool_turn_content,
)
else:
turn["content"] += tool_turn_content
pattern = re.compile(
r"(<\|image\|>|<\|discrete_image\|>|<\|audio\|>|<\|discrete_audio\|>)"
)
all_tags_in_order = [
match.group() for match in pattern.finditer(turn["content"])
]
assert (
len(turn.get("image_urls", []))
+ len(turn.get("video_urls", []))
+ len(turn.get("audio_urls", []))
) == len(
all_tags_in_order
), f"Number of media URLs does not match number of media tags."
if hasattr(output, "input_str"):
output.input_str += f"{cls.new_line}{cls.turn_prefix}{turn['role']}"
if is_eval and is_last_turn:
if reasoning_content.strip() == "":
output.input_str += f"<think>\n\n</think>\n\n"
turn["content"] = stripped_content
else:
output.input_str += f"{turn['content']}{cls.turn_suffix}"
if getattr(output, "input_ids", None) is not None:
role_encoded = tokenizer.encode(
f"{cls.new_line}{cls.turn_prefix}{turn['role']}", truncation=False
)
output.input_ids += role_encoded
if is_eval and is_last_turn:
if reasoning_content.strip() == "":
output.input_ids += tokenizer.encode(
f"<think>\n\n</think>\n\n", truncation=False
)
turn["content"] = stripped_content
else:
if turn.get("trainable_role", True):
output.label_ids += role_encoded
else:
output.label_ids += [IGNORE_INDEX for _ in range(len(role_encoded))]
turn_img_idx = 0
content_parts = pattern.split(turn["content"].strip())
for part in content_parts:
part = part.strip()
if not part:
continue
if part not in [
"<|image|>",
"<|discrete_image|>",
"<|audio|>",
"<|discrete_audio|>",
]:
content_text = part
if hasattr(output, "input_str"):
output.input_str += "\n" + content_text
if getattr(output, "input_ids", None) is not None:
content_encoded = tokenizer.encode(
"\n" + content_text, truncation=False
)
output.input_ids += content_encoded
if turn.get("trainable_content", True):
output.label_ids += content_encoded
else:
output.label_ids += [
IGNORE_INDEX for _ in range(len(content_encoded))
]
continue
if part == "<|image|>":
file_name = turn.get("image_urls", [])[turn_img_idx]
if isinstance(file_name, str) and "#" in file_name:
file_name = file_name.split("#")[-1]
file_name = os.path.basename(file_name)
mime = Preprocessor.prompt_mime(
mimes=None,
fixed_mime=False,
file_name=file_name,
tag_idx=output.sample_mm_counter["image"],
is_video=False,
is_audio=False,
seed=None,
)
mime_str = f"{cls.mime_start}{json.dumps(mime, ensure_ascii=False)}{cls.mime_end}"
discrete_image_str = f"{cls.discrete_image_start}{cls.discrete_image_pad}{cls.discrete_image_end}"
vector_str = f"{cls.image_start}{cls.image_pad}{cls.image_end}"
mm_str = (
cls.new_line
+ mime_str
+ cls.new_line
+ discrete_image_str
+ cls.new_line
+ vector_str
)
if hasattr(output, "input_str"):
output.input_str += mm_str
if getattr(output, "input_ids", None) is not None:
token_ids = tokenizer.encode(mm_str, truncation=False)
output.input_ids += token_ids
output.label_ids += [
IGNORE_INDEX for _ in range(len(token_ids))
]
turn_img_idx += 1
output.sample_mm_counter["image"] += 1
elif part == "<|discrete_image|>":
discrete_image_str = f"{cls.discrete_image_start}{cls.discrete_image_pad}{cls.discrete_image_end}"
mm_str = cls.new_line + discrete_image_str
if hasattr(output, "input_str"):
output.input_str += mm_str
if getattr(output, "input_ids", None) is not None:
token_ids = tokenizer.encode(mm_str, truncation=False)
output.input_ids += token_ids
output.label_ids += token_ids
turn_img_idx += 1
elif part == "<|discrete_audio|>":
discrete_audio_str = f"{cls.discrete_audio_start}{cls.discrete_audio_pad}{cls.discrete_audio_end}"
mm_str = cls.new_line + discrete_audio_str
if hasattr(output, "input_str"):
output.input_str += mm_str
if getattr(output, "input_ids", None) is not None:
token_ids = tokenizer.encode(mm_str, truncation=False)
output.input_ids += token_ids
if turn.get("trainable_content", True):
output.label_ids += token_ids
else:
output.label_ids += [
IGNORE_INDEX for _ in range(len(token_ids))
]
elif part == "<|audio|>":
                raise Exception(
                    "The <|audio|> tag is not supported in assistant turns; only <|discrete_audio|> is supported."
                )
if hasattr(output, "input_str"):
output.input_str += cls.turn_suffix
if getattr(output, "input_ids", None) is not None:
token_ids = tokenizer.encode(cls.turn_suffix, truncation=False)
output.input_ids += token_ids
if turn.get("trainable_content", True):
output.label_ids += token_ids
else:
output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))]
return output
@classmethod
def prompt_tool(
cls,
output,
tokenizer=None,
turn: Optional[dict] = None,
role: Optional[str] = None,
content: Optional[str] = None,
eot: Optional[bool] = None,
need_start_tag=True,
need_end_tag=True,
):
assert (content and role) or turn
if turn is None:
turn = {
"content": content,
"role": role,
"endofturn": eot,
}
assert (
"tool" == turn["role"]
), f'[warning] unexpected turn["role"]: {turn["role"]}'
content_value = turn.get("content", "")
if isinstance(content_value, dict):
if "response" in content_value:
content_str = content_value["response"]
else:
content_str = json.dumps(content_value, ensure_ascii=False)
elif isinstance(content_value, str):
try:
parsed = json.loads(content_value)
if isinstance(parsed, dict):
if "response" in parsed:
content_str = parsed["response"]
else:
content_str = json.dumps(parsed, ensure_ascii=False)
else:
content_str = content_value
except (json.JSONDecodeError, TypeError):
content_str = content_value
else:
content_str = str(content_value)
turn["content"] = (
f"<tool_response>{turn.get('name', '')}\n{content_str}\n</tool_response>"
)
if hasattr(output, "input_str"):
if need_start_tag:
output.input_str += f"{cls.new_line}{cls.turn_prefix}{turn['role']}"
output.input_str += f"{cls.new_line}{turn['content']}"
if need_end_tag:
output.input_str += cls.turn_suffix
if getattr(output, "input_ids", None) is not None:
if need_start_tag:
role_encoded = tokenizer.encode(
f"{cls.new_line}{cls.turn_prefix}{turn['role']}", truncation=False
)
output.input_ids += role_encoded
if turn.get("trainable_role", True):
output.label_ids += role_encoded
else:
output.label_ids += [IGNORE_INDEX for _ in range(len(role_encoded))]
content = f"{cls.new_line}{turn['content']}"
content_encoded = tokenizer.encode(content, truncation=False)
if need_end_tag:
content_encoded += tokenizer.encode(
f"{cls.turn_suffix}", truncation=False
)
output.input_ids += content_encoded
if turn.get("trainable_content", True):
output.label_ids += content_encoded
else:
output.label_ids += [
IGNORE_INDEX for _ in range(len(content_encoded))
]
return output
@classmethod
def prompt_etc(
cls,
output,
tokenizer=None,
turn: Optional[dict] = None,
role: Optional[str] = None,
content: Optional[str] = None,
eot: Optional[bool] = None,
):
assert (content and role) or turn
if turn is None:
turn = {
"content": content,
"role": role,
"endofturn": eot,
}
print(f'[warning] unexpected turn["role"]: {turn["role"]}')
if hasattr(output, "input_str"):
output.input_str += f"{cls.turn_prefix}{turn['role']}\n"
output.input_str += f"{turn['content']}{cls.turn_suffix}"
if turn.get("stop", False):
output.input_str += cls.stop_token
if turn.get("endofturn", False):
output.input_str += cls.eot
if getattr(output, "input_ids", None) is not None:
role_encoded = tokenizer.encode(
f"{cls.turn_prefix}{turn['role']}\n", truncation=False
)
output.input_ids += role_encoded
if turn.get("trainable_role", True):
output.label_ids += role_encoded
else:
output.label_ids += [IGNORE_INDEX for _ in range(len(role_encoded))]
content = f"{turn['content']}{cls.turn_suffix}"
if turn.get("stop", False):
content += cls.stop_token
if turn.get("endofturn", False):
content += cls.eot
content_encoded = tokenizer.encode(content, truncation=False)
output.input_ids += content_encoded
if turn.get("trainable_content", True):
output.label_ids += content_encoded
else:
output.label_ids += [IGNORE_INDEX for _ in range(len(content_encoded))]
return output
def __call__(self, sample):
return self.preprocess_new(sample)
@classmethod
def batchify(
cls,
        items: List[Dict[str, Any]],
device: str = None,
):
batch = dict()
for item in items:
for k, v in item.items():
if isinstance(v, torch.Tensor):
if device is not None:
v = v.to(device=device)
elif k == "pixel_values":
v = [_v.to(device=device) for _v in v]
if k not in batch:
batch[k] = [
v,
]
else:
batch[k].append(v)
for k, v in batch.items():
if isinstance(v[0], torch.Tensor):
if k in ["image_grid_thw", "video_grid_thw"]:
batch[k] = torch.cat(v, dim=0)
continue
batch[k] = torch.stack(v, dim=0)
batch["video_grid_thw"] = None
batch["pixel_values_videos"] = None
return batch
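    # Illustrative sketch: batchify stacks per-sample tensors along a new batch dimension,
    # concatenates *_grid_thw tensors along dim 0, and keeps non-tensor values as lists;
    # video tensors are nulled out at the end.
    #
    #   item = {"input_ids": torch.zeros(8, dtype=torch.long), "image_grid_thw": torch.ones(1, 3)}
    #   batch = Preprocessor.batchify([item, item])
    #   # batch["input_ids"].shape == (2, 8), batch["image_grid_thw"].shape == (2, 3)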
def convert_wds_to_datalake(
self,
img: Union[PIL.Image.Image, Dict[str, PIL.Image.Image]] = {},
json: Dict[str, Any] = {},
benchmark: Optional[str] = None,
video: Union[io.BytesIO, Dict[str, io.BytesIO]] = {},
audio: Union[io.BytesIO, Dict[str, io.BytesIO]] = {},
):
if "lines" in json:
del json["lines"]
if "paragraphs" in json:
del json["paragraphs"]
assert json["meta"]["type"] in [
"caption",
"vqa",
"textread",
], f"{json['meta']['path']}, {json['meta']['type']}: The dataset type should be one of them: caption, vqa, textread."
sample = {"vlm": {}}
sample["vlm"] = get_wds_default_config(
json["meta"], existing_default_config=self.wds_default_config
)
sample["vlm"]["data_name"] = json["meta"].get("name", "unk")
sample["vlm"]["data_type"] = (
"wds"
if (isinstance(img, PIL.Image.Image) and img)
or (isinstance(img, dict) and len(img) > 0)
else "sft1"
)
sample["vlm"]["sample_id"] = json.get("qa_id", None)
sample["vlm"]["category"] = json.get("category", None)
sample["vlm"]["data_info"] = json.get("data_info", dict())
sample["vlm"]["options"] = None
if "choices_en" in sample["vlm"]["data_info"]:
if sample["vlm"]["options"] is None and json["meta"]["lang"] == "en":
sample["vlm"]["options"] = sample["vlm"]["data_info"]["choices_en"]
sample["vlm"]["options_en"] = sample["vlm"]["data_info"]["choices_en"]
if "choices_ko" in sample["vlm"]["data_info"]:
if sample["vlm"]["options"] is None and json["meta"]["lang"] == "ko":
sample["vlm"]["options"] = sample["vlm"]["data_info"]["choices_ko"]
sample["vlm"]["options_ko"] = sample["vlm"]["data_info"]["choices_ko"]
sample["vlm"]["image_index"] = json.get(
"image_index", json.get("img_url", None)
)
if sample["vlm"].get("video", False):
is_multi_image_dataset = False
else:
is_multi_image_dataset, img, json = convert_format_for_multi_image(
img, json
)
if json["meta"]["type"] == "textread":
key = "words"
elif json["meta"].get("subtask", "") == "region":
key = f"regions_{json['meta']['lang']}"
elif json["meta"]["type"] == "vqa":
key = f"qa_pairs_{json['meta']['lang']}"
elif json["meta"]["type"] == "caption":
key = f"captions_{json['meta']['lang']}"
else:
raise ConditionalError(
f"wrong task type in wds config: {sample['vlm']['data_name']}"
)
turns = [
{
"role": "tool_list",
"content": "",
"content_type": "text",
"trainable_role": False,
"trainable_content": False,
"stop": False,
"debuggingInfo": {},
"meta": {},
"candidates": [],
"endofturn": False,
},
{
"role": "system",
"content_type": "text",
"candidates": [],
"trainable_role": False,
"trainable_content": False,
"stop": False,
"debuggingInfo": {},
"meta": {},
"content": "",
"endofturn": False,
},
]
if json["meta"].get("llava_pretrain", False):
sample["vlm"]["llava_pretrain"] = True
use_task_prompt = json["meta"].get(
"use_task_prompt", self.wds_default_config["use_task_prompt"]
)
get_random = json["meta"].get(
"get_random", self.wds_default_config["get_random"]
)
reasoning = json["meta"].get("reasoning", self.wds_default_config["reasoning"])
try:
if key not in json:
key = key[:-3]
assert key in json
if len(json[key]) == 0:
key = key[:-3]
assert key in json
        except Exception:
raise ConditionalError(
f"{key} key is not in json? dataset name: {sample['vlm']['data_name']}"
)
first_turn = True
if "region" in key:
json[key] = json[key]["00"]
sample["vlm"]["multiturn_n_samples"] = 1
if (
not is_multi_image_dataset
and sample["vlm"]["multiturn_n_samples"] > 1
or "region" in key
):
json[key] = sampling_multiturn_single_img(
json[key],
sample["vlm"]["multiturn_n_samples"],
sample["vlm"]["multiturn_preserve_order"],
sample["vlm"]["multiturn_continuous"],
)
if sample["vlm"].get("video", False):
for qa in json[key]:
vid_src = []
user = {
"role": "user",
"content_type": "text",
"candidates": [],
"trainable_role": False,
"trainable_content": False,
"stop": False,
"debuggingInfo": {},
"meta": {},
"image_urls": [],
"image_metas": [],
"video_urls": [],
"video_metas": [],
"audio_urls": [],
"audio_metas": [],
"content": "",
"endofturn": False,
}
instruct_prompt, task_prompt = hcx_vision_prompter(
task=json["meta"]["type"],
subtask=json["meta"].get("subtask", None),
lang=json["meta"]["lang"],
get_random=get_random,
use_task_prompt=use_task_prompt,
)
prompt = qa[0]
answer = qa[-1] if reasoning else qa[1]
if first_turn:
user["video_metas"].append({"lens": []})
user["content"] += "<|video|>"
prompt = task_prompt.format(prompt)
if "entities" in json:
user["video_metas"][0]["lens"] = json["entities"].get("00", [])
if isinstance(video, dict):
vid_src.append(video["00"])
else:
vid_src.append(video)
first_turn = False
user["video_urls"] = vid_src
user["content"] += prompt
assistant = {
"candidates": [],
"content": answer,
"content_type": "text",
"debuggingInfo": {},
"meta": {},
"role": "assistant",
"trainable_content": True,
"trainable_role": True,
"stop": False,
"endofturn": True,
}
turns.append(user)
turns.append(assistant)
else:
if key.startswith("qa_pairs") or key.startswith("captions"):
if self.mode != "train" and key.startswith("qa_pairs"):
qas = dict()
for qa in json[key]:
q = qa[0]
if q not in qas:
qas[q] = list()
for _i, _e in enumerate(qa[1:]):
if len(qas[q]) <= _i:
qas[q].append(list())
qas[q][_i].append(_e)
json[key] = [
[
k,
]
+ v
for k, v in qas.items()
]
if self.mode != "train":
json[key] = json[key][:1]
for qa in json[key]:
img_src = []
user = {
"role": "user",
"content_type": "text",
"candidates": [],
"trainable_role": False,
"trainable_content": False,
"stop": False,
"debuggingInfo": {},
"meta": {},
"image_urls": [],
"image_metas": [],
"video_urls": [],
"video_metas": [],
"audio_urls": [],
"audio_metas": [],
"content": "",
"endofturn": False,
}
img_keys = re.findall(r"<image_(\d+)>", qa[0])
video_keys = re.findall(r"<video_(\d+)>", qa[0])
audio_keys = re.findall(r"<audio_(\d+)>", qa[0])
if key.startswith("qa_pairs"):
if len(qa) > 2:
sample_id = qa[2]
if (
isinstance(sample_id, (list, tuple))
and len(sample_id) > 0
):
sample_id = sample_id[0]
sample["vlm"]["sample_id"] = sample_id
instruct_prompt, task_prompt = hcx_vision_prompter(
task=json["meta"]["type"],
subtask=json["meta"].get("subtask", None),
lang=json["meta"]["lang"],
get_random=get_random,
use_task_prompt=use_task_prompt,
)
if json["meta"]["type"] == "vqa":
prompt = qa[0]
answer = qa[-1] if reasoning else qa[1]
elif json["meta"]["type"] == "caption":
prompt = task_prompt.format("")
answer = qa
if first_turn or self.mode != "train":
if json["meta"]["type"] == "vqa":
prompt = task_prompt.format(prompt)
if first_turn and not is_multi_image_dataset:
user["image_metas"].append({"words": [], "lens": []})
if "<image_00>" in prompt:
prompt = prompt.replace("<image_00>", "<|image|>")
else:
user["content"] += "<|image|>"
user["image_metas"][0]["words"] = json.get("words", {}).get(
"00", []
)
if "objects" in json:
user["image_metas"][0]["lens"] = json["objects"].get(
"00", []
)
elif "entities" in json:
user["image_metas"][0]["lens"] = json["entities"].get(
"00", []
)
if isinstance(img, dict):
img_src.append(img["00"])
else:
img_src.append(img)
elif len(img_keys) > 0:
                            # Use a distinct loop variable so the outer dataset `key` is not clobbered.
                            for i, img_key in enumerate(img_keys):
                                user["image_metas"].append({"words": [], "lens": []})
                                if f"<image_{i:02d}>" in prompt:
                                    prompt = prompt.replace(f"<image_{i:02d}>", "<|image|>")
                                else:
                                    user["content"] += "<|image|>"
                                img_src.append(img[img_key])
                                _words = json.get("words", {})
                                if isinstance(_words, dict):
                                    _words = _words.get(img_key, [])
                                user["image_metas"][i]["words"] = _words
                                if "objects" in json:
                                    _objects = json["objects"].get(img_key, [])
                                    if isinstance(_objects, dict):
                                        _objects = _objects.get(img_key, [])
                                    user["image_metas"][i]["lens"] = _objects
                                if "entities" in json:
                                    _entities = json["entities"].get(img_key, [])
                                    if isinstance(_entities, dict):
                                        _entities = _entities.get(img_key, [])
                                    user["image_metas"][i]["lens"] = _entities
user["image_urls"] = img_src
if len(audio_keys) > 0:
                        for i, audio_key in enumerate(audio_keys):
                            if isinstance(audio, dict):
                                user["audio_urls"].append(audio[audio_key])
else:
user["audio_urls"].append(audio)
user["audio_metas"].append(
{
"format": "wav",
"note": "This audio sample is passed to convert_wds_to_datalake function.",
}
)
if f"<audio_{i:02d}>" in prompt:
prompt = prompt.replace(f"<audio_{i:02d}>", "<|audio|>")
else:
user["content"] += "<|audio|>"
user["content"] += prompt
content, candidates = None, list()
if self.mode != "train":
if isinstance(answer, (int, float)):
pass
elif isinstance(answer, str):
if answer != "None":
try:
answer = ast.literal_eval(answer)
except Exception as ex:
pass
if not isinstance(answer, (list, tuple)):
answer = [
answer,
]
candidates += answer[1:]
answer = answer[0]
content = answer
elif isinstance(answer, (list, tuple)):
for _idx, _answer in enumerate(answer):
if isinstance(_answer, str):
if isinstance(benchmark, str) and benchmark in [
"textvqa",
]:
try:
_answer = ast.literal_eval(_answer)
except Exception as ex:
pass
if isinstance(_answer, dict):
_answer = str(_answer)
if not isinstance(_answer, (list, tuple)):
_answer = [
_answer,
]
if _idx == 0:
content = _answer[0]
candidates += _answer[1:]
else:
candidates += _answer
if isinstance(content, (int, float)):
content = str(content)
assert content is None or isinstance(content, str)
for _idx, _candidate in enumerate(candidates):
if isinstance(_candidate, (int, float)):
candidates[_idx] = str(_candidate)
assert isinstance(candidates[_idx], str)
mcqa_gt = sample["vlm"]["data_info"].get("choice_answer", None)
if isinstance(mcqa_gt, str):
content = mcqa_gt
assistant = {
"candidates": candidates,
"content": answer if self.mode == "train" else content,
"content_type": "text",
"debuggingInfo": {},
"meta": {},
"role": "assistant",
"trainable_content": True,
"trainable_role": True,
"stop": False,
"endofturn": True,
}
turns.append(user)
turns.append(assistant)
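            # OCR-style "words" samples: the user turn shows image "00" with the task
            # prompt, and the target is every word text joined with spaces.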
elif key == "words":
img_src = []
user = {
"role": "user",
"content_type": "text",
"candidates": [],
"trainable_role": False,
"trainable_content": False,
"stop": False,
"debuggingInfo": {},
"meta": {},
"image_urls": [],
"image_metas": [],
"video_urls": [],
"video_metas": [],
"audio_urls": [],
"audio_metas": [],
"content": "<|image|>",
"endofturn": False,
}
instruct_prompt, task_prompt = hcx_vision_prompter(
task=json["meta"]["type"],
subtask=json["meta"].get("subtask", None),
lang=json["meta"]["lang"],
get_random=get_random,
use_task_prompt=use_task_prompt,
)
user["content"] += task_prompt
user["image_metas"].append({"words": [], "lens": []})
user["image_metas"][0]["words"] = json["words"]["00"]
if "entities" in json:
user["image_metas"][0]["lens"] = json["entities"].get("00", [])
img_src.append(img["00"])
user["image_urls"] = img_src
words_list = [
d["text"].strip() for d in json["words"]["00"] if d["text"]
]
gt = " ".join(words_list)
assistant = {
"candidates": [],
"content": gt,
"content_type": "text",
"debuggingInfo": {},
"meta": {},
"role": "assistant",
"trainable_content": True,
"trainable_role": True,
"stop": False,
"endofturn": True,
}
turns.append(user)
turns.append(assistant)
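            # Region samples: one user/assistant pair per annotated region; the box is
            # carried in image_metas[0]["region"] and the target is the region text.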
elif key.startswith("regions"):
for region in json[key]:
img_src = []
user = {
"role": "user",
"content_type": "text",
"candidates": [],
"trainable_role": False,
"trainable_content": False,
"stop": False,
"debuggingInfo": {},
"meta": {},
"image_urls": [],
"image_metas": [],
"video_urls": [],
"video_metas": [],
"audio_urls": [],
"audio_metas": [],
"content": "<|image|><|region|>",
"endofturn": False,
}
instruct_prompt, task_prompt = hcx_vision_prompter(
task=json["meta"]["type"],
subtask=json["meta"].get("subtask", None),
lang=json["meta"]["lang"],
get_random=get_random,
use_task_prompt=use_task_prompt,
)
sample["vlm"]["query_template"] = [task_prompt]
user["image_metas"].append({"words": [], "lens": []})
user["image_metas"][0]["region"] = region
if "words" in json:
user["image_metas"][0]["words"] = json["words"].get("00", [])
if "objects" in json:
user["image_metas"][0]["lens"] = json["objects"].get("00", [])
if "entities" in json:
user["image_metas"][0]["lens"] = json["entities"].get("00", [])
img_src.append(img["00"])
user["image_urls"] = img_src
assistant = {
"candidates": [],
"content": region["text"],
"content_type": "text",
"debuggingInfo": {},
"meta": {},
"role": "assistant",
"trainable_content": True,
"trainable_role": True,
"stop": False,
"endofturn": True,
}
turns.append(user)
turns.append(assistant)
else:
raise ConditionalError(
f"wrong task type in wds config: {sample['vlm']['data_name']}"
)
sample["data"] = turns
return sample
def preprocess_new(self, sample):
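        """Convert one conversation sample into model-ready tensors.

        Tokenizes the system/user/assistant/tool turns, loads and preprocesses every
        referenced image, video, and audio clip, expands each multimodal placeholder
        token into its run of query tokens (masking labels where appropriate), and
        returns the dict consumed by the collator.
        """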
config = sample.get("vlm", {})
if config["data_type"] in ["sft1", "datalake"]:
default_config = copy.deepcopy(self.default_config)
default_config.update(config)
config = default_config
idx_for_debug = sample.get("idx", -1)
turns = sample["data"] if "data" in sample else sample["messages"]
if self.random_system_prompt and self.rng.random() < config.get(
"random_system_prob", 0.0
):
for turn in turns:
if turn["role"] == "system":
turn["content"] = self.random_system_prompt
break
if sample.get("tools", None) is None:
sample["tools"] = []
if len(sample["tools"]) == 0:
if (
self.rng.random() < config.get("random_tool_prob", 0.005)
and len(self.common_tools) > 0
):
max_n_tools = min(7, len(self.common_tools))
tool_counts = np.arange(1, max_n_tools + 1)
tool_count_weights = 1.0 / tool_counts
tool_count_weights = tool_count_weights / tool_count_weights.sum()
n_tools = int(self.rng.choice(tool_counts, p=tool_count_weights))
idxs = np.arange(len(self.common_tools))
weights = 1.0 / (idxs + 1)
weights[0] += 1.0
weights = weights / weights.sum()
chosen_indices = self.rng.choice(
len(self.common_tools), size=n_tools, replace=False, p=weights
)
self.rng.shuffle(chosen_indices)
sample["tools"] = [self.common_tools[i] for i in chosen_indices]
if "tools" in sample and sample["tools"]:
tool_prompt = []
tool_prompt.append("# Tools\n\n")
tool_prompt.append(
"You may call one or more functions to assist with the user query.\n\n"
)
tool_prompt.append(
"You are provided with function signatures within <tools></tools> XML tags:\n"
)
tool_prompt.append("<tools>\n")
for tool in sample["tools"]:
tool_prompt.append(json.dumps(tool, ensure_ascii=False))
tool_prompt.append("\n</tools>\n\n")
tool_prompt.append(
"For each function call, output the function name and arguments within the following XML format:\n"
)
tool_prompt.append("<tool_call>{function-name}\n")
tool_prompt.append("<arg_key>{arg-key-1}</arg_key>\n")
tool_prompt.append("<arg_value>{arg-value-1}</arg_value>\n")
tool_prompt.append("<arg_key>{arg-key-2}</arg_key>\n")
tool_prompt.append("<arg_value>{arg-value-2}</arg_value>\n")
tool_prompt.append("...\n")
tool_prompt.append("</tool_call>")
tool_prompt = "".join(tool_prompt)
else:
tool_prompt = ""
multiturn_n_sample = config.get("multiturn_n_samples", 0)
if multiturn_n_sample > 0 and self.mode == "train":
turns = self._sampling_multiturn(
turns,
multiturn_n_sample,
multiturn_preserve_order=config.get("multiturn_preserve_order", True),
multiturn_continuous=config.get("multiturn_continuous", False),
)
for i, turn in enumerate(turns):
if turn["role"] == "user":
if "img_src" in turn:
turns[i]["image_urls"] = turn["img_src"]
turns[i]["image_metas"] = turn["meta"]
for j, turn_img_meta in enumerate(turns[i]["image_metas"]):
if "entities" in turn_img_meta:
turns[i]["image_metas"][j]["lens"] = turn_img_meta[
"entities"
]
turns[i]["meta"] = {}
max_image_cnt = config.get("max_image_cnt", 20)
if max_image_cnt > 0 and config["data_type"] != "sft1":
n_imgs = {}
for i, turn in enumerate(turns):
if turn["role"] == "user":
n_imgs[i] = len(turn.get("image_urls", []))
assert (
n_imgs[i] <= max_image_cnt
), "skip sample if image_nums exceeds max_image_count per turn"
if sum(n_imgs.values()) > max_image_cnt:
img_count = 0
for k, v in reversed(list(n_imgs.items())):
img_count += v
if img_count > max_image_cnt:
break
img_count = sum(n_imgs.values()) - max_image_cnt
                for i in range(k + 1):
                    if img_count <= 0:
                        # Excess images already removed; re.subn with count=0 means
                        # "replace all" and would strip tags that should be kept.
                        break
                    if turns[i]["role"] == "user":
                        turns[i]["content"], n_removed1 = re.subn(
                            r"<image_\d{2}>",
                            "",
                            turns[i]["content"].strip(),
                            count=img_count,
                        )
                        img_count -= n_removed1
                        n_removed2 = 0
                        if img_count > 0:
                            turns[i]["content"], n_removed2 = re.subn(
                                r"<\|image\|>",
                                "",
                                turns[i]["content"].strip(),
                                count=img_count,
                            )
                            img_count -= n_removed2
n_removed_imgs = n_removed1 + n_removed2
turns[i]["image_urls"] = turns[i]["image_urls"][n_removed_imgs:]
if n_removed_imgs > 0 and len(turns[i]["image_urls"]) == 0:
idx = i
while True:
idx += 1
turns[idx]["trainable_role"] = False
turns[idx]["trainable_content"] = False
if turns[idx]["role"] == "assistant":
break
n_imgs_after = {}
for i, turn in enumerate(turns):
if turn["role"] == "user":
n_imgs_after[i] = len(turn.get("image_urls", []))
assert sum(n_imgs_after.values()) > 0, "The n_imgs of vlm data is zero."
n_mm_after = {}
for i, turn in enumerate(turns):
if turn["role"] == "user" or turn["role"] == "assistant":
n_mm_after[i] = (
len(turn.get("image_urls", []))
+ len(turn.get("video_urls", []))
+ len(turn.get("audio_urls", []))
)
assert sum(n_mm_after.values()) > 0, "The n_mm of omni data is zero."
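        # Accumulate the tokenized prompt and every decoded multimodal asset here.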
queries, gts = list(), list()
output = Processed_sample(
input_str="",
input_ids=[],
label_ids=[],
imgs=[],
discrete_imgs=[],
videos=[],
videos_duration=[],
video_audios=[],
audios=[],
audios_duration=[],
discrete_audios=[],
sample_mm_counter={
"image": 0,
"video": 0,
"audio": 0,
},
)
system_role_count = 0
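        # The first assistant turn after the final user turn is flagged as the last
        # turn when building the assistant prompt.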
last_user_idx = max(
(i for i, d in enumerate(turns) if d.get("role") == "user"), default=-1
)
for i, turn in enumerate(turns):
if turn["role"] == "tool_list":
continue
elif turn["role"] == "system":
if config.get("llava_pretrain", False):
continue
output = Preprocessor.prompt_system(
turn=turn,
output=output,
tokenizer=self.tokenizer,
seed=self.rng,
tool_prompt=tool_prompt,
system_role_count=system_role_count,
)
system_role_count += 1
elif turn["role"].startswith("user"):
output = Preprocessor.load_mm(
output=output,
img_dir=config.get("img_dir", ""),
turn=turn,
prepare_input_fn=self.prepare_input_fn,
max_image_cnt=max_image_cnt,
video_max_num_frames=self.video_max_num_frames,
video_max_pixels=self.video_max_pixels,
use_audio=self.train_audio,
)
output = Preprocessor.prompt_user(
output=output,
tokenizer=self.tokenizer,
turn=turn,
is_train=True if self.mode == "train" else False,
fixed_mime=config.get("fixed_mime", False),
mimes=self.mimes,
query_template=config.get("query_template", None),
config=config,
seed=self.rng,
)
queries.append(turn["content"].replace("<|image|>", "").strip())
elif turn["role"].startswith("assistant"):
output = Preprocessor.load_mm(
output=output,
img_dir=config.get("img_dir", ""),
turn=turn,
prepare_input_fn=self.prepare_input_fn,
max_image_cnt=max_image_cnt,
video_max_num_frames=self.video_max_num_frames,
video_max_pixels=self.video_max_pixels,
use_audio=self.train_audio,
)
is_after_last_user = i > last_user_idx
is_first_assistant_after_last_user = False
if is_after_last_user:
is_first_assistant_after_last_user = all(
turns[j]["role"] != "assistant"
for j in range(last_user_idx + 1, i)
)
output = Preprocessor.prompt_assistant(
output=output,
tokenizer=self.tokenizer,
turn=turn,
is_last_turn=is_first_assistant_after_last_user,
is_eval=True if self.mode != "train" else False,
is_llava_pretrain=config.get("llava_pretrain", False),
is_after_last_user_turn=is_after_last_user,
)
_gts = turn["content"]
if isinstance(_gts, str):
_gts = [
_gts,
]
if "candidates" in turn and len(turn["candidates"]) > 0:
for _candidates in turn["candidates"]:
if isinstance(_candidates, str):
_gts += [
_candidates,
]
                        elif isinstance(_candidates, (list, tuple)):
_gts += _candidates
gts.append(_gts)
elif turn["role"] == "tool":
if config.get("llava_pretrain", False):
continue
output = Preprocessor.prompt_tool(
output=output,
tokenizer=self.tokenizer,
turn=turn,
need_start_tag=(
True
if (i == 0 or turns[i - 1].get("role") != "tool")
else False
),
need_end_tag=(
True
if (i == (len(turns) - 1) or turns[i + 1].get("role") != "tool")
else False
),
)
else:
if config.get("llava_pretrain", False):
continue
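                # Unknown role: fall back to the generic prompt builder.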
output = Preprocessor.prompt_etc(
output=output,
tokenizer=self.tokenizer,
turn=turn,
)
pixel_values = []
mm_query_lengths = []
discrete_pixel_values = []
image_ratios = []
discrete_image_query_lengths = []
labels = output.label_ids
input_ids = output.input_ids
total_mm_query_length = 0
is_sft1 = False
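        # Text-only (sft1) samples: pad to a multiple of sequence_parallel_size and
        # keep only the last decoder_max_length tokens.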
if config["data_type"] == "sft1":
if self.sequence_parallel_size > 1:
if len(input_ids) % self.sequence_parallel_size != 0:
input_ids += [self.tokenizer.pad_token_id] * (
self.sequence_parallel_size
- (len(input_ids) % self.sequence_parallel_size)
)
labels += [IGNORE_INDEX] * (
self.sequence_parallel_size
- (len(labels) % self.sequence_parallel_size)
)
input_ids = input_ids[
: (len(input_ids) // self.sequence_parallel_size)
* self.sequence_parallel_size
]
labels = labels[
: (len(labels) // self.sequence_parallel_size)
* self.sequence_parallel_size
]
input_ids = torch.tensor(input_ids[-self.decoder_max_length :])
labels = torch.tensor(labels[-self.decoder_max_length :])
is_sft1 = True
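        # Dummy preprocessed tensors for each modality, presumably so downstream
        # collation always has a valid tensor even when a modality is absent.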
dummy_preprocess_results = self.prepare_input_fn.image_processor(
Image.new("RGB", (224, 224), (0, 0, 0))
)
dummy_pixel_values = torch.from_numpy(
np.concatenate([dummy_preprocess_results.pixel_values], axis=0)
)
dummy_grid_thw = torch.from_numpy(
np.concatenate([dummy_preprocess_results.image_grid_thw], axis=0)
)
image_grid_thw = []
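        # Continuous image features: pixel_values rows are grouped 4-to-1 into query
        # tokens (mm_query_lengths), with the patch grid recorded in image_grid_thw.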
for img in output.imgs:
w, h = img.size
img = self._resize_min_edge(img)
preprocess_results = self.prepare_input_fn.image_processor([img])
pixel_values.append(preprocess_results.pixel_values)
image_grid_thw.append(preprocess_results.image_grid_thw)
mm_query_lengths.append(preprocess_results.pixel_values.shape[0] // 4)
if len(output.imgs) == 0:
pixel_values = torch.zeros(0, 1176)
image_grid_thw = torch.zeros(0, 3, dtype=torch.long)
else:
pixel_values = torch.from_numpy(np.concatenate(pixel_values, axis=0))
image_grid_thw = torch.from_numpy(np.concatenate(image_grid_thw, axis=0))
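        # Discrete image path: pick the closest aspect-ratio token, resize to 384x384,
        # and reserve 729 (= 27x27) query tokens per image.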
for img in output.discrete_imgs:
w, h = img.size
img_ratio = self._find_best_ratio_token([h, w])
image_ratios.append(img_ratio)
discrete_pixel_value = img.resize((384, 384), Image.BICUBIC)
discrete_pixel_tensor = to_tensor(discrete_pixel_value)
assert discrete_pixel_tensor.shape == (
3,
384,
384,
), f"Unexpected discrete_pixel_tensor shape: {discrete_pixel_tensor.shape}"
assert not torch.isnan(
discrete_pixel_tensor
).any(), "discrete_pixel_tensor contains NaN"
assert not torch.isinf(
discrete_pixel_tensor
).any(), "discrete_pixel_tensor contains Inf"
pixel_min = discrete_pixel_tensor.min().item()
pixel_max = discrete_pixel_tensor.max().item()
assert (
0.0 <= pixel_min <= 1.0 and 0.0 <= pixel_max <= 1.0
), f"discrete_pixel_tensor values out of range [0, 1]: min={pixel_min}, max={pixel_max}"
discrete_pixel_values.append(discrete_pixel_tensor)
discrete_image_query_lengths.append(729)
if len(output.discrete_imgs) == 0:
discrete_pixel_values = torch.zeros(0, 3, 384, 384)
else:
discrete_pixel_values = torch.stack(discrete_pixel_values, dim=0)
assert discrete_pixel_values.shape[1:] == (
3,
384,
384,
), f"Unexpected stacked discrete_pixel_values shape: {discrete_pixel_values.shape}"
assert not torch.isnan(
discrete_pixel_values
).any(), "Stacked discrete_pixel_values contains NaN"
assert not torch.isinf(
discrete_pixel_values
).any(), "Stacked discrete_pixel_values contains Inf"
pixel_values_videos = None
video_grid_thw = None
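        # Video frames follow the same continuous path as images: pixel rows are
        # grouped 4-to-1 into query tokens per video.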
if self.train_video:
pixel_values_videos = []
video_grid_thw = []
video_query_lengths = []
for video in output.videos:
preprocess_results = self.prepare_input_fn.video_processor([video])
pixel_values_videos.append(preprocess_results.pixel_values_videos)
video_grid_thw.append(preprocess_results.video_grid_thw)
video_query_lengths.append(
preprocess_results.pixel_values_videos.shape[0] // 4
)
if len(output.videos) == 0:
pixel_values_videos = torch.zeros(0, 1176)
video_grid_thw = torch.zeros(0, 3, dtype=torch.long)
else:
pixel_values_videos = torch.from_numpy(
np.concatenate(pixel_values_videos, axis=0)
)
video_grid_thw = torch.from_numpy(
np.concatenate(video_grid_thw, axis=0)
)
video_audio_values = []
video_audio_masks = []
video_audio_query_lengths = []
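        # Audio tracks extracted from videos are processed chunk by chunk; the
        # (mask_sum - 1) // 2 + 1 then (x - 2) // 2 + 1 arithmetic presumably mirrors
        # two stride-2 stages in the audio encoder (30 s of mel frames -> 750 outputs),
        # and outputs are pooled in groups of 25 (~1 token per second of audio).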
if self.train_video and hasattr(output, "video_audios") and output.video_audios:
for idx, video_audio_chunks in enumerate(output.video_audios):
if video_audio_chunks:
processed_audio_values = []
processed_audio_masks = []
chunk_output_lengths = []
for chunk in video_audio_chunks:
if isinstance(chunk, torch.Tensor):
chunk_np = chunk.cpu().numpy()
else:
chunk_np = chunk
preprocess_results = self.prepare_audio_input_fn(
[chunk_np],
sampling_rate=self.prepare_audio_input_fn.sampling_rate,
return_attention_mask=True,
padding="max_length",
)
audio_value = preprocess_results.input_features[0]
audio_mask = preprocess_results.attention_mask[0]
mask_sum = int(audio_mask.sum())
input_lengths = (mask_sum - 1) // 2 + 1
output_lengths = (input_lengths - 2) // 2 + 1
chunk_output_lengths.append(output_lengths)
processed_audio_values.append(torch.from_numpy(audio_value))
processed_audio_masks.append(torch.from_numpy(audio_mask))
pool_size = 25
if self.video_audio_compressor_type is not None:
total_valid_len = sum(chunk_output_lengths)
total_audio_query_length = (
total_valid_len + pool_size - 1
) // pool_size
else:
total_audio_query_length = sum(
(valid_len + pool_size - 1) // pool_size
for valid_len in chunk_output_lengths
)
video_audio_values.append(processed_audio_values)
video_audio_masks.append(processed_audio_masks)
video_audio_query_lengths.append(total_audio_query_length)
else:
video_audio_values.append([])
video_audio_masks.append([])
video_audio_query_lengths.append(0)
dummy_video_preprocess_results = self.prepare_input_fn.video_processor(
[Image.new("RGB", (224, 224), (0, 0, 0))] * 3
)
dummy_pixel_values_videos = torch.from_numpy(
np.concatenate([dummy_video_preprocess_results.pixel_values_videos], axis=0)
)
dummy_video_grid_thw = torch.from_numpy(
np.concatenate([dummy_video_preprocess_results.video_grid_thw], axis=0)
)
        dummy_video_audio_preprocess_results = self.prepare_audio_input_fn(
            [np.zeros(self.prepare_audio_input_fn.sampling_rate * 3, dtype=np.float32)],
            sampling_rate=self.prepare_audio_input_fn.sampling_rate,
            return_attention_mask=True,
            padding="max_length",
        )
        dummy_video_audio_values = torch.from_numpy(
            dummy_video_audio_preprocess_results.input_features
        )
        dummy_video_audio_masks = torch.from_numpy(
            dummy_video_audio_preprocess_results.attention_mask
        )
audio_values = None
discrete_audio_values = None
audio_masks = None
dummy_preprocess_results = self.prepare_audio_input_fn(
[np.zeros(self.prepare_audio_input_fn.sampling_rate * 3, dtype=np.float32)],
sampling_rate=self.prepare_audio_input_fn.sampling_rate,
return_attention_mask=True,
padding="max_length",
)
dummy_audio_values = torch.from_numpy(dummy_preprocess_results.input_features)
dummy_audio_masks = torch.from_numpy(dummy_preprocess_results.attention_mask)
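        # Standalone audio clips: split into 30 s chunks (the feature extractor's
        # fixed window), keep the padded mel features and masks, and derive the
        # encoder output length from the valid-mask sum.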
if self.train_audio:
audio_values = []
discrete_audio_values = []
audio_masks = []
audio_query_lengths = []
discrete_audio_query_lengths = []
if len(output.audios) > 99:
raise ConditionalError(
f"Too many audio segments in one sample: {len(output.audios)} audios."
)
for audio in output.audios:
chunks = []
for i in range(
0, len(audio), 30 * self.prepare_audio_input_fn.sampling_rate
):
chunks.append(
audio[i : i + 30 * self.prepare_audio_input_fn.sampling_rate]
)
num_of_chunks = len(chunks)
preprocess_results = self.prepare_audio_input_fn(
chunks,
sampling_rate=self.prepare_audio_input_fn.sampling_rate,
return_attention_mask=True,
padding="max_length",
)
audio_value = preprocess_results.input_features
audio_mask = preprocess_results.attention_mask
audio_values.append(audio_value)
audio_masks.append(audio_mask)
input_lengths = int(audio_mask.sum())
input_lengths = (input_lengths - 1) // 2 + 1
output_lengths = (input_lengths - 2) // 2 + 1
audio_query_lengths.append(output_lengths)
if len(output.audios) == 0:
audio_values = torch.zeros(0, 128, 3000)
audio_masks = torch.zeros(0, 3000)
else:
audio_values = torch.from_numpy(np.concatenate(audio_values, axis=0))
audio_masks = torch.from_numpy(np.concatenate(audio_masks, axis=0))
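            # Discrete (speech-token) audio: validate the waveform, then estimate the
            # number of codec tokens at ~25 tokens/sec. E.g. a 4 s clip at 16 kHz ->
            # 64000 samples -> mel_len 400 -> 200 after the first stride-2 conv ->
            # code_len 100. Clips longer than 80 s are split so no single chunk
            # exceeds the 2048-position limit.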
for audio in output.discrete_audios:
audio_length = len(audio)
assert audio_length >= MIN_DISCRETE_AUDIO_CHUNK_SAMPLES, (
f"discrete_audio is too short ({audio_length} samples < {MIN_DISCRETE_AUDIO_CHUNK_SAMPLES}). "
f"This will cause 0-dim/empty tensor in CosyVoice encoder. "
f"Skip this sample."
)
max_audio_length = 600 * DEFAULT_SAMPLE_RATE
audio_duration_sec = audio_length / DEFAULT_SAMPLE_RATE
assert (
audio_length <= max_audio_length
), f"discrete_audio is too long ({audio_length} samples = {audio_duration_sec:.1f}s > 600s). "
assert not torch.isnan(audio).any(), (
f"discrete_audio contains NaN values! "
f"This will cause CUDA illegal memory access. Skip this sample."
)
assert not torch.isinf(audio).any(), (
f"discrete_audio contains Inf values! "
f"This will cause CUDA illegal memory access. Skip this sample."
)
audio_min, audio_max = audio.min().item(), audio.max().item()
assert -100.0 <= audio_min <= 100.0 and -100.0 <= audio_max <= 100.0, (
f"discrete_audio has extreme values (min={audio_min:.2f}, max={audio_max:.2f}). "
f"Expected roughly [-1, 1] range. This indicates corrupted audio. Skip this sample."
)
discrete_audio_values.append(audio)
if audio_length > 80 * DEFAULT_SAMPLE_RATE:
chunk_size = 80 * DEFAULT_SAMPLE_RATE
total_code_len = 0
for start in range(0, audio_length, chunk_size):
end = min(start + chunk_size, audio_length)
if (
end < audio_length
and audio_length - end < MIN_DISCRETE_AUDIO_CHUNK_SAMPLES
):
end = audio_length
chunk_length = end - start
assert chunk_length >= MIN_DISCRETE_AUDIO_CHUNK_SAMPLES, (
f"chunk_length={chunk_length} < {MIN_DISCRETE_AUDIO_CHUNK_SAMPLES}. This should never happen with our chunking logic. "
f"audio_length={audio_length}, start={start}, end={end}. Skip this sample."
)
mel_len = chunk_length // 160
assert mel_len > 0, (
f"mel_len={mel_len} is invalid (chunk_length={chunk_length}). "
f"This will cause illegal memory access in AudioEncoder. Skip this sample."
)
after_conv1 = (mel_len + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1
code_len = (after_conv1 + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1
assert code_len > 0, (
f"code_len={code_len} is invalid (mel_len={mel_len}, after_conv1={after_conv1}). "
f"This will cause illegal memory access. Skip this sample."
)
total_code_len += code_len
if end >= audio_length:
break
assert total_code_len > 0, (
f"total_code_len={total_code_len} is invalid after processing all chunks. "
f"audio_length={audio_length}. This should never happen. Skip this sample."
)
audio_duration_sec = audio_length / DEFAULT_SAMPLE_RATE
max_expected_codes = int(audio_duration_sec * 25 * 1.1)
assert total_code_len <= max_expected_codes, (
f"total_code_len={total_code_len} is suspiciously large (max_expected={max_expected_codes}). "
f"audio_length={audio_length} ({audio_duration_sec:.1f}s). "
f"Expected ~{int(audio_duration_sec * 25)} tokens (25 tokens/sec). "
f"This indicates calculation error. Skip this sample."
)
discrete_audio_query_lengths.append(total_code_len)
else:
mel_len = audio_length // 160
assert mel_len > 0, (
f"mel_len={mel_len} is invalid (audio_length={audio_length}). "
f"This will cause illegal memory access in AudioEncoder. Skip this sample."
)
after_conv1 = (mel_len + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1
code_len = (after_conv1 + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1
assert code_len > 0, (
f"Calculated code_len={code_len} is invalid (audio_length={audio_length}, "
f"mel_len={mel_len}, after_conv1={after_conv1}). "
f"This indicates corrupted audio data. Skip this sample."
)
assert code_len <= 2048, (
f"code_len={code_len} exceeds freqs_cis max length (2048). "
f"Audio length: {audio_length / DEFAULT_SAMPLE_RATE:.1f}s (max ~82s for single chunk). "
f"Expected ~{int((audio_length / DEFAULT_SAMPLE_RATE) * 25)} tokens at 25 tokens/sec. "
f"This will cause illegal memory access in apply_rotary_emb. Skip this sample."
)
discrete_audio_query_lengths.append(code_len)
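        # Expand each <image> placeholder into its query-length run of image tokens;
        # these positions are masked out of the labels.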
img_start_ids = [
i for i, token in enumerate(input_ids) if token == self.img_token
]
assert len(img_start_ids) == len(mm_query_lengths)
for i, length in zip(
range(len(mm_query_lengths) - 1, -1, -1), mm_query_lengths[::-1]
):
labels[img_start_ids[i] : img_start_ids[i] + 1] = [IGNORE_INDEX] * length
input_ids[img_start_ids[i] : img_start_ids[i] + 1] = [
self.img_token
] * length
total_mm_query_length += length
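        # Discrete image placeholders become: aspect-ratio token + 729 image tokens
        # with an end-of-line token after every 27 + an end-of-frame token; labels
        # keep the sequence only where the position was trainable.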
discrete_image_start_ids = [
i for i, token in enumerate(input_ids) if token == self.discrete_image_token
]
assert len(discrete_image_start_ids) == len(discrete_image_query_lengths)
assert len(discrete_image_start_ids) == len(
image_ratios
), "discrete_image_start_ids and image_ratios length mismatch"
for idx in range(len(discrete_image_query_lengths) - 1, -1, -1):
i = discrete_image_start_ids[idx]
length = discrete_image_query_lengths[idx]
ratio_token_id = image_ratios[idx]
assert (
length == 729
), f"discrete_image_query_length must be 729, but got {length}"
token_sequence = [ratio_token_id]
for token_idx in range(length):
token_sequence.append(self.discrete_image_token)
if (token_idx + 1) % 27 == 0:
token_sequence.append(self.discrete_image_eol_token)
token_sequence.append(self.discrete_image_eof_token)
total_length = len(token_sequence)
if labels[i] == IGNORE_INDEX:
labels[i : i + 1] = [IGNORE_INDEX] * total_length
else:
labels[i : i + 1] = token_sequence
input_ids[i : i + 1] = token_sequence
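        # Video placeholders: per-frame visual tokens, interleaved with the clip's
        # audio tokens distributed as evenly as possible across frames when an audio
        # track exists.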
if self.train_video:
vid_start_ids = [
i for i, token in enumerate(input_ids) if token == self.video_token
]
for idx in range(len(vid_start_ids) - 1, -1, -1):
pos = vid_start_ids[idx]
num_frames = int(video_grid_thw[idx][0])
frame_query_length = video_query_lengths[idx]
has_video_audio = (
idx < len(video_audio_query_lengths)
and video_audio_query_lengths[idx] > 0
)
if has_video_audio:
total_audio_tokens = video_audio_query_lengths[idx]
token_sequence = []
if num_frames > 0:
frame_base = frame_query_length // num_frames
frame_remainder = frame_query_length % num_frames
assert frame_remainder == 0, (
f"frame_query_length({frame_query_length}) must be divisible by num_frames({num_frames}). "
f"Each frame produces fixed number of tokens. Got remainder={frame_remainder}."
)
audio_base = total_audio_tokens // num_frames
audio_remainder = total_audio_tokens % num_frames
for frame_idx in range(num_frames):
frame_tokens = frame_base + (
1 if frame_idx < frame_remainder else 0
)
token_sequence.extend([self.video_token] * frame_tokens)
audio_tokens = audio_base + (
1 if frame_idx < audio_remainder else 0
)
if audio_tokens > 0:
token_sequence.extend(
[self.video_audio_token] * audio_tokens
)
else:
token_sequence = [self.video_token] * frame_query_length
else:
token_sequence = [self.video_token] * frame_query_length
total_length = len(token_sequence)
labels[pos : pos + 1] = [IGNORE_INDEX] * total_length
input_ids[pos : pos + 1] = token_sequence
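        # Audio placeholders expand the same way; for discrete audio, trainable
        # positions keep the discrete-audio token id in the labels.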
if self.train_audio:
audio_start_ids = [
i for i, token in enumerate(input_ids) if token == self.audio_token
]
assert len(audio_start_ids) == len(audio_query_lengths)
for i, length in zip(
range(len(audio_query_lengths) - 1, -1, -1), audio_query_lengths[::-1]
):
labels[audio_start_ids[i] : audio_start_ids[i] + 1] = [
IGNORE_INDEX
] * length
input_ids[audio_start_ids[i] : audio_start_ids[i] + 1] = [
self.audio_token
] * length
discrete_audio_start_ids = [
i
for i, token in enumerate(input_ids)
if token == self.discrete_audio_token
]
assert len(discrete_audio_start_ids) == len(discrete_audio_query_lengths), (
f"discrete_audio_start_ids count ({len(discrete_audio_start_ids)}) != "
f"discrete_audio_query_lengths count ({len(discrete_audio_query_lengths)}). "
f"This indicates a serious bug in preprocessor or corrupted data. Skip this sample."
)
for i, length in zip(
range(len(discrete_audio_query_lengths) - 1, -1, -1),
discrete_audio_query_lengths[::-1],
):
assert 0 < length < 16000, (
f"discrete_audio_query_length={length} is out of valid range [1, 16000). "
f"Expected max ~15,000 for 600s audio at 25 tokens/sec. "
f"This can cause illegal memory access when creating embeddings. Skip this sample."
)
if labels[discrete_audio_start_ids[i]] == IGNORE_INDEX:
labels[
discrete_audio_start_ids[i] : discrete_audio_start_ids[i] + 1
] = [IGNORE_INDEX] * length
else:
labels[
discrete_audio_start_ids[i] : discrete_audio_start_ids[i] + 1
] = [self.discrete_audio_token] * length
input_ids[
discrete_audio_start_ids[i] : discrete_audio_start_ids[i] + 1
] = [self.discrete_audio_token] * length
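        # Pad input_ids/labels to a multiple of sequence_parallel_size so the
        # sequence shards evenly across SP ranks.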
if self.sequence_parallel_size > 1:
if len(input_ids) % self.sequence_parallel_size != 0:
input_ids += [self.tokenizer.pad_token_id] * (
self.sequence_parallel_size
- (len(input_ids) % self.sequence_parallel_size)
)
labels += [IGNORE_INDEX] * (
self.sequence_parallel_size
- (len(labels) % self.sequence_parallel_size)
)
if not is_sft1:
input_ids = torch.tensor(input_ids)
labels = torch.tensor(labels)
if self.mode == "train":
if self.sample_min_length is not None and self.sample_min_length > 0:
assert (
len(labels) >= self.sample_min_length
), "The sample is too short: {} < {}".format(
len(labels), self.sample_min_length
)
assert (
len(labels) <= self.decoder_max_length
), "The sample exceeds decoder_max_len: {} > {}".format(
len(labels), self.decoder_max_length
)
assert len(input_ids) == len(labels)
if len(labels) < 30:
raise ConditionalError(
"The sample is too short: {}".format(len(labels))
)
if torch.all(labels == IGNORE_INDEX):
raise ConditionalError(
"Labels contain only IGNORE_INDEX, no training targets available"
)
sample = {
"pixel_values": pixel_values,
"discrete_pixel_values": discrete_pixel_values,
"idx_for_debug": idx_for_debug,
"input_ids": input_ids,
"labels": labels,
"queries": queries if len(queries) > 0 else None,
"gts": gts if len(gts) > 0 else None,
"mm_query_lengths": mm_query_lengths,
"non_mm_query_lengths": len(labels) - total_mm_query_length,
"total_length": len(labels),
"data_name": config["data_name"],
"data_type": config["data_type"],
"img_start_ids": img_start_ids,
"prompt": output.input_str,
"options": config.get("options", None),
"image_grid_thw": image_grid_thw,
"pixel_values_videos": pixel_values_videos,
"video_grid_thw": video_grid_thw,
"video_audio_values": (
video_audio_values if len(video_audio_values) > 0 else None
),
"video_audio_masks": (
video_audio_masks if len(video_audio_masks) > 0 else None
),
"audio_values": audio_values,
"discrete_audio_values": discrete_audio_values,
"audio_masks": audio_masks,
"dummy_pixel_values": dummy_pixel_values,
"dummy_grid_thw": dummy_grid_thw,
"dummy_audio_values": dummy_audio_values,
"dummy_audio_masks": dummy_audio_masks,
"dummy_pixel_values_videos": dummy_pixel_values_videos,
"dummy_video_grid_thw": dummy_video_grid_thw,
"dummy_video_audio_values": dummy_video_audio_values,
"dummy_video_audio_masks": dummy_video_audio_masks,
}
return sample
def _sampling_multiturn(
self,
turns,
n_sample,
multiturn_preserve_order=True,
multiturn_continuous=False,
):
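        """Sub-sample groups of turns for multi-turn training.

        Turns are segmented so that, roughly, each segment starts at a user turn
        containing an image tag (system/tool_list turns are always kept), and
        n_sample segments are selected: a contiguous window, an order-preserving
        random subset, or a fully shuffled subset depending on the flags.
        """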
new_turns = []
sample_indices = []
first_user_turn = True
start_idx = 0
for idx, turn in enumerate(turns):
if turn["role"] in ["system", "tool_list"]:
new_turns.append(turn)
start_idx = idx + 1
continue
if turn["role"] == "user":
image_nums = re.findall(r"<image_(\d+)>", turn["content"])
if len(image_nums) == 0:
image_nums = re.findall(r"<\|image\|>", turn["content"])
if len(image_nums) > 0:
if first_user_turn:
first_user_turn = False
continue
sample_indices.append([i for i in range(start_idx, idx)])
start_idx = idx
sample_indices.append([i for i in range(start_idx, idx + 1)])
n_sample = min(n_sample, len(sample_indices))
if multiturn_continuous:
start_index = random.randint(0, len(sample_indices) - n_sample)
indices = range(start_index, start_index + n_sample)
elif multiturn_preserve_order:
indices = sorted(random.sample(range(len(sample_indices)), n_sample))
else:
indices = random.sample(range(len(sample_indices)), n_sample)
sampled_indices = [sample_indices[i] for i in indices]
new_turns = new_turns + [
turns[i] for sampled_turns in sampled_indices for i in sampled_turns
]
return new_turns