import tensorflow as tf
import numpy as np
import random
from collections import deque

from utils import conv, noisy_dense

class NoisyNetDQN():
    def __init__(self, env, config):
        self.sess = tf.InteractiveSession()
        self.config = config

        # Experience replay buffer; the deque drops the oldest transition
        # once replay_buffer_size is reached.
        self.replay_buffer = deque(maxlen=self.config.replay_buffer_size)
        self.time_step = 0

        self.state_dim = env.observation_space.shape
        self.action_dim = env.action_space.n

        print('state_dim:', self.state_dim)
        print('action_dim:', self.action_dim)

        self.action_batch = tf.placeholder('int32', [None])
        self.y_input = tf.placeholder('float', [None, self.action_dim])

        # Both networks take a batch of states: [None, *state_dim]
        batch_shape = [None]
        batch_shape.extend(self.state_dim)

        self.eval_input = tf.placeholder('float', batch_shape)
        self.target_input = tf.placeholder('float', batch_shape)

        self.build_noisy_dqn_net()

        self.saver = tf.train.Saver()

        self.sess.run(tf.global_variables_initializer())

        # Write an initial checkpoint and load it back so training can
        # resume from MODEL_PATH.
        self.save_model()
        self.restore_model()

    def build_layers(self, state, c_names, units_1, units_2, w_i, b_i, reg=None):
        # Two small conv layers with stride 2; the filter shape [5, 5, 3, 6]
        # assumes a 3-channel (RGB) image observation.
        with tf.variable_scope('conv1'):
            conv1 = conv(state, [5, 5, 3, 6], [6], [1, 2, 2, 1], w_i, b_i)
        with tf.variable_scope('conv2'):
            conv2 = conv(conv1, [3, 3, 6, 12], [12], [1, 2, 2, 1], w_i, b_i)
        with tf.variable_scope('flatten'):
            flatten = tf.contrib.layers.flatten(conv2)

        # Fully connected head built from noisy layers: the learned,
        # parametric noise in these layers drives exploration (NoisyNet),
        # replacing an epsilon-greedy schedule.
        with tf.variable_scope('dense1'):
            dense1 = noisy_dense(flatten, units_1, [units_1], c_names, w_i, b_i,
                                 noisy_distribution=self.config.noisy_distribution)
        with tf.variable_scope('dense2'):
            dense2 = noisy_dense(dense1, units_2, [units_2], c_names, w_i, b_i,
                                 noisy_distribution=self.config.noisy_distribution)
        with tf.variable_scope('dense3'):
            dense3 = noisy_dense(dense2, self.action_dim, [self.action_dim], c_names, w_i, b_i,
                                 noisy_distribution=self.config.noisy_distribution)

        return dense3

    def build_noisy_dqn_net(self):
        with tf.variable_scope('target_net'):
            # Collection names must match the tf.get_collection lookups below.
            c_names = ['target_net_params', tf.GraphKeys.GLOBAL_VARIABLES]
            w_i = tf.random_uniform_initializer(-0.1, 0.1)
            b_i = tf.constant_initializer(0.1)
            self.q_target = self.build_layers(self.target_input, c_names, 24, 24, w_i, b_i)

        with tf.variable_scope('eval_net'):
            c_names = ['eval_net_params', tf.GraphKeys.GLOBAL_VARIABLES]
            w_i = tf.random_uniform_initializer(-0.1, 0.1)
            b_i = tf.constant_initializer(0.1)
            self.q_eval = self.build_layers(self.eval_input, c_names, 24, 24, w_i, b_i)

        self.loss = tf.reduce_mean(tf.squared_difference(self.q_eval, self.y_input))
        self.optimizer = tf.train.AdamOptimizer(self.config.LEARNING_RATE).minimize(self.loss)

        # Op that copies the eval-net weights into the target net.
        eval_params = tf.get_collection('eval_net_params')
        target_params = tf.get_collection('target_net_params')
        self.update_target_net = [tf.assign(t, e) for t, e in zip(target_params, eval_params)]

    def save_model(self):
        print("Model saved in:", self.saver.save(self.sess, self.config.MODEL_PATH))

    def restore_model(self):
        self.saver.restore(self.sess, self.config.MODEL_PATH)
        print("Model restored.")

    def perceive(self, state, action, reward, next_state, done):
        # Store one transition in the replay buffer.
        self.replay_buffer.append((state, action, reward, next_state, done))

    def train_q_network(self, update=True):
        # Wait until the buffer holds enough transitions to train on.
        if len(self.replay_buffer) < self.config.START_TRAINING:
            return

        self.time_step += 1
        # random.sample already returns the minibatch in random order.
        minibatch = random.sample(self.replay_buffer, self.config.BATCH_SIZE)

        state_batch = [data[0] for data in minibatch]
        action_batch = [data[1] for data in minibatch]
        reward_batch = [data[2] for data in minibatch]
        next_state_batch = [data[3] for data in minibatch]
        done = [data[4] for data in minibatch]

        q_target = self.sess.run(self.q_target, feed_dict={self.target_input: next_state_batch})
        q_eval = self.sess.run(self.q_eval, feed_dict={self.eval_input: state_batch})

        done = np.array(done).astype(int)

        # Standard DQN target: r + gamma * max_a Q_target(s', a).
        # y_batch matches q_eval everywhere except at the taken action,
        # so the MSE loss only penalises that entry.
        y_batch = np.zeros((self.config.BATCH_SIZE, self.action_dim))
        for i in range(self.config.BATCH_SIZE):
            temp = q_eval[i].copy()  # copy so q_eval is not mutated in place
            action = np.argmax(q_target[i])
            temp[action_batch[i]] = reward_batch[i] + (1 - done[i]) * self.config.GAMMA * q_target[i][action]
            y_batch[i] = temp

        # action_batch is fed for completeness; the loss itself compares
        # the full Q vectors.
        self.sess.run(self.optimizer, feed_dict={
            self.y_input: y_batch,
            self.eval_input: state_batch,
            self.action_batch: action_batch
        })

        # Periodically sync the target net with the eval net.
        if update and self.time_step % self.config.UPDATE_TARGET_NET == 0:
            self.sess.run(self.update_target_net)

    def noisy_action(self, state):
        # With NoisyNet there is no epsilon-greedy schedule: exploration
        # comes from the noise sampled inside the noisy_dense layers, so we
        # act greedily with respect to the online (eval) network.
        return np.argmax(self.sess.run(self.q_eval, feed_dict={self.eval_input: [state]})[0])
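

# For reference, a minimal sketch of what utils.noisy_dense is assumed to
# implement: a NoisyNet dense layer (Fortunato et al., 2017) with independent
# Gaussian noise, y = (w_mu + w_sigma * eps_w) x + (b_mu + b_sigma * eps_b).
# The real layer in utils.py may differ in detail (factorised noise, sigma
# initialisation, collection handling); this sketch is illustrative only.
def _noisy_dense_sketch(inputs, units, name, sigma_init=0.017):
    in_dim = inputs.get_shape().as_list()[-1]
    with tf.variable_scope(name):
        # Learned means
        w_mu = tf.get_variable('w_mu', [in_dim, units],
                               initializer=tf.random_uniform_initializer(-0.1, 0.1))
        b_mu = tf.get_variable('b_mu', [units], initializer=tf.constant_initializer(0.1))
        # Learned per-weight noise scales
        w_sigma = tf.get_variable('w_sigma', [in_dim, units],
                                  initializer=tf.constant_initializer(sigma_init))
        b_sigma = tf.get_variable('b_sigma', [units],
                                  initializer=tf.constant_initializer(sigma_init))
        # Fresh noise is sampled on every forward pass
        eps_w = tf.random_normal([in_dim, units])
        eps_b = tf.random_normal([units])
        return tf.matmul(inputs, w_mu + w_sigma * eps_w) + b_mu + b_sigma * eps_b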
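

# Hypothetical usage sketch (not part of this file): assumes a gym-style env
# with image observations and a config object exposing the fields referenced
# above (replay_buffer_size, noisy_distribution, LEARNING_RATE, MODEL_PATH,
# START_TRAINING, BATCH_SIZE, GAMMA, UPDATE_TARGET_NET).
#
#     agent = NoisyNetDQN(env, config)
#     for episode in range(num_episodes):
#         state = env.reset()
#         done = False
#         while not done:
#             action = agent.noisy_action(state)
#             next_state, reward, done, _ = env.step(action)
#             agent.perceive(state, action, reward, next_state, done)
#             agent.train_q_network()
#             state = next_state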