Implementation of the NeurIPS 2025 paper "Nested Learning" by Behrouz et al. (Google Research).
Mathematical verification status:
- ✅ Paper-exact: Delta rule, L2RegressionAttention, LinearAttention, CMS update scheduling
- ✅ Paper-exact (optional): DMGD with internal_loss_mode='l2_regression' (Eq 21-23)
- Default: DMGD uses a practical surrogate loss (stable, conceptually aligned)
This version includes paper-exact modes, mathematical correctness tests, and inference-time adaptation.
| Component | Status | Description |
|---|---|---|
| DeepMomentumGD | ✅ Complete | Memory modules trained via internal loss L^(2) |
| SelfModifyingLinear | ✅ Complete | Paper-exact (normalized=False) and stable modes |
| L2RegressionAttention | ✅ New | Paper's L2 regression variant (Eq 27-29) |
| ContinuumMemorySystem | ✅ Complete | Paper-exact nesting (use_residual=False) available |
| LinearAttention | ✅ Fixed | Per-sequence memory (not batch-averaged) |
| HOPE Model | ✅ Complete | Full integration with all components |
| Math Correctness Tests | ✅ New | Finite difference gradient verification |
| Benchmarks | ✅ Complete | WikiText-103 and LAMBADA evaluation scripts |
See IMPLEMENTATION_STATUS.md for detailed component documentation.
```bash
# Clone the repository
git clone https://github.com/yourusername/nested-learning.git
cd nested-learning

# Create virtual environment
python -m venv .venv
source .venv/bin/activate

# Install dependencies
pip install -r requirements.txt

# Install in development mode
pip install -e .
```

```bash
# Run all tests (27+ tests including math correctness)
python -m pytest tests/ -v

# Run the demo
python examples/nested_learning_demo.py
```

```bash
# Synthetic benchmarks (fast, no data download)
python experiments/benchmark_pattern_learning.py    # DMGD vs SGD
python experiments/benchmark_continual_learning.py  # HOPE vs Vanilla

# WikiText-103 benchmark (Paper Table 1)
python experiments/benchmark_wikitext.py --test

# LAMBADA zero-shot evaluation
python experiments/benchmark_lambada.py --test
```

```
nested-learning/
├── src/nested_learning/
│   ├── optimizers/   # DeepMomentumGD, NestedDeepMomentumGD
│   ├── memory/       # ContinuumMemorySystem, AssociativeMemory
│   ├── models/       # HOPE, SelfModifyingAttention
│   ├── training/     # NestedLearningTrainer
│   └── utils/        # AMP utilities, helpers
├── experiments/      # Benchmark scripts
├── examples/         # Demo scripts
├── tests/            # Comprehensive test suite
└── docs/             # Documentation
```
```python
from nested_learning.optimizers import DeepMomentumGD

# Memory modules are trained via internal loss every step
optimizer = DeepMomentumGD(
    params=model.parameters(),
    lr=1e-3,
    memory_lr=1e-4,               # Learning rate for memory modules
    use_shared_memory=True,       # Efficient memory pooling
    gradient_checkpointing=True,  # Memory efficient (v0.3.0)
    use_factorized_memory=True,   # Parameter efficient (v0.3.0)
)

for x, y in dataloader:
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()  # Trains both model AND memory modules
```
```python
from nested_learning.optimizers import NestedDeepMomentumGD

optimizer = NestedDeepMomentumGD(
    params=model.parameters(),
    lr=0.01,
    memory_lr=0.001,
    meta_learning=True,
)

# Inner loop: training steps
for _ in range(inner_steps):
    loss = model(train_batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step(create_graph=True)  # Preserve graph for meta-learning

# Outer loop: meta-update based on validation
val_loss = model(val_batch)
optimizer.meta_step(val_loss)  # Updates memory modules
```
```python
from nested_learning.models import HOPE

model = HOPE(
    dim=512,
    n_layers=12,
    n_heads=8,
    vocab_size=50257,
    use_self_modification=True,  # Enable delta-rule attention
)

# Weights update during forward pass (works in training AND inference)
logits = model(input_ids, enable_self_modification=True)

# Apply pending weight updates after backward pass
model.apply_pending_updates()
```
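A minimal training-step sketch showing the intended ordering; the criterion, dataloader, and optimizer here are placeholders, not part of the library:

```python
# Hypothetical training step illustrating the ordering described above.
for input_ids, targets in dataloader:
    logits = model(input_ids, enable_self_modification=True)  # weights adapt in-flight
    loss = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Deferred delta-rule updates are applied only after the backward pass,
    # so gradient computation sees the original weights.
    model.apply_pending_updates()
```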
```python
# adaptation_scope allows temporary self-modification that reverts on exit
# Useful for inference-time adaptation without permanent weight drift
with model.adaptation_scope():
    # Self-modification is active within this scope
    output = model.generate(prompt, max_new_tokens=100)
    # Weights adapt to the prompt context

# Weights are automatically restored to pre-scope values
# Safe to reuse model for unrelated prompts
```
```python
from nested_learning.models.titans import SelfModifyingLinear, L2RegressionAttention
from nested_learning.memory import ContinuumMemorySystem

# Paper-exact self-modification (Eq 28-29)
layer = SelfModifyingLinear(512, 512, normalized=False)

# Paper-exact CMS nesting (Eq 30)
cms = ContinuumMemorySystem(dim=512, num_levels=3)
output = cms(x, use_residual=False)  # True nesting: MLP_k(MLP_{k-1}(...))

# Paper-exact L2 regression attention (Eq 27-29)
attn = L2RegressionAttention(dim=512, num_heads=8, normalized=False)
```

Paper-exactness quick reference:

| Component | Paper-Exact | Stable Default |
|---|---|---|
| SelfModifyingLinear | normalized=False | normalized=True |
| CMS | use_residual=False | use_residual=True |
| L2RegressionAttention | normalized=False | normalized=True |
| DMGD internal loss | internal_loss_mode='l2_regression' | 'surrogate' |
See IMPLEMENTATION_STATUS.md for a detailed configuration guide.
```python
import torch
from nested_learning.utils.amp import NestedAMPWrapper

amp = NestedAMPWrapper(enabled=True, dtype=torch.bfloat16)

with amp.model_autocast():
    loss = model(batch)

amp.backward(loss)
amp.unscale_and_clip(optimizer, max_norm=1.0)
amp.step(optimizer)
amp.update()
```
```python
from nested_learning.training import NestedLearningTrainer

trainer = NestedLearningTrainer(
    model=model,
    optimizer=optimizer,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
)

# Trains with CMS multi-frequency updates and self-modification
trainer.train(num_epochs=10)
```

The core innovation: memory modules learn to compress gradients through a self-supervised internal loss.
Note: The current DMGD internal loss is a practical surrogate (not the literal L² regression on K-V matrices from the paper). It uses:
- Reconstruction Loss: Memory output should capture gradient direction (cosine similarity)
- Magnitude Preservation: Output magnitude proportional to input gradient
- Temporal Smoothness: Smooth changes over consecutive steps
This is conceptually aligned with the paper's internal objective but not mathematically identical. For paper-exact L² regression, see L2RegressionAttention.
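As a rough illustration of these three terms (not the library's actual implementation; the function name and term weightings are made up for this sketch):

```python
import torch
import torch.nn.functional as F

def surrogate_internal_loss(mem_out, grad, prev_mem_out, w_mag=0.1, w_smooth=0.01):
    """Hypothetical sketch of the three surrogate terms described above."""
    # 1. Reconstruction: memory output should point in the gradient's direction
    recon = 1.0 - F.cosine_similarity(mem_out.flatten(), grad.flatten(), dim=0)
    # 2. Magnitude preservation: output norm should track the gradient norm
    mag = (mem_out.norm() - grad.norm()).pow(2)
    # 3. Temporal smoothness: discourage abrupt changes across consecutive steps
    smooth = (mem_out - prev_mem_out).pow(2).mean()
    return recon + w_mag * mag + w_smooth * smooth
```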
Different MLP levels update at different rates, as sketched after this list:
- Level 0: Every step (working memory)
- Level 1: Every 10 steps (short-term patterns)
- Level 2: Every 100 steps (long-term knowledge)
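A minimal sketch of this scheduling logic; UPDATE_PERIODS mirrors the list above and is illustrative, not a ContinuumMemorySystem constant:

```python
# Hypothetical scheduler: update level k only when the step count divides its period.
UPDATE_PERIODS = [1, 10, 100]  # level 0, 1, 2

def levels_to_update(step, periods=UPDATE_PERIODS):
    return [level for level, period in enumerate(periods) if step % period == 0]

# e.g. step 100 -> [0, 1, 2]; step 10 -> [0, 1]; step 7 -> [0]
```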
Weights change during forward pass via delta rule (Equations 28-29):
```
# Normalized mode (default, more stable):
W -= lr * (W @ x @ x^T) / (x^T @ x)

# Paper-exact mode (normalized=False):
W -= lr * (W @ x @ x^T)
```
Updates are deferred until after backward pass to preserve gradient computation. New in v0.3.1: Works during both training AND inference for online adaptation.
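In PyTorch terms, the two modes correspond roughly to the following sketch (not the SelfModifyingLinear internals; the helper assumes x is a single input vector):

```python
import torch

def delta_rule_update_(W, x, lr=0.01, normalized=True):
    """In-place delta-rule update on a weight matrix W of shape (out_features, in_features)."""
    with torch.no_grad():
        outer = torch.outer(W @ x, x)        # (W x) x^T
        if normalized:
            outer = outer / (x @ x + 1e-8)   # divide by x^T x (stable default)
        W -= lr * outer                      # paper-exact mode skips the normalization
```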
Trade compute for memory - useful for large models:

```python
optimizer = DeepMomentumGD(..., gradient_checkpointing=True)
```

Low-rank factorization for large parameter tensors (4x parameter reduction):

```python
optimizer = DeepMomentumGD(..., use_factorized_memory=True, factorized_rank=16)
```

BF16/FP16 training with gradient scaling:

```python
import torch
from nested_learning.utils.amp import AMPTrainer, AMPConfig

trainer = AMPTrainer(model, optimizer, amp_config=AMPConfig(dtype=torch.bfloat16))
```
```bash
# Run all tests
python -m pytest tests/ -v

# Run specific test suites
python -m pytest tests/test_math_correctness.py -v  # Mathematical verification
python -m pytest tests/test_meta_learning.py -v     # Meta-learning validation
python -m pytest tests/test_scalability.py -v       # Scalability features
python -m pytest tests/test_optimizers.py -v        # Optimizer tests
```

Test Results: 27+ passed (including 12 math correctness tests)
Roadmap:
- Distributed training (multi-GPU with DDP)
- Continual learning benchmark
- Long-context benchmark (100K+ tokens)
- Hyperparameter search for exact paper reproduction
If you use this implementation, please cite the original paper:
```bibtex
@inproceedings{behrouz2025nested,
  title={Nested Learning: The Illusion of Deep Learning Architectures},
  author={Behrouz, Ali and Razaviyayn, Meisam and Zhong, Peilin and Mirrokni, Vahab},
  booktitle={Advances in Neural Information Processing Systems},
  year={2025}
}
```

- IMPLEMENTATION_STATUS.md - Detailed component documentation
- CONTRIBUTING.md - Contribution guidelines
- docs/CONCEPTS.md - Paper concepts explained
MIT License - See LICENSE file for details
This implementation is based on the NeurIPS 2025 paper by Google Research. All credit for the theoretical contributions goes to the original authors.