Skip to content

Commit

Permalink
Initial support for Pipeline Parallelism (#279)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelbenayoun authored Jan 23, 2024
1 parent b643d7f commit ca6c4ff
Show file tree
Hide file tree
Showing 51 changed files with 3,912 additions and 1,304 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/test_trainium_common.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ jobs:
run: echo "/home/ubuntu/.local/bin" >> $GITHUB_PATH
- name: Set pip repository pointing to the Neuron repository
run: pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com
- name: Update pip
run: pip install -U pip
- name: Install Python dependencies
run: pip install .[tests,neuronx]
- name: Run tests on Neuron cores
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test_trainium_distributed.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ jobs:
run: pip install .[tests,neuronx]
- name: Run tests on Neuron cores
run: |
HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_NEURON_CI }} pytest -m "is_trainium_test" tests/distributed/
HF_TOKEN=${{ secrets.HF_TOKEN_OPTIMUM_NEURON_CI }} pytest -m "is_trainium_test" tests/distributed/ -v --durations=0 -x --ignore tests/distributed/test_training.py
6 changes: 3 additions & 3 deletions docs/source/guides/distributed_training.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,11 @@ Just as for ZeRO-1, it is possible to wrap the optimizer class to make it lazy.
```python
from torch.optim import AdamW
from optimum.neuron import NeuronAccelerator
from optimum.neuron.accelerate.utils import TensorParallelismPlugin
from optimum.neuron.accelerate.utils import ModelParallelismPlugin
from optimum.neuron.distributed import lazy_load_for_parallelism

tensor_parallel_size = 8
tp_plugin = TensorParallelismPlugin(
mp_plugin = ModelParallelismPlugin(
tensor_parallel_size,
parallelize_embeddings=True,
sequence_parallel_enabled=True,
Expand All @@ -195,7 +195,7 @@ tp_plugin = TensorParallelismPlugin(

accelerator = NeuronAccelerator(
...
tp_plugin=tp_plugin,
mp_plugin=mp_plugin,
)

with lazy_load_for_parallelism(tensor_parallel_size=tensor_parallel_size):
Expand Down
2 changes: 1 addition & 1 deletion docs/source/package_reference/distributed.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ The [`~optimum.neuron.distributed.Parallelizer`] class is the base abstract clas
[[autodoc]] distributed.Parallelizer
- _parallelize
- parallelize
- optimizer_for_tp
- optimizer_for_mp
- save_model_checkpoint
- load_model_checkpoint

Expand Down
59 changes: 48 additions & 11 deletions examples/image-classification/run_image_classification.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import logging
import os
import sys
import warnings
from dataclasses import dataclass, field
from typing import Optional

Expand All @@ -28,6 +29,7 @@
from torchvision.transforms import (
CenterCrop,
Compose,
Lambda,
Normalize,
RandomHorizontalFlip,
RandomResizedCrop,
Expand Down Expand Up @@ -56,7 +58,7 @@
logger = logging.getLogger(__name__)

# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.31.0")
check_min_version("4.35.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

Expand Down Expand Up @@ -143,12 +145,28 @@ class ModelArguments:
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
image_processor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
token: str = field(
default=None,
metadata={
"help": (
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
"generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
)
},
)
use_auth_token: bool = field(
default=None,
metadata={
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
},
)
trust_remote_code: bool = field(
default=False,
metadata={
"help": (
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
"with private models)."
"Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
"should only be set to `True` for repositories you trust and in which you have read the code, as it will "
"execute code present on the Hub on your local machine."
)
},
)
Expand Down Expand Up @@ -177,6 +195,15 @@ def main():
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

if model_args.use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
FutureWarning,
)
if model_args.token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
model_args.token = model_args.use_auth_token

# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_image_classification", model_args, data_args)
Expand All @@ -200,8 +227,8 @@ def main():

# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
+ f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")

Expand Down Expand Up @@ -230,7 +257,7 @@ def main():
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
task="image-classification",
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
)
else:
data_files = {}
Expand Down Expand Up @@ -277,32 +304,42 @@ def compute_metrics(p):
finetuning_task="image-classification",
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
with lazy_load_for_parallelism(tensor_parallel_size=training_args.tensor_parallel_size):
with lazy_load_for_parallelism(
tensor_parallel_size=training_args.tensor_parallel_size,
pipeline_parallel_size=training_args.pipeline_parallel_size,
):
model = AutoModelForImageClassification.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)

image_processor = AutoImageProcessor.from_pretrained(
model_args.image_processor_name or model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)

# Define torchvision transforms to be applied to each image.
if "shortest_edge" in image_processor.size:
size = image_processor.size["shortest_edge"]
else:
size = (image_processor.size["height"], image_processor.size["width"])
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
normalize = (
Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
if hasattr(image_processor, "image_mean") and hasattr(image_processor, "image_std")
else Lambda(lambda x: x)
)
_train_transforms = Compose(
[
RandomResizedCrop(size),
Expand Down
Loading

0 comments on commit ca6c4ff

Please sign in to comment.