Skip to content

Commit

Permalink
Minor refactor to conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragjn committed Dec 21, 2023
1 parent 604f135 commit ae60af5
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ def _train(
torch.cuda.synchronize()

dist_s.wait_for_everyone()
if hasattr(trainer, "deepspeed"):
if training_arguments.deepspeed:
with patched_deepspeed_load_checkpoint():
trainer.train(resume_from_checkpoint=last_checkpoint_dir)
else:
Expand All @@ -852,12 +852,11 @@ def _train(
logger.info("Saving model...")
if dist_s.is_main_process:
cleanup_checkpoints(output_dir=training_arguments.output_dir)

dist_s.wait_for_everyone()
trainer.save_model(output_dir=training_arguments.output_dir)
dist_s.wait_for_everyone()

if hasattr(trainer, "deepspeed") and hasattr(trainer.deepspeed, "destroy"):
if training_arguments.deepspeed and hasattr(trainer, "deepspeed") and hasattr(trainer.deepspeed, "destroy"):
trainer.deepspeed.destroy()
trainer.accelerator.free_memory()
dist_s.wait_for_everyone()
Expand Down

0 comments on commit ae60af5

Please sign in to comment.