@@ -11,6 +11,22 @@ def get_logger(name):
11
11
return logging .getLogger (name )
12
12
13
13
14
+ class TqdmHandler (logging .StreamHandler ):
15
+
16
+ def __init__ (self , * args , ** kwargs ):
17
+ super ().__init__ (* args , ** kwargs )
18
+
19
+ def emit (self , record ):
20
+ try :
21
+ msg = self .format (record )
22
+ tqdm .write (msg )
23
+ self .flush ()
24
+ except (KeyboardInterrupt , SystemExit ):
25
+ raise
26
+ except Exception :
27
+ self .handleError (record )
28
+
29
+
14
30
def init_logger (logger ,
15
31
path = None ,
16
32
mode = 'w' ,
@@ -19,9 +35,9 @@ def init_logger(logger,
19
35
verbose = True ):
20
36
level = level or logging .WARNING
21
37
if not handlers :
22
- handlers = [logging . StreamHandler ()]
38
+ handlers = [TqdmHandler ()]
23
39
if path :
24
- os .makedirs (os .path .dirname (path ), exist_ok = True )
40
+ os .makedirs (os .path .dirname (path ) or './' , exist_ok = True )
25
41
handlers .append (logging .FileHandler (path , mode ))
26
42
logging .basicConfig (format = '%(asctime)s %(levelname)s %(message)s' ,
27
43
datefmt = '%Y-%m-%d %H:%M:%S' ,
@@ -33,13 +49,15 @@ def init_logger(logger,
33
49
def progress_bar (iterator ,
34
50
ncols = None ,
35
51
bar_format = '{l_bar}{bar:18}| {n_fmt}/{total_fmt} {elapsed}<{remaining}, {rate_fmt}{postfix}' ,
36
- leave = True ):
52
+ leave = False ,
53
+ ** kwargs ):
37
54
return tqdm (iterator ,
38
55
ncols = ncols ,
39
56
bar_format = bar_format ,
40
57
ascii = True ,
41
58
disable = (not (logger .level == logging .INFO and is_master ())),
42
- leave = leave )
59
+ leave = leave ,
60
+ ** kwargs )
43
61
44
62
45
63
logger = get_logger ('supar' )
0 commit comments