Packing & rmpad

Nov 25, 2025
Large Model

Introduction

Using lmms-engine as the reference, this post walks through how training data is packed and how use_rmpad removes all computation on padding tokens.

Packing

The overall logic groups different samples into one sequence whose total token count is bounded by packing_length. Specifically:

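In the Dataset, as the code below shows, samples are appended to a buffer list. The buffer is yielded as soon as the next sample would push it past packing_length.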

if self.config.packing:
    # Reset index at the start of each iteration pass
    self.cur_idx = 0
    buffer = []
    buffer_length = 0
    packing_length = self.config.packing_length

    # Iterate through the dataset once per epoch
    while self.cur_idx < len(curr_data_list):
        try:
            data_dict = self.get_one_sample(self.cur_idx, curr_data_folder[self.cur_idx], curr_data_list)
        except Exception as e:
            traceback.print_exc()
            logger.error(f"Error getting one sample: {e}, skip this sample")
            self.cur_idx += 1
            continue
        input_ids = data_dict["input_ids"]
        data_length = input_ids.shape[0]
        self.cur_idx += 1

        # Drop overlong sample if filtering is enabled
        if data_length > packing_length and self.config.filter_overlong:
            continue

        # If current sample cannot fit into current buffer, yield the buffer first
        if buffer_length > 0 and buffer_length + data_length > packing_length:
            yield buffer
            buffer = []
            buffer_length = 0

        # If the sample is still longer than packing_length (and not filtered),
        # yield it as its own batch to avoid stalling
        if data_length > packing_length:
            yield [data_dict]
            continue

        # Append to buffer
        buffer.append(data_dict)
        buffer_length += data_length

    # Flush remaining buffer
    if len(buffer) > 0:
        yield buffer
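To see the greedy grouping in isolation, here is a minimal sketch that replays the same buffer logic over a plain list of sample lengths instead of real data_dicts. `pack_lengths` is a hypothetical helper written for illustration and is not part of lmms-engine. It yields groups of sample indices.

```python
from typing import Iterator, List


def pack_lengths(lengths: List[int], packing_length: int,
                 filter_overlong: bool = False) -> Iterator[List[int]]:
    """Hypothetical helper: greedily group sample indices so each group fits packing_length."""
    buffer: List[int] = []
    buffer_length = 0
    for idx, length in enumerate(lengths):
        # Drop overlong samples when filtering is enabled
        if length > packing_length and filter_overlong:
            continue
        # Flush the buffer if this sample would overflow it
        if buffer_length > 0 and buffer_length + length > packing_length:
            yield buffer
            buffer, buffer_length = [], 0
        # Overlong but not filtered: emit it on its own so iteration never stalls
        if length > packing_length:
            yield [idx]
            continue
        buffer.append(idx)
        buffer_length += length
    # Flush whatever is left at the end of the pass
    if buffer:
        yield buffer


print(list(pack_lengths([3, 4, 2, 6, 1], packing_length=8)))  # [[0, 1], [2, 3], [4]]
```

The packing is greedy and order-preserving. It never reorders samples to fill gaps, so some buffers end up shorter than packing_length. That is the trade-off for a single streaming pass over the data.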

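The Collator then assembles these samples into a batch for the model input. Note that the data is still padded at this stage. The collator flattens each packed buffer back into individual samples and pads them to the longest sample in the batch.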

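import collections
from dataclasses import dataclass
from typing import Dict, Sequence

import torch

# Processable is the project's processor type (its import is not shown in the source)
@dataclass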
class VisionCollator:
    processor: Processable

    def pad_sequence(self, input_ids, batch_first, padding_value):
        if self.processor.tokenizer.padding_side == "left":
            input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
        if self.processor.tokenizer.padding_side == "left":
            input_ids = torch.flip(input_ids, [1])
        return input_ids

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        if isinstance(instances[0], list):
            instances = [inst for instance in instances for inst in instance]
        inputs = collections.defaultdict(list)
        for instance in instances:
            for key, values in instance.items():
                inputs[key].append(values)

        batched_inputs = {}
        if "input_ids" in inputs.keys():
            input_ids = inputs.pop("input_ids")
            input_ids = self.pad_sequence(
                input_ids,
                batch_first=True,
                padding_value=self.processor.tokenizer.pad_token_id,
            )
            batched_inputs["input_ids"] = input_ids
        if "labels" in inputs.keys():
            labels = inputs.pop("labels")
            labels = self.pad_sequence(
                labels,
                batch_first=True,
                padding_value=-100,
            )
            batched_inputs["labels"] = labels

        if "attention_mask" in inputs.keys():
            inputs.pop("attention_mask")

        attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id).long()
        batched_inputs["attention_mask"] = attention_mask

        # for the other keys
        for key, values in inputs.items():
            # Handle scalar/boolean values (e.g. use_audio_in_video) by keeping the first one
            if isinstance(values[0], bool) or (
                isinstance(values[0], (int, float)) and not isinstance(values[0], torch.Tensor)
            ):
                batched_inputs[key] = values[0]
            else:
                batched_inputs[key] = torch.concatenate(values, dim=0)
        return batched_inputs

    @property
    def image_token_id(self):
        return self.processor.tokenizer.convert_tokens_to_ids(self.processor.image_token)

rmpad

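The project implements rmpad through monkey patching (hot-patching classes at runtime). As the code below shows, it mainly replaces the forward methods of the relevant model classes: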

if use_rmpad:
    from .qwen3_vl_ops import attn_forward as qwen3_ops_attn_forward
    from .qwen3_vl_ops import (
        decoder_layer_forward as qwen3_ops_decoder_layer_forward,
    )
    from .qwen3_vl_ops import model_forward as qwen3_ops_model_forward
    from .qwen3_vl_ops import text_model_forward as qwen3_ops_text_model_forward

    modeling_qwen3_vl.Qwen3VLModel.forward = qwen3_ops_model_forward
    modeling_qwen3_vl.Qwen3VLTextModel.forward = qwen3_ops_text_model_forward
    modeling_qwen3_vl.Qwen3VLTextDecoderLayer.forward = qwen3_ops_decoder_layer_forward
    modeling_qwen3_vl.Qwen3VLTextAttention.forward = qwen3_ops_attn_forward
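Why assigning to the class attribute is enough can be shown with a toy class. `ToyAttention` is a hypothetical stand-in for a class such as Qwen3VLTextAttention, and `rmpad_attn_forward` for a patched forward.

```python
class ToyAttention:
    """Hypothetical stand-in for a model class such as Qwen3VLTextAttention."""

    def forward(self, x):
        return f"padded attention on {x}"


def rmpad_attn_forward(self, x):
    # Same signature as the method it replaces; `self` is bound as usual
    return f"varlen attention on {x}"


attn = ToyAttention()  # an instance created before the patch
use_rmpad = True
if use_rmpad:
    # Rebind the class attribute: method lookup on every instance now finds the new function
    ToyAttention.forward = rmpad_attn_forward

print(attn.forward("tokens"))  # varlen attention on tokens
```

Python resolves methods through the class at call time. Patching the class therefore affects models built before or after the patch, as long as the replacement keeps a signature compatible with the callers.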

Qwen3VLModel.forward

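It explicitly calls _unpad_input, which computes the indices of the non-padding elements (indices) and the cumulative sequence lengths (cu_seq_lens), and strips the padding tokens from the input.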

if input_ids is not None:
    original_input_ids = input_ids
    input_ids, indices, cu_seq_lens, _ = _unpad_input(input_ids, attention_mask=attention_mask)
    batch_size, seq_length = original_input_ids.shape
elif inputs_embeds is not None:
    original_inputs_embeds = inputs_embeds
    inputs_embeds, indices, cu_seq_lens, _ = _unpad_input(inputs_embeds, attention_mask=attention_mask)
    batch_size, seq_length, _ = original_inputs_embeds.shape
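
import torch
import torch.nn.functional as F
from einops import rearrange


def _unpad_input(input_ids, attention_mask):
    # No-op for a (b, s) mask; reduces a (b, 1, 1, s) mask to (b, s)
    valid_mask = attention_mask.squeeze(1).squeeze(1).eq(1)
    seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
    # Positions of non-padding tokens in the flattened (b * s) view
    indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Prefix sum with a leading 0, e.g. lengths [3, 2] -> [0, 3, 5]
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # Flatten (b, s, ...) -> (b * s, ...) and keep only the valid rows
    input_ids = rearrange(input_ids, "b s ... -> (b s) ...")[indices]

    return input_ids, indices, cu_seqlens, max_seqlen_in_batch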
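To make the four return values concrete, here is a list-based analogue of _unpad_input for a plain 2D (batch, seq) mask. `unpad_lists` is a hypothetical name used only for illustration, and the batch below is made up.

```python
from itertools import accumulate
from typing import List, Tuple


def unpad_lists(input_ids: List[List[int]], attention_mask: List[List[int]]
                ) -> Tuple[List[int], List[int], List[int], int]:
    seq_len = len(attention_mask[0])
    seqlens = [sum(row) for row in attention_mask]  # valid tokens per sample
    # Row-major flat positions of valid tokens, like nonzero(valid_mask.flatten())
    indices = [b * seq_len + s
               for b, row in enumerate(attention_mask)
               for s, m in enumerate(row) if m == 1]
    flat = [tok for row in input_ids for tok in row]
    unpadded = [flat[i] for i in indices]
    cu_seqlens = [0] + list(accumulate(seqlens))  # same as F.pad(cumsum, (1, 0))
    return unpadded, indices, cu_seqlens, max(seqlens)


ids = [[11, 12, 13, 0], [21, 22, 0, 0], [31, 32, 33, 34]]
mask = [[1, 1, 1, 0], [1, 1, 0, 0], [1, 1, 1, 1]]
print(unpad_lists(ids, mask))
# ([11, 12, 13, 21, 22, 31, 32, 33, 34], [0, 1, 2, 4, 5, 8, 9, 10, 11], [0, 3, 5, 9], 4)
```

The batch of 12 slots shrinks to 9 real tokens. cu_seqlens records where each sample starts in the flat strip, and indices records where each kept token came from.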

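Because the input shape has changed, the way positional encodings are applied must change as well.

The computed position_ids must be rearranged and filtered by indices so that they line up with the flattened input_ids: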

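# Flatten position_ids from (c, batch, seq) to (batch * seq, c) and keep only the non-padding rows,
# then transpose to (c, total_tokens) and add a singleton batch dim: (c, 1, total_tokens)
position_ids = (
    index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices).transpose(0, 1).unsqueeze(1)
)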
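The gather can be mirrored with nested lists. Indexing each channel separately gives the same result as rearranging to (b s, c), selecting rows with indices, and transposing back. `gather_position_ids` is a hypothetical helper. The example assumes three position channels with identical text-only positions 0..3 on every row.

```python
from typing import List


def gather_position_ids(position_ids: List[List[List[int]]],
                        indices: List[int]) -> List[List[int]]:
    """(c, b, s) nested lists -> (c, total_tokens), keeping only non-padding slots."""
    # Row-major flatten of each channel's (b, s) grid, matching "(b s)" in rearrange
    flat = [[p for row in channel for p in row] for channel in position_ids]
    # Keep only the valid slots, as index_first_axis(..., indices) does
    return [[ch[i] for i in indices] for ch in flat]


pos = [[list(range(4)) for _ in range(3)] for _ in range(3)]  # c=3, b=3, s=4
indices = [0, 1, 2, 4, 5, 8, 9, 10, 11]                       # from a right-padded batch
print(gather_position_ids(pos, indices)[0])  # [0, 1, 2, 0, 1, 0, 1, 2, 3]
```

The positions restart at 0 for every sample inside the flat strip. This is why position_ids are computed on the padded (batch, seq) layout first and only then gathered, rather than numbered 0..total_tokens-1 over the packed strip.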

Finally, when the underlying LLM (language_model) is called, its argument list differs as well:

outputs = self.language_model(
    ...,
    indices=indices,          # key: used to restore the original shape or indices
    cu_seq_lens=cu_seq_lens,  # key: FlashAttention varlen needs the boundary of each sequence
    ...
)
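"Restoring the original shape" with indices is just a scatter back into a padded buffer. Below is a list-based sketch, assuming a fill value of 0. `pad_back` is a hypothetical name, similar in spirit to the `pad_input` helper in flash-attn's `bert_padding` module.

```python
from typing import List


def pad_back(unpadded: List[int], indices: List[int], batch: int, seq_len: int,
             pad_value: int = 0) -> List[List[int]]:
    flat = [pad_value] * (batch * seq_len)
    for value, idx in zip(unpadded, indices):
        flat[idx] = value  # scatter each valid token back to its original slot
    # Split the flat (batch * seq_len) buffer into rows
    return [flat[b * seq_len:(b + 1) * seq_len] for b in range(batch)]


print(pad_back([11, 12, 13, 21, 22, 31, 32, 33, 34],
               [0, 1, 2, 4, 5, 8, 9, 10, 11], batch=3, seq_len=4))
# [[11, 12, 13, 0], [21, 22, 0, 0], [31, 32, 33, 34]]
```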

Qwen3VLTextAttention.forward

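Here, flash_attn_varlen_func is called explicitly to compute attention. This is what ultimately makes packing work: all samples share one flat sequence, yet each token attends only within its own sample.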

attn_output = flash_attn_varlen_func(
    q=query_states,
    k=key_states,
    v=value_states,
    cu_seqlens_q=cu_seq_lens,
    cu_seqlens_k=cu_seq_lens,
    max_seqlen_q=max_seqlen,
    max_seqlen_k=max_seqlen,
    causal=True,
    window_size=window_size,
    softmax_scale=self.head_dim**-0.5,
    dropout_p=0.0,
)
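The semantics of the call above can be pinned down with a naive reference: split the strip at cu_seqlens and run ordinary causal attention on each slice. The sketch below assumes a single head, no sliding window and no dropout, and it is O(n²). It is not how FlashAttention computes the result (the real kernel tiles the work and never materializes the score matrix); only the expected output should match up to floating-point differences.

```python
import math
from typing import List

Vec = List[float]


def causal_attention(q: List[Vec], k: List[Vec], v: List[Vec], scale: float) -> List[Vec]:
    """Single-head causal softmax attention over one sequence (naive reference)."""
    out = []
    for i, qi in enumerate(q):
        # Token i may only look at tokens 0..i
        scores = [scale * sum(a * b for a, b in zip(qi, k[j])) for j in range(i + 1)]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j][d] for j, w in enumerate(weights)) / total
                    for d in range(len(v[i]))])
    return out


def varlen_causal_attention(q, k, v, cu_seqlens, scale):
    """Reference semantics: causal attention run separately on each [cu[i], cu[i+1]) slice."""
    out = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        out.extend(causal_attention(q[start:end], k[start:end], v[start:end], scale))
    return out


head_dim = 2
scale = head_dim ** -0.5   # same role as softmax_scale=self.head_dim**-0.5
cu_seqlens = [0, 2, 5]     # two packed samples of lengths 2 and 3
q = [[1.0, 0.5], [0.2, 1.0], [0.3, 0.3], [1.0, 1.0], [0.5, 0.1]]
k = [[0.5, 1.0], [1.0, 0.0], [0.1, 0.9], [0.7, 0.2], [0.4, 0.4]]
v = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [3.0, 1.0], [1.0, 3.0]]

packed = varlen_causal_attention(q, k, v, cu_seqlens, scale)
naive = causal_attention(q, k, v, scale)  # ignores boundaries, so samples leak into each other
print(packed[2])              # [2.0, 2.0]: the first token of a sample sees only itself
print(naive[2] == packed[2])  # False
```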

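Put simply, the core job of flash_attn_varlen_func is to take one long strip of tokens, in which the samples are concatenated and no longer distinguishable by shape. It then uses the boundary offsets in cu_seqlens to recover the original per-sequence structure exactly and computes attention within each sequence.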

The process can be broken down into two parts:

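  1. Reshaping the data, from a "matrix" to a "strip": the padded (batch, seq_len) input is flattened to (total_tokens, ...) and the padding positions are dropped, so the samples sit back to back in one long sequence.
  2. cu_seqlens (Cumulative Sequence Lengths): the prefix sum of the per-sample lengths with a leading 0. Sample i occupies positions cu_seqlens[i] through cu_seqlens[i+1] - 1 of the strip; for example, lengths 3 and 2 give cu_seqlens = [0, 3, 5].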
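The attention pattern implied by cu_seqlens together with causal=True can be made concrete. It is block-diagonal, with each block lower-triangular. FlashAttention never builds this matrix. The helpers below are hypothetical and only visualize which (query, key) pairs are allowed.

```python
from bisect import bisect_right
from typing import List


def segment_ids(cu_seqlens: List[int]) -> List[int]:
    """Map each position of the flat strip to the index of the sample it belongs to."""
    return [bisect_right(cu_seqlens, t) - 1 for t in range(cu_seqlens[-1])]


def allowed_mask(cu_seqlens: List[int]) -> List[List[int]]:
    """1 where query i may attend to key j: same sample and j <= i (causal)."""
    seg = segment_ids(cu_seqlens)
    n = len(seg)
    return [[int(seg[i] == seg[j] and j <= i) for j in range(n)] for i in range(n)]


for row in allowed_mask([0, 3, 5]):
    print(row)
# [1, 0, 0, 0, 0]
# [1, 1, 0, 0, 0]
# [1, 1, 1, 0, 0]
# [0, 0, 0, 1, 0]
# [0, 0, 0, 1, 1]
```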

Loss computation

Using cu_seqlens (named seq_lens here), the shifted hidden states and labels are computed separately for each sample:

if use_rmpad:
    # We need to shift the tokens according to seq lens
    # Otherwise, the first labels of the next seq will be the last labels of the current seq
    shift_hidden_states = []
    shift_labels = []
    for i in range(len(seq_lens) - 1):
        cur_hidden_states = hidden_states[seq_lens[i] : seq_lens[i + 1], :]
        cur_shift_hidden_states = cur_hidden_states[:-1, :].contiguous()
        cur_labels = labels[seq_lens[i] : seq_lens[i + 1]]
        cur_shift_labels = cur_labels[1:].contiguous()
        shift_hidden_states.append(cur_shift_hidden_states)
        shift_labels.append(cur_shift_labels)
    shift_hidden_states = torch.cat(shift_hidden_states, dim=0)
    shift_labels = torch.cat(shift_labels, dim=0)
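The comment in the loop above ("the first labels of the next seq will be the last labels of the current seq") is easiest to see with positions instead of tensors. In the sketch below, the hidden state at position i is paired with label i+1. The helper names and the label values are made up for illustration.

```python
from typing import List, Tuple


def shift_per_sequence(labels: List[int], cu_seqlens: List[int]) -> Tuple[List[int], List[int]]:
    """(hidden-state positions used, target labels), shifted inside each sample."""
    src, tgt = [], []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        src.extend(range(start, end - 1))  # drop the last position of each sample
        tgt.extend(labels[start + 1:end])  # drop the first label of each sample
    return src, tgt


def shift_globally(labels: List[int]) -> Tuple[List[int], List[int]]:
    """Naive shift over the whole packed strip, ignoring sample boundaries."""
    return list(range(len(labels) - 1)), labels[1:]


labels = [10, 11, 12, 20, 21]  # sample A = 10..12, sample B = 20..21
cu_seqlens = [0, 3, 5]
print(shift_per_sequence(labels, cu_seqlens))  # ([0, 1, 3], [11, 12, 21])
print(shift_globally(labels))                  # ([0, 1, 2, 3], [11, 12, 20, 21])
```

With the global shift, position 2 (the last token of sample A) is trained to predict 20, the first token of sample B, which is exactly the cross-sample leakage the per-sequence loop avoids.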