|
21 | 21 |
|
22 | 22 |
|
23 | 23 | class Actor(object): |
24 | | - def __init__(self, n_features, action_range, lr=0.0001): |
| 24 | + def __init__(self, sess, n_features, action_range, lr=0.0001): |
| 25 | + self.sess = sess |
25 | 26 | with tf.name_scope('inputs'): |
26 | 27 | self.state = tf.placeholder(tf.float32, [n_features, ], "state") |
27 | 28 | state = tf.expand_dims(self.state, axis=0) |
@@ -79,7 +80,8 @@ def choose_action(self, s): |
79 | 80 |
|
80 | 81 |
|
81 | 82 | class Critic(object): |
82 | | - def __init__(self, n_features, lr=0.01): |
| 83 | + def __init__(self, sess, n_features, lr=0.01): |
| 84 | + self.sess = sess |
83 | 85 | with tf.name_scope('inputs'): |
84 | 86 | self.state = tf.placeholder(tf.float32, [n_features, ], "state") |
85 | 87 | state = tf.expand_dims(self.state, axis=0) |
@@ -118,53 +120,56 @@ def evaluate(self, s): |
118 | 120 |
|
119 | 121 |
|
120 | 122 | OUTPUT_GRAPH = False |
| 123 | +MAX_EPISODE = 3000 |
121 | 124 | EPISODE_TIME_THRESHOLD = 300 |
122 | 125 | DISPLAY_REWARD_THRESHOLD = -550 # renders environment if total episode reward is greater than this threshold |
123 | 126 | RENDER = False # rendering wastes time |
124 | 127 | GAMMA = 0.9 |
| 128 | +LR_A = 0.001 # learning rate for actor |
| 129 | +LR_C = 0.01 # learning rate for critic |
125 | 130 |
|
126 | 131 | env = gym.make('Pendulum-v0') |
127 | | -# env.seed(1) # reproducible |
128 | | - |
129 | | -actor = Actor(n_features=env.observation_space.shape[0], action_range=[env.action_space.low[0], env.action_space.high[0]], lr=0.001) |
130 | | -critic = Critic(n_features=env.observation_space.shape[0], lr=0.002) |
131 | | - |
132 | | -with tf.Session() as sess: |
133 | | - if OUTPUT_GRAPH: |
134 | | - tf.summary.FileWriter("logs/", sess.graph) |
135 | | - |
136 | | - actor.sess, critic.sess = sess, sess # define the tf session |
137 | | - tf.global_variables_initializer().run() |
138 | | - |
139 | | - for i_episode in range(3000): |
140 | | - observation = env.reset() |
141 | | - t = 0 |
142 | | - ep_rs = [] |
143 | | - while True: |
144 | | - # if RENDER: |
145 | | - env.render() |
146 | | - action, mu, sigma = actor.choose_action(observation) |
147 | | - |
148 | | - observation_, reward, done, info = env.step(action) |
149 | | - reward /= 10 |
150 | | - TD_target = reward + GAMMA * critic.evaluate(observation_) # r + gamma * V_next |
151 | | - TD_eval = critic.evaluate(observation) # V_now |
152 | | - TD_error = TD_target - TD_eval |
153 | | - |
154 | | - actor.update(s=observation, a=action, adv=TD_error) |
155 | | - critic.update(s=observation, target=TD_target) |
156 | | - |
157 | | - observation = observation_ |
158 | | - t += 1 |
159 | | - # print(reward) |
160 | | - ep_rs.append(reward) |
161 | | - if t > EPISODE_TIME_THRESHOLD: |
162 | | - ep_rs_sum = sum(ep_rs) |
163 | | - if 'running_reward' not in globals(): |
164 | | - running_reward = ep_rs_sum |
165 | | - else: |
166 | | - running_reward = running_reward * 0.9 + ep_rs_sum * 0.1 |
167 | | - if running_reward > DISPLAY_REWARD_THRESHOLD: RENDER = True # rendering |
168 | | - print("episode:", i_episode, " reward:", int(running_reward)) |
169 | | - break |
| 132 | +env.seed(1) # reproducible |
| 133 | + |
| 134 | +sess = tf.Session() |
| 135 | + |
| 136 | +actor = Actor(sess, n_features=env.observation_space.shape[0], action_range=[env.action_space.low[0], env.action_space.high[0]], lr=LR_A) |
| 137 | +critic = Critic(sess, n_features=env.observation_space.shape[0], lr=LR_C) |
| 138 | + |
| 139 | +sess.run(tf.global_variables_initializer()) |
| 140 | + |
| 141 | +if OUTPUT_GRAPH: |
| 142 | + tf.summary.FileWriter("logs/", sess.graph) |
| 143 | + |
| 144 | +for i_episode in range(MAX_EPISODE): |
| 145 | + s = env.reset() |
| 146 | + t = 0 |
| 147 | + ep_rs = [] |
| 148 | + while True: |
| 149 | +        if RENDER: |
| 150 | +            env.render() |
| 151 | + a, mu, sigma = actor.choose_action(s) |
| 152 | + |
| 153 | + s_, r, done, info = env.step(a) |
| 154 | + r /= 10 |
| 155 | + TD_target = r + GAMMA * critic.evaluate(s_) # r + gamma * V_next |
| 156 | + TD_eval = critic.evaluate(s) # V_now |
| 157 | + TD_error = TD_target - TD_eval |
| 158 | + |
| 159 | + actor.update(s=s, a=a, adv=TD_error) |
| 160 | + critic.update(s=s, target=TD_target) |
| 161 | + |
| 162 | + s = s_ |
| 163 | + t += 1 |
| 164 | +        # print(r) |
| 165 | + ep_rs.append(r) |
| 166 | + if t > EPISODE_TIME_THRESHOLD: |
| 167 | + ep_rs_sum = sum(ep_rs) |
| 168 | + if 'running_reward' not in globals(): |
| 169 | + running_reward = ep_rs_sum |
| 170 | + else: |
| 171 | + running_reward = running_reward * 0.9 + ep_rs_sum * 0.1 |
| 172 | + if running_reward > DISPLAY_REWARD_THRESHOLD: RENDER = True # rendering |
| 173 | + print("episode:", i_episode, " reward:", int(running_reward)) |
| 174 | + break |
170 | 175 |
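For reference, a minimal sketch of the wiring this commit moves to: a single tf.Session is created up front and injected into both constructors, instead of being attached to actor.sess / critic.sess after the fact. The Actor and Critic bodies below are stand-in stubs (the real network-building code lives earlier in the file), and the constructor arguments stand in for the env.observation_space / env.action_space values used in the diff; only the session handling mirrors the commit.

    import tensorflow as tf

    class Actor(object):
        def __init__(self, sess, n_features, action_range, lr=0.0001):
            self.sess = sess  # injected shared session; network construction omitted in this stub

    class Critic(object):
        def __init__(self, sess, n_features, lr=0.01):
            self.sess = sess  # injected shared session; network construction omitted in this stub

    sess = tf.Session()  # one session, created before both networks
    actor = Actor(sess, n_features=3, action_range=[-2.0, 2.0], lr=0.001)
    critic = Critic(sess, n_features=3, lr=0.01)
    sess.run(tf.global_variables_initializer())  # initialise once both graphs exist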
|