Skip to content

Commit aec1fec

Browse files
author
Taylor Robie
authored
Fix/ncf eval default (tensorflow#5438)
* improve default handling for eval_batch_size * return eval_batch_size default to None * fix syntax error
1 parent 505cad9 commit aec1fec

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

official/recommendation/ncf_main.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,9 @@ def run_ncf(_):
128128
batch_size = distribution_utils.per_device_batch_size(
129129
int(FLAGS.batch_size), num_gpus)
130130

131-
eval_batch_size = int(FLAGS.eval_batch_size or FLAGS.batch_size)
132131
eval_per_user = rconst.NUM_EVAL_NEGATIVES + 1
132+
eval_batch_size = int(FLAGS.eval_batch_size or
133+
max([FLAGS.batch_size, eval_per_user]))
133134
if eval_batch_size % eval_per_user:
134135
eval_batch_size = eval_batch_size // eval_per_user * eval_per_user
135136
tf.logging.warning(
@@ -365,7 +366,8 @@ def define_ncf_flags():
365366
@flags.validator("eval_batch_size", "eval_batch_size must be at least {}"
366367
.format(rconst.NUM_EVAL_NEGATIVES + 1))
367368
def eval_size_check(eval_batch_size):
368-
return int(eval_batch_size) > rconst.NUM_EVAL_NEGATIVES
369+
return (eval_batch_size is None or
370+
int(eval_batch_size) > rconst.NUM_EVAL_NEGATIVES)
369371

370372

371373
if __name__ == "__main__":

0 commit comments

Comments
 (0)