Skip to content

Commit eafd9bd

Browse files
author
Yusuke Sugomori
committed
for scalac RBM.scala
1 parent 1a4c59d commit eafd9bd

File tree

4 files changed

+63
-43
lines changed

4 files changed

+63
-43
lines changed

java/DBN/src/DBN.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ public void predict(int[] x, double[] y) {
150150
log_layer.softmax(y);
151151
}
152152

153-
public static void main(String[] arg) {
153+
private static void test_dbn() {
154154
Random rng = new Random(123);
155155

156156
double pretrain_lr = 0.1;
@@ -215,4 +215,8 @@ public static void main(String[] arg) {
215215
System.out.println();
216216
}
217217
}
218+
219+
public static void main(String[] args) {
220+
test_dbn();
221+
}
218222
}

java/LogisticRegression/src/LogisticRegression.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ public void predict(int[] x, double[] y) {
6565
softmax(y);
6666
}
6767

68-
public static void main(String[] arg) {
68+
private static void test_lr() {
6969
double learning_rate = 0.1;
7070
double n_epochs = 500;
7171

@@ -121,4 +121,8 @@ public static void main(String[] arg) {
121121
System.out.println();
122122
}
123123
}
124+
125+
public static void main(String[] args) {
126+
test_lr();
127+
}
124128
}

java/RBM/src/RBM.java

+6-1
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ public void reconstruct(int[] v, double[] reconstructed_v) {
163163

164164

165165

166-
public static void main(String[] arg) {
166+
private static void test_rbm() {
167167
Random rng = new Random(123);
168168

169169
double learning_rate = 0.1;
@@ -212,4 +212,9 @@ public static void main(String[] arg) {
212212
System.out.println();
213213
}
214214
}
215+
216+
public static void main(String[] args) {
217+
test_rbm();
218+
}
219+
215220
}

scala/RBM.scala

+47-40
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
// $ scalac RBM.scala
2+
// $ scala RBM
3+
14
import scala.util.Random
25
import scala.math
36

@@ -159,57 +162,61 @@ class RBM(val N: Int, val n_visible: Int, val n_hidden: Int,
159162
}
160163

161164

162-
def test_rbm() {
163-
val rng: Random = new Random(123)
165+
object RBM {
166+
def test_rbm() {
167+
val rng: Random = new Random(123)
164168

165-
var learning_rate: Double = 0.1
166-
val training_epochs: Int = 1000
167-
val k: Int = 1
169+
var learning_rate: Double = 0.1
170+
val training_epochs: Int = 1000
171+
val k: Int = 1
168172

169-
val train_N: Int = 6;
170-
val test_N: Int = 2
171-
val n_visible: Int = 6
172-
val n_hidden: Int = 3
173+
val train_N: Int = 6;
174+
val test_N: Int = 2
175+
val n_visible: Int = 6
176+
val n_hidden: Int = 3
173177

174-
val train_X: Array[Array[Int]] = Array(
175-
Array(1, 1, 1, 0, 0, 0),
176-
Array(1, 0, 1, 0, 0, 0),
177-
Array(1, 1, 1, 0, 0, 0),
178-
Array(0, 0, 1, 1, 1, 0),
179-
Array(0, 0, 1, 0, 1, 0),
180-
Array(0, 0, 1, 1, 1, 0)
181-
)
178+
val train_X: Array[Array[Int]] = Array(
179+
Array(1, 1, 1, 0, 0, 0),
180+
Array(1, 0, 1, 0, 0, 0),
181+
Array(1, 1, 1, 0, 0, 0),
182+
Array(0, 0, 1, 1, 1, 0),
183+
Array(0, 0, 1, 0, 1, 0),
184+
Array(0, 0, 1, 1, 1, 0)
185+
)
182186

183187

184-
val rbm: RBM = new RBM(train_N, n_visible, n_hidden, rng=rng)
188+
val rbm: RBM = new RBM(train_N, n_visible, n_hidden, rng=rng)
185189

186-
var i: Int = 0
187-
var j: Int = 0
190+
var i: Int = 0
191+
var j: Int = 0
188192

189-
// train
190-
var epoch: Int = 0
191-
for(epoch <- 0 until training_epochs) {
192-
for(i <- 0 until train_N) {
193-
rbm.contrastive_divergence(train_X(i), learning_rate, k)
193+
// train
194+
var epoch: Int = 0
195+
for(epoch <- 0 until training_epochs) {
196+
for(i <- 0 until train_N) {
197+
rbm.contrastive_divergence(train_X(i), learning_rate, k)
198+
}
194199
}
195-
}
196200

197-
// test data
198-
val test_X: Array[Array[Int]] = Array(
199-
Array(1, 1, 0, 0, 0, 0),
200-
Array(0, 0, 0, 1, 1, 0)
201-
)
201+
// test data
202+
val test_X: Array[Array[Int]] = Array(
203+
Array(1, 1, 0, 0, 0, 0),
204+
Array(0, 0, 0, 1, 1, 0)
205+
)
202206

203-
val reconstructed_X: Array[Array[Double]] = Array.ofDim[Double](test_N, n_visible)
204-
for(i <- 0 until test_N) {
205-
rbm.reconstruct(test_X(i), reconstructed_X(i))
206-
for(j <- 0 until n_visible) {
207-
printf("%.5f", reconstructed_X(i)(j))
207+
val reconstructed_X: Array[Array[Double]] = Array.ofDim[Double](test_N, n_visible)
208+
for(i <- 0 until test_N) {
209+
rbm.reconstruct(test_X(i), reconstructed_X(i))
210+
for(j <- 0 until n_visible) {
211+
printf("%.5f ", reconstructed_X(i)(j))
212+
}
213+
println()
208214
}
209-
println()
210-
}
211215

212-
}
216+
}
213217

214-
test_rbm()
218+
def main(args: Array[String]) {
219+
test_rbm()
220+
}
215221

222+
}

0 commit comments

Comments
 (0)