|
6 | 6 |
|
7 | 7 | import torch
|
8 | 8 |
|
| 9 | +from config import Config |
| 10 | + |
| 11 | + |
9 | 12 | if __name__ == '__main__':
|
10 | 13 | parser = argparse.ArgumentParser(
|
11 | 14 | description='Create the Biaffine Parser model.'
|
12 | 15 | )
|
13 |
| - subparsers = parser.add_subparsers(title='Commands') |
| 16 | + subparsers = parser.add_subparsers(title='Commands', dest='mode') |
14 | 17 | subcommands = {
|
15 | 18 | 'evaluate': Evaluate(),
|
16 | 19 | 'predict': Predict(),
|
17 | 20 | 'train': Train()
|
18 | 21 | }
|
19 | 22 | for name, subcommand in subcommands.items():
|
20 | 23 | subparser = subcommand.add_subparser(name, subparsers)
|
| 24 | + subparser.add_argument('--conf', '-c', default='config.ini', |
| 25 | + help='path to config file') |
| 26 | + subparser.add_argument('--model', '-m', default='exp/ptb/model.tag', |
| 27 | + help='path to model file') |
| 28 | + subparser.add_argument('--vocab', '-v', default='exp/ptb/vocab.tag', |
| 29 | + help='path to vocab file') |
21 | 30 | subparser.add_argument('--device', '-d', default='-1',
|
22 | 31 | help='ID of GPU to use')
|
23 | 32 | subparser.add_argument('--seed', '-s', default=1, type=int,
|
24 | 33 | help='seed for generating random numbers')
|
25 | 34 | subparser.add_argument('--threads', '-t', default=4, type=int,
|
26 | 35 | help='max num of threads')
|
27 |
| - subparser.add_argument('--file', '-f', default='model.pt', |
28 |
| - help='path to model file') |
29 |
| - subparser.add_argument('--vocab', '-v', default='vocab.pt', |
30 |
| - help='path to vocabulary file') |
31 | 36 | args = parser.parse_args()
|
32 | 37 |
|
33 | 38 | print(f"Set the max num of threads to {args.threads}")
|
|
37 | 42 | torch.manual_seed(args.seed)
|
38 | 43 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device
|
39 | 44 |
|
40 |
| - args.func(args) |
| 45 | + print(f"Override the default configs with parsed arguments") |
| 46 | + config = Config(args.conf) |
| 47 | + config.update(vars(args)) |
| 48 | + print(config) |
| 49 | + |
| 50 | + print(f"Run the subcommand in mode {args.mode}") |
| 51 | + cmd = subcommands[args.mode] |
| 52 | + cmd(config) |
0 commit comments