Skip to content

Commit c0a380d

Browse files
authored
Update the wide_deep and transformer code for latest benchmark config. (tensorflow#4246)
* Update the wide_deep code for latest benchmark config. * Also update the transformer benchmark code.
1 parent b9ca525 commit c0a380d

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

official/transformer/transformer_main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ def run_transformer(flags_obj):
433433
tensors_to_log=TENSORS_TO_LOG, # used for logging hooks
434434
batch_size=params.batch_size # for ExamplesPerSecondHook
435435
)
436-
benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir)
436+
benchmark_logger = logger.config_benchmark_logger(flags_obj)
437437
benchmark_logger.log_run_info(
438438
model_name="transformer",
439439
dataset_name="wmt_translate_ende",

official/wide_deep/wide_deep.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def eval_input_fn():
245245
'model_type': flags_obj.model_type,
246246
}
247247

248-
benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir)
248+
benchmark_logger = logger.config_benchmark_logger(flags_obj)
249249
benchmark_logger.log_run_info('wide_deep', 'Census Income', run_params)
250250

251251
loss_prefix = LOSS_PREFIX.get(flags_obj.model_type, '')

0 commit comments

Comments
 (0)