align jsd with paper & twist tests beta
Signed-off-by: Austin Liu <[email protected]>
austin362667 committed Jan 23, 2025
1 parent 50022c0 commit 7268f64
Showing 2 changed files with 6 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/liger_kernel/chunked_loss/jsd_loss.py
@@ -26,8 +26,8 @@ def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
         student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
         teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
 
-        # JSD is the average of the KL divergences
-        jsd_loss = beta * student_kl + (1 - beta) * teacher_kl
+        # JSD is the weighted average of the KL divergences
+        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
         return jsd_loss
 
     @staticmethod
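For context, the hunk above shows only the tail of distillation_loss_fn. The sketch below is a minimal, self-contained reconstruction of the beta-weighted JSD those lines compute; the weighted_jsd name and the construction of log_mean_probs as the log of a beta-weighted mixture are illustrative assumptions and are not shown in this hunk.

import torch
import torch.nn.functional as F

def weighted_jsd(student_logits, teacher_logits, beta=0.5):
    # Hypothetical standalone sketch; only the KL and weighting lines mirror the diff above.
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

    # Assumed mixture: (1 - beta) * P_student + beta * P_teacher (not part of this hunk).
    mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
    log_mean_probs = mean_probs.log()

    # F.kl_div(input, target, log_target=True) computes KL(target || input), so these
    # are KL(student || mixture) and KL(teacher || mixture).
    student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
    teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)

    # JSD as the weighted average of the two KL divergences (the new lines in this commit).
    return beta * teacher_kl + (1 - beta) * student_kl

Under the assumed mixture, beta = 0.5 recovers the usual symmetric JSD (0.5 * KL(P || M) + 0.5 * KL(Q || M) with M the equal-weight mixture), while other beta values skew the loss toward the teacher or the student term.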
6 changes: 4 additions & 2 deletions test/chunked_loss/test_jsd_loss.py
@@ -57,8 +57,8 @@ def distillation_loss(self, student_logits, teacher_logits, beta=0.5):
         student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="batchmean", log_target=True)
         teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="batchmean", log_target=True)
 
-        # JSD is the average of the KL divergences
-        jsd_loss = beta * student_kl + (1 - beta) * teacher_kl
+        # JSD is the weighted average of the KL divergences
+        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
         return jsd_loss


@@ -167,6 +167,8 @@ def forward(self, student_input, teacher_input, target):
         (0.5, 0.5, 0.5, 0.5),
         (1.0, 0.0, 1.0, 0.5),
         (1.0, 1.0, 0.0, 0.5),
+        (1.0, 0.5, 0.5, 0.3),
+        (2.0, 0.5, 0.5, 0.7),
     ],
 )
 def test_correctness(
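The two added parameter rows appear to exercise asymmetric weightings, assuming beta is the last field of each tuple (as the existing rows and the commit message suggest). A hypothetical sanity check using the weighted_jsd sketch above: once beta differs from 0.5, swapping the student and teacher roles should change the loss.

torch.manual_seed(0)
student = torch.randn(4, 8)  # stand-in logits; shapes chosen arbitrarily
teacher = torch.randn(4, 8)

for beta in (0.3, 0.5, 0.7):
    forward = weighted_jsd(student, teacher, beta=beta)
    swapped = weighted_jsd(teacher, student, beta=beta)
    # The two values coincide only at beta = 0.5; otherwise the weighting is asymmetric.
    print(f"beta={beta}: {forward.item():.6f} vs swapped {swapped.item():.6f}")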
