From 9676a6543bc2fc29b9ec4a6f616955925d8d60f3 Mon Sep 17 00:00:00 2001 From: JingyaHuang Date: Tue, 28 Nov 2023 18:35:29 +0000 Subject: [PATCH] reuse neuron gen mix --- optimum/neuron/generation/utils.py | 1155 ++++++++++++++-------------- optimum/neuron/modeling_seq2seq.py | 298 +------ 2 files changed, 576 insertions(+), 877 deletions(-) diff --git a/optimum/neuron/generation/utils.py b/optimum/neuron/generation/utils.py index 9ab87e914..11f64d88e 100644 --- a/optimum/neuron/generation/utils.py +++ b/optimum/neuron/generation/utils.py @@ -52,8 +52,7 @@ if TYPE_CHECKING: - from transformers.generation.streamers import BaseStreamer - from transformers.modeling_utils import PreTrainedModel + pass logger = logging.get_logger(__name__) @@ -272,419 +271,6 @@ def _expand_dict_for_generation(dict_to_expand): return input_ids, model_kwargs - def beam_search( - self, - input_ids: torch.LongTensor, - beam_scorer: BeamScorer, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, - seq_length: Optional[int] = None, - **model_kwargs, - ) -> Union[BeamSearchOutput, torch.LongTensor]: - r""" - Generates sequences of token ids for models with a language modeling head using **beam search decoding** and - can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - - - - In most cases, you do not need to call [`~generation.GenerationMixin.beam_search`] directly. Use generate() - instead. For an overview of generation strategies and code examples, check the [following - guide](../generation_strategies). - - - - Parameters: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - The sequence used as a prompt for the generation. - beam_scorer (`BeamScorer`): - An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and - sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. - logits_processor (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - stopping_criteria (`StoppingCriteriaList`, *optional*): - An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] - used to tell if the generation loop should stop. - max_length (`int`, *optional*, defaults to 20): - **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated - tokens. The maximum length of the sequence to be generated. - pad_token_id (`int`, *optional*): - The id of the *padding* token. - eos_token_id (`Union[int, List[int]]`, *optional*): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - output_attentions (`bool`, *optional*, defaults to `False`): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more details. 
- output_hidden_states (`bool`, *optional*, defaults to `False`): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more details. - output_scores (`bool`, *optional*, defaults to `False`): - Whether or not to return the prediction scores. See `scores` under returned tensors for more details. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - seq_length: - Length of current input_ids sequence - model_kwargs: - Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is - an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - [`generation.BeamSearchDecoderOnlyOutput`], [`~generation.BeamSearchEncoderDecoderOutput`] or - `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation.BeamSearchEncoderDecoderOutput`] if - `model.config.is_encoder_decoder=True`. - - - Examples: - - ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForSeq2SeqLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... BeamSearchScorer, - ... ) - >>> import torch - - >>> tokenizer = AutoTokenizer.from_pretrained("t5-base") - >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") - - >>> encoder_input_str = "translate English to German: How old are you?" - >>> encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids - - - >>> # lets run beam search using 3 beams - >>> num_beams = 3 - >>> # define decoder start token ids - >>> input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long) - >>> input_ids = input_ids * model.config.decoder_start_token_id - - >>> # add encoder_outputs to model keyword arguments - >>> model_kwargs = { - ... "encoder_outputs": model.get_encoder()( - ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True - ... ) - ... } - - >>> # instantiate beam scorer - >>> beam_scorer = BeamSearchScorer( - ... batch_size=1, - ... num_beams=num_beams, - ... device=model.device, - ... ) - - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [ - ... MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id), - ... ] - ... 
) - - >>> outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs) - - >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ['Wie alt bist du?'] - ```""" - # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if max_length is not None: - warnings.warn( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", - UserWarning, - ) - stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - if len(stopping_criteria) == 0: - warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions - ) - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate - ) - - batch_size = len(beam_scorer._beam_hyps) - num_beams = beam_scorer.num_beams - - batch_beam_size, cur_len = input_ids.shape - - # Overwrite cur_len - cur_len = seq_length - - if num_beams * batch_size != batch_beam_size: - raise ValueError( - f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." - ) - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - beam_indices = ( - tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None - ) - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None - ) - - # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens - # of the first beam are considered to avoid sampling the exact same tokens across all beams. 
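To make the comment above concrete: because every beam starts from the same prompt, a top-k over the flattened `(num_beams * vocab_size)` scores would otherwise pick the same token for every beam. Masking beams 1..n-1 with `-1e9` forces all first-step candidates to come from beam 0. A minimal standalone sketch (plain PyTorch, illustrative sizes only, not part of the patch):

```python
import torch

batch_size, num_beams, vocab_size = 1, 3, 10

# At the first step every beam sees identical log-probabilities.
next_token_scores = torch.log_softmax(torch.randn(1, vocab_size), dim=-1).repeat(batch_size * num_beams, 1)

# Initialise score of first beam with 0 and the rest with -1e9, as described above.
beam_scores = torch.zeros((batch_size, num_beams))
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view((batch_size * num_beams,))

scores = (next_token_scores + beam_scores[:, None]).view(batch_size, num_beams * vocab_size)
topk_scores, topk_ids = torch.topk(scores, 2 * num_beams, dim=1)

# Every selected candidate maps back to beam 0, so the beams diverge instead of duplicating.
print(torch.div(topk_ids, vocab_size, rounding_mode="floor"))  # tensor([[0, 0, 0, 0, 0, 0]])
```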
- # beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) - beam_scores_device = "cpu" - beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=beam_scores_device) - beam_scores[:, 1:] = -1e9 - beam_scores = beam_scores.view((batch_size * num_beams,)) - - this_peer_finished = False # used by synced_gpus only - while True: - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: - break - - # prepare model inputs - if model_kwargs["use_cache"]: - # From max_length-sized input_ids, select first - # cur_len - 1 values. - update_indices = torch.stack( - [torch.arange(input_ids.size(0)), torch.tensor(cur_len - 1).repeat(input_ids.size(0))], dim=-1 - ) - input_ids_ = input_ids[update_indices[:, 0], update_indices[:, 1], None] - model_inputs = self.prepare_inputs_for_generation(input_ids_, **model_kwargs) - else: - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - if synced_gpus and this_peer_finished: - cur_len = cur_len + 1 - continue # don't waste resources running the code we don't need - - if not model_kwargs["use_cache"]: - one_hot = ( - torch.cat( - [ - torch.tensor([0]).repeat(1, cur_len - 1), - torch.tensor([1]).repeat(1, 1), - torch.tensor([0]).repeat(1, input_ids.size(1) - cur_len), - ], - dim=1, - ) - .to(device=outputs.logits.device) - .float() - ) - next_token_logits = torch.matmul(one_hot, outputs.logits) - next_token_logits = next_token_logits.squeeze(1) - else: - next_token_logits = outputs.logits[:, -1, :] - - # Manually compute log softmax - # log_softmax(vi) = vi - max(vi) - log(sum(exp(vi - max(vi)))) - logit_max, _ = torch.max(next_token_logits, dim=-1, keepdim=True) - logsumexp = torch.log(torch.exp(next_token_logits - logit_max).sum(dim=-1, keepdim=True)) - next_token_scores = next_token_logits - logit_max - logsumexp - # (batch_size * num_beams, vocab_size) - - xm.mark_step() - - # We don't want to change every single logit processor, so - # we peform this processing on CPU. 
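As an aside on the `one_hot` trick used a few lines above: multiplying the padded logits by a one-hot row gathers the slice at position `cur_len - 1` without data-dependent indexing on the padded sequence. A small standalone check (made-up shapes, CPU only, not part of the patch):

```python
import torch

batch, padded_len, vocab = 2, 8, 11
cur_len = 5
logits = torch.randn(batch, padded_len, vocab)  # stands in for outputs.logits

# Same construction as in the mixin: a (1, padded_len) one-hot row at position cur_len - 1.
one_hot = (
    torch.cat(
        [
            torch.tensor([0]).repeat(1, cur_len - 1),
            torch.tensor([1]).repeat(1, 1),
            torch.tensor([0]).repeat(1, padded_len - cur_len),
        ],
        dim=1,
    )
    .float()
)

next_token_logits = torch.matmul(one_hot, logits).squeeze(1)  # (batch, vocab)
assert torch.allclose(next_token_logits, logits[:, cur_len - 1, :])
```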
- input_ids_ = input_ids.to("cpu")[:, :cur_len] - next_token_scores_ = next_token_scores.to("cpu") - next_token_scores_processed = logits_processor(input_ids_, next_token_scores_) - - next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += (next_token_scores_processed,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # reshape for beam search - vocab_size = next_token_scores.shape[-1] - next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) - next_token_scores = next_token_scores * 1 - - # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search) - next_token_scores, next_tokens = torch.topk( - next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True - ) - - next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") - next_tokens = next_tokens % vocab_size - - # stateless - beam_outputs = beam_scorer.process( - input_ids.to("cpu")[:, :cur_len], - next_token_scores.to("cpu"), - next_tokens.to("cpu"), - next_indices.to("cpu"), - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - beam_indices=beam_indices, - ) - - beam_scores = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - - update_indices = torch.stack( - [torch.arange(batch_beam_size), torch.tensor(cur_len - 1).repeat(batch_beam_size)], dim=-1 - ) - update_indices_2 = torch.stack( - [torch.arange(batch_beam_size), torch.tensor(cur_len).repeat(batch_beam_size)], dim=-1 - ) - # First select beam_indices - device = input_ids.device - beam_idx_device = beam_idx.to(device=input_ids.device) - input_ids[:, :] = input_ids[beam_idx_device.long(), :] - - # Then append new tokens - input_ids[update_indices_2[:, 0], update_indices_2[:, 1], None] = beam_next_tokens.unsqueeze(-1).to(device) - input_ids = input_ids * 1 # Hack to materialize tensor - - # update generated ids, model inputs, and length for next step - model_kwargs = self._update_model_kwargs_for_xla_generation( - outputs, - model_kwargs, - batch_size=batch_beam_size, - is_encoder_decoder=self.config.is_encoder_decoder, - max_length=stopping_criteria.max_length, - seq_length=cur_len, - use_cache=model_kwargs["use_cache"], - ) - if model_kwargs["past_key_values"] is not None: - model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) - - if return_dict_in_generate and output_scores: - beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) - - # increase cur_len - cur_len = cur_len + 1 - - # stop when each sentence is finished, or if we exceed the maximum length - stop_criterion_1 = beam_scorer.is_done - if isinstance(stopping_criteria, list): - if len(stopping_criteria) == 1: - stopping_criteria = stopping_criteria[0] - - # Cases that can be handled in XLA without requiring - # non-padded input_ids - if isinstance(stopping_criteria, MaxLengthCriteria): - stop_criterion_2 = cur_len >= stopping_criteria.max_length - 
elif isinstance(stopping_criteria, MaxTimeCriteria): - stop_criterion_2 = stopping_criteria(input_ids, scores) - else: - # Other cases will be handled on CPU - batch_size, _ = input_ids.shape - input_ids_cpu = input_ids.to("cpu") - mask = torch.cat( - [torch.ones(batch_size, cur_len), torch.zeros(batch_size, input_ids.shape[1] - cur_len)], dim=1 - ).bool() - input_ids_cpu = torch.masked_select(input_ids_cpu, mask).reshape((batch_size, cur_len)) - scores_cpu = scores.to("cpu") if torch.is_tensor(scores) else scores - stop_criterion_2 = stopping_criteria(input_ids_cpu, scores_cpu) - - if stop_criterion_1 or stop_criterion_2: - if not synced_gpus: - break - else: - this_peer_finished = True - - sequence_outputs = beam_scorer.finalize( - input_ids.to("cpu"), - beam_scores.to("cpu"), - next_tokens.to("cpu"), - next_indices.to("cpu"), - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - max_length=stopping_criteria.max_length, - beam_indices=beam_indices, - ) - - for k, v in sequence_outputs.items(): - if type(v) == torch.Tensor: - sequence_outputs[k] = sequence_outputs[k].to(input_ids.device) - - if return_dict_in_generate: - if not output_scores: - sequence_outputs["sequence_scores"] = None - - if self.config.is_encoder_decoder: - return BeamSearchEncoderDecoderOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - beam_indices=sequence_outputs["beam_indices"], - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - ) - else: - return BeamSearchDecoderOnlyOutput( - sequences=sequence_outputs["sequences"], - sequences_scores=sequence_outputs["sequence_scores"], - scores=scores, - beam_indices=sequence_outputs["beam_indices"], - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) - else: - return sequence_outputs["sequences"] - @torch.no_grad() def generate( self, @@ -694,8 +280,7 @@ def generate( stopping_criteria: Optional[StoppingCriteriaList] = None, prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, synced_gpus: Optional[bool] = None, - assistant_model: Optional["PreTrainedModel"] = None, - streamer: Optional["BaseStreamer"] = None, + is_traced_inference: bool = False, **kwargs, ) -> Union[GenerateOutput, torch.LongTensor]: r""" @@ -714,23 +299,23 @@ def generate( Parameters: - inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): + inputs (`Optional[torch.Tensor]`, defaults to `None`): The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of `input_ids`, `input_values`, `input_features`, or `pixel_values`. - generation_config (`~generation.GenerationConfig`, *optional*): + generation_config (`Optional[GenerationConfig]`, defaults to `None`): The generation configuration to be used as base parametrization for the generation call. `**kwargs` passed to generate matching the attributes of `generation_config` will override them. If `generation_config` is not provided, the default will be used, which had the following loading priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model configuration. 
Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s default values, whose documentation should be checked to parameterize generation. - logits_processor (`LogitsProcessorList`, *optional*): + logits_processor (`Optional[LogitsProcessorList]`, defaults to `None`): Custom logits processors that complement the default logits processors built from arguments and generation config. If a logit processor is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. - stopping_criteria (`StoppingCriteriaList`, *optional*): + stopping_criteria (`Optional[StoppingCriteriaList]`, defaults to `None`): Custom stopping criteria that complement the default stopping criteria built from arguments and a generation config. If a stopping criteria is passed that is already created with the arguments or a generation config an error is thrown. This feature is intended for advanced users. @@ -741,18 +326,13 @@ def generate( on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful for constrained generation conditioned on the prefix, as described in [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904). - synced_gpus (`bool`, *optional*): + synced_gpus (`Optional[bool]`, defaults to `None`): Whether to continue running the while loop until max_length. Unless overridden this flag will be set to `True` under DeepSpeed ZeRO Stage 3 multiple GPUs environment to avoid hanging if one GPU finished generating before other GPUs. Otherwise it'll be set to `False`. - assistant_model (`PreTrainedModel`, *optional*): - An assistant model that can be used to accelerate generation. The assistant model must have the exact - same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model - is much faster than running generation with the model you're calling generate from. As such, the - assistant model should be much smaller. - streamer (`BaseStreamer`, *optional*): - Streamer object that will be used to stream the generated sequences. Generated tokens are passed - through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + is_traced_inference (`bool`, defaults to `False`): + Whether the decoder is traced or using XLA lazy tensor. If the decoder is traced, next tokens and the beam scores + are computed inside the decoder. kwargs: Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder @@ -832,9 +412,11 @@ def generate( # 4. 
Define other model kwargs model_kwargs["output_attentions"] = generation_config.output_attentions model_kwargs["output_hidden_states"] = generation_config.output_hidden_states - if generation_config.use_cache: + if generation_config.use_cache and not is_traced_inference: warnings.warn("use_cache is not supported for generation on Neuron devices, switching to use_cache=False.") - model_kwargs["use_cache"] = False + model_kwargs["use_cache"] = False + else: + model_kwargs["use_cache"] = generation_config.use_cache accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys()) requires_attention_mask = "encoder_outputs" not in model_kwargs @@ -875,9 +457,6 @@ def generate( else: input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") - if streamer is not None: - streamer.put(input_ids.cpu()) - # 6. Prepare `max_length` depending on other stopping criteria. input_ids_seq_length = input_ids.shape[-1] has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None @@ -974,11 +553,6 @@ def generate( if generation_config.num_beam_groups > generation_config.num_beams: raise ValueError("`num_beam_groups` has to be smaller or equal to `num_beams`") - if streamer is not None and (generation_config.num_beams > 1): - raise ValueError( - "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1." - ) - if hasattr(self, "device") and self.device.type != input_ids.device.type: warnings.warn( "You are calling .generate() with the `input_ids` being on a device type different" @@ -1022,7 +596,7 @@ def generate( return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, seq_length=input_ids_seq_length, - streamer=streamer, + is_traced_inference=is_traced_inference, **model_kwargs, ) elif is_beam_gen_mode: @@ -1061,15 +635,332 @@ def generate( return_dict_in_generate=generation_config.return_dict_in_generate, synced_gpus=synced_gpus, seq_length=input_ids_seq_length, + is_traced_inference=is_traced_inference, **model_kwargs, ) else: - raise ValueError("Only greedy search and beam search are supported on Neuron.") + raise ValueError("Only greedy search and beam search are supported on Neuron.") + + def greedy_search( + self, + input_ids: torch.LongTensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: bool = False, + seq_length: Optional[int] = None, + is_traced_inference: bool = False, + **model_kwargs, + ) -> Union[GreedySearchOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be + used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + + + In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate() + instead. For an overview of generation strategies and code examples, check the [following + guide](../generation_strategies). + + + + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. 
+ logits_processor (`LogitsProcessorList`, *optional*): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`, *optional*): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + + max_length (`int`, *optional*, defaults to 20): + **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated + tokens. The maximum length of the sequence to be generated. + pad_token_id (`int`, *optional*): + The id of the *padding* token. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more details. + output_hidden_states (`bool`, *optional*, defaults to `False`): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more details. + output_scores (`bool`, *optional*, defaults to `False`): + Whether or not to return the prediction scores. See `scores` under returned tensors for more details. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + seq_length (`Optional[int]`, defaults to `False`): + Length of current input_ids sequence + is_traced_inference (`bool`, defaults to `False`): + Whether the decoder is traced or using XLA lazy tensor. If the decoder is traced, next tokens and the beam scores + are computed inside the decoder. + model_kwargs: + Additional model specific keyword arguments will be forwarded to the `forward` function of the model. + If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or + `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + + Examples: + + ```python + >>> from transformers import AutoTokenizer + >>> from optimum.neuron import NeuronModelForSeq2SeqLM + + >>> tokenizer = AutoTokenizer.from_pretrained("t5-small") + >>> input_shapes = {"batch_size": 1, "sequence_length": 128, "num_beams": 1} + >>> model = NeuronModelForSeq2SeqLM.from_pretrained("t5-small", export=True, dynamic_batch_size=False, **input_shapes) + + >>> input_prompt = "translate English to German: Lets eat good food." 
+ >>> inputs = tokenizer(input_prompt, return_tensors="pt") + + >>> outputs = model.greedy_search(input_ids) + + >>> results = [tokenizer.decode(t, skip_special_tokens=True) for t in outputs] + ``` + """ + # init values + if logits_processor is not None and is_traced_inference: + logger.warning( + "`logits_processor` will not be neglected because in `optimum-neuron`, `next_tokens` is computed inside the compiled decoder. If you want us to support custom logits_processor during the compilation, please file an issue to https://github.com/huggingface/optimum-neuron." + ) + elif logits_processor is None: + logits_processor = LogitsProcessorList() + use_cache = model_kwargs.pop("use_cache", False) + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) + + this_peer_finished = False # used by synced_gpus only + while True: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? 
the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + break + + # prepare model inputs + if use_cache: + # From max_length-sized input_ids, select first + # seq_length - 1 values. + + if model_kwargs.get("past_key_values") is None: + input_ids_ = input_ids[:, :seq_length] + else: + update_indices = torch.stack( + [torch.arange(input_ids.size(0)), torch.tensor(seq_length - 1).repeat(input_ids.size(0))], + dim=-1, + ) + input_ids_ = input_ids[update_indices[:, 0], update_indices[:, 1], None] + + model_inputs = self.prepare_inputs_for_generation(input_ids_, **model_kwargs) + else: + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + if not is_traced_inference: + if not use_cache: + one_hot = ( + torch.cat( + [ + torch.tensor([0]).repeat(1, seq_length - 1), + torch.tensor([1]).repeat(1, 1), + torch.tensor([0]).repeat(1, input_ids.size(1) - seq_length), + ], + dim=1, + ) + .to(device=outputs.logits.device) + .float() + ) + next_token_logits = torch.matmul(one_hot, outputs.logits) + next_token_logits = next_token_logits.squeeze(1) + else: + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + # Move to cpu to handle arbitrary logits_processor + next_tokens_scores = logits_processor(input_ids.to("cpu")[:, :seq_length], next_token_logits.to("cpu")) + next_tokens_scores = next_tokens_scores.to(input_ids.device) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_tokens_scores,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # argmax + next_tokens = torch.argmax(next_tokens_scores, dim=-1) + else: + next_tokens = outputs[0] + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + # update generated ids, model inputs, and length for next step + batch_size, _ = input_ids.shape + update_indices = torch.stack( + [torch.arange(batch_size), torch.tensor(seq_length).repeat(batch_size)], dim=-1 + ) + input_ids[update_indices[:, 0], update_indices[:, 1]] = next_tokens[:] + model_kwargs = self._update_model_kwargs_for_xla_generation( + outputs=outputs, + model_kwargs=model_kwargs, + batch_size=batch_size, + is_encoder_decoder=self.config.is_encoder_decoder, + max_length=stopping_criteria.max_length, + seq_length=seq_length, + use_cache=use_cache, + ) + + seq_length += 1 + + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id_tensor is not None: + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + ) + + if not 
is_traced_inference: + xm.mark_step() + + # stop when each sentence is finished, or if we exceed the maximum length + stop_criterion_1 = unfinished_sequences.max() == 0 + + if isinstance(stopping_criteria, list): + if len(stopping_criteria) == 1: + stopping_criteria = stopping_criteria[0] + + # Cases that can be handled in XLA without requiring + # non-padded input_ids + if isinstance(stopping_criteria, MaxLengthCriteria): + stop_criterion_2 = seq_length >= stopping_criteria.max_length + elif isinstance(stopping_criteria, MaxTimeCriteria): + stop_criterion_2 = stopping_criteria(input_ids, scores) + else: + # Other cases will be handled on CPU + batch_size, _ = input_ids.shape + mask = torch.cat( + [torch.ones(batch_size, seq_length), torch.zeros(batch_size, input_ids.shape[1] - seq_length)], + dim=1, + ).bool() + input_ids_cpu = torch.masked_select(input_ids, mask).reshape((batch_size, seq_length)).to("cpu") + scores_cpu = scores.to("cpu") if torch.is_tensor(scores) else scores + stop_criterion_2 = stopping_criteria(input_ids_cpu, scores_cpu) + + if stop_criterion_1 or stop_criterion_2: + this_peer_finished = True + + if this_peer_finished and not synced_gpus: + break + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return GreedySearchEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return GreedySearchDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return input_ids - def greedy_search( + def beam_search( self, input_ids: torch.LongTensor, + beam_scorer: BeamScorer, logits_processor: Optional[LogitsProcessorList] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, max_length: Optional[int] = None, @@ -1079,34 +970,35 @@ def greedy_search( output_hidden_states: Optional[bool] = None, output_scores: Optional[bool] = None, return_dict_in_generate: Optional[bool] = None, - synced_gpus: bool = False, + synced_gpus: Optional[bool] = False, seq_length: Optional[int] = None, - streamer: Optional["BaseStreamer"] = None, + is_traced_inference: bool = False, **model_kwargs, - ) -> Union[GreedySearchOutput, torch.LongTensor]: + ) -> Union[BeamSearchOutput, torch.LongTensor]: r""" - Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be - used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + Generates sequences of token ids for models with a language modeling head using **beam search decoding** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. - In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate() + In most cases, you do not need to call [`~generation.GenerationMixin.beam_search`] directly. Use generate() instead. For an overview of generation strategies and code examples, check the [following guide](../generation_strategies). - Parameters: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. + beam_scorer (`BeamScorer`): + An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and + sorted during generation. 
For more information, the documentation of [`BeamScorer`] should be read. logits_processor (`LogitsProcessorList`, *optional*): An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] used to modify the prediction scores of the language modeling head applied at each generation step. stopping_criteria (`StoppingCriteriaList`, *optional*): An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] used to tell if the generation loop should stop. - max_length (`int`, *optional*, defaults to 20): **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated tokens. The maximum length of the sequence to be generated. @@ -1126,75 +1018,74 @@ def greedy_search( Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. synced_gpus (`bool`, *optional*, defaults to `False`): Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - seq_length: + seq_length (`Optional[int]`, defaults to `False`): Length of current input_ids sequence - streamer (`BaseStreamer`, *optional*): - Streamer object that will be used to stream the generated sequences. Generated tokens are passed - through `streamer.put(token_ids)` and the streamer is responsible for any further processing. - Unsupported for XLA devices + is_traced_inference (`bool`, defaults to `False`): + Whether the decoder is traced or using XLA lazy tensor. If the decoder is traced, next tokens and the beam scores + are computed inside the decoder. model_kwargs: - Additional model specific keyword arguments will be forwarded to the `forward` function of the model. - If model is an encoder-decoder model the kwargs should include `encoder_outputs`. + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. Return: - [`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or + [`generation.BeamSearchDecoderOnlyOutput`], [`~generation.BeamSearchEncoderDecoderOutput`] or `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a - [`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and - `return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if + [`~generation.BeamSearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.BeamSearchEncoderDecoderOutput`] if `model.config.is_encoder_decoder=True`. + Examples: ```python - >>> from transformers import ( - ... AutoTokenizer, - ... AutoModelForCausalLM, - ... LogitsProcessorList, - ... MinLengthLogitsProcessor, - ... StoppingCriteriaList, - ... MaxLengthCriteria, - ... 
) - - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") - >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + >>> from transformers import AutoTokenizer + >>> from optimum.neuron import NeuronModelForSeq2SeqLM - >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token - >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id + >>> tokenizer = AutoTokenizer.from_pretrained("t5-small") + >>> input_shapes = {"batch_size": 1, "sequence_length": 128, "num_beams": 4} + >>> model = NeuronModelForSeq2SeqLM.from_pretrained("t5-small", export=True, dynamic_batch_size=False, **input_shapes) - >>> input_prompt = "It might be possible to" - >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids - - >>> # instantiate logits processors - >>> logits_processor = LogitsProcessorList( - ... [ - ... MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id), - ... ] - ... ) - >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) + >>> input_prompt = "translate English to German: Lets eat good food." + >>> inputs = tokenizer(input_prompt, return_tensors="pt") - >>> outputs = model.greedy_search( - ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria + >>> # add encoder_outputs to model keyword arguments + >>> model_kwargs = { + ... "encoder_outputs": model.get_encoder()( + ... encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True + ... ) + ... } + >>> # instantiate beam scorer + >>> beam_scorer = BeamSearchScorer( + ... batch_size=1, + ... num_beams=num_beams, + ... device=model.device, ... ) + >>> outputs = model.beam_search(input_ids, beam_scorer) >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) - ["It might be possible to get a better understanding of the nature of the problem, but it's not"] - ```""" + ``` + """ # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - use_cache = model_kwargs["use_cache"] if "use_cache" in model_kwargs else False + if logits_processor is not None and is_traced_inference: + logger.warning( + "`logits_processor` will not be neglected because in `optimum-neuron`, `next_tokens` is computed inside the compiled decoder. If you want us to support custom logits_processor during the compilation, please file an issue to https://github.com/huggingface/optimum-neuron." 
+ ) + elif logits_processor is None: + logits_processor = LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", + " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", UserWarning, ) stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + if len(stopping_criteria) == 0: + warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning) pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None output_scores = output_scores if output_scores is not None else self.generation_config.output_scores output_attentions = ( output_attentions if output_attentions is not None else self.generation_config.output_attentions @@ -1208,8 +1099,24 @@ def greedy_search( else self.generation_config.return_dict_in_generate ) + batch_size = len(beam_scorer._beam_hyps) + num_beams = beam_scorer.num_beams + + batch_beam_size, cur_len = input_ids.shape + + # Overwrite cur_len + cur_len = seq_length + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None + beam_indices = ( + tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None + ) decoder_attentions = () if (return_dict_in_generate and output_attentions) else None cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None @@ -1221,8 +1128,13 @@ def greedy_search( model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None ) - # keep track of which sequences are already finished - unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) + # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens + # of the first beam are considered to avoid sampling the exact same tokens across all beams. + # beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) + beam_scores_device = "cpu" + beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=beam_scores_device) + beam_scores[:, 1:] = -1e9 + beam_scores = beam_scores.view((batch_size * num_beams,)) this_peer_finished = False # used by synced_gpus only while True: @@ -1237,113 +1149,153 @@ def greedy_search( break # prepare model inputs - if use_cache: - # From max_length-sized input_ids, select first - # seq_length - 1 values. 
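For reference, the `update_indices` pattern used throughout this mixin (stacking `arange(batch)` with a repeated column index and indexing with it) is effectively a static-shape way of selecting a single position. A tiny sketch with made-up values, not part of the patch:

```python
import torch

batch, max_length, seq_length = 3, 10, 4
input_ids = torch.arange(batch * max_length).reshape(batch, max_length)

update_indices = torch.stack(
    [torch.arange(batch), torch.tensor(seq_length - 1).repeat(batch)], dim=-1
)
# Select column seq_length - 1 for every row, keeping a trailing dimension of 1.
input_ids_ = input_ids[update_indices[:, 0], update_indices[:, 1], None]

assert torch.equal(input_ids_, input_ids[:, seq_length - 1 : seq_length])
```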
-
-                if model_kwargs.get("past_key_values") is None:
-                    input_ids_ = input_ids[:, :seq_length]
-                else:
-                    update_indices = torch.stack(
-                        [torch.arange(input_ids.size(0)), torch.tensor(seq_length - 1).repeat(input_ids.size(0))],
-                        dim=-1,
-                    )
-                    input_ids_ = input_ids[update_indices[:, 0], update_indices[:, 1], None]
+            if model_kwargs["use_cache"]:
+                # From max_length-sized input_ids, select first
+                # cur_len - 1 values.
+                update_indices = torch.stack(
+                    [torch.arange(input_ids.size(0)), torch.tensor(cur_len - 1).repeat(input_ids.size(0))], dim=-1
+                )
+                input_ids_ = input_ids[update_indices[:, 0], update_indices[:, 1], None]
                 model_inputs = self.prepare_inputs_for_generation(input_ids_, **model_kwargs)
             else:
                 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
-            # forward pass to get next token
-            outputs = self(
-                **model_inputs,
-                return_dict=True,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-            )
-
-            if synced_gpus and this_peer_finished:
-                continue  # don't waste resources running the code we don't need
+            if is_traced_inference:
+                next_token_scores, next_tokens, next_indices = self(**model_inputs, beam_scores=beam_scores)
+            else:
+                outputs = self(
+                    **model_inputs,
+                    return_dict=True,
+                    output_attentions=output_attentions,
+                    output_hidden_states=output_hidden_states,
+                )
 
-            if not use_cache:
-                one_hot = (
-                    torch.cat(
-                        [
-                            torch.tensor([0]).repeat(1, seq_length - 1),
-                            torch.tensor([1]).repeat(1, 1),
-                            torch.tensor([0]).repeat(1, input_ids.size(1) - seq_length),
-                        ],
-                        dim=1,
+                if synced_gpus and this_peer_finished:
+                    cur_len = cur_len + 1
+                    continue  # don't waste resources running the code we don't need
+
+                if not model_kwargs["use_cache"]:
+                    one_hot = (
+                        torch.cat(
+                            [
+                                torch.tensor([0]).repeat(1, cur_len - 1),
+                                torch.tensor([1]).repeat(1, 1),
+                                torch.tensor([0]).repeat(1, input_ids.size(1) - cur_len),
+                            ],
+                            dim=1,
+                        )
+                        .to(device=outputs.logits.device)
+                        .float()
                     )
-                    .to(device=outputs.logits.device)
-                    .float()
+                    next_token_logits = torch.matmul(one_hot, outputs.logits)
+                    next_token_logits = next_token_logits.squeeze(1)
+                else:
+                    next_token_logits = outputs.logits[:, -1, :]
+
+                # Manually compute log softmax
+                # log_softmax(vi) = vi - max(vi) - log(sum(exp(vi - max(vi))))
+                logit_max, _ = torch.max(next_token_logits, dim=-1, keepdim=True)
+                logsumexp = torch.log(torch.exp(next_token_logits - logit_max).sum(dim=-1, keepdim=True))
+                next_token_scores = next_token_logits - logit_max - logsumexp
+                # (batch_size * num_beams, vocab_size)
+
+                xm.mark_step()
+
+                # We don't want to change every single logit processor, so
+                # we perform this processing on CPU.
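The manual computation above follows the standard stabilised identity `log_softmax(v) = v - max(v) - log(sum(exp(v - max(v))))`. A quick standalone check against `torch.nn.functional.log_softmax` (illustrative sizes, not part of the patch):

```python
import torch
import torch.nn.functional as F

next_token_logits = torch.randn(6, 32)  # e.g. (batch_size * num_beams, vocab_size)

# Manual, numerically stabilised log-softmax, as in the mixin above.
logit_max, _ = torch.max(next_token_logits, dim=-1, keepdim=True)
logsumexp = torch.log(torch.exp(next_token_logits - logit_max).sum(dim=-1, keepdim=True))
manual = next_token_logits - logit_max - logsumexp

assert torch.allclose(manual, F.log_softmax(next_token_logits, dim=-1), atol=1e-6)
```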
+ input_ids_ = input_ids.to("cpu")[:, :cur_len] + next_token_scores_ = next_token_scores.to("cpu") + next_token_scores_processed = logits_processor(input_ids_, next_token_scores_) + + next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores_processed,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # reshape for beam search + vocab_size = next_token_scores.shape[-1] + next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) + next_token_scores = next_token_scores * 1 + + # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search) + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True ) - next_token_logits = torch.matmul(one_hot, outputs.logits) - next_token_logits = next_token_logits.squeeze(1) - else: - next_token_logits = outputs.logits[:, -1, :] - - # pre-process distribution - # Move to cpu to handle arbitrary logits_processor - next_tokens_scores = logits_processor(input_ids.to("cpu")[:, :seq_length], next_token_logits.to("cpu")) - next_tokens_scores = next_tokens_scores.to(input_ids.device) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += (next_tokens_scores,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - # argmax - next_tokens = torch.argmax(next_tokens_scores, dim=-1) + next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") + next_tokens = next_tokens % vocab_size - # finished sentences should have their next token be a padding token - if eos_token_id is not None: - if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + # stateless + beam_outputs = beam_scorer.process( + input_ids.to("cpu")[:, :cur_len], + next_token_scores.to("cpu"), + next_tokens.to("cpu"), + next_indices.to("cpu"), + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=beam_indices, + ) + + beam_scores = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] - # update generated ids, model inputs, and length for next step - batch_size, _ = input_ids.shape update_indices = torch.stack( - [torch.arange(batch_size), torch.tensor(seq_length).repeat(batch_size)], dim=-1 + [torch.arange(batch_beam_size), torch.tensor(cur_len - 1).repeat(batch_beam_size)], dim=-1 ) - input_ids[update_indices[:, 0], update_indices[:, 1]] = next_tokens[:] + update_indices_2 = torch.stack( + 
[torch.arange(batch_beam_size), torch.tensor(cur_len).repeat(batch_beam_size)], dim=-1 + ) + # First select beam_indices + device = input_ids.device + beam_idx_device = beam_idx.to(device=input_ids.device) + input_ids[:, :] = input_ids[beam_idx_device.long(), :] + + # Then append new tokens + input_ids[update_indices_2[:, 0], update_indices_2[:, 1], None] = beam_next_tokens.unsqueeze(-1).to(device) + input_ids = input_ids * 1 # Hack to materialize tensor + + # update generated ids, model inputs, and length for next step model_kwargs = self._update_model_kwargs_for_xla_generation( outputs, model_kwargs, - batch_size=batch_size, + batch_size=batch_beam_size, is_encoder_decoder=self.config.is_encoder_decoder, max_length=stopping_criteria.max_length, - seq_length=seq_length, - use_cache=use_cache, + seq_length=cur_len, + use_cache=model_kwargs["use_cache"], ) + if is_traced_inference: + self._reorder_cache(beam_idx.to(torch.int64)) + elif model_kwargs["past_key_values"] is not None: + model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx) - seq_length += 1 - - # if eos_token was found in one sentence, set sentence to finished - if eos_token_id_tensor is not None: - unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) - ) + if return_dict_in_generate and output_scores: + beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) - xm.mark_step() + # increase cur_len + cur_len = cur_len + 1 # stop when each sentence is finished, or if we exceed the maximum length - stop_criterion_1 = unfinished_sequences.max() == 0 - + stop_criterion_1 = beam_scorer.is_done if isinstance(stopping_criteria, list): if len(stopping_criteria) == 1: stopping_criteria = stopping_criteria[0] @@ -1351,34 +1303,51 @@ def greedy_search( # Cases that can be handled in XLA without requiring # non-padded input_ids if isinstance(stopping_criteria, MaxLengthCriteria): - stop_criterion_2 = seq_length >= stopping_criteria.max_length + stop_criterion_2 = cur_len >= stopping_criteria.max_length elif isinstance(stopping_criteria, MaxTimeCriteria): stop_criterion_2 = stopping_criteria(input_ids, scores) else: # Other cases will be handled on CPU batch_size, _ = input_ids.shape + input_ids_cpu = input_ids.to("cpu") mask = torch.cat( - [torch.ones(batch_size, seq_length), torch.zeros(batch_size, input_ids.shape[1] - seq_length)], - dim=1, + [torch.ones(batch_size, cur_len), torch.zeros(batch_size, input_ids.shape[1] - cur_len)], dim=1 ).bool() - input_ids_cpu = torch.masked_select(input_ids, mask).reshape((batch_size, seq_length)).to("cpu") + input_ids_cpu = torch.masked_select(input_ids_cpu, mask).reshape((batch_size, cur_len)) scores_cpu = scores.to("cpu") if torch.is_tensor(scores) else scores stop_criterion_2 = stopping_criteria(input_ids_cpu, scores_cpu) if stop_criterion_1 or stop_criterion_2: - this_peer_finished = True + if not synced_gpus: + break + else: + this_peer_finished = True - if this_peer_finished and not synced_gpus: - break + sequence_outputs = beam_scorer.finalize( + input_ids.to("cpu"), + beam_scores.to("cpu"), + next_tokens.to("cpu"), + next_indices.to("cpu"), + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + max_length=stopping_criteria.max_length, + beam_indices=beam_indices, + ) - if streamer is not None: - streamer.end() + for k, v in sequence_outputs.items(): + if type(v) == torch.Tensor: + sequence_outputs[k] = 
sequence_outputs[k].to(input_ids.device) if return_dict_in_generate: + if not output_scores: + sequence_outputs["sequence_scores"] = None + if self.config.is_encoder_decoder: - return GreedySearchEncoderDecoderOutput( - sequences=input_ids, + return BeamSearchEncoderDecoderOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], scores=scores, + beam_indices=sequence_outputs["beam_indices"], encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, @@ -1386,11 +1355,13 @@ def greedy_search( decoder_hidden_states=decoder_hidden_states, ) else: - return GreedySearchDecoderOnlyOutput( - sequences=input_ids, + return BeamSearchDecoderOnlyOutput( + sequences=sequence_outputs["sequences"], + sequences_scores=sequence_outputs["sequence_scores"], scores=scores, + beam_indices=sequence_outputs["beam_indices"], attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) else: - return input_ids + return sequence_outputs["sequences"] diff --git a/optimum/neuron/modeling_seq2seq.py b/optimum/neuron/modeling_seq2seq.py index 23396002e..b52e7e863 100644 --- a/optimum/neuron/modeling_seq2seq.py +++ b/optimum/neuron/modeling_seq2seq.py @@ -26,19 +26,12 @@ import torch from huggingface_hub import snapshot_download from transformers import AutoConfig, AutoModelForSeq2SeqLM, GenerationConfig -from transformers.generation.beam_search import BeamScorer from transformers.generation.logits_process import ( LogitsProcessorList, ) from transformers.generation.stopping_criteria import ( - MaxLengthCriteria, - MaxTimeCriteria, StoppingCriteriaList, ) -from transformers.generation.utils import ( - BeamSearchOutput, - GreedySearchOutput, -) from transformers.modeling_outputs import Seq2SeqLMOutput from ..exporters.neuron import ( @@ -59,6 +52,7 @@ if TYPE_CHECKING: from transformers import PretrainedConfig, PreTrainedModel + from transformers.utils import ModelOutput if is_neuronx_available(): import torch_neuronx @@ -357,6 +351,10 @@ def forward( decoder_attention_mask: Optional[torch.BoolTensor] = None, encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, beam_scores=None, + # Leave following kwargs for compatibility, will not have any effect. 
+ return_dict: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: hidden_states = encoder_outputs["last_hidden_state"] @@ -418,290 +416,14 @@ def generate( max_length=kwargs.pop("max_length", None) or max_length, num_beams=num_beams, do_sample=kwargs.pop("do_sample", False), - use_cache=kwargs.pop( - "use_cache", False - ), # `use_cache` is supported by default in `optimum-neuron`, set to False to avoid warning + use_cache=True, # pkv is cached by default decoder_attention_mask=decoder_attention_mask, # Pass fake encoder_outputs so the transfomers code will not invoke the encoder encoder_outputs={"last_hidden_state": torch.ones((batch_size, max_length, 1))}, + is_traced_inference=True, ) return output - def beam_search( - self, - input_ids: torch.LongTensor, - beam_scorer: "BeamScorer", - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: Optional[bool] = False, - seq_length: Optional[int] = None, - **model_kwargs, - ) -> Union[BeamSearchOutput, torch.LongTensor]: - """ - Overriding beam search to use next_token_scores returned from neuron device instead of logits. - """ - if logits_processor is not None: - logger.warning( - "`logits_processor` will be neglected because in `optimum-neuron`, `next_tokens` is computed inside the compiled decoder. If you want us to support custom logits_processor during the compilation, please file an issue to https://github.com/huggingface/optimum-neuron." - ) - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - - batch_size = len(beam_scorer._beam_hyps) - num_beams = beam_scorer.num_beams - - batch_beam_size, cur_len = input_ids.shape - - # Overwrite cur_len - cur_len = seq_length - - if num_beams * batch_size != batch_beam_size: - raise ValueError( - f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." - ) - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - beam_indices = ( - tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None - ) - - # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens - # of the first beam are considered to avoid sampling the exact same tokens across all beams. - beam_scores_device = "cpu" - beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=beam_scores_device) - beam_scores[:, 1:] = -1e9 - beam_scores = beam_scores.view((batch_size * num_beams,)) - - while True: - # prepare model inputs - # From max_length-sized input_ids, select first - # cur_len - 1 values. 
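# --- Editor's note (illustration only, not part of the patch): a minimal,
# self-contained sketch of the static-shape token gather performed at this
# step. The `input_ids` buffer is pre-allocated to `max_length`, and each
# decoding step only reads the token at position `cur_len - 1` through
# (row, column) index pairs, so tensor shapes stay constant across steps
# and the traced/compiled graph can be reused. Sizes below are arbitrary.
import torch

batch_beam_size, max_length, cur_len = 6, 12, 4
input_ids = torch.zeros((batch_beam_size, max_length), dtype=torch.long)
input_ids[:, :cur_len] = torch.randint(5, 100, (batch_beam_size, cur_len))

# one (row, column) pair per sequence, all pointing at the last real token
update_indices = torch.stack(
    [torch.arange(batch_beam_size), torch.tensor(cur_len - 1).repeat(batch_beam_size)], dim=-1
)
current_tokens = input_ids[update_indices[:, 0], update_indices[:, 1], None]
assert current_tokens.shape == (batch_beam_size, 1)  # fixed shape at every step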
- update_indices = torch.stack( - [torch.arange(input_ids.size(0)), torch.tensor(cur_len - 1).repeat(input_ids.size(0))], dim=-1 - ) - input_ids_ = input_ids[update_indices[:, 0], update_indices[:, 1], None] - model_inputs = self.prepare_inputs_for_generation(input_ids_, **model_kwargs) - - next_token_scores, next_tokens, next_indices = self(**model_inputs, beam_scores=beam_scores) - - # stateless - beam_outputs = beam_scorer.process( - input_ids.to("cpu")[:, :cur_len], - next_token_scores.to("cpu"), - next_tokens.to("cpu"), - next_indices.to("cpu"), - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - beam_indices=beam_indices, - ) - - beam_scores = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - - update_indices = torch.stack( - [torch.arange(batch_beam_size), torch.tensor(cur_len - 1).repeat(batch_beam_size)], dim=-1 - ) - update_indices_2 = torch.stack( - [torch.arange(batch_beam_size), torch.tensor(cur_len).repeat(batch_beam_size)], dim=-1 - ) - # First select beam_indices - device = input_ids.device - beam_idx_device = beam_idx.to(device=input_ids.device) - input_ids[:, :] = input_ids[beam_idx_device.long(), :] - - # Then append new tokens - input_ids[update_indices_2[:, 0], update_indices_2[:, 1], None] = ( - beam_next_tokens.unsqueeze(-1).to(device).to(torch.long) - ) - input_ids = input_ids * 1 # Hack to materialize tensor - - # update generated ids, model inputs, and length for next step - model_kwargs = self._update_model_kwargs_for_xla_generation( - model_kwargs, - batch_size=batch_beam_size, - is_encoder_decoder=self.config.is_encoder_decoder, - ) - self._reorder_cache(beam_idx.to(torch.int64)) - - if return_dict_in_generate and output_scores: - beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) - - # increase cur_len - cur_len = cur_len + 1 - - # stop when each sentence is finished, or if we exceed the maximum length - stop_criterion_1 = beam_scorer.is_done - if isinstance(stopping_criteria, list): - if len(stopping_criteria) == 1: - stopping_criteria = stopping_criteria[0] - - # Cases that can be handled in XLA without requiring - # non-padded input_ids - if isinstance(stopping_criteria, MaxLengthCriteria): - stop_criterion_2 = cur_len >= stopping_criteria.max_length - elif isinstance(stopping_criteria, MaxTimeCriteria): - stop_criterion_2 = stopping_criteria(input_ids, scores) - else: - # Other cases will be handled on CPU - batch_size, _ = input_ids.shape - input_ids_cpu = input_ids.to("cpu") - mask = torch.cat( - [torch.ones(batch_size, cur_len), torch.zeros(batch_size, input_ids.shape[1] - cur_len)], dim=1 - ).bool() - input_ids_cpu = torch.masked_select(input_ids_cpu, mask).reshape((batch_size, cur_len)) - scores_cpu = scores.to("cpu") if torch.is_tensor(scores) else scores - stop_criterion_2 = stopping_criteria(input_ids_cpu, scores_cpu) - - if stop_criterion_1 or stop_criterion_2: - if not synced_gpus: - break - - sequence_outputs = beam_scorer.finalize( - input_ids.to("cpu"), - beam_scores.to("cpu"), - next_tokens.to("cpu"), - next_indices.to("cpu"), - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - max_length=stopping_criteria.max_length, - beam_indices=beam_indices, - ) - - for k, v in sequence_outputs.items(): - if type(v) == torch.Tensor: - sequence_outputs[k] = sequence_outputs[k].to(input_ids.device) - - return sequence_outputs["sequences"] - - def greedy_search( - self, - input_ids: torch.LongTensor, - 
logits_processor: Optional["LogitsProcessorList"] = None, - stopping_criteria: Optional["StoppingCriteriaList"] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - seq_length: Optional[int] = int, - **model_kwargs, - ) -> Union[GreedySearchOutput, torch.LongTensor]: - """ - Overriding greedy sampling to use next tokens returned from neuron device instead of logits. - """ - # init values - if logits_processor is not None: - logger.warning( - "`logits_processor` will not be neglected because in `optimum-neuron`, `next_tokens` is computed inside the compiled decoder. If you want us to support custom logits_processor during the compilation, please file an issue to https://github.com/huggingface/optimum-neuron." - ) - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - if max_length is not None: - from transformers.generation.stopping_criteria import validate_stopping_criteria - - stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id - eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - - # keep track of which sequences are already finished - unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) - - this_peer_finished = False # used by synced_gpus only - while True: - # prepare model inputs - # From max_length-sized input_ids, select first - # seq_length - 1 values. 
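# --- Editor's note (illustration only, not part of the patch): a hypothetical
# sketch of the prefill/decode split applied at this step. On the first call
# there is no KV cache yet, so the whole prompt (the first `seq_length`
# columns of the padded buffer) is fed; on later calls only the single token
# written at the previous step is fed, keeping per-step shapes static. The
# helper name `select_step_inputs` is ours, chosen for the example.
import torch

def select_step_inputs(input_ids: torch.Tensor, seq_length: int, has_cache: bool) -> torch.Tensor:
    batch_size = input_ids.size(0)
    if not has_cache:
        return input_ids[:, :seq_length]                    # prefill: full prompt
    rows = torch.arange(batch_size)
    cols = torch.tensor(seq_length - 1).repeat(batch_size)  # last generated position
    return input_ids[rows, cols, None]                      # decode: one token per row

padded = torch.zeros((2, 8), dtype=torch.long)
padded[:, :3] = torch.tensor([[7, 8, 9], [4, 5, 6]])
assert select_step_inputs(padded, 3, has_cache=False).shape == (2, 3)
assert select_step_inputs(padded, 3, has_cache=True).shape == (2, 1)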
- - if model_kwargs.get("past_key_values") is None: - input_ids_ = input_ids[:, :seq_length] - else: - update_indices = torch.stack( - [torch.arange(input_ids.size(0)), torch.tensor(seq_length - 1).repeat(input_ids.size(0))], - dim=-1, - ) - input_ids_ = input_ids[update_indices[:, 0], update_indices[:, 1], None] - - model_inputs = self.prepare_inputs_for_generation(input_ids_, **model_kwargs) - - # forward pass to get next token - output = self(**model_inputs) - next_tokens = output[0] - - # finished sentences should have their next token be a padding token - if eos_token_id is not None: - if pad_token_id is None: - raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) - - # update generated ids, model inputs, and length for next step - - batch_size, _ = input_ids.shape - update_indices = torch.stack( - [torch.arange(batch_size), torch.tensor(seq_length).repeat(batch_size)], dim=-1 - ) - input_ids[update_indices[:, 0], update_indices[:, 1]] = next_tokens[:] - model_kwargs = self._update_model_kwargs_for_xla_generation( - model_kwargs, - batch_size=batch_size, - is_encoder_decoder=self.config.is_encoder_decoder, - ) - - seq_length += 1 - - # if eos_token was found in one sentence, set sentence to finished - if eos_token_id_tensor is not None: - unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) - ) - - # stop when each sentence is finished, or if we exceed the maximum length - stop_criterion_1 = unfinished_sequences.max() == 0 - - if isinstance(stopping_criteria, list): - if len(stopping_criteria) == 1: - stopping_criteria = stopping_criteria[0] - - # Cases that can be handled in XLA without requiring - # non-padded input_ids - if isinstance(stopping_criteria, MaxLengthCriteria): - stop_criterion_2 = seq_length >= stopping_criteria.max_length - elif isinstance(stopping_criteria, MaxTimeCriteria): - stop_criterion_2 = stopping_criteria(input_ids, scores) - else: - # Other cases will be handled on CPU - batch_size, _ = input_ids.shape - mask = torch.cat( - [torch.ones(batch_size, seq_length), torch.zeros(batch_size, input_ids.shape[1] - seq_length)], - dim=1, - ).bool() - input_ids_cpu = torch.masked_select(input_ids, mask).reshape((batch_size, seq_length)).to("cpu") - scores_cpu = scores.to("cpu") if torch.is_tensor(scores) else scores - stop_criterion_2 = stopping_criteria(input_ids_cpu, scores_cpu) - - if stop_criterion_1 or stop_criterion_2: - this_peer_finished = True - - if this_peer_finished: - break - - return input_ids - def _reorder_cache(self, beam_idx): """ The cache was reordered during the tracing of the decoder so we can skip it here. This is needed for beam search and not greedy sampling. @@ -716,6 +438,12 @@ def _update_model_kwargs_for_xla_generation( model_kwargs: Dict[str, Any], batch_size: int, is_encoder_decoder: bool = False, + # Leave following kwargs for compatibility, will not have any effect. + outputs: "ModelOutput" = None, + standardize_cache_format: bool = False, + max_length: Optional[int] = None, + seq_length: Optional[int] = None, + use_cache: bool = True, ) -> Dict[str, Any]: mask = self._update_attention(model_kwargs, batch_size, is_encoder_decoder) # sets the updated variables (mask and past_key_values)
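# --- Editor's note (illustration only, not part of the patch): a hypothetical
# sketch of the "accept and ignore" signature pattern used just above. The
# traced seq2seq override of `_update_model_kwargs_for_xla_generation` keeps
# extra keyword arguments (`outputs`, `standardize_cache_format`, `max_length`,
# `seq_length`, `use_cache`) purely for compatibility, so the shared
# NeuronGenerationMixin loop can call one signature on both the XLA and the
# traced back ends. The function name below is ours, chosen for the example.
from typing import Any, Dict, Optional

def update_kwargs_traced(
    model_kwargs: Dict[str, Any],
    batch_size: int,
    is_encoder_decoder: bool = False,
    # accepted for call-site compatibility, unused on the traced path
    outputs: Optional[Any] = None,
    standardize_cache_format: bool = False,
    max_length: Optional[int] = None,
    seq_length: Optional[int] = None,
    use_cache: bool = True,
) -> Dict[str, Any]:
    # stand-in for the real attention-mask / past_key_values bookkeeping
    model_kwargs["batch_size"] = batch_size
    return model_kwargs

# one call shape works regardless of which implementation is behind it
kwargs = update_kwargs_traced({}, batch_size=4, is_encoder_decoder=True, seq_length=7, use_cache=True)
assert "past_key_values" not in kwargs  # nothing XLA-specific was touched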
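# --- Editor's note (illustration only, not part of the patch): a minimal
# sketch of the beam-candidate selection used in the shared generation loop
# earlier in this patch. Per-beam scores are flattened to
# (batch_size, num_beams * vocab_size), the top 2 * num_beams candidates are
# kept so finished hypotheses can be replaced, and each flat index is split
# back into a beam index and a token id. Sizes below are arbitrary.
import torch

batch_size, num_beams, vocab_size = 2, 3, 11
next_token_scores = torch.randn(batch_size * num_beams, vocab_size)

flat_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
topk_scores, topk_ids = torch.topk(flat_scores, 2 * num_beams, dim=1, largest=True, sorted=True)

next_indices = torch.div(topk_ids, vocab_size, rounding_mode="floor")  # which beam each candidate came from
next_tokens = topk_ids % vocab_size                                    # which token to append

assert next_indices.shape == next_tokens.shape == (batch_size, 2 * num_beams)
assert int(next_indices.max()) < num_beams and int(next_tokens.max()) < vocab_size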