Skip to content

Commit 7d03dd0

Browse files
author
shixiaowen03
committed
distributional rl
1 parent c659ae9 commit 7d03dd0

File tree

3 files changed

+19
-10
lines changed

3 files changed

+19
-10
lines changed

.idea/workspace.xml

Lines changed: 7 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

RL/Basic-DisRL-Demo/Categorical_DQN.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@ def __init__(self,env,config):
1515
self.v_min = self.config.v_min
1616
self.atoms = self.config.atoms
1717

18-
self.time_step = 0
1918
self.epsilon = self.config.INITIAL_EPSILON
2019
self.state_shape = env.observation_space.shape
2120
self.action_dim = env.action_space.n
2221

22+
self.time_step = 0
23+
2324
target_state_shape = [1]
2425
target_state_shape.extend(self.state_shape)
2526

@@ -82,10 +83,16 @@ def build_cate_dqn_net(self):
8283

8384
self.optimizer = tf.train.AdamOptimizer(self.config.LEARNING_RATE).minimize(self.cross_entropy_loss)
8485

86+
eval_params = tf.get_collection("eval_net_params")
87+
target_params = tf.get_collection('target_net_params')
88+
89+
self.update_target_net = [tf.assign(t, e) for t, e in zip(target_params, eval_params)]
90+
8591

8692

8793

8894
def train(self,s,r,action,s_,gamma):
95+
self.time_step += 1
8996
list_q_ = [self.sess.run(self.q_target,feed_dict={self.state_input:[s_],self.action_input:[[a]]}) for a in range(self.action_dim)]
9097
a_ = tf.argmax(list_q_).eval()
9198
m = np.zeros(self.atoms)
@@ -103,6 +110,8 @@ def train(self,s,r,action,s_,gamma):
103110
self.sess.run(self.optimizer,feed_dict={self.state_input:[s] , self.action_input:[action], self.m_input: m })
104111

105112

113+
if self.time_step % self.config.UPDATE_TARGET_NET == 0:
114+
self.sess.run(self.update_target_net)
106115

107116
def save_model(self):
108117
print("Model saved in : ", self.saver.save(self.sess, self.config.MODEL_PATH))

RL/Basic-DisRL-Demo/Config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,5 @@ class Categorical_DQN_Config():
2020
replay_buffer_size = 2000
2121
iteration = 5
2222
episode = 300 # 300 games per iteration
23+
24+

0 commit comments

Comments (0)