🔬 SFT simplification (#2405)
* initial commit

* update

* Refactor SFTTrainer and SFTConfig

* Update SFTConfig class in sft_config.py

* Fix SFTConfig torch_dtype validation and dataset preprocessing flag

* Refactor dataset mapping and conversion

* Refactor dataset mapping in SFTTrainer

* Fix SFTTrainerTester unit test by removing unnecessary code

* Remove unused variables and update tokenization logic

* Remove pack_dataset function

* Add deprecation warning for tokenizer in SFTTrainer constructor

* add docstring back

* Update model parameter type annotation

* Update SFTTrainer class definition

* style

* preprocess_dataset -> _prepare_dataset

* Retro compat

* Update formatting_func type hint in SFTTrainer constructor

* typo

* better comment

* simplify tokenize row

* Fix type hint for peft_config

* fix doc

* Add pack_examples function to `test_data_utils.py`

* promote pack_examples and document

* improve doc

* Add new SFTTrainerTester2 class for testing

* test was reversed

* ©️ Copyrights update (#2454)

* First changes

* Other files

* Finally

* rm comment

* fix nashmd

* Fix example

* Fix example

* 💬 Fix chat for windows (#2443)

* fix chat for windows

* add some tests back

* Revert "add some tests back"

This reverts commit 350aef5.

* 🆔 Add `dataset_config` to `ScriptArguments` (#2440)

* datast_config_name

* Update trl/utils.py

* sort import

* typo

* Trigger CI

* Rename `dataset_config_name` to `dataset_config`

* 🏎 Fix deepspeed preparation of `ref_model` in `OnlineDPOTrainer` (#2417)

* Remove unused deepspeed code

* add model prep back

* add deepspeed even if it doesn't work

* rm old code

* 👯 Standardize `model_args` (#2442)

* `model_config` -> `model_args`

* sort

* refactor config

* drop skip prepare dataset

* add sep to packing

* drop prompt-completion for now

* Revert "drop prompt-completion for now"

This reverts commit 16ef195.

* Revert "add sep to packing"

This reverts commit dc84d08.

* Revert "drop skip prepare dataset"

This reverts commit d2ee070.

* Revert "refactor config"

This reverts commit f732aa8.

* Format

* Update doc-builder workflow to use specific commit sha

* add peft edge cases

* no logits when using liger

* remove unused columns

* proper handle of prompt-completion

* trick to keep messages

* fix messages missing

* for Liger kernel, ensure only input_ids is present

* packing and liger are compatible

* shiny doc and final nits

* another nit

* refactor config and doc

* re add truncation

* fix ci

* drop deprecated params in tests

* fix link

* fix config docstring

---------

Co-authored-by: Kashif Rasul <[email protected]>
qgallouedec and kashif authored Feb 7, 2025
1 parent 82d12eb commit 5b9236d
Showing 9 changed files with 597 additions and 526 deletions.
4 changes: 4 additions & 0 deletions docs/source/data_utils.md
@@ -27,3 +27,7 @@
## maybe_unpair_preference_dataset

[[autodoc]] maybe_unpair_preference_dataset

## pack_examples

[[autodoc]] pack_examples
43 changes: 43 additions & 0 deletions tests/test_data_utils.py
@@ -26,6 +26,7 @@
maybe_apply_chat_template,
maybe_extract_prompt,
maybe_unpair_preference_dataset,
pack_examples,
unpair_preference_dataset,
)

@@ -392,6 +393,48 @@ def test_maybe_extract_prompt_standard_already_explicit(self):
)


class TestPackExamples(unittest.TestCase):
def test_pack_examples_larger_chunks(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
}
seq_length = 5
expected_output = {
"input_ids": [[1, 2, 3, 4, 5], [6, 7, 8]],
"attention_mask": [[0, 1, 1, 0, 0], [1, 1, 1]],
}
result = pack_examples(examples, seq_length)
self.assertEqual(result, expected_output)

def test_pack_examples_smaller_chunks(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
}
seq_length = 2
expected_output = {
"input_ids": [[1, 2], [3, 4], [5, 6], [7, 8]],
"attention_mask": [[0, 1], [1, 0], [0, 1], [1, 1]],
}
result = pack_examples(examples, seq_length)
self.assertEqual(result, expected_output)

def test_pack_with_dataset(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
}
dataset = Dataset.from_dict(examples)
seq_length = 3
expected_output = {
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
"attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
}
dataset = dataset.map(pack_examples, batched=True, fn_kwargs={"seq_length": seq_length})
self.assertEqual(dataset.to_dict(), expected_output)


# Run the tests
if __name__ == "__main__":
unittest.main()
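These tests can also be run on their own with, for example, `pytest tests/test_data_utils.py -k Pack` (assuming the repository's usual pytest setup).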
120 changes: 71 additions & 49 deletions tests/test_sft_trainer.py
@@ -53,7 +53,7 @@ def formatting_prompts_func_batched(example):


if is_peft_available():
from peft import LoraConfig, PeftModel
from peft import LoraConfig, PeftModel, get_peft_model

if is_vision_available():
from PIL import Image as PILImage
@@ -327,7 +327,6 @@ def test_sft_trainer_uncorrect_data(self):
save_steps=1,
per_device_train_batch_size=2,
max_seq_length=32, # make sure there is at least 1 packed sequence
num_of_sequences=32,
packing=True,
report_to="none",
)
@@ -408,45 +407,6 @@ def test_sft_trainer_uncorrect_data(self):
formatting_func=formatting_prompts_func,
)

# This should not work because not enough data for one sample
training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
max_steps=2,
eval_steps=1,
save_steps=1,
per_device_train_batch_size=2,
max_seq_length=1024, # make sure there is NOT at least 1 packed sequence
packing=True,
report_to="none",
)
with self.assertRaises(ValueError):
_ = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_dataset,
formatting_func=formatting_prompts_func,
)

# This should not work as well
with self.assertRaises(ValueError):
training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
max_steps=2,
eval_steps=1,
save_steps=1,
per_device_train_batch_size=2,
packing=False,
report_to="none",
)
_ = SFTTrainer(
model=self.model,
args=training_args,
train_dataset=self.dummy_dataset,
formatting_func=formatting_prompts_func,
)

# but this should work
training_args = SFTConfig(
output_dir=tmp_dir,
@@ -502,7 +462,6 @@ def test_sft_trainer_with_model_num_train_epochs(self):
num_train_epochs=2,
per_device_train_batch_size=2,
max_seq_length=16,
num_of_sequences=16,
packing=True,
report_to="none",
)
@@ -576,7 +535,6 @@ def test_sft_trainer_with_model(self):
save_steps=1,
per_device_train_batch_size=2,
max_seq_length=16,
num_of_sequences=16,
packing=True,
report_to="none",
)
@@ -601,7 +559,6 @@ def test_sft_trainer_with_model(self):
save_steps=1,
per_device_train_batch_size=2,
max_seq_length=16,
num_of_sequences=16,
packing=True,
report_to="none",
)
@@ -808,8 +765,6 @@ def test_sft_trainer_infinite_with_model(self):
eval_dataset=self.eval_dataset,
)

self.assertTrue(trainer.train_dataset.infinite)

trainer.train()

self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"])
@@ -837,8 +792,6 @@ def test_sft_trainer_infinite_with_model_epochs(self):
eval_dataset=self.eval_dataset,
)

self.assertFalse(trainer.train_dataset.infinite)

trainer.train()

self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"])
@@ -1345,6 +1298,75 @@ def test_sft_trainer_torch_dtype(self):
)

self.assertIn(
"Invalid `torch_dtype` passed to the SFTConfig. Expected a string with either `torch.dtype` or 'auto', but got -1.",
"Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing "
"a `torch.dtype` (e.g., 'float32'), but got -1.",
str(context.exception),
)
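For reference, here is a minimal sketch of what the clarified message allows, assuming the dtype travels through `model_init_kwargs` as in the test above; the values shown are illustrative:

```python
from trl import SFTConfig

# Valid: "auto", or a string naming a torch.dtype such as "float32" or "bfloat16".
config = SFTConfig(output_dir="out", model_init_kwargs={"torch_dtype": "float32"})
config = SFTConfig(output_dir="out", model_init_kwargs={"torch_dtype": "auto"})

# Invalid: anything else (e.g., -1) is rejected with the ValueError quoted above.
```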


# This new tester aims to replace the first one at some point
class SFTTrainerTester2(unittest.TestCase):
def test_train(self):
# Get the model and dataset
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
trainer = SFTTrainer(args=training_args, model=model, train_dataset=dataset)

# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

# Train the model
trainer.train()

# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

@require_peft
def test_train_peft_model(self):
# Get the base model
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id)

# Get the base model parameter names
base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]

# Turn the model into a peft model
lora_config = LoraConfig()
model = get_peft_model(model, lora_config)

# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
trainer = SFTTrainer(args=training_args, model=model, train_dataset=dataset)

# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

# Train the model
trainer.train()

# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model parameters to be the same
self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
elif (
"base_layer" not in n
): # We expect the peft parameters to be different (except for the base layer)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
4 changes: 0 additions & 4 deletions tests/test_trainers_args.py
@@ -375,8 +375,6 @@ def test_sft(self):
model_init_kwargs={"trust_remote_code": True},
dataset_kwargs={"append_concat_token": True, "skip_prepare_dataset": True},
eval_packing=True,
num_of_sequences=32,
chars_per_token=4.2,
)
trainer = SFTTrainer(model_id, args=training_args, train_dataset=dataset)
self.assertEqual(trainer.args.dataset_text_field, "dummy_text_field")
@@ -389,8 +387,6 @@ def test_sft(self):
self.assertIn("append_concat_token", trainer.args.dataset_kwargs)
self.assertEqual(trainer.args.dataset_kwargs["append_concat_token"], True)
self.assertEqual(trainer.args.eval_packing, True)
self.assertEqual(trainer.args.num_of_sequences, 32)
self.assertEqual(trainer.args.chars_per_token, 4.2)

@parameterized.expand([(False,), (True,)])
def test_xpo(self, alpha_list):
2 changes: 2 additions & 0 deletions trl/__init__.py
@@ -28,6 +28,7 @@
"maybe_apply_chat_template",
"maybe_extract_prompt",
"maybe_unpair_preference_dataset",
"pack_examples",
"unpair_preference_dataset",
],
"environment": ["TextEnvironment", "TextHistory"],
@@ -127,6 +128,7 @@
maybe_apply_chat_template,
maybe_extract_prompt,
maybe_unpair_preference_dataset,
pack_examples,
unpair_preference_dataset,
)
from .environment import TextEnvironment, TextHistory
34 changes: 34 additions & 0 deletions trl/data_utils.py
@@ -412,3 +412,37 @@ def maybe_extract_prompt(example: dict[str, list]) -> dict[str, list]:
if (chosen_conv and prompt_conv) or (not chosen_conv and not prompt_conv):
return example
return extract_prompt({"chosen": example["chosen"], "rejected": example["rejected"]})


def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str, list[list]]:
"""
    Pack examples into chunks of size `seq_length`.

    Args:
        examples (`dict[str, list[list]]`):
            Dictionary of examples with keys as strings and values as lists of lists.
        seq_length (`int`):
            Maximum sequence length.

    Returns:
        `dict[str, list[list]]`: Dictionary of examples with keys as strings and values as lists of lists.

    Example:
```python
>>> from trl import pack_examples
>>> examples = {
... "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
... "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
... }
>>> pack_examples(examples, seq_length=5)
{'input_ids': [[1, 2, 3, 4, 5], [6, 7, 8]], 'attention_mask': [[0, 1, 1, 0, 0], [1, 1, 1]]}
>>> pack_examples(examples, seq_length=2)
{'input_ids': [[1, 2], [3, 4], [5, 6], [7, 8]], 'attention_mask': [[0, 1], [1, 0], [0, 1], [1, 1]]}
```
"""
# Join all the values into a single list
examples = {k: sum(v, []) for k, v in examples.items()}
# Split the values into chunks of size seq_length
examples = {k: [v[i : i + seq_length] for i in range(0, len(v), seq_length)] for k, v in examples.items()}
return examples
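As a quick usage sketch (not part of the diff), `pack_examples` slots into a tokenize-then-pack pipeline via `datasets.Dataset.map`; the model name is the tiny test model used elsewhere in this commit, and `seq_length=8` is an arbitrary choice:

```python
from datasets import Dataset
from transformers import AutoTokenizer

from trl import pack_examples

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
dataset = Dataset.from_dict({"text": ["a first example", "a second, longer example"]})

# Tokenize in batches, dropping the raw text column.
dataset = dataset.map(lambda batch: tokenizer(batch["text"]), batched=True, remove_columns="text")

# Concatenate all sequences and re-chunk them into blocks of at most 8 tokens.
dataset = dataset.map(pack_examples, batched=True, fn_kwargs={"seq_length": 8})
print(dataset[0]["input_ids"])
```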
10 changes: 8 additions & 2 deletions trl/trainer/gkd_trainer.py
@@ -78,7 +78,6 @@ def __init__(
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
@@ -97,7 +96,6 @@ def __init__(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
model_init=model_init,
compute_metrics=compute_metrics,
callbacks=callbacks,
optimizers=optimizers,
@@ -158,6 +156,14 @@
):
self.generation_config.eos_token_id = self.model.generation_config.eos_token_id

def _prepare_dataset(self, dataset, *args):
        # SFTTrainer._prepare_dataset() applies the chat template and renames the "messages" column to "text". GKD,
        # however, needs the original "messages" column, so we stash it under a temporary name and restore it after.
dataset = dataset.add_column("_messages", dataset["messages"])
dataset = super()._prepare_dataset(dataset, *args)
dataset = dataset.rename_column("_messages", "messages")
return dataset

@staticmethod
def generalized_jsd_loss(
student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
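The stash-and-restore trick in `_prepare_dataset` can be illustrated with a toy `datasets` snippet (not from the diff; the `remove_columns` call stands in for whatever the parent preparation does to the original column):

```python
from datasets import Dataset

dataset = Dataset.from_dict({"messages": [[{"role": "user", "content": "hi"}]]})

# Copy the column under a temporary name so the preparation step cannot destroy it.
dataset = dataset.add_column("_messages", dataset["messages"])
dataset = dataset.remove_columns("messages")  # stand-in for super()._prepare_dataset(...)
dataset = dataset.rename_column("_messages", "messages")

assert dataset.column_names == ["messages"]
```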
