tf.compat.v1.tpu.batch_parallel


Shards computation along the batch dimension for parallel execution.

tf.compat.v1.tpu.batch_parallel(
    computation, inputs=None, num_shards=1, infeed_queue=None,
    device_assignment=None, name=None
)

Convenience wrapper around shard().

inputs must be a list of Tensors or None (equivalent to an empty list). Each input is split into num_shards pieces along the 0-th dimension, and computation is applied to each shard in parallel.
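For illustration, a minimal sketch of the splitting behaviour (the tensor names and shapes are assumptions, and a TF1-style graph with an initialized TPU system is assumed):

import tensorflow as tf

def computation(images):
  # Each shard receives 4 of the 8 examples: shape [4, 32, 32, 3] (assumed).
  return tf.reduce_mean(images, axis=[1, 2])

images = tf.random.uniform([8, 32, 32, 3])  # batch of 8, split along dimension 0
outputs = tf.compat.v1.tpu.batch_parallel(
    computation, inputs=[images], num_shards=2)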

Tensors are broadcast to all shards if they are lexically captured by computation. e.g.,

x = tf.constant(7)
def computation():
  return x + 3
... = shard(computation, ...)

The outputs from all shards are concatenated back together along their 0-th dimension.

Inputs and outputs of the computation must be at least rank-1 Tensors.
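Putting the pieces together, a minimal sketch (names, shapes, and the surrounding TPU setup are assumptions) in which a lexically captured tensor is broadcast to every shard and the per-shard outputs are concatenated:

import tensorflow as tf

bias = tf.constant(1.0)          # lexically captured: broadcast to every shard
x = tf.random.uniform([8, 16])   # rank-2 input, split into two [4, 16] shards

def computation(x_shard):
  return x_shard + bias          # per-shard output: shape [4, 16]

outputs = tf.compat.v1.tpu.batch_parallel(
    computation, inputs=[x], num_shards=2)
# outputs is a list; outputs[0] has shape [8, 16], the per-shard results
# concatenated back along the 0-th dimension.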

Args:

computation: A Python function that builds a computation to apply to each shard of the input.
inputs: A list of input tensors or None (equivalent to an empty list). The 0-th dimension of each Tensor must have size divisible by num_shards.
num_shards: The number of shards.
infeed_queue: If not None, the InfeedQueue from which to append a tuple of arguments as inputs to computation.
device_assignment: If not None, a DeviceAssignment describing the mapping between logical cores in the computation and physical cores in the TPU topology. Uses a default device assignment if None. The DeviceAssignment may be omitted if each shard of the computation uses only one core, and there is either only one shard, or the number of shards is equal to the number of cores in the TPU system.
name: (Deprecated) Does nothing.

Returns:

A list of output tensors.

Raises:

ValueError: If num_shards <= 0.