transformers中generate方法

Aug 08, 2024
2 views
Large Model

比起两年前,NLG任务已经得到了非常有效的发展,transformers模块的使用广泛程度也达到前所未有的程度。在模型推理预测时,一个核心的语句就是model.generate(),本文就来详细介绍一下generate方法是如何运作的。在生成的过程中,包含了诸多生成策略,本文将以最常用的beam search为例,尽可能详细地展开介绍。

随着各种LLM的出现,transformers中与generate相关的代码发生了一些变化,主要区别在于:

generate的源码位置发生了改变; generate方法中,采用一个generation_config参数来管理生成相关的各种配置,并优化了逻辑,使得逻辑更加清晰。

1. generate的代码位置

在之前版本的transformers中(transformers~=4.9),generate方法位于transformers.generation_utils.py,这个方法是GenerationMixin类的一个方法。

而在新版本的transformers中(transformers~=4.42),generate方法被转移到了transformers.generation.utils.py,仍然是GenerationMixin的一个类方法。

而对于一个hf形式的预训练模型,都是继承了PreTrainedModel类的,而顺着这个PreTrainedModel类,可以看到更上一级的继承逻辑,GenerationMixin就在其中:

class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):

这就是为什么通过AutoModel.from_pretrained()实例化的一个model为什么可以直接调用generate方法去做推理。

2. GenerationMixin概览

这一部分作为一个速查表写在这里,不建议直接阅读,而是在读后面代码的过程中,返回来查看这部分内容。

GenerationMixin类所有方法概览如下:

3. generate签名

在介绍流程之前先看一下generate方法的签名,在4.42.4版本中,其签名简化如下:

@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    negative_prompt_ids: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:

相比之前的版本,这样写的直接优点就是,与原版的超长签名相比,减少了传入的参数,将诸如top_k, top_p, num_beams等参数全部都整合到了generation_config中,使得函数看起来更加简化,并且该参数可以直接从模型路径下的generation_config.json文件中读取,一定程度上为用户提供了便捷

相应的缺点就是很多参数没有显性地暴露出来,在查看注释和自定义生成配置的时候就不是很方便了。 需要在GenerationConfig中查看可选的参数:

from transformers.generation.configuration_utils import GenerationConfig

help(GenerationConfig)

generate方法的参数含义与作用介绍如下:

在这些输入中,logits_processorstopping_criteria,将是用户手动干预生成过程的主要手段。

4. generate过程

在4.42版本的transformers代码中,generate过程的注释写的比较条理清晰,所以本文也沿用代码注释中的序号进行划分。

4.1 读取并更新generation config

这一部分的大概逻辑就是处理generation config为None的情况,以及检查是否存在与生成策略不一致的错误参数。

# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
self._validate_model_kwargs(model_kwargs.copy())
self._validate_assistant(assistant_model)

其中_validate_model_class_validate_model_kwargs_validate_assistant方法都不是重点,这里不展开介绍。

4.2 补充没有传入的参数

这部分需要补充的参数包括logits_processor, stopping_criteria, 以及generation_config中的pad_token_id。前两项是设置为默认的空list;查看self.forward以及model_args有没有attention_mask的传入

# 2. Set generation parameters if not already defined
if synced_gpus is None:
    if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
        synced_gpus = True
    else:
        synced_gpus = False

logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

4.3 定义模型输入

# 3. Define model inputs
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
    inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]

device = inputs_tensor.device
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

# decoder-only models must use left-padding for batched generation.
if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
    # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
    # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
    if (
        generation_config.pad_token_id is not None
        and batch_size > 1
        and len(inputs_tensor.shape) == 2
        and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
    ):
        logger.warning(
            "A decoder-only architecture is being used, but right-padding was detected! For correct "
            "generation results, please set `padding_side='left'` when initializing the tokenizer."
        )

这里主要需要关注_prepare_model_inputs这个方法,这个方法的核心,一句话概括就是模型输入的序列input_ids,必须非空,如果空的话,就用bos_token去初始化。其余部分都是用来应对个别模型的特殊情况。并检查decoder-only的模型输入,检查input_ids中任何序列中的最后一个 id 是否为“pad_token_id”

def _prepare_model_inputs(
    self,
    inputs: Optional[torch.Tensor] = None,
    bos_token_id: Optional[torch.Tensor] = None,
    model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
    """
    This function extracts the model-specific `inputs` for generation.
    """
    # 这一步似乎是起到一个校准的作用,防止某些encoder-decoder模型的主模型和encoder的输入名称不一致
    # 1. retrieve all kwargs that are non-None or non-model input related.
    # some encoder-decoder models have different names for model and encoder
    if (
        self.config.is_encoder_decoder
        and hasattr(self, "encoder")
        and self.encoder.main_input_name != self.main_input_name
    ):
        input_name = self.encoder.main_input_name
    else:
        input_name = self.main_input_name

    model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}

        # 确保inputs没有重复传入
    # 2. check whether model_input_name is passed as kwarg
    # if yes and `inputs` is None use kwarg inputs
    inputs_kwarg = model_kwargs.pop(input_name, None)
    if inputs_kwarg is not None and inputs is not None:
        raise ValueError(
            f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. "
            f"Make sure to either pass {inputs} or {input_name}=..."
        )
    elif inputs_kwarg is not None:
        inputs = inputs_kwarg

        # 如果 input_name 是 input_ids 且 model_kwargs 中存在 inputs_embeds这一输入参数:
        # 如果是decoder-only模型,如果支持 inputs_embeds,则将 input_ids 转移到 model_kwargs 中,
        # 这样后续的一些自动化步骤(如创建 attention_mask)可以依赖实际的模型输入。需要把'input_ids'这一参数放在inputs_kwarg中传入
        # 如果是encoder-decoder模型,input_ids与inputs_embeds只能传入其一
    # 3. In the presence of `inputs_embeds` for text models:
    # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
    # doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
    # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
    # - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
    # pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
    if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
        if not self.config.is_encoder_decoder:
            has_inputs_embeds_forwarding = "inputs_embeds" in set(
                inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
            )
            if not has_inputs_embeds_forwarding:
                raise ValueError(
                    f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
                    "doesn't have its forwarding implemented. See the GPT2 implementation for an example "
                    "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
                )
            # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
            # the attention mask) can rely on the actual model input.
            model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
                inputs, bos_token_id, model_kwargs=model_kwargs
            )
        else:
            if inputs is not None:
                raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
        inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"

        # 如果最后还是没有input_ids, 采用bos创建input_ids,可以简化理解为:
    # torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
    # 4. if `inputs` is still None, try to create `input_ids` from BOS token
    inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
    return inputs, input_name, model_kwargs

这里稍微解释下为什么会存在这样的差异检查: **decoder-only** 模型与 **encoder-decoder** 模型的不同

4.4 定义模型的其他参数

这一部分没有需要特别注意的地方,主要就是一些config设置,补齐模型的其他参数,如创建attention_mask,确保encoder-decoder模型能够返回’ModelOutput’类等等。

# 4. Define other model kwargs
# decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
# generating the first new token or not, and we only want to use the embeddings for the first new token)
if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
    model_kwargs["use_cache"] = True
else:
    model_kwargs["use_cache"] = generation_config.use_cache

if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
    model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
        inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
    )

if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
    # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
    model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
        inputs_tensor, model_kwargs, model_input_name, generation_config
    )

4.5 对自回归模型准备input_ids

这一步与4.3的主要区别在于,针对AR模型额外进行了处理。如果是encoder-decoder模型,确保模型的解码器输入是正确格式的、如果是decoder-only的模型则直接采用4.3创建的input_tensor作为input_ids。

# 5. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
    input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
        batch_size=batch_size,
        model_input_name=model_input_name,
        model_kwargs=model_kwargs,
        decoder_start_token_id=generation_config.decoder_start_token_id,
        device=inputs_tensor.device,
    )
else:
    input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

if generation_config.token_healing:
    input_ids = self.heal_tokens(input_ids, tokenizer)

if streamer is not None:
    streamer.put(input_ids.cpu())

另外, 这里还有个函数heal_tokens, 主要功能为模型生成新的序列, 扩展或者替换序列中的尾部token,以便更好地匹配可能的扩展或替代

  • 遍历每个批次中的尾部 token ID,并尝试找到这个 token 的可能扩展(即可能的替代 token)
  • 如果找到了扩展 token,为每个扩展 token 应用一个生成偏置值 (sequence_bias),使得这些 token 在生成时更有可能被选择。
  • 在生成时,会轻微地偏向原始 token 以避免过于激进的修复(例如 'http' -> 'https' 这样的替代可能是不期望的)。
  • 使用 self.generate 方法重新生成新的序列,替换掉原始序列中的尾部 token。
  • 最终返回处理后的 input_ids,这些序列可能在尾部 token 被替换或扩展后得到了改进。

    假设你有一个输入序列为 input_ids = torch.tensor([[203, 204, 205]]),这些 IDs 对应的 tokens 是 'hello world'。如果 205 是一个可以扩展的 token(例如,它可能表示 'world' 或者 'worldwide'),那么这个函数会尝试查找和替换它。如果找到了更好的扩展(例如 'worldwide'),函数将会生成一个新序列替换掉原来的序列。 最终生成的新序列可能是 torch.tensor([[203, 204, 305]]),其中 305 是新的 token ID,对应 'worldwide'。这样,原始输入序列就被“修复”成了更合适的输出。

**4.6 准备最大长度 **

这一部分就是根据config中的相关配置,判断input_id的长度有没有超长, 被封装进了_prepare_generated_length_validate_generated_length函数中, 后面都是cache相关的参数设置, 暂时不去考虑。

# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(
    generation_config=generation_config,
    has_default_max_length=has_default_max_length,
    has_default_min_length=has_default_min_length,
    model_input_name=model_input_name,
    inputs_tensor=inputs_tensor,
    input_ids_length=input_ids_length,
)

use_dynamic_cache_by_default = False
if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
    raise ValueError(
        "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
        "Cache object) is unsupported. Please use only one of the two."
    )
elif generation_config.cache_implementation is not None:
    if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
        if generation_config.cache_implementation == "static" and not self._supports_static_cache:
            raise ValueError(
                "This model does not support `cache_implementation='static'`. Please check the following "
                "issue: https://github.com/huggingface/transformers/issues/28981"
            )
        model_kwargs["past_key_values"] = self._get_cache(
            generation_config.cache_implementation,
            getattr(generation_config, "num_beams", 1) * batch_size,
            generation_config.max_length,
        )
    elif generation_config.cache_implementation == "quantized":
        if not self._supports_quantized_cache:
            raise ValueError(
                "This model does not support the quantized cache. If you want your model to support quantized "
                "cache, please open an issue."
            )

        cache_config = (
            generation_config.cache_config
            if generation_config.cache_config is not None
            else QuantizedCacheConfig()
        )
        cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]

        if cache_config.backend == "quanto" and not is_quanto_available():
            raise ImportError(
                "You need to install `quanto` in order to use KV cache quantization with quanto backend. "
                "Please install it via  with `pip install quanto`"
            )
        elif cache_config.backend == "HQQ" and not is_hqq_available():
            raise ImportError(
                "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
                "Please install it via  with `pip install hqq`"
            )

        model_kwargs["past_key_values"] = cache_class(cache_config)
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
# keeps copying the cache thus using much more memory
elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
    past = model_kwargs.get("past_key_values", None)
    if past is None:
        model_kwargs["past_key_values"] = DynamicCache()
        use_dynamic_cache_by_default = True
    elif isinstance(past, tuple):
        model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(past)
        use_dynamic_cache_by_default = True

self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

4.7 确认生成模式

这里直接选择beam search分支了,其他模式不做展开介绍,下同。

beam search分为两种,beam_search以及进阶款的后者对应后续的生成方法为beam_sample。

如果do_sample为True, 会选择beam_sample

二者的区别主要在于,进阶款的beam_sample_gen_mode可以设置temperature、top_k、top_p等参数进一步控制生成,设置的方法在4.11节:logits warper中介绍。对于基础款的beam_search,就没有创建logits warper这一环节。

# 7. determine generation mode
generation_mode = generation_config.get_generation_mode(assistant_model)

if streamer is not None and (generation_config.num_beams > 1):
    raise ValueError(
        "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
    )

4.8 创建logits processor

# 8. prepare distribution pre_processing samplers
prepared_logits_processor = self._get_logits_processor(
    generation_config=generation_config,
    input_ids_seq_length=input_ids_length,
    encoder_input_ids=inputs_tensor,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    logits_processor=logits_processor,
    device=inputs_tensor.device,
    model_kwargs=model_kwargs,
    negative_prompt_ids=negative_prompt_ids,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
)

这一个环节比较重要,因为涉及到了logits processor。这些processor是在生成的过程中,在每一个step,对计算出来的得分进行修正处理的。在transformers中,预设了若干processor,用户也可以定义自己的processor(需要继承抽象类transformers.generation.logit_process.LogitsProcessor),自己设计逻辑,来对生成的过程进行人工干预。

在beam search中,logits process的使用方法是:

next_token_scores_processed = logits_processor(input_ids, next_token_scores)

其中,input_ids是当前step传给模型的序列token id对应Tensor(batch_size, sequence_length),next_token_scores是经过模型计算之后的分数(即在vocab上的概率分布)取log_softmax。

在这里简单介绍一下在transformers中预设的processor。限于篇幅,不贴出全部源码,只对其功能进行总结。

4.9 创建停止规则

stopping_criterialogits_processor是用户对生成过程进行干预的主要手段,相比logits_processor强行改变概率空间,stopping_criteria则是直接设定了终止生成的策略,理解起来也会相对容易一些。

# 9. prepare stopping criteria
prepared_stopping_criteria = self._get_stopping_criteria(
    generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
)

预设的criteria总结如下:

如果是自定义criteria,应当继承抽象类transformers.generation.stopping_criteria.StoppingCriteria

4.10 进入相应的分支

这里直接选择进入beam search的分支。如前文所述,如果要控制temperature等超参数,则应该进入is_beam_sample_gen_mode这个分支。

这些是生成模型中不同的生成方法,每种方法在文本生成的过程中有不同的策略:

  1. CONTRASTIVE_SEARCH: 是一种平衡探索和利用的生成方法,适用于需要生成高质量且有一定多样性的文本任务。
  2. GREEDY_SEARCH: 每一步都选择最有可能的词,生成单一且通常连贯的文本,但缺乏多样性。
  3. SAMPLE: 根据概率分布随机选择下一个词,生成更具多样性的文本,但可能影响连贯性。
  4. ASSISTED_GENERATION: 结合人类输入或外部辅助信息来生成文本,通常用于对生成内容有特定需求的任务。
  5. BEAM_SEARCH: 在生成过程中维护多个候选序列(通常称为“束”),最终选择概率最高的完整序列。这种方法在质量和多样性之间取得平衡。
  6. BEAM_SAMPLE: 结合了束搜索和采样的策略,在保持候选序列的同时对每一步进行采样,以提高生成文本的多样性。
  7. CONSTRAINED_BEAM_SEARCH: 在束搜索的基础上增加了某些约束条件,确保生成的文本满足特定要求,如包含某些关键字或遵循特定的语法结构。
  8. GROUP_BEAM_SEARCH: 将候选序列分组进行束搜索,通常用于减少生成文本的重复性,并提高生成结果的多样性。这在需要生成多个不重复答案的场景中非常有用。 后面主要介绍Beam_search

4.11 创建logits warper

elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
    # 11. prepare logits warper
    prepared_logits_warper = (
        self._get_logits_warper(generation_config, device=input_ids.device)
        if generation_config.do_sample
        else None
    )

logits warper的使用方法与logits processor一样,都是用来修改概率的输出。关于他们的区别,暂时没有找到很好的解释,可以理解为warper控制着temperature、topk等与生成策略相关的参数。并且是在logits processor处理之后再进行处理的。

普通的beam search不会涉及这一部分,只有选择sample模式的beam search时,才会使用到logits warper。

需要记住的是,它的输入与processor一样,都是当前的序列(token_ids)与之前计算出的得分(scores),返回的结果是处理之后的得分,形状是(batch_size, config.vocab_size)。

预设的warper包括:

这一部分是beam search的核心流程,其具体的执行生成过程将在第5节中进行详细的介绍。

在这一部分中,首先创建了用于打分的BeamSearchScorer(具体作用将在第5节中进行介绍),然后根据num_beams对input_ids进行了扩展,最后执行beam search的核心方法beam_search,或beam sample对应的beam_sample方法。

# 12. prepare beam search scorer
beam_scorer = BeamSearchScorer(
    batch_size=batch_size,
    num_beams=generation_config.num_beams,
    device=inputs_tensor.device,
    length_penalty=generation_config.length_penalty,
    do_early_stopping=generation_config.early_stopping,
    num_beam_hyps_to_keep=generation_config.num_return_sequences,
    max_length=generation_config.max_length,
)
 # 13. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
    input_ids=input_ids,
    expand_size=generation_config.num_beams,
    is_encoder_decoder=self.config.is_encoder_decoder,
    **model_kwargs,
)

# 14. run beam sample
result = self._beam_search(
    input_ids,
    beam_scorer,
    logits_processor=prepared_logits_processor,
    logits_warper=prepared_logits_warper,
    stopping_criteria=prepared_stopping_criteria,
    generation_config=generation_config,
    synced_gpus=synced_gpus,
    **model_kwargs,
)

5. Beam Search

简单介绍一下在文本生成任务中常用的解码策略Beam Search(集束搜索)。

生成式任务相比普通的分类、tagging等NLP任务会复杂不少。在生成的时候,模型的输出是一个时间步一个时间步依次获得的,而且前面时间步的结果还会影响后面时间步的结果。也就是说,每一个时间步,模型给出的都是基于历史生成结果的条件概率。为了生成完整的句子,需要一个称为解码(decode)的额外动作来融合模型多个时间步的输出,而且使得最终得到的序列的每一步条件概率连乘起来最大

在文本生成任务中,每一个时间步可能的输出种类称为字典大小(vocabulary size,我们用V表示),进行T步随机的生成可能获得的结果总共有$V^T$种。拿中文文本生成来说,V 的值大约是5000-6000,即常用汉字的个数。在如此大的基数下,遍历整个生成空间是不现实的。