Commit ea06532

dqn work in progress..
1 parent f3868f4 commit ea06532

1 file changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
%matplotlib inline

import torch
import torch.nn as nn
import gym
import random
import numpy as np
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.autograd import Variable
from collections import deque, namedtuple

env = gym.envs.make("CartPole-v0")

class Net(nn.Module):
    """Two-layer fully connected Q-network: 4 state features -> 2 action values."""

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(4, 128)
        self.tanh = nn.Tanh()
        self.fc2 = nn.Linear(128, 2)
        self.init_weights()

    def init_weights(self):
        self.fc1.weight.data.uniform_(-0.1, 0.1)
        self.fc2.weight.data.uniform_(-0.1, 0.1)

    def forward(self, x):
        out = self.fc1(x)
        out = self.tanh(out)
        out = self.fc2(out)
        return out

def make_epsilon_greedy_policy(network, epsilon, nA):
    """Builds an epsilon-greedy policy over the network's Q-values."""
    def policy(state):
        sample = random.random()
        if sample < (1 - epsilon) + (epsilon / nA):
            # Greedy action: index of the largest Q-value.
            q_values = network(state.view(1, -1))
            action = q_values.data.max(1)[1][0, 0]
        else:
            # Random exploratory action.
            action = random.randrange(nA)
        return action
    return policy

class ReplayMemory(object):
    """Fixed-capacity FIFO buffer of transitions for experience replay."""

    def __init__(self, capacity):
        self.memory = deque()
        self.capacity = capacity

    def push(self, transition):
        # Evict the oldest transition once the buffer is full.
        if len(self.memory) >= self.capacity:
            self.memory.popleft()
        self.memory.append(transition)

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

def to_tensor(ndarray, volatile=False):
    # Wrap a numpy array as a float Variable (volatile=True skips graph building).
    return Variable(torch.from_numpy(ndarray), volatile=volatile).float()

def deep_q_learning(num_episodes=10, batch_size=100,
                    discount_factor=0.95, epsilon=0.1, epsilon_decay=0.95):

    # Q-network and replay memory
    net = Net()
    memory = ReplayMemory(10000)

    # Loss and optimizer
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    for i_episode in range(num_episodes):

        # Decay epsilon once per episode, then rebuild the policy with it.
        if i_episode > 0:
            epsilon *= epsilon_decay
        policy = make_epsilon_greedy_policy(
            net, epsilon, env.action_space.n)

        # Start an episode
        state = env.reset()

        for t in range(10000):

            # Sample an action from the epsilon-greedy policy
            action = policy(to_tensor(state))
            next_state, reward, done, _ = env.step(action)

            # Store the transition in replay memory
            memory.push([state, action, reward, next_state, done])

            if len(memory) >= batch_size:
                # Sample a mini-batch of transitions from memory
                batch = memory.sample(batch_size)
                state_batch = np.vstack([trans[0] for trans in batch])
                action_batch = np.vstack([trans[1] for trans in batch])
                reward_batch = np.vstack([trans[2] for trans in batch])
                next_state_batch = np.vstack([trans[3] for trans in batch])
                done_batch = np.vstack([float(trans[4]) for trans in batch])

                # Forward + backward + optimize
                net.zero_grad()
                q_values = net(to_tensor(state_batch))
                # Targets are computed without building a graph, so no
                # gradients flow through the bootstrap term.
                next_q_values = net(to_tensor(next_state_batch, volatile=True))
                next_q_values.volatile = False

                # Zero out the bootstrap term for terminal transitions.
                not_done = to_tensor(1.0 - done_batch)
                td_target = to_tensor(reward_batch) + \
                    discount_factor * not_done * next_q_values.max(1)[0]
                loss = criterion(q_values.gather(1,
                    to_tensor(action_batch).long().view(-1, 1)), td_target)
                loss.backward()
                optimizer.step()

            if done:
                break

            state = next_state

        if len(memory) >= batch_size and (i_episode+1) % 10 == 0:
            print('episode: %d, time: %d, loss: %.4f' % (i_episode, t, loss.data[0]))
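
# A minimal way to launch training, assuming the defaults above; the episode
# count here is illustrative rather than a tuned value.
if __name__ == '__main__':
    deep_q_learning(num_episodes=200)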
