Introduce Knowledge Distillation Base (#432)
## Summary

Recreates #417 from the main repo.

Thanks to the nice suggestions from @Tcc0403 and @hongpeng-guo. This PR
is the first split from
#408, focusing solely on
introducing the Knowledge Distillation base class. As a result, this PR
does not include any tests at the moment.

#### Code Changes

1. Refactor `beta` into two weights, `weight_hard_loss` and
`weight_soft_loss`, used as the coefficients of `hard_loss` and `soft_loss`
respectively. @Tcc0403 also pointed out that we could use `torch.lerp` if
applicable (see the sketch after this list).

2. Pass `teacher_logits` and `student_logits` directly to the divergence
loss function. This avoids redundant computations of converting logits
to log probabilities and then reverting them to raw logits. However, note
that we are not reusing the `student_log_probs` value calculated during
`ce_loss` in the distillation base.

    1. Remove the unnecessary `get_batch_logps` in `test/utils.py`.

3. Modify the `chunking` dimension from `B` to `B * T`. Thanks to
@hongpeng-guo's great advice.
    1. Fix the loss calculation to use per-token values instead of averaging
across the sequence length dimension.

4. Normalize the `distillation_loss` using `(full_target !=
ignore_index).sum()`.
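
As a rough illustration of points 1 and 4, here is a minimal sketch of how the two weights and the token-count normalization might combine. The helper name and the sum-reduced loss inputs are assumptions for illustration, not this PR's API:

```python
import torch


def combine_losses(ce_loss_sum, distill_loss_sum, full_target,
                   weight_hard_loss=0.5, weight_soft_loss=0.5, ignore_index=-100):
    # Normalize both sum-reduced losses by the number of non-ignored tokens,
    # then blend them with the two explicit weights.
    num_valid_tokens = (full_target != ignore_index).sum()
    hard_loss = ce_loss_sum / num_valid_tokens
    soft_loss = distill_loss_sum / num_valid_tokens
    # Equivalent to torch.lerp(hard_loss, soft_loss, weight_soft_loss)
    # whenever weight_hard_loss + weight_soft_loss == 1.
    return weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
```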

#### TODO  

1. [X] Although a slight slowdown is reasonable, we need to
investigate why this PR's implementation is **significantly slower**
compared to the naive approach. Thanks to @Tcc0403's clarification.
    
The issue arises because we are not properly configuring the
`chunk_size` for the `B * T` dimension, which is extremely large (a few
thousand). The previous default of 1 results in an excessive number of
chunks.

In contrast, this problem does not occur with the preference loss, as
chunking is performed on the `B` dimension. This produces fewer than 10
chunks, which is efficient and works as expected.

In conclusion, setting `chunk_size` to `1024` works pretty well in the new
benchmark results shown in
#425 (a quick worked example follows this list).

2. [ ]
#417 (comment)
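
As a quick worked example (with hypothetical shapes): for `B = 4` sequences of length `T = 2048`, chunking over `B * T = 8192` rows with the old default `chunk_size = 1` produces 8192 chunks, while `chunk_size = 1024` produces only 8. Chunking the preference loss over `B` alone produces just 4 chunks, which is why the slowdown never showed up there.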

#### Knowledge Distillation

Knowledge Distillation (KD; [Hinton et al.
2015](https://arxiv.org/abs/1503.02531), [Gou et al.
2020](https://arxiv.org/abs/2006.05525)) is a straightforward way to
build a smaller, cheaper model (“student model”) to speed up inference
by transferring skills from a pre-trained expensive model (“teacher
model”) into the student.

In knowledge distillation, a student model is trained to replicate the
outputs of a teacher model using a distillation loss. Neural networks
typically include a softmax layer; for instance, a large language model
produces a probability distribution over tokens. Let `z_t` and `z_s`
represent the logits before the softmax layer for the teacher and
student models, respectively. The distillation loss reduces the
discrepancy between the two softmax outputs at a high temperature `T`.
When ground truth labels `y` are available, this approach can be
combined with a supervised learning objective, such as cross-entropy, to
compare the student’s outputs with the ground truth.

The combined loss function is defined as:

```math
\mathcal{L}_{\text{knowledge distillation}} = w_{\text{soft}} \cdot \mathcal{L}_{\text{distill}}(\mathbf{z_t}, \mathbf{z_s}, T) + w_{\text{hard}} \cdot \mathcal{L}_{\text{cross entropy}}(\mathbf{y}, \mathbf{z_s}).
```

Here, we directly pass in `logits` rather than log probabilities. @Tcc0403
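
For intuition, here is a minimal, unfused sketch of the combined objective above. The function name, the forward-KL choice for the distillation term, and the default arguments are illustrative assumptions, not this PR's API; the fused, chunked version is in the diff below:

```python
import torch.nn.functional as F


def naive_kd_loss(student_logits, teacher_logits, labels,
                  weight_soft=0.5, weight_hard=0.5, temperature=1.0, ignore_index=-100):
    # Soft loss: KL divergence between temperature-softened teacher and student
    # distributions, computed directly from raw logits.
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.log_softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
        log_target=True,
    )
    # Hard loss: cross-entropy of the student's logits against ground-truth labels.
    hard_loss = F.cross_entropy(student_logits, labels, ignore_index=ignore_index)
    return weight_soft * soft_loss + weight_hard * hard_loss
```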

#### Shared `DistillationBase`

To support various distillation learning objectives, this PR aims to add
a `LigerFusedLinearDistillationBase`, which is basically the same as proposed
by @hongpeng-guo in this discussion:
#371 (comment).
Thank you @hongpeng-guo for thinking through this.
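
For reference, a concrete subclass mainly needs to supply `distillation_loss_fn`. Below is a minimal sketch; the class name and the plain KL divergence are illustrative only (the actual JSD loss is planned for #425):

```python
from torch.nn import functional as F

from liger_kernel.chunked_loss.fused_linear_distillation import (
    LigerFusedLinearDistillationBase,
)


class ExampleKLDistillationLoss(LigerFusedLinearDistillationBase):
    @staticmethod
    def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0):
        # Forward KL(teacher || student) on temperature-softened distributions,
        # sum-reduced so the base class can normalize over the full batch.
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)
        return F.kl_div(
            student_log_probs,
            teacher_log_probs,
            reduction="sum",
            log_target=True,
        )
```

Since the base class is a `torch.autograd.Function`, such a subclass would be invoked through `.apply(...)` with the arguments of `forward` shown in the diff below.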

## Testing Done

I'll post JSD tests and benchmark results in the next PR:
#425

- Hardware Type: L40S
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu <[email protected]>
Co-authored-by: shivam15s <[email protected]>
austin362667 and shivam15s authored Dec 9, 2024
1 parent d887657 commit fcba35a
Showing 3 changed files with 461 additions and 101 deletions.
250 changes: 250 additions & 0 deletions src/liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -0,0 +1,250 @@
from abc import abstractmethod
from functools import partial

import torch
from torch.nn import functional as F


class LigerFusedLinearDistillationBase(torch.autograd.Function):

    @abstractmethod
    def distillation_loss_fn(student_logits, teacher_logits, temperature):
        """
        Compute distillation loss.
        Args:
            student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
            teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
            temperature (float): Temperature used to soften both distributions.
        """
        raise NotImplementedError("Distillation loss function must be implemented.")

    @staticmethod
    def chunk_forward(
        student_input_chunk,
        student_weight,
        teacher_input_chunk,
        teacher_weight,
        target_chunk,
        student_bias=None,
        teacher_bias=None,
        ignore_index=-100,
        compute_ce_loss=True,
    ):
        # Student
        student_logits_chunk = student_input_chunk @ student_weight.t()
        if student_bias is not None:
            student_logits_chunk += student_bias
        student_log_probs_chunk = F.log_softmax(student_logits_chunk.float(), dim=-1)

        # Teacher
        with torch.no_grad():
            teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t()
            if teacher_bias is not None:
                teacher_logits_chunk += teacher_bias

        # The hard/task loss
        ce_loss = 0.0
        if compute_ce_loss:
            ce_loss = F.nll_loss(
                student_log_probs_chunk.view(-1, student_log_probs_chunk.shape[-1]),
                target_chunk.view(-1),
                reduction="sum",
                ignore_index=ignore_index,
            )

        return student_logits_chunk, teacher_logits_chunk, ce_loss

    @staticmethod
    def _compute_loss(
        student_input_chunk,
        student_weight,
        teacher_input_chunk,
        teacher_weight,
        target_chunk,
        student_bias=None,
        teacher_bias=None,
        distillation_loss_fn=None,
        full_target=None,
        ignore_index=-100,
        temperature=1.0,
        weight_hard_loss=0.5,
        weight_soft_loss=0.5,
        compute_ce_loss=True,
        **loss_kwargs,
    ):
        """
        Compute the total loss for a chunk of input and target, while using a knowledge distillation loss function.
        Args:
            distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
            student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size).
            teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size).
            teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size).
            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
            student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
            full_target (torch.Tensor): Full target tensor. Shape: (batch_size * seq_len,).
            ignore_index (int): Index to ignore for loss computation.
            temperature (float): Temperature used to soften the distributions for the soft loss.
            weight_hard_loss (float): Weight for hard loss.
            weight_soft_loss (float): Weight for soft loss.
            compute_ce_loss (bool): Whether to compute CE loss.
            loss_kwargs (dict): Additional arguments for the loss function.
        """
        student_logits_chunk, teacher_logits_chunk, hard_loss = (
            LigerFusedLinearDistillationBase.chunk_forward(
                student_input_chunk,
                student_weight,
                teacher_input_chunk,
                teacher_weight,
                target_chunk,
                student_bias=student_bias,
                teacher_bias=teacher_bias,
                ignore_index=ignore_index,
                compute_ce_loss=compute_ce_loss,
            )
        )

        hard_loss /= full_target.shape[0]

        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature)
        soft_loss /= full_target.shape[0]

        loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
        return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)

    @staticmethod
    def forward(
        ctx,
        student_input,
        student_weight,
        teacher_input,
        teacher_weight,
        target,
        student_bias=None,
        teacher_bias=None,
        loss_fn=None,
        chunk_size=1024,
        ignore_index=-100,
        weight_hard_loss=0.5,
        weight_soft_loss=0.5,
        compute_ce_loss=True,
        temperature=1.0,
        compiled=True,
        **loss_kwargs,
    ):
        """
        Base class for fused linear layer with distillation loss.
        Only need to compute gradients for student model.
        Args:
            student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, student_hidden_size).
            student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, student_hidden_size).
            teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, teacher_hidden_size).
            teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, teacher_hidden_size).
            target (torch.Tensor): Target truth label tensor. Shape: (batch_size * seq_len).
            student_bias (torch.Tensor, optional): Student bias tensor. Shape: (vocab_size,).
            teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,).
            loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
            chunk_size (int): Size of a chunk.
            compute_ce_loss (bool): Whether to compute CE loss.
            ignore_index (int): Index to ignore for loss computation.
            weight_hard_loss (float): Weight for hard/task loss.
            weight_soft_loss (float): Weight for soft/distillation loss.
            temperature (float): Temperature used to soften the distributions for the distillation loss.
            compiled (bool): Whether to use torch compile for chunk accumulation.
            loss_kwargs (dict): Other possible arguments that a loss function might need
        """
        CHUNK_SIZE = chunk_size
        grad_weight = torch.zeros_like(student_weight)
        grad_inputs = []
        grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None
        loss_acc = torch.zeros((), device=student_input.device)

        loss_func_to_call = partial(
            LigerFusedLinearDistillationBase._compute_loss,
            distillation_loss_fn=loss_fn,
            full_target=target,
            ignore_index=ignore_index,
            weight_hard_loss=weight_hard_loss,
            weight_soft_loss=weight_soft_loss,
            compute_ce_loss=compute_ce_loss,
            temperature=temperature,
            **loss_kwargs,
        )

        def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
            if student_bias is not None:
                (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
                    chunk_loss,
                    (
                        chunk_soft_loss,
                        chunk_hard_loss,
                        chunk_student_logits,
                        chunk_teacher_logits,
                    ),
                ) = torch.func.grad_and_value(
                    loss_func_to_call, argnums=(0, 1, 5), has_aux=True
                )(
                    student_input_chunk,
                    student_weight,
                    teacher_input_chunk,
                    teacher_weight,
                    target_chunk,
                    student_bias,
                    teacher_bias,
                )
                grad_bias.add_(chunk_grad_bias)
            else:
                (chunk_grad_input, chunk_grad_weight), (
                    chunk_loss,
                    (
                        chunk_soft_loss,
                        chunk_hard_loss,
                        chunk_student_logits,
                        chunk_teacher_logits,
                    ),
                ) = torch.func.grad_and_value(
                    loss_func_to_call, argnums=(0, 1), has_aux=True
                )(
                    student_input_chunk,
                    student_weight,
                    teacher_input_chunk,
                    teacher_weight,
                    target_chunk,
                    student_bias,
                    teacher_bias,
                )
            grad_weight.add_(chunk_grad_weight)
            loss_acc.add_(chunk_loss)
            return chunk_grad_input

        if compiled:
            accumulate_chunk = torch.compile(accumulate_chunk)

        num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
        _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
        _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
        _target_chunks = torch.chunk(target, chunks=num_chunks, dim=0)

        for student_input_chunk, teacher_input_chunk, target_chunk in zip(
            _student_input_chunks, _teacher_input_chunks, _target_chunks
        ):
            grad_input = accumulate_chunk(
                student_input_chunk, teacher_input_chunk, target_chunk
            )
            grad_inputs.append(grad_input)

        ctx.save_for_backward(
            torch.cat(grad_inputs, dim=0),
            grad_weight,
            grad_bias,
        )
        return loss_acc

    @staticmethod
    def backward(ctx, grad_output):
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
            grad_input = grad_input * grad_output
            grad_weight = grad_weight * grad_output
            grad_bias = grad_bias * grad_output if grad_bias is not None else None

        return grad_input, grad_weight, None, grad_bias