Better Performance With TF - Function - TensorFlow Core
Don't rely on Python side effects like object mutation or list appends.
Setup
import tensorflow as tf
https://www.tensorflow.org/guide/function 1/45
11/11/24, 9:27 PM Better performance with tf.function | TensorFlow Core
Define a helper function to demonstrate the kinds of errors you might encounter:
import traceback
import contextlib
# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))
Basics
Usage
A tf.function (https://www.tensorflow.org/api_docs/python/tf/function) that you define (for
example by applying the @tf.function (https://www.tensorflow.org/api_docs/python/tf/function)
decorator) is just like a core TensorFlow operation: You can execute it eagerly; you can
compute gradients; and so on.
@tf.function  # The decorator converts `add` into a tf.function.
def add(a, b):
  return a + b
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)
@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)
import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)
@tf.function
def conv_fn(image):
  return conv_layer(image)
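As a rough sketch of how such a comparison can be measured, assuming a hypothetical `many_small_ops` workload (not from this guide) in which per-op Python overhead dominates, you could time eager and tf.function execution with timeit. Absolute numbers depend on hardware and TensorFlow version, so the sketch only prints them and checks that both paths agree.

```python
import timeit

import numpy as np
import tensorflow as tf

def many_small_ops(x):
  # Hypothetical workload: many tiny element-wise ops.
  for _ in range(50):
    x = x * 1.001 + 0.001
  return x

graph_fn = tf.function(many_small_ops)

x = tf.ones([10])
graph_fn(x)  # Warm up: the first call traces, and tracing should not be timed.

eager_time = timeit.timeit(lambda: many_small_ops(x), number=20)
graph_time = timeit.timeit(lambda: graph_fn(x), number=20)
print(f"eager: {eager_time:.4f}s, tf.function: {graph_time:.4f}s")

# Both execution modes compute (numerically) the same values.
same_result = np.allclose(many_small_ops(x).numpy(), graph_fn(x).numpy())
```

Tracing happens before timing so that the measurement reflects steady-state graph execution rather than the one-off cost of building the graph.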
Tracing
This section exposes how tf.function (https://www.tensorflow.org/api_docs/python/tf/function)
works under the hood, including implementation details which may change in the future.
However, once you understand why and when tracing happens, it's much easier to use
tf.function (https://www.tensorflow.org/api_docs/python/tf/function) effectively!
What is "tracing"?
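During tracing, tf.function runs your Python code to record its TensorFlow operations into a
tf.Graph (https://www.tensorflow.org/api_docs/python/tf/Graph); subsequent calls with compatible
inputs execute that graph without re-running the Python code. You will see later how you can run
only the tracing stage with get_concrete_function (#obtaining_concrete_functions).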
@tf.function
def double(a):
  print("Tracing with", a)
  return a + a
print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
print(double.pretty_printed_concrete_signatures())
Input Parameters:
a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.int32, name=None)
Output Type:
TensorSpec(shape=(), dtype=tf.int32, name=None)
Captures:
None
Input Parameters:
a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.float32, name=None)
Output Type:
TensorSpec(shape=(), dtype=tf.float32, name=None)
Captures:
None
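A small sketch of the two stages, assuming a hypothetical `double_it` function and a Python list `trace_log` used only to observe when the Python body runs: get_concrete_function performs tracing alone, and calling the resulting ConcreteFunction executes the graph without running the Python code again.

```python
import tensorflow as tf

trace_log = []  # Appended to only while the Python body is being traced.

@tf.function
def double_it(a):
  trace_log.append(a.dtype)  # Python side effect: happens at trace time only
  return a + a

# Run only the tracing stage for int32 scalars; nothing is executed yet.
concrete = double_it.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.int32))
traces_after_tracing = len(trace_log)

# Executing the traced graph does not run the Python body again.
result = concrete(tf.constant(21))
traces_after_call = len(trace_log)
print(traces_after_tracing, traces_after_call, result.numpy())  # 1 1 42
```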
Rules of tracing
If multiple matches are found, the most specific signature is chosen. Matching is done by
subtyping (https://en.wikipedia.org/wiki/Subtyping), much like normal function calls in C++ or Java,
for instance. For example, TensorShape([1, 2]) is a subtype of TensorShape([None,
None]) and so a call to the tf.function with TensorShape([1, 2]) can be dispatched to the
ConcreteFunction produced with TensorShape([None, None]) but if a ConcreteFunction
with TensorShape([1, None]) also exists then it will be prioritized since it is more specific.
For Tensor, the type is parameterized by the Tensor's dtype and shape; ranked shapes
are a subtype of unranked shapes; fixed dimensions are a subtype of unknown
dimensions.
For Variable, the type is similar to Tensor, but also includes a unique resource ID of the
variable, necessary to correctly wire control dependencies.
For Python primitive values, the type corresponds to the value itself. For example, the
TraceType of the value 3 is LiteralTraceType<3>, not int.
For Python ordered containers such as list and tuple, the type is parameterized by
the types of their elements; for example, the type of [1, 2] is
ListTraceType<LiteralTraceType<1>, LiteralTraceType<2>> and the type for [2, 1] is
ListTraceType<LiteralTraceType<2>, LiteralTraceType<1>>, which is different.
For Python mappings such as dict, the type is also a mapping from the same keys but to
the types of values instead of the actual values. For example, the type of {1: 2, 3: 4}
is MappingTraceType<<KeyValue<1, LiteralTraceType<2>>>, <KeyValue<3,
LiteralTraceType<4>>>>. However, unlike ordered containers, {1: 2, 3: 4} and {3:
4, 1: 2} have equivalent types.
For Python objects which implement the __tf_tracing_type__ method, the type is
whatever that method returns.
For any other Python objects, the type is a generic TraceType, and the matching
procedure is:
First it checks if the object is the same object used in the previous trace (using
Python id() or is). Note that this will still match if the object has changed, so if
you use Python objects as tf.function
(https://www.tensorflow.org/api_docs/python/tf/function) arguments it's best to use
immutable ones.
Next it checks if the object is equal to the object used in the previous trace (using
Python ==).
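The rules above can be checked with experimental_get_tracing_count(), which reports how many traces a tf.function has made. In this sketch (the functions `combine` and `first_minus_second` are hypothetical), reordering a dict's keys reuses a trace, while changing a literal value or the order of list elements triggers a new one.

```python
import tensorflow as tf

@tf.function
def combine(x, coeffs):
  # `coeffs` is a plain Python dict, so its trace type is built from the
  # literal values it holds rather than from a dtype and shape.
  return x * coeffs["scale"] + coeffs["shift"]

x = tf.constant(1.0)
combine(x, {"scale": 2.0, "shift": 1.0})  # first trace
combine(x, {"shift": 1.0, "scale": 2.0})  # same mapping, different key order: reused
dict_traces = combine.experimental_get_tracing_count()
combine(x, {"scale": 3.0, "shift": 1.0})  # a new literal value: retraces
dict_traces_after_new_value = combine.experimental_get_tracing_count()

@tf.function
def first_minus_second(pair):
  return tf.constant(pair[0]) - tf.constant(pair[1])

first_minus_second([1.0, 2.0])
first_minus_second([2.0, 1.0])  # ordered container: element order is part of the type
list_traces = first_minus_second.experimental_get_tracing_count()

print(dict_traces, dict_traces_after_new_value, list_traces)  # 1 2 2
```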
Controlling retracing
Retracing, which is when your tf.function
(https://www.tensorflow.org/api_docs/python/tf/function) creates more than one trace, helps ensure
that TensorFlow generates correct graphs for each set of inputs. However, tracing is an
expensive operation! If your tf.function (https://www.tensorflow.org/api_docs/python/tf/function)
retraces a new graph for every call, you'll find that your code executes more slowly than if you
didn't use tf.function (https://www.tensorflow.org/api_docs/python/tf/function).
To control the tracing behavior, you can use the following techniques:
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)
print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(TypeError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))
# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(TypeError):
  next_collatz(tf.constant([1.0, 2.0]))
TypeError: Binding inputs to tf.function failed due to `Can not cast TensorSpec
array([[1 2]
Since TensorFlow matches tensors based on their shape, using a None dimension as a
wildcard will allow tf.function (https://www.tensorflow.org/api_docs/python/tf/function)s to reuse
traces for variably-sized input. Variably-sized input can occur if you have sequences of
different length, or images of different sizes for each batch. You can check out the
Transformer (https://www.tensorflow.org/text/tutorials/transformer) and Deep Dream
(https://www.tensorflow.org/tutorials/generative/deepdream) tutorials for examples.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
  print('Tracing with', x)
  return x
# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
@tf.function(reduce_retracing=True)
def g(x):
  print('Tracing with', x)
  return x
# Traces once.
print(g(tf.constant([1, 2, 3])))
# Traces again, this time with a generalized shape.
print(g(tf.constant([1, 2, 3, 4, 5, 6, 7])))
# No more tracing!
print(g(tf.constant([1, 2, 3, 4, 5, 6, 7, 8, 9])))
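To compare the two techniques, here is a minimal sketch (the function names are hypothetical) that feeds int32 vectors of three different lengths to a function with an input_signature and to one with reduce_retracing=True, then counts traces with experimental_get_tracing_count().

```python
import tensorflow as tf

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def with_signature(x):
  return x * 2

@tf.function(reduce_retracing=True)
def with_reduce_retracing(x):
  return x * 2

for n in (3, 7, 9):
  data = tf.range(n)  # int32 vectors of different lengths
  with_signature(data)
  with_reduce_retracing(data)

# input_signature fixes the trace type up front: one trace covers every length.
signature_traces = with_signature.experimental_get_tracing_count()
# reduce_retracing traces the exact shape first, then a generalized shape.
reduce_traces = with_reduce_retracing.experimental_get_tracing_count()
print(signature_traces, reduce_traces)  # 1 2
```

The trade-off: input_signature needs the shapes to be known in advance, while reduce_retracing discovers a generalized shape at the cost of one extra trace.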
Often, Python arguments are used to control hyperparameters and graph construction, for
example, num_layers=10, training=True, or nonlinearity='relu'. So, if the Python
argument changes, it makes sense that you'd have to retrace the graph.
However, it's possible that a Python argument is not being used to control graph construction.
In these cases, a change in the Python value can trigger needless retracing. Take, for example,
this training loop, which AutoGraph will dynamically unroll. Despite the multiple traces, the
generated graph is actually identical, so retracing is unnecessary.
def train_one_step():
  pass
@tf.function
def train(num_steps):
  print("Tracing with num_steps = ", num_steps)
  tf.print("Executing with num_steps = ", num_steps)
  for _ in tf.range(num_steps):
    train_one_step()
print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
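The retracing caused by Python arguments can be observed directly. This sketch wraps a hypothetical `accumulate` function twice, passing the step count once as Python ints and once as tensors, and compares the trace counts.

```python
import tensorflow as tf

def accumulate(x, num_steps):
  # AutoGraph turns this tf.range loop into a tf.while_loop.
  total = tf.constant(0.0)
  for _ in tf.range(num_steps):
    total += x
  return total

with_python_steps = tf.function(accumulate)
with_tensor_steps = tf.function(accumulate)

x = tf.constant(1.5)

# Python ints are part of the trace type by value: each new value retraces.
with_python_steps(x, 10)
with_python_steps(x, 20)
python_traces = with_python_steps.experimental_get_tracing_count()

# Scalar int32 tensors share one trace regardless of their value.
r10 = with_tensor_steps(x, tf.constant(10))
r20 = with_tensor_steps(x, tf.constant(20))
tensor_traces = with_tensor_steps.experimental_get_tracing_count()

print(python_traces, tensor_traces)  # 2 1
print(r10.numpy(), r20.numpy())      # 15.0 30.0
```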
def f():
  print('Tracing!')
  tf.print('Executing')
tf.function(f)()
tf.function(f)()
Tracing!
Executing
Tracing!
Executing
Where possible, you should prefer converting the Python type into a
tf.experimental.ExtensionType
(https://www.tensorflow.org/api_docs/python/tf/experimental/ExtensionType) instead. Moreover, the
TraceType of an ExtensionType is the tf.TypeSpec
(https://www.tensorflow.org/api_docs/python/tf/TypeSpec) associated with it. Therefore, if needed,
you can simply override the default tf.TypeSpec
(https://www.tensorflow.org/api_docs/python/tf/TypeSpec) to take control of an ExtensionType's
Tracing Protocol. Refer to the Customizing the ExtensionType's TypeSpec section in the
Extension types (/guide/extension_type) guide for details.
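A minimal sketch of that approach, assuming a hypothetical `FruitExt` extension type with a single tensor field (it is not the `Fruit` class used below): because the trace type comes from the TypeSpec, new instances whose fields have the same dtype and shape reuse the existing trace.

```python
import tensorflow as tf

class FruitExt(tf.experimental.ExtensionType):
  """Hypothetical fruit type; its trace type is derived from its TypeSpec."""
  flavor: tf.Tensor

@tf.function
def mixed_flavor(fruit_a, fruit_b):
  return fruit_a.flavor + fruit_b.flavor

apple = FruitExt(flavor=tf.constant([1, 2]))
mango = FruitExt(flavor=tf.constant([3, 4]))
first = mixed_flavor(apple, mango)

# Brand-new instances with the same field dtype and shape share the TypeSpec,
# so the existing trace is reused even though the objects are different.
second = mixed_flavor(FruitExt(flavor=tf.constant([5, 6])),
                      FruitExt(flavor=tf.constant([7, 8])))
num_traces = mixed_flavor.experimental_get_tracing_count()
print(first.numpy(), second.numpy(), num_traces)  # [4 6] [12 14] 1
```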
@tf.function
def get_mixed_flavor(fruit_a, fruit_b):
  return fruit_a.flavor + fruit_b.flavor
class Fruit:
  flavor = tf.constant([0, 0])
class Apple(Fruit):
  flavor = tf.constant([1, 2])
class Mango(Fruit):
  flavor = tf.constant([3, 4])
# As described in the above rules, a generic TraceType for `Apple` and `Mango`
# is generated (and a corresponding ConcreteFunction is traced) but it fails to
# match the second function call since the first pair of Apple() and Mango()
# have gone out of scope and been deleted by then.
get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function
get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function again
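# However, each subclass of the `Fruit` class has a fixed flavor, so a trace made
# for one pair of subclasses could be reused for later instances of the same
# subclasses. A custom TraceType that compares subclasses enables that reuse:
class FruitTraceType(tf.types.experimental.TraceType):
  def __init__(self, fruit):
    self.fruit_type = type(fruit)
    self.fruit_value = fruit
  def is_subtype_of(self, other):
    # True if `other` is a FruitTraceType for the same subclass.
    return (type(other) is FruitTraceType and
            self.fruit_type is other.fruit_type)
  def most_specific_common_supertype(self, others):
    # `self` is the common supertype only if all the other types match it.
    return self if all(self == other for other in others) else None
  def placeholder_value(self, placeholder_context=None):
    # Use the fruit itself as the placeholder value during tracing.
    return self.fruit_value
  def __eq__(self, other):
    return type(other) is FruitTraceType and self.fruit_type == other.fruit_type
  def __hash__(self):
    return hash(self.fruit_type)
class FruitWithTraceType:
  def __tf_tracing_type__(self, context):
    return FruitTraceType(self)
class AppleWithTraceType(FruitWithTraceType):
  flavor = tf.constant([1, 2])
class MangoWithTraceType(FruitWithTraceType):
  flavor = tf.constant([3, 4])
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType())  # Traces a new concrete function
get_mixed_flavor(AppleWithTraceType(), MangoWithTraceType())  # Reuses the traced concrete function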
Printing a ConcreteFunction displays a summary of its input arguments (with types) and its
output type.
double_strings = double.get_concrete_function(tf.constant("a"))
print(double_strings)
print(double_strings.function_type)
with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))
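As a further sketch (the `square_all` function is hypothetical), a concrete function obtained from a tf.TensorSpec with a None dimension accepts inputs of any length along that dimension, so a single trace serves all of them.

```python
import tensorflow as tf

@tf.function
def square_all(x):
  return x * x

# Trace once for float32 vectors of any length.
concrete = square_all.get_concrete_function(
    tf.TensorSpec(shape=[None], dtype=tf.float32))
print(concrete)  # summary of the input and output types

short = concrete(tf.constant([1.0, 2.0]))
longer = concrete(tf.constant([1.0, 2.0, 3.0, 4.0]))
num_traces = square_all.experimental_get_tracing_count()
print(short.numpy(), longer.numpy())
```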
You may notice that Python arguments are given special treatment in a concrete function's
input signature. Prior to TensorFlow 2.3, Python arguments were simply removed from the
concrete function's signature. Starting with TensorFlow 2.3, Python arguments remain in the
signature, but are constrained to take the value set during tracing.
@tf.function
def pow(a, b):
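  return a ** b
square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
assert square(tf.constant(10.0)) == 100
with assert_raises(TypeError):
  square(tf.constant(10.0), b=3)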
Obtaining graphs
Although retrieving the actual tf.Graph (https://www.tensorflow.org/api_docs/python/tf/Graph)
object is not something you'll normally need to do, you can obtain it easily from any concrete
function.
graph = double_strings.graph
for node in graph.as_graph_def().node:
  print(f'{node.input} -> {node.name}')
[] -> a
['a', 'a'] -> add
['add'] -> Identity
atomic_fn = double_strings.inference_fn
atomic_fn(tf.constant("a"))
This has the advantage of having lower Python overhead for high-performance scenarios. But
it should only be used for forward inference (no gradient support), and captured tensor values
(if any) would need to be explicitly supplied.
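Beyond listing node names, you can inspect the operation types recorded during tracing. A short sketch, assuming a hypothetical `affine` function; exact op lists can vary between TensorFlow versions, so only a few expected types are checked.

```python
import tensorflow as tf

@tf.function
def affine(x):
  return 2.0 * x + 1.0

concrete = affine.get_concrete_function(tf.TensorSpec(shape=[None], dtype=tf.float32))
graph = concrete.graph  # the graph recorded during tracing
op_types = [op.type for op in graph.get_operations()]
print(op_types)  # includes the input Placeholder, a Mul, and an AddV2
```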
Debugging
In general, debugging code is easier in eager mode than inside tf.function
(https://www.tensorflow.org/api_docs/python/tf/function). You should ensure that your code
executes error-free in eager mode before decorating with tf.function
(https://www.tensorflow.org/api_docs/python/tf/function). To assist in the debugging process, you
can call tf.config.run_functions_eagerly(True)
(https://www.tensorflow.org/api_docs/python/tf/config/run_functions_eagerly) to globally disable
tf.function (https://www.tensorflow.org/api_docs/python/tf/function), and
tf.config.run_functions_eagerly(False) to re-enable it.
Plain old Python print calls only execute during tracing, helping you track down when
your function gets (re)traced.
tf.debugging.enable_check_numerics
(https://www.tensorflow.org/api_docs/python/tf/debugging/enable_check_numerics) is an easy
way to track down where NaNs and Inf are created.
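A minimal sketch of the run_functions_eagerly switch, assuming a hypothetical `scaled` function and a Python list used to count how often the Python body runs. The try/finally makes sure graph execution is restored afterwards.

```python
import tensorflow as tf

body_runs = []  # Python-side log: appended each time the Python body actually runs

@tf.function
def scaled(x):
  body_runs.append(1)  # Python side effect
  return x * 2.0

# Debug mode: tf.function bodies run as ordinary eager Python code.
tf.config.run_functions_eagerly(True)
try:
  scaled(tf.constant(1.0))
  scaled(tf.constant(2.0))
finally:
  # Always restore graph execution, even if the debugged code raises.
  tf.config.run_functions_eagerly(False)
eager_runs = len(body_runs)  # the body ran on every call

body_runs.clear()
scaled(tf.constant(1.0))  # traces once for float32 scalars
scaled(tf.constant(2.0))  # reuses the trace; the Python body does not run
graph_runs = len(body_runs)
print(eager_runs, graph_runs)  # 2 1
```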
AutoGraph transformations
AutoGraph is a library that is on by default in tf.function
(https://www.tensorflow.org/api_docs/python/tf/function), and transforms a subset of Python eager
code into graph-compatible TensorFlow ops. This includes control flow like if, for, and while.
TensorFlow ops like tf.cond and tf.while_loop continue to work, but
control flow is often easier to write and understand when written in Python.
# A simple loop
@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x
f(tf.random.uniform([5]))
print(tf.autograph.to_code(f.python_function))
def tf__f(x):
with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=Tru
do_return = False
retval_ = ag__.UndefinedReturnValue()
def get_state():
return (x,)
def set_state(vars_):
nonlocal x
(x,) = vars_
def loop_body():
Conditionals
AutoGraph will convert some if <condition> statements into the equivalent tf.cond calls.
This substitution is made if <condition> is a Tensor. Otherwise, the if statement is executed
as a Python conditional.
A Python conditional executes during tracing, so exactly one branch of the conditional will be
added to the graph. Without AutoGraph, this traced graph would be unable to take the alternate
branch if there is data-dependent control flow.
@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)
fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
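Here is a small sketch contrasting the two kinds of conditional (the `clip_negative` function is hypothetical): a Tensor condition becomes a tf.cond that picks a branch at run time, while a Python bool argument picks a branch at trace time and triggers a retrace when its value changes.

```python
import tensorflow as tf

@tf.function
def clip_negative(x, use_clipping):
  if use_clipping:  # Python bool: decided once, at trace time
    if x < 0:       # Tensor condition: AutoGraph emits tf.cond
      y = tf.zeros_like(x)
    else:
      y = x
  else:
    y = x
  return y

a = clip_negative(tf.constant(-3.0), True)   # 0.0
b = clip_negative(tf.constant(4.0), True)    # 4.0, same trace, other tf.cond branch
c = clip_negative(tf.constant(-3.0), False)  # -3.0, new Python value: retrace
num_traces = clip_negative.experimental_get_tracing_count()
print(a.numpy(), b.numpy(), c.numpy(), num_traces)
```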
Loops
AutoGraph will convert some for and while statements into the equivalent TensorFlow
looping ops, like tf.while_loop (https://www.tensorflow.org/api_docs/python/tf/while_loop). If not
converted, the for or while loop is executed as a Python loop.
A Python loop executes during tracing, adding additional ops to the tf.Graph
(https://www.tensorflow.org/api_docs/python/tf/Graph) for every iteration of the loop.
A TensorFlow loop traces the body of the loop, and dynamically selects how many iterations to
run at execution time. The loop body only appears once in the generated tf.Graph
(https://www.tensorflow.org/api_docs/python/tf/Graph).
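To see the difference in graph size, this sketch (with hypothetical `python_loop` and `tf_loop` functions) counts the operations in each traced graph via get_concrete_function(...).graph.get_operations(). The unrolled Python loop records one copy of the body per iteration, so its graph is much larger.

```python
import tensorflow as tf

@tf.function
def python_loop(x):
  for _ in range(100):     # Python range: unrolled, one body copy per iteration
    x = x + 1
  return x

@tf.function
def tf_loop(x):
  for _ in tf.range(100):  # tf.range: converted to a single tf.while_loop
    x = x + 1
  return x

def num_ops(fn):
  spec = tf.TensorSpec(shape=[], dtype=tf.int32)
  return len(fn.get_concrete_function(spec).graph.get_operations())

python_ops, tf_ops = num_ops(python_loop), num_ops(tf_loop)
print(python_ops, tf_ops)
py_result = python_loop(tf.constant(0))
tf_result = tf_loop(tf.constant(0))
```

Both functions compute the same value; only the structure of the recorded graph differs.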
@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x)  # Some dummy computation.
  return loss
def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))
small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (
train(<_FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32,
train(<_FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32,
Reading data from files via TFRecordDataset, CsvDataset, etc. is the most effective way to
consume data, as then TensorFlow itself can manage the asynchronous loading and
prefetching of data, without having to involve Python. To learn more, see the tf.data: Build
TensorFlow input pipelines (/guide/data) guide.
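Accumulating values in a loop
A common pattern is to accumulate intermediate values from a loop. Inside a tf.function, use
tf.TensorArray (https://www.tensorflow.org/api_docs/python/tf/TensorArray) for this instead of
appending to a Python list, which is only updated while tracing.
batch_size = 2
seq_len = 3
feature_size = 4
def rnn_step(inp, state):
  return inp + state
@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]
  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])
dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))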
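A smaller sketch of the same tf.TensorArray pattern, assuming a hypothetical `running_sum` function that collects each partial sum produced inside a tf.range loop and stacks them into one tensor.

```python
import tensorflow as tf

@tf.function
def running_sum(values):
  # The TensorArray is a loop variable of the tf.while_loop that AutoGraph
  # builds from the tf.range loop, so every iteration's write is kept.
  n = tf.shape(values)[0]
  sums = tf.TensorArray(tf.float32, size=n)
  total = tf.constant(0.0)
  for i in tf.range(n):
    total += values[i]
    sums = sums.write(i, total)
  return sums.stack()

result = running_sum(tf.constant([1.0, 2.0, 3.0]))
print(result.numpy())  # [1. 3. 6.]
```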
Limitations
tf.function (https://www.tensorflow.org/api_docs/python/tf/function) has a few limitations by
design that you should be aware of when converting a Python function to a tf.function
(https://www.tensorflow.org/api_docs/python/tf/function).
Side effects, like printing, appending to lists, and mutating globals, can behave unexpectedly
inside a tf.function (https://www.tensorflow.org/api_docs/python/tf/function), sometimes
executing twice or not at all. They only happen the first time you call a tf.function
(https://www.tensorflow.org/api_docs/python/tf/function) with a set of inputs. Afterwards, the traced
tf.Graph (https://www.tensorflow.org/api_docs/python/tf/Graph) is re-executed, without executing
the Python code.
The general rule of thumb is to avoid relying on Python side effects in your logic and only use
them to debug your traces. Otherwise, TensorFlow APIs like tf.data
(https://www.tensorflow.org/api_docs/python/tf/data), tf.print
(https://www.tensorflow.org/api_docs/python/tf/print), tf.summary
(https://www.tensorflow.org/api_docs/python/tf/summary), tf.Variable.assign
(https://www.tensorflow.org/api_docs/python/tf/Variable#assign), and tf.TensorArray
(https://www.tensorflow.org/api_docs/python/tf/TensorArray) are the best way to ensure your code
will be executed by the TensorFlow runtime with each call.
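A minimal sketch of the difference, assuming a hypothetical `step` function: a Python list records only the trace, while a tf.Variable updated with assign_add is changed by the graph on every call.

```python
import tensorflow as tf

python_log = []            # plain Python state: appended to only while tracing
tf_calls = tf.Variable(0)  # TensorFlow state: updated by the graph on every call

@tf.function
def step(x):
  python_log.append(1)     # Python side effect, runs at trace time only
  tf_calls.assign_add(1)   # recorded in the graph, runs on every call
  return x + 1

for _ in range(3):
  step(tf.constant(1))

print(len(python_log), tf_calls.numpy())  # 1 3
```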
@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)
f(1)
f(1)
f(2)
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2
If you would like to execute Python code during each invocation of a tf.function
(https://www.tensorflow.org/api_docs/python/tf/function), tf.py_function is an exit hatch. The
drawbacks of tf.py_function (https://www.tensorflow.org/api_docs/python/tf/py_function) are that
it's not portable or particularly performant, cannot be saved with SavedModel, and does not
work well in distributed (multi-GPU, TPU) setups. Also, since tf.py_function has to be wired
into the graph, it casts all inputs and outputs to tensors.
@tf.py_function(Tout=tf.float32)
def py_plus(x, y):
  print('Executing eagerly.')
  return x + y
@tf.function
def tf_wrapper(x, y):
  print('Tracing.')
  return py_plus(x, y)
tf_wrapper(tf.constant(1.0), tf.constant(2.0)).numpy()
Tracing.
Executing eagerly.
3.0
tf_wrapper(tf.constant(1.0), tf.constant(2.0)).numpy()
Executing eagerly.
3.0
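tf.py_function can also be called directly instead of being used as a decorator. In this sketch, `log_and_add` and `call_log` are hypothetical; the Python list grows on every call because the wrapped function runs in the Python runtime each time the graph executes.

```python
import tensorflow as tf

call_log = []  # mutated on every execution of the py_function body

def log_and_add(x, y):
  # Runs eagerly in the Python runtime each time the graph executes this op,
  # so `x` is an eager tensor and .numpy() is available here.
  call_log.append(float(x.numpy()))
  return x + y

@tf.function
def wrapper(x, y):
  return tf.py_function(func=log_and_add, inp=[x, y], Tout=tf.float32)

results = [float(wrapper(tf.constant(float(i)), tf.constant(1.0)).numpy())
           for i in range(3)]
print(results, call_log)  # [1.0, 2.0, 3.0] [0.0, 1.0, 2.0]
```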
external_list = []
@tf.function
def side_effect(x):
  print('Python side effect')
  external_list.append(x)
side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Sometimes unexpected behaviors are very hard to notice. In the example below, the counter is
intended to safeguard the increment of a variable. However, because it is a Python integer and
not a TensorFlow object, its value is captured during the first trace. When the tf.function
(https://www.tensorflow.org/api_docs/python/tf/function) is used, the assign_add will be recorded
unconditionally in the underlying graph. Therefore v will increase by 1 every time the
tf.function (https://www.tensorflow.org/api_docs/python/tf/function) is called. This issue is
common among users who try to migrate their Graph-mode TensorFlow code to TensorFlow 2
using tf.function (https://www.tensorflow.org/api_docs/python/tf/function) decorators, when
Python side effects (the counter in the example) are used to determine what ops to run
(assign_add in the example). Usually, users realize this only after seeing suspicious numerical
results, or significantly lower performance than expected (e.g. if the guarded operation is very
costly).
class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0
  @tf.function
  def __call__(self):
    if self.counter == 0:
      # A Python side effect
      self.counter += 1
      self.v.assign_add(1)
    return self.v
m = Model()
for n in range(3):
  print(m().numpy())  # prints 1, 2, 3
1
2
3
class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0
  @tf.function
  def __call__(self):
    if self.counter == 0:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.counter += 1
        self.v.assign_add(1)
    return self.v
m = Model()
for n in range(3):
  print(m().numpy())  # prints 1, 1, 1
1
1
1
In summary, as a rule of thumb, you should avoid mutating Python objects such as integers or
containers like lists that live outside the tf.function
(https://www.tensorflow.org/api_docs/python/tf/function). Instead, use arguments and TF objects.
For example, the section "Accumulating values in a loop" (#accumulating_values_in_a_loop) has
one example of how list-like operations can be implemented.
Many Python features, such as generators and iterators, rely on the Python runtime to keep
track of state. In general, while these constructs work as expected in eager mode, they are
examples of Python side effects and therefore only happen during tracing.
@tf.function
def buggy_consume_next(iterator):
  tf.print("Value:", next(iterator))
iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1
Value: 1
Value: 1
@tf.function
def good_consume_next(iterator):
  # This is ok, iterator is a tf.data.Iterator
  tf.print("Value:", next(iterator))
ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1
Value: 2
Value: 3
For example, the function below "leaks" the tensor a through the Python global x:
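x = None
@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return a + 2
correct_a = leaky_function(tf.constant(1))
print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)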
3
'SymbolicTensor' object has no attribute 'numpy'
@tf.function
def leaky_function(a):
  global x
  x = a + 1  # Bad - leaks local tensor
  return x  # Good - uses local tensor
correct_a = leaky_function(tf.constant(1))
print(correct_a.numpy())  # Good - value obtained from function's returns
try:
  x.numpy()  # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
  print(expected)
@tf.function
def captures_leaked_tensor(b):
  b += x  # Bad - `x` is leaked from `leaky_function`
  return b
with assert_raises(TypeError):
  captures_leaked_tensor(tf.constant(2))
2
'SymbolicTensor' object has no attribute 'numpy'
Caught expected exception
<class 'TypeError'>:
Traceback (most recent call last):
File "/tmpfs/tmp/ipykernel_167534/3551158538.py", line 8, in assert_raises
yield
File "/tmpfs/tmp/ipykernel_167534/566849597.py", line 21, in <module>
captures_leaked_tensor(tf.constant(2))
TypeError: <tf.Tensor 'add:0' shape=() dtype=int32> is out of scope and cannot
Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunctio
Usually, leaks such as these occur when you use Python statements or data structures. In
addition to leaking inaccessible tensors, such statements are also likely wrong because they
count as Python side effects, and are not guaranteed to execute at every function call.
Common ways to leak local tensors also include mutating an external Python collection, or an
object:
class MyClass:
  def __init__(self):
    self.field = None
external_list = []
external_object = MyClass()
def leaky_function():
  a = tf.constant(1)
  external_list.append(a)  # Bad - leaks tensor
  external_object.field = a  # Bad - leaks tensor
@tf.function
def recursive_fn(n):
  if n > 0:
    return recursive_fn(n - 1)
  else:
    return 1
with assert_raises(Exception):
  recursive_fn(tf.constant(5))  # Bad - maximum recursion error.
@tf.function
def recursive_fn(n):
  if n > 0:
    print('tracing')
    return recursive_fn(n - 1)
  else:
    return 1
recursive_fn(5)  # Ok - works, because `n` is a Python value
tracing
tracing
tracing
tracing
tracing
<tf.Tensor: shape=(), dtype=int32, numpy=1>
Known Issues
If your tf.function (https://www.tensorflow.org/api_docs/python/tf/function) is not evaluating
correctly, the error may be explained by these known issues which are planned to be fixed in
the future.
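A tf.function does not retrace when the Python globals or free variables it closes over change
between calls; it keeps using the values they had when it was traced. For that reason, you
should follow a functional programming style that uses arguments instead of closing over outer
names.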
@tf.function
def buggy_add():
  return 1 + foo
@tf.function
def recommended_add(foo):
  return 1 + foo
foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))
@tf.function
def variable_add():
  return 1 + foo
foo = tf.Variable(1)
print("Variable:", variable_add())
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
For maximum feature coverage, consider transforming the objects into Extension types
(/guide/extension_type) before passing them to tf.function
(https://www.tensorflow.org/api_docs/python/tf/function). You can also use Python primitives and
tf.nest (https://www.tensorflow.org/api_docs/python/tf/nest)-compatible structures.
However, as covered in the rules of tracing (#rules_of_tracing), when a custom TraceType is not
provided by the custom Python class, tf.function
(https://www.tensorflow.org/api_docs/python/tf/function) is forced to use instance-based equality
which means it will not create a new trace when you pass the same object with modified
attributes.
class SimpleModel(tf.Module):
  def __init__(self):
    # These values are *not* tf.Variables.
    self.bias = 0.
    self.weight = 2.
@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias
simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x)) # Didn't change :(
Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)
new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`. `tf.function` already captured its state during tracing.
print(evaluate_no_bias(x))
print("Adding bias!")
new_model.bias += 5.0
# Create new `tf.function` and `ConcreteFunction` since you modified `new_model`
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)
class BetterModel:
  def __init__(self):
    self.bias = tf.Variable(0.)
    self.weight = tf.Variable(2.)
@tf.function
def evaluate(model, x):
  return model.weight * x + model.bias
better_model = BetterModel()
print(evaluate(better_model, x))
print("Adding bias!")
better_model.bias.assign_add(5.0) # Note: instead of better_model.bias += 5
print(evaluate(better_model, x)) # This works!
Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)
Creating tf.Variables
tf.function (https://www.tensorflow.org/api_docs/python/tf/function) only supports singleton
tf.Variable (https://www.tensorflow.org/api_docs/python/tf/Variable)s created once on the first
call, and reused across subsequent function calls. The code snippet below would create a new
tf.Variable (https://www.tensorflow.org/api_docs/python/tf/Variable) in every function call, which
results in a ValueError exception.
Example:
@tf.function
def f(x):
  v = tf.Variable(1.0)
  return v
with assert_raises(ValueError):
  f(1.0)
A common pattern used to work around this limitation is to start with a Python None value,
then conditionally create the tf.Variable (https://www.tensorflow.org/api_docs/python/tf/Variable)
if the value is None:
class Count(tf.Module):
  def __init__(self):
    self.count = None
  @tf.function
  def __call__(self):
    if self.count is None:
      self.count = tf.Variable(0)
    return self.count.assign_add(1)
c = Count()
print(c())
print(c())
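Another way to satisfy the singleton rule, sketched here with a hypothetical `Counter` module, is to create the tf.Variable eagerly in `__init__`, outside any trace, and only update it inside the tf.function.

```python
import tensorflow as tf

class Counter(tf.Module):
  def __init__(self):
    super().__init__()
    # Created eagerly, exactly once, outside any tf.function trace.
    self.count = tf.Variable(0)

  @tf.function
  def increment(self):
    # Only updates the existing variable; no variable is created here.
    return self.count.assign_add(1)

counter = Counter()
values = [int(counter.increment().numpy()) for _ in range(3)]
print(values)  # [1, 2, 3]
```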
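When a single tf.function is used with more than one Keras optimizer, you may encounter the same
ValueError, because optimizers create tf.Variables internally the first time they apply gradients:
opt1 = tf.keras.optimizers.Adam(learning_rate=1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate=1e-3)
@tf.function
def train_step(w, x, y, optimizer):
  with tf.GradientTape() as tape:
    L = tf.reduce_sum(tf.square(w*x - y))
  gradients = tape.gradient(L, [w])
  optimizer.apply_gradients(zip(gradients, [w]))
w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])
train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
  train_step(w, x, y, opt2)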
If you need to change a stateful object between calls, it's simplest to define a tf.Module
(https://www.tensorflow.org/api_docs/python/tf/Module) subclass, and create instances to hold
those objects:
class TrainStep(tf.Module):
  def __init__(self, optimizer):
    self.optimizer = optimizer
  @tf.function
  def __call__(self, w, x, y):
    with tf.GradientTape() as tape:
      L = tf.reduce_sum(tf.square(w*x - y))
    gradients = tape.gradient(L, [w])
    self.optimizer.apply_gradients(zip(gradients, [w]))
train_o1 = TrainStep(opt1)
train_o2 = TrainStep(opt2)
train_o1(w, x, y)
train_o2(w, x, y)
You could also do this manually by creating multiple instances of the @tf.function
(https://www.tensorflow.org/api_docs/python/tf/function) wrapper, one for each optimizer:
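# Not a tf.function.
def train_step(w, x, y, optimizer):
  with tf.GradientTape() as tape:
    L = tf.reduce_sum(tf.square(w*x - y))
  gradients = tape.gradient(L, [w])
  optimizer.apply_gradients(zip(gradients, [w]))
w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])
# Make a new tf.function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step)
train_step_2 = tf.function(train_step)
for i in range(10):
  if i % 2 == 0:
    train_step_1(w, x, y, opt1)
  else:
    train_step_2(w, x, y, opt2)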
You may also encounter a similar ValueError when passing different Keras model instances to the
same tf.function. This error occurs because Keras models (which do not have their input shape defined
(https://www.tensorflow.org/guide/keras/custom_layers_and_models#best_practice_deferring_weight_creation_until_the_shape_of_the_inputs_is_known))
and Keras layers create tf.Variable (https://www.tensorflow.org/api_docs/python/tf/Variable)s
when they are first called. You may be attempting to initialize those variables inside a
tf.function (https://www.tensorflow.org/api_docs/python/tf/function), which has already been
called. To avoid this error, try calling model.build(input_shape) to initialize all the weights
before training the model.
Further reading
To learn about how to export and load a tf.function
(https://www.tensorflow.org/api_docs/python/tf/function), see the SavedModel guide
(https://www.tensorflow.org/guide/saved_model). To learn more about graph optimizations that are
performed after tracing, see the Grappler guide
(https://www.tensorflow.org/guide/graph_optimization). To learn how to optimize your data pipeline
and profile your model, see the Profiler guide (https://www.tensorflow.org/guide/profiler).
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License
(https://creativecommons.org/licenses/by/4.0/), and code samples are licensed under the Apache 2.0 License
(https://www.apache.org/licenses/LICENSE-2.0). For details, see the Google Developers Site Policies
(https://developers.google.com/site-policies). Java is a registered trademark of Oracle and/or its affiliates.