"""Main DQN agent."""

import numpy as np
import tensorflow as tf
from PIL import Image
import random
from huberLoss import mean_huber_loss, weighted_huber_loss

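# Annealing endpoints: epsilon (exploration rate) decays linearly from EPSILON_BEGIN to
# EPSILON_END over roughly 2,000,000 environment steps, and the prioritized-replay
# exponent beta grows from BETA_BEGIN to BETA_END over 2,000,000 updates
# (see the increments computed in __init__ below).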
EPSILON_BEGIN = 1.0
EPSILON_END = 0.1
BETA_BEGIN = 0.5
BETA_END = 1.0

class DQNAgent():
    def __init__(self,
                 eval_model,
                 target_model,
                 memory,
                 num_actions,
                 gamma,
                 update_freq,
                 target_update_freq,
                 update_target_params_ops,
                 batch_size,
                 is_double_dqn,
                 is_per,
                 is_distributional,
                 num_step,
                 is_noisy,
                 learning_rate,
                 rmsp_decay,
                 rmsp_momentum,
                 rmsp_epsilon):

        self._eval_model = eval_model
        self._target_model = target_model
        self._memory = memory
        self._num_actions = num_actions
        self._gamma = gamma
        self._update_freq = update_freq
        self._target_update_freq = target_update_freq
        self._update_target_params_ops = update_target_params_ops
        self._batch_size = batch_size
        self._is_double_dqn = is_double_dqn
        self._is_per = is_per
        self._is_distributional = is_distributional
        self._num_step = num_step
        self._is_noisy = is_noisy
        self._learning_rate = learning_rate
        self._rmsp_decay = rmsp_decay
        self._rmsp_momentum = rmsp_momentum
        self._rmsp_epsilon = rmsp_epsilon
        self._update_times = 0
        self._beta = BETA_BEGIN
        self._beta_increment = (BETA_END - BETA_BEGIN) / 2000000.0
        # With noisy networks, exploration comes from parameter noise, so epsilon-greedy is disabled.
        self._epsilon = 0. if is_noisy else EPSILON_BEGIN
        self._epsilon_increment = (EPSILON_END - EPSILON_BEGIN) / 2000000.0 if not is_noisy else 0.
        self._action_ph = tf.placeholder(tf.int32, [None, 2], 'action_ph')
        self._reward_ph = tf.placeholder(tf.float32, name='reward_ph')
        self._is_terminal_ph = tf.placeholder(tf.float32, name='is_terminal_ph')
        self._action_chosen_by_eval_ph = tf.placeholder(tf.int32, [None, 2], 'action_chosen_by_eval_ph')
        self._loss_weight_ph = tf.placeholder(tf.float32, name='loss_weight_ph')
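        # The [None, 2] action placeholders hold (batch_index, action) pairs, e.g.
        # [[0, a_0], [1, a_1], ...], which is the row format tf.gather_nd expects;
        # fit() builds them with list(enumerate(action_list)).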
        self._error_op, self._train_op = self._get_error_and_train_op(
            self._reward_ph, self._is_terminal_ph, self._action_ph,
            self._action_chosen_by_eval_ph, self._loss_weight_ph)

    def _get_error_and_train_op(self, reward_ph,
                                is_terminal_ph,
                                action_ph,
                                action_chosen_by_eval_ph,
                                loss_weight_ph):

        if self._is_distributional == 0:
            q_values_target = self._target_model['q_values']
            q_values_eval = self._eval_model['q_values']

            if self._is_double_dqn:
                # Double DQN: the action is chosen by the eval network, but its Q-value is read from the target network.
                max_q = tf.gather_nd(q_values_target, action_chosen_by_eval_ph)
            else:
                max_q = tf.reduce_max(q_values_target, axis=1)

            target = reward_ph + (1.0 - is_terminal_ph) * (self._gamma ** self._num_step) * max_q  # n-step DQN target
            gathered_outputs = tf.gather_nd(q_values_eval, action_ph, name='gathered_outputs')
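            # TD target used above, written out (n-step form):
            #   y = R_{t:t+n} + (1 - done) * gamma^n * Q_target(s_{t+n}, a*)
            # where a* is the greedy action (taken from the eval net when double DQN is on);
            # the error below is |Q_eval(s_t, a_t) - y|.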

            if self._is_per == 1:
                loss = weighted_huber_loss(target, gathered_outputs, loss_weight_ph)
            else:
                loss = mean_huber_loss(target, gathered_outputs)
            train_op = tf.train.RMSPropOptimizer(self._learning_rate, decay=self._rmsp_decay,
                                                 momentum=self._rmsp_momentum, epsilon=self._rmsp_epsilon).minimize(loss)

            error_op = tf.abs(gathered_outputs - target, name='abs_error')
            # Return order must match the unpacking in __init__ (error_op first, then train_op),
            # as the distributional branch below already does.
            return error_op, train_op

        else:
            N_atoms = 51
            V_Max = 20.0
            V_Min = 0.0
            Delta_z = (V_Max - V_Min) / (N_atoms - 1)
            z_list = tf.constant([V_Min + i * Delta_z for i in range(N_atoms)], dtype=tf.float32)
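            # Distributional (C51-style) head: the return distribution is represented by
            # N_atoms = 51 atoms z_i spaced Delta_z apart on the fixed support [V_Min, V_Max].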

            q_distributional_values_target = self._target_model['q_distributional_network']  # batch_size * num_actions * N_atoms
            tmp_batch_size = tf.shape(q_distributional_values_target)[0]  # dynamic batch size

            if self._is_double_dqn:
                q_distributional_chosen_by_action_target = tf.gather_nd(q_distributional_values_target, action_chosen_by_eval_ph)
            else:
                action_chosen_by_target_q = tf.cast(tf.argmax(self._target_model['q_values'], axis=1), tf.int32)
                q_distributional_chosen_by_action_target = tf.gather_nd(
                    q_distributional_values_target,
                    tf.concat([tf.reshape(tf.range(tmp_batch_size), [-1, 1]),
                               tf.reshape(action_chosen_by_target_q, [-1, 1])], axis=1))


            # Bellman update of the support: each atom z_j is shifted to r + gamma^n * z_j
            # (and to just r for terminal states).
            target = tf.tile(tf.reshape(reward_ph, [-1, 1]), [1, N_atoms]) + \
                     (self._gamma ** self._num_step) * \
                     tf.multiply(tf.reshape(z_list, [1, N_atoms]),
                                 (1.0 - tf.tile(tf.reshape(is_terminal_ph, [-1, 1]), [1, N_atoms])))

            target = tf.clip_by_value(target, V_Min, V_Max)

            b = (target - V_Min) / Delta_z

            u, l = tf.ceil(b), tf.floor(b)

            u_id, l_id = tf.cast(u, tf.int32), tf.cast(l, tf.int32)

            u_minus_b, b_minus_l = u - b, b - l
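            # Projection of the shifted atoms back onto the fixed support: each updated atom
            # lands at fractional index b between its neighbours l = floor(b) and u = ceil(b);
            # the fraction (u - b) of its probability mass goes to atom l and (b - l) to atom u.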
            q_distributional_values_eval = self._eval_model['q_distributional_network']

            q_distributional_chosen_by_action_eval = tf.gather_nd(q_distributional_values_eval, action_ph)

            index_help = tf.tile(tf.reshape(tf.range(tmp_batch_size), [-1, 1]), [1, N_atoms])

            index_help = tf.expand_dims(index_help, -1)  # batch * N_atoms * 1
            u_id = tf.concat([index_help, tf.expand_dims(u_id, -1)], axis=2)
            l_id = tf.concat([index_help, tf.expand_dims(l_id, -1)], axis=2)

            error = q_distributional_chosen_by_action_target * u_minus_b * \
                    tf.log(tf.gather_nd(q_distributional_chosen_by_action_eval, l_id)) \
                    + q_distributional_chosen_by_action_target * b_minus_l * \
                    tf.log(tf.gather_nd(q_distributional_chosen_by_action_eval, u_id))
            error = tf.reduce_sum(error, axis=1)
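            # `error` is the log-likelihood of the eval distribution under the projected target
            # distribution (i.e. the negative cross-entropy), so the loss below negates it and
            # its absolute value is later used as the replay priority.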

            if self._is_per == 1:
                loss = tf.negative(error * loss_weight_ph)
            else:
                loss = tf.negative(error)

            train_op = tf.train.RMSPropOptimizer(self._learning_rate,
                                                 decay=self._rmsp_decay, momentum=self._rmsp_momentum,
                                                 epsilon=self._rmsp_epsilon).minimize(loss)
            error_op = tf.abs(error, name='abs_error')
            return error_op, train_op

    def select_action(self, sess, state, epsilon, model):
        """Epsilon-greedy action selection for a batch of states."""
        batch_size = len(state)
        if np.random.rand() < epsilon:
            action = np.random.randint(0, self._num_actions, size=(batch_size,))
        else:
            state = state.astype(np.float32) / 255.0
            feed_dict = {model['input_frames']: state}
            action = sess.run(model['action'], feed_dict=feed_dict)
        return action

    def get_multi_step_sample(self, env, sess, num_step, epsilon):
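        # Collects an n-step transition: rewards are clipped with np.sign and accumulated as
        #   R = sum_{i=0}^{n-1} gamma^i * sign(r_{t+i}),
        # and the returned terminal flag is 1 if any of the n steps ended the episode.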
        old_state, action, reward, new_state, is_terminal = env.get_state()
        total_reward = np.sign(reward)
        total_is_terminal = is_terminal

        next_action = self.select_action(sess, new_state, epsilon, self._eval_model)
        env.take_action(next_action)

        for i in range(1, num_step):
            _, _, reward, new_state, is_terminal = env.get_state()
            total_reward += self._gamma ** i * np.sign(reward)
            total_is_terminal += is_terminal
            next_action = self.select_action(sess, new_state, epsilon, self._eval_model)
            env.take_action(next_action)

        return old_state, action, total_reward, new_state, np.sign(total_is_terminal)

    def fit(self, sess, env, num_iterations, do_train=True):

        num_environment = env.num_process
        env.reset()

        for t in range(0, num_iterations, num_environment):
            # Prepare the data: one n-step sample from each parallel environment.
            old_state, action, reward, new_state, is_terminal = self.get_multi_step_sample(env, sess, self._num_step, self._epsilon)
            self._memory.append(old_state, action, reward, new_state, is_terminal)  # store the transitions
            if self._epsilon > EPSILON_END:
                self._epsilon += num_environment * self._epsilon_increment
            if do_train:
                num_update = sum([1 if i % self._update_freq == 0 else 0 for i in range(t, t + num_environment)])
                # Sample training batches from replay memory.
                for _ in range(num_update):
                    if self._is_per == 1:
                        (old_state_list, action_list, reward_list, new_state_list, is_terminal_list), \
                            idx_list, p_list, sum_p, count = self._memory.sample(self._batch_size)
                    else:
                        old_state_list, action_list, reward_list, new_state_list, is_terminal_list \
                            = self._memory.sample(self._batch_size)

                    feed_dict = {self._target_model['input_frames']: new_state_list.astype(np.float32) / 255.0,
                                 self._eval_model['input_frames']: old_state_list.astype(np.float32) / 255.0,
                                 self._action_ph: list(enumerate(action_list)),
                                 self._reward_ph: np.array(reward_list).astype(np.float32),
                                 self._is_terminal_ph: np.array(is_terminal_list).astype(np.float32),
                                 }

                    if self._is_double_dqn:
                        action_chosen_by_online = sess.run(self._eval_model['action'], feed_dict={
                            self._eval_model['input_frames']: new_state_list.astype(np.float32) / 255.0})
                        feed_dict[self._action_chosen_by_eval_ph] = list(enumerate(action_chosen_by_online))

                    if self._is_per == 1:
                        # Annealing weight beta
                        feed_dict[self._loss_weight_ph] = (np.array(p_list) * count / sum_p) ** (-self._beta)
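                        # Importance-sampling weight for PER: w_i = (N * P(i))^(-beta),
                        # with P(i) = p_i / sum_p and N = count; beta is annealed toward BETA_END below.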
                        error, _ = sess.run([self._error_op, self._train_op], feed_dict=feed_dict)
                        self._memory.update(idx_list, error)
                    else:
                        sess.run(self._train_op, feed_dict=feed_dict)

                    self._update_times += 1
                    if self._beta < BETA_END:
                        self._beta += self._beta_increment

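                    # Periodically copy the eval-network weights into the target network via the
                    # assign ops supplied as update_target_params_ops.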
                    if self._update_times % self._target_update_freq == 0:
                        sess.run(self._update_target_params_ops)


    def _get_error(self, sess, old_state, action, reward, new_state, is_terminal):
        '''
        Get TD error for Prioritized Experience Replay.
        '''
        feed_dict = {self._target_model['input_frames']: new_state.astype(np.float32) / 255.0,
                     self._eval_model['input_frames']: old_state.astype(np.float32) / 255.0,
                     self._action_ph: list(enumerate(action)),
                     self._reward_ph: np.array(reward).astype(np.float32),
                     self._is_terminal_ph: np.array(is_terminal).astype(np.float32),
                     }

        if self._is_double_dqn:
            action_chosen_by_online = sess.run(self._eval_model['action'], feed_dict={
                self._eval_model['input_frames']: new_state.astype(np.float32) / 255.0})
            feed_dict[self._action_chosen_by_eval_ph] = list(enumerate(action_chosen_by_online))

        error = sess.run(self._error_op, feed_dict=feed_dict)
        return error

    def get_mean_max_Q(self, sess, samples):
        '''
        Average of the eval network's max-Q values over the given sample frames,
        computed in chunks of INCREMENT frames to bound memory use.
        '''
        mean_max = []
        INCREMENT = 1000
        for i in range(0, len(samples), INCREMENT):
            feed_dict = {self._eval_model['input_frames']:
                         samples[i: i + INCREMENT].astype(np.float32) / 255.0}
            mean_max.append(sess.run(self._eval_model['mean_max_Q'],
                                     feed_dict=feed_dict))
        return np.mean(mean_max)


    def evaluate(self, sess, env, num_episode):
        """Evaluate num_episode games with the online (eval) model.

        Parameters
        ----------
        sess: tf.Session
        env: batchEnv.BatchEnvironment
            The parallelized Atari environment.
        num_episode: int
            Number of episodes to evaluate.

        Returns
        -------
        Mean and standard deviation of the per-episode rewards.
        """
        num_environment = env.num_process
        env.reset()
        reward_of_each_environment = np.zeros(num_environment)
        rewards_list = []

        num_finished_episode = 0

        while num_finished_episode < num_episode:
            old_state, action, reward, new_state, is_terminal = env.get_state()
            action = self.select_action(sess, new_state, 0, self._eval_model)
            env.take_action(action)
            for i, r, is_t in zip(range(num_environment), reward, is_terminal):
                if not is_t:
                    reward_of_each_environment[i] += r
                else:
                    rewards_list.append(reward_of_each_environment[i])
                    reward_of_each_environment[i] = 0
                    num_finished_episode += 1
        return np.mean(rewards_list), np.std(rewards_list)

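# A rough usage sketch with illustrative hyperparameter values (the model dicts, replay
# memory, batched environment and target-update assign ops are assumed to be built
# elsewhere in this repo):
#
#   agent = DQNAgent(eval_model, target_model, memory,
#                    num_actions=NUM_ACTIONS, gamma=0.99,
#                    update_freq=4, target_update_freq=10000,
#                    update_target_params_ops=update_target_ops, batch_size=32,
#                    is_double_dqn=True, is_per=1, is_distributional=0,
#                    num_step=3, is_noisy=0, learning_rate=2.5e-4,
#                    rmsp_decay=0.95, rmsp_momentum=0.0, rmsp_epsilon=1e-2)
#   with tf.Session() as sess:
#       sess.run(tf.global_variables_initializer())
#       agent.fit(sess, env, num_iterations=2000000)
#       mean_reward, std_reward = agent.evaluate(sess, env, num_episode=10)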