Class ReplicaContext
Aliases:
- Class tf.contrib.distribute.ReplicaContext
- Class tf.distribute.ReplicaContext
Defined in tensorflow/python/distribute/distribute_lib.py.
The tf.distribute.Strategy API for use when in a replica context.
To be used inside your replicated step function, such as in a tf.distribute.StrategyExtended.call_for_each_replica call.
__init__
__init__(
    strategy,
    replica_id_in_sync_group
)
Initialize self. See help(type(self)) for accurate signature.
Properties
devices
The devices this replica is to be executed on, as a tuple of strings.
num_replicas_in_sync
Returns number of replicas over which gradients are aggregated.
replica_id_in_sync_group
Which replica is being defined, from 0 to num_replicas_in_sync - 1.
strategy
The current tf.distribute.Strategy object.
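As an illustration of how replica_id_in_sync_group and num_replicas_in_sync are commonly used together, here is a minimal sketch in plain Python. FakeReplicaContext, shard_for_replica and scaled_loss are hypothetical names invented for this example and are not part of TensorFlow; the stand-in only mimics the two properties described above. It shows one way a replica might pick its share of a batch and scale its loss so that a sum across replicas yields the global mean (assuming equal-sized shards).

```python
from collections import namedtuple

# Hypothetical stand-in exposing two of the properties listed above.
# A real ReplicaContext is obtained from TensorFlow, not constructed like this.
FakeReplicaContext = namedtuple(
    "FakeReplicaContext",
    ["replica_id_in_sync_group", "num_replicas_in_sync"],
)


def shard_for_replica(ctx, examples):
    """Return every num_replicas-th example, starting at this replica's id."""
    return examples[ctx.replica_id_in_sync_group::ctx.num_replicas_in_sync]


def scaled_loss(ctx, per_example_losses):
    """Local mean loss divided by the replica count.

    Summing this value over all replicas gives the global mean loss,
    provided every replica holds the same number of examples.
    """
    local_mean = sum(per_example_losses) / len(per_example_losses)
    return local_mean / ctx.num_replicas_in_sync


# Usage: 8 per-example losses spread over 4 replicas.
losses = [float(i) for i in range(8)]
contexts = [FakeReplicaContext(i, 4) for i in range(4)]
shards = [shard_for_replica(ctx, losses) for ctx in contexts]
total = sum(scaled_loss(ctx, shard) for ctx, shard in zip(contexts, shards))
```

In this sketch the summed, scaled losses match the mean over the whole batch, which is the reason for dividing by the number of replicas before aggregation.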
Methods
tf.distribute.ReplicaContext.__enter__
__enter__()
tf.distribute.ReplicaContext.__exit__
__exit__(
    exception_type,
    exception_value,
    traceback
)
tf.distribute.ReplicaContext.merge_call
merge_call(
    merge_fn,
    args=(),
    kwargs=None
)
Merge args across replicas and run merge_fn in a cross-replica context.
This allows communication and coordination when there are multiple calls to a model function triggered by a call to strategy.extended.call_for_each_replica(model_fn, ...).
See tf.distribute.StrategyExtended.call_for_each_replica for an explanation.
If not inside a distributed scope, this is equivalent to the following pseudocode:
    strategy = tf.distribute.get_strategy()
    with cross-replica-context(strategy):
        return merge_fn(strategy, *args, **kwargs)
Args:
- merge_fn: Function that joins arguments from threads that are given as PerReplica. It accepts a tf.distribute.Strategy object as the first argument.
- args: List or tuple with positional per-thread arguments for merge_fn.
- kwargs: Dict with keyword per-thread arguments for merge_fn.
Returns:
The return value of merge_fn, except for PerReplica values which are unpacked.
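The merge_call semantics described above can be illustrated with a toy model in plain Python. This is a sketch under stated assumptions, not TensorFlow's implementation: ToyStrategy, ToyReplicaContext and this PerReplica class are hypothetical stand-ins written for the example, and one thread per replica is assumed. Each replica deposits its arguments, all replicas wait at a barrier, merge_fn runs once with the strategy as its first argument and the per-replica values grouped together, and a PerReplica result is unpacked so that each replica gets its own component.

```python
import threading


class PerReplica(tuple):
    """Toy container holding one value per replica (hypothetical stand-in)."""


class ToyStrategy:
    """Hypothetical stand-in for a strategy; not a TensorFlow class."""

    def __init__(self, num_replicas):
        self.num_replicas_in_sync = num_replicas
        self._barrier = threading.Barrier(num_replicas)
        self._slots = [None] * num_replicas
        self._result = None

    def call_for_each_replica(self, fn):
        """Run fn(ctx) once per replica, each in its own thread."""
        outputs = [None] * self.num_replicas_in_sync

        def run(replica_id):
            outputs[replica_id] = fn(ToyReplicaContext(self, replica_id))

        threads = [threading.Thread(target=run, args=(i,))
                   for i in range(self.num_replicas_in_sync)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        return PerReplica(outputs)


class ToyReplicaContext:
    """Toy replica context exposing a merge_call with the documented shape."""

    def __init__(self, strategy, replica_id_in_sync_group):
        self.strategy = strategy
        self.replica_id_in_sync_group = replica_id_in_sync_group

    def merge_call(self, merge_fn, args=(), kwargs=None):
        s, rid = self.strategy, self.replica_id_in_sync_group
        kwargs = kwargs or {}
        s._slots[rid] = (args, kwargs)
        s._barrier.wait()  # every replica has deposited its arguments
        if rid == 0:
            # Only one thread runs merge_fn, with values grouped per replica.
            n = s.num_replicas_in_sync
            merged_args = [PerReplica(s._slots[r][0][i] for r in range(n))
                           for i in range(len(args))]
            merged_kwargs = {k: PerReplica(s._slots[r][1][k] for r in range(n))
                             for k in kwargs}
            s._result = merge_fn(s, *merged_args, **merged_kwargs)
        s._barrier.wait()  # merge_fn has finished; the result is visible
        result = s._result
        # A PerReplica result is unpacked; anything else is shared as-is.
        return result[rid] if isinstance(result, PerReplica) else result


def sum_across_replicas(strategy, values):
    return sum(values)


def double_each(strategy, values):
    return PerReplica(v * 2 for v in values)


def step_sum(ctx):
    local = (ctx.replica_id_in_sync_group + 1) * 10
    return ctx.merge_call(sum_across_replicas, args=(local,))


def step_double(ctx):
    local = (ctx.replica_id_in_sync_group + 1) * 10
    return ctx.merge_call(double_each, kwargs={"values": local})


totals = ToyStrategy(num_replicas=3).call_for_each_replica(step_sum)
doubled = ToyStrategy(num_replicas=3).call_for_each_replica(step_double)
```

In this toy model every replica receives the same total from step_sum, while step_double returns a PerReplica value, so each replica sees only its own doubled input. Error handling is deliberately omitted: if merge_fn raised, the other threads would block at the barrier.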