tf.compat.v1.tpu.shard

Shards computation for parallel execution.

tf.compat.v1.tpu.shard(
    computation, inputs=None, num_shards=1, input_shard_axes=None,
    outputs_from_all_shards=True, output_shard_axes=None, infeed_queue=None,
    device_assignment=None, name=None
)

inputs must be a list of Tensors or None (equivalent to an empty list), each of which has a corresponding split axis (from input_shard_axes). Each input is split into num_shards pieces along the corresponding axis, and computation is applied to each shard in parallel.
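A minimal sketch of this input-splitting behavior (TF1 graph mode assumed; TPU system initialization and session execution are omitted, and the shapes and computation body are illustrative only):

import tensorflow.compat.v1 as tf

def computation(images):
  # Each shard receives a [2, 28, 28, 1] slice of `images` (8 / num_shards).
  return tf.reduce_sum(images, axis=[1, 2, 3])  # rank-1 output per shard

images = tf.random.uniform([8, 28, 28, 1])
# Splits `images` into 4 pieces along axis 0 and applies `computation` to
# each piece in parallel; the per-shard outputs are concatenated back
# together along the default output shard axis (0).
outputs = tf.tpu.shard(
    computation,
    inputs=[images],
    num_shards=4,
    input_shard_axes=[0])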

Tensors are broadcast to all shards if they are lexically captured by computation. e.g.,

x = tf.constant(7)
def computation():
  return x + 3
... = shard(computation, ...)

TODO(phawkins): consider adding support for broadcasting Tensors passed as inputs.

If outputs_from_all_shards is true, the outputs from all shards of computation are concatenated back together along their output_shard_axes. Otherwise, each output is taken from an arbitrary shard.
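A hedged sketch of the two output modes (again omitting TPU initialization and execution; shapes are illustrative):

import tensorflow.compat.v1 as tf

data = tf.ones([8])  # rank-1 input, split into two [4] pieces along axis 0

def computation(x):
  # Per-shard partial sum, kept rank-1 so it can be concatenated.
  return tf.reduce_sum(x, keepdims=True)  # shape [1] on each shard

# Default: the per-shard [1] results are concatenated along axis 0 into [2].
partials = tf.tpu.shard(
    computation, inputs=[data], num_shards=2, output_shard_axes=[0])

# outputs_from_all_shards=False: the result is taken from an arbitrary
# shard, which is only meaningful when every shard computes the same value.
one_copy = tf.tpu.shard(
    computation, inputs=[data], num_shards=2,
    outputs_from_all_shards=False)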

Inputs and outputs of the computation must be at least rank-1 Tensors.

Args:

computation: A Python function that builds the computation to apply to each shard of the input.
inputs: A list of input tensors or None (equivalent to an empty list). Each input has a corresponding shard axis, given by input_shard_axes, whose size must be divisible by num_shards.
num_shards: The number of shards.
input_shard_axes: A list of dimensions along which to shard inputs, or None. None means "shard all inputs along dimension 0". If not None, there must be one dimension per input.
outputs_from_all_shards: Boolean or list of booleans. For each output, if True, outputs from all shards are concatenated along the corresponding output_shard_axes entry; otherwise each output is taken from an arbitrary shard. If a single boolean is given, its value is used for every output.
output_shard_axes: A list of dimensions along which to concatenate the outputs of computation, or None. None means "concatenate all outputs along dimension 0". If not None, there must be one dimension per output. Ignored if outputs_from_all_shards is False.
infeed_queue: If not None, the InfeedQueue to use to augment the inputs of computation.
device_assignment: If not None, a DeviceAssignment describing the mapping between logical cores in the computation and physical cores in the TPU topology. Uses a default device assignment if None.
name: (Deprecated) Does nothing.

Returns:

A list of output tensors.

Raises:

ValueError: If num_shards <= 0.