From fcba35a06720fce91b2bb6cf6486ab5d37929853 Mon Sep 17 00:00:00 2001
From: Austin Liu
Date: Mon, 9 Dec 2024 14:08:55 +0800
Subject: [PATCH] Introduce Knowledge Distillation Base (#432)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

Remakes https://github.com/linkedin/Liger-Kernel/pull/417 from the main repo. Thanks to the helpful suggestions from @Tcc0403 and @hongpeng-guo.

This PR is the first split from https://github.com/linkedin/Liger-Kernel/pull/408, focusing solely on introducing the Knowledge Distillation base class. As a result, this PR does not include any tests at the moment.

#### Code Changes

1. Refactor `beta` into two weights, `weight_hard_loss` and `weight_soft_loss`, used as coefficients for `hard_loss` and `soft_loss`. @Tcc0403 also pointed out that we could use `torch.lerp` if applicable.
2. Pass `teacher_logits` and `student_logits` directly to the divergence loss function. This avoids the redundant computation of converting logits to log probabilities and then reverting them to raw logits. Note, however, that we are not reusing the `student_log_probs` value calculated during `ce_loss` in the distillation base.
   1. Remove the unnecessary `get_batch_logps` from `test/utils.py`.
3. Change the `chunking` dimension from `B` to `B * T`. Thanks to @hongpeng-guo's great advice.
   1. Fix the loss calculation to use per-token values instead of averaging across the sequence-length dimension.
4. Normalize the `distillation_loss` using `(full_target != ignore_index).sum()`.

#### TODO

1. [X] Although a slight slowdown is reasonable, we need to investigate why this PR's implementation is **significantly slower** than the naive approach. Thanks to @Tcc0403's clarification: the issue arises because we were not properly configuring the `chunk_size` for the `B * T` dimension, which is extremely large (a few thousand). The previous default of 1 results in an excessive number of chunks. In contrast, this problem does not occur with the preference loss, where chunking is performed on the `B` dimension; that produces fewer than 10 chunks, which is efficient and works as expected. In conclusion, setting `chunk_size` to `1024` works pretty well in the new benchmark results, as shown in https://github.com/linkedin/Liger-Kernel/pull/425.
2. [ ] https://github.com/linkedin/Liger-Kernel/pull/417#discussion_r1874231427

#### Knowledge Distillation

Knowledge Distillation (KD; [Hinton et al. 2015](https://arxiv.org/abs/1503.02531), [Gou et al. 2020](https://arxiv.org/abs/2006.05525)) is a straightforward way to build a smaller, cheaper model (“student model”) that speeds up inference by transferring skills from a pre-trained, expensive model (“teacher model”) into the student.

In knowledge distillation, a student model is trained to replicate the outputs of a teacher model using a distillation loss. Neural networks typically include a softmax layer; for instance, a large language model produces a probability distribution over tokens. Let `z_t` and `z_s` be the logits before the softmax layer for the teacher and student models, respectively. The distillation loss reduces the discrepancy between the two softmax outputs at a high temperature `T`. When ground-truth labels `y` are available, this approach can be combined with a supervised learning objective, such as cross-entropy, that compares the student's outputs with the ground truth.
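As a rough, minimal sketch of this weighted soft/hard objective (the formal definition follows below): the soft term compares the two temperature-scaled distributions and the hard term is plain cross-entropy. The forward-KL choice, weights, and temperature here are purely illustrative, since the base class leaves the exact divergence to `distillation_loss_fn`.

```python
import torch.nn.functional as F


def naive_distillation_loss(
    student_logits,       # (batch_size * seq_len, vocab_size)
    teacher_logits,       # (batch_size * seq_len, vocab_size)
    labels,               # (batch_size * seq_len,)
    weight_hard_loss=0.5,
    weight_soft_loss=0.5,
    temperature=1.0,
    ignore_index=-100,
):
    # Soft loss: divergence between teacher and student distributions at temperature T
    # (forward KL here, purely for illustration).
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    )
    # Hard loss: cross-entropy between the student's logits and the ground-truth labels.
    hard_loss = F.cross_entropy(student_logits, labels, ignore_index=ignore_index)
    return weight_soft_loss * soft_loss + weight_hard_loss * hard_loss
```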
The combined loss function is defined as:

```math
\mathcal{L}_{\text{knowledge distillation}} = w_{\text{soft}} \cdot \mathcal{L}_{\text{distill}}(\mathbf{z_t}, \mathbf{z_s}, T) + w_{\text{hard}} \cdot \mathcal{L}_{\text{cross entropy}}(\mathbf{y}, \mathbf{z_s}),
```

Here, we directly pass in `logits` rather than `logprobs`. @Tcc0403

#### Shared `DistillationBase`

To support various distillation learning objectives, this PR adds a `LigerFusedLinearDistillationBase`, which is basically the same as what @hongpeng-guo proposed in this discussion: https://github.com/linkedin/Liger-Kernel/issues/371#issuecomment-2496940347. Thank you @hongpeng-guo for thinking through this.

## Testing Done

I'll post JSD tests and benchmark results in the next PR: https://github.com/linkedin/Liger-Kernel/pull/425

- Hardware Type: L40S
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu
Co-authored-by: shivam15s
---
 .../chunked_loss/fused_linear_distillation.py | 250 ++++++++++++++++++ .../chunked_loss/fused_linear_preference.py | 202 +++++++------- test/utils.py | 110 ++++++++ 3 files changed, 461 insertions(+), 101 deletions(-) create mode 100644 src/liger_kernel/chunked_loss/fused_linear_distillation.py diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py new file mode 100644 index 000000000..11ae767f6 --- /dev/null +++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py @@ -0,0 +1,250 @@ +from abc import abstractmethod +from functools import partial + +import torch +from torch.nn import functional as F + + +class LigerFusedLinearDistillationBase(torch.autograd.Function): + + @abstractmethod + def distillation_loss_fn(student_logits, teacher_logits, temperature): + """ + Compute distillation loss. + Args: + student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len, vocab_size). + teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+ """ + raise NotImplementedError("Distillation loss function must be implemented.") + + @staticmethod + def chunk_forward( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias=None, + teacher_bias=None, + ignore_index=-100, + compute_ce_loss=True, + ): + # Student + student_logits_chunk = student_input_chunk @ student_weight.t() + if student_bias is not None: + student_logits_chunk += student_bias + student_log_probs_chunk = F.log_softmax(student_logits_chunk.float(), dim=-1) + + # Teacher + with torch.no_grad(): + teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t() + if teacher_bias is not None: + teacher_logits_chunk += teacher_bias + + # The hard/task loss + ce_loss = 0.0 + if compute_ce_loss: + ce_loss = F.nll_loss( + student_log_probs_chunk.view(-1, student_log_probs_chunk.shape[-1]), + target_chunk.view(-1), + reduction="sum", + ignore_index=ignore_index, + ) + + return student_logits_chunk, teacher_logits_chunk, ce_loss + + @staticmethod + def _compute_loss( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias=None, + teacher_bias=None, + distillation_loss_fn=None, + full_target=None, + ignore_index=-100, + temperature=1.0, + weight_hard_loss=0.5, + weight_soft_loss=0.5, + compute_ce_loss=True, + **loss_kwargs, + ): + """ + Compute the total loss for a chunk of input and target, while using an knowleedge distillation loss function. + Args: + distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target. + student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size). + student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size). + teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size). + teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size). + target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,). + student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). + teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). + full_target (torch.Tensor): Full target tensor. Shape: (chunk_size,). + ignore_index (int): Index to ignore for loss computation. + weight_hard_loss (float): Weight for hard loss. + weight_soft_loss (float): Weight for soft loss. + compute_ce_loss (bool): Whether to compute CE loss. + loss_kwargs (dict): Additional arguments for the loss function. 
+ """ + student_logits_chunk, teacher_logits_chunk, hard_loss = ( + LigerFusedLinearDistillationBase.chunk_forward( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias=student_bias, + teacher_bias=teacher_bias, + ignore_index=ignore_index, + compute_ce_loss=compute_ce_loss, + ) + ) + + hard_loss /= full_target.shape[0] + + soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, temperature) + soft_loss /= full_target.shape[0] + + loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss + return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk) + + @staticmethod + def forward( + ctx, + student_input, + student_weight, + teacher_input, + teacher_weight, + target, + student_bias=None, + teacher_bias=None, + loss_fn=None, + chunk_size=1024, + ignore_index=-100, + weight_hard_loss=0.5, + weight_soft_loss=0.5, + compute_ce_loss=True, + temperature=1.0, + compiled=True, + **loss_kwargs, + ): + """ + Base class for fused linear layer with distillation loss. + Only need to compute gradients for student model. + + Args: + student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, student_hidden_size). + student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, student_hidden_size). + teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, teacher_hidden_size). + teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, teacher_hidden_size). + target (torch.Tensor): Target truth label tensor. Shape: (batch_size * seq_len). + student_bias (torch.Tensor, optional): Student bias tensor. Shape: (vocab_size,). + teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,). + loss_fn (callable): Loss function to compute the loss on a chunk of input/target. + chunk_size (int): Size of a chunk. + compute_ce_loss (bool): Whether to compute CE loss. + ignore_index (int): Index to ignore for loss computation. + weight_hard_loss (float): Weight for hard/task loss. + weight_soft_loss (float): Weight for soft/distillation loss. + compiled (bool): Whether to use torch compile for chunk accumulation. 
+ loss_kwargs (dict): Other possible arguments that a loss function might need + """ + CHUNK_SIZE = chunk_size + grad_weight = torch.zeros_like(student_weight) + grad_inputs = [] + grad_bias = torch.zeros_like(student_bias) if student_bias is not None else None + loss_acc = torch.zeros((), device=student_input.device) + + loss_func_to_call = partial( + LigerFusedLinearDistillationBase._compute_loss, + distillation_loss_fn=loss_fn, + full_target=target, + ignore_index=ignore_index, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + compute_ce_loss=compute_ce_loss, + temperature=temperature, + **loss_kwargs, + ) + + def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk): + if student_bias is not None: + (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), ( + chunk_loss, + ( + chunk_soft_loss, + chunk_hard_loss, + chunk_student_logits, + chunk_teacher_logits, + ), + ) = torch.func.grad_and_value( + loss_func_to_call, argnums=(0, 1, 5), has_aux=True + )( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias, + teacher_bias, + ) + grad_bias.add_(chunk_grad_bias) + else: + (chunk_grad_input, chunk_grad_weight), ( + chunk_loss, + ( + chunk_soft_loss, + chunk_hard_loss, + chunk_student_logits, + chunk_teacher_logits, + ), + ) = torch.func.grad_and_value( + loss_func_to_call, argnums=(0, 1), has_aux=True + )( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias, + teacher_bias, + ) + grad_weight.add_(chunk_grad_weight) + loss_acc.add_(chunk_loss) + return chunk_grad_input + + if compiled: + accumulate_chunk = torch.compile(accumulate_chunk) + + num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE) + _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0) + _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0) + _target_chunks = torch.chunk(target, chunks=num_chunks, dim=0) + + for student_input_chunk, teacher_input_chunk, target_chunk in zip( + _student_input_chunks, _teacher_input_chunks, _target_chunks + ): + grad_input = accumulate_chunk( + student_input_chunk, teacher_input_chunk, target_chunk + ) + grad_inputs.append(grad_input) + + ctx.save_for_backward( + torch.cat(grad_inputs, dim=0), + grad_weight, + grad_bias, + ) + return loss_acc + + @staticmethod + def backward(ctx, grad_output): + grad_input, grad_weight, grad_bias = ctx.saved_tensors + if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)): + grad_input = grad_input * grad_output + grad_weight = grad_weight * grad_output + grad_bias = grad_bias * grad_output if grad_bias is not None else None + + return grad_input, grad_weight, None, grad_bias diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index c31cbba8b..26ae38a3d 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -64,6 +64,103 @@ def chunk_forward( chosen_nll_loss, ) + @staticmethod + def _compute_loss( + input_chunk, + weight, + target_chunk, + bias=None, + preference_loss_fn=None, + full_target=None, + ignore_index=-100, + alpha=1.0, + beta=0.1, + compute_nll_loss=True, + use_ref_model=False, + ref_weight=None, + ref_bias=None, + **loss_kwargs, + ): + """ + Compute the total loss for a chunk of input and target, while using an alignment/preference loss function. 
+ Args: + preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target. + input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size). + weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). + target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length). + bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). + full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length). + ignore_index (int): Index to ignore for loss computation. + alpha (float): Weight for the NLL loss. + beta (float): Weight for the odds ratio loss. + compute_nll_loss (bool): Whether to compute NLL loss. + use_ref_model (bool): Whether to use a reference model for the alignment loss. + ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). + ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). + loss_kwargs (dict): Additional arguments for the loss function. + """ + ( + chosen_logps, + rejected_logps, + chosen_logits, + rejected_logits, + chosen_nll_loss, + ) = LigerFusedLinearPreferenceBase.chunk_forward( + input_chunk, + weight, + target_chunk, + bias=bias, + ignore_index=ignore_index, + compute_nll_loss=compute_nll_loss, + ) + chosen_nll_loss = ( + chosen_nll_loss + / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() + ) + chosen_logits_mean = chosen_logits.sum() / ( + full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] + ) + rejected_logits_mean = rejected_logits.sum() / ( + full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] + ) + + if use_ref_model: + with torch.no_grad(): + ( + ref_chosen_logps, + ref_rejected_logps, + ref_chosen_logits, + ref_rejected_logits, + ref_chosen_nll_loss, + ) = LigerFusedLinearPreferenceBase.chunk_forward( + input_chunk, + ref_weight, + target_chunk, + ref_bias, + ignore_index=ignore_index, + compute_nll_loss=False, # We don't need NLL loss for the reference model + ) + loss_kwargs["ref_chosen_logps"] = ref_chosen_logps + loss_kwargs["ref_rejected_logps"] = ref_rejected_logps + + preference_loss_outputs = preference_loss_fn( + chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs + ) + if isinstance(preference_loss_outputs, tuple): + preference_loss, *aux_outputs = preference_loss_outputs + else: + preference_loss, aux_outputs = preference_loss_outputs, [] + + loss = alpha * chosen_nll_loss - preference_loss + return_vars = ( + chosen_logps, + rejected_logps, + chosen_logits_mean, + rejected_logits_mean, + chosen_nll_loss, + ) + return loss, (*return_vars, *aux_outputs) + @staticmethod def forward( ctx, @@ -134,7 +231,7 @@ def forward( **loss_kwargs, ) - def accumulate_helper(input_chunk, target_chunk): + def accumulate_core(input_chunk, target_chunk): if bias is not None: return torch.func.grad_and_value( loss_func_to_call, argnums=(0, 1, 3), has_aux=True @@ -156,7 +253,7 @@ def accumulate_chunk(input_chunk, target_chunk): chunk_nll_loss, *aux_outputs, ), - ) = accumulate_helper(input_chunk, target_chunk) + ) = accumulate_core(input_chunk, target_chunk) grad_bias.add_(chunk_grad_bias) # accumulate bias gradient else: (chunk_grad_input, chunk_grad_weight), ( @@ -169,7 +266,7 @@ def accumulate_chunk(input_chunk, target_chunk): chunk_nll_loss, *aux_outputs, ), - ) = accumulate_helper(input_chunk, target_chunk) + ) = accumulate_core(input_chunk, target_chunk) grad_weight.add_(chunk_grad_weight) 
loss_acc.add_(chunk_loss) @@ -199,7 +296,7 @@ def accumulate_chunk(input_chunk, target_chunk): return chunk_grad_input if compiled: - accumulate_helper = torch.compile(accumulate_helper) + accumulate_core = torch.compile(accumulate_core) len_chosen = target.shape[0] // 2 chunks = max(1, _input.shape[0] // (2 * CHUNK_SIZE)) @@ -270,100 +367,3 @@ def backward(ctx, *grad_output): grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None return grad_input, grad_weight, None, grad_bias, None, None, None - - @staticmethod - def _compute_loss( - input_chunk, - weight, - target_chunk, - bias=None, - preference_loss_fn=None, - full_target=None, - ignore_index=-100, - alpha=1.0, - beta=0.1, - compute_nll_loss=True, - use_ref_model=False, - ref_weight=None, - ref_bias=None, - **loss_kwargs, - ): - """ - Compute the total loss for a chunk of input and target, while using an alignment/preference loss function. - Args: - preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target. - input_chunk (torch.Tensor): Chunk of input tensor. Shape: (2 * chunk_size, sequence_length, hidden_size). - weight (torch.Tensor): Weight tensor. Shape: (vocab_size, hidden_size). - target_chunk (torch.Tensor): Chunk of target tensor. Shape: (2 * chunk_size, sequence_length). - bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). - full_target (torch.Tensor): Full target tensor. Shape: (batch_size, sequence_length). - ignore_index (int): Index to ignore for loss computation. - alpha (float): Weight for the NLL loss. - beta (float): Weight for the odds ratio loss. - compute_nll_loss (bool): Whether to compute NLL loss. - use_ref_model (bool): Whether to use a reference model for the alignment loss. - ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). - ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). - loss_kwargs (dict): Additional arguments for the loss function. 
- """ - ( - chosen_logps, - rejected_logps, - chosen_logits, - rejected_logits, - chosen_nll_loss, - ) = LigerFusedLinearPreferenceBase.chunk_forward( - input_chunk, - weight, - target_chunk, - bias=bias, - ignore_index=ignore_index, - compute_nll_loss=compute_nll_loss, - ) - chosen_nll_loss = ( - chosen_nll_loss - / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() - ) - chosen_logits_mean = chosen_logits.sum() / ( - full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] - ) - rejected_logits_mean = rejected_logits.sum() / ( - full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] - ) - - if use_ref_model: - with torch.no_grad(): - ( - ref_chosen_logps, - ref_rejected_logps, - ref_chosen_logits, - ref_rejected_logits, - ref_chosen_nll_loss, - ) = LigerFusedLinearPreferenceBase.chunk_forward( - input_chunk, - ref_weight, - target_chunk, - ref_bias, - ignore_index=ignore_index, - compute_nll_loss=False, # We don't need NLL loss for the reference model - ) - loss_kwargs["ref_chosen_logps"] = ref_chosen_logps - loss_kwargs["ref_rejected_logps"] = ref_rejected_logps - - preference_loss_outputs = preference_loss_fn( - chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs - ) - if isinstance(preference_loss_outputs, tuple): - preference_loss, *aux_outputs = preference_loss_outputs - else: - preference_loss, aux_outputs = preference_loss_outputs, [] - - loss = alpha * chosen_nll_loss - preference_loss - return_vars = ( - chosen_logps, - rejected_logps, - chosen_logits_mean, - rejected_logits_mean, - chosen_nll_loss, - ) - return loss, (*return_vars, *aux_outputs) diff --git a/test/utils.py b/test/utils.py index 711c4f870..29e0d9143 100644 --- a/test/utils.py +++ b/test/utils.py @@ -519,3 +519,113 @@ def get_batch_loss_metrics( policy_nll_loss, ) return loss, (*return_vars, *aggregated_aux_outputs) + + +class HFDistillationLoss: + def __init__( + self, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + ignore_index: int = -100, + temperature: float = 1, + ): + self.weight_hard_loss = weight_hard_loss + self.weight_soft_loss = weight_soft_loss + self.ignore_index = ignore_index + self.temperature = temperature + + @abstractmethod + def distillation_loss(self, student_logits, teacher_logits): + """Abstract method for computing distillation loss.""" + pass + + def concatenated_forward( + self, + student_input: torch.FloatTensor, + student_weight: torch.FloatTensor, + teacher_input: torch.FloatTensor, + teacher_weight: torch.FloatTensor, + target: torch.LongTensor, + student_bias: torch.FloatTensor = None, + teacher_bias: torch.FloatTensor = None, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + ]: + """Compute forward pass for both student and teacher models.""" + + student_batch_seq_len_size, student_hidden_size = student_input.shape + student_input_reshaped = student_input.view(-1, student_hidden_size) + teacher_batch_seq_len_size, teacher_hidden_size = teacher_input.shape + teacher_input_reshaped = teacher_input.view(-1, teacher_hidden_size) + + student_outputs = student_input_reshaped @ student_weight.t() + if student_bias is not None: + student_outputs = student_outputs + student_bias + + with torch.no_grad(): + teacher_outputs = teacher_input_reshaped @ teacher_weight.t() + if teacher_bias is not None: + teacher_outputs = teacher_outputs + teacher_bias + + student_logits = student_outputs.view(student_batch_seq_len_size, -1).float() + teacher_logits = 
teacher_outputs.view(teacher_batch_seq_len_size, -1).float() + + if torch.all(target == self.ignore_index): + return torch.tensor(0.0) + + def cross_entropy_loss(logits, labels): + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index) + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1) + # Enable model parallelism + labels = labels.to(logits.device) + loss = loss_fct(logits, labels) + return loss + + labels = target + ce_loss = cross_entropy_loss( + student_logits.view(-1, student_logits.shape[-1]), + labels.view(-1), + ) + + return ( + student_logits, + teacher_logits, + ce_loss, + ) + + def get_batch_loss_metrics( + self, + student_input: torch.FloatTensor, + student_weight: torch.FloatTensor, + teacher_input: torch.FloatTensor, + teacher_weight: torch.FloatTensor, + target: torch.LongTensor, + student_bias: torch.FloatTensor = None, + teacher_bias: torch.FloatTensor = None, + ): + """Compute the distillation loss metrics for the given batch.""" + forward_output = self.concatenated_forward( + student_input, + student_weight, + teacher_input, + teacher_weight, + target, + student_bias, + teacher_bias, + ) + ( + student_logits, + teacher_logits, + hard_loss, + ) = forward_output + + soft_loss = self.distillation_loss(student_logits, teacher_logits) + # full loss + loss = self.weight_hard_loss * hard_loss + self.weight_soft_loss * soft_loss.mean() + return loss
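To show how the `HFDistillationLoss` test helper above is intended to be used, here is a hedged sketch of a concrete subclass with a forward-KL soft loss and a toy invocation. The class name `HFKLDivLoss`, the import path, and the shapes are hypothetical; the actual JSD reference loss, tests, and benchmarks land in https://github.com/linkedin/Liger-Kernel/pull/425.

```python
import torch
import torch.nn.functional as F

from test.utils import HFDistillationLoss  # hypothetical import path for the helper above


class HFKLDivLoss(HFDistillationLoss):
    """Reference (non-fused) distillation loss using a forward-KL soft term."""

    def distillation_loss(self, student_logits, teacher_logits):
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
        # Per-token KL; `get_batch_loss_metrics` applies `.mean()` to this soft loss.
        return F.kl_div(student_log_probs, teacher_probs, reduction="none").sum(dim=-1)


# Toy shapes, flattened to (batch_size * seq_len, hidden_size) as the helper expects.
B_T, student_hidden, teacher_hidden, vocab = 8, 16, 32, 64
student_input = torch.randn(B_T, student_hidden)
teacher_input = torch.randn(B_T, teacher_hidden)
student_weight = torch.randn(vocab, student_hidden, requires_grad=True)
teacher_weight = torch.randn(vocab, teacher_hidden)
target = torch.randint(0, vocab, (B_T,))

loss = HFKLDivLoss(
    weight_hard_loss=0.5, weight_soft_loss=0.5, temperature=2.0
).get_batch_loss_metrics(
    student_input, student_weight, teacher_input, teacher_weight, target
)
loss.backward()  # reference gradients, e.g. to compare against the fused chunked implementation
```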