diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt index 6eec0a8d46e1e13d7b966316be5424c9e7c5fb2d..48d4f951d8df1b26739bc6261a82360045509f79 100644 --- a/.jenkins/check/config/whitelizard.txt +++ b/.jenkins/check/config/whitelizard.txt @@ -11,7 +11,6 @@ mindformers/mindformers/models/llama/llama.py:construct mindformers/mindformers/trainer/base_trainer.py:training_process mindformers/mindformers/trainer/base_trainer.py:predict_process mindformers/mindformers/core/callback/callback.py:print_output_info -mindformers/mindformers/generation/text_generator.py:_beam_search mindformers/mindformers/generation/text_generator.py:generate mindformers/mindformers/generation/text_generator.py:postprocess mindformers/mindformers/models/modeling_utils.py:from_pretrained_experimental_mode diff --git a/docs/api/api_python/generation/mindformers.generation.GenerationMixin.rst b/docs/api/api_python/generation/mindformers.generation.GenerationMixin.rst index fa43a11c5e82b974780ce1c39ede68c1f7175ca7..255b04f26314dc504faa6d47174d098afcf622d8 100644 --- a/docs/api/api_python/generation/mindformers.generation.GenerationMixin.rst +++ b/docs/api/api_python/generation/mindformers.generation.GenerationMixin.rst @@ -76,7 +76,7 @@ mindformers.generation.GenerationMixin - **eos_token_id** (int) - 句子结束的词元索引。如果设置为None,则遵循模型配置中的设置。 - **pad_token_id** (int) - 填充的词元索引。如果设置为None,则遵循模型配置中的设置。 - **repetition_penalty** (float) - 生成单词频率的惩罚因子。如果将其设置为1,则不启用 `repeat_penalty` 。如果将其设置为 ``None`` ,则遵循模型配置中的设置。默认值: ``None`` 。 - - **num_beams** (int) - 用于束搜寻的束的数量。1表示不使用束搜寻。如果大于1,则 `do_sample` 将被设置为 ``False`` 。 + - **num_beams** (int) - 用于束搜寻的束的数量。1表示不使用束搜寻。当前仅支持设置为1。此参数将在未来删除。 返回: 生成的一个词元索引列表。 diff --git a/mindformers/generation/beam_search.py b/mindformers/generation/beam_search.py deleted file mode 100644 index fa308fc191881dac35285cde218bdacc91579361..0000000000000000000000000000000000000000 --- a/mindformers/generation/beam_search.py +++ /dev/null @@ -1,411 +0,0 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# Copyright 2020 The HuggingFace Inc. team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Beam search for text generation.""" -from abc import ABC, abstractmethod -from collections import UserDict -from typing import Optional, Union, List - -import numpy as np - - -class BeamScorer(ABC): - """Abstract base class for all beam scorers""" - - @abstractmethod - def process(self, input_ids, next_scores, next_tokens, next_indices, pad_token_id, eos_token_id, beam_indices, - group_index): - r""" - Args: - input_ids: - Indices of input sequence tokens in the vocabulary. - next_scores: - Current scores of the top `2 * num_beams` non-finished beam hypotheses. - next_tokens: - `input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses. - next_indices: - Beam indices indicating to which beam hypothesis the `next_tokens` correspond. - pad_token_id: - The id of the *padding* token. 
- eos_token_id: - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - beam_indices: - Beam indices indicating to which beam hypothesis each token correspond. - group_index: - The index of the group of beams. - """ - raise NotImplementedError("This is an abstract method.") - - def finalize(self, input_ids, final_beam_scores, max_length, pad_token_id, eos_token_id, beam_indices): - r""" - Args: - input_ids: - Indices of input sequence tokens in the vocabulary. - final_beam_scores: - The final scores of all non-finished beams. - max_length: - The max_length of output ids. - pad_token_id: - The id of the *padding* token. - eos_token_id: - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - beam_indices: - Beam indices indicating to which beam hypothesis each token correspond. - """ - raise NotImplementedError("This is an abstract method.") - - -class BeamSearchScorer(BeamScorer): - r""" - [`BeamScorer`] implementing standard beam search decoding. - - Args: - batch_size (`int`): - Batch Size of `input_ids` for which standard beam search decoding is run in parallel. - num_beams (`int`): - Number of beams for beam search. - length_penalty (`float`, *optional*, defaults to 1.0): - Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to - the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log - likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while - `length_penalty` < 0.0 encourages shorter sequences. - do_early_stopping (`bool` or `str`, *optional*, defaults to `False`): - Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values: - `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an - heuristic is applied and the generation stops when is it very unlikely to find better candidates; - `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical - beam search algorithm). - num_beam_hyps_to_keep (`int`, *optional*, defaults to 1): - The number of beam hypotheses that shall be returned upon calling - [`~transformer.BeamSearchScorer.finalize`]. - num_beam_groups (`int`): - Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. - See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. - max_length (`int`, *optional*): - The maximum length of the sequence to be generated. - """ - - def __init__(self, - batch_size: int, - num_beams: int, - length_penalty: Optional[float] = 1.0, - do_early_stopping: Optional[Union[bool, str]] = False, - num_beam_hyps_to_keep: Optional[int] = 1, - num_beam_groups: Optional[int] = 1, - max_length: Optional[int] = None): - self.num_beams = num_beams - self.length_penalty = length_penalty - self.do_early_stopping = do_early_stopping - self.num_beam_hyps_to_keep = num_beam_hyps_to_keep - self.num_beam_groups = num_beam_groups - self.group_size = self.num_beams // self.num_beam_groups - - self._is_init = False - # self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch. - # If group_beam_search is not used, the list consists of `batch_size` beam_hyps. 
- self._beam_hyps = [ - BeamHypotheses( - num_beams=self.group_size, - length_penalty=self.length_penalty, - early_stopping=self.do_early_stopping, - max_length=max_length, - ) - for _ in range(batch_size * self.num_beam_groups) - ] - # self._done[i*self.num_beam_groups+j] indicates whether the generation of the beam_hyps of the j-th group - # in the i-th mini-batch is complete. - self._done = np.array([False for _ in range(batch_size * self.num_beam_groups)]) - - if not isinstance(num_beams, int) or num_beams <= 1: - raise ValueError( - f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1," - " one should make use of `greedy_search` instead." - ) - - if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0): - raise ValueError( - "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be" - f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}." - ) - - @property - def is_done(self) -> bool: - return self._done.all() - - def process(self, - input_ids, - next_scores, - next_tokens, - next_indices, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - beam_indices=None, - group_index: Optional[int] = 0): - batch_size = len(self._beam_hyps) // self.num_beam_groups - - if not batch_size == (input_ids.shape[0] // self.group_size): - if self.num_beam_groups > 1: - raise ValueError( - f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam " - f"size of {self.group_size} is expected by the beam scorer." - ) - raise ValueError( - f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of " - f"{self.group_size} is expected by the beam scorer." 
- ) - - next_beam_scores = np.zeros((batch_size, self.group_size), dtype=next_scores.dtype) - next_beam_tokens = np.zeros((batch_size, self.group_size), dtype=next_tokens.dtype) - next_beam_indices = np.zeros((batch_size, self.group_size), dtype=next_indices.dtype) - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - - for batch_idx in range(batch_size): - batch_group_idx = batch_idx * self.num_beam_groups + group_index - if self._done[batch_group_idx]: - if self.num_beams < len(self._beam_hyps[batch_group_idx]): - raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated") - if eos_token_id is None or pad_token_id is None: - raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined") - # pad the batch - next_beam_scores[batch_idx, :] = 0 - next_beam_tokens[batch_idx, :] = pad_token_id - next_beam_indices[batch_idx, :] = 0 - continue - - # next tokens for this sentence - beam_idx = 0 - cur_len = np.min(np.where(input_ids[beam_idx] == pad_token_id)) - for beam_token_rank, (next_token, next_score, next_index) in enumerate( - zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx]) - ): - batch_beam_idx = batch_idx * self.group_size + next_index - # add to generated hypotheses if end of sentence - if (eos_token_id is not None) and (next_token in eos_token_id): - # if beam_token does not belong to top num_beams tokens, it should not be added - is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size - if is_beam_token_worse_than_top_num_beams: - continue - if beam_indices is not None: - beam_index = beam_indices[batch_beam_idx] - beam_index = beam_index + (batch_beam_idx,) - else: - beam_index = None - - self._beam_hyps[batch_group_idx].add( - input_ids[batch_beam_idx].copy(), - next_score, - beam_indices=beam_index, - ) - else: - # add next predicted token since it is not eos_token - next_beam_scores[batch_idx, beam_idx] = next_score - next_beam_tokens[batch_idx, beam_idx] = next_token - next_beam_indices[batch_idx, beam_idx] = batch_beam_idx - beam_idx += 1 - - # once the beam for next step is full, don't add more tokens to it. - if beam_idx == self.group_size: - break - - if beam_idx < self.group_size: - raise ValueError( - f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:" - f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected." 
- ) - - # Check if we are done so that we can save a pad step if all(done) - self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done( - next_scores[batch_idx].max(), cur_len - ) - - return UserDict( - { - "next_beam_scores": next_beam_scores.flatten(), - "next_beam_tokens": next_beam_tokens.flatten(), - "next_beam_indices": next_beam_indices.flatten(), - } - ) - - def finalize(self, - input_ids, - final_beam_scores, - max_length: int, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - beam_indices=None): - batch_size = len(self._beam_hyps) // self.num_beam_groups - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - - # finalize all open beam hypotheses and add to generated hypotheses - for batch_group_idx, beam_hyp in enumerate(self._beam_hyps): - if self._done[batch_group_idx]: - continue - - # all open beam hypotheses are added to the beam hypothesis - # beam hypothesis class automatically keeps the best beams - for index_per_group in range(self.group_size): - batch_beam_idx = batch_group_idx * self.group_size + index_per_group - final_score = final_beam_scores[batch_beam_idx].item() - final_tokens = input_ids[batch_beam_idx] - beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None - beam_hyp.add(final_tokens, final_score, beam_indices=beam_index) - - # select the best hypotheses - sent_lengths = np.zeros(batch_size * self.num_beam_hyps_to_keep, dtype=np.int32) - best = [] - best_indices = [] - best_scores = np.zeros(batch_size * self.num_beam_hyps_to_keep, dtype=np.float32) - - # retrieve best hypotheses - for i in range(batch_size): - beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups: (i + 1) * self.num_beam_groups] - candidate_beams = [ - beam - for beam_hyp in beam_hyps_in_batch - for beam in beam_hyp.beams - ] - sorted_hyps = sorted(candidate_beams, key=lambda x: x[0]) - for j in range(self.num_beam_hyps_to_keep): - best_hyp_tuple = sorted_hyps.pop() - best_score = best_hyp_tuple[0] - best_hyp = best_hyp_tuple[1] - best_index = best_hyp_tuple[2] - sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp) - - # append hyp to lists - best.append(best_hyp) - - # append indices to list - best_indices.append(best_index) - - best_scores[i * self.num_beam_hyps_to_keep + j] = best_score - - # prepare for adding eos - sent_lengths_max = max(sent_lengths) + 1 - sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max - decoded = np.zeros((batch_size * self.num_beam_hyps_to_keep, sent_max_len), dtype=np.int32) - - if best_indices and best_indices[0] is not None: - indices = np.zeros((batch_size * self.num_beam_hyps_to_keep, sent_max_len), dtype=np.int32) - else: - indices = None - - # shorter batches are padded if needed - if min(sent_lengths) != max(sent_lengths): - if pad_token_id is None: - raise ValueError("`pad_token_id` has to be defined") - decoded.fill(pad_token_id) - - if indices is not None: - indices.fill(-1) - - # fill with hypotheses and eos_token_id if the latter fits in - for i, (hypo, best_idx) in enumerate(zip(best, best_indices)): - sent_length = min(decoded.shape[-1], sent_lengths[i]) - decoded[i, : sent_length] = hypo[:sent_length] - - if indices is not None: - indices[i, : len(best_idx)] = best_idx - - if sent_lengths[i] < sent_max_len: - # inserting only the first eos_token_id - decoded[i, sent_lengths[i]] = eos_token_id[0] - - return UserDict( - { - "sequences": decoded, - 
"sequence_scores": best_scores, - "beam_indices": indices, - } - ) - - -class BeamHypotheses: - """ - Beam hypotheses maintaining n-best list - """ - - def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None): - """ - Initialize n-best list of hypotheses. - """ - self.length_penalty = length_penalty - self.early_stopping = early_stopping - self.max_length = max_length - self.num_beams = num_beams - self.beams = [] - self.worst_score = 1e9 - - if not isinstance(self.early_stopping, bool) and self.max_length is None: - raise ValueError( - "When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the" - " BeamScorer class instance at initialization time." - ) - - def __len__(self): - """ - Number of hypotheses in the list. - """ - return len(self.beams) - - def add(self, hyp, sum_logprobs: float, beam_indices=None): - """ - Add a new hypothesis to the list. - """ - score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty) - if len(self) < self.num_beams or score > self.worst_score: - self.beams.append((score, hyp, beam_indices)) - if len(self) > self.num_beams: - sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)]) - del self.beams[sorted_next_scores[0][1]] - self.worst_score = sorted_next_scores[1][0] - else: - self.worst_score = min(score, self.worst_score) - - def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool: - """ - If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst - one in the heap, then we are done with this sentence. - """ - - if len(self) < self.num_beams: - return False - - # `True`: stop as soon as at least `num_beams` hypotheses are finished - if self.early_stopping is True: - return True - - # `False`: heuristic compute the best possible score from `cur_len`, even though it is not entirely accurate - # when `length_penalty` is positive. 
- if self.early_stopping is False: - highest_attainable_score = best_sum_logprobs / cur_len ** self.length_penalty - ret = self.worst_score >= highest_attainable_score - return ret - - # `"never"`: compute the best possible score, depending on the signal of `length_penalty` - # `length_penalty` > 0.0 -> max denominator is obtained from `max_length`, not from `cur_len` -> min - # abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain - # its max this way - if self.length_penalty > 0.0: - highest_attainable_score = best_sum_logprobs / self.max_length ** self.length_penalty - # the opposite logic applies here (max `highest_attainable_score` from `cur_len`) - else: - highest_attainable_score = best_sum_logprobs / cur_len ** self.length_penalty - ret = self.worst_score >= highest_attainable_score - return ret diff --git a/mindformers/generation/generation_config.py b/mindformers/generation/generation_config.py index d8a02285cc6ba7633b52c5874f80ef91239df3a6..76b06514c364a4caad878ed1714887f5db3b198c 100644 --- a/mindformers/generation/generation_config.py +++ b/mindformers/generation/generation_config.py @@ -151,6 +151,9 @@ class GenerationConfig: # number of beams self.num_beams = kwargs.pop("num_beams", 1) + if self.num_beams > 1: + self.num_beams = 1 + logger.warning("Beam search is no longer supported, will set num_beams to 1.") # do sample or not self.do_sample = kwargs.pop("do_sample", False) # incremental infer diff --git a/mindformers/generation/text_generator.py b/mindformers/generation/text_generator.py index 8c7d57c975da4cd2a1772505d8788a4fd29c4121..90da2a99aad6507d0700ac4a865659e99779434c 100644 --- a/mindformers/generation/text_generator.py +++ b/mindformers/generation/text_generator.py @@ -29,7 +29,6 @@ from mindspore.ops import operations as P import mindspore.common.dtype as mstype from mindspore.common.tensor import Tensor -from mindformers.generation.beam_search import BeamSearchScorer from mindformers.generation.generation_config import GenerationConfig from mindformers.generation.logits_process import (LogitNormalization, LogitsProcessorList, RepetitionPenaltyLogitsProcessor, @@ -77,11 +76,8 @@ class GenerationMode: Possible generation modes. """ - # Non-beam methods GREEDY_SEARCH = "greedy_search" SAMPLE = "sample" - # Beam methods - BEAM_SEARCH = "beam_search" class GenerationMixin: @@ -405,11 +401,9 @@ class GenerationMixin: @staticmethod def _get_generation_mode(generation_config: GenerationConfig): """determine the generation mode by config""" - if generation_config.num_beams == 1: - if generation_config.do_sample: - return GenerationMode.SAMPLE - return GenerationMode.GREEDY_SEARCH - return GenerationMode.BEAM_SEARCH + if generation_config.do_sample: + return GenerationMode.SAMPLE + return GenerationMode.GREEDY_SEARCH def _prepare_model_inputs_for_decoder(self, input_ids, input_mask): """generate the inputs for the decoder""" @@ -542,230 +536,6 @@ class GenerationMixin: res = self.gather(res, mint.cumsum(q_seq_lens, dim=0) - 1, 0) return res - def _beam_search(self, - origin_inputs, - beam_scorer: BeamSearchScorer, - generation_config: GenerationConfig, - logits_processor: Optional[LogitsProcessorList] = None, - streamer: BaseStreamer = None, - **model_kwargs): - r""" - Generates sequences of token ids for models with a language modeling head using **beam search decoding** and - can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. 
- - Parameters: - origin_inputs (`List(str), List(List(str))`): - The sequence used as a prompt for the generation. - beam_scorer (`BeamScorer`): - An derived instance of [`BeamScorer`] that defines how beam hypotheses are constructed, stored and - sorted during generation. For more information, the documentation of [`BeamScorer`] should be read. - generation_config (`GenerationConfig`, *optional*): - The generation configuration to be used as base parametrization for the generation - call. `**kwargs` passed to generate matching the attributes of `generation_config` - will override them. If `generation_config` is not provided, the default config - from the model configuration will be used. Please note that unspecified parameters - will inherit [`GenerationConfig`]'s default values, whose documentation should be - checked to parameterize generation. - logits_processor (`LogitsProcessorList`, *optional*): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - streamer (`TextStreamer, *optional*`): - The streamer that generator uses. - model_kwargs: - Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is - an encoder-decoder model the kwargs should include `encoder_outputs`. - - Return: - A list of the generated token ids - """ - if streamer is not None: - raise ValueError("Streamer does not support in beam search method yet!") - if generation_config.use_past: - raise ValueError("Beam search does not support incremental inference yet! Please set use_past to False.") - if self.config.is_sample_acceleration: - raise ValueError("Beam search does not support sample acceleration yet! " - "Please set is_sample_acceleration to False.") - - total_time = time.time() - prepare_time = time.time() - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - - batch_size = len(beam_scorer._beam_hyps) # pylint: disable=W0212 - num_beams = beam_scorer.num_beams - batch_beam_size = origin_inputs.shape[0] - logger.debug("The input shape is: %s", origin_inputs.shape) - if num_beams * batch_size != batch_beam_size: - raise ValueError( - f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." 
- ) - - valid_length_each_example, _ = \ - get_valid_length_each_example(origin_inputs, generation_config.pad_token_id) - - target_length = ( - self.config.seq_length - if generation_config.max_length > self.config.seq_length - else generation_config.max_length - ) - logger.debug("max target_length is: %s", target_length) - input_ids = self._pad_inputs_using_max_length( - origin_inputs=origin_inputs, pad_token_id=generation_config.pad_token_id - ) - - logger.debug( - "pad the origin inputs from %s into shape: %s", - origin_inputs.shape, - input_ids.shape, - ) - - beam_scores = np.zeros((batch_size, num_beams), dtype=np.float64) - beam_scores[:, 1:] = -1e9 - beam_scores = beam_scores.reshape((batch_size * num_beams,)) - - input_mask = np.zeros_like(input_ids) - for i in range(valid_length_each_example.shape[0]): - input_mask[i, :valid_length_each_example[i]] = 1 - encoder_output = None - encoder_mask = None - if self.config.is_encoder_decoder: - target_length = min(target_length, self.config.max_decode_length) - logger.debug("target_length is: %s", target_length) - - # When do encoder and decoder prediction, the encoder can be cached - # to speed up the inference - ( - encoder_output, - encoder_mask, - input_ids, - target_mask, - ) = self._prepare_model_inputs_for_decoder(input_ids, input_mask) - valid_length_each_example = np.ones((batch_beam_size, 1)).astype(np.int32) - - # update model kwargs once, before go into generate loop. - self.update_model_kwargs_before_generate(input_ids, model_kwargs) - - need_gather_logits = True - - is_first_token = True - - origin_len = np.sum(valid_length_each_example) / num_beams - prepare_time = time.time() - prepare_time - logger.debug("forward prepare time: %s s", prepare_time) - - while True: - forward_time = time.time() - seq_length = input_ids.shape[1] - current_index = [ - valid_length_each_example[i] - 1 + i * seq_length - for i in range(batch_beam_size) - ] - logger.debug("validate length: %s", valid_length_each_example) - if self.config.is_encoder_decoder: - inputs = Tensor(input_ids, mstype.int32) - # pylint: disable=E1102 - res = self( - input_ids=None, - attention_mask=encoder_mask, - encoder_outputs=encoder_output, - decoder_input_ids=inputs, - decoder_attention_mask=Tensor(target_mask, mstype.float32), - ) - else: - model_kwargs["current_index"] = current_index - # model prepare input dict - model_inputs = self.prepare_inputs_for_generation( # pylint: disable=E1111 - input_ids, **model_kwargs - ) - # incremental generate - if generation_config.use_past: - logger.warning("Beam search currently not support incremental, " - "auto-aggressive generate will be performed.") - # auto-aggressive generate - res = self(**model_inputs) # pylint: disable=E1102 - forward_time = time.time() - forward_time - - search_time = time.time() - # post process logits - # convert to numpy for post process - logits = res[0] if isinstance(res, tuple) else res - if isinstance(logits, Tensor): - logits = logits.asnumpy().astype(np.float32) - logits = np.reshape(logits, (-1, logits.shape[-1])) # (batch_size * num_beams * seq_length, vocab_size) - # need gather last seq logits using current_index - # compare length to determine if need gather; if not, gather should be done in model construct - if need_gather_logits and logits.shape[0] > len(current_index): - logits = logits[current_index] # (total_batch_size, vocab_size) - logits_processor.append(LogitNormalization()) - - # post process logits, without changing logits shape and order - next_token_scores = 
logits_processor(input_ids, logits) # (batch_size * num_beams, vocab_size) - - # reshape for beam search - vocab_size = next_token_scores.shape[-1] - next_token_scores = np.reshape(next_token_scores, (batch_size, -1)) # (batch_size, num_beams * vocab_size) - - if is_first_token: - next_token_scores = next_token_scores[:, :vocab_size] - is_first_token = False - - # sample 2 next tokens for each beam, so we have at least 1 non eos token per beam - next_token_scores, next_tokens = topk( - next_token_scores, 2 * num_beams, axis=1, largest=True, sort=True - ) - - next_indices = np.floor_divide(next_tokens, vocab_size) - next_tokens = next_tokens % vocab_size - - beam_outputs = beam_scorer.process( - input_ids, # (batch_size * num_beams, seq_length) - next_token_scores, - next_tokens, - next_indices, - pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id - ) - beam_scores = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - search_time = time.time() - search_time - - update_time = time.time() - # reorder model inputs - old_input_ids = input_ids.copy() - for i in range(batch_beam_size): - input_ids[i] = old_input_ids[beam_idx[i], :] - - # add new tokens to input_ids - for i in range(batch_beam_size): - input_ids[i, valid_length_each_example[i]] = beam_next_tokens[i] - if self.config.is_encoder_decoder: - target_mask[i][valid_length_each_example[i]] = int(1) - - input_mask[i][valid_length_each_example[i]] = 1 - valid_length_each_example[i] += int(1) - - update_time = time.time() - update_time - logger.debug("forward time: %s s; beam search time: %s s; update time: %s s; total count: %s s", - forward_time, search_time, update_time, forward_time + search_time + update_time) - - if beam_scorer.is_done or np.min(valid_length_each_example) >= generation_config.max_length: - break - - sequence_outputs = beam_scorer.finalize( - input_ids, - beam_scores, - pad_token_id=generation_config.pad_token_id, - eos_token_id=generation_config.eos_token_id, - max_length=generation_config.max_length - ) - - generate_len = np.sum(valid_length_each_example) / num_beams - origin_len - total_time = time.time() - total_time - logger.info("total time: %s s; generated tokens: %s tokens; generate speed: %s tokens/s", - total_time, generate_len, generate_len / total_time) - - return sequence_outputs["sequences"] - def generate(self, input_ids: Optional[Union[List[int], List[List[int]]]], generation_config: Optional[GenerationConfig] = None, @@ -827,8 +597,8 @@ class GenerationMixin: - repetition_penalty (float): The penalty factor of the frequency that generated words. The If set 1, the repetition_penalty will not be enabled. If set None, it follows the setting in the configureation in the model. Default: ``None``. - - num_beams (int): Number of beams for beam search. 1 means no beam search. If larger than 1, do_sample - will be set to false. + - num_beams (int): Number of beams for beam search. 1 means no beam search. Only 1 is supported now. + This parameter will be deleted in the future. Returns: A list of the generated token ids. 
@@ -873,10 +643,6 @@ class GenerationMixin: **kwargs ) # All unused kwargs must be model kwargs - if generation_config.num_beams > 1: - logger.warning("When num_beams is set to a value greater than 1, do_sample will be set to False, " - "due to the current beam search does not support sampling.") - generation_config.do_sample = False logger.info("Generation Config is: %s", generation_config) if generation_config.pad_token_id is None: @@ -936,10 +702,6 @@ class GenerationMixin: # determine generation mode generation_config.generation_mode = self._get_generation_mode(generation_config) logger.info(f"The generation mode will be **{generation_config.generation_mode.upper()}**.") - if streamer is not None and (generation_config.num_beams > 1): - raise ValueError( - "`streamer` cannot be used with beam search yet. Make sure that `num_beams` is set to 1." - ) if not use_legacy and not hasattr(self, "is_train_model"): self._set_block_mgr(batch_size, self.config.seq_length) @@ -960,176 +722,154 @@ class GenerationMixin: scores = () if generation_config.return_dict_in_generate and generation_config.output_scores else None raw_logits = () if generation_config.return_dict_in_generate and generation_config.output_logits else None - # beam search - if generation_config.generation_mode == GenerationMode.BEAM_SEARCH: - # prepare beam search scorer - beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=generation_config.num_beams, - max_length=generation_config.max_length - ) - # interleave input_ids with `num_beams` additional sequences per batch - input_ids = np.repeat(input_ids, generation_config.num_beams, 0) - - # run beam search - output_ids = self._beam_search( - origin_inputs=input_ids, - beam_scorer=beam_scorer, - generation_config=generation_config, - logits_processor=logits_processor, - streamer=streamer, - **model_kwargs - ) - # greedy search or sample - else: - total_time = time.time() - prepare_time = time.time() + total_time = time.time() + prepare_time = time.time() - origin_inputs = input_ids - logits_warper = self.get_logits_warper(generation_config) \ - if generation_config.generation_mode == GenerationMode.SAMPLE else None + origin_inputs = input_ids + logits_warper = self.get_logits_warper(generation_config) \ + if generation_config.generation_mode == GenerationMode.SAMPLE else None - if streamer is not None: - streamer.put(origin_inputs) + if streamer is not None: + streamer.put(origin_inputs) - batch_size = origin_inputs.shape[0] - logger.debug("The input shape is: %s", origin_inputs.shape) + batch_size = origin_inputs.shape[0] + logger.debug("The input shape is: %s", origin_inputs.shape) - valid_length_each_example, _ = \ - get_valid_length_each_example(origin_inputs, generation_config.pad_token_id) + valid_length_each_example, _ = \ + get_valid_length_each_example(origin_inputs, generation_config.pad_token_id) - input_ids = self._pad_inputs_using_max_length( - origin_inputs=origin_inputs, pad_token_id=generation_config.pad_token_id - ) + input_ids = self._pad_inputs_using_max_length( + origin_inputs=origin_inputs, pad_token_id=generation_config.pad_token_id + ) - logger.debug( - "pad the origin inputs from %s into shape: %s", - origin_inputs.shape, - input_ids.shape, - ) + logger.debug( + "pad the origin inputs from %s into shape: %s", + origin_inputs.shape, + input_ids.shape, + ) - input_mask = np.zeros_like(input_ids) - for i in range(valid_length_each_example.shape[0]): - input_mask[i, :valid_length_each_example[i]] = 1 - encoder_output = None - encoder_mask = 
None - target_mask = None - if self.config.is_encoder_decoder: - generation_config.max_length = min(generation_config.max_length, self.config.max_decode_length) - logger.debug("max decode length is: %s", generation_config.max_length) - - # When do encoder and decoder prediction, the encoder can be cached - # to speed up the inference - ( - encoder_output, - encoder_mask, - input_ids, - target_mask, - ) = self._prepare_model_inputs_for_decoder(input_ids, input_mask) - valid_length_each_example = np.array([1 for _ in range(batch_size)]) - # A single loop generates one token, loop until reaching target - # model_origin_max_length or generating eod token - is_finished = [False] * batch_size - - # update model kwargs once, before go into generate loop. - self.update_model_kwargs_before_generate(input_ids, model_kwargs) - - origin_len = np.sum(valid_length_each_example) - prepare_time = time.time() - prepare_time - logger.debug("forward prepare time: %s s", prepare_time) - - prefill = True - model_kwargs["origin_inputs"] = origin_inputs - - if (hasattr(self.config, 'pet_config') and self.config.pet_config is not None - and self.config.pet_config.pet_type == "slora"): - adapter_id = kwargs.pop("adapter_id", None) - if adapter_id is not None and len(adapter_id) > 1: - if len(adapter_id) != batch_size: - raise ValueError("adapter_ids has different length with inputs.") - model_kwargs["adapter_ids"] = adapter_id - else: - model_kwargs["adapter_ids"] = adapter_id * batch_size if adapter_id is not None else None - - while np.sum(is_finished) != batch_size: - self.detailed_latency.start_preprocess_timer() - block_tables = None - slot_mapping = None - if (not use_legacy or generation_config.use_past) and not hasattr(self, "is_train_model"): - if prefill: - if (use_legacy and self.is_pynative and self.config.is_dynamic): - max_input_length = len(origin_inputs[0]) - else: - max_input_length = self.config.seq_length - block_tables, slot_mapping = self.block_mgr.assemble_pa_full_inputs(max_input_length, - valid_length_each_example, - is_finished) + input_mask = np.zeros_like(input_ids) + for i in range(valid_length_each_example.shape[0]): + input_mask[i, :valid_length_each_example[i]] = 1 + encoder_output = None + encoder_mask = None + target_mask = None + if self.config.is_encoder_decoder: + generation_config.max_length = min(generation_config.max_length, self.config.max_decode_length) + logger.debug("max decode length is: %s", generation_config.max_length) + + # When do encoder and decoder prediction, the encoder can be cached + # to speed up the inference + ( + encoder_output, + encoder_mask, + input_ids, + target_mask, + ) = self._prepare_model_inputs_for_decoder(input_ids, input_mask) + valid_length_each_example = np.array([1 for _ in range(batch_size)]) + # A single loop generates one token, loop until reaching target + # model_origin_max_length or generating eod token + is_finished = [False] * batch_size + + # update model kwargs once, before go into generate loop. 
+ self.update_model_kwargs_before_generate(input_ids, model_kwargs) + + origin_len = np.sum(valid_length_each_example) + prepare_time = time.time() - prepare_time + logger.debug("forward prepare time: %s s", prepare_time) + + prefill = True + model_kwargs["origin_inputs"] = origin_inputs + + if (hasattr(self.config, 'pet_config') and self.config.pet_config is not None + and self.config.pet_config.pet_type == "slora"): + adapter_id = kwargs.pop("adapter_id", None) + if adapter_id is not None and len(adapter_id) > 1: + if len(adapter_id) != batch_size: + raise ValueError("adapter_ids has different length with inputs.") + model_kwargs["adapter_ids"] = adapter_id + else: + model_kwargs["adapter_ids"] = adapter_id * batch_size if adapter_id is not None else None + + while np.sum(is_finished) != batch_size: + self.detailed_latency.start_preprocess_timer() + block_tables = None + slot_mapping = None + if (not use_legacy or generation_config.use_past) and not hasattr(self, "is_train_model"): + if prefill: + if (use_legacy and self.is_pynative and self.config.is_dynamic): + max_input_length = len(origin_inputs[0]) else: - block_tables, slot_mapping = self.block_mgr.assemble_pa_inc_inputs(valid_length_each_example, - is_finished) - self.profile.start_profiling(valid_length_each_example[0] - input_ids_length) - if use_legacy or (hasattr(self, "is_train_model") and self.is_train_model): - infer_output, is_finished = self.infer(input_ids=input_ids, - valid_length_each_example=valid_length_each_example, - generation_config=generation_config, - logits_processor=logits_processor, - logits_warper=logits_warper, - block_tables=block_tables, - slot_mapping=slot_mapping, - prefill=prefill, - is_finished=is_finished, - encoder_mask=encoder_mask, - encoder_output=encoder_output, - target_mask=target_mask, - **model_kwargs) + max_input_length = self.config.seq_length + block_tables, slot_mapping = self.block_mgr.assemble_pa_full_inputs(max_input_length, + valid_length_each_example, + is_finished) else: - infer_output, is_finished = self.infer_mcore(input_ids=input_ids, - valid_length_each_example=valid_length_each_example, - generation_config=generation_config, - logits_processor=logits_processor, - logits_warper=logits_warper, - block_tables=block_tables, - slot_mapping=slot_mapping, - prefill=prefill, - is_finished=is_finished, - **model_kwargs) - self.profile.stop_profiling(valid_length_each_example[0] - input_ids_length) - if generation_config.return_dict_in_generate: - target_list = infer_output["target_list"] - if generation_config.output_scores: - scores += (infer_output["probs"],) - if generation_config.output_logits: - raw_logits += (infer_output["logits"],) - else: - target_list = infer_output - if not use_legacy or generation_config.use_past: - if prefill and "origin_inputs" in model_kwargs: - model_kwargs.pop("origin_inputs") - prefill = False + block_tables, slot_mapping = self.block_mgr.assemble_pa_inc_inputs(valid_length_each_example, + is_finished) + self.profile.start_profiling(valid_length_each_example[0] - input_ids_length) + if use_legacy or (hasattr(self, "is_train_model") and self.is_train_model): + infer_output, is_finished = self.infer(input_ids=input_ids, + valid_length_each_example=valid_length_each_example, + generation_config=generation_config, + logits_processor=logits_processor, + logits_warper=logits_warper, + block_tables=block_tables, + slot_mapping=slot_mapping, + prefill=prefill, + is_finished=is_finished, + encoder_mask=encoder_mask, + encoder_output=encoder_output, + 
target_mask=target_mask, + **model_kwargs) + else: + infer_output, is_finished = self.infer_mcore(input_ids=input_ids, + valid_length_each_example=valid_length_each_example, + generation_config=generation_config, + logits_processor=logits_processor, + logits_warper=logits_warper, + block_tables=block_tables, + slot_mapping=slot_mapping, + prefill=prefill, + is_finished=is_finished, + **model_kwargs) + self.profile.stop_profiling(valid_length_each_example[0] - input_ids_length) + if generation_config.return_dict_in_generate: + target_list = infer_output["target_list"] + if generation_config.output_scores: + scores += (infer_output["probs"],) + if generation_config.output_logits: + raw_logits += (infer_output["logits"],) + else: + target_list = infer_output + if not use_legacy or generation_config.use_past: + if prefill and "origin_inputs" in model_kwargs: + model_kwargs.pop("origin_inputs") + prefill = False - for i in range(batch_size): - if is_finished[i]: - continue - input_ids[i, valid_length_each_example[i]] = target_list[i] + for i in range(batch_size): + if is_finished[i]: + continue + input_ids[i, valid_length_each_example[i]] = target_list[i] - if self.config.is_encoder_decoder: - target_mask[i][valid_length_each_example[i]] = int(1) + if self.config.is_encoder_decoder: + target_mask[i][valid_length_each_example[i]] = int(1) - # Stop judgment - if target_list[i] in generation_config.eos_token_id \ - or valid_length_each_example[i] + 1 == generation_config.max_length \ - or valid_length_each_example[i] + 1 == max_length_each_example[i]: - is_finished[i] = True - else: - valid_length_each_example[i] += 1 - input_mask[i][valid_length_each_example[i] - 1] = 1 + # Stop judgment + if target_list[i] in generation_config.eos_token_id \ + or valid_length_each_example[i] + 1 == generation_config.max_length \ + or valid_length_each_example[i] + 1 == max_length_each_example[i]: + is_finished[i] = True + else: + valid_length_each_example[i] += 1 + input_mask[i][valid_length_each_example[i] - 1] = 1 - if streamer is not None: - if batch_size == 1: - streamer.put(target_list[0]) - else: - streamer.put(target_list) - self.detailed_latency.end_postprocess_timer() + if streamer is not None: + if batch_size == 1: + streamer.put(target_list[0]) + else: + streamer.put(target_list) + self.detailed_latency.end_postprocess_timer() # Return valid outputs out of padded outputs valid_length_each_example += 1 @@ -1745,8 +1485,6 @@ class GenerationMixin: target = p_args[i][target_index] target_list[i] = target - elif generation_config.generation_mode == GenerationMode.BEAM_SEARCH: - raise ValueError("sampler method doesn't support BEAM_SEARCH. 
") if not self.is_pynative: # pylint: disable=C0415 from mindspore.common.api import _pynative_executor diff --git a/tests/st/test_ut/base_schema.json b/tests/st/test_ut/base_schema.json index deb984083159d2e0b62eea338e1a89c6c2c12cac..ff369eab14e92682c1059bdc4f9cef9178ed50c5 100644 --- a/tests/st/test_ut/base_schema.json +++ b/tests/st/test_ut/base_schema.json @@ -1304,9 +1304,6 @@ "mindformers.generation.GenerationMixin._incremental_infer_mcore": { "signature": "(self, model_inputs: dict, prefill, gather_decode=True)" }, - "mindformers.generation.GenerationMixin._beam_search": { - "signature": "(self, origin_inputs, beam_scorer: mindformers.generation.beam_search.BeamSearchScorer, generation_config: mindformers.generation.generation_config.GenerationConfig, logits_processor: Optional[mindformers.generation.logits_process.LogitsProcessorList] = None, streamer: mindformers.generation.streamers.BaseStreamer = None, **model_kwargs)" - }, "mindformers.generation.GenerationMixin.generate": { "signature": "(self, input_ids: Union[List[List[int]], List[int], NoneType], generation_config: Optional[mindformers.generation.generation_config.GenerationConfig] = None, logits_processor: Optional[mindformers.generation.logits_process.LogitsProcessorList] = None, streamer: Optional[mindformers.generation.streamers.BaseStreamer] = None, seed: Optional[int] = None, **kwargs)" }, @@ -1595,9 +1592,6 @@ "mindformers.generation.text_generator.GenerationMixin._incremental_infer_mcore": { "signature": "(self, model_inputs: dict, prefill, gather_decode=True)" }, - "mindformers.generation.text_generator.GenerationMixin._beam_search": { - "signature": "(self, origin_inputs, beam_scorer: mindformers.generation.beam_search.BeamSearchScorer, generation_config: mindformers.generation.generation_config.GenerationConfig, logits_processor: Optional[mindformers.generation.logits_process.LogitsProcessorList] = None, streamer: mindformers.generation.streamers.BaseStreamer = None, **model_kwargs)" - }, "mindformers.generation.text_generator.GenerationMixin.generate": { "signature": "(self, input_ids: Union[List[List[int]], List[int], NoneType], generation_config: Optional[mindformers.generation.generation_config.GenerationConfig] = None, logits_processor: Optional[mindformers.generation.logits_process.LogitsProcessorList] = None, streamer: Optional[mindformers.generation.streamers.BaseStreamer] = None, seed: Optional[int] = None, **kwargs)" },