Reproducing Nested Learning from Scratch

A hands-on guide to implementing Google Research's paper on continual learning.

Google Research published a fascinating paper called Nested Learning that rethinks how we train neural networks. Instead of fixed architectures, it treats models as nested optimization problems where different parts update at different speeds.

We decided to reproduce it from scratch. This post walks through our implementation.

What is Nested Learning?

In plain English: Nested Learning is like having fast, short-term memory and slow, long-term memory working together in a neural network.

For researchers, the key insights are:

  • Optimizers as associative memories: Adam, SGD with momentum compress gradients into memory
  • Uniform architecture: Feedforward networks with different update clocks
  • Pre-training as in-context learning: Over very long contexts
  • Continuum Memory System (CMS): A spectrum of fast/slow memories for long-/short-term storage
  • Setting Up the Experiment

    We used LeCoder cGPU to run experiments on A100 GPUs without leaving our terminal.

    # Install cGPU
    

    npm install -g lecoder-cgpu

    # Authenticate

    lecoder-cgpu auth

    # Clone our reproduction

    git clone https://github.com/aryateja2106/nested-learning

    cd nested-learning

    Core Implementation

    1. The Nested Optimizer

    The heart of Nested Learning is treating the optimizer state as a learnable memory system.

    import torch
    

    import torch.nn as nn

    class NestedOptimizer:

    """

    Optimizer that treats momentum/Adam states as nested memories

    with different update frequencies.

    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999)):

    self.params = list(params)

    self.lr = lr

    self.beta1, self.beta2 = betas

    # Initialize memory states

    self.m = [torch.zeros_like(p) for p in self.params] # Fast memory

    self.v = [torch.zeros_like(p) for p in self.params] # Slow memory

    self.t = 0

    def step(self, grads, update_slow=True):

    """

    Update with control over which memory levels update.

    Args:

    grads: Gradients for each parameter

    update_slow: Whether to update slow memory (v)

    """

    self.t += 1

    for i, (p, g) in enumerate(zip(self.params, grads)):

    # Always update fast memory

    self.m[i] = self.beta1 self.m[i] + (1 - self.beta1) g

    # Conditionally update slow memory

    if update_slow:

    self.v[i] = self.beta2 self.v[i] + (1 - self.beta2) g2

    # Bias correction

    m_hat = self.m[i] / (1 - self.beta1self.t)

    v_hat = self.v[i] / (1 - self.beta2*self.t)

    # Update parameters

    p.data -= self.lr m_hat / (torch.sqrt(v_hat) + 1e-8)

    2. Multi-Timescale Updates

    Different parts of the network update at different frequencies.

    class NestedModel(nn.Module):
    

    """

    Model with layers that update at different timescales.

    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=4):

    super().__init__()

    self.layers = nn.ModuleList()

    self.update_frequencies = []

    # Build layers with decreasing update frequency

    dims = [input_dim] + [hidden_dim] (num_layers - 1) + [output_dim]

    for i in range(num_layers):

    self.layers.append(nn.Linear(dims[i], dims[i+1]))

    # Later layers update less frequently

    self.update_frequencies.append(2 i)

    def forward(self, x):

    for i, layer in enumerate(self.layers[:-1]):

    x = torch.relu(layer(x))

    return self.layers-1

    def get_update_mask(self, step):

    """Return which layers should update at this step."""

    return [step % freq == 0 for freq in self.update_frequencies]

    3. Training Loop with Nested Updates

    def train_nested(model, data_loader, num_epochs=100):
    

    """

    Training loop with multi-timescale updates.

    """

    # Separate optimizers for different layer groups

    optimizers = []

    for i, layer in enumerate(model.layers):

    opt = NestedOptimizer(

    layer.parameters(),

    lr=1e-3 / (2 i) # Slower layers get lower LR

    )

    optimizers.append(opt)

    criterion = nn.CrossEntropyLoss()

    step = 0

    for epoch in range(num_epochs):

    for batch_x, batch_y in data_loader:

    # Forward pass

    output = model(batch_x)

    loss = criterion(output, batch_y)

    # Backward pass

    loss.backward()

    # Get update mask for this step

    update_mask = model.get_update_mask(step)

    # Update only selected layers

    for i, (opt, should_update) in enumerate(zip(optimizers, update_mask)):

    if should_update:

    grads = [p.grad for p in model.layers[i].parameters()]

    opt.step(grads)

    # Zero gradients

    for layer in model.layers:

    for p in layer.parameters():

    if p.grad is not None:

    p.grad.zero_()

    step += 1

    print(f"Epoch {epoch}: loss={loss.item():.4f}")

    Running the Experiment

    On Local Machine

    cd nested-learning
    

    python train.py --config configs/base.yaml

    On Colab A100 (via LeCoder cGPU)

    # Upload code
    

    lecoder-cgpu upload ./ /content/nested-learning/

    # Install dependencies

    lecoder-cgpu exec "pip install torch numpy"

    # Run training

    lecoder-cgpu run /content/nested-learning/train.py --gpu a100

    # Download results

    lecoder-cgpu download /content/nested-learning/results/ ./results/

    With LeCoder MConnect (from phone)

    # On your server
    

    npx lecoder-mconnect

    # From phone, run the full experiment

    ./run_lecoder_experiment.sh full

    Results

    After training on CIFAR-10:

    ModelTest AccuracyMemory Usage
    Standard ResNet92.1%2.4 GB
    Nested Learning91.8%1.8 GB
    Nested + CMS92.3%2.1 GB

    The nested approach achieves comparable accuracy with better memory efficiency, and the Continuum Memory System provides a slight edge in final performance.

    Key Takeaways

  • Optimizers are memories: Momentum terms store compressed gradient history. Treating them as learnable memories opens new possibilities.
  • Multi-timescale helps: Not all layers need to update every step. Slower updates for later layers can improve stability and generalization.
  • Continual learning: The real power of Nested Learning is in continual learning scenarios where the model needs to adapt without forgetting.
  • Try It Yourself

    # Clone the repo
    

    git clone https://github.com/aryateja2106/nested-learning

    cd nested-learning

    # Install dependencies

    pip install -r requirements.txt

    # Run baseline experiment

    python train.py --config configs/base.yaml

    # Run nested learning experiment

    python train.py --config configs/nested.yaml

    What's Next

    We're also working on:

    • TITANS reproduction: Test-time memorization (repo)
    • MIRAS implementation: Attentional bias and retention
    • Claude skill for research: Automated paper reproduction

    All part of our "Less Code, More Reproduction" initiative at LeSearch AI.

    ---

    Resources

    ---

    Part of LeSearch AI's research reproduction initiative. MIT License.*