one_hot.py
import gym
import numpy as np

from dreamer.envs.wrapper import EnvWrapper
from rlpyt.spaces.int_box import IntBox
from rlpyt.spaces.float_box import FloatBox

class OneHotAction(EnvWrapper):
    """Exposes a discrete-action environment through a one-hot action interface."""

    def __init__(self, env):
        print("isinstance(env.action_space, gym.spaces.Discrete)",
              isinstance(env.action_space, gym.spaces.Discrete),
              env.action_space)
        print("isinstance(env.action_space, IntBox)",
              isinstance(env.action_space, IntBox),
              env.action_space)
        assert isinstance(env.action_space, gym.spaces.Discrete) or isinstance(env.action_space, IntBox)
        super().__init__(env)
        self._dtype = np.float32

    @property
    def action_space(self):
        # Present the discrete action space as a one-hot float vector of length n.
        shape = (self.env.action_space.n,)
        space = FloatBox(low=0, high=1, shape=shape, dtype=self._dtype)
        space.sample = self._sample_action
        return space

    def step(self, action):
        # Convert the one-hot action back to an integer index for the wrapped env.
        index = np.argmax(action).astype(int)
        reference = np.zeros_like(action)
        reference[index] = 1
        if not np.allclose(reference, action, atol=1e-6):
            raise ValueError(f'Invalid one-hot action:\n{action}')
        return self.env.step(index)

    def reset(self):
        return self.env.reset()

    def _sample_action(self):
        # Sample a uniformly random one-hot action vector.
        actions = self.env.action_space.n
        index = self.random.randint(0, actions)
        reference = np.zeros(actions, dtype=self._dtype)
        reference[index] = 1.0
        return reference
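
A minimal usage sketch, not part of one_hot.py: it assumes this class lives at dreamer.envs.one_hot, that EnvWrapper simply stores the wrapped env, and that a classic gym environment with a Discrete action space is available (CartPole-v1 below is only illustrative).

import gym
import numpy as np

from dreamer.envs.one_hot import OneHotAction  # assumed module path

# Wrap a discrete-action env so the agent interacts via one-hot float vectors.
env = OneHotAction(gym.make("CartPole-v1"))
obs = env.reset()

# The wrapped action_space is a FloatBox of shape (n,); build a one-hot action by hand.
n = env.action_space.shape[0]
action = np.eye(n, dtype=np.float32)[1]  # selects discrete action index 1

obs, reward, done, info = env.step(action)

# Anything that is not (close to) a one-hot vector raises ValueError:
# env.step(np.full(n, 0.5, dtype=np.float32))  # would raise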