Class StrategyExtended
Aliases:
- Class tf.contrib.distribute.DistributionStrategyExtended
- Class tf.distribute.StrategyExtended
Defined in tensorflow/python/distribute/distribute_lib.py.
Additional APIs for algorithms that need to be distribution-aware.
The intent is that you can write an algorithm in a stylized way and
it will be usable with a variety of different tf.distribute.Strategy
implementations. Each descendant will implement a different strategy
for distributing the algorithm across multiple devices/machines.
Furthermore, these changes can be hidden inside the specific layers
and other library classes that need special treatment to run in a
distributed setting, so that most users' model definition code can
run unchanged. The tf.distribute.Strategy API works the same way
with eager and graph execution.
First let's introduce a few high-level concepts:
- Data parallelism is where we run multiple copies of the model on different slices of the input data. This is in contrast to model parallelism where we divide up a single copy of a model across multiple devices. Note: we only support data parallelism for now, but hope to add support for model parallelism in the future.
- A replica is one copy of the model, running on one slice of the input data.
- Synchronous, or more commonly sync, training is where the updates from each replica are aggregated together before updating the model variables. This is in contrast to asynchronous, or async training, where each replica updates the model variables independently.
- Furthermore you might run your computation on multiple devices on one machine (or "host"), or on multiple machines/hosts. If you are running on multiple machines, you might have a single master host that drives computation across all of them, or you might have multiple clients driving the computation asynchronously.
To distribute an algorithm, we might use some of these ingredients:
- Parameter servers: These are hosts that hold a single copy of parameters/variables. All replicas that want to operate on a variable retrieve it at the beginning of a step and send an update to be applied at the end of the step. Can support either sync or async training.
- Mirrored variables: These are variables that are copied to multiple devices, where we keep the copies in sync by applying the same updates to every copy. Normally would only be used with sync training.
- Reductions and Allreduce: A reduction is some method of aggregating multiple values into one value, like "sum" or "mean". If doing sync training, we will perform a reduction on the gradients to a parameter from all replicas before applying the update. Allreduce is an algorithm for performing a reduction on values from multiple devices and making the result available on all of those devices.
- In the future we will have support for TensorFlow's partitioned variables, where a single variable is split across multiple devices.
We then have a few approaches we want to support:
- Code written (as if) with no knowledge of class tf.distribute.Strategy. This code should work as before, even if some of the layers, etc. used by that code are written to be distribution-aware. This is done by having a default tf.distribute.Strategy that gives ordinary behavior, and by default being in a single replica context.
- Ordinary model code that you want to run using a specific tf.distribute.Strategy. This can be as simple as:
with my_strategy.scope():
  iterator = my_strategy.make_dataset_iterator(dataset)
  session.run(iterator.initialize())
  replica_train_ops = my_strategy.extended.call_for_each_replica(
      replica_fn, args=(iterator.get_next(),))
  train_op = my_strategy.group(replica_train_ops)
This takes an ordinary dataset and replica_fn and runs it
distributed using a particular tf.distribute.Strategy in
my_strategy. Any variables created in replica_fn are created
using my_strategy's policy, and library functions called by
replica_fn can use the get_replica_context() API to get enhanced
behavior in this case.
- If you want to write a distributed algorithm, you may use any of
the tf.distribute.Strategy APIs inside a with my_strategy.scope():
block of code.
Lower-level concepts:
- Wrapped values: In order to represent values parallel across devices (either replicas or the devices associated with a particular value), we wrap them in a "PerReplica" or "Mirrored" object that contains a map from device to values. "PerReplica" is used when the value may be different across replicas, and "Mirrored" when the value is the same.
- Unwrapping and merging: Consider calling a function fn on multiple replicas, like extended.call_for_each_replica(fn, args=[w]) with an argument w that is a wrapped value. This means w will have a map taking replica device d0 to w0, replica device d1 to w1, etc. extended.call_for_each_replica() unwraps w before calling fn, so it calls fn(w0) on d0, fn(w1) on d1, etc. It then merges the return values from fn(), which can possibly result in wrapped values. For example, let's say fn() returns a tuple with three components: (x, a, v0) from replica 0, (x, b, v1) on replica 1, etc. If the first component is the same object x from every replica, then the first component of the merged result will also be x. If the second component is different (a, b, ...) from each replica, then the merged value will have a wrapped map from replica device to the different values. If the third component is the members of a mirrored variable (v maps d0 to v0, d1 to v1, etc.), then the merged result will be that mirrored variable (v). A sketch of this unwrap/merge behavior follows this list.
- Replica context vs. cross-replica context: replica context is when we are in some function that is being called once for each replica. Otherwise we are in cross-replica context, which is useful for calling tf.distribute.Strategy methods which operate across the replicas (like reduce_to()). By default you start in a replica context (the default "single replica context") and then some methods can switch you back and forth, as described below.
- Worker devices vs. parameter devices: Most replica computations will happen on worker devices. Since we don't yet support model parallelism, there will be one worker device per replica. When using parameter servers (see above), the set of devices holding variables may be different; otherwise the parameter devices might match the worker devices.
- Non-slot devices are some subset of the parameter devices where we put all the non-slot variables. We need to ensure that all non-slot variables are allocated on the same device, or mirrored across the same set of devices. If you have some variable you want to colocate all the non-slot variables with, you can use colocate_vars_with() to get the remaining non-slot variables on the same device. Otherwise you can use non_slot_devices() to pick a consistent set of devices to pass to both colocate_vars_with() and update_non_slot().
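As a rough illustration of the unwrap/merge behavior, here is a minimal sketch assuming a tf.distribute.MirroredStrategy over two GPUs in TF 1.x graph mode; the strategy choice, device list, and replica_fn below are illustrative and not part of this API:
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/gpu:0", "/gpu:1"])  # assumed devices

def replica_fn():
  ctx = tf.distribute.get_replica_context()
  # Each replica returns a different value, so the merged result is a
  # PerReplica-wrapped value.
  return tf.cast(ctx.replica_id_in_sync_group, tf.float32)

with strategy.scope():
  per_replica = strategy.extended.call_for_each_replica(replica_fn)
  # unwrap() turns the wrapped value back into one tensor per replica device.
  print(strategy.unwrap(per_replica))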
When using a tf.distribute.Strategy, we have a new type dimension
called locality that says what values are compatible with which
APIs:
- T: different value for each replica (e.g. a PerReplica-wrapped value).
- M: value is "mirrored" across replicas, i.e. there are copies with the same value on each replica (e.g. a Mirrored-wrapped value).
- V(v): value is "mirrored" across all the devices which have a copy of variable v (also a Mirrored-wrapped value, but over parameter devices instead of worker devices).
- N: value is "mirrored" across all the "non-slot" devices.
Rules for methods with respect to locality and single-replica vs. cross-replica context:
- with d.scope(): default single-replica context -> cross-replica context for d
- with d.extended.colocate_vars_with(v): in replica/cross-replica context, variables will be created with locality V(v). That is, if we write with d.extended.colocate_vars_with(v1): v2 = tf.get_variable(...), then v2 will have locality V(v1), i.e. locality V(v2) will equal V(v1).
- with d.extended.colocate_vars_with(d.extended.non_slot_devices(...)): in replica/cross-replica context, variables will be created with locality N
- v = tf.get_variable(...): in replica/cross-replica context, creates a variable (which by definition will have locality V(v), though will match another locality if inside a colocate_vars_with scope).
- d.make_dataset_iterator(dataset) (or the deprecated d.distribute_dataset(dataset).make_one_shot_iterator()): in cross-replica context, produces an iterator with locality T
- d.extended.broadcast_to(t): in cross-replica context, produces a value with locality M
- d.extended.broadcast_to(t, v): in cross-replica context, produces a value with locality V(v)
- d.extended.call_for_each_replica(fn, ...): in cross-replica context, runs fn() in a replica context (and so may call get_replica_context() and use its API, including merge_call() to get back to cross-replica context), once for each replica. May use values with locality T or M, and any variable.
- d.extended.reduce_to(m, t, t): in cross-replica context, accepts t with locality T and produces a value with locality M.
- d.extended.reduce_to(m, t, v): in cross-replica context, accepts t with locality T and produces a value with locality V(v).
- d.extended.batch_reduce_to(m, [(t, v)]): see d.extended.reduce_to()
- d.extended.update(v, fn, ...): in cross-replica context, runs fn() once for each device v is copied to; all inputs should have locality V(v), output will have locality V(v) as well.
- d.extended.update_non_slot(d.extended.non_slot_devices(), fn): in cross-replica context, like d.extended.update() except with locality N.
- d.extended.read_var(v): Gets the (read-only) value of the variable v (on the device determined by the current device scope), aggregating across replicas for replica-local variables. Frequently, this will be done automatically when using v in an expression or fetching it in a cross-replica context, but this function can be used to force that the conversion happens at a particular point in time (for example, to add the result of the conversion to a graph collection).
The standard pattern for updating variables is to:
1. Create an input iterator with d.make_dataset_iterator().
2. Define each replica d.extended.call_for_each_replica() up to the point of getting a list of gradient, variable pairs.
3. Call d.extended.reduce_to(VariableAggregation.SUM, t, v) or d.extended.batch_reduce_to() to sum the gradients (with locality T) into values with locality V(v).
4. Call d.extended.update(v) for each variable to update its value.
Steps 3 and 4 are done automatically by class Optimizer
if you call
its apply_gradients
method in a replica context. Otherwise you can
manually call its _distributed_apply
method in a cross-replica context.
Another thing you might want to do in the middle of your replica function is
an all-reduce of some intermediate value, using d.extended.reduce_to() or
d.extended.batch_reduce_to(). You simply provide the same tensor as the
input and destination.
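For illustration, here is a rough sketch of steps 1-4 above in TF 1.x graph mode, assuming a tf.distribute.MirroredStrategy; the dataset, the fabricated per-replica "gradient", and the names used are placeholders rather than part of this API:
import tensorflow as tf

d = tf.distribute.MirroredStrategy()  # assumed strategy; one CPU replica if no GPUs
dataset = tf.data.Dataset.from_tensors([1.]).repeat()

with d.scope():
  w = tf.get_variable("w", initializer=1.0)

  iterator = d.make_dataset_iterator(dataset)               # step 1

  def replica_fn(x):
    # In real code this is where you would compute the loss and obtain
    # (gradient, variable) pairs; here we fabricate a per-replica value.
    return tf.reduce_sum(x)

  grads = d.extended.call_for_each_replica(                 # step 2
      replica_fn, args=(iterator.get_next(),))

  summed = d.extended.reduce_to(                            # step 3
      tf.distribute.ReduceOp.SUM, grads, w)

  train_op = d.extended.update(                             # step 4
      w, lambda v, g: v.assign_sub(0.1 * g), args=(summed,))

with tf.Session() as sess:
  sess.run(iterator.initialize())
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)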
Layers should expect to be called in a replica context, and can use
the tf.distribute.get_replica_context
function to get a
tf.distribute.ReplicaContext
object. The
ReplicaContext
object has a merge_call()
method for entering
cross-replica context where you can use reduce_to()
(or
batch_reduce_to()
) and then optionally update()
to update state.
You may use this API whether or not a tf.distribute.Strategy is
being used, since there is a default implementation of
ReplicaContext and tf.distribute.Strategy.
NOTE for new tf.distribute.Strategy
implementations: Please put all logic
in a subclass of tf.distribute.StrategyExtended
. The only code needed for
the tf.distribute.Strategy
subclass is for instantiating your subclass of
tf.distribute.StrategyExtended
in the __init__
method.
__init__
__init__(container_strategy)
Initialize self. See help(type(self)) for accurate signature.
Properties
experimental_between_graph
Whether the strategy uses between-graph replication or not.
This is expected to return a constant value that will not be changed throughout its life cycle.
experimental_require_static_shapes
experimental_should_init
Whether initialization is needed.
parameter_devices
Returns the tuple of all devices used to place variables.
should_checkpoint
Whether checkpointing is needed.
should_save_summary
Whether saving summaries is needed.
worker_devices
Returns the tuple of all devices used for replica computation.
Methods
tf.distribute.StrategyExtended.batch_reduce_to
batch_reduce_to(
reduce_op,
value_destination_pairs
)
Combine multiple reduce_to
calls into one for faster execution.
Args:
- reduce_op: Reduction type, an instance of tf.distribute.ReduceOp enum. DEPRECATED but still accepted values: tf.VariableAggregation.SUM, tf.VariableAggregation.MEAN.
- value_destination_pairs: A sequence of (value, destinations) pairs. See reduce_to() for a description.
Returns:
A list of mirrored values, one per pair in value_destination_pairs.
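For illustration, a minimal sketch of batch_reduce_to assuming a tf.distribute.MirroredStrategy in TF 1.x graph mode; the variables and the fabricated per-replica values are placeholders:
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumed strategy

with strategy.scope():
  v1 = tf.get_variable("v1", initializer=0.0)
  v2 = tf.get_variable("v2", initializer=0.0)

  def replica_fn():
    # Stand-ins for per-replica gradients of v1 and v2.
    return tf.constant(1.0), tf.constant(2.0)

  g1, g2 = strategy.extended.call_for_each_replica(replica_fn)

  # One call reduces both (value, destinations) pairs.
  r1, r2 = strategy.extended.batch_reduce_to(
      tf.distribute.ReduceOp.SUM, [(g1, v1), (g2, v2)])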
tf.distribute.StrategyExtended.broadcast_to
broadcast_to(
tensor,
destinations
)
Mirror a tensor on one device to all worker devices.
Args:
- tensor: A Tensor value to broadcast.
- destinations: A mirrored variable or device string specifying the destination devices to copy tensor to.
Returns:
A value mirrored to destinations devices.
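For illustration, a minimal sketch assuming a tf.distribute.MirroredStrategy in TF 1.x graph mode; the variable w is just a placeholder destination:
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumed strategy

with strategy.scope():
  w = tf.get_variable("w", initializer=0.0)  # a mirrored variable
  t = tf.constant(3.0)

  # Copy `t` to every device that holds a copy of `w` (locality V(w)).
  mirrored_t = strategy.extended.broadcast_to(t, destinations=w)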
tf.distribute.StrategyExtended.call_for_each_replica
call_for_each_replica(
fn,
args=(),
kwargs=None
)
Run fn once per replica.
fn may call tf.get_replica_context() to access methods such as
replica_id_in_sync_group and merge_call().
merge_call() is used to communicate between the replicas and
re-enter the cross-replica context. All replicas pause their execution
having encountered a merge_call() call. After that the
merge_fn function is executed. Its results are then unwrapped and
given back to each replica call. After that execution resumes until
fn is complete or encounters another merge_call(). Example:
# Called once in "cross-replica" context.
def merge_fn(distribution, three_plus_replica_id):
  # sum the values across replicas
  return sum(distribution.unwrap(three_plus_replica_id))

# Called once per replica in `distribution`, in a "replica" context.
def fn(three):
  replica_ctx = tf.get_replica_context()
  v = three + replica_ctx.replica_id_in_sync_group
  # Computes the sum of the `v` values across all replicas.
  s = replica_ctx.merge_call(merge_fn, args=(v,))
  return s + v

with distribution.scope():
  # in "cross-replica" context
  ...
  merged_results = distribution.extended.call_for_each_replica(fn, args=[3])
  # merged_results has the values from every replica execution of `fn`.
  print(distribution.unwrap(merged_results))  # Prints a list
Args:
- fn: function to run (will be run once per replica).
- args: Tuple or list with positional arguments for fn.
- kwargs: Dict with keyword arguments for fn.
Returns:
Merged return value of fn across all replicas.
tf.distribute.StrategyExtended.colocate_vars_with
colocate_vars_with(colocate_with_variable)
Scope that controls which devices variables will be created on.
No operations should be added to the graph inside this scope; it should only be used when creating variables (some implementations work by changing variable creation, others work by using a tf.colocate_with() scope).
This may only be used inside self.scope().
Example usage:
with strategy.scope():
  var1 = tf.get_variable(...)
  with strategy.extended.colocate_vars_with(var1):
    # var2 and var3 will be created on the same device(s) as var1
    var2 = tf.get_variable(...)
    var3 = tf.get_variable(...)

  def fn(v1, v2, v3):
    # operates on v1 from var1, v2 from var2, and v3 from var3
    ...

  # `fn` runs on every device `var1` is on; `var2` and `var3` will be there too.
  strategy.extended.update(var1, fn, args=(var2, var3))
Args:
- colocate_with_variable: A variable created in self.scope(). Variables created while in the returned context manager will be on the same set of devices as colocate_with_variable.
Returns:
A context manager.
tf.distribute.StrategyExtended.experimental_run_steps_on_iterator
experimental_run_steps_on_iterator(
fn,
iterator,
iterations=1,
initial_loop_values=None
)
Run fn with input from iterator for iterations times.
This method can be used to run a step function for training a number of times using input from a dataset.
Args:
- fn: function to run using this distribution strategy. The function must have the following signature: def fn(context, inputs). context is an instance of MultiStepContext that will be passed when fn is run. context can be used to specify the outputs to be returned from fn by calling context.set_last_step_output. It can also be used to capture non-tensor outputs by context.set_non_tensor_output. See MultiStepContext documentation for more information. inputs will have the same type/structure as iterator.get_next(). Typically, fn will use the call_for_each_replica method of the strategy to distribute the computation over multiple replicas.
- iterator: Iterator of a dataset that represents the input for fn. The caller is responsible for initializing the iterator as needed.
- iterations: (Optional) Number of iterations that fn should be run. Defaults to 1.
- initial_loop_values: (Optional) Initial values to be passed into the loop that runs fn. Defaults to None. # TODO(priyag): Remove initial_loop_values argument when we have a mechanism to infer the outputs of fn.
Returns:
Returns the MultiStepContext object which has the following properties, among other things:
- run_op: An op that runs fn iterations times.
- last_step_outputs: A dictionary containing tensors set using context.set_last_step_output. Evaluating this returns the value of the tensors after the last iteration.
- non_tensor_outputs: A dictionary containing anything that was set by fn by calling context.set_non_tensor_output.
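For illustration, a rough sketch assuming a tf.distribute.MirroredStrategy in TF 1.x graph mode; the dataset and step_fn are placeholders, and a real step function would typically also update variables and set last-step outputs via context:
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumed strategy
dataset = tf.data.Dataset.from_tensors([1.]).repeat()

with strategy.scope():
  iterator = strategy.make_dataset_iterator(dataset)

  def step_fn(context, inputs):
    # Distribute one step's work over the replicas and return an op/tensor
    # for the loop body to depend on.
    def replica_fn(x):
      return tf.reduce_sum(x)
    per_replica = strategy.extended.call_for_each_replica(
        replica_fn, args=(inputs,))
    return strategy.group(per_replica)

  ctx = strategy.extended.experimental_run_steps_on_iterator(
      step_fn, iterator, iterations=10)

with tf.Session() as sess:
  sess.run(iterator.initialize())
  sess.run(ctx.run_op)  # runs step_fn 10 times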
tf.distribute.StrategyExtended.non_slot_devices
non_slot_devices(var_list)
Device(s) for non-slot variables.
Create variables on these devices in a with colocate_vars_with(non_slot_devices(...)): block. Update those using update_non_slot().
Args:
- var_list: The list of variables being optimized, needed with the default tf.distribute.Strategy.
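For illustration, a minimal sketch assuming a tf.distribute.MirroredStrategy in TF 1.x graph mode; the "step" counter is a placeholder for whatever non-slot state (e.g. an optimizer's accumulators) you need:
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumed strategy

with strategy.scope():
  w = tf.get_variable("w", initializer=1.0)
  non_slot = strategy.extended.non_slot_devices([w])

  # Create non-slot state on the non-slot devices (locality N).
  with strategy.extended.colocate_vars_with(non_slot):
    step = tf.get_variable("step", shape=[], dtype=tf.int64,
                           initializer=tf.zeros_initializer(), trainable=False)

  # Update the non-slot state in cross-replica context.
  inc_op = strategy.extended.update_non_slot(
      non_slot, lambda: step.assign_add(1))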
tf.distribute.StrategyExtended.read_var
read_var(v)
Reads the value of a variable.
Returns the aggregate value of a replica-local variable, or the (read-only) value of any other variable.
Args:
- v: A variable allocated within the scope of this tf.distribute.Strategy.
Returns:
A tensor representing the value of v, aggregated across replicas if necessary.
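For illustration, a minimal sketch assuming a tf.distribute.MirroredStrategy in TF 1.x graph mode; the collection name is just a placeholder:
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumed strategy

with strategy.scope():
  v = tf.get_variable("v", initializer=1.0)
  # Force the variable-to-tensor conversion here, e.g. to put the value
  # into a graph collection.
  v_value = strategy.extended.read_var(v)
  tf.add_to_collection("logged_values", v_value)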
tf.distribute.StrategyExtended.reduce_to
reduce_to(
reduce_op,
value,
destinations
)
Combine (via e.g. sum or mean) values across replicas.
Args:
- reduce_op: Reduction type, an instance of tf.distribute.ReduceOp enum. DEPRECATED but still accepted values: tf.VariableAggregation.SUM, tf.VariableAggregation.MEAN.
- value: A per-replica value with one value per replica.
- destinations: A mirrored variable, a per-replica tensor, or a device string. The return value will be copied to all destination devices (or all the devices where the destinations value resides). To perform an all-reduction, pass value to destinations.
Returns:
A value mirrored to destinations.
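For illustration, a minimal sketch of an all-reduce with reduce_to, assuming a tf.distribute.MirroredStrategy in TF 1.x graph mode; replica_fn is a placeholder for whatever per-replica value you want to aggregate:
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumed strategy

with strategy.scope():
  def replica_fn():
    ctx = tf.distribute.get_replica_context()
    return tf.cast(ctx.replica_id_in_sync_group, tf.float32)

  per_replica = strategy.extended.call_for_each_replica(replica_fn)

  # All-reduce: pass the same value as both `value` and `destinations`,
  # so every replica device receives the mean.
  mean = strategy.extended.reduce_to(
      tf.distribute.ReduceOp.MEAN, per_replica, per_replica)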
tf.distribute.StrategyExtended.update
update(
var,
fn,
args=(),
kwargs=None,
group=True
)
Run fn to update var using inputs mirrored to the same devices.
If var is mirrored across multiple devices, then this implements
logic like:
results = {}
for device, v in var:
  with tf.device(device):
    # args and kwargs will be unwrapped if they are mirrored.
    results[device] = fn(v, *args, **kwargs)
return merged(results)
Otherwise this returns fn(var, *args, **kwargs) colocated with var.
Neither args nor kwargs may contain per-replica values.
If they contain mirrored values, they will be unwrapped before calling fn.
Args:
- var: Variable, possibly mirrored to multiple devices, to operate on.
- fn: Function to call. Should take the variable as the first argument.
- args: Tuple or list. Additional positional arguments to pass to fn().
- kwargs: Dict with keyword arguments to pass to fn().
- group: Boolean. Defaults to True. If False, the return value will be unwrapped.
Returns:
By default, the merged return value of fn
across all replicas. The
merged result has dependencies to make sure that if it is evaluated at
all, the side effects (updates) will happen on every replica. If instead
"group=False" is specified, this function will return a nest of lists
where each list has an element per replica, and the caller is responsible
for ensuring all elements are executed.
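For illustration, a minimal sketch assuming a tf.distribute.MirroredStrategy in TF 1.x graph mode; broadcasting the delta first gives it locality V(w), as update expects:
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumed strategy

with strategy.scope():
  w = tf.get_variable("w", initializer=1.0)

  # Give the delta locality V(w) so it is mirrored to `w`'s devices.
  delta = strategy.extended.broadcast_to(tf.constant(0.5), destinations=w)

  def subtract(v, d):
    # `v` is the per-device copy of `w`; `d` is the unwrapped delta.
    return v.assign_sub(d)

  update_op = strategy.extended.update(w, subtract, args=(delta,))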
tf.distribute.StrategyExtended.update_non_slot
update_non_slot(
colocate_with,
fn,
args=(),
kwargs=None,
group=True
)
Runs fn(*args, **kwargs) on colocate_with devices.
Args:
- colocate_with: The return value of non_slot_devices().
- fn: Function to execute.
- args: Tuple or list. Positional arguments to pass to fn().
- kwargs: Dict with keyword arguments to pass to fn().
- group: Boolean. Defaults to True. If False, the return value will be unwrapped.
Returns:
Return value of fn, possibly merged across devices.
tf.distribute.StrategyExtended.value_container
value_container(value)
Returns the container that this per-replica value belongs to.
Args:
- value: A value returned by call_for_each_replica() or a variable created in scope().
Returns:
A container that value belongs to.
If value does not belong to any container (including the case of container having been destroyed), returns the value itself.
value in unwrap(value_container(value)) will always be true.
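For illustration, a minimal sketch assuming a tf.distribute.MirroredStrategy in TF 1.x graph mode; it checks the invariant stated above:
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumed strategy

with strategy.scope():
  w = tf.get_variable("w", initializer=1.0)  # possibly a mirrored variable
  component = strategy.unwrap(w)[0]          # one per-device component

  # Maps the component back to the container (the mirrored variable itself,
  # or the value itself if it has no container).
  container = strategy.extended.value_container(component)
  assert component in strategy.unwrap(container)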