Skip to content

Commit

Permalink
fix custom loss (#2374)
Browse files (browse the repository at this point in the history)
  • Loading branch information
Jintao-Huang authored Nov 1, 2024
1 parent 6b88670 commit e94e1e4
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
2 changes: 1 addition & 1 deletion requirements/framework.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ rouge
safetensors
tensorboard
tqdm
transformers>=4.33,<4.47
transformers>=4.33,<4.48
transformers_stream_generator
trl>=0.11.0
18 changes: 12 additions & 6 deletions swift/trainers/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,12 @@ def __init__(self, length_smooth: float = 0.9):
self._norm_factor = 0
self._smoothing = length_smooth

def __call__(self, outputs, labels) -> torch.Tensor:
def __call__(self, outputs, labels, num_items_in_batch=None) -> torch.Tensor:
# moving average
loss, masks = ce_loss_func(outputs, labels)
if num_items_in_batch is not None:
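# For transformers >= 4.46: normalize by num_items_in_batch so that gradient accumulation is equivalent to a single mini-batch; when it is not provided, fall back to the moving-average behavior below.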
return loss.sum() / num_items_in_batch
self._s_length = self._s_length * self._smoothing + loss.shape[0]
self._norm_factor = self._norm_factor * self._smoothing + 1
loss = loss.sum() / (self._s_length / self._norm_factor)
Expand All @@ -65,14 +68,17 @@ def __call__(self, outputs, labels) -> torch.Tensor:


@register_loss_func(LossName.loss_scale)
def loss_scale_func(outputs, labels, loss_scale=None) -> torch.Tensor:
def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
loss, masks = ce_loss_func(outputs, labels)
if loss_scale is None:
loss = loss.mean()
else:
if loss_scale is not None:
shift_scale = loss_scale[..., 1:].to(masks.device)
shift_scale = shift_scale[masks]
loss = (shift_scale * loss).mean()
loss = (shift_scale * loss)
if num_items_in_batch is None:
loss = loss.mean()
else:
# Compatibility with transformers>=4.46: normalize the summed loss by num_items_in_batch.
loss = loss.sum() / num_items_in_batch
return loss


Expand Down
6 changes: 3 additions & 3 deletions swift/trainers/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,17 +150,17 @@ def compute_loss(self, model, inputs, return_outputs=None, num_items_in_batch=No
if loss_name is None and 'loss_scale' in inputs:
loss_name = 'loss-scale'

loss_kwargs = {}
loss_kwargs = {'num_items_in_batch': num_items_in_batch}
if loss_name == 'loss-scale':
loss_kwargs['loss_scale'] = inputs.pop('loss_scale')
loss_kwargs['loss_scale'] = inputs.pop('loss_scale', None)

if loss_name is not None or self.label_smoother is not None and 'labels' in inputs:
labels = inputs.pop('labels')

loss_kwargs['labels'] = labels
outputs = model(**inputs)
# fix https://github.com/huggingface/transformers/issues/34263
if outputs.loss is not None and num_items_in_batch is not None:
if 'labels' in inputs and num_items_in_batch is not None:
outputs.loss = outputs.loss * (inputs['labels'][:, 1:] != -100).sum() / num_items_in_batch
if loss_name is not None:
loss_func = get_loss_func(loss_name)
Expand Down

0 comments on commit e94e1e4

Please sign in to comment.