Commit a3d585b
Merge branch 'main' into austin362667/chunked_compiled_jsd_loss
austin362667 authored Jan 20, 2025
2 parents 02de410 + a8fa3bb commit a3d585b
Showing 7 changed files with 146 additions and 20 deletions.
18 changes: 16 additions & 2 deletions benchmark/scripts/benchmark_rope.py
@@ -1,6 +1,8 @@
 import torch
 import triton
 
+from test.utils import transformers_version_dispatch
+from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 from utils import QUANTILES
@@ -30,7 +32,13 @@ def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
     seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x
 
     head_dim = hidden_size // num_q_heads
-    rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
+    rotary_emb = transformers_version_dispatch(
+        "4.48.0",
+        LlamaRotaryEmbedding,
+        LlamaRotaryEmbedding,
+        before_kwargs={"dim": head_dim, "device": device},
+        after_kwargs={"config": LlamaConfig(num_kv_heads=num_kv_heads, head_dim=head_dim), "device": device},
+    )
     q = torch.randn(
         (1, seq_len, num_q_heads, head_dim),
         device=device,
@@ -105,7 +113,13 @@ def bench_memory_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
     seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x
 
     head_dim = hidden_size // num_q_heads
-    rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
+    rotary_emb = transformers_version_dispatch(
+        "4.48.0",
+        LlamaRotaryEmbedding,
+        LlamaRotaryEmbedding,
+        before_kwargs={"dim": head_dim, "device": device},
+        after_kwargs={"config": LlamaConfig(num_kv_heads=num_kv_heads, head_dim=head_dim), "device": device},
+    )
     q = torch.randn(
         (1, seq_len, num_q_heads, head_dim),
         device=device,
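Note: `transformers_version_dispatch` (imported from `test/utils` above) is not shown in this diff. Judging from the call sites, it picks one of two constructors depending on the installed `transformers` version, since `LlamaRotaryEmbedding` switched from per-argument construction to config-driven construction in 4.48.0. A minimal sketch of such a helper, under the assumption that the signature matches these call sites:

# Hypothetical sketch of the version-dispatch helper; the real one lives in
# test/utils, and its exact behavior is inferred from the call sites above.
from packaging.version import Version

import transformers


def transformers_version_dispatch(
    required_version: str,
    before_fn,
    after_fn,
    before_kwargs: dict = None,
    after_kwargs: dict = None,
):
    """Call before_fn(**before_kwargs) on transformers < required_version,
    else after_fn(**after_kwargs)."""
    before_kwargs = before_kwargs or {}
    after_kwargs = after_kwargs or {}
    if Version(transformers.__version__) < Version(required_version):
        return before_fn(**before_kwargs)
    return after_fn(**after_kwargs)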
27 changes: 21 additions & 6 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -24,7 +24,9 @@ def fused_linear_cross_entropy_forward(
     label_smoothing=0.0,
     reduction="mean",
     softcap=None,
+    return_z_loss=False,
 ):
+    assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"
     device = _input.device
 
     # inputs have shape: BT x H
@@ -47,6 +49,7 @@ def fused_linear_cross_entropy_forward(
     grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None
     # we use fp32 for loss accumulator
     loss_1d = torch.zeros(BT, dtype=torch.float32, device=device)
+    z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None
 
     # TODO: evaluate how CUDA synchronization caused by .item() affects the speed
     target_mask = target != ignore_index
@@ -81,6 +84,7 @@ def fused_linear_cross_entropy_forward(

         # unreduced loss
         loss_1d_slice = loss_1d[start_idx:end_idx]  # chunk_size,
+        z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None
 
         # ensure _input and target are contiguous
         logits_chunk = logits_chunk.contiguous()
@@ -94,7 +98,7 @@ def fused_linear_cross_entropy_forward(
             Y_stride=target_chunk.stride(-1),  # always 1
             weight_ptr=ce_weight,
             loss_ptr=loss_1d_slice,
-            z_loss_ptr=None,
+            z_loss_ptr=z_loss_1d_slice,
             loss_stride=loss_1d_slice.stride(-1),  # always 1
             n_cols=V,
             n_non_ignore=total_n_non_ignore,
@@ -105,14 +109,16 @@ def fused_linear_cross_entropy_forward(
             label_smoothing=label_smoothing,
             reduction=reduction,
             softcap=softcap,
-            RETURN_Z_LOSS=False,
+            RETURN_Z_LOSS=return_z_loss,
             HAS_WEIGHT=True if ce_weight is not None else False,
             HAS_SOFTCAPPING=True if softcap is not None else False,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32 if not is_hip() else 16,
         )
 
         loss_1d[start_idx:end_idx] = loss_1d_slice
+        if return_z_loss:
+            z_loss_1d[start_idx:end_idx] = z_loss_1d_slice
         grad_logits_chunk = logits_chunk  # chunk_size x V
 
         grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
@@ -139,9 +145,12 @@ def fused_linear_cross_entropy_forward(

     if reduction == "none":
         loss = loss_1d
+        z_loss = z_loss_1d if return_z_loss else None
+
     else:
         loss = torch.sum(loss_1d)
-    return loss, grad_input, grad_weight, grad_bias
+        z_loss = torch.sum(z_loss_1d) if return_z_loss else None
+    return loss, z_loss, grad_input, grad_weight, grad_bias
 
 
 def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
@@ -206,6 +215,7 @@ def forward(
         label_smoothing=0.0,
         reduction="mean",
         softcap=None,
+        return_z_loss: bool = False,
     ):
         """
         Fusing the last linear layer with cross-entropy loss
@@ -226,7 +236,7 @@ def forward(
         reduction: reduction to apply
         """
 
-        loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
+        loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
             _input=_input,
             weight=weight,
             target=target,
@@ -237,18 +247,22 @@ def forward(
             label_smoothing=label_smoothing,
             reduction=reduction,
             softcap=softcap,
+            return_z_loss=return_z_loss,
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
            grad_input.detach(),
            grad_weight.detach() if grad_weight is not None else None,
            grad_bias.detach() if bias is not None else None,
         )
-        return loss
+        ctx.return_z_loss = return_z_loss
+        return loss, z_loss
 
     @staticmethod
     @amp_custom_bwd
-    def backward(ctx, grad_output):
+    def backward(ctx, grad_output, grad_output2):
+        if ctx.return_z_loss:
+            del grad_output2  # z_loss is only for logging
         (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
         grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
             grad_output, grad_input, grad_weight, grad_bias
@@ -264,4 +278,5 @@ def backward(ctx, grad_output):
             None,
             None,
             None,
+            None,
         )
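Note on the autograd contract: because `forward` now returns the pair `(loss, z_loss)`, PyTorch passes `backward` one incoming gradient per output, hence the new `grad_output2` parameter, and `backward` must return one gradient per `forward` argument, hence the extra trailing `None` for the new `return_z_loss` argument. The `z_loss` output is logging-only, so its gradient is simply dropped. A self-contained sketch of this two-output pattern (hypothetical example, not from this repo):

# Minimal sketch of a custom autograd Function with a second, logging-only
# output; hypothetical SquareWithStat, not part of liger_kernel.
import torch


class SquareWithStat(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, return_stat: bool = False):
        ctx.save_for_backward(x)
        stat = x.detach().mean() if return_stat else None  # logging-only output
        return x * x, stat

    @staticmethod
    def backward(ctx, grad_output, grad_stat):
        del grad_stat  # the second output is never differentiated
        (x,) = ctx.saved_tensors
        # one gradient per forward argument: x and return_stat
        return grad_output * 2 * x, None


x = torch.randn(4, requires_grad=True)
y, stat = SquareWithStat.apply(x, True)
y.sum().backward()  # only y contributes gradients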
7 changes: 6 additions & 1 deletion src/liger_kernel/transformers/functional.py
@@ -56,8 +56,9 @@ def liger_fused_linear_cross_entropy(
     label_smoothing: float = 0.0,
     reduction: str = "mean",
     softcap: Optional[float] = None,
+    return_z_loss: bool = False,
 ):
-    return LigerFusedLinearCrossEntropyFunction.apply(
+    loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
         input,
         weight,
         target,
@@ -68,7 +69,11 @@ def liger_fused_linear_cross_entropy(
         label_smoothing,
         reduction,
         softcap,
+        return_z_loss,
     )
+    if not return_z_loss:
+        return loss
+    return loss, z_loss
 
 
 def liger_fused_linear_jsd(
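Note: with `return_z_loss=True` the functional wrapper returns `(loss, z_loss)`; with the default `False` it keeps the old single-tensor return, so existing callers are unaffected. An illustrative call (shapes made up; a CUDA device is assumed, since the op is backed by a Triton kernel):

# Illustrative usage of the new return_z_loss flag; tensor shapes are made up.
import torch

from liger_kernel.transformers.functional import liger_fused_linear_cross_entropy

BT, H, V = 8, 64, 128  # (batch * seq_len), hidden size, vocab size
x = torch.randn(BT, H, device="cuda", requires_grad=True)
weight = torch.randn(V, H, device="cuda")
target = torch.randint(0, V, (BT,), device="cuda")

loss, z_loss = liger_fused_linear_cross_entropy(
    input=x, weight=weight, target=target, return_z_loss=True
)
loss.backward()  # z_loss carries no gradient; it is for logging only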
8 changes: 7 additions & 1 deletion src/liger_kernel/transformers/fused_linear_cross_entropy.py
@@ -14,6 +14,7 @@ def __init__(
         label_smoothing: float = 0.0,
         reduction: str = "mean",
         softcap: Optional[float] = None,
+        return_z_loss: bool = False,
     ):
         super().__init__()
         assert (label_smoothing >= 0) and (
@@ -31,9 +32,10 @@ def __init__(
         self.label_smoothing = label_smoothing
         self.reduction = reduction
         self.softcap = softcap
+        self.return_z_loss = return_z_loss
 
     def forward(self, lin_weight, _input, target, bias=None):
-        return LigerFusedLinearCrossEntropyFunction.apply(
+        loss, z_loss = LigerFusedLinearCrossEntropyFunction.apply(
             _input,
             lin_weight,
             target,
@@ -44,4 +46,8 @@ def forward(self, lin_weight, _input, target, bias=None):
             self.label_smoothing,
             self.reduction,
             self.softcap,
+            self.return_z_loss,
         )
+        if not self.return_z_loss:
+            return loss
+        return loss, z_loss
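Note: the module API mirrors the functional one. An illustrative use, assuming the enclosing class is `LigerFusedLinearCrossEntropyLoss` (the class name is not visible in this hunk):

# Illustrative use of return_z_loss via the module API; shapes are made up and
# the class name is an assumption, as it does not appear in the hunk above.
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss

H, V = 64, 128
lm_head = torch.nn.Linear(H, V, bias=False, device="cuda")
loss_fn = LigerFusedLinearCrossEntropyLoss(return_z_loss=True)

hidden = torch.randn(8, H, device="cuda", requires_grad=True)
target = torch.randint(0, V, (8,), device="cuda")

loss, z_loss = loss_fn(lm_head.weight, hidden, target)  # forward(lin_weight, _input, target)
loss.backward()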
41 changes: 33 additions & 8 deletions test/transformers/test_fused_linear_cross_entropy.py
@@ -45,6 +45,7 @@ def __init__(
         label_smoothing: float = 0.0,
         reduction: str = "mean",
         softcap: Optional[float] = None,
+        return_z_loss: bool = False,
     ):
         super().__init__()
         self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
@@ -54,6 +55,7 @@ def __init__(
             lse_square_scale=lse_square_scale,
             label_smoothing=label_smoothing,
             reduction=reduction,
+            return_z_loss=return_z_loss,
         )
         self.softcap = softcap
 
@@ -77,6 +79,7 @@ def __init__(
         label_smoothing: float = 0.0,
         reduction: str = "mean",
         softcap: Optional[float] = None,
+        return_z_loss: bool = False,
     ):
         super().__init__()
         self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
@@ -87,6 +90,7 @@ def __init__(
             label_smoothing=label_smoothing,
             reduction=reduction,
             softcap=softcap,
+            return_z_loss=return_z_loss,
         )
 
     def forward(self, x, y):
@@ -118,11 +122,11 @@ def forward(self, x, y):
 )
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize(
-    "has_ce_weight, label_smoothing, ignore_index, lse_square_scale, softcap",
+    "has_ce_weight, label_smoothing, ignore_index, lse_square_scale, softcap, return_z_loss",
     [
-        (False, 0, -100, 0, None),
+        (False, 0, -100, 0, None, False),
         # Pass non-default values once to ensure all params work along
-        (True, 0.1, 42, 1e-4, 30.0),
+        (True, 0.1, 42, 1e-4, 30.0, True),
     ],
 )
 def test_correctness(
def test_correctness(
Expand All @@ -139,6 +143,7 @@ def test_correctness(
ignore_index,
reduction,
softcap,
return_z_loss,
atol,
rtol,
):
@@ -156,6 +161,7 @@ def test_correctness(
         ignore_index=ignore_index,
         reduction=reduction,
         softcap=softcap,
+        return_z_loss=return_z_loss,
         dtype=dtype,
     ).to(device)
     liger_lm_head_ce = LigerLMHeadCE(
@@ -168,6 +174,7 @@ def test_correctness(
         ignore_index=ignore_index,
         reduction=reduction,
         softcap=softcap,
+        return_z_loss=return_z_loss,
         dtype=dtype,
     ).to(device)
 
@@ -189,10 +196,16 @@ def test_correctness(
     indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]  # Randomly select indices
     target[indices_to_assign] = ignore_index
 
-    output1 = torch_lm_head_ce(_input1, target)
-    output2 = liger_lm_head_ce(_input2, target)
+    if return_z_loss:
+        output1, z_output1 = torch_lm_head_ce(_input1, target)
+        output2, z_output2 = liger_lm_head_ce(_input2, target)
+    else:
+        output1 = torch_lm_head_ce(_input1, target)
+        output2 = liger_lm_head_ce(_input2, target)
 
     assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol)
+    if return_z_loss:
+        assert_verbose_allclose(z_output1, z_output2, atol=atol, rtol=rtol)
 
     output1.backward(gradient=torch.ones_like(output1))
     output2.backward(gradient=torch.ones_like(output2))
@@ -230,8 +243,9 @@ def test_correctness(
         (1.0, torch.float32, 1e-5, 5e-4),
     ],
 )
+@pytest.mark.parametrize("ce_weight", [True, False])
 @pytest.mark.parametrize("bias", [True, False])
-def test_correctness_functional(B, T, H, V, scalar, dtype, bias, atol, rtol):
+def test_correctness_functional(B, T, H, V, scalar, dtype, bias, ce_weight, atol, rtol):
     _input = torch.randn(B * T, H, device=device, dtype=dtype) * scalar
     x1 = _input.detach().clone().requires_grad_(True)
     x2 = _input.detach().clone().requires_grad_(True)
@@ -241,15 +255,26 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, bias, atol, rtol):
     weight = torch.randn(V, H, device=device, dtype=dtype)
     bias = torch.randn(V, device=device, dtype=dtype) if bias else None
 
-    y1 = liger_fused_linear_cross_entropy(
+    ce_weight = torch.randn(V, device=device) if ce_weight else None
+    y1, z1 = liger_fused_linear_cross_entropy(
         input=x1,
         weight=weight,
         target=target,
         bias=bias,
+        ce_weight=ce_weight,
+        ignore_index=-100,
+        lse_square_scale=1e-4,
+        label_smoothing=0.1,
+        reduction="mean",
+        softcap=30.0,
+        return_z_loss=True,
     )
-    y2 = LigerFusedLinearCrossEntropyFunction.apply(x2, weight, target, bias)
+    y2, z2 = LigerFusedLinearCrossEntropyFunction.apply(
+        x2, weight, target, bias, ce_weight, -100, 1e-4, 0.1, "mean", 30.0, True
+    )
 
     assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
+    assert torch.allclose(z1, z2, atol=atol, rtol=rtol)
 
     grad_output = torch.randn_like(y1)
 
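Note: the positional `.apply(...)` call above must follow the parameter order of `LigerFusedLinearCrossEntropyFunction.forward`. Mapped out against the signature visible in the ops diff:

# Positional arguments of the .apply(...) call mapped to parameter names
# (order inferred from forward() in the ops diff); purely illustrative.
y2, z2 = LigerFusedLinearCrossEntropyFunction.apply(
    x2,         # _input
    weight,     # weight
    target,     # target
    bias,       # bias
    ce_weight,  # ce_weight
    -100,       # ignore_index
    1e-4,       # lse_square_scale
    0.1,        # label_smoothing
    "mean",     # reduction
    30.0,       # softcap
    True,       # return_z_loss
)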
18 changes: 16 additions & 2 deletions test/transformers/test_rope.py
@@ -2,6 +2,8 @@
 import torch
 
 from test.utils import supports_bfloat16
+from test.utils import transformers_version_dispatch
+from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 
@@ -57,7 +59,13 @@ def test_correctness(
     atol,
     rtol,
 ):
-    rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
+    rotary_emb = transformers_version_dispatch(
+        "4.48.0",
+        LlamaRotaryEmbedding,
+        LlamaRotaryEmbedding,
+        before_kwargs={"dim": head_dim, "device": device},
+        after_kwargs={"config": LlamaConfig(num_kv_heads=num_kv_heads, head_dim=head_dim), "device": device},
+    )
 
     _tensor_q = torch.randn((bsz, seq_len, num_q_heads, head_dim), device=device).transpose(1, 2).to(dtype)
 
@@ -133,7 +141,13 @@ def test_functional_correctness(
     k1 = _k.clone().requires_grad_(True)
     k2 = _k.clone().requires_grad_(True)
 
-    rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
+    rotary_emb = transformers_version_dispatch(
+        "4.48.0",
+        LlamaRotaryEmbedding,
+        LlamaRotaryEmbedding,
+        before_kwargs={"dim": head_dim, "device": device},
+        after_kwargs={"config": LlamaConfig(num_kv_heads=num_kv_heads, head_dim=head_dim), "device": device},
+    )
 
     pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
     if expand_position_ids:
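Note: the dispatch above papers over a breaking change in transformers: before 4.48.0 `LlamaRotaryEmbedding` took per-argument settings such as `dim=`, while from 4.48.0 it is constructed from a `LlamaConfig`. A hedged sketch of the two call styles (keyword names follow the call sites in this diff and may differ across releases; only one branch will work on a given install):

# Sketch of the two construction styles bridged by transformers_version_dispatch.
# Treat keyword names as illustrative; they are taken from the call sites above.
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

head_dim, device = 64, "cuda"

# transformers < 4.48.0: per-argument constructor
rotary_emb_old = LlamaRotaryEmbedding(dim=head_dim, device=device)

# transformers >= 4.48.0: config-driven constructor
config = LlamaConfig(num_kv_heads=8, head_dim=head_dim)
rotary_emb_new = LlamaRotaryEmbedding(config=config, device=device)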