from data_utils import load_task, vectorize_data
from sklearn import cross_validation, metrics
from memn2n import MemN2N
from itertools import chain
from six.moves import range, reduce

import tensorflow as tf
import numpy as np

tf.flags.DEFINE_float("learning_rate", 0.01, "Learning rate for SGD.")
tf.flags.DEFINE_float("anneal_rate", 25, "Number of epochs between halving the learning rate.")
tf.flags.DEFINE_float("anneal_stop_epoch", 100, "Epoch number to end annealed lr schedule.")
tf.flags.DEFINE_float("max_grad_norm", 40.0, "Clip gradients to this norm.")
tf.flags.DEFINE_integer("evaluation_interval", 10, "Evaluate and print results every x epochs.")
tf.flags.DEFINE_integer("batch_size", 32, "Batch size for training.")
tf.flags.DEFINE_integer("hops", 3, "Number of hops in the Memory Network.")
tf.flags.DEFINE_integer("epochs", 100, "Number of epochs to train for.")
tf.flags.DEFINE_integer("embedding_size", 20, "Embedding size for embedding matrices.")
tf.flags.DEFINE_integer("memory_size", 50, "Maximum size of memory.")
tf.flags.DEFINE_integer("task_id", 1, "bAbI task id, 1 <= id <= 20.")
tf.flags.DEFINE_integer("random_state", None, "Random state.")
tf.flags.DEFINE_string("data_dir", "data/tasks_1-20_v1-2/en/", "Directory containing bAbI tasks.")
FLAGS = tf.flags.FLAGS

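# Example invocation (a sketch, assuming this script is saved as single.py next to
# data_utils.py and memn2n.py); any flag defined above can be overridden on the
# command line:
#
#   python single.py --task_id=2 --epochs=200 --data_dir=data/tasks_1-20_v1-2/en/
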
| 25 | +print("Started Task:", FLAGS.task_id) |
| 26 | + |
| 27 | +# task data |
| 28 | +train, test = load_task(FLAGS.data_dir, FLAGS.task_id) |
| 29 | +data = train + test |
| 30 | + |
# Build the vocabulary from every word appearing in the stories, queries and answers.
vocab = sorted(reduce(lambda x, y: x | y, (set(list(chain.from_iterable(s)) + q + a) for s, q, a in data)))
word_idx = dict((c, i + 1) for i, c in enumerate(vocab))  # index 0 is reserved for the nil (padding) word

max_story_size = max(map(len, (s for s, _, _ in data)))
mean_story_size = int(np.mean([len(s) for s, _, _ in data]))
sentence_size = max(map(len, chain.from_iterable(s for s, _, _ in data)))
query_size = max(map(len, (q for _, q, _ in data)))
memory_size = min(FLAGS.memory_size, max_story_size)

# Add time words/indexes; each time word only needs its own unique slot in the vocabulary.
for i in range(memory_size):
    word_idx['time{}'.format(i + 1)] = len(word_idx) + 1

vocab_size = len(word_idx) + 1  # +1 for nil word
sentence_size = max(query_size, sentence_size)  # pad queries to the same length as story sentences
sentence_size += 1  # +1 for time words

print("Longest sentence length", sentence_size)
print("Longest story length", max_story_size)
print("Average story length", mean_story_size)

# train/validation/test sets
S, Q, A = vectorize_data(train, word_idx, sentence_size, memory_size)
trainS, valS, trainQ, valQ, trainA, valA = cross_validation.train_test_split(S, Q, A, test_size=.1, random_state=FLAGS.random_state)
testS, testQ, testA = vectorize_data(test, word_idx, sentence_size, memory_size)

print(testS[0])

print("Training set shape", trainS.shape)

# params
n_train = trainS.shape[0]
n_test = testS.shape[0]
n_val = valS.shape[0]

print("Training Size", n_train)
print("Validation Size", n_val)
print("Testing Size", n_test)

train_labels = np.argmax(trainA, axis=1)
test_labels = np.argmax(testA, axis=1)
val_labels = np.argmax(valA, axis=1)

tf.set_random_seed(FLAGS.random_state)
batch_size = FLAGS.batch_size

batches = zip(range(0, n_train - batch_size, batch_size), range(batch_size, n_train, batch_size))
batches = [(start, end) for start, end in batches]
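# Each (start, end) pair above marks one full batch. For example, if n_train
# happened to be 900 with batch_size 32, batches would be
# [(0, 32), (32, 64), ..., (864, 896)]; trailing examples that do not fill a
# whole batch are not visited within an epoch. (Numbers are illustrative only;
# the actual n_train depends on the task and the validation split.)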

with tf.Session() as sess:
    model = MemN2N(batch_size, vocab_size, sentence_size, memory_size, FLAGS.embedding_size, session=sess,
                   hops=FLAGS.hops, max_grad_norm=FLAGS.max_grad_norm)
    for t in range(1, FLAGS.epochs + 1):
        # Stepped learning rate: halve the learning rate every anneal_rate epochs,
        # then hold it constant once anneal_stop_epoch is reached.
        if t - 1 <= FLAGS.anneal_stop_epoch:
            anneal = 2.0 ** ((t - 1) // FLAGS.anneal_rate)
        else:
            anneal = 2.0 ** (FLAGS.anneal_stop_epoch // FLAGS.anneal_rate)
        lr = FLAGS.learning_rate / anneal

        np.random.shuffle(batches)
        total_cost = 0.0
        for start, end in batches:
            s = trainS[start:end]
            q = trainQ[start:end]
            a = trainA[start:end]
            cost_t = model.batch_fit(s, q, a, lr)
            total_cost += cost_t

        if t % FLAGS.evaluation_interval == 0:
            train_preds = []
            for start in range(0, n_train, batch_size):
                end = start + batch_size
                s = trainS[start:end]
                q = trainQ[start:end]
                pred = model.predict(s, q)
                train_preds += list(pred)

            val_preds = model.predict(valS, valQ)
            train_acc = metrics.accuracy_score(np.array(train_preds), train_labels)
            val_acc = metrics.accuracy_score(val_preds, val_labels)

            print('-----------------------')
            print('Epoch', t)
            print('Total Cost:', total_cost)
            print('Training Accuracy:', train_acc)
            print('Validation Accuracy:', val_acc)
            print('-----------------------')

    test_preds = model.predict(testS, testQ)
    test_acc = metrics.accuracy_score(test_preds, test_labels)
    print("Testing Accuracy:", test_acc)