diff --git a/docs/source/tutorials/stable_diffusion.mdx b/docs/source/tutorials/stable_diffusion.mdx
index c115dd760..5d6a734b6 100644
--- a/docs/source/tutorials/stable_diffusion.mdx
+++ b/docs/source/tutorials/stable_diffusion.mdx
@@ -357,7 +357,7 @@ To avoid Neuron device out of memory, it's suggested to finish all base inferenc
Latent Consistency Models (LCMs) were proposed in [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference by Simian Luo, Yiqin Tan, Longbo Huang, Jian Li, and Hang Zhao](https://huggingface.co/papers/2310.04378). LCMs enable inference with fewer steps on any pre-trained LDMs, including Stable Diffusion and SDXL.
In `optimum-neuron`, you can:
- - Use the class `NeuronLatentConsistencyModelPipeline` to compile and run inference of LCMs distilled from Stable Diffusion (SD) models,
+ - Use the class `NeuronLatentConsistencyModelPipeline` to compile and run inference of LCMs distilled from Stable Diffusion (SD) models.
- And continue to use the class `NeuronStableDiffusionXLPipeline` for LCMs distilled from SDXL models.
Here are examples to compile the LCMs of Stable Diffusion ([SimianLuo/LCM_Dreamshaper_v7](https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7)) and Stable Diffusion XL ([latent-consistency/lcm-sdxl](https://huggingface.co/latent-consistency/lcm-sdxl)), and then run inference on AWS Inferentia 2:
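Below is a minimal sketch of that flow for the Stable Diffusion LCM (the model ID, static shapes, and generation settings are illustrative; the complete examples for both models follow in this guide):

```python
from optimum.neuron import NeuronLatentConsistencyModelPipeline

# Compile (export) the LCM with static input shapes, as required by Neuron.
pipe = NeuronLatentConsistencyModelPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7",
    export=True,
    batch_size=1,
    height=768,
    width=768,
)
pipe.save_pretrained("lcm_sd_neuron/")

# Few-step inference on AWS Inferentia 2.
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
image = pipe(prompt, num_inference_steps=4, guidance_scale=8.0).images[0]
```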
diff --git a/optimum/commands/export/neuronx.py b/optimum/commands/export/neuronx.py
index d1a5cace4..5761bac44 100644
--- a/optimum/commands/export/neuronx.py
+++ b/optimum/commands/export/neuronx.py
@@ -102,6 +102,12 @@ def parse_args_neuronx(parser: "ArgumentParser"):
type=int,
help=f"Sequence length {doc_input}",
)
+ input_group.add_argument(
+ "--num_beams",
+ type=int,
+ default=1,
+ help=f"Number of beams for beam search {doc_input}",
+ )
input_group.add_argument(
"--num_choices",
type=int,
@@ -135,6 +141,16 @@ def parse_args_neuronx(parser: "ArgumentParser"):
"UNet model ID on huggingface.co or path on disk to load model from. This will replace the unet in the original Stable Diffusion pipeline."
),
)
+ optional_group.add_argument(
+ "--output_hidden_states",
+ action="store_true",
+ help=("Whether or not for the traced model to return the hidden states of all layers."),
+ )
+ optional_group.add_argument(
+ "--output_attentions",
+ action="store_true",
+ help=("Whether or not for the traced model to return the attentions tensors of all attention layers."),
+ )
class NeuronxExportCommand(BaseOptimumCLICommand):
diff --git a/optimum/exporters/neuron/__main__.py b/optimum/exporters/neuron/__main__.py
index 4c3788d47..8e70ee4d7 100644
--- a/optimum/exporters/neuron/__main__.py
+++ b/optimum/exporters/neuron/__main__.py
@@ -22,14 +22,16 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from requests.exceptions import ConnectionError as RequestsConnectionError
-from transformers import AutoConfig
+from transformers import AutoConfig, PretrainedConfig
from ...neuron.utils import (
+ DECODER_NAME,
DIFFUSION_MODEL_TEXT_ENCODER_2_NAME,
DIFFUSION_MODEL_TEXT_ENCODER_NAME,
DIFFUSION_MODEL_UNET_NAME,
DIFFUSION_MODEL_VAE_DECODER_NAME,
DIFFUSION_MODEL_VAE_ENCODER_NAME,
+ ENCODER_NAME,
NEURON_FILE_NAME,
is_neuron_available,
is_neuronx_available,
@@ -43,6 +45,7 @@
from .model_configs import * # noqa: F403
from .utils import (
build_stable_diffusion_components_mandatory_shapes,
+ get_encoder_decoder_models_for_export,
get_stable_diffusion_models_for_export,
replace_stable_diffusion_submodels,
)
@@ -64,8 +67,10 @@
if TYPE_CHECKING:
+ from transformers import PreTrainedModel
+
if is_diffusers_available():
- from diffusers import StableDiffusionPipeline
+ from diffusers import DiffusionPipeline, StableDiffusionPipeline
logger = logging.get_logger()
@@ -103,7 +108,11 @@ def infer_task(task: str, model_name_or_path: str) -> str:
def normalize_input_shapes(task: str, args: argparse.Namespace) -> Dict[str, int]:
config = AutoConfig.from_pretrained(args.model)
+
model_type = config.model_type.replace("_", "-")
+ if config.is_encoder_decoder:
+ model_type = model_type + "-encoder"
+
neuron_config_constructor = TasksManager.get_exporter_config_constructor(
model_type=model_type, exporter="neuron", task=task
)
@@ -112,6 +121,18 @@ def normalize_input_shapes(task: str, args: argparse.Namespace) -> Dict[str, int
return input_shapes
+def customize_optional_outputs(args: argparse.Namespace) -> Dict[str, bool]:
+ """
+ Customize optional outputs of the traced model, e.g. if `output_attentions=True`, the attention tensors will be traced.
+ """
+ possible_outputs = ["output_attentions", "output_hidden_states"]
+
+ customized_outputs = {}
+ for name in possible_outputs:
+ customized_outputs[name] = getattr(args, name, False)
+ return customized_outputs
+
+
def normalize_stable_diffusion_input_shapes(
args: argparse.Namespace,
) -> Dict[str, Dict[str, int]]:
@@ -173,6 +194,135 @@ def infer_stable_diffusion_shapes_from_diffusers(
return input_shapes
+def _get_submodels_and_neuron_configs(
+ model: Union["PreTrainedModel", "DiffusionPipeline"],
+ input_shapes: Dict[str, int],
+ task: str,
+ output: Path,
+ dynamic_batch_size: bool = False,
+ model_name_or_path: Optional[Union[str, Path]] = None,
+ submodels: Optional[Dict[str, Union[Path, str]]] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+):
+ is_stable_diffusion = "stable-diffusion" in task
+ is_encoder_decoder = (
+ getattr(model.config, "is_encoder_decoder", False) if isinstance(model.config, PretrainedConfig) else False
+ )
+
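+ # Dispatch to the dedicated helpers below; diffusion pipelines do not expose a `PretrainedConfig`, hence the isinstance guard above.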
+ if is_stable_diffusion:
+ # TODO: Enable optional outputs for Stable Diffusion
+ if output_attentions or output_hidden_states:
+ raise ValueError(
+ f"`output_attentions` and `output_hidden_states` are not supported by the {task} task yet."
+ )
+ models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs_for_stable_diffusion(
+ model, input_shapes, task, output, dynamic_batch_size, submodels
+ )
+ elif is_encoder_decoder:
+ optional_outputs = {"output_attentions": output_attentions, "output_hidden_states": output_hidden_states}
+ models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs_for_encoder_decoder(
+ model, input_shapes, task, output, dynamic_batch_size, model_name_or_path, **optional_outputs
+ )
+ else:
+ # TODO: Enable optional outputs for encoders
+ if output_attentions or output_hidden_states:
+ raise ValueError(
+ f"`output_attentions` and `output_hidden_states` are not supported by the {task} task yet."
+ )
+ neuron_config_constructor = TasksManager.get_exporter_config_constructor(
+ model=model, exporter="neuron", task=task
+ )
+ neuron_config = neuron_config_constructor(model.config, dynamic_batch_size=dynamic_batch_size, **input_shapes)
+ model_name = model.name_or_path.split("/")[-1]
+ output_model_names = {model_name: "model.neuron"}
+ models_and_neuron_configs = {model_name: (model, neuron_config)}
+ maybe_save_preprocessors(model_name_or_path, output)
+ return models_and_neuron_configs, output_model_names
+
+
+def _get_submodels_and_neuron_configs_for_stable_diffusion(
+ model: Union["PreTrainedModel", "DiffusionPipeline"],
+ input_shapes: Dict[str, int],
+ task: str,
+ output: Path,
+ dynamic_batch_size: bool = False,
+ submodels: Optional[Dict[str, Union[Path, str]]] = None,
+):
+ check_compiler_compatibility_for_stable_diffusion()
+ model = replace_stable_diffusion_submodels(model, submodels)
+ if is_neuron_available():
+ raise RuntimeError(
+ "Stable diffusion export is not supported by neuron-cc on inf1, please use neuronx-cc on either inf2/trn1 instead."
+ )
+ input_shapes = infer_stable_diffusion_shapes_from_diffusers(input_shapes, model)
+
+ # Saving the model config and preprocessor as this is needed sometimes.
+ model.scheduler.save_pretrained(output.joinpath("scheduler"))
+ if getattr(model, "tokenizer", None) is not None:
+ model.tokenizer.save_pretrained(output.joinpath("tokenizer"))
+ if getattr(model, "tokenizer_2", None) is not None:
+ model.tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))
+ if getattr(model, "feature_extractor", None) is not None:
+ model.feature_extractor.save_pretrained(output.joinpath("feature_extractor"))
+ model.save_config(output)
+
+ models_and_neuron_configs = get_stable_diffusion_models_for_export(
+ pipeline=model,
+ task=task,
+ dynamic_batch_size=dynamic_batch_size,
+ **input_shapes,
+ )
+ output_model_names = {
+ DIFFUSION_MODEL_UNET_NAME: os.path.join(DIFFUSION_MODEL_UNET_NAME, NEURON_FILE_NAME),
+ DIFFUSION_MODEL_VAE_ENCODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_ENCODER_NAME, NEURON_FILE_NAME),
+ DIFFUSION_MODEL_VAE_DECODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_DECODER_NAME, NEURON_FILE_NAME),
+ }
+ if getattr(model, "text_encoder", None) is not None:
+ output_model_names[DIFFUSION_MODEL_TEXT_ENCODER_NAME] = os.path.join(
+ DIFFUSION_MODEL_TEXT_ENCODER_NAME, NEURON_FILE_NAME
+ )
+ if getattr(model, "text_encoder_2", None) is not None:
+ output_model_names[DIFFUSION_MODEL_TEXT_ENCODER_2_NAME] = os.path.join(
+ DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, NEURON_FILE_NAME
+ )
+ del model
+
+ return models_and_neuron_configs, output_model_names
+
+
+def _get_submodels_and_neuron_configs_for_encoder_decoder(
+ model: "PreTrainedModel",
+ input_shapes: Dict[str, int],
+ task: str,
+ output: Path,
+ dynamic_batch_size: bool = False,
+ model_name_or_path: Optional[Union[str, Path]] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+):
+ if is_neuron_available():
+ raise RuntimeError(
+ "Encoder-decoder models export is not supported by neuron-cc on inf1, please use neuronx-cc on either inf2/trn1 instead."
+ )
+
+ models_and_neuron_configs = get_encoder_decoder_models_for_export(
+ model=model,
+ task=task,
+ dynamic_batch_size=dynamic_batch_size,
+ input_shapes=input_shapes,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+ output_model_names = {
+ ENCODER_NAME: os.path.join(ENCODER_NAME, NEURON_FILE_NAME),
+ DECODER_NAME: os.path.join(DECODER_NAME, NEURON_FILE_NAME),
+ }
+ maybe_save_preprocessors(model_name_or_path, output)
+
+ return models_and_neuron_configs, output_model_names
+
+
def main_export(
model_name_or_path: str,
output: Union[str, Path],
@@ -188,7 +338,9 @@ def main_export(
local_files_only: bool = False,
use_auth_token: Optional[Union[bool, str]] = None,
do_validation: bool = True,
- submodels: Dict[str, Union[Path, str]] = None,
+ submodels: Optional[Dict[str, Union[Path, str]]] = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
**input_shapes,
):
output = Path(output)
@@ -196,6 +348,7 @@ def main_export(
output.parent.mkdir(parents=True)
task = TasksManager.map_from_synonym(task)
+ is_stable_diffusion = "stable-diffusion" in task
model_kwargs = {
"task": task,
@@ -211,58 +364,17 @@ def main_export(
}
model = TasksManager.get_model_from_task(**model_kwargs)
- is_stable_diffusion = "stable-diffusion" in task
- if not is_stable_diffusion:
- neuron_config_constructor = TasksManager.get_exporter_config_constructor(
- model=model, exporter="neuron", task=task
- )
- neuron_config = neuron_config_constructor(model.config, dynamic_batch_size=dynamic_batch_size, **input_shapes)
- if atol is None:
- atol = neuron_config.ATOL_FOR_VALIDATION
- model_name = model.name_or_path.split("/")[-1]
- output_model_names = {model_name: "model.neuron"}
- models_and_neuron_configs = {model_name: (model, neuron_config)}
- maybe_save_preprocessors(model, output.parent)
-
- if is_stable_diffusion:
- model = replace_stable_diffusion_submodels(model, submodels)
- check_compiler_compatibility_for_stable_diffusion()
- if is_neuron_available():
- raise RuntimeError(
- "Stable diffusion export is not supported by neuron-cc on inf1, please use neuronx-cc on either inf2/trn1 instead."
- )
- input_shapes = infer_stable_diffusion_shapes_from_diffusers(input_shapes, model)
-
- # Saving the model config and preprocessor as this is needed sometimes.
- model.scheduler.save_pretrained(output.joinpath("scheduler"))
- if hasattr(model, "tokenizer") and model.tokenizer is not None:
- model.tokenizer.save_pretrained(output.joinpath("tokenizer"))
- if hasattr(model, "tokenizer_2") and model.tokenizer_2 is not None:
- model.tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))
- if hasattr(model, "feature_extractor") and model.feature_extractor is not None:
- model.feature_extractor.save_pretrained(output.joinpath("feature_extractor"))
- model.save_config(output)
-
- models_and_neuron_configs = get_stable_diffusion_models_for_export(
- pipeline=model,
- task=task,
- dynamic_batch_size=dynamic_batch_size,
- **input_shapes,
- )
- output_model_names = {
- DIFFUSION_MODEL_UNET_NAME: os.path.join(DIFFUSION_MODEL_UNET_NAME, NEURON_FILE_NAME),
- DIFFUSION_MODEL_VAE_ENCODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_ENCODER_NAME, NEURON_FILE_NAME),
- DIFFUSION_MODEL_VAE_DECODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_DECODER_NAME, NEURON_FILE_NAME),
- }
- if hasattr(model, "text_encoder") and model.text_encoder is not None:
- output_model_names[DIFFUSION_MODEL_TEXT_ENCODER_NAME] = os.path.join(
- DIFFUSION_MODEL_TEXT_ENCODER_NAME, NEURON_FILE_NAME
- )
- if hasattr(model, "text_encoder_2") and model.text_encoder_2 is not None:
- output_model_names[DIFFUSION_MODEL_TEXT_ENCODER_2_NAME] = os.path.join(
- DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, NEURON_FILE_NAME
- )
- del model
+ models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs(
+ model=model,
+ input_shapes=input_shapes,
+ task=task,
+ output=output,
+ dynamic_batch_size=dynamic_batch_size,
+ model_name_or_path=model_name_or_path,
+ submodels=submodels,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
_, neuron_outputs = export_models(
models_and_neuron_configs=models_and_neuron_configs,
@@ -329,6 +441,8 @@ def main():
input_shapes = normalize_input_shapes(task, args)
submodels = None
+ optional_outputs = customize_optional_outputs(args)
+
main_export(
model_name_or_path=args.model,
output=args.output,
@@ -340,6 +454,7 @@ def main():
trust_remote_code=args.trust_remote_code,
do_validation=not args.disable_validation,
submodels=submodels,
+ **optional_outputs,
**input_shapes,
)
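For reference, here is a minimal sketch of how the new options reach `main_export` when exporting a T5 model programmatically, assuming `main_export` is importable from `optimum.exporters.neuron` (the model ID, output path, and shapes are illustrative):

```python
from optimum.exporters.neuron import main_export

# The shape kwargs map to the mandatory axes of the encoder/decoder neuron configs:
# ("batch_size", "sequence_length", "num_beams").
main_export(
    model_name_or_path="t5-small",
    output="t5_small_neuron/",
    task="text2text-generation",
    batch_size=1,
    sequence_length=64,
    num_beams=4,
    output_attentions=True,     # newly added optional output
    output_hidden_states=True,  # newly added optional output
)
```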
diff --git a/optimum/exporters/neuron/base.py b/optimum/exporters/neuron/base.py
index 9e41f0c17..5f7277b53 100644
--- a/optimum/exporters/neuron/base.py
+++ b/optimum/exporters/neuron/base.py
@@ -119,6 +119,9 @@ def __init__(
audio_sequence_length: Optional[int] = None,
point_batch_size: Optional[int] = None,
nb_points_per_image: Optional[int] = None,
+ num_beams: int = 1,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
# TODO: add custom dtype after optimum 1.13 release
# int_dtype: str = "int64",
# float_dtype: str = "fp32",
@@ -147,6 +150,7 @@ def __init__(
"audio_sequence_length": audio_sequence_length,
"point_batch_size": point_batch_size,
"nb_points_per_image": nb_points_per_image,
+ "num_beams": num_beams,
}
input_shapes = {}
for name, value in axes_values.items():
@@ -154,6 +158,8 @@ def __init__(
input_shapes[name] = value
setattr(self, name, value)
setattr(self, "input_shapes", input_shapes)
+ setattr(self, "output_attentions", output_attentions)
+ setattr(self, "output_hidden_states", output_hidden_states)
setattr(self, "compiler_type", compiler_type)
setattr(self, "compiler_version", compiler_version)
@@ -290,7 +296,7 @@ def flatten_inputs(cls, inputs: Dict[str, Any]) -> Dict[str, Any]:
flatten[name] = value
return flatten
- def check_model_inputs_order(
+ def patch_model_for_export(
self,
model: "PreTrainedModel",
dummy_inputs: Optional[Dict[str, torch.Tensor]] = None,
diff --git a/optimum/exporters/neuron/config.py b/optimum/exporters/neuron/config.py
index 0e3d61bc8..01a3ae86a 100644
--- a/optimum/exporters/neuron/config.py
+++ b/optimum/exporters/neuron/config.py
@@ -16,9 +16,13 @@
Common Neuron configuration classes that handle most of the features for building model specific
configurations.
"""
+from typing import List
from ...utils import (
DummyBboxInputGenerator,
+ DummyInputGenerator,
+ DummySeq2SeqDecoderTextInputGenerator,
+ DummySeq2SeqPastKeyValuesGenerator,
DummyTextInputGenerator,
DummyVisionInputGenerator,
logging,
@@ -61,3 +65,93 @@ class TextNeuronDecoderConfig(NeuronDecoderConfig):
"""
pass
+
+
+class TextSeq2SeqNeuronConfig(NeuronConfig):
+ """
+ Handles encoder-decoder-based text architectures.
+ """
+
+ DUMMY_INPUT_GENERATOR_CLASSES = (
+ DummyTextInputGenerator,
+ DummySeq2SeqDecoderTextInputGenerator,
+ DummySeq2SeqPastKeyValuesGenerator,
+ )
+
+ @property
+ def inputs(self) -> List[str]:
+ common_inputs = []
+ # encoder + decoder without past
+ if "encoder" in self.MODEL_TYPE:
+ common_inputs = ["input_ids", "attention_mask"]
+ # decoder with past
+ if "decoder" in self.MODEL_TYPE:
+ common_inputs = [
+ "decoder_input_ids",
+ "decoder_attention_mask",
+ "encoder_hidden_states",
+ "attention_mask", # TODO: replace with `encoder_attention_mask` after optimum 1.14 release
+ ]
+
+ return common_inputs
+
+ @property
+ def outputs(self) -> List[str]:
+ common_outputs = []
+ # encoder + decoder without past
+ if "encoder" in self.MODEL_TYPE:
+ common_outputs = (
+ [f"present.{idx}.self.key" for idx in range(self._config.num_decoder_layers)]
+ + [f"present.{idx}.self.value" for idx in range(self._config.num_decoder_layers)]
+ + [f"present.{idx}.cross.key" for idx in range(self._config.num_decoder_layers)]
+ + [f"present.{idx}.cross.value" for idx in range(self._config.num_decoder_layers)]
+ )
+ # decoder with past
+ if "decoder" in self.MODEL_TYPE:
+ beam_outputs = (
+ ["next_token_scores", "next_tokens", "next_indices"] if self.num_beams > 1 else ["next_tokens"]
+ )
+ common_outputs = (
+ beam_outputs
+ + [f"past.{idx}.self.key" for idx in range(self._config.num_decoder_layers)]
+ + [f"past.{idx}.self.value" for idx in range(self._config.num_decoder_layers)]
+ + [f"past.{idx}.cross.key" for idx in range(self._config.num_decoder_layers)]
+ + [f"past.{idx}.cross.value" for idx in range(self._config.num_decoder_layers)]
+ )
+
+ if self.output_hidden_states:
+ # Flatten hidden states of all layers
+ common_outputs += [
+ f"decoder_hidden_state.{idx}" for idx in range(self._config.num_decoder_layers + 1)
+ ] # +1 for the embedding layer
+
+ if self.output_attentions:
+ # Flatten attentions tensors of all attention layers
+ common_outputs += [f"decoder_attention.{idx}" for idx in range(self._config.num_decoder_layers)]
+ if getattr(self._config, "is_encoder_decoder", False) is True:
+ common_outputs += [f"cross_attention.{idx}" for idx in range(self._config.num_decoder_layers)]
+
+ return common_outputs
+
+ def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
+ dummy_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[0](
+ self.task, self._normalized_config, **kwargs
+ )
+ dummy_decoder_text_input_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[1](
+ self.task,
+ self._normalized_config,
+ **kwargs,
+ )
+ dummy_seq2seq_past_key_values_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[2](
+ self.task,
+ self._normalized_config,
+ encoder_sequence_length=dummy_text_input_generator.sequence_length,
+ **kwargs,
+ )
+ dummy_inputs_generators = [
+ dummy_text_input_generator,
+ dummy_decoder_text_input_generator,
+ dummy_seq2seq_past_key_values_generator,
+ ]
+
+ return dummy_inputs_generators
diff --git a/optimum/exporters/neuron/convert.py b/optimum/exporters/neuron/convert.py
index 912049524..d5b826ee6 100644
--- a/optimum/exporters/neuron/convert.py
+++ b/optimum/exporters/neuron/convert.py
@@ -169,8 +169,12 @@ def validate_model_outputs(
with torch.no_grad():
reference_model.eval()
ref_inputs = config.generate_dummy_inputs(return_tuple=False, **input_shapes)
- if hasattr(config._config, "_class_name") and "AutoencoderKL" in config._config._class_name:
- # VAE components for stable diffusion
+ if getattr(reference_model.config, "is_encoder_decoder", False):
+ reference_model = config.patch_model_for_export(reference_model, device="cpu", **input_shapes)
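+ # The patched reference model returns its outputs as a flat list (like the traced model), so they are compared by position below.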
+ if "AutoencoderKL" in getattr(config._config, "_class_name", "") or getattr(
+ reference_model.config, "is_encoder_decoder", False
+ ):
+ # VAE components for stable diffusion or Encoder-Decoder models
ref_inputs = tuple(ref_inputs.values())
ref_outputs = reference_model(*ref_inputs)
neuron_inputs = ref_inputs
@@ -217,9 +221,9 @@ def validate_model_outputs(
# Check the shape and values match
shape_failures = []
value_failures = []
- for name, output in zip(neuron_output_names_list, neuron_outputs):
+ for i, (name, output) in enumerate(zip(neuron_output_names_list, neuron_outputs)):
if isinstance(output, torch.Tensor):
- ref_output = ref_outputs[name].numpy()
+ ref_output = ref_outputs[name].numpy() if isinstance(ref_outputs, dict) else ref_outputs[i].numpy()
output = output.numpy()
elif isinstance(output, tuple): # eg. `hidden_states` of `AutoencoderKL` is a tuple of tensors.
ref_output = torch.stack(ref_outputs[name]).numpy()
@@ -336,6 +340,8 @@ def export_models(
compiler_version=NEURON_COMPILER_VERSION,
model_type=getattr(sub_neuron_config, "MODEL_TYPE", None),
task=getattr(sub_neuron_config, "task", None),
+ output_attentions=getattr(sub_neuron_config, "output_attentions", False),
+ output_hidden_states=getattr(sub_neuron_config, "output_hidden_states", False),
)
if isinstance(model_config, PretrainedConfig):
model_config = DiffusersPretrainedConfig.from_dict(model_config.__dict__)
@@ -424,7 +430,14 @@ def export_neuronx(
dummy_inputs = config.generate_dummy_inputs(**input_shapes)
dummy_inputs = config.flatten_inputs(dummy_inputs)
dummy_inputs_tuple = tuple(dummy_inputs.values())
- checked_model = config.check_model_inputs_order(model, dummy_inputs)
+
+ aliases = {}
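+ # For encoder-decoder models, the decoder wrapper's KV-cache parameters are aliased to outputs so the cache can be updated in place on the device.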
+ if getattr(model.config, "is_encoder_decoder", False):
+ checked_model = config.patch_model_for_export(model, **input_shapes)
+ if getattr(config, "is_decoder", False):
+ aliases = config.generate_io_aliases(checked_model)
+ else:
+ checked_model = config.patch_model_for_export(model, dummy_inputs)
if auto_cast is not None:
logger.info(f"Using Neuron: --auto-cast {auto_cast}")
@@ -440,7 +453,12 @@ def export_neuronx(
# diffusers specific
compiler_args = add_stable_diffusion_compiler_args(config, compiler_args)
- neuron_model = neuronx.trace(checked_model, dummy_inputs_tuple, compiler_args=compiler_args)
+ neuron_model = neuronx.trace(
+ checked_model,
+ dummy_inputs_tuple,
+ compiler_args=compiler_args,
+ input_output_aliases=aliases,
+ )
if config.dynamic_batch_size is True:
neuron_model = neuronx.dynamic_batch(neuron_model)
@@ -538,7 +556,7 @@ def export_neuron(
dummy_inputs = config.generate_dummy_inputs(**input_shapes)
dummy_inputs_tuple = tuple(dummy_inputs.values())
- checked_model = config.check_model_inputs_order(model, dummy_inputs)
+ checked_model = config.patch_model_for_export(model, dummy_inputs)
compiler_args = convert_neuronx_compiler_args_to_neuron(auto_cast, auto_cast_type, disable_fast_relayout)
neuron_model = neuron.trace(
diff --git a/optimum/exporters/neuron/model_configs.py b/optimum/exporters/neuron/model_configs.py
index 8c3891d13..aa7d05fa8 100644
--- a/optimum/exporters/neuron/model_configs.py
+++ b/optimum/exporters/neuron/model_configs.py
@@ -19,22 +19,32 @@
import torch
+from ...neuron.utils import DummyBeamValuesGenerator
from ...utils import (
+ DummyInputGenerator,
DummySeq2SeqDecoderTextInputGenerator,
DummyTimestepInputGenerator,
DummyVisionInputGenerator,
NormalizedConfig,
NormalizedConfigManager,
+ NormalizedSeq2SeqConfig,
NormalizedTextAndVisionConfig,
is_diffusers_available,
)
+from ...utils.normalized_config import T5LikeNormalizedTextConfig
from ..tasks import TasksManager
from .config import (
TextAndVisionNeuronConfig,
TextEncoderNeuronConfig,
TextNeuronDecoderConfig,
+ TextSeq2SeqNeuronConfig,
VisionNeuronConfig,
)
+from .model_wrappers import (
+ T5DecoderWrapper,
+ T5EncoderWrapper,
+ UnetNeuronWrapper,
+)
if TYPE_CHECKING:
@@ -224,7 +234,7 @@ class UNetNeuronConfig(VisionNeuronConfig):
ATOL_FOR_VALIDATION = 1e-3
MANDATORY_AXES = ("batch_size", "sequence_length", "num_channels", "width", "height")
MODEL_TYPE = "unet"
-
+ CUSTOM_MODEL_WRAPPER = UnetNeuronWrapper
NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
image_size="sample_size",
num_channels="in_channels",
@@ -281,40 +291,8 @@ def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs):
else:
return dummy_inputs
- class ModelWrapper(torch.nn.Module):
- def __init__(self, model, input_names: List[str]):
- super().__init__()
- self.model = model
- self.input_names = input_names
-
- def forward(self, *inputs):
- if len(inputs) != len(self.input_names):
- raise ValueError(
- f"The model needs {len(self.input_names)} inputs: {self.input_names}."
- f" But only {len(input)} inputs are passed."
- )
-
- ordered_inputs = dict(zip(self.input_names, inputs))
-
- added_cond_kwargs = {
- "text_embeds": ordered_inputs.pop("text_embeds", None),
- "time_ids": ordered_inputs.pop("time_ids", None),
- }
- sample = ordered_inputs.pop("sample", None)
- timestep = ordered_inputs.pop("timestep").float().expand((sample.shape[0],))
-
- out_tuple = self.model(
- sample=sample,
- timestep=timestep,
- added_cond_kwargs=added_cond_kwargs,
- return_dict=False,
- **ordered_inputs,
- )
-
- return out_tuple
-
- def check_model_inputs_order(self, model, dummy_inputs):
- return self.ModelWrapper(model, list(dummy_inputs.keys()))
+ def patch_model_for_export(self, model, dummy_inputs):
+ return self.CUSTOM_MODEL_WRAPPER(model, list(dummy_inputs.keys()))
@property
def is_sdxl(self) -> bool:
@@ -379,13 +357,13 @@ def inputs(self) -> List[str]:
def outputs(self) -> List[str]:
return ["sample"]
- def check_model_inputs_order(
+ def patch_model_for_export(
self,
model: "VaeDecoder",
dummy_inputs: Dict[str, torch.Tensor],
**kwargs,
):
- return super().check_model_inputs_order(model=model, dummy_inputs=dummy_inputs, forward_with_tuple=True)
+ return super().patch_model_for_export(model=model, dummy_inputs=dummy_inputs, forward_with_tuple=True)
@register_in_tasks_manager("gpt2", "text-generation")
@@ -398,6 +376,30 @@ class LLamaNeuronConfig(TextNeuronDecoderConfig):
NEURONX_CLASS = "llama.model.LlamaForSampling"
+@register_in_tasks_manager("t5-encoder", "text2text-generation")
+class T5EncoderNeuronConfig(TextSeq2SeqNeuronConfig):
+ ATOL_FOR_VALIDATION = 1e-3
+ MANDATORY_AXES = ("batch_size", "sequence_length", "num_beams")
+ MODEL_TYPE = "t5-encoder"
+ CUSTOM_MODEL_WRAPPER = T5EncoderWrapper
+ NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
+ hidden_size="d_model",
+ num_attention_heads="num_heads",
+ encoder_num_layers="num_layers",
+ decoder_num_layers="num_decoder_layers",
+ key_value_dim="d_kv",
+ allow_new=True,
+ )
+
+ @property
+ def is_decoder(self) -> bool:
+ return False
+
+ def patch_model_for_export(self, model, device="xla", **kwargs):
+ num_beams = kwargs.pop("num_beams", 1)
+ return self.CUSTOM_MODEL_WRAPPER(model, num_beams=num_beams, device=device)
+
+
@register_in_tasks_manager("opt", "text-generation")
class OPTNeuronConfig(TextNeuronDecoderConfig):
NEURONX_CLASS = "opt.model.OPTForSampling"
@@ -406,3 +408,66 @@ class OPTNeuronConfig(TextNeuronDecoderConfig):
@register_in_tasks_manager("bloom", "text-generation")
class BloomNeuronConfig(TextNeuronDecoderConfig):
NEURONX_CLASS = "bloom.model.BloomForSampling"
+
+
+@register_in_tasks_manager("t5-decoder", "text2text-generation")
+class T5DecoderNeuronConfig(TextSeq2SeqNeuronConfig):
+ ATOL_FOR_VALIDATION = 1e-3
+ DUMMY_INPUT_GENERATOR_CLASSES = TextSeq2SeqNeuronConfig.DUMMY_INPUT_GENERATOR_CLASSES + (DummyBeamValuesGenerator,)
+ MANDATORY_AXES = ("batch_size", "sequence_length", "num_beams")
+ MODEL_TYPE = "t5-decoder"
+ CUSTOM_MODEL_WRAPPER = T5DecoderWrapper
+ NORMALIZED_CONFIG_CLASS = T5LikeNormalizedTextConfig
+
+ @property
+ def is_decoder(self) -> bool:
+ return True
+
+ @property
+ def inputs(self) -> List[str]:
+ common_inputs = super().inputs + ["beam_idx", "beam_scores"]
+ return common_inputs
+
+ def generate_dummy_inputs(self, **kwargs):
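+ # The decoder is traced on the beam-expanded batch, hence batch_size * num_beams.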
+ batch_size = kwargs.pop("batch_size") * kwargs.get("num_beams")
+ dummy_inputs = super().generate_dummy_inputs(batch_size=batch_size, **kwargs)
+ dummy_inputs["decoder_input_ids"] = dummy_inputs["decoder_input_ids"][:, :1] # sequence_length = 1
+ dummy_inputs["encoder_hidden_states"] = dummy_inputs["encoder_hidden_states"][0]
+
+ return dummy_inputs
+
+ def _create_dummy_input_generator_classes(self, **kwargs) -> List["DummyInputGenerator"]:
+ dummy_inputs_generators = super()._create_dummy_input_generator_classes(**kwargs)
+ dummy_beam_values_generator = self.DUMMY_INPUT_GENERATOR_CLASSES[-1](
+ self.task,
+ self._normalized_config,
+ num_beams=kwargs.pop("num_beams", 1),
+ **kwargs,
+ )
+ dummy_inputs_generators.append(dummy_beam_values_generator)
+ return dummy_inputs_generators
+
+ def patch_model_for_export(self, model, device="xla", **kwargs):
+ batch_size = kwargs.pop("batch_size", 1)
+ sequence_length = kwargs.pop("sequence_length", 1)
+ num_beams = kwargs.pop("num_beams", 1)
+
+ return self.CUSTOM_MODEL_WRAPPER(
+ model,
+ batch_size=batch_size,
+ sequence_length=sequence_length,
+ num_beams=num_beams,
+ output_hidden_states=self.output_hidden_states,
+ output_attentions=self.output_attentions,
+ device=device,
+ )
+
+ def generate_io_aliases(self, model):
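+ # Alias the wrapper's KV-cache parameters to the traced outputs that follow the sampled tokens
+ # (1 leading output for greedy search, 3 for beam search), so the cache is updated in place.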
+ num_outputs_from_trace = 3 if model.num_beams > 1 else 1
+ aliases = {}
+ for i in range(len(model.past_key_values_sa)):
+ aliases[model.past_key_values_sa[i]] = i + num_outputs_from_trace
+ for i in range(len(model.past_key_values_ca)):
+ aliases[model.past_key_values_ca[i]] = len(model.past_key_values_sa) + i + num_outputs_from_trace
+
+ return aliases
diff --git a/optimum/exporters/neuron/model_wrappers.py b/optimum/exporters/neuron/model_wrappers.py
new file mode 100644
index 000000000..0b1ae4504
--- /dev/null
+++ b/optimum/exporters/neuron/model_wrappers.py
@@ -0,0 +1,341 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Model wrappers for Neuron export."""
+
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
+from transformers.models.t5.modeling_t5 import T5LayerCrossAttention
+
+
+if TYPE_CHECKING:
+ from transformers.modeling_utils import PreTrainedModel
+
+
+class UnetNeuronWrapper(torch.nn.Module):
+ def __init__(self, model, input_names: List[str]):
+ super().__init__()
+ self.model = model
+ self.input_names = input_names
+
+ def forward(self, *inputs):
+ if len(inputs) != len(self.input_names):
+ raise ValueError(
+ f"The model needs {len(self.input_names)} inputs: {self.input_names}."
+ f" But only {len(input)} inputs are passed."
+ )
+
+ ordered_inputs = dict(zip(self.input_names, inputs))
+
+ added_cond_kwargs = {
+ "text_embeds": ordered_inputs.pop("text_embeds", None),
+ "time_ids": ordered_inputs.pop("time_ids", None),
+ }
+ sample = ordered_inputs.pop("sample", None)
+ timestep = ordered_inputs.pop("timestep").float().expand((sample.shape[0],))
+
+ out_tuple = self.model(
+ sample=sample,
+ timestep=timestep,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ **ordered_inputs,
+ )
+
+ return out_tuple
+
+
+# Adapted from https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/torch-neuronx/t5-inference-tutorial.html
+class T5EncoderWrapper(torch.nn.Module):
+ """Wrapper to trace the encoder and the kv cache initialization in the decoder."""
+
+ def __init__(
+ self,
+ model: "PreTrainedModel",
+ num_beams: int = 1,
+ device: str = "xla",
+ tp_degree: Optional[int] = None,
+ ):
+ super().__init__()
+ self.model = model
+ self.config = model.config
+ self.num_beams = num_beams
+ self.device = device
+ self.tp_degree = tp_degree
+
+ def forward(self, input_ids, attention_mask):
+ # Infer shapes
+ batch_size = input_ids.shape[0]
+ sequence_length = input_ids.shape[1]
+
+ encoder_output = self.model.encoder(
+ input_ids=input_ids, attention_mask=attention_mask, output_attentions=False, output_hidden_states=False
+ )
+
+ last_hidden_state = encoder_output["last_hidden_state"]
+ encoder_hidden_states = torch.concat(
+ [tensor.unsqueeze(0).repeat(self.num_beams, 1, 1) for tensor in last_hidden_state]
+ )
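+ # Repeat each encoder output num_beams times so the KV cache below matches the beam-expanded batch.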
+
+ decoder_blocks = self.model.decoder.block
+ present_key_value_states_sa = []
+ present_key_value_states_ca = []
+
+ for block in decoder_blocks:
+ # Cross attention has to be initialized with the encoder hidden state
+ cross_attention: T5LayerCrossAttention = block.layer[1]
+ attention = cross_attention.EncDecAttention
+
+ def shape(states):
+ """projection"""
+ return states.view(
+ self.num_beams * batch_size, -1, self.config.num_heads, attention.key_value_proj_dim
+ ).transpose(1, 2)
+
+ key_states = shape(attention.k(encoder_hidden_states))
+ value_states = shape(attention.v(encoder_hidden_states))
+
+ # cross_attn_kv_state
+ present_key_value_states_ca.append(key_states)
+ present_key_value_states_ca.append(value_states)
+
+ # Self-attention KV states are initialized to zeros so that the KV cache tensor keeps a constant size:
+ # the cache is padded here to a fixed shape.
+ # [key states]
+ present_key_value_states_sa.append(
+ torch.zeros(
+ (self.num_beams * batch_size, self.config.num_heads, sequence_length - 1, self.config.d_kv),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ )
+ # [value states]
+ present_key_value_states_sa.append(
+ torch.zeros(
+ (self.num_beams * batch_size, self.config.num_heads, sequence_length - 1, self.config.d_kv),
+ dtype=torch.float32,
+ device=self.device,
+ )
+ )
+
+ return present_key_value_states_sa + present_key_value_states_ca
+
+
+# Adapted from https://awsdocs-neuron.readthedocs-hosted.com/en/latest/src/examples/pytorch/torch-neuronx/t5-inference-tutorial.html
+class T5DecoderWrapper(torch.nn.Module):
+ """Wrapper to trace the decoder with past keys values with a language head."""
+
+ def __init__(
+ self,
+ model: "PreTrainedModel",
+ batch_size: int,
+ sequence_length: int,
+ num_beams: int = 1,
+ output_hidden_states: bool = False,
+ output_attentions: bool = False,
+ device: str = "xla",
+ tp_degree: Optional[int] = None,
+ ):
+ super().__init__()
+ self.model = model
+ self.config = model.config
+ self.batch_size = batch_size
+ self.sequence_length = sequence_length
+ self.num_beams = num_beams
+ self.output_hidden_states = output_hidden_states
+ self.output_attentions = output_attentions
+ self.device = device
+ self.tp_degree = tp_degree
+
+ # Initialize KV cache (num_beams, n_heads, seq_length, dim_per_head)
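+ # On "cpu" (used for output validation) plain tensors are enough; on "xla" the cache is stored as nn.Parameters so it can be aliased to the traced outputs.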
+ if device == "cpu":
+ self.past_key_values_sa = [
+ torch.ones(
+ (num_beams, self.config.num_heads, self.sequence_length - 1, self.config.d_kv), dtype=torch.float32
+ )
+ for _ in range(self.config.num_decoder_layers * 2)
+ ]
+ self.past_key_values_ca = [
+ torch.ones(
+ (num_beams, self.config.num_heads, self.sequence_length, self.config.d_kv), dtype=torch.float32
+ )
+ for _ in range(self.config.num_decoder_layers * 2)
+ ]
+ elif device == "xla":
+ self.past_key_values_sa = torch.nn.ParameterList(
+ [
+ torch.nn.Parameter(
+ torch.ones(
+ (
+ self.batch_size * self.num_beams,
+ self.config.num_heads,
+ sequence_length - 1,
+ self.config.d_kv,
+ ),
+ dtype=torch.float32,
+ ),
+ requires_grad=False,
+ )
+ for _ in range(self.config.num_decoder_layers * 2)
+ ]
+ )
+ self.past_key_values_ca = torch.nn.ParameterList(
+ [
+ torch.nn.Parameter(
+ torch.ones(
+ (
+ self.batch_size * self.num_beams,
+ self.config.num_heads,
+ sequence_length,
+ self.config.d_kv,
+ ),
+ dtype=torch.float32,
+ ),
+ requires_grad=False,
+ )
+ for _ in range(self.config.num_decoder_layers * 2)
+ ]
+ )
+
+ def update_past(self, past_key_values):
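+ # Slide the self-attention cache by dropping its first position along the sequence axis so it keeps a fixed length; the cross-attention cache is returned unchanged.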
+ new_past_sa = []
+ new_past_ca = []
+ for past_layer in past_key_values:
+ new_past_layer = list(past_layer)
+ for i in range(len(new_past_layer[:2])):
+ new_past_layer[i] = past_layer[i][:, :, 1:]
+ new_past_sa += [
+ new_past_layer[:2],
+ ]
+ new_past_ca += [
+ new_past_layer[2:],
+ ]
+ return new_past_sa, new_past_ca
+
+ def reorder_cache(self, past_key_values, beam_idx):
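+ # Gather the cache along the batch axis according to the beams selected at this step.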
+ for i in range(len(past_key_values)):
+ gather_index = beam_idx.view([beam_idx.shape[0], 1, 1, 1]).expand_as(past_key_values[i])
+ past_key_values[i] = torch.gather(past_key_values[i], dim=0, index=gather_index)
+ return past_key_values
+
+ def forward(
+ self,
+ input_ids,
+ decoder_attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ beam_idx,
+ beam_scores,
+ **kwargs,
+ ):
+ if self.num_beams > 1:
+ # We reorder the cache based on the beams selected in each iteration. Required step for beam search.
+ past_key_values_sa = self.reorder_cache(self.past_key_values_sa, beam_idx)
+ past_key_values_ca = self.reorder_cache(self.past_key_values_ca, beam_idx)
+ else:
+ # We do not need to reorder for greedy sampling
+ past_key_values_sa = self.past_key_values_sa
+ past_key_values_ca = self.past_key_values_ca
+
+ # The cache is stored in a flattened form. We order the cache per layer before passing it to the decoder.
+ # Each layer has 4 tensors, so we group them by 4.
+ past_key_values = [
+ [*past_key_values_sa[i * 2 : i * 2 + 2], *past_key_values_ca[i * 2 : i * 2 + 2]]
+ for i in range(0, int(len(past_key_values_ca) / 2))
+ ]
+
+ decoder_output = self.model.decoder(
+ input_ids=input_ids,
+ attention_mask=decoder_attention_mask,
+ past_key_values=past_key_values,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=True,
+ output_attentions=self.output_attentions,
+ output_hidden_states=self.output_hidden_states,
+ )
+
+ last_hidden_state = decoder_output["last_hidden_state"]
+ past_key_values = decoder_output["past_key_values"]
+ if self.output_hidden_states:
+ decoder_hidden_states = list(
+ decoder_output["hidden_states"]
+ ) # flatten `hidden_states` which is a tuple of tensors
+
+ if self.output_attentions:
+ decoder_attentions = list(
+ decoder_output["attentions"]
+ ) # flatten `decoder_attentions` which is a tuple of tensors
+ cross_attentions = list(
+ decoder_output["cross_attentions"]
+ ) # flatten `cross_attentions` which is a tuple of tensors
+
+ if self.config.tie_word_embeddings:
+ # Rescale output before projecting on vocab
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+ last_hidden_state = last_hidden_state * (self.model.config.d_model**-0.5)
+
+ lm_logits = self.model.lm_head(last_hidden_state)
+
+ past_key_values_sa, past_key_values_ca = self.update_past(past_key_values)
+
+ # We flatten the cache to a single array. This is required for the input/output aliasing to work
+ past_key_values_sa = [vec for kv_per_layer in past_key_values_sa for vec in kv_per_layer]
+ past_key_values_ca = [vec for kv_per_layer in past_key_values_ca for vec in kv_per_layer]
+
+ if self.device == "cpu":
+ self.past_key_values_sa = past_key_values_sa
+ self.past_key_values_ca = past_key_values_ca
+
+ # We calculate topk inside the wrapper
+ next_token_logits = lm_logits[:, -1, :]
+
+ if self.num_beams > 1:
+ # This section of beam search is run outside the decoder in the huggingface t5 implementation.
+ # To maximize the computation within the neuron device, we move this within the wrapper
+ logit_max, _ = torch.max(next_token_logits, dim=-1, keepdim=True)
+ logsumexp = torch.log(torch.exp(next_token_logits - logit_max).sum(dim=-1, keepdim=True))
+ next_token_scores = next_token_logits - logit_max - logsumexp
+ next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
+
+ # reshape for beam search
+ vocab_size = next_token_scores.shape[-1]
+ next_token_scores = next_token_scores.view(self.batch_size, self.num_beams * vocab_size)
+ next_token_scores = next_token_scores * 1
+
+ # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
+ next_token_scores, next_tokens = torch.topk(
+ next_token_scores, 2 * self.num_beams, dim=1, largest=True, sorted=True
+ )
+
+ next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
+ next_tokens = next_tokens % vocab_size
+
+ neuron_outputs = [next_token_scores, next_tokens, next_indices] + past_key_values_sa + past_key_values_ca
+
+ else:
+ # Greedy
+ next_tokens = torch.argmax(next_token_logits, dim=-1)
+
+ neuron_outputs = [next_tokens] + past_key_values_sa + past_key_values_ca
+
+ if self.output_hidden_states:
+ neuron_outputs += decoder_hidden_states
+
+ if self.output_attentions:
+ neuron_outputs += decoder_attentions
+ neuron_outputs += cross_attentions
+
+ return neuron_outputs
diff --git a/optimum/exporters/neuron/utils.py b/optimum/exporters/neuron/utils.py
index 46b66920c..b49817f40 100644
--- a/optimum/exporters/neuron/utils.py
+++ b/optimum/exporters/neuron/utils.py
@@ -23,11 +23,13 @@
from transformers import PretrainedConfig
from ...neuron.utils import (
+ DECODER_NAME,
DIFFUSION_MODEL_TEXT_ENCODER_2_NAME,
DIFFUSION_MODEL_TEXT_ENCODER_NAME,
DIFFUSION_MODEL_UNET_NAME,
DIFFUSION_MODEL_VAE_DECODER_NAME,
DIFFUSION_MODEL_VAE_ENCODER_NAME,
+ ENCODER_NAME,
get_attention_scores_sd,
get_attention_scores_sdxl,
)
@@ -157,7 +159,7 @@ def get_stable_diffusion_models_for_export(
Whether the Neuron compiled model supports dynamic batch size.
Returns:
- `Dict[str, Tuple[Union[`PreTrainedModel`, `ModelMixin`], `NeuronConfig`]: A Dict containing the model and
+ `Dict[str, Tuple[Union[`PreTrainedModel`, `ModelMixin`], `NeuronConfig`]`: A Dict containing the model and
Neuron configs for the different components of the model.
"""
models_for_export = _get_submodels_for_export_stable_diffusion(pipeline=pipeline, task=task)
@@ -326,6 +328,15 @@ def override_diffusers_2_0_attn_processors(model):
return model
+def check_mandatory_input_shapes(neuron_config_constructor, task, input_shapes):
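+ """Ensure that a value was provided for every axis that the task's neuron config marks as mandatory."""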
+ mandatory_shapes = neuron_config_constructor.func.get_mandatory_axes_for_task(task)
+ for name in mandatory_shapes:
+ if input_shapes.get(name, None) is None:
+ raise AttributeError(
+ f"Cannot find the value of `{name}` which is mandatory for exporting the model to the neuron format, please set the value explicitly."
+ )
+
+
def replace_stable_diffusion_submodels(pipeline, submodels):
if submodels is not None:
unet_id = submodels.pop("unet", None)
@@ -334,3 +345,68 @@ def replace_stable_diffusion_submodels(pipeline, submodels):
pipeline.unet = unet
return pipeline
+
+
+def get_encoder_decoder_models_for_export(
+ model: "PreTrainedModel",
+ task: str,
+ input_shapes: Dict[str, int],
+ dynamic_batch_size: Optional[bool] = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+) -> Dict[str, Tuple["PreTrainedModel", "NeuronConfig"]]:
+ """
+ Returns the components of an encoder-decoder model and their corresponding neuron configs.
+ The encoder covers the computation of the encoder hidden states and the initialization of the
+ KV cache. The decoder covers the autoregressive generation of tokens, taking past key values
+ as inputs to save compute.
+
+ Args:
+ model ("PreTrainedModel"):
+ The model to export.
+ input_shapes (`Dict[str, int]`):
+ Static shapes used for compiling the encoder and the decoder.
+ dynamic_batch_size (`bool`, defaults to `False`):
+ Whether the Neuron compiled model supports dynamic batch size.
+ output_attentions (`bool`, defaults to `False`):
+ Whether or not the traced model should return the attention tensors of all attention layers.
+ output_hidden_states (`bool`, defaults to `False`):
+ Whether or not the traced model should return the hidden states of all layers.
+
+ Returns:
+ `Dict[str, Tuple["PreTrainedModel", "NeuronConfig"]]`: A Dict containing the model and
+ Neuron configs for the different components of the model.
+ """
+ models_for_export = {}
+
+ # Encoder
+ model_type = getattr(model.config, "model_type") + "-encoder"
+ encoder_config_constructor = TasksManager.get_exporter_config_constructor(
+ exporter="neuron", model_type=model_type, task=task
+ )
+ check_mandatory_input_shapes(encoder_config_constructor, task, input_shapes)
+ encoder_neuron_config = encoder_config_constructor(
+ config=model.config,
+ task=task,
+ dynamic_batch_size=dynamic_batch_size,
+ **input_shapes,
+ )
+ models_for_export[ENCODER_NAME] = (model, encoder_neuron_config)
+
+ # Decoder
+ model_type = getattr(model.config, "model_type") + "-decoder"
+ decoder_config_constructor = TasksManager.get_exporter_config_constructor(
+ exporter="neuron", model_type=model_type, task=task
+ )
+ check_mandatory_input_shapes(decoder_config_constructor, task, input_shapes)
+ decoder_neuron_config = decoder_config_constructor(
+ config=model.config,
+ task=task,
+ dynamic_batch_size=dynamic_batch_size,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ **input_shapes,
+ )
+ models_for_export[DECODER_NAME] = (model, decoder_neuron_config)
+
+ return models_for_export
diff --git a/optimum/neuron/__init__.py b/optimum/neuron/__init__.py
index 276365daa..f9ceb961d 100644
--- a/optimum/neuron/__init__.py
+++ b/optimum/neuron/__init__.py
@@ -42,6 +42,7 @@
"NeuronStableDiffusionXLInpaintPipeline",
],
"modeling_decoder": ["NeuronDecoderModel"],
+ "modeling_seq2seq": ["NeuronModelForSeq2SeqLM"],
"accelerate": [
"NeuronAccelerator",
"NeuronAcceleratorState",
@@ -73,6 +74,7 @@
NeuronStableDiffusionXLInpaintPipeline,
NeuronStableDiffusionXLPipeline,
)
+ from .modeling_seq2seq import NeuronModelForSeq2SeqLM
from .pipelines import pipeline
from .trainers import NeuronTrainer, Seq2SeqNeuronTrainer
from .training_args import NeuronTrainingArguments, Seq2SeqNeuronTrainingArguments
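A minimal sketch of the user-facing flow this enables, assuming the usual `from_pretrained(..., export=True)` API (the model ID, compilation shapes, and generation arguments are illustrative):

```python
from transformers import AutoTokenizer
from optimum.neuron import NeuronModelForSeq2SeqLM

# Compile the encoder and decoder with static shapes; num_beams is fixed at compilation time.
model = NeuronModelForSeq2SeqLM.from_pretrained(
    "t5-small", export=True, batch_size=1, sequence_length=64, num_beams=4
)
tokenizer = AutoTokenizer.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: How old are you?", return_tensors="pt")
output_ids = model.generate(**inputs, num_beams=4, max_length=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```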
diff --git a/optimum/neuron/generation/utils.py b/optimum/neuron/generation/utils.py
index ce6f93e8b..51027af4d 100644
--- a/optimum/neuron/generation/utils.py
+++ b/optimum/neuron/generation/utils.py
@@ -17,7 +17,7 @@
import copy
import inspect
import warnings
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
@@ -51,10 +51,6 @@
from transformers.utils import ModelOutput, logging
-if TYPE_CHECKING:
- from transformers.generation.streamers import BaseStreamer
- from transformers.modeling_utils import PreTrainedModel
-
logger = logging.get_logger(__name__)
@@ -82,6 +78,91 @@ class NeuronGenerationMixin(GenerationMixin):
learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
"""
+ @staticmethod
+ def _initialize_attention(
+ model_kwargs,
+ num_padding_values,
+ batch_size,
+ device,
+ is_encoder_decoder,
+ ):
+ """Initializes the appropriate attention mask -- encoder-decoder models use `decoder_attention_mask`"""
+ if is_encoder_decoder:
+ # One 1 for decoder_start_token_id, 0s for the currently-unfilled locations in the past_key_values tensor,
+ # 1s for the actual input_ids
+ decoder_attention_mask = torch.cat(
+ [
+ torch.zeros((batch_size, num_padding_values), dtype=torch.int32),
+ torch.ones((batch_size, 2), dtype=torch.int32),
+ ],
+ axis=1,
+ ).to(device)
+ mask = {"decoder_attention_mask": decoder_attention_mask}
+ else:
+ attention_mask = model_kwargs.pop("attention_mask")
+ # 0s for the currently-unfilled locations in the past_key_values tensor, 1s for the actual input_ids
+ attention_mask = torch.cat(
+ [
+ torch.zeros(
+ (batch_size, num_padding_values), dtype=attention_mask.dtype, device=attention_mask.device
+ ),
+ attention_mask,
+ torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device),
+ ],
+ axis=1,
+ )
+ mask = {"attention_mask": attention_mask}
+
+ return mask
+
+ @staticmethod
+ def _update_attention(model_kwargs, batch_size, is_encoder_decoder):
+ """Updates the appropriate attention mask -- encoder-decoder models use `decoder_attention_mask`"""
+
+ attention_mask_name = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
+ attention_mask = model_kwargs.pop(attention_mask_name)
+ attention_mask_update_slice = torch.ones(
+ (batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device
+ )
+ attention_mask = torch.cat([attention_mask[:, 1:], attention_mask_update_slice], dim=-1)
+ mask = {attention_mask_name: attention_mask}
+ return mask
+
+ @staticmethod
+ def _initialize_past(past_key_values, num_padding_values):
+ """Initialize past_key_values with zeros -- the structure depends on `batch_axis`"""
+
+ new_past = ()
+ for past_layer in past_key_values:
+ new_past_layer = list(past_layer)
+ for i in range(len(new_past_layer[:2])):
+ b, n_heads, _, head_dim = past_layer[i].shape
+ new_past_layer[i] = torch.cat(
+ [
+ torch.zeros(
+ (b, n_heads, num_padding_values, head_dim),
+ dtype=past_layer[i].dtype,
+ device=past_layer[i].device,
+ ),
+ past_layer[i],
+ ],
+ dim=2,
+ )
+ new_past += (tuple(new_past_layer),)
+
+ return new_past
+
+ @staticmethod
+ def _update_past(past_key_values):
+ new_past = ()
+ for past_layer in past_key_values:
+ new_past_layer = list(past_layer)
+ for i, _ in enumerate(new_past_layer[:2]):
+ new_past_layer[i] = past_layer[i][:, :, 1:]
+ new_past += (tuple(new_past_layer),)
+
+ return new_past
+
def _update_model_kwargs_for_xla_generation(
self,
outputs: ModelOutput,
@@ -93,81 +174,6 @@ def _update_model_kwargs_for_xla_generation(
seq_length: Optional[int] = None,
use_cache: bool = True,
) -> Dict[str, Any]:
- def _initialize_attention(model_kwargs, num_padding_values, is_encoder_decoder):
- """Initializes the appropriate attention mask -- encoder-decoder models use `decoder_attention_mask`"""
- if is_encoder_decoder:
- # One 1 for decoder_start_token_id, 0s for the currently-unfilled locations in the past_key_values tensor,
- # 1s for the actual input_ids
- decoder_attention_mask = torch.cat(
- [
- torch.zeros((batch_size, num_padding_values), dtype=torch.int32),
- torch.ones((batch_size, 2), dtype=torch.int32),
- ],
- axis=1,
- ).to(outputs.logits.device)
- mask = {"decoder_attention_mask": decoder_attention_mask}
- else:
- attention_mask = model_kwargs.pop("attention_mask")
- # 0s for the currently-unfilled locations in the past_key_values tensor, 1s for the actual input_ids
- attention_mask = torch.cat(
- [
- torch.zeros(
- (batch_size, num_padding_values), dtype=attention_mask.dtype, device=attention_mask.device
- ),
- attention_mask,
- torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device),
- ],
- axis=1,
- )
- mask = {"attention_mask": attention_mask}
-
- return mask
-
- def _update_attention(model_kwargs, is_encoder_decoder):
- """Updates the appropriate attention mask -- encoder-decoder models use `decoder_attention_mask`"""
-
- attention_mask_name = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
- attention_mask = model_kwargs.pop(attention_mask_name)
- attention_mask_update_slice = torch.ones(
- (batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device
- )
- attention_mask = torch.cat([attention_mask[:, 1:], attention_mask_update_slice], dim=-1)
- mask = {attention_mask_name: attention_mask}
- return mask
-
- def _initialize_past(past_key_values, num_padding_values):
- """Initialize past_key_values with zeros -- the structure depends on `batch_axis`"""
-
- new_past = ()
- for past_layer in past_key_values:
- new_past_layer = list(past_layer)
- for i in range(len(new_past_layer[:2])):
- b, n_heads, _, head_dim = past_layer[i].shape
- new_past_layer[i] = torch.cat(
- [
- torch.zeros(
- (b, n_heads, num_padding_values, head_dim),
- dtype=past_layer[i].dtype,
- device=past_layer[i].device,
- ),
- past_layer[i],
- ],
- dim=2,
- )
- new_past += (tuple(new_past_layer),)
-
- return new_past
-
- def _update_past(past_key_values):
- new_past = ()
- for past_layer in past_key_values:
- new_past_layer = list(past_layer)
- for i, _ in enumerate(new_past_layer[:2]):
- new_past_layer[i] = past_layer[i][:, :, 1:]
- new_past += (tuple(new_past_layer),)
-
- return new_past
-
if use_cache:
past_key_values = self._extract_past_from_model_output(outputs)
if past_key_values is None:
@@ -182,11 +188,13 @@ def _update_past(past_key_values):
# previous autoregressive generation steps (step 0 has no past_key_values, step 1 has 1 past_key_values value, ..., the last step
# has `max_length - 1` past_key_values values).
num_padding_values = max_length - seq_length
- mask = _initialize_attention(model_kwargs, num_padding_values, is_encoder_decoder)
- new_past = _initialize_past(past_key_values, num_padding_values)
+ mask = self._initialize_attention(
+ model_kwargs, num_padding_values, batch_size, outputs.logits.device, is_encoder_decoder
+ )
+ new_past = self._initialize_past(past_key_values, num_padding_values)
else:
- mask = _update_attention(model_kwargs, is_encoder_decoder)
- new_past = _update_past(past_key_values)
+ mask = self._update_attention(model_kwargs, batch_size, is_encoder_decoder)
+ new_past = self._update_past(past_key_values)
# sets the updated variables (mask and past_key_values)
model_kwargs.update(mask)
@@ -253,425 +261,12 @@ def _expand_dict_for_generation(dict_to_expand):
model_kwargs = _expand_dict_for_generation(model_kwargs)
- if is_encoder_decoder:
- if model_kwargs.get("encoder_outputs") is None:
- raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
- model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
-
- return input_ids, model_kwargs
-
- def beam_search(
- self,
- input_ids: torch.LongTensor,
- beam_scorer: BeamScorer,
- logits_processor: Optional[LogitsProcessorList] = None,
- stopping_criteria: Optional[StoppingCriteriaList] = None,
- max_length: Optional[int] = None,
- pad_token_id: Optional[int] = None,
- eos_token_id: Optional[Union[int, List[int]]] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- output_scores: Optional[bool] = None,
- return_dict_in_generate: Optional[bool] = None,
- synced_gpus: Optional[bool] = False,
- seq_length: Optional[int] = None,
- **model_kwargs,
- ) -> Union[BeamSearchOutput, torch.LongTensor]:
- r"""
- Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
- can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
-
-
-
- In most cases, you do not need to call [`~generation.GenerationMixin.beam_search`] directly. Use generate()
- instead. For an overview of generation strategies and code examples, check the [following
- guide](../generation_strategies).
-
-
-
- Parameters:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- The sequence used as a prompt for the generation.
- beam_scorer (`BeamScorer`):
- An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
- sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
- logits_processor (`LogitsProcessorList`, *optional*):
- An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
- used to modify the prediction scores of the language modeling head applied at each generation step.
- stopping_criteria (`StoppingCriteriaList`, *optional*):
- An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
- used to tell if the generation loop should stop.
- max_length (`int`, *optional*, defaults to 20):
- **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
- tokens. The maximum length of the sequence to be generated.
- pad_token_id (`int`, *optional*):
- The id of the *padding* token.
- eos_token_id (`Union[int, List[int]]`, *optional*):
- The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
- output_attentions (`bool`, *optional*, defaults to `False`):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more details.
- output_hidden_states (`bool`, *optional*, defaults to `False`):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
- for more details.
- output_scores (`bool`, *optional*, defaults to `False`):
- Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
- return_dict_in_generate (`bool`, *optional*, defaults to `False`):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- synced_gpus (`bool`, *optional*, defaults to `False`):
- Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
- seq_length:
- Length of current input_ids sequence
- model_kwargs:
- Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
- an encoder-decoder model the kwargs should include `encoder_outputs`.
-
- Return:
- [`generation.BeamSearchDecoderOnlyOutput`], [`~generation.BeamSearchEncoderDecoderOutput`] or
- `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
- [`~generation.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
- `return_dict_in_generate=True` or a [`~generation.BeamSearchEncoderDecoderOutput`] if
- `model.config.is_encoder_decoder=True`.
-
-
- Examples:
-
- ```python
- >>> from transformers import (
- ... AutoTokenizer,
- ... AutoModelForSeq2SeqLM,
- ... LogitsProcessorList,
- ... MinLengthLogitsProcessor,
- ... BeamSearchScorer,
- ... )
- >>> import torch
-
- >>> tokenizer = AutoTokenizer.from_pretrained("t5-base")
- >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
-
- >>> encoder_input_str = "translate English to German: How old are you?"
- >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
-
-
- >>> # lets run beam search using 3 beams
- >>> num_beams = 3
- >>> # define decoder start token ids
- >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
- >>> input_ids = input_ids * model.config.decoder_start_token_id
-
- >>> # add encoder_outputs to model keyword arguments
- >>> model_kwargs = {
- ... "encoder_outputs": model.get_encoder()(
- ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
- ... )
- ... }
-
- >>> # instantiate beam scorer
- >>> beam_scorer = BeamSearchScorer(
- ... batch_size=1,
- ... num_beams=num_beams,
- ... device=model.device,
- ... )
-
- >>> # instantiate logits processors
- >>> logits_processor = LogitsProcessorList(
- ... [
- ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
- ... ]
- ... )
-
- >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
-
- >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
- ['Wie alt bist du?']
- ```"""
- # init values
- logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
- stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
- if max_length is not None:
- warnings.warn(
- "`max_length` is deprecated in this function, use"
- " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
- UserWarning,
- )
- stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
- if len(stopping_criteria) == 0:
- warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
- pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
- eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
- if isinstance(eos_token_id, int):
- eos_token_id = [eos_token_id]
- output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
- output_attentions = (
- output_attentions if output_attentions is not None else self.generation_config.output_attentions
- )
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
- )
- return_dict_in_generate = (
- return_dict_in_generate
- if return_dict_in_generate is not None
- else self.generation_config.return_dict_in_generate
- )
-
- batch_size = len(beam_scorer._beam_hyps)
- num_beams = beam_scorer.num_beams
-
- batch_beam_size, cur_len = input_ids.shape
-
- # Overwrite cur_len
- cur_len = seq_length
-
- if num_beams * batch_size != batch_beam_size:
- raise ValueError(
- f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
- )
-
- # init attention / hidden states / scores tuples
- scores = () if (return_dict_in_generate and output_scores) else None
- beam_indices = (
- tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
- )
- decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
- cross_attentions = () if (return_dict_in_generate and output_attentions) else None
- decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
-
- # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
- if return_dict_in_generate and self.config.is_encoder_decoder:
- encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
- encoder_hidden_states = (
- model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
- )
-
- # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
- # of the first beam are considered to avoid sampling the exact same tokens across all beams.
- # beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
- beam_scores_device = "cpu"
- beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=beam_scores_device)
- beam_scores[:, 1:] = -1e9
- beam_scores = beam_scores.view((batch_size * num_beams,))
-
- this_peer_finished = False # used by synced_gpus only
- while True:
- if synced_gpus:
- # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
- # The following logic allows an early break if all peers finished generating their sequence
- this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
- # send 0.0 if we finished, 1.0 otherwise
- dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
- # did all peers finish? the reduced sum will be 0.0 then
- if this_peer_finished_flag.item() == 0.0:
- break
-
- # prepare model inputs
- if model_kwargs["use_cache"]:
- # From max_length-sized input_ids, select first
- # cur_len - 1 values.
- update_indices = torch.stack(
- [torch.arange(input_ids.size(0)), torch.tensor(cur_len - 1).repeat(input_ids.size(0))], dim=-1
- )
- input_ids_ = input_ids[update_indices[:, 0], update_indices[:, 1], None]
- model_inputs = self.prepare_inputs_for_generation(input_ids_, **model_kwargs)
- else:
- model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
- outputs = self(
- **model_inputs,
- return_dict=True,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- )
-
- if synced_gpus and this_peer_finished:
- cur_len = cur_len + 1
- continue # don't waste resources running the code we don't need
-
- if not model_kwargs["use_cache"]:
- one_hot = (
- torch.cat(
- [
- torch.tensor([0]).repeat(1, cur_len - 1),
- torch.tensor([1]).repeat(1, 1),
- torch.tensor([0]).repeat(1, input_ids.size(1) - cur_len),
- ],
- dim=1,
- )
- .to(device=outputs.logits.device)
- .float()
- )
- next_token_logits = torch.matmul(one_hot, outputs.logits)
- next_token_logits = next_token_logits.squeeze(1)
- else:
- next_token_logits = outputs.logits[:, -1, :]
-
- # Manually compute log softmax
- # log_softmax(vi) = vi - max(vi) - log(sum(exp(vi - max(vi))))
- logit_max, _ = torch.max(next_token_logits, dim=-1, keepdim=True)
- logsumexp = torch.log(torch.exp(next_token_logits - logit_max).sum(dim=-1, keepdim=True))
- next_token_scores = next_token_logits - logit_max - logsumexp
- # (batch_size * num_beams, vocab_size)
-
- xm.mark_step()
-
- # We don't want to change every single logit processor, so
- # we peform this processing on CPU.
- input_ids_ = input_ids.to("cpu")[:, :cur_len]
- next_token_scores_ = next_token_scores.to("cpu")
- next_token_scores_processed = logits_processor(input_ids_, next_token_scores_)
-
- next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
-
- # Store scores, attentions and hidden_states when required
- if return_dict_in_generate:
- if output_scores:
- scores += (next_token_scores_processed,)
- if output_attentions:
- decoder_attentions += (
- (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
- )
- if self.config.is_encoder_decoder:
- cross_attentions += (outputs.cross_attentions,)
-
- if output_hidden_states:
- decoder_hidden_states += (
- (outputs.decoder_hidden_states,)
- if self.config.is_encoder_decoder
- else (outputs.hidden_states,)
- )
-
- # reshape for beam search
- vocab_size = next_token_scores.shape[-1]
- next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
- next_token_scores = next_token_scores * 1
-
- # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
- next_token_scores, next_tokens = torch.topk(
- next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
- )
-
- next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
- next_tokens = next_tokens % vocab_size
-
- # stateless
- beam_outputs = beam_scorer.process(
- input_ids.to("cpu")[:, :cur_len],
- next_token_scores.to("cpu"),
- next_tokens.to("cpu"),
- next_indices.to("cpu"),
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- beam_indices=beam_indices,
- )
-
- beam_scores = beam_outputs["next_beam_scores"]
- beam_next_tokens = beam_outputs["next_beam_tokens"]
- beam_idx = beam_outputs["next_beam_indices"]
-
- update_indices = torch.stack(
- [torch.arange(batch_beam_size), torch.tensor(cur_len - 1).repeat(batch_beam_size)], dim=-1
- )
- update_indices_2 = torch.stack(
- [torch.arange(batch_beam_size), torch.tensor(cur_len).repeat(batch_beam_size)], dim=-1
- )
- # First select beam_indices
- device = input_ids.device
- beam_idx_device = beam_idx.to(device=input_ids.device)
- input_ids[:, :] = input_ids[beam_idx_device.long(), :]
-
- # Then append new tokens
- input_ids[update_indices_2[:, 0], update_indices_2[:, 1], None] = beam_next_tokens.unsqueeze(-1).to(device)
- input_ids = input_ids * 1 # Hack to materialize tensor
-
- # update generated ids, model inputs, and length for next step
- model_kwargs = self._update_model_kwargs_for_xla_generation(
- outputs,
- model_kwargs,
- batch_size=batch_beam_size,
- is_encoder_decoder=self.config.is_encoder_decoder,
- max_length=stopping_criteria.max_length,
- seq_length=cur_len,
- use_cache=model_kwargs["use_cache"],
- )
- if model_kwargs["past_key_values"] is not None:
- model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
-
- if return_dict_in_generate and output_scores:
- beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
-
- # increase cur_len
- cur_len = cur_len + 1
-
- # stop when each sentence is finished, or if we exceed the maximum length
- stop_criterion_1 = beam_scorer.is_done
- if isinstance(stopping_criteria, list):
- if len(stopping_criteria) == 1:
- stopping_criteria = stopping_criteria[0]
-
- # Cases that can be handled in XLA without requiring
- # non-padded input_ids
- if isinstance(stopping_criteria, MaxLengthCriteria):
- stop_criterion_2 = cur_len >= stopping_criteria.max_length
- elif isinstance(stopping_criteria, MaxTimeCriteria):
- stop_criterion_2 = stopping_criteria(input_ids, scores)
- else:
- # Other cases will be handled on CPU
- batch_size, _ = input_ids.shape
- input_ids_cpu = input_ids.to("cpu")
- mask = torch.cat(
- [torch.ones(batch_size, cur_len), torch.zeros(batch_size, input_ids.shape[1] - cur_len)], dim=1
- ).bool()
- input_ids_cpu = torch.masked_select(input_ids_cpu, mask).reshape((batch_size, cur_len))
- scores_cpu = scores.to("cpu") if torch.is_tensor(scores) else scores
- stop_criterion_2 = stopping_criteria(input_ids_cpu, scores_cpu)
-
- if stop_criterion_1 or stop_criterion_2:
- if not synced_gpus:
- break
- else:
- this_peer_finished = True
-
- sequence_outputs = beam_scorer.finalize(
- input_ids.to("cpu"),
- beam_scores.to("cpu"),
- next_tokens.to("cpu"),
- next_indices.to("cpu"),
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- max_length=stopping_criteria.max_length,
- beam_indices=beam_indices,
- )
-
- for k, v in sequence_outputs.items():
- if type(v) == torch.Tensor:
- sequence_outputs[k] = sequence_outputs[k].to(input_ids.device)
-
- if return_dict_in_generate:
- if not output_scores:
- sequence_outputs["sequence_scores"] = None
-
- if self.config.is_encoder_decoder:
- return BeamSearchEncoderDecoderOutput(
- sequences=sequence_outputs["sequences"],
- sequences_scores=sequence_outputs["sequence_scores"],
- scores=scores,
- beam_indices=sequence_outputs["beam_indices"],
- encoder_attentions=encoder_attentions,
- encoder_hidden_states=encoder_hidden_states,
- decoder_attentions=decoder_attentions,
- cross_attentions=cross_attentions,
- decoder_hidden_states=decoder_hidden_states,
- )
- else:
- return BeamSearchDecoderOnlyOutput(
- sequences=sequence_outputs["sequences"],
- sequences_scores=sequence_outputs["sequence_scores"],
- scores=scores,
- beam_indices=sequence_outputs["beam_indices"],
- attentions=decoder_attentions,
- hidden_states=decoder_hidden_states,
- )
- else:
- return sequence_outputs["sequences"]
+ if is_encoder_decoder:
+ if model_kwargs.get("encoder_outputs") is None:
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+ return input_ids, model_kwargs
@torch.no_grad()
def generate(
@@ -682,8 +277,7 @@ def generate(
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
- assistant_model: Optional["PreTrainedModel"] = None,
- streamer: Optional["BaseStreamer"] = None,
+ is_traced_inference: bool = False,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
r"""
@@ -702,23 +296,23 @@ def generate(
Parameters:
- inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
+ inputs (`Optional[torch.Tensor]`, defaults to `None`):
The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
`input_ids`, `input_values`, `input_features`, or `pixel_values`.
- generation_config (`~generation.GenerationConfig`, *optional*):
+ generation_config (`Optional[GenerationConfig]`, defaults to `None`):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which had the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
- logits_processor (`LogitsProcessorList`, *optional*):
+ logits_processor (`Optional[LogitsProcessorList]`, defaults to `None`):
Custom logits processors that complement the default logits processors built from arguments and
generation config. If a logit processor is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
- stopping_criteria (`StoppingCriteriaList`, *optional*):
+ stopping_criteria (`Optional[StoppingCriteriaList]`, defaults to `None`):
Custom stopping criteria that complement the default stopping criteria built from arguments and a
generation config. If a stopping criteria is passed that is already created with the arguments or a
generation config an error is thrown. This feature is intended for advanced users.
@@ -729,18 +323,13 @@ def generate(
on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
Retrieval](https://arxiv.org/abs/2010.00904).
- synced_gpus (`bool`, *optional*):
+ synced_gpus (`Optional[bool]`, defaults to `None`):
Whether to continue running the while loop until max_length. Unless overridden this flag will be set to
`True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished
generating before other GPUs. Otherwise it'll be set to `False`.
- assistant_model (`PreTrainedModel`, *optional*):
- An assistant model that can be used to accelerate generation. The assistant model must have the exact
- same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model
- is much faster than running generation with the model you're calling generate from. As such, the
- assistant model should be much smaller.
- streamer (`BaseStreamer`, *optional*):
- Streamer object that will be used to stream the generated sequences. Generated tokens are passed
- through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+ is_traced_inference (`bool`, defaults to `False`):
+ Whether the decoder is traced or run with XLA lazy tensors. If the decoder is traced, the next tokens and
+ beam scores are computed inside the decoder.
kwargs:
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@@ -820,12 +409,14 @@ def generate(
# 4. Define other model kwargs
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
- if generation_config.use_cache:
+ if generation_config.use_cache and not is_traced_inference:
warnings.warn("use_cache is not supported for generation on Neuron devices, switching to use_cache=False.")
- model_kwargs["use_cache"] = False
+ model_kwargs["use_cache"] = False
+ else:
+ model_kwargs["use_cache"] = generation_config.use_cache
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
- requires_attention_mask = "encoder_outputs" not in model_kwargs
+ requires_attention_mask = "encoder_outputs" not in model_kwargs and not is_traced_inference
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
@@ -843,7 +434,7 @@ def generate(
"generation results, please set `padding_side='left'` when initializing the tokenizer."
)
- if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
+ if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs and not is_traced_inference:
# if model is encoder decoder encoder_outputs are created
# and added to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
@@ -863,9 +454,6 @@ def generate(
else:
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
- if streamer is not None:
- streamer.put(input_ids.cpu())
-
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
@@ -962,12 +550,7 @@ def generate(
if generation_config.num_beam_groups > generation_config.num_beams:
raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`")
- if streamer is not None and (generation_config.num_beams > 1):
- raise ValueError(
- "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
- )
-
- if self.device.type != input_ids.device.type:
+ if hasattr(self, "device") and self.device.type != input_ids.device.type:
warnings.warn(
"You are calling .generate() with the `input_ids` being on a device type different"
f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
@@ -1010,7 +593,7 @@ def generate(
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
seq_length=input_ids_seq_length,
- streamer=streamer,
+ is_traced_inference=is_traced_inference,
**model_kwargs,
)
elif is_beam_gen_mode:
@@ -1049,15 +632,340 @@ def generate(
return_dict_in_generate=generation_config.return_dict_in_generate,
synced_gpus=synced_gpus,
seq_length=input_ids_seq_length,
+ is_traced_inference=is_traced_inference,
**model_kwargs,
)
else:
- raise ValueError("Only greedy search and beam search are supported on Neuron.")
+ raise ValueError("Only greedy search and beam search are supported on Neuron.")
+
+ def greedy_search(
+ self,
+ input_ids: torch.LongTensor,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ max_length: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ eos_token_id: Optional[Union[int, List[int]]] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_scores: Optional[bool] = None,
+ return_dict_in_generate: Optional[bool] = None,
+ synced_gpus: bool = False,
+ seq_length: Optional[int] = None,
+ is_traced_inference: bool = False,
+ **model_kwargs,
+ ) -> Union[GreedySearchOutput, torch.LongTensor]:
+ r"""
+ Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
+ used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
+
+
+
+ In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate()
+ instead. For an overview of generation strategies and code examples, check the [following
+ guide](../generation_strategies).
+
+
+
+
+ Parameters:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ The sequence used as a prompt for the generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+ used to modify the prediction scores of the language modeling head applied at each generation step.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
+ used to tell if the generation loop should stop.
+
+ max_length (`int`, *optional*, defaults to 20):
+ **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
+ tokens. The maximum length of the sequence to be generated.
+ pad_token_id (`int`, *optional*):
+ The id of the *padding* token.
+ eos_token_id (`Union[int, List[int]]`, *optional*):
+ The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
+ output_attentions (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more details.
+ output_hidden_states (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more details.
+ output_scores (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
+ return_dict_in_generate (`bool`, *optional*, defaults to `False`):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ synced_gpus (`bool`, *optional*, defaults to `False`):
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+ seq_length (`Optional[int]`, defaults to `None`):
+ Length of current input_ids sequence
+ is_traced_inference (`bool`, defaults to `False`):
+ Whether the decoder is traced or run with XLA lazy tensors. If the decoder is traced, the next tokens and
+ beam scores are computed inside the decoder.
+ model_kwargs:
+ Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
+ If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
+
+ Return:
+ [`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or
+ `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
+ [`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
+ `return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if
+ `model.config.is_encoder_decoder=True`.
+
+ Examples:
+
+ ```python
+ >>> from transformers import AutoTokenizer
+ >>> from optimum.neuron import NeuronModelForSeq2SeqLM
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
+ >>> input_shapes = {"batch_size": 1, "sequence_length": 128, "num_beams": 1}
+ >>> model = NeuronModelForSeq2SeqLM.from_pretrained("t5-small", export=True, dynamic_batch_size=False, **input_shapes)
+
+ >>> input_prompt = "translate English to German: Let's eat good food."
+ >>> inputs = tokenizer(input_prompt, return_tensors="pt")
+
+ >>> outputs = model.greedy_search(inputs.input_ids)
+
+ >>> results = [tokenizer.decode(t, skip_special_tokens=True) for t in outputs]
+ ```
+ """
+ # init values
+ if logits_processor is not None and is_traced_inference:
+ logger.warning(
+ "`logits_processor` will not be neglected because in `optimum-neuron`, `next_tokens` is computed inside the compiled decoder. If you want us to support custom logits_processor during the compilation, please file an issue to https://github.com/huggingface/optimum-neuron."
+ )
+ elif logits_processor is None:
+ logits_processor = LogitsProcessorList()
+ use_cache = model_kwargs.pop("use_cache", False)
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+ if max_length is not None:
+ warnings.warn(
+ "`max_length` is deprecated in this function, use"
+ " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
+ UserWarning,
+ )
+ stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
+ pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+ eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+ eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
+ output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+ output_attentions = (
+ output_attentions if output_attentions is not None else self.generation_config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
+ )
+ return_dict_in_generate = (
+ return_dict_in_generate
+ if return_dict_in_generate is not None
+ else self.generation_config.return_dict_in_generate
+ )
+
+ # init attention / hidden states / scores tuples
+ scores = None
+ if return_dict_in_generate and output_scores:
+ if is_traced_inference:
+ logger.warning(
+ "`output_scores` will be neglected because currently we do not trace `next_token_scores` for greedy search (we do only in beam search). If you want us to support the option during the compilation, please file an issue to https://github.com/huggingface/optimum-neuron."
+ )
+ else:
+ scores = ()
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+ if return_dict_in_generate and self.config.is_encoder_decoder:
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+ encoder_hidden_states = (
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+ )
+
+ # keep track of which sequences are already finished
+ unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+
+ this_peer_finished = False # used by synced_gpus only
+ while True:
+ if synced_gpus:
+ # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
+ # The following logic allows an early break if all peers finished generating their sequence
+ this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
+ # send 0.0 if we finished, 1.0 otherwise
+ dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
+ # did all peers finish? the reduced sum will be 0.0 then
+ if this_peer_finished_flag.item() == 0.0:
+ break
+
+ # prepare model inputs
+ if use_cache:
+ # input_ids is a max_length-sized buffer: on the first step feed the first
+ # seq_length tokens, afterwards only the token at position seq_length - 1.
+
+ if model_kwargs.get("past_key_values") is None:
+ input_ids_ = input_ids[:, :seq_length]
+ else:
+ update_indices = torch.stack(
+ [torch.arange(input_ids.size(0)), torch.tensor(seq_length - 1).repeat(input_ids.size(0))],
+ dim=-1,
+ )
+ input_ids_ = input_ids[update_indices[:, 0], update_indices[:, 1], None]
+
+ model_inputs = self.prepare_inputs_for_generation(input_ids_, **model_kwargs)
+ else:
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+ # forward pass to get next token
+ outputs = self(
+ **model_inputs,
+ return_dict=True,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ if synced_gpus and this_peer_finished:
+ continue # don't waste resources running the code we don't need
+
+ if not is_traced_inference:
+ if not use_cache:
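+ # Without the KV cache, logits are returned for the whole padded sequence: a one-hot
+ # matmul extracts the logits at position seq_length - 1 while keeping shapes static.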
+ one_hot = (
+ torch.cat(
+ [
+ torch.tensor([0]).repeat(1, seq_length - 1),
+ torch.tensor([1]).repeat(1, 1),
+ torch.tensor([0]).repeat(1, input_ids.size(1) - seq_length),
+ ],
+ dim=1,
+ )
+ .to(device=outputs.logits.device)
+ .float()
+ )
+ next_token_logits = torch.matmul(one_hot, outputs.logits)
+ next_token_logits = next_token_logits.squeeze(1)
+ else:
+ next_token_logits = outputs.logits[:, -1, :]
+
+ # pre-process distribution
+ # Move to cpu to handle arbitrary logits_processor
+ next_tokens_scores = logits_processor(input_ids.to("cpu")[:, :seq_length], next_token_logits.to("cpu"))
+ next_tokens_scores = next_tokens_scores.to(input_ids.device)
+
+ # argmax
+ next_tokens = torch.argmax(next_tokens_scores, dim=-1)
+
+ if return_dict_in_generate and output_scores:
+ scores += (next_tokens_scores,)
+ else:
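+ # The traced decoder selects the next tokens inside the compiled graph and returns them as its first output.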
+ next_tokens = outputs[0]
+
+ # Store scores, attentions and hidden_states when required
+ if return_dict_in_generate:
+ if output_attentions:
+ decoder_attentions += (
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+ )
+ if self.config.is_encoder_decoder:
+ cross_attentions += (outputs.cross_attentions,)
+
+ if output_hidden_states:
+ decoder_hidden_states += (
+ (outputs.decoder_hidden_states,)
+ if self.config.is_encoder_decoder
+ else (outputs.hidden_states,)
+ )
+
+ # finished sentences should have their next token be a padding token
+ if eos_token_id is not None:
+ if pad_token_id is None:
+ raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+ # update generated ids, model inputs, and length for next step
+ batch_size, _ = input_ids.shape
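+ # input_ids is a fixed max_length buffer: write the new token in place at column
+ # seq_length via advanced indexing so that tensor shapes stay constant for XLA.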
+ update_indices = torch.stack(
+ [torch.arange(batch_size), torch.tensor(seq_length).repeat(batch_size)], dim=-1
+ )
+ input_ids[update_indices[:, 0], update_indices[:, 1]] = next_tokens[:]
+ model_kwargs = self._update_model_kwargs_for_xla_generation(
+ outputs=outputs,
+ model_kwargs=model_kwargs,
+ batch_size=batch_size,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ max_length=stopping_criteria.max_length,
+ seq_length=seq_length,
+ use_cache=use_cache,
+ )
+
+ seq_length += 1
+
+ # if eos_token was found in one sentence, set sentence to finished
+ if eos_token_id_tensor is not None:
+ unfinished_sequences = unfinished_sequences.mul(
+ next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
+ )
+
+ if not is_traced_inference:
+ xm.mark_step()
+
+ # stop when each sentence is finished, or if we exceed the maximum length
+ stop_criterion_1 = unfinished_sequences.max() == 0
+
+ if isinstance(stopping_criteria, list):
+ if len(stopping_criteria) == 1:
+ stopping_criteria = stopping_criteria[0]
+
+ # Cases that can be handled in XLA without requiring
+ # non-padded input_ids
+ if isinstance(stopping_criteria, MaxLengthCriteria):
+ stop_criterion_2 = seq_length >= stopping_criteria.max_length
+ elif isinstance(stopping_criteria, MaxTimeCriteria):
+ stop_criterion_2 = stopping_criteria(input_ids, scores)
+ else:
+ # Other cases will be handled on CPU
+ batch_size, _ = input_ids.shape
+ mask = torch.cat(
+ [torch.ones(batch_size, seq_length), torch.zeros(batch_size, input_ids.shape[1] - seq_length)],
+ dim=1,
+ ).bool()
+ input_ids_cpu = torch.masked_select(input_ids, mask).reshape((batch_size, seq_length)).to("cpu")
+ scores_cpu = scores.to("cpu") if torch.is_tensor(scores) else scores
+ stop_criterion_2 = stopping_criteria(input_ids_cpu, scores_cpu)
+
+ if stop_criterion_1 or stop_criterion_2:
+ this_peer_finished = True
+
+ if this_peer_finished and not synced_gpus:
+ break
+
+ if return_dict_in_generate:
+ if self.config.is_encoder_decoder:
+ return GreedySearchEncoderDecoderOutput(
+ sequences=input_ids,
+ scores=scores,
+ encoder_attentions=encoder_attentions,
+ encoder_hidden_states=encoder_hidden_states,
+ decoder_attentions=decoder_attentions,
+ cross_attentions=cross_attentions,
+ decoder_hidden_states=decoder_hidden_states,
+ )
+ else:
+ return GreedySearchDecoderOnlyOutput(
+ sequences=input_ids,
+ scores=scores,
+ attentions=decoder_attentions,
+ hidden_states=decoder_hidden_states,
+ )
+ else:
+ return input_ids
- def greedy_search(
+ def beam_search(
self,
input_ids: torch.LongTensor,
+ beam_scorer: BeamScorer,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
max_length: Optional[int] = None,
@@ -1067,34 +975,35 @@ def greedy_search(
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
- synced_gpus: bool = False,
+ synced_gpus: Optional[bool] = False,
seq_length: Optional[int] = None,
- streamer: Optional["BaseStreamer"] = None,
+ is_traced_inference: bool = False,
**model_kwargs,
- ) -> Union[GreedySearchOutput, torch.LongTensor]:
+ ) -> Union[BeamSearchOutput, torch.LongTensor]:
r"""
- Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
- used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
+ Generates sequences of token ids for models with a language modeling head using **beam search decoding** and
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
- In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate()
+ In most cases, you do not need to call [`~generation.GenerationMixin.beam_search`] directly. Use generate()
instead. For an overview of generation strategies and code examples, check the [following
guide](../generation_strategies).
-
Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
+ beam_scorer (`BeamScorer`):
+ A derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and
+ sorted during generation. For more information, the documentation of [`BeamScorer`] should be read.
logits_processor (`LogitsProcessorList`, *optional*):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
stopping_criteria (`StoppingCriteriaList`, *optional*):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
-
max_length (`int`, *optional*, defaults to 20):
**DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
tokens. The maximum length of the sequence to be generated.
@@ -1114,75 +1023,74 @@ def greedy_search(
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
synced_gpus (`bool`, *optional*, defaults to `False`):
Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
- seq_length:
+ seq_length (`Optional[int]`, defaults to `None`):
Length of current input_ids sequence
- streamer (`BaseStreamer`, *optional*):
- Streamer object that will be used to stream the generated sequences. Generated tokens are passed
- through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
- Unsupported for XLA devices
+ is_traced_inference (`bool`, defaults to `False`):
+ Whether the decoder is traced or run with XLA lazy tensors. If the decoder is traced, the next tokens and
+ beam scores are computed inside the decoder.
model_kwargs:
- Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
- If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
Return:
- [`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or
+ [`~generation.BeamSearchDecoderOnlyOutput`], [`~generation.BeamSearchEncoderDecoderOutput`] or
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
- [`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
- `return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if
+ [`~generation.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
+ `return_dict_in_generate=True` or a [`~generation.BeamSearchEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.
+
Examples:
```python
- >>> from transformers import (
- ... AutoTokenizer,
- ... AutoModelForCausalLM,
- ... LogitsProcessorList,
- ... MinLengthLogitsProcessor,
- ... StoppingCriteriaList,
- ... MaxLengthCriteria,
- ... )
-
- >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
- >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+ >>> import torch
+ >>> from transformers import AutoTokenizer, BeamSearchScorer
+ >>> from optimum.neuron import NeuronModelForSeq2SeqLM
- >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
- >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id
+ >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
+ >>> input_shapes = {"batch_size": 1, "sequence_length": 128, "num_beams": 4}
+ >>> model = NeuronModelForSeq2SeqLM.from_pretrained("t5-small", export=True, dynamic_batch_size=False, **input_shapes)
- >>> input_prompt = "It might be possible to"
- >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
-
- >>> # instantiate logits processors
- >>> logits_processor = LogitsProcessorList(
- ... [
- ... MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id),
- ... ]
- ... )
- >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
+ >>> input_prompt = "translate English to German: Let's eat good food."
+ >>> inputs = tokenizer(input_prompt, return_tensors="pt")
- >>> outputs = model.greedy_search(
- ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
+ >>> num_beams = 4
+ >>> # define decoder start token ids
+ >>> input_ids = torch.ones((num_beams, 1), dtype=torch.long) * model.config.decoder_start_token_id
+ >>> # add encoder_outputs to model keyword arguments
+ >>> model_kwargs = {
+ ... "encoder_outputs": model.get_encoder()(
+ ... inputs.input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
+ ... )
+ ... }
+ >>> # instantiate beam scorer
+ >>> beam_scorer = BeamSearchScorer(
+ ... batch_size=1,
+ ... num_beams=num_beams,
+ ... device=model.device,
... )
+ >>> outputs = model.beam_search(input_ids, beam_scorer, **model_kwargs)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
- ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
- ```"""
+ ```
+ """
# init values
- logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
- use_cache = model_kwargs["use_cache"] if "use_cache" in model_kwargs else False
+ if logits_processor is not None and is_traced_inference:
+ logger.warning(
+ "`logits_processor` will be neglected because in `optimum-neuron`, `next_tokens` is computed inside the compiled decoder. If you want us to support custom logits_processor during the compilation, please file an issue to https://github.com/huggingface/optimum-neuron."
+ )
+ elif logits_processor is None:
+ logits_processor = LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
if max_length is not None:
warnings.warn(
"`max_length` is deprecated in this function, use"
- " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
+ " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
UserWarning,
)
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
+ if len(stopping_criteria) == 0:
+ warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
- eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
@@ -1196,8 +1104,24 @@ def greedy_search(
else self.generation_config.return_dict_in_generate
)
+ batch_size = len(beam_scorer._beam_hyps)
+ num_beams = beam_scorer.num_beams
+
+ batch_beam_size, cur_len = input_ids.shape
+
+ # Overwrite cur_len
+ cur_len = seq_length
+
+ if num_beams * batch_size != batch_beam_size:
+ raise ValueError(
+ f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
+ )
+
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
+ beam_indices = (
+ tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
+ )
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
@@ -1209,8 +1133,12 @@ def greedy_search(
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
- # keep track of which sequences are already finished
- unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
+ # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
+ # of the first beam are considered to avoid sampling the exact same tokens across all beams.
+ beam_scores_device = "cpu"
+ beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=beam_scores_device)
+ beam_scores[:, 1:] = -1e9
+ beam_scores = beam_scores.view((batch_size * num_beams,))
this_peer_finished = False # used by synced_gpus only
while True:
@@ -1225,61 +1153,96 @@ def greedy_search(
break
# prepare model inputs
- if use_cache:
+ if model_kwargs["use_cache"]:
# From max_length-sized input_ids, select first
- # seq_length - 1 values.
-
- if model_kwargs.get("past_key_values") is None:
- input_ids_ = input_ids[:, :seq_length]
- else:
- update_indices = torch.stack(
- [torch.arange(input_ids.size(0)), torch.tensor(seq_length - 1).repeat(input_ids.size(0))],
- dim=-1,
- )
- input_ids_ = input_ids[update_indices[:, 0], update_indices[:, 1], None]
-
+ # cur_len - 1 values.
+ update_indices = torch.stack(
+ [torch.arange(input_ids.size(0)), torch.tensor(cur_len - 1).repeat(input_ids.size(0))], dim=-1
+ )
+ input_ids_ = input_ids[update_indices[:, 0], update_indices[:, 1], None]
model_inputs = self.prepare_inputs_for_generation(input_ids_, **model_kwargs)
else:
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
- # forward pass to get next token
- outputs = self(
- **model_inputs,
- return_dict=True,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- )
+ if is_traced_inference:
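+ # The compiled decoder takes the running beam scores and returns the candidate token
+ # scores, tokens and beam indices computed inside the graph.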
+ outputs = self(
+ **model_inputs,
+ beam_scores=beam_scores,
+ return_dict=True,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+ next_token_scores = outputs.next_token_scores
+ next_tokens = outputs.next_tokens
+ next_indices = outputs.next_indices
- if synced_gpus and this_peer_finished:
- continue # don't waste resources running the code we don't need
+ if return_dict_in_generate and output_scores:
+ scores += (next_token_scores,)
+ else:
+ outputs = self(
+ **model_inputs,
+ return_dict=True,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
- if not use_cache:
- one_hot = (
- torch.cat(
- [
- torch.tensor([0]).repeat(1, seq_length - 1),
- torch.tensor([1]).repeat(1, 1),
- torch.tensor([0]).repeat(1, input_ids.size(1) - seq_length),
- ],
- dim=1,
+ if synced_gpus and this_peer_finished:
+ cur_len = cur_len + 1
+ continue # don't waste resources running the code we don't need
+
+ if not model_kwargs["use_cache"]:
+ one_hot = (
+ torch.cat(
+ [
+ torch.tensor([0]).repeat(1, cur_len - 1),
+ torch.tensor([1]).repeat(1, 1),
+ torch.tensor([0]).repeat(1, input_ids.size(1) - cur_len),
+ ],
+ dim=1,
+ )
+ .to(device=outputs.logits.device)
+ .float()
)
- .to(device=outputs.logits.device)
- .float()
+ next_token_logits = torch.matmul(one_hot, outputs.logits)
+ next_token_logits = next_token_logits.squeeze(1)
+ else:
+ next_token_logits = outputs.logits[:, -1, :]
+
+ # Manually compute log softmax
+ # log_softmax(vi) = vi - max(vi) - log(sum(exp(vi - max(vi))))
+ logit_max, _ = torch.max(next_token_logits, dim=-1, keepdim=True)
+ logsumexp = torch.log(torch.exp(next_token_logits - logit_max).sum(dim=-1, keepdim=True))
+ next_token_scores = next_token_logits - logit_max - logsumexp
+ # (batch_size * num_beams, vocab_size)
+
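+ # mark_step flushes the pending XLA graph so the computation above is executed on the device.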
+ xm.mark_step()
+
+ # We don't want to change every single logit processor, so
+ # we perform this processing on CPU.
+ input_ids_ = input_ids.to("cpu")[:, :cur_len]
+ next_token_scores_ = next_token_scores.to("cpu")
+ next_token_scores_processed = logits_processor(input_ids_, next_token_scores_)
+
+ next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
+
+ # reshape for beam search
+ vocab_size = next_token_scores.shape[-1]
+ next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
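+ # No-op multiply, likely kept to materialize the lazy tensor before top-k (same trick as the input_ids hack below).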
+ next_token_scores = next_token_scores * 1
+
+ # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
+ next_token_scores, next_tokens = torch.topk(
+ next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
)
- next_token_logits = torch.matmul(one_hot, outputs.logits)
- next_token_logits = next_token_logits.squeeze(1)
- else:
- next_token_logits = outputs.logits[:, -1, :]
- # pre-process distribution
- # Move to cpu to handle arbitrary logits_processor
- next_tokens_scores = logits_processor(input_ids.to("cpu")[:, :seq_length], next_token_logits.to("cpu"))
- next_tokens_scores = next_tokens_scores.to(input_ids.device)
+ next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
+ next_tokens = next_tokens % vocab_size
+
+ if return_dict_in_generate and output_scores:
+ scores += (next_token_scores_processed,)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
- if output_scores:
- scores += (next_tokens_scores,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
@@ -1293,45 +1256,67 @@ def greedy_search(
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
+ # stateless
+ beam_outputs = beam_scorer.process(
+ input_ids.to("cpu")[:, :cur_len],
+ next_token_scores.to("cpu"),
+ next_tokens.to("cpu"),
+ next_indices.to("cpu"),
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ beam_indices=beam_indices,
+ )
- # argmax
- next_tokens = torch.argmax(next_tokens_scores, dim=-1)
-
- # finished sentences should have their next token be a padding token
- if eos_token_id is not None:
- if pad_token_id is None:
- raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
- next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+ beam_scores = beam_outputs["next_beam_scores"]
+ beam_next_tokens = beam_outputs["next_beam_tokens"]
+ beam_idx = beam_outputs["next_beam_indices"]
- # update generated ids, model inputs, and length for next step
- batch_size, _ = input_ids.shape
update_indices = torch.stack(
- [torch.arange(batch_size), torch.tensor(seq_length).repeat(batch_size)], dim=-1
+ [torch.arange(batch_beam_size), torch.tensor(cur_len - 1).repeat(batch_beam_size)], dim=-1
)
- input_ids[update_indices[:, 0], update_indices[:, 1]] = next_tokens[:]
+ update_indices_2 = torch.stack(
+ [torch.arange(batch_beam_size), torch.tensor(cur_len).repeat(batch_beam_size)], dim=-1
+ )
+ # First select beam_indices
+ device = input_ids.device
+ beam_idx_device = beam_idx.to(device=input_ids.device)
+ input_ids[:, :] = input_ids[beam_idx_device.long(), :]
+
+ # Then append new tokens
+ if is_traced_inference:
+ # int64 is not natively supported by inf2 and has been cast down to int32
+ input_ids[update_indices_2[:, 0], update_indices_2[:, 1], None] = (
+ beam_next_tokens.unsqueeze(-1).to(device).to(torch.long)
+ )
+ else:
+ input_ids[update_indices_2[:, 0], update_indices_2[:, 1], None] = beam_next_tokens.unsqueeze(-1).to(
+ device
+ )
+ input_ids = input_ids * 1 # Hack to materialize tensor
+
+ # update generated ids, model inputs, and length for next step
model_kwargs = self._update_model_kwargs_for_xla_generation(
- outputs,
- model_kwargs,
- batch_size=batch_size,
+ outputs=outputs,
+ model_kwargs=model_kwargs,
+ batch_size=batch_beam_size,
is_encoder_decoder=self.config.is_encoder_decoder,
max_length=stopping_criteria.max_length,
- seq_length=seq_length,
- use_cache=use_cache,
+ seq_length=cur_len,
+ use_cache=model_kwargs["use_cache"],
)
+ if is_traced_inference:
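+ # The traced decoder manages its past key values internally, so the cache reordering is delegated to the model.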
+ self._reorder_cache(beam_idx.to(torch.int64))
+ elif model_kwargs["past_key_values"] is not None:
+ model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
- seq_length += 1
-
- # if eos_token was found in one sentence, set sentence to finished
- if eos_token_id_tensor is not None:
- unfinished_sequences = unfinished_sequences.mul(
- next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
- )
+ if return_dict_in_generate and output_scores:
+ beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
- xm.mark_step()
+ # increase cur_len
+ cur_len = cur_len + 1
# stop when each sentence is finished, or if we exceed the maximum length
- stop_criterion_1 = unfinished_sequences.max() == 0
-
+ stop_criterion_1 = beam_scorer.is_done
if isinstance(stopping_criteria, list):
if len(stopping_criteria) == 1:
stopping_criteria = stopping_criteria[0]
@@ -1339,34 +1324,51 @@ def greedy_search(
# Cases that can be handled in XLA without requiring
# non-padded input_ids
if isinstance(stopping_criteria, MaxLengthCriteria):
- stop_criterion_2 = seq_length >= stopping_criteria.max_length
+ stop_criterion_2 = cur_len >= stopping_criteria.max_length
elif isinstance(stopping_criteria, MaxTimeCriteria):
stop_criterion_2 = stopping_criteria(input_ids, scores)
else:
# Other cases will be handled on CPU
batch_size, _ = input_ids.shape
+ input_ids_cpu = input_ids.to("cpu")
mask = torch.cat(
- [torch.ones(batch_size, seq_length), torch.zeros(batch_size, input_ids.shape[1] - seq_length)],
- dim=1,
+ [torch.ones(batch_size, cur_len), torch.zeros(batch_size, input_ids.shape[1] - cur_len)], dim=1
).bool()
- input_ids_cpu = torch.masked_select(input_ids, mask).reshape((batch_size, seq_length)).to("cpu")
+ input_ids_cpu = torch.masked_select(input_ids_cpu, mask).reshape((batch_size, cur_len))
scores_cpu = scores.to("cpu") if torch.is_tensor(scores) else scores
stop_criterion_2 = stopping_criteria(input_ids_cpu, scores_cpu)
if stop_criterion_1 or stop_criterion_2:
- this_peer_finished = True
+ if not synced_gpus:
+ break
+ else:
+ this_peer_finished = True
- if this_peer_finished and not synced_gpus:
- break
+ sequence_outputs = beam_scorer.finalize(
+ input_ids.to("cpu"),
+ beam_scores.to("cpu"),
+ next_tokens.to("cpu"),
+ next_indices.to("cpu"),
+ pad_token_id=pad_token_id,
+ eos_token_id=eos_token_id,
+ max_length=stopping_criteria.max_length,
+ beam_indices=beam_indices,
+ )
- if streamer is not None:
- streamer.end()
+ for k, v in sequence_outputs.items():
+ if type(v) == torch.Tensor:
+ sequence_outputs[k] = sequence_outputs[k].to(input_ids.device)
if return_dict_in_generate:
+ if not output_scores:
+ sequence_outputs["sequence_scores"] = None
+
if self.config.is_encoder_decoder:
- return GreedySearchEncoderDecoderOutput(
- sequences=input_ids,
+ return BeamSearchEncoderDecoderOutput(
+ sequences=sequence_outputs["sequences"],
+ sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
+ beam_indices=sequence_outputs["beam_indices"],
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
@@ -1374,11 +1376,13 @@ def greedy_search(
decoder_hidden_states=decoder_hidden_states,
)
else:
- return GreedySearchDecoderOnlyOutput(
- sequences=input_ids,
+ return BeamSearchDecoderOnlyOutput(
+ sequences=sequence_outputs["sequences"],
+ sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
+ beam_indices=sequence_outputs["beam_indices"],
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
else:
- return input_ids
+ return sequence_outputs["sequences"]
diff --git a/optimum/neuron/modeling_base.py b/optimum/neuron/modeling_base.py
index be6a4950f..6cc1cd95c 100644
--- a/optimum/neuron/modeling_base.py
+++ b/optimum/neuron/modeling_base.py
@@ -20,7 +20,7 @@
from contextlib import contextmanager
from pathlib import Path
from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
import torch
from huggingface_hub import HfApi, HfFolder, hf_hub_download
@@ -297,14 +297,15 @@ def _from_transformers(
)
store_compilation_config(
- config,
- input_shapes,
- compiler_kwargs,
- input_names,
- output_names,
- dynamic_batch_size,
- compiler_type,
- compiler_version,
+ config=config,
+ input_shapes=input_shapes,
+ compiler_kwargs=compiler_kwargs,
+ input_names=input_names,
+ output_names=output_names,
+ dynamic_batch_size=dynamic_batch_size,
+ compiler_type=compiler_type,
+ compiler_version=compiler_version,
+ task=task,
)
config.save_pretrained(save_dir_path)
@@ -375,9 +376,6 @@ def _attributes_init(
self.preprocessors = preprocessors if preprocessors is not None else []
- self.input_names = getattr(self.config, "input_names", [])
- self.output_names = getattr(self.config, "output_names", [])
-
# Registers the NeuronModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating
# a pipeline https://github.com/huggingface/transformers/blob/3d3204c025b6b5de013e07dd364208e28b4d9589/src/transformers/pipelines/base.py#L940
AutoConfig.register(self.model_type, AutoConfig)
@@ -395,10 +393,10 @@ def _neuron_config_init(cls, config: "PretrainedConfig") -> "NeuronConfig":
)
return
- neuron_configs = config.neuron
+ neuron_config = config.neuron
# Fetch compiler information
- compiler_type = neuron_configs.get("compiler_type")
- compiler_version = neuron_configs.get("compiler_version")
+ compiler_type = neuron_config.get("compiler_type")
+ compiler_version = neuron_config.get("compiler_version")
# Fetch mandatory shapes from config
compile_shapes = {
@@ -410,13 +408,14 @@ def _neuron_config_init(cls, config: "PretrainedConfig") -> "NeuronConfig":
# Neuron config constructor
task = getattr(config, "task") or TasksManager.infer_task_from_model(cls.auto_model_class)
task = TasksManager.map_from_synonym(task)
+ model_type = neuron_config.get("model_type", None) or config.model_type
neuron_config_constructor = TasksManager.get_exporter_config_constructor(
- model_type=config.model_type, exporter="neuron", task=task
+ model_type=model_type, exporter="neuron", task=task
)
return neuron_config_constructor(
config,
- dynamic_batch_size=neuron_configs.get("dynamic_batch_size", False),
+ dynamic_batch_size=neuron_config.get("dynamic_batch_size", False),
compiler_type=compiler_type,
compiler_version=compiler_version,
**compile_shapes,
@@ -453,10 +452,19 @@ def _raise_if_invalid_padding(self, input_name, input_tensor, target_shapes, to_
f" than the static shapes used for compilation: {target_shapes}{extra}."
)
- def _pad_to_compiled_shape(self, inputs: Dict[str, "torch.Tensor"]):
+ def _pad_to_compiled_shape(
+ self, inputs: Dict[str, "torch.Tensor"], padding_side: Literal["right", "left"] = "right"
+ ):
"""
Pads input tensors if they are not in valid shape.
+
+ Args:
+ inputs (`Dict[str, "torch.Tensor"]`):
+ Dictionary of input torch tensors.
+ padding_side (`Literal["right", "left"]`, defaults to "right"):
+ The side on which to apply the padding.
"""
+ logger.info(f"Padding input tensors, the padding side is: {padding_side}.")
for input_name, input_tensor in inputs.items():
target_shapes = self.input_static_shapes[input_name]
padding = ()
@@ -468,7 +476,7 @@ def _pad_to_compiled_shape(self, inputs: Dict[str, "torch.Tensor"]):
to_pad = target_shapes[i] - input_tensor.size(i)
self._raise_if_invalid_padding(input_name, input_tensor, target_shapes, to_pad, i)
- padding += (0, to_pad)
+ padding += (0, to_pad) if padding_side == "right" else (to_pad, 0)
if (
self.preprocessors is not None
@@ -484,16 +492,21 @@ def _pad_to_compiled_shape(self, inputs: Dict[str, "torch.Tensor"]):
# Pad to batch size: dimension 0 (pad_token_id can't be 0)
padding = (0,) * len(padding)
- if self.neuron_config.dynamic_batch_size is True and input_tensor.size(0) % target_shapes[0] == 0:
+ is_encoder_decoder = getattr(self.config, "is_encoder_decoder", False)
+ if (
+ not is_encoder_decoder
+ and self.neuron_config.dynamic_batch_size is True
+ and input_tensor.size(0) % target_shapes[0] == 0
+ ):
inputs[input_name] = input_tensor
continue
- elif self.neuron_config.dynamic_batch_size is True:
+ elif not is_encoder_decoder and self.neuron_config.dynamic_batch_size is True:
target_shape = (input_tensor.size(0) // target_shapes[0] + 1) * target_shapes[0]
to_pad = target_shape - input_tensor.size(0)
else:
to_pad = target_shapes[0] - input_tensor.size(0)
self._raise_if_invalid_padding(input_name, input_tensor, target_shapes, to_pad, 0)
- padding += (0, to_pad)
+ padding += (0, to_pad) if padding_side == "right" else (to_pad, 0)
pad_id = 1
inputs[input_name] = torch.nn.functional.pad(input_tensor, padding, mode="constant", value=pad_id)
@@ -505,7 +518,13 @@ def neuron_padding_manager(self, inputs: Dict[str, "torch.Tensor"]):
inputs = tuple(self._pad_to_compiled_shape(inputs).values())
yield inputs
- def remove_padding(self, outputs: List[torch.Tensor], dims: List[int], indices: List[int]) -> List[torch.Tensor]:
+ @staticmethod
+ def remove_padding(
+ outputs: List[torch.Tensor],
+ dims: List[int],
+ indices: List[int],
+ padding_side: Literal["right", "left"] = "right",
+ ) -> List[torch.Tensor]:
"""
Removes padding from output tensors.
@@ -516,12 +535,26 @@ def remove_padding(self, outputs: List[torch.Tensor], dims: List[int], indices:
List of dimensions in which we slice a tensor.
indices (`List[int]`):
List of indices in which we slice a tensor along an axis.
+ padding_side (`Literal["right", "left"]`, defaults to "right"):
+ The side on which the padding has been applied.
"""
if len(dims) != len(indices):
raise ValueError(f"The size of `dims`({len(dims)}) and indices`({len(indices)}) must be equal.")
+
for dim, indice in zip(dims, indices):
- outputs = [
- torch.index_select(output_tensor, dim, torch.LongTensor(range(indice))) for output_tensor in outputs
- ]
+ if padding_side == "right":
+ outputs = [
+ torch.index_select(output_tensor, dim, torch.LongTensor(range(indice)))
+ for output_tensor in outputs
+ ]
+ elif padding_side == "left":
+ outputs = [
+ torch.index_select(
+ output_tensor,
+ dim,
+ torch.LongTensor(range(output_tensor.shape[dim] - indice, output_tensor.shape[dim])),
+ )
+ for output_tensor in outputs
+ ]
return outputs
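
A toy illustration of what the new `padding_side` argument changes in `remove_padding`: right padding keeps the first `indice` positions along a dimension, while left padding keeps the last ones. The helper below is a standalone sketch, not the library API.

import torch

def strip_padding(tensor: torch.Tensor, dim: int, indice: int, padding_side: str = "right") -> torch.Tensor:
    # Keep the first `indice` positions for right padding, the last `indice` for left padding.
    if padding_side == "right":
        keep = torch.arange(indice)
    else:
        keep = torch.arange(tensor.shape[dim] - indice, tensor.shape[dim])
    return torch.index_select(tensor, dim, keep)

x = torch.arange(8).reshape(1, 8)  # [[0, 1, 2, 3, 4, 5, 6, 7]]
print(strip_padding(x, dim=1, indice=3, padding_side="right"))  # tensor([[0, 1, 2]])
print(strip_padding(x, dim=1, indice=3, padding_side="left"))   # tensor([[5, 6, 7]])
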
diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py
index 80fe64b09..356fb1603 100644
--- a/optimum/neuron/modeling_diffusion.py
+++ b/optimum/neuron/modeling_diffusion.py
@@ -333,6 +333,12 @@ def _save_pretrained(
"""
Saves the model to the serialized format optimized for Neuron devices.
"""
+ if self.model_and_config_save_paths is None:
+ logger.warning(
+ "`model_save_paths` is None which means that no path of Neuron model is defined. Nothing will be saved."
+ )
+ return
+
save_directory = Path(save_directory)
if not self.model_and_config_save_paths.get(DIFFUSION_MODEL_VAE_ENCODER_NAME)[0].is_file():
self.model_and_config_save_paths.pop(DIFFUSION_MODEL_VAE_ENCODER_NAME)
@@ -343,13 +349,7 @@ def _save_pretrained(
if not self.model_and_config_save_paths.get(DIFFUSION_MODEL_TEXT_ENCODER_2_NAME)[0].is_file():
self.model_and_config_save_paths.pop(DIFFUSION_MODEL_TEXT_ENCODER_2_NAME)
- if self.model_and_config_save_paths is None:
- logger.warning(
- "`model_save_paths` is None which means that no path of Neuron model is defined. Nothing will be saved."
- )
- return
- else:
- logger.info(f"Saving the {tuple(self.model_and_config_save_paths.keys())}...")
+ logger.info(f"Saving the {tuple(self.model_and_config_save_paths.keys())}...")
dst_paths = {
DIFFUSION_MODEL_TEXT_ENCODER_NAME: save_directory
@@ -399,6 +399,7 @@ def _from_pretrained(
config: Dict[str, Any],
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
+ force_download: bool = False,
cache_dir: Optional[str] = None,
text_encoder_file_name: Optional[str] = NEURON_FILE_NAME,
text_encoder_2_file_name: Optional[str] = NEURON_FILE_NAME,
@@ -439,6 +440,7 @@ def _from_pretrained(
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
+ force_download=force_download,
allow_patterns=allow_patterns,
ignore_patterns=["*.msgpack", "*.safetensors", "*.bin"],
)
diff --git a/optimum/neuron/modeling_seq2seq.py b/optimum/neuron/modeling_seq2seq.py
new file mode 100644
index 000000000..3d42a7129
--- /dev/null
+++ b/optimum/neuron/modeling_seq2seq.py
@@ -0,0 +1,601 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""NeuroModelForXXX classes for seq2seq models' inference on Neuron devices."""
+
+import copy
+import logging
+import os
+import shutil
+from abc import ABC, abstractmethod
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from huggingface_hub import snapshot_download
+from transformers import AutoConfig, AutoModelForSeq2SeqLM, GenerationConfig
+from transformers.generation.logits_process import LogitsProcessorList
+from transformers.generation.stopping_criteria import StoppingCriteriaList
+from transformers.utils import ModelOutput
+
+from ..exporters.neuron import (
+ NeuronConfig,
+ main_export,
+)
+from ..exporters.tasks import TasksManager
+from ..utils.save_utils import maybe_load_preprocessors
+from .generation import NeuronGenerationMixin
+from .modeling_base import NeuronBaseModel
+from .utils import (
+ DECODER_NAME,
+ ENCODER_NAME,
+ NEURON_FILE_NAME,
+ is_neuronx_available,
+)
+
+
+if TYPE_CHECKING:
+ from transformers import PretrainedConfig, PreTrainedModel
+
+if is_neuronx_available():
+ import torch_neuronx
+
+logger = logging.getLogger(__name__)
+
+
+class NeuronModelForConditionalGeneration(NeuronBaseModel, ABC):
+ base_model_prefix = "neuron_model"
+ config_name = "config.json"
+
+ def __init__(
+ self,
+ encoder: torch.jit._script.ScriptModule,
+ decoder: torch.jit._script.ScriptModule,
+ config: "PretrainedConfig",
+ model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+ encoder_file_name: Optional[str] = NEURON_FILE_NAME,
+ decoder_file_name: Optional[str] = NEURON_FILE_NAME,
+ preprocessors: Optional[List] = None,
+ neuron_configs: Optional[Dict[str, "NeuronConfig"]] = None,
+ configs: Optional[Dict[str, "PretrainedConfig"]] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ model_and_config_save_paths: Optional[Dict[str, Tuple[str, Path]]] = None,
+ **kwargs,
+ ):
+ self.config = config
+ self.configs = configs
+ self.neuron_configs = neuron_configs
+ self.input_static_shapes = NeuronModelForConditionalGeneration.get_input_static_shapes(
+ self.neuron_configs[ENCODER_NAME]
+ ) # only for the encoder
+ self._attributes_init(model_save_dir, preprocessors, **kwargs)
+ self.model_and_config_save_paths = model_and_config_save_paths if model_and_config_save_paths else None
+ self.encoder = NeuronEncoder(
+ encoder,
+ self,
+ self.configs[ENCODER_NAME],
+ self.neuron_configs[ENCODER_NAME],
+ )
+ self.decoder = NeuronDecoder(
+ decoder,
+ self,
+ self.configs[DECODER_NAME],
+ self.neuron_configs[DECODER_NAME],
+ )
+ self.dynamic_batch_size = all(
+ neuron_config._config.neuron["dynamic_batch_size"] for neuron_config in self.neuron_configs.values()
+ )
+ self.encoder_file_name = encoder_file_name
+ self.decoder_file_name = decoder_file_name
+
+ if generation_config is None:
+ generation_config = GenerationConfig.from_model_config(self.configs[DECODER_NAME])
+ self.generation_config = generation_config
+
+ def _save_pretrained(
+ self,
+ save_directory: Union[str, Path],
+ encoder_file_name: str = NEURON_FILE_NAME,
+ decoder_file_name: str = NEURON_FILE_NAME,
+ ):
+ """
+ Saves the model encoder and decoder as well as their configuration files to a
+ directory, so that it can be re-loaded using the
+ [`~optimum.neuron.modeling_seq2seq.NeuronModelForSeq2SeqLM.from_pretrained`] class method.
+
+ Args:
+ save_directory (`Union[str, Path]`):
+ The directory where to save the model files.
+ encoder_file_name (`str`, defaults to `NEURON_FILE_NAME`):
+ The file name to save the encoder.
+ decoder_file_name (`str`, defaults to `NEURON_FILE_NAME`):
+ The file name to save the decoder.
+ """
+ if self.model_and_config_save_paths is None:
+ logger.warning(
+ "`model_save_paths` is None which means that no path of Neuron model is defined. Nothing will be saved."
+ )
+ return
+
+ save_directory = Path(save_directory)
+ if not self.model_and_config_save_paths.get(ENCODER_NAME)[0].is_file():
+ self.model_and_config_save_paths.pop(ENCODER_NAME)
+
+ if not self.model_and_config_save_paths.get(DECODER_NAME)[0].is_file():
+ self.model_and_config_save_paths.pop(DECODER_NAME)
+
+ dst_paths = [
+ save_directory / ENCODER_NAME / encoder_file_name,
+ save_directory / DECODER_NAME / decoder_file_name,
+ ]
+ src_paths = [
+ Path(self.model_and_config_save_paths[ENCODER_NAME][0]),
+ Path(self.model_and_config_save_paths[DECODER_NAME][0]),
+ ]
+
+ for src_path, dst_path in zip(src_paths, dst_paths):
+ dst_path.parent.mkdir(parents=True, exist_ok=True)
+ if src_path.is_file():
+ shutil.copyfile(src_path, dst_path)
+
+ self.generation_config.save_pretrained(save_directory)
+
+ @classmethod
+ def _from_pretrained(
+ cls,
+ model_id: Union[str, Path],
+ config: "PretrainedConfig",
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ force_download: bool = False,
+ cache_dir: Optional[str] = None,
+ encoder_file_name: Optional[str] = NEURON_FILE_NAME,
+ decoder_file_name: Optional[str] = NEURON_FILE_NAME,
+ subfolder: str = "",
+ local_files_only: bool = False,
+ model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+ **kwargs,
+ ):
+ model_id = str(model_id)
+
+ if not os.path.isdir(model_id):
+ # Downloads all repo's files matching the allowed patterns
+ model_id = snapshot_download(
+ model_id,
+ cache_dir=cache_dir,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ force_download=force_download,
+ ignore_patterns=["*.msgpack", "*.safetensors", "*.bin"], # only download *.neuron artifacts
+ )
+
+ preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)
+
+ new_model_save_dir = Path(model_id)
+
+ model_and_config_save_paths = {
+ "encoder": (
+ new_model_save_dir / ENCODER_NAME / encoder_file_name,
+ new_model_save_dir / ENCODER_NAME / cls.config_name,
+ ),
+ "decoder": (
+ new_model_save_dir / DECODER_NAME / decoder_file_name,
+ new_model_save_dir / DECODER_NAME / cls.config_name,
+ ),
+ }
+
+ # Re-build pretrained configs and neuron configs
+ configs, neuron_configs = {}, {}
+ for name, file_paths in model_and_config_save_paths.items():
+ if file_paths[1].is_file():
+ model_config = AutoConfig.from_pretrained(file_paths[1])
+ configs[name] = model_config
+ neuron_configs[name] = cls._neuron_config_init(model_config)
+
+ # Initialize Neuron Runtime before loading models
+ runtime = torch.classes.neuron.Runtime()
+ runtime.initialize()
+ runtime.set_default_neuron_cores(0, 1)
+
+ encoder = cls.load_model(model_and_config_save_paths[ENCODER_NAME][0])
+ decoder = cls.load_model(model_and_config_save_paths[DECODER_NAME][0])
+ torch_neuronx.move_trace_to_device(decoder, 0)
+
+ if model_save_dir is None:
+ model_save_dir = new_model_save_dir
+
+ generation_config = None
+ try:
+ generation_config = GenerationConfig.from_pretrained(
+ model_id,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=os.path.join(subfolder, DECODER_NAME),
+ )
+ except OSError:
+ logger.info("Generation config file not found, using a generation config created from the model config.")
+
+ return cls(
+ encoder=encoder,
+ decoder=decoder,
+ config=config,
+ model_save_dir=model_save_dir,
+ encoder_file_name=encoder_file_name,
+ decoder_file_name=decoder_file_name,
+ preprocessors=preprocessors,
+ neuron_configs=neuron_configs,
+ configs=configs,
+ generation_config=generation_config,
+ model_and_config_save_paths=model_and_config_save_paths,
+ )
+
+ @classmethod
+ def _from_transformers(
+ cls,
+ model_id: str,
+ config: "PretrainedConfig",
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: str = "main",
+ force_download: bool = True,
+ cache_dir: Optional[str] = None,
+ subfolder: str = "",
+ local_files_only: bool = False,
+ trust_remote_code: bool = False,
+ task: Optional[str] = None,
+ auto_cast: Optional[str] = "matmul",
+ auto_cast_type: Optional[str] = "bf16",
+ disable_fast_relayout: Optional[bool] = False,
+ disable_fallback: bool = False,
+ dynamic_batch_size: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ **kwargs_shapes,
+ ) -> "NeuronModelForConditionalGeneration":
+ if dynamic_batch_size is True:
+ logger.warning(
+ "Sequence-to-sequence models don't support dynamic batch size yet, `dynamic_batch_size` will be set to False."
+ )
+
+ if task is None:
+ task = TasksManager.infer_task_from_model(cls.auto_model_class)
+
+ # Get compilation arguments
+ auto_cast_type = None if auto_cast is None else auto_cast_type
+ compiler_kwargs = {
+ "auto_cast": auto_cast,
+ "auto_cast_type": auto_cast_type,
+ "disable_fast_relayout": disable_fast_relayout,
+ "disable_fallback": disable_fallback,
+ }
+
+ save_dir = TemporaryDirectory()
+ save_dir_path = Path(save_dir.name)
+
+ main_export(
+ model_name_or_path=model_id,
+ output=save_dir_path,
+ compiler_kwargs=compiler_kwargs,
+ task=task,
+ dynamic_batch_size=dynamic_batch_size,
+ cache_dir=cache_dir,
+ trust_remote_code=trust_remote_code,
+ subfolder=subfolder,
+ revision=revision,
+ force_download=force_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ do_validation=False,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ **kwargs_shapes,
+ )
+
+ return cls._from_pretrained(
+ model_id=save_dir_path,
+ config=config,
+ model_save_dir=save_dir,
+ )
+
+ def _save_config(self, save_directory):
+ save_directory = Path(save_directory)
+ self.configs[ENCODER_NAME].save_pretrained(save_directory / ENCODER_NAME)
+ self.configs[DECODER_NAME].save_pretrained(save_directory / DECODER_NAME)
+ combined_config = self._combine_encoder_decoder_config(
+ encoder_config=self.configs[ENCODER_NAME],
+ decoder_config=self.configs[DECODER_NAME],
+ )
+ combined_config.save_pretrained(save_directory)
+
+ def _combine_encoder_decoder_config(self, encoder_config: "PretrainedConfig", decoder_config: "PretrainedConfig"):
+ encoder_neuron_config = encoder_config.neuron
+ decoder_neuron_config = decoder_config.neuron
+ combined_config = copy.deepcopy(encoder_config)
+
+ encoder_neuron_config["encoder_input_names"] = encoder_neuron_config.pop("input_names")
+ encoder_neuron_config["encoder_output_names"] = encoder_neuron_config.pop("output_names")
+ decoder_neuron_config["decoder_input_names"] = decoder_neuron_config.pop("input_names")
+ decoder_neuron_config["decoder_output_names"] = decoder_neuron_config.pop("output_names")
+
+ encoder_neuron_config.update(decoder_neuron_config)
+ encoder_neuron_config.pop("model_type")
+ combined_config.__setattr__("neuron", encoder_neuron_config)
+
+ return combined_config
+
+
+class NeuronModelForSeq2SeqLM(NeuronModelForConditionalGeneration, NeuronGenerationMixin):
+ auto_model_class = AutoModelForSeq2SeqLM
+ main_input_name = "input_ids"
+
+ def forward(
+ self,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ beam_scores: Optional[torch.FloatTensor] = None,
+ return_dict: bool = False,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
+ ) -> Union[Tuple[torch.FloatTensor], ModelOutput]:
+ hidden_states = encoder_outputs["last_hidden_state"]
+
+ if not hasattr(self, "beam_idx"):
+ # Inferring the number of beams from the attention mask
+ num_beams = attention_mask.shape[0]
+ self.beam_idx = torch.arange(0, num_beams, dtype=torch.int64)
+
+ outputs = self.decoder(
+ decoder_input_ids, decoder_attention_mask, hidden_states, attention_mask, self.beam_idx, beam_scores
+ )
+
+ # Fetch optional outputs
+ cur_idx = 0
+ cross_attentions = None
+ decoder_attentions = None
+ decoder_hidden_states = None
+
+ # Skip pkv which can't be copied from memory to buffer
+ if output_attentions and self.config.neuron.get("output_attentions"):
+ if self.config.is_encoder_decoder:
+ cross_attentions = outputs[-self.config.num_decoder_layers :]
+ cur_idx += self.config.num_decoder_layers
+ decoder_attentions = outputs[-(self.config.num_decoder_layers + cur_idx) : -cur_idx]
+ cur_idx += self.config.num_decoder_layers
+
+ if output_hidden_states and self.config.neuron.get("output_hidden_states"):
+ decoder_hidden_states = outputs[-(self.config.num_decoder_layers + 1 + cur_idx) : -cur_idx]
+
+ decoder_outputs = ModelOutput(
+ next_token_scores=outputs[0],
+ next_tokens=outputs[1],
+ next_indices=outputs[2],
+ cross_attentions=cross_attentions,
+ decoder_attentions=decoder_attentions,
+ decoder_hidden_states=decoder_hidden_states,
+ )
+
+ if return_dict:
+ return decoder_outputs
+ else:
+ return decoder_outputs.to_tuple()
+
+ def generate(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+ assistant_model: Optional["PreTrainedModel"] = None,
+ num_return_sequences: int = 1,
+ **kwargs,
+ ):
+ max_length = self.neuron_configs[ENCODER_NAME].sequence_length
+ num_beams = self.neuron_configs[ENCODER_NAME].num_beams
+ batch_size = self.neuron_configs[ENCODER_NAME].batch_size
+
+ inputs = {"input_ids": input_ids}
+ if attention_mask is not None:
+ inputs["attention_mask"] = attention_mask
+ inputs = self._pad_to_compiled_shape(inputs)
+
+ past_key_values = self.encoder(**inputs)
+
+ decoder_attention_mask = torch.cat(
+ [
+ torch.zeros((batch_size, max_length - 1), dtype=torch.int64),
+ torch.ones((batch_size, 1), dtype=torch.int64),
+ ],
+ axis=1,
+ )
+
+ # copy the new cache state to the decoder
+ for state, tensor in zip(self.decoder.model.parameters(), past_key_values):
+ state.copy_(tensor)
+
+ output = super().generate(
+ **inputs,
+ generation_config=generation_config,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+ assistant_model=assistant_model,
+ num_return_sequences=num_return_sequences,
+ max_length=kwargs.pop("max_length", None) or max_length,
+ max_new_tokens=kwargs.pop("max_new_tokens", None),
+ output_attentions=kwargs.pop("output_attentions", False),
+ output_hidden_states=kwargs.pop("output_hidden_states", False),
+ output_scores=kwargs.pop("output_scores", False),
+ return_dict_in_generate=kwargs.pop("return_dict_in_generate", False),
+ num_beams=num_beams,
+ do_sample=kwargs.pop("do_sample", False),
+ use_cache=True, # past_key_values are cached by default
+ decoder_attention_mask=decoder_attention_mask,
+ # Pass fake encoder_outputs so the transformers code will not invoke the encoder
+ encoder_outputs={"last_hidden_state": torch.ones((batch_size, max_length, 1))},
+ is_traced_inference=True,
+ )
+ return output
+
+ def _reorder_cache(self, beam_idx):
+ """
+ The cache was already reordered during the tracing of the decoder, so reordering can be skipped here. This is needed for beam search, not for greedy sampling.
+ """
+ self.beam_idx = beam_idx
+
+ def get_encoder(self) -> "NeuronEncoder":
+ return self.encoder
+
+ def _update_model_kwargs_for_xla_generation(
+ self,
+ model_kwargs: Dict[str, Any],
+ batch_size: int,
+ is_encoder_decoder: bool = False,
+ # Leave following kwargs for compatibility, will not have any effect.
+ outputs: "ModelOutput" = None,
+ standardize_cache_format: bool = False,
+ max_length: Optional[int] = None,
+ seq_length: Optional[int] = None,
+ use_cache: bool = True,
+ ) -> Dict[str, Any]:
+ mask = self._update_attention(model_kwargs, batch_size, is_encoder_decoder)
+ # sets the updated variables (mask and past_key_values)
+ model_kwargs.update(mask)
+
+ return model_kwargs
+
+ # Override to cut the input_ids to just last token
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ attention_mask=None,
+ decoder_attention_mask=None,
+ encoder_outputs=None,
+ **kwargs,
+ ):
+ # cut decoder_input_ids as past is cached
+ input_ids = input_ids[:, -1:]
+
+ return {
+ "decoder_input_ids": input_ids,
+ "encoder_outputs": encoder_outputs,
+ "attention_mask": attention_mask,
+ "decoder_attention_mask": decoder_attention_mask,
+ }
+
+ def _validate_static_shape(self, input_shapes: List[int], target_shapes: List[int]) -> bool:
+ """
+ Checks if an input needs to be padded.
+ """
+ return input_shapes == target_shapes
+
+ def can_generate(self):
+ """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
+ return True
+
+
+class _NeuronSeq2SeqModelPart:
+ """
+ For seq2seq architectures, the model is usually compiled into multiple Neuron models, each representing a part of the whole model.
+ """
+
+ def __init__(
+ self,
+ model: torch.jit._script.ScriptModule,
+ parent_model: NeuronBaseModel,
+ config: Optional["PretrainedConfig"] = None,
+ neuron_config: Optional["NeuronConfig"] = None,
+ model_type: str = "encoder",
+ device: Optional[int] = None,
+ ):
+ self.model = model
+ self.parent_model = parent_model
+ self.config = config
+ self.neuron_config = neuron_config
+ self.model_type = model_type
+ self.device = device
+
+ @abstractmethod
+ def forward(self, *args, **kwargs):
+ pass
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
+
+
+class NeuronEncoder(_NeuronSeq2SeqModelPart):
+ """
+ Encoder part of the encoder-decoder model for Neuron inference. (It is actually a monolith of the encoder and the decoder without past_key_values, to work around the control flow in the decoder.)
+ """
+
+ main_input_name = "input_ids"
+
+ def __init__(
+ self,
+ model: torch.jit._script.ScriptModule,
+ parent_model: NeuronBaseModel,
+ config: Optional["PretrainedConfig"] = None,
+ neuron_config: Optional[Dict[str, str]] = None,
+ ):
+ super().__init__(model, parent_model, config, neuron_config, "encoder")
+
+ def forward(self, input_ids: torch.LongTensor, attention_mask: torch.FloatTensor):
+ inputs = (
+ input_ids,
+ attention_mask,
+ )
+ outputs = self.model(*inputs)
+ return outputs
+
+
+class NeuronDecoder(_NeuronSeq2SeqModelPart):
+ """
+ Decoder part of the encoder-decoder model for Neuron inference. (It is actually the decoder with past_key_values.)
+ """
+
+ def __init__(
+ self,
+ model: torch.jit._script.ScriptModule,
+ parent_model: NeuronBaseModel,
+ config: Optional["PretrainedConfig"] = None,
+ neuron_config: Optional[Dict[str, str]] = None,
+ ):
+ super().__init__(model, parent_model, config, neuron_config, "decoder")
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ decoder_attention_mask: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor,
+ encoder_attention_mask: torch.FloatTensor,
+ beam_idx: torch.LongTensor,
+ beam_scores: torch.FloatTensor,
+ ):
+ inputs = (
+ input_ids,
+ decoder_attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ beam_idx,
+ beam_scores,
+ )
+ outputs = self.model(*inputs)
+ return outputs
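
Put together, the new class is intended to be used roughly as sketched below, mirroring the fixtures and tests later in this diff. The checkpoint, static shapes and output directory are illustrative only, and both the export and the generation steps assume an instance with Neuron devices and the Neuron SDK available.

from transformers import AutoTokenizer
from optimum.neuron import NeuronModelForSeq2SeqLM

# Compile the encoder and decoder with static shapes, then save the Neuron artifacts.
model = NeuronModelForSeq2SeqLM.from_pretrained(
    "hf-internal-testing/tiny-random-t5",
    export=True,
    batch_size=1,
    sequence_length=64,
    num_beams=4,
)
model.save_pretrained("t5_neuronx/")

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
inputs = tokenizer("translate English to German: Lets eat good food.", return_tensors="pt")
output = model.generate(**inputs, max_length=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))
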
diff --git a/optimum/neuron/utils/__init__.py b/optimum/neuron/utils/__init__.py
index 559f501c3..c859ba71b 100644
--- a/optimum/neuron/utils/__init__.py
+++ b/optimum/neuron/utils/__init__.py
@@ -15,11 +15,13 @@
from .argument_utils import convert_neuronx_compiler_args_to_neuron, store_compilation_config
from .constant import (
+ DECODER_NAME,
DIFFUSION_MODEL_TEXT_ENCODER_2_NAME,
DIFFUSION_MODEL_TEXT_ENCODER_NAME,
DIFFUSION_MODEL_UNET_NAME,
DIFFUSION_MODEL_VAE_DECODER_NAME,
DIFFUSION_MODEL_VAE_ENCODER_NAME,
+ ENCODER_NAME,
NEURON_FILE_NAME,
)
from .import_utils import (
@@ -31,6 +33,7 @@
is_torch_xla_available,
is_transformers_neuronx_available,
)
+from .input_generators import DummyBeamValuesGenerator
from .optimization_utils import get_attention_scores_sd, get_attention_scores_sdxl
from .patching import DynamicPatch, ModelPatcher, Patcher, patch_everywhere, patch_within_function
from .training_utils import (
diff --git a/optimum/neuron/utils/argument_utils.py b/optimum/neuron/utils/argument_utils.py
index 68c79b684..d910cd074 100644
--- a/optimum/neuron/utils/argument_utils.py
+++ b/optimum/neuron/utils/argument_utils.py
@@ -147,6 +147,8 @@ def store_compilation_config(
compiler_version: str,
model_type: Optional[str] = None,
task: str = None,
+ output_attentions: bool = False,
+ output_hidden_states: bool = False,
**kwargs,
):
if isinstance(config, OrderedDict):
@@ -173,6 +175,21 @@ def store_compilation_config(
config_args["input_names"] = input_names
config_args["output_names"] = output_names
+ original_model_type = getattr(config, "model_type", None)
+ neuron_model_type = str(model_type).replace("_", "-") if model_type is not None else model_type
+ if original_model_type is None:
+ update_func(
+ "model_type", neuron_model_type
+ ) # Add model_type to the config if it doesn't exist before, eg. submodel of Stable Diffusion.
+ else:
+ config_args["model_type"] = (
+ neuron_model_type or original_model_type
+ ) # Prioritize Neuron custom model_type, eg. `t5-encoder`.
+
+ # Add args of optional outputs
+ config_args["output_attentions"] = output_attentions
+ config_args["output_hidden_states"] = output_hidden_states
+
update_func("neuron", config_args)
if hasattr(config, "_diffusers_version"):
@@ -180,9 +197,6 @@ def store_compilation_config(
update_func("_diffusers_version", diffusers.__version__)
- model_type = getattr(config, "model_type", None) or model_type
- model_type = str(model_type).replace("_", "-")
- update_func("model_type", model_type)
update_func("task", task)
return config
diff --git a/optimum/neuron/utils/constant.py b/optimum/neuron/utils/constant.py
index 7719ce8a2..edc6eebb8 100644
--- a/optimum/neuron/utils/constant.py
+++ b/optimum/neuron/utils/constant.py
@@ -15,6 +15,8 @@
"""Constants used as default values."""
NEURON_FILE_NAME = "model.neuron"
+ENCODER_NAME = "encoder"
+DECODER_NAME = "decoder"
DIFFUSION_MODEL_TEXT_ENCODER_NAME = "text_encoder"
DIFFUSION_MODEL_TEXT_ENCODER_2_NAME = "text_encoder_2"
DIFFUSION_MODEL_UNET_NAME = "unet"
diff --git a/optimum/neuron/utils/input_generators.py b/optimum/neuron/utils/input_generators.py
new file mode 100644
index 000000000..91a1657d9
--- /dev/null
+++ b/optimum/neuron/utils/input_generators.py
@@ -0,0 +1,46 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Dummy input generation classes."""
+
+import torch
+
+from ...utils import DTYPE_MAPPER, DummyInputGenerator, NormalizedTextConfig
+
+
+class DummyBeamValuesGenerator(DummyInputGenerator):
+ """
+ Generates dummy beam search inputs.
+ """
+
+ SUPPORTED_INPUT_NAMES = (
+ "beam_idx",
+ "beam_scores",
+ )
+
+ def __init__(
+ self,
+ task: str,
+ normalized_config: NormalizedTextConfig,
+ num_beams: int = 1,
+ **kwargs,
+ ):
+ self.task = task
+ self.num_beams = num_beams
+
+ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+ if input_name == "beam_idx":
+ return torch.arange(0, self.num_beams, dtype=DTYPE_MAPPER.pt(int_dtype))
+ elif input_name == "beam_scores":
+ return torch.zeros((self.num_beams,), dtype=DTYPE_MAPPER.pt(float_dtype))
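
A quick sketch of what this generator produces. `normalized_config` is not used by `DummyBeamValuesGenerator`, so passing `None` below is purely illustrative.

from optimum.neuron.utils import DummyBeamValuesGenerator

generator = DummyBeamValuesGenerator(task="text2text-generation", normalized_config=None, num_beams=4)
print(generator.generate("beam_idx"))     # tensor([0, 1, 2, 3])
print(generator.generate("beam_scores"))  # tensor([0., 0., 0., 0.])
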
diff --git a/tests/cli/test_export_cli.py b/tests/cli/test_export_cli.py
index 09c8c063b..a8abf30f4 100644
--- a/tests/cli/test_export_cli.py
+++ b/tests/cli/test_export_cli.py
@@ -248,3 +248,63 @@ def test_replace_unet(self):
shell=False,
check=True,
)
+
+ @requires_neuronx
+ def test_encoder_decoder(self):
+ model_id = "hf-internal-testing/tiny-random-t5"
+ with tempfile.TemporaryDirectory() as tempdir:
+ subprocess.run(
+ [
+ "optimum-cli",
+ "export",
+ "neuron",
+ "--model",
+ model_id,
+ "--task",
+ "text2text-generation",
+ "--batch_size",
+ "1",
+ "--sequence_length",
+ "18",
+ "--num_beams",
+ "4",
+ "--auto_cast",
+ "matmul",
+ "--auto_cast_type",
+ "bf16",
+ tempdir,
+ ],
+ shell=False,
+ check=True,
+ )
+
+ @requires_neuronx
+ def test_encoder_decoder_optional_outputs(self):
+ model_id = "hf-internal-testing/tiny-random-t5"
+ with tempfile.TemporaryDirectory() as tempdir:
+ subprocess.run(
+ [
+ "optimum-cli",
+ "export",
+ "neuron",
+ "--model",
+ model_id,
+ "--task",
+ "text2text-generation",
+ "--batch_size",
+ "1",
+ "--sequence_length",
+ "18",
+ "--num_beams",
+ "4",
+ "--auto_cast",
+ "matmul",
+ "--auto_cast_type",
+ "bf16",
+ "--output_hidden_states",
+ "--output_attentions",
+ tempdir,
+ ],
+ shell=False,
+ check=True,
+ )
diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py
index 96f090384..0867b93bd 100644
--- a/tests/exporters/exporters_utils.py
+++ b/tests/exporters/exporters_utils.py
@@ -32,6 +32,10 @@
"xlm-roberta": "hf-internal-testing/tiny-xlm-roberta",
}
+ENCODER_DECODER_MODELS_TINY = {
+ "t5": "hf-internal-testing/tiny-random-t5",
+}
+
STABLE_DIFFUSION_MODELS_TINY = {
"stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch",
"stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl",
diff --git a/tests/exporters/test_export.py b/tests/exporters/test_export.py
index fcd18b98e..41507453a 100644
--- a/tests/exporters/test_export.py
+++ b/tests/exporters/test_export.py
@@ -14,7 +14,6 @@
# limitations under the License.
import copy
-import os
import random
import unittest
from pathlib import Path
@@ -22,7 +21,7 @@
from typing import Dict, Optional
from parameterized import parameterized
-from transformers import AutoConfig, set_seed
+from transformers import AutoConfig, AutoModelForSeq2SeqLM, set_seed
from transformers.testing_utils import require_vision
from optimum.exporters.neuron import (
@@ -30,25 +29,17 @@
build_stable_diffusion_components_mandatory_shapes,
export,
export_models,
- get_stable_diffusion_models_for_export,
validate_model_outputs,
validate_models_outputs,
)
+from optimum.exporters.neuron.__main__ import _get_submodels_and_neuron_configs
from optimum.exporters.neuron.model_configs import * # noqa: F403
from optimum.exporters.tasks import TasksManager
-from optimum.neuron.utils import (
- DIFFUSION_MODEL_TEXT_ENCODER_2_NAME,
- DIFFUSION_MODEL_TEXT_ENCODER_NAME,
- DIFFUSION_MODEL_UNET_NAME,
- DIFFUSION_MODEL_VAE_DECODER_NAME,
- DIFFUSION_MODEL_VAE_ENCODER_NAME,
- NEURON_FILE_NAME,
-)
from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx
from optimum.utils import DEFAULT_DUMMY_SHAPES, is_diffusers_available, logging
from optimum.utils.testing_utils import require_diffusers
-from .exporters_utils import EXPORT_MODELS_TINY, STABLE_DIFFUSION_MODELS_TINY
+from .exporters_utils import ENCODER_DECODER_MODELS_TINY, EXPORT_MODELS_TINY, STABLE_DIFFUSION_MODELS_TINY
if is_diffusers_available():
@@ -164,31 +155,25 @@ class NeuronStableDiffusionExportTestCase(unittest.TestCase):
"""
@parameterized.expand(
- [STABLE_DIFFUSION_MODELS_TINY["stable-diffusion"], STABLE_DIFFUSION_MODELS_TINY["stable-diffusion"]]
+ [STABLE_DIFFUSION_MODELS_TINY["stable-diffusion"], STABLE_DIFFUSION_MODELS_TINY["latent-consistency"]]
)
- def test_export_for_stable_diffusion_models(self, model_name):
+ def test_export_for_stable_diffusion_models(self, model_id):
set_seed(SEED)
# prepare neuron config / models
- pipe = StableDiffusionPipeline.from_pretrained(model_name)
+ model = StableDiffusionPipeline.from_pretrained(model_id)
input_shapes = build_stable_diffusion_components_mandatory_shapes(
- **{"batch_size": 1, "height": 64, "width": 64}
- )
- models_and_neuron_configs = get_stable_diffusion_models_for_export(
- pipeline=pipe,
- task="stable-diffusion",
- dynamic_batch_size=False,
- **input_shapes,
+ **{"batch_size": 1, "height": 64, "width": 64, "num_images_per_prompt": 4}
)
- output_model_names = {
- DIFFUSION_MODEL_TEXT_ENCODER_NAME: os.path.join(DIFFUSION_MODEL_TEXT_ENCODER_NAME, NEURON_FILE_NAME),
- DIFFUSION_MODEL_UNET_NAME: os.path.join(DIFFUSION_MODEL_UNET_NAME, NEURON_FILE_NAME),
- DIFFUSION_MODEL_VAE_ENCODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_ENCODER_NAME, NEURON_FILE_NAME),
- DIFFUSION_MODEL_VAE_DECODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_DECODER_NAME, NEURON_FILE_NAME),
- }
-
with TemporaryDirectory() as tmpdirname:
+ models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs(
+ model=model,
+ input_shapes=input_shapes,
+ task="stable-diffusion",
+ output=Path(tmpdirname),
+ model_name_or_path=model_id,
+ )
_, neuron_outputs = export_models(
models_and_neuron_configs=models_and_neuron_configs,
output_dir=Path(tmpdirname),
@@ -202,30 +187,61 @@ def test_export_for_stable_diffusion_models(self, model_name):
)
@parameterized.expand([STABLE_DIFFUSION_MODELS_TINY["stable-diffusion-xl"]])
- def test_export_for_stable_diffusion_xl_models(self, model_name):
+ def test_export_for_stable_diffusion_xl_models(self, model_id):
set_seed(SEED)
# prepare neuron config / models
- pipe = StableDiffusionXLPipeline.from_pretrained(model_name)
+ model = StableDiffusionXLPipeline.from_pretrained(model_id)
input_shapes = build_stable_diffusion_components_mandatory_shapes(
- **{"batch_size": 1, "height": 64, "width": 64}
- )
- models_and_neuron_configs = get_stable_diffusion_models_for_export(
- pipeline=pipe,
- task="stable-diffusion-xl",
- dynamic_batch_size=False,
- **input_shapes,
+ **{"batch_size": 1, "height": 64, "width": 64, "num_images_per_prompt": 4}
)
- output_model_names = {
- DIFFUSION_MODEL_TEXT_ENCODER_NAME: os.path.join(DIFFUSION_MODEL_TEXT_ENCODER_NAME, NEURON_FILE_NAME),
- DIFFUSION_MODEL_TEXT_ENCODER_2_NAME: os.path.join(DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, NEURON_FILE_NAME),
- DIFFUSION_MODEL_UNET_NAME: os.path.join(DIFFUSION_MODEL_UNET_NAME, NEURON_FILE_NAME),
- DIFFUSION_MODEL_VAE_ENCODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_ENCODER_NAME, NEURON_FILE_NAME),
- DIFFUSION_MODEL_VAE_DECODER_NAME: os.path.join(DIFFUSION_MODEL_VAE_DECODER_NAME, NEURON_FILE_NAME),
- }
+ with TemporaryDirectory() as tmpdirname:
+ models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs(
+ model=model,
+ input_shapes=input_shapes,
+ task="stable-diffusion-xl",
+ output=Path(tmpdirname),
+ model_name_or_path=model_id,
+ )
+ _, neuron_outputs = export_models(
+ models_and_neuron_configs=models_and_neuron_configs,
+ output_dir=Path(tmpdirname),
+ output_file_names=output_model_names,
+ )
+ validate_models_outputs(
+ models_and_neuron_configs=models_and_neuron_configs,
+ neuron_named_outputs=neuron_outputs,
+ output_dir=Path(tmpdirname),
+ neuron_files_subpaths=output_model_names,
+ )
+
+
+@is_inferentia_test
+@requires_neuronx
+class NeuronEncoderDecoderExportTestCase(unittest.TestCase):
+ """
+ Integration tests ensuring encoder-decoder models are correctly exported.
+ """
+
+ @parameterized.expand(ENCODER_DECODER_MODELS_TINY.items())
+ def test_export_encoder_decoder_models(self, model_name, model_id):
+ set_seed(SEED)
+
+ # prepare neuron config / models
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+ input_shapes = {"batch_size": 1, "sequence_length": 18, "num_beams": 4}
with TemporaryDirectory() as tmpdirname:
+ models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs(
+ model=model,
+ input_shapes=input_shapes,
+ task="text2text-generation",
+ output=Path(tmpdirname),
+ model_name_or_path=model_id,
+ output_attentions=True,
+ output_hidden_states=True,
+ )
_, neuron_outputs = export_models(
models_and_neuron_configs=models_and_neuron_configs,
output_dir=Path(tmpdirname),
diff --git a/tests/generation/conftest.py b/tests/generation/conftest.py
index 3997bc9a6..c39a03b38 100644
--- a/tests/generation/conftest.py
+++ b/tests/generation/conftest.py
@@ -17,7 +17,7 @@
import pytest
from transformers import AutoTokenizer
-from optimum.neuron import NeuronModelForCausalLM
+from optimum.neuron import NeuronModelForCausalLM, NeuronModelForSeq2SeqLM
from optimum.neuron.utils.testing_utils import requires_neuronx
from optimum.utils.testing_utils import USER
@@ -29,24 +29,32 @@
"llama": "dacorvo/tiny-random-llama",
"opt": "hf-internal-testing/tiny-random-OPTForCausalLM",
}
+SEQ2SEQ_MODEL_NAMES = {
+ "t5": "hf-internal-testing/tiny-random-t5",
+}
@pytest.fixture(scope="module", params=[DECODER_MODEL_NAMES[model_arch] for model_arch in DECODER_MODEL_ARCHITECTURES])
-def export_model_id(request):
+def export_decoder_id(request):
+ return request.param
+
+
+@pytest.fixture(scope="module", params=[SEQ2SEQ_MODEL_NAMES[model_arch] for model_arch in SEQ2SEQ_MODEL_NAMES])
+def export_seq2seq_id(request):
return request.param
@pytest.fixture(scope="module")
@requires_neuronx
-def neuron_model_path(export_model_id):
+def neuron_decoder_path(export_decoder_id):
model = NeuronModelForCausalLM.from_pretrained(
- export_model_id, export=True, batch_size=1, sequence_length=100, num_cores=2
+ export_decoder_id, export=True, batch_size=1, sequence_length=100, num_cores=2
)
model_dir = TemporaryDirectory()
model_path = model_dir.name
model.save_pretrained(model_path)
del model
- tokenizer = AutoTokenizer.from_pretrained(export_model_id)
+ tokenizer = AutoTokenizer.from_pretrained(export_decoder_id)
tokenizer.save_pretrained(model_path)
del tokenizer
# Yield instead of returning to keep a reference to the temporary directory.
@@ -56,8 +64,91 @@ def neuron_model_path(export_model_id):
@pytest.fixture(scope="module")
-def neuron_push_id(export_model_id):
- model_name = export_model_id.split("/")[-1]
+@requires_neuronx
+def neuron_seq2seq_beam_path(export_seq2seq_id):
+ model = NeuronModelForSeq2SeqLM.from_pretrained(
+ export_seq2seq_id, export=True, batch_size=1, sequence_length=64, num_beams=4
+ )
+ model_dir = TemporaryDirectory()
+ model_path = model_dir.name
+ model.save_pretrained(model_path)
+ del model
+ # Yield instead of returning to keep a reference to the temporary directory.
+ # It will go out of scope and be released only once all tests needing the fixture
+ # have been completed.
+ yield model_path
+
+
+@pytest.fixture(scope="module")
+@requires_neuronx
+def neuron_seq2seq_beam_path_with_optional_outputs(export_seq2seq_id):
+ model = NeuronModelForSeq2SeqLM.from_pretrained(
+ export_seq2seq_id,
+ export=True,
+ batch_size=1,
+ sequence_length=64,
+ num_beams=4,
+ output_attentions=True,
+ output_hidden_states=True,
+ )
+ model_dir = TemporaryDirectory()
+ model_path = model_dir.name
+ model.save_pretrained(model_path)
+ del model
+ # Yield instead of returning to keep a reference to the temporary directory.
+ # It will go out of scope and be released only once all tests needing the fixture
+ # have been completed.
+ yield model_path
+
+
+@pytest.fixture(scope="module")
+@requires_neuronx
+def neuron_seq2seq_greedy_path(export_seq2seq_id):
+ model = NeuronModelForSeq2SeqLM.from_pretrained(
+ export_seq2seq_id, export=True, batch_size=1, sequence_length=64, num_beams=1
+ )
+ model_dir = TemporaryDirectory()
+ model_path = model_dir.name
+ model.save_pretrained(model_path)
+ del model
+ # Yield instead of returning to keep a reference to the temporary directory.
+ # It will go out of scope and be released only once all tests needing the fixture
+ # have been completed.
+ yield model_path
+
+
+@pytest.fixture(scope="module")
+@requires_neuronx
+def neuron_seq2seq_greedy_path_with_optional_outputs(export_seq2seq_id):
+ model = NeuronModelForSeq2SeqLM.from_pretrained(
+ export_seq2seq_id,
+ export=True,
+ batch_size=1,
+ sequence_length=64,
+ num_beams=1,
+ output_attentions=True,
+ output_hidden_states=True,
+ )
+ model_dir = TemporaryDirectory()
+ model_path = model_dir.name
+ model.save_pretrained(model_path)
+ del model
+ # Yield instead of returning to keep a reference to the temporary directory.
+ # It will go out of scope and be released only once all tests needing the fixture
+ # have been completed.
+ yield model_path
+
+
+@pytest.fixture(scope="module")
+def neuron_push_decoder_id(export_decoder_id):
+ model_name = export_decoder_id.split("/")[-1]
+ repo_id = f"{USER}/{model_name}-neuronx"
+ return repo_id
+
+
+@pytest.fixture(scope="module")
+def neuron_push_seq2seq_id(export_seq2seq_id):
+ model_name = export_seq2seq_id.split("/")[-1]
repo_id = f"{USER}/{model_name}-neuronx"
return repo_id
diff --git a/tests/generation/test_export.py b/tests/generation/test_export.py
index e4eaef935..fb69f2a88 100644
--- a/tests/generation/test_export.py
+++ b/tests/generation/test_export.py
@@ -16,7 +16,7 @@
import pytest
from generation_utils import check_neuron_model
-from optimum.neuron import NeuronModelForCausalLM
+from optimum.neuron import NeuronModelForCausalLM, NeuronModelForSeq2SeqLM
from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx
@@ -30,9 +30,9 @@
)
@is_inferentia_test
@requires_neuronx
-def test_model_export(export_model_id, batch_size, sequence_length, num_cores, auto_cast_type):
+def test_decoder_export(export_decoder_id, batch_size, sequence_length, num_cores, auto_cast_type):
model = NeuronModelForCausalLM.from_pretrained(
- export_model_id,
+ export_decoder_id,
export=True,
batch_size=batch_size,
sequence_length=sequence_length,
@@ -44,6 +44,33 @@ def test_model_export(export_model_id, batch_size, sequence_length, num_cores, a
@is_inferentia_test
@requires_neuronx
-def test_model_from_path(neuron_model_path):
- model = NeuronModelForCausalLM.from_pretrained(neuron_model_path)
+def test_model_from_path(neuron_decoder_path):
+ model = NeuronModelForCausalLM.from_pretrained(neuron_decoder_path)
check_neuron_model(model)
+
+
+@pytest.mark.parametrize(
+ "batch_size, sequence_length, num_beams",
+ [
+ [1, 64, 1],
+ [1, 64, 4],
+ ],
+)
+@is_inferentia_test
+@requires_neuronx
+def test_seq2seq_export(export_seq2seq_id, batch_size, sequence_length, num_beams):
+ model = NeuronModelForSeq2SeqLM.from_pretrained(
+ export_seq2seq_id,
+ export=True,
+ batch_size=batch_size,
+ sequence_length=sequence_length,
+ num_beams=num_beams,
+ )
+ return model
+
+
+@is_inferentia_test
+@requires_neuronx
+def test_seq2seq_model_from_path(neuron_seq2seq_greedy_path):
+ model = NeuronModelForSeq2SeqLM.from_pretrained(neuron_seq2seq_greedy_path)
+ return model
diff --git a/tests/generation/test_generate.py b/tests/generation/test_generate.py
index 47eecb8a7..1f7630b4d 100644
--- a/tests/generation/test_generate.py
+++ b/tests/generation/test_generate.py
@@ -17,7 +17,7 @@
import torch
from transformers import AutoTokenizer
-from optimum.neuron import NeuronModelForCausalLM
+from optimum.neuron import NeuronModelForCausalLM, NeuronModelForSeq2SeqLM
from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx
@@ -40,17 +40,17 @@ def _test_model_generation(model, tokenizer, batch_size, input_length, **gen_kwa
)
@is_inferentia_test
@requires_neuronx
-def test_model_generation(neuron_model_path, gen_kwargs):
- model = NeuronModelForCausalLM.from_pretrained(neuron_model_path)
- tokenizer = AutoTokenizer.from_pretrained(neuron_model_path)
+def test_decoder_generation(neuron_decoder_path, gen_kwargs):
+ model = NeuronModelForCausalLM.from_pretrained(neuron_decoder_path)
+ tokenizer = AutoTokenizer.from_pretrained(neuron_decoder_path)
_test_model_generation(model, tokenizer, model.batch_size, 10, **gen_kwargs)
@is_inferentia_test
@requires_neuronx
-def test_model_generation_input_dimensions(neuron_model_path):
- model = NeuronModelForCausalLM.from_pretrained(neuron_model_path)
- tokenizer = AutoTokenizer.from_pretrained(neuron_model_path)
+def test_model_generation_input_dimensions(neuron_decoder_path):
+ model = NeuronModelForCausalLM.from_pretrained(neuron_decoder_path)
+ tokenizer = AutoTokenizer.from_pretrained(neuron_decoder_path)
# Using valid input dimensions
_test_model_generation(model, tokenizer, model.batch_size, model.max_length // 2)
# Using an incompatible batch_size
@@ -59,3 +59,85 @@ def test_model_generation_input_dimensions(neuron_model_path):
# Using an incompatible input length
with pytest.raises(ValueError, match="The input sequence length"):
_test_model_generation(model, tokenizer, model.batch_size, input_length=model.max_length * 2)
+
+
+@is_inferentia_test
+@requires_neuronx
+def test_seq2seq_generation_beam(neuron_seq2seq_beam_path):
+ model = NeuronModelForSeq2SeqLM.from_pretrained(neuron_seq2seq_beam_path)
+ tokenizer = AutoTokenizer.from_pretrained(neuron_seq2seq_beam_path)
+ inputs = tokenizer("translate English to German: Lets eat good food.", return_tensors="pt")
+
+ # 1. max length
+ output = model.generate(**inputs, num_return_sequences=2, max_length=5)
+ assert len(output[0]) <= 5
+
+ # 2. min length
+ output = model.generate(**inputs, num_return_sequences=2, min_length=10)
+ assert len(output[0]) >= 10
+
+ # 3. max new tokens
+ output = model.generate(**inputs, num_return_sequences=2, max_new_tokens=5)
+ assert len(output[0].unique()) <= 5 + 1 # +1 for `decoder_start_token_id`
+
+
+@is_inferentia_test
+@requires_neuronx
+def test_seq2seq_generation_beam_with_optional_outputs(neuron_seq2seq_beam_path_with_optional_outputs):
+ model = NeuronModelForSeq2SeqLM.from_pretrained(neuron_seq2seq_beam_path_with_optional_outputs)
+ tokenizer = AutoTokenizer.from_pretrained(neuron_seq2seq_beam_path_with_optional_outputs)
+ inputs = tokenizer("translate English to German: Lets eat good food.", return_tensors="pt")
+
+ output = model.generate(
+ **inputs,
+ num_return_sequences=1,
+ max_length=20,
+ output_scores=True,
+ output_attentions=True,
+ output_hidden_states=True,
+ return_dict_in_generate=True,
+ )
+ assert "scores" in output
+ assert "decoder_attentions" in output
+ assert "cross_attentions" in output
+ assert "decoder_hidden_states" in output
+
+
+@is_inferentia_test
+@requires_neuronx
+def test_seq2seq_generation_greedy(neuron_seq2seq_greedy_path):
+ model = NeuronModelForSeq2SeqLM.from_pretrained(neuron_seq2seq_greedy_path)
+ tokenizer = AutoTokenizer.from_pretrained(neuron_seq2seq_greedy_path)
+ inputs = tokenizer("translate English to German: Lets eat good food.", return_tensors="pt")
+
+ # 1. max length
+ output = model.generate(**inputs, num_return_sequences=1, max_length=5)
+ assert len(output[0]) <= 5
+
+ # 2. min length
+ output = model.generate(**inputs, num_return_sequences=1, min_length=10)
+ assert len(output[0]) >= 10
+
+ # 3. max new tokens
+ output = model.generate(**inputs, num_return_sequences=1, max_new_tokens=5)
+ assert len(output[0]) <= 5 + 1 # +1 for `decoder_start_token_id`
+
+
+@is_inferentia_test
+@requires_neuronx
+def test_seq2seq_generation_greedy_with_optional_outputs(neuron_seq2seq_greedy_path_with_optional_outputs):
+ model = NeuronModelForSeq2SeqLM.from_pretrained(neuron_seq2seq_greedy_path_with_optional_outputs)
+ tokenizer = AutoTokenizer.from_pretrained(neuron_seq2seq_greedy_path_with_optional_outputs)
+ inputs = tokenizer("translate English to German: Lets eat good food.", return_tensors="pt")
+
+ output = model.generate(
+ **inputs,
+ num_return_sequences=1,
+ max_length=20,
+ output_attentions=True,
+ output_hidden_states=True,
+ return_dict_in_generate=True,
+ )
+ assert "decoder_attentions" in output
+ assert "cross_attentions" in output
+ assert "decoder_hidden_states" in output
diff --git a/tests/generation/test_hub.py b/tests/generation/test_hub.py
index 2966e0199..7e1faa196 100644
--- a/tests/generation/test_hub.py
+++ b/tests/generation/test_hub.py
@@ -18,7 +18,7 @@
from huggingface_hub import HfApi
from transformers.testing_utils import ENDPOINT_STAGING
-from optimum.neuron import NeuronModelForCausalLM
+from optimum.neuron import NeuronModelForCausalLM, NeuronModelForSeq2SeqLM
from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx
from optimum.utils.testing_utils import TOKEN
@@ -34,17 +34,46 @@ def test_model_from_hub():
@is_inferentia_test
@requires_neuronx
-def test_push_to_hub(neuron_model_path, neuron_push_id):
- model = NeuronModelForCausalLM.from_pretrained(neuron_model_path)
- model.push_to_hub(neuron_model_path, neuron_push_id, use_auth_token=TOKEN, endpoint=ENDPOINT_STAGING)
+def test_push_to_hub(neuron_decoder_path, neuron_push_decoder_id):
+ model = NeuronModelForCausalLM.from_pretrained(neuron_decoder_path)
+ model.push_to_hub(neuron_decoder_path, neuron_push_decoder_id, use_auth_token=TOKEN, endpoint=ENDPOINT_STAGING)
api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN)
try:
- hub_files_info = api.list_files_info(neuron_push_id)
+ hub_files_info = api.list_files_info(neuron_push_decoder_id)
hub_files_path = [info.rfilename for info in hub_files_info]
- for path, _, files in os.walk(neuron_model_path):
+ for path, _, files in os.walk(neuron_decoder_path):
for name in files:
local_file_path = os.path.join(path, name)
- hub_file_path = os.path.relpath(local_file_path, neuron_model_path)
+ hub_file_path = os.path.relpath(local_file_path, neuron_decoder_path)
assert hub_file_path in hub_files_path
finally:
- api.delete_repo(neuron_push_id)
+ api.delete_repo(neuron_push_decoder_id)
+
+
+@is_inferentia_test
+@requires_neuronx
+def test_seq2seq_model_from_hub():
+ model = NeuronModelForSeq2SeqLM.from_pretrained(
+ "Jingya/tiny-random-t5-neuronx", revision="ce617676ce12a19df7c6bd523c69b83447fa036b"
+ )
+ return model
+
+
+@is_inferentia_test
+@requires_neuronx
+def test_push_seq2seq_to_hub(neuron_seq2seq_greedy_path, neuron_push_seq2seq_id):
+ model = NeuronModelForSeq2SeqLM.from_pretrained(neuron_seq2seq_greedy_path)
+ model.push_to_hub(
+ neuron_seq2seq_greedy_path, neuron_push_seq2seq_id, use_auth_token=TOKEN, endpoint=ENDPOINT_STAGING
+ )
+ api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN)
+ try:
+ hub_files_info = api.list_files_info(neuron_push_seq2seq_id)
+ hub_files_path = [info.rfilename for info in hub_files_info]
+ for path, _, files in os.walk(neuron_seq2seq_greedy_path):
+ for name in files:
+ local_file_path = os.path.join(path, name)
+ hub_file_path = os.path.relpath(local_file_path, neuron_seq2seq_greedy_path)
+ assert hub_file_path in hub_files_path
+ finally:
+ api.delete_repo(neuron_push_seq2seq_id)