|
|
import ast |
|
|
import copy |
|
|
import datetime |
|
|
import gc |
|
|
import io |
|
|
import json |
|
|
import math |
|
|
import mimetypes |
|
|
import os |
|
|
import random |
|
|
import re |
|
|
import sys |
|
|
import tarfile |
|
|
import tempfile |
|
|
import zipfile |
|
|
from collections import defaultdict, deque |
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
|
|
import av |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import PIL |
|
|
import pkg_resources |
|
|
import scipy.signal as scsig |
|
|
import torch |
|
|
from decord import VideoReader, cpu |
|
|
from PIL import Image, ImageDraw |
|
|
from smart_open import open |
|
|
from torchvision.transforms.functional import to_tensor |
|
|
|
|
|
from hcxvlm.dataset.base_dataset import image_decoder |
|
|
from hcxvlm.dataset.hcx_vision_prompter import HCXVisionPrompter |
|
|
|
|
|
CHOICES = list(map(chr, range(97, 123))) |
|
|
IGNORE_INDEX = -100 |
|
|
DEFAULT_SAMPLE_RATE = 16000 |
|
|
MIN_DISCRETE_AUDIO_CHUNK_SAMPLES = 1600 |
|
|
DEFAULT_VOLUME_LEVEL = 10 ** (-26 / 20) |
|
|
|
|
|
hcx_vision_prompter = HCXVisionPrompter() |
|
|
|
|
|
|
|
|
def hpf_normalize( |
|
|
wav: np.ndarray, |
|
|
sr: int = DEFAULT_SAMPLE_RATE, |
|
|
volume_level: float = DEFAULT_VOLUME_LEVEL, |
|
|
) -> np.ndarray: |
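    """Apply a 2nd-order 70 Hz Butterworth high-pass filter, then scale the signal
    so its RMS matches ``volume_level`` (about -26 dBFS by default), capping the
    gain at 1 / peak so the result never clips."""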
|
|
assert (wav**2).mean() > 0, "Error in the wav file" |
|
|
|
|
|
filter_ = scsig.butter(2, 70, "highpass", fs=sr, output="sos") |
|
|
wav = scsig.sosfilt(filter_, wav) |
|
|
wav = wav.astype(np.float32) |
|
|
|
|
|
gain = min(volume_level / (wav**2).mean() ** 0.5, 1 / np.max(np.abs(wav))) |
|
|
wav *= gain |
|
|
return wav |
|
|
|
|
|
|
|
|
def convert_bboxes(img, img_meta): |
|
|
for k, v in img_meta.items(): |
|
|
if k == "region": |
|
|
bbox_key = "bbox" if "bbox" in img_meta[k] else "boundingBox" |
|
|
img_meta[k] = reform_bbox( |
|
|
img_meta[k][bbox_key], img.size, format=img_meta[k]["format"] |
|
|
) |
|
|
return img_meta |
|
|
|
|
|
|
|
|
def reform_bbox(bbox, image_size, format="REL_XYXY"): |
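    """Convert a relative bbox (REL_XYXY or REL_XYWH) on an image of size (w, h)
    into absolute pixel coordinates, returned as a 4-point polygon
    [top-left, top-right, bottom-right, bottom-left]."""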
|
|
w, h = image_size |
|
|
if format == "REL_XYXY": |
|
|
x1, y1, x2, y2 = bbox[0] * w, bbox[1] * h, bbox[2] * w, bbox[3] * h |
|
|
elif format == "REL_XYWH": |
|
|
x1, y1 = bbox[0] * w, bbox[1] * h |
|
|
x2, y2 = x1 + bbox[2] * w, y1 + bbox[3] * h |
|
|
else: |
|
|
raise NotImplementedError |
|
|
new_bbox = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]] |
|
|
return new_bbox |
|
|
|
|
|
|
|
|
def generate_random_color(use_alpha=True, seed=None): |
|
|
if seed is None: |
|
|
seed = np.random.default_rng() |
|
|
|
|
|
if use_alpha: |
|
|
color_list = [ |
|
|
("빨강", (255, 127, 127, 100)), |
|
|
("노랑", (255, 255, 127, 100)), |
|
|
("초록", (127, 255, 125, 100)), |
|
|
("하늘", (127, 255, 255, 100)), |
|
|
("파랑", (127, 127, 255, 100)), |
|
|
("보라", (255, 127, 255, 100)), |
|
|
] |
|
|
else: |
|
|
color_list = [ |
|
|
("빨강", (255, 0, 0)), |
|
|
("노랑", (255, 255, 0)), |
|
|
("초록", (0, 128, 0)), |
|
|
("하늘", (135, 206, 235)), |
|
|
("파랑", (0, 0, 255)), |
|
|
("보라", (128, 0, 128)), |
|
|
] |
|
|
return color_list[seed.integers(0, len(color_list))] |
|
|
|
|
|
|
|
|
EN_COLOR = { |
|
|
"빨강": "red", |
|
|
"노랑": "yellow", |
|
|
"초록": "green", |
|
|
"하늘": "sky blue", |
|
|
"파랑": "blue", |
|
|
"보라": "purple", |
|
|
} |
|
|
|
|
|
|
|
|
def overlay_rectangle(image, words, lang, seed=None): |
|
|
color_str, color = generate_random_color(seed=seed) |
|
|
draw = ImageDraw.Draw(image, "RGBA") |
|
|
for word in words: |
|
|
shape_rect = word["bbox"] |
|
|
shape_rect = [(round(x[0]), round(x[1])) for x in shape_rect] |
|
|
draw.polygon(shape_rect, color) |
|
|
del draw |
|
|
if lang == "en": |
|
|
color_str = EN_COLOR[color_str] |
|
|
return image, color_str |
|
|
|
|
|
|
|
|
def convert_tags_for_video(img, json): |
|
|
"""video 데이터에는 <image_xx> 태그 대신 <video_00> tag가 있음. |
|
|
img 숫자 만큼 <video_00> tag 대신 <image_xx> tag를 변환하여 넣음 |
|
|
""" |
|
|
image_tag = "".join([f"<image_{idx:02d}>" for idx in range(len(img))]) |
|
|
for json_key in json: |
|
|
if "qa_pairs" in json_key: |
|
|
new_qa_pairs = [] |
|
|
for qa_pair in json[json_key]: |
|
|
question = qa_pair[0] |
|
|
question = question.replace("<video_00>", image_tag) |
|
|
new_qa_pairs.append([question, qa_pair[1]]) |
|
|
json[json_key] = new_qa_pairs |
|
|
|
|
|
return img, json |
|
|
|
|
|
|
|
|
def sampling_multiturn_single_img( |
|
|
seq, |
|
|
count, |
|
|
multiturn_preserve_order=True, |
|
|
multiturn_continuous=False, |
|
|
is_train: bool = True, |
|
|
seed=None, |
|
|
): |
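    """Sample up to ``count`` turns from ``seq``: a contiguous window when
    ``multiturn_continuous`` is set, an order-preserving random subset when
    ``multiturn_preserve_order`` is set, otherwise an unordered random subset."""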
|
|
if seed is None: |
|
|
seed = np.random.default_rng() |
|
|
n_sample = min(count, len(seq)) |
|
|
|
|
|
if multiturn_continuous: |
|
|
if len(seq) <= n_sample: |
|
|
start_index = 0 |
|
|
else: |
|
|
start_index = seed.integers(0, len(seq) - n_sample) |
|
|
indices = range(start_index, start_index + n_sample) |
|
|
elif multiturn_preserve_order: |
|
|
indices = sorted(seed.choice(range(len(seq)), size=n_sample, replace=False)) |
|
|
else: |
|
|
indices = seed.choice(range(len(seq)), size=n_sample, replace=False) |
|
|
|
|
|
return [seq[i] for i in indices] |
|
|
|
|
|
|
|
|
def draw_bbox(image, bbox, lang="en", line_width=5, seed=None): |
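    """Draw ``bbox`` (a 4-point polygon) as a rectangle outline in a randomly chosen
    color and return the image together with the color name (English when
    ``lang == "en"``, Korean otherwise)."""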
|
|
if seed is None: |
|
|
seed = np.random.default_rng() |
|
|
color_str, color = generate_random_color(use_alpha=False, seed=seed) |
|
|
draw = ImageDraw.Draw(image, "RGB") |
|
|
rect_bbox = (bbox[0][0], bbox[0][1], bbox[2][0], bbox[2][1]) |
|
|
draw.rectangle(rect_bbox, outline=color, width=line_width) |
|
|
del draw |
|
|
if lang == "en": |
|
|
color_str = EN_COLOR[color_str] |
|
|
return image, color_str |
|
|
|
|
|
|
|
|
def bbox_process(bbox, detection_precision=2): |
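    """Format bbox coordinates as a bracketed string with two decimal places,
    e.g. "[0.12, 0.34, 0.56, 0.78]". Note: ``detection_precision`` is currently
    unused; formatting is fixed at two decimals."""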
|
|
bbox_str = "[" |
|
|
    for idx, point in enumerate(bbox):
        # Both branches of the original even/odd check assigned the raw value;
        # coordinates are used as-is.
        normalized = point
|
|
|
|
|
if idx < len(bbox) - 1: |
|
|
bbox_str += format(normalized, ".2f") + ", " |
|
|
else: |
|
|
bbox_str += format(normalized, ".2f") |
|
|
bbox_str += "]" |
|
|
return bbox_str |
|
|
|
|
|
|
|
|
def load_txt(file_path): |
|
|
lines_list = [] |
|
|
with open(file_path, "r") as file: |
|
|
for line in file: |
|
|
lines_list.append(line.replace("\\n", "\n").strip()) |
|
|
return lines_list |
|
|
|
|
|
|
|
|
def convert_format_for_multi_image( |
|
|
img, json, convert_key_list=["words", "text", "objects", "entities"] |
|
|
): |
|
|
"""single image dataset 과 multi image dataset 에서 읽어온 img, json format 이 다름. |
|
|
따라서 single image dataset 에서 읽어온 img, json 을 multi image dataset 의 format (dict) 으로 convert |
|
|
""" |
|
|
is_multi_image_dataset = isinstance(img, dict) |
|
|
if not is_multi_image_dataset: |
|
|
img = {"00": img} |
|
|
|
|
|
for convert_key in convert_key_list: |
|
|
if convert_key in json: |
|
|
json[convert_key] = {"00": json[convert_key]} |
|
|
|
|
|
for json_key in json: |
|
|
if "region" in json_key: |
|
|
json[json_key] = {"00": json[json_key]} |
|
|
else: |
|
|
for convert_key in convert_key_list: |
|
|
if convert_key in json: |
|
|
if isinstance(json[convert_key], list): |
|
|
json[convert_key] = {"00": json[convert_key]} |
|
|
|
|
|
for json_key in json: |
|
|
if "region" in json_key: |
|
|
if isinstance(json[json_key], list): |
|
|
json[json_key] = {"00": json[json_key]} |
|
|
|
|
|
return is_multi_image_dataset, img, json |
|
|
|
|
|
|
|
|
class ConditionalError(Exception): |
|
|
def __init__(self, message="Our assertion error"): |
|
|
super().__init__(message) |
|
|
|
|
|
|
|
|
def get_wds_default_config(default_config, existing_default_config=None): |
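    """Return a webdataset-level config with any missing keys filled from the
    built-in defaults (or from ``existing_default_config`` when provided)."""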
|
|
if existing_default_config is None: |
|
|
default_config_check_dict = { |
|
|
"subtask": "", |
|
|
"reasoning": False, |
|
|
"use_task_prompt": True, |
|
|
"get_random": True, |
|
|
"add_instruct_prompts": [], |
|
|
"multiturn_n_samples": 0, |
|
|
"multiturn_preserve_order": True, |
|
|
"multiturn_continuous": False, |
|
|
"insert_ocr": 200, |
|
|
"ocr_filter_strategy": "confidence", |
|
|
"ocr_use_ratio": 1.0, |
|
|
"entity_top_k": 0, |
|
|
"entity_keyword_threshold": 100, |
|
|
"entity_keyword_fashion_threshold": 100, |
|
|
"entity_use_ratio": 0.0, |
|
|
"llava_pretrain": False, |
|
|
"random_system_prob": 0.0, |
|
|
"random_system_path": "", |
|
|
"random_tool_prob": 0.005, |
|
|
} |
|
|
else: |
|
|
default_config_check_dict = existing_default_config |
|
|
if default_config is None: |
|
|
default_config = default_config_check_dict |
|
|
else: |
|
|
for key, value in default_config_check_dict.items(): |
|
|
if key not in default_config: |
|
|
default_config[key] = value |
|
|
return default_config |
|
|
|
|
|
|
|
|
def get_datalake_default_config(default_config): |
|
|
default_config_check_dict = { |
|
|
"multiturn_n_samples": 0, |
|
|
"multiturn_preserve_order": True, |
|
|
"multiturn_continuous": True, |
|
|
"insert_ocr": 0, |
|
|
"ocr_filter_strategy": "confidence", |
|
|
"entity_top_k": 0, |
|
|
"entity_keyword_threshold": 0, |
|
|
"entity_keyword_fashion_threshold": 0, |
|
|
"entity_use_ratio": 0.0, |
|
|
"ocr_use_ratio": 0.0, |
|
|
"llava_pretrain": False, |
|
|
"random_system_prob": 0.0, |
|
|
"random_system_path": "", |
|
|
"random_tool_prob": 0.005, |
|
|
} |
|
|
if default_config is None: |
|
|
default_config = default_config_check_dict |
|
|
else: |
|
|
for key, value in default_config_check_dict.items(): |
|
|
if key not in default_config: |
|
|
default_config[key] = value |
|
|
return default_config |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class Processed_sample: |
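    """Container for one preprocessed sample: the rendered prompt string, token and
    label ids, and the collected multimodal payloads (images, videos, audios) with
    their duration metadata and per-modality counters."""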
|
|
input_str: str = None |
|
|
input_ids: list = None |
|
|
label_ids: list = None |
|
|
imgs: list = None |
|
|
discrete_imgs: list = None |
|
|
videos: list = None |
|
|
videos_duration: List[dict] = None |
|
|
video_audios: list = None |
|
|
audios: list = None |
|
|
audios_duration: List[dict] = None |
|
|
discrete_audios: list = None |
|
|
sample_mm_counter: dict = None |
|
|
|
|
|
|
|
|
from hcxvlm.dataset.bbox_processor import ( |
|
|
extract_bboxes, |
|
|
insert_bboxes_to_json, |
|
|
is_bbox_padded, |
|
|
) |
|
|
|
|
|
|
|
|
class Preprocessor: |
|
|
prompt_head = "" |
|
|
va_prefix = "\n<|im_start|>" |
|
|
new_line = "\n" |
|
|
turn_prefix = "<|im_start|>" |
|
|
turn_suffix = "<|im_end|>" |
|
|
mime_start = "<|mime_start|>" |
|
|
mime_end = "<|mime_end|>" |
|
|
aux_img_start = "<|image_aux_start|>" |
|
|
aux_img_end = "<|image_aux_end|>" |
|
|
aux_video_start = "<|video_aux_start|>" |
|
|
aux_video_end = "<|video_aux_end|>" |
|
|
aux_audio_start = "<|audio_aux_start|>" |
|
|
aux_audio_end = "<|audio_aux_end|>" |
|
|
image_start = "<|image_start|>" |
|
|
image_end = "<|image_end|>" |
|
|
image_pad = "<|IMAGE_PAD|>" |
|
|
video_start = "<|video_start|>" |
|
|
video_end = "<|video_end|>" |
|
|
video_pad = "<|VIDEO_PAD|>" |
|
|
audio_start = "<|audio_start|>" |
|
|
audio_end = "<|audio_end|>" |
|
|
audio_pad = "<|AUDIO_PAD|>" |
|
|
discrete_image_start = "<|discrete_image_start|>" |
|
|
discrete_image_end = "<|discrete_image_end|>" |
|
|
discrete_image_pad = "<|DISCRETE_IMAGE_PAD|>" |
|
|
video_audio_pad = "<|VIDEO_AUDIO_PAD|>" |
|
|
discrete_audio_start = "<|discrete_audio_start|>" |
|
|
discrete_audio_end = "<|discrete_audio_end|>" |
|
|
discrete_audio_pad = "<|DISCRETE_AUDIO_PAD|>" |
|
|
|
|
|
discrete_image_eol = "<|vision_eol|>" |
|
|
discrete_image_eof = "<|vision_eof|>" |
|
|
discrete_image_ratios = { |
|
|
(1, 1): "<|vision_ratio_1:1|>", |
|
|
(1, 2): "<|vision_ratio_1:2|>", |
|
|
(2, 1): "<|vision_ratio_2:1|>", |
|
|
(3, 4): "<|vision_ratio_3:4|>", |
|
|
(4, 3): "<|vision_ratio_4:3|>", |
|
|
(3, 5): "<|vision_ratio_3:5|>", |
|
|
(5, 3): "<|vision_ratio_5:3|>", |
|
|
(4, 5): "<|vision_ratio_4:5|>", |
|
|
(5, 4): "<|vision_ratio_5:4|>", |
|
|
(6, 9): "<|vision_ratio_6:9|>", |
|
|
(9, 6): "<|vision_ratio_9:6|>", |
|
|
(9, 16): "<|vision_ratio_9:16|>", |
|
|
(16, 9): "<|vision_ratio_16:9|>", |
|
|
} |
|
|
|
|
|
aux_vid_prompt = ( |
|
|
"다음 중 video_duration은 비디오 길이 정보입니다. 참고하여 답변하세요. " |
|
|
) |
|
|
aux_audio_prompt = ( |
|
|
"다음 중 audio_duration은 오디오 길이 정보입니다. 참고하여 답변하세요. " |
|
|
) |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
tokenizer=None, |
|
|
prepare_input_fn=None, |
|
|
prepare_audio_input_fn=None, |
|
|
sample_min_length=0, |
|
|
decoder_max_length=None, |
|
|
mode="train", |
|
|
model=None, |
|
|
datalake_default_config=None, |
|
|
wds_default_config=None, |
|
|
video_config=None, |
|
|
train_video=False, |
|
|
train_audio=False, |
|
|
sequence_parallel_size=1, |
|
|
video_audio_compressor_type=None, |
|
|
): |
|
|
self.sequence_parallel_size = sequence_parallel_size |
|
|
if sequence_parallel_size > 1: |
|
|
self.rng = np.random.default_rng(seed=42) |
|
|
else: |
|
|
self.rng = np.random.default_rng() |
|
|
|
|
|
if model is not None: |
|
|
tokenizer = model.tokenizer |
|
|
decoder_max_length = 16000 |
|
|
|
|
|
if model is not None and prepare_input_fn is None: |
|
|
raise "please give ImageProcessor!" |
|
|
|
|
|
self.prepare_input_fn = prepare_input_fn |
|
|
self.prepare_audio_input_fn = prepare_audio_input_fn |
|
|
try: |
|
|
from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import ( |
|
|
Qwen2_5_VLProcessor, |
|
|
) |
|
|
|
|
|
self.is_qwen_visual = isinstance(prepare_input_fn, Qwen2_5_VLProcessor) |
|
|
except Exception as e: |
|
|
self.is_qwen_visual = False |
|
|
try: |
|
|
if not self.is_qwen_visual: |
|
|
from hcxvlm.models.processing_vlm import HCXVisionV2Processor |
|
|
|
|
|
self.is_qwen_visual = isinstance(prepare_input_fn, HCXVisionV2Processor) |
|
|
except Exception as e: |
|
|
self.is_qwen_visual = False |
|
|
assert self.is_qwen_visual, "qwen2.5-vl visual prepare_input_fn import error" |
|
|
|
|
|
self.video_max_num_frames = ( |
|
|
video_config["video_max_num_frames"] |
|
|
if video_config and "video_max_num_frames" in video_config |
|
|
else 120 |
|
|
) |
|
|
self.video_max_pixels = ( |
|
|
video_config["video_max_pixels"] |
|
|
if video_config and "video_max_pixels" in video_config |
|
|
else 378 * 378 |
|
|
) |
|
|
|
|
|
self.tokenizer = tokenizer |
|
|
self.sample_min_length = sample_min_length |
|
|
self.decoder_max_length = decoder_max_length |
|
|
self.mode = mode |
|
|
self.default_config = get_datalake_default_config(datalake_default_config) |
|
|
self.wds_default_config = get_wds_default_config(wds_default_config) |
|
|
self.train_video = train_video |
|
|
self.train_audio = train_audio |
|
|
self.video_audio_compressor_type = video_audio_compressor_type |
|
|
|
|
|
self.img_token = self.tokenizer.encode(Preprocessor.image_pad)[0] |
|
|
assert ( |
|
|
len(self.tokenizer.encode(Preprocessor.image_pad)) == 1 |
|
|
), "img_token is not configured in tokenizer" |
|
|
|
|
|
self.discrete_image_token = self.tokenizer.encode( |
|
|
Preprocessor.discrete_image_pad |
|
|
)[0] |
|
|
assert ( |
|
|
len(self.tokenizer.encode(Preprocessor.discrete_image_pad)) == 1 |
|
|
), "discrete_image_token is not configured in tokenizer" |
|
|
|
|
|
self.discrete_image_eol_token = self.tokenizer.encode( |
|
|
Preprocessor.discrete_image_eol |
|
|
)[0] |
|
|
assert ( |
|
|
len(self.tokenizer.encode(Preprocessor.discrete_image_eol)) == 1 |
|
|
), "discrete_image_eol_token is not configured in tokenizer" |
|
|
|
|
|
self.discrete_image_eof_token = self.tokenizer.encode( |
|
|
Preprocessor.discrete_image_eof |
|
|
)[0] |
|
|
assert ( |
|
|
len(self.tokenizer.encode(Preprocessor.discrete_image_eof)) == 1 |
|
|
), "discrete_image_eof_token is not configured in tokenizer" |
|
|
|
|
|
self.discrete_image_ratio_tokens = dict() |
|
|
for ratio, token_str in Preprocessor.discrete_image_ratios.items(): |
|
|
token_id = self.tokenizer.encode(token_str)[0] |
|
|
assert ( |
|
|
len(self.tokenizer.encode(token_str)) == 1 |
|
|
), f"discrete_image_ratio_token {token_str} is not configured in tokenizer" |
|
|
self.discrete_image_ratio_tokens[ratio] = token_id |
|
|
|
|
|
self.video_token = self.tokenizer.encode(Preprocessor.video_pad)[0] |
|
|
assert ( |
|
|
len(self.tokenizer.encode(Preprocessor.video_pad)) == 1 |
|
|
), "video_token is not configured in tokenizer" |
|
|
|
|
|
self.video_audio_token = self.tokenizer.encode(Preprocessor.video_audio_pad)[0] |
|
|
assert ( |
|
|
len(self.tokenizer.encode(Preprocessor.video_audio_pad)) == 1 |
|
|
), "video_audio_token is not configured in tokenizer" |
|
|
|
|
|
def resize_min_edge(img: Image.Image) -> Image.Image: |
|
|
w, h = img.size |
|
|
min_size = 28 |
|
|
if min(w, h) >= min_size: |
|
|
return img |
|
|
if w < h: |
|
|
new_w = min_size |
|
|
new_h = int(h * (min_size / w)) |
|
|
else: |
|
|
new_h = min_size |
|
|
new_w = int(w * (min_size / h)) |
|
|
return img.resize((new_w, new_h), Image.BICUBIC) |
|
|
|
|
|
self._resize_min_edge = resize_min_edge |
|
|
|
|
|
self.audio_token = self.tokenizer.encode(Preprocessor.audio_pad)[0] |
|
|
assert ( |
|
|
len(self.tokenizer.encode(Preprocessor.audio_pad)) == 1 |
|
|
), "audio_token is not configured in tokenizer" |
|
|
|
|
|
self.discrete_audio_token = self.tokenizer.encode( |
|
|
Preprocessor.discrete_audio_pad |
|
|
)[0] |
|
|
assert ( |
|
|
len(self.tokenizer.encode(Preprocessor.discrete_audio_pad)) == 1 |
|
|
), "audio_token is not configured in tokenizer" |
|
|
|
|
|
from hcxvlm.dataset.json_processer import generate_prompt |
|
|
|
|
|
self.generate_prompt = generate_prompt |
|
|
|
|
|
self.mimes = list() |
|
|
for mime_filename in [ |
|
|
"words_alpha.txt", |
|
|
"korean-366506-wordslistUnique.txt", |
|
|
]: |
|
|
self.mimes += ( |
|
|
pkg_resources.resource_string( |
|
|
"hcxvlm", f"dataset/hcx_vision_prompter/prompts/{mime_filename}" |
|
|
) |
|
|
.decode("utf-8") |
|
|
.split("\r\n") |
|
|
) |
|
|
|
|
|
self.common_tools = [] |
|
|
try: |
|
|
common_tools_bytes = pkg_resources.resource_string( |
|
|
"hcxvlm", |
|
|
"dataset/hcx_vision_prompter/prompts/common_tools.jsonl", |
|
|
) |
|
|
for line in common_tools_bytes.decode("utf-8").splitlines(): |
|
|
line = line.strip() |
|
|
if not line: |
|
|
continue |
|
|
try: |
|
|
self.common_tools.append(json.loads(line)) |
|
|
except Exception: |
|
|
continue |
|
|
except Exception: |
|
|
self.common_tools = [] |
|
|
|
|
|
self.random_system_prompt = "" |
|
|
if self.default_config["random_system_path"] != "": |
|
|
self.random_system_prompt = "" |
|
|
with open(self.default_config["random_system_path"], "r") as f: |
|
|
for line in f: |
|
|
self.random_system_prompt += line |
|
|
|
|
|
if ( |
|
|
self.random_system_prompt != "" |
|
|
and self.wds_default_config["random_system_path"] != "" |
|
|
): |
|
|
assert ( |
|
|
self.wds_default_config["random_system_path"] |
|
|
== self.default_config["random_system_path"] |
|
|
), "random_system_path in both default_config and wds_default_config should be the same" |
|
|
|
|
|
def _find_best_ratio_token(self, original_size): |
|
|
"""Find the best ratio token based on original_size""" |
|
|
base_ratios = list(self.discrete_image_ratio_tokens.keys()) |
|
|
vision_aspect_ratios = [ |
|
|
r for ratio in base_ratios for r in [ratio, ratio[::-1]] |
|
|
][1:] |
|
|
|
|
|
if not isinstance(original_size, list) or len(original_size) != 2: |
|
|
return self.discrete_image_ratio_tokens[(1, 1)] |
|
|
|
|
|
h, w = original_size |
|
|
if h == 0 or w == 0: |
|
|
return self.discrete_image_ratio_tokens[(1, 1)] |
|
|
|
|
|
ratios = [i / j for i, j in vision_aspect_ratios] |
|
|
|
|
|
best_size_idx = np.argmin([abs(w / h - r) for r in ratios]) |
|
|
|
|
|
i, j = vision_aspect_ratios[best_size_idx] |
|
|
return self.discrete_image_ratio_tokens[(i, j)] |
|
|
|
|
|
@classmethod |
|
|
def prompt_mime( |
|
|
cls, |
|
|
mimes: Optional[list[str]] = None, |
|
|
file_name: str = None, |
|
|
tag_idx: int = 1, |
|
|
fixed_mime: bool = False, |
|
|
is_video: bool = False, |
|
|
is_audio: bool = False, |
|
|
seed: np.random.Generator = None, |
|
|
    ) -> dict:
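        """Build the mime metadata dict ({"id", "type", "filename"}) embedded before a
        media item. When ``file_name`` is given its extension is used; otherwise a
        name is drawn from ``mimes`` and the extension is fixed or sampled depending
        on ``fixed_mime`` and the media kind (image/video/audio)."""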
|
|
assert mimes or file_name |
|
|
|
|
|
if seed is None: |
|
|
seed = np.random.default_rng() |
|
|
|
|
|
if file_name: |
|
|
name, ext = os.path.splitext(file_name) |
|
|
ext = ext.lstrip(".") |
|
|
elif fixed_mime: |
|
|
ext = "jpeg" |
|
|
name = mimes[tag_idx] |
|
|
elif not fixed_mime and seed is not None: |
|
|
ext = seed.choice(["png", "jpeg"]) |
|
|
name = mimes[seed.integers(0, len(mimes))] |
|
|
else: |
|
|
ext = "jpeg" |
|
|
name = mimes[tag_idx] |
|
|
|
|
|
if is_video: |
|
|
ext_candidates = ["mp4", "mov", "avi", "webm"] |
|
|
if fixed_mime: |
|
|
ext = "mp4" |
|
|
elif ext not in ext_candidates: |
|
|
ext = seed.choice(ext_candidates) |
|
|
|
|
|
filename = f"{name}.{ext}" |
|
|
mime_type = mimetypes.guess_type(filename)[0] |
|
|
mime_prompt = { |
|
|
"id": f"video_{str(tag_idx).zfill(2)}", |
|
|
"type": f"{mime_type}", |
|
|
"filename": f"{filename}", |
|
|
} |
|
|
return mime_prompt |
|
|
|
|
|
if is_audio: |
|
|
ext_candidates = ["mp3", "wav", "aac", "flac", "pcm"] |
|
|
if fixed_mime: |
|
|
ext = "wav" |
|
|
elif ext not in ext_candidates: |
|
|
ext = seed.choice(ext_candidates) |
|
|
|
|
|
filename = f"{name}.{ext}" |
|
|
mime_type = mimetypes.guess_type(filename)[0] |
|
|
mime_prompt = { |
|
|
"id": f"audio_{str(tag_idx).zfill(2)}", |
|
|
"type": f"{mime_type}", |
|
|
"filename": f"{filename}", |
|
|
} |
|
|
return mime_prompt |
|
|
|
|
|
if file_name: |
|
|
filename = f"{name}.{ext}" |
|
|
mime_type = mimetypes.guess_type(filename)[0] |
|
|
mime_prompt = { |
|
|
"id": f"image_{str(tag_idx).zfill(2)}", |
|
|
"type": f"{mime_type}", |
|
|
"filename": f"{filename}", |
|
|
} |
|
|
else: |
|
|
mime_prompt = { |
|
|
"id": f"image_{str(tag_idx).zfill(2)}", |
|
|
"type": f"image/{ext}", |
|
|
"filename": f"{name}.{'jpg' if ext == 'jpeg' else 'png'}", |
|
|
} |
|
|
return mime_prompt |
|
|
|
|
|
@classmethod |
|
|
def ocr_preprocess( |
|
|
cls, |
|
|
words: list[dict], |
|
|
n_insert_ocr_tokens: int = 2000, |
|
|
insert_ocr: int = 200, |
|
|
ocr_use_ratio: float = 0.5, |
|
|
tokenizer=None, |
|
|
seed=None, |
|
|
    ) -> Optional[str]:
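        """Join up to ``insert_ocr`` OCR words into a single string (filtering by
        confidence >= 0.3 when enough confidence scores are available) and truncate it
        to ``n_insert_ocr_tokens`` tokens when a tokenizer is given. Returns None when
        OCR insertion is skipped (``insert_ocr == 0`` or the ``ocr_use_ratio`` draw fails)."""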
|
|
if seed is None: |
|
|
seed = np.random.default_rng() |
|
|
if ocr_use_ratio < seed.random(): |
|
|
return None |
|
|
if insert_ocr == 0: |
|
|
return None |
|
|
|
|
|
confidence_list = [] |
|
|
insert_ocr_prompt = [] |
|
|
for word in words: |
|
|
if "confidence" in word: |
|
|
confidence_list.append(word["confidence"]) |
|
|
has_ocr_confidence = len(confidence_list) >= insert_ocr |
|
|
|
|
|
if len(words) <= insert_ocr or not has_ocr_confidence: |
|
|
insert_ocr_prompt += [ |
|
|
d["text"].strip() for d in words if d["text"].strip() |
|
|
][:insert_ocr] |
|
|
else: |
|
|
confidence_threshold = 0.3 |
|
|
cnt = 0 |
|
|
for word in words: |
|
|
if word["text"] == "": |
|
|
continue |
|
|
if word["confidence"] >= confidence_threshold: |
|
|
insert_ocr_prompt.append(word["text"]) |
|
|
cnt += 1 |
|
|
if cnt >= insert_ocr: |
|
|
break |
|
|
ocr_inputs = " ".join(insert_ocr_prompt) |
|
|
if tokenizer: |
|
|
ocr_inputs = tokenizer.decode( |
|
|
tokenizer.encode(ocr_inputs)[:n_insert_ocr_tokens] |
|
|
) |
|
|
return ocr_inputs |
|
|
|
|
|
@classmethod |
|
|
def lens_preprocess( |
|
|
cls, |
|
|
lens: list[dict], |
|
|
entity_top_k: int = 100, |
|
|
entity_keyword_threshold: float = 0.0, |
|
|
entity_keyword_fashion_threshold: float = 0.0, |
|
|
entity_use_ratio: float = 0.0, |
|
|
seed=None, |
|
|
): |
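        """Filter Naver Lens entities by type and confidence thresholds, group them by
        object class, and build the "lens_keywords" / "lens_local_keywords" prompt
        entries. Returns None with probability ``1 - entity_use_ratio``."""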
|
|
if seed is None: |
|
|
seed = np.random.default_rng() |
|
|
if seed.uniform(0, 1) > entity_use_ratio: |
|
|
return None |
|
|
|
|
|
entities = lens |
|
|
filter_idx = [] |
|
|
insert_entity_prompt = {} |
|
|
for idx, entity in enumerate(entities): |
|
|
if entity["type"] != "naver_lens_api": |
|
|
filter_idx.append(idx) |
|
|
continue |
|
|
if ( |
|
|
isinstance(entity_keyword_threshold, (int, float)) |
|
|
and entity["confidence"] < entity_keyword_threshold |
|
|
): |
|
|
filter_idx.append(idx) |
|
|
continue |
|
|
if ( |
|
|
isinstance(entity_keyword_fashion_threshold, (int, float)) |
|
|
and ("fashion" in entity["info"]["classes"]) |
|
|
and entity["confidence"] < entity_keyword_fashion_threshold |
|
|
): |
|
|
filter_idx.append(idx) |
|
|
continue |
|
|
|
|
|
entityvalue = [ |
|
|
keyword for idx, keyword in enumerate(entities) if idx not in filter_idx |
|
|
] |
|
|
entityvalue = sorted(entityvalue, key=lambda x: x["confidence"], reverse=True) |
|
|
|
|
|
important_entity_list = [] |
|
|
local_entity_str_list = [] |
|
|
keywords_and_bbox_per_detector = {} |
|
|
for keyword_dict in entityvalue[:entity_top_k]: |
|
|
object_class = "/".join(keyword_dict["info"]["classes"]) |
|
|
if object_class not in keywords_and_bbox_per_detector.keys(): |
|
|
keywords_and_bbox_per_detector[object_class] = [] |
|
|
keywords_and_bbox_per_detector[object_class].append(keyword_dict) |
|
|
|
|
|
for object_class in keywords_and_bbox_per_detector.keys(): |
|
|
entities_per_object = keywords_and_bbox_per_detector[object_class] |
|
|
normalized_bbox = bbox_process( |
|
|
[*entities_per_object[0]["bbox"][0], *entities_per_object[0]["bbox"][2]] |
|
|
) |
|
|
entities = [entity["text"] for entity in entities_per_object] |
|
|
if "context" in object_class: |
|
|
important_entity_list += entities |
|
|
|
|
|
else: |
|
|
local_entity_str_list += [ |
|
|
str(normalized_bbox) + " " + ", ".join(entities) |
|
|
] |
|
|
if len(important_entity_list) > 0: |
|
|
insert_entity_prompt["lens_keywords"] = ", ".join(important_entity_list) |
|
|
if len(local_entity_str_list) > 0: |
|
|
insert_entity_prompt["lens_local_keywords"] = " ".join( |
|
|
local_entity_str_list |
|
|
) |
|
|
|
|
|
return insert_entity_prompt |
|
|
|
|
|
@classmethod |
|
|
def prompt_toollist( |
|
|
cls, |
|
|
output, |
|
|
tokenizer=None, |
|
|
turn: Optional[dict] = None, |
|
|
content: Optional[list[dict]] = None, |
|
|
): |
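        """Append a tool_list turn (role header + content + end-of-turn token) to
        ``output``; its tokens are excluded from the loss via IGNORE_INDEX."""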
|
|
assert content or turn |
|
|
if turn is None: |
|
|
turn = { |
|
|
"role": "tool_list", |
|
|
"content": content, |
|
|
} |
|
|
|
|
|
toollist_str = ( |
|
|
cls.turn_prefix.strip() |
|
|
+ turn["role"] |
|
|
+ "\n" |
|
|
+ turn["content"] |
|
|
+ cls.turn_suffix |
|
|
) |
|
|
|
|
|
if hasattr(output, "input_str"): |
|
|
output.input_str += toollist_str |
|
|
|
|
|
if getattr(output, "input_ids", None) is not None: |
|
|
token_ids = tokenizer.encode(toollist_str, truncation=False) |
|
|
output.input_ids += token_ids |
|
|
output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] |
|
|
return output |
|
|
|
|
|
@classmethod |
|
|
def prompt_system( |
|
|
cls, |
|
|
output, |
|
|
tokenizer=None, |
|
|
turn: Optional[dict] = None, |
|
|
content: Optional[str] = None, |
|
|
seed=None, |
|
|
tool_prompt=None, |
|
|
system_role_count=0, |
|
|
): |
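        """Append a system turn to ``output``: pick one prompt from the turn's
        candidates/content, attach ``tool_prompt`` to the first system turn, and
        extend input_ids with the labels masked to IGNORE_INDEX."""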
|
|
assert content or turn |
|
|
if seed is None: |
|
|
seed = np.random.default_rng() |
|
|
if turn is None: |
|
|
system_prompt = content |
|
|
else: |
|
|
if "candidates" in turn: |
|
|
if len(turn["candidates"]) > 0: |
|
|
system_prompt = seed.choice(turn["candidates"]) |
|
|
                    if isinstance(system_prompt, dict):
|
|
system_prompt = system_prompt["content"] |
|
|
else: |
|
|
system_prompt = "" |
|
|
elif isinstance(turn["content"], str): |
|
|
system_prompt = turn["content"] |
|
|
elif len(turn["content"]) > 0: |
|
|
system_prompt = seed.choice(turn["content"]) |
|
|
|
|
|
system_str = cls.turn_prefix + turn["role"] + "\n" |
|
|
system_str += system_prompt.strip() |
|
|
if system_role_count == 0: |
|
|
if system_prompt.strip(): |
|
|
system_str += "\n" |
|
|
system_str += tool_prompt |
|
|
system_str += cls.turn_suffix |
|
|
|
|
|
if hasattr(output, "input_str"): |
|
|
output.input_str += system_str |
|
|
|
|
|
if getattr(output, "input_ids", None) is not None: |
|
|
token_ids = tokenizer.encode(system_str, truncation=False) |
|
|
output.input_ids += token_ids |
|
|
output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] |
|
|
return output |
|
|
|
|
|
@classmethod |
|
|
def load_mm( |
|
|
cls, |
|
|
output, |
|
|
img_dir: str = "", |
|
|
turn: Optional[dict] = None, |
|
|
image_urls: Optional[list[str]] = None, |
|
|
image_metas: Optional[list[dict]] = None, |
|
|
video_urls: Optional[list[str]] = None, |
|
|
video_metas: Optional[list[dict]] = None, |
|
|
audio_urls: Optional[list[str]] = None, |
|
|
audio_metas: Optional[list[dict]] = None, |
|
|
prepare_input_fn=None, |
|
|
prepare_audio_input_fn=None, |
|
|
max_image_cnt=21, |
|
|
video_max_num_frames=None, |
|
|
video_max_pixels=None, |
|
|
use_audio: bool = False, |
|
|
audio_sample_rate: int = 16000, |
|
|
): |
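        """Resolve the media tags in ``turn["content"]`` and load the referenced media:
        images (optionally from an "archive.zip#inner/path" style URL), videos via
        process_vision_info, and audio via librosa, appending the decoded payloads and
        duration metadata to ``output``."""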
|
|
assert (image_urls or video_urls or audio_urls) or turn |
|
|
if turn is None: |
|
|
turn = {} |
|
|
if image_urls: |
|
|
turn.update({"image_urls": image_urls}) |
|
|
turn.update({"image_metas": image_metas}) |
|
|
if video_urls: |
|
|
turn.update({"video_urls": video_urls}) |
|
|
turn.update({"video_metas": video_metas}) |
|
|
if audio_urls: |
|
|
turn.update({"audio_urls": audio_urls}) |
|
|
turn.update({"audio_metas": audio_metas}) |
|
|
|
|
|
if "video_urls" in turn: |
|
|
if len(turn["video_urls"]) and (prepare_input_fn is None): |
|
|
raise ConditionalError("video processing needs 'prepare_input_fn'") |
|
|
|
|
|
if not isinstance(turn["content"], str): |
|
|
            raise ConditionalError("turn['content'] must be a string")
|
|
|
|
|
turn["content"] = re.sub(r"<image_\d+>", "<|image|>", turn["content"]) |
|
|
pattern = re.compile( |
|
|
r"<\|video\|>|<\|image\|>|<\|t2i_model_generation_target_discrete_image\|>|<\|audio\|>|<\|discrete_audio\|>" |
|
|
) |
|
|
tags = [match.group() for match in pattern.finditer(turn["content"])] |
|
|
|
|
|
img_idx = 0 |
|
|
vid_idx = 0 |
|
|
aud_idx = 0 |
|
|
|
|
|
if "image_urls" not in turn: |
|
|
turn["image_urls"] = [] |
|
|
if "video_urls" not in turn: |
|
|
turn["video_urls"] = [] |
|
|
if "audio_urls" not in turn: |
|
|
turn["audio_urls"] = [] |
|
|
|
|
|
for tag in tags: |
|
|
if ( |
|
|
tag == "<|image|>" |
|
|
or tag == "<|t2i_model_generation_target_discrete_image|>" |
|
|
): |
|
|
img_path = turn["image_urls"][img_idx] |
|
|
|
|
|
if isinstance(img_path, str): |
|
|
if "#" in img_path: |
|
|
compression_path, img_path = img_path.split("#", 1) |
|
|
compression_path = os.path.join(img_dir, compression_path) |
|
|
assert compression_path[-4:] in [ |
|
|
".zip", |
|
|
".tar", |
|
|
], f"unsupported compression format: {compression_path}" |
|
|
|
|
|
with open(compression_path, "rb") as comp_file: |
|
|
if compression_path.endswith(".zip"): |
|
|
with zipfile.ZipFile(comp_file, "r") as zip_file: |
|
|
with zip_file.open(img_path) as img_file: |
|
|
img_binary = img_file.read() |
|
|
elif compression_path.endswith(".tar"): |
|
|
with tarfile.open( |
|
|
fileobj=comp_file, mode="r" |
|
|
) as tar_file: |
|
|
img_file = tar_file.extractfile(img_path) |
|
|
img_binary = img_file.read() |
|
|
else: |
|
|
with open(os.path.join(img_dir, img_path), "rb") as f: |
|
|
img_binary = f.read() |
|
|
img = image_decoder(img_binary) |
|
|
else: |
|
|
if isinstance(img_path, (bytes, bytearray)): |
|
|
img = io.BytesIO(img_path) |
|
|
img = Image.open(img).convert("RGB") |
|
|
else: |
|
|
img = img_path |
|
|
if not isinstance(img, Image.Image): |
|
|
img = Image.fromarray(np.uint8(img)).convert("RGB") |
|
|
|
|
|
if "image_metas" in turn and turn["image_metas"]: |
|
|
turn["image_metas"][img_idx] = convert_bboxes( |
|
|
img, turn["image_metas"][img_idx] |
|
|
) |
|
|
|
|
|
if tag == "<|image|>": |
|
|
output.imgs.append(img) |
|
|
output.discrete_imgs.append(img) |
|
|
|
|
|
img_idx += 1 |
|
|
elif tag == "<|video|>": |
|
|
video_path = turn["video_urls"][vid_idx] |
|
|
if isinstance(video_path, str): |
|
|
if "#" in video_path: |
|
|
compression_path, video_path = video_path.split("#", 1) |
|
|
compression_path = os.path.join(img_dir, compression_path) |
|
|
assert compression_path[-4:] in [ |
|
|
".zip", |
|
|
".tar", |
|
|
], f"unsupported compression format: {compression_path}" |
|
|
|
|
|
with open(compression_path, "rb") as comp_file: |
|
|
if compression_path.endswith(".zip"): |
|
|
with zipfile.ZipFile(comp_file, "r") as zip_file: |
|
|
video_file = zip_file.open(video_path) |
|
|
video_binary = video_file.read() |
|
|
elif compression_path.endswith(".tar"): |
|
|
with tarfile.open( |
|
|
fileobj=comp_file, mode="r" |
|
|
) as tar_file: |
|
|
video_file = tar_file.extractfile(video_path) |
|
|
video_binary = video_file.read() |
|
|
else: |
|
|
with open(os.path.join(img_dir, video_path), "rb") as f: |
|
|
video_binary = f.read() |
|
|
video_binary = io.BytesIO(video_binary) |
|
|
else: |
|
|
video_binary = video_path |
|
|
|
|
|
assert isinstance(video_binary, io.BytesIO), "video binary read error" |
|
|
|
|
|
try: |
|
|
from hcxvlm.dataset.qwen_vision_process import process_vision_info |
|
|
                except ImportError:
|
|
from qwen_vl_utils import process_vision_info |
|
|
|
|
|
if video_max_num_frames is None: |
|
|
video_max_num_frames = 120 |
|
|
if video_max_pixels is None: |
|
|
video_max_pixels = 378 * 378 |
|
|
|
|
|
messages = [ |
|
|
[ |
|
|
{ |
|
|
"role": "user", |
|
|
"content": [ |
|
|
{ |
|
|
"type": "video", |
|
|
"video": video_binary, |
|
|
"max_frames": video_max_num_frames, |
|
|
"max_pixels": video_max_pixels, |
|
|
} |
|
|
], |
|
|
} |
|
|
], |
|
|
] |
|
|
_, videos, video_kwargs = process_vision_info( |
|
|
messages, |
|
|
return_video_kwargs=True, |
|
|
use_audio=use_audio, |
|
|
audio_sample_rate=audio_sample_rate, |
|
|
) |
|
|
output.videos.append(videos[0]) |
|
|
video_len = round(videos[0].shape[0] / video_kwargs["fps"][0], 2) |
|
|
output.videos_duration.append( |
|
|
{ |
|
|
"video_duration": f"{video_len}s", |
|
|
} |
|
|
) |
|
|
|
|
|
if use_audio and "audio_chunks" in video_kwargs: |
|
|
audio_chunks = video_kwargs["audio_chunks"][0] |
|
|
if audio_chunks is not None: |
|
|
output.video_audios.append(audio_chunks) |
|
|
else: |
|
|
output.video_audios.append([]) |
|
|
elif use_audio: |
|
|
output.video_audios.append([]) |
|
|
|
|
|
vid_idx += 1 |
|
|
|
|
|
elif tag == "<|audio|>" or tag == "<|discrete_audio|>": |
|
|
audio_path = turn["audio_urls"][aud_idx] |
|
|
if isinstance(audio_path, str): |
|
|
if "#" in audio_path: |
|
|
compression_path, inner_path = audio_path.split("#", 1) |
|
|
compression_path = os.path.join(img_dir, compression_path) |
|
|
assert compression_path[-4:] in [ |
|
|
".zip", |
|
|
".tar", |
|
|
], f"unsupported compression format: {compression_path}" |
|
|
with open(compression_path, "rb") as comp_file: |
|
|
if compression_path.endswith(".zip"): |
|
|
with zipfile.ZipFile(comp_file, "r") as zip_file: |
|
|
with zip_file.open(inner_path) as audio_file: |
|
|
audio_binary = audio_file.read() |
|
|
elif compression_path.endswith(".tar"): |
|
|
with tarfile.open( |
|
|
fileobj=comp_file, mode="r" |
|
|
) as tar_file: |
|
|
audio_file = tar_file.extractfile(inner_path) |
|
|
audio_binary = audio_file.read() |
|
|
else: |
|
|
with open(os.path.join(img_dir, audio_path), "rb") as f: |
|
|
audio_binary = f.read() |
|
|
audio_stream = io.BytesIO(audio_binary) |
|
|
else: |
|
|
if isinstance(audio_path, (bytes, bytearray)): |
|
|
audio_stream = io.BytesIO(audio_path) |
|
|
else: |
|
|
audio_stream = audio_path |
|
|
|
|
|
try: |
|
|
import librosa |
|
|
|
|
|
y, sr = librosa.load( |
|
|
audio_stream, sr=DEFAULT_SAMPLE_RATE, mono=True |
|
|
) |
|
|
assert ( |
|
|
DEFAULT_SAMPLE_RATE == sr |
|
|
), f"librosa resampling failed: {DEFAULT_SAMPLE_RATE} != {sr}" |
|
|
except Exception as e: |
|
|
raise ConditionalError( |
|
|
f"audio decoding failed for {audio_path}: {e}" |
|
|
) |
|
|
|
|
|
audio_duration = len(y) / sr |
|
|
if audio_duration < 0.5: |
|
|
raise ConditionalError( |
|
|
f"Audio too short ({audio_duration:.2f}s). Minimum 0.5s required." |
|
|
) |
|
|
if audio_duration > 600: |
|
|
raise ConditionalError( |
|
|
f"Audio duration ({audio_duration:.2f}s) exceeds maximum allowed duration (600s)" |
|
|
) |
|
|
|
|
|
if len(y) < MIN_DISCRETE_AUDIO_CHUNK_SAMPLES: |
|
|
raise ConditionalError( |
|
|
f"Audio too short ({len(y)} samples = {audio_duration:.4f}s < 0.1s). " |
|
|
f"Minimum {MIN_DISCRETE_AUDIO_CHUNK_SAMPLES} samples required for CosyVoice encoder." |
|
|
) |
|
|
|
|
|
if not hasattr(output, "audios"): |
|
|
output.audios = [] |
|
|
if not hasattr(output, "discrete_audios"): |
|
|
output.discrete_audios = [] |
|
|
|
|
|
normalized_y = hpf_normalize(y) |
|
|
normalized_y = torch.from_numpy(normalized_y).float() |
|
|
|
|
|
output.discrete_audios.append(normalized_y) |
|
|
if tag == "<|audio|>": |
|
|
|
|
|
output.audios.append(y) |
|
|
total_duration = len(y) / sr |
|
|
output.audios_duration.append( |
|
|
{ |
|
|
"duration": f"{(total_duration):.2f}s", |
|
|
} |
|
|
) |
|
|
|
|
|
aud_idx += 1 |
|
|
else: |
|
|
raise ConditionalError( |
|
|
f"{tag} is not in ['<|image|>', '<|video|>', '<|audio|>']" |
|
|
) |
|
|
|
|
|
return output |
|
|
|
|
|
@classmethod |
|
|
def prompt_user( |
|
|
cls, |
|
|
output, |
|
|
tokenizer=None, |
|
|
turn: Optional[dict] = None, |
|
|
content: Optional[str] = None, |
|
|
is_train=False, |
|
|
fixed_mime=False, |
|
|
insert_ocr=300, |
|
|
file_names: Optional[list[str]] = None, |
|
|
mimes: Optional[list[str]] = None, |
|
|
mm_tokens: Optional[list[str]] = None, |
|
|
words: Optional[list] = None, |
|
|
lens: Optional[list] = None, |
|
|
query_template: Optional[list[str]] = None, |
|
|
config: Optional[dict] = None, |
|
|
seed: np.random.Generator = None, |
|
|
): |
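        """Append a user turn to ``output``: normalize <image_xx>/<video_xx>/<audio_xx>
        tags, emit the role header, then for each media tag emit the mime, auxiliary
        (duration) and pad-token strings. Text and role tokens can be marked trainable
        via the turn's trainable_* flags; media tokens are always masked with IGNORE_INDEX."""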
|
|
assert content or turn |
|
|
if turn is None: |
|
|
image_metas = [ |
|
|
{"words": words[i], "lens": lens[i]} for i in range(len(words)) |
|
|
] |
|
|
turn = { |
|
|
"content": content, |
|
|
"image_metas": image_metas, |
|
|
} |
|
|
if seed is None: |
|
|
seed = np.random.default_rng() |
|
|
|
|
|
turn["content"] = re.sub(r"<image_\d+>", "<|image|>", turn["content"]) |
|
|
turn["content"] = re.sub(r"<video_\d+>", "<|video|>", turn["content"]) |
|
|
turn["content"] = re.sub(r"<audio_\d+>", "<|audio|>", turn["content"]) |
|
|
|
|
|
pattern = re.compile(r"(<\|video\|>|<\|image\|>|<\|audio\|>)") |
|
|
|
|
|
all_tags_in_order = [ |
|
|
match.group() for match in pattern.finditer(turn["content"]) |
|
|
] |
|
|
n_vids = sum(1 for tag in all_tags_in_order if tag == "<|video|>") |
|
|
n_audios = sum(1 for tag in all_tags_in_order if tag == "<|audio|>") |
|
|
|
|
|
assert ( |
|
|
len(turn.get("image_urls", [])) |
|
|
+ len(turn.get("video_urls", [])) |
|
|
+ len(turn.get("audio_urls", [])) |
|
|
) == len( |
|
|
all_tags_in_order |
|
|
), f"Number of media URLs does not match number of media tags." |
|
|
|
|
|
if mm_tokens is None: |
|
|
mm_tokens = [ |
|
|
cls.audio_pad if tag == "<|audio|>" else cls.image_pad |
|
|
for tag in all_tags_in_order |
|
|
] |
|
|
|
|
|
assert len(mm_tokens) == len(all_tags_in_order) |
|
|
|
|
|
if config.get("llava_pretrain", False): |
|
|
mm_str = "".join([mm_tokens[i] for i in range(len(all_tags_in_order))]) |
|
|
if hasattr(output, "input_str"): |
|
|
output.input_str += mm_str |
|
|
|
|
|
if getattr(output, "input_ids", None) is not None: |
|
|
token_ids = tokenizer.encode(mm_str, truncation=False) |
|
|
output.input_ids += token_ids |
|
|
output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] |
|
|
return output |
|
|
|
|
|
if query_template: |
|
|
processed_content = seed.choice(query_template).format(turn["content"]) |
|
|
|
|
|
tags_after_template = pattern.findall(processed_content) |
|
|
if len(all_tags_in_order) != len(tags_after_template): |
|
|
cleaned_template_text = pattern.sub("", processed_content) |
|
|
processed_content = "".join(all_tags_in_order) + cleaned_template_text |
|
|
turn["content"] = processed_content |
|
|
|
|
|
content_parts = pattern.split(turn["content"].strip()) |
|
|
|
|
|
if hasattr(output, "input_str"): |
|
|
output.input_str += f"{cls.new_line}{cls.turn_prefix}{turn['role']}" |
|
|
if getattr(output, "input_ids", None) is not None: |
|
|
role_encoded = tokenizer.encode( |
|
|
f"{cls.new_line}{cls.turn_prefix}{turn['role']}", truncation=False |
|
|
) |
|
|
output.input_ids += role_encoded |
|
|
if turn.get("trainable_role", False): |
|
|
output.label_ids += role_encoded |
|
|
else: |
|
|
output.label_ids += [IGNORE_INDEX for _ in range(len(role_encoded))] |
|
|
|
|
|
tag_cursor = 0 |
|
|
|
|
|
for part in content_parts: |
|
|
part = part.strip() |
|
|
|
|
|
if not part: |
|
|
continue |
|
|
|
|
|
if part not in ["<|image|>", "<|video|>", "<|audio|>"]: |
|
|
content_text = part |
|
|
|
|
|
if hasattr(output, "input_str"): |
|
|
output.input_str += "\n" + content_text |
|
|
if getattr(output, "input_ids", None) is not None: |
|
|
content_encoded = tokenizer.encode( |
|
|
"\n" + content_text, truncation=False |
|
|
) |
|
|
output.input_ids += content_encoded |
|
|
if turn.get("trainable_content", False): |
|
|
output.label_ids += content_encoded |
|
|
else: |
|
|
output.label_ids += [ |
|
|
IGNORE_INDEX for _ in range(len(content_encoded)) |
|
|
] |
|
|
continue |
|
|
|
|
|
if part == "<|image|>": |
|
|
mime = Preprocessor.prompt_mime( |
|
|
mimes=mimes if not file_names else None, |
|
|
fixed_mime=fixed_mime if not file_names else False, |
|
|
file_name=file_names[tag_cursor] if file_names else None, |
|
|
tag_idx=output.sample_mm_counter["image"], |
|
|
is_video=False, |
|
|
is_audio=False, |
|
|
seed=seed, |
|
|
) |
|
|
mime_str = f"{cls.mime_start}{json.dumps(mime, ensure_ascii=False)}{cls.mime_end}" |
|
|
discrete_image_str = f"{cls.discrete_image_start}{cls.discrete_image_pad}{cls.discrete_image_end}" |
|
|
vector_str = f"{cls.image_start}{cls.image_pad}{cls.image_end}" |
|
|
mm_str = ( |
|
|
cls.new_line |
|
|
+ mime_str |
|
|
+ cls.new_line |
|
|
+ discrete_image_str |
|
|
+ cls.new_line |
|
|
+ vector_str |
|
|
) |
|
|
|
|
|
if hasattr(output, "input_str"): |
|
|
output.input_str += mm_str |
|
|
if getattr(output, "input_ids", None) is not None: |
|
|
token_ids = tokenizer.encode(mm_str, truncation=False) |
|
|
output.input_ids += token_ids |
|
|
output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] |
|
|
|
|
|
output.sample_mm_counter["image"] += 1 |
|
|
tag_cursor += 1 |
|
|
|
|
|
elif part == "<|video|>": |
|
|
mime = Preprocessor.prompt_mime( |
|
|
mimes=mimes if not file_names else None, |
|
|
fixed_mime=fixed_mime if not file_names else False, |
|
|
file_name=file_names[tag_cursor] if file_names else None, |
|
|
tag_idx=output.sample_mm_counter["video"], |
|
|
is_video=True, |
|
|
is_audio=False, |
|
|
seed=seed, |
|
|
) |
|
|
mm_str = "" |
|
|
aux_inputs = { |
|
|
"video_duration": output.videos_duration[ |
|
|
output.sample_mm_counter["video"] |
|
|
]["video_duration"], |
|
|
} |
|
|
mime_str = f"{cls.mime_start}{json.dumps(mime, ensure_ascii=False)}{cls.mime_end}" |
|
|
aux_str = f"{cls.aux_video_start}{cls.aux_vid_prompt}{json.dumps(aux_inputs, ensure_ascii=False)}{cls.aux_video_end}" |
|
|
vector_str = f"{cls.video_start}{cls.video_pad}{cls.video_end}" |
|
|
mm_str += ( |
|
|
cls.new_line |
|
|
+ mime_str |
|
|
+ cls.new_line |
|
|
+ aux_str |
|
|
+ cls.new_line |
|
|
+ vector_str |
|
|
) |
|
|
if hasattr(output, "input_str"): |
|
|
output.input_str += mm_str |
|
|
if getattr(output, "input_ids", None) is not None: |
|
|
token_ids = tokenizer.encode(mm_str, truncation=False) |
|
|
output.input_ids += token_ids |
|
|
output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] |
|
|
output.sample_mm_counter["video"] += 1 |
|
|
tag_cursor += 1 |
|
|
|
|
|
elif part == "<|audio|>": |
|
|
mime = Preprocessor.prompt_mime( |
|
|
mimes=mimes if not file_names else None, |
|
|
fixed_mime=fixed_mime if not file_names else False, |
|
|
file_name=file_names[tag_cursor] if file_names else None, |
|
|
tag_idx=output.sample_mm_counter["audio"], |
|
|
is_video=False, |
|
|
is_audio=True, |
|
|
seed=seed, |
|
|
) |
|
|
mm_str = "" |
|
|
aux_inputs = { |
|
|
"audio_duration": output.audios_duration[ |
|
|
output.sample_mm_counter["audio"] |
|
|
]["duration"], |
|
|
} |
|
|
mime_str = f"{cls.mime_start}{json.dumps(mime, ensure_ascii=False)}{cls.mime_end}" |
|
|
aux_str = f"{cls.aux_audio_start}{cls.aux_audio_prompt}{json.dumps(aux_inputs, ensure_ascii=False)}{cls.aux_audio_end}" |
|
|
discrete_audio_str = f"{cls.discrete_audio_start}{cls.discrete_audio_pad}{cls.discrete_audio_end}" |
|
|
vector_str = f"{cls.audio_start}{cls.audio_pad}{cls.audio_end}" |
|
|
mm_str += ( |
|
|
cls.new_line |
|
|
+ mime_str |
|
|
+ cls.new_line |
|
|
+ aux_str |
|
|
+ cls.new_line |
|
|
+ discrete_audio_str |
|
|
+ cls.new_line |
|
|
+ vector_str |
|
|
) |
|
|
if hasattr(output, "input_str"): |
|
|
output.input_str += mm_str |
|
|
if getattr(output, "input_ids", None) is not None: |
|
|
token_ids = tokenizer.encode(mm_str, truncation=False) |
|
|
output.input_ids += token_ids |
|
|
output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] |
|
|
|
|
|
output.sample_mm_counter["audio"] += 1 |
|
|
tag_cursor += 1 |
|
|
|
|
|
if hasattr(output, "input_str"): |
|
|
output.input_str += cls.turn_suffix |
|
|
|
|
|
if getattr(output, "input_ids", None) is not None: |
|
|
token_ids = tokenizer.encode(cls.turn_suffix, truncation=False) |
|
|
output.input_ids += token_ids |
|
|
output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] |
|
|
|
|
|
return output |
|
|
|
|
|
@classmethod |
|
|
def prompt_assistant( |
|
|
cls, |
|
|
output, |
|
|
tokenizer=None, |
|
|
turn: Optional[dict] = None, |
|
|
role: Optional[str] = "assistant", |
|
|
content: Optional[str] = None, |
|
|
is_last_turn=False, |
|
|
is_eval=True, |
|
|
is_llava_pretrain=False, |
|
|
is_after_last_user_turn=False, |
|
|
): |
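        """Append an assistant turn to ``output``: wrap reasoning in <think> tags when
        present, serialize tool_calls into <tool_call> blocks, expand image /
        discrete-image / discrete-audio tags, and add the tokens to the training labels
        unless the turn's trainable_* flags disable them."""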
|
|
assert content or turn |
|
|
if turn is None: |
|
|
turn = { |
|
|
"content": content, |
|
|
"role": role, |
|
|
} |
|
|
|
|
|
if is_llava_pretrain: |
|
|
if hasattr(output, "input_str"): |
|
|
output.input_str += turn["content"] |
|
|
if getattr(output, "input_ids", None) is not None: |
|
|
content_encoded = tokenizer.encode(turn["content"], truncation=False) |
|
|
output.input_ids += content_encoded |
|
|
output.label_ids += content_encoded |
|
|
return output |
|
|
|
|
|
reasoning_content = turn.get("reasoning_content", "") |
|
|
if ( |
|
|
not reasoning_content |
|
|
and isinstance(turn["content"], str) |
|
|
and "</think>" in turn["content"] |
|
|
): |
|
|
parts = turn["content"].split("</think>", 1) |
|
|
reasoning_content = parts[0].split("<think>", 1)[-1].lstrip("\n") |
|
|
turn["content"] = parts[1].lstrip("\n") |
|
|
|
|
|
if is_after_last_user_turn and (is_last_turn or reasoning_content): |
|
|
content_to_strip = turn.get("content") or "" |
|
|
stripped_content = content_to_strip.lstrip("\n") |
|
|
|
|
|
if reasoning_content is None: |
|
|
reasoning_content = "" |
|
|
turn["content"] = ( |
|
|
f"<think>\n{reasoning_content.strip()}\n</think>\n\n{stripped_content}" |
|
|
) |
|
|
|
|
|
if turn.get("tool_calls"): |
|
|
for tool_call in turn["tool_calls"]: |
|
|
func_name = tool_call.get("function", {}).get("name", "") |
|
|
args = tool_call.get("function", {}).get("arguments", {}) |
|
|
|
|
|
if isinstance(args, str): |
|
|
try: |
|
|
args = json.loads(args) |
|
|
except Exception: |
|
|
pass |
|
|
                if not isinstance(args, dict):
                    print(
                        f"[error] tool_call.function.arguments is not a dict: type={type(args)}, value={str(args)}"
                    )
                    assert (
                        False
                    ), "tool_call.function.arguments must be a dict or a JSON string representing a dict."
|
|
|
|
|
tool_turn_content = f"\n<tool_call>{func_name}\n" |
|
|
|
|
|
for key, value in args.items(): |
|
|
arg_value = ( |
|
|
json.dumps(value, ensure_ascii=False) |
|
|
if not isinstance(value, str) |
|
|
else value |
|
|
) |
|
|
tool_turn_content += f"<arg_key>{key}</arg_key>\n<arg_value>{arg_value}</arg_value>\n" |
|
|
tool_turn_content += "</tool_call>" |
|
|
|
|
|
if func_name == "t2i_model_generation": |
|
|
assert ( |
|
|
"<|t2i_model_generation_target_discrete_image|>" |
|
|
in turn["content"] |
|
|
), "t2i_model_generation tool call must have target discrete image tag in content." |
|
|
turn["content"] = turn["content"].replace( |
|
|
"<|t2i_model_generation_target_discrete_image|>", |
|
|
tool_turn_content, |
|
|
) |
|
|
else: |
|
|
turn["content"] += tool_turn_content |
|
|
|
|
|
pattern = re.compile( |
|
|
r"(<\|image\|>|<\|discrete_image\|>|<\|audio\|>|<\|discrete_audio\|>)" |
|
|
) |
|
|
all_tags_in_order = [ |
|
|
match.group() for match in pattern.finditer(turn["content"]) |
|
|
] |
|
|
|
|
|
assert ( |
|
|
len(turn.get("image_urls", [])) |
|
|
+ len(turn.get("video_urls", [])) |
|
|
+ len(turn.get("audio_urls", [])) |
|
|
) == len( |
|
|
all_tags_in_order |
|
|
), f"Number of media URLs does not match number of media tags." |
|
|
|
|
|
if hasattr(output, "input_str"): |
|
|
output.input_str += f"{cls.new_line}{cls.turn_prefix}{turn['role']}" |
|
|
if is_eval and is_last_turn: |
|
|
if reasoning_content.strip() == "": |
|
|
output.input_str += f"<think>\n\n</think>\n\n" |
|
|
turn["content"] = stripped_content |
|
|
else: |
|
|
output.input_str += f"{turn['content']}{cls.turn_suffix}" |
|
|
|
|
|
if getattr(output, "input_ids", None) is not None: |
|
|
role_encoded = tokenizer.encode( |
|
|
f"{cls.new_line}{cls.turn_prefix}{turn['role']}", truncation=False |
|
|
) |
|
|
output.input_ids += role_encoded |
|
|
|
|
|
if is_eval and is_last_turn: |
|
|
if reasoning_content.strip() == "": |
|
|
output.input_ids += tokenizer.encode( |
|
|
f"<think>\n\n</think>\n\n", truncation=False |
|
|
) |
|
|
turn["content"] = stripped_content |
|
|
else: |
|
|
if turn.get("trainable_role", True): |
|
|
output.label_ids += role_encoded |
|
|
else: |
|
|
output.label_ids += [IGNORE_INDEX for _ in range(len(role_encoded))] |
|
|
|
|
|
turn_img_idx = 0 |
|
|
content_parts = pattern.split(turn["content"].strip()) |
|
|
for part in content_parts: |
|
|
part = part.strip() |
|
|
|
|
|
if not part: |
|
|
continue |
|
|
|
|
|
if part not in [ |
|
|
"<|image|>", |
|
|
"<|discrete_image|>", |
|
|
"<|audio|>", |
|
|
"<|discrete_audio|>", |
|
|
]: |
|
|
content_text = part |
|
|
|
|
|
if hasattr(output, "input_str"): |
|
|
output.input_str += "\n" + content_text |
|
|
if getattr(output, "input_ids", None) is not None: |
|
|
content_encoded = tokenizer.encode( |
|
|
"\n" + content_text, truncation=False |
|
|
) |
|
|
output.input_ids += content_encoded |
|
|
if turn.get("trainable_content", True): |
|
|
output.label_ids += content_encoded |
|
|
else: |
|
|
output.label_ids += [ |
|
|
IGNORE_INDEX for _ in range(len(content_encoded)) |
|
|
] |
|
|
continue |
|
|
|
|
|
if part == "<|image|>": |
|
|
file_name = turn.get("image_urls", [])[turn_img_idx] |
|
|
if isinstance(file_name, str) and "#" in file_name: |
|
|
file_name = file_name.split("#")[-1] |
|
|
file_name = os.path.basename(file_name) |
|
|
mime = Preprocessor.prompt_mime( |
|
|
mimes=None, |
|
|
fixed_mime=False, |
|
|
file_name=file_name, |
|
|
tag_idx=output.sample_mm_counter["image"], |
|
|
is_video=False, |
|
|
is_audio=False, |
|
|
seed=None, |
|
|
) |
|
|
mime_str = f"{cls.mime_start}{json.dumps(mime, ensure_ascii=False)}{cls.mime_end}" |
|
|
discrete_image_str = f"{cls.discrete_image_start}{cls.discrete_image_pad}{cls.discrete_image_end}" |
|
|
vector_str = f"{cls.image_start}{cls.image_pad}{cls.image_end}" |
|
|
mm_str = ( |
|
|
cls.new_line |
|
|
+ mime_str |
|
|
+ cls.new_line |
|
|
+ discrete_image_str |
|
|
+ cls.new_line |
|
|
+ vector_str |
|
|
) |
|
|
|
|
|
if hasattr(output, "input_str"): |
|
|
output.input_str += mm_str |
|
|
if getattr(output, "input_ids", None) is not None: |
|
|
token_ids = tokenizer.encode(mm_str, truncation=False) |
|
|
output.input_ids += token_ids |
|
|
output.label_ids += [ |
|
|
IGNORE_INDEX for _ in range(len(token_ids)) |
|
|
] |
|
|
turn_img_idx += 1 |
|
|
output.sample_mm_counter["image"] += 1 |
|
|
|
|
|
elif part == "<|discrete_image|>": |
|
|
discrete_image_str = f"{cls.discrete_image_start}{cls.discrete_image_pad}{cls.discrete_image_end}" |
|
|
mm_str = cls.new_line + discrete_image_str |
|
|
if hasattr(output, "input_str"): |
|
|
output.input_str += mm_str |
|
|
if getattr(output, "input_ids", None) is not None: |
|
|
token_ids = tokenizer.encode(mm_str, truncation=False) |
|
|
output.input_ids += token_ids |
|
|
output.label_ids += token_ids |
|
|
turn_img_idx += 1 |
|
|
|
|
|
elif part == "<|discrete_audio|>": |
|
|
discrete_audio_str = f"{cls.discrete_audio_start}{cls.discrete_audio_pad}{cls.discrete_audio_end}" |
|
|
mm_str = cls.new_line + discrete_audio_str |
|
|
if hasattr(output, "input_str"): |
|
|
output.input_str += mm_str |
|
|
if getattr(output, "input_ids", None) is not None: |
|
|
token_ids = tokenizer.encode(mm_str, truncation=False) |
|
|
output.input_ids += token_ids |
|
|
if turn.get("trainable_content", True): |
|
|
output.label_ids += token_ids |
|
|
else: |
|
|
output.label_ids += [ |
|
|
IGNORE_INDEX for _ in range(len(token_ids)) |
|
|
] |
|
|
|
|
|
elif part == "<|audio|>": |
|
|
raise Exception( |
|
|
"Assistant turn에서 <|audio|> 태그는 지원하지 않음. discrete_audio 만 지원함." |
|
|
) |
|
|
|
|
|
if hasattr(output, "input_str"): |
|
|
output.input_str += cls.turn_suffix |
|
|
|
|
|
if getattr(output, "input_ids", None) is not None: |
|
|
token_ids = tokenizer.encode(cls.turn_suffix, truncation=False) |
|
|
output.input_ids += token_ids |
|
|
if turn.get("trainable_content", True): |
|
|
output.label_ids += token_ids |
|
|
else: |
|
|
output.label_ids += [IGNORE_INDEX for _ in range(len(token_ids))] |
|
|
|
|
|
return output |
|
|
|
|
|
@classmethod |
|
|
def prompt_tool( |
|
|
cls, |
|
|
output, |
|
|
tokenizer=None, |
|
|
turn: Optional[dict] = None, |
|
|
role: Optional[str] = None, |
|
|
content: Optional[str] = None, |
|
|
eot: Optional[bool] = None, |
|
|
need_start_tag=True, |
|
|
need_end_tag=True, |
|
|
): |
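        """Append a tool-response turn wrapped in <tool_response> tags; dict or JSON
        content is reduced to its "response" field when present, otherwise serialized
        as JSON."""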
|
|
assert (content and role) or turn |
|
|
if turn is None: |
|
|
turn = { |
|
|
"content": content, |
|
|
"role": role, |
|
|
"endofturn": eot, |
|
|
} |
|
|
assert ( |
|
|
"tool" == turn["role"] |
|
|
), f'[warning] unexpected turn["role"]: {turn["role"]}' |
|
|
content_value = turn.get("content", "") |
|
|
|
|
|
if isinstance(content_value, dict): |
|
|
if "response" in content_value: |
|
|
content_str = content_value["response"] |
|
|
else: |
|
|
content_str = json.dumps(content_value, ensure_ascii=False) |
|
|
elif isinstance(content_value, str): |
|
|
try: |
|
|
parsed = json.loads(content_value) |
|
|
if isinstance(parsed, dict): |
|
|
if "response" in parsed: |
|
|
content_str = parsed["response"] |
|
|
else: |
|
|
content_str = json.dumps(parsed, ensure_ascii=False) |
|
|
else: |
|
|
content_str = content_value |
|
|
except (json.JSONDecodeError, TypeError): |
|
|
content_str = content_value |
|
|
else: |
|
|
content_str = str(content_value) |
|
|
|
|
|
turn["content"] = ( |
|
|
f"<tool_response>{turn.get('name', '')}\n{content_str}\n</tool_response>" |
|
|
) |
|
|
|
|
|
if hasattr(output, "input_str"): |
|
|
if need_start_tag: |
|
|
output.input_str += f"{cls.new_line}{cls.turn_prefix}{turn['role']}" |
|
|
output.input_str += f"{cls.new_line}{turn['content']}" |
|
|
if need_end_tag: |
|
|
output.input_str += cls.turn_suffix |
|
|
|
|
|
if getattr(output, "input_ids", None) is not None: |
|
|
if need_start_tag: |
|
|
role_encoded = tokenizer.encode( |
|
|
f"{cls.new_line}{cls.turn_prefix}{turn['role']}", truncation=False |
|
|
) |
|
|
output.input_ids += role_encoded |
|
|
|
|
|
if turn.get("trainable_role", True): |
|
|
output.label_ids += role_encoded |
|
|
else: |
|
|
output.label_ids += [IGNORE_INDEX for _ in range(len(role_encoded))] |
|
|
|
|
|
content = f"{cls.new_line}{turn['content']}" |
|
|
content_encoded = tokenizer.encode(content, truncation=False) |
|
|
if need_end_tag: |
|
|
content_encoded += tokenizer.encode( |
|
|
f"{cls.turn_suffix}", truncation=False |
|
|
) |
|
|
output.input_ids += content_encoded |
|
|
if turn.get("trainable_content", True): |
|
|
output.label_ids += content_encoded |
|
|
else: |
|
|
output.label_ids += [ |
|
|
IGNORE_INDEX for _ in range(len(content_encoded)) |
|
|
] |
|
|
return output |
|
|
|
|
|
@classmethod |
|
|
def prompt_etc( |
|
|
cls, |
|
|
output, |
|
|
tokenizer=None, |
|
|
turn: Optional[dict] = None, |
|
|
role: Optional[str] = None, |
|
|
content: Optional[str] = None, |
|
|
eot: Optional[bool] = None, |
|
|
): |
|
|
assert (content and role) or turn |
|
|
if turn is None: |
|
|
turn = { |
|
|
"content": content, |
|
|
"role": role, |
|
|
"endofturn": eot, |
|
|
} |
|
|
print(f'[warning] unexpected turn["role"]: {turn["role"]}') |
|
|
|
|
|
if hasattr(output, "input_str"): |
|
|
output.input_str += f"{cls.turn_prefix}{turn['role']}\n" |
|
|
output.input_str += f"{turn['content']}{cls.turn_suffix}" |
|
|
if turn.get("stop", False): |
|
|
output.input_str += cls.stop_token |
|
|
if turn.get("endofturn", False): |
|
|
output.input_str += cls.eot |
|
|
|
|
|
if getattr(output, "input_ids", None) is not None: |
|
|
role_encoded = tokenizer.encode( |
|
|
f"{cls.turn_prefix}{turn['role']}\n", truncation=False |
|
|
) |
|
|
output.input_ids += role_encoded |
|
|
|
|
|
if turn.get("trainable_role", True): |
|
|
output.label_ids += role_encoded |
|
|
else: |
|
|
output.label_ids += [IGNORE_INDEX for _ in range(len(role_encoded))] |
|
|
|
|
|
content = f"{turn['content']}{cls.turn_suffix}" |
|
|
if turn.get("stop", False): |
|
|
content += cls.stop_token |
|
|
if turn.get("endofturn", False): |
|
|
content += cls.eot |
|
|
content_encoded = tokenizer.encode(content, truncation=False) |
|
|
output.input_ids += content_encoded |
|
|
if turn.get("trainable_content", True): |
|
|
output.label_ids += content_encoded |
|
|
else: |
|
|
output.label_ids += [IGNORE_INDEX for _ in range(len(content_encoded))] |
|
|
return output |
|
|
|
|
|
def __call__(self, sample): |
|
|
return self.preprocess_new(sample) |
|
|
|
|
|
@classmethod |
|
|
def batchify( |
|
|
cls, |
|
|
items: List[Dict[str, Any]],
|
|
device: Optional[str] = None,
|
|
): |
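# Collate per-sample dicts into per-key lists, moving tensors (and pixel_values
# lists of tensors) onto `device`; *_grid_thw tensors are concatenated along dim 0,
# other tensors are stacked, and the video fields are reset to None at the end.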
|
|
batch = dict() |
|
|
for item in items: |
|
|
for k, v in item.items(): |
|
|
if isinstance(v, torch.Tensor): |
|
|
if device is not None: |
|
|
v = v.to(device=device) |
|
|
elif k == "pixel_values": |
|
|
v = [_v.to(device=device) for _v in v] |
|
|
|
|
|
if k not in batch: |
|
|
batch[k] = [ |
|
|
v, |
|
|
] |
|
|
else: |
|
|
batch[k].append(v) |
|
|
|
|
|
for k, v in batch.items(): |
|
|
if isinstance(v[0], torch.Tensor): |
|
|
if k in ["image_grid_thw", "video_grid_thw"]: |
|
|
batch[k] = torch.cat(v, dim=0) |
|
|
continue |
|
|
batch[k] = torch.stack(v, dim=0) |
|
|
batch["video_grid_thw"] = None |
|
|
batch["pixel_values_videos"] = None |
|
|
return batch |
|
|
|
|
|
def convert_wds_to_datalake( |
|
|
self, |
|
|
img: Union[PIL.Image.Image, Dict[str, PIL.Image.Image]] = {}, |
|
|
json: Dict[str, Any] = {}, |
|
|
benchmark: Optional[str] = None, |
|
|
video: Union[io.BytesIO, Dict[str, io.BytesIO]] = {}, |
|
|
audio: Union[io.BytesIO, Dict[str, io.BytesIO]] = {}, |
|
|
): |
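# Convert one decoded webdataset sample (image(s)/video/audio plus its JSON
# annotation) into the datalake turn format: a tool_list turn, a system turn, then
# alternating user/assistant turns built from captions, QA pairs, OCR words, or
# region annotations.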
|
|
|
|
|
if "lines" in json: |
|
|
del json["lines"] |
|
|
if "paragraphs" in json: |
|
|
del json["paragraphs"] |
|
|
|
|
|
assert json["meta"]["type"] in [ |
|
|
"caption", |
|
|
"vqa", |
|
|
"textread", |
|
|
], f"{json['meta']['path']}, {json['meta']['type']}: The dataset type should be one of them: caption, vqa, textread." |
|
|
|
|
|
sample = {"vlm": {}} |
|
|
sample["vlm"] = get_wds_default_config( |
|
|
json["meta"], existing_default_config=self.wds_default_config |
|
|
) |
|
|
sample["vlm"]["data_name"] = json["meta"].get("name", "unk") |
|
|
|
|
|
sample["vlm"]["data_type"] = ( |
|
|
"wds" |
|
|
if (isinstance(img, PIL.Image.Image) and img) |
|
|
or (isinstance(img, dict) and len(img) > 0) |
|
|
else "sft1" |
|
|
) |
|
|
|
|
|
sample["vlm"]["sample_id"] = json.get("qa_id", None) |
|
|
sample["vlm"]["category"] = json.get("category", None) |
|
|
sample["vlm"]["data_info"] = json.get("data_info", dict()) |
|
|
sample["vlm"]["options"] = None |
|
|
if "choices_en" in sample["vlm"]["data_info"]: |
|
|
if sample["vlm"]["options"] is None and json["meta"]["lang"] == "en": |
|
|
sample["vlm"]["options"] = sample["vlm"]["data_info"]["choices_en"] |
|
|
sample["vlm"]["options_en"] = sample["vlm"]["data_info"]["choices_en"] |
|
|
if "choices_ko" in sample["vlm"]["data_info"]: |
|
|
if sample["vlm"]["options"] is None and json["meta"]["lang"] == "ko": |
|
|
sample["vlm"]["options"] = sample["vlm"]["data_info"]["choices_ko"] |
|
|
sample["vlm"]["options_ko"] = sample["vlm"]["data_info"]["choices_ko"] |
|
|
sample["vlm"]["image_index"] = json.get( |
|
|
"image_index", json.get("img_url", None) |
|
|
) |
|
|
|
|
|
if sample["vlm"].get("video", False): |
|
|
is_multi_image_dataset = False |
|
|
else: |
|
|
is_multi_image_dataset, img, json = convert_format_for_multi_image( |
|
|
img, json |
|
|
) |
|
|
|
|
|
if json["meta"]["type"] == "textread": |
|
|
key = "words" |
|
|
elif json["meta"].get("subtask", "") == "region": |
|
|
key = f"regions_{json['meta']['lang']}" |
|
|
elif json["meta"]["type"] == "vqa": |
|
|
key = f"qa_pairs_{json['meta']['lang']}" |
|
|
elif json["meta"]["type"] == "caption": |
|
|
key = f"captions_{json['meta']['lang']}" |
|
|
else: |
|
|
raise ConditionalError( |
|
|
f"wrong task type in wds config: {sample['vlm']['data_name']}" |
|
|
) |
|
|
|
|
|
turns = [ |
|
|
{ |
|
|
"role": "tool_list", |
|
|
"content": "", |
|
|
"content_type": "text", |
|
|
"trainable_role": False, |
|
|
"trainable_content": False, |
|
|
"stop": False, |
|
|
"debuggingInfo": {}, |
|
|
"meta": {}, |
|
|
"candidates": [], |
|
|
"endofturn": False, |
|
|
}, |
|
|
{ |
|
|
"role": "system", |
|
|
"content_type": "text", |
|
|
"candidates": [], |
|
|
"trainable_role": False, |
|
|
"trainable_content": False, |
|
|
"stop": False, |
|
|
"debuggingInfo": {}, |
|
|
"meta": {}, |
|
|
"content": "", |
|
|
"endofturn": False, |
|
|
}, |
|
|
] |
|
|
|
|
|
if json["meta"].get("llava_pretrain", False): |
|
|
sample["vlm"]["llava_pretrain"] = True |
|
|
|
|
|
use_task_prompt = json["meta"].get( |
|
|
"use_task_prompt", self.wds_default_config["use_task_prompt"] |
|
|
) |
|
|
get_random = json["meta"].get( |
|
|
"get_random", self.wds_default_config["get_random"] |
|
|
) |
|
|
reasoning = json["meta"].get("reasoning", self.wds_default_config["reasoning"]) |
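# Prefer the language-specific annotation key (e.g. "qa_pairs_en"); if it is missing
# or empty, fall back to the suffix-less key, and otherwise reject the sample.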
|
|
|
|
|
try: |
|
|
if key not in json: |
|
|
key = key[:-3] |
|
|
assert key in json |
|
|
if len(json[key]) == 0: |
|
|
key = key[:-3] |
|
|
assert key in json |
|
|
except Exception:
|
|
raise ConditionalError( |
|
|
f"{key} key is not in json? dataset name: {sample['vlm']['data_name']}" |
|
|
) |
|
|
|
|
|
first_turn = True |
|
|
if "region" in key: |
|
|
json[key] = json[key]["00"] |
|
|
sample["vlm"]["multiturn_n_samples"] = 1 |
|
|
if ( |
|
|
not is_multi_image_dataset |
|
|
and sample["vlm"]["multiturn_n_samples"] > 1 |
|
|
or "region" in key |
|
|
): |
|
|
json[key] = sampling_multiturn_single_img( |
|
|
json[key], |
|
|
sample["vlm"]["multiturn_n_samples"], |
|
|
sample["vlm"]["multiturn_preserve_order"], |
|
|
sample["vlm"]["multiturn_continuous"], |
|
|
) |
|
|
|
|
|
if sample["vlm"].get("video", False): |
|
|
for qa in json[key]: |
|
|
vid_src = [] |
|
|
user = { |
|
|
"role": "user", |
|
|
"content_type": "text", |
|
|
"candidates": [], |
|
|
"trainable_role": False, |
|
|
"trainable_content": False, |
|
|
"stop": False, |
|
|
"debuggingInfo": {}, |
|
|
"meta": {}, |
|
|
"image_urls": [], |
|
|
"image_metas": [], |
|
|
"video_urls": [], |
|
|
"video_metas": [], |
|
|
"audio_urls": [], |
|
|
"audio_metas": [], |
|
|
"content": "", |
|
|
"endofturn": False, |
|
|
} |
|
|
|
|
|
instruct_prompt, task_prompt = hcx_vision_prompter( |
|
|
task=json["meta"]["type"], |
|
|
subtask=json["meta"].get("subtask", None), |
|
|
lang=json["meta"]["lang"], |
|
|
get_random=get_random, |
|
|
use_task_prompt=use_task_prompt, |
|
|
) |
|
|
|
|
|
prompt = qa[0] |
|
|
answer = qa[-1] if reasoning else qa[1] |
|
|
|
|
|
if first_turn: |
|
|
user["video_metas"].append({"lens": []}) |
|
|
user["content"] += "<|video|>" |
|
|
prompt = task_prompt.format(prompt) |
|
|
|
|
|
if "entities" in json: |
|
|
user["video_metas"][0]["lens"] = json["entities"].get("00", []) |
|
|
if isinstance(video, dict): |
|
|
vid_src.append(video["00"]) |
|
|
else: |
|
|
vid_src.append(video) |
|
|
first_turn = False |
|
|
|
|
|
user["video_urls"] = vid_src |
|
|
user["content"] += prompt |
|
|
|
|
|
assistant = { |
|
|
"candidates": [], |
|
|
"content": answer, |
|
|
"content_type": "text", |
|
|
"debuggingInfo": {}, |
|
|
"meta": {}, |
|
|
"role": "assistant", |
|
|
"trainable_content": True, |
|
|
"trainable_role": True, |
|
|
"stop": False, |
|
|
"endofturn": True, |
|
|
} |
|
|
turns.append(user) |
|
|
turns.append(assistant) |
|
|
|
|
|
else: |
|
|
if key.startswith("qa_pairs") or key.startswith("captions"): |
|
|
if self.mode != "train" and key.startswith("qa_pairs"): |
|
|
qas = dict() |
|
|
for qa in json[key]: |
|
|
q = qa[0] |
|
|
if q not in qas: |
|
|
qas[q] = list() |
|
|
for _i, _e in enumerate(qa[1:]): |
|
|
if len(qas[q]) <= _i: |
|
|
qas[q].append(list()) |
|
|
qas[q][_i].append(_e) |
|
|
json[key] = [ |
|
|
[ |
|
|
k, |
|
|
] |
|
|
+ v |
|
|
for k, v in qas.items() |
|
|
] |
|
|
|
|
|
if self.mode != "train": |
|
|
json[key] = json[key][:1] |
|
|
|
|
|
for qa in json[key]: |
|
|
img_src = [] |
|
|
user = { |
|
|
"role": "user", |
|
|
"content_type": "text", |
|
|
"candidates": [], |
|
|
"trainable_role": False, |
|
|
"trainable_content": False, |
|
|
"stop": False, |
|
|
"debuggingInfo": {}, |
|
|
"meta": {}, |
|
|
"image_urls": [], |
|
|
"image_metas": [], |
|
|
"video_urls": [], |
|
|
"video_metas": [], |
|
|
"audio_urls": [], |
|
|
"audio_metas": [], |
|
|
"content": "", |
|
|
"endofturn": False, |
|
|
} |
|
|
img_keys = re.findall(r"<image_(\d+)>", qa[0]) |
|
|
video_keys = re.findall(r"<video_(\d+)>", qa[0]) |
|
|
audio_keys = re.findall(r"<audio_(\d+)>", qa[0]) |
|
|
|
|
|
if key.startswith("qa_pairs"): |
|
|
if len(qa) > 2: |
|
|
sample_id = qa[2] |
|
|
if ( |
|
|
isinstance(sample_id, (list, tuple)) |
|
|
and len(sample_id) > 0 |
|
|
): |
|
|
sample_id = sample_id[0] |
|
|
sample["vlm"]["sample_id"] = sample_id |
|
|
|
|
|
instruct_prompt, task_prompt = hcx_vision_prompter( |
|
|
task=json["meta"]["type"], |
|
|
subtask=json["meta"].get("subtask", None), |
|
|
lang=json["meta"]["lang"], |
|
|
get_random=get_random, |
|
|
use_task_prompt=use_task_prompt, |
|
|
) |
|
|
if json["meta"]["type"] == "vqa": |
|
|
prompt = qa[0] |
|
|
answer = qa[-1] if reasoning else qa[1] |
|
|
elif json["meta"]["type"] == "caption": |
|
|
prompt = task_prompt.format("") |
|
|
answer = qa |
|
|
|
|
|
if first_turn or self.mode != "train": |
|
|
if json["meta"]["type"] == "vqa": |
|
|
prompt = task_prompt.format(prompt) |
|
|
if first_turn and not is_multi_image_dataset: |
|
|
user["image_metas"].append({"words": [], "lens": []}) |
|
|
if "<image_00>" in prompt: |
|
|
prompt = prompt.replace("<image_00>", "<|image|>") |
|
|
else: |
|
|
user["content"] += "<|image|>" |
|
|
user["image_metas"][0]["words"] = json.get("words", {}).get( |
|
|
"00", [] |
|
|
) |
|
|
if "objects" in json: |
|
|
user["image_metas"][0]["lens"] = json["objects"].get( |
|
|
"00", [] |
|
|
) |
|
|
elif "entities" in json: |
|
|
user["image_metas"][0]["lens"] = json["entities"].get( |
|
|
"00", [] |
|
|
) |
|
|
if isinstance(img, dict): |
|
|
img_src.append(img["00"]) |
|
|
else: |
|
|
img_src.append(img) |
|
|
elif len(img_keys) > 0: |
|
|
# NOTE: use `img_key` here so the outer annotation key (e.g. "qa_pairs_en") is not shadowed.
for i, img_key in enumerate(img_keys):
user["image_metas"].append({"words": [], "lens": []})
if f"<image_{i:02d}>" in prompt:
prompt = prompt.replace(f"<image_{i:02d}>", "<|image|>")
else:
user["content"] += "<|image|>"
img_src.append(img[img_key])
_words = json.get("words", {})
if isinstance(_words, dict):
_words = _words.get(img_key, [])
user["image_metas"][i]["words"] = _words
if "objects" in json:
_objects = json["objects"].get(img_key, [])
if isinstance(_objects, dict):
_objects = _objects.get(img_key, [])
user["image_metas"][i]["lens"] = _objects
if "entities" in json:
_entities = json["entities"].get(img_key, [])
if isinstance(_entities, dict):
_entities = _entities.get(img_key, [])
user["image_metas"][i]["lens"] = _entities
|
|
user["image_urls"] = img_src |
|
|
|
|
|
if len(audio_keys) > 0: |
|
|
# NOTE: use `audio_key` to avoid shadowing the outer annotation key.
for i, audio_key in enumerate(audio_keys):
if isinstance(audio, dict):
user["audio_urls"].append(audio[audio_key])
else:
user["audio_urls"].append(audio)
|
|
user["audio_metas"].append( |
|
|
{ |
|
|
"format": "wav", |
|
|
"note": "This audio sample is passed to convert_wds_to_datalake function.", |
|
|
} |
|
|
) |
|
|
if f"<audio_{i:02d}>" in prompt: |
|
|
prompt = prompt.replace(f"<audio_{i:02d}>", "<|audio|>") |
|
|
else: |
|
|
user["content"] += "<|audio|>" |
|
|
|
|
|
user["content"] += prompt |
|
|
|
|
|
content, candidates = None, list() |
|
|
if self.mode != "train": |
|
|
if isinstance(answer, (int, float)): |
|
|
pass |
|
|
elif isinstance(answer, str): |
|
|
if answer != "None": |
|
|
try: |
|
|
answer = ast.literal_eval(answer) |
|
|
except Exception as ex: |
|
|
pass |
|
|
if not isinstance(answer, (list, tuple)): |
|
|
answer = [ |
|
|
answer, |
|
|
] |
|
|
candidates += answer[1:] |
|
|
answer = answer[0] |
|
|
content = answer |
|
|
elif isinstance(answer, (list, tuple)): |
|
|
for _idx, _answer in enumerate(answer): |
|
|
if isinstance(_answer, str): |
|
|
if isinstance(benchmark, str) and benchmark in [ |
|
|
"textvqa", |
|
|
]: |
|
|
try: |
|
|
_answer = ast.literal_eval(_answer) |
|
|
except Exception as ex: |
|
|
pass |
|
|
if isinstance(_answer, dict): |
|
|
_answer = str(_answer) |
|
|
if not isinstance(_answer, (list, tuple)): |
|
|
_answer = [ |
|
|
_answer, |
|
|
] |
|
|
if _idx == 0: |
|
|
content = _answer[0] |
|
|
candidates += _answer[1:] |
|
|
else: |
|
|
candidates += _answer |
|
|
|
|
|
if isinstance(content, (int, float)): |
|
|
content = str(content) |
|
|
assert content is None or isinstance(content, str) |
|
|
for _idx, _candidate in enumerate(candidates): |
|
|
if isinstance(_candidate, (int, float)): |
|
|
candidates[_idx] = str(_candidate) |
|
|
assert isinstance(candidates[_idx], str) |
|
|
mcqa_gt = sample["vlm"]["data_info"].get("choice_answer", None) |
|
|
if isinstance(mcqa_gt, str): |
|
|
content = mcqa_gt |
|
|
|
|
|
assistant = { |
|
|
"candidates": candidates, |
|
|
"content": answer if self.mode == "train" else content, |
|
|
"content_type": "text", |
|
|
"debuggingInfo": {}, |
|
|
"meta": {}, |
|
|
"role": "assistant", |
|
|
"trainable_content": True, |
|
|
"trainable_role": True, |
|
|
"stop": False, |
|
|
"endofturn": True, |
|
|
} |
|
|
turns.append(user) |
|
|
turns.append(assistant) |
|
|
|
|
|
elif key == "words": |
|
|
img_src = [] |
|
|
user = { |
|
|
"role": "user", |
|
|
"content_type": "text", |
|
|
"candidates": [], |
|
|
"trainable_role": False, |
|
|
"trainable_content": False, |
|
|
"stop": False, |
|
|
"debuggingInfo": {}, |
|
|
"meta": {}, |
|
|
"image_urls": [], |
|
|
"image_metas": [], |
|
|
"video_urls": [], |
|
|
"video_metas": [], |
|
|
"audio_urls": [], |
|
|
"audio_metas": [], |
|
|
"content": "<|image|>", |
|
|
"endofturn": False, |
|
|
} |
|
|
instruct_prompt, task_prompt = hcx_vision_prompter( |
|
|
task=json["meta"]["type"], |
|
|
subtask=json["meta"].get("subtask", None), |
|
|
lang=json["meta"]["lang"], |
|
|
get_random=get_random, |
|
|
use_task_prompt=use_task_prompt, |
|
|
) |
|
|
user["content"] += task_prompt |
|
|
user["image_metas"].append({"words": [], "lens": []}) |
|
|
user["image_metas"][0]["words"] = json["words"]["00"] |
|
|
if "entities" in json: |
|
|
user["image_metas"][0]["lens"] = json["entities"].get("00", []) |
|
|
img_src.append(img["00"]) |
|
|
user["image_urls"] = img_src |
|
|
|
|
|
words_list = [ |
|
|
d["text"].strip() for d in json["words"]["00"] if d["text"] |
|
|
] |
|
|
gt = " ".join(words_list) |
|
|
assistant = { |
|
|
"candidates": [], |
|
|
"content": gt, |
|
|
"content_type": "text", |
|
|
"debuggingInfo": {}, |
|
|
"meta": {}, |
|
|
"role": "assistant", |
|
|
"trainable_content": True, |
|
|
"trainable_role": True, |
|
|
"stop": False, |
|
|
"endofturn": True, |
|
|
} |
|
|
turns.append(user) |
|
|
turns.append(assistant) |
|
|
|
|
|
elif key.startswith("regions"): |
|
|
for region in json[key]: |
|
|
img_src = [] |
|
|
user = { |
|
|
"role": "user", |
|
|
"content_type": "text", |
|
|
"candidates": [], |
|
|
"trainable_role": False, |
|
|
"trainable_content": False, |
|
|
"stop": False, |
|
|
"debuggingInfo": {}, |
|
|
"meta": {}, |
|
|
"image_urls": [], |
|
|
"image_metas": [], |
|
|
"video_urls": [], |
|
|
"video_metas": [], |
|
|
"audio_urls": [], |
|
|
"audio_metas": [], |
|
|
"content": "<|image|><|region|>", |
|
|
"endofturn": False, |
|
|
} |
|
|
instruct_prompt, task_prompt = hcx_vision_prompter( |
|
|
task=json["meta"]["type"], |
|
|
subtask=json["meta"].get("subtask", None), |
|
|
lang=json["meta"]["lang"], |
|
|
get_random=get_random, |
|
|
use_task_prompt=use_task_prompt, |
|
|
) |
|
|
sample["vlm"]["query_template"] = [task_prompt] |
|
|
user["image_metas"].append({"words": [], "lens": []}) |
|
|
user["image_metas"][0]["region"] = region |
|
|
if "words" in json: |
|
|
user["image_metas"][0]["words"] = json["words"].get("00", []) |
|
|
if "objects" in json: |
|
|
user["image_metas"][0]["lens"] = json["objects"].get("00", []) |
|
|
if "entities" in json: |
|
|
user["image_metas"][0]["lens"] = json["entities"].get("00", []) |
|
|
img_src.append(img["00"]) |
|
|
user["image_urls"] = img_src |
|
|
|
|
|
assistant = { |
|
|
"candidates": [], |
|
|
"content": region["text"], |
|
|
"content_type": "text", |
|
|
"debuggingInfo": {}, |
|
|
"meta": {}, |
|
|
"role": "assistant", |
|
|
"trainable_content": True, |
|
|
"trainable_role": True, |
|
|
"stop": False, |
|
|
"endofturn": True, |
|
|
} |
|
|
turns.append(user) |
|
|
turns.append(assistant) |
|
|
else: |
|
|
raise ConditionalError( |
|
|
f"wrong task type in wds config: {sample['vlm']['data_name']}" |
|
|
) |
|
|
sample["data"] = turns |
|
|
return sample |
|
|
|
|
|
def preprocess_new(self, sample): |
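# Turn one datalake/wds sample into model-ready tensors: merge the per-dataset
# config, optionally inject a random system prompt and distractor tools, render each
# turn through the Preprocessor prompt builders, preprocess image/video/audio inputs,
# expand every multimodal placeholder into the correct number of tokens, and apply
# length/label sanity checks before returning the batch dict.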
|
|
|
|
|
config = sample.get("vlm", {}) |
|
|
if config["data_type"] in ["sft1", "datalake"]: |
|
|
default_config = copy.deepcopy(self.default_config) |
|
|
default_config.update(config) |
|
|
config = default_config |
|
|
idx_for_debug = sample.get("idx", -1) |
|
|
turns = sample["data"] if "data" in sample else sample["messages"] |
|
|
|
|
|
if self.random_system_prompt and self.rng.random() < config.get( |
|
|
"random_system_prob", 0.0 |
|
|
): |
|
|
for turn in turns: |
|
|
if turn["role"] == "system": |
|
|
turn["content"] = self.random_system_prompt |
|
|
break |
|
|
|
|
|
if sample.get("tools", None) is None: |
|
|
sample["tools"] = [] |
|
|
|
|
|
if len(sample["tools"]) == 0: |
|
|
if ( |
|
|
self.rng.random() < config.get("random_tool_prob", 0.005) |
|
|
and len(self.common_tools) > 0 |
|
|
): |
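# Occasionally attach distractor tools to tool-free samples: the number of tools
# (1..7) is drawn with probability proportional to 1/count, tool indices are drawn
# without replacement with inverse-rank weights (extra mass on index 0), and the
# selection is shuffled so position carries no signal.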
|
|
|
|
|
max_n_tools = min(7, len(self.common_tools)) |
|
|
tool_counts = np.arange(1, max_n_tools + 1) |
|
|
tool_count_weights = 1.0 / tool_counts |
|
|
tool_count_weights = tool_count_weights / tool_count_weights.sum() |
|
|
n_tools = int(self.rng.choice(tool_counts, p=tool_count_weights)) |
|
|
|
|
|
idxs = np.arange(len(self.common_tools)) |
|
|
weights = 1.0 / (idxs + 1) |
|
|
weights[0] += 1.0 |
|
|
weights = weights / weights.sum() |
|
|
|
|
|
chosen_indices = self.rng.choice( |
|
|
len(self.common_tools), size=n_tools, replace=False, p=weights |
|
|
) |
|
|
|
|
|
self.rng.shuffle(chosen_indices) |
|
|
|
|
|
sample["tools"] = [self.common_tools[i] for i in chosen_indices] |
|
|
|
|
|
if "tools" in sample and sample["tools"]: |
|
|
tool_prompt = [] |
|
|
tool_prompt.append("# Tools\n\n") |
|
|
tool_prompt.append( |
|
|
"You may call one or more functions to assist with the user query.\n\n" |
|
|
) |
|
|
tool_prompt.append( |
|
|
"You are provided with function signatures within <tools></tools> XML tags:\n" |
|
|
) |
|
|
tool_prompt.append("<tools>\n") |
|
|
for tool in sample["tools"]: |
|
|
tool_prompt.append(json.dumps(tool, ensure_ascii=False)) |
|
|
tool_prompt.append("\n</tools>\n\n") |
|
|
tool_prompt.append( |
|
|
"For each function call, output the function name and arguments within the following XML format:\n" |
|
|
) |
|
|
tool_prompt.append("<tool_call>{function-name}\n") |
|
|
tool_prompt.append("<arg_key>{arg-key-1}</arg_key>\n") |
|
|
tool_prompt.append("<arg_value>{arg-value-1}</arg_value>\n") |
|
|
tool_prompt.append("<arg_key>{arg-key-2}</arg_key>\n") |
|
|
tool_prompt.append("<arg_value>{arg-value-2}</arg_value>\n") |
|
|
tool_prompt.append("...\n") |
|
|
tool_prompt.append("</tool_call>") |
|
|
|
|
|
tool_prompt = "".join(tool_prompt) |
|
|
else: |
|
|
tool_prompt = "" |
|
|
|
|
|
multiturn_n_sample = config.get("multiturn_n_samples", 0) |
|
|
if multiturn_n_sample > 0 and self.mode == "train": |
|
|
turns = self._sampling_multiturn( |
|
|
turns, |
|
|
multiturn_n_sample, |
|
|
multiturn_preserve_order=config.get("multiturn_preserve_order", True), |
|
|
multiturn_continuous=config.get("multiturn_continuous", False), |
|
|
) |
|
|
|
|
|
for i, turn in enumerate(turns): |
|
|
if turn["role"] == "user": |
|
|
if "img_src" in turn: |
|
|
turns[i]["image_urls"] = turn["img_src"] |
|
|
turns[i]["image_metas"] = turn["meta"] |
|
|
for j, turn_img_meta in enumerate(turns[i]["image_metas"]): |
|
|
if "entities" in turn_img_meta: |
|
|
turns[i]["image_metas"][j]["lens"] = turn_img_meta[ |
|
|
"entities" |
|
|
] |
|
|
turns[i]["meta"] = {} |
|
|
|
|
|
max_image_cnt = config.get("max_image_cnt", 20) |
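# Enforce the per-sample image budget: if the total number of images exceeds
# max_image_cnt, strip the excess <image_xx>/<|image|> placeholders (and their URLs)
# from the earliest turns, and mask the turns that follow an image-stripped user turn
# (through its assistant reply) out of the loss.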
|
|
if max_image_cnt > 0 and config["data_type"] != "sft1": |
|
|
n_imgs = {} |
|
|
for i, turn in enumerate(turns): |
|
|
if turn["role"] == "user": |
|
|
n_imgs[i] = len(turn.get("image_urls", [])) |
|
|
assert ( |
|
|
n_imgs[i] <= max_image_cnt |
|
|
), "skip sample if image_nums exceeds max_image_count per turn" |
|
|
|
|
|
if sum(n_imgs.values()) > max_image_cnt: |
|
|
img_count = 0 |
|
|
for k, v in reversed(list(n_imgs.items())): |
|
|
img_count += v |
|
|
if img_count > max_image_cnt: |
|
|
break |
|
|
|
|
|
img_count = sum(n_imgs.values()) - max_image_cnt |
|
|
|
|
|
for i in range(k + 1): |
|
|
if turns[i]["role"] == "user": |
|
|
turns[i]["content"], n_removed1 = re.subn( |
|
|
r"<image_\d{2}>", |
|
|
"", |
|
|
turns[i]["content"].strip(), |
|
|
count=img_count, |
|
|
) |
|
|
img_count -= n_removed1 |
|
|
turns[i]["content"], n_removed2 = re.subn( |
|
|
r"<\|image\|>", |
|
|
"", |
|
|
turns[i]["content"].strip(), |
|
|
count=img_count, |
|
|
) |
|
|
img_count -= n_removed2 |
|
|
n_removed_imgs = n_removed1 + n_removed2 |
|
|
turns[i]["image_urls"] = turns[i]["image_urls"][n_removed_imgs:] |
|
|
|
|
|
if n_removed_imgs > 0 and len(turns[i]["image_urls"]) == 0: |
|
|
idx = i |
|
|
while True: |
|
|
idx += 1 |
|
|
turns[idx]["trainable_role"] = False |
|
|
turns[idx]["trainable_content"] = False |
|
|
if turns[idx]["role"] == "assistant": |
|
|
break |
|
|
|
|
|
n_imgs_after = {} |
|
|
for i, turn in enumerate(turns): |
|
|
if turn["role"] == "user": |
|
|
n_imgs_after[i] = len(turn.get("image_urls", [])) |
|
|
assert sum(n_imgs_after.values()) > 0, "The n_imgs of vlm data is zero." |
|
|
|
|
|
n_mm_after = {} |
|
|
for i, turn in enumerate(turns): |
|
|
if turn["role"] == "user" or turn["role"] == "assistant": |
|
|
n_mm_after[i] = ( |
|
|
len(turn.get("image_urls", [])) |
|
|
+ len(turn.get("video_urls", [])) |
|
|
+ len(turn.get("audio_urls", [])) |
|
|
) |
|
|
assert sum(n_mm_after.values()) > 0, "The n_mm of omni data is zero." |
|
|
|
|
|
queries, gts = list(), list() |
|
|
output = Processed_sample( |
|
|
input_str="", |
|
|
input_ids=[], |
|
|
label_ids=[], |
|
|
imgs=[], |
|
|
discrete_imgs=[], |
|
|
videos=[], |
|
|
videos_duration=[], |
|
|
video_audios=[], |
|
|
audios=[], |
|
|
audios_duration=[], |
|
|
discrete_audios=[], |
|
|
sample_mm_counter={ |
|
|
"image": 0, |
|
|
"video": 0, |
|
|
"audio": 0, |
|
|
}, |
|
|
) |
|
|
system_role_count = 0 |
|
|
last_user_idx = max( |
|
|
(i for i, d in enumerate(turns) if d.get("role") == "user"), default=-1 |
|
|
) |
|
|
for i, turn in enumerate(turns): |
|
|
if turn["role"] == "tool_list": |
|
|
continue |
|
|
|
|
|
elif turn["role"] == "system": |
|
|
if config.get("llava_pretrain", False): |
|
|
continue |
|
|
output = Preprocessor.prompt_system( |
|
|
turn=turn, |
|
|
output=output, |
|
|
tokenizer=self.tokenizer, |
|
|
seed=self.rng, |
|
|
tool_prompt=tool_prompt, |
|
|
system_role_count=system_role_count, |
|
|
) |
|
|
system_role_count += 1 |
|
|
|
|
|
elif turn["role"].startswith("user"): |
|
|
output = Preprocessor.load_mm( |
|
|
output=output, |
|
|
img_dir=config.get("img_dir", ""), |
|
|
turn=turn, |
|
|
prepare_input_fn=self.prepare_input_fn, |
|
|
max_image_cnt=max_image_cnt, |
|
|
video_max_num_frames=self.video_max_num_frames, |
|
|
video_max_pixels=self.video_max_pixels, |
|
|
use_audio=self.train_audio, |
|
|
) |
|
|
output = Preprocessor.prompt_user( |
|
|
output=output, |
|
|
tokenizer=self.tokenizer, |
|
|
turn=turn, |
|
|
is_train=True if self.mode == "train" else False, |
|
|
fixed_mime=config.get("fixed_mime", False), |
|
|
mimes=self.mimes, |
|
|
query_template=config.get("query_template", None), |
|
|
config=config, |
|
|
seed=self.rng, |
|
|
) |
|
|
|
|
|
queries.append(turn["content"].replace("<|image|>", "").strip()) |
|
|
elif turn["role"].startswith("assistant"): |
|
|
output = Preprocessor.load_mm( |
|
|
output=output, |
|
|
img_dir=config.get("img_dir", ""), |
|
|
turn=turn, |
|
|
prepare_input_fn=self.prepare_input_fn, |
|
|
max_image_cnt=max_image_cnt, |
|
|
video_max_num_frames=self.video_max_num_frames, |
|
|
video_max_pixels=self.video_max_pixels, |
|
|
use_audio=self.train_audio, |
|
|
) |
|
|
|
|
|
is_after_last_user = i > last_user_idx |
|
|
is_first_assistant_after_last_user = False |
|
|
if is_after_last_user: |
|
|
is_first_assistant_after_last_user = all( |
|
|
turns[j]["role"] != "assistant" |
|
|
for j in range(last_user_idx + 1, i) |
|
|
) |
|
|
|
|
|
output = Preprocessor.prompt_assistant( |
|
|
output=output, |
|
|
tokenizer=self.tokenizer, |
|
|
turn=turn, |
|
|
is_last_turn=is_first_assistant_after_last_user, |
|
|
is_eval=True if self.mode != "train" else False, |
|
|
is_llava_pretrain=config.get("llava_pretrain", False), |
|
|
is_after_last_user_turn=is_after_last_user, |
|
|
) |
|
|
_gts = turn["content"] |
|
|
if isinstance(_gts, str): |
|
|
_gts = [ |
|
|
_gts, |
|
|
] |
|
|
if "candidates" in turn and len(turn["candidates"]) > 0: |
|
|
for _candidates in turn["candidates"]: |
|
|
if isinstance(_candidates, str): |
|
|
_gts += [ |
|
|
_candidates, |
|
|
] |
|
|
elif isinstance(turn["candidates"][0], (list, tuple)): |
|
|
_gts += _candidates |
|
|
gts.append(_gts) |
|
|
elif turn["role"] == "tool": |
|
|
if config.get("llava_pretrain", False): |
|
|
continue |
|
|
|
|
|
output = Preprocessor.prompt_tool( |
|
|
output=output, |
|
|
tokenizer=self.tokenizer, |
|
|
turn=turn, |
|
|
need_start_tag=( |
|
|
True |
|
|
if (i == 0 or turns[i - 1].get("role") != "tool") |
|
|
else False |
|
|
), |
|
|
need_end_tag=( |
|
|
True |
|
|
if (i == (len(turns) - 1) or turns[i + 1].get("role") != "tool") |
|
|
else False |
|
|
), |
|
|
) |
|
|
else: |
|
|
if config.get("llava_pretrain", False): |
|
|
continue |
|
|
|
|
|
# Unexpected roles fall through to the generic prompt_etc handler below.
|
|
output = Preprocessor.prompt_etc( |
|
|
output=output, |
|
|
tokenizer=self.tokenizer, |
|
|
turn=turn, |
|
|
) |
|
|
|
|
|
pixel_values = [] |
|
|
mm_query_lengths = [] |
|
|
discrete_pixel_values = [] |
|
|
image_ratios = [] |
|
|
discrete_image_query_lengths = [] |
|
|
|
|
|
labels = output.label_ids |
|
|
input_ids = output.input_ids |
|
|
total_mm_query_length = 0 |
|
|
|
|
|
is_sft1 = False |
|
|
if config["data_type"] == "sft1": |
|
|
if self.sequence_parallel_size > 1: |
|
|
if len(input_ids) % self.sequence_parallel_size != 0: |
|
|
input_ids += [self.tokenizer.pad_token_id] * ( |
|
|
self.sequence_parallel_size |
|
|
- (len(input_ids) % self.sequence_parallel_size) |
|
|
) |
|
|
labels += [IGNORE_INDEX] * ( |
|
|
self.sequence_parallel_size |
|
|
- (len(labels) % self.sequence_parallel_size) |
|
|
) |
|
|
|
|
|
input_ids = input_ids[ |
|
|
: (len(input_ids) // self.sequence_parallel_size) |
|
|
* self.sequence_parallel_size |
|
|
] |
|
|
labels = labels[ |
|
|
: (len(labels) // self.sequence_parallel_size) |
|
|
* self.sequence_parallel_size |
|
|
] |
|
|
|
|
|
input_ids = torch.tensor(input_ids[-self.decoder_max_length :]) |
|
|
labels = torch.tensor(labels[-self.decoder_max_length :]) |
|
|
is_sft1 = True |
|
|
|
|
|
dummy_preprocess_results = self.prepare_input_fn.image_processor( |
|
|
Image.new("RGB", (224, 224), (0, 0, 0)) |
|
|
) |
|
|
dummy_pixel_values = torch.from_numpy( |
|
|
np.concatenate([dummy_preprocess_results.pixel_values], axis=0) |
|
|
) |
|
|
dummy_grid_thw = torch.from_numpy( |
|
|
np.concatenate([dummy_preprocess_results.image_grid_thw], axis=0) |
|
|
) |
|
|
|
|
|
image_grid_thw = [] |
|
|
for img in output.imgs: |
|
|
w, h = img.size |
|
|
|
|
|
img = self._resize_min_edge(img) |
|
|
preprocess_results = self.prepare_input_fn.image_processor([img]) |
|
|
pixel_values.append(preprocess_results.pixel_values) |
|
|
image_grid_thw.append(preprocess_results.image_grid_thw) |
|
|
mm_query_lengths.append(preprocess_results.pixel_values.shape[0] // 4) |
|
|
|
|
|
if len(output.imgs) == 0: |
|
|
pixel_values = torch.zeros(0, 1176) |
|
|
image_grid_thw = torch.zeros(0, 3, dtype=torch.long) |
|
|
else: |
|
|
pixel_values = torch.from_numpy(np.concatenate(pixel_values, axis=0)) |
|
|
image_grid_thw = torch.from_numpy(np.concatenate(image_grid_thw, axis=0)) |
|
|
|
|
|
for img in output.discrete_imgs: |
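# Discrete-image path: record the closest aspect-ratio token, resize to a fixed
# 384x384, validate the tensor (shape / NaN / Inf / [0, 1] range), and reserve a
# fixed 729-token (27x27) query per image.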
|
|
w, h = img.size |
|
|
|
|
|
img_ratio = self._find_best_ratio_token([h, w]) |
|
|
image_ratios.append(img_ratio) |
|
|
discrete_pixel_value = img.resize((384, 384), Image.BICUBIC) |
|
|
discrete_pixel_tensor = to_tensor(discrete_pixel_value) |
|
|
|
|
|
assert discrete_pixel_tensor.shape == ( |
|
|
3, |
|
|
384, |
|
|
384, |
|
|
), f"Unexpected discrete_pixel_tensor shape: {discrete_pixel_tensor.shape}" |
|
|
assert not torch.isnan( |
|
|
discrete_pixel_tensor |
|
|
).any(), "discrete_pixel_tensor contains NaN" |
|
|
assert not torch.isinf( |
|
|
discrete_pixel_tensor |
|
|
).any(), "discrete_pixel_tensor contains Inf" |
|
|
pixel_min = discrete_pixel_tensor.min().item() |
|
|
pixel_max = discrete_pixel_tensor.max().item() |
|
|
assert ( |
|
|
0.0 <= pixel_min <= 1.0 and 0.0 <= pixel_max <= 1.0 |
|
|
), f"discrete_pixel_tensor values out of range [0, 1]: min={pixel_min}, max={pixel_max}" |
|
|
|
|
|
discrete_pixel_values.append(discrete_pixel_tensor) |
|
|
discrete_image_query_lengths.append(729) |
|
|
|
|
|
if len(output.discrete_imgs) == 0: |
|
|
discrete_pixel_values = torch.zeros(0, 3, 384, 384) |
|
|
else: |
|
|
discrete_pixel_values = torch.stack(discrete_pixel_values, dim=0) |
|
|
|
|
|
assert discrete_pixel_values.shape[1:] == ( |
|
|
3, |
|
|
384, |
|
|
384, |
|
|
), f"Unexpected stacked discrete_pixel_values shape: {discrete_pixel_values.shape}" |
|
|
assert not torch.isnan( |
|
|
discrete_pixel_values |
|
|
).any(), "Stacked discrete_pixel_values contains NaN" |
|
|
assert not torch.isinf( |
|
|
discrete_pixel_values |
|
|
).any(), "Stacked discrete_pixel_values contains Inf" |
|
|
|
|
|
pixel_values_videos = None |
|
|
video_grid_thw = None |
|
|
if self.train_video: |
|
|
pixel_values_videos = [] |
|
|
video_grid_thw = [] |
|
|
video_query_lengths = [] |
|
|
for video in output.videos: |
|
|
preprocess_results = self.prepare_input_fn.video_processor([video]) |
|
|
pixel_values_videos.append(preprocess_results.pixel_values_videos) |
|
|
video_grid_thw.append(preprocess_results.video_grid_thw) |
|
|
video_query_lengths.append( |
|
|
preprocess_results.pixel_values_videos.shape[0] // 4 |
|
|
) |
|
|
if len(output.videos) == 0: |
|
|
pixel_values_videos = torch.zeros(0, 1176) |
|
|
video_grid_thw = torch.zeros(0, 3, dtype=torch.long) |
|
|
else: |
|
|
pixel_values_videos = torch.from_numpy( |
|
|
np.concatenate(pixel_values_videos, axis=0) |
|
|
) |
|
|
video_grid_thw = torch.from_numpy( |
|
|
np.concatenate(video_grid_thw, axis=0) |
|
|
) |
|
|
|
|
|
video_audio_values = [] |
|
|
video_audio_masks = [] |
|
|
video_audio_query_lengths = [] |
|
|
if self.train_video and hasattr(output, "video_audios") and output.video_audios: |
|
|
for idx, video_audio_chunks in enumerate(output.video_audios): |
|
|
if video_audio_chunks: |
|
|
processed_audio_values = [] |
|
|
processed_audio_masks = [] |
|
|
chunk_output_lengths = [] |
|
|
|
|
|
for chunk in video_audio_chunks: |
|
|
if isinstance(chunk, torch.Tensor): |
|
|
chunk_np = chunk.cpu().numpy() |
|
|
else: |
|
|
chunk_np = chunk |
|
|
|
|
|
preprocess_results = self.prepare_audio_input_fn( |
|
|
[chunk_np], |
|
|
sampling_rate=self.prepare_audio_input_fn.sampling_rate, |
|
|
return_attention_mask=True, |
|
|
padding="max_length", |
|
|
) |
|
|
|
|
|
audio_value = preprocess_results.input_features[0] |
|
|
audio_mask = preprocess_results.attention_mask[0] |
|
|
|
|
|
mask_sum = int(audio_mask.sum()) |
|
|
input_lengths = (mask_sum - 1) // 2 + 1 |
|
|
output_lengths = (input_lengths - 2) // 2 + 1 |
|
|
chunk_output_lengths.append(output_lengths) |
|
|
|
|
|
processed_audio_values.append(torch.from_numpy(audio_value)) |
|
|
processed_audio_masks.append(torch.from_numpy(audio_mask)) |
|
|
|
|
|
pool_size = 25 |
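# 25x temporal pooling of audio-encoder frames: with a compressor the per-chunk
# valid lengths are summed before a single ceil-pooling, otherwise each chunk is
# ceil-pooled independently, e.g. (750 + 24) // 25 = 30 tokens for a full 30 s chunk.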
|
|
if self.video_audio_compressor_type is not None: |
|
|
total_valid_len = sum(chunk_output_lengths) |
|
|
total_audio_query_length = ( |
|
|
total_valid_len + pool_size - 1 |
|
|
) // pool_size |
|
|
else: |
|
|
total_audio_query_length = sum( |
|
|
(valid_len + pool_size - 1) // pool_size |
|
|
for valid_len in chunk_output_lengths |
|
|
) |
|
|
|
|
|
video_audio_values.append(processed_audio_values) |
|
|
video_audio_masks.append(processed_audio_masks) |
|
|
video_audio_query_lengths.append(total_audio_query_length) |
|
|
|
|
|
|
|
|
|
|
if ( |
|
|
int(os.environ.get("RANK", -1)) == 0 |
|
|
and total_audio_query_length == 177 |
|
|
): |
|
|
print( |
|
|
f"\n[PREPROCESSOR VIDEO - 177 TOKENS DETECTED!] total_audio_query_length={total_audio_query_length}, num_chunks={len(processed_audio_masks)}" |
|
|
) |
|
|
for chunk_idx, mask_tensor in enumerate(processed_audio_masks): |
|
|
chunk_mask_sum = int(mask_tensor.sum()) |
|
|
chunk_input_len = (chunk_mask_sum - 1) // 2 + 1 |
|
|
chunk_output_len = (chunk_input_len - 2) // 2 + 1 |
|
|
chunk_pooled = (chunk_output_len + 24) // 25 |
|
|
print( |
|
|
f" Chunk {chunk_idx}: mask_sum={chunk_mask_sum}, output_len={chunk_output_len}, pooled={chunk_pooled}" |
|
|
) |
|
|
print() |
|
|
|
|
|
else: |
|
|
video_audio_values.append([]) |
|
|
video_audio_masks.append([]) |
|
|
video_audio_query_lengths.append(0) |
|
|
|
|
|
dummy_video_preprocess_results = self.prepare_input_fn.video_processor( |
|
|
[Image.new("RGB", (224, 224), (0, 0, 0))] * 3 |
|
|
) |
|
|
dummy_pixel_values_videos = torch.from_numpy( |
|
|
np.concatenate([dummy_video_preprocess_results.pixel_values_videos], axis=0) |
|
|
) |
|
|
dummy_video_grid_thw = torch.from_numpy( |
|
|
np.concatenate([dummy_video_preprocess_results.video_grid_thw], axis=0) |
|
|
) |
|
|
dummy_video_preprocess_results = self.prepare_audio_input_fn( |
|
|
[np.zeros(self.prepare_audio_input_fn.sampling_rate * 3, dtype=np.float32)], |
|
|
sampling_rate=self.prepare_audio_input_fn.sampling_rate, |
|
|
return_attention_mask=True, |
|
|
padding="max_length", |
|
|
) |
|
|
dummy_video_audio_values = torch.from_numpy( |
|
|
dummy_video_preprocess_results.input_features |
|
|
) |
|
|
dummy_video_audio_masks = torch.from_numpy( |
|
|
dummy_video_preprocess_results.attention_mask |
|
|
) |
|
|
|
|
|
audio_values = None |
|
|
discrete_audio_values = None |
|
|
audio_masks = None |
|
|
dummy_preprocess_results = self.prepare_audio_input_fn( |
|
|
[np.zeros(self.prepare_audio_input_fn.sampling_rate * 3, dtype=np.float32)], |
|
|
sampling_rate=self.prepare_audio_input_fn.sampling_rate, |
|
|
return_attention_mask=True, |
|
|
padding="max_length", |
|
|
) |
|
|
dummy_audio_values = torch.from_numpy(dummy_preprocess_results.input_features) |
|
|
dummy_audio_masks = torch.from_numpy(dummy_preprocess_results.attention_mask) |
|
|
if self.train_audio: |
|
|
audio_values = [] |
|
|
discrete_audio_values = [] |
|
|
audio_masks = [] |
|
|
audio_query_lengths = [] |
|
|
discrete_audio_query_lengths = [] |
|
|
|
|
|
if len(output.audios) > 99: |
|
|
raise ConditionalError( |
|
|
f"Too many audio segments in one sample: {len(output.audios)} audios." |
|
|
) |
|
|
|
|
|
for audio in output.audios: |
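# The audio feature extractor works on fixed 30 s windows, so longer clips are split
# into 30 s chunks (the last one shorter) and padded to max_length before computing
# features and attention masks.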
|
|
chunks = [] |
|
|
for i in range( |
|
|
0, len(audio), 30 * self.prepare_audio_input_fn.sampling_rate |
|
|
): |
|
|
chunks.append( |
|
|
audio[i : i + 30 * self.prepare_audio_input_fn.sampling_rate] |
|
|
) |
|
|
num_of_chunks = len(chunks) |
|
|
preprocess_results = self.prepare_audio_input_fn( |
|
|
chunks, |
|
|
sampling_rate=self.prepare_audio_input_fn.sampling_rate, |
|
|
return_attention_mask=True, |
|
|
padding="max_length", |
|
|
) |
|
|
audio_value = preprocess_results.input_features |
|
|
audio_mask = preprocess_results.attention_mask |
|
|
audio_values.append(audio_value) |
|
|
audio_masks.append(audio_mask) |
|
|
input_lengths = int(audio_mask.sum()) |
|
|
input_lengths = (input_lengths - 1) // 2 + 1 |
|
|
output_lengths = (input_lengths - 2) // 2 + 1 |
|
|
audio_query_lengths.append(output_lengths) |
|
|
|
|
|
if len(output.audios) == 0: |
|
|
audio_values = torch.zeros(0, 128, 3000) |
|
|
audio_masks = torch.zeros(0, 3000) |
|
|
else: |
|
|
audio_values = torch.from_numpy(np.concatenate(audio_values, axis=0)) |
|
|
audio_masks = torch.from_numpy(np.concatenate(audio_masks, axis=0)) |
|
|
|
|
|
for audio in output.discrete_audios: |
|
|
audio_length = len(audio) |
|
|
|
|
|
assert audio_length >= MIN_DISCRETE_AUDIO_CHUNK_SAMPLES, ( |
|
|
f"discrete_audio is too short ({audio_length} samples < {MIN_DISCRETE_AUDIO_CHUNK_SAMPLES}). " |
|
|
f"This will cause 0-dim/empty tensor in CosyVoice encoder. " |
|
|
f"Skip this sample." |
|
|
) |
|
|
|
|
|
max_audio_length = 600 * DEFAULT_SAMPLE_RATE |
|
|
audio_duration_sec = audio_length / DEFAULT_SAMPLE_RATE |
|
|
assert ( |
|
|
audio_length <= max_audio_length |
|
|
), f"discrete_audio is too long ({audio_length} samples = {audio_duration_sec:.1f}s > 600s). " |
|
|
|
|
|
assert not torch.isnan(audio).any(), ( |
|
|
f"discrete_audio contains NaN values! " |
|
|
f"This will cause CUDA illegal memory access. Skip this sample." |
|
|
) |
|
|
assert not torch.isinf(audio).any(), ( |
|
|
f"discrete_audio contains Inf values! " |
|
|
f"This will cause CUDA illegal memory access. Skip this sample." |
|
|
) |
|
|
|
|
|
audio_min, audio_max = audio.min().item(), audio.max().item() |
|
|
assert -100.0 <= audio_min <= 100.0 and -100.0 <= audio_max <= 100.0, ( |
|
|
f"discrete_audio has extreme values (min={audio_min:.2f}, max={audio_max:.2f}). " |
|
|
f"Expected roughly [-1, 1] range. This indicates corrupted audio. Skip this sample." |
|
|
) |
|
|
|
|
|
discrete_audio_values.append(audio) |
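# Estimate how many discrete codes the CosyVoice-style encoder will emit:
# 16 kHz samples -> 10 ms mel frames (// 160), then two stride-2 convs
# (kernel 3, padding 1): L_out = (L_in + 2*1 - (3 - 1) - 1) // 2 + 1 applied twice,
# i.e. ~25 codes per second. Worked example: 30 s -> 480_000 samples -> 3_000 mel
# frames -> 1_500 -> 750 codes. Clips longer than 80 s are chunked so a single chunk
# never exceeds the 2048-position rotary cache.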
|
|
|
|
|
if audio_length > 80 * DEFAULT_SAMPLE_RATE: |
|
|
chunk_size = 80 * DEFAULT_SAMPLE_RATE |
|
|
|
|
|
total_code_len = 0 |
|
|
|
|
|
for start in range(0, audio_length, chunk_size): |
|
|
end = min(start + chunk_size, audio_length) |
|
|
|
|
|
if ( |
|
|
end < audio_length |
|
|
and audio_length - end < MIN_DISCRETE_AUDIO_CHUNK_SAMPLES |
|
|
): |
|
|
end = audio_length |
|
|
|
|
|
chunk_length = end - start |
|
|
|
|
|
assert chunk_length >= MIN_DISCRETE_AUDIO_CHUNK_SAMPLES, ( |
|
|
f"chunk_length={chunk_length} < {MIN_DISCRETE_AUDIO_CHUNK_SAMPLES}. This should never happen with our chunking logic. " |
|
|
f"audio_length={audio_length}, start={start}, end={end}. Skip this sample." |
|
|
) |
|
|
|
|
|
mel_len = chunk_length // 160 |
|
|
|
|
|
assert mel_len > 0, ( |
|
|
f"mel_len={mel_len} is invalid (chunk_length={chunk_length}). " |
|
|
f"This will cause illegal memory access in AudioEncoder. Skip this sample." |
|
|
) |
|
|
|
|
|
after_conv1 = (mel_len + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1 |
|
|
code_len = (after_conv1 + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1 |
|
|
|
|
|
assert code_len > 0, ( |
|
|
f"code_len={code_len} is invalid (mel_len={mel_len}, after_conv1={after_conv1}). " |
|
|
f"This will cause illegal memory access. Skip this sample." |
|
|
) |
|
|
|
|
|
total_code_len += code_len |
|
|
|
|
|
if end >= audio_length: |
|
|
break |
|
|
|
|
|
assert total_code_len > 0, ( |
|
|
f"total_code_len={total_code_len} is invalid after processing all chunks. " |
|
|
f"audio_length={audio_length}. This should never happen. Skip this sample." |
|
|
) |
|
|
|
|
|
audio_duration_sec = audio_length / DEFAULT_SAMPLE_RATE |
|
|
max_expected_codes = int(audio_duration_sec * 25 * 1.1) |
|
|
assert total_code_len <= max_expected_codes, ( |
|
|
f"total_code_len={total_code_len} is suspiciously large (max_expected={max_expected_codes}). " |
|
|
f"audio_length={audio_length} ({audio_duration_sec:.1f}s). " |
|
|
f"Expected ~{int(audio_duration_sec * 25)} tokens (25 tokens/sec). " |
|
|
f"This indicates calculation error. Skip this sample." |
|
|
) |
|
|
|
|
|
discrete_audio_query_lengths.append(total_code_len) |
|
|
else: |
|
|
mel_len = audio_length // 160 |
|
|
|
|
|
assert mel_len > 0, ( |
|
|
f"mel_len={mel_len} is invalid (audio_length={audio_length}). " |
|
|
f"This will cause illegal memory access in AudioEncoder. Skip this sample." |
|
|
) |
|
|
|
|
|
after_conv1 = (mel_len + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1 |
|
|
code_len = (after_conv1 + 2 * 1 - 1 * (3 - 1) - 1) // 2 + 1 |
|
|
|
|
|
assert code_len > 0, ( |
|
|
f"Calculated code_len={code_len} is invalid (audio_length={audio_length}, " |
|
|
f"mel_len={mel_len}, after_conv1={after_conv1}). " |
|
|
f"This indicates corrupted audio data. Skip this sample." |
|
|
) |
|
|
|
|
|
assert code_len <= 2048, ( |
|
|
f"code_len={code_len} exceeds freqs_cis max length (2048). " |
|
|
f"Audio length: {audio_length / DEFAULT_SAMPLE_RATE:.1f}s (max ~82s for single chunk). " |
|
|
f"Expected ~{int((audio_length / DEFAULT_SAMPLE_RATE) * 25)} tokens at 25 tokens/sec. " |
|
|
f"This will cause illegal memory access in apply_rotary_emb. Skip this sample." |
|
|
) |
|
|
|
|
|
discrete_audio_query_lengths.append(code_len) |
|
|
|
|
|
img_start_ids = [ |
|
|
i for i, token in enumerate(input_ids) if token == self.img_token |
|
|
] |
|
|
assert len(img_start_ids) == len(mm_query_lengths) |
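# Expand each single <image> placeholder into `length` image tokens (masked with
# IGNORE_INDEX in the labels); walk the placeholders back-to-front so earlier
# offsets stay valid while the sequence grows.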
|
|
for i, length in zip( |
|
|
range(len(mm_query_lengths) - 1, -1, -1), mm_query_lengths[::-1] |
|
|
): |
|
|
labels[img_start_ids[i] : img_start_ids[i] + 1] = [IGNORE_INDEX] * length |
|
|
input_ids[img_start_ids[i] : img_start_ids[i] + 1] = [ |
|
|
self.img_token |
|
|
] * length |
|
|
total_mm_query_length += length |
|
|
|
|
|
discrete_image_start_ids = [ |
|
|
i for i, token in enumerate(input_ids) if token == self.discrete_image_token |
|
|
] |
|
|
assert len(discrete_image_start_ids) == len(discrete_image_query_lengths) |
|
|
assert len(discrete_image_start_ids) == len( |
|
|
image_ratios |
|
|
), "discrete_image_start_ids and image_ratios length mismatch" |
|
|
|
|
|
for idx in range(len(discrete_image_query_lengths) - 1, -1, -1): |
|
|
i = discrete_image_start_ids[idx] |
|
|
length = discrete_image_query_lengths[idx] |
|
|
ratio_token_id = image_ratios[idx] |
|
|
assert ( |
|
|
length == 729 |
|
|
), f"discrete_image_query_length must be 729, but got {length}" |
|
|
|
|
|
token_sequence = [ratio_token_id] |
|
|
for token_idx in range(length): |
|
|
token_sequence.append(self.discrete_image_token) |
|
|
if (token_idx + 1) % 27 == 0: |
|
|
token_sequence.append(self.discrete_image_eol_token) |
|
|
token_sequence.append(self.discrete_image_eof_token) |
|
|
|
|
|
total_length = len(token_sequence) |
|
|
if labels[i] == IGNORE_INDEX: |
|
|
labels[i : i + 1] = [IGNORE_INDEX] * total_length |
|
|
else: |
|
|
labels[i : i + 1] = token_sequence |
|
|
input_ids[i : i + 1] = token_sequence |
|
|
|
|
|
if self.train_video: |
|
|
vid_start_ids = [ |
|
|
i for i, token in enumerate(input_ids) if token == self.video_token |
|
|
] |
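# Rebuild each <video> placeholder: when the clip has an audio track, per-frame
# video tokens are interleaved with the (roughly evenly split) pooled audio tokens;
# otherwise only video tokens are emitted. Video positions are always label-masked.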
|
|
|
|
|
for idx in range(len(vid_start_ids) - 1, -1, -1): |
|
|
pos = vid_start_ids[idx] |
|
|
|
|
|
num_frames = int(video_grid_thw[idx][0]) |
|
|
frame_query_length = video_query_lengths[idx] |
|
|
|
|
|
has_video_audio = ( |
|
|
idx < len(video_audio_query_lengths) |
|
|
and video_audio_query_lengths[idx] > 0 |
|
|
) |
|
|
|
|
|
if has_video_audio: |
|
|
total_audio_tokens = video_audio_query_lengths[idx] |
|
|
|
|
|
token_sequence = [] |
|
|
|
|
|
if num_frames > 0: |
|
|
|
|
|
frame_base = frame_query_length // num_frames |
|
|
frame_remainder = frame_query_length % num_frames |
|
|
|
|
|
assert frame_remainder == 0, ( |
|
|
f"frame_query_length({frame_query_length}) must be divisible by num_frames({num_frames}). " |
|
|
f"Each frame produces fixed number of tokens. Got remainder={frame_remainder}." |
|
|
) |
|
|
|
|
|
audio_base = total_audio_tokens // num_frames |
|
|
audio_remainder = total_audio_tokens % num_frames |
|
|
|
|
|
for frame_idx in range(num_frames): |
|
|
frame_tokens = frame_base + ( |
|
|
1 if frame_idx < frame_remainder else 0 |
|
|
) |
|
|
token_sequence.extend([self.video_token] * frame_tokens) |
|
|
|
|
|
audio_tokens = audio_base + ( |
|
|
1 if frame_idx < audio_remainder else 0 |
|
|
) |
|
|
if audio_tokens > 0: |
|
|
token_sequence.extend( |
|
|
[self.video_audio_token] * audio_tokens |
|
|
) |
|
|
else: |
|
|
token_sequence = [self.video_token] * frame_query_length |
|
|
else: |
|
|
token_sequence = [self.video_token] * frame_query_length |
|
|
|
|
|
total_length = len(token_sequence) |
|
|
labels[pos : pos + 1] = [IGNORE_INDEX] * total_length |
|
|
input_ids[pos : pos + 1] = token_sequence |
|
|
|
|
|
if self.train_audio: |
|
|
audio_start_ids = [ |
|
|
i for i, token in enumerate(input_ids) if token == self.audio_token |
|
|
] |
|
|
assert len(audio_start_ids) == len(audio_query_lengths) |
|
|
for i, length in zip( |
|
|
range(len(audio_query_lengths) - 1, -1, -1), audio_query_lengths[::-1] |
|
|
): |
|
|
labels[audio_start_ids[i] : audio_start_ids[i] + 1] = [ |
|
|
IGNORE_INDEX |
|
|
] * length |
|
|
input_ids[audio_start_ids[i] : audio_start_ids[i] + 1] = [ |
|
|
self.audio_token |
|
|
] * length |
|
|
|
|
|
discrete_audio_start_ids = [ |
|
|
i |
|
|
for i, token in enumerate(input_ids) |
|
|
if token == self.discrete_audio_token |
|
|
] |
|
|
|
|
|
assert len(discrete_audio_start_ids) == len(discrete_audio_query_lengths), ( |
|
|
f"discrete_audio_start_ids count ({len(discrete_audio_start_ids)}) != " |
|
|
f"discrete_audio_query_lengths count ({len(discrete_audio_query_lengths)}). " |
|
|
f"This indicates a serious bug in preprocessor or corrupted data. Skip this sample." |
|
|
) |
|
|
|
|
|
for i, length in zip( |
|
|
range(len(discrete_audio_query_lengths) - 1, -1, -1), |
|
|
discrete_audio_query_lengths[::-1], |
|
|
): |
|
|
assert 0 < length < 16000, ( |
|
|
f"discrete_audio_query_length={length} is out of valid range [1, 16000). " |
|
|
f"Expected max ~15,000 for 600s audio at 25 tokens/sec. " |
|
|
f"This can cause illegal memory access when creating embeddings. Skip this sample." |
|
|
) |
|
|
|
|
|
if labels[discrete_audio_start_ids[i]] == IGNORE_INDEX: |
|
|
labels[ |
|
|
discrete_audio_start_ids[i] : discrete_audio_start_ids[i] + 1 |
|
|
] = [IGNORE_INDEX] * length |
|
|
else: |
|
|
labels[ |
|
|
discrete_audio_start_ids[i] : discrete_audio_start_ids[i] + 1 |
|
|
] = [self.discrete_audio_token] * length |
|
|
input_ids[ |
|
|
discrete_audio_start_ids[i] : discrete_audio_start_ids[i] + 1 |
|
|
] = [self.discrete_audio_token] * length |
|
|
|
|
|
if self.sequence_parallel_size > 1: |
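# Pad input_ids with pad_token_id and labels with IGNORE_INDEX so the sequence
# length is divisible by the sequence-parallel world size.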
|
|
if len(input_ids) % self.sequence_parallel_size != 0: |
|
|
input_ids += [self.tokenizer.pad_token_id] * ( |
|
|
self.sequence_parallel_size |
|
|
- (len(input_ids) % self.sequence_parallel_size) |
|
|
) |
|
|
labels += [IGNORE_INDEX] * ( |
|
|
self.sequence_parallel_size |
|
|
- (len(labels) % self.sequence_parallel_size) |
|
|
) |
|
|
|
|
|
if not is_sft1: |
|
|
input_ids = torch.tensor(input_ids) |
|
|
labels = torch.tensor(labels) |
|
|
|
|
|
if self.mode == "train": |
|
|
if self.sample_min_length is not None and self.sample_min_length > 0: |
|
|
assert ( |
|
|
len(labels) >= self.sample_min_length |
|
|
), "The sample is too short: {} < {}".format( |
|
|
len(labels), self.sample_min_length |
|
|
) |
|
|
assert ( |
|
|
len(labels) <= self.decoder_max_length |
|
|
), "The sample exceeds decoder_max_len: {} > {}".format( |
|
|
len(labels), self.decoder_max_length |
|
|
) |
|
|
assert len(input_ids) == len(labels) |
|
|
|
|
|
if len(labels) < 30: |
|
|
raise ConditionalError( |
|
|
"The sample is too short: {}".format(len(labels)) |
|
|
) |
|
|
|
|
|
if torch.all(labels == IGNORE_INDEX): |
|
|
raise ConditionalError( |
|
|
"Labels contain only IGNORE_INDEX, no training targets available" |
|
|
) |
|
|
|
|
|
sample = { |
|
|
"pixel_values": pixel_values, |
|
|
"discrete_pixel_values": discrete_pixel_values, |
|
|
"idx_for_debug": idx_for_debug, |
|
|
"input_ids": input_ids, |
|
|
"labels": labels, |
|
|
"queries": queries if len(queries) > 0 else None, |
|
|
"gts": gts if len(gts) > 0 else None, |
|
|
"mm_query_lengths": mm_query_lengths, |
|
|
"non_mm_query_lengths": len(labels) - total_mm_query_length, |
|
|
"total_length": len(labels), |
|
|
"data_name": config["data_name"], |
|
|
"data_type": config["data_type"], |
|
|
"img_start_ids": img_start_ids, |
|
|
"prompt": output.input_str, |
|
|
"options": config.get("options", None), |
|
|
"image_grid_thw": image_grid_thw, |
|
|
"pixel_values_videos": pixel_values_videos, |
|
|
"video_grid_thw": video_grid_thw, |
|
|
"video_audio_values": ( |
|
|
video_audio_values if len(video_audio_values) > 0 else None |
|
|
), |
|
|
"video_audio_masks": ( |
|
|
video_audio_masks if len(video_audio_masks) > 0 else None |
|
|
), |
|
|
"audio_values": audio_values, |
|
|
"discrete_audio_values": discrete_audio_values, |
|
|
"audio_masks": audio_masks, |
|
|
"dummy_pixel_values": dummy_pixel_values, |
|
|
"dummy_grid_thw": dummy_grid_thw, |
|
|
"dummy_audio_values": dummy_audio_values, |
|
|
"dummy_audio_masks": dummy_audio_masks, |
|
|
"dummy_pixel_values_videos": dummy_pixel_values_videos, |
|
|
"dummy_video_grid_thw": dummy_video_grid_thw, |
|
|
"dummy_video_audio_values": dummy_video_audio_values, |
|
|
"dummy_video_audio_masks": dummy_video_audio_masks, |
|
|
} |
|
|
|
|
|
return sample |
|
|
|
|
|
def _sampling_multiturn( |
|
|
self, |
|
|
turns, |
|
|
n_sample, |
|
|
multiturn_preserve_order=True, |
|
|
multiturn_continuous=False, |
|
|
): |
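# Split the conversation into segments, each starting at a user turn that introduces
# new image placeholders, always keep system/tool_list turns, and sample n_sample
# segments either as a contiguous window (multiturn_continuous), in original order
# (multiturn_preserve_order), or fully shuffled.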
|
|
new_turns = [] |
|
|
sample_indices = [] |
|
|
first_user_turn = True |
|
|
start_idx = 0 |
|
|
for idx, turn in enumerate(turns): |
|
|
if turn["role"] in ["system", "tool_list"]: |
|
|
new_turns.append(turn) |
|
|
start_idx = idx + 1 |
|
|
continue |
|
|
if turn["role"] == "user": |
|
|
image_nums = re.findall(r"<image_(\d+)>", turn["content"]) |
|
|
if len(image_nums) == 0: |
|
|
image_nums = re.findall(r"<\|image\|>", turn["content"]) |
|
|
if len(image_nums) > 0: |
|
|
if first_user_turn: |
|
|
first_user_turn = False |
|
|
continue |
|
|
sample_indices.append([i for i in range(start_idx, idx)]) |
|
|
start_idx = idx |
|
|
sample_indices.append([i for i in range(start_idx, idx + 1)]) |
|
|
n_sample = min(n_sample, len(sample_indices)) |
|
|
if multiturn_continuous: |
|
|
start_index = random.randint(0, len(sample_indices) - n_sample) |
|
|
indices = range(start_index, start_index + n_sample) |
|
|
elif multiturn_preserve_order: |
|
|
indices = sorted(random.sample(range(len(sample_indices)), n_sample)) |
|
|
else: |
|
|
indices = random.sample(range(len(sample_indices)), n_sample) |
|
|
sampled_indices = [sample_indices[i] for i in indices] |
|
|
new_turns = new_turns + [ |
|
|
turns[i] for sampled_turns in sampled_indices for i in sampled_turns |
|
|
] |
|
|
return new_turns |
|
|
|