Transformers: more than meets the eye
♫Attention heads use their power to infer the structural relationships♫
It’s fair to say there are two landmark scientific papers in modern AI which stand apart from all others; two which, apart from their own impressive merits, are widely seen to have ushered in entire new eras. The first, the 2012 AlexNet paper, ignited the “deep learning” wave. The second, “Attention Is All You Need,” published in 2017 by Google Brain, propelled us into the Age of the Transformer. Six years later, that is where we live. Transformers are the T in GPT-3, ChatGPT, and BERT; they are integral to DALL-E and Stable Diffusion, and at the heart of DeepMind’s AlphaFold; they lurk behind almost every recent AI headline.
Last week I attended a talk in San Francisco by the founders of Generally Intelligent, where among the many interesting things said was this: “I think it's important for people to understand how these models work at a deep level … they're not as complicated as you might think.” In that spirit, I thought I’d write a post which tries to explain what the Transformer architecture is, and how it works, to a general non-technical audience — now that I finally think I basically understand it myself.
Like my previous post about the basics of neural networks, this will remain a math-free zone, except for some brief mocking of the complexity of the math.
To recur or to attend?
AI models, like computers in general, are fundamentally machines that transform a series of numbers into a different series of numbers. The magic and meaning are in what those outputs represent — written language, visual images, protein structures, etc.
But to get good outputs, you need good inputs, which can be quite lengthy. Before 2017, text inputs were generally fed into a model one word at a time. Obviously, to make sense of a full sentence, the model had to somehow remember its previous words. This was managed by feeding the output for each word back into the inputs for the next. Such models were called recurrent neural networks.
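To make the recurrence idea concrete, here is a toy sketch in Python. It is not a real RNN cell (those are learned, and usually LSTMs or GRUs); it just illustrates how each word's output gets folded back in as the next word's input. Every name and number below is invented for the example.

```python
# A toy sketch of recurrence: the "memory" passed from word to word is simply
# the previous step's output fed back in. A real RNN cell is learned during
# training, not hand-written like this stand-in.
def rnn_step(word_vector, state):
    """One recurrence step: combine the current word with the carried-over state."""
    return [0.5 * w + 0.5 * s for w, s in zip(word_vector, state)]

sentence = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # three toy word vectors
state = [0.0, 0.0]                                 # the model's rolling "memory"
for word in sentence:
    state = rnn_step(word, state)                  # output becomes the next step's input
print(state)   # the final state is all the model "remembers" of the sentence
```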
It may surprise you that modern AI doesn’t have “memory” like traditional computers. Everything an AI model knows is “baked in” during its training phase; when models perform what’s known as inference, i.e. when they’re actually used, they don’t have any data storage, they just accept inputs and generate outputs. (This may change; there’s a lot of active research re adding some kind of storage to large language models.) Recurrence is in a sense a hack which gives models a rough equivalent of memory.
But like many/most hacks, recurrence doesn’t work well at scale. Transformers get far better results by doing something much simpler — feeding the entire input to the model at once, and training it to pay attention to which aspects actually matter. An imperfect analogy: a recurrent model is taught how to read an essay and then answer questions about it, whereas an attention model is taught how to read the questions first, then go through the essay and identify which bits are most relevant.
Latent space, the final frontier
The above is all pretty abstract. Let’s talk about what transformer models actually do.
Suppose our input is an English phrase, and we want our model to continue where it leaves off, or translate it into another language. (GPT-3 can do both.) The first step is to take our phrase and tokenize it — turn it into a series of numbers. In general, individual words turn into individual numbers, but complex or unusual words may be subdivided into two or more tokens. For instance, OpenAI’s GPT-3 tokenizer turns
To be or not to be, that is the question; whether 'tis nobler in the mind to suffer the slings and arrows of outrageous fortune, or to take arms against a sea of troubles, and by opposing end them.
which is 39 words, into 47 tokens:
[2514, 307, 393, 407, 284, 307, 11, 326, 318, 262, 1808, 26, 1771, 705, 48010, 31801, 1754, 287, 262, 2000, 284, 8659, 262, 1017, 654, 290, 20507, 286, 23077, 15807, 11, 393, 284, 1011, 5101, 1028, 257, 5417, 286, 14979, 11, 290, 416, 12330, 886, 606, 13]
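If you want to reproduce that yourself, OpenAI's open-source tiktoken library exposes the same byte-pair-encoding family GPT-3's tokenizer is built on. A minimal sketch (assuming tiktoken is installed; the "gpt2" encoding should yield the token IDs above):

```python
# Tokenizing the Hamlet passage with OpenAI's tiktoken library
# (pip install tiktoken). The "gpt2" encoding is the BPE vocabulary
# GPT-3's tokenizer is based on, so it should reproduce the IDs above.
import tiktoken

enc = tiktoken.get_encoding("gpt2")

text = ("To be or not to be, that is the question; whether 'tis nobler in the mind "
        "to suffer the slings and arrows of outrageous fortune, or to take arms "
        "against a sea of troubles, and by opposing end them.")

tokens = enc.encode(text)
print(len(text.split()), "words ->", len(tokens), "tokens")  # 39 words -> 47 tokens
print(tokens[:6])                  # [2514, 307, 393, 407, 284, 307]
print(enc.decode(tokens) == text)  # True: tokenization is lossless
```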
Next those numbers are turned into embeddings in latent space. What the heck is an embedding? Wikipedia provides an excellent answer: “a real-valued vector that encodes the meaning of the word in such a way that words that are closer in the vector space are expected to be similar in meaning.”
Now, a vector is another series of numbers. Imagine a neural network which performs embeddings — an “encoder” — such that each of the 47 tokens above is embedded as a series of 3 numbers. Now think of those triplets as locations on a three-dimensional graph
where, say, P at (5, 0, 6) represents the word “suffer,” and Q at (0, 4, 5) represents the word “against.” We can expect that, when embedded on this graph, the vector for “troubles” is closer to P, since that word is more like “suffer” than it is like “against,” and that the vector for “opposing” is closer to Q, since it’s more like “against.” In this example, the entire three-dimensional space we’re talking about is our latent space.
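Here's that toy example as a few lines of Python. The coordinates for “suffer” and “against” are the P and Q above; the vectors for “troubles” and “opposing” are made up for illustration, since real embeddings come from a trained model:

```python
# "Closer in latent space means closer in meaning", with the made-up 3-D
# points from the text. The "troubles" and "opposing" vectors are invented
# for this example, not real embeddings.
import math

embeddings = {
    "suffer":   (5, 0, 6),   # point P from the text
    "against":  (0, 4, 5),   # point Q from the text
    "troubles": (4, 1, 6),   # hypothetical: should land near "suffer"
    "opposing": (1, 4, 4),   # hypothetical: should land near "against"
}

def distance(a, b):
    """Euclidean distance between two points in our toy 3-D latent space."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

for word in ("troubles", "opposing"):
    d_suffer = distance(embeddings[word], embeddings["suffer"])
    d_against = distance(embeddings[word], embeddings["against"])
    print(f"{word}: {d_suffer:.2f} from 'suffer', {d_against:.2f} from 'against'")
# "troubles" ends up closer to "suffer"; "opposing" ends up closer to "against".
```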
Good news: that’s fairly comprehensible, right? Bad news: in practice, latent space is N-dimensional, where N is typically 100 or more, and I’m pretty sure you can’t really envision what 100-dimensional space looks like. Sorry. But you get the idea, I hope.
Oh, and what do we use to get the word embeddings? We use a neural network. Do we train it to do that? …Sort of? A remarkable thing is that if you train a neural network to process language, it will very likely generate a word-embedding system internally in order to do so. Various pre-trained embedding models are available, and OpenAI makes theirs available as a paid API.
What it means to pay attention
OK, you’ve rendered a bunch of input text into its word embeddings. Next: attention! To quote Attention Is All You Need, “An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.” See? That was easy! Moving on…
…OK, fine, let’s unpack that a little. In English: “An attention function takes each word in the input, and is trained to learn how closely it’s connected to every other word. Then each of those words is treated with significance corresponding to that closeness.” Take, for example, from the Google Brain announcement post, two sentences which are to be translated from English to French:
The animal didn’t cross the street because it was too tired.
The animal didn’t cross the street because it was too wide.
Previous attempts at translation via machine learning had immense trouble with such sentences, which, in fairness, are genuinely ambiguous; two sentences, identical but for the final word, in which the word “it” means two completely different things! Ah, English. But transformer models learn how words relate to one another: that the first “it” is connected to “animal” and the second “it” is connected to “street.”
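For the curious, the “weighted sum of values” from the paper's definition is only a few lines of code. Below is a minimal NumPy sketch of scaled dot-product self-attention, with three invented 4-dimensional embeddings standing in for “animal,” “street,” and “it”; the numbers are chosen by hand so that “it” sits near “animal,” which is what a trained model would have to learn for itself:

```python
# A minimal sketch of scaled dot-product self-attention (the "weighted sum of
# values" described in the paper), using NumPy. The tiny embeddings below are
# hand-picked for illustration; a real model learns them during training.
import numpy as np

def attention(Q, K, V):
    """Blend the values V, weighted by how well each query matches each key."""
    scores = Q @ K.T / np.sqrt(K.shape[-1])          # query-key compatibility
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)   # softmax over the keys
    return weights @ V, weights

tokens = ["animal", "street", "it"]
X = np.array([[1.0, 0.2, 0.0, 0.5],    # "animal"
              [0.1, 1.0, 0.3, 0.0],    # "street"
              [0.9, 0.3, 0.1, 0.4]])   # "it" -- deliberately close to "animal"

_, weights = attention(X, X, X)        # self-attention: queries = keys = values
print(dict(zip(tokens, np.round(weights[2], 2))))
# "it" gives its largest weight to "animal" in this toy setup
```

In a full transformer, the queries, keys, and values are each produced by separate learned projections of the embeddings, rather than being the embeddings themselves; the toy version above skips that step for brevity.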
It turns out that if you layer such attention functions together with simple feedforward neural networks, and spend vast quantities of computing power training them on gargantuan piles of data, then this learned understanding of the relative significance and connections of words in a given text gives the resulting model an entirely unreasonable and unexpected grasp of language. It’s honestly shocking how effectively large language models (aka LLMs) powered by transformer attention functions can wield language and predict the continuations of text.
Eight heads are better than one
Now, granted, a full-fledged transformer is not quite as simple as I describe above. It doesn’t just run the inputs (and the generated output, as it accumulates) through one trained attention function; it runs them through several simultaneously, in parallel. (Eight, in the original paper.) Each is called an “attention head,” although really “eye” is a better metaphor. If the input is a scroll, and a transformer is a mythical beast, then this mythical beast has eight eyes —
—and each of those eyes learns to perceive something different. “Not only do individual attention heads clearly learn to perform different tasks, many appear to exhibit behavior related to the syntactic and semantic structure of the sentences,” to once again quote the seminal paper.
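Continuing the earlier sketch, and with the same caveat that the projection matrices here are random stand-ins for what a real model learns, multi-head attention is essentially just running that attention function several times side by side and stitching the results together (the attention function is repeated so the snippet stands alone):

```python
# A toy multi-head attention: eight independent "eyes," each with its own
# (here random, in reality learned) projections, whose outputs are concatenated.
import numpy as np

rng = np.random.default_rng(0)

def attention(Q, K, V):
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

def multi_head_attention(X, num_heads=8, d_head=4):
    """Run several attention heads over the same input, then concatenate their views."""
    d_model = X.shape[-1]
    heads = []
    for _ in range(num_heads):
        W_q, W_k, W_v = (rng.normal(size=(d_model, d_head)) for _ in range(3))
        heads.append(attention(X @ W_q, X @ W_k, X @ W_v))
    return np.concatenate(heads, axis=-1)

X = rng.normal(size=(5, 16))           # five token embeddings, 16-dimensional
print(multi_head_attention(X).shape)   # (5, 32): 8 heads x 4 dimensions each
```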
To be clear, we don’t try to make each “attention head” see something different in the inputs we feed them. Like embeddings, that’s something that just happens — automagically — as part of the training process. Why? …We’re not entirely sure. How? …We’re again not entirely sure. Essentially all current AI research is empirical, not theoretical. It’s like we’ve discovered a whole new parallel dimension of epistemology, and we’re still at the “exploring with torches and pith helmets” stage. This is one reason people are so excited about / unnerved by modern AI.
The above graph is not just an illustration of a point, but a personal note; it’s an embedded Metaculus forecast, and I just started a new engineering job there. At Metaculus we provide a platform for collective-wisdom forecasts such as the above. (And we have quite a few AI and LLM questions/forecasts.) As you can see, our forecasting community seems to think that there’s a pretty decent chance of LLMs attracting deleterious legislative attention in the near future.
I would understand why! (Without agreeing.) Modern AI models are unnerving. They seem more like djinn than any other new technology we’ve seen in quite some time. But love them or fear them, they aren’t about to dive back into their lamp. (It’s only a matter of time before someone releases an open-source, or at least open-weight, LLM at least as good as ChatGPT.) And what’s most exciting and unnerving of all is the distinct possibility that they remain only the tip of the AI iceberg…