KTO loss #410
test/chunked_loss/test_cpo_loss.py
Outdated
@@ -126,7 +126,7 @@ def test_correctness(
     input1, weight1, target, bias1, alpha=alpha
 )
 loss2 = LigerFusedLinearCPOFunction.apply(
-    input2, weight2, target, bias2, ignore_index, beta, alpha, True
+    input2, weight2, target, bias2, ignore_index, beta, alpha, False
Why are we changing the test case for an unrelated alignment algo?
Sorry, my bad.
Hey, I think this code needs to be refactored to make it cleaner and easier to understand. Could you also write out the equations for KTO in the PR description so that it's easier for a reviewer to follow?
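For reference, the KTO objective is usually written as below. This is a sketch of the formulation from the KTO paper (Ethayarajh et al., 2024), not a statement of what this PR implements; in particular, how the KL reference point $z_0$ is estimated varies between implementations. Here $\lambda_D$ and $\lambda_U$ are the weights on desirable and undesirable examples, $\beta$ is the usual temperature, and $\lambda_y$ is $\lambda_D$ or $\lambda_U$ depending on the label of $y$.

```latex
\begin{aligned}
r_\theta(x, y) &= \log \frac{\pi_\theta(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} \\
z_0 &= \mathrm{KL}\!\left( \pi_\theta(y' \mid x) \,\|\, \pi_{\mathrm{ref}}(y' \mid x) \right) \\
v(x, y) &=
\begin{cases}
\lambda_D \, \sigma\!\left( \beta \left( r_\theta(x, y) - z_0 \right) \right) & \text{if } y \text{ is desirable} \\
\lambda_U \, \sigma\!\left( \beta \left( z_0 - r_\theta(x, y) \right) \right) & \text{if } y \text{ is undesirable}
\end{cases} \\
\mathcal{L}_{\mathrm{KTO}}(\pi_\theta, \pi_{\mathrm{ref}}) &= \mathbb{E}_{(x, y) \sim D}\left[ \lambda_y - v(x, y) \right]
\end{aligned}
```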
from torch.nn import functional as F


class LigerFusedLinearKTOPreferenceBase(torch.autograd.Function):
Why is this class needed? Can't you reuse https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/chunked_loss/fused_linear_preference.py?
I am getting this error:
E RuntimeError: CUDA error: device-side assert triggered
E CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
E For debugging consider passing CUDA_LAUNCH_BLOCKING=1
E Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
src/liger_kernel/chunked_loss/fused_linear_preference.py:210: RuntimeError
------------------------------ Captured stdout call ------------------------------
------------------------------ Captured stderr call ------------------------------
NoneType: None
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [6,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [7,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [12,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [83,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [32,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [43,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [54,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [59,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [62,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
============================== short test summary info ==============================
FAILED test/chunked_loss/test_kto_loss.py::test_correctness[-100-0.1-1.0-False-1.0-dtype0-0.005-0.005-3-47-31-123] - RuntimeError: CUDA error: device-side assert triggered
============================== 1 failed in 1.86s ==============================
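The `index out of bounds` assertion in the gather kernel fires when an index is negative or not smaller than the size of the indexed dimension. The thread does not state the root cause here, so the following is only one plausible explanation: labels equal to the ignore index (the failing test id begins with -100), or labels at or above the vocabulary size, reach the gather without being masked. The sketch below uses plain Python lists as a stand-in for the tensor gather so it runs without a GPU; `gather_token_logps` and `IGNORE_INDEX` are hypothetical names, not part of Liger-Kernel.

```python
IGNORE_INDEX = -100  # assumed sentinel for padded label positions


def gather_token_logps(log_probs, targets, ignore_index=IGNORE_INDEX):
    """Pick log_probs[t][targets[t]] for each token, skipping ignored positions.

    log_probs: list of per-token lists, each of length vocab_size.
    targets:   list of int labels; ignore_index marks positions to skip.
    Returns (picked, mask); ignored positions contribute 0.0 and mask False.
    """
    picked, mask = [], []
    for row, label in zip(log_probs, targets):
        keep = label != ignore_index
        # Replace the sentinel with a valid index before indexing, then
        # zero out the result with the mask.
        safe_label = label if keep else 0
        if not 0 <= safe_label < len(row):
            raise IndexError(
                f"label {label} out of bounds for vocab size {len(row)}"
            )
        picked.append(row[safe_label] if keep else 0.0)
        mask.append(keep)
    return picked, mask


log_probs = [[-0.1, -2.0, -3.0], [-1.0, -0.5, -2.5]]
targets = [1, IGNORE_INDEX]
picked, mask = gather_token_logps(log_probs, targets)
```

Running the failing test on CPU, or with CUDA_LAUNCH_BLOCKING=1 as the error message suggests, usually gives a stack trace that points at the exact indexing call.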
I will do the equations and the formatting. Also, I need to pass two arguments, 'reference_chosen_logps' and 'reference_rejected_logps', to my custom loss function.
I'm a bit confused: do you still need this file? The base class's abstract preference_loss_fn does accept those two arguments, and you can set beta=0 if it's not needed.
In case you need a completely new function signature, my advice would be to add a new overloaded function in the existing base class.
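The suggestion above can be sketched as follows: keep one shared base class with an abstract loss hook, and let the KTO variant override only that hook while receiving the reference log-probabilities as arguments. This is a plain-Python illustration under stated assumptions, not the Liger-Kernel API. `PreferenceLossBase`, `KTOStyleLoss`, and the `kl` argument are hypothetical; only the name `preference_loss_fn` comes from the thread, and the KL reference point is passed in as a constant rather than estimated.

```python
import math
from abc import ABC, abstractmethod


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


class PreferenceLossBase(ABC):
    """Hypothetical stand-in for a shared preference-loss base class."""

    @abstractmethod
    def preference_loss_fn(self, chosen_logps, rejected_logps,
                           ref_chosen_logps=None, ref_rejected_logps=None,
                           beta=0.1):
        """Return a scalar loss from per-sequence log-probabilities."""

    def forward(self, chosen_logps, rejected_logps, **loss_kwargs):
        # Shared work (chunking, log-prob computation) would live here;
        # only the loss hook differs between algorithms.
        return self.preference_loss_fn(chosen_logps, rejected_logps, **loss_kwargs)


class KTOStyleLoss(PreferenceLossBase):
    """KTO-style loss on paired data, with the KL term supplied by the caller."""

    def preference_loss_fn(self, chosen_logps, rejected_logps,
                           ref_chosen_logps=None, ref_rejected_logps=None,
                           beta=0.1, kl=0.0):
        if ref_chosen_logps is None:
            ref_chosen_logps = [0.0] * len(chosen_logps)
        if ref_rejected_logps is None:
            ref_rejected_logps = [0.0] * len(rejected_logps)
        # Log-ratios between the policy and the reference model.
        chosen_ratios = [p - r for p, r in zip(chosen_logps, ref_chosen_logps)]
        rejected_ratios = [p - r for p, r in zip(rejected_logps, ref_rejected_logps)]
        chosen_losses = [1.0 - sigmoid(beta * (c - kl)) for c in chosen_ratios]
        rejected_losses = [1.0 - sigmoid(beta * (kl - r)) for r in rejected_ratios]
        losses = chosen_losses + rejected_losses
        return sum(losses) / len(losses)


loss_fn = KTOStyleLoss()
neutral = loss_fn.forward(
    [-1.0, -2.0], [-3.0],
    ref_chosen_logps=[-1.0, -2.0], ref_rejected_logps=[-3.0], beta=0.1,
)
improved = loss_fn.forward(
    [0.0], [-3.0],
    ref_chosen_logps=[-1.0], ref_rejected_logps=[-3.0], beta=0.1,
)
```

When the policy matches the reference, every log-ratio is zero and each term is 0.5; raising the chosen log-probability relative to the reference lowers the loss.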
Okay, the code is formatted and a comment about the source of the loss has been added.
@vulkomilev can you please make sure that all unused code is deleted, and can you also confirm that
make checkstyle
make test
both work?
It'd be great if you could add the equations of KTO to the PR's description, similar to #386.
make checkstyle and make test work now. The commented-out code was removed, and I have added the formula in kto_loss.py, but I am not sure about the formatting.
@ByronHsu @shivam15s could either of you please take over reviewing this PR? I have to switch my focus to other stuff.
Summary
This PR implements the KTO loss, based on reference implementations from other projects.
Details
Since this is my first PR, I am not sure the final results are correct, so I expect a lot of comments.
Testing Done
I have done basic testing, modeled on the CPO tests.
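A correctness test for a chunked loss typically checks that accumulating the loss chunk by chunk gives the same value as computing it over the whole batch, within a tolerance. The following is a minimal plain-Python sketch of that pattern, not the test in this PR; `per_example_loss`, `full_loss`, and `chunked_loss` are hypothetical names, and the per-example loss is a simple sigmoid term chosen only for illustration.

```python
import math


def per_example_loss(logratio: float, beta: float = 0.1) -> float:
    # Illustrative loss term: 1 - sigmoid(beta * logratio).
    return 1.0 - 1.0 / (1.0 + math.exp(-beta * logratio))


def full_loss(logratios, beta=0.1):
    """Reference: mean loss over the whole batch in one pass."""
    return sum(per_example_loss(x, beta) for x in logratios) / len(logratios)


def chunked_loss(logratios, chunk_size, beta=0.1):
    """Accumulate the loss chunk by chunk, then normalize by the full batch size."""
    total = 0.0
    for start in range(0, len(logratios), chunk_size):
        chunk = logratios[start:start + chunk_size]
        total += sum(per_example_loss(x, beta) for x in chunk)
    return total / len(logratios)


logratios = [0.5, -1.2, 3.0, 0.0, -0.7]
reference = full_loss(logratios)
results = {size: chunked_loss(logratios, size) for size in (1, 2, 5, 7)}
```

Normalizing by the full batch size, rather than averaging per-chunk means, keeps the result independent of the chunk size even when the last chunk is smaller.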