Titans Part 2: Optimizing Memory Updates and Adaptive Learning
This is a follow-up to my previous post on implementing Titans. After getting the basic implementation working, I dove deeper into performance optimization with Claude Code as a pair programmer. This post covers the journey from a working but slow implementation to something that trains efficiently on commodity GPUs.
The Problem: Memory Update Bottleneck
After training for ~27k iterations on OpenWebText, I hit a wall. The loss plateaued around 3.8 and profiling revealed why - the memory update mechanism was consuming a disproportionate amount of GPU time:
```
_fused_weight_update_kernel    450ms    15.84%
triton_poi_fused_sigmoid       510ms    17.92%
```
Nearly 34% of CUDA time was spent on Titans memory operations. For a 130M parameter model, this overhead was unacceptable. The MFU (Model FLOPs Utilization) was stuck at 3.39%.
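For reference, MFU here is the usual ratio of achieved FLOPs to the GPU's peak FLOPs. A minimal sketch of the estimate, using the common 6 × params FLOPs-per-token approximation (simplified; it ignores the attention term and isn't the exact estimator from the repo):

```python
# Rough MFU estimate: achieved FLOPs per second / peak FLOPs per second
def estimate_mfu(n_params, tokens_per_iter, iter_seconds, peak_flops_per_s):
    flops_per_token = 6 * n_params                          # fwd + bwd approximation
    achieved = flops_per_token * tokens_per_iter / iter_seconds
    return achieved / peak_flops_per_s                      # e.g. 0.034 -> 3.4% MFU
```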
The Culprit: Per-Token Gradient Tensors
Looking at the original implementation, the bottleneck was clear. For each forward pass, we were computing per-token gradients that materialized a massive tensor:
```python
# Old path: outer product for each token
dW = torch.einsum('btc,bth->btch', d_pred, h1)  # [B, T, C, H]
```
With batch=2, sequence=512, hidden=768, expansion=2:
- Tensor size: 2 × 512 × 768 × 1536 = 1.2 billion elements
- Memory: ~4.8GB just for gradients
This was the memory killer. Every forward pass allocated gigabytes of temporary memory, causing GPU memory pressure and slow kernel launches.
The Fix: Aggregated Gradients
The insight came from staring at the equation. We don’t need per-token gradients - we need the sum over tokens for the weight update:
```python
# New path: aggregate directly (no per-token storage!)
dW = torch.einsum('btc,bth->bch', d_pred, h1)  # [B, C, H]
```
Same information, but:
- Tensor size: 2 × 768 × 1536 = 2.4 million elements
- Memory: ~9.6MB
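A quick back-of-the-envelope check on those numbers (fp32 elements; small differences from the figures above are rounding):

```python
B, T, C, H = 2, 512, 768, 2 * 768             # batch, seq len, hidden, expanded hidden
per_token  = B * T * C * H                    # old path: [B, T, C, H]
aggregated = B * C * H                        # new path: [B, C, H]
print(per_token,  per_token  * 4 / 1e9)       # ~1.2e9 elements, ~4.8 GB
print(aggregated, aggregated * 4 / 1e6)       # ~2.4e6 elements, ~9.4 MB
print(per_token // aggregated)                # 512 - the reduction factor is exactly T
```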
500x memory reduction. The momentum update becomes an approximation (mean gradient instead of per-token EMA), but for typical momentum values (0.9), the difference is negligible.
```python
import torch
import torch.nn.functional as F

def aggregated_gradient_memory_update(
    keys, values, weights, momentum, lr, mom_coef, decay,
):
    # Assumed shapes: keys/values [B, T, C]; W0 [B, H, C]; W1 [B, C, H]
    W0, W1 = weights
    m0, m1 = momentum
    B, T, C = keys.shape
    # Forward pass through the two-layer memory MLP
    h1_pre = keys @ W0.transpose(-1, -2)                    # [B, T, H]
    h1 = F.silu(h1_pre)
    pred = h1 @ W1.transpose(-1, -2)                        # [B, T, C]
    # Backward - aggregate over T (never materializes [B, T, C, H])
    d_pred = (2.0 / C) * (pred - values)
    dW1 = torch.einsum('btc,bth->bch', d_pred, h1)          # sum over T
    sig = torch.sigmoid(h1_pre)
    dh1_pre = (d_pred @ W1) * sig * (1 + h1_pre * (1 - sig))  # SiLU backward
    dW0 = torch.einsum('bth,btc->bhc', dh1_pre, keys)       # sum over T
    # Momentum with the mean gradient (scale = 1/T)
    scale = 1.0 / T
    new_m0 = mom_coef * m0 + (1 - mom_coef) * dW0 * scale
    new_m1 = mom_coef * m1 + (1 - mom_coef) * dW1 * scale
    new_W0 = (1 - decay) * W0 - lr * new_m0
    new_W1 = (1 - decay) * W1 - lr * new_m1
    return (new_W0, new_W1), (new_m0, new_m1)
```
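A quick smoke test of the sketch above, assuming per-batch memory weights of shape [B, H, C] and [B, C, H]:

```python
B, T, C, H = 2, 512, 768, 1536
keys, values = torch.randn(B, T, C), torch.randn(B, T, C)
weights  = (0.02 * torch.randn(B, H, C), 0.02 * torch.randn(B, C, H))  # (W0, W1)
momentum = (torch.zeros(B, H, C), torch.zeros(B, C, H))
(new_W0, new_W1), _ = aggregated_gradient_memory_update(
    keys, values, weights, momentum, lr=0.01, mom_coef=0.9, decay=0.01
)
print(new_W0.shape, new_W1.shape)  # torch.Size([2, 1536, 768]) torch.Size([2, 768, 1536])
```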
Adding Adaptive Memory Parameters
The Titans paper mentions adaptive learning rates - per-token learned parameters for lr, momentum, and decay. My initial implementation used fixed hyperparameters. Time to match the paper.
The idea: let the model learn how to update its memory based on the input:
```python
class NeuralMemory(nn.Module):
    def __init__(self, config):
        super().__init__()
        # ... existing code ...
        self.adaptive = config.adaptive_memory
        if config.adaptive_memory:
            self.to_lr = nn.Linear(config.n_embd, 1)
            self.to_momentum = nn.Linear(config.n_embd, 1)
            self.to_decay = nn.Linear(config.n_embd, 1)

    def update(self, x, state):
        if self.adaptive:
            # Per-token adaptive parameters via learned projections, each [B, T, 1]
            adaptive_lr = torch.sigmoid(self.to_lr(x)) * self.lr_max
            adaptive_momentum = torch.sigmoid(self.to_momentum(x))
            adaptive_decay = torch.sigmoid(self.to_decay(x))
        # ... rest of the memory update ...
```
The sigmoid bounds keep parameters in sensible ranges:
- LR: (0, lr_max) - typically 0.01
- Momentum: (0, 1) - higher = more memory persistence
- Decay: (0, 1) - higher = faster forgetting
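To give a feel for how these bounded per-token parameters can plug into the aggregated update, here is a simplified illustration under my own simplifications (helper name is just for this post): the per-token lr is folded into the gradient sum, while momentum and decay are reduced to per-sequence means so they can gate the [B, C, H] state tensors.

```python
import torch

def adaptive_aggregated_update(d_pred, h1, W1, m1,
                               adaptive_lr, adaptive_momentum, adaptive_decay):
    # d_pred [B, T, C], h1 [B, T, H], W1/m1 [B, C, H], adaptive_* [B, T, 1]
    T = d_pred.shape[1]
    mom_b = adaptive_momentum.mean(dim=1, keepdim=True)   # [B, 1, 1], broadcasts over [B, C, H]
    dec_b = adaptive_decay.mean(dim=1, keepdim=True)      # [B, 1, 1]
    # Weight each token's error by its learned lr before summing over T
    dW1 = torch.einsum('btc,bth->bch', adaptive_lr * d_pred, h1)
    new_m1 = mom_b * m1 + (1 - mom_b) * dW1 / T           # momentum on the mean gradient
    new_W1 = (1 - dec_b) * W1 - new_m1                    # lr already folded into dW1
    return new_W1, new_m1
```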
A subtle bug emerged: the default _init_weights function was resetting my carefully initialized biases to zero. The fix required special-casing the adaptive projection layers. Since nn.Module.apply visits child modules before their parent, the generic nn.Linear case zeroes those biases first and the NeuralMemory case then restores the intended constants:
```python
def _init_weights(self, module):
    match module:
        case NeuralMemory() if module.adaptive:
            # Preserve adaptive bias initialization
            nn.init.constant_(module.to_momentum.bias, 2.0)   # sigmoid(2) ≈ 0.88
            nn.init.constant_(module.to_decay.bias, -4.0)     # sigmoid(-4) ≈ 0.02
        case nn.Linear():
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
```
Optimizer Bottleneck: Fused AdamW
With memory updates optimized, a new bottleneck appeared. Profiling showed:
```
Optimizer.step#AdamW.step    425ms    41.17%
aten::_foreach_mul_           87ms     8.43%
aten::_foreach_addcdiv_       81ms     7.84%
aten::_foreach_lerp_          64ms     6.18%
```
The optimizer was taking 41% of GPU time! AdamW was using 7+ separate foreach operations instead of a single fused kernel.
The fix was embarrassingly simple:
```python
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    fused=True,  # single CUDA kernel instead of many foreach ops
)
```
Results after enabling fused optimizer:
| Metric | Before | After |
|---|---|---|
| Optimizer time | 426ms (41%) | 170ms (17%) |
| Total CUDA time | 1,034ms | 787ms |
| MFU | 6.3% | 10.8% |
A 60% reduction in optimizer overhead, just by adding fused=True.
Profiling Deep Dive
Building a proper profiling setup was essential. I added a kernel category summary to quickly identify bottleneck categories:
```
Category        Time (ms)      %
----------------------------------------
other             2053.51   50.6
optimizer          680.86   16.8
matmul/gemm        535.45   13.2
triton/fused       458.69   11.3
elementwise        173.29    4.3
memory_ops         153.29    3.8
attention            0.00    0.0
```
The high “other” category revealed torch.compile overhead - the CompiledFxGraph wrapper operations. Not much to optimize there, but knowing where time goes prevents chasing phantom bottlenecks. Memory operations are now just 3.8% thanks to our aggregated gradient approach.
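For reference, here is a minimal sketch of how a category summary like this can be built on top of torch.profiler. The keyword buckets are illustrative placeholders, not the exact categories my script uses:

```python
from collections import defaultdict
from torch.profiler import profile, ProfilerActivity

# Illustrative keyword buckets - the real script uses a longer list
CATEGORIES = {
    "optimizer": ("adam", "foreach"),
    "matmul/gemm": ("gemm", "matmul", "mm_"),
    "triton/fused": ("triton",),
    "elementwise": ("sigmoid", "silu", "add", "mul"),
    "memory_ops": ("copy", "contiguous", "fill"),
    "attention": ("attention", "flash", "sdpa"),
}

def categorize(name):
    lname = name.lower()
    for cat, keywords in CATEGORIES.items():
        if any(k in lname for k in keywords):
            return cat
    return "other"

def kernel_category_summary(step_fn):
    # Profile one call to step_fn and bucket self CUDA time by kernel name
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        step_fn()
    totals = defaultdict(float)
    for evt in prof.key_averages():
        totals[categorize(evt.key)] += evt.self_cuda_time_total  # microseconds
    grand = sum(totals.values()) or 1.0
    print(f"{'Category':<16}{'Time (ms)':>12}{'%':>8}")
    for cat, us in sorted(totals.items(), key=lambda kv: -kv[1]):
        print(f"{cat:<16}{us / 1e3:>12.2f}{100 * us / grand:>8.1f}")
```

Wrapping a single training iteration in kernel_category_summary is enough to reproduce a table like the one above.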
Checkpoint Compatibility
A practical challenge: resuming training from old checkpoints with new model code. The adaptive memory projections add 6 new parameters that old checkpoints don’t have.
The solution checks the actual state_dict rather than trusting config:
```python
# Check if the checkpoint actually has adaptive memory weights
ckpt_has_adaptive = any(
    "to_lr" in k or "to_momentum" in k or "to_decay" in k
    for k in state_dict
)
# Upgrade when the new config enables adaptive memory but the checkpoint predates it
upgrade_to_adaptive = config.adaptive_memory and not ckpt_has_adaptive

if upgrade_to_adaptive:
    # Load with strict=False to allow missing keys
    missing, _ = model.load_state_dict(state_dict, strict=False)
    print(f"New parameters initialized: {len(missing)} keys")
```
This lets you upgrade a trained model to use adaptive memory mid-training. The core 128M weights are preserved; only the 6 small projection weights start fresh.
Results
After all optimizations, training a 128M Titans model on an L4 GPU:
| Metric | Initial | Optimized |
|---|---|---|
| MFU | 3.4% | 10.8% |
| Time/iter | ~2400ms | ~1950ms |
| Memory overhead | 4.8GB | 9.6MB |
| Optimizer overhead | 41% | 17% |
The throughput improvement came from:
- Larger batch sizes (2 → 4) enabled by memory reduction
- Fused optimizer kernels
- Better GPU utilization
Current training on OpenWebText:
- Iteration: 29k+
- Train loss: 3.55
- Validation loss: 3.57
- Training with adaptive memory enabled
- Steady loss descent with no overfitting (train/val gap < 0.02)
Lessons From This Round
- Profile at the kernel level: High-level timing hides the real story. The optimizer taking 41% was invisible without detailed profiling.
- Memory bandwidth matters: The 500x gradient memory reduction wasn’t about avoiding OOM - it was about reducing memory pressure that slowed everything else.
- Check the defaults: fused=True for AdamW exists but isn’t the default. Same for torch.compile on attention. Read the docs for performance flags.
- Aggregation beats iteration: When you can sum instead of iterate, do it. The per-token gradient loop was algorithmically correct but computationally wasteful.
- Checkpoint evolution: Plan for model architecture changes. Strict loading breaks iteration; flexible loading with smart defaults enables continuous improvement.
What’s Next
- Training to convergence with adaptive memory
- Comparison: fixed vs adaptive memory parameters
- Longer context experiments (2k, 4k, 8k tokens)
- Potential MAG (Memory as Gate) implementation
Code remains at github.com/XiaoConstantine/nanogpt-titans.