Skip to content

Commit 99e52a8

Browse files
ebrevdomartinwicke
authored andcommitted
Bugfixes to TensorArray and functional ops:
- TensorArray shape inference now works correctly for scalar elements - TensorArrays each now get a unique name at runtime, per step. This means that they can be used in nested functional ops (e.g. tf.scan(tf.scan(...))) Change: 125110643
1 parent 1bea99a commit 99e52a8

File tree

6 files changed

+53
-18
lines changed

6 files changed

+53
-18
lines changed

tensorflow/core/kernels/tensor_array.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_GPU);
7575

7676
} // namespace tensor_array
7777

78+
std::atomic<int64> TensorArray::tensor_array_counter{0};
79+
7880
Status TensorArray::CopyShapesFrom(TensorArray* rhs) {
7981
mutex_lock l(mu_);
8082
mutex_lock l_rhs(*rhs->mu());

tensorflow/core/kernels/tensor_array.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_GPU);
124124
//
125125
class TensorArray : public ResourceBase {
126126
public:
127+
static std::atomic<int64> tensor_array_counter;
128+
127129
// Construct a TensorArray for holding Tensors of type 'dtype' with
128130
// 'N' elements. While the underlying storage is a std::vector and
129131
// can hold more than MAX_INT entries, in practice we do not expect

tensorflow/core/kernels/tensor_array_ops.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,14 +147,18 @@ class TensorArrayOp : public TensorArrayCreationOp {
147147
const int32 size = tensor_size->scalar<int32>()();
148148

149149
auto handle = tensor_array_output_handle->flat<string>();
150+
string unique_tensor_array_name =
151+
strings::StrCat(tensor_array_name_, "_",
152+
TensorArray::tensor_array_counter.fetch_add(1));
150153
handle(0) = "_tensor_arrays";
151-
handle(1) = tensor_array_name_;
154+
handle(1) = unique_tensor_array_name;
152155

153156
TensorArray* tensor_array = new TensorArray(
154157
dtype_, *tensor_array_output_handle, size, dynamic_size_,
155158
false /* multiple_writes_aggregate */, clear_after_read_);
156159

157-
TF_RETURN_IF_ERROR(rm->Create(handle(0), tensor_array_name_, tensor_array));
160+
TF_RETURN_IF_ERROR(
161+
rm->Create(handle(0), unique_tensor_array_name, tensor_array));
158162

159163
*output_tensor_array = tensor_array;
160164

tensorflow/python/kernel_tests/functional_ops_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,28 @@ def testScan_Scoped(self):
180180
results = np.array([6, 16, 38, 84, 178, 368])
181181
self.assertAllEqual(results, r.eval())
182182

183+
def testScanFoldl_Nested(self):
184+
with self.test_session():
185+
elems = tf.constant([1.0, 2.0, 3.0, 4.0], name="data")
186+
inner_elems = tf.constant([0.5, 0.5], name="data")
187+
188+
def r_inner(a, x):
189+
return tf.foldl(lambda b, y: b * y * x, inner_elems, initializer=a)
190+
191+
r = tf.scan(r_inner, elems)
192+
193+
# t == 0 (returns 1)
194+
# t == 1, a == 1, x == 2 (returns 1)
195+
# t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1
196+
# t_1 == 1, b == 1, y == 0.5, returns b * y * x = 1
197+
# t == 2, a == 1, x == 3 (returns 1.5*1.5 == 2.25)
198+
# t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1.5
199+
# t_1 == 1, b == 1.5, y == 0.5, returns b * y * x = 1.5*1.5
200+
# t == 3, a == 2.25, x == 4 (returns 9)
201+
# t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5
202+
# t_1 == 1, b == 4.5, y == 0.5, returns b * y * x = 9
203+
self.assertAllClose([1., 1., 2.25, 9.], r.eval())
204+
183205
def testScan_Control(self):
184206
with self.test_session() as sess:
185207
s = tf.placeholder(tf.float32, shape=[None])

tensorflow/python/kernel_tests/tensor_array_ops_test.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import numpy as np
2323
import tensorflow as tf
2424

25-
from tensorflow.python.framework import errors
2625
from tensorflow.python.framework import tensor_shape
2726
from tensorflow.python.ops import gen_data_flow_ops
2827
from tensorflow.python.ops import tensor_array_grad
@@ -462,7 +461,7 @@ def _testTensorArrayWriteGradientAddMultipleAdds(self, dtype):
462461
# Assert that if multiple_writes_aggregate is not enabled,
463462
# multiple writes raise an exception.
464463
with self.assertRaisesOpError(
465-
r"TensorArray foo: Could not write to TensorArray index 2 because "
464+
r"TensorArray foo_.*: Could not write to TensorArray index 2 because "
466465
r"it has already been written to."):
467466
w1.flow.eval()
468467

@@ -495,16 +494,22 @@ def testMultiTensorArray(self):
495494
r = r1 + r2
496495
self.assertAllClose(9.0, r.eval())
497496

498-
def testDuplicateTensorArrayFails(self):
497+
def testDuplicateTensorArrayHasDifferentName(self):
499498
with self.test_session(use_gpu=self._use_gpu) as session:
500499
h1 = tensor_array_ops.TensorArray(
501500
size=1, dtype=tf.float32, tensor_array_name="foo")
502501
c1 = h1.write(0, 4.0)
503502
h2 = tensor_array_ops.TensorArray(
504503
size=1, dtype=tf.float32, tensor_array_name="foo")
505504
c2 = h2.write(0, 5.0)
506-
with self.assertRaises(errors.AlreadyExistsError):
507-
session.run([c1.flow, c2.flow])
505+
_, _, c1h, c2h = session.run([c1.flow, c2.flow, c1.handle, c2.handle])
506+
c1h = [x.decode("ascii") for x in c1h]
507+
c2h = [x.decode("ascii") for x in c2h]
508+
self.assertEqual(c1h[0], "_tensor_arrays")
509+
self.assertEqual(c2h[0], "_tensor_arrays")
510+
self.assertTrue(c1h[1].startswith("foo_"))
511+
self.assertTrue(c2h[1].startswith("foo_"))
512+
self.assertNotEqual(c1h[1], c2h[1])
508513

509514
def _testTensorArrayGradientWriteReadType(self, dtype):
510515
with self.test_session(use_gpu=self._use_gpu) as session:
@@ -692,13 +697,6 @@ def testWriteCloseTensorArray(self):
692697
w1 = w0.write(1, [3.0])
693698
w1.close().run() # Expected to run without problems
694699

695-
ta = tensor_array_ops.TensorArray(
696-
dtype=tf.float32, tensor_array_name="foo", size=3)
697-
with self.assertRaisesOpError(
698-
r"TensorArray foo has already been closed."):
699-
with tf.control_dependencies([w1.close()]):
700-
w1.write(2, 3.0).flow.eval()
701-
702700
def _testWhileLoopWritePackGradients(self, dynamic_size, dtype):
703701
np_dtype = dtype.as_numpy_dtype
704702
with self.test_session(use_gpu=self._use_gpu) as session:

tensorflow/python/ops/tensor_array_ops.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@ def __init__(self, dtype, size=None, dynamic_size=None,
6464
flow=None, infer_shape=True, name=None):
6565
"""Construct a new TensorArray or wrap an existing TensorArray handle.
6666
67+
A note about the parameter `name`:
68+
69+
The name of the `TensorArray` (even if passed in) is uniquified: each time
70+
a new `TensorArray` is created at runtime it is assigned its own name for
71+
the duration of the run. This avoids name collissions if a `TensorArray`
72+
is created within a `while_loop`.
73+
6774
Args:
6875
dtype: (required) data type of the TensorArray.
6976
size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
@@ -235,7 +242,7 @@ def pack(self, name=None):
235242
value = gen_data_flow_ops._tensor_array_pack(
236243
handle=self._handle, flow_in=self._flow, dtype=self._dtype,
237244
name=name)
238-
if self._elem_shape and self._elem_shape[0].dims:
245+
if self._elem_shape and self._elem_shape[0].dims is not None:
239246
value.set_shape([None] + self._elem_shape[0].dims)
240247
return value
241248

@@ -255,7 +262,7 @@ def concat(self, name=None):
255262
value, _ = gen_data_flow_ops._tensor_array_concat(
256263
handle=self._handle, flow_in=self._flow, dtype=self._dtype,
257264
name=name)
258-
if self._elem_shape and self._elem_shape[0].dims:
265+
if self._elem_shape and self._elem_shape[0].dims is not None:
259266
value.set_shape([None] + self._elem_shape[0].dims[1:])
260267
return value
261268

@@ -284,7 +291,7 @@ def unpack(self, value, name=None):
284291
if ta._infer_shape:
285292
val_shape = flow_out.op.inputs[1].get_shape()
286293
elem_shape = tensor_shape.unknown_shape()
287-
if val_shape.dims:
294+
if val_shape.dims is not None:
288295
elem_shape = tensor_shape.TensorShape(val_shape.dims[1:])
289296
if ta._elem_shape:
290297
if not elem_shape == ta._elem_shape[0]:
@@ -326,7 +333,7 @@ def split(self, value, lengths, name=None):
326333
val_shape = flow_out.op.inputs[1].get_shape()
327334
clengths = tensor_util.constant_value(flow_out.op.inputs[2])
328335
elem_shape = tensor_shape.unknown_shape()
329-
if val_shape.dims:
336+
if val_shape.dims is not None:
330337
if clengths is not None and clengths.max() == clengths.min():
331338
elem_shape = tensor_shape.TensorShape(
332339
[clengths[0]] + val_shape.dims[1:])

0 commit comments

Comments
 (0)