Add batch dim idx to support latest deepspeed DistributedAttention (#…
bhargaveede authored Feb 6, 2025
1 parent 58de6b6 commit bedc041
Showing 1 changed file with 73 additions and 7 deletions.
80 changes: 73 additions & 7 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -430,7 +430,68 @@ def forward(self, cur, dim, idx):
        return self.update(self.cache, cur, dim, idx, self.inp_seq_len)


def GaudiDistributedAttention(fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed):
class GaudiDistributedAttention(torch.nn.Module):
    def __init__(
        self, hpu_module_fsdpa: ModuleFusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8
    ):
        super().__init__()
        self._hpu_module_fsdpa = hpu_module_fsdpa
        if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
            from deepspeed.sequence.layer import DistributedAttention

            self._hpu_module_fsdpa_distributed = DistributedAttention(
                self._hpu_module_fsdpa, parallel_state.get_sequence_parallel_group(), 1, 2
            )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: torch.Tensor,
        dropout_p: float,
        is_casual,
        scale,
        softmax_mode,
        recompute_mode,
        valid_sequence_lengths,
        padding_side="left",
    ):
        if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
            return self._hpu_module_fsdpa_distributed(
                query,
                key,
                value,
                0,  # As the shape for inputs is [B, N, S, H]
                None,
                attn_mask,
                dropout_p,
                is_casual,
                scale,
                softmax_mode,
                recompute_mode,
                valid_sequence_lengths,
                padding_side,
            )
        else:
            return self._hpu_module_fsdpa(
                query,
                key,
                value,
                attn_mask,
                dropout_p,
                is_casual,
                scale,
                softmax_mode,
                recompute_mode,
                valid_sequence_lengths,
                padding_side,
            )


def get_gaudi_distributed_attention(
    fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed
):
    if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
        return fused_scaled_dot_product_attention_distributed
    else:
@@ -472,14 +533,19 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
            if FusedSDPA
            else None
        )
        # https://github.com/microsoft/DeepSpeed/issues/4359
        # for all2all comm, Distributed Attention cares about sequence (s) and number of heads (h) dimensions. In HPU, they are at 1 and 2 indices
        self.fused_scaled_dot_product_attention_distributed = None
        if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
            from deepspeed.sequence.layer import DistributedAttention

            self.fused_scaled_dot_product_attention_distributed = DistributedAttention(
                self.fused_scaled_dot_product_attention, parallel_state.get_sequence_parallel_group(), 1, 2
            self.fused_scaled_dot_product_attention_distributed = (
                GaudiDistributedAttention(
                    self.fused_scaled_dot_product_attention,
                    scale=self.norm_factor,
                    attention_dropout=self.attention_dropout,
                    enable_recompute=False,
                    flash_attention_fp8=getattr(config, "flash_attention_fp8", False),
                )
                if FusedSDPA
                else None
            )

    def get_k_proj_weight(self):
@@ -696,7 +762,7 @@ def pre_attn_forward(
                kv_seq_len = key_states.shape[-2]
        else:
            past_key_value = None
        fused_scaled_dot_product_attention = GaudiDistributedAttention(
        fused_scaled_dot_product_attention = get_gaudi_distributed_attention(
            self.fused_scaled_dot_product_attention, self.fused_scaled_dot_product_attention_distributed
        )
        if use_flash_attention and FusedSDPA is not None:
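
The hard-coded dimension indices in the diff follow from the tensor layout called out in the inline comments: on HPU, q/k/v reach the attention call laid out as [B, N, S, H] (batch, heads, sequence, head dim), so the heads/sequence indices (1 and 2) are fixed at construction time and the batch index (0) is supplied per forward call. A minimal sketch of that layout, with hypothetical sizes that are not part of the commit:

import torch

B, N, S, H = 2, 8, 128, 64  # hypothetical sizes, for illustration only
query = torch.randn(B, N, S, H)

batch_dim_idx = 0            # the index passed per forward call in the diff
heads_dim_idx, seq_dim_idx = 1, 2  # the indices passed once at construction
assert query.shape[batch_dim_idx] == B
assert query.shape[heads_dim_idx] == N and query.shape[seq_dim_idx] == S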
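
The selection logic that get_gaudi_distributed_attention implements can also be read in isolation. A self-contained sketch of the same dispatch pattern, using toy stand-ins (the names below are illustrative, not optimum-habana APIs):

from typing import Callable, Optional


def pick_attention(
    local_attn: Callable,
    distributed_attn: Optional[Callable],
    sequence_parallel_world_size: int,
) -> Callable:
    # Route through the sequence-parallel wrapper only when more than one
    # rank shares the sequence dimension; otherwise use the fused kernel
    # directly, mirroring the helper in the diff.
    if distributed_attn is not None and sequence_parallel_world_size > 1:
        return distributed_attn
    return local_attn


# With a world size of 1, the local fused kernel is returned unchanged.
attn_fn = pick_attention(lambda *args: None, None, sequence_parallel_world_size=1)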
