Skip to content

Commit

Permalink
HuggingFace Accelerate trainer load_model choose what method to use b…
Browse files Browse the repository at this point in the history
…ased on path extension
  • Loading branch information
mdw771 committed Mar 22, 2024
1 parent 1e2ffa1 commit 3c72ce0
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions generic_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,9 +1092,15 @@ def save_model_and_states_checkpoint(self):
self.update_saved_model(filename='checkpoint_model')

def load_model(self, path=None, state_dict=None, subcomponent=None):
# This would also load optimizer, scheduler and dataloader states.
# TODO: find a way to load only the encoder.
self.accelerator.load_state(path)
if len(os.path.splitext(path)[1]) == '':
logging.warning('The provided mode path {} does not have an extension so I am assuming it is the '
'HuggingFace checkpoint format. The states of all other components like optimizer, '
'scheduler etc. will also be loaded.'.format(path))
self.accelerator.load_state(path)
else:
logging.warning('The provided mode path {} is assumed to be a native PyTorch checkpoint. Loading it '
'with the native load_model method.'.format(path))
Pretrainer.load_model(self, path=path, state_dict=state_dict, subcomponent=subcomponent)

def load_state_checkpoint(self):
if self.configs.checkpoint_dir is None:
Expand Down

0 comments on commit 3c72ce0

Please sign in to comment.