Merge branch 'main' into austin362667/chunked_compiled_jsd_loss
austin362667 authored Jan 9, 2025
2 parents ed2da69 + 9586a87 commit f02c396
Showing 9 changed files with 94 additions and 64 deletions.
10 changes: 6 additions & 4 deletions benchmark/scripts/benchmark_orpo_loss.py
@@ -45,12 +45,13 @@ def bench_memory_fused_linear_orpo_loss(

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)

def fwd():
if provider == "liger":
return liger_lm_head_orpo(_input, target)
return liger_lm_head_orpo(_input, target, nll_target)
elif provider == "huggingface":
return torch_lm_head_orpo(_input, target)
return torch_lm_head_orpo(_input, target, nll_target)

def full():
y = fwd()
@@ -91,12 +92,13 @@ def bench_speed_fused_linear_orpo_loss(

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)

def fwd():
if provider == "liger":
return liger_lm_head_orpo(_input, target)
return liger_lm_head_orpo(_input, target, nll_target)
elif provider == "huggingface":
return torch_lm_head_orpo(_input, target)
return torch_lm_head_orpo(_input, target, nll_target)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
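For context, a minimal sketch of the timing call these benchmarks rely on, assuming a CUDA device. The quantiles argument is an assumption inferred from the ms_50/ms_20/ms_80 unpacking above, and the toy fwd() stands in for the real liger_lm_head_orpo / torch_lm_head_orpo calls:

import torch
import triton


def fwd():
    # stand-in workload; the real script benchmarks the fused ORPO forward
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    return a @ b


ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=[0.5, 0.2, 0.8])
print(f"median {ms_50:.3f} ms (p20 {ms_20:.3f} ms, p80 {ms_80:.3f} ms)")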
52 changes: 40 additions & 12 deletions src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -27,6 +27,7 @@ def forward(
alpha=1.0,
beta=0.1,
compute_nll_loss=True,
nll_target=None,
compiled=True,
use_ref_model=False,
ref_input=None,
@@ -58,6 +59,7 @@ def forward(
alpha (float): Weight for the NLL loss.
beta (float): Weight for the preference loss.
compute_nll_loss (bool): Whether to compute NLL loss.
nll_target (torch.Tensor, optional): Target tensor for NLL loss. Shape: (batch_size, seq_len). If not provided, the target is used.
compiled (bool): Whether to use torch compile for chunk accumulation.
use_ref_model (bool): Whether to use a reference model for the alignment loss.
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
@@ -96,11 +98,12 @@ def forward(
use_ref_model=use_ref_model,
ref_weight=ref_weight,
ref_bias=ref_bias,
full_nll_target=nll_target,
average_log_prob=average_log_prob,
**loss_kwargs,
)

def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk):
def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk):
"""
Fused forward and backward pass for a chunk of input and target.
"""
@@ -111,13 +114,18 @@ def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk):
target_chunk,
bias,
ref_input_chunk=ref_input_chunk,
chosen_nll_target_chunk=chosen_nll_target_chunk,
)
else:
return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk
input_chunk,
weight,
target_chunk,
ref_input_chunk=ref_input_chunk,
chosen_nll_target_chunk=chosen_nll_target_chunk,
)

def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None, chosen_nll_target_chunk=None):
if bias is not None:
(
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
@@ -132,7 +140,7 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
*aux_outputs,
),
),
) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
grad_bias.add_(chunk_grad_bias) # accumulate bias gradient
else:
(
@@ -148,7 +156,7 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
*aux_outputs,
),
),
) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)

# Accumulate gradients
grad_weight.add_(chunk_grad_weight)
@@ -191,6 +199,9 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
_rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
_rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)

if nll_target is not None:
_chosen_nll_target_chunks = torch.chunk(nll_target[:len_chosen], chunks=chunks, dim=0)

if use_ref_model:
_ref_chosen_input_chunks = torch.chunk(ref_input[:len_chosen], chunks=chunks, dim=0)
_ref_rejected_input_chunks = torch.chunk(ref_input[len_chosen:], chunks=chunks, dim=0)
@@ -202,13 +213,15 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
rejected_target_chunk,
ref_chosen_input_chunk,
ref_rejected_input_chunk,
chosen_nll_target_chunk,
) in zip(
_chosen_input_chunks,
_rejected_input_chunks,
_chosen_target_chunks,
_rejected_target_chunks,
(_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)),
(_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)),
(_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)),
strict=False,
):
input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
@@ -222,9 +235,10 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
torch._dynamo.mark_dynamic(target_chunk, 1)
torch._dynamo.mark_dynamic(target, 1)
torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None
torch._dynamo.mark_dynamic(chosen_nll_target_chunk, 1) if nll_target is not None else None

# accumulate loss, gradients, and metrics
accumulate_chunk(input_chunk, target_chunk, ref_input_chunk)
accumulate_chunk(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)

# combine grad_chosen_inputs and grad_rejected_inputs
grad_inputs = grad_chosen_inputs + grad_rejected_inputs
@@ -258,7 +272,7 @@ def backward(ctx, *grad_output):
grad_weight = grad_weight * grad_output[0][0]
grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None

return grad_input, grad_weight, None, grad_bias, None, None, None
return grad_input, grad_weight, None, grad_bias, None, None, None, None

@staticmethod
def chunk_forward(
@@ -268,6 +282,7 @@ def chunk_forward(
bias=None,
ignore_index=-100,
compute_nll_loss=True,
chosen_nll_target_chunk=None,
average_log_prob=True,
):
len_chosen_chunk = target_chunk.shape[0] // 2
@@ -278,9 +293,12 @@

chosen_nll_loss = 0.0
if compute_nll_loss:
nll_labels = (
chosen_nll_target_chunk if chosen_nll_target_chunk is not None else target_chunk[:len_chosen_chunk]
)
chosen_nll_loss = F.nll_loss(
log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
target_chunk[:len_chosen_chunk].view(-1),
nll_labels.view(-1),
reduction="sum",
ignore_index=ignore_index,
)
@@ -324,6 +342,8 @@ def _compute_loss(
ref_input_chunk=None,
ref_weight=None,
ref_bias=None,
full_nll_target=None,
chosen_nll_target_chunk=None,
average_log_prob=True,
**loss_kwargs,
):
@@ -343,6 +363,8 @@ def _compute_loss(
use_ref_model (bool): Whether to use a reference model for the alignment loss.
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
full_nll_target (torch.Tensor, optional): Full target tensor for NLL loss. Shape: (batch_size, sequence_length).
chosen_nll_target_chunk (torch.Tensor, optional): Target tensor for NLL loss. Shape: (chunk_size, sequence_length). If not provided, the target_chunk is used.
average_log_prob (bool): Whether to average the log probabilities or sum them.
loss_kwargs (dict): Additional arguments for the loss function.
"""
@@ -359,9 +381,14 @@
bias=bias,
ignore_index=ignore_index,
compute_nll_loss=compute_nll_loss,
chosen_nll_target_chunk=chosen_nll_target_chunk,
average_log_prob=average_log_prob,
)
chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
if full_nll_target is not None:
chosen_nll_loss = chosen_nll_loss / (full_nll_target[: full_nll_target.shape[0] // 2] != ignore_index).sum()
else:
chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()

chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
rejected_logits_mean = rejected_logits.sum() / (
full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
@@ -372,16 +399,17 @@
(
ref_chosen_logps,
ref_rejected_logps,
ref_chosen_logits,
ref_rejected_logits,
ref_chosen_nll_loss,
_,
_,
_,
) = LigerFusedLinearPreferenceBase.chunk_forward(
ref_input_chunk,
ref_weight,
target_chunk,
ref_bias,
ignore_index=ignore_index,
compute_nll_loss=False, # We don't need NLL loss for the reference model
chosen_nll_target_chunk=None,
average_log_prob=average_log_prob,
)
loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
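A minimal, self-contained sketch of the chunking pattern introduced above, with illustrative shapes that are assumptions rather than values from this diff: the chosen half of nll_target is split into the same number of chunks as the inputs, and a list of None placeholders keeps the zip aligned when no nll_target is passed:

import torch

chunks = 4
len_chosen = 8                      # assumed: first half of the batch holds the chosen examples
seq_len, hidden = 16, 32

_input = torch.randn(2 * len_chosen, seq_len, hidden)
nll_target = torch.randint(100, (2 * len_chosen, seq_len))   # or None

_chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
_chosen_nll_target_chunks = (
    torch.chunk(nll_target[:len_chosen], chunks=chunks, dim=0)
    if nll_target is not None
    else [None] * len(_chosen_input_chunks)
)

for input_chunk, chosen_nll_target_chunk in zip(_chosen_input_chunks, _chosen_nll_target_chunks):
    # each pair (input chunk, NLL-target chunk or None) is what accumulate_chunk receives
    pass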
7 changes: 5 additions & 2 deletions src/liger_kernel/chunked_loss/orpo_loss.py
@@ -52,6 +52,7 @@ def forward(
ignore_index=-100,
beta=0.1,
compute_nll_loss=True,
nll_target=None,
compiled=True,
):
return LigerFusedLinearPreferenceBase.forward(
@@ -64,13 +65,14 @@ def forward(
ignore_index=ignore_index,
beta=beta,
compute_nll_loss=compute_nll_loss,
nll_target=nll_target,
compiled=compiled,
)

@staticmethod
def backward(ctx, *grad_output):
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
return *grads, None, None, None, None
return *grads, None, None, None, None, None


class LigerFusedLinearORPOLoss(torch.nn.Module):
@@ -96,7 +98,7 @@ def __init__(
self.compute_nll_loss = compute_nll_loss
self.compiled = compiled

def forward(self, lin_weight, _input, target, bias=None):
def forward(self, lin_weight, _input, target, bias=None, nll_target=None):
return LigerFusedLinearORPOFunction.apply(
_input,
lin_weight,
@@ -105,5 +107,6 @@ def forward(self, lin_weight, _input, target, bias=None):
self.ignore_index,
self.beta,
self.compute_nll_loss,
nll_target,
self.compiled,
)
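A hedged usage sketch of the new nll_target argument on LigerFusedLinearORPOLoss; the tensor sizes, the chosen-then-rejected batch layout, and the shape of the returned value are assumptions, not taken from this diff:

import torch

from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss

B, T, H, V = 2, 64, 128, 1000                            # assumed sizes
lin_weight = torch.randn(V, H, requires_grad=True)
_input = torch.randn(2 * B, T, H, requires_grad=True)    # chosen half followed by rejected half (assumed layout)
target = torch.randint(V, (2 * B, T))                     # preference targets
nll_target = torch.randint(V, (2 * B, T))                 # separate labels for the NLL term

orpo_loss = LigerFusedLinearORPOLoss()
out = orpo_loss(lin_weight, _input, target, nll_target=nll_target)
loss = out[0] if isinstance(out, tuple) else out          # hedging on the exact return structure
loss.backward()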
32 changes: 8 additions & 24 deletions src/liger_kernel/ops/cross_entropy.py
@@ -20,9 +20,6 @@
else:
from triton.language.math import tanh

_TRUE: tl.constexpr = tl.constexpr(1)
_FALSE: tl.constexpr = tl.constexpr(0)


@triton.jit
def liger_cross_entropy_kernel(
@@ -95,7 +92,7 @@ def liger_cross_entropy_kernel(
return

loss_ptr += program_id * loss_stride
if RETURN_Z_LOSS == _TRUE:
if RETURN_Z_LOSS:
z_loss_ptr += program_id * loss_stride

if HAS_WEIGHT:
@@ -254,7 +251,7 @@ def liger_cross_entropy_kernel(
loss += z_loss

tl.store(loss_ptr, loss)
if RETURN_Z_LOSS == _TRUE:
if RETURN_Z_LOSS:
tl.store(z_loss_ptr, z_loss)


@@ -264,12 +261,6 @@ def liger_cross_entropy_kernel(
MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning


_bool_to_return_z_loss = {
True: _TRUE.value,
False: _FALSE.value,
}


def cross_entropy_forward(
_input,
target,
@@ -281,11 +272,7 @@
softcap,
return_z_loss,
):
if not isinstance(return_z_loss, int):
assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"
return_z_loss = _bool_to_return_z_loss[return_z_loss]
else:
assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}"
assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}"

BT, V = _input.shape
n_rows = BT
@@ -294,10 +281,7 @@

# unreduced loss
loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
if return_z_loss == _TRUE.value:
z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
else:
z_loss_1d = None # set None when return_z_loss == False
z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None

target_mask = target != ignore_index
n_non_ignore = target_mask.sum().item()
@@ -326,7 +310,7 @@
X_stride=_input.stride(-2),
Y_ptr=target,
Y_stride=target.stride(-1), # always 1
weight_ptr=weight if weight is not None else _input, # dummy if None
weight_ptr=weight, # dummy if None
loss_ptr=loss_1d,
z_loss_ptr=z_loss_1d,
loss_stride=loss_1d.stride(-1), # always 1
@@ -338,7 +322,7 @@
lse_square_scale=lse_square_scale,
label_smoothing=label_smoothing,
reduction=reduction,
softcap=softcap if softcap is not None else 0.0,
softcap=softcap,
RETURN_Z_LOSS=return_z_loss,
BLOCK_SIZE=BLOCK_SIZE,
HAS_WEIGHT=True if weight is not None else False,
@@ -350,10 +334,10 @@

if reduction == "none":
loss = loss_1d
z_loss = z_loss_1d if return_z_loss == _TRUE.value else None
z_loss = z_loss_1d if return_z_loss else None
else:
loss = torch.sum(loss_1d)
z_loss = torch.sum(z_loss_1d) if return_z_loss == _TRUE.value else None
z_loss = torch.sum(z_loss_1d) if return_z_loss else None

return loss, z_loss, _input

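The _TRUE/_FALSE constants can go because a tl.constexpr flag is specialized at kernel compile time, so a plain Python bool can be passed straight through. A standalone sketch of that pattern (not the liger kernel itself; assumes a CUDA device is available):

import torch
import triton
import triton.language as tl


@triton.jit
def scale_kernel(x_ptr, out_ptr, n, DOUBLE: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    if DOUBLE:  # resolved at compile time; the untaken branch is compiled away
        x = x * 2.0
    tl.store(out_ptr + offs, x, mask=mask)


x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
scale_kernel[grid](x, out, x.numel(), DOUBLE=True, BLOCK=256)   # pass True/False directly, no integer mapping needed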
8 changes: 4 additions & 4 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -92,9 +92,9 @@ def fused_linear_cross_entropy_forward(
X_stride=logits_chunk.stride(-2),
Y_ptr=target_chunk,
Y_stride=target_chunk.stride(-1), # always 1
weight_ptr=ce_weight if ce_weight is not None else _input, # dummy if None
weight_ptr=ce_weight,
loss_ptr=loss_1d_slice,
z_loss_ptr=loss_1d_slice, # dummy ptr, not used
z_loss_ptr=None,
loss_stride=loss_1d_slice.stride(-1), # always 1
n_cols=V,
n_non_ignore=total_n_non_ignore,
@@ -104,8 +104,8 @@
lse_square_scale=lse_square_scale,
label_smoothing=label_smoothing,
reduction=reduction,
softcap=softcap if softcap is not None else 0.0,
RETURN_Z_LOSS=0, # False
softcap=softcap,
RETURN_Z_LOSS=False,
HAS_WEIGHT=True if ce_weight is not None else False,
HAS_SOFTCAPPING=True if softcap is not None else False,
BLOCK_SIZE=BLOCK_SIZE,
3 changes: 0 additions & 3 deletions src/liger_kernel/transformers/cross_entropy.py
@@ -20,9 +20,6 @@ def __init__(
assert (label_smoothing >= 0) and (
label_smoothing <= 1
), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
assert (label_smoothing >= 0) and (
label_smoothing <= 1
), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}"
assert reduction in {
"mean",
"sum",