From c7f1b7eeb0e19f0401f44385e282d1d9e060f944 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Wed, 29 May 2024 14:36:39 +0200 Subject: [PATCH] Performance improvements and `neuron_parallel_compile` and gradient checkpointing fixes (#602) --- optimum/neuron/accelerate/accelerator.py | 45 +++++- optimum/neuron/accelerate/state.py | 5 +- optimum/neuron/accelerate/utils/misc.py | 50 ++++++- optimum/neuron/distributed/base.py | 4 - .../distributed/parallelizers_manager.py | 22 ++- optimum/neuron/trainers.py | 132 ++++++++---------- .../neuron/utils/neuron_parallel_compile.py | 34 +++++ optimum/neuron/utils/runner.py | 2 +- .../torch_xla_and_neuronx_initialization.py | 2 +- optimum/neuron/utils/training_utils.py | 3 +- setup.py | 7 +- tests/distributed/test_common.py | 2 + .../distributed/test_model_parallelization.py | 3 +- tests/test_examples.py | 12 +- 14 files changed, 214 insertions(+), 109 deletions(-) create mode 100755 optimum/neuron/utils/neuron_parallel_compile.py diff --git a/optimum/neuron/accelerate/accelerator.py b/optimum/neuron/accelerate/accelerator.py index adb0c27b6..25d4499c8 100644 --- a/optimum/neuron/accelerate/accelerator.py +++ b/optimum/neuron/accelerate/accelerator.py @@ -57,7 +57,11 @@ patch_accelerate_is_torch_xla_available, tie_parameters, ) -from .utils.misc import apply_activation_checkpointing, create_patched_finfo, create_patched_save_pretrained +from .utils.misc import ( + apply_activation_checkpointing, + create_patched_finfo, + create_patched_save_pretrained, +) from .utils.operations import _xla_gather @@ -132,6 +136,15 @@ def __init__( if not isinstance(autocast_backend, AutocastBackend): autocast_backend = AutocastBackend(autocast_backend) + # The original `is_torch_xla_available` function is checking for TPU or GPU in `accelerate`. + # Here, we patch it to return True for Neuron cores as well. + def patched_is_torch_xla_available(check_is_tpu: bool = False, check_is_gpu: bool = False) -> bool: + return is_torch_xla_available() + + import accelerate + + accelerate.state.is_torch_xla_available = patched_is_torch_xla_available + patched_accelerator_state = partial( NeuronAcceleratorState, mp_plugin=mp_plugin, autocast_backend=autocast_backend ) @@ -336,13 +349,24 @@ def patch_model_for_neuron( ), ) + # TODO: @michaelbenayoun generalize an implementation of gradient checkpointing working for: + # - DDP + # - TP + # - PP + # if hasattr(model, "gradient_checkpointing_enable"): + # patching_specs.append( + # ( + # "gradient_checkpointing_enable", + # patched_gradient_checkpointing_enable, + # ), + # ) + prepared_patching_specs = [] for spec in patching_specs: prepared_patching_specs.append((model,) + spec) model_patcher = ModelPatcher(prepared_patching_specs, ignore_missing_attributes=True) model_patcher.patch() - return model @requires_neuronx_distributed @@ -428,6 +452,12 @@ def prepare_model( model.config.output_attentions = False model.config.output_hidden_states = False + should_apply_activation_checkpointing = False + for mod in model.modules(): + if getattr(mod, "gradient_checkpointing", False): + should_apply_activation_checkpointing = True + model.gradient_checkpointing_disable() + # It is needed for now otherwise sdpa is used since PT > 2.* is available. for module in model.modules(): if getattr(module, "_use_sdpa", False): @@ -439,13 +469,16 @@ def prepare_model( model = self._prepare_model_for_mp( model, device_placement=device_placement, evaluation_mode=evaluation_mode ) - apply_activation_checkpointing(model) - return model + if should_apply_activation_checkpointing: + apply_activation_checkpointing(model) else: - apply_activation_checkpointing(model) + if should_apply_activation_checkpointing: + apply_activation_checkpointing(model) move_model_to_device(model, xm.xla_device()) device_placement = False - return super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode) + model = super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode) + xm.mark_step() + return model def backward(self, loss, **kwargs): if self.distributed_type != DistributedType.DEEPSPEED: diff --git a/optimum/neuron/accelerate/state.py b/optimum/neuron/accelerate/state.py index c4d3de0bf..51f87f9d1 100644 --- a/optimum/neuron/accelerate/state.py +++ b/optimum/neuron/accelerate/state.py @@ -96,10 +96,7 @@ def __init__(self, cpu: bool = False, **kwargs): self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0) def wait_for_everyone(self): - if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM: - xm.rendezvous("accelerate.utils.wait_for_everyone") - else: - super().wait_for_everyone() + xm.rendezvous("accelerate.utils.wait_for_everyone") class NeuronAcceleratorState(AcceleratorState): diff --git a/optimum/neuron/accelerate/utils/misc.py b/optimum/neuron/accelerate/utils/misc.py index 37e4a76c2..2a564f7dc 100644 --- a/optimum/neuron/accelerate/utils/misc.py +++ b/optimum/neuron/accelerate/utils/misc.py @@ -16,17 +16,21 @@ import functools import gc +import inspect from typing import TYPE_CHECKING, Callable, Dict, Optional, Union import torch from transformers.modeling_utils import get_parameter_dtype +from ....utils import logging from ...distributed.utils import named_parameters from ...utils import is_torch_neuronx_available, is_torch_xla_available, patch_everywhere from ...utils.patching import Patcher -from ...utils.require_utils import requires_neuronx_distributed, requires_safetensors +from ...utils.require_utils import requires_neuronx_distributed, requires_safetensors, requires_torch_xla +logger = logging.get_logger(__name__) + if TYPE_CHECKING: import os @@ -191,6 +195,41 @@ def tie_parameters(model: Union["torch.nn.Module", "NxDPPModel"], tied_parameter setattr(param_to_tie_parent_module, param_to_tie_name[1], param) +# TODO: @michaelbenayoun +# Needs to make it work in the general case or be deleted and only use `apply_activation_checkpointing`. +@requires_torch_xla +def patched_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + from torch_xla.utils.checkpoint import checkpoint + + if not self.supports_gradient_checkpointing: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {"use_reentrant": True} + + gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs) + + # For old GC format (transformers < 4.35.0) for models that live on the Hub + # we will fall back to the overwritten `_set_gradient_checkpointing` method + _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters + + if not _is_using_old_format: + self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) + else: + self.apply(functools.partial(self._set_gradient_checkpointing, value=True)) + logger.warning( + "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." + "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." + ) + + if getattr(self, "_hf_peft_config_loaded", False): + # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True + # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 + # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate + # the gradients to make sure the gradient flows. + self.enable_input_require_grads() + + @requires_neuronx_distributed def apply_activation_checkpointing(model: Union["PreTrainedModel", "NxDPPModel"]): from neuronx_distributed.pipeline import NxDPPModel @@ -205,9 +244,12 @@ def apply_activation_checkpointing(model: Union["PreTrainedModel", "NxDPPModel"] gradient_checkpointing_modules = set() for module in modules: - if getattr(module, "gradient_checkpointing", False): - module.gradient_checkpointing = False - gradient_checkpointing_modules.add(module) + if isinstance(module, torch.nn.ModuleList): + for mod in module: + # TODO: @michaelbenayoun. Need to find a better way to identify the blocks to apply gradient + # checkpointing to. + if "Layer" in mod.__class__.__name__ or "Block" in mod.__class__.__name__: + gradient_checkpointing_modules.add(mod) def check_fn(m: torch.nn.Module) -> bool: return m in gradient_checkpointing_modules diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py index 469d6f70a..93e1c9177 100644 --- a/optimum/neuron/distributed/base.py +++ b/optimum/neuron/distributed/base.py @@ -631,7 +631,6 @@ def should_parallelize_layer_predicate_func(layer): skip_linear_weight_load=skip_linear_weight_load, kv_size_multiplier=kv_size_multiplier, ) - xm.rendezvous("End of tensor parallelism") if is_main_worker(): logger.info("Tensor parallelism done.") @@ -708,8 +707,6 @@ def should_parallelize_layer_predicate_func(layer): # Initialize or load the weights for the parallelized model if it was lazily loaded. cls._initialize_or_load_weights(model, names_of_the_parameters_to_consider, device=device) gc.collect() - xm.rendezvous(f"weight_loading_and_initialization_{worker}") - xm.rendezvous("End of initalization") if is_main_worker(): logger.info("Load and initialization of the weights done.") @@ -750,7 +747,6 @@ def should_parallelize_layer_predicate_func(layer): tracer_cls=OptimumNeuronFXTracer, ) - xm.rendezvous("End of pipeline paralellism") if is_main_worker(): logger.info("Pipeline parallelism done.") diff --git a/optimum/neuron/distributed/parallelizers_manager.py b/optimum/neuron/distributed/parallelizers_manager.py index 9c7d92e36..f8df3bd5d 100644 --- a/optimum/neuron/distributed/parallelizers_manager.py +++ b/optimum/neuron/distributed/parallelizers_manager.py @@ -15,7 +15,7 @@ """Factory class mapping model architectures to their Parallelizer class.""" import importlib -from typing import Dict, List, Type, Union +from typing import Dict, List, Tuple, Type, Union from transformers import PreTrainedModel @@ -83,16 +83,27 @@ def _get_model_type(cls, model_type_or_model: Union[str, PreTrainedModel]) -> st return model_type @classmethod - def is_model_supported(cls, model_type_or_model: Union[str, PreTrainedModel]) -> bool: + def is_model_supported(cls, model_type_or_model: Union[str, PreTrainedModel]) -> Tuple[bool, bool, bool]: """ - Returns `True` if the model can be parallelized, `False` otherwise. + Returns a tuple of 3 booleans where: + - The first element indicates if tensor parallelism can be used for this model, + - The second element indicates if sequence parallelism can be used on top of tensor parallelism for this model, + - The third element indicates if pipeline parallelism can be used for this model. Args: model_type_or_model (`Union[str, PreTrainedModel]`): Either the model type or an instance of the model. """ model_type = cls._get_model_type(model_type_or_model) - return model_type in cls._MODEL_TYPE_TO_PARALLEL_MODEL_CLASS + for_tp = model_type in cls._MODEL_TYPE_TO_PARALLEL_MODEL_CLASS + if for_tp: + parallelizer = cls._MODEL_TYPE_TO_PARALLEL_MODEL_CLASS[model_type] + for_sp = parallelizer.supports_sequence_parallelism() + for_pp = parallelizer.supports_pipeline_parallelism() + else: + for_sp = for_pp = False + + return (for_tp, for_sp, for_pp) @classmethod def parallelizer_for_model(cls, model_type_or_model: Union[str, PreTrainedModel]) -> Type[Parallelizer]: @@ -105,7 +116,8 @@ def parallelizer_for_model(cls, model_type_or_model: Union[str, PreTrainedModel] """ model_type = cls._get_model_type(model_type_or_model) - if not cls.is_model_supported(model_type_or_model): + is_tp_supported, _, _ = cls.is_model_supported(model_type_or_model) + if not is_tp_supported: supported_models = ", ".join(cls._MODEL_TYPE_TO_PARALLEL_MODEL_CLASS.keys()) raise NotImplementedError( f"{model_type} is not supported for parallelization, supported models: {supported_models}" diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py index 4be6c9f5b..4e5c03478 100755 --- a/optimum/neuron/trainers.py +++ b/optimum/neuron/trainers.py @@ -59,7 +59,6 @@ has_length, speed_metrics, ) -from transformers.training_args import ParallelMode from transformers.utils import WEIGHTS_NAME, is_accelerate_available, is_apex_available, is_sagemaker_mp_enabled from ..utils import logging @@ -94,9 +93,6 @@ if is_apex_available(): from apex import amp -if is_sagemaker_mp_enabled(): - import smdistributed.modelparallel.torch as smp - if is_torch_xla_available(): import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met @@ -181,14 +177,6 @@ def prepare_for_precompilation(self, args: "TrainingArguments"): if not is_precompilation(): return - # `neuron_parallel_compile` relies on the logs to retrieve the HLO graphs to compile. - # For some reason, the logger logs strange characters that make `neuron_parallel_compile` fail when it tries to - # load the log file to extract the graphs to compile. To avoid that, we disable logging when doing - # precompilation. - logging.logging.disable(sys.maxsize) - # We disable tqdm as well just to be safe. - args.disable_tqdm = True - if args.num_train_epochs != 1: if is_main_worker(): logger.info("Setting the number of epochs for precompilation to 1.") @@ -339,11 +327,6 @@ def get_optimizer_cls_and_kwargs( def create_optimizer(self): return super().create_optimizer() - def log(self, logs: Dict[str, float]): - if is_precompilation(): - return - return super().log(logs) - def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]: # When pipeline parallelism is enabled, we should not put any tensor on device. # It is handled by the NxDPPModel class. @@ -370,9 +353,8 @@ def compute_loss(self, model, inputs, return_outputs: bool = False): if isinstance(model, NxDPPModel): inputs = self._prepare_inputs(inputs) loss = model.run_train(**inputs) - return loss - - loss = super().compute_loss(model, inputs, return_outputs=return_outputs) + else: + loss = super().compute_loss(model, inputs, return_outputs=return_outputs) return loss def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True): @@ -405,8 +387,10 @@ def training_step(self, model: torch.nn.Module, inputs: Dict[str, Union[torch.Te loss = torch.tensor(0, dtype=dtype).to(xm.xla_device()) else: loss = loss.detach() - return loss / self.args.gradient_accumulation_steps - return super().training_step(model, inputs) + output = loss / self.args.gradient_accumulation_steps + else: + output = super().training_step(model, inputs) + return output @requires_neuronx_distributed def prediction_step( @@ -431,49 +415,54 @@ def prediction_step( return (loss, None, None) return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) - def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval): - if self.control.should_log and self.state.global_step > self._globalstep_last_logged: - logs: Dict[str, float] = {} + def _reduce_loss(self, tr_loss: torch.Tensor) -> torch.Tensor: + from neuronx_distributed.parallel_layers.parallel_state import ( + get_data_parallel_group, + get_data_parallel_size, + model_parallel_is_initialized, + ) - from neuronx_distributed.parallel_layers.parallel_state import ( - get_data_parallel_group, - get_data_parallel_size, - model_parallel_is_initialized, - ) + if model_parallel_is_initialized(): + dp_size = get_data_parallel_size() + else: + dp_size = xm.xrt_world_size() - if model_parallel_is_initialized(): - dp_size = get_data_parallel_size() - else: - dp_size = xm.xrt_world_size() + tr_loss_div = tr_loss / dp_size + + if self.args.mp_plugin.should_parallelize: + # It works even for PP because under PP we make it so that the main process to log for callbacks is + # the one on dp_rank = tp_rank = 0 and pp_rank = pp_size -1. + reduced_tr_loss = xm.all_reduce(xm.REDUCE_SUM, tr_loss_div, groups=get_data_parallel_group(as_list=True)) + else: + reduced_tr_loss = xm.all_reduce(xm.REDUCE_SUM, tr_loss_div) - tr_loss_div = tr_loss / dp_size + # reset tr_loss to zero + tr_loss.zero_() - xm.mark_step() + return reduced_tr_loss - if self.args.mp_plugin.should_parallelize: - # It works even for PP because under PP we make it so that the main process to log for callbacks is - # the one on dp_rank = tp_rank = 0 and pp_rank = pp_size -1. - tr_loss_div = xm.all_reduce(xm.REDUCE_SUM, tr_loss_div, groups=get_data_parallel_group(as_list=True)) - tr_loss_scalar = tr_loss_div.detach().item() - else: - tr_loss_div = xm.all_reduce(xm.REDUCE_SUM, tr_loss_div) - tr_loss_scalar = tr_loss.detach().item() + def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval): + if self.control.should_log and self.state.global_step > self._globalstep_last_logged: - # reset tr_loss to zero - tr_loss -= tr_loss + def log_closure(self, tr_loss, grad_norm): + if is_main_worker_for_metrics(): + logs: Dict[str, float] = {} + tr_loss_scalar = tr_loss.to("cpu").item() - logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) - logs["learning_rate"] = self._get_learning_rate() + logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + logs["learning_rate"] = self._get_learning_rate() - if grad_norm is not None: - logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm + if grad_norm is not None: + logs["grad_norm"] = ( + grad_norm.detach().to("cpu").item() if isinstance(grad_norm, torch.Tensor) else grad_norm + ) - self._total_loss_scalar += tr_loss_scalar - self._globalstep_last_logged = self.state.global_step - self.store_flos() + self._total_loss_scalar += tr_loss_scalar + self._globalstep_last_logged = self.state.global_step + self.store_flos() + self.log(logs) - if is_main_worker_for_metrics(): - self.log(logs) + xm.add_step_closure(log_closure, (self, tr_loss, grad_norm)) metrics = None if self.control.should_evaluate: @@ -488,8 +477,12 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno self.lr_scheduler.step(metrics[metric_to_check]) if self.control.should_save: - self._save_checkpoint(model, trial, metrics=metrics) - self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + def save_closure(self, model, trial, metrics): + self._save_checkpoint(model, trial, metrics=metrics) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + xm.add_step_closure(save_closure, (self, model, trial, metrics)) def _save_xla(self, output_dir: Optional[str] = None): output_dir = output_dir if output_dir is not None else self.args.output_dir @@ -1001,6 +994,8 @@ def _inner_training_loop( # last step in epoch but step is always smaller than gradient_accumulation_steps is_last_step_and_steps_less_than_grad_acc ): + xm.mark_step() + # Gradient clipping if args.max_grad_norm is not None and args.max_grad_norm > 0: # deepspeed does its own clipping @@ -1032,15 +1027,16 @@ def _inner_training_loop( # Delay optimizer scheduling until metrics are generated if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): self.lr_scheduler.step() - self.optimizer.zero_grad() - xm.mark_step() self.state.global_step += 1 self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch self.control = self.callback_handler.on_step_end(args, self.state, self.control) - self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval) + reduced_tr_loss = self._reduce_loss(tr_loss) + self._maybe_log_save_evaluate( + reduced_tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval + ) else: self.control = self.callback_handler.on_substep_end(args, self.state, self.control) @@ -1083,17 +1079,13 @@ def _inner_training_loop( logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: # Wait for everyone to get here so we are sure the model has been saved by process 0. - if is_torch_xla_available(): - xm.rendezvous("load_best_model_at_end") - elif args.parallel_mode == ParallelMode.DISTRIBUTED: - torch.distributed.barrier() - elif is_sagemaker_mp_enabled(): - smp.barrier() + xm.rendezvous("load_best_model") self._load_best_model() # add remaining tr_loss - self._total_loss_scalar += tr_loss.item() + loss_scalar = tr_loss.to("cpu").item() + self._total_loss_scalar += loss_scalar effective_global_step = max(self.state.global_step, 0.001) # Avoid ZeroDivisionError train_loss = self._total_loss_scalar / effective_global_step @@ -1419,12 +1411,6 @@ def predict( self.synchronize_hub_cache() return result - @patch_within_function(("transformers.Trainer.is_world_process_zero", is_main_worker_for_metrics_method)) - def log_metrics(self, split, metrics): - if is_precompilation(): - return - return super().log_metrics(split, metrics) - @patch_within_function(("transformers.Trainer.is_world_process_zero", is_main_worker_for_metrics_method)) def save_metrics(self, split, metrics, combined=True): return super().save_metrics(split, metrics, combined=combined) diff --git a/optimum/neuron/utils/neuron_parallel_compile.py b/optimum/neuron/utils/neuron_parallel_compile.py new file mode 100755 index 000000000..d7e6e4883 --- /dev/null +++ b/optimum/neuron/utils/neuron_parallel_compile.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import codecs +import re +import sys + +import torch_neuronx +from torch_neuronx.parallel_compile.neuron_parallel_compile import LOGGER as torch_neuronx_logger +from torch_neuronx.parallel_compile.neuron_parallel_compile import main + + +def get_hlos_from_run_log(trial_run_log): + # New graphs are detected by specific message matching key + hlo_key = "Extracting graphs" + new_hlo_list = [] + with codecs.open(trial_run_log, "r", encoding="utf-8", errors="ignore") as f: + for line in f.readlines(): + # Move temporary MODULE_* files into workdir before checking if there are any + # new graphs. In try_compilations, compile only new graphs (those without + # corresponding neffs). + if hlo_key in line: + model_path = line.split("Extracting graphs (")[1].split(")")[0] + new_hlo_list.append(model_path) + + format_str = "\n\t" + torch_neuronx_logger.info(f"New graph list from script: {format_str.join(new_hlo_list)}") + return new_hlo_list + + +torch_neuronx.parallel_compile.neuron_parallel_compile.get_hlos_from_run_log = get_hlos_from_run_log + +if __name__ == "__main__": + sys.argv[0] = re.sub(r"(-script\.pyw|\.exe)?$", "", sys.argv[0]) + sys.exit(main()) diff --git a/optimum/neuron/utils/runner.py b/optimum/neuron/utils/runner.py index d738c6f67..aa4f5a634 100644 --- a/optimum/neuron/utils/runner.py +++ b/optimum/neuron/utils/runner.py @@ -539,7 +539,7 @@ def split_args_and_value_in_command(cmd: List[str]) -> List[str]: precompilation_cmd.pop(-1) # Removing the --output_dir argument. max_steps_cmd_str = "--max_steps 10" max_train_samples = compute_max_train_samples( - 10, num_cores, tensor_parallel_size, train_batch_size + 10, num_cores, tensor_parallel_size, pipeline_parallel_size, train_batch_size ) max_train_samples_cmd = f"--max_train_samples {max_train_samples}" if max_steps_idx >= 0: diff --git a/optimum/neuron/utils/torch_xla_and_neuronx_initialization.py b/optimum/neuron/utils/torch_xla_and_neuronx_initialization.py index c20fba38b..b0e78e6ec 100644 --- a/optimum/neuron/utils/torch_xla_and_neuronx_initialization.py +++ b/optimum/neuron/utils/torch_xla_and_neuronx_initialization.py @@ -38,7 +38,7 @@ def init_process_group(): import torch_xla.distributed.xla_backend as xbn if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla): - torch.distributed.init_process_group(backend="xla") + torch.distributed.init_process_group(backend="xla", init_method="xla://") if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla): raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.") diff --git a/optimum/neuron/utils/training_utils.py b/optimum/neuron/utils/training_utils.py index 7866dc689..1988dfe92 100644 --- a/optimum/neuron/utils/training_utils.py +++ b/optimum/neuron/utils/training_utils.py @@ -244,7 +244,8 @@ def numel(parameter_name, parameter) -> int: if pp_size > 1: param_count = torch.tensor(param_count, dtype=torch.float32).to(xm.xla_device()) param_count = xm.all_reduce(xm.REDUCE_SUM, param_count, groups=get_pipeline_model_parallel_group(as_list=True)) - param_count = int(param_count.detach().item()) + xm.mark_step() + param_count = int(param_count.detach().cpu().item()) return param_count diff --git a/setup.py b/setup.py index 8796b8641..a53daeabe 100644 --- a/setup.py +++ b/setup.py @@ -100,5 +100,10 @@ dependency_links=["https://pip.repos.neuron.amazonaws.com"], include_package_data=True, zip_safe=False, - entry_points={"console_scripts": ["optimum-cli=optimum.commands.optimum_cli:main"]}, + entry_points={ + "console_scripts": [ + "optimum-cli=optimum.commands.optimum_cli:main", + "neuron_parallel_compile=optimum.neuron.utils.neuron_parallel_compile:main", + ] + }, ) diff --git a/tests/distributed/test_common.py b/tests/distributed/test_common.py index c4e2150ec..5bc70ffcd 100644 --- a/tests/distributed/test_common.py +++ b/tests/distributed/test_common.py @@ -355,6 +355,8 @@ def test_save_model_and_load_model(self, parallel_sizes, tmpdir, monkeypatch): accelerator.state._reset_state(reset_partial_state=True) del accelerator + xm.rendezvous("wait_after_save") + if pp_size > 1: # We need to disable `NxDPPModel._set_distributed` since it is already done during the creation of the # first model, otherwise creating new `NxDPPModel`s will fail. diff --git a/tests/distributed/test_model_parallelization.py b/tests/distributed/test_model_parallelization.py index 35afb1d36..9961d10b9 100644 --- a/tests/distributed/test_model_parallelization.py +++ b/tests/distributed/test_model_parallelization.py @@ -85,7 +85,6 @@ CLASSES_TO_IGNORE = [ - "T5ForSequenceClassification", # TODO: enable this class when it can be traced for pipeline parallelism. "LlamaForQuestionAnswering", ] @@ -128,7 +127,7 @@ def _generate_supported_model_classes( for task in supported_tasks: config_class = CONFIG_MAPPING[model_type] model_class = task_mapping[task].get(config_class, None) - if model_class is not None and model_class not in CLASSES_TO_IGNORE: + if model_class is not None and model_class.__name__ not in CLASSES_TO_IGNORE: model_classes.append(model_class) return list(set(model_classes)) diff --git a/tests/test_examples.py b/tests/test_examples.py index 86542c169..ae2b9e569 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -263,6 +263,8 @@ class ExampleTestMeta(type): """ def __new__(cls, name, bases, attrs, example_name=None): + if not is_neuronx_distributed_available(): + return models_to_test = [] if example_name is not None: models_to_test = _SCRIPT_TO_MODEL_MAPPING.get(example_name) @@ -280,13 +282,9 @@ def __new__(cls, name, bases, attrs, example_name=None): # model_type, model_name_or_path, 1, True, True, config_overrides # ) - tensor_parallel_size = 2 if tp_support is not TPSupport.NONE else 1 - - if not is_neuronx_distributed_available(): - pp_support = False - else: - pp_support = ParallelizersManager.parallelizer_for_model(model_type).supports_pipeline_parallelism() - pipeline_parallel_size = 4 if pp_support else 1 + is_tp_supported, _, is_pp_supported = ParallelizersManager.is_model_supported(model_type) + tensor_parallel_size = 2 if is_tp_supported else 1 + pipeline_parallel_size = 4 if is_pp_supported else 1 disable_embedding_parallelization = tp_support is TPSupport.PARTIAL if tensor_parallel_size > 1: