From fa47ca07bc87c41e15bd847be8938fde6a4aaa2b Mon Sep 17 00:00:00 2001 From: "Xiaoming (Jason) Cui" Date: Wed, 21 Mar 2018 10:29:27 -0700 Subject: [PATCH 1/3] removed source_reverse from wmt16_gnmt_8_layer.json, since this hparam has been removed from NMT code, and it caused training and inference failures when use this .json file --- nmt/standard_hparams/wmt16_gnmt_8_layer.json | 1 - 1 file changed, 1 deletion(-) diff --git a/nmt/standard_hparams/wmt16_gnmt_8_layer.json b/nmt/standard_hparams/wmt16_gnmt_8_layer.json index 438ddcf55..da2034ca7 100644 --- a/nmt/standard_hparams/wmt16_gnmt_8_layer.json +++ b/nmt/standard_hparams/wmt16_gnmt_8_layer.json @@ -22,7 +22,6 @@ "share_vocab": false, "subword_option": "bpe", "sos": "", - "source_reverse": false, "src_max_len": 50, "src_max_len_infer": null, "steps_per_external_eval": null, From 8c3a240f3ff3ef707637d026dc52d63a2f4ae744 Mon Sep 17 00:00:00 2001 From: "Xiaoming (Jason) Cui" Date: Tue, 3 Apr 2018 00:36:36 -0700 Subject: [PATCH 2/3] add command line option to control the num_inter_threads and num_intra_threads for inference session --- nmt/inference.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/nmt/inference.py b/nmt/inference.py index 6f589337a..cf7924b5d 100644 --- a/nmt/inference.py +++ b/nmt/inference.py @@ -131,7 +131,10 @@ def single_worker_inference(infer_model, infer_data = load_data(inference_input_file, hparams) with tf.Session( - graph=infer_model.graph, config=utils.get_config_proto()) as sess: + graph=infer_model.graph, config=utils.get_config_proto( + num_intra_threads=hparams.num_intra_threads, + num_inter_threads=hparams.num_inter_threads + )) as sess: loaded_infer_model = model_helper.load_model( infer_model.model, ckpt, sess, "infer") sess.run( @@ -190,7 +193,10 @@ def multi_worker_inference(infer_model, infer_data = infer_data[start_position:end_position] with tf.Session( - graph=infer_model.graph, config=utils.get_config_proto()) as sess: + graph=infer_model.graph, config=utils.get_config_proto( + num_intra_threads=hparams.num_intra_threads, + num_inter_threads=hparams.num_inter_threads + )) as sess: loaded_infer_model = model_helper.load_model( infer_model.model, ckpt, sess, "infer") sess.run(infer_model.iterator.initializer, From 2ee7e32d5285eaee45153423eca8b4c105a1d7cd Mon Sep 17 00:00:00 2001 From: "Xiaoming (Jason) Cui" Date: Thu, 6 Sep 2018 11:40:43 -0700 Subject: [PATCH 3/3] Added support of setting inter_op_parallelism_threads and intra_op_parallelism_threads for inference --- nmt/inference.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/nmt/inference.py b/nmt/inference.py index 2cbef07c2..90fe5d287 100644 --- a/nmt/inference.py +++ b/nmt/inference.py @@ -95,10 +95,13 @@ def get_model_creator(hparams): return model_creator -def start_sess_and_load_model(infer_model, ckpt_path): +def start_sess_and_load_model(infer_model, ckpt_path, hparams): """Start session and load model.""" + print("intra inter is %d %d \n" %(hparams.num_intra_threads , hparams.num_inter_threads)) sess = tf.Session( - graph=infer_model.graph, config=utils.get_config_proto()) + graph=infer_model.graph, config=utils.get_config_proto( + num_intra_threads=hparams.num_intra_threads, + num_inter_threads=hparams.num_inter_threads)) with infer_model.graph.as_default(): loaded_infer_model = model_helper.load_model( infer_model.model, ckpt_path, sess, "infer") @@ -118,7 +121,7 @@ def inference(ckpt_path, model_creator = get_model_creator(hparams) infer_model = model_helper.create_infer_model(model_creator, hparams, scope) - sess, loaded_infer_model = start_sess_and_load_model(infer_model, ckpt_path) + sess, loaded_infer_model = start_sess_and_load_model(infer_model, ckpt_path, hparams) if num_workers == 1: single_worker_inference(