-
Install pytorch in conda
conda install pytorch torchvision cuda80 -c soumith -
Clone this repo
-
Download data in ''data'' directory. The structure looks like below
./data/frames ./data/labels -
Training
python main.py --batch-size 128 --arch resnet18 --workers 4 --pretrained --model_name resnet18 python main.py --batch-size 128 --arch resnet50 --workers 4 --pretrained --model_name resnet50
-
Dataloader 首先,用 pytorch 定義的 Dataset Class ,寫一個HandCam_Dataset。裏面會取出對應的圖片以及標註。 接著對圖片做處理: 將圖片normalize,縮小到224*224大小,再隨機flip圖片做到Data Augmentation
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) self.transform = transforms.Compose([ transforms.Scale(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize, ]) -
Model 使用 pytorch 提供的 resnet18 model 和 resnet50 model。並且直接使用已經已經 pretrained 在 imagenet 的參數,這樣開始 train 會比較快。 pretrain model 再最後一層是 Fully-Connected (512 -> 1000) 和 Fully-Connected (2048 -> 1000),因為 imagenet 有 1000 class,而現在的 dataset 只有 24 class 所以要改成 Fully-Connected (512 -> 24) 和 Fully-Connected (2048 -> 24)
-
Training Procedure Learning rate : 0.01 (learning rate decay every 5 epoch) Batch size : 128 Epoch : 30
Resnet18 : 70.28%
Resnet50 : 72.3%