Skip to content

Commit 209ec99

Browse files
committed
[src/rl] Add experiment and proper results handling
1 parent aa68e0e commit 209ec99

File tree

6 files changed

+106
-3
lines changed

6 files changed

+106
-3
lines changed
+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#!/usr/bin/env python
2+
3+
import gym
4+
5+
from evsim.experiments import setup_logger
6+
from evsim.rl import DDQN
7+
8+
name = "DDDQN-100-100"
9+
episodes = 2
10+
episode_steps = 65334
11+
12+
setup_logger("sim-{}".format(name), write=True)
13+
env = gym.make("evsim-v0")
14+
env.imbalance_costs(5000)
15+
env.prediction_accuracy((100, 100))
16+
17+
dqqn = DDQN(env, name, memory_limit=episode_steps, nb_eps=episode_steps, nb_warmup=1000)
18+
dqqn.run(episodes * episode_steps)
19+
20+
dqqn.test()

src/evsim/experiments/DDDQN-80-95.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#!/usr/bin/env python
2+
3+
import gym
4+
5+
from evsim.experiments import setup_logger
6+
from evsim.rl import DDQN
7+
8+
name = "DDDQN-100-100"
9+
episodes = 2
10+
episode_steps = 65334
11+
12+
setup_logger("sim-{}".format(name), write=True)
13+
env = gym.make("evsim-v0")
14+
env.imbalance_costs(5000)
15+
env.prediction_accuracy((80, 95))
16+
17+
dqqn = DDQN(
18+
env,
19+
name,
20+
memory_limit=(episodes * episode_steps),
21+
nb_eps=(1.5 * episode_steps),
22+
nb_warmup=1000,
23+
)
24+
dqqn.run(episodes * episode_steps)
25+
26+
dqqn.test()

src/evsim/experiments/DDDQN-90-99.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#!/usr/bin/env python
2+
3+
import gym
4+
5+
from evsim.experiments import setup_logger
6+
from evsim.rl import DDQN
7+
8+
name = "DDDQN-90-99"
9+
episodes = 2
10+
episode_steps = 65334
11+
12+
setup_logger("sim-{}".format(name), write=True)
13+
env = gym.make("evsim-v0")
14+
env.imbalance_costs(5000)
15+
env.prediction_accuracy((90, 99))
16+
17+
dqqn = DDQN(
18+
env,
19+
name,
20+
memory_limit=(episodes * episode_steps),
21+
nb_eps=(1.5 * episode_steps),
22+
nb_warmup=1000,
23+
)
24+
dqqn.run(episodes * episode_steps)
25+
26+
dqqn.test()

src/evsim/experiments/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# flake8: noqa
2+
from .logger import setup_logger

src/evsim/experiments/logger.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#!/usr/bin/env python
2+
3+
import logging
4+
import os
5+
6+
logger = logging.getLogger(__name__)
7+
8+
9+
def setup_logger(name, write=True):
10+
f = logging.Formatter("%(levelname)-7s %(message)s")
11+
12+
sh = logging.StreamHandler()
13+
sh.setFormatter(f)
14+
sh.setLevel(logging.ERROR)
15+
handlers = [sh]
16+
17+
if write:
18+
os.makedirs("./logs", exist_ok=True)
19+
fh = logging.FileHandler("./logs/%s.log" % name, mode="w")
20+
fh.setFormatter(f)
21+
fh.setLevel(logging.DEBUG)
22+
handlers = [sh, fh]
23+
24+
logging.basicConfig(
25+
level=logging.DEBUG, datefmt="%d.%m. %H:%M:%S", handlers=handlers
26+
)

src/evsim/rl/ddqn.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class DDQN:
1515
def __init__(
1616
self,
1717
env,
18+
name,
1819
memory_limit=10000,
1920
nb_eps=10000,
2021
nb_warmup=100,
@@ -27,8 +28,10 @@ def __init__(
2728
np.random.seed(123)
2829
random.seed(123)
2930

30-
self.log_filename = "./logs/dqn_{}_log.json".format(self.env.spec.id)
31-
self.weights_filename = "./results/dqn_{}_weights.h5f".format(self.env.spec.id)
31+
self.name = name
32+
self.log_filename = "./logs/{}_log.json".format(self.name)
33+
self.weights_filename = "./results/{}_weights.h5f".format(self.name)
34+
self.result_filename = "./results/{}_result.csv".format(self.name)
3235

3336
# Extract the number of actions form the environment
3437
nb_action = self.env.action_space.spaces[0].n
@@ -94,4 +97,4 @@ def run(self, steps):
9497
def test(self):
9598
self.dqn.load_weights(self.weights_filename)
9699
self.dqn.test(self.env, nb_episodes=1, visualize=False)
97-
self.env.save_results("./results/sim_result_ep_{}.csv".format("test"))
100+
self.env.save_results(self.result_filename)

0 commit comments

Comments
 (0)