|
19 | 19 | from __future__ import print_function
|
20 | 20 |
|
21 | 21 | from tensorflow.contrib.optimizer_v2 import optimizer_v2
|
22 |
| -from tensorflow.python.framework import ops |
23 | 22 | from tensorflow.python.ops import array_ops
|
24 | 23 | from tensorflow.python.ops import gen_array_ops
|
25 | 24 | from tensorflow.python.ops import init_ops
|
@@ -65,17 +64,18 @@ def __init__(self, learning_rate, initial_accumulator_value=0.1,
|
65 | 64 |
|
66 | 65 | def _create_vars(self, var_list, state):
|
67 | 66 | for v in var_list:
|
68 |
| - with ops.colocate_with(v): |
69 |
| - dtype = v.dtype.base_dtype |
70 |
| - if v.get_shape().is_fully_defined(): |
71 |
| - init = init_ops.constant_initializer(self._initial_accumulator_value, |
72 |
| - dtype=dtype) |
73 |
| - else: |
74 |
| - # Use a Tensor instead of initializer if variable does not have static |
75 |
| - # shape. |
76 |
| - init_constant = gen_array_ops.fill( |
77 |
| - array_ops.shape(v), self._initial_accumulator_value) |
78 |
| - init = math_ops.cast(init_constant, dtype) |
| 67 | + # TODO(isaprykin): Delete colocate_with(v) from other optimizers and |
| 68 | + # confirm that colocation will happen anyway. |
| 69 | + dtype = v.dtype.base_dtype |
| 70 | + if v.get_shape().is_fully_defined(): |
| 71 | + init = init_ops.constant_initializer(self._initial_accumulator_value, |
| 72 | + dtype=dtype) |
| 73 | + else: |
| 74 | + # Use a Tensor instead of initializer if variable does not have static |
| 75 | + # shape. |
| 76 | + init_constant = gen_array_ops.fill( |
| 77 | + array_ops.shape(v), self._initial_accumulator_value) |
| 78 | + init = math_ops.cast(init_constant, dtype) |
79 | 79 | state.create_slot_with_initializer(v, init, v.get_shape(), dtype,
|
80 | 80 | "accumulator")
|
81 | 81 |
|
|
0 commit comments