Skip to content

Commit 81ecd16

Browse files
committed
wip: dynamic programming
1 parent d204a62 commit 81ecd16

File tree

4 files changed

+109
-60
lines changed

4 files changed

+109
-60
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import FrozenLakeEnv, { Transition } from '../environments/FrozenLakeEnv'
2+
import { zeros } from '../utils/lists'
3+
4+
export default function policy_evaluation(env: FrozenLakeEnv, policy: Matrix<number>, gamma=1, theta=1e-8) {
5+
const Value = zeros(env.n_states)
6+
7+
for(;;) {
8+
let delta = 0
9+
10+
for(let state = 0; state < env.n_states; state++) {
11+
const vs = Value[state]
12+
13+
for (let action = 0; action < env.n_actions; action++ ) {
14+
const action_prob = policy[state][action]
15+
16+
for (let transition of env.MDP[state][action]) {
17+
const { prob, next_state, reward } = transition
18+
Value[state] += action_prob * prob * (reward + (gamma * Value[next_state]))
19+
}
20+
}
21+
22+
delta = Math.max(delta, Math.abs(vs - Value[state]))
23+
}
24+
25+
if (delta < theta) break
26+
}
27+
}

noderl/environments/FrozenLakeEnv.ts

Lines changed: 82 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@ enum Action {
88
UP = 3,
99
}
1010

11-
type State = Coordinates
11+
type Position = [number, number]
12+
type State = number
1213
type Tile = 'S' | 'F' | 'H' | 'G'
1314
type LakeMap = Matrix<Tile>
1415
type LakeMapDict = { [size: string]: LakeMap }
1516

16-
interface Transition {
17+
export interface Transition {
1718
prob: number,
1819
next_state: State,
1920
reward: number,
@@ -48,9 +49,9 @@ export default class FrozenLakeEnv {
4849

4950
private size: Size
5051
private lake: LakeMap = []
51-
private current_state: Coordinates = [0, 0]
52-
private n_actions: number = 4
53-
private n_states: number = 0
52+
private current_state: State = 0
53+
public n_actions: number = 4
54+
public n_states: number = 0
5455

5556
public MDP: Matrix<Transition[]>
5657

@@ -59,52 +60,22 @@ export default class FrozenLakeEnv {
5960
this.lake = FrozenLakeEnv.generate_lake(size, prob_frozen)
6061
this.n_states = size[0] * size[1]
6162

62-
this.compute_MDP(is_slippery)
63+
this.compute_dynamics(is_slippery)
6364
}
6465

65-
private compute_MDP(is_slippery: boolean): void {
66-
const { size, n_actions } = this
67-
const MDP = this.MDP = full_matrix(size, [] as Transition[])
68-
69-
for(let row = 0; row < size[0]; row++){
70-
for(let col = 0; col < size[1]; col++){
71-
for(let action = 0; action < n_actions; action++) {
72-
const tile = this.get_tile([row, col])
73-
74-
if (['H', 'G'].includes(tile)) {
75-
76-
MDP[row][col].push({
77-
prob: 1.0,
78-
next_state: [row, col],
79-
reward: 0,
80-
done: true,
81-
})
82-
83-
} else if (is_slippery) {
84-
85-
const all_actions = [
86-
(action - 1) % n_actions,
87-
action,
88-
(action + 1) % n_actions,
89-
]
90-
91-
for(let a of all_actions) {
92-
const transition = this.get_transition([row, col], a)
93-
transition.prob = 1/all_actions.length // uniform probs
94-
MDP[row][col].push(transition)
95-
}
66+
private to_state(position: Position): State {
67+
const [ row, col ] = position
68+
return row * this.size[1] + col
69+
}
9670

97-
} else {
98-
const transition = this.get_transition([row, col], action)
99-
MDP[row][col].push(transition)
100-
}
101-
}
102-
}
103-
}
71+
private to_position(state: State): Position {
72+
const row = Math.floor(state / this.size[1])
73+
const col = state - row * this.size[1]
74+
return [row, col] as Position
10475
}
10576

106-
private calc_next_state(state: State, action: Action): State {
107-
let [ row, col ] = state
77+
private calc_next_position(position: Position, action: Action): Position {
78+
let [ row, col ] = position
10879

10980
switch(action) {
11081
case Action.LEFT:
@@ -120,35 +91,90 @@ export default class FrozenLakeEnv {
12091
row = Math.max(row - 1, 0)
12192
}
12293

123-
return [row, col] as State
94+
return [row, col] as Position
12495
}
12596

126-
private get_tile(state: State): Tile {
127-
const [ row, col ] = state
97+
private get_tile(position: Position): Tile {
98+
const [ row, col ] = position
12899
return this.lake[row][col]
129100
}
130101

131-
private get_transition(state: State, action: Action): Transition {
132-
const next_state = this.calc_next_state(state, action)
133-
const tile = this.get_tile(next_state)
102+
private get_transition(position: Position, action: Action): Transition {
103+
const next_position = this.calc_next_position(position, action)
104+
const tile = this.get_tile(next_position)
105+
const next_state = this.to_state(next_position)
134106
const done = ['H', 'G'].includes(tile)
135107
const reward = Number(tile === 'G')
136108

137109
return {
138-
prob: 1,
110+
prob: 1.0,
139111
next_state,
140112
reward,
141113
done,
142114
}
143115
}
144116

117+
private compute_dynamics(is_slippery: boolean): void {
118+
const { size, n_states, n_actions } = this
119+
const MDP = this.MDP = full_matrix([n_states, n_actions], [] as Transition[])
120+
121+
for(let row = 0; row < size[0]; row++){
122+
for(let col = 0; col < size[1]; col++){
123+
124+
const position = [row, col] as Position
125+
const state = this.to_state(position)
126+
127+
for(let action = 0; action < n_actions; action++) {
128+
const tile = this.get_tile(position)
129+
130+
if (['H', 'G'].includes(tile)) {
131+
132+
MDP[state][action].push({
133+
prob: 1.0,
134+
next_state: state,
135+
reward: 0,
136+
done: true,
137+
})
138+
139+
} else if (is_slippery) {
140+
141+
// Floor is slippery, this means that two other
142+
// actions could happen if you take a step in a
143+
// direction. E.g.
144+
// Intended direction: up (or down)
145+
// Unintended outcome: left or right
146+
//
147+
// Intended direction: right (or left)
148+
// Unintended outcome: up or down
149+
const potential_actions = [
150+
(action - 1) % n_actions, // unintended
151+
action, // intended
152+
(action + 1) % n_actions, // unintended
153+
]
154+
const n_all = potential_actions.length
155+
156+
for(let a of potential_actions) {
157+
const transition = this.get_transition(position, a)
158+
transition.prob = 1/n_all // uniform probs
159+
MDP[state][action].push(transition)
160+
}
161+
162+
} else {
163+
const transition = this.get_transition(position, action)
164+
MDP[state][action].push(transition)
165+
}
166+
}
167+
}
168+
}
169+
}
170+
145171
reset(): State {
146-
this.current_state = [0, 0]
147-
return [ ...this.current_state ] // it's safer to clone
172+
return this.current_state = 0
148173
}
149174

150175
step(action: Action): Transition {
151-
const transition = this.get_transition(this.current_state, action)
176+
const current_position = this.to_position(this.current_state)
177+
const transition = this.get_transition(current_position, action)
152178
this.current_state = transition.next_state
153179
return transition
154180
}

noderl/types.d.ts

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
type Matrix<T> = T[][]
22
type Size = number[]
3-
type Coordinates = number[]

ts-node

Lines changed: 0 additions & 3 deletions
This file was deleted.

0 commit comments

Comments
 (0)