
Commit c88fcb2

Use flagfile to pass flags to data async generation process.
1 parent d4ac494 commit c88fcb2

File tree

2 files changed: +49 -91 lines changed


official/recommendation/data_async_generation.py (+13 -32)
@@ -51,7 +51,7 @@
 
 def log_msg(msg):
   """Include timestamp info when logging messages to a file."""
-  if flags.FLAGS.use_command_file:
+  if flags.FLAGS.use_tf_logging:
     tf.logging.info(msg)
     return
 
@@ -440,44 +440,26 @@ def remove_alive_file():
   gc.collect()
 
 
-def _set_flags_with_command_file():
-  """Use arguments from COMMAND_FILE when use_command_file is True."""
-  command_file = os.path.join(flags.FLAGS.data_dir,
-                              rconst.COMMAND_FILE)
-  tf.logging.info("Waiting for command file to appear at {}..."
-                  .format(command_file))
-  while not tf.gfile.Exists(command_file):
+def _parse_flagfile():
+  """Fill flags with flagfile."""
+  flagfile = os.path.join(flags.FLAGS.data_dir,
+                          rconst.FLAGFILE)
+  tf.logging.info("Waiting for flagfile to appear at {}..."
+                  .format(flagfile))
+  while not tf.gfile.Exists(flagfile):
     time.sleep(1)
-  tf.logging.info("Command file found.")
-  with tf.gfile.Open(command_file, "r") as f:
-    command = json.load(f)
-  flags.FLAGS.num_workers = command["num_workers"]
-  assert flags.FLAGS.data_dir == command["data_dir"]
-  flags.FLAGS.cache_id = command["cache_id"]
-  flags.FLAGS.num_readers = command["num_readers"]
-  flags.FLAGS.num_neg = command["num_neg"]
-  flags.FLAGS.num_train_positives = command["num_train_positives"]
-  flags.FLAGS.num_items = command["num_items"]
-  flags.FLAGS.epochs_per_cycle = command["epochs_per_cycle"]
-  flags.FLAGS.train_batch_size = command["train_batch_size"]
-  flags.FLAGS.eval_batch_size = command["eval_batch_size"]
-  flags.FLAGS.spillover = command["spillover"]
-  flags.FLAGS.redirect_logs = command["redirect_logs"]
-  assert flags.FLAGS.redirect_logs is False
-  if "seed" in command:
-    flags.FLAGS.seed = command["seed"]
+  tf.logging.info("flagfile found.")
+  flags.FLAGS([__file__, "--flagfile", flagfile])
 
 
 def main(_):
   global _log_file
-  if flags.FLAGS.use_command_file is not None:
-    _set_flags_with_command_file()
+  _parse_flagfile()
 
   redirect_logs = flags.FLAGS.redirect_logs
   cache_paths = rconst.Paths(
       data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id)
 
-
   log_file_name = "data_gen_proc_{}.log".format(cache_paths.cache_id)
   log_path = os.path.join(cache_paths.data_dir, log_file_name)
   if log_path.startswith("gs://") and redirect_logs:
@@ -559,12 +541,11 @@ def define_flags():
   flags.DEFINE_boolean(name="redirect_logs", default=False,
                        help="Catch logs and write them to a file. "
                             "(Useful if this is run as a subprocess)")
+  flags.DEFINE_boolean(name="use_tf_logging", default=False,
+                       help="Use tf.logging instead of log file.")
   flags.DEFINE_integer(name="seed", default=None,
                        help="NumPy random seed to set at startup. If not "
                             "specified, a seed will not be set.")
-  flags.DEFINE_boolean(name="use_command_file", default=False,
-                       help="Use command arguments from json at command_path. "
-                            "All arguments other than data_dir will be ignored.")
 
 
 if __name__ == "__main__":
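Note on the consumer side of this change: absl's FlagValues object is callable, so flags.FLAGS([__file__, "--flagfile", flagfile]) parses an argv in which --flagfile expands to one --name=value flag per line read from the named file. Below is a minimal standalone sketch of that pattern, assuming absl-py; the wait_for_flagfile helper, the two flag definitions, and the /tmp/ncf_data path are illustrative, not part of this commit.

import os
import time

from absl import flags

# Illustrative flags; the real script defines num_workers, cache_id, etc.
flags.DEFINE_integer(name="num_workers", default=1, help="Worker count.")
flags.DEFINE_integer(name="train_batch_size", default=0, help="Batch size.")


def wait_for_flagfile(data_dir, filename="flagfile"):
  """Block until the producer publishes the flagfile, then parse it."""
  flagfile = os.path.join(data_dir, filename)
  while not os.path.exists(flagfile):
    time.sleep(1)
  # FlagValues instances are callable; parsing an argv containing
  # --flagfile fills in every flag listed in that file.
  flags.FLAGS(["argv0", "--flagfile={}".format(flagfile)])


if __name__ == "__main__":
  wait_for_flagfile("/tmp/ncf_data")
  print(flags.FLAGS.num_workers, flags.FLAGS.train_batch_size)

Polling for the file's existence is safe only because the producer renames the flagfile into place atomically, which is the other half of this commit (see data_preprocessing.py below).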

official/recommendation/data_preprocessing.py (+36 -59)
@@ -430,77 +430,54 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
   # pool underlying the training generation doesn't starve other processes.
   num_workers = int(multiprocessing.cpu_count() * 0.75) or 1
 
+  flags_ = {
+      "data_dir": data_dir,
+      "cache_id": ncf_dataset.cache_paths.cache_id,
+      "num_neg": num_neg,
+      "num_train_positives": ncf_dataset.num_train_positives,
+      "num_items": ncf_dataset.num_items,
+      "num_readers": ncf_dataset.num_data_readers,
+      "epochs_per_cycle": epochs_per_cycle,
+      "train_batch_size": batch_size,
+      "eval_batch_size": eval_batch_size,
+      "num_workers": num_workers,
+      # This allows the training input function to guarantee batch size and
+      # significantly improves performance. (~5% increase in examples/sec on
+      # GPU, and needed for TPU XLA.)
+      "spillover": True,
+      "redirect_logs": use_subprocess,
+      "use_tf_logging": not use_subprocess,
+  }
+  if ncf_dataset.deterministic:
+    flags_["seed"] = stat_utils.random_int32()
+  # We write to a temp file then atomically rename it to the final file,
+  # because writing directly to the final file can cause the data generation
+  # async process to read a partially written file.
+  flagfile_temp = os.path.join(flags.FLAGS.data_dir, rconst.FLAGFILE_TEMP)
+  tf.logging.info("Preparing flagfile for async data generation in {} ..."
+                  .format(flagfile_temp))
+  with tf.gfile.Open(flagfile_temp, "w") as f:
+    for k, v in six.iteritems(flags_):
+      f.write("--{}={}\n".format(k, v))
+  flagfile = os.path.join(data_dir, rconst.FLAGFILE)
+  tf.gfile.Rename(flagfile_temp, flagfile)
+  tf.logging.info(
+      "Wrote flagfile for async data generation in {}."
+      .format(flagfile))
+
   if use_subprocess:
     tf.logging.info("Creating training file subprocess.")
-
     subproc_env = os.environ.copy()
-
     # The subprocess uses TensorFlow for tf.gfile, but it does not need GPU
     # resources and by default will try to allocate GPU memory. This would cause
     # contention with the main training process.
     subproc_env["CUDA_VISIBLE_DEVICES"] = ""
-
     subproc_args = popen_helper.INVOCATION + [
-        "--data_dir", data_dir,
-        "--cache_id", str(ncf_dataset.cache_paths.cache_id),
-        "--num_neg", str(num_neg),
-        "--num_train_positives", str(ncf_dataset.num_train_positives),
-        "--num_items", str(ncf_dataset.num_items),
-        "--num_readers", str(ncf_dataset.num_data_readers),
-        "--epochs_per_cycle", str(epochs_per_cycle),
-        "--train_batch_size", str(batch_size),
-        "--eval_batch_size", str(eval_batch_size),
-        "--num_workers", str(num_workers),
-        # This allows the training input function to guarantee batch size and
-        # significantly improves performance. (~5% increase in examples/sec on
-        # GPU, and needed for TPU XLA.)
-        "--spillover", "True",
-        "--redirect_logs", "True"
-    ]
-    if ncf_dataset.deterministic:
-      subproc_args.extend(["--seed", str(int(stat_utils.random_int32()))])
-
+        "--data_dir", data_dir]
     tf.logging.info(
         "Generation subprocess command: {}".format(" ".join(subproc_args)))
-
     proc = subprocess.Popen(args=subproc_args, shell=False, env=subproc_env)
 
-  else:
-    # We write to a temp file then atomically rename it to the final file,
-    # because writing directly to the final file can cause the data generation
-    # async process to read a partially written JSON file.
-    command_file_temp = os.path.join(data_dir, rconst.COMMAND_FILE_TEMP)
-    tf.logging.info("Generation subprocess command at {} ..."
-                    .format(command_file_temp))
-    with tf.gfile.Open(command_file_temp, "w") as f:
-      command = {
-          "data_dir": data_dir,
-          "cache_id": ncf_dataset.cache_paths.cache_id,
-          "num_neg": num_neg,
-          "num_train_positives": ncf_dataset.num_train_positives,
-          "num_items": ncf_dataset.num_items,
-          "num_readers": ncf_dataset.num_data_readers,
-          "epochs_per_cycle": epochs_per_cycle,
-          "train_batch_size": batch_size,
-          "eval_batch_size": eval_batch_size,
-          "num_workers": num_workers,
-          # This allows the training input function to guarantee batch size and
-          # significantly improves performance. (~5% increase in examples/sec on
-          # GPU, and needed for TPU XLA.)
-          "spillover": True,
-          "redirect_logs": False
-      }
-      if ncf_dataset.deterministic:
-        command["seed"] = stat_utils.random_int32()
-
-      json.dump(command, f)
-    command_file = os.path.join(data_dir, rconst.COMMAND_FILE)
-    tf.gfile.Rename(command_file_temp, command_file)
-
-    tf.logging.info(
-        "Generation subprocess command saved to: {}"
-        .format(command_file))
-
   cleanup_called = {"finished": False}
   @atexit.register
   def cleanup():