Class InputContext
Defined in tensorflow/python/distribute/distribute_lib.py.
A class wrapping information needed by an input function.
This is a context class that is passed to the user's input function and contains information about the compute replicas and input pipelines. The number of compute replicas (in synchronous training) helps compute the per-input-pipeline batch size from the desired global batch size. Input pipeline information can be used to return a different subset of the input in each input pipeline (e.g., to shard the input pipeline or to use a different input source).
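As an illustration of how an input function might use these fields, the sketch below models the idea in plain Python. It assumes a hypothetical `input_fn` and a hypothetical `Ctx` namedtuple standing in for the context (it is not `tf.distribute.InputContext` itself). Each pipeline keeps a round-robin shard selected by `input_pipeline_id`, similar to how `tf.data.Dataset.shard` splits data, and then batches that shard with the per-replica batch size, one common pattern.

```python
from collections import namedtuple

# Hypothetical stand-in exposing the same three fields as InputContext.
Ctx = namedtuple("Ctx", ["num_input_pipelines", "input_pipeline_id", "num_replicas_in_sync"])


def input_fn(ctx, examples, global_batch_size):
    """Hypothetical input fn: shard the examples, then batch them per replica."""
    # Round-robin shard: pipeline i keeps elements i, i + n, i + 2n, ...
    shard = examples[ctx.input_pipeline_id::ctx.num_input_pipelines]
    # Derive the per-replica batch size from the global batch size.
    if global_batch_size % ctx.num_replicas_in_sync != 0:
        raise ValueError("global_batch_size must be divisible by num_replicas_in_sync")
    batch_size = global_batch_size // ctx.num_replicas_in_sync
    # Split the shard into consecutive batches; the last one may be smaller.
    return [shard[i:i + batch_size] for i in range(0, len(shard), batch_size)]


# Two input pipelines feeding four replicas in sync, global batch size 8.
data = list(range(16))
for pipeline_id in range(2):
    ctx = Ctx(num_input_pipelines=2, input_pipeline_id=pipeline_id, num_replicas_in_sync=4)
    print(pipeline_id, input_fn(ctx, data, global_batch_size=8))
```

Each pipeline sees a disjoint half of the data, and the two shards together cover every element exactly once.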
__init__
__init__(
num_input_pipelines=1,
input_pipeline_id=0,
num_replicas_in_sync=1
)
Initializes an InputContext object.
Args:
num_input_pipelines: the number of input pipelines in a cluster.
input_pipeline_id: the current input pipeline id, should be an int in [0, num_input_pipelines).
num_replicas_in_sync: the number of replicas that are in sync.
Properties
input_pipeline_id
Returns the input pipeline ID.
num_input_pipelines
Returns the number of input pipelines.
num_replicas_in_sync
Returns the number of compute replicas in sync.
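To make the constructor defaults and the read-only properties concrete, here is a plain-Python model. The class name `SimpleInputContext` is hypothetical, and the range check on `input_pipeline_id` is an assumption added for illustration: the documentation only says the id *should* lie in [0, num_input_pipelines).

```python
class SimpleInputContext:
    """Hypothetical plain-Python model of the documented constructor and properties."""

    def __init__(self, num_input_pipelines=1, input_pipeline_id=0, num_replicas_in_sync=1):
        # Assumption: enforce the documented range; not confirmed TensorFlow behavior.
        if not 0 <= input_pipeline_id < num_input_pipelines:
            raise ValueError("input_pipeline_id must be in [0, num_input_pipelines)")
        self._num_input_pipelines = num_input_pipelines
        self._input_pipeline_id = input_pipeline_id
        self._num_replicas_in_sync = num_replicas_in_sync

    @property
    def num_input_pipelines(self):
        """Number of input pipelines in the cluster."""
        return self._num_input_pipelines

    @property
    def input_pipeline_id(self):
        """ID of the current input pipeline."""
        return self._input_pipeline_id

    @property
    def num_replicas_in_sync(self):
        """Number of compute replicas in sync."""
        return self._num_replicas_in_sync


ctx = SimpleInputContext(num_input_pipelines=4, input_pipeline_id=3, num_replicas_in_sync=8)
print(ctx.input_pipeline_id, ctx.num_input_pipelines, ctx.num_replicas_in_sync)  # 3 4 8
```

Because the fields are exposed as properties without setters, assigning to them (e.g. `ctx.input_pipeline_id = 0`) raises `AttributeError` in this model.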
Methods
tf.distribute.InputContext.get_per_replica_batch_size
get_per_replica_batch_size(global_batch_size)
Returns the per-replica batch size.
Args:
global_batch_size: the global batch size, which should be divisible by num_replicas_in_sync.
Returns:
the per-replica batch size.
Raises:
ValueError: if global_batch_size is not divisible by num_replicas_in_sync.
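The divide-or-raise contract described above can be modeled in a few lines. This is a hypothetical stand-alone function, not TensorFlow's source, and its error message is illustrative rather than the library's exact text.

```python
def get_per_replica_batch_size(global_batch_size, num_replicas_in_sync):
    """Hypothetical model: split the global batch evenly across replicas or raise."""
    if global_batch_size % num_replicas_in_sync != 0:
        raise ValueError(
            "global_batch_size %d is not divisible by num_replicas_in_sync %d"
            % (global_batch_size, num_replicas_in_sync))
    # Integer division is exact here because the remainder is zero.
    return global_batch_size // num_replicas_in_sync


print(get_per_replica_batch_size(64, 8))  # 8
try:
    get_per_replica_batch_size(63, 8)
except ValueError as err:
    print("ValueError:", err)
```

With TensorFlow installed, the real method is called on a context instance, e.g. `tf.distribute.InputContext(num_replicas_in_sync=8).get_per_replica_batch_size(64)`, which should likewise return 8.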