Skip to content

Commit 032bc2f

Browse files
committed
Support loading config from source
1 parent f410ffd commit 032bc2f

File tree

1 file changed

+28
-8
lines changed

1 file changed

+28
-8
lines changed

supar/utils/config.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
# -*- coding: utf-8 -*-
22

3+
import argparse
4+
import os
35
from ast import literal_eval
46
from configparser import ConfigParser
57

8+
import supar
9+
from supar.utils.fn import download
10+
611

712
class Config(object):
813

9-
def __init__(self, conf=None, **kwargs):
14+
def __init__(self, **kwargs):
1015
super(Config, self).__init__()
1116

12-
config = ConfigParser()
13-
config.read(conf or [])
14-
self.update({**dict((name, literal_eval(value))
15-
for section in config.sections()
16-
for name, value in config.items(section)),
17-
**kwargs})
17+
self.update(kwargs)
1818

1919
def __repr__(self):
2020
s = line = "-" * 20 + "-+-" + "-" * 30 + "\n"
@@ -28,6 +28,9 @@ def __repr__(self):
2828
def __getitem__(self, key):
2929
return getattr(self, key)
3030

31+
def __contains__(self, key):
32+
return hasattr(self, key)
33+
3134
def __getstate__(self):
3235
return vars(self)
3336

@@ -46,8 +49,25 @@ def update(self, kwargs):
4649
kwargs.update(kwargs.pop('kwargs', dict()))
4750
for name, value in kwargs.items():
4851
setattr(self, name, value)
49-
5052
return self
5153

54+
def get(self, key, default=None):
55+
return getattr(self, key) if hasattr(self, key) else default
56+
5257
def pop(self, key, val=None):
5358
return self.__dict__.pop(key, val)
59+
60+
@classmethod
61+
def load(cls, conf='', unknown=None, **kwargs):
62+
config = ConfigParser()
63+
config.read(conf if not conf or os.path.exists(conf) else download(supar.CONFIG.get(conf, conf)))
64+
config = dict((name, literal_eval(value))
65+
for section in config.sections()
66+
for name, value in config.items(section))
67+
if unknown is not None:
68+
parser = argparse.ArgumentParser()
69+
for name, value in config.items():
70+
parser.add_argument('--'+name.replace('_', '-'), type=type(value), default=value)
71+
config.update(vars(parser.parse_args(unknown)))
72+
config.update(kwargs)
73+
return cls(**config)

0 commit comments

Comments
 (0)