Commit 50cf8eb

Merge branch 'master' into serving_blogpost

2 parents ec51c8c + 5be3727

577 files changed, +18453 -93815 lines

.gitignore (-2)

@@ -90,5 +90,3 @@ ENV/
 
 # PyCharm
 .idea/
-
-samples/outreach/blogs/serving_blogpost/data/

CODEOWNERS (+2)

@@ -10,6 +10,7 @@
 /research/brain_coder/ @danabo
 /research/cognitive_mapping_and_planning/ @s-gupta
 /research/compression/ @nmjohn
+/research/deep_contextual_bandits/ @rikel
 /research/deeplab/ @aquariusjay @yknzhu @gpapan
 /research/delf/ @andrefaraujo
 /research/differential_privacy/ @ilyamironov @ananthr
@@ -19,6 +20,7 @@
 /research/global_objectives/ @mackeya-google
 /research/im2txt/ @cshallue
 /research/inception/ @shlens @vincentvanhoucke
+/research/keypointnet/ @mnorouzi
 /research/learned_optimizer/ @olganw @nirum
 /research/learning_to_remember_rare_events/ @lukaszkaiser @ofirnachum
 /research/learning_unsupervised_learning/ @lukemetz @nirum

official/datasets/movielens.py (+6 -2)

@@ -133,8 +133,12 @@ def _progress(count, block_size, total_size):
     _regularize_20m_dataset(temp_dir)
 
     for fname in tf.gfile.ListDirectory(temp_dir):
-      tf.gfile.Copy(os.path.join(temp_dir, fname),
-                    os.path.join(data_subdir, fname))
+      if not tf.gfile.Exists(os.path.join(data_subdir, fname)):
+        tf.gfile.Copy(os.path.join(temp_dir, fname),
+                      os.path.join(data_subdir, fname))
+      else:
+        tf.logging.info("Skipping copy of {}, as it already exists in the "
+                        "destination folder.".format(fname))
 
   finally:
     tf.gfile.DeleteRecursively(temp_dir)
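
This change makes the copy step idempotent: a file is copied only if it is not already present in the destination, so re-running the download script does not clobber existing data. A minimal standalone sketch of the same tf.gfile pattern (the function name and directory paths below are hypothetical, not part of the commit):

```
import os

import tensorflow as tf


def copy_if_absent(src_dir, dst_dir):
  """Copies each file from src_dir to dst_dir, skipping files that exist."""
  tf.gfile.MakeDirs(dst_dir)
  for fname in tf.gfile.ListDirectory(src_dir):
    dst = os.path.join(dst_dir, fname)
    if tf.gfile.Exists(dst):
      # Already present: leave the existing copy untouched.
      tf.logging.info("Skipping copy of %s; it already exists.", fname)
    else:
      tf.gfile.Copy(os.path.join(src_dir, fname), dst)


# Hypothetical usage:
# copy_if_absent("/tmp/movielens-temp", "/tmp/movielens-data/ml-20m")
```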

official/keras_application_models/dataset.py (+4 -5)

@@ -18,7 +18,6 @@
 from __future__ import print_function
 
 import tensorflow as tf
-
 from official.utils.misc import model_helpers  # pylint: disable=g-bad-import-order
 
 # Default values for dataset.
@@ -29,7 +28,7 @@
 def _get_default_image_size(model):
   """Provide default image size for each model."""
   image_size = (224, 224)
-  if model in ["inception", "xception", "inceptionresnet"]:
+  if model in ["inceptionv3", "xception", "inceptionresnetv2"]:
     image_size = (299, 299)
   elif model in ["nasnetlarge"]:
     image_size = (331, 331)
@@ -42,8 +41,8 @@ def generate_synthetic_input_dataset(model, batch_size):
   image_shape = (batch_size,) + image_size + (_NUM_CHANNELS,)
   label_shape = (batch_size, _NUM_CLASSES)
 
-  return model_helpers.generate_synthetic_data(
+  dataset = model_helpers.generate_synthetic_data(
       input_shape=tf.TensorShape(image_shape),
-      input_dtype=tf.float32,
       label_shape=tf.TensorShape(label_shape),
-      label_dtype=tf.float32)
+  )
+  return dataset
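
model_helpers.generate_synthetic_data itself is not shown in this diff, so the sketch below is an assumption about its shape, not the repository's actual implementation: a helper that builds an endlessly repeating tf.data.Dataset of random tensors. Under that assumption, float32 defaults would explain why the explicit input_dtype/label_dtype arguments could be dropped from the call above.

```
import tensorflow as tf


def generate_synthetic_data(input_shape, label_shape,
                            input_dtype=tf.float32, label_dtype=tf.float32):
  """Hypothetical sketch: random inputs/labels, repeated indefinitely."""
  inputs = tf.random_uniform(input_shape, dtype=input_dtype)
  labels = tf.random_uniform(label_shape, dtype=label_dtype)
  # One synthetic (inputs, labels) batch, yielded over and over.
  return tf.data.Dataset.from_tensors((inputs, labels)).repeat()
```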

official/mnist/mnist_eager.py (+2 -2)

@@ -83,8 +83,8 @@ def train(model, optimizer, dataset, step_counter, log_interval=None):
 
 def test(model, dataset):
   """Perform an evaluation of `model` on the examples from `dataset`."""
-  avg_loss = tfe.metrics.Mean('loss')
-  accuracy = tfe.metrics.Accuracy('accuracy')
+  avg_loss = tfe.metrics.Mean('loss', dtype=tf.float32)
+  accuracy = tfe.metrics.Accuracy('accuracy', dtype=tf.float32)
 
   for (images, labels) in dataset:
     logits = model(images, training=False)
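
For context, the tfe.metrics objects used above are callables: each call folds a batch of values into a running aggregate, and result() returns the aggregate. A small usage sketch under TF 1.x eager execution (the sample values are made up):

```
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

# Each call to the metric folds a new value into the running mean.
avg_loss = tfe.metrics.Mean('loss', dtype=tf.float32)
avg_loss(2.0)
avg_loss(4.0)
print(avg_loss.result().numpy())  # 3.0

# Accuracy is updated with (labels, predictions) pairs.
accuracy = tfe.metrics.Accuracy('accuracy', dtype=tf.float32)
accuracy([0, 1, 1], [0, 1, 0])
print(accuracy.result().numpy())  # 2 of 3 correct
```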

official/recommendation/README.md (+3 -2)

@@ -43,15 +43,16 @@ In both datasets, the timestamp is represented in seconds since midnight Coordin…
 ### Download and preprocess dataset
 To download the dataset, please install Pandas package first. Then issue the following command:
 ```
-python movielens_dataset.py
+python ../datasets/movielens.py
 ```
 Arguments:
 * `--data_dir`: Directory where to download and save the preprocessed data. By default, it is `/tmp/movielens-data/`.
 * `--dataset`: The dataset name to be downloaded and preprocessed. By default, it is `ml-1m`.
 
 Use the `--help` or `-h` flag to get a full list of possible arguments.
 
-Note the ml-20m dataset is large (the rating file is ~500 MB), and it may take several minutes (~10 mins) for data preprocessing.
+Note the ml-20m dataset is large (the rating file is ~500 MB), and it may take several minutes (~2 mins) for data preprocessing.
+Both the ml-1m and ml-20m datasets will be coerced into a common format when downloaded.
 
 ### Train and evaluate model
 To train and evaluate the model, issue the following command:
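
The download command shown in the diff above takes the two documented flags. For example, preprocessing ml-20m into a custom directory would be (run from within official/recommendation/, which the relative path assumes):

```
python ../datasets/movielens.py --data_dir /tmp/movielens-data --dataset ml-20m
```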

official/recommendation/constants.py (new file, +67)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Central location for NCF specific values."""

import os
import time


# ==============================================================================
# == Main Thread Data Processing ===============================================
# ==============================================================================
class Paths(object):
  """Container for various path information used while training NCF."""

  def __init__(self, data_dir, cache_id=None):
    self.cache_id = cache_id or int(time.time())
    self.data_dir = data_dir
    self.cache_root = os.path.join(
        self.data_dir, "{}_ncf_recommendation_cache".format(self.cache_id))
    self.train_shard_subdir = os.path.join(self.cache_root,
                                           "raw_training_shards")
    self.train_shard_template = os.path.join(self.train_shard_subdir,
                                             "positive_shard_{}.pickle")
    self.train_epoch_dir = os.path.join(self.cache_root, "training_epochs")
    self.eval_data_subdir = os.path.join(self.cache_root, "eval_data")
    self.eval_raw_file = os.path.join(self.eval_data_subdir, "raw.pickle")
    self.eval_record_template_temp = os.path.join(self.eval_data_subdir,
                                                  "eval_records.temp")
    self.eval_record_template = os.path.join(
        self.eval_data_subdir, "padded_eval_batch_size_{}.tfrecords")
    self.subproc_alive = os.path.join(self.cache_root, "subproc.alive")


APPROX_PTS_PER_TRAIN_SHARD = 128000

# In both datasets, each user has at least 20 ratings.
MIN_NUM_RATINGS = 20

# The number of negative examples attached with a positive example
# when performing evaluation.
NUM_EVAL_NEGATIVES = 999

# ==============================================================================
# == Subprocess Data Generation ================================================
# ==============================================================================
CYCLES_TO_BUFFER = 3  # The number of train cycles worth of data to "run ahead"
                      # of the main training loop.

READY_FILE = "ready.json"
TRAIN_RECORD_TEMPLATE = "train_{}.tfrecords"

TIMEOUT_SECONDS = 3600 * 2  # If the train loop goes more than two hours without
                            # consuming an epoch of data, this is a good
                            # indicator that the main thread is dead and the
                            # subprocess is orphaned.
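
To illustrate the directory layout the new Paths class encodes, here is a quick sketch of constructing one and inspecting a few attributes (the data_dir and cache_id values are made up):

```
from official.recommendation import constants

paths = constants.Paths(data_dir="/tmp/movielens-data", cache_id=1234)

print(paths.cache_root)
# /tmp/movielens-data/1234_ncf_recommendation_cache
print(paths.train_shard_template.format(0))
# /tmp/movielens-data/1234_ncf_recommendation_cache/raw_training_shards/positive_shard_0.pickle
print(paths.eval_record_template.format(1024))
# /tmp/movielens-data/1234_ncf_recommendation_cache/eval_data/padded_eval_batch_size_1024.tfrecords
```

Omitting cache_id falls back to int(time.time()), so each run gets a fresh cache directory by default.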
