Skip to content

Commit 6d6b412

Browse files
author
Yusuke Sugomori
committed
dA.java
1 parent eafd9bd commit 6d6b412

File tree

3 files changed

+214
-7
lines changed

3 files changed

+214
-7
lines changed

java/DBN/src/RBM.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ public class RBM {
44
public int N;
55
public int n_visible;
66
public int n_hidden;
7-
double[][] W;
8-
double[] hbias;
9-
double[] vbias;
7+
public double[][] W;
8+
public double[] hbias;
9+
public double[] vbias;
1010
public Random rng;
1111

1212
public double uniform(double min, double max) {

java/RBM/src/RBM.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ public class RBM {
44
public int N;
55
public int n_visible;
66
public int n_hidden;
7-
double[][] W;
8-
double[] hbias;
9-
double[] vbias;
7+
public double[][] W;
8+
public double[] hbias;
9+
public double[] vbias;
1010
public Random rng;
1111

1212
public double uniform(double min, double max) {
@@ -207,7 +207,7 @@ private static void test_rbm() {
207207
for(int i=0; i<test_N; i++) {
208208
rbm.reconstruct(test_X[i], reconstructed_X[i]);
209209
for(int j=0; j<n_visible; j++) {
210-
System.out.printf("%.5f", reconstructed_X[i][j]);
210+
System.out.printf("%.5f ", reconstructed_X[i][j]);
211211
}
212212
System.out.println();
213213
}

java/dA/src/dA.java

+207
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
import java.util.Random;
2+
3+
public class dA {
4+
public int N;
5+
public int n_visible;
6+
public int n_hidden;
7+
public double[][] W;
8+
public double[] hbias;
9+
public double[] vbias;
10+
public Random rng;
11+
12+
13+
public double uniform(double min, double max) {
14+
return rng.nextDouble() * (max - min) + min;
15+
}
16+
17+
public int binomial(int n, double p) {
18+
if(p < 0 || p > 1) return 0;
19+
20+
int c = 0;
21+
double r;
22+
23+
for(int i=0; i<n; i++) {
24+
r = rng.nextDouble();
25+
if (r < p) c++;
26+
}
27+
28+
return c;
29+
}
30+
31+
public static double sigmoid(double x) {
32+
return 1.0 / (1.0 + Math.pow(Math.E, -x));
33+
}
34+
35+
public dA(int N, int n_visible, int n_hidden,
36+
double[][] W, double[] hbias, double[] vbias, Random rng) {
37+
this.N = N;
38+
this.n_visible = n_visible;
39+
this.n_hidden = n_hidden;
40+
41+
if(rng == null) this.rng = new Random(1234);
42+
else this.rng = rng;
43+
44+
if(W == null) {
45+
this.W = new double[this.n_hidden][this.n_visible];
46+
double a = 1.0 / this.n_visible;
47+
48+
for(int i=0; i<this.n_hidden; i++) {
49+
for(int j=0; j<this.n_visible; j++) {
50+
this.W[i][j] = uniform(-a, a);
51+
}
52+
}
53+
} else {
54+
this.W = W;
55+
}
56+
57+
if(hbias == null) {
58+
this.hbias = new double[this.n_hidden];
59+
for(int i=0; i<this.n_hidden; i++) this.hbias[i] = 0;
60+
} else {
61+
this.hbias = hbias;
62+
}
63+
64+
if(vbias == null) {
65+
this.vbias = new double[this.n_visible];
66+
for(int i=0; i<this.n_visible; i++) this.vbias[i] = 0;
67+
} else {
68+
this.vbias = vbias;
69+
}
70+
}
71+
72+
public void get_corrupted_input(int[] x, int[] tilde_x, double p) {
73+
for(int i=0; i<n_visible; i++) {
74+
if(x[i] == 0) {
75+
tilde_x[i] = 0;
76+
} else {
77+
tilde_x[i] = binomial(1, p);
78+
}
79+
}
80+
}
81+
82+
// Encode
83+
public void get_hidden_values(int[] x, double[] y) {
84+
for(int i=0; i<n_hidden; i++) {
85+
y[i] = 0;
86+
for(int j=0; j<n_visible; j++) {
87+
y[i] += W[i][j] * x[j];
88+
}
89+
y[i] += hbias[i];
90+
y[i] = sigmoid(y[i]);
91+
}
92+
}
93+
94+
// Decode
95+
public void get_reconstructed_input(double[] y, double[] z) {
96+
for(int i=0; i<n_visible; i++) {
97+
z[i] = 0;
98+
for(int j=0; j<n_hidden; j++) {
99+
z[i] += W[j][i] * y[j];
100+
}
101+
z[i] += vbias[i];
102+
z[i] = sigmoid(z[i]);
103+
}
104+
}
105+
106+
public void train(int[] x, double lr, double corruption_level) {
107+
int[] tilde_x = new int[n_visible];
108+
double[] y = new double[n_hidden];
109+
double[] z = new double[n_visible];
110+
111+
double[] L_vbias = new double[n_visible];
112+
double[] L_hbias = new double[n_hidden];
113+
114+
double p = 1 - corruption_level;
115+
116+
get_corrupted_input(x, tilde_x, p);
117+
get_hidden_values(tilde_x, y);
118+
get_reconstructed_input(y, z);
119+
120+
// vbias
121+
for(int i=0; i<n_visible; i++) {
122+
L_vbias[i] = x[i] - z[i];
123+
vbias[i] += lr * L_vbias[i] / N;
124+
}
125+
126+
// hbias
127+
for(int i=0; i<n_hidden; i++) {
128+
L_hbias[i] = 0;
129+
for(int j=0; j<n_visible; j++) {
130+
L_hbias[i] += W[i][j] * L_vbias[j];
131+
}
132+
L_hbias[i] *= y[i] * (1 - y[i]);
133+
hbias[i] += lr * L_hbias[i] / N;
134+
}
135+
136+
// W
137+
for(int i=0; i<n_hidden; i++) {
138+
for(int j=0; j<n_visible; j++) {
139+
W[i][j] += lr * (L_hbias[i] * tilde_x[j] + L_vbias[j] * y[i]) / N;
140+
}
141+
}
142+
}
143+
144+
public void reconstruct(int[] x, double[] z) {
145+
double[] y = new double[n_hidden];
146+
147+
get_hidden_values(x, y);
148+
get_reconstructed_input(y, z);
149+
}
150+
151+
152+
private static void test_dA() {
153+
Random rng = new Random(123);
154+
155+
double learning_rate = 0.1;
156+
double corruption_level = 0.3;
157+
int training_epochs = 100;
158+
159+
int train_N = 10;
160+
int test_N = 2;
161+
int n_visible = 20;
162+
int n_hidden = 5;
163+
164+
int[][] train_X = {
165+
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
166+
{1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
167+
{1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
168+
{1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
169+
{0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
170+
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
171+
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1},
172+
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1},
173+
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1},
174+
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0}
175+
};
176+
177+
dA da = new dA(train_N, n_visible, n_hidden, null, null, null, rng);
178+
179+
// train
180+
for(int epoch=0; epoch<training_epochs; epoch++) {
181+
for(int i=0; i<train_N; i++) {
182+
da.train(train_X[i], learning_rate, corruption_level);
183+
}
184+
}
185+
186+
// test data
187+
int[][] test_X = {
188+
{1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
189+
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0}
190+
};
191+
192+
double[][] reconstructed_X = new double[test_N][n_visible];
193+
194+
// test
195+
for(int i=0; i<test_N; i++) {
196+
da.reconstruct(test_X[i], reconstructed_X[i]);
197+
for(int j=0; j<n_visible; j++) {
198+
System.out.printf("%.5f ", reconstructed_X[i][j]);
199+
}
200+
System.out.println();
201+
}
202+
}
203+
204+
public static void main(String[] args) {
205+
test_dA();
206+
}
207+
}

0 commit comments

Comments
 (0)