Commit
option for padding to max length
chiragjn committed Dec 28, 2023
1 parent 0f00b66 commit 69ddccc
Showing 2 changed files with 28 additions and 9 deletions.
data_utils.py: 13 changes (10 additions & 3 deletions)
@@ -200,9 +200,10 @@ def construct_dataset(self, input_batch):
 class CausalDatasetBuilder(DatasetBuilder):
     """Builds generative dataset for Causal LM."""
 
-    def __init__(self, tokenizer, max_length, train_on_prompt=True):
+    def __init__(self, tokenizer, max_length, train_on_prompt=True, pad_to_max_length=False):
         super().__init__(tokenizer, max_length)
         self.train_on_prompt = train_on_prompt
+        self.pad_to_max_length = pad_to_max_length
 
     def construct_dataset(self, input_batch):
         tokenized_prompts = self.batch_tokenize(input_batch[PROMPT_KEY], truncation=False)
@@ -215,7 +216,9 @@ def construct_dataset(self, input_batch):
         labels = []
         for prompt, completion in zip(input_batch[PROMPT_KEY], input_batch[COMPLETION_KEY]):
             labels.append(prompt + "\n" + completion + self.tokenizer.eos_token)
-        input_ids = [val.squeeze() for val in self.batch_tokenize(labels)]
+
+        padding = "max_length" if self.pad_to_max_length else "longest"
+        input_ids = [val.squeeze() for val in self.batch_tokenize(labels, padding=padding)]
         labels = copy.deepcopy(input_ids)
         if not self.train_on_prompt:
             # Masking for loss computation
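
For context, a minimal sketch (not this repository's batch_tokenize; the model name and texts are placeholders) of how the two padding strategies behave with a standard Hugging Face tokenizer: "longest" pads each batch only up to its longest member, while "max_length" pads every batch to the same fixed width, which is what the new pad_to_max_length option selects. Fixed shapes are typically useful when you want uniform batch sizes, e.g. to surface out-of-memory errors on the first step rather than on an unusually long batch later.

```python
# Minimal sketch, assuming the transformers library; "gpt2" and the texts are
# placeholders, not values used by this repository.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

texts = ["short prompt", "a somewhat longer prompt with quite a few more tokens"]

# padding="longest": tensors are only as wide as the longest sample in the batch.
dynamic = tokenizer(texts, padding="longest", return_tensors="pt")

# padding="max_length": every batch comes out at the same fixed width,
# which is what pad_to_max_length=True chooses above.
static = tokenizer(texts, padding="max_length", max_length=32, truncation=True, return_tensors="pt")

print(dynamic["input_ids"].shape)  # width follows the longest sample
print(static["input_ids"].shape)   # torch.Size([2, 32])
```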
@@ -304,7 +307,11 @@ def build_dataset(
 ):
     # TODO (chiragjn): This should not be loading the entire dataset in memory all at once. Make this streaming
     # TODO (chiragjn): Add dataset packing to increase training efficiency
-    builder = CausalDatasetBuilder(tokenizer=tokenizer, max_length=max_length, train_on_prompt=train_on_prompt)
+    builder = CausalDatasetBuilder(
+        tokenizer=tokenizer,
+        max_length=max_length,
+        train_on_prompt=train_on_prompt,
+    )
     dataset_dict = DatasetDict(train=Dataset.from_list(train_data), eval=Dataset.from_list(eval_data))
     # TODO (chiragjn): Read cpu limits from cgroup, cpu_count is not usable in containers environment
     num_proc = max(1, min(4, os.cpu_count()))
train.py: 24 changes (18 additions & 6 deletions)
@@ -9,7 +9,7 @@
 import time
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 import bitsandbytes as bnb
 import mlfoundry
@@ -185,6 +185,10 @@ class OtherArguments:
             "from tokenizer config (default: None)"
         },
     )
+    pad_to_max_length: bool = field(
+        default=False,
+        metadata={"help": "If to always pad to the given/found max sequence length (default: False)"},
+    )
     max_num_samples: Optional[int] = field(
         default=None,
         metadata={"help": "For quick debugging purposes, how many samples to use (default: all)"},
@@ -516,6 +520,7 @@ def get_peft_wrapped_model(
     model,
     training_arguments: HFTrainingArguments,
     other_arguments: OtherArguments,
+    modules_to_save: Optional[List[str]] = None,
     _device_map=None,
     _checkpoint_dir: Optional[str] = None,
 ):
@@ -542,6 +547,7 @@
             lora_dropout=other_arguments.lora_dropout,
             bias=other_arguments.lora_bias,
             task_type="CAUSAL_LM",
+            modules_to_save=modules_to_save,
         )
     )
     logger.info("Applying peft config ...")
@@ -608,6 +614,9 @@ def get_tokenizer(model_source: str):
         special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
     # TODO (chiragjn): Consider adding fake tokens to vocab to pad to multiple of 64. Can provide better throughput
     num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+    if "bos_token" in special_tokens_dict or "eos_token" in special_tokens_dict:
+        if hasattr(tokenizer, "update_post_processor"):
+            tokenizer.update_post_processor()
     return tokenizer, num_new_tokens
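
The update_post_processor call matters mainly for Llama-style fast tokenizers, whose post-processor bakes the bos/eos ids into a template at load time; if bos or eos is changed via add_special_tokens, the template has to be rebuilt or encodings keep using the old ids. A sketch under that assumption (the checkpoint is a public test tokenizer used here only as a placeholder):

```python
# Sketch, assuming a Llama-style fast tokenizer with update_post_processor();
# the checkpoint name and new special tokens are placeholders.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
special_tokens_dict = {"eos_token": "<|endoftext|>", "pad_token": "[PAD]"}
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)

# Rebuild the bos/eos template so the *new* eos id is used from now on.
if "bos_token" in special_tokens_dict or "eos_token" in special_tokens_dict:
    if hasattr(tokenizer, "update_post_processor"):
        tokenizer.update_post_processor()

print(num_new_tokens, tokenizer.eos_token_id)
```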


@@ -705,7 +714,8 @@ def dist_build_dataset(
     train_data,
     eval_data,
     tokenizer,
-    max_length,
+    max_length: int,
+    pad_to_max_length: bool,
     train_on_prompt: bool,
     training_arguments: HFTrainingArguments,
 ):
@@ -718,6 +728,7 @@
             eval_data=eval_data,
             tokenizer=tokenizer,
             max_length=max_length,
+            pad_to_max_length=pad_to_max_length,
             train_on_prompt=train_on_prompt,
         )
         dataset_dict.save_to_disk(dataset_cache_path)
@@ -796,6 +807,7 @@ def _train(
         eval_data=eval_data,
         tokenizer=tokenizer,
         max_length=max_length,
+        pad_to_max_length=other_arguments.pad_to_max_length,
         train_on_prompt=other_arguments.train_on_prompt,
         training_arguments=training_arguments,
     )
@@ -831,23 +843,23 @@ def _train(
         training_arguments=training_arguments,
         other_arguments=other_arguments,
     )
+    lora_modules_to_save = None
     if model.get_input_embeddings().num_embeddings < len(tokenizer):
         logger.info(
             f"Resizing embeddings layer for newly added tokens. "
             f"Tokenizer length is {len(tokenizer)} but model embedding "
             f"layer has {model.get_input_embeddings().num_embeddings}"
         )
         model.resize_token_embeddings(len(tokenizer))
+        # TODO (chiragjn): Check if we want to enable this!
+        # lora_modules_to_save = ["embed_tokens", "lm_head"]
 
-    # TODO (chiragjn):
-    #   If there are new tokens added, check if we want grads to be enabled on embedding and lm head.
-    #   prepare_model_for_k_bit actually disables grad on embedding and lm head
-    #   We need to pass them to modules_to_save in LoraConfig
     if other_arguments.use_lora or other_arguments.use_qlora:
         model = get_peft_wrapped_model(
             model,
             training_arguments=training_arguments,
             other_arguments=other_arguments,
+            modules_to_save=lora_modules_to_save,
         )
     logger.info("Training...")
     # TODO (chiragjn): Add text generation metrics to `compute_metrics
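
Finally, a sketch of the resize path this last hunk guards: when newly added special tokens grow the vocabulary past the embedding matrix, the embeddings (and any tied output head) must be resized before training, and under LoRA those resized modules only stay trainable if they are listed in modules_to_save, which is what the new lora_modules_to_save plumbing prepares for. The model and tokenizer names below are placeholders, not this repository's _train().

```python
# Sketch with a tiny placeholder model so the resize branch actually triggers.
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
tokenizer.add_special_tokens({"additional_special_tokens": ["<|sep|>"]})

if model.get_input_embeddings().num_embeddings < len(tokenizer):
    # Grow the embedding matrix (and tied lm_head) to cover the new token ids.
    model.resize_token_embeddings(len(tokenizer))

assert model.get_input_embeddings().num_embeddings >= len(tokenizer)
```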
