KTO loss #410
Open
vulkomilev wants to merge 6 commits into linkedin:main from valkomilev:main
Commits (6)
ef59f91  working on tests (vulkomilev)
b053b0c  test are working but I have problem with assertions (vulkomilev)
2461a33  basic test working (vulkomilev)
5deb7f9  returned to fused loss (vulkomilev)
cf0c3eb  fromated code and addded source for kto (vulkomilev)
440241c  checkstyles tests and formula done (vulkomilev)
src/liger_kernel/chunked_loss/fused_linear_preference_kto.py (236 additions, 0 deletions)
@@ -0,0 +1,236 @@
from abc import abstractmethod
from functools import partial

import torch
from torch.nn import functional as F


class LigerFusedLinearKTOPreferenceBase(torch.autograd.Function):

    @abstractmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
        """
        Compute preference loss.
        Args:
            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
            beta (float): Weight for the odds ratio loss.
        """
        raise NotImplementedError("Preference loss function must be implemented.")

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        target,
        labels,
        reference_logps,
        bias=None,
        loss_fn=None,
        chunk_size=1,
        compute_nll_loss=True,
        ignore_index=-100,
        alpha=1.0,
        beta=0.1,
        compiled=True,
        **loss_kwargs,
    ):
        """
        Base class for fused linear layer with preference loss.
        Expects _input to be stacked with chosen and rejected inputs on the batch dimension.

        Args:
            _input (torch.Tensor): Input tensor. Shape: (batch_size, seq_len, hidden_size).
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
            target (torch.Tensor): Target tensor. Shape: (batch_size, seq_len).
            labels (torch.Tensor): Per-example labels; 1. marks a chosen example, 0. a rejected one.
            reference_logps (torch.Tensor): Avg log probabilities under the reference model. Shape: (batch_size,).
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            chunk_size (int): Size of a chunk (# of batches of stacked chosen and rejected inputs).
            compute_nll_loss (bool): Whether to compute NLL loss.
            ignore_index (int): Index to ignore for loss computation.
            alpha (float): Weight for the NLL loss.
            beta (float): Weight for the odds ratio loss.
            compiled (bool): Whether to use torch compile for chunk accumulation.
            loss_kwargs (dict): Other possible arguments that a loss function might need.
        """
        # TODO: Tune CHUNK_SIZE to fully utilize the GPU
        CHUNK_SIZE = chunk_size

        grad_weight = torch.zeros_like(weight)
        grad_chosen_inputs = []
        grad_rejected_inputs = []
        grad_bias = torch.zeros_like(bias) if bias is not None else None
        loss_acc = torch.zeros((), device=_input.device)

        chunks = 1  # max(1, _input.shape[0] // (2 * CHUNK_SIZE))

        loss_func_to_call = partial(
            LigerFusedLinearKTOPreferenceBase._compute_loss,
            preference_loss_fn=loss_fn,
            ignore_index=ignore_index,
            alpha=alpha,
            beta=beta,
            compute_nll_loss=compute_nll_loss,
            full_target=target,
            **loss_kwargs,
        )

        def accumulate_chunk(input_chunk, target_chunk):
            # global loss_acc
            if bias is not None:
                (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
                    chunk_loss,
                    (chunk_or_loss, chunk_chosen_logps, chunk_rejected_logps),
                ) = torch.func.grad_and_value(
                    loss_func_to_call, argnums=(0, 1, 2, 3, 4, 5), has_aux=True
                )(
                    input_chunk, weight[0], labels, reference_logps, target_chunk, bias
                )
                grad_bias.add_(chunk_grad_bias)
            else:
                (
                    (chunk_grad_input, chunk_grad_weight),
                    (chunk_loss, (alignment_loss, chosen_logps, rejected_logps)),
                ) = torch.func.grad_and_value(
                    loss_func_to_call, argnums=(0, 1), has_aux=True
                )(
                    input_chunk, weight, labels, reference_logps, target_chunk
                )
            # grad_weight.add_(chunk_grad_weight)
            grad_weight = chunk_grad_weight
            # loss_acc = chunk_loss
            loss_acc.add_(chunk_loss)
            return chunk_grad_input

        if compiled:
            accumulate_chunk = torch.compile(accumulate_chunk)

        len_chosen = target.shape[0] // 2
        _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
        _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
        _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0)
        _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0)

        for (
            chosen_input_chunk,
            rejected_input_chunk,
            chosen_target_chunk,
            rejected_target_chunk,
        ) in zip(
            _chosen_input_chunks,
            _rejected_input_chunks,
            _chosen_target_chunks,
            _rejected_target_chunks,
        ):
            input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0)
            target_chunk = torch.cat(
                [chosen_target_chunk, rejected_target_chunk], dim=0
            )

            grad_input = accumulate_chunk(input_chunk, target_chunk)

            grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]])
            grad_rejected_inputs.append(grad_input[chosen_target_chunk.shape[0] :])

        # combine grad_chosen_inputs and grad_rejected_inputs
        grad_inputs = grad_chosen_inputs + grad_rejected_inputs

        ctx.save_for_backward(
            torch.cat(grad_inputs, dim=0),
            grad_weight,
            grad_bias,
        )

        return loss_acc

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
            grad_input = grad_input * grad_output
            grad_weight = grad_weight * grad_output
            grad_bias = grad_bias * grad_output if grad_bias is not None else None

        return grad_input, grad_weight, None, grad_bias, None, None, None

    @staticmethod
    def _compute_loss(
        input_chunk,
        weight,
        labels,
        reference_logps,
        target_chunk,
        bias=None,
        preference_loss_fn=None,
        full_target=None,
        ignore_index=-100,
        alpha=1.0,
        beta=0.1,
        compute_nll_loss=False,
        **loss_kwargs,
    ):
        """
        Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
        Args:
            preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size).
            weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size).
            labels (torch.Tensor): Per-example labels; 1. marks a chosen example, 0. a rejected one.
            reference_logps (torch.Tensor): Avg log probabilities under the reference model. Shape: (batch_size,).
            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length).
            bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length).
            ignore_index (int): Index to ignore for loss computation.
            alpha (float): Weight for the NLL loss.
            beta (float): Weight for the odds ratio loss.
            loss_kwargs (dict): Additional arguments for the loss function.
        """
        len_chosen_chunk = target_chunk.shape[0] // 2

        logits_chunk = input_chunk @ weight.t()  # chunk_size x V
        if bias is not None:
            logits_chunk = logits_chunk + bias
        log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)

        chosen_nll_loss = 0.0
        # compute_nll_loss=False
        if compute_nll_loss:
            chosen_nll_loss = F.nll_loss(
                log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
                target_chunk[:len_chosen_chunk].view(-1),
                reduction="sum",
                ignore_index=ignore_index,
            )
            chosen_nll_loss = (
                chosen_nll_loss
                / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
            )

        loss_mask = target_chunk != ignore_index
        label_chunk = torch.where(loss_mask, target_chunk, 0)

        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
            -1
        )
        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)

        chosen_logps = average_log_prob[:len_chosen_chunk]
        rejected_logps = average_log_prob[len_chosen_chunk:]

        # Split the reference log-probs into chosen/rejected using the per-example labels.
        chosen_idx = [i for i in range(reference_logps.shape[0]) if labels[i][0].item() == 1.0]
        rejected_idx = [i for i in range(reference_logps.shape[0]) if labels[i][0].item() == 0.0]

        reference_chosen_logps = reference_logps[chosen_idx, ...]
        reference_rejected_logps = reference_logps[rejected_idx, ...]

        alignment_loss, chosen_rewards, rejected_rewards = preference_loss_fn(
            chosen_logps,
            rejected_logps,
            reference_chosen_logps,
            reference_rejected_logps,
            beta=beta,
            **loss_kwargs,
        )

        # alignment_loss = alignment_loss / (full_target.shape[0] // 2)

        # loss = alpha * chosen_nll_loss - alignment_loss
        loss = 0 - alignment_loss.mean()

        return loss, (alignment_loss, chosen_logps, rejected_logps)
        # return loss[0], (alignment_loss, chosen_logps, rejected_logps)
        # return None, (None, None, None)
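For reference, below is a minimal sketch of a KTO-style loss that matches the four-tensor signature this base class calls (policy chosen/rejected log-probs plus the two reference log-prob tensors). It is an illustration only, not the implementation in this PR; the desirable/undesirable weights and the clamped mean log-ratio used as a KL proxy follow the common KTO reference formulation and are assumptions here. Whatever formula is used also has to agree in sign with the `loss = 0 - alignment_loss.mean()` line above.

import torch


def kto_preference_loss_fn(
    chosen_logps,
    rejected_logps,
    reference_chosen_logps,
    reference_rejected_logps,
    beta=0.1,
    desirable_weight=1.0,    # illustrative knob, not a parameter of this PR
    undesirable_weight=1.0,  # illustrative knob, not a parameter of this PR
):
    # Log-ratios between the policy and the reference model.
    chosen_logratios = chosen_logps - reference_chosen_logps
    rejected_logratios = rejected_logps - reference_rejected_logps

    # KL proxies: clamped mean log-ratio per group, detached so gradients flow only through the losses.
    chosen_kl = chosen_logratios.mean().clamp(min=0).detach()
    rejected_kl = rejected_logratios.mean().clamp(min=0).detach()

    chosen_losses = 1 - torch.sigmoid(beta * (chosen_logratios - rejected_kl))
    rejected_losses = 1 - torch.sigmoid(beta * (chosen_kl - rejected_logratios))
    losses = torch.cat(
        (desirable_weight * chosen_losses, undesirable_weight * rejected_losses), dim=0
    )

    chosen_rewards = beta * chosen_logratios.detach()
    rejected_rewards = beta * rejected_logratios.detach()
    return losses, chosen_rewards, rejected_rewards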
Why is this class needed? Can't you reuse https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/chunked_loss/fused_linear_preference.py?
I am getting this error:

E   RuntimeError: CUDA error: device-side assert triggered
E   CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
E   For debugging consider passing CUDA_LAUNCH_BLOCKING=1
E   Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

src/liger_kernel/chunked_loss/fused_linear_preference.py:210: RuntimeError
---------------------------- Captured stdout call ----------------------------
---------------------------- Captured stderr call ----------------------------
NoneType: None
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [6,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
(the same assertion failure is repeated for threads [7,0,0], [12,0,0], [83,0,0], [32,0,0], [43,0,0], [54,0,0], [59,0,0], and [62,0,0])
=========================== short test summary info ===========================
FAILED test/chunked_loss/test_kto_loss.py::test_correctness[-100-0.1-1.0-False-1.0-dtype0-0.005-0.005-3-47-31-123] - RuntimeError: CUDA error: device-side assert triggered
============================== 1 failed in 1.86s ==============================
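For context on the assert itself: that ScatterGatherKernel "index out of bounds" failure usually means the gather over the vocabulary dimension received a target id outside [0, vocab_size). A small, hypothetical pre-check along these lines (the helper name is illustrative, not from this repo) surfaces the bad ids on the host before the kernel launch; running the failing test with CUDA_LAUNCH_BLOCKING=1, or on CPU, also makes the stack trace point at the exact gather call.

import torch


def check_target_ids(target: torch.Tensor, vocab_size: int, ignore_index: int = -100) -> None:
    # Gathering log-probs by target id only works if every non-ignored id is in [0, vocab_size).
    valid = target != ignore_index
    bad = valid & ((target < 0) | (target >= vocab_size))
    if bad.any():
        raise ValueError(
            f"{int(bad.sum())} target ids fall outside [0, {vocab_size}); "
            f"examples: {target[bad].unique()[:10].tolist()}"
        )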
I will do the equations and the formatting. I also need to pass two arguments, 'reference_chosen_logps' and 'reference_rejected_logps', to my custom loss function.
I'm a bit confused: do you still need this file? The base class's abstract preference_loss_fn does accept those two arguments; you can set beta=0 if it's not needed. In case you need a completely new function signature, my advice would be to add a new overloaded function to the existing base class.
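As a rough illustration of that suggestion (an assumption about how it could be wired, not code from this PR or from fused_linear_preference.py): giving the two reference arguments defaults lets the same loss function plug into a base class whether or not the reference log-probs are forwarded to it, and a zero reference recovers a reference-free variant. The simplified formula below omits the KL terms shown in the earlier sketch.

import torch


def kto_loss_with_optional_reference(
    chosen_logps,
    rejected_logps,
    reference_chosen_logps=None,
    reference_rejected_logps=None,
    beta=0.1,
):
    # Fall back to a zero reference when the caller does not supply one.
    if reference_chosen_logps is None:
        reference_chosen_logps = torch.zeros_like(chosen_logps)
    if reference_rejected_logps is None:
        reference_rejected_logps = torch.zeros_like(rejected_logps)

    chosen_logratios = chosen_logps - reference_chosen_logps
    rejected_logratios = rejected_logps - reference_rejected_logps
    losses = torch.cat(
        (
            1 - torch.sigmoid(beta * chosen_logratios),
            1 - torch.sigmoid(-beta * rejected_logratios),
        ),
        dim=0,
    )
    return losses, beta * chosen_logratios.detach(), beta * rejected_logratios.detach()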