tf.contrib.gan.stargan_model(
generator_fn,
discriminator_fn,
input_data,
input_data_domain_label,
generator_scope='Generator',
discriminator_scope='Discriminator'
)
Defined in tensorflow/contrib/gan/python/train.py.
Returns StarGAN model outputs and variables.
See https://arxiv.org/abs/1711.09020 for more details.
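The StarGAN paper describes a single generator that translates an image into a sampled target domain and then back into its original domain for a reconstruction (cycle) loss, while the discriminator scores both real and generated images for source and domain. The NumPy sketch below traces that data flow under stated assumptions. It is not the tf.contrib.gan implementation. The helpers random_domain_targets and stargan_forward, the uniform one-hot target sampling, and the toy generator and discriminator are all hypothetical stand-ins that only respect the shapes documented under Args.

```python
import numpy as np


def random_domain_targets(batch_size, num_domains, rng):
    # Hypothetical helper: draw one random target domain per example as a one-hot row.
    idx = rng.integers(0, num_domains, size=batch_size)
    return np.eye(num_domains, dtype=np.float32)[idx]


def stargan_forward(generator_fn, discriminator_fn, input_data, input_label, seed=0):
    """Sketch of the StarGAN paper's data flow; not the library's actual code."""
    rng = np.random.default_rng(seed)
    n, num_domains = input_label.shape
    target = random_domain_targets(n, num_domains, rng)
    # Translate the real images into the randomly chosen target domains.
    generated = generator_fn(input_data, target)
    # Translate the generated images back to their original domains (cycle reconstruction).
    reconstructed = generator_fn(generated, input_label)
    # Score real and generated images for source (real/generated) and domain.
    real_source, real_domain = discriminator_fn(input_data, num_domains)
    fake_source, fake_domain = discriminator_fn(generated, num_domains)
    return {
        "target": target,
        "generated": generated,
        "reconstructed": reconstructed,
        "real": (real_source, real_domain),
        "fake": (fake_source, fake_domain),
    }


# Toy stand-ins (hypothetical) that only respect the documented shapes.
toy_generator = lambda x, t: x * 0.5
toy_discriminator = lambda x, d: (x.mean(axis=(1, 2, 3)), np.zeros((x.shape[0], d)))

images = np.ones((4, 8, 8, 3), dtype=np.float32)      # (n, h, w, c)
labels = np.eye(5, dtype=np.float32)[[0, 1, 2, 3]]    # (n, num_domains)
out = stargan_forward(toy_generator, toy_discriminator, images, labels)
```

The real function builds these tensors inside the given variable scopes and also collects the generator and discriminator variables; the sketch omits that bookkeeping to keep the data flow visible.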
Args:
generator_fn: A Python lambda that takes inputs and targets and returns generated_data, the transformation of inputs into the domains given by targets. inputs has shape (n, h, w, c), targets has shape (n, num_domains), and generated_data has the same shape as inputs.
discriminator_fn: A Python lambda that takes inputs and num_domains and returns a tuple (source_prediction, domain_prediction). source_prediction is the discriminator's prediction of whether each input is real or generated, and domain_prediction is its domain classification of each input. source_prediction has shape (n) and domain_prediction has shape (n, num_domains).
input_data: A Tensor or a list of Tensors of shape (n, h, w, c) representing the real input images.
input_data_domain_label: A Tensor or a list of Tensors of shape (n, num_domains) representing the domain labels associated with the real images.
generator_scope: Optional generator variable scope. Useful if you want to reuse a subgraph that has already been created.
discriminator_scope: Optional discriminator variable scope. Useful if you want to reuse a subgraph that has already been created.
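The two callables only need to honor the shape contract above. The NumPy sketch below shows one way to satisfy it; in practice these functions would build TensorFlow layers inside generator_scope and discriminator_scope. The conditioning trick (spatially replicating the one-hot target labels and concatenating them to the image channels) follows the StarGAN paper's description, while the fixed random 1x1 projection and global-average-pooling "networks" are hypothetical placeholders, not real architectures.

```python
import numpy as np


def generator_fn(inputs, targets):
    """Toy generator: (n, h, w, c) images + (n, num_domains) labels -> (n, h, w, c)."""
    n, h, w, c = inputs.shape
    num_domains = targets.shape[1]
    # Spatially replicate the labels to (n, h, w, num_domains).
    label_maps = np.broadcast_to(targets[:, None, None, :], (n, h, w, num_domains))
    x = np.concatenate([inputs, label_maps], axis=-1)  # (n, h, w, c + num_domains)
    # Hypothetical fixed 1x1 projection back to c channels (stand-in for a conv net).
    weights = np.random.default_rng(0).standard_normal((c + num_domains, c))
    return np.tanh(x @ weights)  # (n, h, w, c)


def discriminator_fn(inputs, num_domains):
    """Toy discriminator returning (source_prediction, domain_prediction)."""
    pooled = inputs.mean(axis=(1, 2))  # global average pool -> (n, c)
    c = pooled.shape[1]
    # Hypothetical fixed linear head: 1 source logit + num_domains domain logits.
    weights = np.random.default_rng(1).standard_normal((c, 1 + num_domains))
    logits = pooled @ weights  # (n, 1 + num_domains)
    source_prediction = logits[:, 0]   # (n,)
    domain_prediction = logits[:, 1:]  # (n, num_domains)
    return source_prediction, domain_prediction


images = np.zeros((4, 8, 8, 3), dtype=np.float32)
labels = np.eye(5, dtype=np.float32)[[0, 1, 2, 3]]
generated = generator_fn(images, labels)
source, domain = discriminator_fn(generated, 5)
```

Note that discriminator_fn receives num_domains rather than the labels themselves, so its domain head must size its output from that argument.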
Returns:
A StarGANModel namedtuple containing the tensors needed to compute the loss.
Raises:
ValueError: If the shape of input_data_domain_label is not rank 2 or is not fully defined in every dimension.
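The condition behind this ValueError can be sketched in plain Python. Here a shape is modeled as a tuple in which None marks an unknown dimension; check_domain_label_shape is a hypothetical helper, not part of tf.contrib.gan. In TensorFlow itself, the analogous checks would likely use a TensorShape's ndims and is_fully_defined().

```python
def check_domain_label_shape(shape):
    """Hypothetical check mirroring the documented ValueError condition.

    `shape` is a tuple of dimensions; None marks an unknown (not fully defined) dimension.
    """
    if len(shape) != 2:
        raise ValueError(
            "input_data_domain_label must be rank 2, got rank %d" % len(shape))
    if any(dim is None for dim in shape):
        raise ValueError(
            "input_data_domain_label shape must be fully defined, got %r" % (shape,))
    return shape
```

A practical consequence: because every dimension must be known, a label tensor with an unknown batch dimension, such as shape (None, num_domains), fails this check, so the batch size has to be static.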