Skip to content

Commit 48c04b3

Browse files
author
shixiaowen03
committed
noisy net
1 parent 5aa1e47 commit 48c04b3

File tree

8 files changed

+700
-154
lines changed

8 files changed

+700
-154
lines changed

.idea/workspace.xml

Lines changed: 142 additions & 154 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

RL/Basic-NoisyNet-Demo/Config.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
class NoisyNetDQNConfig:
2+
# ENV_NAME = "CartPole-v1"
3+
ENV_NAME = 'Breakout-v0' # 0: hold 1: throw the ball 2: move right 3: move left
4+
# ENV_NAME = "Freeway-v0"
5+
GAMMA = 0.99 # discount factor for target Q
6+
START_TRAINING = 1000 # experience replay buffer size
7+
BATCH_SIZE = 64 # size of minibatch
8+
UPDATE_TARGET_NET = 400 # update eval_network params every 200 steps
9+
LEARNING_RATE = 0.01
10+
MODEL_PATH = './model/NoisyNetDQN_model'
11+
12+
INITIAL_EPSILON = 1.0 # starting value of epsilon
13+
FINAL_EPSILON = 0.01 # final value of epsilon
14+
EPSILIN_DECAY = 0.999
15+
16+
replay_buffer_size = 2000
17+
iteration = 5
18+
episode = 300 # 300 games per iteration
19+
20+
noisy_distribution = 'factorised' # independent or factorised
21+
22+
23+
24+
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import tensorflow as tf
2+
import numpy as np
3+
import random
4+
from collections import deque
5+
6+
from utils import conv,noisy_dense
7+
8+
class NoisyNetDQN():
9+
def __init__(self,env,config):
10+
self.sess = tf.InteractiveSession()
11+
self.config = config
12+
13+
self.replay_buffer = deque(maxlen = self.config.replay_buffer_size)
14+
self.time_step = 0
15+
16+
self.state_dim = env.observation_space.shape
17+
self.action_dim = env.action_space.n
18+
19+
print('state_dim:', self.state_dim)
20+
print('action_dim:', self.action_dim)
21+
22+
self.action_batch = tf.placeholder('int32',[None])
23+
self.y_input = tf.placeholder('float',[None,self.action_dim])
24+
25+
batch_shape = [None]
26+
batch_shape.extend(self.state_dim)
27+
28+
self.eval_input = tf.placeholder('float',batch_shape)
29+
self.target_input = tf.placeholder('float',batch_shape)
30+
31+
self.build_noisy_dqn_net()
32+
33+
self.saver = tf.train.Saver()
34+
35+
self.sess.run(tf.global_variables_initializer())
36+
37+
self.save_model()
38+
self.restore_model()
39+
40+
def build_layers(self,state,c_names,units_1,units_2,w_i,b_i,reg=None):
41+
with tf.variable_scope('conv1'):
42+
conv1 = conv(state,[5,5,3,6],[6],[1,2,2,1],w_i,b_i)
43+
with tf.variable_scope('conv2'):
44+
conv2 = conv(conv1,[3,3,6,12],[12],[1,2,2,1],w_i,b_i)
45+
with tf.variable_scope('flatten'):
46+
flatten = tf.contrib.layers.flatten(conv2)
47+
48+
with tf.variable_scope('dense1'):
49+
dense1 = noisy_dense(flatten,units_1,[units_1],c_names,w_i,b_i,noisy_distribution = self.config.noisy_distribution)
50+
51+
with tf.variable_scope('dense2'):
52+
dense2 = noisy_dense(dense1,units_2,[units_2],c_names,w_i,b_i,noisy_distribution = self.config.noisy_distribution)
53+
54+
with tf.variable_scope('dense3'):
55+
dense3 = noisy_dense(dense2,self.action_dim,[self.action_dim],c_names,w_i,b_i,noisy_distribution = self.config.noisy_distribution)
56+
57+
return dense3
58+
59+
def build_noisy_dqn_net(self):
60+
with tf.variable_scope('target_net'):
61+
c_names = ['target_net_arams',tf.GraphKeys.GLOBAL_VARIABLES]
62+
w_i = tf.random_uniform_initializer(-0.1,0.1)
63+
b_i = tf.constant_initializer(0.1)
64+
self.q_target = self.build_layers(self.target_input,c_names,24,24,w_i,b_i)
65+
66+
with tf.variable_scope('eval_net'):
67+
c_names = ['eval_net_params',tf.GraphKeys.GLOBAL_VARIABLES]
68+
w_i = tf.random_uniform_initializer(-0.1,0.1)
69+
b_i = tf.constant_initializer(0.1)
70+
self.q_eval = self.build_layers(self.eval_input,c_names,24,24,w_i,b_i)
71+
72+
self.loss = tf.reduce_mean(tf.squared_difference(self.q_eval,self.y_input))
73+
74+
self.optimizer = tf.train.AdamOptimizer(self.config.LEARNING_RATE).minimize(self.loss)
75+
76+
eval_params = tf.get_collection("eval_net_params")
77+
target_params = tf.get_collection('target_net_params')
78+
79+
self.update_target_net = [tf.assign(t,e) for t,e in zip(target_params,eval_params)]
80+
81+
82+
def save_model(self):
83+
print("Model saved in : ", self.saver.save(self.sess, self.config.MODEL_PATH))
84+
85+
def restore_model(self):
86+
self.saver.restore(self.sess, self.config.MODEL_PATH)
87+
print("Model restored.")
88+
89+
90+
def perceive(self,state,action,reward,next_state,done):
91+
self.replay_buffer.append((state,action,reward,next_state,done))
92+
93+
94+
def train_q_network(self,update=True):
95+
96+
if len(self.replay_buffer) < self.config.START_TRAINING:
97+
return
98+
99+
self.time_step += 1
100+
minibatch = random.sample(self.replay_buffer,self.config.BATCH_SIZE)
101+
102+
np.random.shuffle(minibatch)
103+
104+
state_batch = [data[0] for data in minibatch]
105+
action_batch = [data[1] for data in minibatch]
106+
reward_batch = [data[2] for data in minibatch]
107+
next_state_batch = [data[3] for data in minibatch]
108+
done = [data[4] for data in minibatch]
109+
110+
q_target = self.sess.run(self.q_target,feed_dict={self.target_input:next_state_batch})
111+
q_eval = self.sess.run(self.q_eval,feed_dict={self.eval_input:state_batch})
112+
113+
done = np.array(done) + 0
114+
115+
# DQN的结构 r + max q_target[a]
116+
y_batch = np.zeros((self.config.BATCH_SIZE,self.action_dim))
117+
for i in range(0,self.config.BATCH_SIZE):
118+
temp = q_eval[i]
119+
action = np.argmax(q_target[i])
120+
temp[action_batch[i]] = reward_batch[i] + (1 - done[i]) * self.config.GAMMA * q_target[i][action]
121+
y_batch[i] = temp
122+
123+
124+
self.sess.run(self.optimizer,feed_dict={
125+
self.y_input:y_batch,
126+
self.eval_input:state_batch,
127+
self.action_batch:action_batch
128+
})
129+
130+
if update and self.time_step % self.config.UPDATE_TARGET_NET == 0:
131+
self.sess.run(self.update_target_net)
132+
133+
134+
135+
def noisy_action(self, state):
136+
137+
return np.argmax(self.sess.run(self.q_target,feed_dict={self.target_input: [state]})[0])
138+
139+
140+
141+
142+
143+
144+
145+
146+
147+
148+

RL/Basic-NoisyNet-Demo/main.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import matplotlib.pyplot as plt
2+
import gym
3+
4+
import numpy as np
5+
import tensorflow as tf
6+
7+
import pickle
8+
9+
from Config import NoisyNetDQNConfig
10+
from NoisyNetDQN import NoisyNetDQN
11+
12+
def map_scores(dqfd_scores=None, ddqn_scores=None, xlabel=None, ylabel=None):
13+
if dqfd_scores is not None:
14+
plt.plot(dqfd_scores, 'r')
15+
if ddqn_scores is not None:
16+
plt.plot(ddqn_scores, 'b')
17+
if xlabel is not None:
18+
plt.xlabel(xlabel)
19+
if ylabel is not None:
20+
plt.ylabel(ylabel)
21+
plt.show()
22+
23+
24+
def BreakOut_NoisyNetDQN(index,env):
25+
with tf.variable_scope('DQfD_' + str(index)):
26+
agent = NoisyNetDQN(env,NoisyNetDQNConfig())
27+
scores = []
28+
for e in range(NoisyNetDQNConfig.episode):
29+
done = False
30+
score = 0 # sum of reward in one episode
31+
state = env.reset()
32+
# while done is False:
33+
last_lives = 5
34+
throw = True
35+
items_buffer = []
36+
while not done:
37+
env.render()
38+
action = 1 if throw else agent.noisy_action(state)
39+
next_state, real_reward, done, info = env.step(action)
40+
lives = info['ale.lives']
41+
train_reward = 1 if throw else -1 if lives < last_lives else real_reward
42+
score += real_reward
43+
throw = lives < last_lives
44+
last_lives = lives
45+
# agent.perceive(state, action, train_reward, next_state, done) # miss: -1 break: reward nothing: 0
46+
items_buffer.append([state, action, next_state, done]) # miss: -1 break: reward nothing: 0
47+
state = next_state
48+
if train_reward != 0: # train when miss the ball or score or throw the ball in the beginning
49+
print ('len(items_buffer):', len(items_buffer))
50+
for item in items_buffer:
51+
agent.perceive(item[0], item[1], -1 if throw else train_reward, item[2], item[3])
52+
agent.train_q_network(update=False)
53+
items_buffer = []
54+
scores.append(score)
55+
agent.sess.run(agent.update_target_net)
56+
print("episode:", e, " score:", score, " memory length:", len(agent.replay_buffer))
57+
58+
return scores
59+
60+
61+
if __name__ == '__main__':
62+
env = gym.make('Breakout-v0') # 打砖块游戏
63+
64+
NoisyNetDQN_sum_scores = np.zeros(NoisyNetDQNConfig.episode)
65+
66+
for i in range(NoisyNetDQNConfig.iteration):
67+
scores = BreakOut_NoisyNetDQN(i,env)
68+
dqfd_sum_scores = [a + b for a, b in zip(scores, NoisyNetDQN_sum_scores)]
69+
NoisyNetDQN_mean_scores = NoisyNetDQN_sum_scores / NoisyNetDQNConfig.iteration
70+
with open('/Users/mahailong/DQfD/NoisyNetDQN_mean_scores.p', 'wb') as f:
71+
pickle.dump(NoisyNetDQN_mean_scores, f, protocol=2)

RL/Basic-NoisyNet-Demo/readme

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
论文名称:Noisy Networks for Exploration
2+
论文下载地址:https://arxiv.org/abs/1706.10295v1

RL/Basic-NoisyNet-Demo/utils.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import tensorflow as tf
2+
from tensorflow.python.framework import ops
3+
4+
5+
def conv(inputs, kernel_shape, bias_shape, strides, w_i, b_i=None, activation=tf.nn.relu):
6+
7+
weights = tf.get_variable('weights', shape=kernel_shape, initializer=w_i)
8+
conv = tf.nn.conv2d(inputs, weights, strides=strides, padding='SAME')
9+
if bias_shape is not None:
10+
biases = tf.get_variable('biases', shape=bias_shape, initializer=b_i)
11+
return activation(conv + biases) if activation is not None else conv+biases
12+
return activation(conv) if activation is not None else conv
13+
14+
def noisy_dense(inputs, units, bias_shape, c_names, w_i, b_i=None, activation=tf.nn.relu, noisy_distribution='factorised'):
15+
def f(e_list):
16+
return tf.multiply(tf.sign(e_list), tf.pow(tf.abs(e_list), 0.5))
17+
18+
if not isinstance(inputs, ops.Tensor):
19+
inputs = ops.convert_to_tensor(inputs, dtype='float')
20+
21+
if len(inputs.shape) > 2:
22+
inputs = tf.contrib.layers.flatten(inputs)
23+
flatten_shape = inputs.shape[1]
24+
weights = tf.get_variable('weights', shape=[flatten_shape, units], initializer=w_i)
25+
w_noise = tf.get_variable('w_noise', [flatten_shape, units], initializer=w_i, collections=c_names)
26+
if noisy_distribution == 'independent':
27+
weights += tf.multiply(tf.random_normal(shape=w_noise.shape), w_noise)
28+
elif noisy_distribution == 'factorised':
29+
noise_1 = f(tf.random_normal(tf.TensorShape([flatten_shape, 1]), dtype=tf.float32)) # 注意是列向量形式,方便矩阵乘法
30+
noise_2 = f(tf.random_normal(tf.TensorShape([1, units]), dtype=tf.float32))
31+
weights += tf.multiply(noise_1 * noise_2, w_noise)
32+
dense = tf.matmul(inputs, weights)
33+
if bias_shape is not None:
34+
assert bias_shape[0] == units
35+
biases = tf.get_variable('biases', shape=bias_shape, initializer=b_i)
36+
b_noise = tf.get_variable('b_noise', [1, units], initializer=b_i, collections=c_names)
37+
if noisy_distribution == 'independent':
38+
biases += tf.multiply(tf.random_normal(shape=b_noise.shape), b_noise)
39+
elif noisy_distribution == 'factorised':
40+
biases += tf.multiply(noise_2, b_noise)
41+
return activation(dense + biases) if activation is not None else dense + biases
42+
return activation(dense) if activation is not None else dense
43+
44+
45+
46+
47+
48+
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"cells": [],
3+
"metadata": {},
4+
"nbformat": 4,
5+
"nbformat_minor": 2
6+
}

0 commit comments

Comments
 (0)