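Position Interpolation (PI) achieves length extrapolation by scaling position indices down proportionally, so that the positions of a sequence longer than the training length map into the position range the model has already learned.
Its advantage is that the model does not have to be retrained from scratch: the scaling can be applied directly at inference time, though the paper itself pairs it with a short fine-tuning run.
It was proposed in the paper "Extending Context Window of Large Language Models via Positional Interpolation".
LLaMA uses RoPE positional encoding, so its implementations of position interpolation all target RoPE.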
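As a minimal sketch of the idea, the snippet below rescales integer positions by a fixed ratio. The helper name `interpolate_positions` and the lengths 2048 and 4096 are assumptions chosen for illustration, not values taken from the paper or from LLaMA.

```python
def interpolate_positions(positions, train_len, target_len):
    """Rescale positions from [0, target_len) into [0, train_len).

    Hypothetical helper for illustration: every position is multiplied
    by the same ratio train_len / target_len.
    """
    scale = train_len / target_len
    return [p * scale for p in positions]


train_len = 2048    # assumed training context length
target_len = 4096   # assumed extended context length

positions = interpolate_positions(range(target_len), train_len, target_len)
print(positions[:4])   # neighbouring tokens are now 0.5 apart
print(positions[-1])   # the largest position stays below train_len
```

After scaling, positions become fractional, so the model sees positions between the integers it was trained on rather than positions beyond the trained range.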
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def forward(self, x, position_ids):
        # difference to the original RoPE: a scaling factor is applied to the position ids
        position_ids = position_ids.float() / self.scaling_factor
        cos, sin = super().forward(x, position_ids)
        return cos, sin
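To see why dividing the position ids is enough, here is a small pure-Python sketch of the RoPE angle computation. The function `rope_cos_sin` is a hypothetical stand-in written for this note, not the Transformers API, and `dim=8` is an arbitrary toy size.

```python
import math


def rope_cos_sin(position, dim, base=10000.0, scaling_factor=1.0):
    """Return the RoPE cos/sin values for one position (toy version).

    The only change relative to plain RoPE is that the position is
    divided by scaling_factor before the angles are computed.
    """
    position = float(position) / scaling_factor
    cos, sin = [], []
    for i in range(dim // 2):
        angle = position / (base ** (2 * i / dim))
        cos.append(math.cos(angle))
        sin.append(math.sin(angle))
    return cos, sin


# Position 4000 with a scaling factor of 4 is encoded exactly like
# position 1000 without scaling.
cos_scaled, sin_scaled = rope_cos_sin(4000, dim=8, scaling_factor=4.0)
cos_plain, sin_plain = rope_cos_sin(1000, dim=8)
print(cos_scaled == cos_plain and sin_scaled == sin_plain)
```

With a scaling factor of 4, a position four times larger than the trained maximum is encoded like a position the model has already seen.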
An introduction to the principle behind position interpolation: https://kaiokendev.github.io/til#extending-context-to-8k
import torch


class ScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        # Standard RoPE inverse frequencies, one per pair of dimensions.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Hard-coded override: build the cache for an 8192-token context.
        max_position_embeddings = 8192

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(
            self.max_seq_len_cached,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )

        # These two lines:
        # scale every position by 1/4 before the rotary angles are computed.
        self.scale = 1 / 4
        t *= self.scale
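The arithmetic behind the two highlighted lines can be checked in a few lines of plain Python. This sketch assumes that the factor 1/4 is the ratio of the default 2048 to the hard-coded 8192 in the excerpt; the variable names are illustrative.

```python
train_len = 2048       # default max_position_embeddings in the signature
extended_len = 8192    # value hard-coded inside the constructor

scale = train_len / extended_len            # multiply-style factor
scaling_factor = extended_len / train_len   # divide-style factor

# Scaled position table, analogous to `t` after `t *= self.scale`.
t = [p * scale for p in range(extended_len)]

# Multiplying by 1/4 and dividing by 4 give the same positions.
same = all(p * scale == p / scaling_factor for p in range(extended_len))

print(scale, scaling_factor)   # the two factors are reciprocals
print(t[-1], same)             # the last position stays below 2048
```

Multiplying `t` by `self.scale` and dividing `position_ids` by a scaling factor are therefore the same operation written two ways.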
References:
1. https://zhuanlan.zhihu.com/p/679147878
2. https://blog.csdn.net/v_JULY_v/article/details/135072211
3. https://kaiokendev.github.io/til#extending-context-to-8k
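Baichuan-13B uses ALiBi positional encoding, so length extrapolation for it targets ALiBi.
Some tests suggest that performance is best when the maximum extrapolated length is about 8 times the training length (see the comment section).
Implementation code and steps: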
https://github.com/seanzhang-zhichen/baichuan-Dynamic-NTK-ALiBi
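For intuition only, the sketch below applies the same interpolation idea to ALiBi by shrinking the per-head slopes. It is an assumption made by analogy with the RoPE case and is not taken from the repository above, whose name suggests a Dynamic NTK variant. The slope formula shown holds for head counts that are powers of two, and the helper names and the 4096 training length are hypothetical.

```python
def alibi_slopes(num_heads):
    """ALiBi slopes for a power-of-two number of heads.

    For 8 heads this gives 1/2, 1/4, ..., 1/256.
    """
    start = 2 ** (-8 / num_heads)
    return [start ** (i + 1) for i in range(num_heads)]


def alibi_bias(slope, query_pos, key_pos):
    """Linear attention penalty that grows with the query-key distance."""
    return -slope * (query_pos - key_pos)


def interpolated_slopes(num_heads, train_len, seq_len):
    """Assumed interpolation: shrink slopes when seq_len exceeds train_len."""
    scale = min(1.0, train_len / seq_len)
    return [s * scale for s in alibi_slopes(num_heads)]


train_len, seq_len = 4096, 8192   # assumed lengths for illustration
plain = alibi_slopes(8)
scaled = interpolated_slopes(8, train_len, seq_len)

# Largest penalty seen in training vs. at the extended length.
print(alibi_bias(plain[0], train_len - 1, 0))
print(alibi_bias(scaled[0], seq_len - 1, 0))
```

With the slopes halved, the largest penalty at length 8192 stays close to the largest penalty seen at the training length.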
References:
1. https://zhuanlan.zhihu.com/p/657161287
2. https://zhuanlan.zhihu.com/p/647628295