Commit 9cc6687

wip: running bandits experiments
1 parent 6e9df77 commit 9cc6687

5 files changed: +28 -22 lines

noderl/multiarm_bandits/BaseAgent.ts

Lines changed: 8 additions & 14 deletions
@@ -1,34 +1,33 @@
 import { argmax, full, zeros, range } from '../utils/lists'
 import { choice } from '../utils/random'
 
-export default class GreedyAgent {
-  protected arms_count: number[]
+export default class BaseAgent {
+  public name: string
+  protected arm_counts: number[]
   protected q_values: number[]
   protected actions: number[]
-  protected explored: number = 0
-  protected exploited: number = 0
 
   get estimates(): number[] {
     return this.q_values
   }
 
   get counts(): number[] {
-    return this.arms_count
+    return this.arm_counts
   }
 
   constructor(n_arms: number, init_value: number = 0) {
-    this.arms_count = zeros(n_arms)
+    this.name = `${this.constructor.name}_arms-${n_arms}_initvals-${init_value}`
+    this.arm_counts = zeros(n_arms)
     this.q_values = full(n_arms, init_value)
     this.actions = range(n_arms)
   }
 
   pull(arm: number) {
-    this.arms_count[arm]++
+    this.arm_counts[arm]++
     return arm
   }
 
   random_action(): number {
-    this.explored++
     const action = choice(this.actions)
     return this.pull(action)
   }
@@ -39,14 +38,9 @@ export default class GreedyAgent {
 
   optimize(action: number, reward: number) {
     const old_stimate = this.q_values[action]
-    const step_size = 1/this.arms_count[action]
+    const step_size = 1/this.arm_counts[action]
 
     // Algorithm in section 2.4 of Sutton & Barto book
     this.q_values[action] = old_stimate + step_size * (reward - old_stimate)
   }
-
-  explore_exploit_rate(): number[] {
-    const total = this.explored + this.exploited
-    return [this.explored/total, this.exploited/total]
-  }
 }
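
Note on the update rule: optimize implements the incremental sample-average estimate from section 2.4 of Sutton & Barto, Q_{n+1} = Q_n + (1/n)(R_n - Q_n), a running mean of each arm's rewards that needs no reward history. A minimal standalone sketch of the same rule (the helper name is illustrative, not part of this commit):

// Running mean via the incremental rule Q <- Q + (1/n)(R - Q).
// After n rewards, q equals the plain arithmetic mean.
function incrementalMean(rewards: number[]): number {
  let q = 0
  let n = 0
  for (const r of rewards) {
    n++
    q = q + (1 / n) * (r - q)  // same step as BaseAgent.optimize
  }
  return q
}

console.log(incrementalMean([1, 0, 1, 1]))  // 0.75, the mean of the four rewards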

noderl/multiarm_bandits/EpsilonGreegyAgent.ts

Lines changed: 10 additions & 0 deletions
@@ -4,17 +4,27 @@ import GreedyAgent from './GreedyAgent'
 
 export default class EpsilonGreedyAgent extends GreedyAgent {
   private eps: number
+  protected explored: number = 0
+  protected exploited: number = 0
 
   constructor(n_arms: number, eps: number, init_value: number = 0) {
     super(n_arms, init_value)
+    this.name += `_eps-${eps}`
     this.eps = eps
   }
 
   act(): number {
     if (uniform() < this.eps) {
+      this.explored++
       return this.random_action()
     } else {
+      this.exploited++
       return this.greedy_action()
     }
   }
+
+  explore_exploit_rate(): number[] {
+    const total = this.explored + this.exploited
+    return [this.explored/total, this.exploited/total]
+  }
 }
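
With this change the explore/exploit bookkeeping lives where the decision is made: act() explores with probability eps and exploits otherwise, so explore_exploit_rate() should converge to roughly [eps, 1 - eps]. A usage sketch, assuming a loop like the one in training.ts:

import EpsilonGreedyAgent from './EpsilonGreegyAgent'

// eps = 0.01: about 1% of actions are random exploration.
const agent = new EpsilonGreedyAgent(10, 0.01)
for (let t = 0; t < 100000; t++) {
  const action = agent.act()
  // ...step the environment here and call agent.optimize(action, reward)
}
console.log(agent.explore_exploit_rate())  // ≈ [0.01, 0.99]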

noderl/multiarm_bandits/GreedyAgent.ts

Lines changed: 0 additions & 1 deletion
@@ -5,7 +5,6 @@ export default class GreedyAgent extends BaseAgent
 
   greedy_action(): number {
     // Greedy action
-    this.exploited++
     const action = argmax(this.q_values)
     return this.pull(action)
   }

noderl/multiarm_bandits/UCB1Agent.ts

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ export default class UCB1Agent extends BaseAgent
   ucb_action() {
     const plays = ++this.plays
     const values = this.q_values.map((estimate, i) => {
-      const arm_count = this.arms_count[i]
+      const arm_count = this.arm_counts[i]
       return estimate + Math.sqrt(2 * Math.log(plays)/ arm_count)
     })
 
noderl/multiarm_bandits/training.ts

Lines changed: 9 additions & 6 deletions
@@ -4,10 +4,12 @@ import GreedyAgent from './GreedyAgent'
 import EpsilonGreegyAgent from './EpsilonGreegyAgent'
 import UCB1Agent from './UCB1Agent'
 import { LOG_PATH } from '../config'
+import { full } from '../utils/lists'
 
-const PROB_DIST = [0.2, 0.5, 0.75]
-const REWARD_DIST = [1, 1, 1]
-const INIT_VALUE = 10
+const PROB_DIST = [0.2, 0.5, 0.75, 0.15, 0.01, 0.92, 0.88, 0.36, 0.79, 0.9]
+const N = PROB_DIST.length
+const REWARD_DIST = full(N, 1)
+const INIT_VALUE = 0
 const EPSILON = 0.01
 const EPISODES = 100000
 
@@ -18,16 +20,17 @@ const env = new BanditEnv(PROB_DIST, REWARD_DIST)
 // const agent = new EpsilonGreegyAgent(PROB_DIST.length, EPSILON, INIT_VALUE)
 const agent = new UCB1Agent(PROB_DIST.length, INIT_VALUE)
 
-const writer = tf.node.summaryFileWriter(`${LOG_PATH}/bandits`)
+const writer = tf.node.summaryFileWriter(`${LOG_PATH}/${agent.name}`)
 
 let total_reward = 0
 
 if (agent instanceof UCB1Agent) {
-  PROB_DIST.forEach((_, i) => {
+  // Initialization: try every arm once
+  for (let i = 0; i < N; i++) {
     const action = agent.pull(i)
     const reward = env.step(action)
     agent.optimize(action, reward)
-  })
+  }
 }
 
 for (let e = 0; e < EPISODES; e++) {
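
Two reading notes on this file. First, INIT_VALUE drops from 10 to 0, so the optimistic-initial-values trick is switched off now that UCB1 drives exploration explicitly. Second, the PROB_DIST/REWARD_DIST pair suggests a Bernoulli-style bandit (an assumption from usage here, not confirmed by the diff): arm i pays REWARD_DIST[i] with probability PROB_DIST[i], so its expected reward is their product and arm 5 (p = 0.92) is the one the agents should converge to:

const PROB_DIST = [0.2, 0.5, 0.75, 0.15, 0.01, 0.92, 0.88, 0.36, 0.79, 0.9]
// With REWARD_DIST all ones, expected reward per arm is just its probability.
const expected = PROB_DIST.map((p) => p * 1)
const best = expected.indexOf(Math.max(...expected))
console.log(best, expected[best])  // 5 0.92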
