This repository contains PyTorch code for several CNN architectures and training improvements, based on the papers listed below. I hope the implementation and results are helpful for your research!

- Architecture
  - (lenet) LeNet-5, convolutional neural networks
  - (alexnet) ImageNet Classification with Deep Convolutional Neural Networks
  - (vgg) Very Deep Convolutional Networks for Large-Scale Image Recognition
  - (resnet) Deep Residual Learning for Image Recognition
  - (preresnet) Identity Mappings in Deep Residual Networks
  - (resnext) Aggregated Residual Transformations for Deep Neural Networks
  - (densenet) Densely Connected Convolutional Networks
  - (senet) Squeeze-and-Excitation Networks
  - (bam) BAM: Bottleneck Attention Module
  - (cbam) CBAM: Convolutional Block Attention Module
  - (genet) Gather-Excite: Exploiting Feature Context in Convolutional Neural Networks
  - (sknet) SKNet: Selective Kernel Networks
- Regularization
- Learning Rate Scheduler

Requirements:

- Python (>=3.6)
- PyTorch (>=1.1.0)
- TensorBoard (>=1.4.0) (for visualization)
- Other dependencies (pyyaml, easydict)

PS: for the TensorboardX version, check the tag pt1.0.

Install the dependencies:

```bash
pip install -r requirements.txt
```

Then run one of the following commands to start training:

```bash
## 1 GPU for lenet
CUDA_VISIBLE_DEVICES=0 python -u train.py --work-path ./experiments/cifar10/lenet
## resume from ckpt
CUDA_VISIBLE_DEVICES=0 python -u train.py --work-path ./experiments/cifar10/lenet --resume
## 2 GPUs for preresnet1202
CUDA_VISIBLE_DEVICES=0,1 python -u train.py --work-path ./experiments/cifar10/preresnet1202
## 4 GPUs for densenet190bc
CUDA_VISIBLE_DEVICES=0,1,2,3 python -u train.py --work-path ./experiments/cifar10/densenet190bc
```

The parameters are stored in a YAML file, `config.yaml`; check any of the files under `./experiments` for more details.
You can view the training curves with TensorBoard: `tensorboard --logdir path-to-event --port your-port`.
The training log is written via `logging`; check `log.txt` in your work path.

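
The commands above pass a work path and an optional `--resume` flag, and the text mentions `config.yaml` and `log.txt`. As a rough illustration of how those pieces might fit together, here is a minimal sketch using only the standard library. The helper names `parse_args` and `experiment_files` are hypothetical, and the assumption that `config.yaml` sits inside the work path next to `log.txt` is mine; the actual `train.py` may be organized differently.

```python
import argparse
import os


def parse_args(argv=None):
    """Parse the two flags that appear in the training commands."""
    parser = argparse.ArgumentParser(description="CIFAR training (sketch)")
    parser.add_argument("--work-path", required=True, type=str,
                        help="experiment directory")
    parser.add_argument("--resume", action="store_true",
                        help="resume from a saved checkpoint")
    return parser.parse_args(argv)


def experiment_files(work_path):
    """Build the per-experiment file paths.

    Assumption: config.yaml lives in the work path alongside log.txt.
    """
    return {
        "config": os.path.join(work_path, "config.yaml"),
        "log": os.path.join(work_path, "log.txt"),
    }


# Simulate: python train.py --work-path ./experiments/cifar10/lenet --resume
args = parse_args(["--work-path", "./experiments/cifar10/lenet", "--resume"])
files = experiment_files(args.work_path)
print(args.resume, files["config"], files["log"])
```

Keeping every experiment's configuration and log in its own directory means a run can be reproduced or resumed by pointing `--work-path` at that directory again.
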
| architecture | params | batch size | epoch | C10 test acc (%) | C100 test acc (%) |
|---|---|---|---|---|---|
| Lecun | 62K | 128 | 250 | 67.46 | 34.10 |
| alexnet | 2.4M | 128 | 250 | 75.56 | 38.67 |
| vgg19 | 20M | 128 | 250 | 93.00 | 72.07 |
| preresnet20 | 0.27M | 128 | 250 | 91.88 | 67.03 |
| preresnet110 | 1.7M | 128 | 250 | 94.24 | 72.96 |
| preresnet1202 | 19.4M | 128 | 250 | 94.74 | 75.28 |
| densenet100bc | 0.76M | 64 | 300 | 95.08 | 77.55 |
| densenet190bc | 25.6M | 64 | 300 | 96.11 | 82.59 |
| resnext29_16x64d | 68.1M | 128 | 300 | 95.94 | 83.18 |
| se_resnext29_16x64d | 68.6M | 128 | 300 | 96.15 | 83.65 |
| cbam_resnext29_16x64d | 68.7M | 128 | 300 | 96.27 | 83.62 |
| ge_resnext29_16x64d | 70.0M | 128 | 300 | 96.21 | 83.57 |

PS: the default data augmentation methods are RandomCrop + RandomHorizontalFlip + Normalize,
and a √ marks each additional method that was used. 🍰

| architecture | epoch | cutout | mixup | C10 test acc (%) |
|---|---|---|---|---|
| preresnet20 | 250 | | | 91.88 |
| preresnet20 | 250 | √ | | 92.57 |
| preresnet20 | 250 | | √ | 92.71 |
| preresnet20 | 250 | √ | √ | 92.66 |
| preresnet110 | 250 | | | 94.24 |
| preresnet110 | 250 | √ | | 94.67 |
| preresnet110 | 250 | | √ | 94.94 |
| preresnet110 | 250 | √ | √ | 95.66 |
| se_resnext29_16x64d | 300 | | | 96.15 |
| se_resnext29_16x64d | 300 | √ | | 96.60 |
| se_resnext29_16x64d | 300 | | √ | 96.86 |
| se_resnext29_16x64d | 300 | √ | √ | 97.03 |
| cbam_resnext29_16x64d | 300 | √ | √ | 97.16 |
| ge_resnext29_16x64d | 300 | √ | √ | 97.19 |
| -- | -- | -- | -- | -- |
| shake_resnet26_2x64d | 1800 | | | 96.94 |
| shake_resnet26_2x64d | 1800 | √ | | 97.20 |
| shake_resnet26_2x64d | 1800 | | √ | 97.42 |
| shake_resnet26_2x64d | 1800 | √ | √ | 97.71 |
PS: shake_resnet26_2x64d achieved 97.71% test accuracy with cutout and mixup!!
It's cool, right?
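
For readers unfamiliar with the cutout and mixup columns in the table above, the following is a minimal PyTorch sketch of both techniques. It is not taken from this repository: the function names, the patch length of 16, and `alpha=1.0` are illustrative assumptions, and the random tensors stand in for real CIFAR batches and model outputs.

```python
import random

import torch
import torch.nn.functional as F


def cutout(img, length=16):
    """Zero one randomly centred square patch of a (C, H, W) tensor.

    The patch is clipped at the image border, so it can be smaller
    than length x length.
    """
    _, h, w = img.shape
    cy = random.randrange(h)
    cx = random.randrange(w)
    y1, y2 = max(cy - length // 2, 0), min(cy + length // 2, h)
    x1, x2 = max(cx - length // 2, 0), min(cx + length // 2, w)
    out = img.clone()
    out[:, y1:y2, x1:x2] = 0.0
    return out


def mixup_batch(x, y, alpha=1.0):
    """Blend each sample with a randomly chosen partner from the batch."""
    lam = random.betavariate(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0))
    mixed_x = lam * x + (1.0 - lam) * x[index]
    return mixed_x, y, y[index], lam


def mixup_loss(logits, y_a, y_b, lam):
    """Weight the two cross-entropy terms by the mixing coefficient."""
    loss_a = F.cross_entropy(logits, y_a)
    loss_b = F.cross_entropy(logits, y_b)
    return lam * loss_a + (1.0 - lam) * loss_b


# Toy usage on random data shaped like a CIFAR batch.
images = torch.rand(8, 3, 32, 32)
labels = torch.randint(0, 10, (8,))
images = torch.stack([cutout(img) for img in images])
mixed, y_a, y_b, lam = mixup_batch(images, labels)
logits = torch.randn(8, 10)  # stand-in for model(mixed)
loss = mixup_loss(logits, y_a, y_b, lam)
```

In this sketch cutout is applied per image before batching, while mixup operates on the whole batch and changes the loss, which is why the two can be combined as in the rows with both columns checked.
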
| architecture | epoch | step decay | cosine | htd(-6,3) | cutout | mixup | C10 test acc (%) |
|---|---|---|---|---|---|---|---|
| preresnet20 | 250 | √ | | | | | 91.88 |
| preresnet20 | 250 | | √ | | | | 92.13 |
| preresnet20 | 250 | | | √ | | | 92.44 |
| preresnet20 | 250 | | | √ | √ | √ | 93.30 |
| preresnet110 | 250 | √ | | | | | 94.24 |
| preresnet110 | 250 | | √ | | | | 94.48 |
| preresnet110 | 250 | | | √ | | | 94.82 |
| preresnet110 | 250 | | | √ | √ | √ | 95.88 |
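
The three schedules compared in the table above can be written as simple functions of the epoch. The sketch below is illustrative rather than the repository's own code: it reads `htd(-6,3)` as hyperbolic-tangent decay with lower bound -6 and upper bound 3, and the base learning rate of 0.1, the step milestones, and the decay factor are assumptions chosen only to make the example concrete.

```python
import math


def step_decay(epoch, base_lr=0.1, milestones=(100, 150, 200), gamma=0.1):
    """Multiply the rate by gamma for every milestone already reached.

    The milestones and gamma are hypothetical values.
    """
    passed = sum(epoch >= m for m in milestones)
    return base_lr * gamma ** passed


def cosine_decay(epoch, total_epochs, base_lr=0.1):
    """Half-cosine from base_lr at epoch 0 down to 0 at total_epochs."""
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * epoch / total_epochs))


def htd_decay(epoch, total_epochs, base_lr=0.1, lower=-6.0, upper=3.0):
    """Hyperbolic-tangent decay with bounds (lower, upper)."""
    progress = epoch / total_epochs
    return 0.5 * base_lr * (1.0 - math.tanh(lower + (upper - lower) * progress))


# Compare the three schedules at the start, middle, and end of 250 epochs.
for epoch in (0, 125, 250):
    print(epoch,
          step_decay(epoch),
          cosine_decay(epoch, 250),
          htd_decay(epoch, 250))
```

With bounds (-6, 3), the tanh schedule stays near the base rate for longer than the cosine schedule and then drops quickly, ending slightly above zero rather than exactly at zero.
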

The code was adapted from:

- kuangliu/pytorch-cifar
- bearpaw/pytorch-classification
- timgaripov/swa
- xgastaldi/shake-shake
- uoguelph-mlrg/Cutout
- facebookresearch/mixup-cifar10
- BIGBALLON/cifar-10-cnn
- BayesWatch/pytorch-GENet
- Jongchan/attention-module
- pppLang/SKNet

Feel free to contact me if you have any suggestions or questions. Issues are welcome,
and please create a PR if you find a bug or want to contribute. 😊

```bibtex
@misc{bigballon2019cifarzoo,
  author = {Wei Li},
  title = {CIFAR-ZOO: PyTorch implementation of CNNs for CIFAR dataset},
  howpublished = {\url{https://github.com/BIGBALLON/CIFAR-ZOO}},
  year = {2019}
}
```
