
Commit fdc8cdd

Author: shixiaowen03
Commit message: attention is all u need
1 parent fcef0ce commit fdc8cdd

File tree

7 files changed (+903, -266 lines)


.idea/workspace.xml

Lines changed: 216 additions & 266 deletions
Some generated files are not rendered by default.
Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
from hyperparams import Hyperparams as hp
import tensorflow as tf
import numpy as np
import codecs
import regex

def load_de_vocab():
    # Keep only German words that appear at least hp.min_cnt times.
    vocab = [line.split()[0] for line in codecs.open('data/de.vocab.tsv', 'r', 'utf-8').read().splitlines()
             if int(line.split()[1]) >= hp.min_cnt]
    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = {idx: word for idx, word in enumerate(vocab)}

    return word2idx, idx2word

def load_en_vocab():
    # Keep only English words that appear at least hp.min_cnt times.
    vocab = [line.split()[0] for line in codecs.open('data/en.vocab.tsv', 'r', 'utf-8').read().splitlines()
             if int(line.split()[1]) >= hp.min_cnt]

    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = {idx: word for idx, word in enumerate(vocab)}
    return word2idx, idx2word


def create_data(source_sents, target_sents):
    de2idx, idx2de = load_de_vocab()
    en2idx, idx2en = load_en_vocab()

    x_list, y_list, Sources, Targets = [], [], [], []
    for source_sent, target_sent in zip(source_sents, target_sents):
        x = [de2idx.get(word, 1) for word in (source_sent + u" </S>").split()]  # 1: OOV, </S>: End of Text
        y = [en2idx.get(word, 1) for word in (target_sent + u" </S>").split()]

        if max(len(x), len(y)) <= hp.maxlen:
            x_list.append(np.array(x))
            y_list.append(np.array(y))
            Sources.append(source_sent)
            Targets.append(target_sent)

    # Pad every sentence with zeros (index 0) up to hp.maxlen.
    X = np.zeros([len(x_list), hp.maxlen], np.int32)
    Y = np.zeros([len(y_list), hp.maxlen], np.int32)

    for i, (x, y) in enumerate(zip(x_list, y_list)):
        X[i] = np.lib.pad(x, [0, hp.maxlen - len(x)], 'constant', constant_values=(0, 0))
        Y[i] = np.lib.pad(y, [0, hp.maxlen - len(y)], 'constant', constant_values=(0, 0))
    return X, Y, Sources, Targets


def load_train_data():
    def _refine(line):
        # Raw strings avoid invalid-escape warnings for \s and \p{Latin}.
        line = regex.sub(r"[^\s\p{Latin}']", "", line)
        return line.strip()

    de_sents = [_refine(line) for line in codecs.open(hp.source_train, 'r', 'utf-8').read().split('\n') if
                line and line[0] != "<"]
    en_sents = [_refine(line) for line in codecs.open(hp.target_train, 'r', 'utf-8').read().split('\n') if
                line and line[0] != '<']

    X, Y, Sources, Targets = create_data(de_sents, en_sents)
    return X, Y


def load_test_data():
    def _refine(line):
        line = regex.sub(r"<[^>]+>", "", line)
        line = regex.sub(r"[^\s\p{Latin}']", "", line)
        return line.strip()

    de_sents = [_refine(line) for line in codecs.open(hp.source_test, 'r', 'utf-8').read().split('\n') if line and line[:4] == "<seg"]
    en_sents = [_refine(line) for line in codecs.open(hp.target_test, 'r', 'utf-8').read().split('\n') if line and line[:4] == '<seg']

    X, Y, Sources, Targets = create_data(de_sents, en_sents)
    return X, Sources, Targets


def get_batch_data():
    X, Y = load_train_data()

    num_batch = len(X) // hp.batch_size

    print(X[:10], Y[:10])
    X = tf.convert_to_tensor(X, tf.int32)
    Y = tf.convert_to_tensor(Y, tf.int32)

    # TF 1.x queue-based input pipeline: slice the arrays and draw shuffled mini-batches.
    input_queues = tf.train.slice_input_producer([X, Y])

    x, y = tf.train.shuffle_batch(input_queues,
                                  num_threads=8,
                                  batch_size=hp.batch_size,
                                  capacity=hp.batch_size * 64,
                                  min_after_dequeue=hp.batch_size * 32,
                                  allow_smaller_final_batch=False)

    return x, y, num_batch
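
A minimal usage sketch for the queue-based pipeline above, assuming TF 1.x (the tf.train queue API was removed in TF 2) and that the data/ files referenced in Hyperparams are in place; slice_input_producer and shuffle_batch only yield data once queue runners are started inside a session:

# Sketch only, not part of the commit: consuming get_batch_data() under TF 1.x.
import tensorflow as tf
from data_load import get_batch_data  # module name taken from the import in the evaluation script below

x, y, num_batch = get_batch_data()
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    x_np, y_np = sess.run([x, y])  # one shuffled mini-batch of shape (batch_size, maxlen)
    print(x_np.shape, y_np.shape)
    coord.request_stop()
    coord.join(threads)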
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
The dataset can be downloaded from: https://pan.baidu.com/s/14XfprCqjmBKde9NmNZeCNg (password: lfwu)
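
Judging from the paths referenced elsewhere in this commit (the vocabulary loaders and Hyperparams), the archive is presumably expected to unpack into a data/ directory containing de.vocab.tsv, en.vocab.tsv, train.tags.de-en.de, train.tags.de-en.en, IWSLT16.TED.tst2014.de-en.de.xml, and IWSLT16.TED.tst2014.de-en.en.xml; the actual archive contents are not shown here.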
Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@

import codecs
import os

import tensorflow as tf
import numpy as np

from hyperparams import Hyperparams as hp
from data_load import load_test_data, load_de_vocab, load_en_vocab
from train import Graph
from nltk.translate.bleu_score import corpus_bleu


def eval():
    # Load graph
    g = Graph(is_training=False)
    print("Graph loaded")

    # Load data
    X, Sources, Targets = load_test_data()
    de2idx, idx2de = load_de_vocab()
    en2idx, idx2en = load_en_vocab()

    # X, Sources, Targets = X[:33], Sources[:33], Targets[:33]

    # Start session
    with g.graph.as_default():
        sv = tf.train.Supervisor()
        with sv.managed_session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            ## Restore parameters
            sv.saver.restore(sess, tf.train.latest_checkpoint(hp.logdir))
            print("Restored!")

            ## Get model name
            mname = open(hp.logdir + '/checkpoint', 'r').read().split('"')[1]  # model name

            ## Inference
            if not os.path.exists('results'): os.mkdir('results')
            with codecs.open("results/" + mname, "w", "utf-8") as fout:
                list_of_refs, hypotheses = [], []
                for i in range(len(X) // hp.batch_size):

                    ### Get mini-batches
                    x = X[i * hp.batch_size: (i + 1) * hp.batch_size]
                    sources = Sources[i * hp.batch_size: (i + 1) * hp.batch_size]
                    targets = Targets[i * hp.batch_size: (i + 1) * hp.batch_size]

                    ### Autoregressive inference
                    ### At test time, tokens are predicted one position at a time.
                    preds = np.zeros((hp.batch_size, hp.maxlen), np.int32)
                    for j in range(hp.maxlen):
                        _preds = sess.run(g.preds, {g.x: x, g.y: preds})
                        preds[:, j] = _preds[:, j]

                    ### Write to file
                    for source, target, pred in zip(sources, targets, preds):  # sentence-wise
                        got = " ".join(idx2en[idx] for idx in pred).split("</S>")[0].strip()
                        fout.write("- source: " + source + "\n")
                        fout.write("- expected: " + target + "\n")
                        fout.write("- got: " + got + "\n\n")
                        fout.flush()

                        # bleu score
                        ref = target.split()
                        hypothesis = got.split()
                        if len(ref) > 3 and len(hypothesis) > 3:
                            list_of_refs.append([ref])
                            hypotheses.append(hypothesis)

                ## Calculate bleu score
                score = corpus_bleu(list_of_refs, hypotheses)
                fout.write("Bleu Score = " + str(100 * score))


if __name__ == '__main__':
    eval()
    print("Done")
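
The list_of_refs.append([ref]) nesting above matches what nltk's corpus_bleu expects: one list of reference translations per hypothesis (here a single reference each). A toy illustration with made-up sentences, not taken from the commit:

# Toy example of the argument shapes corpus_bleu expects.
from nltk.translate.bleu_score import corpus_bleu

references = [[["the", "cat", "sat", "on", "the", "mat"]],  # one list of references per hypothesis
              [["there", "is", "a", "small", "house", "by", "the", "lake"]]]
hypotheses = [["the", "cat", "sat", "on", "a", "mat"],
              ["there", "is", "a", "small", "house", "near", "the", "lake"]]
print("Bleu Score =", 100 * corpus_bleu(references, hypotheses))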
Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
class Hyperparams:
    '''Hyperparameters'''
    # data
    source_train = 'data/train.tags.de-en.de'
    target_train = 'data/train.tags.de-en.en'
    source_test = 'data/IWSLT16.TED.tst2014.de-en.de.xml'
    target_test = 'data/IWSLT16.TED.tst2014.de-en.en.xml'

    # training
    batch_size = 32  # alias = N
    lr = 0.0001  # learning rate. In the paper, the learning rate is adjusted according to the global step.
    logdir = 'logdir'  # log directory

    # model
    maxlen = 10  # Maximum number of words in a sentence. alias = T.
                 # Feel free to increase this if you are ambitious.
    min_cnt = 20  # Words that occur fewer than min_cnt times are encoded as <UNK>.
    hidden_units = 512  # alias = C
    num_blocks = 6  # number of encoder/decoder blocks
    num_epochs = 20
    num_heads = 8
    dropout_rate = 0.1
    sinusoid = False  # If True, use sinusoidal positional encoding; if False, use a learned positional embedding.
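
For reference, a minimal numpy sketch of the sinusoidal positional encoding that the sinusoid flag refers to, following the formula from "Attention Is All You Need"; this is an illustration under the hyperparameters above, not the implementation used in train.py (the module name hyperparams is taken from the imports elsewhere in this commit):

# Illustration only: sinusoidal positional encoding with the dimensions defined above.
import numpy as np
from hyperparams import Hyperparams as hp

def sinusoid_encoding(maxlen, hidden_units):
    pos = np.arange(maxlen)[:, None]        # positions 0..T-1, shape (T, 1)
    i = np.arange(hidden_units)[None, :]    # dimensions 0..C-1, shape (1, C)
    angles = pos / np.power(10000.0, (2 * (i // 2)) / float(hidden_units))
    enc = np.zeros((maxlen, hidden_units), np.float32)
    enc[:, 0::2] = np.sin(angles[:, 0::2])  # even dimensions use sine
    enc[:, 1::2] = np.cos(angles[:, 1::2])  # odd dimensions use cosine
    return enc

pe = sinusoid_encoding(hp.maxlen, hp.hidden_units)  # shape (10, 512) with the defaults above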