Skip to content

Commit b23ef62

Browse files
committed
Merge pull request lisa-lab#17 from boulanni/master
Small fixes to RNN-RBM tutorial
2 parents 2eb32d0 + d72c34c commit b23ef62

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

code/rnnrbm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,9 +271,10 @@ def test_rnnrbm(batch_size=100, num_epochs=200):
271271
model = RnnRbm()
272272
model.train(glob.glob('../data/Nottingham/train/*.mid'),
273273
batch_size=batch_size, num_epochs=num_epochs)
274+
return model
274275

275276
if __name__ == '__main__':
276-
test_rnnrbm()
277+
model = test_rnnrbm()
277278
model.generate('sample1.mid')
278279
model.generate('sample2.mid')
279280
pylab.show()

doc/rnnrbm.txt

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Modeling and generating sequences of polyphonic music with the RNN-RBM
1111
and `restricted Boltzmann machines (RBM) <rbm.html>`_.
1212

1313
.. note::
14-
The code for this section is available for download here: `rnnrbm.py <../code/rnnrbm.py>`_.
14+
The code for this section is available for download here: `rnnrbm.py <code/rnnrbm.py>`_.
1515

1616
You will need the modified `Python MIDI package (GPL license) <http://www.iro.umontreal.ca/~lisa/deep/midi.zip>`_ in your ``$PYTHONPATH`` or in the working directory in order to convert MIDI files to and from piano-rolls.
1717
The script also assumes that the content of the `Nottingham Database of folk tunes <http://www.iro.umontreal.ca/~lisa/deep/data/Nottingham.zip>`_ has been extracted in the ``../data`` directory.
@@ -266,7 +266,7 @@ We now have all the necessary ingredients to start training our network on real
266266
n_hidden_recurrent)
267267

268268
gradient = T.grad(cost, params, consider_constant=[v_sample])
269-
updates_train.update(dict((p, p - lr * g) for p, g in zip(params,
269+
updates_train.update(((p, p - lr * g) for p, g in zip(params,
270270
gradient)))
271271
self.train_function = theano.function([v], monitor,
272272
updates=updates_train)
@@ -288,7 +288,10 @@ We now have all the necessary ingredients to start training our network on real
288288

289289
assert len(files) > 0, 'Training set is empty!' \
290290
' (did you download the data files?)'
291-
dataset = [midiread(f, self.r, self.dt).piano_roll for f in files]
291+
dataset = [midiread(f, self.r,
292+
self.dt).piano_roll.astype(theano.config.floatX)
293+
for f in files]
294+
292295
try:
293296
for epoch in xrange(num_epochs):
294297
numpy.random.shuffle(dataset)

0 commit comments

Comments
 (0)