diff --git a/pyproject.toml b/pyproject.toml
index 87843322a..9a53ee223 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,5 +1,5 @@
 [build-system]
-requires = ["setuptools>=42", "wheel"]
+requires = ["setuptools>=61.0", "wheel"]
 build-backend = "setuptools.build_meta"
 
 [project]
@@ -27,6 +27,11 @@ dev = [
     "seaborn",
 ]
 
+test = [
+    "causal-conv1d>=1.4.0",
+    "mamba-ssm>=2.2.2",
+]
+
 [tool.setuptools.packages.find]
 where = ["src"]
 include = ["liger_kernel", "liger_kernel.*"]
diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py
index 147948b18..2f0de8ae2 100644
--- a/src/liger_kernel/transformers/__init__.py
+++ b/src/liger_kernel/transformers/__init__.py
@@ -10,6 +10,7 @@
 from liger_kernel.transformers.monkey_patch import (  # noqa: F401
     apply_liger_kernel_to_gemma,
     apply_liger_kernel_to_gemma2,
+    apply_liger_kernel_to_jamba,
     apply_liger_kernel_to_llama,
     apply_liger_kernel_to_mistral,
     apply_liger_kernel_to_mixtral,
diff --git a/src/liger_kernel/transformers/model/jamba.py b/src/liger_kernel/transformers/model/jamba.py
new file mode 100644
index 000000000..8bef7304e
--- /dev/null
+++ b/src/liger_kernel/transformers/model/jamba.py
@@ -0,0 +1,169 @@
+from typing import Optional, Tuple, Union
+
+import torch
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import MoeCausalLMOutputWithPast
+from transformers.models.jamba.modeling_jamba import (
+    _CONFIG_FOR_DOC,
+    JAMBA_INPUTS_DOCSTRING,
+    HybridMambaAttentionDynamicCache,
+    load_balancing_loss_func,
+)
+from transformers.utils import (
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+
+from liger_kernel.transformers.fused_linear_cross_entropy import (
+    LigerFusedLinearCrossEntropyLoss,
+)
+
+
+@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    output_router_logits: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    num_logits_to_keep: Optional[int] = None,
+) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        num_logits_to_keep (`int` or `None`, *optional*):
+            Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all
+            `input_ids`. Only last token logits are needed for generation, and calculating them only for that token
+            can save memory, which becomes pretty significant for long sequences.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, JambaForCausalLM
+
+    >>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
+    >>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
+
+    >>> prompt = "Hey, are you conscious? Can you talk to me?"
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_router_logits = (
+        output_router_logits
+        if output_router_logits is not None
+        else self.config.output_router_logits
+    )
+
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        output_router_logits=output_router_logits,
+        cache_position=cache_position,
+        return_dict=return_dict,
+    )
+
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training and labels is not None:
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        lce = LigerFusedLinearCrossEntropyLoss()
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+    else:
+        if num_logits_to_keep is None:
+            logits = self.lm_head(hidden_states)
+        else:
+            logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :])
+        logits = logits.float()
+
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+    aux_loss = None
+    if output_router_logits:
+        aux_loss = load_balancing_loss_func(
+            outputs.router_logits if return_dict else outputs[-1],
+            self.num_experts,
+            self.num_experts_per_tok,
+            attention_mask,
+        )
+        if labels is not None:
+            loss += self.router_aux_loss_coef * aux_loss.to(
+                loss.device
+            )  # make sure to reside in the same device
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        if output_router_logits:
+            output = (aux_loss,) + output
+        return (loss,) + output if loss is not None else output
+
+    return MoeCausalLMOutputWithPast(
+        loss=loss,
+        aux_loss=aux_loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        router_logits=outputs.router_logits,
+    )
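For orientation (this note and the sketch below are not part of the diff): the training branch of `lce_forward` passes `lm_head.weight` and the flattened hidden states straight into `LigerFusedLinearCrossEntropyLoss`, so the `(batch * seq_len, vocab_size)` logits tensor is never materialized. A rough unfused reference of the computation it replaces, with illustrative shapes that are not taken from the PR, would look like:

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only: batch=4, seq_len=512, hidden_size=1024, vocab_size=32064.
hidden = torch.randn(4 * 512, 1024)           # flattened shift_hidden_states
lm_head_weight = torch.randn(32064, 1024)     # lm_head.weight, shape (vocab_size, hidden_size)
labels = torch.randint(0, 32064, (4 * 512,))  # flattened shift_labels

# Unfused path: materializes a (4 * 512, 32064) logits tensor before computing the loss.
logits = hidden @ lm_head_weight.T
loss = F.cross_entropy(logits, labels, ignore_index=-100)
print(loss.item())
```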
diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py
index 1cca9753c..bb8555cc6 100644
--- a/src/liger_kernel/transformers/monkey_patch.py
+++ b/src/liger_kernel/transformers/monkey_patch.py
@@ -6,6 +6,7 @@
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.layer_norm import LigerLayerNorm
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
+from liger_kernel.transformers.model.jamba import lce_forward as jamba_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
 from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
@@ -332,6 +333,43 @@ def apply_liger_kernel_to_phi3(
         modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
 
 
+def apply_liger_kernel_to_jamba(
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Jamba models
+    to make GPU go burrr.
+
+    Note: the Jamba model does not use rotary position embedding (RoPE).
+
+    Args:
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+    """
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+    from transformers.models.jamba import modeling_jamba
+
+    if rms_norm:
+        modeling_jamba.JambaRMSNorm = LigerRMSNorm
+    if cross_entropy:
+        modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
+    if swiglu:
+        modeling_jamba.JambaMLP = LigerSwiGLUMLP
+    if fused_linear_cross_entropy:
+        modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
+
+
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
 MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
@@ -342,6 +380,7 @@ def apply_liger_kernel_to_phi3(
     "qwen2": apply_liger_kernel_to_qwen2,
     "qwen2_vl": apply_liger_kernel_to_qwen2_vl,
     "phi3": apply_liger_kernel_to_phi3,
+    "jamba": apply_liger_kernel_to_jamba,
 }
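A minimal usage sketch for the new patch entry point (not part of the diff; the checkpoint name comes from the docstring example above, and `AutoModelForCausalLM` is assumed here for illustration). The patch has to run before the model is instantiated, since it swaps module-level classes and the `forward` method in `modeling_jamba`:

```python
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_jamba

# Patch first, then load: rms_norm/swiglu swap JambaRMSNorm and JambaMLP,
# fused_linear_cross_entropy replaces JambaForCausalLM.forward with lce_forward.
apply_liger_kernel_to_jamba(
    rms_norm=True, swiglu=True, fused_linear_cross_entropy=True
)

model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/Jamba-v0.1", torch_dtype=torch.bfloat16
)
```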
diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py
index a2b6f59ef..e01fd7318 100644
--- a/test/convergence/test_mini_models.py
+++ b/test/convergence/test_mini_models.py
@@ -15,6 +15,7 @@
 from torch.utils.data import DataLoader
 from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
 from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM
+from transformers.models.jamba import JambaConfig, JambaForCausalLM
 from transformers.models.llama import LlamaConfig, LlamaForCausalLM
 from transformers.models.mistral import MistralConfig, MistralForCausalLM
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
@@ -24,6 +25,7 @@
 from liger_kernel.transformers import (
     apply_liger_kernel_to_gemma,
     apply_liger_kernel_to_gemma2,
+    apply_liger_kernel_to_jamba,
     apply_liger_kernel_to_llama,
     apply_liger_kernel_to_mistral,
     apply_liger_kernel_to_mixtral,
@@ -279,6 +281,31 @@
             attn_implementation="eager",
         ),
     ),
+    "mini_jamba": MiniModelConfig(
+        liger_kernel_patch_func=functools.partial(
+            apply_liger_kernel_to_jamba, fused_linear_cross_entropy=False
+        ),
+        model_class=JambaForCausalLM,
+        mini_model_config=JambaConfig(
+            attention_dropout=0.0,
+            num_experts_per_tok=1,
+            num_experts=2,
+            bos_token_id=1,
+            eos_token_id=2,  # 32000
+            hidden_act="silu",
+            hidden_size=1024,  # 3072
+            initializer_range=0.02,
+            intermediate_size=2048,  # 8192
+            max_position_embeddings=32768,
+            num_attention_heads=8,  # 32
+            num_hidden_layers=4,  # 32
+            rms_norm_eps=1e-5,
+            sliding_window=None,
+            tie_word_embeddings=False,
+            use_cache=True,
+            vocab_size=32064,
+        ),
+    ),
 }
@@ -316,6 +343,8 @@ def run_mini_model(
             kwargs["geglu"] = True
         else:
             kwargs["swiglu"] = True
+        if model_name == "mini_jamba":
+            del kwargs["rope"]
 
     MINI_MODEL_SETUPS[model_name].liger_kernel_patch_func(**kwargs)
     model = create_model(model_name).to(dtype).to("cuda")
@@ -328,7 +357,6 @@ def run_mini_model(
     optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
 
     loss_list = []
-
     for i in range(num_steps):
         batch = next(loader_iter).to(model.device)
         optimizer.zero_grad()
@@ -461,6 +489,8 @@ def run_mini_model(
                 not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
         ),
+        # To run this test, you need to first run `pip install '.[test]'`
+        # ("mini_jamba", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
     ],
 )
 def test_mini_model(
diff --git a/test/convergence/test_mini_models_no_logits.py b/test/convergence/test_mini_models_no_logits.py
index 540468849..38f295dc1 100644
--- a/test/convergence/test_mini_models_no_logits.py
+++ b/test/convergence/test_mini_models_no_logits.py
@@ -11,6 +11,7 @@
 import torch
 from datasets import load_from_disk
 from torch.utils.data import DataLoader
+from transformers import JambaConfig, JambaForCausalLM
 from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
 from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM
 from transformers.models.llama import LlamaConfig, LlamaForCausalLM
@@ -22,6 +23,7 @@
 from liger_kernel.transformers import (
     apply_liger_kernel_to_gemma,
     apply_liger_kernel_to_gemma2,
+    apply_liger_kernel_to_jamba,
     apply_liger_kernel_to_llama,
     apply_liger_kernel_to_mistral,
     apply_liger_kernel_to_mixtral,
@@ -259,6 +261,29 @@
             attention_dropout=0.0,
         ),
     ),
+    "mini_jamba": MiniModelConfig(
+        liger_kernel_patch_func=apply_liger_kernel_to_jamba,
+        model_class=JambaForCausalLM,
+        mini_model_config=JambaConfig(
+            attention_dropout=0.0,
+            num_experts_per_tok=1,
+            num_experts=2,
+            bos_token_id=1,
+            eos_token_id=2,  # 32000
+            hidden_act="silu",
+            hidden_size=1024,  # 3072
+            initializer_range=0.02,
+            intermediate_size=2048,  # 8192
+            max_position_embeddings=32768,
+            num_attention_heads=8,  # 32
+            num_hidden_layers=4,  # 32
+            rms_norm_eps=1e-5,
+            sliding_window=None,
+            tie_word_embeddings=False,
+            use_cache=True,
+            vocab_size=32064,
+        ),
+    ),
 }
 
 if QWEN2_VL_AVAILABLE:
@@ -346,7 +371,8 @@ def run_mini_model(
             kwargs["geglu"] = True
         else:
             kwargs["swiglu"] = True
-
+        if model_name == "mini_jamba":
+            del kwargs["rope"]
     model_support_flce = "gemma2" not in model_name
     if model_support_flce:
         kwargs["fused_linear_cross_entropy"] = True
@@ -547,6 +573,8 @@ def run_mini_model(
                 not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
             ),
         ),
+        # To run this test, you need to first run `pip install '.[test]'`
+        # ("mini_jamba", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5),
     ],
 )
 def test_mini_model(