tfgan折腾笔记(三):核心函数详述——gan_loss族
WongWai95 人气:0gan_loss族的函数有:
1.gan_loss:
函数原型:
def gan_loss( # GANModel. model, # Loss functions. generator_loss_fn=tuple_losses.wasserstein_generator_loss, discriminator_loss_fn=tuple_losses.wasserstein_discriminator_loss, # Auxiliary losses. gradient_penalty_weight=None, gradient_penalty_epsilon=1e-10, gradient_penalty_target=1.0, gradient_penalty_one_sided=False, mutual_information_penalty_weight=None, aux_cond_generator_weight=None, aux_cond_discriminator_weight=None, tensor_pool_fn=None, # Options. reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS, add_summaries=True)
参数:
model:gan_model族函数的返回值
generator_loss_fn:生成器使用的损失函数,可用函数见其他说明。
discriminator_loss_fn:判别器使用的损失函数,可用函数见其他说明。
gradient_penalty_weight:如果不是None,则必须提供一个非负数或Tensor,意义为梯度惩罚的权值。
gradient_penalty_epsilon:如果提供了上一个参数,那么这个参数应该提供一个用于在梯度罚函数中维持数值稳定性的较小的正值。 请注意,某些应用程序需要增加此值以避免NaN。
gradient_penalty_target:如果上上个参数不是None,那么这个参数就指明了梯度规范的目标值。应该是一个数值类型或Tensor。
gradient_penalty_one_sided:(暂不明白什么意思)。
mutual_information_penalty_weight:交叉信息惩罚权值。如果不是None,必须提供一个非负数或Tensor。
aux_cond_generator_weight:如果不是None,则添加生成器分类损失。
aux_cond_discriminator_weight:如果不是None,则添加判别器分类损失。
tensor_pool_fn:tensor pool函数。此函数传入tuple类型:(generated_data, generator_inputs),函数将它们放在内部pool中,并且返回上一个pool中的值。如,可以传入tfgan.features.tensor_pool。
reduction:传入tf.losses.Reduction类的函数。
add_summaries:是否添加总结到Tensorboard日志。
返回值:
返回“GANLoss 命名元组”。
函数内部实现:
# Create standard losses with optional kwargs, if the loss functions accept # them. def _optional_kwargs(fn, possible_kwargs): """Returns a kwargs dictionary of valid kwargs for a given function.""" if inspect.getargspec(fn).keywords is not None: return possible_kwargs actual_args = inspect.getargspec(fn).args actual_kwargs = {} for k, v in possible_kwargs.items(): if k in actual_args: actual_kwargs[k] = v return actual_kwargs possible_kwargs = {'reduction': reduction, 'add_summaries': add_summaries} gen_loss = generator_loss_fn( model, **_optional_kwargs(generator_loss_fn, possible_kwargs)) dis_loss = discriminator_loss_fn( pooled_model, **_optional_kwargs(discriminator_loss_fn, possible_kwargs))
其他说明:
- tfgan内置损失函数:
__all__ = [ 'acgan_discriminator_loss', 'acgan_generator_loss', 'least_squares_discriminator_loss', 'least_squares_generator_loss', 'modified_discriminator_loss', 'modified_generator_loss', 'minimax_discriminator_loss', 'minimax_generator_loss', 'wasserstein_discriminator_loss', 'wasserstein_hinge_discriminator_loss', 'wasserstein_hinge_generator_loss', 'wasserstein_generator_loss', 'wasserstein_gradient_penalty', 'mutual_information_penalty', 'combine_adversarial_loss', 'cycle_consistency_loss', 'stargan_generator_loss_wrapper', 'stargan_discriminator_loss_wrapper', 'stargan_gradient_penalty_wrapper' ]
2.cyclegan_loss:
函数原型:
def cyclegan_loss( model, # Loss functions. generator_loss_fn=tuple_losses.least_squares_generator_loss, discriminator_loss_fn=tuple_losses.least_squares_discriminator_loss, # Auxiliary losses. cycle_consistency_loss_fn=tuple_losses.cycle_consistency_loss, cycle_consistency_loss_weight=10.0, # Options **kwargs)
参数:
model:gan_model族函数的返回值
generator_loss_fn:生成器使用的损失函数。
discriminator_loss_fn:判别器使用的损失函数。
cycle_consistency_loss_fn:循环一致性损失函数。
cycle_consistency_loss_weight:循环一致性损失的权值。
**kwargs:这里的参数将直接传递给cyclegan_loss函数内部调用的gan_loss函数。
返回值:
返回“CycleGANLoss 命名元组”。
函数内部实现:
循环一致性损失函数与权值的定义:
# Defines cycle consistency loss. cycle_consistency_loss = cycle_consistency_loss_fn( model, add_summaries=kwargs.get('add_summaries', True)) cycle_consistency_loss_weight = _validate_aux_loss_weight( cycle_consistency_loss_weight, 'cycle_consistency_loss_weight') aux_loss = cycle_consistency_loss_weight * cycle_consistency_loss
**kwargs的实现:
# Defines losses for each partial model. def _partial_loss(partial_model): partial_loss = gan_loss( partial_model, generator_loss_fn=generator_loss_fn, discriminator_loss_fn=discriminator_loss_fn, **kwargs) return partial_loss._replace(generator_loss=partial_loss.generator_loss + aux_loss) with tf.compat.v1.name_scope('cyclegan_loss_x2y'): loss_x2y = _partial_loss(model.model_x2y) with tf.compat.v1.name_scope('cyclegan_loss_y2x'): loss_y2x = _partial_loss(model.model_y2x)
其他说明:
- cycle-gan实际上是由两个普通gan组合而成的,其loss是普通gan的loss加上循环一致性损失。
- 循环一致性损失权值越大,则X->Y->X循环的相似性方面学习的越快。
3.stargan_loss:
函数原型:
def stargan_loss( model, generator_loss_fn=tuple_losses.stargan_generator_loss_wrapper( losses_wargs.wasserstein_generator_loss), discriminator_loss_fn=tuple_losses.stargan_discriminator_loss_wrapper( losses_wargs.wasserstein_discriminator_loss), gradient_penalty_weight=10.0, gradient_penalty_epsilon=1e-10, gradient_penalty_target=1.0, gradient_penalty_one_sided=False, reconstruction_loss_fn=tf.compat.v1.losses.absolute_difference, reconstruction_loss_weight=10.0, classification_loss_fn=tf.compat.v1.losses.softmax_cross_entropy, classification_loss_weight=1.0, classification_one_hot=True, add_summaries=True)
参数:
model:gan_model族函数的返回值
generator_loss_fn:生成器使用的损失函数。
discriminator_loss_fn:判别器使用的损失函数。
gradient_penalty_weight:如果不是None,则必须提供一个非负数或Tensor,意义为梯度惩罚的权值。
gradient_penalty_epsilon:如果提供了上一个参数,那么这个参数应该提供一个用于在梯度罚函数中维持数值稳定性的较小的正值。 请注意,某些应用程序需要增加此值以避免NaN。
gradient_penalty_target:如果上上个参数不是None,那么这个参数就指明了梯度规范的目标值。应该是一个数值类型或Tensor。
gradient_penalty_one_sided:(暂不明白什么意思)。
reconstruction_loss_fn:重建损失函数。
reconstruction_loss_weight:重建损失的权重。
classification_loss_fn:分类损失函数。
classification_loss_weight:分类损失的权重。
classification_one_hot:分类的one_hot_label。
add_summaries:是否向tensorboard添加总结。
返回值:
返回“StarGANLoss 命名元组”。
函数内部实现:
梯度惩罚函数与权值的定义:
# Gradient Penalty. if _use_aux_loss(gradient_penalty_weight): gradient_penalty_fn = tuple_losses.stargan_gradient_penalty_wrapper( losses_wargs.wasserstein_gradient_penalty) discriminator_loss += gradient_penalty_fn( model, epsilon=gradient_penalty_epsilon, target=gradient_penalty_target, one_sided=gradient_penalty_one_sided, add_summaries=add_summaries) * gradient_penalty_weight
重建损失函数与权值的定义:
# Reconstruction Loss. reconstruction_loss = reconstruction_loss_fn(model.input_data, model.reconstructed_data) generator_loss += reconstruction_loss * reconstruction_loss_weight if add_summaries: tf.compat.v1.summary.scalar('reconstruction_loss', reconstruction_loss)
分类损失函数与权值定义:
# Classification Loss. generator_loss += _classification_loss_helper( true_labels=model.generated_data_domain_target, predict_logits=model.discriminator_generated_data_domain_predication, scope_name='generator_classification_loss') * classification_loss_weight discriminator_loss += _classification_loss_helper( true_labels=model.input_data_domain_label, predict_logits=model.discriminator_input_data_domain_predication, scope_name='discriminator_classification_loss' ) * classification_loss_weight
其他说明:
无
加载全部内容