
Class WGAN<T>

Namespace
AiDotNet.NeuralNetworks
Assembly
AiDotNet.dll

Represents a Wasserstein Generative Adversarial Network (WGAN), which uses the Wasserstein distance (Earth Mover's distance) to measure the difference between the generated and real data distributions.

public class WGAN<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable

Type Parameters

T

The numeric type used for calculations, typically float or double.

Inheritance
NeuralNetworkBase<T>
WGAN<T>
Implements
IFullModel<T, Tensor<T>, Tensor<T>>
IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
IParameterizable<T, Tensor<T>, Tensor<T>>
ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>
IGradientComputable<T, Tensor<T>, Tensor<T>>

Remarks

WGAN addresses several training instabilities of vanilla GANs by:

  • Using the Wasserstein distance instead of the Jensen-Shannon divergence
  • Replacing the discriminator with a "critic" that outputs unbounded scores rather than probabilities
  • Enforcing a Lipschitz constraint on the critic through weight clipping
  • Providing a loss that correlates with the quality of generated samples
  • Enabling more stable training and better convergence
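To make the Earth Mover's idea concrete, the sketch below computes the Wasserstein-1 distance between two equally sized one-dimensional samples. It is a standalone illustration, not part of the AiDotNet API: it uses plain double arrays instead of Tensor<T>, and the helper name Wasserstein1D is hypothetical. For equal-sized 1-D samples, the cheapest way to "move" one pile of values onto the other is to pair them in sorted order, so the distance is the mean absolute difference of the sorted values.

```csharp
using System;
using System.Linq;

// Wasserstein-1 (Earth Mover's) distance between two equally sized 1-D samples.
// Pairing the values in sorted order is the optimal transport plan in one dimension,
// so the distance is the mean absolute difference of the sorted values.
static double Wasserstein1D(double[] first, double[] second)
{
    if (first.Length == 0 || first.Length != second.Length)
        throw new ArgumentException("Samples must be non-empty and the same size.");

    double[] sortedFirst = first.OrderBy(v => v).ToArray();
    double[] sortedSecond = second.OrderBy(v => v).ToArray();

    double totalCost = 0.0;
    for (int i = 0; i < sortedFirst.Length; i++)
        totalCost += Math.Abs(sortedFirst[i] - sortedSecond[i]);
    return totalCost / sortedFirst.Length;
}

double[] realSample = { 0.0, 1.0, 2.0 };
double[] fakeSample = { 5.0, 6.0, 7.0 };
Console.WriteLine(Wasserstein1D(realSample, fakeSample)); // prints 5
```

Shifting one sample further away increases the distance in proportion to the shift (moving the fake sample to 10, 11, 12 gives 10), whereas the Jensen-Shannon divergence between distributions with disjoint supports stays at log 2 however far apart they are. That difference is why the Wasserstein distance still gives the generator useful gradients.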

For Beginners: WGAN is an improved GAN that solves many training problems.

Key improvements over vanilla GAN:

  • More stable training (less likely to diverge or stall)
  • The loss value tends to track sample quality, so it tells you how well training is going
  • Less prone to mode collapse (generating only a few types of outputs)
  • The critic can be trained many times per generator update without the generator's gradients vanishing

The main change is measuring the difference between real and fake images with the Wasserstein distance, which still gives the generator useful feedback even when real and fake images are easy to tell apart. This turns out to be much more stable.

Reference: Arjovsky et al., "Wasserstein GAN" (2017)

Constructors

WGAN(NeuralNetworkArchitecture<T>, NeuralNetworkArchitecture<T>, InputType, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?, double, int)

Initializes a new instance of the WGAN<T> class.

public WGAN(NeuralNetworkArchitecture<T> generatorArchitecture, NeuralNetworkArchitecture<T> criticArchitecture, InputType inputType, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? generatorOptimizer = null, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? criticOptimizer = null, ILossFunction<T>? lossFunction = null, double weightClipValue = 0.01, int criticIterations = 5)

Parameters

generatorArchitecture NeuralNetworkArchitecture<T>

The neural network architecture for the generator. The generator output size must match the critic input size.

criticArchitecture NeuralNetworkArchitecture<T>

The neural network architecture for the critic. The critic output size must be 1 (single Wasserstein score).

inputType InputType

The type of input the WGAN will process.

generatorOptimizer IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>

Optional optimizer for the generator. If null, an RMSprop optimizer is used (recommended for WGAN).

criticOptimizer IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>

Optional optimizer for the critic. If null, an RMSprop optimizer is used (recommended for WGAN).

lossFunction ILossFunction<T>

Optional loss function. Defaults to WassersteinLoss<T> which implements the Wasserstein distance formula. WGAN training uses the critic scores directly for gradient computation, but the WassersteinLoss provides a consistent interface for computing loss values and serialization.

weightClipValue double

The weight clipping threshold c; the critic's weights are clamped to the range [-c, c]. Default is 0.01.

criticIterations int

Number of critic iterations per generator iteration. Default is 5.

Remarks

The WGAN constructor initializes both the generator and critic networks along with their respective optimizers. RMSprop is recommended over Adam for WGAN training stability.

Architecture Validation:

  • Generator output size must match critic input size (generator produces images that critic evaluates)
  • Critic output size must be 1 (outputs a Wasserstein score, not a probability)

About the Loss Function: Unlike traditional GANs that use binary cross-entropy loss, WGAN uses the Wasserstein distance (Earth Mover's distance). By default, WGAN uses WassersteinLoss<T>, which implements this loss. In practice, WGAN training optimizes the critic outputs directly:

  • Critic loss: maximize E[critic(real)] - E[critic(fake)]
  • Generator loss: maximize E[critic(fake)]
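
WassersteinLoss<T> expresses these objectives as a quantity to minimize: -mean(predicted * label), where label is +1 for real samples and -1 for fake samples. Minimizing it over a real batch and a fake batch maximizes E[critic(real)] - E[critic(fake)], and applying it to fake samples with label +1 gives the generator objective, maximizing E[critic(fake)].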
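The sketch below evaluates the formula above on plain double arrays to show how both objectives fall out of it. It assumes one label per batch (+1 or -1) for brevity; WassersteinLossValue is a hypothetical helper, not the library's WassersteinLoss<T> API.

```csharp
using System;
using System.Linq;

// Hypothetical helper: -mean(predicted * label), with one label for the whole batch
// (+1 for real samples, -1 for fake samples).
static double WassersteinLossValue(double[] predicted, double label) =>
    -predicted.Average(p => p * label);

double[] criticScoresReal = { 2.0, 1.5, 2.5 };   // critic outputs on a real batch
double[] criticScoresFake = { -1.0, 0.0, -0.5 }; // critic outputs on a generated batch

// Critic objective: minimizing this maximizes E[critic(real)] - E[critic(fake)].
double criticLoss = WassersteinLossValue(criticScoresReal, +1.0)
                  + WassersteinLossValue(criticScoresFake, -1.0);

// Generator objective: fake samples labeled +1, so minimizing this maximizes E[critic(fake)].
double generatorLoss = WassersteinLossValue(criticScoresFake, +1.0);

Console.WriteLine($"critic loss {criticLoss}, generator loss {generatorLoss}");
```

With these scores the critic loss is -(2.0 - (-0.5)) = -2.5 and the generator loss is 0.5; driving either value down pushes the corresponding expectation up.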

For Beginners: This sets up the WGAN with sensible defaults.

Key parameters:

  • Generator/critic architectures define the network structures
  • Optimizers control how the networks learn (RMSprop is recommended for WGAN)
  • Weight clipping (0.01) keeps the critic's weights small, enforcing the Lipschitz constraint the Wasserstein estimate relies on
  • Critic iterations (5) means the critic trains 5 times per generator update
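Weight clipping itself is a simple element-wise clamp. A minimal sketch, assuming the critic's weights are exposed as a flat double array (ClipWeights is a hypothetical helper, not a library method):

```csharp
using System;

// Hypothetical helper: clamp every critic weight into [-clipValue, clipValue] in place.
static void ClipWeights(double[] weights, double clipValue)
{
    if (!(clipValue > 0))
        throw new ArgumentOutOfRangeException(nameof(clipValue), "Clip value must be positive.");
    for (int i = 0; i < weights.Length; i++)
        weights[i] = Math.Clamp(weights[i], -clipValue, clipValue);
}

double[] criticWeights = { 0.5, -0.003, -0.2, 0.008 };
ClipWeights(criticWeights, 0.01);
// criticWeights now holds 0.01, -0.003, -0.01, 0.008
Console.WriteLine(string.Join(", ", criticWeights));
```

The clip value is a trade-off noted in the original WGAN paper: too small a value can lead to vanishing gradients, while too large a value makes the critic slow to train to optimality.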

About the loss function: WGAN uses the "Wasserstein distance" (also called Earth Mover's distance) to measure how different real and fake images are, and the default WassersteinLoss implements it. The critic's output is a score (higher = more real-looking), not a probability as in regular GANs. You don't need to specify a loss function: the default WassersteinLoss is the correct choice.

Exceptions

ArgumentNullException

Thrown when generatorArchitecture or criticArchitecture is null.

ArgumentException

Thrown when architecture sizes are incompatible (the generator output size differs from the critic input size, or the critic output size is not 1).

ArgumentOutOfRangeException

Thrown when weightClipValue or criticIterations is invalid.
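The documented validation rules can be summarized as a small check. This is a hypothetical helper that takes plain sizes instead of NeuralNetworkArchitecture<T> objects. The exact conditions under which the real constructor rejects weightClipValue and criticIterations are not documented here, so the sketch assumes a valid clip value is positive and finite and that at least one critic iteration is required.

```csharp
using System;

// Hypothetical helper mirroring the documented constructor checks.
// Assumption: a valid clip value is positive and finite, and criticIterations is at least 1.
static void ValidateWganSettings(int generatorOutputSize, int criticInputSize, int criticOutputSize,
                                 double weightClipValue, int criticIterations)
{
    if (generatorOutputSize != criticInputSize)
        throw new ArgumentException(
            $"Generator output size ({generatorOutputSize}) must match critic input size ({criticInputSize}).");
    if (criticOutputSize != 1)
        throw new ArgumentException($"Critic output size must be 1 (a single Wasserstein score), but was {criticOutputSize}.");
    if (!(weightClipValue > 0) || double.IsInfinity(weightClipValue))
        throw new ArgumentOutOfRangeException(nameof(weightClipValue), "Clip value must be positive and finite.");
    if (criticIterations < 1)
        throw new ArgumentOutOfRangeException(nameof(criticIterations), "At least one critic iteration is required.");
}

// For example, 28x28 images flattened to 784 values on both sides and one critic output: passes.
ValidateWganSettings(784, 784, 1, 0.01, 5);
```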

Properties

Critic

Gets the critic network (the counterpart of the discriminator in a vanilla GAN) that evaluates data.

public NeuralNetworkBase<T> Critic { get; }

Property Value

NeuralNetworkBase<T>

Remarks

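In WGAN, this network is called a "critic" rather than a "discriminator" because it doesn't output a probability. Instead, it outputs an unbounded score, and the difference between its average scores on real and generated data estimates the Wasserstein distance (up to a constant factor determined by the weight clipping).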

For Beginners: The critic does the discriminator's job, but it scores inputs instead of classifying them.

Critic vs. Discriminator:

  • Discriminator outputs probability (0-1): "Is this real?"
  • Critic outputs a score (any number): "How real is this?"
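  • The gap between the critic's average scores on real and fake images tracks image quality
  • Higher scores mean the critic considers an input more realistic; scores are only meaningful relative to one another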

Generator

Gets the generator network that creates synthetic data.

public NeuralNetworkBase<T> Generator { get; }

Property Value

NeuralNetworkBase<T>

Methods

CreateNewInstance()

Creates a new instance of the same type as this neural network.

protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()

Returns

IFullModel<T, Tensor<T>, Tensor<T>>

A new instance of the same neural network type.

Remarks

For Beginners: This creates a blank version of the same type of neural network.

It's used internally by methods like DeepCopy and Clone to create the right type of network before copying the data into it.

DeserializeNetworkSpecificData(BinaryReader)

Deserializes network-specific data that was not covered by the general deserialization process.

protected override void DeserializeNetworkSpecificData(BinaryReader reader)

Parameters

reader BinaryReader

The BinaryReader to read the data from.

Remarks

This method is called at the end of the general deserialization process to allow derived classes to read any additional data specific to their implementation.

For Beginners: Think of serialization as packing a suitcase, and of this method as unpacking its special compartment. After the main deserialization method has unpacked the common items (layers, parameters), this method lets each specific type of neural network unpack its own unique items that were stored during serialization.

EvaluateModel(int)

Evaluates the WGAN by generating images and calculating metrics.

public Dictionary<string, double> EvaluateModel(int sampleSize = 100)

Parameters

sampleSize int

The number of samples to generate for evaluation.

Returns

Dictionary<string, double>

A dictionary containing evaluation metrics.

GenerateImages(Tensor<T>)

Generates synthetic images using the generator.

public Tensor<T> GenerateImages(Tensor<T> noise)

Parameters

noise Tensor<T>

The noise tensor to generate images from.

Returns

Tensor<T>

A tensor containing the generated images.

GenerateRandomNoiseTensor(int, int)

Generates a tensor of random noise for the generator.

public Tensor<T> GenerateRandomNoiseTensor(int batchSize, int noiseSize)

Parameters

batchSize int

The number of noise vectors to generate.

noiseSize int

The dimensionality of each noise vector.

Returns

Tensor<T>

A tensor of random noise values.

Remarks

This method uses the Engine's vectorized Gaussian noise generation. The noise is sampled from a standard normal distribution (mean=0, stddev=1).

For Beginners: This creates random input values for the generator.

The random noise serves as the "seed" for generating images:

  • Each batch contains multiple noise vectors
  • Each vector contains noiseSize values, which should match the generator's input size
  • The values are randomly sampled from a bell curve (normal distribution)
  • Different random values will produce different generated images
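The library draws its noise through the Engine's vectorized routine. As a portable illustration of the same distribution, the sketch below fills a batchSize-by-noiseSize array with standard normal values using the Box-Muller transform. GaussianNoise is a hypothetical helper, and Box-Muller is a swapped-in technique, not necessarily what the Engine uses.

```csharp
using System;

// Hypothetical helper: batchSize-by-noiseSize array of N(0, 1) samples (Box-Muller transform).
static double[,] GaussianNoise(int batchSize, int noiseSize, Random rng)
{
    var noise = new double[batchSize, noiseSize];
    for (int row = 0; row < batchSize; row++)
    {
        for (int col = 0; col < noiseSize; col++)
        {
            double u1 = 1.0 - rng.NextDouble(); // in (0, 1], so Math.Log(u1) is finite
            double u2 = rng.NextDouble();
            noise[row, col] = Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2);
        }
    }
    return noise;
}

// One noise vector of length 100 per sample in a batch of 4 (seeded for reproducibility).
double[,] z = GaussianNoise(4, 100, new Random(42));
Console.WriteLine($"{z.GetLength(0)} x {z.GetLength(1)}");
```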

GetModelMetadata()

Gets the metadata for this neural network model.

public override ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

A ModelMetadata<T> object containing information about the model.

InitializeLayers()

Initializes the layers of the neural network based on the architecture.

protected override void InitializeLayers()

Remarks

For Beginners: This method sets up all the layers in your neural network according to the architecture you've defined. It's like assembling the parts of your network before you can use it.

Predict(Tensor<T>)

Makes a prediction using the neural network.

public override Tensor<T> Predict(Tensor<T> input)

Parameters

input Tensor<T>

The input data to process.

Returns

Tensor<T>

The network's prediction.

Remarks

For Beginners: This is the main method you'll use to get results from your trained neural network. You provide some input data (like an image or text), and the network processes it through all its layers to produce an output (like a classification or prediction).

ResetOptimizerState()

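Resets the internal state of both the generator and critic optimizers (for example, RMSprop's running averages of squared gradients) so that a new training run starts fresh.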

public void ResetOptimizerState()

SerializeNetworkSpecificData(BinaryWriter)

Serializes network-specific data that is not covered by the general serialization process.

protected override void SerializeNetworkSpecificData(BinaryWriter writer)

Parameters

writer BinaryWriter

The BinaryWriter to write the data to.

Remarks

This method is called at the end of the general serialization process to allow derived classes to write any additional data specific to their implementation.

For Beginners: Think of this as packing a special compartment in your suitcase. While the main serialization method packs the common items (layers, parameters), this method allows each specific type of neural network to pack its own unique items that other networks might not have.
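The serialization pair must write and read fields in exactly the same order. The sketch below shows that contract with BinaryWriter and BinaryReader over a MemoryStream. Which fields WGAN actually stores is not documented here, so weightClipValue and criticIterations are assumed examples, and WriteWganSettings and ReadWganSettings are hypothetical helpers.

```csharp
using System;
using System.IO;

// Hypothetical network-specific fields; the real class may store different or additional data.
static void WriteWganSettings(BinaryWriter writer, double weightClipValue, int criticIterations)
{
    writer.Write(weightClipValue);  // 8 bytes
    writer.Write(criticIterations); // 4 bytes
}

static (double WeightClipValue, int CriticIterations) ReadWganSettings(BinaryReader reader)
{
    // Read back in exactly the order the values were written.
    double clip = reader.ReadDouble();
    int iterations = reader.ReadInt32();
    return (clip, iterations);
}

using var stream = new MemoryStream();
using (var writer = new BinaryWriter(stream, System.Text.Encoding.UTF8, leaveOpen: true))
    WriteWganSettings(writer, 0.01, 5);

stream.Position = 0; // rewind before reading
using var reader = new BinaryReader(stream);
var settings = ReadWganSettings(reader);
Console.WriteLine($"{settings.WeightClipValue} {settings.CriticIterations}");
```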

Train(Tensor<T>, Tensor<T>)

Trains the neural network on a single input-output pair.

public override void Train(Tensor<T> input, Tensor<T> expectedOutput)

Parameters

input Tensor<T>

The input data.

expectedOutput Tensor<T>

The expected output for the given input.

Remarks

This method performs one training step on the neural network using the provided input and expected output. It updates the network's parameters to reduce the error between the network's prediction and the expected output.

For Beginners: This is how your neural network learns. You provide:

  • An input (what the network should process)
  • The expected output (what the correct answer should be)

The network then:

  1. Makes a prediction based on the input
  2. Compares its prediction to the expected output
  3. Calculates how wrong it was (the loss)
  4. Adjusts its internal values to do better next time

After training, you can get the loss value using the GetLastLoss() method to see how well the network is learning.
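The four steps above can be traced by hand on the smallest possible model. A minimal sketch, assuming a single weight w, prediction w * x, and squared-error loss (TrainOnce is a hypothetical helper, unrelated to the library's internals):

```csharp
using System;

// One gradient-descent step for a one-weight model y_hat = w * x with squared-error loss.
static (double NewWeight, double LossBefore, double LossAfter) TrainOnce(double w, double x, double y, double learningRate)
{
    double prediction = w * x;                    // 1. make a prediction
    double error = prediction - y;                // 2. compare it with the expected output
    double lossBefore = error * error;            // 3. calculate the loss
    double gradient = 2.0 * error * x;            //    dLoss/dw
    double newWeight = w - learningRate * gradient; // 4. adjust the weight
    double lossAfter = Math.Pow(newWeight * x - y, 2);
    return (newWeight, lossBefore, lossAfter);
}

var step = TrainOnce(w: 0.0, x: 2.0, y: 4.0, learningRate: 0.1);
Console.WriteLine($"w = {step.NewWeight}, loss {step.LossBefore} -> {step.LossAfter}");
```

With w = 0, x = 2, and an expected output of 4, one step with learning rate 0.1 moves w to 1.6 and the loss drops from 16 to about 0.64.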

TrainStep(Tensor<T>, Tensor<T>)

Performs one training step for the WGAN using tensor batches.

public (T criticLoss, T generatorLoss) TrainStep(Tensor<T> realImages, Tensor<T> noise)

Parameters

realImages Tensor<T>

A tensor containing real images.

noise Tensor<T>

A tensor containing random noise for the generator.

Returns

(T criticLoss, T generatorLoss)

A tuple containing the critic and generator loss values.

Remarks

This method implements the WGAN training algorithm:

  1. Train the critic multiple times (criticIterations, 5 by default), applying weight clipping to the critic.
  2. Train the generator once.

The critic is trained to maximize the difference between its scores on real and fake images, and the generator is trained to maximize the critic's score on fake images.

For Beginners: One training round for WGAN.

The training process:

  • Trains the critic several times to make it really good at judging quality
  • Clips the critic's weights to keep it well-behaved
  • Trains the generator once to improve its outputs
  • Returns the critic and generator losses; because the critic loss tracks the estimated Wasserstein distance, its trend over training is a meaningful progress signal
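The algorithm can be traced end to end on a toy problem. The sketch below is an illustrative WGAN on one-dimensional data, not the library's implementation: real samples come from N(realMean, 1), the generator is g(z) = z + b, the critic is linear (f(x) = w * x), and both networks are updated with RMSprop. Because a linear critic can only compare means, only the generator's shift b is learned. All names and hyperparameters are assumptions chosen for the example.

```csharp
using System;

// Toy 1-D WGAN (illustrative assumptions, not the library's implementation):
// real data ~ N(realMean, 1), generator g(z) = z + b with z ~ N(0, 1), linear critic f(x) = w * x.
// A linear critic can only compare means, so the only generator parameter learned is the shift b.
static double TrainToyWgan(double realMean, int generatorSteps, int criticIterations,
                           double clipValue, int batchSize, int seed)
{
    var rng = new Random(seed);
    double b = 0.0, w = 0.0;   // generator shift and critic weight
    double vb = 0.0, vw = 0.0; // RMSprop running averages of squared gradients
    const double criticLr = 0.005, generatorLr = 0.01;

    for (int step = 0; step < generatorSteps; step++)
    {
        // 1. Train the critic criticIterations times, clipping its weight after every update.
        for (int k = 0; k < criticIterations; k++)
        {
            double meanReal = BatchMean(realMean);
            double meanFake = BatchMean(0.0) + b;
            // Critic loss -(mean f(real) - mean f(fake)) = -w * (meanReal - meanFake).
            RmsProp(ref w, ref vw, -(meanReal - meanFake), criticLr);
            w = Math.Clamp(w, -clipValue, clipValue);
        }

        // 2. Train the generator once: the loss -mean f(z + b) has gradient -w with respect to b.
        RmsProp(ref b, ref vb, -w, generatorLr);
    }
    return b;

    // Mean of batchSize draws from N(mean, 1), sampled with the Box-Muller transform.
    double BatchMean(double mean)
    {
        double sum = 0.0;
        for (int i = 0; i < batchSize; i++)
        {
            double u1 = 1.0 - rng.NextDouble(), u2 = rng.NextDouble();
            sum += mean + Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2);
        }
        return sum / batchSize;
    }

    // RMSprop: divide each step by the root of a running average of squared gradients.
    static void RmsProp(ref double param, ref double meanSquare, double grad, double lr)
    {
        meanSquare = 0.9 * meanSquare + 0.1 * grad * grad;
        param -= lr * grad / (Math.Sqrt(meanSquare) + 1e-8);
    }
}

double learnedShift = TrainToyWgan(realMean: 3.0, generatorSteps: 2000, criticIterations: 5,
                                   clipValue: 0.01, batchSize: 256, seed: 1);
Console.WriteLine($"learned generator shift b = {learnedShift:F2} (real mean 3.0)");
```

While the distributions are far apart, the critic's weight sits at the clip value and the generator's shift moves steadily toward the real mean; once b overshoots, the critic's weight changes sign and pushes it back, so b ends up hovering near the real mean.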

UpdateParameters(Vector<T>)

Updates the parameters of both the generator and critic networks.

public override void UpdateParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

A vector containing all parameters for both networks.

Remarks

The parameters vector is split between the generator and critic based on their respective parameter counts. Generator parameters come first, followed by critic parameters.
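A minimal sketch of the documented split, with plain double arrays standing in for Vector<T>. SplitParameters is a hypothetical helper, and the length check is an assumption, since the source does not say how a mismatched vector is handled.

```csharp
using System;

// Hypothetical helper: the first generatorCount entries go to the generator, the rest to the critic.
static (double[] Generator, double[] Critic) SplitParameters(double[] parameters, int generatorCount, int criticCount)
{
    if (generatorCount < 0 || criticCount < 0 || parameters.Length != generatorCount + criticCount)
        throw new ArgumentException("Vector length must equal generator + critic parameter counts.", nameof(parameters));
    return (parameters[..generatorCount], parameters[generatorCount..]);
}

var (generatorParams, criticParams) = SplitParameters(new[] { 0.1, 0.2, 0.3, 0.4, 0.5 }, 2, 3);
Console.WriteLine($"generator: {generatorParams.Length}, critic: {criticParams.Length}");
```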