Attention is all you need to understand
Unfortunately, it is also true that attention must be paid.
This AI boom / shockwave / revolution / proto-singularity (in order of increasing hype) is propelled by two technological breakthroughs, plus gargantuan amounts of data and computing power. Those breakthroughs—both, interestingly, simpler than the recurrent / generative adversarial / convolutional technologies they supplanted—are:
“diffusion” models, used to generate new images by beginning with noise and improving the quality across the entire image until you get something which is high- or even superb-quality, and what's more, matches the prompt. Dall-E, Midjourney, Stable Diffusion, etc.
“transformer” models, which generate their output one piece at a time, and use each newly generated fragment as part of the inputs for the next. ChatGPT, Bing, Bard, LLaMa, and above all, GPT-4.
Technically, transformers are in turn only one kind of "autoregressive" model ... but thus far, they stand head, shoulders, torso, hips, and knees above all alternatives. They, and the landmark paper “Attention is all you need” in which they were introduced, are the key breakthrough that made today’s large language models possible.
“Attention” means the mathematical mechanism at the heart of what transformers do. If “attention is all you need,” then for non-technical people to have informed opinions / make informed decisions about AI, perhaps attention is all they need to understand? …OK, fine, clearly not all, but I truly believe it’s important for everyone to understand not just what modern AI can do, but how.
I've previously written a two-part post explaining diffusion models, with which I'm still pretty happy. I've also written about transformers ... but there I kinda handwaved the attention mechanism, and relied too heavily on forbiddingly complex diagrams.
So here I'm overcompensating by setting myself the task of explaining transformer “attention” in plain English … with no math, no diagrams at all, and no code beyond a few strictly optional sketches that non-coders can skip entirely. Mostly just words. This feels a little like announcing one's intention to bust out a few dance moves with a refrigerator strapped to one's back. Wish me luck.
(And, to be clear, this is definitely going to be an eat-your-mental-veggies abstruse-technical-deep-dive post, rather than a fun one about crazy AI weirdness. More of those in the near future, I promise! And I swear I won’t think any less of you for bailing out if you’re not in a brain-health-food mood.)
(…OK, for those of you who stuck around:)
Training and Embeddings
To quickly summarize my previous “how AI / neural networks work in general” post: neural networks are composed of interconnected nodes, aka “neurons,” each of which has many inputs. Each node determines its output value by assigning a different weight to each input, multiplying each input by its corresponding weight, and adding up all those results. (When people say things like “GPT-3 is a model with 175 billion weights,” those weights are what they mean.)
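(If you happen to be a coder and would rather see that as code, here's a tiny sketch; everyone else can safely skip these occasional code asides. The numbers are made up purely for illustration, and real nodes also add a bias and a nonlinearity, which I'm leaving out.)

```python
# A single "neuron": multiply each input by its learned weight, then add it all up.
inputs  = [0.5, -1.2, 3.0]      # values arriving from other nodes (made up)
weights = [0.8,  0.1, -0.4]     # this node's learned weights (also made up)

output = sum(x * w for x, w in zip(inputs, weights))
print(output)  # 0.5*0.8 + (-1.2)*0.1 + 3.0*(-0.4) ≈ -0.92
```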
Training
We can train such networks by coming up with mathematical ways to express what we want from them. These are called loss functions. (None are shown here. This is a safe, math-free zone.) An elegant mathematical trick called the “chain rule” lets us take a neural network’s output and, for each individual weight, determine whether that weight needs to be increased or decreased to improve the output. The process of training neural networks — of repeatedly working out how much every weight needs to change, and nudging them all accordingly — is called “gradient descent,” and for large networks, it consumes an extraordinary amount of computing power.
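(For the coders: here is a toy, one-weight illustration of that nudging process. It's not how a real framework does it, just the idea, with a made-up example.)

```python
# Toy gradient descent on a single weight: learn w so that w * x matches y.
# The loss is (w*x - y)^2; the chain rule gives us the gradient below.
x, y = 2.0, 6.0        # one made-up training example: input 2, desired output 6
w = 0.0                # start with a bad guess
learning_rate = 0.1

for step in range(20):
    prediction = w * x
    gradient = 2 * (prediction - y) * x   # d(loss)/dw
    w -= learning_rate * gradient         # nudge w a little bit "downhill"

print(w)  # ends up very close to 3.0, since 3.0 * 2 = 6
```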
Embeddings
When you train neural networks on language, on words, it turns out they immediately construct embeddings. What are they? I’ve gone into more detail previously, but, briefly: an embedding is the transformation of an input token — a token can represent a single word or just a few letters — into an array of (e.g.) 128 numbers. Each of those numbers represents something different about the embedded token.
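(Again for the coders, a sketch: an embedding layer is essentially a giant lookup table. The table below is random, standing in for what training would actually learn, and the token IDs are invented.)

```python
import numpy as np

# An embedding layer is essentially a lookup table: one row of
# (e.g.) 128 learned numbers per token in the vocabulary.
vocab_size, embedding_dim = 50_000, 128
embedding_table = np.random.randn(vocab_size, embedding_dim)

token_ids = [1017, 92, 40311]           # pretend these are the tokens of "as playful as"
embedded = embedding_table[token_ids]   # shape (3, 128): one vector per token
print(embedded.shape)
```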
An analogy: Consider a single Google Street View picture of a building. Now imagine you wanted to translate all the information implicit in that Street View shot to a row on a spreadsheet. That row would feature quite a large number of characteristics: latitude, longitude, altitude, address, time the picture was taken, building color, size, estimated age, style, whether it’s a residence, how many people live there if so, number of windows; you could go on and on.
Well, embeddings do exactly that for tokens. We don’t know the meanings of the (e.g.) 128 categories — the internal workings of neural networks remain a mystery! — but we know that language is complicated and words are very different from one another. (Consider, say, “the,” an article that connects to one other word, vs. “kitten,” a noun that relates to other words in a completely different way.) To a neural network, a word can have at least as many significant characteristics as a picture of a house can have for us. For simplicity’s sake, I’ll use “word”/“token” interchangeably from here on in.
Keys, Queries, Values, What?
OK. With that background information, let’s get into how “attention” actually works — and how large language models like ChatGPT use it to generate words.
In order to predict what word comes next in a text, as LLMs do, you must understand the words so far. Does this mean a conscious mind envisioning a kitten when it sees the word “kitten”? No; it means understanding all the contextual cues so that when you see the phrase “as playful as,” you know not just that the next word needs to be some kind of noun, but that “kitten” is especially likely, because “as playful as a kitten” is a somewhat common figure of speech. There are countless little contextual subtleties like that in English. How do large language models figure them out?
With math. In particular, for each word/token in a prompt or ongoing generated text, LLMs generate a query for that word; match that query against the keys of every other word in the text, multiplying queries and keys together to get “attention scores”; and then use those scores to decide how much of each word’s value to blend into the result.
So, what the heck is a query, a key, and a value? They’re actually just like embeddings — long series of numbers. The key insights that LLMs learn during training lie in the three weight matrices that generate queries, keys, and values. (I’ll get to attention scores in a bit. They’re actually the most intuitive of these four concepts.)
Queries
A query is, loosely, the kind of connections a given word is looking for. Consider that Google Street View of a building, and the row of values corresponding to it — location, color, height, style, age, etc. Suppose you were looking for a building suitable for the house’s current denizens to move into. You’d be very interested in other buildings’ size, amenities, and so forth, but less interested in color, shape, and exact number of windows. You’d be interested in buildings near the current one, pretty uninterested in buildings across the world. So you’d mark some of the items in that spreadsheet as important, some as less so. That’s exactly what a query does.
Keys
A key, meanwhile, equally loosely, is what a given word is good for. Taking the house example, suppose you wanted to write a real estate listing to sell the building. This might not be the same as the query! As a buyer, you might really want a house with a basement and/or a deck, but as a seller, you know that to most people those probably don’t matter nearly as much as location and number of bedrooms. Similarly, if you’re a word, an LLM probably cares quite a lot that you’re a noun or a verb, but doesn’t much care whether you rhyme with “lemon.”
Values
As for a value — well, even more loosely, it’s what a word actually contributes once another word decides to pay attention to it. For a house, stretching the analogy, these are maybe the claims that you can’t really quantify, that you’re in a nice-ish neighborhood, or one that’s up-and-coming. For a word, it’s maybe the mode you’re being used in, or your tense.
Language is a Network
Ultimately, what the LLM “wants” to know is: what other words matter to this word? Language isn’t really a series; it’s a network. When I use the word “query” in this sentence, its meaning is changed for you by the other times I’ve used it above. Imagine a faint spaghetti network of interlaced lines on this page you’re reading, connecting almost every word to almost every other, each line’s darkness corresponding to how strong that connection is, so densely interwoven it’s almost an opaque cloud. That’s the network of meaning that LLMs are trained to understand.
Weight Matrices
…After all that, the math is almost absurdly simple. Use the three weight matrices to generate a query, a key, and a value for every word, by simply multiplying each word’s embedding by each of the three matrices. (Take my word for it, in this math-free zone: this is an extremely straightforward operation.)
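(For the coders, the whole thing really is a few lines. Everything below is random, an illustrative stand-in for what training would actually produce.)

```python
import numpy as np

# Three learned weight matrices turn each word's embedding into its
# query, key, and value. Everything here is random, standing in for
# what training would actually learn.
num_words, d_model = 5, 128
X = np.random.randn(num_words, d_model)      # embeddings of the 5 words so far

W_q = np.random.randn(d_model, d_model)      # the three learned matrices
W_k = np.random.randn(d_model, d_model)
W_v = np.random.randn(d_model, d_model)

Q = X @ W_q   # one query per word
K = X @ W_k   # one key per word
V = X @ W_v   # one value per word
print(Q.shape, K.shape, V.shape)             # each is (5, 128)
```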
And where do we get those weight matrices from? We don’t program them. We don’t know how. Instead their contents are something the neural network learns during training, from the enormous computing power (it was recently revealed that GPT-4 cost well over $100 million to train!) poured into teaching them what to do.
Attention!
Then — once those matrices are trained to generate queries, keys, and values — again, it’s almost absurdly simple. For each word, take its query and multiply it against the keys of all the other words: the results are the “attention scores,” which capture the strength of the connections between that word and every other word in the text — all the lines in that imagined word-network described above. Those scores are then used to blend together the values, producing the attention output for that word.
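(And here, in the same sketchy spirit, is the attention step itself, including the softmax and scaling factor I'm otherwise handwaving away.)

```python
import numpy as np

# Given queries, keys, and values for 5 words (random stand-ins here, as above),
# compute attention: every query times every key gives the scores, a softmax
# turns each row of scores into weights, and those weights blend the values.
num_words, d_model = 5, 128
Q, K, V = (np.random.randn(num_words, d_model) for _ in range(3))

scores = Q @ K.T / np.sqrt(d_model)             # (5, 5): word-to-word connection strengths
weights = np.exp(scores)
weights /= weights.sum(axis=-1, keepdims=True)  # softmax: each row now sums to 1
output = weights @ V                            # (5, 128): weighted blend of the values
print(output.shape)
```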
Still with me? I know it’s not exactly intuitive. But at the same time, considering that this is the major technological breakthrough of the last decade … it’s not as opaque and forbidding as you might expect, is it? It doesn’t require a PhD in some abstruse subject to even begin to grapple with it. In fact it seems oddly simple. If you’re a coder, here’s a pretty brief blog post describing how to write it from scratch. Here’s another. And another. And a description from a more mathematical perspective. Granted, I did leave out a few things — positional embeddings, softmaxes, residual connections, feedforward layers — but I promise you, they’re not super significant; they’re more engineering details than theoretical science.
OK, Fine, But Which Word Comes Next?
Having said that it’s very simple, I will say it gets a little more complicated. First, multiple attention layers are generally “stacked” on top of one another, with one layer’s output serving as the next one’s input, and the overall LLM gaining more insight from each. Second, as mentioned in my Transformers post, there are usually multiple attention networks, each known as an “attention head,” all acting on the same inputs at the same time. As the LLM is trained, each head tends to specialize in a particular kind of understanding, so that e.g. one might focus on syntax while another tracks which earlier words a pronoun refers back to.
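(A rough sketch of that multi-head arrangement, for the coders; dimensions and numbers are again purely illustrative.)

```python
import numpy as np

# Multi-head attention, roughly: several independent "heads", each with its own
# smaller query/key/value matrices, all attend to the same input; their outputs
# are concatenated back together.
def attention(X, W_q, W_k, W_v):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))  # stable softmax
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

num_words, d_model, num_heads = 5, 128, 4
head_dim = d_model // num_heads                     # each head works with 32 numbers
X = np.random.randn(num_words, d_model)

head_outputs = []
for _ in range(num_heads):                          # each head gets its own matrices
    W_q, W_k, W_v = (np.random.randn(d_model, head_dim) for _ in range(3))
    head_outputs.append(attention(X, W_q, W_k, W_v))

combined = np.concatenate(head_outputs, axis=-1)    # back to shape (5, 128)
print(combined.shape)
```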
Ultimately, though, the previous section describes an “attention layer,” and an LLM consists largely of an embedding layer followed by a stack of those attention layers, each containing a bunch of attention heads working in parallel. Finally, in the same way that the “gate” into the LLM is the embedding layer, the “exit” is a special layer that learns how to map the stack’s final output back to our dictionary of tokens. Again, it’s actually a very simple concept: this last layer — the LLM output — provides, for every possible word/token, the likelihood that it comes next.
Note: likelihood. LLMs don’t actually “predict the next word” per se. Instead they say something like “Probably 40% chance that it’s ‘kitten’, but 30% ‘puppy’, 15% ‘goldfish’, 8% ‘Patrick’, 3% ‘Danny’, 1% ‘me’ …” and so forth. It’s then up to the system which incorporates the LLM whether it always takes the most likely word, if/how often it takes less likely words, etc. (If you’re an OpenAI API user, this is the concept described there as “temperature.” At a temperature of 0, the single most likely word is always chosen. At higher temperatures, it’s much less predictable.)
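(A little sketch of that final step, for the coders. The words and raw scores are invented, chosen so the resulting probabilities roughly match the percentages above.)

```python
import numpy as np

# The last layer produces one raw score ("logit") per vocabulary entry; a softmax
# turns those into probabilities, and temperature controls how boldly we sample.
vocab = ["kitten", "puppy", "goldfish", "Patrick", "Danny", "me"]
logits = np.array([4.0, 3.7, 3.0, 2.4, 1.5, 0.4])   # invented raw scores

def next_word_probabilities(logits, temperature=1.0):
    scaled = logits / max(temperature, 1e-6)   # temperature near 0 ≈ always pick the top word
    exp = np.exp(scaled - scaled.max())        # subtract the max for numerical stability
    return exp / exp.sum()

probs = next_word_probabilities(logits, temperature=1.0)
print(dict(zip(vocab, probs.round(2))))        # roughly 0.41, 0.31, 0.15, 0.08, 0.03, 0.01

next_word = np.random.choice(vocab, p=probs)   # sample the next word
print(next_word)
```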
…Then for the next next word, the LLM starts all over again — but now the previously chosen word, the output word, becomes the last word of the new input. That’s what “autoregressive” means.
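(In sketch form, with a fake function standing in for the actual model:)

```python
import random

# The autoregressive loop: whichever word gets chosen is appended to the
# input before the next round. fake_llm is a stand-in for the real network.
def fake_llm(tokens):
    # Pretend model: returns invented (word, probability) pairs for the next token.
    return [("kitten", 0.4), ("puppy", 0.3), ("goldfish", 0.3)]

tokens = ["as", "playful", "as", "a"]
for _ in range(3):
    words, probs = zip(*fake_llm(tokens))
    next_word = random.choices(words, weights=probs)[0]   # or always take the top one
    tokens.append(next_word)                              # output becomes part of the input
print(" ".join(tokens))
```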
Self-Attention vs. Cross-Attention
The above is called “self-attention,” because the neural network is given an input text and, for each word in that text, it calculates its connections — the attention scores — to all the other words in that same text. But this isn’t the only option. In fact it isn’t what transformers were originally built for! The original “Attention is all you need” paper had translation as its objective, so the outputs were different words in a different language.
This is known as “cross-attention” and is, mathematically, actually not that different from self-attention. Instead of the queries, keys, and values all coming from the input text, the keys and values can come from different data. In fact this can be different kinds of data; not just two different languages, but, for example, text and images. (Cross-attention is used in e.g. Stable Diffusion to train it to recognize and respond well to prompts.)
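(One more sketch for the coders: same math as before, but the keys and values now come from a second sequence. All numbers are illustrative.)

```python
import numpy as np

# Cross-attention: queries come from one sequence (say, the translation being
# generated); keys and values come from another (the source sentence, or an
# encoded image).
d_model = 128
target = np.random.randn(4, d_model)     # 4 tokens of output-so-far
source = np.random.randn(7, d_model)     # 7 tokens of the thing being attended to

W_q, W_k, W_v = (np.random.randn(d_model, d_model) for _ in range(3))
Q = target @ W_q                         # queries from the target sequence
K, V = source @ W_k, source @ W_v        # keys and values from the source

scores = Q @ K.T / np.sqrt(d_model)      # (4, 7): each target token vs. each source token
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
output = weights @ V                     # (4, 128)
print(output.shape)
```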
Attention Must Be Paid (for)
Transformers are extraordinarily powerful in large part because their relative simplicity, compared to recurrent neural networks, makes them much easier / cheaper / faster to train. (This isn’t to say training them is easy / cheap / fast! But at least it’s, you know, doable for tens to hundreds of millions of dollars at the state of the art.) The flip side, though, is that they’re relatively expensive to run. The “inference cost” — how much work you need to do to run them once they’re trained — scales with the square of the number of tokens in the input. (For each token you have a query and a key, and you need to multiply every query by every key.)
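(A quick back-of-the-envelope illustration of that square:)

```python
# The score matrix has one entry for every (query, key) pair, so doubling
# the number of tokens quadruples the work, per head, per layer.
for n_tokens in [1_000, 2_000, 4_000, 32_000]:
    print(f"{n_tokens:>6} tokens -> {n_tokens * n_tokens:>13,} query-key products")
```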
This is why the LLMs you can run on your laptop (or even in your browser(!)) are not just kinda dumb but also very slow; why GPT-3.5 Turbo remains a big deal despite the vast superiority of GPT-4; why OpenAI charges by the token; and why a single GPT-4 API call with 32,000 tokens in and out, once they make that promised capability available, will cost you north of six dollars. Yes, that’s right, for a single API call. The future is made of many, many humming GPUs … and it ain’t cheap.
Worth Paying Attention To
The costs and limitations of transformers and the attention mechanism mean that researchers are looking for superior alternatives. For instance, a Stanford team has proposed a new approach, whose cost scales with the number of tokens rather than its square, in a paper whimsically titled “Hungry Hungry Hippos: Towards Language Modeling with State Space Models.” Which is interesting! …But it’s a long way from paper to practical implementations that outperform the state of the art. Still, if it works in practice, then even if it’s less capable overall, it will at the very least fill a very useful LLM niche.
In the interim, though, the transformer remains king. If you want to build one — if you’re a coder and you want to learn by doing, not reading — I cannot recommend this two-hour YouTube tutorial from OpenAI cofounder Andrej Karpathy highly enough.
…And that’s all for now. Attention without diagrams! Clearly you needed that. Right?