
Commit aa7880c

Author: shixiaowen03
Commit message: add
Parents: 6d34cb1 + 2d9c2ff

File tree

14 files changed: +1994 additions, 0 deletions


nlp/Basic-EEMN-Demo/data_utils.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
import os
import re
import numpy as np


def load_task(data_dir, task_id, only_supporting=False):
    '''Load the nth task. There are 20 tasks in total.

    Returns a tuple containing the training and testing data for the task.
    '''
    assert task_id > 0 and task_id < 21

    files = os.listdir(data_dir)
    files = [os.path.join(data_dir, f) for f in files]
    s = 'qa{}_'.format(task_id)
    train_file = [f for f in files if s in f and 'train' in f][0]
    test_file = [f for f in files if s in f and 'test' in f][0]
    train_data = get_stories(train_file, only_supporting)
    test_data = get_stories(test_file, only_supporting)
    return train_data, test_data


def tokenize(sent):
    '''Return the tokens of a sentence including punctuation.
    >>> tokenize('Bob dropped the apple. Where is the apple?')
    ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?']
    '''
    # Split on runs of non-word characters, keeping them as tokens.
    # The capturing group must not be optional: '(\W+)?' raises
    # ValueError on Python 3.7+.
    return [x.strip() for x in re.split(r'(\W+)', sent) if x.strip()]


def parse_stories(lines, only_supporting=False):
    '''Parse stories provided in the bAbI tasks format.
    If only_supporting is true, only the sentences that support the answer are kept.
    '''
    data = []
    story = []
    for line in lines:
        line = str.lower(line)
        nid, line = line.split(' ', 1)
        nid = int(nid)
        if nid == 1:
            # sentence ids restart at 1 for each new story
            story = []
        if '\t' in line:  # question
            q, a, supporting = line.split('\t')
            q = tokenize(q)
            #a = tokenize(a)
            # answer is one vocab word even if it's actually multiple words
            a = [a]
            substory = None

            # remove question marks
            if q[-1] == "?":
                q = q[:-1]

            if only_supporting:
                # Only select the related substory
                supporting = map(int, supporting.split())
                substory = [story[i - 1] for i in supporting]
            else:
                # Provide all the substories
                substory = [x for x in story if x]

            data.append((substory, q, a))
            story.append('')
        else:  # regular sentence
            # remove periods
            sent = tokenize(line)
            if sent[-1] == ".":
                sent = sent[:-1]
            story.append(sent)
    return data


def get_stories(f, only_supporting=False):
    '''Given a file name, read the file, retrieve the stories, and then convert the sentences into a single story.'''
    with open(f) as f:
        return parse_stories(f.readlines(), only_supporting=only_supporting)


def vectorize_data(data, word_idx, sentence_size, memory_size):
    """
    Vectorize stories and queries.

    If a sentence length < sentence_size, the sentence will be padded with 0's.

    If a story length < memory_size, the story will be padded with empty memories.
    Empty memories are 1-D arrays of length sentence_size filled with 0's.

    The answer array is returned as a one-hot encoding.
    """
    S = []
    Q = []
    A = []
    for story, query, answer in data:
        ss = []
        for i, sentence in enumerate(story, 1):
            ls = max(0, sentence_size - len(sentence))
            ss.append([word_idx[w] for w in sentence] + [0] * ls)

        # take only the most recent sentences that fit in memory
        ss = ss[::-1][:memory_size][::-1]

        # Make the last word of each sentence the time 'word', which
        # corresponds to a vector in the temporal-encoding lookup table
        for i in range(len(ss)):
            ss[i][-1] = len(word_idx) - memory_size - i + len(ss)

        # pad to memory_size
        lm = max(0, memory_size - len(ss))
        for _ in range(lm):
            ss.append([0] * sentence_size)

        lq = max(0, sentence_size - len(query))
        q = [word_idx[w] for w in query] + [0] * lq

        y = np.zeros(len(word_idx) + 1)  # 0 is reserved for nil word
        for a in answer:
            y[word_idx[a]] = 1

        S.append(ss)
        Q.append(q)
        A.append(y)
    return np.array(S), np.array(Q), np.array(A)
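
As a quick sanity check, the helpers compose like this on a hand-built story in the bAbI format. This is a minimal sketch: the toy word_idx below is illustrative and does not include the time words that main.py adds to the vocabulary before vectorizing.

from data_utils import parse_stories, vectorize_data

# A tiny story in the bAbI format: numbered sentences; question lines
# carry a tab-separated answer and supporting-fact id.
lines = [
    '1 Mary moved to the bathroom.\n',
    '2 John went to the hallway.\n',
    '3 Where is Mary?\tbathroom\t1\n',
]
data = parse_stories(lines)
# -> [([['mary', 'moved', 'to', 'the', 'bathroom'],
#       ['john', 'went', 'to', 'the', 'hallway']],
#      ['where', 'is', 'mary'],
#      ['bathroom'])]

# Build a toy vocabulary; index 0 stays reserved for the nil word.
vocab = sorted({w for s, q, a in data for w in sum(s, []) + q + a})
word_idx = {w: i + 1 for i, w in enumerate(vocab)}

S, Q, A = vectorize_data(data, word_idx, sentence_size=6, memory_size=10)
print(S.shape, Q.shape, A.shape)  # (1, 10, 6) (1, 6) (1, 11)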

nlp/Basic-EEMN-Demo/main.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
from data_utils import load_task, vectorize_data
from sklearn.model_selection import train_test_split
from sklearn import metrics
from memn2n import MemN2N
from itertools import chain
from six.moves import range, reduce

import tensorflow as tf
import numpy as np

tf.flags.DEFINE_float("learning_rate", 0.01, "Learning rate for SGD.")
tf.flags.DEFINE_float("anneal_rate", 25, "Number of epochs between halving the learning rate.")
tf.flags.DEFINE_float("anneal_stop_epoch", 100, "Epoch number to end annealed lr schedule.")
tf.flags.DEFINE_float("max_grad_norm", 40.0, "Clip gradients to this norm.")
tf.flags.DEFINE_integer("evaluation_interval", 10, "Evaluate and print results every x epochs")
tf.flags.DEFINE_integer("batch_size", 32, "Batch size for training.")
tf.flags.DEFINE_integer("hops", 3, "Number of hops in the Memory Network.")
tf.flags.DEFINE_integer("epochs", 100, "Number of epochs to train for.")
tf.flags.DEFINE_integer("embedding_size", 20, "Embedding size for embedding matrices.")
tf.flags.DEFINE_integer("memory_size", 50, "Maximum size of memory.")
tf.flags.DEFINE_integer("task_id", 1, "bAbI task id, 1 <= id <= 20")
tf.flags.DEFINE_integer("random_state", None, "Random state.")
tf.flags.DEFINE_string("data_dir", "data/tasks_1-20_v1-2/en/", "Directory containing bAbI tasks")
FLAGS = tf.flags.FLAGS

print("Started Task:", FLAGS.task_id)

# task data
train, test = load_task(FLAGS.data_dir, FLAGS.task_id)
data = train + test

vocab = sorted(reduce(lambda x, y: x | y, (set(list(chain.from_iterable(s)) + q + a) for s, q, a in data)))
word_idx = dict((c, i + 1) for i, c in enumerate(vocab))

max_story_size = max(map(len, (s for s, _, _ in data)))
mean_story_size = int(np.mean([len(s) for s, _, _ in data]))
sentence_size = max(map(len, chain.from_iterable(s for s, _, _ in data)))
query_size = max(map(len, (q for _, q, _ in data)))
memory_size = min(FLAGS.memory_size, max_story_size)

# Add time words/indexes: give each time word its own vocabulary index
# (the index is what vectorize_data and the embedding lookup rely on)
for i in range(memory_size):
    word_idx['time{}'.format(i + 1)] = len(word_idx) + 1

vocab_size = len(word_idx) + 1  # +1 for nil word
sentence_size = max(query_size, sentence_size)  # for the position
sentence_size += 1  # +1 for time words

print("Longest sentence length", sentence_size)
print("Longest story length", max_story_size)
print("Average story length", mean_story_size)

# train/validation/test sets
S, Q, A = vectorize_data(train, word_idx, sentence_size, memory_size)
trainS, valS, trainQ, valQ, trainA, valA = train_test_split(S, Q, A, test_size=.1, random_state=FLAGS.random_state)
testS, testQ, testA = vectorize_data(test, word_idx, sentence_size, memory_size)

print(testS[0])  # peek at one vectorized test story

print("Training set shape", trainS.shape)

# params
n_train = trainS.shape[0]
n_test = testS.shape[0]
n_val = valS.shape[0]

print("Training Size", n_train)
print("Validation Size", n_val)
print("Testing Size", n_test)

train_labels = np.argmax(trainA, axis=1)
test_labels = np.argmax(testA, axis=1)
val_labels = np.argmax(valA, axis=1)

tf.set_random_seed(FLAGS.random_state)
batch_size = FLAGS.batch_size

batches = zip(range(0, n_train - batch_size, batch_size), range(batch_size, n_train, batch_size))
batches = [(start, end) for start, end in batches]

with tf.Session() as sess:
    model = MemN2N(batch_size, vocab_size, sentence_size, memory_size, FLAGS.embedding_size, session=sess,
                   hops=FLAGS.hops, max_grad_norm=FLAGS.max_grad_norm)
    for t in range(1, FLAGS.epochs + 1):
        # Stepped learning rate: halve every anneal_rate epochs,
        # then hold constant after anneal_stop_epoch
        if t - 1 <= FLAGS.anneal_stop_epoch:
            anneal = 2.0 ** ((t - 1) // FLAGS.anneal_rate)
        else:
            anneal = 2.0 ** (FLAGS.anneal_stop_epoch // FLAGS.anneal_rate)
        lr = FLAGS.learning_rate / anneal

        np.random.shuffle(batches)
        total_cost = 0.0
        for start, end in batches:
            s = trainS[start:end]
            q = trainQ[start:end]
            a = trainA[start:end]
            cost_t = model.batch_fit(s, q, a, lr)
            total_cost += cost_t

        if t % FLAGS.evaluation_interval == 0:
            train_preds = []
            for start in range(0, n_train, batch_size):
                end = start + batch_size
                s = trainS[start:end]
                q = trainQ[start:end]
                pred = model.predict(s, q)
                train_preds += list(pred)

            val_preds = model.predict(valS, valQ)
            train_acc = metrics.accuracy_score(np.array(train_preds), train_labels)
            val_acc = metrics.accuracy_score(val_preds, val_labels)

            print('-----------------------')
            print('Epoch', t)
            print('Total Cost:', total_cost)
            print('Training Accuracy:', train_acc)
            print('Validation Accuracy:', val_acc)
            print('-----------------------')

    test_preds = model.predict(testS, testQ)
    test_acc = metrics.accuracy_score(test_preds, test_labels)
    print("Testing Accuracy:", test_acc)
