KTO loss #410

Open
wants to merge 6 commits into base: main
Changes from 5 commits
65 changes: 33 additions & 32 deletions src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -20,19 +20,19 @@ def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):

@staticmethod
def forward(
ctx,
_input,
weight,
target,
bias=None,
loss_fn=None,
chunk_size=1,
compute_nll_loss=True,
ignore_index=-100,
alpha=1.0,
beta=0.1,
compiled=True,
**loss_kwargs,
):
"""
Base class for fused linear layer with preference loss.
@@ -53,6 +53,7 @@ def forward(
loss_kwargs (dict): Other possible arguments that a loss function might need
"""
# TODO: Tune CHUNK_SIZE to fully utilize the GPU

CHUNK_SIZE = chunk_size

grad_weight = torch.zeros_like(weight)
@@ -101,16 +102,17 @@ def accumulate_chunk(input_chunk, target_chunk):
accumulate_chunk = torch.compile(accumulate_chunk)

len_chosen = target.shape[0] // 2
chunks = int(chunks)
_chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
_chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
_rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
_rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)

for (
chosen_input_chunk,
rejected_input_chunk,
chosen_target_chunk,
rejected_target_chunk,
) in zip(
_chosen_input_chunks,
_rejected_input_chunks,
@@ -125,7 +127,7 @@ def accumulate_chunk(input_chunk, target_chunk):
grad_input = accumulate_chunk(input_chunk, target_chunk)

grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]])
grad_rejected_inputs.append(grad_input[chosen_target_chunk.shape[0]:])

# combine grad_chosen_inputs and grad_rejected_inputs
grad_inputs = grad_chosen_inputs + grad_rejected_inputs
@@ -149,17 +151,17 @@ def backward(ctx, grad_output):

@staticmethod
def _compute_loss(
input_chunk,
weight,
target_chunk,
bias=None,
preference_loss_fn=None,
full_target=None,
ignore_index=-100,
alpha=1.0,
beta=0.1,
compute_nll_loss=True,
**loss_kwargs,
):
"""
Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
@@ -191,8 +193,8 @@ def _compute_loss(
ignore_index=ignore_index,
)
chosen_nll_loss = (
chosen_nll_loss
/ (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
)

loss_mask = target_chunk != ignore_index
@@ -205,7 +207,6 @@ def _compute_loss(

chosen_logps = average_log_prob[:len_chosen_chunk]
rejected_logps = average_log_prob[len_chosen_chunk:]

alignment_loss = preference_loss_fn(
chosen_logps, rejected_logps, beta=beta, **loss_kwargs
)
236 changes: 236 additions & 0 deletions src/liger_kernel/chunked_loss/fused_linear_preference_kto.py
@@ -0,0 +1,236 @@
from abc import abstractmethod
from functools import partial

import torch
from torch.nn import functional as F


class LigerFusedLinearKTOPreferenceBase(torch.autograd.Function):
Author
I am getting this error:

E   RuntimeError: CUDA error: device-side assert triggered
E   CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
E   For debugging consider passing CUDA_LAUNCH_BLOCKING=1
E   Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

src/liger_kernel/chunked_loss/fused_linear_preference.py:210: RuntimeError
---------------------------------------------------------------------------------------------------------------- Captured stdout call -----------------------------------------------------------------------------------------------------------------

---------------------------------------------------------------------------------------------------------------- Captured stderr call -----------------------------------------------------------------------------------------------------------------
NoneType: None
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [6,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [7,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [12,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [83,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [32,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [43,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [54,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [59,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [62,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
=============================================================================================================== short test summary info ===============================================================================================================
FAILED test/chunked_loss/test_kto_loss.py::test_correctness[-100-0.1-1.0-False-1.0-dtype0-0.005-0.005-3-47-31-123] - RuntimeError: CUDA error: device-side assert triggered
================================================================================================================== 1 failed in 1.86s ==================================================================================================================
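
That assert is an index-out-of-bounds inside a gather/scatter kernel. A minimal sanity check, assuming the failing op is the log-prob gather in _compute_loss (the helper name check_gather_indices and the explicit vocab_size argument below are hypothetical, not part of this PR), is to confirm on CPU that every non-ignored target id is a valid vocabulary index:

import torch

def check_gather_indices(target_chunk, vocab_size, ignore_index=-100):
    # Mirror the masking done before the gather: ignored positions map to index 0.
    loss_mask = target_chunk != ignore_index
    label_chunk = torch.where(loss_mask, target_chunk, torch.zeros_like(target_chunk))
    assert int(label_chunk.min()) >= 0 and int(label_chunk.max()) < vocab_size, (
        f"target ids outside [0, {vocab_size}): "
        f"min={int(label_chunk.min())}, max={int(label_chunk.max())}"
    )

Running the failing test with CUDA_LAUNCH_BLOCKING=1, or once on CPU, will also pin the assert to the exact Python line.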

Author
I will do the equations and the formatting. Also, I need two additional arguments, 'reference_chosen_logps' and 'reference_rejected_logps', in my custom loss function.
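
For context, a minimal sketch of what such a loss function could look like, assuming the standard KTO formulation with the KL baseline fixed to zero and the per-class desirable/undesirable weights omitted (the name kto_preference_loss and those simplifications are assumptions, not this PR's final code). It returns the (losses, chosen_rewards, rejected_rewards) triple that _compute_loss below unpacks:

import torch

def kto_preference_loss(
    chosen_logps,              # policy average log-probs, chosen examples
    rejected_logps,            # policy average log-probs, rejected examples
    reference_chosen_logps,    # reference model average log-probs, chosen examples
    reference_rejected_logps,  # reference model average log-probs, rejected examples
    beta=0.1,
    kl=0.0,                    # KTO's KL baseline; the paper estimates it from the batch
):
    chosen_logratios = chosen_logps - reference_chosen_logps
    rejected_logratios = rejected_logps - reference_rejected_logps
    # Desirable examples are pushed above the baseline, undesirable ones below it.
    chosen_losses = 1 - torch.sigmoid(beta * (chosen_logratios - kl))
    rejected_losses = 1 - torch.sigmoid(beta * (kl - rejected_logratios))
    losses = torch.cat([chosen_losses, rejected_losses], dim=0)
    chosen_rewards = beta * chosen_logratios.detach()
    rejected_rewards = beta * rejected_logratios.detach()
    return losses, chosen_rewards, rejected_rewards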

Collaborator
I'm a bit confused: do you still need this file? The base class's abstract preference_loss_fn does accept those two arguments; you can set beta=0 if it's not needed.

In case you need a completely new function signature, my advice would be to add a new overloaded function in the existing base class.
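
A rough sketch of that suggestion (the method name preference_loss_fn_with_reference is hypothetical, not existing repo code): keep the current hook and add a second abstract hook that also receives the reference log-probs, so a KTO-style subclass overrides the extended one without changing forward:

from abc import abstractmethod

import torch


class LigerFusedLinearPreferenceBase(torch.autograd.Function):
    @abstractmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
        raise NotImplementedError("Preference loss function must be implemented.")

    # Hypothetical second hook: same contract, plus reference model log-probs.
    @abstractmethod
    def preference_loss_fn_with_reference(
        chosen_logps,
        rejected_logps,
        reference_chosen_logps,
        reference_rejected_logps,
        beta=0.1,
    ):
        raise NotImplementedError("Reference-based preference loss must be implemented.")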


@abstractmethod
def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
"""
Compute preference loss.
Args:
chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
beta (float): Weight for the odds ratio loss.
"""
raise NotImplementedError("Preference loss function must be implemented.")

@staticmethod
def forward(
ctx,
_input,
weight,
target,
labels,
reference_logps,
bias=None,
loss_fn=None,
chunk_size=1,
compute_nll_loss=True,
ignore_index=-100,
alpha=1.0,
beta=0.1,
compiled=True,
**loss_kwargs,
):
"""
Base class for fused linear layer with preference loss.
Expects _input to be stacked with chosen and rejected inputs on the batch dimension.

Args:
_input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
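labels (torch.Tensor): Per-example binary indicator used to split reference_logps; labels[i][0] == 1 marks a chosen example, 0 a rejected one.
reference_logps (torch.Tensor): Average log probabilities from the reference model, one entry per example in the batch.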
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
compute_nll_loss (bool): Whether to compute NLL loss.
ignore_index (int): Index to ignore for loss computation.
alpha (float): Weight for the NLL loss.
beta (float): Weight for the odds ratio loss.
compiled (bool): Whether to use torch compile for chunk accumulation.
loss_kwargs (dict): Other possible arguments that a loss function might need
"""
# TODO: Tune CHUNK_SIZE to fully utilize the GPU
CHUNK_SIZE = chunk_size

grad_weight = torch.zeros_like(weight)
grad_chosen_inputs = []
grad_rejected_inputs = []
grad_bias = torch.zeros_like(bias) if bias is not None else None
loss_acc = torch.zeros((), device=_input.device)

chunks = 1  # max(1, _input.shape[0] // (2 * CHUNK_SIZE)) is disabled; a single chunk is forced for now

loss_func_to_call = partial(
LigerFusedLinearKTOPreferenceBase._compute_loss,
preference_loss_fn=loss_fn,
ignore_index=ignore_index,
alpha=alpha,
beta=beta,
compute_nll_loss=compute_nll_loss,
full_target=target,
**loss_kwargs,
)

def accumulate_chunk(input_chunk, target_chunk):
# Differentiate only w.r.t. the input, the weight (and the bias); labels,
# reference_logps and the target chunk are data and receive no gradients.
if bias is not None:
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
chunk_loss,
(chunk_alignment_loss, chunk_chosen_logps, chunk_rejected_logps),
) = torch.func.grad_and_value(
loss_func_to_call, argnums=(0, 1, 5), has_aux=True
)(
input_chunk, weight, labels, reference_logps, target_chunk, bias
)
grad_bias.add_(chunk_grad_bias)
else:
(chunk_grad_input, chunk_grad_weight), (
chunk_loss,
(chunk_alignment_loss, chunk_chosen_logps, chunk_rejected_logps),
) = torch.func.grad_and_value(
loss_func_to_call, argnums=(0, 1), has_aux=True
)(
input_chunk, weight, labels, reference_logps, target_chunk
)
grad_weight.add_(chunk_grad_weight)
loss_acc.add_(chunk_loss)
return chunk_grad_input

if compiled:
accumulate_chunk = torch.compile(accumulate_chunk)

len_chosen = target.shape[0] // 2
_chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
_chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
_rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
_rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)

for (
chosen_input_chunk,
rejected_input_chunk,
chosen_target_chunk,
rejected_target_chunk,
) in zip(
_chosen_input_chunks,
_rejected_input_chunks,
_chosen_target_chunks,
_rejected_target_chunks,
):
input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
target_chunk = torch.cat(
[chosen_target_chunk, rejected_target_chunk], dim=0
)

grad_input = accumulate_chunk(input_chunk, target_chunk)

grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]])
grad_rejected_inputs.append(grad_input[chosen_target_chunk.shape[0] :])

# combine grad_chosen_inputs and grad_rejected_inputs
grad_inputs = grad_chosen_inputs + grad_rejected_inputs

ctx.save_for_backward(
torch.cat(grad_inputs, dim=0),
grad_weight,
grad_bias,
)

return loss_acc

@staticmethod
def backward(ctx, grad_output):
grad_input, grad_weight, grad_bias = ctx.saved_tensors
if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
grad_input = grad_input * grad_output
grad_weight = grad_weight * grad_output
grad_bias = grad_bias * grad_output if grad_bias is not None else None

return grad_input, grad_weight, None, grad_bias, None, None, None

@staticmethod
def _compute_loss(
input_chunk,
weight,
labels,
reference_logps,
target_chunk,
bias=None,
preference_loss_fn=None,
full_target=None,
ignore_index=-100,
alpha=1.0,
beta=0.1,
compute_nll_loss=False,
**loss_kwargs,
):
"""
Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
Args:
preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
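labels (torch.Tensor): Per-example binary chosen/rejected indicator used to split reference_logps.
reference_logps (torch.Tensor): Reference model average log probabilities for the full batch.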
bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
ignore_index (int): Index to ignore for loss computation.
alpha (float): Weight for the NLL loss.
beta (float): Weight for the odds ratio loss.
loss_kwargs (dict): Additional arguments for the loss function.
"""
len_chosen_chunk = target_chunk.shape[0] // 2

logits_chunk = input_chunk @ weight.t() # chunk_size x V
if bias is not None:
logits_chunk = logits_chunk + bias
log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)

chosen_nll_loss = 0.0
if compute_nll_loss:
chosen_nll_loss = F.nll_loss(
log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
target_chunk[:len_chosen_chunk].view(-1),
reduction="sum",
ignore_index=ignore_index,
)
chosen_nll_loss = (
chosen_nll_loss
/ (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
)

loss_mask = target_chunk != ignore_index
label_chunk = torch.where(loss_mask, target_chunk, 0)

per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
-1
)
average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)

chosen_logps = average_log_prob[:len_chosen_chunk]
rejected_logps = average_log_prob[len_chosen_chunk:]
# Split the reference log probabilities into chosen/rejected using the binary
# labels (labels[i][0] == 1 marks a chosen example, 0 a rejected one).
chosen_idx = [i for i in range(reference_logps.shape[0]) if labels[i][0].item() == 1.0]
rejected_idx = [i for i in range(reference_logps.shape[0]) if labels[i][0].item() == 0.0]

reference_chosen_logps = reference_logps[chosen_idx, ...]
reference_rejected_logps = reference_logps[rejected_idx, ...]

alignment_loss, chosen_rewards, rejected_rewards = preference_loss_fn(
chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps,
beta=beta, **loss_kwargs
)

# alignment_loss = alignment_loss / (full_target.shape[0] // 2)
# loss = alpha * chosen_nll_loss - alignment_loss
loss = -alignment_loss.mean()

return loss, (alignment_loss, chosen_logps, rejected_logps)