HuggingFace Accelerate checkpointing
mdw771 committed Mar 13, 2024
1 parent 16ac05b commit 129dddc
Showing 3 changed files with 56 additions and 4 deletions.
1 change: 1 addition & 0 deletions examples/classification_huggingface_accelerate.py
@@ -32,6 +32,7 @@
    optimizer=torch.optim.AdamW,
    optimizer_params={'weight_decay': 0.01},
    num_epochs=5,
    # checkpoint_dir='temp',
    model_save_dir='temp',
    task_type='classification'
)
11 changes: 10 additions & 1 deletion generic_trainer/configs.py
@@ -166,6 +166,13 @@ class TrainingConfig(Config):
"""

num_epochs: int = 60
"""
The number of epochs. When loading a checkpoint using `TrainingConfig.checkpoint_dir`, the epoch counter continues
from the checkpoint, and this parameters sets the final number of epochs: for example, if the checkpoint is at
epoch 200 and `num_epochs` is 300, then only 100 more epochs will be run in the current job. However, if
`TrainingConfigl.pretrained_model_path` is used instead, then the epoch counter and all the other states start from
scratch.
"""

    learning_rate_per_process: float = 1e-3
    """
@@ -185,7 +192,9 @@ class TrainingConfig(Config):
    checkpoint_dir: Any = None
    """
    The checkpoint directory. If not None, the trainer will load the checkpoint that contains the model,
    optimizer, and all the other state dictionaries from it.
    optimizer, and all the other state dictionaries from it. The given directory should be the one that contains
    "checkpoint_model.pth"; if using HuggingFaceAccelerateTrainer, this should be the directory that contains
    "checkpoint_model".
    """

    pretrained_model_path: Any = None
48 changes: 45 additions & 3 deletions generic_trainer/trainer.py
@@ -736,7 +736,7 @@ def update_saved_model(self, filename='best_model.pth', save_configs=True, subco
"""
Updates saved model if validation loss is minimum.
"""
if not self.gatekeeper.should_proceed(gate_kept=True):
if not self.gatekeeper.should_proceed(gate_kept=run_with_only_rank_0):
return
path = self.configs.model_save_dir
dest_path = os.path.join(path, filename)
@@ -1070,8 +1070,6 @@ def build(self):
        self.build_scheduler()
        self.build_accelerate()

        self.load_state_checkpoint()

        self.build_dir()

    def build_model(self):
@@ -1085,9 +1083,53 @@ def build_accelerate(self):
            self.model, self.optimizer, self.training_dataloader, self.scheduler
        )

        if self.configs.checkpoint_dir is not None:
            self.load_state_checkpoint()

    def run_backprop(self, loss_node):
        self.accelerator.backward(loss_node)

    def move_to_device(self, var):
        # HuggingFace Accelerate should not need manual data offloading.
        return var

    def update_saved_model(self, filename='best_model', save_configs=True, subcomponent=None, **kwargs):
        """
        Save model checkpoint.
        HuggingFace Accelerate takes a directory to save the model. This directory will be named as
        basename(splitext(filename)[0]).
        :param filename: str. Name of the checkpoint directory. If it comes with an extension, the extension will
            be removed.
        :param save_configs: bool. If True, trainer configs will also be saved as a JSON.
        :param subcomponent: str. If not None, only the subcomponent of the model with this name will be saved.
        """
        path = os.path.join(self.configs.model_save_dir, os.path.splitext(filename)[0])
        self.accelerator.save_state(path)
        if save_configs and self.gatekeeper.should_proceed(gate_kept=True):
            self.configs.dump_to_json(os.path.join(self.configs.model_save_dir, 'configs.json'))
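
A small sketch of the filename handling described in the docstring above (the save directory value is made up):

import os

model_save_dir = 'temp'      # illustrative value of configs.model_save_dir
filename = 'best_model.pth'  # the extension is stripped before saving
path = os.path.join(model_save_dir, os.path.splitext(filename)[0])  # -> 'temp/best_model'
# accelerator.save_state(path) then writes the model, optimizer, scheduler, and
# dataloader states into the 'temp/best_model' directory rather than a single .pth file.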

    def save_model_and_states_checkpoint(self):
        # Save epoch counter and loss tracker.
        state_dict = self.generate_state_dict()
        torch.save(state_dict, os.path.join(self.configs.model_save_dir, 'checkpoint.state'))

        # Save model, optimizer, scheduler, and dataloader states.
        self.update_saved_model(filename='checkpoint_model')

    def load_model(self):
        self.load_state_checkpoint()

    def load_state_checkpoint(self):
        if self.configs.checkpoint_dir is None:
            return
        # Accelerator loads model, optimizer, scheduler, and dataloader states.
        self.accelerator.load_state(os.path.join(self.configs.checkpoint_dir, 'checkpoint_model'))

        # Also load epoch counter and loss tracker.
        checkpoint_fname = os.path.join(self.configs.checkpoint_dir, 'checkpoint.state')
        if not os.path.exists(checkpoint_fname):
            logger.warning('Checkpoint not found in {}.'.format(checkpoint_fname))
            # Missing state file: keep the freshly initialized epoch counter and loss tracker.
            return
        state_dict = torch.load(checkpoint_fname)
        self.current_epoch = state_dict['current_epoch']
        self.loss_tracker = state_dict['loss_tracker']
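
Putting the pieces together, a hedged sketch of resuming a HuggingFaceAccelerateTrainer run whose outputs were written to 'temp' (the constructor and the build()/run() call pattern are assumptions, not shown in this diff):

configs = TrainingConfig(
    num_epochs=300,           # final epoch count; the counter resumes from checkpoint.state
    checkpoint_dir='temp',    # must contain 'checkpoint_model/' and 'checkpoint.state'
    model_save_dir='temp',    # new checkpoints keep being written here
    task_type='classification',
)
trainer = HuggingFaceAccelerateTrainer(configs)
trainer.build()   # build_accelerate() calls load_state_checkpoint() when checkpoint_dir is set
trainer.run()     # training continues from the restored current_epoch and loss_tracker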
