Skip to content

Commit d204a62

Browse files
committed
wip: FrozenLakeEnv
1 parent 3b9d879 commit d204a62

File tree

1 file changed

+112
-8
lines changed

1 file changed

+112
-8
lines changed

noderl/environments/FrozenLakeEnv.ts

Lines changed: 112 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,28 @@
1+
import { full_matrix } from "../utils/lists"
12
import { choice_NxN } from "../utils/random"
23

4+
enum Action {
5+
LEFT = 0,
6+
DOWN = 1,
7+
RIGHT = 2,
8+
UP = 3,
9+
}
10+
11+
type State = Coordinates
12+
type Tile = 'S' | 'F' | 'H' | 'G'
13+
type LakeMap = Matrix<Tile>
14+
type LakeMapDict = { [size: string]: LakeMap }
15+
316
interface Transition {
417
prob: number,
5-
next_state: number,
18+
next_state: State,
619
reward: number,
720
done: boolean,
821
}
922

1023
export default class FrozenLakeEnv {
1124

12-
static MAPS = {
25+
static MAPS: LakeMapDict = {
1326
'4x4': [
1427
['S', 'F', 'F', 'F'],
1528
['F', 'H', 'F', 'H'],
@@ -28,24 +41,115 @@ export default class FrozenLakeEnv {
2841
]
2942
}
3043

31-
static generate_lake(size: Size, prob_frozen: number): Matrix<string> {
44+
static generate_lake(size: Size, prob_frozen: number): LakeMap {
3245
// TODO: implement a random generator
3346
return FrozenLakeEnv.MAPS['4x4']
3447
}
3548

36-
private lake: Matrix<string> = []
49+
private size: Size
50+
private lake: LakeMap = []
51+
private current_state: Coordinates = [0, 0]
3752
private n_actions: number = 4
38-
private n_states: number
39-
private MDP: Matrix<Transition[]> = []
40-
private current_state: Coordinates
53+
private n_states: number = 0
54+
55+
public MDP: Matrix<Transition[]>
4156

4257
constructor(size: Size = [4, 4], is_slippery = true, prob_frozen = 0.8) {
58+
this.size = size
4359
this.lake = FrozenLakeEnv.generate_lake(size, prob_frozen)
4460
this.n_states = size[0] * size[1]
4561

62+
this.compute_MDP(is_slippery)
4663
}
4764

48-
step(action: number) {
65+
private compute_MDP(is_slippery: boolean): void {
66+
const { size, n_actions } = this
67+
const MDP = this.MDP = full_matrix(size, [] as Transition[])
68+
69+
for(let row = 0; row < size[0]; row++){
70+
for(let col = 0; col < size[1]; col++){
71+
for(let action = 0; action < n_actions; action++) {
72+
const tile = this.get_tile([row, col])
73+
74+
if (['H', 'G'].includes(tile)) {
75+
76+
MDP[row][col].push({
77+
prob: 1.0,
78+
next_state: [row, col],
79+
reward: 0,
80+
done: true,
81+
})
82+
83+
} else if (is_slippery) {
84+
85+
const all_actions = [
86+
(action - 1) % n_actions,
87+
action,
88+
(action + 1) % n_actions,
89+
]
90+
91+
for(let a of all_actions) {
92+
const transition = this.get_transition([row, col], a)
93+
transition.prob = 1/all_actions.length // uniform probs
94+
MDP[row][col].push(transition)
95+
}
96+
97+
} else {
98+
const transition = this.get_transition([row, col], action)
99+
MDP[row][col].push(transition)
100+
}
101+
}
102+
}
103+
}
104+
}
105+
106+
private calc_next_state(state: State, action: Action): State {
107+
let [ row, col ] = state
108+
109+
switch(action) {
110+
case Action.LEFT:
111+
col = Math.max(col - 1, 0)
112+
break
113+
case Action.DOWN:
114+
row = Math.min(row + 1, this.size[0] - 1)
115+
break;
116+
case Action.RIGHT:
117+
col = Math.min(col + 1, this.size[1] - 1)
118+
break
119+
case Action.UP:
120+
row = Math.max(row - 1, 0)
121+
}
122+
123+
return [row, col] as State
124+
}
125+
126+
private get_tile(state: State): Tile {
127+
const [ row, col ] = state
128+
return this.lake[row][col]
129+
}
130+
131+
private get_transition(state: State, action: Action): Transition {
132+
const next_state = this.calc_next_state(state, action)
133+
const tile = this.get_tile(next_state)
134+
const done = ['H', 'G'].includes(tile)
135+
const reward = Number(tile === 'G')
136+
137+
return {
138+
prob: 1,
139+
next_state,
140+
reward,
141+
done,
142+
}
143+
}
144+
145+
reset(): State {
146+
this.current_state = [0, 0]
147+
return [ ...this.current_state ] // it's safer to clone
148+
}
49149

150+
step(action: Action): Transition {
151+
const transition = this.get_transition(this.current_state, action)
152+
this.current_state = transition.next_state
153+
return transition
50154
}
51155
}

0 commit comments

Comments
 (0)