Commit 7d4b17e

refactor
1 parent 5803311 commit 7d4b17e

4 files changed: +54 −32 lines


neural_ner/config.py

Lines changed: 13 additions & 2 deletions
@@ -1,5 +1,11 @@
 import os
 import torch
+import numpy as np
+
+np.random.seed(123)
+torch.manual_seed(123)
+if torch.cuda.is_available():
+    torch.cuda.manual_seed_all(123)
 
 class Config(object):
     pass
@@ -18,7 +24,7 @@ class Config(object):
 config.dropout_ratio = 0.5
 
 config.max_grad_norm = 5.0
-config.batch_size = 32
+config.batch_size = 1
 config.num_epochs = 100
 
 config.print_every = 100
@@ -45,4 +51,9 @@ class Config(object):
 
 config.is_l2_loss = False
 
-config.is_cuda = config.is_cuda and torch.cuda.is_available()
+config.model_name = 'model.NER_SOFTMAX_CHAR'
+config.optimizer = 'sgd'
+
+# config postprocess
+config.is_cuda = config.is_cuda and torch.cuda.is_available()
+

neural_ner/model.py

Lines changed: 0 additions & 20 deletions
@@ -15,26 +15,6 @@
 
 logging.basicConfig(level=logging.INFO)
 
-np.random.seed(123)
-torch.manual_seed(123)
-if torch.cuda.is_available():
-    torch.cuda.manual_seed_all(123)
-
-def get_model(vocab, config, model_file_path, is_eval=False):
-    #model = NER_SOFTMAX_CHAR_CRF(vocab, config)
-    model = NER_SOFTMAX_CHAR(vocab, config)
-
-    if is_eval:
-        model = model.eval()
-    if config.is_cuda:
-        model = model.cuda()
-
-    if model_file_path is not None:
-        state = torch.load(model_file_path, map_location=lambda storage, location: storage)
-        model.load_state_dict(state['model'], strict=False)
-
-    return model
-
 class NER_SOFTMAX_CHAR(nn.Module):
     def __init__(self, vocab, config):
         super(NER_SOFTMAX_CHAR, self).__init__()

neural_ner/model_utils.py

Lines changed: 39 additions & 2 deletions
@@ -1,7 +1,44 @@
-from __future__ import unicode_literals, print_function, division
-
 import torch
 import numpy as np
+import importlib
+import logging
+from torch.optim import Adam, SGD
+
+logging.basicConfig(level=logging.INFO)
+
+def get_model(vocab, config, model_file_path, is_eval=False):
+    class_path = config.model_name.split('.')
+    class_name = class_path[-1]
+    class_module = '.'.join(class_path[:-1])
+
+    ModelClass = getattr(importlib.import_module(class_module), class_name)
+
+    model = ModelClass(vocab, config)
+
+    if is_eval:
+        model = model.eval()
+    if config.is_cuda:
+        model = model.cuda()
+
+    if model_file_path is not None:
+        state = torch.load(model_file_path, map_location=lambda storage, location: storage)
+        model.load_state_dict(state['model'], strict=False)
+
+    return model
+
+def get_optimizer(model, config):
+    params = list(filter(lambda p: p.requires_grad, model.parameters()))
+
+    optimizer = None
+    if config.optimizer == 'adam':
+        optimizer = Adam(params, amsgrad=True)
+    elif config.optimizer == 'sgd':
+        optimizer = SGD(params, lr=0.01, momentum=0.9)
+
+    num_params = sum(p.numel() for p in params)
+    logging.info("Number of params: %d" % num_params)
+
+    return optimizer, params
 
 def get_mask(lengths, config):
     seq_lens = lengths.view(-1, 1)
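
Note (not part of the commit): the key change above is that get_model no longer hard-codes NER_SOFTMAX_CHAR; it resolves the dotted path in config.model_name via importlib. A minimal standalone sketch of that lookup, using 'collections.OrderedDict' as a stand-in for 'model.NER_SOFTMAX_CHAR' so it runs outside the repo:

import importlib

model_name = 'collections.OrderedDict'      # stand-in for config.model_name
class_path = model_name.split('.')
class_name = class_path[-1]                 # 'OrderedDict'
class_module = '.'.join(class_path[:-1])    # 'collections'

# the same lookup get_model performs before calling ModelClass(vocab, config)
ModelClass = getattr(importlib.import_module(class_module), class_name)
print(ModelClass())                         # OrderedDict()

With this in place, switching to another model class (for example the previously commented-out NER_SOFTMAX_CHAR_CRF) should only require a different config.model_name string rather than an edit to the loading code.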

neural_ner/process_training.py

Lines changed: 2 additions & 8 deletions
@@ -5,11 +5,10 @@
 
 import torch
 from torch.nn.utils import clip_grad_norm_
-from torch.optim import Adam, SGD
 
 from data_utils.batcher import DatasetConll2003
 from data_utils.vocab import Vocab
-from model import get_model
+from model_utils import get_model, get_optimizer
 from train_utils import setup_train_dir, save_model, write_summary, \
     get_param_norm, get_grad_norm, Evaluter
 
@@ -49,12 +48,7 @@ def train_one_batch(self, batch, optimizer, params):
 
     def train(self):
         train_dir, summary_writer = setup_train_dir(self.config)
-
-        params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
-        optimizer = Adam(params, amsgrad=True)
-
-        num_params = sum(p.numel() for p in params)
-        logging.info("Number of params: %d" % num_params)
+        optimizer, params = get_optimizer(self.model, self.config)
 
         exp_loss, best_dev_f1 = None, None
 
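
Note (not part of the commit): train() now delegates optimizer construction to the new get_optimizer helper, driven by config.optimizer. A rough, self-contained sketch of that dispatch, using a toy nn.Linear in place of the NER model:

import torch.nn as nn
from torch.optim import Adam, SGD

model = nn.Linear(4, 2)                     # toy stand-in for the NER model
params = [p for p in model.parameters() if p.requires_grad]

optimizer_name = 'sgd'                      # mirrors config.optimizer = 'sgd'
if optimizer_name == 'adam':
    optimizer = Adam(params, amsgrad=True)
else:
    optimizer = SGD(params, lr=0.01, momentum=0.9)

print(sum(p.numel() for p in params))       # 10 trainable parameters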
