Class SampleEmbeddingHelper
Inherits From: GreedyEmbeddingHelper
Defined in tensorflow/contrib/seq2seq/python/ops/helper.py.
A helper for use during inference.
Uses sampling (from a distribution) instead of argmax and passes the result through an embedding layer to get the next input.
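The difference between greedy decoding and sampled decoding can be illustrated with a minimal pure-Python sketch. This is not the TensorFlow implementation. The names softmax, argmax_then_embed, sample_then_embed and the list-of-lists embedding table are hypothetical stand-ins for the decoder logits and the embedding layer.

```python
import math
import random


def softmax(logits):
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax_then_embed(logits, embedding_table):
    # Greedy choice (what GreedyEmbeddingHelper does conceptually).
    token_id = max(range(len(logits)), key=lambda i: logits[i])
    return token_id, embedding_table[token_id]


def sample_then_embed(logits, embedding_table, rng):
    # Sampled choice: draw an id with probability softmax(logits)[id],
    # then look up its embedding row to use as the next decoder input.
    probs = softmax(logits)
    token_id = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return token_id, embedding_table[token_id]


# Hypothetical 3-token vocabulary with 2-dimensional embeddings.
table = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
logits = [0.1, 2.0, 0.5]
greedy = argmax_then_embed(logits, table)
sampled = sample_then_embed(logits, table, random.Random(0))
```

The greedy path always returns the same token for the same logits, while the sampled path can return any token with non-zero probability.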
__init__
__init__(
embedding,
start_tokens,
end_token,
softmax_temperature=None,
seed=None
)
Initializer.
Args:
embedding: A callable that takes a vector tensor of ids (for this helper, the sampled ids), or the params argument for embedding_lookup. The returned tensor will be passed to the decoder input.
start_tokens: int32 vector shaped [batch_size], the start tokens.
end_token: int32 scalar, the token that marks the end of decoding.
softmax_temperature: (Optional) float32 scalar, the value to divide the logits by before computing the softmax. Larger values (above 1.0) result in more random samples, while smaller values push the sampling distribution towards the argmax. Must be strictly greater than 0. Defaults to 1.0.
seed: (Optional) The sampling seed.
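The effect of softmax_temperature can be checked numerically. The sketch below is a plain-Python illustration, assuming a hypothetical helper softmax_with_temperature; it divides the logits by the temperature before the softmax, as the argument description states.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    # The temperature must be strictly positive, as documented above.
    if temperature <= 0:
        raise ValueError("temperature must be strictly greater than 0")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # stabilize exp() by shifting the largest value to 0
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

Lower temperatures concentrate probability on the largest logit (approaching argmax), and higher temperatures flatten the distribution toward uniform, which is why larger values give more random samples.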
Raises:
ValueError: if start_tokens is not a 1D tensor or end_token is not a scalar.
Properties
batch_size
Batch size of tensor returned by sample.
Returns a scalar int32 tensor.
sample_ids_dtype
DType of tensor returned by sample.
Returns a DType.
sample_ids_shape
Shape of tensor returned by sample, excluding the batch dimension.
Returns a TensorShape.
Methods
tf.contrib.seq2seq.SampleEmbeddingHelper.initialize
initialize(name=None)
Returns (initial_finished, initial_inputs).
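Conceptually, initialization marks every batch entry as not finished and embeds the start tokens to form the first decoder input. The pure-Python sketch below assumes a hypothetical embedding_fn that maps a list of ids to a list of embedding rows; it is an illustration, not the library code.

```python
def initialize(start_tokens, embedding_fn):
    # One start token per batch entry, so batch_size is the vector length.
    batch_size = len(start_tokens)
    # No sequence has produced end_token yet.
    initial_finished = [False] * batch_size
    # The first decoder input is the embedding of the start tokens.
    initial_inputs = embedding_fn(start_tokens)
    return initial_finished, initial_inputs


# Hypothetical embedding table: row i is the embedding of token id i.
table = [[0.0, 1.0], [1.0, 0.0]]
embedding_fn = lambda ids: [table[i] for i in ids]
finished, inputs = initialize([1, 1, 1], embedding_fn)
```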
tf.contrib.seq2seq.SampleEmbeddingHelper.next_inputs
next_inputs(
time,
outputs,
state,
sample_ids,
name=None
)
next_inputs_fn for GreedyEmbeddingHelper.
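The next-input logic inherited from GreedyEmbeddingHelper can be sketched in plain Python as follows, assuming hypothetical extra parameters end_token, start_inputs, and embedding_fn that the real helper stores on the instance. A batch entry is finished once its sampled id equals end_token; when every entry is finished, the start inputs are returned as a placeholder input, otherwise the sampled ids are embedded. time and outputs are accepted but not used.

```python
def next_inputs(time, outputs, state, sample_ids,
                end_token, start_inputs, embedding_fn):
    del time, outputs  # not needed to compute the next inputs
    # An entry is finished when it has just emitted the end token.
    finished = [sid == end_token for sid in sample_ids]
    if all(finished):
        # Nothing left to decode: feed a harmless placeholder input.
        next_in = start_inputs
    else:
        # Embed the sampled ids to get the next decoder input.
        next_in = embedding_fn(sample_ids)
    # The decoder state is passed through unchanged.
    return finished, next_in, state


# Hypothetical 4-token embedding table (1-dimensional embeddings).
table = [[0.0], [1.0], [2.0], [3.0]]
embedding_fn = lambda ids: [table[i] for i in ids]
start_inputs = embedding_fn([1, 1])
step = next_inputs(0, None, "state", [2, 3], 2, start_inputs, embedding_fn)
```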
tf.contrib.seq2seq.SampleEmbeddingHelper.sample
sample(
time,
outputs,
state,
name=None
)
sample for SampleEmbeddingHelper.
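The sampling step draws one id per batch row from a categorical distribution over the (optionally temperature-scaled) logits. The plain-Python sketch below assumes outputs is a list of logit rows, one per batch entry, and uses Python's random.Random for seeding; TensorFlow's seed semantics differ, so this only illustrates the idea.

```python
import math
import random


def _softmax(logits):
    m = max(logits)  # shift for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample(outputs, softmax_temperature=None, seed=None):
    """Draw one id per batch row from softmax(row / temperature)."""
    rng = random.Random(seed)
    sample_ids = []
    for row in outputs:
        # None means no scaling, which is equivalent to a temperature of 1.0.
        if softmax_temperature is None:
            logits = row
        else:
            logits = [x / softmax_temperature for x in row]
        probs = _softmax(logits)
        sample_ids.append(rng.choices(range(len(probs)), weights=probs, k=1)[0])
    return sample_ids


batch_logits = [[0.1, 2.0, 0.5], [1.0, 1.0, 1.0]]
ids = sample(batch_logits, seed=42)
```

With the same seed, the same logits yield the same ids, and a very small temperature makes the result effectively equal to the argmax.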