Although large language models can improve their performance by generating a chain of thought (CoT) — intermediate text tokens that break down the process of responding to a prompt into a series of steps — much of the CoT text is aimed at maintaining fluency (such as “a”, “of”, “we know that”) rather than reasoning (“a² + b² = c²”). Researchers addressed this inefficiency.
What’s new: Shibo Hao, Sainbayar Sukhbaatar, and colleagues at Meta and University of California San Diego introduced Coconut (Chain of Continuous Thought), a method that trains large language models (LLMs) to process chains of thought as vectors rather than words.
Key insight: A large language model can be broken into an embedding layer, a transformer, and a classification layer. To generate the next text token from input text, the embedding layer embeds the text; given those embeddings, the transformer outputs a hidden vector; and the classification layer maps the vector to text-token probabilities. Based on these probabilities, a decoding algorithm selects the next token, which feeds back into the input sequence to produce the next vector, and so on. When a model generates a CoT, committing to a specific word at each step limits the information carried forward to the meanings of the words generated so far, whereas a vector can represent many possible words at once. Using vectors instead of text lets the CoT encode richer information.
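For reference, here's a minimal sketch of that standard next-token loop using an off-the-shelf GPT-2 from Hugging Face; the model choice, prompt, and greedy decoding are illustrative assumptions, not details from the paper.

```python
# Standard text-token loop: embed, run the transformer, map the final hidden
# vector to token probabilities, commit to one token, feed it back as input.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

ids = tokenizer("The square of the hypotenuse equals", return_tensors="pt").input_ids

with torch.no_grad():
    for _ in range(5):
        embeds = model.transformer.wte(ids)                                 # embedding layer
        hidden = model.transformer(inputs_embeds=embeds).last_hidden_state  # transformer
        logits = model.lm_head(hidden[:, -1])                               # classification layer
        next_id = logits.argmax(dim=-1, keepdim=True)                       # commit to a single word
        ids = torch.cat([ids, next_id], dim=-1)                             # feed it back as input

print(tokenizer.decode(ids[0]))
```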
How it works: The authors built three LLMs by fine-tuning a pre-trained GPT-2 on three datasets of prompts, CoTs, and final outputs: GSM8k (grade-school math word problems); ProntoQA (questions and answers about fictional concepts expressed in made-up words, including synthetic CoTs in natural language); and ProsQA, a more challenging question-answering dataset introduced by the authors, inspired by ProntoQA but with longer reasoning steps.
- Fine-tuning began with supervised training: the LLM learned to generate the text in the training set, including the CoTs and final answers. As usual, the last generated text token was fed back as input to produce the next token.
- Fine-tuning then progressed through k stages for each example. At stage k, the authors replaced the first k sentences of the CoT text with thought vectors (one or two per sentence), marking the start and end of the chain of vectors with two special tokens. During these vector steps, the LLM fed its output vectors back as input without decoding them into text. The LLM learned to generate only the remaining text tokens, not the thought vectors, which encouraged it to optimize its vector-based reasoning indirectly.
- During inference, the LLM generated a special token to mark the start of the chain of vectors. From that point, it fed its output vectors back as input, bypassing text decoding, for six steps. Afterward, it switched back to generating text for the final output (see the sketch below).
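A minimal sketch of that inference loop, assuming a GPT-2-style model from Hugging Face; the marker string and prompt are illustrative stand-ins for the special tokens and fine-tuned models described above, not the authors' code.

```python
# Continuous-thought inference: after the start-of-thought marker, feed the last
# hidden state back as the next input embedding (no decoding to text) for six
# steps, then resume ordinary token-by-token generation for the final answer.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

prompt = "Question: Is every wumpus a yumpus? <bot>"  # <bot> stands in for the start-of-thought token
embeds = model.transformer.wte(tokenizer(prompt, return_tensors="pt").input_ids)

with torch.no_grad():
    # Vector phase: the last hidden state becomes the next input embedding.
    for _ in range(6):
        hidden = model.transformer(inputs_embeds=embeds).last_hidden_state
        thought = hidden[:, -1:, :]                   # one continuous "thought vector"
        embeds = torch.cat([embeds, thought], dim=1)

    # Text phase: decode tokens as usual for the final answer.
    answer_ids = []
    for _ in range(20):
        hidden = model.transformer(inputs_embeds=embeds).last_hidden_state
        next_id = model.lm_head(hidden[:, -1]).argmax(dim=-1, keepdim=True)
        answer_ids.append(next_id)
        embeds = torch.cat([embeds, model.transformer.wte(next_id)], dim=1)

print(tokenizer.decode(torch.cat(answer_ids, dim=-1)[0]))
```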
Results: The authors compared their method to a pretrained GPT-2 that was fine-tuned on the same datasets to predict the next token, generating its chain of thought as text.
- On ProntoQA, Coconut outperformed the fine-tuned GPT-2 while producing far fewer intermediate steps (vectors for Coconut, text tokens for the baseline). It achieved 99.8 percent accuracy after generating nine vectors (or tokens) on average, while GPT-2 achieved 98.8 percent accuracy using 92.5 text tokens.
- Coconut excelled on ProsQA’s more complex questions. It achieved 97.0 percent accuracy after generating 14.2 vectors (or tokens) on average, while GPT-2 achieved 77.5 percent accuracy after generating 49.4 text tokens on average.
Yes, but: On GSM8k, Coconut achieved 34.1 percent accuracy, while the baseline LLM achieved 42.9 percent. However, Coconut generated far fewer intermediate steps: 8.2 vectors on average compared to the baseline's 25 text tokens.
Why it matters: A traditional CoT commits to a single word at each step and thus encodes only one reasoning path. Vectors are less interpretable to humans than language, but the model's output layer can still decode a thought vector into probabilities over tokens. Inspecting those distributions along the chain of vectors offers a way to glimpse the multiple potential reasoning paths stored in a single continuous CoT.
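A probe along these lines (a hypothetical snippet that assumes a `thought` tensor like the one in the inference sketch above) projects a continuous thought through the output layer to see which words it weights most heavily.

```python
# Decode one thought vector into a token distribution and list the top candidates.
probs = torch.softmax(model.lm_head(thought[:, -1]), dim=-1)
top = torch.topk(probs, k=5, dim=-1)
print([tokenizer.decode([i]) for i in top.indices[0].tolist()], top.values[0].tolist())
```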
We’re thinking: LLMs typically learn to reason over text, mainly because text data is widely available to train on. In contrast, neuroscience shows that the part of the human brain responsible for language largely goes quiet during reasoning tasks, which suggests that explicit language is not a key mechanism for reasoning. Coconut takes an intriguing step to enable LLMs to explore representations that don’t encode the limitations of language.