Commit d052bab

Merge branch 'main' into austin362667/chunked_compiled_jsd_loss

lancerts authored Dec 11, 2024
2 parents b23a4f8 + eee40c5
Showing 20 changed files with 394 additions and 294 deletions.
73 changes: 40 additions & 33 deletions README.md
@@ -55,11 +55,12 @@

<img src="https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/logo-banner.png">

[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Cite our work](#cite-this-work)
[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [High-level APIs](#high-level-apis) | [Low-level APIs](#low-level-apis) | [Cite our work](#cite-this-work)

<details>
<summary>Latest News 🔥</summary>

- [2024/12/15] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
- [2024/11/6] We release [v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
- [2024/10/21] We have released the tech report of Liger Kernel on arXiv: https://arxiv.org/pdf/2410.10989
- [2024/9/6] We release v0.2.1 ([X post](https://x.com/liger_kernel/status/1832168197002510649)). 2500+ Stars, 10+ New Contributors, 50+ PRs, 50k Downloads in two weeks!
@@ -71,6 +72,8 @@

**Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduce **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernels work out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.

We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, JSD, and many more.
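With a single import swap the optimized kernels are applied automatically. A minimal sketch, assuming a CUDA machine with `transformers` installed (the model name is illustrative):

```python
# Minimal sketch: AutoLigerKernelForCausalLM patches a supported architecture
# before loading, so no other code changes are needed. Model name is illustrative.
from liger_kernel.transformers import AutoLigerKernelForCausalLM

model = AutoLigerKernelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
```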

## Supercharge Your Model with Liger Kernel

![Banner](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/banner.GIF)
@@ -94,6 +97,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
| [**Lightning Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/lightning) | Increase 15% throughput and reduce memory usage by 40% with LLaMA3-8B on MMLU dataset using 8 A100s with DeepSpeed ZeRO3 |
| [**Medusa Multi-head LLM (Retraining Phase)**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP |
| [**Vision-Language Model SFT**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface/run_qwen2_vl.sh) | Finetune Qwen2-VL on image-text data using 4 A100s with FSDP |
| [**Liger ORPO Trainer**](https://github.com/linkedin/Liger-Kernel/blob/main/examples/alignment/run_orpo.py) | Align Llama 3.2 using the Liger ORPO Trainer with FSDP, achieving a 50% memory reduction |

## Key Features

@@ -211,7 +215,7 @@ loss = loss_fn(model.weight, input, target)
loss.backward()
```

## APIs
## High-level APIs

### AutoModel

@@ -235,8 +239,12 @@ loss.backward()
| Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
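Each patching API can also be called with per-kernel toggles. A hedged sketch for Phi3 (the keyword names follow the common pattern across these functions and may vary by version):

```python
# Sketch: patch the Phi3 modules inside transformers, then load the model as usual.
# The per-kernel toggles are assumptions based on the kernels listed above.
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_phi3

apply_liger_kernel_to_phi3(rope=True, rms_norm=True, swiglu=True)
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct", torch_dtype=torch.bfloat16
)
```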


## Low-level APIs

- `Fused Linear` kernels combine linear layers with losses, reducing memory usage by up to 80%, making them ideal for HBM-constrained workloads (a usage sketch follows the Model Kernels table below).
- Other kernels use fusion and in-place techniques for memory and performance optimization.

### Kernels
### Model Kernels

| **Kernel** | **API** |
|---------------------------------|-------------------------------------------------------------|
@@ -246,46 +254,51 @@ loss.backward()
| SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
| GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
| CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
| FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
| Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
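As referenced above, here is a hedged sketch of the fused-linear pattern: the loss takes the LM-head weight and raw hidden states, so the full `(tokens, vocab)` logits tensor is never materialized (shapes are illustrative; assumes a CUDA device):

```python
# Sketch of the fused linear + loss pattern; mirrors the snippet
# `loss = loss_fn(model.weight, input, target)` earlier in this README.
import torch

from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

vocab, hidden = 128256, 4096
loss_fn = LigerFusedLinearCrossEntropyLoss()
lm_head_weight = torch.randn(vocab, hidden, device="cuda", requires_grad=True)
hidden_states = torch.randn(32, hidden, device="cuda", requires_grad=True)  # flattened tokens
labels = torch.randint(0, vocab, (32,), device="cuda")

loss = loss_fn(lm_head_weight, hidden_states, labels)
loss.backward()
```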


### Alignment Kernels

| **Kernel** | **API** |
|---------------------------------|-------------------------------------------------------------|
| Fused Linear CPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearCPOLoss` |
| Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
| Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
| Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
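These chunked losses follow the same fused-linear call convention. A hedged sketch for ORPO (the argument order and return value are assumptions; consult the class docstring, and note that chosen and rejected responses are stacked along the batch dimension):

```python
# Sketch only: argument order assumed from the fused-linear pattern above.
import torch

from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss

B, T, H, V = 2, 128, 4096, 128256
orpo = LigerFusedLinearORPOLoss(beta=0.1)
hidden_states = torch.randn(2 * B, T, H, device="cuda", requires_grad=True)  # chosen then rejected
labels = torch.randint(0, V, (2 * B, T), device="cuda")
lm_head_weight = torch.randn(V, H, device="cuda", requires_grad=True)

out = orpo(lm_head_weight, hidden_states, labels)
loss = out[0] if isinstance(out, tuple) else out  # some versions also return aux stats
loss.backward()
```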

### Distillation Kernels

| **Kernel** | **API** |
|---------------------------------|-------------------------------------------------------------|
| KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
| JSD | `liger_kernel.transformers.LigerJSD` |
| FusedLinearJSD | `liger_kernel.transformers.LigerFusedLinearJSD` |

- **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction (see the usage sketch after this list).
- **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
- **GroupNorm**: [GroupNorm](https://arxiv.org/pdf/1803.08494), which normalizes activations across groups of channels for a given sample (channels are split into K groups over which normalization is performed), is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and can achieve up to ~2X speedup as the number of channels/groups increases.
- **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key rotary embeddings into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
- **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
$$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
, is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction.
- **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
$$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
, is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
- **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
<!-- TODO: verify vocab sizes are accurate -->
- **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
- **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward pass into a single Triton kernel, with reduction done outside the kernel. It achieves ~1.5X speedup and ~15% memory reduction for 128K vocab size.
- **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence) is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speedup and ~54% memory reduction for 128k vocab size. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192. **NOTE**: It implements forward/reverse KL when `beta` equals 0 and 1 respectively.
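The module-style kernels above are drop-in `nn.Module` replacements. A minimal hedged sketch for RMSNorm (the constructor argument is an assumption; check the class signature):

```python
# Sketch: LigerRMSNorm as a drop-in replacement for a model's RMSNorm module.
import torch

from liger_kernel.transformers import LigerRMSNorm

norm = LigerRMSNorm(hidden_size=4096).to("cuda")
x = torch.randn(2, 128, 4096, device="cuda", requires_grad=True)
y = norm(x)  # same shape as x, normalized by root mean square and rescaled
```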

| Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
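A hedged sketch for the fused-linear JSD; the argument layout (student hidden states and head weight, then teacher counterparts) is an assumption, so verify it against the `LigerFusedLinearJSD` docstring:

```python
# Sketch only: argument layout assumed; verify against the class docstring.
import torch

from liger_kernel.transformers import LigerFusedLinearJSD

BT, H, V = 8192, 4096, 128256  # (batch x sequence) tokens, hidden size, vocab size
jsd = LigerFusedLinearJSD(jsd_beta=0.5, temperature=1.0)
student_hidden = torch.randn(BT, H, device="cuda", requires_grad=True)
student_head = torch.randn(V, H, device="cuda", requires_grad=True)
teacher_hidden = torch.randn(BT, H, device="cuda")
teacher_head = torch.randn(V, H, device="cuda")

loss = jsd(student_hidden, student_head, teacher_hidden, teacher_head)
loss.backward()
```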

### Experimental Kernels

| **Kernel** | **API** |
|---------------------------------|-------------------------------------------------------------|
| Embedding | `liger_kernel.transformers.experimental.LigerEmbedding` |
| Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul`
| Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul` |

- **Embedding**: [Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is implemented by fusing embedding lookup and output operations. It achieves a peak speedup of ~1.5x in the forward pass and an overall speedup of ~1.1x.
- **Matmul int2xint8**: implemented using cache-tiled matrix multiplication and by fusing the matmul with the unpacking process, achieving a considerable speedup and performing on par with `torch.compile`.
<!-- TODO: be more specific about batch size -->
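A minimal hedged sketch for the experimental embedding (constructor assumed to mirror `torch.nn.Embedding`):

```python
# Sketch: LigerEmbedding as a drop-in replacement for torch.nn.Embedding.
import torch

from liger_kernel.transformers.experimental import LigerEmbedding

emb = LigerEmbedding(32000, 4096).to("cuda")  # (num_embeddings, embedding_dim)
token_ids = torch.randint(0, 32000, (2, 128), device="cuda")
hidden = emb(token_ids)  # shape: (2, 128, 4096)
```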

## Contributing, Acknowledgements, and License

- [Contributing Guidelines](https://github.com/linkedin/Liger-Kernel/blob/main/docs/CONTRIBUTING.md)
- [Acknowledgements](https://github.com/linkedin/Liger-Kernel/blob/main/docs/Acknowledgement.md)
- [License Information](https://github.com/linkedin/Liger-Kernel/blob/main/docs/License.md)

## Sponsorship and Collaboration

- [AMD](https://www.amd.com/en.html): Providing AMD GPUs for our AMD CI.
- [Intel](https://www.intel.com/): Providing Intel GPUs for our Intel CI.
- [Modal](https://modal.com/): Free 3000 credits from GPU MODE IRL for our NVIDIA CI.
- [EmbeddedLLM](https://embeddedllm.com/): Making Liger Kernel run fast and stable on AMD.
- [HuggingFace](https://huggingface.co/): Integrating Liger Kernel into Hugging Face Transformers and TRL.
- [Lightning AI](https://lightning.ai/): Integrating Liger Kernel into Lightning Thunder.
- [Axolotl](https://axolotl.ai/): Integrating Liger Kernel into Axolotl.
- [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory): Integrating Liger Kernel into Llama-Factory.

## Contact

- For issues, create a GitHub ticket in this repository
@@ -311,12 +324,6 @@ Biblatex entry:
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=linkedin/Liger-Kernel&type=Date)](https://star-history.com/#linkedin/Liger-Kernel&Date)

## Contributors

<a href="https://github.com/linkedin/Liger-Kernel/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=linkedin/Liger-Kernel"/>
</a>

<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
↑ Back to Top ↑
9 changes: 5 additions & 4 deletions dev/modal/tests.py
@@ -4,6 +4,7 @@
import modal

ROOT_PATH = Path(__file__).parent.parent.parent
REMOTE_ROOT_PATH = "/root/liger-kernel"

# REBUILD_IMAGE is an environment variable that is set to "true" in the nightly build
REBUILD_IMAGE = os.getenv("REBUILD_IMAGE") is not None
@@ -17,13 +18,13 @@
app = modal.App("liger_tests", image=image)

# mount: add local files to the remote container
repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path="/root/liger-kernel")
repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)


@app.function(gpu="A10G", mounts=[repo], timeout=60 * 15)
def liger_tests():
import subprocess

subprocess.run(["pip", "install", "-e", "."], check=True, cwd="/root/liger-kernel")
subprocess.run(["make", "test"], check=True, cwd="/root/liger-kernel")
subprocess.run(["make", "test-convergence"], check=True, cwd="/root/liger-kernel")
subprocess.run(["pip", "install", "-e", "."], check=True, cwd=REMOTE_ROOT_PATH)
subprocess.run(["make", "test"], check=True, cwd=REMOTE_ROOT_PATH)
subprocess.run(["make", "test-convergence"], check=True, cwd=REMOTE_ROOT_PATH)
9 changes: 5 additions & 4 deletions dev/modal/tests_bwd.py
@@ -4,6 +4,7 @@
import modal

ROOT_PATH = Path(__file__).parent.parent.parent
REMOTE_ROOT_PATH = "/root/liger-kernel"

# REBUILD_IMAGE is an environment variable that is set to "true" in the nightly build
REBUILD_IMAGE = os.getenv("REBUILD_IMAGE") is not None
@@ -22,13 +23,13 @@
app = modal.App("liger_tests", image=image)

# mount: add local files to the remote container
repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path="/root/liger-kernel")
repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path=REMOTE_ROOT_PATH)


@app.function(gpu="A10G", mounts=[repo], timeout=60 * 10)
def liger_tests_bwd():
import subprocess

subprocess.run(["pip", "install", "-e", "."], check=True, cwd="/root/liger-kernel")
subprocess.run(["make", "test"], check=True, cwd="/root/liger-kernel")
subprocess.run(["make", "test-convergence"], check=True, cwd="/root/liger-kernel")
subprocess.run(["pip", "install", "-e", "."], check=True, cwd=REMOTE_ROOT_PATH)
subprocess.run(["make", "test"], check=True, cwd=REMOTE_ROOT_PATH)
subprocess.run(["make", "test-convergence"], check=True, cwd=REMOTE_ROOT_PATH)
11 changes: 2 additions & 9 deletions examples/alignment/run_orpo.py
@@ -1,9 +1,9 @@
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import ORPOConfig, ORPOTrainer # noqa: F401
from trl import ORPOConfig # noqa: F401

from liger_kernel.transformers import LigerORPOTrainer # noqa: F401
from liger_kernel.transformers.trainer import LigerORPOTrainer # noqa: F401

model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-1B-Instruct",
@@ -19,13 +19,6 @@

train_dataset = load_dataset("trl-lib/tldr-preference", split="train")

# train_dataset = train_dataset.map(
# lambda example: {
# "prompt": example["prompt"],
# "chosen": example["chosen"][0]["content"],
# "rejected": example["rejected"][0]["content"],
# }
# )
training_args = ORPOConfig(
output_dir="Llama3.2_1B_Instruct",
beta=0.1,
7 changes: 5 additions & 2 deletions pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "liger_kernel"
version = "0.4.2"
version = "0.5.2"
description = "Efficient Triton kernels for LLM Training"
urls = { "Homepage" = "https://github.com/linkedin/Liger-Kernel" }
readme = { file = "README.md", content-type = "text/markdown" }
@@ -20,9 +20,12 @@ transformers = [
"transformers~=4.0"
]

trl = [
"trl>=0.11.0",
]

dev = [
"transformers>=4.44.2",
"trl>=0.11.0",
"matplotlib>=3.7.2",
"flake8>=4.0.1.1",
"black>=24.4.2",
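With TRL split into its own optional dependency group, it can presumably be pulled in via the standard extras syntax, e.g. `pip install liger-kernel[trl]`, without bloating the default install.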
25 changes: 25 additions & 0 deletions src/liger_kernel/chunked_loss/README.md
@@ -0,0 +1,25 @@
# Liger FlexChunkLoss: Alignment and Distillation Loss

Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment losses (DPO, ORPO, CPO) and, very soon, distillation losses. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases.

### User interface

FlexChunkLoss offers two flexible usage options:

1. **Via `Liger[Custom Loss]Trainer`**
For example, by replacing the Hugging Face `ORPOTrainer` with `LigerORPOTrainer` in your code, you can leverage our optimized ORPO implementation and immediately benefit from improved performance (see the sketch after this list).

2. **Using `nn.Module` Implementations of Custom Loss Functions**
Explore the [LigerORPOTrainer implementation](https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/orpo_trainer.py) to see how the modular design integrates custom loss functions seamlessly.
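A minimal sketch of option 1, mirroring `examples/alignment/run_orpo.py` above (TRL keyword names vary across versions; older releases use `tokenizer=`, newer ones `processing_class=`):

```python
# Sketch of option 1; mirrors examples/alignment/run_orpo.py.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import ORPOConfig

from liger_kernel.transformers.trainer import LigerORPOTrainer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
train_dataset = load_dataset("trl-lib/tldr-preference", split="train")

training_args = ORPOConfig(output_dir="Llama3.2_1B_Instruct", beta=0.1)
trainer = LigerORPOTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,  # processing_class= on newer TRL versions
    train_dataset=train_dataset,
)
trainer.train()
```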

### What's under the hood?

We employ chunking and fused kernel optimizations to enhance performance. By fusing the final linear layer with loss computation and calculating backward gradients during the forward pass, we significantly reduce the need for storing intermediate activations. All operations are implemented in PyTorch, leveraging `torch.compile` to streamline kernel execution without relying on extensive low-level optimizations. Additionally, we minimize `torch.compile` recompilations to reduce overhead and ensure consistent performance gains.
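To make the chunking idea concrete, here is a toy illustration (not the library's implementation): the head projection and the loss run one chunk at a time, so only one chunk's logits ever exist in memory.

```python
# Toy illustration of chunked fused-linear loss; not the actual implementation.
import torch
import torch.nn.functional as F


def chunked_head_plus_loss(hidden, head_weight, labels, num_chunks=4):
    """Mean cross entropy over (tokens, hidden) without materializing all logits."""
    total, n_tokens = hidden.new_zeros(()), hidden.shape[0]
    for h, y in zip(hidden.chunk(num_chunks), labels.chunk(num_chunks)):
        logits = h @ head_weight.t()  # logits for this chunk only
        total = total + F.cross_entropy(logits, y, reduction="sum")
    return total / n_tokens
```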

### Extending to custom loss functions

We provide two base classes: `LigerFusedLinearPreferenceBase` for alignment use cases and `LigerFusedLinearDistillationBase` for distillation use cases. These base classes manage chunking, kernel fusions, and Torch compilation.

To implement a custom loss function, you need to create a subclass that defines the custom preference or distillation loss function, capable of processing a given input chunk. The base class will take care of the optimizations, handling most of the heavy lifting for you.

For a working example, refer to the [ORPO loss implementation](https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/chunked_loss/orpo_loss.py).
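As a hypothetical sketch of such a subclass (the hook name `preference_loss_fn` and import path follow the linked ORPO implementation and may differ across versions):

```python
# Hypothetical sketch; hook name and import path follow the ORPO implementation
# linked above and may differ across versions.
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_preference import (
    LigerFusedLinearPreferenceBase,
)


class MyMarginLossFunction(LigerFusedLinearPreferenceBase):
    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
        # Reward a log-prob margin between chosen and rejected responses; the
        # base class handles chunking, the fused head, and torch.compile.
        margin = beta * (chosen_logps - rejected_logps)
        return -F.logsigmoid(margin).sum() / (full_target.shape[0] // 2)
```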
