Skip to content

Commit d0582c2

Browse files
author
Yusuke Sugomori
committed
rbm constructor
1 parent fb8d50c commit d0582c2

File tree

4 files changed

+72
-46
lines changed

4 files changed

+72
-46
lines changed

c/RBM.c

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,24 +35,39 @@ double sigmoid(double x) {
3535
}
3636

3737

38-
void RBM__construct(RBM* this, int N, int n_visible, int n_hidden) {
38+
void RBM__construct(RBM* this, int N, int n_visible, int n_hidden, \
39+
double **W, double *hbias, double *vbias) {
3940
int i, j;
4041
double a = 1.0 / n_visible;
4142

4243
this->N = N;
4344
this->n_visible = n_visible;
4445
this->n_hidden = n_hidden;
4546

46-
this->W = (double **)malloc(sizeof(double*) * n_hidden);
47-
this->W[0] = (double *)malloc(sizeof(double) * n_visible * n_hidden);
48-
for(i=0; i<n_hidden; i++) this->W[i] = this->W[0] + i * n_visible;
49-
this->hbias = (double *)malloc(sizeof(double) * n_hidden);
50-
this->vbias = (double *)malloc(sizeof(double) * n_visible);
47+
if(W == NULL) {
48+
this->W = (double **)malloc(sizeof(double*) * n_hidden);
49+
this->W[0] = (double *)malloc(sizeof(double) * n_visible * n_hidden);
50+
for(i=0; i<n_hidden; i++) this->W[i] = this->W[0] + i * n_visible;
5151

52-
for(i=0; i<n_hidden; i++) {
53-
for(j=0; j<n_visible; j++) {
54-
this->W[i][j] = uniform(-a, a);
52+
for(i=0; i<n_hidden; i++) {
53+
for(j=0; j<n_visible; j++) {
54+
this->W[i][j] = uniform(-a, a);
55+
}
5556
}
57+
} else {
58+
this->W = W;
59+
}
60+
61+
if(hbias == NULL) {
62+
this->hbias = (double *)malloc(sizeof(double) * n_hidden);
63+
} else {
64+
this->hbias = hbias;
65+
}
66+
67+
if(vbias == NULL) {
68+
this->vbias = (double *)malloc(sizeof(double) * n_visible);
69+
} else {
70+
this->vbias = vbias;
5671
}
5772
}
5873

@@ -198,7 +213,7 @@ void test_rbm(void) {
198213

199214
// construct RBM
200215
RBM rbm;
201-
RBM__construct(&rbm, train_N, n_visible, n_hidden);
216+
RBM__construct(&rbm, train_N, n_visible, n_hidden, NULL, NULL, NULL);
202217

203218
// train
204219
for(epoch=0; epoch<training_epochs; epoch++) {

c/RBM.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ typedef struct {
1010
double *vbias;
1111
} RBM;
1212

13-
void RBM__construct(RBM*, int, int, int);
13+
void RBM__construct(RBM*, int, int, int, double**, double*, double*);
1414
void RBM__destruct(RBM*);
1515
void RBM_contrastive_divergence(RBM*, int*, double, int);
1616
void RBM_sample_h_given_v(RBM*, int*, double*, int*);

cpp/RBM.cpp

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,38 +3,11 @@
33
#include "RBM.h"
44
using namespace std;
55

6-
7-
RBM::RBM(int size, int n_v, int n_h) {
8-
N = size;
9-
n_visible = n_v;
10-
n_hidden = n_h;
11-
12-
W = new double*[n_hidden];
13-
for(int i=0; i<n_hidden; i++) W[i] = new double[n_visible];
14-
hbias = new double[n_hidden];
15-
vbias = new double[n_visible];
16-
17-
double a = 1.0 / n_visible;
18-
19-
for(int i=0; i<n_hidden; i++) {
20-
for(int j=0; j<n_visible; j++) {
21-
W[i][j] = uniform(-a, a);
22-
}
23-
}
24-
}
25-
26-
RBM::~RBM() {
27-
for(int i=0; i<n_hidden; i++) delete[] W[i];
28-
delete[] W;
29-
delete[] hbias;
30-
delete[] vbias;
31-
}
32-
33-
double RBM::uniform(double min, double max) {
6+
double uniform(double min, double max) {
347
return rand() / (RAND_MAX + 1.0) * (max - min) + min;
358
}
369

37-
int RBM::binomial(int n, double p) {
10+
int binomial(int n, double p) {
3811
if(p < 0 || p > 1) return 0;
3912

4013
int c = 0;
@@ -48,11 +21,51 @@ int RBM::binomial(int n, double p) {
4821
return c;
4922
}
5023

51-
double RBM::sigmoid(double x) {
24+
double sigmoid(double x) {
5225
return 1.0 / (1.0 + exp(-x));
5326
}
5427

5528

29+
RBM::RBM(int size, int n_v, int n_h, double **w, double *hb, double *vb) {
30+
N = size;
31+
n_visible = n_v;
32+
n_hidden = n_h;
33+
34+
if(w == NULL) {
35+
W = new double*[n_hidden];
36+
for(int i=0; i<n_hidden; i++) W[i] = new double[n_visible];
37+
double a = 1.0 / n_visible;
38+
39+
for(int i=0; i<n_hidden; i++) {
40+
for(int j=0; j<n_visible; j++) {
41+
W[i][j] = uniform(-a, a);
42+
}
43+
}
44+
} else {
45+
W = w;
46+
}
47+
48+
if(hb == NULL) {
49+
hbias = new double[n_hidden];
50+
} else {
51+
hbias = hb;
52+
}
53+
54+
if(vb == NULL) {
55+
vbias = new double[n_visible];
56+
} else {
57+
vbias = vb;
58+
}
59+
}
60+
61+
RBM::~RBM() {
62+
for(int i=0; i<n_hidden; i++) delete[] W[i];
63+
delete[] W;
64+
delete[] hbias;
65+
delete[] vbias;
66+
}
67+
68+
5669
void RBM::contrastive_divergence(int *input, double lr, int k) {
5770
double *ph_mean = new double[n_hidden];
5871
int *ph_sample = new int[n_hidden];
@@ -173,8 +186,9 @@ void test_rbm() {
173186
{0, 0, 1, 1, 1, 0}
174187
};
175188

189+
176190
// construct RBM
177-
RBM rbm(train_N, n_visible, n_hidden);
191+
RBM rbm(train_N, n_visible, n_hidden, NULL, NULL, NULL);
178192

179193
// train
180194
for(int epoch=0; epoch<training_epochs; epoch++) {

cpp/RBM.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,8 @@ class RBM {
77
double **W;
88
double *hbias;
99
double *vbias;
10-
RBM(int, int, int);
10+
RBM(int, int, int, double**, double*, double*);
1111
~RBM();
12-
double uniform(double, double);
13-
int binomial(int, double);
14-
double sigmoid(double);
1512
void contrastive_divergence(int*, double, int);
1613
void sample_h_given_v(int*, double*, int*);
1714
void sample_v_given_h(int*, double*, int*);

0 commit comments

Comments
 (0)