Merge branch 'main' into austin362667/chunked_compiled_jsd_loss
lancerts authored Jan 21, 2025
2 parents a3d585b + 2ea3cfb commit 7ec96b1
Showing 1 changed file with 21 additions and 3 deletions.
src/liger_kernel/transformers/model/qwen2_vl.py: 21 additions & 3 deletions
@@ -36,6 +36,7 @@ def lce_forward(
     image_grid_thw: Optional[torch.LongTensor] = None,
     video_grid_thw: Optional[torch.LongTensor] = None,
     rope_deltas: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
     r"""
     Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
@@ -125,14 +126,30 @@ def lce_forward(
         if attention_mask is not None:
             attention_mask = attention_mask.to(inputs_embeds.device)

-    if version.parse(transformers_version) > version.parse("4.46.2"):
+    if version.parse(transformers_version) > version.parse("4.46.3"):
         # NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
         # https://github.com/huggingface/transformers/issues/33401
         # While correct, this breaks equivalence with past versions of Qwen2-VL from
         # transformers and leads to failed tests or users noticing differences in results.
         # TODO: remove the above conditional when liger drops support for transformers<4.47.0
-        if position_ids is None and input_ids is not None:
-            position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
+        # if we get a 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
+        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
+            # calculate the RoPE index once per generation, in the pre-fill stage only
+            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids, image_grid_thw, video_grid_thw, attention_mask
+                )
+                self.rope_deltas = rope_deltas
+            # then use the previously calculated rope deltas to get the correct position ids
+            else:
+                batch_size, seq_length, _ = inputs_embeds.shape
+                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                if cache_position is not None:  # otherwise `delta` is an int `0`
+                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

     outputs = self.model(
         input_ids=None,
@@ -144,6 +161,7 @@ def lce_forward(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
+        cache_position=cache_position,
     )

     hidden_states = outputs[0]
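A note on the docstring's claim above: lce_forward keeps Qwen2-VL's forward pass but computes the loss with liger's fused linear cross entropy, which fuses the lm_head projection with the cross entropy so the full (batch * seq_len, vocab) logits tensor is never materialized. A minimal sketch of that loss, assuming liger-kernel is installed and a CUDA device is available; the (weight, input, target) argument order is my reading of liger's module and worth verifying against the installed version:

import torch
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss

device = "cuda"  # liger's Triton kernels require a GPU
hidden = torch.randn(4 * 16, 64, device=device, requires_grad=True)       # (batch * seq_len, hidden)
lm_head_weight = torch.randn(128, 64, device=device, requires_grad=True)  # (vocab, hidden)
labels = torch.randint(0, 128, (4 * 16,), device=device)

loss_fn = LigerFusedLinearCrossEntropyLoss()    # ignores -100 labels by default
loss = loss_fn(lm_head_weight, hidden, labels)  # assumed order: weight, input, target
loss.backward()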
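The version gate in the second hunk now excludes transformers 4.46.3 as well, because the rope-index fix (huggingface/transformers#33401) only landed in 4.47.0 and applying it earlier would change results versus stock Qwen2-VL. The gating pattern itself is a plain packaging comparison; a minimal sketch, where the flag name is illustrative rather than from the diff:

from packaging import version
import transformers

# True only on versions that ship the upstream fix (4.47.0 and later);
# 4.46.3 is the last release without it.
use_cached_rope_deltas = version.parse(transformers.__version__) > version.parse("4.46.3")
print(use_cached_rope_deltas)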

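The bulk of the new hunk changes when rope indices are computed: get_rope_index runs once at pre-fill (cache_position[0] == 0, or no cached deltas yet) and its per-sequence rope_deltas are cached on the model; every later decode step rebuilds position_ids with cheap arithmetic instead of re-scanning the image/video grids. A standalone sketch of that decode-step arithmetic, assuming only torch; the function name positions_from_cached_deltas is hypothetical (in the real model this logic is inline in lce_forward):

import torch

def positions_from_cached_deltas(inputs_embeds, cache_position, rope_deltas):
    # Decode-step branch from the diff: offset a fresh arange by where this
    # step sits in the KV cache plus the cached per-sequence rope delta.
    batch_size, seq_length, _ = inputs_embeds.shape
    delta = cache_position[0] + rope_deltas if cache_position is not None else 0
    position_ids = torch.arange(seq_length, device=inputs_embeds.device)
    position_ids = position_ids.view(1, -1).expand(batch_size, -1)
    if cache_position is not None:  # otherwise `delta` is the int 0
        # broadcast deltas when the batch was expanded (e.g. beam search)
        delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
    position_ids = position_ids.add(delta)
    # Qwen2-VL rope is 3D (temporal / height / width), hence the leading 3
    return position_ids.unsqueeze(0).expand(3, -1, -1)

# A decode step at cache offset 9 for a batch of two sequences:
emb = torch.randn(2, 1, 8)         # (batch, seq_len=1, hidden)
deltas = torch.tensor([[3], [7]])  # cached rope_deltas, shape (batch, 1)
pos = positions_from_cached_deltas(emb, torch.tensor([9]), deltas)
print(pos.shape)  # torch.Size([3, 2, 1])

Recomputing from the cached delta is constant work per generated token, where calling get_rope_index on every step would repeat the multimodal grid bookkeeping already done at pre-fill.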