Add huggingface llava #524

Open · wants to merge 46 commits into base: main

Changes shown are from 20 of the 46 commits.

Commits (46):
f84c03f
Add: LLaVa lce_loss
jp1924 Jan 8, 2025
4aaf1ef
Add: LLaVa monkey_patched
jp1924 Jan 8, 2025
8a6f8cd
Add: llava config
jp1924 Jan 9, 2025
4477479
Add: load_image_processing_config
jp1924 Jan 9, 2025
f77d647
Add: llava test
jp1924 Jan 9, 2025
53efec3
Add: llava
jp1924 Jan 9, 2025
1fadb0e
Add: llava test
jp1924 Jan 15, 2025
2f14cf6
Add: llava test
jp1924 Jan 15, 2025
a3db21c
Refactor: update llava model forward method and enhance loss handling
jp1924 Jan 16, 2025
79e2102
debugging
jp1924 Jan 16, 2025
80cfc08
Refactor: simplify lce_forward method and improve loss calculation
jp1924 Jan 16, 2025
24ca457
Refactor: clean up apply_liger_kernel_to_llava function by removing c…
jp1924 Jan 16, 2025
8c181a6
rm debugging
jp1924 Jan 16, 2025
24bc56c
Refactor: remove redundant import of llava_lce_forward in apply_liger…
jp1924 Jan 16, 2025
42dbe36
Refactor: update mini_llava model configuration and add test cases fo…
jp1924 Jan 16, 2025
f6ba33f
Fix: loss value error
jp1924 Jan 16, 2025
9223959
Refactor: add new processor and tokenizer configuration files, remove…
jp1924 Jan 16, 2025
61476e8
Refactor: clean up lce_forward function and update apply_liger_kernel…
jp1924 Jan 16, 2025
22cafbf
Add: processor configuration loading function and clean up model conf…
jp1924 Jan 16, 2025
7f61b13
Merge branch 'main' into add_llava
jp1924 Jan 16, 2025
21bdc0a
Fix: typo Qwen2-VL -> LLaVa
jp1924 Jan 20, 2025
37857ef
Refactor: remove unused input list from run_mini_model function
jp1924 Jan 20, 2025
23403c5
Add model initialization for llava in multimodal tests
jp1924 Jan 20, 2025
5fb2853
Add support for deprecated llava forward function and warning for leg…
jp1924 Jan 20, 2025
8289347
Merge branch 'main' into add_llava
lancerts Jan 21, 2025
dbc1cdd
Merge branch 'main' into add_llava
jp1924 Jan 22, 2025
5adf3e2
Clean: unused module
jp1924 Jan 23, 2025
5b72328
Update: validate kwargs & model
jp1924 Jan 23, 2025
9b1f929
Merge branch 'main' into add_llava
jp1924 Jan 23, 2025
35c1788
Fix: incorrect model input
jp1924 Jan 23, 2025
20e87db
Update: enhance documentation for Llava
jp1924 Jan 23, 2025
419467e
Clean: check model
jp1924 Jan 23, 2025
a12c95b
Merge branch 'main' into add_llava
jp1924 Jan 24, 2025
5801eb9
Fix: conflict
jp1924 Jan 24, 2025
a980a88
revert change
jp1924 Jan 24, 2025
bcca064
solve conflict
jp1924 Jan 26, 2025
4fe5307
solve conflict
jp1924 Jan 26, 2025
5dae0ff
fix conflict
jp1924 Jan 26, 2025
dcef62f
Merge branch 'main' into add_llava
jp1924 Jan 26, 2025
eaef787
solve conflict
jp1924 Jan 26, 2025
90587e8
Add: load_processor_config, load_image_processing_config
jp1924 Jan 26, 2025
ec94e9c
Add: revert_liger_kernel_to_llava
jp1924 Jan 26, 2025
34db41f
apply lint
jp1924 Jan 28, 2025
f29dce4
resolve conflict
jp1924 Jan 28, 2025
38ef343
Merge branch 'main' into add_llava
jp1924 Jan 28, 2025
3f360b5
Merge branch 'main' into add_llava
jp1924 Jan 30, 2025
1 change: 1 addition & 0 deletions src/liger_kernel/transformers/__init__.py
@@ -10,6 +10,7 @@
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401
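This re-export makes the new patch function importable from the package root alongside the other `apply_liger_kernel_to_*` helpers. A trivial sanity check (hypothetical session, assuming the liger-kernel build from this branch is installed):

```python
# Hypothetical check; assumes this branch of liger-kernel is installed.
from liger_kernel.transformers import apply_liger_kernel_to_llava

print(callable(apply_liger_kernel_to_llava))  # True
```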
164 changes: 164 additions & 0 deletions src/liger_kernel/transformers/model/llava.py
@@ -0,0 +1,164 @@
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch
import torch.nn as nn

from transformers.models.llava.modeling_llava import _CONFIG_FOR_DOC
from transformers.models.llava.modeling_llava import LLAVA_INPUTS_DOCSTRING
from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
from transformers.models.llava.modeling_llava import logger
from transformers.utils import add_start_docstrings_to_model_forward
from transformers.utils import replace_return_docstrings

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss


@add_start_docstrings_to_model_forward(LLAVA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=LlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    pixel_values: torch.FloatTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    vision_feature_layer: Optional[int] = None,
    vision_feature_select_strategy: Optional[str] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        num_logits_to_keep (`int`, *optional*):
            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.


    Returns:

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, LlavaForConditionalGeneration

    >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
    >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

    >>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
    >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
    >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
    ```"""
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    vision_feature_layer = (
        vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
    )
    vision_feature_select_strategy = (
        vision_feature_select_strategy
        if vision_feature_select_strategy is not None
        else self.config.vision_feature_select_strategy
    )

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if pixel_values is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
        )

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

    if pixel_values is not None:
        image_features = self.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
        )

        n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
        n_image_features = image_features.shape[0] * image_features.shape[1]
        if n_image_tokens != n_image_features:
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
        special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

    outputs = self.language_model.model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        num_logits_to_keep=num_logits_to_keep,
    )
    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and (labels is not None):
        # Shift so that tokens < n predict n
        if attention_mask is not None:
            # we use the input attention mask to shift the logits and labels, because it is 2D.
            # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
            shift_attention_mask = attention_mask[:, -(hidden_states.shape[1] - 1) :].to(hidden_states.device)
            shift_hidden_states = hidden_states[..., :-1, :][
                shift_attention_mask.to(hidden_states.device) != 0
            ].contiguous()
            shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
        else:
            shift_hidden_states = hidden_states[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)

    if not return_dict:
        # NOTE: This part has not been tested.
        # `outputs` comes from the base language model and carries no logits/loss,
        # so return the locally computed values instead.
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return LlavaCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if pixel_values is not None else None,
    )
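The training branch above is the heart of the patch: instead of materializing logits via `lm_head` and then applying `CrossEntropyLoss`, it hands the `lm_head` weight, the shifted hidden states, and the shifted labels directly to `LigerFusedLinearCrossEntropyLoss`. A minimal, self-contained sketch of that call pattern (illustrative shapes and random tensors only, not the PR's real model; assumes a CUDA-capable device, since the Triton kernels do not run on CPU):

```python
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

# Illustrative sizes only.
batch, seq_len, hidden_size, vocab_size = 2, 8, 64, 128
device = "cuda"

lm_head_weight = torch.randn(vocab_size, hidden_size, device=device, requires_grad=True)
hidden_states = torch.randn(batch, seq_len, hidden_size, device=device, requires_grad=True)
labels = torch.randint(0, vocab_size, (batch, seq_len), device=device)

# Shift so that the hidden state at position t predicts token t + 1, then flatten.
shift_hidden_states = hidden_states[..., :-1, :].reshape(-1, hidden_size)
shift_labels = labels[..., 1:].reshape(-1)

lce = LigerFusedLinearCrossEntropyLoss()
# Cross-entropy of shift_hidden_states @ lm_head_weight.T against shift_labels,
# computed without materializing the (tokens, vocab_size) logits tensor.
loss = lce(lm_head_weight, shift_hidden_states, shift_labels)
loss.backward()
```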
52 changes: 52 additions & 0 deletions src/liger_kernel/transformers/monkey_patch.py
@@ -19,6 +19,7 @@
from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
@@ -133,6 +134,56 @@ def apply_liger_kernel_to_llama(
                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


def apply_liger_kernel_to_llava(
    cross_entropy: bool = False,
    fused_linear_cross_entropy: bool = True,
    model: PreTrainedModel = None,
    **kwargs,
) -> None:
    """
    Apply Liger kernels to replace the original implementation in HuggingFace Llava models.

    NOTE: Llava is not available in transformers<4.45.0

    Args:
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
            loaded. Default is None.
        **kwargs: Additional options (e.g. `rope`, `rms_norm`, `swiglu`, all defaulting to True in the underlying
            patch functions) forwarded to the text- and vision-model patch functions when `model` is provided.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.llava import modeling_llava

    if cross_entropy:
        logger.warning(TRANSFORMER_DEPRECATION_WARNING)
        modeling_llava.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward

    if model is not None:
        if model.config.text_config.model_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
            MODEL_TYPE_TO_APPLY_LIGER_FN[model.config.text_config.model_type](
                cross_entropy=False,
                fused_linear_cross_entropy=False,
                **kwargs,
            )

        if model.config.vision_config.model_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
            MODEL_TYPE_TO_APPLY_LIGER_FN[model.config.vision_config.model_type](
                cross_entropy=False,
                fused_linear_cross_entropy=False,
                **kwargs,
            )


def apply_liger_kernel_to_mllama(
    rope: bool = True,
    cross_entropy: bool = False,
@@ -740,6 +791,7 @@ def apply_liger_kernel_to_phi3(
    "gemma": apply_liger_kernel_to_gemma,
    "gemma2": apply_liger_kernel_to_gemma2,
    "llama": apply_liger_kernel_to_llama,
    "llava": apply_liger_kernel_to_llava,
    "mllama": apply_liger_kernel_to_mllama,
    "mllama_text_model": apply_liger_kernel_to_mllama,
    "mistral": apply_liger_kernel_to_mistral,
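Usage note (not part of the diff): a rough sketch of how the new entry point might be called, assuming this branch of liger-kernel is installed, `transformers>=4.45.0`, and the same checkpoint as the docstring example above.

```python
import torch
from transformers import LlavaForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_llava

# Monkey-patch LlavaForConditionalGeneration.forward with the fused linear
# cross-entropy variant; this can be done before the model is instantiated.
apply_liger_kernel_to_llava(fused_linear_cross_entropy=True)

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# With an already-loaded instance, the text backbone (e.g. llama) is also
# patched via MODEL_TYPE_TO_APPLY_LIGER_FN; extra kwargs such as rope,
# rms_norm, and swiglu are forwarded to that per-model patch function.
apply_liger_kernel_to_llava(model=model, rope=True, rms_norm=True, swiglu=True)
```

Training then goes through `lce_forward`, which skips logits materialization whenever the model is in training mode and labels are provided.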