|
| 1 | +import os |
| 2 | +import argparse |
| 3 | +from util.helper import get_optimizer_argparse, preprocess_args, create_exp_directory, BaseConfig, get_logging_config |
| 4 | +from util.data import Dataset |
| 5 | +from util.evaluation import evaluate_model, get_eval, get_model_scores |
| 6 | +from util.cmn import CollaborativeMemoryNetwork |
| 7 | +import numpy as np |
| 8 | +import tensorflow as tf |
| 9 | +from logging.config import dictConfig |
| 10 | +from tqdm import tqdm |
| 11 | + |
# ---------------------------------------------------------------------------
# Command-line interface. Optimizer-related flags (optimizer, learning_rate,
# decay, momentum, ...) are inherited from the shared parent parser.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(parents=[get_optimizer_argparse()],
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# BUG FIX: default must be the string '0', not the int 0. argparse only applies
# `type` to string defaults, so the old default=0 left FLAGS.gpu as an int and
# `os.environ[...] = FLAGS.gpu` below raised "TypeError: str expected" whenever
# the flag was omitted.
parser.add_argument('-g', '--gpu', help='set gpu device number 0-3', type=str, default='0')
parser.add_argument('--iters', help='Max iters', type=int, default=30)
parser.add_argument('-b', '--batch_size', help='Batch Size', type=int, default=128)
parser.add_argument('-e', '--embedding', help='Embedding Size', type=int, default=50)
parser.add_argument('--dataset', help='path to file', type=str, default='pretrain_data/citeulike-a.npz')
parser.add_argument('--hops', help='Number of hops/layers', type=int, default=2)
parser.add_argument('-n', '--neg', help='Negative Samples Count', type=int, default=4)
parser.add_argument('--l2', help='l2 Regularization', type=float, default=0.1)
parser.add_argument('-l', '--logdir', help='Set custom name for logdirectory',
                    type=str, default=None)
parser.add_argument('--resume', help='Resume existing from logdir', action="store_true")
parser.add_argument('--pretrain', help='Load pretrained user/item embeddings', type=str,
                    default='pretrain/citeulike-a_e50.npz')
parser.set_defaults(optimizer='rmsprop', learning_rate=0.001, decay=0.9, momentum=0.9)
FLAGS = parser.parse_args()
preprocess_args(FLAGS)  # project helper: normalizes/validates optimizer flags into FLAGS
# Restrict TensorFlow to the requested device (environ values must be str).
os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu
| 31 | + |
# Results land under result/ unless the user supplies an explicit --logdir;
# a user-supplied directory is created up front if it does not exist yet.
BASE_DIR = 'result/'
if FLAGS.logdir is not None:
    if not os.path.exists(FLAGS.logdir):
        os.mkdir(FLAGS.logdir)
| 36 | + |
class Config(BaseConfig):
    """Run configuration for training the Collaborative Memory Network.

    Most fields mirror the command-line FLAGS; the -1 placeholders are
    filled in from the dataset after it has been loaded (see below).
    """
    # Experiment directory: auto-created under BASE_DIR unless --logdir was given.
    logdir = create_exp_directory(BASE_DIR) if FLAGS.logdir is None else FLAGS.logdir
    filename = FLAGS.dataset              # path to the ratings .npz file
    embed_size = FLAGS.embedding          # user/item embedding dimensionality
    batch_size = FLAGS.batch_size
    hops = FLAGS.hops                     # number of memory hops/layers
    l2 = FLAGS.l2                         # l2 regularization strength
    user_count = -1                       # placeholder; overwritten from the dataset
    item_count = -1                       # placeholder; overwritten from the dataset
    optimizer = FLAGS.optimizer
    tol = 1e-5                            # tolerance constant; consumed downstream (BaseConfig/model) — confirm usage
    neg_count = FLAGS.neg                 # negative samples per positive example
    optimizer_params = FLAGS.optimizer_params
    grad_clip = 5.0                       # gradient clipping threshold passed to the model
    decay_rate = 0.9
    learning_rate = FLAGS.learning_rate
    pretrain = FLAGS.pretrain             # path to pretrained embeddings (.npz)
    max_neighbors = -1                    # placeholder; overwritten from the dataset
| 55 | + |
config = Config()

# When resuming, reload the previously persisted configuration from the logdir
# so the run continues with the exact hyper-parameters it was started with.
if FLAGS.resume:
    config.save_directory = config.logdir
    config.load()

# Route python logging into the experiment directory.
dictConfig(get_logging_config(config.logdir))
dataset = Dataset(config.filename)

# Fill in the dataset-dependent fields that were -1 placeholders on Config.
config.item_count = dataset.item_count
config.user_count = dataset.user_count
config.save_directory = config.logdir
# NOTE(review): reaches into a private Dataset attribute — a public accessor would be safer.
config.max_neighbors = dataset._max_user_neighbors

tf.logging.info('\n\n%s\n\n' % config)

# Persist the configuration for fresh runs (a resumed run already has one on disk).
if not FLAGS.resume:
    config.save()
| 74 | + |
model = CollaborativeMemoryNetwork(config)

# Supervisor checkpoints the model every 10 minutes and restores it on resume;
# save_summaries_secs=0 disables its automatic summary writing.
sv = tf.train.Supervisor(logdir=config.logdir, save_model_secs=60 * 10,
                         save_summaries_secs=0)

# allow_growth keeps TF from pre-allocating all GPU memory up front.
sess = sv.prepare_or_wait_for_session(config=tf.ConfigProto(
    gpu_options=tf.GPUOptions(allow_growth=True)))

# Warm-start the user/item memories from pretrained embeddings on fresh runs only.
if not FLAGS.resume:
    pretrain = np.load(FLAGS.pretrain)
    # Supervisor finalizes the graph; the assign ops below add new graph nodes,
    # so the graph must be (unsafely) unfinalized first.
    sess.graph._unsafe_unfinalize()
    tf.logging.info('Loading Pretrained Embeddings.... from %s' % FLAGS.pretrain)
    # Embeddings are scaled by 0.5 before being written into the memories.
    sess.run([
        model.user_memory.embeddings.assign(pretrain['user']*0.5),
        model.item_memory.embeddings.assign(pretrain['item']*0.5)])
| 90 | + |
# Training: one shuffled pass over the (user, pos item, neg item) triples per
# epoch, followed by an evaluation sweep on the held-out test data.
for epoch in range(FLAGS.iters):
    if sv.should_stop():
        break

    # Hoist the (approximate) batch count so tqdm can render a proper bar.
    batches_per_epoch = (dataset.train_size * FLAGS.neg) // FLAGS.batch_size
    progress = tqdm(enumerate(dataset.get_data(FLAGS.batch_size, True, FLAGS.neg)),
                    dynamic_ncols=True, total=batches_per_epoch)

    epoch_losses = []
    for step, example in progress:
        (ratings, pos_neighborhoods, pos_neighborhood_length,
         neg_neighborhoods, neg_neighborhood_length) = example
        feed = {
            model.input_users: ratings[:, 0],
            model.input_items: ratings[:, 1],
            model.input_items_negative: ratings[:, 2],
            model.input_neighborhoods: pos_neighborhoods,
            model.input_neighborhood_lengths: pos_neighborhood_length,
            model.input_neighborhoods_negative: neg_neighborhoods,
            model.input_neighborhood_lengths_negative: neg_neighborhood_length,
        }
        batch_loss, _ = sess.run([model.loss, model.train], feed)
        epoch_losses.append(batch_loss)
        progress.set_description(u"[{}] Loss: {:,.4f} » » » » ".format(epoch, batch_loss))

    tf.logging.info("Epoch {}: Avg Loss/Batch {:<20,.6f}".format(epoch, np.mean(epoch_losses)))
    evaluate_model(sess, dataset.test_data, dataset.item_users_list, model.input_users, model.input_items,
                   model.input_neighborhoods, model.input_neighborhood_lengths,
                   model.dropout, model.score, config.max_neighbors)
| 119 | + |
# Final evaluation: HR@k and NDCG@k for k = 1..10 on the test data.
EVAL_AT = range(1, 11)
hrs = []
ndcgs = []
report_lines = []
scores, out = get_model_scores(sess, dataset.test_data, dataset.item_users_list, model.input_users, model.input_items,
                               model.input_neighborhoods, model.input_neighborhood_lengths,
                               model.dropout, model.score, config.max_neighbors, True)

# Number of negatives per test instance (everything except the one positive).
negative_count = len(scores[0]) - 1
for k in EVAL_AT:
    hr, ndcg = get_eval(scores, negative_count, k)
    hrs.append(hr)
    ndcgs.append(ndcg)
    report_lines.append("{:<14} {:<14.6f}{:<14} {:.6f}\n".format('HR@%s' % k, hr,
                                                                 'NDCG@%s' % k, ndcg))
s = "".join(report_lines)
tf.logging.info(s)
| 133 | +tf.logging.info(s) |
# Persist the metrics as a small CSV: one header row of cutoffs, then one row
# per metric (ndcg, hr). Note: no trailing newline after the hr row.
with open("{}/final_results".format(config.logdir), 'w') as fout:
    fout.write("{},{}\n".format('metric', ','.join(str(k) for k in EVAL_AT)))
    fout.write("ndcg,{}\n".format(','.join(str(x) for x in ndcgs)))
    fout.write("hr,{}".format(','.join(str(x) for x in hrs)))
| 142 | + |
tf.logging.info("Saving model...")
# Save a final checkpoint before exiting, then tell the Supervisor to shut
# down its background services (checkpointing threads, etc.).
sv.saver.save(sess, sv.save_path,
              global_step=tf.contrib.framework.get_global_step())
sv.request_stop()
| 148 | + |
0 commit comments