Attention mechanisms and beyond

Diverger
May 17, 2024


by Aitor Mira

Self-attention

Quick review

Self-attention is the core of transformers and the reason behind the current wave of AI disruption driven by LLMs. Since this topic has been discussed at length, most of you have surely read plenty of articles about it by now. So I won't go into detail, but I will at least recall the basic idea and its components.

Core idea

Just as we humans tend to give more importance to certain words while reading because they carry more meaning, transformers try to do the same by “attending” more to certain words. This importance, or attention, is just a number (a score) that transformers learn by repeating a task over and over. That task is predicting the next word.

After some training you get a model able to assign importance relations between the words in a text. But remember, these importances are not fixed between any two words; they depend on the context (the surrounding words). This is what they look like:

Check some nice visualizations here.

The self in self-attention

As you may have guessed, the self in self-attention means that the word-importance calculations are performed within the same text. To be clear, transformers don't read any words other than the text you provide.

The technical components

The self-attention mechanism consists of three main matrices called Query, Key and Value. They work as follows:

  1. Query represents the terms for which you want to compute the attention value.
  2. Key holds the information linking queries and values.
  3. When Key and Query are multiplied (dot product), the result is the Attention Matrix.
  4. Value holds the actual content or meaning.
  5. Finally, multiplying Value by the Attention Matrix gives the Attention Output, or score.

This explanation is just an allegory of what is actually happening, which is three randomly initialized matrices being iteratively multiplied and magically working by the grace of deep learning.
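To make the allegory concrete, here is a minimal single-head sketch of those five steps in NumPy. Everything in it (the function name, the toy dimensions, the random projection matrices) is made up for illustration; real transformers learn Wq, Wk and Wv during training and add masking, multiple heads and other tricks on top.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, Wq, Wk, Wv):
    """Single-head self-attention over a sequence X of shape (N, d_model)."""
    Q = X @ Wq                                  # queries: (N, d_head)
    K = X @ Wk                                  # keys:    (N, d_head)
    V = X @ Wv                                  # values:  (N, d_head)
    scores = Q @ K.T / np.sqrt(Q.shape[-1])     # attention matrix: (N, N)
    attn = softmax(scores, axis=-1)             # each row sums to 1
    return attn @ V                             # attention output: (N, d_head)

# Toy run: 5 "words", model width 16, head size 8 (arbitrary numbers).
rng = np.random.default_rng(0)
N, d_model, d_head = 5, 16, 8
X = rng.normal(size=(N, d_model))
Wq, Wk, Wv = (rng.normal(size=(d_model, d_head)) for _ in range(3))
print(self_attention(X, Wq, Wk, Wv).shape)  # (5, 8)
```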

Keep it in memory

If you remember just one thing, keep in mind that in this process the memory bottleneck comes in the third step. Why? Because the dot product of Query and Key (the Attention Matrix) results in a matrix of N×N elements, where N is the number of words in the original input (the text itself). N×N = N², which means the memory needed grows quadratically with the length of the original text.

All the other multiplications in self-attention are at some point bounded by a constant term. For example, the Attention Output matrix is A×N, where A is the expected output dimension (a fixed number decided by the designers of the transformer). Comparing some cases:

  1. Case 1 with small texts:
    a. A = 64
    b. N = 32
    c. Attention matrix memory = 32*32 = 1024
    d. Attention output memory = 64*32 = 2048
  2. Case 2 with longer texts:
    a. A = 64
    b. N = 128
    c. Attention matrix memory = 128*128 = 16384
    d. Attention output memory = 64*128 = 8192

Note: it may seem obvious to you, but a common strategy for understanding complex mechanisms is to reduce them to toy problems with small numbers (as in the snippet below). If, like me, you try to do this, you'll have a hard time.
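Following that advice, here is the same toy calculation as a few lines of Python, using the same made-up A and N as in the two cases above:

```python
# Element counts (a rough proxy for memory) for the two toy cases above.
def attention_memory(A, N):
    return {"attention_matrix": N * N, "attention_output": A * N}

for N in (32, 128):
    print(N, attention_memory(A=64, N=N))
# 32  {'attention_matrix': 1024, 'attention_output': 2048}
# 128 {'attention_matrix': 16384, 'attention_output': 8192}
# Multiplying N by 4 multiplies the attention matrix by 16 (quadratic),
# but the attention output only by 4 (linear).
```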

Sliding Window Attention

If input text is too long, just cut it

Given that the problem comes from the length of the input text, why not just clip it to a maximum length? Okay: just set a window_size parameter and drop the words further than this parameter from the current word (i.e. window attention). Let's analyze the following example:

With the sentence “The cat is hunting a grey” and window_size = 3, if we try to continue the sentence (predict the next word), what the transformer will actually see is “hunting a grey”, with “The cat is” cropped out. Try, as a human, to complete the sentence in both cases. See the problem now?

Hint: search for both sentences on Google and see how substantially different the next-word prediction could be.

Precisely: just cropping the text results in information loss. How can we overcome this? Let's add the sliding concept to window attention (SWA)¹.

Core Idea

See the reference below²

Cropping the text results in information loss, but if we could figure out how to pass the result from one chunk of text on to the following chunks, we could escape the tradeoff. The idea would be something like “pay attention to a small chunk of text, then pay attention to the next chunk while keeping a vague idea of the previous attention”. In theory, if we perform this operation sequentially (in steps, or layers), we could cover an “infinite” text while keeping an idea of all the attention paid so far. Let's apply it to our example:

With window_size = 3 (counting the word itself, that's just looking 2 words behind), the words in the current window are shown in green and the words that are somehow remembered from previous iterations are shown in red.

This is how Sliding Window Attention handles long texts: it saves memory while keeping (or at least trying to keep) attention over the whole text.
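A minimal way to picture this is as a mask over the regular attention matrix: each word may only attend to itself and to the few words right behind it. The sketch below is just that mask in NumPy (the function name and sizes are invented for the example); a real SWA implementation also avoids materializing the masked-out entries, which is where the memory saving actually comes from.

```python
import numpy as np

def sliding_window_mask(n_tokens, window_size):
    """True where attention is allowed: each token sees itself and the
    previous window_size - 1 tokens, never anything ahead of it."""
    idx = np.arange(n_tokens)
    causal = idx[None, :] <= idx[:, None]                  # no peeking forward
    in_window = idx[:, None] - idx[None, :] < window_size  # stay inside the window
    return causal & in_window

print(sliding_window_mask(n_tokens=6, window_size=3).astype(int))
# Each row has at most 3 ones. Information from older tokens still reaches
# later tokens indirectly, one layer at a time: that's the "sliding" part.
```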

Keep it in memory

If we want to compare the efficiency of SWA with full self-attention, we have to understand that the memory consumed by the attention matrix at inference time with SWA is W×N, where W is the window size and N is the text length, while with full self-attention it is N×N.

Important: these savings come almost exclusively at inference time, not during training. During training there are other, more subtle savings, but they are out of the scope of this article.
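As a quick sanity check on that W×N claim, with some invented numbers:

```python
# Toy element counts for the attention matrix (W and N are made-up numbers).
N, W = 4096, 512                      # text length, window size
sizes = {"full self-attention": N * N, "sliding window": W * N}
print(sizes)                          # 16,777,216 vs 2,097,152 elements
print(sizes["full self-attention"] / sizes["sliding window"])  # 8.0, i.e. N / W
```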

Attention Sinks

Let it flow

Self-attention and Sliding Window Attention are great and work well, but there's always a but. When an LLM using the previous mechanisms handles a long multi-turn conversation, at some point it tends to “lose fluency”, which basically means it starts writing nonsense. Why? For some reason, the conversation starters (the first few words or messages) hold a large share of the attention score. These are what we call Attention Sinks.

Therefore, when they fall outside the window, the attention mechanism is forced to redistribute their attention score across the current window. Hence, some words of little importance gain importance, and over successive layers the model “loses track”.

This is not super intuitive, and it is mainly caused by the inner math of attention mechanisms, so don't overthink it; some smart people already did that for us³.

Core Idea

Since the problem arises when the first words fall outside the window, let's just keep them forever. We can force the SWA mechanism to always keep the first few words and instead drop words from the middle. Easy, right?

Here's a figure showing how the attention sinks mechanism works³:

Since the first words (tokens) are kept, their high importance is never lost and redistributed. This effectively solves the fluency problem in long multi-turn conversations.
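In practice this boils down to a cache-eviction rule over the stored keys and values. Below is a loose sketch of that rule; the function name, the list-of-entries cache and the choice of 4 sink tokens are all assumptions for illustration, see Xiao et al.³ for the actual recipe.

```python
def evict_kv_cache(cache, window_size, n_sink=4):
    """Keep the first n_sink entries (the attention sinks) plus the most
    recent window_size entries; drop everything in the middle.
    `cache` is a list of per-token KV entries, oldest first."""
    if len(cache) <= n_sink + window_size:
        return cache
    return cache[:n_sink] + cache[-window_size:]

# Toy run with token positions standing in for KV pairs.
print(evict_kv_cache(list(range(12)), window_size=5))
# [0, 1, 2, 3, 7, 8, 9, 10, 11]  -> sinks kept, middle dropped, recent kept
```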

Thoughts

I know this idea sounds weird. Even after long hours of study, it seems neither convincing nor intuitive to me. But think of it as a mathematical optimization for a specific problem.

Beware: this works exclusively for conversations whose goal is, precisely, to keep the conversation flowing. It won't work effectively for information retrieval over the middle part of the text, as stated by the paper's authors. Finally, something intuitive 😅

Grouped Query Attention

Stick similar things together

Remember the Query, Key, Value stuff I told you about in self-attention? Well, I hid something from you. Most transformers take this QKV attention architecture and create copies that are executed in parallel. These QKV blocks are called Heads, and the mechanism is called Multi-Head Self-Attention (MHSA). Why does this work? Easy: “four eyes see better than two”. Each Head learns to pay attention to different text patterns and hence improves understanding when the attention of all Heads is combined. But wait, if a single QKV block already had memory issues, this just multiplies the problem by H. How could we solve that?

Well, we don't have a full solution for that, but if we could reduce the number of Key/Value heads, we would save precious memory and compute power. That's where Grouped Query Attention comes in. Since query heads often end up computing similar things, we can group them so that each group of Query heads shares a single Key and Value head.

As shown in the figure⁴, by grouping Queries so that they share Key/Value heads, the number of Key/Value heads can be reduced, decreasing memory and computation roughly linearly.

So that's all? Well, yes, but the attention mechanism will now rely on the quality of the grouping, i.e. on how similar the grouped query heads really are. However, that's out of the scope of this article.
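To see what “grouping” means in tensor terms, here is a rough NumPy sketch in which 8 query heads share 2 key/value heads. All shapes and names are invented for the example; real implementations fuse this into the model's projection layers.

```python
import numpy as np

def grouped_query_attention(Q, K, V):
    """Q: (n_q_heads, N, d). K, V: (n_kv_heads, N, d).
    Each group of n_q_heads // n_kv_heads query heads shares one KV head."""
    n_q_heads, N, d = Q.shape
    group = n_q_heads // K.shape[0]
    K = np.repeat(K, group, axis=0)           # broadcast each KV head to its group
    V = np.repeat(V, group, axis=0)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d)        # (n_q_heads, N, N)
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)             # row-wise softmax
    return weights @ V                                    # (n_q_heads, N, d)

# 8 query heads sharing 2 KV heads: the KV memory shrinks 4x vs. full multi-head.
rng = np.random.default_rng(0)
Q = rng.normal(size=(8, 6, 16))
K, V = rng.normal(size=(2, 6, 16)), rng.normal(size=(2, 6, 16))
print(grouped_query_attention(Q, K, V).shape)  # (8, 6, 16)
```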

Ring Attention

Think bigger

Forget everything above (please don't, it's just a saying). We are not #GPUpoor anymore and have multiple devices (e.g. GPUs, TPUs, CPUs…). How could we leverage them effectively?

Back to the roots. Having multiple devices means having distributed memory and compute power. The limitations of self-attention come mainly from the attention matrix memory and the QKV computations. So, the solution? Distributing memory and computations across all devices. Brilliant, but how? Fortunately, some clever people from UC Berkeley came up with Ring Attention⁵.

Core Idea

This one is pretty easy and intuitive, I promise. Let's say we have N devices; the mechanism goes as follows (a rough code sketch follows the figure below):

  1. Each device receives an Nth fraction of the input (input/N). E.g. if the input has 400 words and we have 4 devices, each device receives 100 words.
  2. Each deviceᵢ performs local self-attention (QKV) on its input/N block. The result is Voutᵢ.
  3. Once finished, each device feeds the following device its K and V matrices.
  4. Self-attention is computed again on each deviceᵢ, but this time with its own Query (Qᵢ) and the K and V received from the previous device (Kᵢ₋₁, Vᵢ₋₁).
  5. Then, having Voutᵢ and the new Voutᵢ₋₁, we just combine them (a weighted sum).

Illustration of ring attention topology as explained by Coconut Mode⁶.
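Here is a single-process simulation of those five steps, with an online softmax doing the “weighted sum” of step 5 so that the final result matches full attention. This is a sketch under simplifying assumptions (no causal mask, no real devices or communication, invented names throughout), not the actual Ring Attention implementation.

```python
import numpy as np

def ring_attention(Q_blocks, K_blocks, V_blocks):
    """Simulate ring attention on one machine.
    'Device' i owns Q_blocks[i] and starts with K_blocks[i], V_blocks[i].
    KV blocks rotate around the ring; each device folds every block it sees
    into a running (online) softmax."""
    n_dev = len(Q_blocks)
    d = Q_blocks[0].shape[-1]
    # Per-device running accumulators for the online softmax.
    out = [np.zeros_like(q) for q in Q_blocks]
    denom = [np.zeros((q.shape[0], 1)) for q in Q_blocks]
    run_max = [np.full((q.shape[0], 1), -np.inf) for q in Q_blocks]

    K_cur, V_cur = list(K_blocks), list(V_blocks)
    for _ in range(n_dev):                      # one rotation step per device
        for i in range(n_dev):                  # "in parallel" on each device
            scores = Q_blocks[i] @ K_cur[i].T / np.sqrt(d)
            new_max = np.maximum(run_max[i], scores.max(-1, keepdims=True))
            scale = np.exp(run_max[i] - new_max)   # rescale old accumulators
            probs = np.exp(scores - new_max)
            out[i] = out[i] * scale + probs @ V_cur[i]
            denom[i] = denom[i] * scale + probs.sum(-1, keepdims=True)
            run_max[i] = new_max
        # Each device passes its current KV block to the next device in the ring.
        K_cur = K_cur[-1:] + K_cur[:-1]
        V_cur = V_cur[-1:] + V_cur[:-1]
    return [out[i] / denom[i] for i in range(n_dev)]

# Toy run: 4 "devices", 3 tokens each, width 8; Q = K = V for simplicity.
rng = np.random.default_rng(0)
blocks = [rng.normal(size=(3, 8)) for _ in range(4)]
print(np.vstack(ring_attention(blocks, blocks, blocks)).shape)  # (12, 8)
```

The list rotation plays the role of each device handing its K and V block to its neighbour; after as many steps as there are devices, every query block has seen every key/value block, yet no single device ever held the full N×N attention matrix.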

That's it, simple and effective. By leveraging the ring structure and the algorithm above we can effectively:

  • Reduce the memory needed on each device to a fraction of the input length.
  • Parallelize the calculations and aggregate them seamlessly.
  • Perform a pseudo-global attention without leaving any context out.

The benefits are wide and deep. The limitations mainly come from managing the hardware:

  • Many of us don't have multiple devices well suited for a ring topology.
  • Making the hardware work and communicate across multiple devices without latency problems is hard.

Bonus

Somehow, the way ring attention solves global context awareness reminds me of how SWA does it. Does it ring a bell for you too? 😁

One more thing: ring topologies are not only used in attention mechanisms within transformers. Word embeddings have a ring-like implementation called Rotary Positional Embedding (RoPE), but that's a bit out of scope.

References

  1. Beltagy, I., Peters, M. E., & Cohan, A. (2020). Longformer: The Long-Document Transformer.
  2. Jiang, A. Q., Sablayrolles, A., Mensch, A., Bamford, C., Chaplot, D. S., de las Casas, D., Bressand, F., Lengyel, G., Lample, G., Saulnier, L., Lavaud, L. R., Lachaux, M.-A., Stock, P., Le Scao, T., Lavril, T., Wang, T., Lacroix, T., & El Sayed, W. (2023). Mistral 7B.
  3. Xiao, G., Tian, Y., Chen, B., Han, S., & Lewis, M. (2023). Efficient Streaming Language Models with Attention Sinks.
  4. Ainslie, J., Lee-Thorp, J., de Jong, M., Zemlyanskiy, Y., Lebrón, F., & Sanghai, S. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.
  5. Liu, H., Zaharia, M., & Abbeel, P. (2023). Ring Attention with Blockwise Transformers for Near-Infinite Context.
  6. Coconut Mode, Ring Attention (blog post).

