Class ModelAverageCustomGetter
Defined in tensorflow/contrib/opt/python/training/model_average_optimizer.py.
This custom_getter class does two things:
- Moves trainable variables into the local variables collection and places them on the worker device.
- Generates the corresponding global variables.
Note that the class should be used with tf.train.replica_device_setter so that the global center variables and the global step variable are placed on the ps device. Also, use tf.get_variable instead of tf.Variable so that variable creation goes through this custom getter.
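The two behaviors listed above can be illustrated without TensorFlow. This is a minimal sketch under stated assumptions, not the library's implementation: Var, base_getter, SketchModelAverageGetter, and the collection-name constants are hypothetical stand-ins, the global name prefix global_center_variable is assumed, and ps_device is passed explicitly only because the sketch has no replica_device_setter to decide placement.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List

# Hypothetical stand-ins for TF1 graph collection keys.
LOCAL_VARIABLES = "local_variables"
GLOBAL_VARIABLES = "global_variables"
GLOBAL_PREFIX = "global_center_variable"  # assumed name prefix, illustration only


@dataclass
class Var:
    """Toy variable record; not a TensorFlow object."""
    name: str
    value: Any
    device: str
    trainable: bool
    collections: List[str] = field(default_factory=list)


def base_getter(name, value, device, trainable, collections):
    """Stand-in for the default getter: it just builds a Var."""
    return Var(name, value, device, trainable, list(collections))


class SketchModelAverageGetter:
    """Hypothetical analogue of the custom getter's routing logic."""

    def __init__(self, worker_device: str, ps_device: str):
        self.worker_device = worker_device
        self.ps_device = ps_device  # in TF, replica_device_setter decides this
        self.local_to_global: Dict[str, Var] = {}

    def __call__(self, getter, name, trainable, collections, value):
        if trainable:
            # Trainable: the working copy lives in the local collection on the worker.
            local = getter(name, value, self.worker_device, True, [LOCAL_VARIABLES])
            # Generate a non-trainable global copy initialized from the local one.
            center = getter(f"{GLOBAL_PREFIX}/{name}", local.value,
                            self.ps_device, False, [GLOBAL_VARIABLES])
            self.local_to_global[name] = center
            return local
        # Non-trainable: keep the caller's collections; local ones stay on the worker.
        device = self.worker_device if LOCAL_VARIABLES in collections else self.ps_device
        return getter(name, value, device, False, collections)


# Usage: one trainable weight and a non-trainable step counter.
ma_getter = SketchModelAverageGetter("/job:worker/task:0", "/job:ps/task:0")
hid_w = ma_getter(base_getter, "hid_w", True, [GLOBAL_VARIABLES], 0.5)
step = ma_getter(base_getter, "global_step", False, [GLOBAL_VARIABLES], 0)
```

The caller gets back the local copy, which is what the model trains on; the getter itself only records the local-to-global pairing and leaves any averaging to the code that consumes that mapping.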
For example:

    # TensorFlow 1.x, where tf.contrib is available. worker_device, cluster,
    # IMAGE_PIXELS and FLAGS are assumed to be defined elsewhere.
    import tensorflow as tf
    from tensorflow.contrib.opt import ModelAverageCustomGetter

    ma_custom_getter = ModelAverageCustomGetter(worker_device)
    with tf.device(
        tf.train.replica_device_setter(
            worker_device=worker_device,
            ps_device="/job:ps/cpu:0",
            cluster=cluster)), \
        tf.variable_scope('', custom_getter=ma_custom_getter):
      hid_w = tf.get_variable(
          initializer=tf.truncated_normal(
              [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
              stddev=1.0 / IMAGE_PIXELS),
          name="hid_w")
      hid_b = tf.get_variable(initializer=tf.zeros([FLAGS.hidden_units]),
                              name="hid_b")
__init__
__init__(worker_device)
Create a new ModelAverageCustomGetter.
Args:
worker_device: String. Name of the worker job.
Methods
tf.contrib.opt.ModelAverageCustomGetter.__call__
__call__(
getter,
name,
trainable,
collections,
*args,
**kwargs
)
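Call self as a function. The variable scope invokes this method for every tf.get_variable call made under it, passing the underlying getter as the first argument. Trainable variables are created on the worker device in the local variables collection and a matching global variable is generated; other variables are forwarded to getter with their original trainable and collections arguments.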