From 129dddcd11585aa604b810763c34988923bd980a Mon Sep 17 00:00:00 2001
From: Ming Du
Date: Wed, 13 Mar 2024 16:52:39 -0500
Subject: [PATCH] HuggingFace Accelerate checkpointing

---
 .../classification_huggingface_accelerate.py |  1 +
 generic_trainer/configs.py                   | 11 ++++-
 generic_trainer/trainer.py                   | 48 +++++++++++++++++--
 3 files changed, 56 insertions(+), 4 deletions(-)

diff --git a/examples/classification_huggingface_accelerate.py b/examples/classification_huggingface_accelerate.py
index e150be5..dd1192c 100644
--- a/examples/classification_huggingface_accelerate.py
+++ b/examples/classification_huggingface_accelerate.py
@@ -32,6 +32,7 @@
     optimizer=torch.optim.AdamW,
     optimizer_params={'weight_decay': 0.01},
     num_epochs=5,
+    # checkpoint_dir='temp',
     model_save_dir='temp',
     task_type='classification'
 )
diff --git a/generic_trainer/configs.py b/generic_trainer/configs.py
index c2acd8f..9ea52ab 100644
--- a/generic_trainer/configs.py
+++ b/generic_trainer/configs.py
@@ -166,6 +166,13 @@ class TrainingConfig(Config):
     """
 
     num_epochs: int = 60
+    """
+    The number of epochs. When loading a checkpoint using `TrainingConfig.checkpoint_dir`, the epoch counter continues
+    from the checkpoint, and this parameter sets the final number of epochs: for example, if the checkpoint is at
+    epoch 200 and `num_epochs` is 300, then only 100 more epochs will be run in the current job. However, if
+    `TrainingConfig.pretrained_model_path` is used instead, then the epoch counter and all the other states start from
+    scratch.
+    """
 
     learning_rate_per_process: float = 1e-3
     """
@@ -185,7 +192,9 @@ class TrainingConfig(Config):
     checkpoint_dir: Any = None
     """
     The checkpoint directory. If not None, the trainer will load the checkpoint that contains the model,
-    optimizer, and all the other state dictionaries from it.
+    optimizer, and all the other state dictionaries from it. The given directory should be the one that contains
+    "checkpoint_model.pth"; if using HuggingFaceAccelerateTrainer, this should be the directory that contains
+    "checkpoint_model".
     """
 
     pretrained_model_path: Any = None
diff --git a/generic_trainer/trainer.py b/generic_trainer/trainer.py
index cf80211..4fb51fb 100644
--- a/generic_trainer/trainer.py
+++ b/generic_trainer/trainer.py
@@ -736,7 +736,7 @@ def update_saved_model(self, filename='best_model.pth', save_configs=True, subco
         """
         Updates saved model if validation loss is minimum.
         """
-        if not self.gatekeeper.should_proceed(gate_kept=True):
+        if not self.gatekeeper.should_proceed(gate_kept=run_with_only_rank_0):
             return
         path = self.configs.model_save_dir
         dest_path = os.path.join(path, filename)
@@ -1070,8 +1070,6 @@ def build(self):
         self.build_scheduler()
         self.build_accelerate()
 
-        self.load_state_checkpoint()
-
         self.build_dir()
 
     def build_model(self):
@@ -1085,9 +1083,53 @@ def build_accelerate(self):
             self.model, self.optimizer, self.training_dataloader, self.scheduler
         )
 
+        if self.configs.checkpoint_dir is not None:
+            self.load_state_checkpoint()
+
     def run_backprop(self, loss_node):
         self.accelerator.backward(loss_node)
 
     def move_to_device(self, var):
         # HuggingFace Accelerate should not need manual data offloading.
         return var
+
+    def update_saved_model(self, filename='best_model', save_configs=True, subcomponent=None, **kwargs):
+        """
+        Save model checkpoint. HuggingFace Accelerate takes a directory rather than a single file to save the model;
+        the directory will be named `os.path.splitext(filename)[0]` and created under `model_save_dir`.
+
+        :param filename: str. Name of the checkpoint directory. If it comes with an extension, the extension will
+                         be removed.
+        :param save_configs: bool. If True, trainer configs will also be saved as a JSON.
+        :param subcomponent: str. If not None, only the subcomponent of the model with this name will be saved.
+        """
+        path = os.path.join(self.configs.model_save_dir, os.path.splitext(filename)[0])
+        self.accelerator.save_state(path)
+        if save_configs and self.gatekeeper.should_proceed(gate_kept=True):
+            self.configs.dump_to_json(os.path.join(self.configs.model_save_dir, 'configs.json'))
+
+    def save_model_and_states_checkpoint(self):
+        # Save epoch counter and loss tracker.
+        state_dict = self.generate_state_dict()
+        torch.save(state_dict, os.path.join(self.configs.model_save_dir, 'checkpoint.state'))
+
+        # Save model, optimizer, scheduler, and dataloader states.
+        self.update_saved_model(filename='checkpoint_model')
+
+    def load_model(self):
+        self.load_state_checkpoint()
+
+    def load_state_checkpoint(self):
+        if self.configs.checkpoint_dir is None:
+            return
+        # Accelerator loads model, optimizer, scheduler, and dataloader states.
+        self.accelerator.load_state(os.path.join(self.configs.checkpoint_dir, 'checkpoint_model'))
+
+        # Also load epoch counter and loss tracker.
+        checkpoint_fname = os.path.join(self.configs.checkpoint_dir, 'checkpoint.state')
+        if not os.path.exists(checkpoint_fname):
+            logger.warning('Checkpoint not found in {}.'.format(checkpoint_fname))
+            return
+        state_dict = torch.load(checkpoint_fname)
+        self.current_epoch = state_dict['current_epoch']
+        self.loss_tracker = state_dict['loss_tracker']
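
Usage sketch: resuming a run with the new `checkpoint_dir` option, mirroring
examples/classification_huggingface_accelerate.py. The exact config class and trainer constructor used by that example
are not shown in this patch, so the call below is an assumption; the field names (`checkpoint_dir`, `num_epochs`,
`model_save_dir`, `optimizer`, `optimizer_params`, `task_type`) come from the hunks above.

    # Sketch of a resumed run. The example may use a task-specific subclass of
    # TrainingConfig, so this constructor call is illustrative only.
    import torch
    from generic_trainer.configs import TrainingConfig

    configs = TrainingConfig(
        optimizer=torch.optim.AdamW,
        optimizer_params={'weight_decay': 0.01},
        # Directory of the previous run that contains 'checkpoint_model' (and
        # 'checkpoint.state'); here it is the previous run's model_save_dir.
        checkpoint_dir='temp',
        # The epoch counter resumes from the checkpoint: if the checkpoint is at
        # epoch 3, only 2 more epochs run before num_epochs=5 is reached.
        num_epochs=5,
        model_save_dir='temp',
        task_type='classification'
    )
    # The configs object is then passed to HuggingFaceAccelerateTrainer as in the
    # example script; when the trainer is built, build_accelerate() calls
    # load_state_checkpoint() because checkpoint_dir is not None.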