40 changes: 22 additions & 18 deletions atme.py
@@ -48,13 +48,15 @@ def train(opt):
dataset_size = len(dataset)
print('The number of training images = %d' % dataset_size)

model = create_model(opt, dataset)
model.setup(opt)
visualizer = Visualizer(opt)
total_iters = 0


for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
model = create_model(opt, dataset)
model.setup(opt)
visualizer = Visualizer(opt)
start_epoch = getattr(model, 'start_epoch', opt.epoch_count)
total_iters = getattr(model, 'start_iter', 0)
opt.epoch_count = start_epoch


for epoch in range(start_epoch, opt.n_epochs + opt.n_epochs_decay + 1):
epoch_start_time = time.time()
iter_data_time = time.time()
epoch_iter = 0
@@ -85,19 +87,21 @@ def train(opt):
# if opt.display_id > 0:
# visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

if total_iters % opt.save_latest_freq == 0: # cache our latest model every <save_latest_freq> iterations
print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
model.save_networks(save_suffix)
visuals = model.get_current_visuals()
slice_num = i if opt.batch_size == 1 else random.randint(0, opt.batch_size)
save_atme_images(visuals, save_fig_dir, slice_num, iter_num=total_iters, epoch=epoch)
if total_iters % opt.save_latest_freq == 0: # cache our latest model every <save_latest_freq> iterations
print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
model.save_networks(save_suffix)
visuals = model.get_current_visuals()
slice_num = i if opt.batch_size == 1 else random.randint(0, opt.batch_size)
save_atme_images(visuals, save_fig_dir, slice_num, iter_num=total_iters, epoch=epoch)
model.save_training_state(epoch, total_iters)

iter_data_time = time.time()
if epoch % opt.save_epoch_freq == 0: # cache our model every <save_epoch_freq> epochs
print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
model.save_networks('latest')
model.save_networks(epoch)
if epoch % opt.save_epoch_freq == 0: # cache our model every <save_epoch_freq> epochs
print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
model.save_networks('latest')
model.save_networks(epoch)
model.save_training_state(epoch + 1, total_iters)

# Save D_real and D_fake
visualizer.save_D_losses(model.get_current_losses())
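Note the asymmetry between the two new save_training_state calls: the mid-epoch save passes the current epoch, so an interrupted run restarts (and finishes) that epoch, while the end-of-epoch save passes epoch + 1, so a fully completed epoch is not trained again. A minimal sketch of the resulting resume behaviour, using a hypothetical stub in place of the real model class:

# Hypothetical stub, only to illustrate why the end-of-epoch save uses epoch + 1.
class StubModel:
    start_epoch = 1          # overwritten when a saved training state is loaded

    def save_training_state(self, epoch, total_iters):
        print('checkpoint says: resume at epoch %d' % epoch)

model, n_epochs = StubModel(), 3
for epoch in range(model.start_epoch, n_epochs + 1):
    # ... batches; a mid-epoch save would pass `epoch` (restart this epoch) ...
    model.save_training_state(epoch + 1, total_iters=0)  # epoch finished: resume at the next one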
87 changes: 69 additions & 18 deletions models/base_model.py
@@ -38,10 +38,12 @@ def __init__(self, opt):
torch.backends.cudnn.benchmark = True
self.loss_names = []
self.model_names = []
self.visual_names = []
self.optimizers = []
self.image_paths = []
self.metric = 0 # used for learning rate policy 'plateau'
self.visual_names = []
self.optimizers = []
self.image_paths = []
self.metric = 0 # used for learning rate policy 'plateau'
self.start_epoch = opt.epoch_count
self.start_iter = 0

@abstractmethod
def set_input(self, input):
@@ -68,12 +70,14 @@ def setup(self, opt):
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
if self.isTrain:
self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
self.print_networks(opt.verbose)
if not self.isTrain or opt.continue_train:
load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
self.load_networks(load_suffix, pre_train_G_path=opt.pre_train_G_path)
if self.isTrain:
self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
self.print_networks(opt.verbose)
if not self.isTrain or opt.continue_train:
load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
self.load_networks(load_suffix, pre_train_G_path=opt.pre_train_G_path)
if self.isTrain and opt.continue_train:
self.load_training_state()


def eval(self):
@@ -159,14 +163,61 @@ def save_specific_networks(self, networks_names, epoch):
save_path = os.path.join(self.save_dir, save_filename)
net = getattr(self, 'net' + name)

if len(self.gpu_ids) > 1 and torch.cuda.is_available():
torch.save(net.module.cpu().state_dict(), save_path)
net.cuda(self.gpu_ids[0])
else:
torch.save(net.state_dict(), save_path)

def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
if len(self.gpu_ids) > 1 and torch.cuda.is_available():
torch.save(net.module.cpu().state_dict(), save_path)
net.cuda(self.gpu_ids[0])
else:
torch.save(net.state_dict(), save_path)

def save_training_state(self, epoch, total_iters):
"""Save optimizer and scheduler states for resuming training later."""
if not self.isTrain:
return

state = {
'epoch': int(epoch),
'total_iters': int(total_iters),
'optimizer_states': [optimizer.state_dict() for optimizer in self.optimizers],
}

if hasattr(self, 'schedulers'):
state['scheduler_states'] = [scheduler.state_dict() for scheduler in getattr(self, 'schedulers', [])]

os.makedirs(self.save_dir, exist_ok=True)
save_path = os.path.join(self.save_dir, 'training_state.pth')
torch.save(state, save_path)

def load_training_state(self):
"""Load optimizer and scheduler states to resume training."""
load_path = os.path.join(self.save_dir, 'training_state.pth')
if not os.path.isfile(load_path):
return

print('loading training state from %s' % load_path)
state = torch.load(load_path, map_location='cpu')

optimizer_states = state.get('optimizer_states', [])
if len(optimizer_states) != len(self.optimizers):
print('Warning: number of optimizers does not match when loading training state.')
for optimizer, opt_state in zip(self.optimizers, optimizer_states):
optimizer.load_state_dict(opt_state)
for opt_state_value in optimizer.state.values():
for k, v in opt_state_value.items():
if isinstance(v, torch.Tensor):
opt_state_value[k] = v.to(self.device)

if hasattr(self, 'schedulers'):
scheduler_states = state.get('scheduler_states', [])
if len(scheduler_states) != len(getattr(self, 'schedulers', [])):
print('Warning: number of schedulers does not match when loading training state.')
for scheduler, scheduler_state in zip(getattr(self, 'schedulers', []), scheduler_states):
scheduler.load_state_dict(scheduler_state)

self.start_epoch = state.get('epoch', self.start_epoch)
self.start_iter = state.get('total_iters', self.start_iter)

def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
key = keys[i]
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
if module.__class__.__name__.startswith('InstanceNorm') and \
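The state file written by save_training_state() is a plain torch pickle, so it can be inspected directly. Also worth noting: load_training_state() loads it with map_location='cpu' and then moves the restored optimizer tensors onto self.device, which is what lets a GPU run pick up Adam/momentum buffers from a CPU-side load. A quick inspection sketch (the checkpoint path is hypothetical; the keys are the ones written above):

import torch

# Path is hypothetical; the real file lives in the model's save_dir.
state = torch.load('checkpoints/my_experiment/training_state.pth', map_location='cpu')

print('resume epoch:', state['epoch'])
print('total iters: ', state['total_iters'])
print('optimizers:  ', len(state['optimizer_states']))
print('schedulers:  ', len(state.get('scheduler_states', [])))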
41 changes: 22 additions & 19 deletions simple.py
@@ -23,21 +23,23 @@ def train(opt):
opt.save_dir = os.path.join(opt.main_root, opt.model_root, opt.exp_name)
mkdir(opt.save_dir)

model = create_model(opt)
model.setup(opt)
visualizer = Visualizer(opt)

train_loader = create_simple_train_dataset(opt)
print('prepare data_loader done')

total_iters = 0

figures_path = os.path.join(opt.save_dir, 'figures', 'train')
mkdir(figures_path)

slice_index = int(opt.patch_size / 2)

for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
model = create_model(opt)
model.setup(opt)
visualizer = Visualizer(opt)

train_loader = create_simple_train_dataset(opt)
print('prepare data_loader done')

start_epoch = getattr(model, 'start_epoch', opt.epoch_count)
total_iters = getattr(model, 'start_iter', 0)
opt.epoch_count = start_epoch

figures_path = os.path.join(opt.save_dir, 'figures', 'train')
mkdir(figures_path)

slice_index = int(opt.patch_size / 2)

for epoch in range(start_epoch, opt.n_epochs + opt.n_epochs_decay + 1):
epoch_start_time = time.time()
iter_data_time = time.time()
epoch_iter = 0
@@ -64,10 +66,11 @@
visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)

iter_data_time = time.time()
if epoch % opt.save_epoch_freq == 0:
print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
model.save_networks('latest')
model.save_networks(epoch)
if epoch % opt.save_epoch_freq == 0:
print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
model.save_networks('latest')
model.save_networks(epoch)
model.save_training_state(epoch + 1, total_iters)

losses = model.get_current_losses()
visualizer.save_to_tensorboard_writer(epoch, losses)
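Taken together, resuming a run only needs the existing continue-training path: setup() now chains load_networks() with load_training_state(), and both training scripts read start_epoch / start_iter back from the model. A hypothetical launcher following simple.py above; TrainOptions and the options/ and models/ import paths are assumed and not shown in this PR:

# Hypothetical resume launcher (module paths assumed, not part of this PR).
from options.train_options import TrainOptions
from models import create_model

opt = TrainOptions().parse()
opt.continue_train = True     # setup() then calls load_networks() and load_training_state()

model = create_model(opt)
model.setup(opt)              # weights, optimizer/scheduler state and counters restored here
print('resuming at epoch %d, iter %d' % (model.start_epoch, model.start_iter))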