Class WassersteinLoss<T>

Namespace
AiDotNet.LossFunctions
Assembly
AiDotNet.dll

Implements the Wasserstein loss function used in Wasserstein Generative Adversarial Networks (WGAN).

public class WassersteinLoss<T> : LossFunctionBase<T>, ILossFunction<T>

Type Parameters

T

The numeric type used for calculations (e.g., float, double).

Inheritance
object → LossFunctionBase<T> → WassersteinLoss<T>

Implements
ILossFunction<T>

Remarks

The Wasserstein loss (also known as Earth Mover's Distance loss) measures the distance between two probability distributions. In the context of GANs, it provides a meaningful gradient signal even when the discriminator (critic) is well-trained.

Mathematical Formula:

  • Loss = -mean(predicted * label)
  • Where label is +1 for real samples and -1 for fake samples
  • The critic aims to maximize E[critic(real)] - E[critic(fake)] (see the sketch below)
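The formula can be sketched in plain C# as follows (a minimal standalone illustration using double arrays rather than the library's Vector<T> type; the method name is illustrative):

// Minimal sketch of the Wasserstein loss: the negated mean of score * label.
static double WassersteinLossSketch(double[] predicted, double[] labels)
{
    double sum = 0.0;
    for (int i = 0; i < predicted.Length; i++)
        sum += predicted[i] * labels[i];
    // Negating means that minimizing the loss maximizes E[critic(real)] - E[critic(fake)].
    return -sum / predicted.Length;
}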

For Beginners: Wasserstein loss is a special way to measure how different two groups of data are.

Why use Wasserstein loss instead of regular binary cross-entropy?

  • More stable training - gradients don't vanish when the critic is confident
  • The loss value correlates with sample quality - a decreasing loss generally means better images
  • Reduced mode collapse - the generator is less likely to get stuck producing the same output
  • Can train the critic to convergence without breaking training

How it works:

  • For real images, we want the critic to output high scores (label = +1)
  • For fake images, we want the critic to output low scores (label = -1)
  • The loss is the negated average of (score * label)
  • A well-trained critic gives positive scores to real images and negative scores to fakes (see the sketch below)
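To make the label convention concrete, here is a hedged standalone sketch of one batch (plain C# arrays; the values and batch layout are illustrative, not the library's API):

// Hypothetical batch: two real samples followed by two fake samples.
double[] scores = { 4.2, 3.1, -2.7, -3.9 };  // critic outputs
double[] labels = { +1, +1, -1, -1 };        // +1 = real, -1 = fake

double sum = 0.0;
for (int i = 0; i < scores.Length; i++)
    sum += scores[i] * labels[i];            // every product is positive here
double loss = -sum / scores.Length;          // -3.475: the critic is doing well

A strongly negative loss on a batch like this indicates the critic is cleanly separating real scores from fake scores.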

Reference: Arjovsky et al., "Wasserstein GAN" (2017)

Methods

CalculateDerivative(Vector<T>, Vector<T>)

Calculates the derivative of the Wasserstein loss function.

public override Vector<T> CalculateDerivative(Vector<T> predicted, Vector<T> actual)

Parameters

predicted Vector<T>

The critic's output scores for each sample.

actual Vector<T>

The labels: +1 for real samples, -1 for fake samples.

Returns

Vector<T>

A vector containing the derivatives of the Wasserstein loss for each prediction.

Remarks

The derivative of the Wasserstein loss with respect to each predicted score is simply the negative of the corresponding label, scaled by 1/n because the loss averages over n samples.

Mathematical Derivation:

  • Loss = -mean(predicted * actual) = -(1/n) * sum(predicted_i * actual_i)
  • dLoss/d(predicted_i) = -actual_i / n (see the sketch below)
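A minimal standalone sketch of this gradient (plain C#, using double arrays rather than the library's Vector<T> type; the method name is illustrative):

// Gradient of Loss = -(1/n) * sum(predicted_i * actual_i) w.r.t. each predicted_i.
static double[] WassersteinGradientSketch(double[] predicted, double[] actual)
{
    int n = predicted.Length;
    var grad = new double[n];
    for (int i = 0; i < n; i++)
        grad[i] = -actual[i] / n;  // -1/n for real samples, +1/n for fake samples
    return grad;
}

Note that the gradient depends only on the labels and the batch size, never on the scores themselves - this is why the gradient signal does not vanish as the critic becomes confident.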

For Beginners: The derivative tells the network which direction to adjust.

For a real sample (label = +1):

  • Derivative is negative, so increasing the score decreases the loss
  • This pushes the critic to give higher scores to real images

For a fake sample (label = -1):

  • Derivative is positive, so decreasing the score decreases the loss
  • This pushes the critic to give lower scores to fake images

CalculateLoss(Vector<T>, Vector<T>)

Calculates the Wasserstein loss between predicted critic scores and labels.

public override T CalculateLoss(Vector<T> predicted, Vector<T> actual)

Parameters

predicted Vector<T>

The critic's output scores for each sample.

actual Vector<T>

The labels: +1 for real samples, -1 for fake samples.

Returns

T

The mean Wasserstein loss across all samples.

Remarks

The loss is computed as the negative mean of (predicted * actual), which means:

  • For real samples (label=+1): we want high predicted scores
  • For fake samples (label=-1): we want low predicted scores

For Beginners: This computes how well the critic is doing at telling real from fake.

Example:

  • Real image, critic outputs +5, label is +1: product is +5, lowering the loss (good!)
  • Fake image, critic outputs -3, label is -1: product is +3, lowering the loss (good!)
  • Real image, critic outputs -2, label is +1: product is -2, raising the loss (bad!)

The loss is negated so that minimizing it is equivalent to maximizing the critic's estimate of the Wasserstein distance, E[critic(real)] - E[critic(fake)]. The sketch below works through these numbers.
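Putting the example above into numbers (standalone C#; plain double arrays are used because the exact Vector<T> construction may differ):

using System;

double[] predicted = { 5.0, -3.0, -2.0 };   // critic scores
double[] actual    = { +1.0, -1.0, +1.0 };  // real, fake, real

// products = { +5, +3, -2 }, mean = 2, loss = -mean = -2
double sum = 0.0;
for (int i = 0; i < predicted.Length; i++)
    sum += predicted[i] * actual[i];
double loss = -sum / predicted.Length;
Console.WriteLine(loss);                    // prints -2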