Skip to content

Commit 4927055

Browse files
author
Yusuke Sugomori
committed
bug fix
1 parent 0d32eaa commit 4927055

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

cpp/RBM.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,13 @@ void RBM::contrastive_divergence(int *input, double lr, int k) {
7474

7575
for(int i=0; i<n_hidden; i++) {
7676
for(int j=0; j<n_visible; j++) {
77-
W[i][j] += lr * (ph_sample[i] * input[j] - nh_means[i] * nv_samples[j]);
77+
W[i][j] += lr * (ph_sample[i] * input[j] - nh_means[i] * nv_samples[j]) / N;
7878
}
79-
hbias[i] += lr * (ph_sample[i] - nh_means[i]);
79+
hbias[i] += lr * (ph_sample[i] - nh_means[i]) / N;
8080
}
8181

8282
for(int i=0; i<n_visible; i++) {
83-
vbias[i] += lr * (input[i] - nv_samples[i]);
83+
vbias[i] += lr * (input[i] - nv_samples[i]) / N;
8484
}
8585

8686
delete[] ph_mean;
@@ -155,7 +155,7 @@ void test_rbm() {
155155
srand(0);
156156

157157
double learning_rate = 0.1;
158-
int training_epochs = 100;
158+
int training_epochs = 1000;
159159
int k = 1;
160160

161161
int train_N = 6;

0 commit comments

Comments
 (0)