Merge pull request #140 from mkhansen-intel/fix_env_reset
Fix monitor env.reset() issue #135
ahcorde authored Apr 25, 2018
2 parents 1fbedc1 + 886c043 commit bdd3fe5
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions examples/turtlebot/circuit2_turtlebot_lidar_dqn.py
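The substance of the fix, visible in the hunks below: instead of driving each episode with a fixed-length "for t in xrange(steps)" loop, which could end up calling env.reset() on an episode the Monitor wrapper did not yet consider done, the script now sets env._max_episode_steps before wrapping the environment in gym.wrappers.Monitor and iterates "while not done", counting steps in episode_step. A minimal, self-contained sketch of the same pattern, using CartPole and a random policy purely as stand-ins for the Gazebo environment and the DQN agent (old-style Gym Monitor API assumed):

    import gym

    steps = 1000                                # per-episode cap, as in the script below
    env = gym.make('CartPole-v0')               # placeholder env, not part of this commit
    env._max_episode_steps = steps              # env returns done after at most `steps` steps
    env = gym.wrappers.Monitor(env, '/tmp/monitor_demo', force=True)

    for epoch in range(3):                      # a few episodes for illustration
        observation = env.reset()               # reset only between episodes
        done = False
        episode_step = 0
        while not done:                         # run until the env itself reports done
            action = env.action_space.sample()  # random policy stands in for the DQN
            observation, reward, done, info = env.step(action)
            episode_step += 1
        print("episode %d finished after %d steps" % (epoch, episode_step))
    env.close()

Because env.reset() is now only reached after done, the Monitor wrapper no longer sees an episode being reset mid-run.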
@@ -277,7 +277,7 @@ def clear_monitor_files(training_dir):
#Each time we run through the entire dataset, it's called an epoch.
#PARAMETER LIST
epochs = 1000
steps = 10000
steps = 1000
updateTargetNetwork = 10000
explorationRate = 1
minibatch_size = 64
@@ -318,6 +318,7 @@ def clear_monitor_files(training_dir):
clear_monitor_files(outdir)
copy_tree(monitor_path,outdir)

env._max_episode_steps = steps # env returns done after _max_episode_steps
env = gym.wrappers.Monitor(env, outdir,force=not continue_execution, resume=continue_execution)

last100Scores = [0] * 100
@@ -329,13 +330,14 @@ def clear_monitor_files(training_dir):
start_time = time.time()

#start iterating from 'current epoch'.

for epoch in xrange(current_epoch+1, epochs+1, 1):
observation = env.reset()
cumulated_reward = 0
done = False
episode_step = 0

# number of timesteps
for t in xrange(steps):
# run until env returns done
while not done:
# env.render()
qValues = deepQ.getQValues(observation)

@@ -357,23 +359,18 @@ def clear_monitor_files(training_dir):

observation = newObservation

if (t >= 1000):
print ("reached the end! :D")
done = True

env._flush(force=True)
if done:
last100Scores[last100ScoresIndex] = t
last100Scores[last100ScoresIndex] = episode_step
last100ScoresIndex += 1
if last100ScoresIndex >= 100:
last100Filled = True
last100ScoresIndex = 0
if not last100Filled:
print ("EP "+str(epoch)+" - {} timesteps".format(t+1)+" Exploration="+str(round(explorationRate, 2)))
print ("EP " + str(epoch) + " - " + format(episode_step + 1) + "/" + str(steps) + " Episode steps Exploration=" + str(round(explorationRate, 2)))
else :
m, s = divmod(int(time.time() - start_time), 60)
h, m = divmod(m, 60)
print ("EP "+str(epoch)+" - {} timesteps".format(t+1)+" - last100 Steps : "+str((sum(last100Scores)/len(last100Scores)))+" - Cumulated R: "+str(cumulated_reward)+" Eps="+str(round(explorationRate, 2))+" Time: %d:%02d:%02d" % (h, m, s))
print ("EP " + str(epoch) + " - " + format(episode_step + 1) + "/" + str(steps) + " Episode steps - last100 Steps : " + str((sum(last100Scores) / len(last100Scores))) + " - Cumulated R: " + str(cumulated_reward) + " Eps=" + str(round(explorationRate, 2)) + " Time: %d:%02d:%02d" % (h, m, s))
if (epoch)%100==0:
#save model weights and monitoring data every 100 epochs.
deepQ.saveModel('/tmp/turtle_c2_dqn_ep'+str(epoch)+'.h5')
@@ -385,13 +382,14 @@ def clear_monitor_files(training_dir):
parameter_dictionary = dict(zip(parameter_keys, parameter_values))
with open('/tmp/turtle_c2_dqn_ep'+str(epoch)+'.json', 'w') as outfile:
json.dump(parameter_dictionary, outfile)
break

stepCounter += 1
if stepCounter % updateTargetNetwork == 0:
deepQ.updateTargetNetwork()
print ("updating target network")

episode_step += 1

explorationRate *= 0.995 #epsilon decay
# explorationRate -= (2.0/epochs)
explorationRate = max (0.05, explorationRate)
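A side note on the hyperparameters this commit leaves in place: with explorationRate multiplied by 0.995 each epoch and floored at 0.05 (last hunk above), exploration reaches its floor after roughly 600 of the 1000 epochs, since 0.995**598 is about 0.05. A standalone check, not part of the commit:

    import math

    decay, floor = 0.995, 0.05
    epochs_to_floor = math.log(floor) / math.log(decay)
    print("exploration hits the %.2f floor after ~%d epochs" % (floor, math.ceil(epochs_to_floor)))
    # prints: exploration hits the 0.05 floor after ~598 epochs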
