[DPO] cDPO loss (#1035)
* add cDPO loss

* add comment

* docs

* info about label_smoothing not being used
kashif authored Nov 30, 2023
1 parent 4b67af3 commit c84e591
Showing 2 changed files with 21 additions and 40 deletions.
2 changes: 2 additions & 0 deletions docs/source/dpo_trainer.mdx
@@ -87,6 +87,8 @@ The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss

The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithm, identify an issue with overfitting, and propose an alternative loss that can be used via the `loss_type="ipo"` argument to the trainer.

The [cDPO](https://ericmitchell.ai/cdpo.pdf) loss is a tweak on the DPO loss that assumes the preference labels are noisy with some probability. That probability is passed to the `DPOTrainer` via the `label_smoothing` argument (a value between 0 and 0.5), and a conservative DPO loss is then used.
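
Concretely, writing $\epsilon$ for `label_smoothing`, $\beta$ for the temperature, and $z$ for the difference between the policy and reference log-ratios, the conservative loss mixes the standard DPO loss over the two possible labelings (a sketch in our own notation, not the paper's):

$$
\mathcal{L}_{\text{cDPO}}(z) = -(1 - \epsilon)\,\log \sigma(\beta z) - \epsilon\,\log \sigma(-\beta z)
$$

With $\epsilon = 0$ this reduces to the standard sigmoid DPO loss.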

## Logging

While training and evaluating we record the following reward metrics:
59 changes: 19 additions & 40 deletions trl/trainer/dpo_trainer.py
@@ -64,6 +64,8 @@ class DPOTrainer(Trainer):
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
beta (`float`, defaults to 0.1):
The beta factor in DPO loss. Higher beta means less divergence from the initial policy. For the IPO loss, beta is the regularization parameter denoted by tau in the paper.
label_smoothing (`float`, defaults to 0):
The robust DPO label smoothing parameter that should be between 0 and 0.5.
loss_type (`str`, defaults to `"sigmoid"`):
The type of DPO loss to use. Either `"sigmoid"` (the default DPO loss), `"hinge"` loss from the SLiC paper, or `"ipo"` loss from the IPO paper.
args (`transformers.TrainingArguments`):
@@ -120,6 +122,7 @@ def __init__(
model: Union[PreTrainedModel, nn.Module, str] = None,
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
beta: float = 0.1,
label_smoothing: float = 0,
loss_type: Literal["sigmoid", "hinge", "ipo"] = "sigmoid",
args: TrainingArguments = None,
data_collator: Optional[DataCollator] = None,
@@ -308,7 +311,13 @@ def make_inputs_require_grad(module, input, output):
self.label_pad_token_id = label_pad_token_id
self.padding_value = padding_value

if loss_type in ["hinge", "ipo"] and label_smoothing > 0:
warnings.warn(
"You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
)

self.beta = beta
self.label_smoothing = label_smoothing
self.loss_type = loss_type

self._stored_metrics = defaultdict(lambda: defaultdict(list))
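
For context, the new `label_smoothing` argument is passed at construction time. A minimal usage sketch (not part of this commit; `model`, `ref_model`, `training_args`, `train_dataset`, and `tokenizer` are assumed to be set up as in the existing DPO examples):

```python
from trl import DPOTrainer

# Hypothetical setup: only the new label_smoothing argument is of interest here;
# all other objects are assumed to exist as in the standard DPO example scripts.
trainer = DPOTrainer(
    model,
    ref_model,
    beta=0.1,
    label_smoothing=0.1,  # assumed noise rate of the preference labels, in [0, 0.5]
    loss_type="sigmoid",  # label smoothing only affects the sigmoid (cDPO) loss
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
```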
@@ -444,11 +453,18 @@ def dpo_loss(
logits = pi_logratios - ref_logratios

# The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
# We ignore the reference model as beta -> 0.
# We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
# calculates a conservative DPO loss.
if self.loss_type == "sigmoid":
losses = -F.logsigmoid(self.beta * logits)
losses = (
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
elif self.loss_type == "hinge":
losses = torch.relu(1 - self.beta * logits)
elif self.loss_type == "ipo":
# eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
losses = (logits - 1 / (2 * self.beta)) ** 2
else:
raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo']")
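
The new sigmoid branch can be sanity-checked in isolation. A minimal sketch of the conservative loss it implements (the helper name and toy values below are ours, not part of the trainer):

```python
import torch
import torch.nn.functional as F

def conservative_sigmoid_dpo_loss(
    logits: torch.Tensor, beta: float = 0.1, label_smoothing: float = 0.0
) -> torch.Tensor:
    # logits = (policy chosen - policy rejected) - (reference chosen - reference rejected) log-probs.
    # With label_smoothing=0 this is the standard sigmoid DPO loss; with 0 < label_smoothing <= 0.5
    # the flipped-label term models noisy preferences and keeps the optimum at a finite logit.
    return (
        -F.logsigmoid(beta * logits) * (1 - label_smoothing)
        - F.logsigmoid(-beta * logits) * label_smoothing
    )

# Toy check: with smoothing, the loss bottoms out at a finite logit instead of decreasing forever.
z = torch.tensor([0.0, 2.0, 10.0, 50.0])
print(conservative_sigmoid_dpo_loss(z, label_smoothing=0.0))  # monotonically decreasing in z
print(conservative_sigmoid_dpo_loss(z, label_smoothing=0.2))  # minimum near z = ln(4) / beta
```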

@@ -457,38 +473,6 @@

return losses, chosen_rewards, rejected_rewards

def ipo_loss(
self,
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Compute the IPO loss for a batch of policy and reference model log probabilities.
Args:
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
Returns:
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
The losses tensor contains the IPO loss for each example in the batch.
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
"""
pi_logratios = policy_chosen_logps + reference_rejected_logps
ref_logratios = policy_rejected_logps + reference_chosen_logps

logits = pi_logratios - ref_logratios
# eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
losses = (logits - 1 / (2 * self.beta)) ** 2

chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()

return losses, chosen_rewards, rejected_rewards

def _get_batch_logps(
self,
logits: torch.FloatTensor,
@@ -593,12 +577,7 @@ def get_batch_metrics(
_,
) = self.concatenated_forward(self.ref_model, batch)

if self.loss_type == "ipo":
loss_fn = self.ipo_loss
else:
loss_fn = self.dpo_loss

losses, chosen_rewards, rejected_rewards = loss_fn(
losses, chosen_rewards, rejected_rewards = self.dpo_loss(
policy_chosen_logps,
policy_rejected_logps,
reference_chosen_logps,
