Skip to content

Commit 0bb7c84

Browse files
Nupur Garggunan
Nupur Garg
authored andcommitted
Fix Python API.
PiperOrigin-RevId: 199171845
1 parent 6eb43fc commit 0bb7c84

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

tensorflow/contrib/lite/python/convert_saved_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,9 @@ def set_tensor_shapes(tensors, shapes):
216216
"""
217217
if shapes:
218218
for tensor in tensors:
219-
shape = shapes.get(tensor.name)
219+
shape = shapes.get(tensor_name(tensor))
220220
if shape is not None:
221-
tensor.set_shape(shapes[tensor.name])
221+
tensor.set_shape(shape)
222222

223223

224224
def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,

tensorflow/contrib/lite/python/convert_saved_model_test.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,15 @@ def testSetTensorShapeValid(self):
7373
tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
7474
self.assertEqual([None, 3, 5], tensor.shape.as_list())
7575

76-
convert_saved_model.set_tensor_shapes([tensor],
77-
{"Placeholder:0": [5, 3, 5]})
76+
convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [5, 3, 5]})
7877
self.assertEqual([5, 3, 5], tensor.shape.as_list())
7978

79+
def testSetTensorShapeNoneValid(self):
80+
tensor = array_ops.placeholder(dtype=dtypes.float32)
81+
82+
convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [1, 3, 5]})
83+
self.assertEqual([1, 3, 5], tensor.shape.as_list())
84+
8085
def testSetTensorShapeInvalid(self):
8186
tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32)
8287
self.assertEqual([None, 3, 5], tensor.shape.as_list())

0 commit comments

Comments
 (0)