import contextlib import gc import inspect import json import os import time from functools import partial from pathlib import Path from typing import List, Optional, Tuple, Union import torch import torch.distributed as dist import torch.nn as nn from liger_kernel.transformers import ( LigerCrossEntropyLoss, LigerFusedLinearCrossEntropyLoss, ) from torch.nn import CrossEntropyLoss from transformers import AutoTokenizer from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0 from hcxvlm.models.ulysses.sp_utils import ( gather_outputs_and_unpad, get_ulysses_sequence_parallel_group, get_ulysses_sequence_parallel_rank, get_ulysses_sequence_parallel_world_size, slice_input_tensor, ) from .configuration_vlm import HCXVisionConfig from .modeling_vlm import HCXVisionForCausalLM, get_rank extra_special_tokens = { "image_token": "<|IMAGE_PAD|>", "discrete_image_token": "<|DISCRETE_IMAGE_PAD|>", "discrete_image_unit_0_id": "<|vision00000|>", "video_token": "<|VIDEO_PAD|>", "video_audio_token": "<|VIDEO_AUDIO_PAD|>", "audio_token": "<|AUDIO_PAD|>", "discrete_audio_token": "<|DISCRETE_AUDIO_PAD|>", "discrete_audio_unit_0_id": "<|audio0000|>", } def load_state_dict_into_model(model_to_load, state_dict, strict=True, start_prefix=""): old_keys = [] new_keys = [] for key in state_dict.keys(): new_key = None if "gamma" in key: new_key = key.replace("gamma", "weight") if "beta" in key: new_key = key.replace("beta", "bias") if new_key: old_keys.append(key) new_keys.append(new_key) for old_key, new_key in zip(old_keys, new_keys): state_dict[new_key] = state_dict.pop(old_key) metadata = getattr(state_dict, "_metadata", None) state_dict = state_dict.copy() if metadata is not None: state_dict._metadata = metadata error_msgs = [] def load(module: nn.Module, state_dict, prefix=""): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) args = (state_dict, prefix, local_metadata, strict, [], [], error_msgs) if len([key for key in state_dict if key.startswith(prefix)]) > 0: if is_deepspeed_zero3_enabled(): import deepspeed named_parameters = dict( module.named_parameters(prefix=prefix[:-1], recurse=False) ) params_to_gather = [ named_parameters[k] for k in state_dict.keys() if k in named_parameters ] if len(params_to_gather) > 0: with deepspeed.zero.GatheredParameters( params_to_gather, modifier_rank=0 ): if torch.distributed.get_rank() == 0: module._load_from_state_dict(*args) else: module._load_from_state_dict(*args) for name, child in module._modules.items(): if child is not None: load(child, state_dict, prefix + name + ".") load(model_to_load, state_dict, prefix=start_prefix) del state_dict return error_msgs def load_sharded_checkpoint( model, folder, pick_prefix="", replace_prefix_list=[], replace_prefix_dict={}, print_info=True, ): if folder is None: return {} files = os.listdir(folder) pytorch_bin_files = [ file for file in files if file.startswith("pytorch_model") and file.endswith(".bin") ] safetensor_files = [file for file in files if file.endswith(".safetensors")] shard_index_file = [file for file in files if file.endswith(".index.json")] index_present = len(shard_index_file) > 0 index_file = os.path.join(folder, shard_index_file[0]) if index_present else [] is_safetensor = len(safetensor_files) > 0 model_keys = model.state_dict().keys() if is_safetensor: from safetensors.torch import load_file load_function = load_file shard_files = safetensor_files else: load_function = partial(torch.load, map_location="cpu") shard_files = pytorch_bin_files if index_present: with open(index_file, "r", encoding="utf-8") as f: index = json.load(f) loaded_keys = index["weight_map"].keys() if pick_prefix: loaded_keys = [ k[len(pick_prefix) :] for k in loaded_keys if k.startswith(pick_prefix) ] if replace_prefix_list: for rep_prefix in replace_prefix_list: loaded_keys = [ k[len(rep_prefix) :] if k.startswith(rep_prefix) else k for k in loaded_keys ] if replace_prefix_dict: for rep_prefix in replace_prefix_dict: loaded_keys = [ ( k.replace(rep_prefix, replace_prefix_dict[rep_prefix]) if k.startswith(rep_prefix) else k ) for k in loaded_keys ] for i, shard_file in enumerate(shard_files): state_dict = load_function(os.path.join(folder, shard_file)) if pick_prefix: state_dict = { k[len(pick_prefix) :]: v for k, v in state_dict.items() if k.startswith(pick_prefix) } for rep_prefix in replace_prefix_list: state_dict = { k[len(rep_prefix) :] if k.startswith(rep_prefix) else k: v for k, v in state_dict.items() } for rep_prefix in replace_prefix_dict: state_dict = { ( k.replace(rep_prefix, replace_prefix_dict[rep_prefix]) if k.startswith(rep_prefix) else k ): v for k, v in state_dict.items() } if is_deepspeed_zero3_enabled(): rank = torch.distributed.get_rank() print(f"# [info] ZeRo3 - load sharded no {i}, rank {rank}") load_state_dict_into_model(model, state_dict, strict=False) elif is_fsdp_enabled(): if is_local_dist_rank_0(): model.load_state_dict(state_dict, strict=False) else: model.load_state_dict(state_dict, strict=False) if not index_present: loaded_keys = state_dict.keys() del state_dict gc.collect() missing_keys = [key for key in model_keys if key not in loaded_keys] unexpected_keys = [key for key in loaded_keys if key not in model_keys] if get_rank() == 0 and print_info: print(f"[info] missing_keys: {missing_keys}") print(f"[info] unexpected_keys: {unexpected_keys}") return {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys} class HCXVisionForCausalLM_VU(HCXVisionForCausalLM): def __init__(self, config, **kwargs): self.use_liger = kwargs.pop("use_liger", True) self.use_fused_ce = kwargs.pop("use_fused_ce", True) self.use_meansum_loss = kwargs.pop("use_meansum_loss", True) self.use_turnmeansum_loss = kwargs.pop("use_turnmeansum_loss", False) self.use_sqrtsum_loss = kwargs.pop("use_sqrtsum_loss", False) use_sum_loss = True if kwargs.pop("use_sum_loss", False) else False self.sequence_parallel_size = kwargs.pop("sequence_parallel_size", 1) self.sp_manager = kwargs.pop("sp_manager", None) self.train_video = kwargs.pop("train_video", False) assert ( int(self.use_meansum_loss) + int(self.use_turnmeansum_loss) + int(self.use_sqrtsum_loss) ) <= 1, "use_meansum_loss, use_turnmeansum_loss, use_sqrtsum_loss 중 둘 이상을 동시에 True로 설정할 수 없습니다." if self.use_meansum_loss or self.use_turnmeansum_loss or self.use_sqrtsum_loss: self.reduction = "none" elif use_sum_loss: self.reduction = "sum" else: self.reduction = "mean" super().__init__(config, **kwargs) if config.text_config.model_type == "hyperclovax" and self.use_liger: self.language_model._get_apply_liger_kernel_converter()( model=self.language_model ) print("[info] use liger kernel for hcx 24b") if config.freeze_encoder: for param in self.vision_model.parameters(): param.requires_grad = False assert ( all(param.requires_grad for param in self.vision_model.parameters()) == False ) @classmethod def from_pretrained( cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, text_model_name_or_path: Optional[Union[str, os.PathLike]] = None, vision_model_name_or_path: Optional[Union[str, os.PathLike]] = None, discrete_vision_model_name_or_path: Optional[Union[str, os.PathLike]] = None, audio_model_name_or_path: Optional[Union[str, os.PathLike]] = None, discrete_audio_model_name_or_path: Optional[Union[str, os.PathLike]] = None, q_former_model_name_or_path: Optional[Union[str, os.PathLike]] = None, without_llm: bool = False, *model_args, **kwargs, ): """ :param pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] : pre-trained path for LLM(text_model_name_or_path) e.g. /path/to/model/ :param vision_model_name_or_path: Optional[Union[str, os.PathLike]] : pre-trained path for VisionModule(HyperClova-VisionModule) e.g. /path/to/vision/module/ :param q_former_model_name_or_path: Optional[Union[str, os.PathLike]] : pre-trained path for VLM e.g. /path/to/vlm/checkpoint/ :param without_llm: Bool: False: init/load llm weight from pre-trained True: init/load llm weight from dummy file :param model_args: :param kwargs: :return: """ assert pretrained_model_name_or_path is not None or ( text_model_name_or_path is not None and vision_model_name_or_path is not None ) cache_dirpath = kwargs.pop("cache_dirpath", None) if cache_dirpath is None: cache_dirpath = "~/.cache" runtime_only_keys = { "use_liger", "use_fused_ce", "use_meansum_loss", "use_turnmeansum_loss", "use_sqrtsum_loss", "use_sum_loss", "sequence_parallel_size", "sp_manager", "train_video", } runtime_kwargs = {} for k in list(runtime_only_keys): if k in kwargs: runtime_kwargs[k] = kwargs.pop(k) kwargs["vision_model_name_or_path"] = vision_model_name_or_path kwargs["discrete_vision_model_name_or_path"] = ( discrete_vision_model_name_or_path ) kwargs["audio_model_name_or_path"] = audio_model_name_or_path kwargs["discrete_audio_model_name_or_path"] = discrete_audio_model_name_or_path save_only_vision = ( kwargs.pop("save_only_vision") if "save_only_vision" in kwargs else False ) save_only_qformer = ( kwargs.pop("save_only_qformer") if "save_only_qformer" in kwargs else False ) save_shard_size = ( kwargs.pop("save_shard_size") if "save_shard_size" in kwargs else "5GB" ) def _purge_runtime_from_config(cfg): for rk in runtime_only_keys: if hasattr(cfg, rk): delattr(cfg, rk) template_path = "hcxvlm/dataset/chat_template.jinja" with open(template_path, "r", encoding="utf-8") as f: chat_template_str = f.read() if without_llm: assert pretrained_model_name_or_path is not None and os.path.exists( pretrained_model_name_or_path ) dummy_config = HCXVisionConfig.from_pretrained( pretrained_model_name_or_path=pretrained_model_name_or_path, *model_args, **kwargs, ) _purge_runtime_from_config(dummy_config) dummy_config.text_config.num_hidden_layers = 0 dummy_config.text_config.num_attention_heads = 1 if isinstance( dummy_config.vision_model_name_or_path, str ) and os.path.exists(dummy_config.vision_model_name_or_path): vision_model_name_or_path = dummy_config.vision_model_name_or_path assert isinstance(vision_model_name_or_path, str) and os.path.exists( vision_model_name_or_path ), f"# [error] invalid vision_model_name_or_path: {vision_model_name_or_path}" dummy_config.vision_model_name_or_path = vision_model_name_or_path dummy_config.vision_config._name_or_path = vision_model_name_or_path dummy_config.vision_config.vison_pretrained_name_or_path = ( vision_model_name_or_path ) model = super().from_pretrained( pretrained_model_name_or_path=pretrained_model_name_or_path, without_llm=True, config=dummy_config, *model_args, **{**kwargs, **runtime_kwargs}, ) model.tokenizer = AutoTokenizer.from_pretrained( pretrained_model_name_or_path ) model.tokenizer.chat_template = chat_template_str model.transformer = None else: if pretrained_model_name_or_path is not None and ( audio_model_name_or_path is not None or discrete_audio_model_name_or_path is not None or discrete_vision_model_name_or_path is not None ): assert ( audio_model_name_or_path is not None and discrete_audio_model_name_or_path is not None and discrete_vision_model_name_or_path is not None ) print(f"[DEBUG] image stage2 끝난 시점에서 audio 를 stage3 로 붙일때.") pt_config = HCXVisionConfig.from_pretrained( pretrained_model_name_or_path ) _purge_runtime_from_config(pt_config) config_dict = pt_config.to_dict() config_dict["audio_model_name_or_path"] = audio_model_name_or_path config_dict["discrete_audio_model_name_or_path"] = ( discrete_audio_model_name_or_path ) config_dict["discrete_vision_model_name_or_path"] = ( discrete_vision_model_name_or_path ) config = HCXVisionConfig.from_dict(config_dict) print(f"config: {config}") model = super().from_pretrained( pretrained_model_name_or_path, without_llm=False, config=config, _fast_init=False, *model_args, **kwargs, ) model.tokenizer = AutoTokenizer.from_pretrained( pretrained_model_name_or_path ) model.tokenizer.chat_template = chat_template_str elif isinstance(q_former_model_name_or_path, str): config = HCXVisionConfig.from_dict( {"text_model_name_or_path": text_model_name_or_path, **kwargs} ) _purge_runtime_from_config(config) model = super().from_pretrained( q_former_model_name_or_path, without_llm=False, config=config, _fast_init=False, *model_args, **{**kwargs, **runtime_kwargs}, ) model.tokenizer = AutoTokenizer.from_pretrained( q_former_model_name_or_path ) model.tokenizer.chat_template = chat_template_str elif pretrained_model_name_or_path is not None: config = HCXVisionConfig.from_pretrained( pretrained_model_name_or_path, *model_args, **kwargs ) _purge_runtime_from_config(config) model = super().from_pretrained( pretrained_model_name_or_path, *model_args, config=config, **runtime_kwargs, ) model.tokenizer = AutoTokenizer.from_pretrained( pretrained_model_name_or_path ) model.tokenizer.chat_template = chat_template_str else: config = HCXVisionConfig.from_dict( {"text_model_name_or_path": text_model_name_or_path, **kwargs} ) _purge_runtime_from_config(config) model = HCXVisionForCausalLM_VU( config, *model_args, **{**kwargs, **runtime_kwargs} ) model.tokenizer = AutoTokenizer.from_pretrained(text_model_name_or_path) model.tokenizer.chat_template = chat_template_str model.mm_projector.apply(model._init_weights) img_start_id = model.tokenizer.encode( extra_special_tokens["image_token"], add_special_tokens=False ) assert ( len(img_start_id) == 1 ), f'{extra_special_tokens["image_token"]} was not encoded into a single special token. Encoding result: {img_start_id}' model.config.img_start_id = img_start_id[0] model.config.image_token_id = img_start_id[0] video_start_id = model.tokenizer.encode( extra_special_tokens["video_token"], add_special_tokens=False ) assert ( len(video_start_id) == 1 ), f"video_token was not encoded into a single special token. Encoding result: {video_start_id}" model.config.video_start_id = video_start_id[0] model.config.video_token_id = video_start_id[0] video_audio_start_id = model.tokenizer.encode( extra_special_tokens["video_audio_token"], add_special_tokens=False ) assert ( len(video_audio_start_id) == 1 ), f"video_audio_token was not encoded into a single special token. Encoding result: {video_audio_start_id}" model.config.video_audio_start_id = video_audio_start_id[0] model.config.video_audio_token_id = video_audio_start_id[0] if ( audio_model_name_or_path is not None or discrete_audio_model_name_or_path is not None or discrete_vision_model_name_or_path is not None ): audio_start_id = model.tokenizer.encode( extra_special_tokens["audio_token"], add_special_tokens=False ) assert ( len(audio_start_id) == 1 ), f"audio_token was not encoded into a single special token. Encoding result: {audio_start_id}" model.config.audio_start_id = audio_start_id[0] model.config.audio_token_id = audio_start_id[0] discrete_audio_start_id = model.tokenizer.encode( extra_special_tokens["discrete_audio_token"], add_special_tokens=False ) assert ( len(discrete_audio_start_id) == 1 ), f"discrete_audio_token was not encoded into a single special token. Encoding result: {discrete_audio_start_id}" model.config.discrete_audio_start_id = discrete_audio_start_id[0] model.config.discrete_audio_token_id = discrete_audio_start_id[0] discrete_audio_unit_0_id = model.tokenizer.encode( extra_special_tokens["discrete_audio_unit_0_id"], add_special_tokens=False, ) assert ( len(discrete_audio_unit_0_id) == 1 ), f'{extra_special_tokens["discrete_audio_unit_0_id"]} was not encoded into a single special token. Encoding result: {discrete_audio_unit_0_id}' model.config.discrete_audio_unit_0_id = discrete_audio_unit_0_id[0] discrete_image_start_id = model.tokenizer.encode( extra_special_tokens["discrete_image_token"], add_special_tokens=False ) assert ( len(discrete_image_start_id) == 1 ), f'{extra_special_tokens["discrete_image_token"]} was not encoded into a single special token. Encoding result: {discrete_image_start_id}' model.config.discrete_image_start_id = discrete_image_start_id[0] model.config.discrete_image_token_id = discrete_image_start_id[0] discrete_image_unit_0_id = model.tokenizer.encode( extra_special_tokens["discrete_image_unit_0_id"], add_special_tokens=False, ) assert ( len(discrete_image_unit_0_id) == 1 ), f'{extra_special_tokens["discrete_image_unit_0_id"]} was not encoded into a single special token. Encoding result: {discrete_image_unit_0_id}' model.config.discrete_image_unit_0_id = discrete_image_unit_0_id[0] model.save_only_vision = save_only_vision model.save_only_qformer = save_only_qformer model.save_shard_size = save_shard_size if pretrained_model_name_or_path is None or ( pretrained_model_name_or_path is not None and audio_model_name_or_path is not None ): vision_model_name_or_path = kwargs.get("vision_model_name_or_path", None) if vision_model_name_or_path is not None: load_sharded_checkpoint(model.vision_model, vision_model_name_or_path) if get_rank() == 0: print("[info] vision model loading complete") discrete_vision_model_name_or_path = kwargs.get( "discrete_vision_model_name_or_path", None ) if discrete_vision_model_name_or_path is not None: model.discrete_vision_model.load_state_dict( torch.load( discrete_vision_model_name_or_path, map_location=model.device, weights_only=False, )["model"]["sd"], strict=True, ) if get_rank() == 0: print("[info] discrete vision model loading complete") audio_model_name_or_path = kwargs.get("audio_model_name_or_path", None) if audio_model_name_or_path is not None: load_sharded_checkpoint(model.audio_model, audio_model_name_or_path) if get_rank() == 0: print("[info] audio model loading complete") discrete_audio_model_name_or_path = kwargs.get( "discrete_audio_model_name_or_path", None ) if discrete_audio_model_name_or_path is not None: model.discrete_audio_model.load_state_dict( torch.load( discrete_audio_model_name_or_path, map_location=model.device, weights_only=False, ), strict=True, ) if get_rank() == 0: print("[info] discrete audio model loading complete") if text_model_name_or_path is not None: load_sharded_checkpoint(model.language_model, text_model_name_or_path) if get_rank() == 0: print("[info] text model loading complete") if isinstance(q_former_model_name_or_path, str): assert Path( q_former_model_name_or_path ).exists(), f"# [error] given q_former_name_or_path not exist: {q_former_model_name_or_path}" load_result = load_sharded_checkpoint( model, q_former_model_name_or_path, replace_prefix_dict={ "vision_model.image_encoder.model.vision_tower": "vision_model", "model": "language_model.model", "lm_head.weight": "language_model.lm_head.weight", }, print_info=False, ) if get_rank() == 0: missing_keys_summary = dict() for key in load_result["missing_keys"]: if key.split(".")[0] in missing_keys_summary: missing_keys_summary[key.split(".")[0]] += 1 else: missing_keys_summary[key.split(".")[0]] = 1 print(f"[info] missing_keys summary : {missing_keys_summary}") print("[info] q_former model loading complete") config: HCXVisionConfig = model.config if config.model_type != "vlm": model.config.model_type = "vlm" return model def _pad_sequence_for_sp( self, inputs_embeds: torch.Tensor, labels: Optional[torch.Tensor], sp_world_size: int, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Ensure sequence length is divisible by the SP group size by padding on the sequence dimension. Returns the possibly padded (inputs_embeds, labels). """ batch_size, seqlen, hidden_size = inputs_embeds.shape remainder = seqlen % sp_world_size if remainder != 0: print( f"[info] Padding sequence dimension to make it divisible by {sp_world_size}" ) pad_len = sp_world_size - remainder pad_embeds = torch.zeros( (batch_size, pad_len, hidden_size), dtype=inputs_embeds.dtype, device=inputs_embeds.device, ) inputs_embeds = torch.cat([inputs_embeds, pad_embeds], dim=1) if labels is not None: ignore_index = getattr(self.config, "ignore_index", -100) pad_labels = torch.full( (batch_size, pad_len), fill_value=ignore_index, dtype=labels.dtype, device=labels.device, ) labels = torch.cat([labels, pad_labels], dim=1) return inputs_embeds, labels def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[List[List[torch.FloatTensor]]] = None, discrete_pixel_values: Optional[List[List[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, image_sizes: Optional[List[List[List[int]]]] = None, mm_query_lengths: Optional[List[List[int]]] = None, non_mm_query_lengths: Optional[List[List[int]]] = None, img_start_ids_list: Optional[List[List[int]]] = None, num_queries_vis_abstractors: Optional[List[List[int]]] = None, num_queries_vis_abstractors_slow: Optional[List[List[int]]] = None, first_last_frames_slows: Optional[List[List[bool]]] = None, is_videos: Optional[List[List[bool]]] = None, image_grid_thw: Optional[torch.LongTensor] = None, pixel_values_videos: Optional[torch.FloatTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, video_audio_values: Optional[torch.FloatTensor] = None, video_audio_masks: Optional[torch.FloatTensor] = None, audio_values: Optional[torch.FloatTensor] = None, discrete_audio_values: Optional[torch.FloatTensor] = None, discrete_audio_value_num_per_sample: Optional[torch.LongTensor] = None, audio_masks: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: """ :param input_ids: torch.int64 : torch.size([batchsize, variable)]) : SystemPrompt with Question text token indices for tokenizer. In positions where images are inputted, the value is replaced by config.img_start_id, which is a vocabulary index used to indicate the start of image data. :param pixel_values: List of List of 4D tensor (torch.float32) Each outer list corresponds to a batch and contains inner lists, each holding tensors for images in a sample. The structure accounts for samples with multiple images. :param past_key_values: None :param inputs_embeds: None :param labels: Optional[torch.int64] : [batchsize, variable (input_ids.size(1)+ num visual tokens)] visual token 들은 모두 IGNORE_INDEX :param use_cache: None :param output_attentions: Optional[bool] : get attention weights of each layers of transformer network (true: 결과값에 포함, false: 결과값에 미포함) :param output_hidden_states: Optional[bool] : get hidden states of each layers of transformer network (true: 결과값에 포함, false: 결과값에 미포함) :param return_dict: Optional[bool] : True - return dict, Fasle - return tensor :param image_sizes: Stacked as a List of List, representing image sizes (width, height). In cases where a sample contains no images, a single dummy image is included. :param mm_query_lengths: A List of List that stores the lengths when each image is converted into visual tokens for LLM input. In cases where a sample does not contain any images, an empty list is included. :param non_mm_query_lengths: contains the lengths of text tokens (excluding visual tokens) for each sample in a batch. :img_start_ids_list: contains the indices of the img_start_id tokens for each sample. :num_queries_vis_abstractors: A List of List that contains the number of visual tokens for each image grid. :num_queries_vis_abstractors_slow: A List of List that contains the number of visual tokens for the slow part when applying the slowfast algorithm to video frames. If the slowfast algorithm is not applied, it will have a value of None. :first_last_frames_slows: A List of List that contains the only first and last frames slow mode for each sample in a batch. :is_videos: A List of List that contains the boolean value indicating whether each sample in a batch is a video. :image_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder. :pixel_values_videos: A 2D tensor (torch.float32) for qwen2.5-vl visual encoder. :video_grid_thw: A 3D tensor (torch.int64) for qwen2.5-vl visual encoder. :return: """ if self.sp_manager is not None and self.train_video: sp_group = get_ulysses_sequence_parallel_group() if sp_group is not None: sp_rank = get_ulysses_sequence_parallel_rank(sp_group) sp_world_size = get_ulysses_sequence_parallel_world_size(sp_group) if sp_rank == 0: payload = { "input_ids": input_ids, "labels": labels, "pixel_values": pixel_values, "image_grid_thw": image_grid_thw, "pixel_values_videos": pixel_values_videos, "video_grid_thw": video_grid_thw, "video_audio_values": video_audio_values, "video_audio_masks": video_audio_masks, } else: payload = { "input_ids": None, "labels": None, "pixel_values": None, "image_grid_thw": None, "pixel_values_videos": None, "video_grid_thw": None, "video_audio_values": None, "video_audio_masks": None, } obj_list = [payload] src_global_rank = dist.get_global_rank(sp_group, 0) dist.broadcast_object_list( obj_list, src=src_global_rank, group=sp_group ) payload = obj_list[0] if sp_rank != 0: device = input_ids.device input_ids = payload["input_ids"] if isinstance(input_ids, torch.Tensor): input_ids = input_ids.to(device) labels = payload["labels"] if isinstance(labels, torch.Tensor): labels = labels.to(device) image_grid_thw = payload["image_grid_thw"] if isinstance(image_grid_thw, torch.Tensor): image_grid_thw = image_grid_thw.to(device) pixel_values_videos = payload["pixel_values_videos"] if isinstance(pixel_values_videos, torch.Tensor): pixel_values_videos = pixel_values_videos.to(device) video_grid_thw = payload["video_grid_thw"] if isinstance(video_grid_thw, torch.Tensor): video_grid_thw = video_grid_thw.to(device) video_audio_values = payload["video_audio_values"] if isinstance(video_audio_values, torch.Tensor): video_audio_values = video_audio_values.to(device) video_audio_masks = payload["video_audio_masks"] if isinstance(video_audio_masks, torch.Tensor): video_audio_masks = video_audio_masks.to(device) pixel_values = payload["pixel_values"] if isinstance(pixel_values, torch.Tensor): pixel_values = pixel_values.to(device) attention_mask = None output_attentions = ( output_attentions if output_attentions is not None else self.config.vision_config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.vision_config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if inputs_embeds is None and past_key_values is None: inputs_embeds, labels = self.model.extract_inputs_embeds( input_ids=input_ids, labels=labels, pixel_values=pixel_values, discrete_pixel_values=discrete_pixel_values, past_key_values=past_key_values, image_sizes=image_sizes, mm_query_lengths=mm_query_lengths, non_mm_query_lengths=non_mm_query_lengths, img_start_ids_list=img_start_ids_list, num_queries_vis_abstractors=num_queries_vis_abstractors, num_queries_vis_abstractors_slow=num_queries_vis_abstractors_slow, first_last_frames_slows=first_last_frames_slows, is_videos=is_videos, image_grid_thw=image_grid_thw, pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, video_audio_values=video_audio_values, video_audio_masks=video_audio_masks, audio_values=audio_values, discrete_audio_values=discrete_audio_values, discrete_audio_value_num_per_sample=discrete_audio_value_num_per_sample, audio_masks=audio_masks, ) if labels is not None and labels.size(1) > 32768: print( f"[RANK {rank} debug] ❌ labels.size(1) > 32768. labels.size(): {labels.size()}" ) if inputs_embeds is not None: input_ids = None import os rank = int(os.environ.get("RANK", -1)) if inputs_embeds is not None: expected_hidden_size = self.config.text_config.hidden_size if inputs_embeds.shape[-1] != expected_hidden_size: print(f"[RANK {rank}] ❌ inputs_embeds dimension mismatch!") print( f" Expected: {expected_hidden_size}, Got: {inputs_embeds.shape[-1]}" ) if labels is not None: vocab_size = self.get_input_embeddings().num_embeddings valid_labels = labels[labels != -100] if len(valid_labels) > 0: if (valid_labels >= vocab_size).any() or (valid_labels < 0).any(): print(f"[RANK {rank}] ❌ CRITICAL: labels out of vocab range!") print( f" labels min/max: {valid_labels.min().item()}/{valid_labels.max().item()}" ) print(f" vocab_size: {vocab_size}") print( f" Out-of-range count: {(valid_labels >= vocab_size).sum().item()}" ) if attention_mask is not None and inputs_embeds is not None: if attention_mask.shape[1] != inputs_embeds.shape[1]: print(f"[RANK {rank}] ❌ attention_mask shape mismatch!") print( f" attention_mask: {attention_mask.shape}, inputs_embeds: {inputs_embeds.shape}" ) if position_ids is not None: max_position = position_ids.max().item() if hasattr(self.language_model.config, "max_position_embeddings"): max_allowed = self.language_model.config.max_position_embeddings if max_position >= max_allowed: print(f"[RANK {rank}] ❌ position_ids out of range!") print(f" max_position: {max_position}, max_allowed: {max_allowed}") if self.sp_manager is not None: batch_size, seqlen, hidden_size = inputs_embeds.shape sp_group = get_ulysses_sequence_parallel_group() sp_world_size = get_ulysses_sequence_parallel_world_size(sp_group) inputs_embeds, labels = self._pad_sequence_for_sp( inputs_embeds, labels, sp_world_size ) if position_ids is None: position_ids = torch.arange( seqlen, device=inputs_embeds.device, dtype=torch.long ) position_ids = ( position_ids.unsqueeze(0).expand(batch_size, -1).contiguous() ) inputs_embeds = slice_input_tensor( inputs_embeds, 1, padding=False, group=sp_group ) labels = slice_input_tensor(labels, 1, padding=False, group=sp_group) use_cache = False outputs = self.language_model.base_model( input_ids=input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] hidden_states = hidden_states * self.config.text_config.logits_scaling loss = None logits = None if labels is not None: if self.use_liger and self.use_fused_ce: shift_labels = labels[..., 1:].contiguous() shift_labels = shift_labels.view(-1) hidden_states = hidden_states[..., :-1, :].contiguous() hidden_states = hidden_states.view( -1, self.language_model.config.hidden_size ).to(self.language_model.lm_head.weight.dtype) import os rank = int(os.environ.get("RANK", -1)) vocab_size = self.language_model.lm_head.weight.shape[0] valid_labels = shift_labels[shift_labels != -100] if len(valid_labels) > 0 and ( (valid_labels >= vocab_size).any() or (valid_labels < 0).any() ): print( f"[RANK {rank}] ❌ CRITICAL: shift_labels out of vocab range!" ) print( f" min/max: {valid_labels.min().item()}/{valid_labels.max().item()}, vocab: {vocab_size}" ) print( f" Out-of-range count: {(valid_labels >= vocab_size).sum().item()}" ) lce = LigerFusedLinearCrossEntropyLoss(reduction=self.reduction) try: loss = lce( self.language_model.lm_head.weight, hidden_states, shift_labels ) except RuntimeError as e: print( f"[RANK {rank}] ❌ FATAL: LigerFusedLinearCrossEntropyLoss failed!" ) print(f" Error: {e}") print( f" hidden_states: shape={hidden_states.shape}, dtype={hidden_states.dtype}" ) print( f" shift_labels: shape={shift_labels.shape}, unique_values={torch.unique(shift_labels).tolist()[:20]}" ) print( f" lm_head.weight: shape={self.language_model.lm_head.weight.shape}" ) raise elif self.use_liger: logits = self.language_model.lm_head(hidden_states) shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss_fct = LigerCrossEntropyLoss(reduction=self.reduction) shift_logits = shift_logits.view(-1, self.config.text_config.vocab_size) shift_labels = shift_labels.view(-1) shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) else: logits = self.language_model.lm_head(hidden_states) shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss_fct = CrossEntropyLoss(reduction=self.reduction) shift_logits = shift_logits.view(-1, self.config.text_config.vocab_size) shift_labels = shift_labels.view(-1) shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if self.sp_manager is not None: loss = gather_outputs_and_unpad( loss, gather_dim=0, unpad_dim=0, padding_size=0, group=sp_group ) if self.use_meansum_loss: loss = loss.view(labels.size(0), -1).mean(dim=1).sum() elif self.use_sqrtsum_loss: per_token = loss.view(labels.size(0), -1) per_sample_mean = per_token.mean(dim=1) with torch.no_grad(): labels_2d = labels.view(labels.size(0), -1) ignore_index = getattr(self.config, "ignore_index", -100) valid_mask = labels_2d.ne(ignore_index) valid_count = valid_mask.sum(dim=1).clamp(min=1).float() raw_w = valid_count.sqrt() w_mean = raw_w.mean().clamp(min=1e-6) norm_w = raw_w / w_mean loss = (per_sample_mean * norm_w).sum() elif self.use_turnmeansum_loss: with torch.no_grad(): mask = shift_labels.view(labels.size(0), -1).ne( self.config.ignore_index ) prev_mask = mask.roll(shifts=1, dims=1) prev_mask[:, 0] = False turn_starts = mask & (~prev_mask) turn_count = turn_starts.sum(dim=1).clamp(min=1).float() loss = (loss.view(labels.size(0), -1).mean(dim=1) * turn_count).sum() if self.sp_manager is not None: loss = loss / self.sp_manager.device_mesh.shape[1] if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def save_pretrained( self, save_directory: Union[str, os.PathLike], *args, **kwargs, ): state_dict = ( kwargs["state_dict"] if kwargs.get("state_dict", None) else self.state_dict() ) partial_state_dict = self.get_pretrained_state_dict( state_dict, ) kwargs["state_dict"] = partial_state_dict kwargs["safe_serialization"] = self.is_safetensor_save kwargs.setdefault("max_shard_size", self.save_shard_size) super().save_pretrained(save_directory, *args, **kwargs) if self.is_qwen_visual: self.config.architectures = ["HCXVisionV2ForCausalLM"] else: self.config.architectures = ["HCXVisionForCausalLM"] self.config.auto_map["AutoModelForCausalLM"] = ( "modeling_vlm.HCXVisionForCausalLM" ) self.config.auto_map["AutoModelForSequenceClassification"] = ( "modeling_vlm.HCXVisionForSequenceClassification" ) self.config.save_pretrained(save_directory) def get_pretrained_state_dict(self, state_dict): vision_key = "vision_model." llm_keys = ["language_model."] head_key = "lm_head." for key in list(state_dict.keys()): if self.save_only_vision: for llm_key in llm_keys: if llm_key in key: state_dict.pop(key) if key.startswith(head_key): state_dict.pop(key) elif self.save_only_qformer: if f"{vision_key}" in key: state_dict.pop(key) return state_dict