比起两年前,NLG任务已经得到了非常有效的发展,transformers模块的使用广泛程度也达到前所未有的程度。在模型推理预测时,一个核心的语句就是model.generate(),本文就来详细介绍一下generate方法是如何运作的。在生成的过程中,包含了诸多生成策略,本文将以最常用的beam search为例,尽可能详细地展开介绍。
随着各种LLM的出现,transformers中与generate相关的代码发生了一些变化,主要区别在于:
generate的源码位置发生了改变; generate方法中,采用一个generation_config参数来管理生成相关的各种配置,并优化了逻辑,使得逻辑更加清晰。
1. generate的代码位置
在之前版本的transformers中(transformers~=4.9),generate方法位于transformers.generation_utils.py,这个方法是GenerationMixin类的一个方法。
而在新版本的transformers中(transformers~=4.42),generate方法被转移到了transformers.generation.utils.py,仍然是GenerationMixin的一个类方法。
而对于一个hf形式的预训练模型,都是继承了PreTrainedModel类的,而顺着这个PreTrainedModel类,可以看到更上一级的继承逻辑,GenerationMixin就在其中:
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
这就是为什么通过AutoModel.from_pretrained()实例化的一个model为什么可以直接调用generate方法去做推理。
2. GenerationMixin概览
这一部分作为一个速查表写在这里,不建议直接阅读,而是在读后面代码的过程中,返回来查看这部分内容。
GenerationMixin类所有方法概览如下:
3. generate签名
在介绍流程之前先看一下generate方法的签名,在4.42.4版本中,其签名简化如下:
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
相比之前的版本,这样写的直接优点就是,与原版的超长签名相比,减少了传入的参数,将诸如top_k, top_p, num_beams等参数全部都整合到了generation_config中,使得函数看起来更加简化,并且该参数可以直接从模型路径下的generation_config.json文件中读取,一定程度上为用户提供了便捷。
相应的缺点就是很多参数没有显性地暴露出来,在查看注释和自定义生成配置的时候就不是很方便了。 需要在GenerationConfig中查看可选的参数:
from transformers.generation.configuration_utils import GenerationConfig
help(GenerationConfig)
generate方法的参数含义与作用介绍如下:
在这些输入中,logits_processor和stopping_criteria,将是用户手动干预生成过程的主要手段。
4. generate过程
在4.42版本的transformers代码中,generate过程的注释写的比较条理清晰,所以本文也沿用代码注释中的序号进行划分。
4.1 读取并更新generation config
这一部分的大概逻辑就是处理generation config为None的情况,以及检查是否存在与生成策略不一致的错误参数。
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
self._validate_model_kwargs(model_kwargs.copy())
self._validate_assistant(assistant_model)
其中_validate_model_class,_validate_model_kwargs,_validate_assistant方法都不是重点,这里不展开介绍。
4.2 补充没有传入的参数
这部分需要补充的参数包括logits_processor, stopping_criteria, 以及generation_config中的pad_token_id。前两项是设置为默认的空list;查看self.forward以及model_args有没有attention_mask的传入
# 2. Set generation parameters if not already defined
if synced_gpus is None:
if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
synced_gpus = True
else:
synced_gpus = False
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
4.3 定义模型输入
# 3. Define model inputs
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
device = inputs_tensor.device
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
# decoder-only models must use left-padding for batched generation.
if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
# Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
if (
generation_config.pad_token_id is not None
and batch_size > 1
and len(inputs_tensor.shape) == 2
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
"generation results, please set `padding_side='left'` when initializing the tokenizer."
)
这里主要需要关注_prepare_model_inputs这个方法,这个方法的核心,一句话概括就是模型输入的序列input_ids,必须非空,如果空的话,就用bos_token去初始化。其余部分都是用来应对个别模型的特殊情况。并检查decoder-only的模型输入,检查input_ids中任何序列中的最后一个 id 是否为“pad_token_id”
def _prepare_model_inputs(
self,
inputs: Optional[torch.Tensor] = None,
bos_token_id: Optional[torch.Tensor] = None,
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
"""
This function extracts the model-specific `inputs` for generation.
"""
# 这一步似乎是起到一个校准的作用,防止某些encoder-decoder模型的主模型和encoder的输入名称不一致
# 1. retrieve all kwargs that are non-None or non-model input related.
# some encoder-decoder models have different names for model and encoder
if (
self.config.is_encoder_decoder
and hasattr(self, "encoder")
and self.encoder.main_input_name != self.main_input_name
):
input_name = self.encoder.main_input_name
else:
input_name = self.main_input_name
model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
# 确保inputs没有重复传入
# 2. check whether model_input_name is passed as kwarg
# if yes and `inputs` is None use kwarg inputs
inputs_kwarg = model_kwargs.pop(input_name, None)
if inputs_kwarg is not None and inputs is not None:
raise ValueError(
f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. "
f"Make sure to either pass {inputs} or {input_name}=..."
)
elif inputs_kwarg is not None:
inputs = inputs_kwarg
# 如果 input_name 是 input_ids 且 model_kwargs 中存在 inputs_embeds这一输入参数:
# 如果是decoder-only模型,如果支持 inputs_embeds,则将 input_ids 转移到 model_kwargs 中,
# 这样后续的一些自动化步骤(如创建 attention_mask)可以依赖实际的模型输入。需要把'input_ids'这一参数放在inputs_kwarg中传入
# 如果是encoder-decoder模型,input_ids与inputs_embeds只能传入其一
# 3. In the presence of `inputs_embeds` for text models:
# - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
# doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
# input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
# - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
# pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
if not self.config.is_encoder_decoder:
has_inputs_embeds_forwarding = "inputs_embeds" in set(
inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
)
if not has_inputs_embeds_forwarding:
raise ValueError(
f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
"doesn't have its forwarding implemented. See the GPT2 implementation for an example "
"(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
)
# In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
# the attention mask) can rely on the actual model input.
model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
inputs, bos_token_id, model_kwargs=model_kwargs
)
else:
if inputs is not None:
raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
# 如果最后还是没有input_ids, 采用bos创建input_ids,可以简化理解为:
# torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
return inputs, input_name, model_kwargs
这里稍微解释下为什么会存在这样的差异检查:
**decoder-only**模型与**encoder-decoder**模型的不同
4.4 定义模型的其他参数
这一部分没有需要特别注意的地方,主要就是一些config设置,补齐模型的其他参数,如创建attention_mask,确保encoder-decoder模型能够返回’ModelOutput’类等等。
# 4. Define other model kwargs
# decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
# generating the first new token or not, and we only want to use the embeddings for the first new token)
if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
model_kwargs["use_cache"] = True
else:
model_kwargs["use_cache"] = generation_config.use_cache
if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
)
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
# if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
inputs_tensor, model_kwargs, model_input_name, generation_config
)
4.5 对自回归模型准备input_ids
这一步与4.3的主要区别在于,针对AR模型额外进行了处理。如果是encoder-decoder模型,确保模型的解码器输入是正确格式的、如果是decoder-only的模型则直接采用4.3创建的input_tensor作为input_ids。
# 5. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config.decoder_start_token_id,
device=inputs_tensor.device,
)
else:
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
if generation_config.token_healing:
input_ids = self.heal_tokens(input_ids, tokenizer)
if streamer is not None:
streamer.put(input_ids.cpu())
另外, 这里还有个函数heal_tokens, 主要功能为模型生成新的序列, 扩展或者替换序列中的尾部token,以便更好地匹配可能的扩展或替代
- 遍历每个批次中的尾部 token ID,并尝试找到这个 token 的可能扩展(即可能的替代 token)
- 如果找到了扩展 token,为每个扩展 token 应用一个生成偏置值 (
sequence_bias),使得这些 token 在生成时更有可能被选择。 - 在生成时,会轻微地偏向原始 token 以避免过于激进的修复(例如 'http' -> 'https' 这样的替代可能是不期望的)。
- 使用
self.generate方法重新生成新的序列,替换掉原始序列中的尾部 token。 - 最终返回处理后的
input_ids,这些序列可能在尾部 token 被替换或扩展后得到了改进。假设你有一个输入序列为
input_ids = torch.tensor([[203, 204, 205]]),这些 IDs 对应的 tokens 是'hello world'。如果205是一个可以扩展的 token(例如,它可能表示'world'或者'worldwide'),那么这个函数会尝试查找和替换它。如果找到了更好的扩展(例如'worldwide'),函数将会生成一个新序列替换掉原来的序列。 最终生成的新序列可能是torch.tensor([[203, 204, 305]]),其中305是新的 token ID,对应'worldwide'。这样,原始输入序列就被“修复”成了更合适的输出。
**4.6 准备最大长度 **
这一部分就是根据config中的相关配置,判断input_id的长度有没有超长, 被封装进了_prepare_generated_length和_validate_generated_length函数中, 后面都是cache相关的参数设置, 暂时不去考虑。
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
has_default_min_length=has_default_min_length,
model_input_name=model_input_name,
inputs_tensor=inputs_tensor,
input_ids_length=input_ids_length,
)
use_dynamic_cache_by_default = False
if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
raise ValueError(
"Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
"Cache object) is unsupported. Please use only one of the two."
)
elif generation_config.cache_implementation is not None:
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation == "static" and not self._supports_static_cache:
raise ValueError(
"This model does not support `cache_implementation='static'`. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981"
)
model_kwargs["past_key_values"] = self._get_cache(
generation_config.cache_implementation,
getattr(generation_config, "num_beams", 1) * batch_size,
generation_config.max_length,
)
elif generation_config.cache_implementation == "quantized":
if not self._supports_quantized_cache:
raise ValueError(
"This model does not support the quantized cache. If you want your model to support quantized "
"cache, please open an issue."
)
cache_config = (
generation_config.cache_config
if generation_config.cache_config is not None
else QuantizedCacheConfig()
)
cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
if cache_config.backend == "quanto" and not is_quanto_available():
raise ImportError(
"You need to install `quanto` in order to use KV cache quantization with quanto backend. "
"Please install it via with `pip install quanto`"
)
elif cache_config.backend == "HQQ" and not is_hqq_available():
raise ImportError(
"You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
"Please install it via with `pip install hqq`"
)
model_kwargs["past_key_values"] = cache_class(cache_config)
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
# keeps copying the cache thus using much more memory
elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
past = model_kwargs.get("past_key_values", None)
if past is None:
model_kwargs["past_key_values"] = DynamicCache()
use_dynamic_cache_by_default = True
elif isinstance(past, tuple):
model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(past)
use_dynamic_cache_by_default = True
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
4.7 确认生成模式
这里直接选择beam search分支了,其他模式不做展开介绍,下同。
beam search分为两种,beam_search以及进阶款的后者对应后续的生成方法为beam_sample。
如果do_sample为True, 会选择beam_sample
二者的区别主要在于,进阶款的beam_sample_gen_mode可以设置temperature、top_k、top_p等参数进一步控制生成,设置的方法在4.11节:logits warper中介绍。对于基础款的beam_search,就没有创建logits warper这一环节。
# 7. determine generation mode
generation_mode = generation_config.get_generation_mode(assistant_model)
if streamer is not None and (generation_config.num_beams > 1):
raise ValueError(
"`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
)
4.8 创建logits processor
# 8. prepare distribution pre_processing samplers
prepared_logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
device=inputs_tensor.device,
model_kwargs=model_kwargs,
negative_prompt_ids=negative_prompt_ids,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
这一个环节比较重要,因为涉及到了logits processor。这些processor是在生成的过程中,在每一个step,对计算出来的得分进行修正处理的。在transformers中,预设了若干processor,用户也可以定义自己的processor(需要继承抽象类transformers.generation.logit_process.LogitsProcessor),自己设计逻辑,来对生成的过程进行人工干预。
在beam search中,logits process的使用方法是:
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
其中,input_ids是当前step传给模型的序列token id对应Tensor(batch_size, sequence_length),next_token_scores是经过模型计算之后的分数(即在vocab上的概率分布)取log_softmax。
在这里简单介绍一下在transformers中预设的processor。限于篇幅,不贴出全部源码,只对其功能进行总结。
4.9 创建停止规则
stopping_criteria与logits_processor是用户对生成过程进行干预的主要手段,相比logits_processor强行改变概率空间,stopping_criteria则是直接设定了终止生成的策略,理解起来也会相对容易一些。
# 9. prepare stopping criteria
prepared_stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
)
预设的criteria总结如下:
如果是自定义criteria,应当继承抽象类transformers.generation.stopping_criteria.StoppingCriteria。
4.10 进入相应的分支
这里直接选择进入beam search的分支。如前文所述,如果要控制temperature等超参数,则应该进入is_beam_sample_gen_mode这个分支。
这些是生成模型中不同的生成方法,每种方法在文本生成的过程中有不同的策略:
- CONTRASTIVE_SEARCH: 是一种平衡探索和利用的生成方法,适用于需要生成高质量且有一定多样性的文本任务。
- GREEDY_SEARCH: 每一步都选择最有可能的词,生成单一且通常连贯的文本,但缺乏多样性。
- SAMPLE: 根据概率分布随机选择下一个词,生成更具多样性的文本,但可能影响连贯性。
- ASSISTED_GENERATION: 结合人类输入或外部辅助信息来生成文本,通常用于对生成内容有特定需求的任务。
- BEAM_SEARCH: 在生成过程中维护多个候选序列(通常称为“束”),最终选择概率最高的完整序列。这种方法在质量和多样性之间取得平衡。
- BEAM_SAMPLE: 结合了束搜索和采样的策略,在保持候选序列的同时对每一步进行采样,以提高生成文本的多样性。
- CONSTRAINED_BEAM_SEARCH: 在束搜索的基础上增加了某些约束条件,确保生成的文本满足特定要求,如包含某些关键字或遵循特定的语法结构。
- GROUP_BEAM_SEARCH: 将候选序列分组进行束搜索,通常用于减少生成文本的重复性,并提高生成结果的多样性。这在需要生成多个不重复答案的场景中非常有用。 后面主要介绍Beam_search
4.11 创建logits warper
elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
# 11. prepare logits warper
prepared_logits_warper = (
self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
)
logits warper的使用方法与logits processor一样,都是用来修改概率的输出。关于他们的区别,暂时没有找到很好的解释,可以理解为warper控制着temperature、topk等与生成策略相关的参数。并且是在logits processor处理之后再进行处理的。
普通的beam search不会涉及这一部分,只有选择sample模式的beam search时,才会使用到logits warper。
需要记住的是,它的输入与processor一样,都是当前的序列(token_ids)与之前计算出的得分(scores),返回的结果是处理之后的得分,形状是(batch_size, config.vocab_size)。
预设的warper包括:
4.12 beam search
这一部分是beam search的核心流程,其具体的执行生成过程将在第5节中进行详细的介绍。
在这一部分中,首先创建了用于打分的BeamSearchScorer(具体作用将在第5节中进行介绍),然后根据num_beams对input_ids进行了扩展,最后执行beam search的核心方法beam_search,或beam sample对应的beam_sample方法。
# 12. prepare beam search scorer
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=generation_config.num_beams,
device=inputs_tensor.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
max_length=generation_config.max_length,
)
# 13. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_beams,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 14. run beam sample
result = self._beam_search(
input_ids,
beam_scorer,
logits_processor=prepared_logits_processor,
logits_warper=prepared_logits_warper,
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
**model_kwargs,
)
5. Beam Search
简单介绍一下在文本生成任务中常用的解码策略Beam Search(集束搜索)。
生成式任务相比普通的分类、tagging等NLP任务会复杂不少。在生成的时候,模型的输出是一个时间步一个时间步依次获得的,而且前面时间步的结果还会影响后面时间步的结果。也就是说,每一个时间步,模型给出的都是基于历史生成结果的条件概率。为了生成完整的句子,需要一个称为解码(decode)的额外动作来融合模型多个时间步的输出,而且使得最终得到的序列的每一步条件概率连乘起来最大。
在文本生成任务中,每一个时间步可能的输出种类称为字典大小(vocabulary size,我们用V表示),进行T步随机的生成可能获得的结果总共有$V^T$种。拿中文文本生成来说,V 的值大约是5000-6000,即常用汉字的个数。在如此大的基数下,遍历整个生成空间是不现实的。