diff --git a/xtuner/dataset/hybrid/collate.py b/xtuner/dataset/hybrid/collate.py
index 925b9ac01..74d2218e7 100644
--- a/xtuner/dataset/hybrid/collate.py
+++ b/xtuner/dataset/hybrid/collate.py
@@ -16,9 +16,9 @@ def hybrid_collate_fn(instances: Sequence[Dict],
     pixel_values = []
     cumulative_len = []
     image_ranges = []
-    image_belong = []
+    image_belongs = []
     position_ids = []
-
+
     for i, data in enumerate(instances):
         input_ids.append(torch.LongTensor(data['input_ids']))
         labels.append(torch.LongTensor(data['labels']))
@@ -27,28 +27,33 @@ def hybrid_collate_fn(instances: Sequence[Dict],
         if 'cumulative_len' in data:
             cumulative_len.append(torch.IntTensor(data['cumulative_len']))
 
-        image_belong.append(i)
-        pixel_values.extend(data['pixel_values'])
-        image_ranges.extend(torch.IntTensor(data['image_ranges']))
-
+        _values = data['pixel_values']
+        _ranges = data['image_ranges']
+
+        assert len(_values) == len(_ranges)
+        for v, rng in zip(_values, _ranges):
+            pixel_values.append(v)
+            image_ranges.append(rng)
+            image_belongs.append(i)
+
     if len(pixel_values) > 0:
         assert len(image_ranges) > 0
-        assert len(image_belong) > 0
+        assert len(image_belongs) > 0
         pixel_values = torch.stack(pixel_values)
-        image_ranges = torch.stack(image_ranges)
-        image_belong = torch.IntTensor(image_belong)
     else:
         pixel_values = None
         image_ranges = None
-        image_belong = None
+        image_belongs = None
 
     if len(instances) > 1:
         input_ids = pad_sequence(
             input_ids, batch_first=True, padding_value=pad_index)
         labels = pad_sequence(
             labels, batch_first=True, padding_value=IGNORE_INDEX)
-        position_ids = pad_sequence(labels, batch_first=True, padding_value=0)
+        position_ids = pad_sequence(
+            position_ids, batch_first=True, padding_value=0)
     else:
         input_ids = torch.stack(input_ids)
         labels = torch.stack(labels)
@@ -65,8 +71,9 @@ def hybrid_collate_fn(instances: Sequence[Dict],
         'pixel_values': pixel_values,
         'cumulative_len': cumulative_len,
         'image_ranges': image_ranges,
-        'image_belong': image_belong
+        'image_belongs': image_belongs
     }
+
     if return_hf_format:
         return data_dict
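
For reference, a small self-contained sketch (toy values, hypothetical instance dicts) of the per-image bookkeeping that hybrid_collate_fn now performs: every image contributes one entry to pixel_values, image_ranges and image_belongs, and image_belongs records the index of the sample the image came from.

    instances = [
        {'pixel_values': ['img_a', 'img_b'], 'image_ranges': [(2, 5), (9, 12)]},
        {'pixel_values': ['img_c'], 'image_ranges': [(0, 3)]},
    ]

    pixel_values, image_ranges, image_belongs = [], [], []
    for i, data in enumerate(instances):
        assert len(data['pixel_values']) == len(data['image_ranges'])
        for v, rng in zip(data['pixel_values'], data['image_ranges']):
            pixel_values.append(v)
            image_ranges.append(rng)
            image_belongs.append(i)

    assert image_belongs == [0, 0, 1]  # images 0 and 1 belong to sample 0
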
diff --git a/xtuner/dataset/hybrid/hybrid.py b/xtuner/dataset/hybrid/hybrid.py
deleted file mode 100644
index 289d21de2..000000000
--- a/xtuner/dataset/hybrid/hybrid.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, Sequence
-
-import torch
-from torch.nn.utils.rnn import pad_sequence
-
-from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX
-from xtuner.types import RawTrainingData
-
-
-def hybrid_collate_fn(instances: Sequence[Dict],
-                      pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
-                      return_hf_format: bool = False):
-
-    input_ids = []
-    labels = []
-    pixel_values = []
-    cumulative_len = []
-    image_ranges = []
-    # indexes = []
-
-
-    for item in instances:
-        input_ids.append(torch.LongTensor(item['input_ids']))
-        labels.append(torch.LongTensor(item['labels']))
-
-        if 'cumulative_len' in item:
-            cumulative_len.append(torch.IntTensor(item['cumulative_len']))
-
-        pixel_values.extend(item['pixel_values'])
-        # image_ranges.extend(torch.IntTensor(item['image_ranges']))
-
-    if len(pixel_values) > 0:
-        pixel_values = torch.stack(pixel_values)
-    else:
-        pixel_values = None
-
-    if len(instances) > 1:
-        input_ids = pad_sequence(
-            input_ids, batch_first=True, padding_value=pad_index)
-        labels = pad_sequence(
-            labels, batch_first=True, padding_value=IGNORE_INDEX)
-    else:
-        input_ids = torch.stack(input_ids)
-        labels = torch.stack(labels)
-
-    # if len(image_ranges) > 0:
-    #     image_ranges = torch.stack(image_ranges)
-    # else:
-    #     image_ranges = None
-
-    if len(cumulative_len) == 0:
-        cumulative_len = None
-
-    data_dict = {
-        'input_ids': input_ids,
-        'attention_mask': input_ids.ne(pad_index),
-        'labels': labels,
-        'pixel_values': pixel_values,
-        'cumulative_len': cumulative_len,
-        # 'image_ranges': image_ranges,
-    }
-
-
-    if return_hf_format:
-        return data_dict
-    else:
-        return {'data': data_dict, 'data_samples': None}
diff --git a/xtuner/model/hybrid.py b/xtuner/model/hybrid.py
index 40db93820..24b165cd7 100644
--- a/xtuner/model/hybrid.py
+++ b/xtuner/model/hybrid.py
@@ -4,8 +4,9 @@
 import torch
 from mmengine.model import BaseModel
 from peft import LoraConfig
+from mmengine import print_log
 from torch import nn
-
+import math
 from xtuner.registry import BUILDER
 from xtuner.utils.config import build_from_cfg_or_obj
 from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
@@ -13,8 +14,8 @@
                     get_peft_model_state_dict, prepare_for_llm_lora,
                     prepare_for_vision_lora,
                     smart_tokenizer_and_embedding_resize)
-
-
+import torch.distributed as dist
+from mmengine import runner
 
 class HybridFinetune(BaseModel):
 
     def __init__(
@@ -106,46 +107,117 @@ def forward(self, data, data_samples=None, mode='loss'):
         """Overload parent class method, only support training."""
         if mode == 'loss':
-            return self.compute_loss(data, data_samples)
+            return self.compute_loss(data)
         else:
             raise NotImplementedError(
                 f"{type(self)}'s forward is only supported for use during "
                 'training. If you want to get predictions or chat, please '
                "directly use `llm`'s forward.")
 
-    def compute_loss(self, data, data_samples=None):
-
+    def _get_vision_embeds_and_ranges(self, data):
+
         input_ids = data['input_ids']
-        labels = data['labels']
-        position_ids = data['position_ids']
-        attention_mask = data['attention_mask']
         pixel_values = data['pixel_values']
         img_rngs = data['image_ranges']
-        img_belong = data['image_belong']
-
-        input_embeds = self.llm.get_input_embeddings()(input_ids)
+        img_belongs = data['image_belongs']
+
+        bs, tokens = input_ids.shape
+
+        img_embeds = []
+        ranges_in_flat_batch = []
 
         if pixel_values is not None:
+            assert isinstance(pixel_values, torch.Tensor)
+            assert len(img_rngs) == len(img_belongs) == pixel_values.size(0)
+
+            batch_total_imgs = len(img_rngs)
+
             visual_outputs = self.visual_encoder(
                 pixel_values, output_hidden_states=True)
-            img_embeds = self.projector(
+            features = self.projector(
                 visual_outputs.hidden_states[self.visual_select_layer][:, 1:])
-
-            empty_embs = torch.zeros_like(input_embeds)
-            for emb, rng, b_id in zip(img_embeds, img_rngs, img_belong):
-                left, right = rng
-                if emb.size(0) == right - left:
-                    empty_embs[b_id, left:right, :] = emb
-                elif not emb.size(0) == right - left and left == 0:
-                    empty_embs[b_id, left:right, :] = emb[-right:]
-                elif not emb.size(
-                        0) == right - left and right == empty_embs.size(1):
-                    empty_embs[b_id, left:right, :] = emb[:right - left]
+            batch_total_imgs, actual_img_tokens, _ = features.shape
+
+            for i in range(batch_total_imgs):
+                img_start, img_end = img_rngs[i]
+                expect_img_tokens = img_end - img_start
+                img_emb = features[i]
+                img_bs_ind = img_belongs[i]
+
+                if actual_img_tokens == expect_img_tokens:
+                    img_embeds.append(img_emb)
+                elif actual_img_tokens != expect_img_tokens and img_start == 0:
+                    # image truncated at the left edge of the sequence
+                    img_embeds.append(img_emb[actual_img_tokens - img_end:])
+                elif actual_img_tokens != expect_img_tokens and img_end == tokens:
+                    # image truncated at the right edge of the sequence
+                    img_embeds.append(img_emb[:expect_img_tokens])
                 else:
-                    breakpoint()
+                    raise RuntimeError(
+                        f'Image range ({img_start}, {img_end}) expects '
+                        f'{expect_img_tokens} tokens, but the visual encoder '
+                        f'produced {actual_img_tokens} tokens.')
+
+                # map the per-sample range to a range in the flattened batch
+                flat_offset = tokens * img_bs_ind
+                left = flat_offset + img_start
+                right = flat_offset + img_end
+                ranges_in_flat_batch.append((left, right))
+
+        return img_embeds, ranges_in_flat_batch
+
+    def _insert_mm_embeddings(self, flat_embeds, mm_embeds, ranges):
+
+        assert len(mm_embeds) == len(ranges)
+        if len(mm_embeds) == 0:
+            return flat_embeds
+
+        _empty_embeds = torch.zeros_like(flat_embeds)
+        for (start, end), emb in zip(ranges, mm_embeds):
+            _empty_embeds[start:end] += emb
+
+        # Zero out the image positions of the text embeddings, then add the
+        # projected visual features in their place.
+        flat_embeds = flat_embeds * (_empty_embeds == 0)
+
+        return flat_embeds + _empty_embeds
+
+    def compute_loss(self, data):
 
-            non_img_mask = (empty_embs == 0)
-            input_embeds = input_embeds * non_img_mask + empty_embs
+        input_ids = data['input_ids']
+        labels = data['labels']
+        position_ids = data['position_ids']
+        attention_mask = data['attention_mask']
+
+        input_embeds = self.llm.get_input_embeddings()(input_ids)
+
+        bs, tokens, dim = input_embeds.shape
+        flat_embeds = input_embeds.flatten(0, 1)
+
+        img_embs, flat_bs_img_rngs = self._get_vision_embeds_and_ranges(data)
+        flat_embeds = self._insert_mm_embeddings(flat_embeds, img_embs,
+                                                 flat_bs_img_rngs)
+        input_embeds = flat_embeds.reshape(bs, tokens, dim)
 
         outputs = self.llm(
             input_ids=None,
@@ -153,7 +225,7 @@ def compute_loss(self, data, data_samples=None):
             attention_mask=attention_mask,
             inputs_embeds=input_embeds,
             labels=labels)
-
+
         loss_dict = {'loss': outputs.loss}
 
         return loss_dict
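
To illustrate the bookkeeping that _get_vision_embeds_and_ranges and _insert_mm_embeddings rely on, here is a minimal runnable sketch with toy sizes (all values made up): an image occupying positions (img_start, img_end) of sample img_bs_ind maps to a range in the flattened (bs * tokens, dim) embedding tensor; the text embeddings at those positions are zeroed out and the projected visual features are added in their place.

    import torch

    bs, tokens, dim = 2, 8, 4                    # toy batch sizes
    flat_embeds = torch.randn(bs * tokens, dim)  # flattened text embeddings

    # one image spanning positions 2..5 of sample 1
    img_start, img_end, img_bs_ind = 2, 5, 1
    img_emb = torch.randn(img_end - img_start, dim)

    # same offset arithmetic as _get_vision_embeds_and_ranges
    flat_offset = tokens * img_bs_ind
    left, right = flat_offset + img_start, flat_offset + img_end

    # same insertion as _insert_mm_embeddings
    _empty_embeds = torch.zeros_like(flat_embeds)
    _empty_embeds[left:right] += img_emb
    merged = flat_embeds * (_empty_embeds == 0) + _empty_embeds

    assert torch.equal(merged[left:right], img_emb)
    assert torch.equal(merged[:left], flat_embeds[:left])
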
diff --git a/xtuner/types/chat.py b/xtuner/types/chat.py
index 616d89b09..125f789ee 100644
--- a/xtuner/types/chat.py
+++ b/xtuner/types/chat.py
@@ -92,6 +92,30 @@ def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
 
         return chat_template.decorate_function_result(self.content)
 
 
+class CodeInterpreterCallMsg(BaseModel):
+
+    role: Literal['assistant']
+    content: str
+    code_interpreter_call: Union[str, Dict]
+
+    def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+
+        return chat_template.decorate_code_interpreter_call(
+            self.content, self.code_interpreter_call)
+
+
+class CodeInterpreterResultMsg(BaseModel):
+
+    role: Literal['function']
+    name: str
+    content: Union[str, Dict]
+
+    def apply_chat_template(self, chat_template: HybridChatTemplate) -> str:
+        return chat_template.decorate_code_interpreter_result(self.content)
+
+
 class Functions(BaseModel):
 
     # class Parameters(BaseModel):
@@ -108,6 +132,26 @@ class Functions(BaseModel):
     name: str
     description: Union[str, Dict]
     parameters: Union[str, Dict]
+
+
+class CodeInterpreter(BaseModel):
+
+    name: str
+    description: Union[str, Dict]
+
 
 HybridChatMsgType = Union[ChatMsg, FunctionCallMsg, FunctionResultMsg]
diff --git a/xtuner/utils/config.py b/xtuner/utils/config.py
index ecd165920..0514dd8bf 100644
--- a/xtuner/utils/config.py
+++ b/xtuner/utils/config.py
@@ -124,7 +124,7 @@ def build_from_cfg_or_obj(cfg_or_obj: Union[dict, OBJ_T],
             raise TypeError(
                 f'Expect an object of {accept}, but there is an object of '
                 f'{type(obj)}.')
-        return BUILDER.build(cfg_or_obj)
+        return obj
     else:
         raise TypeError(f'cfg_or_obj must be a dict, or {accept}, but got '
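
As a usage note for the new message models added in xtuner/types/chat.py above, a minimal construction sketch (toy values only; rendering them through apply_chat_template would additionally need a HybridChatTemplate instance, which this diff does not show):

    from xtuner.types.chat import (CodeInterpreterCallMsg,
                                   CodeInterpreterResultMsg)

    call = CodeInterpreterCallMsg(
        role='assistant',
        content='Let me run that in the interpreter.',
        code_interpreter_call='print(1 + 1)')

    result = CodeInterpreterResultMsg(
        role='function',
        name='code_interpreter',
        content='2')
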