Skip to content

Commit

Permalink
Applying scales rename to fp8 config (ROCm#387)
Browse files Browse the repository at this point in the history
  • Loading branch information
gshtras authored Jan 24, 2025
1 parent 84f5d47 commit 8e87b08
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,10 @@ def get_cache_scale(self, name: str) -> Optional[str]:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
if name.endswith(".output_scale") and ".q_proj" in name:
return name.replace(".q_proj.output_scale", ".attn.q_scale")
if name.endswith("self_attn.prob_output_scale"):
return name.replace(".prob_output_scale", ".attn.prob_scale")
# If no matches, return None
return None

Expand Down
20 changes: 20 additions & 0 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,26 @@ def get_quant_method(self, layer: torch.nn.Module,
return Fp8KVCacheMethod(self)
return None

def get_cache_scale(self, name: str) -> Optional[str]:
"""
Check whether the param name matches the format for k/v cache scales
in compressed-tensors. If this is the case, return its equivalent
param name expected by vLLM
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if name.endswith(".output_scale") and ".k_proj" in name:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
if name.endswith(".output_scale") and ".q_proj" in name:
return name.replace(".q_proj.output_scale", ".attn.q_scale")
if name.endswith("self_attn.prob_output_scale"):
return name.replace(".prob_output_scale", ".attn.prob_scale")
# If no matches, return None
return None


class Fp8LinearMethod(LinearMethodBase):
"""Linear method for FP8.
Expand Down

0 comments on commit 8e87b08

Please sign in to comment.