Skip to content

Commit 401a99a

Browse files
committed
Add a way to reload pretrained model
1 parent 3d9b1ac commit 401a99a

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

code/lstm.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ def _p(pp, name):
8080

8181

8282
def init_params(options):
83+
"""
84+
Global (not LSTM) parameter. For the embeding and the classifier.
85+
"""
8386
params = OrderedDict()
8487
# embedding
8588
randn = numpy.random.rand(options['n_words'],
@@ -125,6 +128,11 @@ def ortho_weight(ndim):
125128

126129

127130
def param_init_lstm(options, params, prefix='lstm'):
131+
"""
132+
Init the LSTM parameter:
133+
134+
:see: init_params
135+
"""
128136
W = numpy.concatenate([ortho_weight(options['dim_proj']),
129137
ortho_weight(options['dim_proj']),
130138
ortho_weight(options['dim_proj']),
@@ -388,6 +396,7 @@ def test_lstm(
388396
noise_std=0.,
389397
use_dropout=True, # if False slightly faster, but worst test error
390398
# This frequently need a bigger model.
399+
reload_model="", # Path to a saved model we want to start from.
391400
):
392401

393402
# Model options
@@ -407,6 +416,9 @@ def test_lstm(
407416
# Dict name (string) -> numpy ndarray
408417
params = init_params(model_options)
409418

419+
if reload_model:
420+
load_params('lstm_model.npz', params)
421+
410422
# This create Theano Shared Variable from the parameters.
411423
# Dict name (string) -> Theano Tensor Shared Variable
412424
# params and tparams have different copy of the weights.
@@ -561,4 +573,7 @@ def test_lstm(
561573
theano.config.scan.allow_gc = False
562574

563575
# See function train for all possible parameter and there definition.
564-
test_lstm(max_epochs=10)
576+
test_lstm(
577+
#reload_model="lstm_model.npz",
578+
max_epochs=10,
579+
)

0 commit comments

Comments
 (0)