Adaptive Mamba: Elastic compute with dynamic Matryoshka scaling
🤗 Model on HuggingFace | 📂 GitHub
Also known as: ElasticGPT • Accordion-Net • Dynamic Compute Budget Model
Adamba (Adaptive Mamba) combines three efficiency techniques into a unified pipeline:
| Technique | Implementation | Purpose |
|---|---|---|
| Matryoshka (MRL) | Width: 128 → 4096 per layer | Elastic compute |
| Early Exit | ConfidenceGate | Skip layers when confident |
| Static SSM | Mamba at full dim | Stable memory backbone |
```
┌─────────────────────────────────────────────────┐
│ PROMPT → LayerDimPredictor → [dim per layer]    │
│                                                 │
│   Attention + MLP: Dynamic (sliced)             │
│   Mamba: Static (full dim)                      │
│                                                 │
│   Gate > 0.95 → EXIT EARLY                      │
│   Gate < 0.50 → EXPAND remaining layers         │
└─────────────────────────────────────────────────┘
```
- 🎯 `LayerDimPredictor`: Predicts per-layer dims upfront (no graph breaks)
- 🚪 `ConfidenceGate`: Unified early exit + dim expansion
- 📦 `MatryoshkaKVCache`: Slice-down cache strategy
- 🧠 Static Mamba: Uses efficient CUDA kernel (no SSM state resizing)
Key insight: Resizing SSM states on the fly is mathematically messy. Resizing Attention heads is trivial. Keep Mamba static, make Attention/MLP dynamic.
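To make this concrete, here is a minimal sketch of the slicing idea (the class and function names below are illustrative assumptions, not the repo's API): attention/MLP projections keep their full-width weights but run on only the leading active dims, and the activation is zero-padded back to full width before the static Mamba block.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SlicedLinear(nn.Linear):
    """Illustrative Matryoshka-style projection: weights are stored at full width,
    but each forward pass may use only the leading rows/columns."""
    def forward(self, x: torch.Tensor, active_out: int | None = None) -> torch.Tensor:
        active_in = x.shape[-1]                      # input is already sliced upstream
        active_out = active_out or self.out_features
        w = self.weight[:active_out, :active_in]     # views of the full weight, no copy
        b = self.bias[:active_out] if self.bias is not None else None
        return F.linear(x, w, b)

def pad_to_full(x: torch.Tensor, full_dim: int) -> torch.Tensor:
    """Zero-pad a sliced activation back to full width before the static Mamba block,
    which always runs at the full hidden dimension."""
    return F.pad(x, (0, full_dim - x.shape[-1]))
```

In this picture the LayerDimPredictor chooses a `dim` per layer up front, and a sliced call then looks like `proj(x[..., :dim], active_out=dim)`.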
Results from the `tiny_experiment` validation suite:

```
Early exits: 71.5% (37.5% compute saved!)
Gate loss: 0.28 → 0.03 (self-supervised difficulty learning)
Hard tasks get more dims than easy tasks ✓
```
The gate trains itself using a shadow loss: at each layer, the model computes what the loss would be if it exited there and compares it to the full-depth loss, teaching the gate when it is safe to exit.
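Here is a minimal sketch of that training signal, assuming a shared LM head, a list of per-layer hidden states, and sigmoid gate outputs (all names are illustrative, not the repo's code):

```python
import torch
import torch.nn.functional as F

def shadow_gate_loss(hidden_states, lm_head, targets, gate_outputs, tol=0.05):
    """Illustrative shadow loss: for every layer, compute the LM loss the model
    would have if it exited there, then train the gate to be confident exactly
    when that loss is within `tol` of the final layer's loss.
    hidden_states: list of [B, T, D] tensors, one per layer
    gate_outputs:  list of per-layer gate confidences in (0, 1)"""
    with torch.no_grad():
        layer_losses = []
        for h in hidden_states:
            logits = lm_head(h)                       # exit through the shared output head
            layer_losses.append(F.cross_entropy(logits.flatten(0, 1), targets.flatten()))
        final_loss = layer_losses[-1]
        # Target 1 ("safe to exit") when exiting here costs at most `tol` extra loss
        gate_targets = torch.stack([(l - final_loss <= tol).float() for l in layer_losses])
    gate_scores = torch.stack([g.mean() for g in gate_outputs])  # scalar confidence per layer
    return F.binary_cross_entropy(gate_scores, gate_targets)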
```
nanochat-d32 (1.9B, 32 layers, dim=2048)
        ↓ Surgery (add 32 Mamba layers)
Stage 1: 6.4B (dim=2048)   ← Hybrid, no expansion
        ↓ Progressive expand
Stage 2: 9.3B (dim=2560)
        ↓ Progressive expand
Stage 3: 20B (dim=4096)
```
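Conceptually, the surgery step interleaves freshly initialized Mamba blocks into the pretrained stack. A rough sketch, under the assumption that each new Mamba block exposes an `out_proj` on its residual branch (illustrative, not the repo's `scripts/surgery.py`):

```python
import torch.nn as nn

def interleave_mamba(transformer_blocks: nn.ModuleList, make_mamba_block) -> nn.ModuleList:
    """Illustrative surgery: insert a new Mamba block after each pretrained block.
    Zero-initializing the Mamba output projection makes each new residual branch a
    no-op, so the hybrid reproduces the base model's outputs exactly at step 0."""
    hybrid = []
    for block in transformer_blocks:
        hybrid.append(block)                          # keep the pretrained block as-is
        mamba = make_mamba_block()                    # assumed factory for a Mamba block
        nn.init.zeros_(mamba.out_proj.weight)         # zero-init: nothing to copy (Stage 1)
        if mamba.out_proj.bias is not None:
            nn.init.zeros_(mamba.out_proj.bias)
        hybrid.append(mamba)
    return nn.ModuleList(hybrid)
```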
| Stage | Model Size | Dim | Hours (8×H100) | Est. Cost |
|---|---|---|---|---|
| 1 | 6.4B | 2048 | 40h | $1,000 |
| 2 | 9.3B | 2560 | 50h | $1,200 |
| 3 | 20B | 4096 | 100h | $2,400 |
| Total | | | 190h | ~$4,600 |
```bash
# 1. Download nanochat-d32 base
huggingface-cli download karpathy/nanochat-d32 \
    --local-dir ~/.cache/nanochat/chatsft_checkpoints/d32

# 2. Create 6.4B hybrid
python -m scripts.surgery --new-dim=2048

# 3. Train Stage 1
torchrun --nproc_per_node=8 -m scripts.fractal_train \
    --checkpoint ~/.cache/nanochat/hybrid_checkpoints/d32_2048/model.pt \
    --expanded-dim=2048 --matryoshka --sample-dim

# 4. Expand → Stage 2
python -m scripts.surgery --expand-from=2048 --new-dim=2560
```

| File | Purpose |
|---|---|
| `nanochat/hybrid_gpt.py` | Adamba model (Mamba+Attention+Gate) |
| `nanochat/mamba_block.py` | Static Mamba with SSM fallback |
| `nanochat/matryoshka.py` | Dimension slicing + energy loss |
| `nanochat/confidence_probe.py` | V2: LayerDimPredictor, ConfidenceGate |
| `scripts/surgery.py` | Create/expand hybrid checkpoints |
| `scripts/fractal_train.py` | Matryoshka training |
| `tiny_experiment/` | Local validation suite |
| Mode | Dim | Compute | Use Case |
|---|---|---|---|
| Ghost | 128 | ~0.4% | Trivial tasks |
| Whisper | 512 | ~6% | Simple Q&A |
| Normal | 1024 | ~25% | General use |
| Think | 2048+ | 100% | Complex reasoning |
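These modes could be nothing more than a lookup from mode name to active width; the dict below is an illustrative sketch (the compute column roughly matches (dim/2048)² for the sliced attention/MLP paths):

```python
# Illustrative mode-to-width lookup; names and values mirror the table above
COMPUTE_MODES = {
    "ghost":   128,   # ~0.4% of full attention/MLP compute
    "whisper": 512,   # ~6%
    "normal":  1024,  # ~25%
    "think":   2048,  # 100% (or higher after expansion)
}

def active_dim(mode: str) -> int:
    return COMPUTE_MODES[mode]
```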
- Matryoshka Embeddings (OpenAI/Harvard): MRL applied to model weights
- FastBERT / DeeBERT: Confidence-based early exit
- Mixture of Depths (Google DeepMind): Dynamic FLOP allocation
Use the `--phase` flag in `scripts/fractal_train.py`:

| Phase | Command | What Trains | Matryoshka |
|---|---|---|---|
| 1 | `--phase=1` | Mamba only (frozen attn/mlp) | ✗ Off |
| 2 | `--phase=2` | All + Gates | ✓ On |
| 3 | `--phase=3` | Expanded weights | ✓ On |
```bash
# Phase 1: Integrate Mamba (freeze attention)
torchrun -m scripts.fractal_train --phase=1 --checkpoint=phase1.pt

# Phase 2: Matryoshka + Gates (unfreeze all)
torchrun -m scripts.fractal_train --phase=2 --checkpoint=phase2.pt

# Phase 3: After expansion surgery
torchrun -m scripts.fractal_train --phase=3 --checkpoint=phase3.pt
```

Mamba (Stage 1): uses zero-init ✓ (correct, nothing to copy).
MLP/Attention expansion (Stage 2/3): uses LoRA-style initialization:

```python
# Instead of zeros, new dims = A @ B (low-rank, small init)
expand_weight_lora(weight, target_size, dim, rank=16, std=0.01)
```
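The call above refers to `scripts/surgery.py`; the function body below is only a sketch of what LoRA-style expansion of a single weight matrix could look like (simplified signature, illustrative):

```python
import torch

def expand_weight_lora_sketch(weight: torch.Tensor, new_out: int, new_in: int,
                              rank: int = 16, std: float = 0.01) -> torch.Tensor:
    """Illustrative LoRA-style expansion: trained weights are copied verbatim and the
    new rows/columns are filled with a small low-rank A @ B product instead of zeros."""
    old_out, old_in = weight.shape
    expanded = torch.zeros(new_out, new_in, dtype=weight.dtype, device=weight.device)
    expanded[:old_out, :old_in] = weight                       # preserve trained weights
    A = torch.randn(new_out, rank, dtype=weight.dtype, device=weight.device) * std
    B = torch.randn(rank, new_in, dtype=weight.dtype, device=weight.device) * std
    low_rank = A @ B                                           # [new_out, new_in], small init
    expanded[old_out:, :] = low_rank[old_out:, :]              # new output dims
    expanded[:old_out, old_in:] = low_rank[:old_out, old_in:]  # new input dims
    return expanded
```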
Problem solved: MHA weights are stored head-by-head as `[Head1 | Head2 | ... | Head16]`, so naively appending new dims at the end would put them all after the last head rather than inside each head.

Solution: `expand_attention_interleaved()` expands each head's dims separately:

```
[Head1_128+32 | Head2_128+32 | ... | Head16_128+32]   ← CORRECT
```
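A sketch of the per-head idea (illustrative, for a single per-head projection such as the query weight; new rows are left at zero here, whereas the real code could reuse the LoRA-style init):

```python
import torch

def expand_attention_interleaved_sketch(proj_weight: torch.Tensor, n_heads: int,
                                        old_head_dim: int, new_head_dim: int) -> torch.Tensor:
    """Illustrative per-head expansion: grow each head's slice from old_head_dim to
    new_head_dim in place, instead of appending all new rows after the last head."""
    out_features, in_features = proj_weight.shape
    assert out_features == n_heads * old_head_dim
    heads = proj_weight.view(n_heads, old_head_dim, in_features)
    expanded = torch.zeros(n_heads, new_head_dim, in_features,
                           dtype=proj_weight.dtype, device=proj_weight.device)
    expanded[:, :old_head_dim, :] = heads            # copy each head's trained rows
    return expanded.reshape(n_heads * new_head_dim, in_features)
```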
Functions in `scripts/surgery.py`:
- `expand_weight_lora()` - LoRA-style low-rank initialization
- `expand_attention_interleaved()` - Per-head dimension expansion
Based on nanochat by Andrej Karpathy