
Commit c659ae9

Author: shixiaowen03
Commit message: distributional rl
1 parent 62bca23 commit c659ae9

File tree

8 files changed: +1458 additions, -168 deletions


.idea/workspace.xml

Lines changed: 177 additions & 168 deletions
Some generated files are not rendered by default.
RL/Basic-DisRL-Demo/Categorical_DQN.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
import tensorflow as tf
import numpy as np
import random
from collections import deque
from Config import Categorical_DQN_Config
from utils import conv, dense
import math


class Categorical_DQN():
    def __init__(self, env, config):
        self.sess = tf.InteractiveSession()
        self.config = config
        self.v_max = self.config.v_max
        self.v_min = self.config.v_min
        self.atoms = self.config.atoms

        self.time_step = 0
        self.epsilon = self.config.INITIAL_EPSILON
        self.state_shape = env.observation_space.shape
        self.action_dim = env.action_space.n

        # placeholders take a single state/action at a time (batch size 1)
        target_state_shape = [1]
        target_state_shape.extend(self.state_shape)

        self.state_input = tf.placeholder(tf.float32, target_state_shape)
        self.action_input = tf.placeholder(tf.int32, [1, 1])

        # projected target distribution m, one probability per atom
        self.m_input = tf.placeholder(tf.float32, [self.atoms])

        # fixed support of the return distribution: atoms evenly spaced values in [v_min, v_max]
        self.delta_z = (self.v_max - self.v_min) / (self.atoms - 1)
        self.z = [self.v_min + i * self.delta_z for i in range(self.atoms)]

        self.build_cate_dqn_net()

        self.saver = tf.train.Saver()

        self.sess.run(tf.global_variables_initializer())

        self.save_model()
        self.restore_model()

    def build_layers(self, state, action, c_names, units_1, units_2, w_i, b_i, reg=None):
        with tf.variable_scope('conv1'):
            conv1 = conv(state, [5, 5, 3, 6], [6], [1, 2, 2, 1], w_i, b_i)
        with tf.variable_scope('conv2'):
            conv2 = conv(conv1, [3, 3, 6, 12], [12], [1, 2, 2, 1], w_i, b_i)
        with tf.variable_scope('flatten'):
            flatten = tf.contrib.layers.flatten(conv2)

        with tf.variable_scope('dense1'):
            dense1 = dense(flatten, units_1, [units_1], w_i, b_i)
        with tf.variable_scope('dense2'):
            dense2 = dense(dense1, units_2, [units_2], w_i, b_i)
        with tf.variable_scope('concat'):
            concatenated = tf.concat([dense2, tf.cast(action, tf.float32)], 1)
        with tf.variable_scope('dense3'):
            # softmax so the output is a probability distribution over the atoms,
            # as required by the categorical (C51) cross-entropy loss below
            dense3 = dense(concatenated, self.atoms, [self.atoms], w_i, b_i, activation=tf.nn.softmax)
        return dense3

    def build_cate_dqn_net(self):
        with tf.variable_scope('target_net'):
            c_names = ['target_net_params', tf.GraphKeys.GLOBAL_VARIABLES]
            w_i = tf.random_uniform_initializer(-0.1, 0.1)
            b_i = tf.constant_initializer(0.1)
            self.z_target = self.build_layers(self.state_input, self.action_input, c_names, 24, 24, w_i, b_i)

        with tf.variable_scope('eval_net'):
            c_names = ['eval_net_params', tf.GraphKeys.GLOBAL_VARIABLES]
            w_i = tf.random_uniform_initializer(-0.1, 0.1)
            b_i = tf.constant_initializer(0.1)
            self.z_eval = self.build_layers(self.state_input, self.action_input, c_names, 24, 24, w_i, b_i)

        # Q(s, a) is the expectation of the return distribution over the support z
        self.q_eval = tf.reduce_sum(self.z_eval * self.z)
        self.q_target = tf.reduce_sum(self.z_target * self.z)

        # cross-entropy between the projected target distribution m and the predicted distribution
        self.cross_entropy_loss = -tf.reduce_sum(self.m_input * tf.log(self.z_eval))

        self.optimizer = tf.train.AdamOptimizer(self.config.LEARNING_RATE).minimize(self.cross_entropy_loss)

    def train(self, s, r, action, s_, gamma):
        list_q_ = [self.sess.run(self.q_target, feed_dict={self.state_input: [s_], self.action_input: [[a]]}) for a in range(self.action_dim)]
        a_ = int(np.argmax(list_q_))  # greedy next action under the target network
        m = np.zeros(self.atoms)
        p = self.sess.run(self.z_target, feed_dict={self.state_input: [s_], self.action_input: [[a_]]})[0]
        for j in range(self.atoms):
            # apply the Bellman update to atom z_j and clip it to the support
            Tz = min(self.v_max, max(self.v_min, r + gamma * self.z[j]))
            bj = (Tz - self.v_min) / self.delta_z  # fractional index of Tz on the support
            l, u = math.floor(bj), math.ceil(bj)  # neighbouring atom indices

            pj = p[j]

            if l == u:
                # Tz falls exactly on an atom; give it the full probability mass
                m[int(l)] += pj
            else:
                # split the probability mass between the two neighbouring atoms
                m[int(l)] += pj * (u - bj)
                m[int(u)] += pj * (bj - l)

        self.sess.run(self.optimizer, feed_dict={self.state_input: [s], self.action_input: [action], self.m_input: m})

    def save_model(self):
        print("Model saved in : ", self.saver.save(self.sess, self.config.MODEL_PATH))

    def restore_model(self):
        self.saver.restore(self.sess, self.config.MODEL_PATH)
        print("Model restored.")

    def greedy_action(self, s):
        # epsilon-greedy exploration with exponential decay
        self.epsilon = max(self.config.FINAL_EPSILON, self.epsilon * self.config.EPSILON_DECAY)
        if random.random() <= self.epsilon:
            return random.randint(0, self.action_dim - 1)
        return np.argmax(
            [self.sess.run(self.q_target, feed_dict={self.state_input: [s], self.action_input: [[a]]}) for a in range(self.action_dim)])
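
The core of train() above is the categorical (C51) projection of the Bellman-updated atoms back onto the fixed support. The standalone NumPy sketch below is not part of the commit; it reuses the v_min / v_max / atoms values from Config.py and is only meant to check that the projected target m stays a valid probability distribution.

import numpy as np

# Sketch of the C51 categorical projection used in Categorical_DQN.train()
def project_distribution(p, r, gamma, v_min=0.0, v_max=1000.0, atoms=51):
    delta_z = (v_max - v_min) / (atoms - 1)
    z = np.array([v_min + i * delta_z for i in range(atoms)])
    m = np.zeros(atoms)
    for j in range(atoms):
        Tz = min(v_max, max(v_min, r + gamma * z[j]))  # Bellman-updated atom, clipped to the support
        bj = (Tz - v_min) / delta_z                    # fractional index on the support
        l, u = int(np.floor(bj)), int(np.ceil(bj))
        if l == u:
            m[l] += p[j]
        else:
            m[l] += p[j] * (u - bj)
            m[u] += p[j] * (bj - l)
    return m

# example: a uniform predicted distribution stays normalized after projection
p = np.full(51, 1.0 / 51)
m = project_distribution(p, r=1.0, gamma=0.99)
print(m.sum())  # ~1.0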

RL/Basic-DisRL-Demo/Config.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
class Categorical_DQN_Config():
    v_min = 0
    v_max = 1000
    atoms = 51

    # ENV_NAME = "CartPole-v1"
    ENV_NAME = 'Breakout-v0'  # actions -- 0: hold, 1: throw the ball, 2: move right, 3: move left
    # ENV_NAME = "Freeway-v0"
    GAMMA = 0.99  # discount factor for the target Q value
    START_TRAINING = 1000  # number of experiences collected before training starts
    BATCH_SIZE = 64  # size of a minibatch
    UPDATE_TARGET_NET = 400  # update the target network params every 400 steps
    LEARNING_RATE = 0.01
    MODEL_PATH = './model/C51DQN_model'

    INITIAL_EPSILON = 0.9  # starting value of epsilon
    FINAL_EPSILON = 0.05  # final value of epsilon
    EPSILON_DECAY = 0.9999

    replay_buffer_size = 2000
    iteration = 5
    episode = 300  # 300 games per iteration
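
A quick check of the support these settings imply (a sketch, not part of the commit): with v_min = 0, v_max = 1000 and 51 atoms, delta_z = 1000 / 50 = 20, so the atoms are 0, 20, 40, ..., 1000.

# support implied by Categorical_DQN_Config (sketch)
v_min, v_max, atoms = 0, 1000, 51
delta_z = (v_max - v_min) / (atoms - 1)   # = 20.0
z = [v_min + i * delta_z for i in range(atoms)]
print(delta_z, z[:3], z[-1])              # 20.0 [0.0, 20.0, 40.0] 1000.0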

RL/Basic-DisRL-Demo/main.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
import matplotlib.pyplot as plt
import tensorflow as tf
import gym
import numpy as np
import pickle
from Config import Categorical_DQN_Config
from Categorical_DQN import Categorical_DQN


def map_scores(dqfd_scores=None, ddqn_scores=None, xlabel=None, ylabel=None):
    if dqfd_scores is not None:
        plt.plot(dqfd_scores, 'r')
    if ddqn_scores is not None:
        plt.plot(ddqn_scores, 'b')
    if xlabel is not None:
        plt.xlabel(xlabel)
    if ylabel is not None:
        plt.ylabel(ylabel)
    plt.show()


def BreakOut_CDQN(index, env):
    with tf.variable_scope('DQfD_' + str(index)):
        agent = Categorical_DQN(env, Categorical_DQN_Config())
        scores = []
        for e in range(Categorical_DQN_Config.episode):
            done = False
            score = 0  # sum of the real rewards in one episode
            state = env.reset()
            # while done is False:
            last_lives = 5
            throw = True
            items_buffer = []
            while not done:
                env.render()
                action = 1 if throw else agent.greedy_action(state)
                next_state, real_reward, done, info = env.step(action)
                lives = info['ale.lives']
                # shaped reward: +1 for throwing the ball, -1 for losing a life, otherwise the game reward
                train_reward = 1 if throw else -1 if lives < last_lives else real_reward
                score += real_reward
                throw = lives < last_lives
                last_lives = lives
                # agent.train(state, train_reward, [action], next_state, 0.1)
                items_buffer.append([state, [action], next_state, 0.1])
                state = next_state
                if train_reward != 0:  # train when the ball is missed, a point is scored, or the ball was just thrown
                    for item in items_buffer:
                        agent.train(item[0], -1 if throw else train_reward, item[1], item[2], item[3])
                    items_buffer = []
            scores.append(score)
            agent.save_model()
            # if np.mean(scores[-min(10, len(scores)):]) > 495:
            #     break
        return scores


if __name__ == '__main__':
    env = gym.make("Breakout-v0")
    CDQN_sum_scores = np.zeros(Categorical_DQN_Config.episode)
    for i in range(Categorical_DQN_Config.iteration):
        scores = BreakOut_CDQN(i, env)
        # accumulate the per-episode scores over all iterations
        CDQN_sum_scores = np.array([a + b for a, b in zip(scores, CDQN_sum_scores)])
    C51DQN_mean_scores = CDQN_sum_scores / Categorical_DQN_Config.iteration
    with open('/Users/mahailong/C51DQN/C51DQN_mean_scores.p', 'wb') as f:
        pickle.dump(C51DQN_mean_scores, f, protocol=2)
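
The loop above is Breakout-specific: it forces action 1 to relaunch the ball after a lost life and buffers transitions until a non-zero shaped reward. Stripped of that shaping, a minimal interaction loop against the agent's interface would look like the sketch below; it is not part of the commit and assumes only the greedy_action and train methods defined in Categorical_DQN.py plus an image-observation environment compatible with the conv layers.

# minimal interaction-loop sketch
import gym
from Config import Categorical_DQN_Config
from Categorical_DQN import Categorical_DQN

env = gym.make(Categorical_DQN_Config.ENV_NAME)
agent = Categorical_DQN(env, Categorical_DQN_Config())
state = env.reset()
done = False
while not done:
    action = agent.greedy_action(state)
    next_state, reward, done, info = env.step(action)
    agent.train(state, reward, [action], next_state, Categorical_DQN_Config.GAMMA)
    state = next_state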

RL/Basic-DisRL-Demo/readme

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
Paper: A Distributional Perspective on Reinforcement Learning
Link: https://arxiv.org/abs/1707.06887
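
For context (a brief note, not part of the committed files): the paper treats the return as a random variable Z governed by the distributional Bellman equation, and C51 restricts Z to a categorical distribution over N fixed atoms trained with a cross-entropy loss against the projected target, which is what cross_entropy_loss in Categorical_DQN.py computes:

Z(s, a) \stackrel{D}{=} R(s, a) + \gamma Z(S', A'), \qquad
z_i = v_{\min} + i\,\Delta z, \quad \Delta z = \frac{v_{\max} - v_{\min}}{N - 1}, \qquad
\mathcal{L}(s, a) = -\sum_{i=0}^{N-1} m_i \log p_i(s, a)

where m is the projection of r + \gamma z onto the support \{z_i\} (the vector m computed in train()).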

RL/Basic-DisRL-Demo/utils.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
import tensorflow as tf
from tensorflow.python.framework import ops


def conv(inputs, kernel_shape, bias_shape, strides, w_i, b_i=None, activation=tf.nn.relu):
    # tf.layers equivalent:
    # relu1 = tf.layers.conv2d(input_imgs, filters=24, kernel_size=[5, 5], strides=[2, 2],
    #                          padding='SAME', activation=tf.nn.relu,
    #                          kernel_initializer=w_i, bias_initializer=b_i)
    weights = tf.get_variable('weights', shape=kernel_shape, initializer=w_i)
    conv = tf.nn.conv2d(inputs, weights, strides=strides, padding='SAME')
    if bias_shape is not None:
        biases = tf.get_variable('biases', shape=bias_shape, initializer=b_i)
        return activation(conv + biases) if activation is not None else conv + biases
    return activation(conv) if activation is not None else conv


def dense(inputs, units, bias_shape, w_i, b_i=None, activation=tf.nn.relu):
    # tf.layers equivalent (note: flatten first):
    # dense1 = tf.layers.dense(tf.contrib.layers.flatten(relu5), activation=tf.nn.relu, units=50)
    if not isinstance(inputs, ops.Tensor):
        inputs = ops.convert_to_tensor(inputs, dtype='float')
    # dim_list = inputs.get_shape().as_list()
    # flatten_shape = dim_list[1] if len(dim_list) <= 2 else reduce(lambda x, y: x * y, dim_list[1:])
    # reshaped = tf.reshape(inputs, [dim_list[0], flatten_shape])
    if len(inputs.shape) > 2:
        inputs = tf.contrib.layers.flatten(inputs)
    flatten_shape = inputs.shape[1]
    weights = tf.get_variable('weights', shape=[flatten_shape, units], initializer=w_i)
    dense = tf.matmul(inputs, weights)
    if bias_shape is not None:
        assert bias_shape[0] == units
        biases = tf.get_variable('biases', shape=bias_shape, initializer=b_i)
        return activation(dense + biases) if activation is not None else dense + biases
    return activation(dense) if activation is not None else dense
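
Both helpers create their variables with tf.get_variable, so every call needs its own variable scope, as build_layers in Categorical_DQN.py does. A minimal usage sketch, not part of the commit and assuming TensorFlow 1.x:

# usage sketch for the conv/dense helpers above
import tensorflow as tf
from utils import conv, dense

x = tf.placeholder(tf.float32, [1, 84, 84, 3])
w_i = tf.random_uniform_initializer(-0.1, 0.1)
b_i = tf.constant_initializer(0.1)
with tf.variable_scope('c1'):
    h = conv(x, [5, 5, 3, 6], [6], [1, 2, 2, 1], w_i, b_i)  # 5x5 conv, 3->6 channels, stride 2
with tf.variable_scope('d1'):
    out = dense(h, 24, [24], w_i, b_i)                       # flattens h, then a 24-unit layer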
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
{
  "cells": [],
  "metadata": {},
  "nbformat": 4,
  "nbformat_minor": 2
}

0 commit comments
