diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index 7b182f657..b68ea49f9 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -745,3 +745,27 @@ kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,5544.25390625,5544.253906 kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,9057.287109375,9057.287109375,9057.287109375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2 kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,16087.353515625,16087.353515625,16087.353515625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2 kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,30147.486328125,30147.486328125,30147.486328125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-23 23:34:59,0.5.2 +distill_jsd_loss,liger,forward,speed,ms,BT,B x T,1024,7.735536098480225,7.729177474975586,7.798131465911865,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2 +distill_jsd_loss,liger,forward,speed,ms,BT,B x T,2048,15.20411205291748,15.165056228637695,15.226079940795898,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2 +distill_jsd_loss,liger,forward,speed,ms,BT,B x 
T,4096,30.159456253051758,30.126911163330078,30.165311813354492,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2 +distill_jsd_loss,liger,forward,speed,ms,BT,B x T,8192,60.24163055419922,60.24163055419922,60.24163055419922,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:58:46,0.4.2 +distill_jsd_loss,torch,forward,speed,ms,BT,B x T,1024,10.906111717224121,10.903244972229004,10.91296672821045,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2 +distill_jsd_loss,torch,forward,speed,ms,BT,B x T,2048,21.480207443237305,21.465139389038086,21.489286422729492,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2 +distill_jsd_loss,torch,forward,speed,ms,BT,B x T,4096,42.96339416503906,42.96237564086914,42.96440887451172,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2 +distill_jsd_loss,torch,forward,speed,ms,BT,B x T,8192,85.3946533203125,85.3946533203125,85.3946533203125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:18,0.4.2 
+distill_jsd_loss,liger,full,speed,ms,BT,B x T,1024,8.312895774841309,8.310400009155273,8.326751708984375,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2 +distill_jsd_loss,liger,full,speed,ms,BT,B x T,2048,15.770208358764648,15.767775535583496,15.774784088134766,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2 +distill_jsd_loss,liger,full,speed,ms,BT,B x T,4096,30.922752380371094,30.920312881469727,30.927898406982422,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2 +distill_jsd_loss,liger,full,speed,ms,BT,B x T,8192,60.70627212524414,60.70627212524414,60.70627212524414,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 07:59:51,0.4.2 +distill_jsd_loss,torch,full,speed,ms,BT,B x T,1024,28.72480010986328,28.718809127807617,28.728179931640625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2 +distill_jsd_loss,torch,full,speed,ms,BT,B x T,2048,54.281761169433594,54.281761169433594,54.281761169433594,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 
08:00:25,0.4.2 +distill_jsd_loss,torch,full,speed,ms,BT,B x T,4096,107.08905792236328,107.08905792236328,107.08905792236328,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2 +distill_jsd_loss,torch,full,speed,ms,BT,B x T,8192,213.1598663330078,213.1598663330078,213.1598663330078,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:25,0.4.2 +distill_jsd_loss,liger,full,memory,MB,BT,B x T,1024,10913.541015625,10913.541015625,10913.541015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2 +distill_jsd_loss,liger,full,memory,MB,BT,B x T,2048,10941.548828125,10941.548828125,10941.548828125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2 +distill_jsd_loss,liger,full,memory,MB,BT,B x T,4096,10997.564453125,10997.564453125,10997.564453125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:00:58,0.4.2 +distill_jsd_loss,liger,full,memory,MB,BT,B x T,8192,11109.595703125,11109.595703125,11109.595703125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 
08:00:58,0.4.2 +distill_jsd_loss,torch,full,memory,MB,BT,B x T,1024,16174.0390625,16174.0390625,16174.0390625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2 +distill_jsd_loss,torch,full,memory,MB,BT,B x T,2048,23713.05078125,23713.05078125,23713.05078125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2 +distill_jsd_loss,torch,full,memory,MB,BT,B x T,4096,38791.07421875,38791.07421875,38791.07421875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2 +distill_jsd_loss,torch,full,memory,MB,BT,B x T,8192,68947.1015625,68947.1015625,68947.1015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2 diff --git a/benchmark/scripts/benchmark_distill_jsd_loss.py b/benchmark/scripts/benchmark_distill_jsd_loss.py new file mode 100644 index 000000000..331cbab7f --- /dev/null +++ b/benchmark/scripts/benchmark_distill_jsd_loss.py @@ -0,0 +1,261 @@ +import os +import sys + +import torch +import triton + +from utils import QUANTILES +from utils import SingleBenchmarkRunInput +from utils import SingleBenchmarkRunOutput +from utils import _test_memory +from utils import parse_benchmark_script_args +from utils import run_benchmarks + +from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction +from liger_kernel.utils import infer_device + +device = infer_device() + 
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + + +class TorchJSDLoss(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + bias: bool = False, + ): + from test.chunked_loss.test_jsd_loss import HFJSDLoss + + super().__init__() + self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype) + self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.jsd_loss = HFJSDLoss( + ignore_index=ignore_index, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + temperature=temperature, + ).get_batch_loss_metrics + + def forward(self, student, teacher, target): + return self.jsd_loss( + student, + self.student_lin.weight, + teacher, + self.teacher_lin.weight, + target, + ) + + +class LigerJSDLoss(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + weight_hard_loss: float = 0.5, + weight_soft_loss: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, + bias: bool = False, + ): + super().__init__() + self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype) + self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.weight_hard_loss = weight_hard_loss + self.weight_soft_loss = weight_soft_loss + self.ignore_index = ignore_index + self.temperature = temperature + self.jsd_loss = LigerFusedLinearJSDFunction.apply + + def forward(self, student, teacher, target): + return self.jsd_loss( + student, + self.student_lin.weight, + teacher, + self.teacher_lin.weight, + target, + self.weight_hard_loss, + self.weight_soft_loss, + ) + + +def bench_memory_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + BT = input.x + H = input.extra_benchmark_config["H"] + V = 
input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + bias = input.extra_benchmark_config["bias"] + weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"] + weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"] + ignore_index = input.extra_benchmark_config["ignore_index"] + provider = input.kernel_provider + + torch_jsd_loss = TorchJSDLoss( + H=H, + V=V, + dtype=dtype, + ignore_index=ignore_index, + bias=bias, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + ).to(device) + liger_jsd_loss = LigerJSDLoss( + H=H, + V=V, + dtype=dtype, + ignore_index=ignore_index, + bias=bias, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + ).to(device) + + _tensor = torch.rand(BT, H // 2, device=device, dtype=dtype) + student_input1 = _tensor.detach().clone().requires_grad_(True) + student_input2 = _tensor.detach().clone().requires_grad_(True) + + teacher_input = torch.rand(BT, H, device=device, dtype=dtype) + + target = torch.randint(0, V, (BT,), device=device, dtype=torch.long) + + def fwd(): + if provider == "liger": + return liger_jsd_loss(student_input1, teacher_input, target) + elif provider == "torch": + return torch_jsd_loss(student_input2, teacher_input, target) + + def full(): + y = fwd() + y.backward() + + mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +def bench_speed_jsd_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + BT = input.x + H = input.extra_benchmark_config["H"] + V = input.extra_benchmark_config["V"] + dtype = input.extra_benchmark_config["dtype"] + bias = input.extra_benchmark_config["bias"] + weight_hard_loss = input.extra_benchmark_config["weight_hard_loss"] + weight_soft_loss = input.extra_benchmark_config["weight_soft_loss"] + ignore_index = input.extra_benchmark_config["ignore_index"] + provider = 
input.kernel_provider + mode = input.kernel_operation_mode + + torch_jsd_loss = TorchJSDLoss( + H=H, + V=V, + dtype=dtype, + ignore_index=ignore_index, + bias=bias, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + ).to(device) + liger_jsd_loss = LigerJSDLoss( + H=H, + V=V, + dtype=dtype, + ignore_index=ignore_index, + bias=bias, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + ).to(device) + + _tensor = torch.rand(BT, H // 2, device=device, dtype=dtype) + student_input1 = _tensor.detach().clone().requires_grad_(True) + student_input2 = _tensor.detach().clone().requires_grad_(True) + + teacher_input = torch.rand(BT, H, device=device, dtype=dtype) + + target = torch.randint(0, V, (BT,), device=device, dtype=torch.long) + + def fwd(): + if provider == "liger": + return liger_jsd_loss(student_input1, teacher_input, target) + elif provider == "torch": + return torch_jsd_loss(student_input2, teacher_input, target) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + fwd, + rep=100, + quantiles=QUANTILES, + ) + elif mode == "backward": + y = fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(retain_graph=True), + grad_to_none=[student_input1, student_input2], + rep=100, + quantiles=QUANTILES, + ) + elif mode == "full": + + def full(): + y = fwd() + y.backward() + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, + rep=100, + quantiles=QUANTILES, + ) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "distill_jsd_loss", + "x_name": "BT", + "x_label": "B x T", + "x_values": [2**i for i in range(10, 14)], + "kernel_providers": ["liger", "torch"], + "extra_benchmark_configs": [ + { + "H": 4096, + "V": 128256, + "mode": "forward", + "dtype": torch.bfloat16, + "bias": False, + "weight_hard_loss": 0.5, + "weight_soft_loss": 0.5, + 
"ignore_index": -100, + } + ], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_jsd_loss, + kernel_operation_modes=["forward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + + run_benchmarks( + bench_test_fn=bench_memory_jsd_loss, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/src/liger_kernel/chunked_loss/__init__.py b/src/liger_kernel/chunked_loss/__init__.py index 87f3887b5..4f76ab79d 100644 --- a/src/liger_kernel/chunked_loss/__init__.py +++ b/src/liger_kernel/chunked_loss/__init__.py @@ -1,5 +1,6 @@ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401 +from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401 from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401 from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401 diff --git a/src/liger_kernel/chunked_loss/functional.py b/src/liger_kernel/chunked_loss/functional.py index a10398400..2fd0439ff 100644 --- a/src/liger_kernel/chunked_loss/functional.py +++ b/src/liger_kernel/chunked_loss/functional.py @@ -1,11 +1,13 @@ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction +from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOFunction from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply 
+liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply liger_fused_linear_kto = LigerFusedLinearKTOFunction.apply diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py index d8d3a5315..9ed5cdcb9 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_distillation.py +++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py @@ -17,6 +17,9 @@ def distillation_loss_fn( Args: student_logits (torch.Tensor): Raw (temperature-scaled) logits of student tokens. Shape: (batch_size * seq_len, vocab_size). teacher_logits (torch.Tensor): Raw (temperature-scaled) logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size). + Returns: + torch.Tensor: Sum of distillation losses for the chunk. The class will handle + converting this to mean loss by dividing by the full batch size * sequence length in _compute_loss. """ raise NotImplementedError("Distillation loss function must be implemented.") @@ -71,10 +74,11 @@ def _compute_loss( weight_hard_loss=0.5, weight_soft_loss=0.5, compute_ce_loss=True, + temperature=1, **loss_kwargs, ): """ - Compute the total loss for a chunk of input and target, while using an knowleedge distillation loss function. + Compute the total loss for a chunk of input and target, while using an knowledge distillation loss function. Args: distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target. student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size). @@ -84,11 +88,12 @@ def _compute_loss( target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,). student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). - full_target (torch.Tensor): Full target tensor. 
Shape: (chunk_size,). + full_target (torch.Tensor): Full target tensor. Shape: (batch_size * sequence_length,). ignore_index (int): Index to ignore for loss computation. weight_hard_loss (float): Weight for hard loss. weight_soft_loss (float): Weight for soft loss. compute_ce_loss (bool): Whether to compute CE loss. + temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale) loss_kwargs (dict): Additional arguments for the loss function. """ ( @@ -107,6 +112,9 @@ def _compute_loss( compute_ce_loss=compute_ce_loss, ) + student_logits_chunk /= temperature + teacher_logits_chunk /= temperature + hard_loss /= full_target.shape[0] soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk) @@ -130,6 +138,7 @@ def forward( ignore_index=-100, weight_hard_loss=0.5, weight_soft_loss=0.5, + beta=0.5, compute_ce_loss=True, temperature=1.0, compiled=True, @@ -152,6 +161,7 @@ def forward( ignore_index (int): Index to ignore for loss computation. weight_hard_loss (float): Weight for hard/task loss. weight_soft_loss (float): Weight for soft/distillation loss. + beta (float): Interpolation coefficient between 0 and 1 (default: 0.5). compute_ce_loss (bool): Whether to compute CE loss. temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale) compiled (bool): Whether to use torch compile for chunk accumulation. 
@@ -170,7 +180,9 @@ def forward(
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
+            beta=beta,
             compute_ce_loss=compute_ce_loss,
+            temperature=temperature,
             **loss_kwargs,
         )
 
@@ -225,9 +237,6 @@ def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
         if compiled:
             accumulate_chunk = torch.compile(accumulate_chunk)
 
-        student_input /= temperature
-        teacher_input /= temperature
-
         num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
         _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
         _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
diff --git a/src/liger_kernel/chunked_loss/jsd_loss.py b/src/liger_kernel/chunked_loss/jsd_loss.py
new file mode 100644
index 000000000..90510b8e7
--- /dev/null
+++ b/src/liger_kernel/chunked_loss/jsd_loss.py
@@ -0,0 +1,158 @@
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
+
+
+class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
+    @staticmethod
+    def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
+        """
+        Compute JSD loss (Jensen-Shannon Divergence Loss):
+        beta * KL(teacher || M) + (1 - beta) * KL(student || M),
+        where M = beta * teacher + (1 - beta) * student is the mixture distribution.
+        Args:
+            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
+            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+        Returns:
+            torch.Tensor: Jensen-Shannon Divergence loss, summed over every token and vocabulary entry.
+        """
+        student_log_probs = F.log_softmax(student_logits, dim=-1)
+        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
+
+        # Mixture M: the teacher gets weight beta, matching the weight of teacher_kl below
+        mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
+        log_mean_probs = mean_probs.log()
+
+        student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
+        teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
+
+        # JSD is the weighted average of the KL divergences
+        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+        return jsd_loss
+
+    @staticmethod
+    def forward(
+        ctx,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+    ):
+        """
+        Fused linear layer with JSD distillation loss.
+        Args:
+            student_input (torch.Tensor): Student input tensor. Shape: (batch_size * seq_len, hidden_size_student)
+            student_weight (torch.Tensor): Student weight tensor. Shape: (vocab_size, hidden_size_student)
+            teacher_input (torch.Tensor): Teacher input tensor. Shape: (batch_size * seq_len, hidden_size_teacher)
+            teacher_weight (torch.Tensor): Teacher weight tensor. Shape: (vocab_size, hidden_size_teacher)
+            true_labels (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
+            weight_hard_loss (float): Weight for hard loss.
+            weight_soft_loss (float): Weight for soft loss.
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+            ignore_index (int): Index to ignore in loss computation
+            temperature (float): Temperature for softening/sharpening distributions
+            compiled (bool): Whether to use torch compile
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return LigerFusedLinearDistillationBase.forward(
+            ctx=ctx,
+            student_input=student_input,
+            student_weight=student_weight,
+            teacher_input=teacher_input,
+            teacher_weight=teacher_weight,
+            target=true_labels,
+            loss_fn=LigerFusedLinearJSDFunction.distillation_loss_fn,
+            chunk_size=1,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            beta=beta,
+            ignore_index=ignore_index,
+            temperature=temperature,
+            compiled=compiled,
+        )
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # autograd expects one entry per forward() input (11 here): keep the first four
+        # gradients from the base class and return None for the seven remaining arguments.
+        grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:4]
+
+        return (*grads, None, None, None, None, None, None, None)
+
+
+class LigerFusedLinearJSDLoss(torch.nn.Module):
+    """
+    Fused linear layer with JSD distillation loss.
+    """
+
+    def __init__(
+        self,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+        compiled: bool = True,
+    ):
+        """
+        Args:
+            weight_hard_loss (float): Weight for hard loss.
+            weight_soft_loss (float): Weight for soft loss.
+            ignore_index (int): Index to ignore in the loss
+            temperature (float): Temperature for softening distributions
+            compiled (bool): Whether to use torch compile
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+        """
+        super().__init__()
+        assert temperature != 0, "Temperature cannot be 0."
+        self.weight_hard_loss = weight_hard_loss
+        self.weight_soft_loss = weight_soft_loss
+        self.ignore_index = ignore_index
+        self.temperature = temperature
+        self.compiled = compiled
+        self.beta = beta
+
+    def forward(
+        self,
+        student_input: torch.Tensor,
+        student_weight: torch.Tensor,
+        teacher_input: torch.Tensor,
+        teacher_weight: torch.Tensor,
+        true_labels: torch.LongTensor,
+    ) -> torch.Tensor:
+        """
+        Compute the JSD distillation loss.
+
+        Args:
+            student_input (torch.Tensor): Student input tensor
+            student_weight (torch.Tensor): Student weight tensor
+            teacher_input (torch.Tensor): Teacher input tensor
+            teacher_weight (torch.Tensor): Teacher weight tensor
+            true_labels (torch.LongTensor): Target labels tensor
+
+        Returns:
+            torch.Tensor: Computed loss
+        """
+        return LigerFusedLinearJSDFunction.apply(
+            student_input,
+            student_weight,
+            teacher_input,
+            teacher_weight,
+            true_labels,
+            self.weight_hard_loss,
+            self.weight_soft_loss,
+            self.beta,
+            self.ignore_index,
+            self.temperature,
+            self.compiled,
+        )
diff --git a/test/chunked_loss/test_jsd_loss.py b/test/chunked_loss/test_jsd_loss.py
new file mode 100644
index 000000000..513d91acb
--- /dev/null
+++ b/test/chunked_loss/test_jsd_loss.py
@@ -0,0 +1,321 @@
+import pytest
+import torch
+import torch.nn.functional as F
+
+from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss
+from liger_kernel.chunked_loss.functional import liger_fused_linear_jsd
+from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
+from liger_kernel.utils import infer_device
+from test.utils import HFDistillationLoss
+from test.utils import assert_verbose_allclose
+from test.utils import set_seed
+
+device = infer_device()
+
+# set random seed globally
+set_seed()
+
+
+class HFJSDLoss(HFDistillationLoss):
+    """
+    Naive implementation of a distillation loss using Jensen-Shannon Divergence (JSD).
+    """
+
+    def __init__(
+        self,
+        temperature: float = 1.0,
+        ignore_index: int = -100,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+    ):
+        super().__init__(
+            ignore_index=ignore_index,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            temperature=temperature,
+        )
+        # NOTE(review): stored as a plain float, but get_batch_loss_metrics calls
+        # distillation_loss() without a beta argument, so its default of 0.5 is what is used.
+        self.beta = beta
+
+    def distillation_loss(self, student_logits, teacher_logits, beta=0.5):
+        """
+        Compute JSD loss (Jensen-Shannon Divergence Loss).
+        Args:
+            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
+            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
+        Returns:
+            torch.Tensor: Jensen-Shannon Divergence loss
+        """
+        student_log_probs = F.log_softmax(student_logits, dim=-1)
+        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
+
+        # Mixture M: the teacher gets weight beta, matching the weight of teacher_kl below
+        mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
+        log_mean_probs = mean_probs.log()
+
+        student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="batchmean", log_target=True)
+        teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="batchmean", log_target=True)
+
+        # JSD is the weighted average of the KL divergences
+        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+        return jsd_loss
+
+
+class TorchLMHeadJSD(torch.nn.Module):
+    """Ground truth implementation of the linear fused with torch based jsd loss.
+    :param H: hidden size
+    :param V: vocab size
+    :param temperature: softmax temperature
+    :param weight_hard_loss: weight_hard_loss
+    :param weight_soft_loss: weight_soft_loss
+    """
+
+    def __init__(
+        self,
+        H: int,
+        V: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+    ):
+        super().__init__()
+        # smaller student model weights
+        self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device)
+        self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
+        self.jsd = HFJSDLoss(
+            ignore_index=ignore_index,
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            temperature=temperature,
+            beta=beta,
+        ).get_batch_loss_metrics
+
+    def forward(self, student_input, teacher_input, target):
+        jsd_loss = self.jsd(
+            student_input,
+            self.student_lin.weight,
+            teacher_input,
+            self.teacher_lin.weight,
+            target,
+        )
+        return jsd_loss
+
+
+class LigerLMHeadJSD(torch.nn.Module):
+    def __init__(
+        self,
+        H: int,
+        V: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        weight_hard_loss: float = 0.5,
+        weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
+        ignore_index: int = -100,
+        temperature: float = 1.0,
+    ):
+        super().__init__()
+        # smaller student model weights
+        self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device)
+        self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
+        self.chunked_jsd = LigerFusedLinearJSDLoss(
+            weight_hard_loss=weight_hard_loss,
+            weight_soft_loss=weight_soft_loss,
+            beta=beta,
+            ignore_index=ignore_index,
+            temperature=temperature,
+        )
+
+    def forward(self, student_input, teacher_input, target):
+        return self.chunked_jsd(
+            student_input,
+            self.student_lin.weight,
+            teacher_input,
self.teacher_lin.weight, + target, + ) + + +############################################################################# +# Test the correctness of the fused linear JSD +############################################################################# + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-1), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize( + "temperature, weight_hard_loss, weight_soft_loss, beta", + [ + (1.0, 0.5, 0.5, 0.5), + (2.0, 0.5, 0.5, 0.5), + (0.5, 0.5, 0.5, 0.5), + (1.0, 0.0, 1.0, 0.5), + (1.0, 1.0, 0.0, 0.5), + (1.0, 0.5, 0.5, 0.3), + (2.0, 0.5, 0.5, 0.7), + ], +) +def test_correctness( + B, + T, + H, + V, + scalar, + dtype, + atol, + rtol, + temperature, + weight_hard_loss, + weight_soft_loss, + beta, +): + torch_lm_head_jsd = TorchLMHeadJSD( + H=H, + V=V, + dtype=dtype, + device=device, + temperature=temperature, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + beta=beta, + ) + liger_lm_head_jsd = LigerLMHeadJSD( + H=H, + V=V, + dtype=dtype, + device=device, + temperature=temperature, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + beta=beta, + ) + + torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand( + V, H // 2, device=device, dtype=dtype + ) + torch_lm_head_jsd.teacher_lin.weight.data = liger_lm_head_jsd.teacher_lin.weight.data = torch.rand( + V, H, device=device, dtype=dtype + ) + + _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar + student_input1 = _tensor.detach().clone().requires_grad_(True) + student_input2 = _tensor.detach().clone().requires_grad_(True) + + teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar + + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + loss1 = 
torch_lm_head_jsd(student_input1, teacher_input, target) + loss2 = liger_lm_head_jsd(student_input2, teacher_input, target) + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + loss1.backward() + loss2.backward() + + assert_verbose_allclose(student_input1.grad, student_input2.grad, atol=atol, rtol=rtol) + + assert_verbose_allclose( + torch_lm_head_jsd.student_lin.weight.grad, + liger_lm_head_jsd.student_lin.weight.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 2, 8, 8), + (9, 7, 41, 41), + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-2), + (1.0, torch.float32, 1e-4, 5e-3), + ], +) +@pytest.mark.parametrize( + "temperature, weight_hard_loss, weight_soft_loss, beta, ignore_index", + [(1.0, 0.5, 0.5, 0.5, -100), (2.0, 0.1, 0.9, 0.5, 42)], +) +def test_correctness_functional( + B, + T, + H, + V, + scalar, + dtype, + weight_hard_loss, + weight_soft_loss, + beta, + ignore_index, + temperature, + atol, + rtol, +): + _weight = torch.rand(V, H // 2, device=device, dtype=dtype) + student_weight1 = _weight.detach().clone().requires_grad_(True) + student_weight2 = _weight.detach().clone().requires_grad_(True) + teacher_weight = torch.rand(V, H, device=device, dtype=dtype) + + _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar + student_input1 = _tensor.detach().clone().requires_grad_(True) + student_input2 = _tensor.detach().clone().requires_grad_(True) + teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar + + label = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + output1 = liger_fused_linear_jsd( + student_input1, + student_weight1, + teacher_input, + teacher_weight, + label, + weight_hard_loss, + weight_soft_loss, + beta, + ignore_index, + temperature, + ) + output2 = LigerFusedLinearJSDFunction.apply( + student_input2, + student_weight2, + teacher_input, + teacher_weight, + label, + 
weight_hard_loss, + weight_soft_loss, + beta, + ignore_index, + temperature, + ) + + assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol) + + output1.backward() + output2.backward() + + assert_verbose_allclose(student_input1.grad, student_input2.grad, atol=atol, rtol=rtol) + + assert_verbose_allclose(student_weight1.grad, student_weight2.grad, atol=atol, rtol=rtol) diff --git a/test/utils.py b/test/utils.py index 974d09a45..a6af16e21 100644 --- a/test/utils.py +++ b/test/utils.py @@ -689,7 +689,10 @@ def get_batch_loss_metrics( hard_loss, ) = forward_output + student_logits /= self.temperature + teacher_logits /= self.temperature + soft_loss = self.distillation_loss(student_logits, teacher_logits) # full loss - loss = self.weight_hard_loss * hard_loss + self.weight_soft_loss * soft_loss.mean() + loss = self.weight_hard_loss * hard_loss + self.weight_soft_loss * soft_loss return loss