Skip to content

Commit

Permalink
Fix checkpoint saving and consolidation for TP (#378)
Browse files Browse the repository at this point in the history
* Fix issue #373

* Fix issue #368

* Fix #367
  • Loading branch information
michaelbenayoun authored Dec 15, 2023
1 parent 51d7793 commit ecdeee8
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 2 deletions.
2 changes: 1 addition & 1 deletion optimum/neuron/distributed/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,6 @@ def consolidate_tensor_parallel_checkpoints_to_unified_checkpoint(
torch.save(shard, output_dir / shard_file)
if index is not None:
save_index_file = SAFE_WEIGHTS_INDEX_NAME if save_format == "safetensors" else WEIGHTS_INDEX_NAME
with open(save_index_file, "w") as fp:
with open(output_dir / save_index_file, "w") as fp:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
fp.write(content)
8 changes: 7 additions & 1 deletion optimum/neuron/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Defines Trainer subclasses to perform training on AWS Neuron instances."""

import contextlib
import copy
import glob
import os
import random
Expand Down Expand Up @@ -395,7 +396,12 @@ def _save_xla(self, output_dir: Optional[str] = None):
if self.accelerator.distributed_type is NeuronDistributedType.TENSOR_PARALLELISM:
logger.info("Model parallelism is enabled, only saving the model sharded state dict.")
if isinstance(self.model, PreTrainedModel):
self.model.config.save_pretrained(output_dir)
from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size

config = copy.deepcopy(self.model.config)
if self.args.tp_plugin.parallelize_embeddings:
config.vocab_size = config.vocab_size * get_tensor_model_parallel_size()
config.save_pretrained(output_dir)

parallelizer = ParallelizersManager.parallelizer_for_model(self.model)
# This mark_step is needed to avoid hang issues.
Expand Down
4 changes: 4 additions & 0 deletions optimum/neuron/utils/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,10 @@ def prepare_environment_for_neuron():
"""
# Set compiler flag to compile for transformer model type
os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + " --model-type=transformer"
# Setting MALLOC_ARENA_MAX is needed because of a memory issue in XLA/glibc, otherwise OOM can happen during
# checkpointing. More information here:
# https://awsdocs-neuron.readthedocs-hosted.com/en/latest/release-notes/torch/torch-neuronx/index.html#memory-leaking-in-glibc
os.environ["MALLOC_ARENA_MAX"] = "64"


def set_neuron_cc_optlevel_for_model(model: "PreTrainedModel", optlevel: str = "auto"):
Expand Down

0 comments on commit ecdeee8

Please sign in to comment.