File tree 1 file changed +4
-2
lines changed
1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change @@ -128,8 +128,9 @@ def run_ncf(_):
128
128
batch_size = distribution_utils .per_device_batch_size (
129
129
int (FLAGS .batch_size ), num_gpus )
130
130
131
- eval_batch_size = int (FLAGS .eval_batch_size or FLAGS .batch_size )
132
131
eval_per_user = rconst .NUM_EVAL_NEGATIVES + 1
132
+ eval_batch_size = int (FLAGS .eval_batch_size or
133
+ max ([FLAGS .batch_size , eval_per_user ]))
133
134
if eval_batch_size % eval_per_user :
134
135
eval_batch_size = eval_batch_size // eval_per_user * eval_per_user
135
136
tf .logging .warning (
@@ -365,7 +366,8 @@ def define_ncf_flags():
365
366
@flags .validator ("eval_batch_size" , "eval_batch_size must be at least {}"
366
367
.format (rconst .NUM_EVAL_NEGATIVES + 1 ))
367
368
def eval_size_check (eval_batch_size ):
368
- return int (eval_batch_size ) > rconst .NUM_EVAL_NEGATIVES
369
+ return (eval_batch_size is None or
370
+ int (eval_batch_size ) > rconst .NUM_EVAL_NEGATIVES )
369
371
370
372
371
373
if __name__ == "__main__" :
You can’t perform that action at this time.
0 commit comments