RetNet (The Transformer Killer) Demystified
TL;DR
In summary, Retentive Network (RetNet) exhibits similar performance to the Transformer of the same size, can be trained in parallel, and supports a recurrence mode that enables O(1) inference complexity per token.
For a comprehensive yet unofficial implementation, please refer to the repository provided below:
Transformers have emerged as the predominant architecture for Large Language Models (LLMs) because they solve the sequential training bottleneck of recurrent neural networks (RNNs). Nevertheless, Transformers are not without flaws, as they only satisfy two corners of the “impossible triangle”. Enter Microsoft’s RetNet, which claims to sit at the exact center of this triangle, succeeding where previous attempts fell short. To underscore the significance of this claim:
- RetNet exhibits SUPERIOR language modeling performance.
- RetNet accomplishes this with a 3.4x reduction in memory usage.
- RetNet achieves an 8.4x increase in throughput.
- RetNet boasts a 15.6x decrease in latency.
These are large improvements over the current state of the art, up to roughly an order of magnitude in latency, while also delivering enhanced performance. If other teams are able to reproduce these results and if the model becomes available in the open-source community, this will undoubtedly have a monumental impact. For now, Microsoft holds a highly valuable asset.
So what makes RetNet work this well? In this blog post we answer that question by walking through each equation with a detailed example, showing how RetNet operates, where it differs from Transformers, and why it has the potential to become the dominant architecture. Prepare yourself for an intellectually stimulating experience, as we delve into a multitude of captivating concepts!
Quick History
Various methods have been developed to address the costly inference of generative transformers. Prominent contributions in this area include Linear Transformers, Attention-Free Transformers (AFT) by Apple, and RWKV by BlinkDL, which builds upon the AFT framework.
Motivation
The “impossible triangle” illustrates the difficulty current sequence models/LLMs have in achieving three crucial objectives at once: training parallelism, cost-effective inference, and high performance. Each edge of the triangle represents the two properties that an existing family of methods achieves, while missing the property at the opposite vertex. RetNet, however, claims to accomplish all three within a unified framework. To understand this motivation behind RetNet, let us go through each property in turn.
Parallelism
RNNs, as their name implies, process sequences in a sequential manner, one after the other. The computation at each time step relies on the hidden state from the previous time step, making it impossible to perform parallel computation until all preceding steps have been processed. Consequently, this significantly slows down the training process.
On the other hand, Transformers utilize the self-attention mechanism, which is highly parallelizable. This allows for parallel processing of outputs at each time step using the Q, K, and V matrices. However, this advantage of parallelization becomes a disadvantage during inference, as we will discuss later.
RetNet combines the strengths of both approaches by supporting three processing paradigms: parallel, recurrent, and chunk-wise recurrent. Training uses the parallel form, while inference uses the recurrent or chunk-wise form. It keeps the parallelizable structure of self-attention from Transformers, but with clever techniques to overcome its limitations (more details on this will be provided later).
Complexity (Memory + Inference)
Inference cost per time step covers GPU memory use, throughput, and latency, while memory complexity describes how memory requirements scale with the length of a sequence. RNNs update a fixed-size hidden state with a few inexpensive operations, such as matrix multiplications, so their per-token inference cost stays constant regardless of sequence length (i.e. O(1)), and their memory complexity scales linearly with sequence length (O(N)).
Transformers, on the other hand, rely on self-attention, which compares every token with every other token and therefore builds an N×N score matrix over the sequence. At inference time each new token must attend to all N previous tokens (in practice through a key-value cache that grows with the sequence), so the per-token inference cost grows linearly (O(N)), and memory complexity grows quadratically (O(N²)).
RetNet keeps the parallel training and cutting-edge performance of Transformer-style blocks without inheriting these inference-cost and memory problems. It does so by replacing the self-attention module with a retention module and adopting a recurrent inference paradigm, which lets it reproduce at inference time the same computation that was trained in parallel.
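To make the difference concrete, here is a rough back-of-the-envelope sketch of how much state each approach carries per layer while decoding. The function names and the single-head simplification are my own assumptions for illustration, not something taken from the paper.

```python
def kv_cache_floats(n_tokens: int, d_model: int) -> int:
    """Values cached by attention after n_tokens: one key row and one value row per token."""
    return 2 * n_tokens * d_model


def retention_state_floats(d_key: int, d_value: int) -> int:
    """Values held by a recurrent retention state: one d_key x d_value matrix, independent of length."""
    return d_key * d_value


# Hypothetical sizes, chosen only to show the trend.
for n in (1_000, 10_000, 100_000):
    print(n, kv_cache_floats(n, 64), retention_state_floats(64, 64))
```

The first count grows with every generated token, while the second stays fixed, which is the O(N) versus O(1) per-token behaviour described above.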
Performance
The primary benefit of Transformers over RNNs is their ability to handle longer sequences without losing earlier context (the forgetting problem that affects RNNs), thanks to their self-attention heads. RetNet, in turn, achieves comparable or superior performance to Transformers, and we will soon delve into the details behind this claim.
RetNet vs Transformers
RetNet combines the strengths of both approaches. By using the parallelizable training approach of Transformers, it avoids the slow, inefficient auto-regressive training step of RNNs. In a moment, we will examine the subtle yet significant modification it makes to the original self-attention computation, which is replaced with a retention mechanism during training. During inference, that same retention mechanism lets RetNet switch to the more memory- and compute-efficient recurrent paradigm of RNNs.
Step 1: Parallelization
RetNet utilizes parallel representation learning from original transformers during training to move away from the constrained auto-regressive sequence processing of RNNs. However, it introduces a few modifications to the overall process. The summary diagrams above can help us identify these changes.
One noticeable change is that RetNet replaces the softmax operation with a Hadamard product involving a newly introduced D-matrix, followed by a GroupNorm operation. This deviation from the traditional approach may seem peculiar. The softmax operation has been the foundation of self-attention, which has contributed to the state-of-the-art performance of Transformers. It assigns relative attention weights to each token in the input sequence, helping the model learn and retain long-term dependencies. However, it is worth recalling that the computation of softmax(QKᵀ) is also the reason behind the suboptimal inference-time performance of Transformers, because it requires the N×N matrix, which grows quadratically with the sequence length. Ironically, the very feature that gives Transformers an advantage during training and in downstream performance becomes their greatest obstacle during inference.
Prior attempts to address this issue have focused on approximating the softmax operation, but such approaches have typically resulted in decreased performance. RetNet instead introduces the combination of the D-matrix and GroupNorm, which achieves similar or even superior performance compared to Transformers. It also exhibits significantly faster inference and better memory efficiency, and it can still be trained in parallel.
How RetNet Works
RetNet is a plug-and-play modification in the Transformer architecture that substitutes “attention” with “retention”. I will review them in a systematic manner, starting from the top.
RetNet block
RetNet is composed of multiple stacks of identical blocks at its highest level. Each of these blocks includes MultiScaleRetention (MSR) and FeedForwardNetwork (FFN), along with layer-norm and skip-connections, similar to Transformers. The FFN in RetNet is also very similar to Transformers, as it consists of a 2-layer MLP with a hidden dimension size equal to 2 times the embedding size, and utilizes gelu activation.
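Written out, the block structure described above looks roughly like this. This is a sketch assuming the pre-norm arrangement I recall from the paper; the symbols X^l, Y^l, W_1 and W_2 are notation I am assuming here, since the original equation is not reproduced in this post.

```latex
\begin{aligned}
Y^{l}   &= \mathrm{MSR}\big(\mathrm{LN}(X^{l})\big) + X^{l} \\
X^{l+1} &= \mathrm{FFN}\big(\mathrm{LN}(Y^{l})\big) + Y^{l} \\
\mathrm{FFN}(X) &= \mathrm{gelu}(X W_1)\, W_2
\end{aligned}
```

Here LN is layer-norm, the "+ X" terms are the skip-connections, and W_1, W_2 are the two layers of the MLP.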
If we replace MSR with MultiHeadAttention, RetNet essentially becomes a Transformer. Therefore, the distinctions between RetNet and Transformers lie primarily in the MSR component.
Gated Multiscale Retention
Multi-Scale is similar to Multi-Head. In the given equation, γ represents a hyperparameter utilized for retention, and it is defined independently for each head. Prior to the group norm, this corresponds to conventional multi-head attention, but with retention.
Gated MSR adds a group norm, a swish gate, and an output projection on top, which can be regarded as an additional design choice. (Group norm enables the scaled dot-product, but that detail is not crucial at this moment.) The most significant difference, the retention module, is yet to be discussed.
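To give a feel for what "multi-scale" means, here is a small sketch of per-head decay rates. The schedule γ_h = 1 − 2^(−5−h) is the one I recall from the paper, so treat it as an assumption; the function names are hypothetical.

```python
def head_decays(num_heads: int) -> list[float]:
    """One decay rate per head, assuming the schedule gamma_h = 1 - 2^(-5 - h)."""
    return [1.0 - 2.0 ** (-5 - h) for h in range(num_heads)]


def effective_horizon(gamma: float) -> float:
    """Rough number of past tokens a head 'remembers': the sum of gamma^k over k, i.e. 1 / (1 - gamma)."""
    return 1.0 / (1.0 - gamma)


for gamma in head_decays(4):
    print(gamma, effective_horizon(gamma))
```

Heads with γ closer to 1 decay more slowly and so look further back, which is where the different "scales" come from.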
Retention
This block comprises parallel, recurrent, and chunk-wise recurrent modules.
Parallel Retention
The focus should be on the final line. Once again, without considering D, this is dot-product attention without the softmax function. Therefore, the crucial information lies in D and Theta.
- Theta (and bar(Theta), its complex conjugate) is the complex form of the “xpos” encoding. This is an extension of rotary embedding that helps the model extrapolate to longer sequence lengths. In non-complex space, there is an equivalent representation, also known as xpos, which is based on RoPE.
- D refers to the causal masking and decay matrix.
- The exponent corresponds to the number of times the previous token representation has been decayed. The significance of this will become clearer when we examine the recurrent representation.
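For reference, the parallel form that these bullets describe can be written as follows. This is my reconstruction from the definitions in this post and my reading of the paper, so the exact notation is an assumption.

```latex
\begin{aligned}
Q &= (X W_Q) \odot \Theta, \qquad K = (X W_K) \odot \bar{\Theta}, \qquad V = X W_V \\
\Theta_n &= e^{i n \theta}, \qquad
D_{nm} = \begin{cases} \gamma^{\,n-m}, & n \ge m \\ 0, & n < m \end{cases} \\
\mathrm{Retention}(X) &= \big(Q K^{\top} \odot D\big)\, V
\end{aligned}
```

The last line is the one to focus on: without D it is dot-product attention with the softmax removed.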
In the original Transformer framework, the output is generated by applying affine transformations to the input embedding X using the WQ, WK, and WV matrices. This is followed by taking the softmax of the resulting QKᵀ and multiplying it with the V matrix. The result is stored in the output matrix O, which contains the contextualized embeddings of the input matrix X.
The authors of RetNet introduce the “retention” block, which operates in both a recurrent and parallelized manner. They first explain the recurrent setting, where each input element is processed individually. Then, they propose a vectorized version of the recurrent retention block.
The RetNet formulation closely resembles the original Transformer formulation, with a few changes. Most notably, the softmax operation is gone and a positional embedding term (pos) is introduced in its place. The concept of pos and its purpose will be discussed in more detail shortly.
In the equation, pos’ represents the complex conjugate of pos. By further simplifying the equation with γ as a scalar value, the computation can be easily parallelized during training iterations, as shown in the image.
The first step of obtaining Q, K, and V in RetNet is the same as in the original Transformer. However, now the pos/pos’ embedding is multiplied element-wise with the Q and K matrices. The details of pos/pos’ and the D-matrix will be explained shortly. Looking at the final step of the parallel training phase formulation, it closely resembles the original Transformer computation, except for the substitution of softmax with the D-matrix. This formulation allows for complete parallelization, as the D-matrix can be pre-computed since it only consists of a relative positional embedding and a causal mask representation.
In summary, RetNet can be trained in a parallel paradigm with minor changes. However, further exploration is needed to understand the details of pos/pos’ and the D-matrix.
Relative Positional Embeddings
We do not need to delve into excessive detail regarding these positional embeddings, as they derive their intuition and functionality from the original positional embeddings of Transformers/LLMs. However, to gain a better understanding of the equation at hand, let us explore it briefly.
The Θ parameter incorporates “relative positional information” into each vector of the Q and K matrices through vector rotation. This imbues them with positional awareness and is achieved by taking the Hadamard product between the Q and K vectors and their respective position-specific vector rotations, as depicted in the provided image.
The Q_n and K_m vectors at each position are rotated, with the red arrows acting as rotation vectors. Referring to the accompanying vector rotation diagram in Equation 5, e^{inθ} and e^{imθ} with n = m = 1 each undergo a single rotation. These are the rotation vectors for the Q_1 and K_1 positions. Similarly, for the n = 2 and m = 2 positions, the vectors undergo a double rotation. The dot product between vectors with the same rotation (i.e., positions along the diagonal) equals 1. When n = 1 and m = 2, the dot product is taken between two differently rotated vectors and yields a value specific to that pair of positions. It is worth noting that as we progress further (e.g., m = 2, n = 1, etc.), the dot product tends toward 0 as the vectors become orthogonal to each other.
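The rotation argument above can be checked numerically with a few lines of Python. This is a minimal sketch using a single rotation angle; the value of `theta` and the function names are hypothetical choices for illustration.

```python
import cmath


def rotation(position: int, theta: float) -> complex:
    """Unit-length rotation e^{i * position * theta} for one position."""
    return cmath.exp(1j * position * theta)


def relative_phase(n: int, m: int, theta: float) -> complex:
    """Query rotation at position n times the conjugate of the key rotation at position m."""
    return rotation(n, theta) * rotation(m, theta).conjugate()


theta = 0.3  # hypothetical angle, chosen only for illustration
same = relative_phase(2, 2, theta)       # diagonal entry: the two rotations cancel
shifted_a = relative_phase(3, 1, theta)  # offset n - m = 2
shifted_b = relative_phase(7, 5, theta)  # same offset at different absolute positions
print(abs(same - 1) < 1e-12)
print(abs(shifted_a - shifted_b) < 1e-12)
```

The second check is the point of the exercise: the product depends only on the offset n − m, which is what makes the encoding "relative".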
Why do we only focus on the lower triangle in the figure above? The reasoning behind this will become clearer in the upcoming sections, but in short, it is because, at each time step, we are solely interested in and wish to “attend to” the information from past time steps.
Now, let’s move on from the part where Θ is element-wise multiplied with each vector in Q and K to confer them with “position awareness”. Next, we will explore the purpose of the proposed D-matrix.
The D-matrix functions as both a causal mask and an exponential decay weighting scheme for past positions.
Based on the definition of D in Equation 6, we observe that D in retention computation performs the tasks that masked attention and softmax accomplish in self-attention.
Masked attention (causal mask): With n as the query position and m as the key position, the entries of QKᵀ where n < m are multiplied by 0 to ensure the causal assumption of sequence processing holds true. This guarantees that information from future time steps does not leak into the current time step.
Softmax (exponential decay): The entries of QKᵀ where n >= m are weighted by γ^(n−m), an exponentially decaying factor. This implies that the further a token is in the past, the less significant it is for the current time step. This accomplishes the task of assigning different weights to the information from previous time steps, similar to what softmax achieves in self-attention. Although this approach is more constrained and less flexible than softmax due to its restrictive assumption, the authors have demonstrated its effectiveness.
By combining the operations illustrated in the figure above using the given Hadamard product, we achieve the final step of the parallel operation. Now, it becomes apparent why the upper triangle of the “position-aware” QKᵀ is disregarded, as these values are set to 0 after the operation with D. The entire operation can be executed in a completely parallelizable manner during training.
Working example for the parallel training
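Here is a small numeric sketch of the parallel form, (QKᵀ ⊙ D)V, on a two-token input. To keep it short I am assuming real-valued Q, K and V, so the Θ rotation and the GroupNorm step are left out, and all numbers are made up for illustration.

```python
def matmul(a: list[list[float]], b: list[list[float]]) -> list[list[float]]:
    """Plain matrix product of two nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a: list[list[float]]) -> list[list[float]]:
    return [list(col) for col in zip(*a)]


def parallel_retention(q, k, v, gamma: float):
    """Compute (Q K^T * D) V where D[i][j] = gamma^(i - j) for i >= j and 0 otherwise."""
    n = len(q)
    scores = matmul(q, transpose(k))
    masked = [
        [scores[i][j] * gamma ** (i - j) if i >= j else 0.0 for j in range(n)]
        for i in range(n)
    ]
    return matmul(masked, v)


# Hypothetical two-token example with 2-dimensional queries, keys and values.
Q = [[1.0, 2.0], [3.0, 4.0]]
K = [[1.0, 1.0], [2.0, 0.0]]
V = [[1.0, 0.0], [0.0, 2.0]]
print(parallel_retention(Q, K, V, gamma=0.5))  # [[3.0, 0.0], [3.5, 12.0]]
```

Step by step: QKᵀ is [[3, 2], [7, 6]], D is [[1, 0], [0.5, 1]], their element-wise product is [[3, 0], [3.5, 6]], and multiplying by V gives the printed result.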
Recurrent Retention
S_n is similar to the KV-cache in Transformers. However, instead of concatenating keys and values sequentially, RetNet folds them into a single matrix using the recurrence in the first line. This matrix is then multiplied by the current query.
This process is essentially the same as parallel Retention.
Here is an informal proof sketch:
Let S_0 be equal to 0. By solving the recurrence of S_n,
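The unrolled form can be written out as below. This is my reconstruction, assuming the recurrence is S_n = γ S_{n−1} + K_nᵀ V_n with the output given by Q_n S_n, which is how I read the paper.

```latex
\begin{aligned}
S_n &= \gamma S_{n-1} + K_n^{\top} V_n, \qquad S_0 = 0 \\
    &= \sum_{i=1}^{n} \gamma^{\,n-i} K_i^{\top} V_i \\
\mathrm{Retention}(X_n) &= Q_n S_n = \sum_{i=1}^{n} \gamma^{\,n-i}\, Q_n K_i^{\top} V_i
\end{aligned}
```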
Recall the last row of the exponent matrix for the D in the parallel representation, which was [3, 2, 1, 0]. Notice that n is equal to 4. When we calculate the retention for the 4th token compared to the 1st token, we decay it 3 times, which is equivalent to n − i = 3 in the equation above. Since the rest of the equation remains the same, the parallel and recurrent representations are identical to each other. For some more reworked examples on Recurrent Retention you can refer to this article.
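The equivalence is easy to check numerically. The sketch below implements both forms for real-valued inputs (again leaving out the Θ rotation and GroupNorm) and compares them on a made-up four-token sequence; the function names and numbers are hypothetical.

```python
import math


def parallel_retention(q, k, v, gamma: float):
    """Direct sum: out[n] = sum over i <= n of gamma^(n - i) * (q[n] . k[i]) * v[i]."""
    out = []
    for n in range(len(q)):
        row = [0.0] * len(v[0])
        for i in range(n + 1):
            weight = gamma ** (n - i) * sum(a * b for a, b in zip(q[n], k[i]))
            for c in range(len(row)):
                row[c] += weight * v[i][c]
        out.append(row)
    return out


def recurrent_retention(q, k, v, gamma: float):
    """One token at a time: S = gamma * S + k^T v, then out = q S."""
    d_k, d_v = len(k[0]), len(v[0])
    state = [[0.0] * d_v for _ in range(d_k)]  # S_0 = 0
    out = []
    for q_n, k_n, v_n in zip(q, k, v):
        state = [
            [gamma * state[a][b] + k_n[a] * v_n[b] for b in range(d_v)]
            for a in range(d_k)
        ]
        out.append([sum(q_n[a] * state[a][b] for a in range(d_k)) for b in range(d_v)])
    return out


# Hypothetical four-token sequence.
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]
K = [[1.0, 2.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [1.0, -1.0]]
par = parallel_retention(Q, K, V, gamma=0.9)
rec = recurrent_retention(Q, K, V, gamma=0.9)
print(all(math.isclose(x, y, abs_tol=1e-9) for p, r in zip(par, rec) for x, y in zip(p, r)))
```

Note that the recurrent version only ever stores the fixed-size `state`, regardless of how many tokens it has seen.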
Chunkwise Retention
Although it may appear intricate, the process involves parallel computation per chunk and the recurrent connection of the chunks. The crucial factor to consider is the number of decays applied.
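To show where the decays go, here is a minimal sketch of the chunk-wise form for real-valued inputs, checked against the direct sum. The decay placement (γ^(r+1) on the carried state for the r-th token of a chunk, γ^B on the state between chunks of size B) follows my own derivation from the recurrent form, so treat the details as an assumption rather than the paper's exact notation.

```python
import math


def reference_retention(q, k, v, gamma: float):
    """Direct sum over all past tokens, used only as a correctness check."""
    return [
        [
            sum(
                gamma ** (n - i) * sum(a * b for a, b in zip(q[n], k[i])) * v[i][c]
                for i in range(n + 1)
            )
            for c in range(len(v[0]))
        ]
        for n in range(len(q))
    ]


def chunkwise_retention(q, k, v, gamma: float, chunk_size: int):
    d_k, d_v = len(k[0]), len(v[0])
    state = [[0.0] * d_v for _ in range(d_k)]  # summary of all previous chunks
    out = []
    for start in range(0, len(q), chunk_size):
        qc = q[start:start + chunk_size]
        kc = k[start:start + chunk_size]
        vc = v[start:start + chunk_size]
        b = len(qc)  # the last chunk may be shorter
        for r in range(b):
            row = [0.0] * d_v
            # Inner-chunk part: the parallel form restricted to this chunk.
            for s in range(r + 1):
                weight = gamma ** (r - s) * sum(x * y for x, y in zip(qc[r], kc[s]))
                for c in range(d_v):
                    row[c] += weight * vc[s][c]
            # Cross-chunk part: the carried state, decayed r + 1 more times.
            for c in range(d_v):
                row[c] += gamma ** (r + 1) * sum(qc[r][a] * state[a][c] for a in range(d_k))
            out.append(row)
        # Carry the state forward: decay it by the chunk length, add this chunk's keys and values.
        state = [
            [
                gamma ** b * state[a][c]
                + sum(gamma ** (b - 1 - s) * kc[s][a] * vc[s][c] for s in range(b))
                for c in range(d_v)
            ]
            for a in range(d_k)
        ]
    return out


# Hypothetical five-token sequence.
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
K = [[1.0, 2.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [2.0, 1.0]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [1.0, -1.0], [0.0, 3.0]]
ref = reference_retention(Q, K, V, 0.9)
chunked = chunkwise_retention(Q, K, V, 0.9, chunk_size=2)
print(all(math.isclose(x, y, abs_tol=1e-9) for a, b in zip(ref, chunked) for x, y in zip(a, b)))
```

With a chunk size of 1 this reduces to the recurrent form, and with a chunk size equal to the sequence length it reduces to the parallel form.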
Why Decay?
The key aspect to note is that it utilizes a concept known as decay, which, when applied the appropriate number of times, enables parallelization. However, it is crucial to comprehend the underlying motivation for employing such decay. The derivation, at a high level, is rather straightforward.
We define the recurrent state (s_n) as kv_cache. The first line in the provided equation represents the recurrence relationship. Additionally, we define the output at time n as Q_n * s_n. The second line in the equation illustrates this and solves the recurrence to unveil the complete dependency. It is worth noting that the A matrix is applied multiple times.
Next, we diagonalize the A matrix as follows.
Consequently, the Λ factors can be absorbed into other learnable parameters (Q_n = X_n * W_Q and K_n = X_n * W_K, hence Λ can be absorbed into W_Q and W_K!). Only the middle part remains.
This middle part precisely corresponds to the γ (decay) and theta that were previously observed.
In essence, they function as a “closed-form positional encoding” that also possesses a recurrent-form, enabling the computation of the encoding at time n in advance, thereby facilitating parallelization.
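Putting the steps of this section together, the derivation runs roughly as follows. This is my reconstruction of the argument as I understand it from the paper; the index m and the dagger for the conjugate transpose are notation I am assuming.

```latex
\begin{aligned}
s_n &= A\, s_{n-1} + K_n^{\top} v_n \\
o_n &= Q_n s_n = \sum_{m=1}^{n} Q_n A^{\,n-m} K_m^{\top} v_m \\
A &= \Lambda \big(\gamma e^{i\theta}\big) \Lambda^{-1}
  \;\;\Rightarrow\;\;
  A^{\,n-m} = \Lambda \big(\gamma e^{i\theta}\big)^{n-m} \Lambda^{-1} \\
o_n &= \sum_{m=1}^{n} Q_n \big(\gamma e^{i\theta}\big)^{n-m} K_m^{\top} v_m
  \qquad (\Lambda \text{ absorbed into } W_Q,\, W_K) \\
    &= \sum_{m=1}^{n} \gamma^{\,n-m} \big(Q_n e^{i n \theta}\big) \big(K_m e^{i m \theta}\big)^{\dagger} v_m
  \qquad (\gamma \text{ a scalar})
\end{aligned}
```

The last line is exactly the decay γ^(n−m) combined with the rotations on Q and K, which is what the parallel form computes in one shot.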
Results from the paper
In our exploration of the inner workings of each component of the RetNet, we have provided working examples and explanations. The purpose of this blog is to equip readers with the necessary knowledge and mathematical understanding to delve deeper into the paper themselves and engage in further discussions and analysis.
Key Takeaways:
- RetNet demonstrates superior performance compared to the Transformer model as it scales up in size. It is uncertain whether this trend will continue in the future.
- RetNet outperforms other linear time transformers in terms of performance.
- The architecture of RetNet allows for fast processing, as expected.
- There are some missing details in the paper that will become clearer once the official code is published.
- The paper inaccurately dismisses the possibility of training parallelization in RWKV.
- It is uncertain whether this work will perform on par with LLaMA models.