Added sliding window feature to Gaudi Gemma2 model #1736

Draft: wants to merge 2 commits into base: main
5 changes: 5 additions & 0 deletions examples/text-generation/run_generation.py
@@ -274,6 +274,11 @@ def setup_parser(parser):
        action="store_true",
        help="Whether to enable Habana Flash Attention in fast softmax mode.",
    )
    parser.add_argument(
        "--flash_attention_2",
        action="store_true",
        help="Whether to enable Flash Attention-2, provided that the model supports it.",
    )
    parser.add_argument(
        "--book_source",
        action="store_true",
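For reference, a minimal standalone sketch of how the new store_true flag behaves when parsed. This only illustrates the argparse pattern used above and is not code from the PR.

# Minimal sketch of the new flag in isolation (illustration only, not part of the PR).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--flash_attention_2",
    action="store_true",
    help="Whether to enable Flash Attention-2, provided that the model supports it.",
)

# The flag defaults to False and flips to True when passed on the command line.
print(parser.parse_args([]).flash_attention_2)                       # False
print(parser.parse_args(["--flash_attention_2"]).flash_attention_2)  # True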
5 changes: 5 additions & 0 deletions examples/text-generation/utils.py
@@ -670,6 +670,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
    generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask
    generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax
    generation_config.trust_remote_code = args.trust_remote_code
    generation_config.flash_attention_2 = args.flash_attention_2
    generation_config.valid_sequence_lengths = None

    return generation_config
@@ -716,6 +717,10 @@ def initialize_model(args, logger):
"token": args.token,
"trust_remote_code": args.trust_remote_code,
}

if args.flash_attention_2:
model_kwargs["attn_implementation"] = "flash_attention_2"

if args.load_quantized_model_with_inc or args.local_quantized_inc_model_path:
model_kwargs["torch_dtype"] = torch.bfloat16

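For context, a hedged sketch of how an "attn_implementation" entry in model_kwargs is typically consumed downstream. The actual loading path inside initialize_model is not shown in this diff, and the checkpoint name below is purely illustrative.

# Sketch only: the standard transformers API that picks up attn_implementation.
# The real Optimum Habana loading code may differ; "google/gemma-2-9b" is just an example checkpoint.
import torch
from transformers import AutoModelForCausalLM

model_kwargs = {
    "torch_dtype": torch.bfloat16,
    "attn_implementation": "flash_attention_2",  # set when --flash_attention_2 is passed
}
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b", **model_kwargs)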
14 changes: 14 additions & 0 deletions optimum/habana/transformers/models/gemma2/modeling_gemma2.py
@@ -571,6 +571,20 @@ def forward(
        The only differences are:
        - add new args token_idx
        """
        if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
            # With flash_attention_2 the attention mask is a 2D tensor
            if self.config._attn_implementation == "flash_attention_2":
                if past_key_value is not None:  # when decoding
                    attention_mask = attention_mask[:, -self.sliding_window :]
            else:
                min_dtype = torch.finfo(hidden_states.dtype).min
                sliding_window_mask = torch.tril(
                    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                )
                attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
                if attention_mask.shape[-1] <= 1:  # when decoding
                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]

        residual = hidden_states

        hidden_states, self_attn_weights, present_key_value = self.pre_attn(
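To make the non-flash branch above easier to follow, here is a small standalone reproduction of the sliding-window masking applied to a 4D additive causal mask. The shapes and window size are made up for illustration; this snippet is not part of the PR.

# Standalone illustration of the sliding-window masking on a 4D additive attention mask.
import torch

sliding_window = 3
seq_len = 6
min_dtype = torch.finfo(torch.float32).min

# Standard causal additive mask of shape (batch, 1, q_len, kv_len):
# 0.0 where attention is allowed, min_dtype where it is not.
causal = torch.triu(torch.full((seq_len, seq_len), min_dtype), diagonal=1)
attention_mask = causal[None, None, :, :]

# Additionally mask out keys more than `sliding_window` positions behind each query.
sliding_window_mask = torch.tril(
    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)

# Each query now sees at most `sliding_window` keys (itself plus the most recent past).
print((attention_mask[0, 0] == 0).sum(dim=-1))  # tensor([1, 2, 3, 3, 3, 3])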