Skip to content

Commit 58037d2

Browse files
reedwm authored and Taylor Robie committed
Fix bug where data_async_generation.py would freeze. (tensorflow#4989)
The data_async_generation.py process would print to stderr, but the main process would redirect its stderr to a pipe. The main process never read from the pipe, so when the pipe was full, data_async_generation.py would stall on a write to stderr. This change makes data_async_generation.py not write to stdout/stderr.
1 parent 4acdc50 commit 58037d2

File tree

2 files changed

+39
-47
lines changed

2 files changed

+39
-47
lines changed

official/recommendation/data_async_generation.py

+38-44
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import datetime
2424
import gc
2525
import functools
26-
import logging
2726
import multiprocessing
2827
import json
2928
import os
@@ -40,23 +39,25 @@
4039
import tensorflow as tf
4140

4241
from absl import app as absl_app
43-
from absl import logging as absl_logging
4442
from absl import flags
4543

4644
from official.datasets import movielens
4745
from official.recommendation import constants as rconst
4846
from official.recommendation import stat_utils
4947

5048

49+
_log_file = None
50+
51+
5152
def log_msg(msg):
5253
"""Include timestamp info when logging messages to a file."""
5354
if flags.FLAGS.redirect_logs:
5455
timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
55-
absl_logging.info("[{}] {}".format(timestamp, msg))
56+
print("[{}] {}".format(timestamp, msg), file=_log_file)
5657
else:
57-
absl_logging.info(msg)
58-
sys.stdout.flush()
59-
sys.stderr.flush()
58+
print(msg, file=_log_file)
59+
if _log_file:
60+
_log_file.flush()
6061

6162

6263
def get_cycle_folder_name(i):
@@ -395,61 +396,54 @@ def _generation_loop(
395396

396397

397398
def main(_):
399+
global _log_file
398400
redirect_logs = flags.FLAGS.redirect_logs
399401
cache_paths = rconst.Paths(
400402
data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id)
401403

402404

403405
log_file_name = "data_gen_proc_{}.log".format(cache_paths.cache_id)
404-
log_file = os.path.join(cache_paths.data_dir, log_file_name)
405-
if log_file.startswith("gs://") and redirect_logs:
406+
log_path = os.path.join(cache_paths.data_dir, log_file_name)
407+
if log_path.startswith("gs://") and redirect_logs:
406408
fallback_log_file = os.path.join(tempfile.gettempdir(), log_file_name)
407409
print("Unable to log to {}. Falling back to {}"
408-
.format(log_file, fallback_log_file))
409-
log_file = fallback_log_file
410+
.format(log_path, fallback_log_file))
411+
log_path = fallback_log_file
410412

411413
# This server is generally run in a subprocess.
412414
if redirect_logs:
413-
print("Redirecting stdout and stderr to {}".format(log_file))
414-
log_stream = open(log_file, "wt") # Note: not tf.gfile.Open().
415-
stdout = log_stream
416-
stderr = log_stream
415+
print("Redirecting output of data_async_generation.py process to {}"
416+
.format(log_path))
417+
_log_file = open(log_path, "wt") # Note: not tf.gfile.Open().
417418
try:
418-
if redirect_logs:
419-
absl_logging.get_absl_logger().addHandler(
420-
hdlr=logging.StreamHandler(stream=stdout))
421-
sys.stdout = stdout
422-
sys.stderr = stderr
423-
print("Logs redirected.")
424-
try:
425-
log_msg("sys.argv: {}".format(" ".join(sys.argv)))
426-
427-
if flags.FLAGS.seed is not None:
428-
np.random.seed(flags.FLAGS.seed)
429-
430-
_generation_loop(
431-
num_workers=flags.FLAGS.num_workers,
432-
cache_paths=cache_paths,
433-
num_readers=flags.FLAGS.num_readers,
434-
num_neg=flags.FLAGS.num_neg,
435-
num_train_positives=flags.FLAGS.num_train_positives,
436-
num_items=flags.FLAGS.num_items,
437-
spillover=flags.FLAGS.spillover,
438-
epochs_per_cycle=flags.FLAGS.epochs_per_cycle,
439-
train_batch_size=flags.FLAGS.train_batch_size,
440-
eval_batch_size=flags.FLAGS.eval_batch_size,
441-
)
442-
except KeyboardInterrupt:
443-
log_msg("KeyboardInterrupt registered.")
444-
except:
445-
traceback.print_exc()
446-
raise
419+
log_msg("sys.argv: {}".format(" ".join(sys.argv)))
420+
421+
if flags.FLAGS.seed is not None:
422+
np.random.seed(flags.FLAGS.seed)
423+
424+
_generation_loop(
425+
num_workers=flags.FLAGS.num_workers,
426+
cache_paths=cache_paths,
427+
num_readers=flags.FLAGS.num_readers,
428+
num_neg=flags.FLAGS.num_neg,
429+
num_train_positives=flags.FLAGS.num_train_positives,
430+
num_items=flags.FLAGS.num_items,
431+
spillover=flags.FLAGS.spillover,
432+
epochs_per_cycle=flags.FLAGS.epochs_per_cycle,
433+
train_batch_size=flags.FLAGS.train_batch_size,
434+
eval_batch_size=flags.FLAGS.eval_batch_size,
435+
)
436+
except KeyboardInterrupt:
437+
log_msg("KeyboardInterrupt registered.")
438+
except:
439+
traceback.print_exc(file=_log_file)
440+
raise
447441
finally:
448442
log_msg("Shutting down generation subprocess.")
449443
sys.stdout.flush()
450444
sys.stderr.flush()
451445
if redirect_logs:
452-
log_stream.close()
446+
_log_file.close()
453447

454448

455449
def define_flags():

official/recommendation/data_preprocessing.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -419,9 +419,7 @@ def instantiate_pipeline(dataset, data_dir, batch_size, eval_batch_size,
419419
tf.logging.info(
420420
"Generation subprocess command: {}".format(" ".join(subproc_args)))
421421

422-
proc = subprocess.Popen(args=subproc_args, stdin=subprocess.PIPE,
423-
stdout=subprocess.PIPE, stderr=subprocess.PIPE,
424-
shell=False, env=subproc_env)
422+
proc = subprocess.Popen(args=subproc_args, shell=False, env=subproc_env)
425423

426424
atexit.register(_shutdown, proc=proc)
427425
atexit.register(tf.gfile.DeleteRecursively,

0 commit comments

Comments (0)