Skip to content

Commit e50f1f3

Browse files
author
zysite
committed
Read configs from INI file
1 parent df9d88c commit e50f1f3

File tree

2 files changed

+50
-29
lines changed

2 files changed

+50
-29
lines changed

config.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,37 @@
11
# -*- coding: utf-8 -*-
22

3+
from ast import literal_eval
4+
from configparser import ConfigParser
5+
36

47
class Config(object):
58

6-
# [Network]
7-
n_embed = 100
8-
n_tag_embed = 100
9-
embed_dropout = 0.33
10-
n_lstm_hidden = 400
11-
n_lstm_layers = 3
12-
lstm_dropout = 0.33
13-
n_mlp_arc = 500
14-
n_mlp_rel = 100
15-
mlp_dropout = 0.33
16-
17-
# [Optimizer]
18-
lr = 2e-3
19-
beta_1 = 0.9
20-
beta_2 = 0.9
21-
epsilon = 1e-12
22-
decay = .75
23-
decay_steps = 5000
24-
25-
# [Run]
26-
batch_size = 200
27-
epochs = 1000
28-
patience = 100
9+
def __init__(self, fname):
10+
super(Config, self).__init__()
11+
12+
self.config = ConfigParser()
13+
self.config.read(fname)
14+
self.kwargs = dict((option, literal_eval(value))
15+
for section in self.config.sections()
16+
for option, value in self.config.items(section))
17+
18+
def __repr__(self):
19+
info = f"{self.__class__.__name__}:\n"
20+
for i, (option, value) in enumerate(self.kwargs.items()):
21+
info += f"{option:15} {value:<25}" + ('\n' if i % 2 > 0 else '')
22+
if i % 2 == 0:
23+
info += '\n'
24+
25+
return info
26+
27+
def __getattr__(self, attr):
28+
return self.kwargs.get(attr, None)
29+
30+
def __getstate__(self):
31+
return vars(self)
32+
33+
def __setstate__(self, state):
34+
self.__dict__.update(state)
35+
36+
def update(self, kwargs):
37+
self.kwargs.update(kwargs)

run.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,33 @@
66

77
import torch
88

9+
from config import Config
10+
11+
912
if __name__ == '__main__':
1013
parser = argparse.ArgumentParser(
1114
description='Create the Biaffine Parser model.'
1215
)
13-
subparsers = parser.add_subparsers(title='Commands')
16+
subparsers = parser.add_subparsers(title='Commands', dest='mode')
1417
subcommands = {
1518
'evaluate': Evaluate(),
1619
'predict': Predict(),
1720
'train': Train()
1821
}
1922
for name, subcommand in subcommands.items():
2023
subparser = subcommand.add_subparser(name, subparsers)
24+
subparser.add_argument('--conf', '-c', default='config.ini',
25+
help='path to config file')
26+
subparser.add_argument('--model', '-m', default='exp/ptb/model.tag',
27+
help='path to model file')
28+
subparser.add_argument('--vocab', '-v', default='exp/ptb/vocab.tag',
29+
help='path to vocab file')
2130
subparser.add_argument('--device', '-d', default='-1',
2231
help='ID of GPU to use')
2332
subparser.add_argument('--seed', '-s', default=1, type=int,
2433
help='seed for generating random numbers')
2534
subparser.add_argument('--threads', '-t', default=4, type=int,
2635
help='max num of threads')
27-
subparser.add_argument('--file', '-f', default='model.pt',
28-
help='path to model file')
29-
subparser.add_argument('--vocab', '-v', default='vocab.pt',
30-
help='path to vocabulary file')
3136
args = parser.parse_args()
3237

3338
print(f"Set the max num of threads to {args.threads}")
@@ -37,4 +42,11 @@
3742
torch.manual_seed(args.seed)
3843
os.environ['CUDA_VISIBLE_DEVICES'] = args.device
3944

40-
args.func(args)
45+
print(f"Override the default configs with parsed arguments")
46+
config = Config(args.conf)
47+
config.update(vars(args))
48+
print(config)
49+
50+
print(f"Run the subcommand in mode {args.mode}")
51+
cmd = subcommands[args.mode]
52+
cmd(config)

0 commit comments

Comments
 (0)