Class BeamSearchDecoder
Inherits From: Decoder
Defined in tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py.
BeamSearch sampling decoder.
NOTE If you are using the BeamSearchDecoder with a cell wrapped in
AttentionWrapper, then you must ensure that:
- The encoder output has been tiled to beam_width via
tf.contrib.seq2seq.tile_batch (NOT tf.tile).
- The batch_size argument passed to the zero_state method of this wrapper
is equal to true_batch_size * beam_width.
- The initial state created with zero_state above contains a cell_state
value containing the properly tiled final state from the encoder.
An example:
tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
    encoder_outputs, multiplier=beam_width)
tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
    encoder_final_state, multiplier=beam_width)
tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
    sequence_length, multiplier=beam_width)
attention_mechanism = MyFavoriteAttentionMechanism(
    num_units=attention_depth,
    memory=tiled_encoder_outputs,
    memory_sequence_length=tiled_sequence_length)
attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
decoder_initial_state = attention_cell.zero_state(
    dtype, batch_size=true_batch_size * beam_width)
decoder_initial_state = decoder_initial_state.clone(
    cell_state=tiled_encoder_final_state)
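The reason tile_batch is required rather than tf.tile is the memory layout: tile_batch repeats each batch entry contiguously, so all beam_width copies of an example sit next to each other, matching how the decoder reshapes between [batch_size, beam_width] and [batch_size * beam_width]. A pure-Python sketch of the two layouts (list-based stand-ins for illustration, not the TensorFlow ops, which operate on tensors of any rank):

```python
def tile_batch_like(batch, multiplier):
    # tf.contrib.seq2seq.tile_batch repeats each entry `multiplier`
    # times contiguously: [a, b] -> [a, a, b, b] for multiplier=2.
    return [item for item in batch for _ in range(multiplier)]

def tf_tile_like(batch, multiplier):
    # tf.tile along the batch axis repeats the whole batch:
    # [a, b] -> [a, b, a, b] for multiplier=2.
    return batch * multiplier

encoder_outputs = ["enc_a", "enc_b"]
beam_width = 2

print(tile_batch_like(encoder_outputs, beam_width))  # ['enc_a', 'enc_a', 'enc_b', 'enc_b']
print(tf_tile_like(encoder_outputs, beam_width))     # ['enc_a', 'enc_b', 'enc_a', 'enc_b']
```

With the tf.tile layout, beams of different examples would be interleaved, and each decoded beam would attend over the wrong example's memory.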
Meanwhile, when using BeamSearchDecoder with AttentionWrapper, a nonzero
coverage penalty is suggested when computing scores
(https://arxiv.org/pdf/1609.08144.pdf). It encourages the translation to
cover all inputs.
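The length penalty referenced by length_penalty_weight follows the formula in the linked GNMT paper: lp(Y) = ((5 + |Y|) / 6)^alpha, with beams ranked by total log probability divided by this penalty. A small sketch, assuming that formula (alpha standing in for length_penalty_weight):

```python
def length_penalty(sequence_length, alpha):
    # GNMT length penalty: lp(Y) = ((5 + |Y|) / 6) ** alpha.
    # alpha == 0.0 disables the penalty (lp == 1.0 for any length).
    return ((5.0 + sequence_length) / 6.0) ** alpha

def score(total_log_prob, sequence_length, alpha):
    # Beams are ranked by length-normalized total log probability.
    return total_log_prob / length_penalty(sequence_length, alpha)

# With alpha = 0.0 scores are raw log probabilities; with alpha = 1.0
# the normalization stops beam search from always preferring short outputs.
print(score(-6.0, 10, 0.0))            # -6.0
print(round(score(-6.0, 10, 1.0), 3))  # -2.4
```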
__init__
__init__(
cell,
embedding,
start_tokens,
end_token,
initial_state,
beam_width,
output_layer=None,
length_penalty_weight=0.0,
coverage_penalty_weight=0.0,
reorder_tensor_arrays=True
)
Initialize the BeamSearchDecoder.
Args:
cell: An RNNCell instance.
embedding: A callable that takes a vector tensor of ids (argmax ids), or the params argument for embedding_lookup.
start_tokens: int32 vector shaped [batch_size], the start tokens.
end_token: int32 scalar, the token that marks end of decoding.
initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
beam_width: Python integer, the number of beams.
output_layer: (Optional) An instance of tf.layers.Layer, i.e., tf.layers.Dense. Optional layer to apply to the RNN output prior to storing the result or sampling.
length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
coverage_penalty_weight: Float weight to penalize the coverage of the source sentence. Disabled with 0.0.
reorder_tensor_arrays: If True, TensorArrays' elements within the cell state will be reordered according to the beam search path. If the TensorArray can be reordered, the stacked form will be returned. Otherwise, the TensorArray will be returned as is. Set this flag to False if the cell state contains TensorArrays that are not amenable to reordering.
Raises:
TypeError: if cell is not an instance of RNNCell, or output_layer is not an instance of tf.layers.Layer.
ValueError: If start_tokens is not a vector or end_token is not a scalar.
Properties
batch_size
The batch size of input values.
output_dtype
A (possibly nested tuple of...) dtype[s].
output_size
A (possibly nested tuple of...) integer[s] or TensorShape object[s].
tracks_own_finished
The BeamSearchDecoder shuffles its beams and their finished state.
For this reason, it conflicts with the dynamic_decode function's
tracking of finished states. Setting this property to true avoids
early stopping of decoding due to mismanagement of the finished state
in dynamic_decode.
Returns:
True.
Methods
tf.contrib.seq2seq.BeamSearchDecoder.finalize
finalize(
outputs,
final_state,
sequence_lengths
)
Finalize and return the predicted_ids.
Args:
outputs: An instance of BeamSearchDecoderOutput.
final_state: An instance of BeamSearchDecoderState. Passed through to the output.
sequence_lengths: An int64 tensor shaped [batch_size, beam_width]. The sequence lengths determined for each beam during decode. NOTE These are ignored; the updated sequence lengths are stored in final_state.lengths.
Returns:
outputs: An instance of FinalBeamSearchDecoderOutput where the predicted_ids are the result of calling _gather_tree.
final_state: The same input instance of BeamSearchDecoderState.
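During decoding each step records which parent beam every token extended, so finalize's call to _gather_tree recovers the predicted ids by walking those parent pointers backwards from the last step. A pure-Python sketch of that backtracking for a single batch entry (gather_tree_single is a hypothetical helper for illustration, not the TensorFlow op, which also handles batching and end tokens):

```python
def gather_tree_single(step_ids, parent_ids, final_beam):
    # step_ids[t][b]:   token emitted at step t by beam b.
    # parent_ids[t][b]: index of the beam at step t-1 that beam b extended.
    # Walk backwards from `final_beam` at the last step, collecting tokens.
    beam = final_beam
    tokens = []
    for t in range(len(step_ids) - 1, -1, -1):
        tokens.append(step_ids[t][beam])
        beam = parent_ids[t][beam]
    return list(reversed(tokens))

# Two beams, three steps: the winning beam (index 1) at the last step
# extended beam 0 at step 1, which in turn extended beam 1 at step 0.
step_ids   = [[7, 4], [5, 9], [2, 8]]
parent_ids = [[0, 0], [1, 0], [1, 0]]
print(gather_tree_single(step_ids, parent_ids, 1))  # [4, 5, 8]
```

Without this reordering, reading step_ids column-by-column would mix tokens from different hypotheses, since beams are shuffled at every step.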
tf.contrib.seq2seq.BeamSearchDecoder.initialize
initialize(name=None)
Initialize the decoder.
Args:
name: Name scope for any created operations.
Returns:
(finished, start_inputs, initial_state).
tf.contrib.seq2seq.BeamSearchDecoder.step
step(
time,
inputs,
state,
name=None
)
Perform a decoding step.
Args:
time: scalar int32 tensor.
inputs: A (structure of) input tensors.
state: A (structure of) state tensors and TensorArrays.
name: Name scope for any created operations.
Returns:
(outputs, next_state, next_inputs, finished).
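Conceptually, each step scores every (beam, token) continuation and keeps the best beam_width of them, recording the parent beam of each survivor (these parent ids are what finalize later backtracks through). A minimal pure-Python sketch of that selection (beam_step is a hypothetical helper for illustration; the real implementation works on batched tensors and also tracks finished beams and the length and coverage penalties):

```python
import heapq

def beam_step(beam_scores, token_log_probs, beam_width):
    # beam_scores[b]:        cumulative log probability of beam b so far.
    # token_log_probs[b][v]: log probability of token v continuing beam b.
    # Expand every (beam, token) pair and keep the top beam_width overall.
    candidates = [
        (beam_scores[b] + lp, b, v)
        for b, row in enumerate(token_log_probs)
        for v, lp in enumerate(row)
    ]
    best = heapq.nlargest(beam_width, candidates)
    new_scores = [s for s, _, _ in best]
    parent_ids = [b for _, b, _ in best]
    token_ids  = [v for _, _, v in best]
    return new_scores, parent_ids, token_ids

# Two beams, two-token vocabulary: both survivors pick token 0,
# one continuing beam 0 and one continuing beam 1.
scores, parents, tokens = beam_step(
    beam_scores=[-1.0, -2.0],
    token_log_probs=[[-0.5, -3.0], [-0.1, -4.0]],
    beam_width=2)
print(parents, tokens)  # [0, 1] [0, 0]
```

Because the top beam_width candidates are chosen across all beams jointly, a single strong beam can claim several slots while weak beams are dropped entirely.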