tf.function

Compiles a function into a callable TensorFlow graph.

tf.function(
    func=None, input_signature=None, autograph=True, experimental_implements=None,
    experimental_autograph_options=None, experimental_relax_shapes=False,
    experimental_compile=None
)

tf.function constructs a callable that executes a TensorFlow graph (tf.Graph) created by trace-compiling the TensorFlow operations in func, effectively executing func as a TensorFlow graph.

Example usage:

>>> @tf.function
... def f(x, y):
...   return x ** 2 + y
>>> x = tf.constant([2, 3])
>>> y = tf.constant([3, -2])
>>> f(x, y)
<tf.Tensor: ... numpy=array([7, 7], ...)>

Features

func may use data-dependent control flow, including if, for, while, break, continue and return statements:

>>> @tf.function
... def f(x):
...   if tf.reduce_sum(x) > 0:
...     return x * x
...   else:
...     return -x // 2
>>> f(tf.constant(-2))
<tf.Tensor: ... numpy=1>
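
While loops, break and continue are handled the same way; a minimal sketch of a data-dependent while loop:

>>> @tf.function
... def g(x):
...   while tf.reduce_sum(x) > 1:
...     x = x / 2
...   return x
>>> g(tf.constant([4.0]))
<tf.Tensor: ... numpy=array([1.], ...)>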

func's closure may include tf.Tensor and tf.Variable objects:

>>> @tf.function
... def f():
...   return x ** 2 + y
>>> x = tf.constant([-2, -3])
>>> y = tf.Variable([3, -2])
>>> f()
<tf.Tensor: ... numpy=array([7, 7], ...)>

func may also use ops with side effects, such as tf.print, tf.Variable and others:

>>> v = tf.Variable(1)
>>> @tf.function
... def f(x):
...   for i in tf.range(x):
...     v.assign_add(i)
>>> f(3)
>>> v
<tf.Variable ... numpy=4>

Important: Any Python side-effects (appending to a list, printing with print, etc.) will only happen once, when func is traced. To have side effects executed each time your tf.function is called, they need to be written as TF ops:

>>> l = []
>>> @tf.function
... def f(x):
...   for i in x:
...     l.append(i + 1)    # Caution! Will only happen once when tracing
>>> f(tf.constant([1, 2, 3]))
>>> l
[<tf.Tensor ...>]

Instead, use TensorFlow collections like tf.TensorArray:

>>> @tf.function
... def f(x):
...   ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True)
...   for i in range(len(x)):
...     ta = ta.write(i, x[i] + 1)
...   return ta.stack()
>>> f(tf.constant([1, 2, 3]))
<tf.Tensor: ..., numpy=array([2, 3, 4], ...)>
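
The same distinction applies to printing: Python's print runs only while tracing, whereas tf.print is traced as an op and runs on every call. A minimal sketch:

>>> @tf.function
... def f(x):
...   print("Traced with", x)       # Python side effect: happens once, at tracing
...   tf.print("Executed with", x)  # TF op: happens on every call
>>> f(tf.constant(1))
Traced with Tensor(...)
Executed with 1
>>> f(tf.constant(2))
Executed with 2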

tf.function is polymorphic

Internally, tf.function can build more than one graph to support arguments with different data types or shapes, since TensorFlow can generate graphs that are more efficient when specialized to particular shapes and dtypes. tf.function also treats pure Python values as opaque objects, and builds a separate graph for each set of Python arguments that it encounters.
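
One way to observe this is with a Python print call, a side effect that fires only while tracing; each new input dtype below triggers a fresh trace:

>>> @tf.function
... def double(a):
...   print('Tracing with', a)
...   return a + a
>>> double(tf.constant(1))
Tracing with Tensor(...)
<tf.Tensor: ... numpy=2>
>>> double(tf.constant(1.5))
Tracing with Tensor(...)
<tf.Tensor: ... numpy=3.0>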

To obtain an individual graph, use the get_concrete_function method of the callable created by tf.function. It can be called with the same arguments as func and returns a ConcreteFunction object that wraps a single tf.Graph:

>>> @tf.function
... def f(x):
...   return x + 1
>>> isinstance(f.get_concrete_function(1).graph, tf.Graph)
True
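
The returned concrete function is itself callable, though only with inputs compatible with the signature it was traced with; a minimal sketch:

>>> cf = f.get_concrete_function(tf.constant(1))
>>> cf(tf.constant(5))
<tf.Tensor: ... numpy=6>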

Caution: Passing Python scalars or lists as arguments to tf.function will build a new graph for every distinct value. To avoid this, pass numeric arguments as Tensors whenever possible:

>>> @tf.function
... def f(x):
...   return tf.abs(x)
>>> f1 = f.get_concrete_function(1)
>>> f2 = f.get_concrete_function(2)  # Slow - builds new graph
>>> f1 is f2
False
>>> f1 = f.get_concrete_function(tf.constant(1))
>>> f2 = f.get_concrete_function(tf.constant(2))  # Fast - reuses f1
>>> f1 is f2
True

Python numerical arguments should only be used when they take few distinct values, such as hyperparameters like the number of layers in a neural network.
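
A hedged sketch of an acceptable Python argument: a boolean flag that only ever takes two values, so at most two graphs are traced (apply_dropout is a hypothetical helper, not part of the API):

>>> @tf.function
... def apply_dropout(x, training):  # `training` is a Python bool
...   if training:
...     return tf.nn.dropout(x, rate=0.5)
...   return x
>>> _ = apply_dropout(tf.ones([2, 2]), training=True)   # traces the training graph
>>> _ = apply_dropout(tf.ones([2, 2]), training=False)  # traces the inference graph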

Input signatures

For Tensor arguments, tf.function instantiates a separate graph for every unique set of input shapes and datatypes. The example below creates two separate graphs, each specialized to a different shape:

>>> @tf.function
... def f(x):
...   return x + 1
>>> vector = tf.constant([1.0, 1.0])
>>> matrix = tf.constant([[3.0]])
>>> f.get_concrete_function(vector) is f.get_concrete_function(matrix)
False

An "input signature" can be optionally provided to tf.function to control the graphs traced. The input signature specifies the shape and type of each Tensor argument to the function using a tf.TensorSpec object. More general shapes can be used. This is useful to avoid creating multiple graphs when Tensors have dynamic shapes. It also restricts the dhape and datatype of Tensors that can be used:

>>> @tf.function(
...     input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
... def f(x):
...   return x + 1
>>> vector = tf.constant([1.0, 1.0])
>>> matrix = tf.constant([[3.0]])
>>> f.get_concrete_function(vector) is f.get_concrete_function(matrix)
True
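
Inputs that do not match the signature are rejected at call time; a minimal sketch (the exact error message may vary):

>>> f(tf.constant([1, 2]))  # int32 does not match the float32 signature
Traceback (most recent call last):
...
ValueError: ...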

Variables may only be created once

tf.function only allows creating new tf.Variable objects when it is called for the first time:

>>> class MyModule(tf.Module):
...   def __init__(self):
...     self.v = None
...
...   @tf.function
...   def call(self, x):
...     if self.v is None:
...       self.v = tf.Variable(tf.ones_like(x))
...     return self.v * x

In general, it is recommended to create stateful objects like tf.Variable outside of tf.function and pass them as arguments.
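
A minimal sketch of this pattern, with the variable created eagerly and passed in as an argument:

>>> v = tf.Variable(2.0)
>>> @tf.function
... def scale(x, v):
...   return v * x
>>> scale(tf.constant(3.0), v)
<tf.Tensor: ... numpy=6.0>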

Args:

func: the function to be compiled. If func is None, tf.function returns a decorator that can be invoked with a single argument - func. In other words, tf.function(input_signature=...)(func) is equivalent to tf.function(func, input_signature=...).
input_signature: A possibly nested sequence of tf.TensorSpec objects specifying the shapes and dtypes of the Tensors that will be supplied to this function. If None, a separate graph is traced for each inferred input signature.
autograph: Whether autograph should be applied on func before tracing a graph. Data-dependent control flow requires autograph=True.
experimental_implements: If provided, contains the name of a "known" function this implements.
experimental_autograph_options: Optional tuple of tf.autograph.experimental.Feature values.
experimental_relax_shapes: When True, tf.function may generate fewer graphs that are less specialized on input shapes.
experimental_compile: If True, the function is always compiled by XLA.

Returns:

If func is not None, returns a callable that will execute the compiled function (and return zero or more tf.Tensor objects). If func is None, returns a decorator that, when invoked with a single func argument, returns a callable equivalent to the case above.