24
24
from contextlib import suppress
25
25
from datetime import datetime
26
26
from functools import partial
27
+ import typing as t
27
28
28
29
import torch
29
30
import torch .nn as nn
395
396
help = 'log training and validation metrics to wandb' )
396
397
397
398
398
- def _parse_args (config_path : str | None = None ):
399
+ def _parse_args (config : dict [ str , t . Any ] ):
399
400
# Do we have a config file to parse?
400
401
args_config , remaining = config_parser .parse_known_args ()
401
- if args_config .config or config_path :
402
- config_path = config_path or args_config .config
403
- with open (config_path , 'r' ) as f :
402
+ if args_config .config :
403
+ with open (args_config .config , 'r' ) as f :
404
404
cfg = yaml .safe_load (f )
405
405
parser .set_defaults (** cfg )
406
406
407
407
# The main arg parser parses the rest of the args, the usual
408
408
# defaults will have been overridden if config file specified.
409
409
args = parser .parse_args (remaining )
410
410
411
+ # override args passed through config
412
+ args_dict = vars (args )
413
+ for k , v in config .items ():
414
+ assert k in args_dict , f"Parameter { k } not recognized"
415
+ args .__setattr__ (k , v )
416
+
411
417
# Cache the args as a text string to save them in the output dir later
412
418
args_text = yaml .safe_dump (args .__dict__ , default_flow_style = False )
413
419
return args , args_text
@@ -420,10 +426,9 @@ def _log_params(args):
420
426
})
421
427
422
428
423
-
424
- def train (config_path : str | None = None ):
429
+ def train (config : dict [str , t .Any ]):
425
430
utils .setup_default_logging ()
426
- args , args_text = _parse_args (config_path )
431
+ args , args_text = _parse_args (config )
427
432
428
433
if args .device_modules :
429
434
for module in args .device_modules :
@@ -1199,4 +1204,4 @@ def validate(
1199
1204
1200
1205
1201
1206
if __name__ == '__main__' :
1202
- train ()
1207
+ train (dict () )
0 commit comments