This is the second assignment in Reinforcement Learning course at MIMUW

kubawini/MetricLearning

Metric Learning in Reinforcement Learning

Overview

This project implements and compares three approaches to metric learning in a maze-solving reinforcement learning environment: supervised, contrastive, and stitching. Each model estimates the distance from a state to the goal, and that estimate guides the search. The project is an assignment for the Reinforcement Learning course at the University of Warsaw. The instructor provided the general concept; I implemented the core functionality myself (about 90% of the project).

Setup

All three methods were trained with the same hyperparameters:

  • train_steps: 1,000,000
  • n_train_trajectories: 4000
  • learning_rate: 0.0001
  • batch_size: 64

Model Architectures (BRO Net)

Method        Input Dim   Hidden Dim   Output Dim
Supervised    800         128          64
Contrastive   400         128          64
Stitching     400         128          64

Learning Methods

Supervised Learning

The model is trained to predict the exact number of steps to the goal, framed as a classification problem. It performs well when its predictions are accurate, but it is prone to large errors and does not generalize to unseen transitions.
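A minimal sketch of distance prediction as classification, to show why errors can be large. It assumes one class per step count and a standard cross-entropy loss; the repository's actual loss and class layout are not stated here, and `NUM_CLASSES`, `cross_entropy`, and `predicted_distance` are hypothetical names.

```python
import math

NUM_CLASSES = 64  # assumption: one class per step count, matching the output dim of 64


def cross_entropy(logits, target_class):
    """Cross-entropy for a single example, using a stable log-sum-exp."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[target_class]


def predicted_distance(logits):
    """The predicted step count is the index of the largest logit."""
    return max(range(len(logits)), key=lambda k: logits[k])


# Hypothetical logits that put most of the mass on "5 steps to the goal".
logits = [0.0] * NUM_CLASSES
logits[5] = 4.0

loss_true_is_5 = cross_entropy(logits, target_class=5)
loss_true_is_6 = cross_entropy(logits, target_class=6)
loss_true_is_50 = cross_entropy(logits, target_class=50)
```

With these logits the loss is identical whether the true distance is 6 or 50: plain cross-entropy has no notion that 6 is closer to 5 than 50 is, which is one way a classifier can end up making large mistakes.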

Contrastive Learning

The model learns a latent space in which the L2 distance between state embeddings reflects how close the states are in time. This approach is more robust and consistent than supervised learning, and it better preserves the relative order of states.
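A minimal sketch of a contrastive objective over L2 distances. It assumes an InfoNCE-style loss in which the score of a pair is the negative L2 distance between embeddings; the exact loss used in this project is not stated, and the function names and 2-D embeddings are illustrative only.

```python
import math


def l2_distance(a, b):
    """Euclidean distance between two embedding vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def info_nce(anchor, positive, negatives):
    """InfoNCE loss where similarity is the negative L2 distance.

    The loss is small when the positive (a temporally close state) is
    nearer to the anchor than every negative.
    """
    scores = [-l2_distance(anchor, positive)]
    scores += [-l2_distance(anchor, n) for n in negatives]
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_z - scores[0]


# Hypothetical 2-D embeddings: `near` is one unit away, `far` is five.
anchor, near, far = (0.0, 0.0), (1.0, 0.0), (3.0, 4.0)

loss_good = info_nce(anchor, positive=near, negatives=[far])
loss_bad = info_nce(anchor, positive=far, negatives=[near])
```

The loss depends only on how the distances compare with each other, which fits the observation that the contrastive model preserves relative order even when its absolute scale is off.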

Stitching

Stitching uses the contrastive setup, but training connects parts of different trajectories. The resulting estimates are slightly more biased, but the model generalizes better to new compositions of paths.
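One way to picture "connecting parts of different trajectories" is sketched below. This is an assumption about how stitched pairs could be built, not the sampling scheme used in this project: two trajectories that visit a common state are joined there, which yields start/goal pairs that neither trajectory contains on its own. The function `stitched_pairs` and the toy trajectories are hypothetical.

```python
def stitched_pairs(traj_a, traj_b):
    """Return (start, goal, steps) triples that cross from traj_a into traj_b.

    A pair is formed by walking along traj_a to a state that traj_b also
    visits, then continuing along traj_b. `steps` is the length of that
    combined path, so it is an upper bound on the true distance.
    """
    first_visit_b = {}
    for j, state in enumerate(traj_b):
        first_visit_b.setdefault(state, j)

    pairs = []
    for i, waypoint in enumerate(traj_a):
        if waypoint not in first_visit_b:
            continue
        j = first_visit_b[waypoint]
        for start_idx in range(i):
            for goal_idx in range(j + 1, len(traj_b)):
                steps = (i - start_idx) + (goal_idx - j)
                pairs.append((traj_a[start_idx], traj_b[goal_idx], steps))
    return pairs


# Two toy grid trajectories that share the cell (0, 1).
traj_a = [(0, 0), (0, 1), (0, 2), (0, 3)]
traj_b = [(0, 1), (1, 1), (2, 1)]

pairs = stitched_pairs(traj_a, traj_b)
```

Here the start `(0, 0)` only appears in `traj_a` and the goals `(1, 1)` and `(2, 1)` only appear in `traj_b`, so the pairs exist only because of the shared cell. Because the stitched path need not be the shortest one, labels built this way can overestimate distance, which is consistent with the slight bias noted above.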

Metrics

Correlation

The plots show predicted versus actual distance to the goal along 10 trajectories.

  • Supervised learning produces discrete but sometimes inaccurate estimates.
  • Contrastive and stitching yield smoother, proportional predictions with some bias.
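The relation between bias and correlation can be checked with a few lines of code. This sketch assumes the Pearson coefficient; which correlation measure the plots rely on is not stated, and the sample distances are made up for illustration.

```python
import math


def pearson(xs, ys):
    """Pearson correlation coefficient of two equally long sequences."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


# Actual distance to the goal along one hypothetical trajectory.
actual = [5, 4, 3, 2, 1, 0]

# Biased but proportional estimate: wrong scale and offset, correct order.
biased = [2.0 * d + 1.0 for d in actual]

# Mostly exact estimate with a single large mistake.
one_big_error = [5, 4, 3, 9, 1, 0]

r_biased = pearson(actual, biased)
r_error = pearson(actual, one_big_error)
```

A prediction that is off by a constant scale and offset still correlates perfectly with the true distance, while a single large error lowers the coefficient even though every other value is exact.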

Supervised

[Figure: Supervised Correlation]

Contrastive

[Figure: Contrastive Correlation]

Stitching

[Figure: Stitching Correlation]

Heatmaps

The heatmaps visualize the predicted distances in a maze. For each method, the first picture shows the maze: the goal is yellow and the walls are dark blue. The second picture shows the predicted distances, with the walls again in dark blue and the distance scale on the right-hand side.

Supervised

[Figures: Supervised Maze, Supervised Heatmap]

Contrastive

[Figures: Contrastive Maze, Contrastive Heatmap]

Stitching

[Figures: Stitching Maze, Stitching Heatmap]

Solved Rate

The solved rate is the fraction of runs in which the model reached the goal while its predictions guided the search. The table also reports the average number of nodes expanded during search.

Method        With Search (Rate / Nodes)   Without Search (Rate / Nodes)
Supervised    0.995 / 43.8                 0.285 / 11.6
Contrastive   0.994 / 27.5                 0.561 / 16.7
Stitching     0.997 / 33.6                 0.391 / 14.8

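With search, all three methods solve more than 99% of runs. Contrastive learning expands the fewest nodes with search (27.5) and has the highest solved rate without search (0.561), despite some scale bias in its distance estimates.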
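A minimal sketch of search guided by a distance estimate, counting expanded nodes. It assumes greedy best-first search on a grid maze and uses Manhattan distance as a stand-in for a learned model; the search procedure, the node-counting convention, and all names here (`best_first_search`, `manhattan`, the toy maze) are illustrative assumptions rather than the project's implementation.

```python
import heapq


def manhattan(cell, goal):
    """Stand-in for a learned distance model."""
    return abs(cell[0] - goal[0]) + abs(cell[1] - goal[1])


def best_first_search(grid, start, goal, distance_fn, max_expansions=1000):
    """Always expand the frontier cell with the smallest predicted distance.

    Returns (solved, expanded), where `expanded` counts cells popped from
    the frontier, including the goal itself.
    """
    frontier = [(distance_fn(start, goal), start)]
    visited = {start}
    expanded = 0
    while frontier and expanded < max_expansions:
        _, cell = heapq.heappop(frontier)
        expanded += 1
        if cell == goal:
            return True, expanded
        r, c = cell
        for nr, nc in ((r + 1, c), (r - 1, c), (r, c + 1), (r, c - 1)):
            inside = 0 <= nr < len(grid) and 0 <= nc < len(grid[0])
            if inside and grid[nr][nc] != "#" and (nr, nc) not in visited:
                visited.add((nr, nc))
                heapq.heappush(frontier, (distance_fn((nr, nc), goal), (nr, nc)))
    return False, expanded


# Toy maze: "#" is a wall, "." is free.
maze = [
    "....",
    ".##.",
    "....",
]
solved, expanded = best_first_search(maze, (0, 0), (2, 3), manhattan)

# A goal that is walled off cannot be reached.
blocked = [
    "..#.",
    "..#.",
]
solved_blocked, _ = best_first_search(blocked, (0, 0), (0, 3), manhattan)
```

Averaging `solved` over many mazes would give a solved rate, and averaging `expanded` would give a node count of the kind reported in the table. A better-ordered distance estimate pulls promising cells to the front of the queue, so fewer nodes are expanded before the goal is found.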

Training Curves

Supervised

Supervised Loss

Contrastive

Contrastive Loss

Summary

  • Supervised learning performs well with dense supervision but fails to generalize when discrete predictions are off.
  • Contrastive learning builds a latent space where relative proximity is preserved, leading to better planning and fewer expanded nodes.
  • Stitching extends contrastive learning to combine known trajectory parts and shows promise in generalizing to new sequences.

Insights

  • Supervised learning predicts distances directly but lacks robustness.
  • Contrastive learning is better for guiding search and planning due to consistent relative distances.
  • Stitching enables reusing trajectory fragments for generalization, critical in complex RL tasks.

Lessons Learned

  • I avoided float16 because it was unstable; float64 worked well and did not exceed memory limits.
  • Latent space learned via contrastive loss gives better structure for planning than direct regression or classification.
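The cause of the float16 instability is not recorded here, but one well-known failure mode of half precision is easy to reproduce: small updates vanish when added to larger values. The sketch below uses the standard library's `struct` format `"e"` (IEEE 754 half precision) to round values; the weight of 1.0 is a made-up example, and the update size matches the learning rate listed under Setup.

```python
import struct


def to_float16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


weight = 1.0     # hypothetical parameter value
update = 0.0001  # same magnitude as the learning rate of 0.0001

# Half precision: near 1.0 the spacing between representable values is
# 2**-10 (about 0.00098), so the update is rounded away.
after_fp16 = to_float16(to_float16(weight) + to_float16(update))

# Double precision (Python's native float) keeps the update.
after_fp64 = weight + update
```

In half precision the weight stays exactly at 1.0 after the update, whereas in double precision it moves as expected. Whether this particular effect is what made float16 unstable in this project is an open assumption.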
