Introduction
Large Language Models (LLMs) are trained with a predetermined context length, which limits their applicability in situations that demand lengthy inputs. The usual remedy is to fine-tune the model on sequences of the desired length, which incurs significant training costs. To decouple training length from target length and extend the context window more efficiently, we use Positional Skip-wisE (PoSE) training, which cleverly emulates long inputs while keeping a fixed context window.
Drawbacks of existing methods
Position Interpolation (PI) was one of the first techniques designed to extend the context window. It down-scales the position indices so they align with the original window size, which yields markedly better results when extending the context. There are multiple approaches to implementing Position Interpolation (Linear, NTK, YaRN, etc.). However, all of them rely on full-length fine-tuning, i.e., fine-tuning on sequences of the target length, which is demanding on memory and time because computational complexity grows quadratically with input length. These high computational costs make it impractical to extend the context window to extreme lengths.
PoSE’s key advantage is that it is the first cost-effective technique to extend the context window to extreme lengths, such as 128K. It substantially reduces memory and time overhead compared to full-length fine-tuning while only minimally affecting performance.
Why PoSE?
The key idea of PoSE is to simulate long inputs by manipulating position indices within a fixed context window. Figure 1 shows how the original context window is divided into chunks. We then adjust the position indices of each chunk by adding a distinct skipping bias term. The bias terms and chunk lengths are changed for every training example, which helps the model adapt to all positions (both absolute and relative) in the target context window through fine-tuning. At the same time, PoSE keeps position indices contiguous within each chunk, resembling pre-training, so the model retains its pre-trained capacity for language modeling and comprehension.
The above image is from the paper: https://arxiv.org/abs/2309.10400
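To make the idea concrete, here is a minimal numerical illustration. The chunk lengths and bias values below are hypothetical, chosen only for readability; PoSE samples them randomly for every example, as described later.

```python
# Minimal illustration with hypothetical numbers: an 8K training window is split
# into two chunks whose position indices are shifted to cover a 32K target window.
original_window = 8192    # Lc: context length the model was pre-trained with
target_window = 32768     # Lt: context length we want the model to support

chunk1_len, chunk2_len = 4096, 4096   # chunk lengths (re-sampled for each example)
bias1, bias2 = 0, 24576               # skipping bias terms (re-sampled for each example)

# Each chunk keeps its original, contiguous indices, shifted by its bias term.
chunk1_ids = [i + bias1 for i in range(chunk1_len)]               # 0 .. 4095
chunk2_ids = [chunk1_len + i + bias2 for i in range(chunk2_len)]  # 28672 .. 32767

position_ids = chunk1_ids + chunk2_ids   # only 8192 tokens are processed,
                                         # yet the indices span the full 32K range
assert max(position_ids) == target_window - 1
```

Although each example still fits in the original 8K window, the relative distance between the two chunks here is as large as it would be in a genuine 32K input.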
The advantages of PoSE are threefold:
1) Memory and Time Efficiency: During fine-tuning, PoSE avoids the quadratic growth in computational complexity that comes with the target length. It only requires the original context size, reducing memory and time overhead.
2) Potential for Extremely Long Context: PoSE allows us to increase the context window of Mistral 7B by up to four times (8K→32K) while maintaining a reasonable level of language modeling and understanding capability.
3) Compatible with all RoPE-based LLMs and PI strategies: PoSE’s efficacy is empirically validated across several representative RoPE-based LLMs, including LLaMA, LLaMA2, GPT-J, Baichuan (Baichuan, 2023), and Mistral.
Compatibility of PoSE with Fundamental Building Blocks:
Positional Encoding (PE):
- Vanilla Positional Encoding: Vanilla positional encoding in the Transformer model uses sine and cosine functions to embed information about the position of tokens in a sequence, allowing the model to capture sequential relationships and understand the order of input tokens.
- Rotary Position Embedding (RoPE): RoPE is popularly used in LLMs such as LLaMA, GPT-J, Mistral-7B, and others. It encodes token position information using a rotation matrix that includes explicit relative position dependency.
PoSE is compatible with all RoPE-based LLMs.
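As a rough sketch of how RoPE works, each pair of dimensions of a query or key vector is rotated by an angle proportional to the token’s position index. This is a simplified interleaved-pair version, not the exact code used by LLaMA or Mistral, which use a mathematically equivalent “rotate-half” layout with cached cos/sin tables.

```python
import torch

def rotary_embed(x: torch.Tensor, position_ids: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Apply rotary position embedding to x of shape (seq_len, dim).

    Each dimension pair (2i, 2i+1) is rotated by the angle m * theta_i,
    where m is the token's position index and theta_i = base**(-2i/dim).
    """
    seq_len, dim = x.shape
    inv_freq = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)  # (dim/2,)
    angles = position_ids[:, None].float() * inv_freq[None, :]                # (seq_len, dim/2)
    cos, sin = angles.cos(), angles.sin()

    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin   # 2D rotation of each dimension pair
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

# Because only position_ids enter the computation, PoSE can substitute its
# manipulated indices without touching the model weights or architecture.
q = torch.randn(8192, 128)                  # one attention head: (seq_len, head_dim)
q_rotated = rotary_embed(q, torch.arange(8192))
```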
Position Interpolation (PI):
- Linear Interpolation: Linear interpolation proportionally down-scales the position index m to m/α. Consequently, the attention score between a query q at position m and a key k at position n becomes g(q, k, θ, (m−n)/α). Theoretical analysis has shown that the interpolated attention score is significantly more stable than its extrapolated counterpart.
- Neural Tangent Kernel (NTK) Interpolation: NTK interpolation modifies the base of RoPE, as opposed to linear interpolation, thereby changing the rotational “speed” of each RoPE dimension.
- YaRN Interpolation: In contrast to Linear and NTK interpolation, which treat every RoPE dimension equally, YaRN uses a ramp function to combine Linear and NTK interpolation in different proportions across dimensions. It also introduces a temperature factor to counteract the distribution shift in the attention matrix caused by long inputs.
PoSE is compatible with all three methods mentioned above – Linear, NTK, and YaRN interpolation.
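The sketch below (simplified formulas under common conventions, not the exact derivations from the respective papers) shows where the first two strategies intervene: Linear interpolation rescales the position indices themselves, whereas NTK interpolation leaves the indices alone and enlarges the RoPE base so every dimension rotates more slowly. YaRN, which blends the two with a ramp function and adds a temperature factor, is omitted for brevity.

```python
import torch

def linear_interpolated_positions(position_ids: torch.Tensor, alpha: float) -> torch.Tensor:
    """Linear (PI) interpolation: down-scale position index m to m / alpha,
    so a relative distance (m - n) becomes (m - n) / alpha."""
    return position_ids.float() / alpha

def ntk_scaled_base(base: float, alpha: float, dim: int) -> float:
    """NTK-aware interpolation: keep the position indices, but enlarge the RoPE
    base, which slows the rotation of each dimension. The exponent follows the
    commonly used NTK-aware scaling rule."""
    return base * alpha ** (dim / (dim - 2))

# Example: extending an 8K model to 32K corresponds to a scaling factor of 4.
alpha = 32768 / 8192
scaled_ids = linear_interpolated_positions(torch.arange(32768), alpha)  # 0, 0.25, 0.5, ...
new_base = ntk_scaled_base(base=10000.0, alpha=alpha, dim=128)          # ≈ 40.9k for these values
```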
Deep dive into PoSE
Extending the context window of an LLM exposes it to out-of-distribution position indices, and fine-tuning directly on extremely long sequences to address this incurs impractical computational costs. PoSE tackles both problems by manipulating position indices within the original context window to simulate longer inputs, adhering to two key design principles:
- Comprehensive coverage of relative position distances within the range {1, …, Lt − 1}, where Lt is the target context length.
- Preservation of the model’s original capabilities by closely mirroring the original position index structure.
In the Positional Skip-wisE (PoSE) method, the original context window of length Lc is split into N chunks. For each chunk, we introduce a skipping bias term sampled from a discrete uniform distribution. This bias term shifts the chunk’s original position indices, broadening the range of relative positions the model is exposed to.
To ensure extensive coverage of the target context window, the length and skipping bias term of each chunk are re-sampled for every training example. At the same time, position indices stay contiguous within each chunk, closely resembling the pre-training structure, which preserves the model’s pre-trained language modeling and comprehension abilities when it is fine-tuned on the new position indices.
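A minimal sketch of how PoSE-style position indices might be sampled for one training example is given below. The chunking and bias-sampling scheme follows the description above but is simplified; the paper’s exact procedure (e.g., how many chunks are used and how lengths and biases are drawn) may differ in detail.

```python
import random

def sample_pose_position_ids(Lc: int, Lt: int, num_chunks: int = 2) -> list[int]:
    """Return Lc position indices that together span the target window Lt.

    The Lc-token training window is split into `num_chunks` contiguous chunks.
    Indices stay contiguous inside each chunk (as in pre-training), and each
    chunk is shifted by a skipping bias so that relative distances up to
    Lt - 1 can be observed across training examples.
    """
    # 1) Randomly split the training window into chunk lengths that sum to Lc.
    cuts = sorted(random.sample(range(1, Lc), num_chunks - 1))
    chunk_lens = [b - a for a, b in zip([0] + cuts, cuts + [Lc])]

    # 2) Sample non-decreasing skipping bias terms; the largest skip must fit in Lt.
    max_skip = Lt - Lc
    biases = sorted(random.randint(0, max_skip) for _ in range(num_chunks))

    # 3) Shift each chunk's original indices by its bias term.
    position_ids, start = [], 0
    for length, bias in zip(chunk_lens, biases):
        position_ids.extend(start + bias + i for i in range(length))
        start += length
    return position_ids

# Example: emulate a 32K target window using an 8K training window.
ids = sample_pose_position_ids(Lc=8192, Lt=32768)
assert len(ids) == 8192 and max(ids) <= 32767
```

Because both the cut points and the biases are redrawn for every example, repeated sampling gradually covers the whole target range of relative distances.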
Experiments and Results:
In this experiment, we use Positional Skip-wisE (PoSE) to increase the context window of Mistral 7B from 8K to 32K. The extended model shows strong results on both language modeling and passkey retrieval.
For each setting in the main experiments, we train Mistral 7B with the next-token prediction objective. Training runs for 1,000 steps with a global batch size of 64 on 8 A6000 GPUs, using DeepSpeed ZeRO stage 3.
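Conceptually, the only change to a standard next-token-prediction fine-tuning step is that the sampled PoSE position indices are passed to the model along with the input tokens. The hedged sketch below uses the Hugging Face transformers API; the model name, sequence handling, and hyperparameters are illustrative rather than our exact training code, and it assumes a position interpolation strategy for the 32K target is configured on the model separately.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mistralai/Mistral-7B-v0.1"     # illustrative; any RoPE-based causal LM
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

text = "..."                                  # one training document
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=8192)
seq_len = inputs.input_ids.shape[1]

# PoSE-style indices for this example (see the sampling sketch in the previous
# section), shaped (batch_size, seq_len) as expected by the forward pass.
pose_ids = torch.tensor([sample_pose_position_ids(Lc=seq_len, Lt=32768)])

outputs = model(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    position_ids=pose_ids,       # override the default contiguous 0..seq_len-1 indices
    labels=inputs.input_ids,     # causal LM loss: labels are shifted internally
)
outputs.loss.backward()          # in practice, wrapped in the DeepSpeed training loop
```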
Passkey Retrieval:
Table 1: The evaluation focuses on passkey retrieval effectiveness, highlighting the impact of varying context lengths on the models’ ability to extract crucial information. Our model excels at this task and handles context lengths of up to 32K, surpassing the original Mistral 7B, which passes the test cases only when the context is under 8K.
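For readers unfamiliar with the passkey retrieval test, the sketch below builds the kind of synthetic prompt typically used for it: a short numeric passkey is hidden at a random depth inside a long stretch of repetitive filler text, and the model is asked to repeat it. The filler sentences and phrasing are illustrative rather than the exact prompt used in our evaluation.

```python
import random

def build_passkey_prompt(n_filler: int, passkey: int | None = None) -> tuple[str, str]:
    """Build a synthetic passkey-retrieval prompt with roughly n_filler filler sentences."""
    passkey = passkey if passkey is not None else random.randint(10000, 99999)
    filler = ("The grass is green. The sky is blue. The sun is yellow. "
              "Here we go. There and back again. ")
    needle = f"The pass key is {passkey}. Remember it. {passkey} is the pass key. "

    insert_at = random.randint(0, n_filler)    # hide the needle at a random depth
    context = filler * insert_at + needle + filler * (n_filler - insert_at)
    prompt = ("There is important info hidden inside a lot of irrelevant text. "
              "Find it and memorize it.\n\n" + context +
              "\nWhat is the pass key? The pass key is")
    return prompt, str(passkey)

# Larger n_filler values stretch the prompt toward the 32K-token regime.
prompt, answer = build_passkey_prompt(n_filler=1500)
```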
Standard Benchmarking:
Table 2: Our model extends the context window to 32K while experiencing only a marginal impact on standard benchmark accuracy, demonstrating that it can handle longer contexts without significantly compromising overall performance.
Conclusion:
In conclusion, we successfully used Positional Skip-wisE (PoSE) training to extend the context window of Mistral 7B. PoSE simulates long inputs purely by manipulating position indices, requires only the original context window for fine-tuning, and thereby decouples training length from target length. Experiments show that, compared with full-length fine-tuning, PoSE significantly reduces memory and time overhead. Using it, we scaled Mistral 7B to a 32K context window on 8 A6000 GPUs with only minor performance degradation on standard benchmarks. As discussed above, PoSE has also been empirically shown to work with representative RoPE-based LLMs and position interpolation strategies.
The model is published on Hugging Face: mistral-7B-PoSE-32k