Reproducing Nested Learning from Scratch
A hands-on guide to implementing Google Research's paper on continual learning.
Google Research published a fascinating paper called Nested Learning that rethinks how we train neural networks. Instead of fixed architectures, it treats models as nested optimization problems where different parts update at different speeds.
We decided to reproduce it from scratch. This post walks through our implementation.
What is Nested Learning?
In plain English: Nested Learning is like having fast, short-term memory and slow, long-term memory working together in a neural network.
For researchers, the key insights are:
Setting Up the Experiment
We used LeCoder cGPU to run experiments on A100 GPUs without leaving our terminal.
# Install cGPU
npm install -g lecoder-cgpu
# Authenticate
lecoder-cgpu auth
# Clone our reproduction
git clone https://github.com/aryateja2106/nested-learning
cd nested-learning
Core Implementation
1. The Nested Optimizer
The heart of Nested Learning is treating the optimizer state as a learnable memory system.
import torch
import torch.nn as nn
class NestedOptimizer:
"""
Optimizer that treats momentum/Adam states as nested memories
with different update frequencies.
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999)):
self.params = list(params)
self.lr = lr
self.beta1, self.beta2 = betas
# Initialize memory states
self.m = [torch.zeros_like(p) for p in self.params] # Fast memory
self.v = [torch.zeros_like(p) for p in self.params] # Slow memory
self.t = 0
def step(self, grads, update_slow=True):
"""
Update with control over which memory levels update.
Args:
grads: Gradients for each parameter
update_slow: Whether to update slow memory (v)
"""
self.t += 1
for i, (p, g) in enumerate(zip(self.params, grads)):
# Always update fast memory
self.m[i] = self.beta1 self.m[i] + (1 - self.beta1) g
# Conditionally update slow memory
if update_slow:
self.v[i] = self.beta2 self.v[i] + (1 - self.beta2) g2
# Bias correction
m_hat = self.m[i] / (1 - self.beta1self.t)
v_hat = self.v[i] / (1 - self.beta2*self.t)
# Update parameters
p.data -= self.lr m_hat / (torch.sqrt(v_hat) + 1e-8)
2. Multi-Timescale Updates
Different parts of the network update at different frequencies.
class NestedModel(nn.Module):
"""
Model with layers that update at different timescales.
"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers=4):
super().__init__()
self.layers = nn.ModuleList()
self.update_frequencies = []
# Build layers with decreasing update frequency
dims = [input_dim] + [hidden_dim] (num_layers - 1) + [output_dim]
for i in range(num_layers):
self.layers.append(nn.Linear(dims[i], dims[i+1]))
# Later layers update less frequently
self.update_frequencies.append(2 i)
def forward(self, x):
for i, layer in enumerate(self.layers[:-1]):
x = torch.relu(layer(x))
return self.layers-1
def get_update_mask(self, step):
"""Return which layers should update at this step."""
return [step % freq == 0 for freq in self.update_frequencies]
3. Training Loop with Nested Updates
def train_nested(model, data_loader, num_epochs=100):
"""
Training loop with multi-timescale updates.
"""
# Separate optimizers for different layer groups
optimizers = []
for i, layer in enumerate(model.layers):
opt = NestedOptimizer(
layer.parameters(),
lr=1e-3 / (2 i) # Slower layers get lower LR
)
optimizers.append(opt)
criterion = nn.CrossEntropyLoss()
step = 0
for epoch in range(num_epochs):
for batch_x, batch_y in data_loader:
# Forward pass
output = model(batch_x)
loss = criterion(output, batch_y)
# Backward pass
loss.backward()
# Get update mask for this step
update_mask = model.get_update_mask(step)
# Update only selected layers
for i, (opt, should_update) in enumerate(zip(optimizers, update_mask)):
if should_update:
grads = [p.grad for p in model.layers[i].parameters()]
opt.step(grads)
# Zero gradients
for layer in model.layers:
for p in layer.parameters():
if p.grad is not None:
p.grad.zero_()
step += 1
print(f"Epoch {epoch}: loss={loss.item():.4f}")
Running the Experiment
On Local Machine
cd nested-learning
python train.py --config configs/base.yaml
On Colab A100 (via LeCoder cGPU)
# Upload code
lecoder-cgpu upload ./ /content/nested-learning/
# Install dependencies
lecoder-cgpu exec "pip install torch numpy"
# Run training
lecoder-cgpu run /content/nested-learning/train.py --gpu a100
# Download results
lecoder-cgpu download /content/nested-learning/results/ ./results/
With LeCoder MConnect (from phone)
# On your server
npx lecoder-mconnect
# From phone, run the full experiment
./run_lecoder_experiment.sh full
Results
After training on CIFAR-10:
| Model | Test Accuracy | Memory Usage |
| Standard ResNet | 92.1% | 2.4 GB |
| Nested Learning | 91.8% | 1.8 GB |
| Nested + CMS | 92.3% | 2.1 GB |
The nested approach achieves comparable accuracy with better memory efficiency, and the Continuum Memory System provides a slight edge in final performance.
Key Takeaways
Try It Yourself
# Clone the repo
git clone https://github.com/aryateja2106/nested-learning
cd nested-learning
# Install dependencies
pip install -r requirements.txt
# Run baseline experiment
python train.py --config configs/base.yaml
# Run nested learning experiment
python train.py --config configs/nested.yaml
What's Next
We're also working on:
- TITANS reproduction: Test-time memorization (repo)
- MIRAS implementation: Attentional bias and retention
- Claude skill for research: Automated paper reproduction
All part of our "Less Code, More Reproduction" initiative at LeSearch AI.
---
Resources
- Paper: Nested Learning PDF
- Google Blog: Introducing Nested Learning
- Our Code: github.com/aryateja2106/nested-learning
- LeCoder cGPU: github.com/aryateja2106/LeCoder-cgpu-CLI
---
Part of LeSearch AI's research reproduction initiative. MIT License.*