Skip to content

Commit fd9ec7d

Browse files
author
bart2
committed
Pass train config as dict
1 parent 3ad821f commit fd9ec7d

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

timm/train.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from contextlib import suppress
2525
from datetime import datetime
2626
from functools import partial
27+
import typing as t
2728

2829
import torch
2930
import torch.nn as nn
@@ -395,19 +396,24 @@
395396
help='log training and validation metrics to wandb')
396397

397398

398-
def _parse_args(config_path: str | None = None):
399+
def _parse_args(config: dict[str, t.Any]):
399400
# Do we have a config file to parse?
400401
args_config, remaining = config_parser.parse_known_args()
401-
if args_config.config or config_path:
402-
config_path = config_path or args_config.config
403-
with open(config_path, 'r') as f:
402+
if args_config.config:
403+
with open(args_config.config, 'r') as f:
404404
cfg = yaml.safe_load(f)
405405
parser.set_defaults(**cfg)
406406

407407
# The main arg parser parses the rest of the args, the usual
408408
# defaults will have been overridden if config file specified.
409409
args = parser.parse_args(remaining)
410410

411+
# override args passed through config
412+
args_dict = vars(args)
413+
for k, v in config.items():
414+
assert k in args_dict, f"Parameter {k} not recognized"
415+
args.__setattr__(k, v)
416+
411417
# Cache the args as a text string to save them in the output dir later
412418
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
413419
return args, args_text
@@ -420,10 +426,9 @@ def _log_params(args):
420426
})
421427

422428

423-
424-
def train(config_path: str | None = None):
429+
def train(config: dict[str, t.Any]):
425430
utils.setup_default_logging()
426-
args, args_text = _parse_args(config_path)
431+
args, args_text = _parse_args(config)
427432

428433
if args.device_modules:
429434
for module in args.device_modules:
@@ -1199,4 +1204,4 @@ def validate(
11991204

12001205

12011206
if __name__ == '__main__':
1202-
train()
1207+
train(dict())

0 commit comments

Comments
 (0)