This project implements and compares supervised, contrastive, and stitching approaches to metric learning in a maze-solving reinforcement learning environment. The models estimate the distance from each state to the goal and use those estimates to guide search. The project is an assignment for the University of Warsaw Reinforcement Learning course: the general concept was proposed by the instructor, while the core functionality (roughly 90% of the project) was implemented by me.
All training methods used the same hyperparameters:

- `train_steps`: 1,000,000
- `n_train_trajectories`: 4,000
- `learning_rate`: 0.0001
- `batch_size`: 64
| Method | Input Dim | Hidden Dim | Output Dim |
|---|---|---|---|
| Supervised | 800 | 128 | 64 |
| Contrastive | 400 | 128 | 64 |
| Stitching | 400 | 128 | 64 |
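For reference, a minimal sketch of a network with these dimensions and the shared hyperparameters. This assumes PyTorch and a single hidden ReLU layer; the actual framework, depth, and activations are not specified in this README, and reading the 800-dimensional supervised input as a concatenated (state, goal) pair (twice the 400-dimensional state encoding) is also an assumption.

```python
import torch
import torch.nn as nn

# Shared hyperparameters listed above.
TRAIN_STEPS = 1_000_000
N_TRAIN_TRAJECTORIES = 4_000
LEARNING_RATE = 1e-4
BATCH_SIZE = 64

def make_network(input_dim: int, hidden_dim: int = 128, output_dim: int = 64) -> nn.Module:
    """MLP matching the dimensions in the table (depth/activation assumed)."""
    return nn.Sequential(
        nn.Linear(input_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, output_dim),
    )

supervised_net = make_network(800)   # input: state-goal pair (assumed)
contrastive_net = make_network(400)  # input: single state encoding
optimizer = torch.optim.Adam(contrastive_net.parameters(), lr=LEARNING_RATE)
```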
**Supervised.** Trains the model to predict the exact number of steps to the goal, framed as classification. It performs well when its predictions are accurate but is prone to large errors and does not generalize to unseen transitions.
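One plausible reading (not confirmed by this README) is that the 64 output units act as class logits over discretized step counts 0–63. A minimal PyTorch-style sketch under that assumption:

```python
import torch
import torch.nn.functional as F

def supervised_loss(model, state_goal, steps_to_goal):
    """Cross-entropy over discretized step counts (the binning is assumed)."""
    logits = model(state_goal)            # (B, 64) class logits
    return F.cross_entropy(logits, steps_to_goal.long())

def predict_distance(model, state_goal):
    """Predicted distance = index of the most likely class."""
    return model(state_goal).argmax(dim=-1)
```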
**Contrastive.** Learns a latent space in which the L2 distance between state embeddings reflects temporal proximity. It is more robust and consistent, and better preserves the relative ordering of states.
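A standard way to train such an embedding is an InfoNCE-style objective that uses negative squared L2 distances as logits; whether the project uses exactly this loss and pairing scheme is an assumption. A minimal sketch:

```python
import torch
import torch.nn.functional as F

def contrastive_loss(encoder, anchors, positives):
    """InfoNCE-style loss over negative squared L2 distances.

    `anchors` and `positives` are batches of states that occur close
    together in time on the same trajectory; the other batch elements
    act as negatives. This pairing scheme is an assumption.
    """
    z_a = encoder(anchors)                # (B, 64) embeddings
    z_p = encoder(positives)              # (B, 64)
    dists = torch.cdist(z_a, z_p).pow(2)  # (B, B) squared L2 distances
    labels = torch.arange(z_a.size(0), device=z_a.device)
    # Matching pairs sit on the diagonal; smaller distance = higher logit.
    return F.cross_entropy(-dists, labels)
```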
**Stitching.** Uses the contrastive setup but is trained to connect parts of different trajectories. Its estimates are slightly more biased, but it generalizes better to new compositions of paths.
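One way to realize "connecting parts of different trajectories" is to sample training pairs that span a shared state: if trajectories A and B both visit state m, a state before m on A and a state after m on B form a composable pair. This pairing logic is a hypothetical sketch; the project's actual implementation may differ.

```python
import random

def sample_stitched_pair(trajectories):
    """Sample an (anchor, goal) pair spanning two trajectories.

    Assumes each trajectory is a list of hashable states; returns None
    when the sampled trajectories share no state, so the caller resamples.
    """
    a, b = random.sample(trajectories, 2)
    shared = set(a) & set(b)
    if not shared:
        return None
    m = random.choice(list(shared))        # junction state visited by both
    i = a.index(m)                         # junction position on A
    j = b.index(m)                         # junction position on B
    anchor = a[random.randrange(i + 1)]    # somewhere on A at or before m
    goal = b[random.randrange(j, len(b))]  # somewhere on B at or after m
    return anchor, goal
```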
Plots show predicted vs. actual distance for 10 trajectories.
- Supervised learning produces discrete but sometimes inaccurate estimates.
- Contrastive and stitching yield smoother predictions, roughly proportional to the true distance but with some scale bias.
Visual representation of predicted distances in the maze. In the first picture for each method, yellow marks the goal and dark blue marks walls. In the second picture, walls are again dark blue, and the estimated-distance scale is shown on the right-hand side.
Success is measured by whether the model reaches the goal when its distance predictions guide the search. The average number of nodes expanded during search is also reported.
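A minimal sketch of such an evaluation loop: greedy best-first search that always expands the state with the smallest predicted distance to the goal. The `neighbors` and `predict_distance` callables and the expansion cap are assumptions rather than the project's actual API; the "without search" variant presumably just follows the single best-scored neighbor.

```python
import heapq

def evaluate_with_search(start, goal, neighbors, predict_distance, max_expansions=1_000):
    """Greedy best-first search guided by predicted distances.

    Returns (reached_goal, nodes_expanded). States are assumed hashable
    and comparable (e.g. coordinate tuples) so the heap can break ties.
    """
    frontier = [(predict_distance(start, goal), start)]
    visited = {start}
    expanded = 0
    while frontier and expanded < max_expansions:
        _, state = heapq.heappop(frontier)
        expanded += 1
        if state == goal:
            return True, expanded
        for nxt in neighbors(state):
            if nxt not in visited:
                visited.add(nxt)
                heapq.heappush(frontier, (predict_distance(nxt, goal), nxt))
    return False, expanded
```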
| Method | With Search (Success Rate / Avg. Expanded Nodes) | Without Search (Success Rate / Avg. Expanded Nodes) |
|---|---|---|
| Supervised | 0.995 / 43.8 | 0.285 / 11.6 |
| Contrastive | 0.994 / 27.5 | 0.561 / 16.7 |
| Stitching | 0.997 / 33.6 | 0.391 / 14.8 |
Contrastive learning provides the best efficiency and generalization, even with some scale bias.
- Supervised learning performs well with dense supervision but fails to generalize when discrete predictions are off.
- Contrastive learning builds a latent space where relative proximity is preserved, leading to better planning and fewer expanded nodes.
- Stitching extends contrastive learning to combine known trajectory parts and shows promise in generalizing to new sequences.
- Supervised learning predicts distances directly but lacks robustness.
- Contrastive learning is better for guiding search and planning due to consistent relative distances.
- Stitching enables reusing trajectory fragments for generalization, critical in complex RL tasks.
- Avoided `float16` due to instability; `float64` worked well and didn't exceed memory limits.
- The latent space learned via contrastive loss gives better structure for planning than direct regression or classification.
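Assuming a PyTorch setup (not stated in this README), the dtype choice above amounts to a one-line default:

```python
import torch

# Use float64 everywhere; per the note above, float16 was unstable here.
torch.set_default_dtype(torch.float64)
```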