Introduction
Notes on the training-time data packing logic in lmms-engine, and on how use_rmpad eliminates all computation spent on padding.
Packing
The overall idea is to pack several samples into a single sequence whose total length is bounded by packing_length. Concretely:
In the Dataset, as shown in the code below, samples are appended one by one to a buffer list:
if self.config.packing:
    # Reset index at the start of each iteration pass
    self.cur_idx = 0
    buffer = []
    buffer_length = 0
    packing_length = self.config.packing_length
    # Iterate through the dataset once per epoch
    while self.cur_idx < len(curr_data_list):
        try:
            data_dict = self.get_one_sample(self.cur_idx, curr_data_folder[self.cur_idx], curr_data_list)
        except Exception as e:
            traceback.print_exc()
            logger.error(f"Error getting one sample: {e}, skip this sample")
            self.cur_idx += 1
            continue
        input_ids = data_dict["input_ids"]
        data_length = input_ids.shape[0]
        self.cur_idx += 1
        # Drop overlong sample if filtering is enabled
        if data_length > packing_length and self.config.filter_overlong:
            continue
        # If current sample cannot fit into current buffer, yield the buffer first
        if buffer_length > 0 and buffer_length + data_length > packing_length:
            yield buffer
            buffer = []
            buffer_length = 0
        # If the sample is still longer than packing_length (and not filtered),
        # yield it as its own batch to avoid stalling
        if data_length > packing_length:
            yield [data_dict]
            continue
        # Append to buffer
        buffer.append(data_dict)
        buffer_length += data_length
    # Flush remaining buffer
    if len(buffer) > 0:
        yield buffer
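For intuition, here is a toy re-implementation of the same greedy loop on plain token lengths (pack_lengths is illustrative, not a function in the repo):

def pack_lengths(lengths, packing_length, filter_overlong=False):
    """Mimics the buffer logic above: group lengths greedily up to packing_length."""
    buffer, buffer_length = [], 0
    for n in lengths:
        if n > packing_length and filter_overlong:
            continue                      # drop overlong samples
        if buffer_length > 0 and buffer_length + n > packing_length:
            yield buffer                  # the new sample does not fit: flush the buffer
            buffer, buffer_length = [], 0
        if n > packing_length:
            yield [n]                     # still overlong: emit it alone
            continue
        buffer.append(n)
        buffer_length += n
    if buffer:
        yield buffer                      # flush the tail

print(list(pack_lengths([3, 4, 2, 9, 5], packing_length=8)))
# [[3, 4], [2], [9], [5]]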
The Collator then assembles these samples into batched model inputs. Note that at this stage the data is still padded:
@dataclass
class VisionCollator:
    processor: Processable

    def pad_sequence(self, input_ids, batch_first, padding_value):
        if self.processor.tokenizer.padding_side == "left":
            input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
        if self.processor.tokenizer.padding_side == "left":
            input_ids = torch.flip(input_ids, [1])
        return input_ids

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        if isinstance(instances[0], list):
            instances = [inst for instance in instances for inst in instance]
        inputs = collections.defaultdict(list)
        for instance in instances:
            for key, values in instance.items():
                inputs[key].append(values)
        batched_inputs = {}
        if "input_ids" in inputs.keys():
            input_ids = inputs.pop("input_ids")
            input_ids = self.pad_sequence(
                input_ids,
                batch_first=True,
                padding_value=self.processor.tokenizer.pad_token_id,
            )
            batched_inputs["input_ids"] = input_ids
        if "labels" in inputs.keys():
            labels = inputs.pop("labels")
            labels = self.pad_sequence(
                labels,
                batch_first=True,
                padding_value=-100,
            )
            batched_inputs["labels"] = labels
        if "attention_mask" in inputs.keys():
            inputs.pop("attention_mask")
            attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id).long()
            batched_inputs["attention_mask"] = attention_mask
        # for the other keys
        for key, values in inputs.items():
            # Handle scalar/boolean values (e.g. use_audio_in_video)
            if isinstance(values[0], bool) or (
                isinstance(values[0], (int, float)) and not isinstance(values[0], torch.Tensor)
            ):
                batched_inputs[key] = values[0]
            else:
                batched_inputs[key] = torch.concatenate(values, dim=0)
        return batched_inputs

    @property
    def image_token_id(self):
        return self.processor.tokenizer.convert_tokens_to_ids(self.processor.image_token)
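One detail worth noting: torch.nn.utils.rnn.pad_sequence only pads on the right, so VisionCollator.pad_sequence implements left padding by flipping each sequence, padding, and flipping back. A standalone sketch of that trick (pad token assumed to be 0 here):

import torch

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

# Native behaviour: right padding
right = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=0)
# tensor([[1, 2, 3],
#         [4, 5, 0]])

# Left padding via flip -> pad -> flip, as in VisionCollator.pad_sequence
flipped = [torch.flip(s, [0]) for s in seqs]
left = torch.flip(torch.nn.utils.rnn.pad_sequence(flipped, batch_first=True, padding_value=0), [1])
# tensor([[1, 2, 3],
#         [0, 4, 5]])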
rmpad
In the project, rmpad is applied as a monkey patch (i.e. a hot patch): as the code below shows, the relevant forward methods of the model are replaced.
if use_rmpad:
    from .qwen3_vl_ops import attn_forward as qwen3_ops_attn_forward
    from .qwen3_vl_ops import (
        decoder_layer_forward as qwen3_ops_decoder_layer_forward,
    )
    from .qwen3_vl_ops import model_forward as qwen3_ops_model_forward
    from .qwen3_vl_ops import text_model_forward as qwen3_ops_text_model_forward

    modeling_qwen3_vl.Qwen3VLModel.forward = qwen3_ops_model_forward
    modeling_qwen3_vl.Qwen3VLTextModel.forward = qwen3_ops_text_model_forward
    modeling_qwen3_vl.Qwen3VLTextDecoderLayer.forward = qwen3_ops_decoder_layer_forward
    modeling_qwen3_vl.Qwen3VLTextAttention.forward = qwen3_ops_attn_forward
Qwen3VLModel.forward
The patched forward explicitly calls _unpad_input. It computes the indices of the non-padding elements (indices) and the cumulative sequence lengths (cu_seq_lens), and removes the padding tokens from the input:
if input_ids is not None:
    original_input_ids = input_ids
    input_ids, indices, cu_seq_lens, _ = _unpad_input(input_ids, attention_mask=attention_mask)
    batch_size, seq_length = original_input_ids.shape
elif inputs_embeds is not None:
    original_inputs_embeds = inputs_embeds
    inputs_embeds, indices, cu_seq_lens, _ = _unpad_input(inputs_embeds, attention_mask=attention_mask)
    batch_size, seq_length, _ = original_inputs_embeds.shape
where _unpad_input itself is defined as:

def _unpad_input(input_ids, attention_mask):
    valid_mask = attention_mask.squeeze(1).squeeze(1).eq(1)
    seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    input_ids = rearrange(input_ids, "b s ... -> (b s) ...")[indices]
    unpad_seq_len = input_ids.shape[0]
    return input_ids, indices, cu_seqlens, max_seqlen_in_batch
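To make indices and cu_seqlens concrete, here is a toy walk-through of the same index arithmetic (results shown in comments), assuming a plain 2-D attention_mask; the squeeze(1) calls in the real function additionally handle 4-D masks:

import torch
import torch.nn.functional as F

# Batch of 2 sequences padded to length 5 (1 = real token, 0 = padding)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])
valid_mask = attention_mask.eq(1)
seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)        # tensor([3, 5])
indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
# indices -> tensor([0, 1, 2, 5, 6, 7, 8, 9]): positions of real tokens in the flattened (b*s) view
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# cu_seqlens -> tensor([0, 3, 8]): sample i occupies rows [cu_seqlens[i], cu_seqlens[i+1]) after unpadding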
Because the input shape has changed, the way positional encodings are applied must change as well: the computed position_ids are rearranged and filtered by indices so that they line up with the flattened input_ids.
# Flatten position_ids over (batch, seq) and keep only the non-padding positions
position_ids = (
    index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices).transpose(0, 1).unsqueeze(1)
)
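A shape-only sketch of what this does, assuming position_ids has shape (3, batch, seq) as in Qwen's multimodal rope; plain indexing stands in below for flash_attn's index_first_axis, which performs the same row selection:

import torch
from einops import rearrange

position_ids = torch.arange(3 * 2 * 5).reshape(3, 2, 5)     # (rope sections, batch, seq)
indices = torch.tensor([0, 1, 2, 5, 6, 7, 8, 9])            # valid-token indices from the example above

flat = rearrange(position_ids, "c b s -> (b s) c")           # (10, 3)
unpadded = flat[indices].transpose(0, 1).unsqueeze(1)        # keep valid rows, restore the rope-section axis
print(unpadded.shape)                                         # torch.Size([3, 1, 8])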
Finally, when the underlying language model is called, the argument list differs as well:
outputs = self.language_model(
    ...,
    indices=indices,          # key: used to recover the original shape / index positions
    cu_seq_lens=cu_seq_lens,  # key: FlashAttention varlen needs the boundaries of every sequence
    ...,
)
Qwen3VLTextAttention.forward
Here the attention computation explicitly uses flash_attn_varlen_func, which is what ultimately makes packing pay off:
attn_output = flash_attn_varlen_func(
    q=query_states,
    k=key_states,
    v=value_states,
    cu_seqlens_q=cu_seq_lens,
    cu_seqlens_k=cu_seq_lens,
    max_seqlen_q=max_seqlen,
    max_seqlen_k=max_seqlen,
    causal=True,
    window_size=window_size,
    softmax_scale=self.head_dim**-0.5,
    dropout_p=0.0,
)
In short, the core job of flash_attn_varlen_func is this: given a long strip of tokens in which several samples have already been concatenated together, use the index information to recover the original sample boundaries and compute attention only within each sample. Two ingredients make that possible:
- Data reshaping: from the padded "square" batch of shape (batch, seq) to the unpadded "strip" of valid tokens.
- cu_seqlens (cumulative sequence lengths): prefix sums of the per-sample lengths, marking where every sample starts and ends inside the strip. A small worked example follows.
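A toy example (not code from the repo) of what cu_seqlens carries and how the varlen kernel uses it, assuming three packed samples of lengths 3, 5 and 2:

import torch
import torch.nn.functional as F

seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)                 # per-sample lengths after packing
cu_seqlens = F.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)                                                      # tensor([ 0,  3,  8, 10], dtype=torch.int32)
max_seqlen = int(seq_lens.max())                                       # 5

# q/k/v are passed to flash_attn_varlen_func in the flattened strip layout (total_tokens, num_heads, head_dim).
# cu_seqlens tells the kernel that rows [0, 3), [3, 8), [8, 10) are three independent causal sequences,
# so packed samples never attend across their boundaries even though they share one tensor.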
Loss computation
Based on cu_seqlens (called seq_lens here), the shifted hidden states and labels are computed separately for each sample:
if use_rmpad:
    # We need to shift the tokens according to seq lens
    # Otherwise, the first labels of the next seq will be the last labels of the current seq
    shift_hidden_states = []
    shift_labels = []
    for i in range(len(seq_lens) - 1):
        cur_hidden_states = hidden_states[seq_lens[i] : seq_lens[i + 1], :]
        cur_shift_hidden_states = cur_hidden_states[:-1, :].contiguous()
        cur_labels = labels[seq_lens[i] : seq_lens[i + 1]]
        cur_shift_labels = cur_labels[1:].contiguous()
        shift_hidden_states.append(cur_shift_hidden_states)
        shift_labels.append(cur_shift_labels)
    shift_hidden_states = torch.cat(shift_hidden_states, dim=0)
    shift_labels = torch.cat(shift_labels, dim=0)
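A toy illustration (not repo code) of why the shift has to respect the boundaries in seq_lens: with a naive global shift, the last token of one packed sample would be trained to predict the first label of the next sample.

import torch

hidden_states = torch.arange(5).unsqueeze(-1).float()    # stand-in for (total_tokens, hidden_size)
labels = torch.tensor([10, 11, 12, 20, 21])               # two packed samples: [10, 11, 12] and [20, 21]
seq_lens = [0, 3, 5]                                       # cu_seqlens-style boundaries

# A naive labels[1:] shift would pair hidden_states[2] (last token of sample 1) with labels[3]
# (first label of sample 2). Shifting per sample avoids that leak:
shift_hidden_states, shift_labels = [], []
for i in range(len(seq_lens) - 1):
    h = hidden_states[seq_lens[i] : seq_lens[i + 1]]
    l = labels[seq_lens[i] : seq_lens[i + 1]]
    shift_hidden_states.append(h[:-1])
    shift_labels.append(l[1:])

print(torch.cat(shift_labels))                             # tensor([11, 12, 21]) -- no label crosses a boundary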