Add batch dim idx to support latest deepspeed DistributedAttention #1725

Merged
merged 3 commits · Feb 6, 2025
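
Context for the change: recent DeepSpeed releases add a batch_dim_idx argument to deepspeed.sequence.layer.DistributedAttention.forward(query, key, value, batch_dim_idx, *args), so the previous direct call from pre_attn_forward no longer matches the upstream signature. The new GaudiDistributedAttention wrapper supplies that index as 0, because Gaudi passes attention inputs as [B, N, S, H]. Below is a minimal sketch of the call-site difference, not the PR code, assuming an already initialized sequence-parallel process group (sp_group), a ModuleFusedSDPA instance (fsdpa), and placeholder tensors q, k, v, attn_mask and a dropout_p value; the forward signatures referenced are taken from upstream DeepSpeed and should be checked against the installed release.

# Sketch only, not the PR code. Prerequisites assumed: torch.distributed is
# initialized, sp_group is the sequence-parallel process group, fsdpa is a
# ModuleFusedSDPA module, and q, k, v, attn_mask, dropout_p are the usual
# attention inputs in [B, N, S, H] layout.
from deepspeed.sequence.layer import DistributedAttention

# Heads sit at dim 1 and sequence at dim 2 in [B, N, S, H], hence
# scatter_idx=1 and gather_idx=2 for the all-to-all.
dist_attn = DistributedAttention(fsdpa, sp_group, 1, 2)

# Older DeepSpeed: forward(query, key, value, *args)
#     out = dist_attn(q, k, v, attn_mask, dropout_p, ...)
# Newer DeepSpeed: forward(query, key, value, batch_dim_idx, *args)
out = dist_attn(q, k, v, 0, attn_mask, dropout_p)  # 0 = batch dim of [B, N, S, H]
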
86 changes: 76 additions & 10 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -397,9 +397,9 @@ def allocate(self, inp_seq_len, dtype, device, shape):
self.inp_seq_len = inp_seq_len
self.cache = torch.zeros(shape, dtype=dtype, device=device)
else:
assert (
self.inp_seq_len == inp_seq_len
), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
assert self.inp_seq_len == inp_seq_len, (
f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
)
self.cache.fill_(0)

@staticmethod
@@ -427,7 +427,68 @@ def forward(self, cur, dim, idx):
return self.update(self.cache, cur, dim, idx, self.inp_seq_len)


def GaudiDistributedAttention(fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed):
class GaudiDistributedAttention(torch.nn.Module):
def __init__(
self, hpu_module_fsdpa: ModuleFusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8
):
super().__init__()
self._hpu_module_fsdpa = hpu_module_fsdpa
if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
from deepspeed.sequence.layer import DistributedAttention

self._hpu_module_fsdpa_distributed = DistributedAttention(
self._hpu_module_fsdpa, parallel_state.get_sequence_parallel_group(), 1, 2
)

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor,
dropout_p: float,
is_casual,
scale,
softmax_mode,
recompute_mode,
valid_sequence_lengths,
padding_side="left",
):
if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
return self._hpu_module_fsdpa_distributed(
query,
key,
value,
0, # As the shape for inputs is [B, N, S, H]
None,
attn_mask,
dropout_p,
is_casual,
scale,
softmax_mode,
recompute_mode,
valid_sequence_lengths,
padding_side,
)
else:
return self._hpu_module_fsdpa(
query,
key,
value,
attn_mask,
dropout_p,
is_casual,
scale,
softmax_mode,
recompute_mode,
valid_sequence_lengths,
padding_side,
)


def get_gaudi_distributed_attention(
fused_scaled_dot_product_attention, fused_scaled_dot_product_attention_distributed
):
if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
return fused_scaled_dot_product_attention_distributed
else:
@@ -469,14 +530,19 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
if FusedSDPA
else None
)
# https://github.com/microsoft/DeepSpeed/issues/4359
# for all2all comm, Distributed Attention cares about sequence (s) and number of heads (h) dimensions. In HPU, they are at 1 and 2 indices
self.fused_scaled_dot_product_attention_distributed = None
if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
from deepspeed.sequence.layer import DistributedAttention

self.fused_scaled_dot_product_attention_distributed = DistributedAttention(
self.fused_scaled_dot_product_attention, parallel_state.get_sequence_parallel_group(), 1, 2
self.fused_scaled_dot_product_attention_distributed = (
GaudiDistributedAttention(
self.fused_scaled_dot_product_attention,
scale=self.norm_factor,
attention_dropout=self.attention_dropout,
enable_recompute=False,
flash_attention_fp8=getattr(config, "flash_attention_fp8", False),
)
if FusedSDPA
else None
)

def get_k_proj_weight(self):
@@ -683,7 +749,7 @@ def pre_attn_forward(
kv_seq_len = key_states.shape[-2]
else:
past_key_value = None
fused_scaled_dot_product_attention = GaudiDistributedAttention(
fused_scaled_dot_product_attention = get_gaudi_distributed_attention(
self.fused_scaled_dot_product_attention, self.fused_scaled_dot_product_attention_distributed
)
if use_flash_attention and FusedSDPA is not None:
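
For reference, a self-contained shape sketch of what the scatter_idx=1 / gather_idx=2 all-to-all means for the [B, N, S, H] layout used here: each rank starts with all attention heads but only its local sequence shard, and ends up with the full sequence for a subset of heads. The sizes below are illustrative, and the chunk/cat pair only stands in for the real dist.all_to_all exchange; it is not part of the diff above.

import torch

world_size = 4                   # sequence-parallel group size (example value)
B, N, S, H = 2, 32, 4096, 128    # batch, total heads, full sequence, head size

# Before the all-to-all: all heads, but only this rank's sequence shard.
local_qkv = torch.randn(B, N, S // world_size, H)

# Scatter dim 1 (heads) across ranks, gather dim 2 (sequence) from all ranks.
send_chunks = list(local_qkv.chunk(world_size, dim=1))
recv_chunks = send_chunks                      # stand-in for dist.all_to_all
attn_input = torch.cat(recv_chunks, dim=2)     # shape fed to the local FusedSDPA

print(local_qkv.shape)   # torch.Size([2, 32, 1024, 128])
print(attn_input.shape)  # torch.Size([2, 8, 4096, 128])
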