I went down a very deep rabbit hole, so this is part 2 on how to make models learn. Here I'll discuss how models can remember things more dynamically, beyond just the good old RAG. Please read Part 1 if you want to learn why these models being stateless is a problem, and how RL can help models learn over time even when we don't have complete information about the context or the human.
A quick recap on why we need a more dynamic memory system… LLMs are stateless.
Each interaction is treated independently, with no inherent memory of prior interactions. They essentially cannot learn continuously from real-time experience, which makes it impossible to develop a personalised, long-term understanding of an individual. A model that can't adapt dynamically to human feedback and shifting context ends up feeling unintelligent.
So how can it go about remembering the right things in the right moment?
Retrieval-aware transformers
I'm going to assume most of you who got this far are already familiar with RAG as a memory mechanism, so I won't spend much time on it. A quick recap: it is a two-step external memory process. It retrieves relevant snippets from an external store (e.g. a vector database), injects them into the current prompt context, and passes the combined prompt to the LLM for generation. This makes the LLM appear to remember things from the past, when really it is just explicitly feeding relevant snippets back into each new session's context window.
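To make that two-step loop concrete, here is a minimal sketch in Python. Everything here is a made-up stand-in for illustration: `embed`, `call_llm`, and the toy `MEMORY_STORE` are not any particular library's API.

```python
# Minimal sketch of the classic two-step RAG loop: retrieve, then prepend.
# All names here (embed, MEMORY_STORE, call_llm) are illustrative stand-ins.
import numpy as np

def embed(text: str) -> np.ndarray:
    # Stand-in for a real embedding model: a deterministic pseudo-random vector.
    rng = np.random.default_rng(abs(hash(text)) % (2**32))
    return rng.random(8)

MEMORY_STORE = [
    {"text": "User is vegetarian.", "vec": embed("User is vegetarian.")},
    {"text": "User ordered ramen last Friday.", "vec": embed("User ordered ramen last Friday.")},
]

def retrieve(query: str, k: int = 1) -> list[str]:
    # Step 1: pull the k most similar snippets from the external store.
    q = embed(query)
    scored = sorted(
        MEMORY_STORE,
        key=lambda m: float(q @ m["vec"]) / (np.linalg.norm(q) * np.linalg.norm(m["vec"])),
        reverse=True,
    )
    return [m["text"] for m in scored[:k]]

def call_llm(prompt: str) -> str:
    return f"<completion for: {prompt[:40]}...>"  # stand-in for a real LLM call

def answer(query: str) -> str:
    # Step 2: prepend the retrieved snippets to the prompt, then generate.
    context = "\n".join(retrieve(query))
    return call_llm(f"Context:\n{context}\n\nUser: {query}\nAssistant:")

print(answer("What should I eat for lunch?"))
```

Note that the retrieval happens once, before generation starts; that is the constraint the next approach removes.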
Retrieval-aware transformers are different. Instead of injecting snippets externally at the prompt level, they integrate the retrieval mechanism directly into the transformer's attention layers, meaning the model can reference external memory during its internal generation process, not just before it.
The model explicitly learns to decide, internally and dynamically, when and what to retrieve.
This “in-model retrieval” isn’t just a prepend; it is woven into the attention mechanism itself. RAG cannot achieve this level of flexibility and granularity because it is constrained to whatever (often compressed) context fits within the context window. In-model retrieval works because the model can pause internally during each token’s generation to consider “Should I fetch external context relevant to what I am currently generating?” If yes, it performs that retrieval before moving on to generate the next token.
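Here is a rough sketch, under my own assumptions, of what that per-token loop could look like. None of this is a published architecture: `model.next_token_probs`, `model.append_memory`, and the `retrieve_neighbors` callback are hypothetical interfaces, and the fixed entropy threshold is an arbitrary stand-in for the learned trigger discussed further down.

```python
# Conceptual sketch of per-token, in-model retrieval: generate, check an
# internal uncertainty signal, and only then fetch external memory.
import math
import random

def entropy(probs: dict[str, float]) -> float:
    # Shannon entropy of a next-token distribution; flatter = higher entropy.
    return -sum(p * math.log(p + 1e-12) for p in probs.values())

def sample(probs: dict[str, float]) -> str:
    tokens, weights = zip(*probs.items())
    return random.choices(tokens, weights=weights, k=1)[0]

def generate_with_inline_retrieval(model, retrieve_neighbors, prompt_tokens,
                                   max_new_tokens=64, entropy_threshold=3.0):
    tokens = list(prompt_tokens)
    for _ in range(max_new_tokens):
        probs = model.next_token_probs(tokens)      # p(next token | context so far)
        if entropy(probs) > entropy_threshold:      # "should I fetch external context?"
            memories = retrieve_neighbors(tokens)   # external lookup, mid-generation
            model.append_memory(memories)           # expose memories to the attention layers
            probs = model.next_token_probs(tokens)  # re-score with the memory attached
        tokens.append(sample(probs))
    return tokens
```

In a real retrieval-aware model this decision is learned and happens inside the attention layers rather than as an explicit if-statement, but the control flow is the useful mental model.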
Multi-Head Latent Attention (MLA) from DeepSeek is a great example.
The part I find particularly fascinating is that I had always assumed LLMs aren’t inherently great at explicitly measuring their own uncertainty. With these contextual architectures, the internal attention mechanism is trained to sense uncertainty, and that sense emerges implicitly from internal signals during training:
Attention distribution patterns: when the model is more confident, attention tends to concentrate on a few clear, well-defined tokens; when it is uncertain, attention tends to be more diffuse and spread out.
Token probability distributions: models generate tokens by assigning probabilities to all possible next tokens. When uncertain, that distribution becomes flatter, meaning it is harder to predict what comes next.
So during training the model learns to correlate diffuse attention or a flat token distribution with a retrieval trigger; it learns a latent representation of “external retrieval would help reduce my uncertainty here.”
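A toy, runnable illustration of those two signals: both a flat next-token distribution and diffuse attention have high entropy, which is the kind of internal metric a model could learn to associate with “retrieval would help.” The numbers and the threshold below are made up purely for illustration.

```python
# Toy demo: entropy as an uncertainty signal over token and attention distributions.
import numpy as np

def entropy(p: np.ndarray) -> float:
    p = p / p.sum()
    return float(-(p * np.log(p + 1e-12)).sum())

# Confident: probability mass concentrated on one token / one position.
confident_tokens = np.array([0.90, 0.05, 0.03, 0.02])
confident_attn   = np.array([0.85, 0.10, 0.03, 0.02])

# Uncertain: mass spread out, nearly uniform.
uncertain_tokens = np.array([0.28, 0.26, 0.24, 0.22])
uncertain_attn   = np.array([0.25, 0.25, 0.25, 0.25])

THRESHOLD = 1.0  # arbitrary cut-off for this toy example

for name, p in [("confident tokens", confident_tokens),
                ("uncertain tokens", uncertain_tokens),
                ("confident attention", confident_attn),
                ("uncertain attention", uncertain_attn)]:
    h = entropy(p)
    print(f"{name}: entropy={h:.2f} -> trigger retrieval: {h > THRESHOLD}")
```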
This means if you ask your system
“What to do for lunch?”
Traditional RAG: broadly fetches general meal preferences, dietary info, or recent food orders upfront. Much of it won’t be used in the final response, and the extra retrieval can slow down the interaction.
Contextual transformer: starts generating quickly without any retrieval. Only when it senses a need mid-generation does it grab that specific, granular piece of information, so there is far less overhead from redundant retrieval.
Why should you care?
Understanding the underlying material and what is possible is critical to making the right design and experience tradeoffs.
We need to balance stability against open-endedness: RAG may be all you need for something like customer support, but for more personalised, goal-oriented experiences, a dynamic retrieval-aware transformer architecture may be necessary.