[DPO] add KTO loss (#1075)
* add KTO loss

* fix docs

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <[email protected]>

* formatting

* add link to papers

---------

Co-authored-by: lewtun <[email protected]>
kashif and lewtun authored Dec 11, 2023
1 parent 7d0a8ee commit d275cb4
Showing 3 changed files with 30 additions and 8 deletions.
7 changes: 5 additions & 2 deletions docs/source/dpo_trainer.mdx
@@ -80,14 +80,17 @@ dpo_trainer.train()

Note that `beta` is the temperature parameter for the DPO loss, typically in the range of `0.1` to `0.5`. As `beta` -> 0, the reference model is ignored.

-## Loss function
+## Loss functions

Given the preference data, we can fit a binary classifier according to the Bradley-Terry model, and in fact the DPO authors propose the sigmoid loss on the normalized likelihood via `logsigmoid` to fit a logistic regression.

The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://arxiv.org/abs/2305.10425) paper. The `DPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument; in this case `beta` is the reciprocal of the margin.

The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithm, identify an issue with overfitting, and propose an alternative loss, which can be used via the `loss_type="ipo"` argument to the trainer.

-The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability that can be passed to the `DPOTrainer` via `label_smoothing` argument (between 0 and 0.5) and then a conservative DPO loss is used.
+The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability. This probability is passed to the `DPOTrainer` via the `label_smoothing` argument (a value between 0 and 0.5), and a conservative DPO loss is then used. Use the `loss_type="cdpo"` argument to the trainer to enable it.
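
For orientation, here is a sketch of these per-example losses as a hypothetical standalone helper. It illustrates the formulas above under the usual DPO definitions and is not the verbatim trainer code:

```python
import torch
import torch.nn.functional as F

def preference_losses(logits, beta, label_smoothing=0.0, loss_type="sigmoid"):
    """Sketch of the per-example losses.

    logits = (policy_chosen_logps - policy_rejected_logps)
             - (reference_chosen_logps - reference_rejected_logps)
    """
    if loss_type == "sigmoid":
        # original DPO loss; label_smoothing > 0 gives the conservative cDPO variant
        return (
            -F.logsigmoid(beta * logits) * (1 - label_smoothing)
            - F.logsigmoid(-beta * logits) * label_smoothing
        )
    if loss_type == "hinge":
        # SLiC-style hinge loss; beta is the reciprocal of the margin
        return torch.relu(1 - beta * logits)
    if loss_type == "ipo":
        # IPO loss, eqn (17) of the paper; beta is the regularization parameter tau
        return (logits - 1 / (2 * beta)) ** 2
    raise ValueError(f"Unknown loss type: {loss_type}")
```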

+The [KTO](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf) loss is derived to directly maximize the utility of LLM generations instead of the log-likelihood of preferences. Thus the dataset does not necessarily consist of preferences but rather of desirable vs. undesirable pairs. Use the `loss_type="kto"` argument to the trainer to utilize this loss.
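
For example, switching the trainer to the KTO loss only requires changing the `loss_type` argument; `model`, `model_ref`, `training_args`, `train_dataset`, and `tokenizer` below are placeholders for whatever you already pass to `DPOTrainer`:

```python
from trl import DPOTrainer

dpo_trainer = DPOTrainer(
    model,                        # the model to optimize
    model_ref,                    # the frozen reference model
    args=training_args,           # transformers.TrainingArguments
    beta=0.1,
    loss_type="kto",              # or "sigmoid" (default), "hinge", "ipo"
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
dpo_trainer.train()
```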

## Logging

4 changes: 3 additions & 1 deletion tests/test_dpo_trainer.py
@@ -74,7 +74,9 @@ def _init_dummy_dataset(self):
        # fmt: on
        return Dataset.from_dict(dummy_dataset_dict)

-    @parameterized.expand([["gpt2", "sigmoid"], ["t5", "hinge"], ["gpt2", "ipo"], ["t5", "ipo"]])
+    @parameterized.expand(
+        [["gpt2", "sigmoid"], ["t5", "hinge"], ["gpt2", "ipo"], ["t5", "ipo"], ["gpt2", "kto"], ["t5", "kto"]]
+    )
    def test_dpo_trainer(self, name, loss_type):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = TrainingArguments(
27 changes: 22 additions & 5 deletions trl/trainer/dpo_trainer.py
@@ -65,9 +65,9 @@ class DPOTrainer(Trainer):
        beta (`float`, defaults to 0.1):
            The beta factor in DPO loss. Higher beta means less divergence from the initial policy. For the IPO loss, beta is the regularization parameter denoted by tau in the paper.
        label_smoothing (`float`, defaults to 0):
-            The robust DPO label smoothing parameter that should be between 0 and 0.5.
+            The robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report that should be between 0 and 0.5.
        loss_type (`str`, defaults to `"sigmoid"`):
-            The type of DPO loss to use. Either `"sigmoid"` the default DPO loss,`"hinge"` loss from SLiC paper or `"ipo"` from IPO paper.
+            The type of DPO loss to use. Either `"sigmoid"` the default DPO loss, `"hinge"` loss from the [SLiC](https://arxiv.org/abs/2305.10425) paper, `"ipo"` from the [IPO](https://arxiv.org/abs/2310.12036) paper, or `"kto"` from the HALOs [report](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf).
        args (`transformers.TrainingArguments`):
            The arguments to use for training.
        data_collator (`transformers.DataCollator`):
@@ -123,7 +123,7 @@ def __init__(
        ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        beta: float = 0.1,
        label_smoothing: float = 0,
-        loss_type: Literal["sigmoid", "hinge", "ipo"] = "sigmoid",
+        loss_type: Literal["sigmoid", "hinge", "ipo", "kto"] = "sigmoid",
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        label_pad_token_id: int = -100,
@@ -311,7 +311,7 @@ def make_inputs_require_grad(module, input, output):
        self.label_pad_token_id = label_pad_token_id
        self.padding_value = padding_value

-        if loss_type in ["hinge", "ipo"] and label_smoothing > 0:
+        if loss_type in ["hinge", "ipo", "kto"] and label_smoothing > 0:
            warnings.warn(
                "You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
            )
@@ -465,8 +465,25 @@ def dpo_loss(
elif self.loss_type == "ipo":
# eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
losses = (logits - 1 / (2 * self.beta)) ** 2
elif self.loss_type == "kto":
# eqn (7) of the HALOs paper
chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)
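+            # clamping at zero keeps these batch-level KL estimates non-negative, as a true KL divergence must be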

+            chosen_logratios = policy_chosen_logps - reference_chosen_logps
+            rejected_logratios = policy_rejected_logps - reference_rejected_logps
+            # As described in the KTO report, the KL term for chosen (rejected) is estimated using the rejected (chosen) half.
+            losses = torch.cat(
+                (
+                    1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)),
+                    1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)),
+                ),
+                0,
+            )
        else:
-            raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo']")
+            raise ValueError(
+                f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto']"
+            )

        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
