Skip to content

Commit 0a7f6f3

Browse files
committed
Built BanditEnv
1 parent 08f16f4 commit 0a7f6f3

File tree

9 files changed

+786
-1
lines changed

9 files changed

+786
-1
lines changed

.editorconfig

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# EditorConfig is awesome: https://EditorConfig.org
2+
3+
# top-most EditorConfig file
4+
root = true
5+
6+
# Unix-style newlines with a newline ending every file
7+
[*]
8+
end_of_line = lf
9+
insert_final_newline = true
10+
11+
# Matches multiple files with brace expansion notation
12+
# Set default charset
13+
[*.{ts,py}]
14+
charset = utf-8
15+
16+
# 4 space indentation
17+
[*.py]
18+
indent_style = space
19+
indent_size = 4
20+
21+
# Indentation override for all JS under lib directory
22+
[*.ts]
23+
indent_style = space
24+
indent_size = 2
25+
26+
# Matches the exact files either package.json or .travis.yml
27+
[{package.json,.travis.yml}]
28+
indent_style = space
29+
indent_size = 2

noderl/environments/BanditEnv.test.ts

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import test from 'tape'
2+
import BanditEnv from './BanditEnv'
3+
4+
test('new BanditEnv', t => {
5+
t.throws(() => new BanditEnv([], []), 'Empty probability distribution')
6+
t.throws(() => new BanditEnv([1], []), 'Empty reward distribution')
7+
t.throws(() => new BanditEnv([1, 0], [1]), 'Probability and Reward distribution must be the same length')
8+
t.throws(() => new BanditEnv([-1, 1], [1, 1]), 'All probabilities must be greater or equal to 0')
9+
t.throws(() => new BanditEnv([0, 2], [1, 1]), 'All probabilities must be less or equal to 1')
10+
t.doesNotThrow(() => new BanditEnv([0, 1], [1, 1]), 'Bandit environment correctly initialized')
11+
t.end()
12+
})
13+
14+
test('BanditEnv#pull', t => {
15+
const probs = [0.2, 0.5, 0.75]
16+
const rewards = [8, 5, 2.5]
17+
const bandit = new BanditEnv(probs, rewards)
18+
19+
t.throws(() => bandit.pull(-1), 'Wrong accion passed in')
20+
21+
rewards.forEach((reward, arm) => {
22+
let attempts = 0
23+
while(true) {
24+
attempts++
25+
if(bandit.pull(arm) === reward) {
26+
t.pass(`Arm ${arm} eventually paid off after ${attempts} attempts`)
27+
break
28+
}
29+
}
30+
})
31+
32+
t.end()
33+
})

noderl/environments/BanditEnv.ts

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import { min, max } from '../utils/lists'
2+
import { uniform } from '../utils/random'
3+
import { assert } from '../utils/assertion'
4+
5+
export default class BanditEnv {
6+
7+
private n_arms: number
8+
private p_dist: number[]
9+
private r_dist: number[]
10+
11+
constructor(p_dist: number[], r_dist: number[]) {
12+
13+
assert(p_dist.length !== 0, 'Empty probability distribution')
14+
assert(r_dist.length !== 0, 'Empty reward distribution')
15+
assert(p_dist.length === r_dist.length, 'Probability and Reward distribution must be the same length')
16+
assert(min(p_dist) >= 0 && max(p_dist) <= 1, 'All probabilities must be between 0 and 1')
17+
18+
this.n_arms = p_dist.length
19+
this.p_dist = p_dist
20+
this.r_dist = r_dist
21+
}
22+
23+
pull(action: number) {
24+
assert(action >= 0 && action < this.n_arms, `Wrong accion passed in: "${action}"`)
25+
26+
let reward = 0
27+
28+
if (uniform() < this.p_dist[action]) {
29+
reward = this.r_dist[action]
30+
}
31+
32+
return reward
33+
}
34+
}

noderl/utils/assertion.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
export class AssertionError extends Error {
2+
name: string = 'AssertionError'
3+
constructor(message?: string) {
4+
super(message)
5+
}
6+
}
7+
8+
export function assert(condition: boolean, message: string) {
9+
if (!condition) {
10+
throw new AssertionError(message)
11+
}
12+
}

noderl/utils/lists.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
export function min(list: number[]): number {
2+
return Math.min(...list)
3+
}
4+
5+
export function max(list: number[]): number {
6+
return Math.max(...list)
7+
}

noderl/utils/random.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
export function uniform() {
2+
return Math.random()
3+
}

0 commit comments

Comments
 (0)