Remove graph breaks for torch.compile() in padding free branch in DataCollatorForCompletionOnlyLM (#2158)

* feat: Add info to batch in DataCollatorForCompletionOnlyLM

Signed-off-by: Abhishek <[email protected]>

* fix: formatting

Signed-off-by: Abhishek <[email protected]>

* feat: Add info to batch in DataCollatorForCompletionOnlyLM

Signed-off-by: Abhishek <[email protected]>

* fix: formatting

Signed-off-by: Abhishek <[email protected]>

* fix: max_length_k to int

Signed-off-by: Abhishek <[email protected]>

* fix: Added comments

Signed-off-by: Abhishek <[email protected]>

* test cases

Signed-off-by: Abhishek <[email protected]>

* test cases

Signed-off-by: Abhishek <[email protected]>

* test cases

Signed-off-by: Abhishek <[email protected]>

* feat: Add info to batch in DataCollatorForCompletionOnlyLM

Signed-off-by: Abhishek <[email protected]>

* fix: formatting

Signed-off-by: Abhishek <[email protected]>

* feat: Add info to batch in DataCollatorForCompletionOnlyLM

Signed-off-by: Abhishek <[email protected]>

* test cases

Signed-off-by: Abhishek <[email protected]>

* test cases

Signed-off-by: Abhishek <[email protected]>

* test cases

Signed-off-by: Abhishek <[email protected]>

* unit test changes

Signed-off-by: Abhishek <[email protected]>

* style

* add test

* remove test

---------

Signed-off-by: Abhishek <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
3 people authored Jan 6, 2025
1 parent 763738f commit d9ee2fd
Showing 2 changed files with 38 additions and 1 deletion.
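
For context, the graph breaks in question come from deriving sequence boundaries inside the compiled forward pass. Below is a minimal sketch of that failure mode, not code from this commit (`derive_boundaries` is an illustrative name): data-dependent ops such as nonzero() cannot be traced into a single graph, which is why the collator now precomputes the boundaries once per batch instead.

import torch

@torch.compile(fullgraph=True)
def derive_boundaries(position_ids: torch.Tensor) -> torch.Tensor:
    # nonzero() has a data-dependent output shape; under fullgraph=True this
    # raises, and under the default mode it forces a graph break.
    return (position_ids.flatten() == 0).nonzero().flatten()

# derive_boundaries(torch.tensor([[0, 1, 2, 0, 1]]))  # breaks the graph
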
20 changes: 19 additions & 1 deletion tests/test_data_collator_completion_only.py
@@ -114,7 +114,7 @@ def test_padding_free(self):
inst1 = "### System: You are a helpful assistant.\n\n### User: How much is 2+2?\n\n### Assistant: 2+2 equals 4"
inst2 = "### System: You are a honest and helpful assistant.\n\n### User: What is the answer of 22x22?\n\n### Assistant: 22x22 equals 484"

response_template = "\n### Assistant:"
response_template = "\n\n### Assistant:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
collator_paddingfree = DataCollatorForCompletionOnlyLM(
response_template, tokenizer=tokenizer, padding_free=True
@@ -143,3 +143,21 @@ def test_padding_free(self):
    self.assertTrue((input_ids_remove_pad == batch_paddingfree["input_ids"]).all())
    self.assertTrue((expected_position_ids == batch_paddingfree["position_ids"]).all())
    self.assertTrue((expected_labels == batch_paddingfree["labels"]).all())

def test_data_collator_for_completion_only_lm(self):
    # The tokenizer isn't used, but the collator requires one to be provided.
    tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")

    collator = DataCollatorForCompletionOnlyLM(tokenizer.decode(9999), tokenizer=tokenizer, padding_free=True)

    tokenized_instruction = [
        {"input_ids": [1, 2, 3, 9999, 4, 5], "attention_mask": [1, 1, 1, 1, 1, 1]},
        {"input_ids": [6, 7, 8, 9, 9999, 10, 11], "attention_mask": [1, 1, 1, 1, 1, 1, 1]},
    ]
    batch = collator(tokenized_instruction)

    self.assertEqual(batch["position_ids"].tolist(), [[0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6]])  # flat position ids
    self.assertEqual(batch["cu_seq_lens_q"].tolist(), [0, 6, 13])  # start index of each sequence + total number of tokens
    self.assertEqual(batch["cu_seq_lens_k"].tolist(), [0, 6, 13])  # idem
    self.assertEqual(batch["max_length_k"], 7)  # max sequence length in the batch, here 7 (second sequence)
    self.assertEqual(batch["max_length_q"], 7)  # idem
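
To see where those expected values come from, here is a worked sketch using only the toy values from the test above (nothing below is new API): the two packed sequences have lengths 6 and 7, each restarting its position ids at 0, so the zero positions mark sequence starts, and appending the total token count yields the cumulative lengths.

import torch

# Flattened position ids of the two packed sequences (lengths 6 and 7).
position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])

indices = torch.arange(position_ids.size(0), dtype=torch.int32)
starts = indices[position_ids == 0]  # tensor([0, 6]): each sequence restarts at 0
total = torch.tensor([position_ids.size(0)], dtype=torch.int32)  # tensor([13])

print(torch.cat((starts, total)).tolist())  # [0, 6, 13] -> cu_seq_lens_q / cu_seq_lens_k
print(position_ids.max().item() + 1)        # 7 -> max_length_q / max_length_k
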
19 changes: 19 additions & 0 deletions trl/trainer/utils.py
Original file line number Diff line number Diff line change
@@ -211,6 +211,25 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
batch["labels"] = batch["labels"][attn_mask.bool()].unsqueeze(0)
batch["labels"][batch["position_ids"] == 0] = self.ignore_index

# Calculate cumulative sequence lengths for queries and keys to prevent graph breaks during further computations.
flattened_position_ids = batch["position_ids"].flatten()
indices_q = torch.arange(
flattened_position_ids.size(0), device=flattened_position_ids.device, dtype=torch.int32
)
batch["cu_seq_lens_q"] = torch.cat(
(
indices_q[flattened_position_ids == 0],
torch.tensor(
flattened_position_ids.size(), device=flattened_position_ids.device, dtype=torch.int32
),
)
)
batch["cu_seq_lens_k"] = batch["cu_seq_lens_q"]

# Determine maximum sequence lengths to prevent graph breaks during further computations.
batch["max_length_k"] = flattened_position_ids.max().item() + 1
batch["max_length_q"] = batch["max_length_k"]

return batch
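
Downstream, these batch keys map onto the arguments of variable-length attention kernels. A minimal consumer sketch, not code from this commit, assuming the flash-attn package is installed, a CUDA device is available, and with made-up model dimensions:

import torch
from flash_attn import flash_attn_varlen_func  # assumes flash-attn is installed

# Boundaries matching the toy batch in the test above.
cu_seq_lens = torch.tensor([0, 6, 13], dtype=torch.int32, device="cuda")
max_length = 7
total_tokens, n_heads, head_dim = 13, 4, 64  # illustrative dimensions

q = torch.randn(total_tokens, n_heads, head_dim, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# The kernel consumes the precomputed boundaries directly, so no .item() or
# nonzero() call is needed at attention time and torch.compile can trace
# the forward pass without breaking the graph.
out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seq_lens,
    cu_seqlens_k=cu_seq_lens,
    max_seqlen_q=max_length,
    max_seqlen_k=max_length,
    causal=True,
)
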


