diff --git a/models/atme_model.py b/models/atme_model.py index 9d50334..f524577 100644 --- a/models/atme_model.py +++ b/models/atme_model.py @@ -1,198 +1,214 @@ -import os -import torch -from .base_model import BaseModel -from . import networks -from util.image_pool import DiscPool -import util.util as util -from itertools import chain -# from data import create_dataset - - -class AtmeModel(BaseModel): - """ This class implements the ATME model, for learning a mapping from input images to output images given paired data. - - The model training requires '--dataset_mode aligned' dataset. - By default, it uses a '--netG unet_256_attn' U-Net generator, - a '--netD basic' discriminator (PatchGAN), - and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper). - - atme paper: https://arxiv.org/pdf/x.pdf - """ - - @staticmethod - def modify_commandline_options(parser, is_train=True): - """Add new dataset-specific options, and rewrite default values for existing options. - - Parameters: - parser -- original option parser - is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. - - Returns: - the modified parser. - - For atme, we do not use image buffer - The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1 - By default, we use vanilla GAN loss, UNet with instance norm, and aligned datasets. - """ - # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/) - parser.add_argument('--n_save_noisy', type=int, default=0, help='number of D_t and W_t to keep track of') - parser.add_argument('--mask_size', type=int, default=512) - parser.add_argument('--dim', type=int, default=64, help='dim for the ddm UNet') - parser.add_argument('--dim_mults', type=tuple, default=(1, 2, 4, 8), help='dim_mults for the ddm UNet') - parser.add_argument('--groups', type=int, default=8, help='number of groups for GroupNorm within ResnetBlocks') - parser.add_argument('--init_dim', type=int, default=64, help='output channels after initial conv2d of x_t') - parser.add_argument('--learned_sinusoidal_cond', type=bool, default=False, - help='learn fourier features for positional embedding?') - parser.add_argument('--random_fourier_features', type=bool, default=False, - help='random fourier features for positional embedding?') - parser.add_argument('--learned_sinusoidal_dim', type=int, default=16, - help='twice the number of fourier frequencies to learn') - parser.add_argument('--time_dim_mult', type=int, default=4, - help='dim * time_dim_mult amounts to output channels after time-MLP') - - if is_train: - parser.set_defaults(pool_size=0, gan_mode='vanilla') - parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss') - - return parser - - - def __init__(self, opt, dataset): - """Initialize the pix2pix class. - - Parameters: - opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions - """ - BaseModel.__init__(self, opt) - self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake'] - if self.isTrain: - self.visual_names = ['real_A', 'fake_B', 'real_B'] - self.model_names = ['G', 'D', 'W'] - opt.pool_size = 0 - opt.gan_mode = 'vanilla' - else: - self.visual_names = ['real_A', 'fake_B'] - self.model_names = ['G', 'W'] - - self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, - not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, - **{'dim': opt.dim, - 'dim_mults': opt.dim_mults, - 'init_dim': opt.init_dim, - 'resnet_block_groups': opt.groups, - 'learned_sinusoidal_cond': opt.learned_sinusoidal_cond, - 'learned_sinusoidal_dim': opt.learned_sinusoidal_dim, - 'random_fourier_features': opt.random_fourier_features, - 'time_dim_mult': opt.time_dim_mult}) - - self.netW = networks.define_W(opt.init_type, opt.init_gain, self.gpu_ids, output_size=opt.vol_cube_dim) - self.disc_pool = DiscPool(opt, self.device, dataset, isTrain=self.isTrain) - - if self.isTrain: - self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, - opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) - - if self.isTrain: - self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) - self.criterionL1 = torch.nn.L1Loss() - self.optimizer_G = torch.optim.Adam(chain(self.netW.parameters(), self.netG.parameters()), lr=opt.lr, - betas=(opt.beta1, 0.999)) - self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) - self.optimizers.append(self.optimizer_G) - self.optimizers.append(self.optimizer_D) - - self.save_noisy = True if opt.n_save_noisy > 0 else False - if self.save_noisy: - self.save_DW_idx = torch.randint(len(dataset), (opt.n_save_noisy,)) - self.img_DW_dir = os.path.join(opt.checkpoints_dir, opt.name, 'images_noisy') - util.mkdir(self.img_DW_dir) - - def setup(self, opt): - """Load and print networks; create schedulers - - Parameters: - opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions - """ - if self.isTrain: - self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] - self.print_networks(opt.verbose) - if not self.isTrain or opt.continue_train: - load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch - self.load_networks(load_suffix, pre_train_G_path=opt.pre_train_G_path, pre_train_W_path=opt.pre_train_W_path) - - def _save_DW(self, visuals): - to_save = (self.batch_indices.view(1, -1) == self.save_DW_idx.view(-1, 1)).any(dim=0) - if any(to_save) > 0: - idx_to_save = torch.nonzero(to_save)[0] - for label, images in visuals.items(): - for idx, image in zip(idx_to_save, images[to_save]): - img_idx = self.batch_indices[idx].item() - image_numpy = util.tensor2im(image[None]) - img_path = os.path.join(self.img_DW_dir, f'epoch_{self.epoch:03d}_{label}_{img_idx}.png') - util.save_image(image_numpy, img_path) - - def set_input(self, input, epoch=None): - """Unpack input data from the dataloader and perform necessary pre-processing steps. - - Parameters: - input (dict): include the data itself and its metadata information. - - The option 'direction' can be used to swap images in domain A and domain B. - """ - self.epoch = epoch - AtoB = self.opt.direction == 'AtoB' - self.real_A = input['A' if AtoB else 'B'].to(self.device, dtype=torch.float) - if self.isTrain: - self.real_B = input['B' if AtoB else 'A'].to(self.device, dtype=torch.float) - self.image_paths = input['A_paths' if AtoB else 'B_paths'] - self.batch_indices = input['batch_indices'] - self.disc_B = self.disc_pool.query(self.batch_indices) - - def forward(self): - """Run forward pass; called by both functions and .""" - self.Disc_B = self.netW(self.disc_B) - self.noisy_A = self.real_A * (1 + self.Disc_B) - self.fake_B = self.netG(self.noisy_A, self.Disc_B) - - def backward_D(self): - """Calculate GAN loss for the discriminator""" - # Fake - fake_AB = torch.cat((self.real_A, self.fake_B), 1) - pred_fake = self.netD(fake_AB.detach()) - self.loss_D_fake = self.criterionGAN(pred_fake, False) - # Real - real_AB = torch.cat((self.real_A, self.real_B), 1) - pred_real = self.netD(real_AB) - self.loss_D_real = self.criterionGAN(pred_real, True) - # combine loss and calculate gradients - self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 - self.loss_D.backward() - - def backward_G(self): - """Calculate GAN and L1 loss for the generator""" - # First, G(A) should fake the discriminator - fake_AB = torch.cat((self.real_A, self.fake_B), 1) - self.disc_B = self.netD(fake_AB) - self.loss_G_GAN = self.criterionGAN(self.disc_B, True) - # Second, G(A) = B - self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 - # combine loss and calculate gradients - self.loss_G = self.loss_G_GAN + self.loss_G_L1 - self.loss_G.backward() - - def optimize_parameters(self): - self.forward() - # update D - self.set_requires_grad(self.netD, True) - self.optimizer_D.zero_grad() - self.backward_D() - self.optimizer_D.step() - # update G - self.set_requires_grad(self.netD, False) - self.optimizer_G.zero_grad() - self.backward_G() - self.optimizer_G.step() - # Save discriminator output - self.disc_pool.insert(self.disc_B.detach(), self.batch_indices) - if self.save_noisy: # Save images corresponding to disc_B and Disc_B - self._save_DW({'D': torch.sigmoid(self.disc_B), 'W': self.Disc_B}) \ No newline at end of file +import os +import torch +from torch.cuda.amp import GradScaler, autocast +from .base_model import BaseModel +from . import networks +from util.image_pool import DiscPool +import util.util as util +from itertools import chain +# from data import create_dataset + + +class AtmeModel(BaseModel): + """ This class implements the ATME model, for learning a mapping from input images to output images given paired data. + + The model training requires '--dataset_mode aligned' dataset. + By default, it uses a '--netG unet_256_attn' U-Net generator, + a '--netD basic' discriminator (PatchGAN), + and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper). + + atme paper: https://arxiv.org/pdf/x.pdf + """ + + @staticmethod + def modify_commandline_options(parser, is_train=True): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + + For atme, we do not use image buffer + The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1 + By default, we use vanilla GAN loss, UNet with instance norm, and aligned datasets. + """ + # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/) + parser.add_argument('--n_save_noisy', type=int, default=0, help='number of D_t and W_t to keep track of') + parser.add_argument('--mask_size', type=int, default=512) + parser.add_argument('--dim', type=int, default=64, help='dim for the ddm UNet') + parser.add_argument('--dim_mults', type=tuple, default=(1, 2, 4, 8), help='dim_mults for the ddm UNet') + parser.add_argument('--groups', type=int, default=8, help='number of groups for GroupNorm within ResnetBlocks') + parser.add_argument('--init_dim', type=int, default=64, help='output channels after initial conv2d of x_t') + parser.add_argument('--learned_sinusoidal_cond', type=bool, default=False, + help='learn fourier features for positional embedding?') + parser.add_argument('--random_fourier_features', type=bool, default=False, + help='random fourier features for positional embedding?') + parser.add_argument('--learned_sinusoidal_dim', type=int, default=16, + help='twice the number of fourier frequencies to learn') + parser.add_argument('--time_dim_mult', type=int, default=4, + help='dim * time_dim_mult amounts to output channels after time-MLP') + + if is_train: + parser.set_defaults(pool_size=0, gan_mode='vanilla') + parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss') + + return parser + + + def __init__(self, opt, dataset): + """Initialize the pix2pix class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseModel.__init__(self, opt) + self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake'] + if self.isTrain: + self.visual_names = ['real_A', 'fake_B', 'real_B'] + self.model_names = ['G', 'D', 'W'] + opt.pool_size = 0 + opt.gan_mode = 'vanilla' + else: + self.visual_names = ['real_A', 'fake_B'] + self.model_names = ['G', 'W'] + + self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, + not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, + **{'dim': opt.dim, + 'dim_mults': opt.dim_mults, + 'init_dim': opt.init_dim, + 'resnet_block_groups': opt.groups, + 'learned_sinusoidal_cond': opt.learned_sinusoidal_cond, + 'learned_sinusoidal_dim': opt.learned_sinusoidal_dim, + 'random_fourier_features': opt.random_fourier_features, + 'time_dim_mult': opt.time_dim_mult}) + + self.netW = networks.define_W(opt.init_type, opt.init_gain, self.gpu_ids, output_size=opt.vol_cube_dim) + self.disc_pool = DiscPool(opt, self.device, dataset, isTrain=self.isTrain) + + if self.isTrain: + self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, + opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) + + if self.isTrain: + self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) + self.criterionL1 = torch.nn.L1Loss() + self.optimizer_G = torch.optim.Adam(chain(self.netW.parameters(), self.netG.parameters()), lr=opt.lr, + betas=(opt.beta1, 0.999)) + self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizers.append(self.optimizer_G) + self.optimizers.append(self.optimizer_D) + + self.use_amp = bool(getattr(opt, 'use_amp', False)) and self.isTrain and torch.cuda.is_available() + self.scaler = GradScaler(enabled=self.use_amp) + + self.save_noisy = True if opt.n_save_noisy > 0 else False + if self.save_noisy: + self.save_DW_idx = torch.randint(len(dataset), (opt.n_save_noisy,)) + self.img_DW_dir = os.path.join(opt.checkpoints_dir, opt.name, 'images_noisy') + util.mkdir(self.img_DW_dir) + + def setup(self, opt): + """Load and print networks; create schedulers + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + if self.isTrain: + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + self.print_networks(opt.verbose) + if not self.isTrain or opt.continue_train: + load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch + self.load_networks(load_suffix, pre_train_G_path=opt.pre_train_G_path, pre_train_W_path=opt.pre_train_W_path) + + def _save_DW(self, visuals): + to_save = (self.batch_indices.view(1, -1) == self.save_DW_idx.view(-1, 1)).any(dim=0) + if any(to_save) > 0: + idx_to_save = torch.nonzero(to_save)[0] + for label, images in visuals.items(): + for idx, image in zip(idx_to_save, images[to_save]): + img_idx = self.batch_indices[idx].item() + image_numpy = util.tensor2im(image[None]) + img_path = os.path.join(self.img_DW_dir, f'epoch_{self.epoch:03d}_{label}_{img_idx}.png') + util.save_image(image_numpy, img_path) + + def set_input(self, input, epoch=None): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input (dict): include the data itself and its metadata information. + + The option 'direction' can be used to swap images in domain A and domain B. + """ + self.epoch = epoch + AtoB = self.opt.direction == 'AtoB' + self.real_A = input['A' if AtoB else 'B'].to(self.device, dtype=torch.float) + if self.isTrain: + self.real_B = input['B' if AtoB else 'A'].to(self.device, dtype=torch.float) + self.image_paths = input['A_paths' if AtoB else 'B_paths'] + self.batch_indices = input['batch_indices'] + self.disc_B = self.disc_pool.query(self.batch_indices) + + def forward(self): + """Run forward pass; called by both functions and .""" + self.Disc_B = self.netW(self.disc_B) + self.noisy_A = self.real_A * (1 + self.Disc_B) + self.fake_B = self.netG(self.noisy_A, self.Disc_B) + + def backward_D(self): + """Calculate GAN loss for the discriminator""" + # Fake + fake_AB = torch.cat((self.real_A, self.fake_B), 1) + pred_fake = self.netD(fake_AB.detach()) + self.loss_D_fake = self.criterionGAN(pred_fake, False) + # Real + real_AB = torch.cat((self.real_A, self.real_B), 1) + pred_real = self.netD(real_AB) + self.loss_D_real = self.criterionGAN(pred_real, True) + # combine loss and calculate gradients + self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 + + def backward_G(self): + """Calculate GAN and L1 loss for the generator""" + # First, G(A) should fake the discriminator + fake_AB = torch.cat((self.real_A, self.fake_B), 1) + self.disc_B = self.netD(fake_AB) + self.loss_G_GAN = self.criterionGAN(self.disc_B, True) + # Second, G(A) = B + self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 + # combine loss and calculate gradients + self.loss_G = self.loss_G_GAN + self.loss_G_L1 + + def optimize_parameters(self): + with autocast(enabled=self.use_amp): + self.forward() + # update D + self.set_requires_grad(self.netD, True) + self.optimizer_D.zero_grad() + with autocast(enabled=self.use_amp): + self.backward_D() + if self.use_amp: + self.scaler.scale(self.loss_D).backward() + self.scaler.step(self.optimizer_D) + else: + self.loss_D.backward() + self.optimizer_D.step() + # update G + self.set_requires_grad(self.netD, False) + self.optimizer_G.zero_grad() + with autocast(enabled=self.use_amp): + self.backward_G() + if self.use_amp: + self.scaler.scale(self.loss_G).backward() + self.scaler.step(self.optimizer_G) + self.scaler.update() + else: + self.loss_G.backward() + self.optimizer_G.step() + # Save discriminator output + self.disc_pool.insert(self.disc_B.detach(), self.batch_indices) + if self.save_noisy: # Save images corresponding to disc_B and Disc_B + self._save_DW({'D': torch.sigmoid(self.disc_B), 'W': self.Disc_B}) diff --git a/options/atme_options.py b/options/atme_options.py index 1d87408..b5e3840 100644 --- a/options/atme_options.py +++ b/options/atme_options.py @@ -1,65 +1,67 @@ -from .base_options import BaseOptions -import argparse - - -class AtmeOptions(BaseOptions): - """This class includes training options. - - It also includes shared options defined in BaseOptions. - """ - - def initialize(self, parser): - parser = BaseOptions.initialize(self, parser) - # visdom and HTML visualization parameters - parser.add_argument('--plane', type=str, required=True, default='coronal', help='define the plane the atme is trained on') - parser.add_argument('--model_root', type=str, default='atme_coronal_output', help='path to atme coronal images (should have subfolders trainA, trainB, valA, valB, etc)') - parser.add_argument('--TestAfterTrain', default=True, action=argparse.BooleanOptionalAction, help='specify if to test immediatly after train') - parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') - parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') - parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') - parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') - parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') - parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') - parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') - parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') - parser.add_argument('--no_html', type=bool, default=False, help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') - # model parameters - parser.add_argument('--model', type=str, default='atme', help='chooses which model to use. [atme, simple]') - parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA') - parser.add_argument('--input_nc', type=int, default=1, help='# of input image channels: 3 for RGB and 1 for grayscale') - parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels: 3 for RGB and 1 for grayscale') - parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') - parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') - parser.add_argument('--netD', type=str, default='n_layers', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator') - parser.add_argument('--netG', type=str, default='unet_256_ddm', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]') - parser.add_argument('--n_layers_D', type=int, default=4, help='only used if netD==n_layers') - parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]') - parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') - parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') - parser.add_argument('--no_dropout', default=True, action=argparse.BooleanOptionalAction, help='no dropout for the generator') - # dataset parameters - parser.add_argument('--batch_size', type=int, default=1, help='input batch size') - parser.add_argument('--load_size', type=int, default=286, help='scale images to this size') - parser.add_argument('--crop_val', type=int, default=7, help='value for cropping the volume') - parser.add_argument('--stride', type=int, default=7, help='value for cropping the volume') - parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') - parser.add_argument('--dataset_mode', type=str, default='aligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]') - # network saving and loading parameters - parser.add_argument('--pre_train_W_path', type=str, default='', help='load path for pre-trained W model') - parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') - parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') - parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') - parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') - parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') - parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') - # training parameters - parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate') - parser.add_argument('--n_epochs_decay', type=int, default=500, help='number of epochs to linearly decay learning rate to zero') - parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') - parser.add_argument('--lr', type=float, default=0.00001, help='initial learning rate for adam') - parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') - parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') - parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]') - parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') - - return parser \ No newline at end of file +from .base_options import BaseOptions +import argparse + + +class AtmeOptions(BaseOptions): + """This class includes training options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + # visdom and HTML visualization parameters + parser.add_argument('--plane', type=str, required=True, default='coronal', help='define the plane the atme is trained on') + parser.add_argument('--model_root', type=str, default='atme_coronal_output', help='path to atme coronal images (should have subfolders trainA, trainB, valA, valB, etc)') + parser.add_argument('--TestAfterTrain', default=True, action=argparse.BooleanOptionalAction, help='specify if to test immediatly after train') + parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') + parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') + parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') + parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') + parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') + parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') + parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') + parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') + parser.add_argument('--no_html', type=bool, default=False, help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') + # model parameters + parser.add_argument('--model', type=str, default='atme', help='chooses which model to use. [atme, simple]') + parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA') + parser.add_argument('--input_nc', type=int, default=1, help='# of input image channels: 3 for RGB and 1 for grayscale') + parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels: 3 for RGB and 1 for grayscale') + parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') + parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') + parser.add_argument('--netD', type=str, default='n_layers', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator') + parser.add_argument('--netG', type=str, default='unet_256_ddm', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]') + parser.add_argument('--n_layers_D', type=int, default=4, help='only used if netD==n_layers') + parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]') + parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') + parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') + parser.add_argument('--no_dropout', default=True, action=argparse.BooleanOptionalAction, help='no dropout for the generator') + # dataset parameters + parser.add_argument('--batch_size', type=int, default=1, help='input batch size') + parser.add_argument('--load_size', type=int, default=286, help='scale images to this size') + parser.add_argument('--crop_val', type=int, default=7, help='value for cropping the volume') + parser.add_argument('--stride', type=int, default=7, help='value for cropping the volume') + parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') + parser.add_argument('--dataset_mode', type=str, default='aligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]') + # network saving and loading parameters + parser.add_argument('--pre_train_W_path', type=str, default='', help='load path for pre-trained W model') + parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') + parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') + parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') + parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') + parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') + parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + # training parameters + parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate') + parser.add_argument('--n_epochs_decay', type=int, default=500, help='number of epochs to linearly decay learning rate to zero') + parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') + parser.add_argument('--lr', type=float, default=0.00001, help='initial learning rate for adam') + parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') + parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') + parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]') + parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') + parser.add_argument('--use_amp', default=True, action=argparse.BooleanOptionalAction, + help='use automatic mixed precision for faster training when CUDA is available') + + return parser