@@ -8,12 +8,13 @@ enum Action {
8
8
UP = 3 ,
9
9
}
10
10
11
- type State = Coordinates
11
// Grid coordinates as a [row, col] pair.
+ type Position = [ number , number ]
12
// Flattened cell index: row * number_of_columns + col.
+ type State = number
12
13
// Map cell kinds. Landing on 'H' or 'G' ends an episode and only 'G' pays a
// reward; 'S' / 'F' are presumably "start" / "frozen" -- TODO confirm.
type Tile = 'S' | 'F' | 'H' | 'G'
13
14
// 2-D grid of tiles, read as lake[row][col].
type LakeMap = Matrix < Tile >
14
15
// Lake maps looked up by a string key (named `size` in the index signature).
type LakeMapDict = { [ size : string ] : LakeMap }
15
16
16
- interface Transition {
17
+ export interface Transition {
17
18
prob : number ,
18
19
next_state : State ,
19
20
reward : number ,
@@ -48,9 +49,9 @@ export default class FrozenLakeEnv {
48
49
49
50
// Grid dimensions: size[0] is the row count, size[1] the column count.
private size : Size
50
51
// Tile grid, read as lake[row][col]; assigned from generate_lake in the constructor.
private lake : LakeMap = [ ]
51
- private current_state : Coordinates = [ 0 , 0 ]
52
- private n_actions : number = 4
53
- private n_states : number = 0
52
// Current state of the environment, stored as a flattened cell index.
+ private current_state : State = 0
53
// Number of actions considered in every state.
+ public n_actions : number = 4
54
// Total number of cells; set to size[0] * size[1] in the constructor.
+ public n_states : number = 0
54
55
55
56
// Transition table filled by compute_dynamics: MDP[state][action] lists the possible outcomes.
public MDP : Matrix < Transition [ ] >
56
57
@@ -59,52 +60,22 @@ export default class FrozenLakeEnv {
59
60
this . lake = FrozenLakeEnv . generate_lake ( size , prob_frozen )
60
61
this . n_states = size [ 0 ] * size [ 1 ]
61
62
62
- this . compute_MDP ( is_slippery )
63
+ this . compute_dynamics ( is_slippery )
63
64
}
64
65
65
- private compute_MDP ( is_slippery : boolean ) : void {
66
- const { size, n_actions } = this
67
- const MDP = this . MDP = full_matrix ( size , [ ] as Transition [ ] )
68
-
69
- for ( let row = 0 ; row < size [ 0 ] ; row ++ ) {
70
- for ( let col = 0 ; col < size [ 1 ] ; col ++ ) {
71
- for ( let action = 0 ; action < n_actions ; action ++ ) {
72
- const tile = this . get_tile ( [ row , col ] )
73
-
74
- if ( [ 'H' , 'G' ] . includes ( tile ) ) {
75
-
76
- MDP [ row ] [ col ] . push ( {
77
- prob : 1.0 ,
78
- next_state : [ row , col ] ,
79
- reward : 0 ,
80
- done : true ,
81
- } )
82
-
83
- } else if ( is_slippery ) {
84
-
85
- const all_actions = [
86
- ( action - 1 ) % n_actions ,
87
- action ,
88
- ( action + 1 ) % n_actions ,
89
- ]
90
-
91
- for ( let a of all_actions ) {
92
- const transition = this . get_transition ( [ row , col ] , a )
93
- transition . prob = 1 / all_actions . length // uniform probs
94
- MDP [ row ] [ col ] . push ( transition )
95
- }
66
// Flattens a [row, col] position into a single state index, numbering the
// cells row by row (row * number_of_columns + col).
private to_state(position: Position): State {
  const n_cols = this.size[1]
  return position[0] * n_cols + position[1]
}
96
70
97
- } else {
98
- const transition = this . get_transition ( [ row , col ] , action )
99
- MDP [ row ] [ col ] . push ( transition )
100
- }
101
- }
102
- }
103
- }
71
// Inverse of to_state: expands a flattened state index back into its
// [row, col] position.
private to_position(state: State): Position {
  const n_cols = this.size[1]
  const row = Math.floor(state / n_cols)
  // Whatever is left after removing the full rows is the column.
  const position: Position = [row, state - row * n_cols]
  return position
}
105
76
106
- private calc_next_state ( state : State , action : Action ) : State {
107
- let [ row , col ] = state
77
+ private calc_next_position ( position : Position , action : Action ) : Position {
78
+ let [ row , col ] = position
108
79
109
80
switch ( action ) {
110
81
case Action . LEFT :
@@ -120,35 +91,90 @@ export default class FrozenLakeEnv {
120
91
row = Math . max ( row - 1 , 0 )
121
92
}
122
93
123
- return [ row , col ] as State
94
+ return [ row , col ] as Position
124
95
}
125
96
126
- private get_tile ( state : State ) : Tile {
127
- const [ row , col ] = state
97
// Returns the tile stored in the lake grid at the given [row, col] position.
private get_tile(position: Position): Tile {
  return this.lake[position[0]][position[1]]
}
130
101
131
- private get_transition ( state : State , action : Action ) : Transition {
132
- const next_state = this . calc_next_state ( state , action )
133
- const tile = this . get_tile ( next_state )
102
// Computes the deterministic outcome of taking `action` from `position`.
// The result always has prob 1.0; `reward` is 1 only when the move lands on
// a 'G' tile, and `done` is true when it lands on either 'H' or 'G'.
private get_transition(position: Position, action: Action): Transition {
  const target = this.calc_next_position(position, action)
  const landed_on = this.get_tile(target)
  const is_goal = landed_on === 'G'

  return {
    prob: 1.0,
    next_state: this.to_state(target),
    reward: is_goal ? 1 : 0,
    done: is_goal || landed_on === 'H',
  }
}
144
116
117
// Builds the transition table `this.MDP`, indexed as MDP[state][action].
// Each cell holds the list of outcomes of taking `action` in `state`:
//   - on an 'H' or 'G' tile: one self-loop with zero reward and done = true;
//   - when `is_slippery` is true: three equally likely outcomes, one for the
//     intended action and one for each of its two neighbouring actions;
//   - otherwise: the single deterministic outcome of the action.
private compute_dynamics(is_slippery: boolean): void {
  const { size, n_states, n_actions } = this
  const MDP = this.MDP = full_matrix([n_states, n_actions], [] as Transition[])

  for (let row = 0; row < size[0]; row++) {
    for (let col = 0; col < size[1]; col++) {

      const position: Position = [row, col]
      const state = this.to_state(position)
      // The tile depends only on the cell, so look it up once per cell
      // instead of once per action.
      const is_terminal = ['H', 'G'].includes(this.get_tile(position))

      for (let action = 0; action < n_actions; action++) {
        // Collect into a fresh array and assign it, so the table does not
        // depend on whether full_matrix clones its fill value per cell.
        const transitions: Transition[] = []

        if (is_terminal) {

          transitions.push({
            prob: 1.0,
            next_state: state,
            reward: 0,
            done: true,
          })

        } else if (is_slippery) {

          // Floor is slippery: besides the intended action, either of its
          // two neighbouring actions may happen instead. Assuming the
          // Action enum orders directions so that neighbours are
          // perpendicular (TODO confirm), e.g.
          //   Intended direction: up (or down)
          //   Unintended outcome: left or right
          //
          //   Intended direction: right (or left)
          //   Unintended outcome: up or down
          const potential_actions = [
            // `%` keeps the sign of the dividend, so (action - 1) would
            // yield -1 for action 0; add n_actions first to wrap around.
            (action + n_actions - 1) % n_actions, // unintended
            action,                               // intended
            (action + 1) % n_actions,             // unintended
          ]
          const uniform_prob = 1 / potential_actions.length

          for (const a of potential_actions) {
            const transition = this.get_transition(position, a)
            transition.prob = uniform_prob
            transitions.push(transition)
          }

        } else {
          transitions.push(this.get_transition(position, action))
        }

        MDP[state][action] = transitions
      }
    }
  }
}
170
+
145
171
// Puts the environment back into state 0 and returns that state.
reset(): State {
  this.current_state = 0
  return this.current_state
}
149
174
150
175
// Takes `action` from the current state, stores the resulting state as the
// new current state and returns the transition that was taken.
// NOTE(review): this always uses the deterministic get_transition result;
// the slippery outcomes stored in `MDP` are never sampled here -- confirm
// that this is intended.
step(action: Action): Transition {
  const outcome = this.get_transition(this.to_position(this.current_state), action)
  this.current_state = outcome.next_state
  return outcome
}
0 commit comments