@@ -23,7 +23,6 @@
 import datetime
 import gc
 import functools
-import logging
 import multiprocessing
 import json
 import os
@@ -40,23 +39,25 @@
 import tensorflow as tf
 
 from absl import app as absl_app
-from absl import logging as absl_logging
 from absl import flags
 
 from official.datasets import movielens
 from official.recommendation import constants as rconst
 from official.recommendation import stat_utils
 
 
+_log_file = None
+
+
 def log_msg(msg):
   """Include timestamp info when logging messages to a file."""
   if flags.FLAGS.redirect_logs:
     timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
-    absl_logging.info("[{}] {}".format(timestamp, msg))
+    print("[{}] {}".format(timestamp, msg), file=_log_file)
   else:
-    absl_logging.info(msg)
-  sys.stdout.flush()
-  sys.stderr.flush()
+    print(msg, file=_log_file)
+  if _log_file:
+    _log_file.flush()
 
 
 def get_cycle_folder_name(i):
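Note on the log_msg change above: Python's print() writes to sys.stdout when its file argument is None, so before _log_file is opened (i.e. whenever the redirect_logs flag is unset) messages still reach the console. A minimal sketch of that behavior, with log_msg_demo and the redirect_logs parameter as hypothetical stand-ins for the real function and the flags.FLAGS lookup:

import datetime

_log_file = None  # opened only when log redirection is requested

def log_msg_demo(msg, redirect_logs=False):
  # Mirrors the new log_msg: timestamp only when writing to a file.
  if redirect_logs:
    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
    print("[{}] {}".format(timestamp, msg), file=_log_file)
  else:
    print(msg, file=_log_file)
  if _log_file:
    _log_file.flush()

log_msg_demo("hello")  # _log_file is None, so this prints to stdout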
@@ -395,61 +396,54 @@ def _generation_loop(
 
 
 def main(_):
+  global _log_file
   redirect_logs = flags.FLAGS.redirect_logs
   cache_paths = rconst.Paths(
       data_dir=flags.FLAGS.data_dir, cache_id=flags.FLAGS.cache_id)
 
 
   log_file_name = "data_gen_proc_{}.log".format(cache_paths.cache_id)
-  log_file = os.path.join(cache_paths.data_dir, log_file_name)
-  if log_file.startswith("gs://") and redirect_logs:
+  log_path = os.path.join(cache_paths.data_dir, log_file_name)
+  if log_path.startswith("gs://") and redirect_logs:
     fallback_log_file = os.path.join(tempfile.gettempdir(), log_file_name)
     print("Unable to log to {}. Falling back to {}"
-          .format(log_file, fallback_log_file))
-    log_file = fallback_log_file
+          .format(log_path, fallback_log_file))
+    log_path = fallback_log_file
 
   # This server is generally run in a subprocess.
   if redirect_logs:
-    print("Redirecting stdout and stderr to {}".format(log_file))
-    log_stream = open(log_file, "wt")  # Note: not tf.gfile.Open().
-    stdout = log_stream
-    stderr = log_stream
+    print("Redirecting output of data_async_generation.py process to {}"
+          .format(log_path))
+    _log_file = open(log_path, "wt")  # Note: not tf.gfile.Open().
   try:
-    if redirect_logs:
-      absl_logging.get_absl_logger().addHandler(
-          hdlr=logging.StreamHandler(stream=stdout))
-      sys.stdout = stdout
-      sys.stderr = stderr
-      print("Logs redirected.")
-    try:
-      log_msg("sys.argv: {}".format(" ".join(sys.argv)))
-
-      if flags.FLAGS.seed is not None:
-        np.random.seed(flags.FLAGS.seed)
-
-      _generation_loop(
-          num_workers=flags.FLAGS.num_workers,
-          cache_paths=cache_paths,
-          num_readers=flags.FLAGS.num_readers,
-          num_neg=flags.FLAGS.num_neg,
-          num_train_positives=flags.FLAGS.num_train_positives,
-          num_items=flags.FLAGS.num_items,
-          spillover=flags.FLAGS.spillover,
-          epochs_per_cycle=flags.FLAGS.epochs_per_cycle,
-          train_batch_size=flags.FLAGS.train_batch_size,
-          eval_batch_size=flags.FLAGS.eval_batch_size,
-      )
-    except KeyboardInterrupt:
-      log_msg("KeyboardInterrupt registered.")
-    except:
-      traceback.print_exc()
-      raise
+    log_msg("sys.argv: {}".format(" ".join(sys.argv)))
+
+    if flags.FLAGS.seed is not None:
+      np.random.seed(flags.FLAGS.seed)
+
+    _generation_loop(
+        num_workers=flags.FLAGS.num_workers,
+        cache_paths=cache_paths,
+        num_readers=flags.FLAGS.num_readers,
+        num_neg=flags.FLAGS.num_neg,
+        num_train_positives=flags.FLAGS.num_train_positives,
+        num_items=flags.FLAGS.num_items,
+        spillover=flags.FLAGS.spillover,
+        epochs_per_cycle=flags.FLAGS.epochs_per_cycle,
+        train_batch_size=flags.FLAGS.train_batch_size,
+        eval_batch_size=flags.FLAGS.eval_batch_size,
+    )
+  except KeyboardInterrupt:
+    log_msg("KeyboardInterrupt registered.")
+  except:
+    traceback.print_exc(file=_log_file)
+    raise
   finally:
     log_msg("Shutting down generation subprocess.")
     sys.stdout.flush()
     sys.stderr.flush()
     if redirect_logs:
-      log_stream.close()
+      _log_file.close()
 
 
 def define_flags():
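The rewritten main() above drops the sys.stdout/sys.stderr reassignment and absl handler juggling in favor of a single module-level handle that every writer targets explicitly. A minimal sketch of that lifecycle, assuming a hypothetical do_work() and LOG_PATH in place of _generation_loop() and the resolved log path:

import traceback

_log_file = None  # module-level handle, shared by all writers
LOG_PATH = "/tmp/data_gen_proc_demo.log"  # hypothetical path

def do_work():
  # Stand-in for _generation_loop(); writes through the shared handle.
  print("working", file=_log_file)

def run(redirect_logs):
  global _log_file
  if redirect_logs:
    _log_file = open(LOG_PATH, "wt")
  try:
    do_work()
  except KeyboardInterrupt:
    print("KeyboardInterrupt registered.", file=_log_file)
  except:
    # Route the traceback into the same log before re-raising, as the
    # commit does with traceback.print_exc(file=_log_file).
    traceback.print_exc(file=_log_file)
    raise
  finally:
    if redirect_logs:
      _log_file.close()  # only the explicit handle needs cleanup

run(redirect_logs=False)  # no file opened; output falls through to stdout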