Skip to content

Commit 1a4c59d

Browse files
author
Yusuke Sugomori
committed
scala
1 parent 1a540b2 commit 1a4c59d

File tree

2 files changed

+216
-1
lines changed

2 files changed

+216
-1
lines changed

java/RBM/src/RBM.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ public void contrastive_divergence(int[] input, double lr, int k) {
9191

9292
for(int i=0; i<n_hidden; i++) {
9393
for(int j=0; j<n_visible; j++) {
94-
W[i][j] += lr *(ph_sample[i] * input[j] - nh_means[i] * nv_samples[j]) / N;
94+
W[i][j] += lr * (ph_sample[i] * input[j] - nh_means[i] * nv_samples[j]) / N;
9595
}
9696
hbias[i] += lr * (ph_sample[i] - nh_means[i]) / N;
9797
}

scala/RBM.scala

+215
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
import scala.util.Random
2+
import scala.math
3+
4+
class RBM(val N: Int, val n_visible: Int, val n_hidden: Int,
5+
_W: Array[Array[Double]]=null, _hbias: Array[Double]=null, _vbias: Array[Double]=null,
6+
var rng: Random=null) {
7+
8+
var W: Array[Array[Double]] = Array.ofDim[Double](n_hidden, n_visible)
9+
var hbias: Array[Double] = new Array[Double](n_hidden)
10+
var vbias: Array[Double] = new Array[Double](n_visible)
11+
12+
13+
if(rng == null) rng = new Random(1234)
14+
15+
if(_W == null) {
16+
var i: Int = 0
17+
var j: Int = 0
18+
19+
val a: Double = 1 / n_visible
20+
for(i <- 0 until n_hidden)
21+
for(j <- 0 until n_visible)
22+
W(i)(j) = uniform(-a, a)
23+
24+
} else {
25+
W = _W
26+
}
27+
28+
if(_hbias == null) {
29+
var i: Int = 0
30+
for(i <- 0 until n_hidden) hbias(i) = 0
31+
} else {
32+
hbias = _hbias
33+
}
34+
35+
if(_vbias == null) {
36+
var i: Int = 0
37+
for(i <- 0 until n_visible) vbias(i) = 0
38+
} else {
39+
vbias = _vbias
40+
}
41+
42+
43+
def uniform(min: Double, max: Double): Double = rng.nextDouble() * (max - min) + min
44+
def binomial(n: Int, p: Double): Int = {
45+
if(p < 0 || p > 1) return 0
46+
47+
var c: Int = 0
48+
var r: Double = 0
49+
50+
var i: Int = 0
51+
for(i <- 0 until n) {
52+
r = rng.nextDouble()
53+
if(r < p) c += 1
54+
}
55+
56+
c
57+
}
58+
59+
def sigmoid(x: Double): Double = 1.0 / (1.0 + math.pow(math.E, -x))
60+
61+
62+
def contrastive_divergence(input: Array[Int], lr: Double, k: Int) {
63+
val ph_mean: Array[Double] = new Array[Double](n_hidden)
64+
val ph_sample: Array[Int] = new Array[Int](n_hidden)
65+
val nv_means: Array[Double] = new Array[Double](n_visible)
66+
val nv_samples: Array[Int] = new Array[Int](n_visible)
67+
val nh_means: Array[Double] = new Array[Double](n_hidden)
68+
val nh_samples: Array[Int] = new Array[Int](n_hidden)
69+
70+
/* CD-k */
71+
sample_h_given_v(input, ph_mean, ph_sample)
72+
73+
var step: Int = 0
74+
for(step <- 0 until k) {
75+
if(step == 0) {
76+
gibbs_hvh(ph_sample, nv_means, nv_samples, nh_means, nh_samples)
77+
} else {
78+
gibbs_hvh(nh_samples, nv_means, nv_samples, nh_means, nh_samples)
79+
}
80+
}
81+
82+
var i: Int = 0
83+
var j: Int = 0
84+
for(i <- 0 until n_hidden) {
85+
for(j <- 0 until n_visible) {
86+
W(i)(j) += lr * (ph_sample(i) * input(j) - nh_means(i) * nv_samples(j)) / N
87+
}
88+
hbias(i) += lr * (ph_sample(i) - nh_means(i)) / N
89+
}
90+
91+
for(i <- 0 until n_visible) {
92+
vbias(i) += lr * (input(i) - nv_samples(i)) / N
93+
}
94+
}
95+
96+
97+
def sample_h_given_v(v0_sample: Array[Int], mean: Array[Double], sample: Array[Int]) {
98+
var i: Int = 0
99+
for(i <- 0 until n_hidden) {
100+
mean(i) = propup(v0_sample, W(i), hbias(i))
101+
sample(i) = binomial(1, mean(i))
102+
}
103+
}
104+
105+
def sample_v_given_h(h0_sample: Array[Int], mean: Array[Double], sample: Array[Int]) {
106+
var i: Int = 0
107+
for(i <- 0 until n_visible) {
108+
mean(i) = propdown(h0_sample, i, vbias(i))
109+
sample(i) = binomial(1, mean(i))
110+
}
111+
}
112+
113+
def propup(v: Array[Int], w: Array[Double], b: Double): Double = {
114+
var pre_sigmoid_activation: Double = 0
115+
var j: Int = 0
116+
for(j <- 0 until n_visible) {
117+
pre_sigmoid_activation += w(j) * v(j)
118+
}
119+
pre_sigmoid_activation += b
120+
sigmoid(pre_sigmoid_activation)
121+
}
122+
123+
def propdown(h: Array[Int], i: Int, b: Double): Double = {
124+
var pre_sigmoid_activation: Double = 0
125+
var j: Int = 0
126+
for(j <- 0 until n_hidden) {
127+
pre_sigmoid_activation += W(j)(i) * h(j)
128+
}
129+
pre_sigmoid_activation += b
130+
sigmoid(pre_sigmoid_activation)
131+
}
132+
133+
def gibbs_hvh(h0_sample: Array[Int], nv_means: Array[Double], nv_samples: Array[Int], nh_means: Array[Double], nh_samples: Array[Int]) {
134+
sample_v_given_h(h0_sample, nv_means, nv_samples)
135+
sample_h_given_v(nv_samples, nh_means, nh_samples)
136+
}
137+
138+
139+
def reconstruct(v: Array[Int], reconstructed_v: Array[Double]) {
140+
val h: Array[Double] = new Array[Double](n_hidden)
141+
var pre_sigmoid_activation: Double = 0
142+
143+
var i: Int = 0
144+
var j: Int = 0
145+
146+
for(i <- 0 until n_hidden) {
147+
h(i) = propup(v, W(i), hbias(i))
148+
}
149+
150+
for(i <- 0 until n_visible) {
151+
pre_sigmoid_activation = 0
152+
for(j <- 0 until n_hidden) {
153+
pre_sigmoid_activation += W(j)(i) * h(j)
154+
}
155+
pre_sigmoid_activation += vbias(i)
156+
reconstructed_v(i) = sigmoid(pre_sigmoid_activation)
157+
}
158+
}
159+
}
160+
161+
162+
def test_rbm() {
163+
val rng: Random = new Random(123)
164+
165+
var learning_rate: Double = 0.1
166+
val training_epochs: Int = 1000
167+
val k: Int = 1
168+
169+
val train_N: Int = 6;
170+
val test_N: Int = 2
171+
val n_visible: Int = 6
172+
val n_hidden: Int = 3
173+
174+
val train_X: Array[Array[Int]] = Array(
175+
Array(1, 1, 1, 0, 0, 0),
176+
Array(1, 0, 1, 0, 0, 0),
177+
Array(1, 1, 1, 0, 0, 0),
178+
Array(0, 0, 1, 1, 1, 0),
179+
Array(0, 0, 1, 0, 1, 0),
180+
Array(0, 0, 1, 1, 1, 0)
181+
)
182+
183+
184+
val rbm: RBM = new RBM(train_N, n_visible, n_hidden, rng=rng)
185+
186+
var i: Int = 0
187+
var j: Int = 0
188+
189+
// train
190+
var epoch: Int = 0
191+
for(epoch <- 0 until training_epochs) {
192+
for(i <- 0 until train_N) {
193+
rbm.contrastive_divergence(train_X(i), learning_rate, k)
194+
}
195+
}
196+
197+
// test data
198+
val test_X: Array[Array[Int]] = Array(
199+
Array(1, 1, 0, 0, 0, 0),
200+
Array(0, 0, 0, 1, 1, 0)
201+
)
202+
203+
val reconstructed_X: Array[Array[Double]] = Array.ofDim[Double](test_N, n_visible)
204+
for(i <- 0 until test_N) {
205+
rbm.reconstruct(test_X(i), reconstructed_X(i))
206+
for(j <- 0 until n_visible) {
207+
printf("%.5f", reconstructed_X(i)(j))
208+
}
209+
println()
210+
}
211+
212+
}
213+
214+
test_rbm()
215+

0 commit comments

Comments
 (0)