@@ -15,11 +15,12 @@ def __init__(self,env,config):
1515 self .v_min = self .config .v_min
1616 self .atoms = self .config .atoms
1717
18- self .time_step = 0
1918 self .epsilon = self .config .INITIAL_EPSILON
2019 self .state_shape = env .observation_space .shape
2120 self .action_dim = env .action_space .n
2221
22+ self .time_step = 0
23+
2324 target_state_shape = [1 ]
2425 target_state_shape .extend (self .state_shape )
2526
@@ -82,10 +83,16 @@ def build_cate_dqn_net(self):
8283
8384 self .optimizer = tf .train .AdamOptimizer (self .config .LEARNING_RATE ).minimize (self .cross_entropy_loss )
8485
86+ eval_params = tf .get_collection ("eval_net_params" )
87+ target_params = tf .get_collection ('target_net_params' )
88+
89+ self .update_target_net = [tf .assign (t , e ) for t , e in zip (target_params , eval_params )]
90+
8591
8692
8793
8894 def train (self ,s ,r ,action ,s_ ,gamma ):
95+ self .time_step += 1
8996 list_q_ = [self .sess .run (self .q_target ,feed_dict = {self .state_input :[s_ ],self .action_input :[[a ]]}) for a in range (self .action_dim )]
9097 a_ = tf .argmax (list_q_ ).eval ()
9198 m = np .zeros (self .atoms )
@@ -103,6 +110,8 @@ def train(self,s,r,action,s_,gamma):
103110 self .sess .run (self .optimizer ,feed_dict = {self .state_input :[s ] , self .action_input :[action ], self .m_input : m })
104111
105112
113+ if self .time_step % self .config .UPDATE_TARGET_NET == 0 :
114+ self .sess .run (self .update_target_net )
106115
def save_model(self):
    """Checkpoint the current TensorFlow session to the configured path.

    Delegates to ``self.saver.save`` with ``self.sess`` and
    ``self.config.MODEL_PATH``, then reports where the checkpoint landed.
    """
    checkpoint_path = self.saver.save(self.sess, self.config.MODEL_PATH)
    print("Model saved in : ", checkpoint_path)
0 commit comments