
Add huggingface llava #524

Open: wants to merge 46 commits into base: main

Commits (46), showing changes from all commits
f84c03f
Add: LLaVa lce_loss
jp1924 Jan 8, 2025
4aaf1ef
Add: LLaVa monkey_patched
jp1924 Jan 8, 2025
8a6f8cd
Add: llava config
jp1924 Jan 9, 2025
4477479
Add: load_image_processing_config
jp1924 Jan 9, 2025
f77d647
Add: llava test
jp1924 Jan 9, 2025
53efec3
Add: llava
jp1924 Jan 9, 2025
1fadb0e
Add: llava test
jp1924 Jan 15, 2025
2f14cf6
Add: llava test
jp1924 Jan 15, 2025
a3db21c
Refactor: update llava model forward method and enhance loss handling
jp1924 Jan 16, 2025
79e2102
debugging
jp1924 Jan 16, 2025
80cfc08
Refactor: simplify lce_forward method and improve loss calculation
jp1924 Jan 16, 2025
24ca457
Refactor: clean up apply_liger_kernel_to_llava function by removing c…
jp1924 Jan 16, 2025
8c181a6
rm debugging
jp1924 Jan 16, 2025
24bc56c
Refactor: remove redundant import of llava_lce_forward in apply_liger…
jp1924 Jan 16, 2025
42dbe36
Refactor: update mini_llava model configuration and add test cases fo…
jp1924 Jan 16, 2025
f6ba33f
Fix: loss value error
jp1924 Jan 16, 2025
9223959
Refactor: add new processor and tokenizer configuration files, remove…
jp1924 Jan 16, 2025
61476e8
Refactor: clean up lce_forward function and update apply_liger_kernel…
jp1924 Jan 16, 2025
22cafbf
Add: processor configuration loading function and clean up model conf…
jp1924 Jan 16, 2025
7f61b13
Merge branch 'main' into add_llava
jp1924 Jan 16, 2025
21bdc0a
Fix: typo Qwen2-VL -> LLaVa
jp1924 Jan 20, 2025
37857ef
Refactor: remove unused input list from run_mini_model function
jp1924 Jan 20, 2025
23403c5
Add model initialization for llava in multimodal tests
jp1924 Jan 20, 2025
5fb2853
Add support for deprecated llava forward function and warning for leg…
jp1924 Jan 20, 2025
8289347
Merge branch 'main' into add_llava
lancerts Jan 21, 2025
dbc1cdd
Merge branch 'main' into add_llava
jp1924 Jan 22, 2025
5adf3e2
Clean: unused module
jp1924 Jan 23, 2025
5b72328
Update: validate kwargs & model
jp1924 Jan 23, 2025
9b1f929
Merge branch 'main' into add_llava
jp1924 Jan 23, 2025
35c1788
Fix: incorrect model input
jp1924 Jan 23, 2025
20e87db
Update: enhance documentation for Llava
jp1924 Jan 23, 2025
419467e
Clean: check model
jp1924 Jan 23, 2025
a12c95b
Merge branch 'main' into add_llava
jp1924 Jan 24, 2025
5801eb9
Fix: conflict
jp1924 Jan 24, 2025
a980a88
revert change
jp1924 Jan 24, 2025
bcca064
solve conflict
jp1924 Jan 26, 2025
4fe5307
solve conflict
jp1924 Jan 26, 2025
5dae0ff
fix conflict
jp1924 Jan 26, 2025
dcef62f
Merge branch 'main' into add_llava
jp1924 Jan 26, 2025
eaef787
solve conflict
jp1924 Jan 26, 2025
90587e8
Add: load_processor_config, load_image_processing_config
jp1924 Jan 26, 2025
ec94e9c
Add: revert_liger_kernel_to_llava
jp1924 Jan 26, 2025
34db41f
apply lint
jp1924 Jan 28, 2025
f29dce4
resolve conflict
jp1924 Jan 28, 2025
38ef343
Merge branch 'main' into add_llava
jp1924 Jan 28, 2025
3f360b5
Merge branch 'main' into add_llava
jp1924 Jan 30, 2025
1 change: 1 addition & 0 deletions src/liger_kernel/transformers/__init__.py
@@ -10,6 +10,7 @@
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma2 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama # noqa: F401
370 changes: 370 additions & 0 deletions src/liger_kernel/transformers/model/llava.py

Large diffs are not rendered by default.
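Since the 370-line `llava.py` diff is not rendered here, the following is a minimal, hypothetical sketch of the fused-linear-cross-entropy pattern that liger-kernel's `lce_forward` implementations for other models follow. The helper name `lce_loss_sketch` and the exact shift/flatten details are illustrative assumptions, not the PR's actual code.

```python
# Hypothetical sketch (not the PR's actual llava.py): the general pattern of
# liger-kernel "lce_forward" implementations, which fuse the lm_head projection
# with the cross entropy loss so full-vocabulary logits are never materialized.
import torch

from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss


def lce_loss_sketch(
    hidden_states: torch.Tensor,   # (batch, seq_len, hidden_size) from the decoder
    lm_head_weight: torch.Tensor,  # (vocab_size, hidden_size)
    labels: torch.Tensor,          # (batch, seq_len)
) -> torch.Tensor:
    # Standard causal-LM shift: tokens < n predict token n.
    shift_hidden = hidden_states[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten batch and sequence dimensions for the fused kernel.
    shift_hidden = shift_hidden.view(-1, shift_hidden.size(-1))
    shift_labels = shift_labels.view(-1).to(shift_hidden.device)

    loss_fct = LigerFusedLinearCrossEntropyLoss()
    return loss_fct(lm_head_weight, shift_hidden, shift_labels)
```

The point of the fusion is that the `(batch * seq_len, vocab_size)` logits tensor is never allocated, which is where the memory savings described by the `fused_linear_cross_entropy` option come from.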

82 changes: 82 additions & 0 deletions src/liger_kernel/transformers/monkey_patch.py
@@ -19,6 +19,8 @@
from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
from liger_kernel.transformers.model.llava import lce_forward_deprecated as llava_lce_forward_deprecated
from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
@@ -133,6 +135,85 @@ def apply_liger_kernel_to_llama(
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)


def apply_liger_kernel_to_llava(
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
model: PreTrainedModel = None,
**kwargs,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Llava models.
tyler-romero marked this conversation as resolved.
Show resolved Hide resolved
Due to the characteristics of LlaVa, the model must be passed to apply Liger-Kernel's patch to other models connected to LLaVa.
However, if an LM not supported by Liger-Kernel is connected to LLaVa, unexpected side effects may occur.
NOTE: Llava is not available in transformers<4.36.0

Args:
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
**kwargs: Additional options (e.g. `rope`, `rms_norm`, `swiglu`) forwarded to the Liger patch
functions of the wrapped language and vision models, where supported.
"""
assert not (
cross_entropy and fused_linear_cross_entropy
), "cross_entropy and fused_linear_cross_entropy cannot both be True."

from transformers.models.llava import modeling_llava

if cross_entropy:
logger.warning(TRANSFORMER_DEPRECATION_WARNING)
modeling_llava.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
if transformer_version >= version.parse("4.49.0"):
modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
else: # if version < 4.49.0
logger.warning(
"Support for transformers versions < 4.49.0 will soon be discontinued due to issues with incorrect legacy processing. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/35526"
)
modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated

if model is not None:
text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)

kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs}
if text_liger_fn:
accept_params = inspect.signature(text_liger_fn).parameters
remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
text_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}

if remain_params:
logger.warning(
f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
)
text_kwargs["model"] = model.language_model
text_liger_fn(**text_kwargs)
elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
logger.warning(f"{text_model_name} is not supported by Liger kernel.")

if vision_liger_fn:
accept_params = inspect.signature(vision_liger_fn).parameters
remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}

if remain_params:
logger.warning(
f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
)
vision_kwargs["model"] = model.vision_tower
vision_liger_fn(**vision_kwargs)
elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
logger.warning(f"{vision_model_name} is not supported by Liger kernel.")


def apply_liger_kernel_to_mllama(
rope: bool = True,
cross_entropy: bool = False,
@@ -740,6 +821,7 @@ def apply_liger_kernel_to_phi3(
"gemma": apply_liger_kernel_to_gemma,
"gemma2": apply_liger_kernel_to_gemma2,
"llama": apply_liger_kernel_to_llama,
"llava": apply_liger_kernel_to_llava,
"mllama": apply_liger_kernel_to_mllama,
"mllama_text_model": apply_liger_kernel_to_mllama,
"mistral": apply_liger_kernel_to_mistral,
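For context, here is a minimal usage sketch of the new entry point, assuming the `llava-hf/llava-1.5-7b-hf` checkpoint referenced in the tests below. As the docstring notes, the loaded model instance must be passed so the patch can also reach the wrapped language model and vision tower.

```python
import torch
from transformers import LlavaForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_llava

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.bfloat16,
)

# Patches LlavaForConditionalGeneration.forward with the fused linear cross
# entropy variant, then forwards the extra kwargs (rope, rms_norm, swiglu)
# to the Liger patch functions of the wrapped models, where supported.
apply_liger_kernel_to_llava(
    fused_linear_cross_entropy=True,
    model=model,
    rope=True,
    rms_norm=True,
    swiglu=True,
)
```

With a CLIP vision tower, the vision model type is not in `MODEL_TYPE_TO_APPLY_LIGER_FN`, so this call would presumably only log that the vision model is unsupported while still patching the Llama language model.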
112 changes: 109 additions & 3 deletions test/convergence/test_mini_models.py
@@ -21,6 +21,7 @@
from liger_kernel.transformers import apply_liger_kernel_to_gemma
from liger_kernel.transformers import apply_liger_kernel_to_gemma2
from liger_kernel.transformers import apply_liger_kernel_to_llama
from liger_kernel.transformers import apply_liger_kernel_to_llava
from liger_kernel.transformers import apply_liger_kernel_to_mistral
from liger_kernel.transformers import apply_liger_kernel_to_mixtral
from liger_kernel.transformers import apply_liger_kernel_to_mllama
@@ -33,6 +34,7 @@
from test.utils import revert_liger_kernel_to_gemma
from test.utils import revert_liger_kernel_to_gemma2
from test.utils import revert_liger_kernel_to_llama
from test.utils import revert_liger_kernel_to_llava
from test.utils import revert_liger_kernel_to_mistral
from test.utils import revert_liger_kernel_to_mixtral
from test.utils import revert_liger_kernel_to_mllama
@@ -61,6 +63,15 @@
except ImportError:
QWEN2_VL_AVAILABLE = False

try:
from transformers import CLIPVisionConfig
from transformers.models.llava.configuration_llava import LlavaConfig
from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration

LLAVA_AVAILABLE = True
except ImportError:
LLAVA_AVAILABLE = False

from liger_kernel.utils import infer_device

device = infer_device()
@@ -381,6 +392,65 @@
),
)

if LLAVA_AVAILABLE:
# https://huggingface.co/llava-hf/llava-1.5-7b-hf
MINI_MODEL_SETUPS["mini_llava"] = MiniModelConfig(
liger_kernel_patch_func=apply_liger_kernel_to_llava,
liger_kernel_patch_revert_func=revert_liger_kernel_to_llava,
model_class=LlavaForConditionalGeneration,
mini_model_config=LlavaConfig(
text_config=LlamaConfig(
attention_bias=False,
attention_dropout=0.0,
bos_token_id=1,
eos_token_id=2,
hidden_act="silu",
hidden_size=1024,
initializer_range=0.02,
intermediate_size=2048,
num_attention_heads=8,
num_hidden_layers=4,
num_key_value_heads=2,
pretraining_tp=1,
rope_scaling=None,
rope_theta=500000.0,
tie_word_embeddings=False,
use_cache=True,
max_position_embeddings=4096, # llava-1.5-7b-hf
rms_norm_eps=1e-05, # llava-1.5-7b-hf
vocab_size=32064, # llava-1.5-7b-hf
# At rope backward
# Eager produces incontiguous dq and dk
# SDPA produces contiguous dq and incontiguous dk
# Flash_attn produces contiguous dq and dk
attn_implementation="sdpa", # default value, pytorch native attention
),
vision_config=CLIPVisionConfig(
hidden_size=1024,
image_size=336,
intermediate_size=4096,
model_type="clip_vision_model",
num_attention_heads=16,
num_hidden_layers=24,
patch_size=14,
projection_dim=768,
vocab_size=32000,
),
vocab_size=32064,
ignore_index=-100,
pad_token_id=4,
image_token_index=3,
projector_hidden_act="gelu",
vision_feature_layer=-2,
vision_feature_select_strategy="default",
# At rope backward
# Eager produces incontiguous dq and dk
# SDPA produces contiguous dq and incontiguous dk
# Flash_attn produces contiguous dq and dk
attn_implementation="sdpa", # default value, pytorch native attention
),
)


def create_model(model_name="mini_llama3"):
"""
Expand All @@ -406,6 +476,8 @@ def run_mini_model(

set_seed(42)

model = create_model(model_name).to(dtype).to(device)

revert_kwargs = {"model_config": MINI_MODEL_SETUPS[model_name]}
if "mllama" in model_name:
revert_kwargs["model_type"] = "causal_lm"
@@ -427,13 +499,13 @@

kwargs["fused_linear_cross_entropy"] = True
kwargs["cross_entropy"] = False
if "llava" in model_name:
kwargs["model"] = model

MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs)
else:
MINI_MODEL_SETUPS[model_name].liger_kernel_patch_revert_func(**revert_kwargs)

model = create_model(model_name).to(dtype).to(device)

train_dataset = load_from_disk(DEFAULT_DATASET_PATH)
loader = DataLoader(train_dataset, batch_size=16, shuffle=False, collate_fn=simple_collate_fn)
loader_iter = iter(loader)
@@ -471,6 +543,41 @@ def run_mini_model(
1e-2,
marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
),
pytest.param(
"mini_llava",
32,
1e-4,
torch.float32,
1e-8,
1e-5,
5e-3,
1e-5,
5e-3,
1e-5,
marks=pytest.mark.skipif(
not LLAVA_AVAILABLE,
reason="LLaVa not available in this version of transformers",
),
),
pytest.param(
"mini_llava",
32,
1e-4,
torch.bfloat16,
1e-3,
1e-2,
1e-1,
1e-2,
1e-2,
1e-2,
marks=[
pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
pytest.mark.skipif(
not LLAVA_AVAILABLE,
reason="LLaVa not available in this version of transformers",
),
],
),
pytest.param(
"mini_mllama",
32,
@@ -673,7 +780,6 @@ def test_mini_model(
atol=loss_atol,
rtol=loss_rtol,
)

# No logits are materialized

# # Compare the logits from the last step
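Assuming the repository's standard pytest setup, the new convergence cases can presumably be selected in isolation with `pytest test/convergence/test_mini_models.py -k mini_llava`; that invocation is an assumption, not part of the diff.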