Skip to content

Commit 5b420c4

Browse files
author
Yusuke Sugomori
committed
buf fix: sampling
1 parent 031e6aa commit 5b420c4

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

DeepBeliefNets.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(self, input=None, label=None,\
5555
if i == 0:
5656
layer_input = self.x
5757
else:
58-
layer_input = self.sigmoid_layers[-1].output()
58+
layer_input = self.sigmoid_layers[-1].sample_h_given_v()
5959

6060
# construct sigmoid_layer
6161
sigmoid_layer = HiddenLayer(input=layer_input,
@@ -76,7 +76,7 @@ def __init__(self, input=None, label=None,\
7676

7777

7878
# layer for output using Logistic Regression
79-
self.log_layer = LogisticRegression(input=self.sigmoid_layers[-1].output(),
79+
self.log_layer = LogisticRegression(input=self.sigmoid_layers[-1].sample_h_given_v(),
8080
label=self.y,
8181
n_in=hidden_layer_sizes[-1],
8282
n_out=n_outs)
@@ -92,9 +92,8 @@ def pretrain(self, lr=0.1, k=1, epochs=100):
9292
if i == 0:
9393
layer_input = self.x
9494
else:
95-
layer_input = self.sigmoid_layers[i-1].output()
95+
layer_input = self.sigmoid_layers[i-1].sample_h_given_v(layer_input)
9696
rbm = self.rbm_layers[i]
97-
9897

9998
for epoch in xrange(epochs):
10099
c = []
@@ -108,7 +107,7 @@ def pretrain(self, lr=0.1, k=1, epochs=100):
108107

109108

110109
def finetune(self, lr=0.1, epochs=100):
111-
layer_input = self.sigmoid_layers[-1].output()
110+
layer_input = self.sigmoid_layers[-1].sample_h_given_v()
112111

113112
# train log_layer
114113
epoch = 0

HiddenLayer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,12 @@ def output(self, input=None):
5656
else self.activation(linear_output))
5757

5858

59+
def sample_h_given_v(self, input=None):
60+
if input is not None:
61+
self.input = input
62+
63+
v_mean = self.output()
64+
h_sample = self.numpy_rng.binomial(size=v_mean.shape,
65+
n=1,
66+
p=v_mean)
67+
return h_sample

0 commit comments

Comments
 (0)