diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index cbf330cc2..f979a5952 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -10,6 +10,7 @@ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401 +from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401 diff --git a/src/liger_kernel/transformers/model/llava.py b/src/liger_kernel/transformers/model/llava.py new file mode 100644 index 000000000..68b2e7f83 --- /dev/null +++ b/src/liger_kernel/transformers/model/llava.py @@ -0,0 +1,370 @@ +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from transformers.models.llava.modeling_llava import _CONFIG_FOR_DOC +from transformers.models.llava.modeling_llava import LLAVA_INPUTS_DOCSTRING +from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast +from transformers.models.llava.modeling_llava import logger +from transformers.utils import add_start_docstrings_to_model_forward +from transformers.utils import replace_return_docstrings + +from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss + + +@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING) +@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) +def lce_forward_deprecated( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[int] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, +) -> Union[Tuple, LlavaCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, LlavaForConditionalGeneration + + >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf") + >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") + + >>> prompt = "USER: \nWhat's the content of the image? ASSISTANT:" + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed" + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + legacy_processing = False + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing + # not very reliable, but we don't expect one to actually pass 500+ images for one prompt + # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True + legacy_processing = ( + (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length + ) or (input_ids.shape[-1] == 1 and pixel_values is not None) + + image_features = None + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + if legacy_processing and image_features is not None: + logger.warning_once( + "Expanding inputs for image tokens in LLaVa should be done in processing. " + "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " + "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.50." 
+ ) + # prefill stage vs decoding stage (legacy behavior copied) + if input_ids.shape[1] != 1: + inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features( + image_features, inputs_embeds, input_ids, attention_mask, labels + ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) + else: + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + + # Get the target length + target_length = input_ids.shape[1] + past_length = first_layer_past_key_value.shape[-1] + + extended_attention_mask = torch.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] + + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + + attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] + + # TODO: @raushan retain only the new behavior after v4.47 + elif image_features is not None: + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] * image_features.shape[1] + + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + special_image_mask = ( + (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model.model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + hidden_states = outputs[0] + + loss = None + logits = None + + if self.training and (labels is not None): + # Shift so that tokens < n predict n + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. 
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(hidden_states.shape[1] - 1) :].to(hidden_states.device) + shift_hidden_states = hidden_states[..., :-1, :][ + shift_attention_mask.to(hidden_states.device) != 0 + ].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + lce = LigerFusedLinearCrossEntropyLoss() + loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels) + + if not return_dict: + # NOTE: This part has not been tested. + output = outputs[1:] + return (loss,) + output if loss is not None else output + + return LlavaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING) +@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) +def lce_forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[int] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, +) -> Union[Tuple, LlavaCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, LlavaForConditionalGeneration + + >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf") + >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") + + >>> prompt = "USER: \nWhat's the content of the image? 
ASSISTANT:" + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed" + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + + n_image_tokens = (input_ids == self.config.image_token_index).sum().item() + n_image_features = image_features.shape[0] * image_features.shape[1] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model.model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + hidden_states = outputs[0] + + loss = None + logits = None + + if self.training and (labels is not None): + # Shift so that tokens < n predict n + if attention_mask is not None: + # we use the input attention mask to shift the logits and labels, because it is 2D. 
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(hidden_states.shape[1] - 1) :].to(hidden_states.device) + shift_hidden_states = hidden_states[..., :-1, :][ + shift_attention_mask.to(hidden_states.device) != 0 + ].contiguous() + shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() + else: + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + lce = LigerFusedLinearCrossEntropyLoss() + loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels) + + if not return_dict: + # NOTE: This part has not been tested. + output = outputs[1:] + return (loss,) + output if loss is not None else output + + return LlavaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index eafce145e..56a63be50 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -19,6 +19,8 @@ from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated +from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward +from liger_kernel.transformers.model.llava import lce_forward_deprecated as llava_lce_forward_deprecated from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated @@ -133,6 +135,85 @@ def apply_liger_kernel_to_llama( _patch_rms_norm_module(decoder_layer.post_attention_layernorm) +def apply_liger_kernel_to_llava( + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + model: PreTrainedModel = None, + **kwargs, +) -> None: + """ + Apply Liger kernels to replace the original implementation in HuggingFace LLaVA models. + Because LLaVA combines a language model with a vision tower, the loaded model instance must be passed in so that Liger-Kernel's patches can also reach the connected sub-models; extra keyword arguments (e.g. `rope`, `rms_norm`, `swiglu`) are forwarded to those sub-models' patch functions. + However, if a language model that is not supported by Liger-Kernel is connected to LLaVA, unexpected side effects may occur. + NOTE: LLaVA is not available in transformers<4.36.0. + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is True. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True. + model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been + loaded. Default is None.
+ """ + assert not ( + cross_entropy and fused_linear_cross_entropy + ), "cross_entropy and fused_linear_cross_entropy cannot both be True." + + from transformers.models.llava import modeling_llava + + if cross_entropy: + logger.warning(TRANSFORMER_DEPRECATION_WARNING) + modeling_llava.CrossEntropyLoss = LigerCrossEntropyLoss + if fused_linear_cross_entropy: + if transformer_version >= version.parse("4.49.0"): + modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward + else: # if version < 4.49.0 + logger.warning( + "Support for transformers versions < 4.49.0 will soon be discontinued due to issues with incorrect legacy processing. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/35526" + ) + modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated + + if model is not None: + text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type + text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None) + vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None) + + kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} + if text_liger_fn: + accept_params = inspect.signature(text_liger_fn).parameters + remain_params = set(kwargs) - (set(accept_params) & set(kwargs)) + text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params} + + if remain_params: + logger.warning( + f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n" + f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}" + ) + text_kwargs["model"] = model.language_model + text_liger_fn(**text_kwargs) + elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN: + logger.warning(f"{text_model_name} is not supported by Liger kernel.") + + if vision_liger_fn: + accept_params = inspect.signature(vision_liger_fn).parameters + remain_params = set(kwargs) - (set(accept_params) & set(kwargs)) + vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params} + + if remain_params: + logger.warning( + f"These parameters are not supported by {vision_model_name}. 
Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n" + f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}" + ) + vision_kwargs["model"] = model.vision_tower + vision_liger_fn(**vision_kwargs) + elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN: + logger.warning(f"{vision_model_name} is not supported by Liger kernel.") + + def apply_liger_kernel_to_mllama( rope: bool = True, cross_entropy: bool = False, @@ -740,6 +821,7 @@ def apply_liger_kernel_to_phi3( "gemma": apply_liger_kernel_to_gemma, "gemma2": apply_liger_kernel_to_gemma2, "llama": apply_liger_kernel_to_llama, + "llava": apply_liger_kernel_to_llava, "mllama": apply_liger_kernel_to_mllama, "mllama_text_model": apply_liger_kernel_to_mllama, "mistral": apply_liger_kernel_to_mistral, diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index 8566558e7..1d6040d84 100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -21,6 +21,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_gemma from liger_kernel.transformers import apply_liger_kernel_to_gemma2 from liger_kernel.transformers import apply_liger_kernel_to_llama +from liger_kernel.transformers import apply_liger_kernel_to_llava from liger_kernel.transformers import apply_liger_kernel_to_mistral from liger_kernel.transformers import apply_liger_kernel_to_mixtral from liger_kernel.transformers import apply_liger_kernel_to_mllama @@ -33,6 +34,7 @@ from test.utils import revert_liger_kernel_to_gemma from test.utils import revert_liger_kernel_to_gemma2 from test.utils import revert_liger_kernel_to_llama +from test.utils import revert_liger_kernel_to_llava from test.utils import revert_liger_kernel_to_mistral from test.utils import revert_liger_kernel_to_mixtral from test.utils import revert_liger_kernel_to_mllama @@ -61,6 +63,15 @@ except ImportError: QWEN2_VL_AVAILABLE = False +try: + from transformers import CLIPVisionConfig + from transformers.models.llava.configuration_llava import LlavaConfig + from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration + + LLAVA_AVAILABLE = True +except ImportError: + LLAVA_AVAILABLE = False + from liger_kernel.utils import infer_device device = infer_device() @@ -381,6 +392,65 @@ ), ) +if LLAVA_AVAILABLE: + # https://huggingface.co/llava-hf/llava-1.5-7b-hf + MINI_MODEL_SETUPS["mini_llava"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_llava, + liger_kernel_patch_revert_func=revert_liger_kernel_to_llava, + model_class=LlavaForConditionalGeneration, + mini_model_config=LlavaConfig( + text_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + pretraining_tp=1, + rope_scaling=None, + rope_theta=500000.0, + tie_word_embeddings=False, + use_cache=True, + max_position_embeddings=4096, # llava-1.5-7b-hf + rms_norm_eps=1e-05, # llava-1.5-7b-hf + vocab_size=32064, # llava-1.5-7b-hf + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + vision_config=CLIPVisionConfig( + hidden_size=1024, + image_size=336, + intermediate_size=4096, + model_type="clip_vision_model", + 
num_attention_heads=16, + num_hidden_layers=24, + patch_size=14, + projection_dim=768, + vocab_size=32000, + ), + vocab_size=32064, + ignore_index=-100, + pad_token_id=4, + image_token_index=3, + projector_hidden_act="gelu", + vision_feature_layer=-2, + vision_feature_select_strategy="default", + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + def create_model(model_name="mini_llama3"): """ @@ -406,6 +476,8 @@ def run_mini_model( set_seed(42) + model = create_model(model_name).to(dtype).to(device) + revert_kwargs = {"model_config": MINI_MODEL_SETUPS[model_name]} if "mllama" in model_name: revert_kwargs["model_type"] = "causal_lm" @@ -427,13 +499,13 @@ def run_mini_model( kwargs["fused_linear_cross_entropy"] = True kwargs["cross_entropy"] = False + if "llava" in model_name: + kwargs["model"] = model MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) else: MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) - model = create_model(model_name).to(dtype).to(device) - train_dataset = load_from_disk(DEFAULT_DATASET_PATH) loader = DataLoader(train_dataset, batch_size=16, shuffle=False, collate_fn=simple_collate_fn) loader_iter = iter(loader) @@ -471,6 +543,41 @@ def run_mini_model( 1e-2, marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), + pytest.param( + "mini_llava", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not LLAVA_AVAILABLE, + reason="LLaVa not available in this version of transformers", + ), + ), + pytest.param( + "mini_llava", + 32, + 1e-4, + torch.bfloat16, + 1e-3, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not LLAVA_AVAILABLE, + reason="LLaVa not available in this version of transformers", + ), + ], + ), pytest.param( "mini_mllama", 32, @@ -673,7 +780,6 @@ def test_mini_model( atol=loss_atol, rtol=loss_rtol, ) - # No logits are materialized # # Compare the logits from the last step diff --git a/test/convergence/test_mini_models_multimodal.py b/test/convergence/test_mini_models_multimodal.py index a7f8296a0..62e9c0b55 100644 --- a/test/convergence/test_mini_models_multimodal.py +++ b/test/convergence/test_mini_models_multimodal.py @@ -14,8 +14,11 @@ from test.utils import UNTOKENIZED_DATASET_PATH from test.utils import MiniModelConfig from test.utils import assert_verbose_allclose +from test.utils import load_image_processing_config +from test.utils import load_processor_config from test.utils import load_tokenizer_config from test.utils import multimodal_collate_fn +from test.utils import revert_liger_kernel_to_llava from test.utils import revert_liger_kernel_to_mllama from test.utils import revert_liger_kernel_to_qwen2_vl from test.utils import set_seed @@ -47,6 +50,21 @@ except ImportError: MLLAMA_AVAILABLE = False +try: + from transformers import CLIPImageProcessor + from transformers import CLIPVisionConfig + from transformers import LlamaConfig + from transformers.models.llava.configuration_llava import LlavaConfig + from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration + from transformers.models.llava.processing_llava import LlavaProcessor + + # fix for conflict + from 
liger_kernel.transformers import apply_liger_kernel_to_llava + + LLAVA_AVAILABLE = True +except ImportError: + LLAVA_AVAILABLE = False + from liger_kernel.utils import infer_device device = infer_device() @@ -169,6 +187,65 @@ ), ) +if LLAVA_AVAILABLE: + # https://huggingface.co/llava-hf/llava-1.5-7b-hf + MINI_MODEL_SETUPS["mini_llava"] = MiniModelConfig( + liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_llava, fused_linear_cross_entropy=False), + liger_kernel_patch_revert_func=revert_liger_kernel_to_llava, + model_class=LlavaForConditionalGeneration, + mini_model_config=LlavaConfig( + text_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + pretraining_tp=1, + rope_scaling=None, + rope_theta=500000.0, + tie_word_embeddings=False, + use_cache=True, + max_position_embeddings=4096, # llava-1.5-7b-hf + rms_norm_eps=1e-05, # llava-1.5-7b-hf + vocab_size=32064, # llava-1.5-7b-hf + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + vision_config=CLIPVisionConfig( + hidden_size=1024, + image_size=336, + intermediate_size=4096, + model_type="clip_vision_model", + num_attention_heads=16, + num_hidden_layers=24, + patch_size=14, + projection_dim=768, + vocab_size=32000, + ), + vocab_size=32064, + ignore_index=-100, + pad_token_id=4, + image_token_index=3, + projector_hidden_act="gelu", + vision_feature_layer=-2, + vision_feature_select_strategy="default", + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + def create_processor(model_name): if model_name == "mini_qwen2_vl": @@ -187,7 +264,6 @@ def create_processor(model_name): qwen_tokenizer = Qwen2TokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config) image_processor = Qwen2VLImageProcessor() return Qwen2VLProcessor(image_processor=image_processor, tokenizer=qwen_tokenizer) - elif model_name == "mini_mllama": tokenizer_config = load_tokenizer_config( os.path.join( @@ -207,6 +283,37 @@ def create_processor(model_name): fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config) image_processor = MllamaImageProcessor(size={"height": 560, "width": 560}) return MllamaProcessor(image_processor=image_processor, tokenizer=fast_tokenizer) + elif model_name == "mini_llava": + tokenizer_config = load_tokenizer_config( + os.path.join( + FAKE_CONFIGS_PATH, + "Llava/llava-1.5-7b-hf/tokenizer_config.json", + ) + ) + image_processor_config = load_image_processing_config( + os.path.join( + FAKE_CONFIGS_PATH, + "Llava/llava-1.5-7b-hf/preprocessor_config.json", + ) + ) + processor_config = load_processor_config( + os.path.join( + FAKE_CONFIGS_PATH, + "Llava/llava-1.5-7b-hf/processor_config.json", + ) + ) + tokenizer_base = train_bpe_tokenizer( + [ + token.content + for key, token in sorted( + tokenizer_config["added_tokens_decoder"].items(), + key=lambda x: int(x[0]), + ) + ] + ) + fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config) + image_processor = 
CLIPImageProcessor(**image_processor_config) + return LlavaProcessor(**processor_config, image_processor=image_processor, tokenizer=fast_tokenizer) else: raise ValueError(f"Processor not available for model {model_name}") @@ -288,6 +395,8 @@ def run_mini_model_multimodal( set_seed(42) + model = create_model(model_name).to(dtype).to(device) + revert_kwargs = {"model_config": MINI_MODEL_SETUPS[model_name]} if "mllama" in model_name: revert_kwargs["model_type"] = "conditional_generation" @@ -297,18 +406,24 @@ def run_mini_model_multimodal( "rope": True, "rms_norm": True, "cross_entropy": True, - "layer_norm": True, } + model_supports_layer_norm = "qwen2_vl" in model_name + if model_supports_layer_norm: + kwargs["layer_norm"] = True + if "gemma" in model_name: kwargs["geglu"] = True else: kwargs["swiglu"] = True + + if "llava" in model_name: + kwargs["model"] = model + MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs) else: MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) - model = create_model(model_name).to(dtype).to(device) model.gradient_checkpointing_enable() train_dataset = create_multimodal_dataset(model_name) @@ -409,6 +524,41 @@ def run_mini_model_multimodal( ), ], ), + pytest.param( + "mini_llava", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not LLAVA_AVAILABLE, + reason="LLaVa not available in this version of transformers", + ), + ), + pytest.param( + "mini_llava", + 32, + 1e-4, + torch.bfloat16, + 1e-3, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not LLAVA_AVAILABLE, + reason="LLaVa not available in this version of transformers", + ), + ], + ), ], ) def test_mini_model_multimodal( diff --git a/test/convergence/test_mini_models_with_logits.py b/test/convergence/test_mini_models_with_logits.py index 9abed2bd9..4c2b32161 100644 --- a/test/convergence/test_mini_models_with_logits.py +++ b/test/convergence/test_mini_models_with_logits.py @@ -21,6 +21,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_gemma from liger_kernel.transformers import apply_liger_kernel_to_gemma2 from liger_kernel.transformers import apply_liger_kernel_to_llama +from liger_kernel.transformers import apply_liger_kernel_to_llava from liger_kernel.transformers import apply_liger_kernel_to_mistral from liger_kernel.transformers import apply_liger_kernel_to_mixtral from liger_kernel.transformers import apply_liger_kernel_to_mllama @@ -33,6 +34,7 @@ from test.utils import revert_liger_kernel_to_gemma from test.utils import revert_liger_kernel_to_gemma2 from test.utils import revert_liger_kernel_to_llama +from test.utils import revert_liger_kernel_to_llava from test.utils import revert_liger_kernel_to_mistral from test.utils import revert_liger_kernel_to_mixtral from test.utils import revert_liger_kernel_to_mllama @@ -61,6 +63,15 @@ except ImportError: QWEN2_VL_AVAILABLE = False +try: + from transformers import CLIPVisionConfig + from transformers.models.llava.configuration_llava import LlavaConfig + from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration + + LLAVA_AVAILABLE = True +except ImportError: + LLAVA_AVAILABLE = False + from liger_kernel.utils import infer_device device = infer_device() @@ -381,6 +392,65 @@ ), ) +if LLAVA_AVAILABLE: + # https://huggingface.co/llava-hf/llava-1.5-7b-hf + MINI_MODEL_SETUPS["mini_llava"] = MiniModelConfig( + 
liger_kernel_patch_func=apply_liger_kernel_to_llava, + liger_kernel_patch_revert_func=revert_liger_kernel_to_llava, + model_class=LlavaForConditionalGeneration, + mini_model_config=LlavaConfig( + text_config=LlamaConfig( + attention_bias=False, + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, + hidden_act="silu", + hidden_size=1024, + initializer_range=0.02, + intermediate_size=2048, + num_attention_heads=8, + num_hidden_layers=4, + num_key_value_heads=2, + pretraining_tp=1, + rope_scaling=None, + rope_theta=500000.0, + tie_word_embeddings=False, + use_cache=True, + max_position_embeddings=4096, # llava-1.5-7b-hf + rms_norm_eps=1e-05, # llava-1.5-7b-hf + vocab_size=32064, # llava-1.5-7b-hf + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + vision_config=CLIPVisionConfig( + hidden_size=1024, + image_size=336, + intermediate_size=4096, + model_type="clip_vision_model", + num_attention_heads=16, + num_hidden_layers=24, + patch_size=14, + projection_dim=768, + vocab_size=32000, + ), + vocab_size=32064, + ignore_index=-100, + pad_token_id=4, + image_token_index=3, + projector_hidden_act="gelu", + vision_feature_layer=-2, + vision_feature_select_strategy="default", + # At rope backward + # Eager produces incontiguous dq and dk + # SDPA produces contiguous dq and incontiguous dk + # Flash_attn produces contiguous dq and dk + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + def create_model(model_name="mini_llama3"): """ @@ -406,6 +476,8 @@ def run_mini_model( set_seed(42) + model = create_model(model_name).to(dtype).to(device) + revert_kwargs = {"model_config": MINI_MODEL_SETUPS[model_name]} if "mllama" in model_name: revert_kwargs["model_type"] = "causal_lm" @@ -425,6 +497,9 @@ def run_mini_model( else: kwargs["swiglu"] = True + if "llava" in model_name: + kwargs["model"] = model + kwargs["fused_linear_cross_entropy"] = False kwargs["cross_entropy"] = True @@ -432,7 +507,6 @@ def run_mini_model( else: MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs) - model = create_model(model_name).to(dtype).to(device) train_dataset = load_from_disk(DEFAULT_DATASET_PATH) loader = DataLoader(train_dataset, batch_size=16, shuffle=False, collate_fn=simple_collate_fn) loader_iter = iter(loader) @@ -470,6 +544,41 @@ def run_mini_model( 1e-2, marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), ), + pytest.param( + "mini_llava", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not LLAVA_AVAILABLE, + reason="LLaVa not available in this version of transformers", + ), + ), + pytest.param( + "mini_llava", + 32, + 1e-4, + torch.bfloat16, + 1e-3, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not LLAVA_AVAILABLE, + reason="LLaVa not available in this version of transformers", + ), + ], + ), pytest.param( "mini_mllama", 32, diff --git a/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json b/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json new file mode 100644 index 000000000..c32625c74 --- /dev/null +++ b/test/resources/fake_configs/Llava/llava-1.5-7b-hf/preprocessor_config.json @@ -0,0 +1,28 @@ +{ + 
"crop_size": { + "height": 336, + "width": 336 + }, + "do_center_crop": true, + "do_convert_rgb": true, + "do_normalize": true, + "do_rescale": true, + "do_resize": true, + "image_mean": [ + 0.48145466, + 0.4578275, + 0.40821073 + ], + "image_processor_type": "CLIPImageProcessor", + "image_std": [ + 0.26862954, + 0.26130258, + 0.27577711 + ], + "processor_class": "LlavaProcessor", + "resample": 3, + "rescale_factor": 0.00392156862745098, + "size": { + "shortest_edge": 336 + } +} \ No newline at end of file diff --git a/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json b/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json new file mode 100644 index 000000000..8fbb221c7 --- /dev/null +++ b/test/resources/fake_configs/Llava/llava-1.5-7b-hf/processor_config.json @@ -0,0 +1,7 @@ +{ + "image_token": "", + "num_additional_image_tokens": 1, + "patch_size": 14, + "processor_class": "LlavaProcessor", + "vision_feature_select_strategy": "default" +} \ No newline at end of file diff --git a/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json b/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json new file mode 100644 index 000000000..f9c6572a8 --- /dev/null +++ b/test/resources/fake_configs/Llava/llava-1.5-7b-hf/tokenizer_config.json @@ -0,0 +1,66 @@ +{ + "add_bos_token": true, + "add_eos_token": false, + "add_prefix_space": null, + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "3": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "4": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "bos_token": "", + "clean_up_tokenization_spaces": false, + "eos_token": "", + "extra_special_tokens": { + "image_token": "" + }, + "image_token": "", + "legacy": false, + "chat_template": "{% if not add_generation_prompt is defined %}{% set add_last_empty_assistant = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message.role == 'user' %}{{ '### User:\n' }}{% if message.content is not string %}{% for content in message.content %}{% if content.type == 'image' %}{{ '' }}{% elif content.type == 'text' %}{{ content.text }}{% else %}{# Do nothing #}{% endif %}{% endfor %}{% else %}{{ message.content }}{% endif %}{{ '\n\n' }}{% elif message.role == 'system' %}{{ '### System:\n' }}{% if message.content is not string %}{% for content in message.content %}{% if content.type == 'image' %}{{ '' }}{% elif content.type == 'text' %}{{ content.text }}{% else %}{# Do nothing #}{% endif %}{% endfor %}{% else %}{{ message.content }}{% endif %}{{ '\n\n' }}{% elif message.role == 'assistant' %}{{ '### Assistant:\n' }}{% if message.content is not string %}{% for content in message.content %}{% if content.type == 'text' %}{{ content.text }}{% else %}{# Do nothing #}{% endif %}{% endfor %}{% else %}{{ message.content }}{% endif %}{% else %}{{ '' }}{% endif %}{% endfor %}{% if not add_generation_prompt %}{{ eos_token }}{% elif add_generation_prompt %}{{ '### Assistant:\n' }}{% else %}{# Do nothing #}{% 
endif %}", + "model_max_length": 1000000000000000019884624838656, + "pad_token": "", + "padding_side": "left", + "processor_class": "LlavaProcessor", + "sp_model_kwargs": {}, + "tokenizer_class": "LlamaTokenizer", + "trust_remote_code": false, + "unk_token": "", + "use_default_system_prompt": false, + "return_token_type_ids": false +} \ No newline at end of file diff --git a/test/utils.py b/test/utils.py index a6af16e21..c1bfcaab4 100644 --- a/test/utils.py +++ b/test/utils.py @@ -180,6 +180,20 @@ def load_tokenizer_config(config_path: str) -> dict: return tokenizer_config +def load_image_processing_config(config_path: str) -> dict: + """Load and process image processing configuration from a JSON file.""" + with open(config_path) as reader: + image_processing_config = json.load(reader) + return image_processing_config + + +def load_processor_config(config_path: str) -> dict: + """Load and process processor configuration from a JSON file.""" + with open(config_path) as reader: + processor_config = json.load(reader) + return processor_config + + def train_bpe_tokenizer(special_tokens: List[str], unk_token: str = "<|unk|>"): """ Train a tokenizer using the BPE algorithm. @@ -380,6 +394,18 @@ def revert_liger_kernel_to_phi3(model_config: MiniModelConfig): print("Liger kernel patches have been reverted.") +def revert_liger_kernel_to_llava(model_config: MiniModelConfig): + """ + Revert all Liger kernel patches applied to llava. + """ + + from transformers.models.llava import modeling_llava + + importlib.reload(modeling_llava) + model_config.model_class = modeling_llava.LlavaForConditionalGeneration + print("Liger kernel patches have been reverted.") + + class HFAlignmentLoss: def __init__( self,
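
The patched `lce_forward`/`lce_forward_deprecated` above replace the usual `lm_head` projection plus `CrossEntropyLoss` with `LigerFusedLinearCrossEntropyLoss`, which computes the loss directly from the shifted hidden states so the full `(batch, seq_len, vocab_size)` logits tensor is never materialized. A minimal standalone sketch of that call, assuming a CUDA device for the Triton kernels; the sizes only mirror the mini test config and are otherwise illustrative:

```python
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

# Illustrative sizes (roughly the mini_llava text config used in the tests above).
hidden_size, vocab_size, num_tokens = 1024, 32064, 8

lm_head_weight = torch.randn(vocab_size, hidden_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)
# Hidden states and labels already shifted and flattened, as in the forward above.
shift_hidden_states = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)
shift_labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

lce = LigerFusedLinearCrossEntropyLoss()
# Same argument order as in lce_forward: lm_head weight, hidden states, labels.
loss = lce(lm_head_weight, shift_hidden_states, shift_labels)
loss.backward()
```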
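End to end, the new `apply_liger_kernel_to_llava` entry point is used like the other patch functions, except that the already-loaded model must be passed in (the convergence tests above do this via `kwargs["model"] = model`). A minimal usage sketch; the checkpoint name and dtype are illustrative only:

```python
import torch

from transformers import AutoProcessor, LlavaForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_llava

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.bfloat16,
)
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Passing the model lets the patch also reach the connected language model and
# vision tower; kwargs they do not accept are dropped with a warning.
apply_liger_kernel_to_llava(fused_linear_cross_entropy=True, model=model)

# During training with labels, the patched forward now computes the loss with the
# fused linear cross-entropy kernel and returns logits=None instead of full logits.
```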