
Class WGANGP<T>

Namespace
AiDotNet.NeuralNetworks
Assembly
AiDotNet.dll

Represents a Wasserstein GAN with Gradient Penalty (WGAN-GP), an improved version of WGAN that uses a gradient penalty instead of weight clipping to enforce the Lipschitz constraint on the critic.

public class WGANGP<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable

Type Parameters

T

The numeric type used for calculations, typically float or double.

Inheritance
WGANGP<T>
Implements
IFullModel<T, Tensor<T>, Tensor<T>>
IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
IParameterizable<T, Tensor<T>, Tensor<T>>
ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>
IGradientComputable<T, Tensor<T>, Tensor<T>>

Remarks

WGAN-GP improves upon WGAN by:

  • Replacing weight clipping with a gradient penalty term
  • Providing smoother and more stable training
  • Avoiding pathological behavior caused by weight clipping
  • Achieving better sample quality and more reliable convergence
  • Eliminating the need to tune the clipping threshold

For Beginners: WGAN-GP is an enhanced version of WGAN with better training stability.

Key improvements over WGAN:

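  • Uses a "gradient penalty" instead of hard weight limits
  • This penalty softly pushes the size (norm) of the critic's gradient with respect to its input toward 1, instead of forcing every weight into a fixed range
  • More stable and reliable training
  • Typically produces higher-quality results
  • Easier to use (there is no clipping threshold to tune, and the default penalty coefficient of 10 works well in many settings)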

The gradient penalty lets the critic learn smoothly without the problems that weight clipping can cause, such as vanishing or exploding gradients and a critic that cannot use its full capacity.
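For reference, the objective from Gulrajani et al. can be written as below. The notation follows the paper rather than this library's code: D is the critic, x is a real sample drawn from the data distribution P_r, x̃ is a generated sample drawn from the generator's distribution P_g, x̂ is a random interpolation between the two, and λ plays the role of the gradientPenaltyCoefficient constructor parameter. The critic minimizes L_critic and the generator minimizes L_generator.

```latex
L_{\text{critic}} = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\left[D(\tilde{x})\right]
  - \mathbb{E}_{x \sim \mathbb{P}_r}\left[D(x)\right]
  + \lambda \, \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right]

\hat{x} = \epsilon x + (1 - \epsilon)\,\tilde{x}, \qquad \epsilon \sim U[0, 1]

L_{\text{generator}} = -\,\mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\left[D(\tilde{x})\right]
```

The first two terms of L_critic are the negated estimate of the Wasserstein distance; the third term is the gradient penalty that replaces weight clipping.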

Reference: Gulrajani et al., "Improved Training of Wasserstein GANs" (2017)

Constructors

WGANGP(NeuralNetworkArchitecture<T>, NeuralNetworkArchitecture<T>, InputType, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?, double, int)

Initializes a new instance of the WGANGP<T> class.

public WGANGP(NeuralNetworkArchitecture<T> generatorArchitecture, NeuralNetworkArchitecture<T> criticArchitecture, InputType inputType, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? generatorOptimizer = null, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? criticOptimizer = null, ILossFunction<T>? lossFunction = null, double gradientPenaltyCoefficient = 10, int criticIterations = 5)

Parameters

generatorArchitecture NeuralNetworkArchitecture<T>

The neural network architecture for the generator.

criticArchitecture NeuralNetworkArchitecture<T>

The neural network architecture for the critic.

inputType InputType

The type of input the WGAN-GP will process.

generatorOptimizer IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>

Optional optimizer for the generator. If null, an Adam optimizer is used.

criticOptimizer IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>

Optional optimizer for the critic. If null, an Adam optimizer is used.

lossFunction ILossFunction<T>

Optional loss function.

gradientPenaltyCoefficient double

The gradient penalty coefficient (lambda). Default is 10.0.

criticIterations int

Number of critic iterations per generator iteration. Default is 5.

Remarks

The WGAN-GP constructor initializes both the generator and critic networks along with their respective optimizers. The gradient penalty coefficient controls the strength of the Lipschitz constraint enforcement.

For Beginners: This sets up the WGAN-GP with sensible defaults.

Key parameters:

  • Generator/critic architectures define the network structures
  • Optimizers control how the networks learn
  • Gradient penalty coefficient (10.0) controls constraint strength
  • Critic iterations (5) means the critic trains 5 times per generator update
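To make the effect of gradientPenaltyCoefficient concrete, the following C# sketch computes only the penalty term: λ times the mean squared deviation of the critic's gradient norms from 1. GradientPenalty is a hypothetical helper, not an AiDotNet API, and it assumes the per-sample gradient norms at the interpolated points have already been computed.

```csharp
using System;
using System.Linq;

// Penalty term only: lambda * mean((||grad|| - 1)^2) over a batch.
// gradientNorms[b] is the L2 norm of the critic's input gradient at the b-th
// interpolated sample; computing those norms is outside this sketch.
static double GradientPenalty(double[] gradientNorms, double lambda)
{
    double meanSquaredDeviation = gradientNorms
        .Select(norm => (norm - 1.0) * (norm - 1.0))
        .Average();
    return lambda * meanSquaredDeviation;
}

// With the default coefficient of 10, norms of 1.5 and 0.5 each deviate by 0.5 from 1.
double penalty = GradientPenalty(new[] { 1.5, 0.5 }, 10.0); // 10 * (0.25 + 0.25) / 2 = 2.5
Console.WriteLine(penalty);
```

Norms of exactly 1 cost nothing, and the penalty scales linearly with the coefficient, which is why a larger value enforces the constraint more strongly.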

Properties

Critic

Gets the critic network that evaluates data quality.

public ConvolutionalNeuralNetwork<T> Critic { get; }

Property Value

ConvolutionalNeuralNetwork<T>

Generator

Gets the generator network that creates synthetic data.

public ConvolutionalNeuralNetwork<T> Generator { get; }

Property Value

ConvolutionalNeuralNetwork<T>

ParameterCount

Gets the total number of trainable parameters in the WGAN-GP.

public override int ParameterCount { get; }

Property Value

int

Methods

CreateNewInstance()

Creates a new instance of the same type as this neural network.

protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()

Returns

IFullModel<T, Tensor<T>, Tensor<T>>

A new instance of the same neural network type.

Remarks

For Beginners: This creates a blank version of the same type of neural network.

It's used internally by methods like DeepCopy and Clone to create the right type of network before copying the data into it.

DeserializeNetworkSpecificData(BinaryReader)

Deserializes network-specific data that was not covered by the general deserialization process.

protected override void DeserializeNetworkSpecificData(BinaryReader reader)

Parameters

reader BinaryReader

The BinaryReader to read the data from.

Remarks

This method is called at the end of the general deserialization process to allow derived classes to read any additional data specific to their implementation.

For Beginners: Think of serialization as packing a suitcase. This method unpacks the special compartment: after the main deserialization method has unpacked the common items (layers, parameters), it lets each specific type of neural network unpack its own unique items that were stored during serialization.

EvaluateModel(int)

Evaluates the WGAN-GP by generating images and calculating metrics.

public Dictionary<string, double> EvaluateModel(int sampleSize = 100)

Parameters

sampleSize int

The number of samples to generate for evaluation.

Returns

Dictionary<string, double>

A dictionary containing evaluation metrics.

GenerateImages(Tensor<T>)

Generates synthetic images using the generator.

public Tensor<T> GenerateImages(Tensor<T> noise)

Parameters

noise Tensor<T>

The noise tensor to generate images from.

Returns

Tensor<T>

A tensor containing the generated images.

GenerateRandomNoiseTensor(int, int)

Generates a tensor of random noise for the generator.

public Tensor<T> GenerateRandomNoiseTensor(int batchSize, int noiseSize)

Parameters

batchSize int

The number of noise vectors to generate.

noiseSize int

The dimensionality of each noise vector.

Returns

Tensor<T>

A tensor of random noise values.

Remarks

This method uses vectorized Gaussian noise generation for performance. The generated noise follows the standard normal distribution (mean 0, standard deviation 1), which is the noise distribution commonly used for GAN training.
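As an illustration of what "mean 0 and standard deviation 1" means in practice, here is a scalar (non-vectorized) C# sketch based on the Box-Muller transform. StandardNormalNoise is a hypothetical helper that returns a flat double[] of length batchSize * noiseSize rather than a Tensor<T>; the library's vectorized implementation is not shown on this page and may use a different method.

```csharp
using System;

// Fills a flat [batchSize * noiseSize] array with standard normal samples
// using the Box-Muller transform (two uniforms -> two independent normals).
static double[] StandardNormalNoise(int batchSize, int noiseSize, Random rng)
{
    var noise = new double[batchSize * noiseSize];
    for (int i = 0; i < noise.Length; i += 2)
    {
        // 1 - NextDouble() lies in (0, 1], so Math.Log never receives zero.
        double u1 = 1.0 - rng.NextDouble();
        double u2 = rng.NextDouble();
        double radius = Math.Sqrt(-2.0 * Math.Log(u1));
        noise[i] = radius * Math.Cos(2.0 * Math.PI * u2);
        if (i + 1 < noise.Length)
        {
            noise[i + 1] = radius * Math.Sin(2.0 * Math.PI * u2);
        }
    }
    return noise;
}

var random = new Random(42);
double[] z = StandardNormalNoise(64, 100, random);

// Empirical mean and standard deviation should be close to 0 and 1.
double mean = 0.0, variance = 0.0;
foreach (double value in z) mean += value;
mean /= z.Length;
foreach (double value in z) variance += (value - mean) * (value - mean);
variance /= z.Length;
Console.WriteLine($"mean={mean:F3}, std={Math.Sqrt(variance):F3}");
```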

GetModelMetadata()

Gets the metadata for this neural network model.

public override ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

A ModelMetadata<T> object containing information about the model.

InitializeLayers()

Initializes the layers of the neural network based on the architecture.

protected override void InitializeLayers()

Remarks

For Beginners: This method sets up all the layers in your neural network according to the architecture you've defined. It's like assembling the parts of your network before you can use it.

Predict(Tensor<T>)

Makes a prediction using the neural network.

public override Tensor<T> Predict(Tensor<T> input)

Parameters

input Tensor<T>

The input data to process.

Returns

Tensor<T>

The network's prediction.

Remarks

For Beginners: This is the main method you'll use to get results from your trained neural network. You provide some input data (like an image or text), and the network processes it through all its layers to produce an output (like a classification or prediction).

ResetOptimizerState()

Resets both optimizer states for a fresh training run.

public void ResetOptimizerState()
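Because the constructor falls back to Adam for both networks, "optimizer state" here plausibly means Adam's running moment estimates and step counter. The C# sketch below shows, for a single scalar parameter, what such a reset clears. It is an assumption about what ResetOptimizerState affects, and AdamStep and ResetAdamState are hypothetical helpers, not AiDotNet APIs.

```csharp
using System;

// Scalar Adam state for one parameter: first moment m, second moment v, step t.
double m = 0.0, v = 0.0;
int t = 0;
const double LearningRate = 0.001, Beta1 = 0.9, Beta2 = 0.999, Epsilon = 1e-8;

// One Adam update with bias correction; mutates (m, v, t) and returns the new value.
double AdamStep(double parameter, double gradient)
{
    t++;
    m = Beta1 * m + (1 - Beta1) * gradient;
    v = Beta2 * v + (1 - Beta2) * gradient * gradient;
    double mHat = m / (1 - Math.Pow(Beta1, t));
    double vHat = v / (1 - Math.Pow(Beta2, t));
    return parameter - LearningRate * mHat / (Math.Sqrt(vHat) + Epsilon);
}

// Clearing the moments and the step counter makes the next update
// behave exactly like the first update of a fresh run.
void ResetAdamState()
{
    m = 0.0;
    v = 0.0;
    t = 0;
}

double theta = AdamStep(0.0, 1.0); // first update from a fresh state
double firstStep = theta;
for (int i = 0; i < 10; i++)
{
    theta = AdamStep(theta, 1.0); // moments and t keep accumulating
}

ResetAdamState();
double stepAfterReset = AdamStep(0.0, 1.0);
Console.WriteLine(stepAfterReset == firstStep); // True
```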

SerializeNetworkSpecificData(BinaryWriter)

Serializes network-specific data that is not covered by the general serialization process.

protected override void SerializeNetworkSpecificData(BinaryWriter writer)

Parameters

writer BinaryWriter

The BinaryWriter to write the data to.

Remarks

This method is called at the end of the general serialization process to allow derived classes to write any additional data specific to their implementation.

For Beginners: Think of this as packing a special compartment in your suitcase. While the main serialization method packs the common items (layers, parameters), this method allows each specific type of neural network to pack its own unique items that other networks might not have.
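The write/read symmetry that SerializeNetworkSpecificData and DeserializeNetworkSpecificData rely on can be sketched with BinaryWriter and BinaryReader over a MemoryStream. The payload here (the gradient penalty coefficient and the critic iteration count) is an assumption for illustration, since this page does not say which fields WGANGP<T> actually stores; WriteSpecificData and ReadSpecificData are hypothetical helpers.

```csharp
using System;
using System.IO;
using System.Text;

// Hypothetical network-specific payload: the two WGAN-GP hyperparameters.
static void WriteSpecificData(BinaryWriter writer, double gradientPenaltyCoefficient, int criticIterations)
{
    writer.Write(gradientPenaltyCoefficient); // 8 bytes
    writer.Write(criticIterations);           // 4 bytes
}

// Reads the same fields back in exactly the order they were written.
static (double GradientPenaltyCoefficient, int CriticIterations) ReadSpecificData(BinaryReader reader)
{
    double coefficient = reader.ReadDouble();
    int iterations = reader.ReadInt32();
    return (coefficient, iterations);
}

using var stream = new MemoryStream();
using (var payloadWriter = new BinaryWriter(stream, Encoding.UTF8, leaveOpen: true))
{
    WriteSpecificData(payloadWriter, 10.0, 5);
}

stream.Position = 0; // rewind before reading
using var payloadReader = new BinaryReader(stream);
var restored = ReadSpecificData(payloadReader);
Console.WriteLine($"lambda={restored.GradientPenaltyCoefficient}, criticIterations={restored.CriticIterations}");
```

Reading fields in a different order than they were written would silently produce garbage values, which is why each override pair must stay in lockstep.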

Train(Tensor<T>, Tensor<T>)

Trains the neural network on a single input-output pair.

public override void Train(Tensor<T> input, Tensor<T> expectedOutput)

Parameters

input Tensor<T>

The input data.

expectedOutput Tensor<T>

The expected output for the given input.

Remarks

This method performs one training step on the neural network using the provided input and expected output. It updates the network's parameters to reduce the error between the network's prediction and the expected output.

For Beginners: This is how your neural network learns. You provide:

  • An input (what the network should process)
  • The expected output (what the correct answer should be)

The network then:

  1. Makes a prediction based on the input
  2. Compares its prediction to the expected output
  3. Calculates how wrong it was (the loss)
  4. Adjusts its internal values to do better next time

After training, you can get the loss value using the GetLastLoss() method to see how well the network is learning.

TrainStep(Tensor<T>, Tensor<T>)

Performs one training step for the WGAN-GP using tensor batches.

public (T criticLoss, T generatorLoss) TrainStep(Tensor<T> realImages, Tensor<T> noise)

Parameters

realImages Tensor<T>

A tensor containing real images.

noise Tensor<T>

A tensor containing random noise for the generator.

Returns

(T criticLoss, T generatorLoss)

A tuple containing the critic loss (including gradient penalty) and generator loss.

Remarks

This method implements the WGAN-GP training algorithm:

  1. Train the critic multiple times (criticIterations times, 5 by default) with the gradient penalty
  2. For each critic update, compute the gradient penalty on samples interpolated between real and generated images
  3. Train the generator once to maximize the critic's score on fake images

For Beginners: One training round for WGAN-GP.

The training process:

  • Trains the critic several times with gradient penalty
  • The gradient penalty keeps the critic well-behaved
  • Trains the generator once so that the critic scores its fake images higher
  • Returns loss values for monitoring progress
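The loss computation inside each critic update can be sketched in C# with a toy critic whose input gradient has a closed form, which sidesteps automatic differentiation. Everything here (the tanh critic, plain double[] arrays in place of Tensor<T>, and the Critic, CriticGradientNorm and Losses helpers) is hypothetical and only mirrors the algorithm from Gulrajani et al.; it computes the two losses for one batch but performs no parameter updates.

```csharp
using System;

// Toy critic D(x) = sum_i w[i] * tanh(x[i]).
static double Critic(double[] w, double[] x)
{
    double score = 0.0;
    for (int i = 0; i < x.Length; i++) score += w[i] * Math.Tanh(x[i]);
    return score;
}

// Closed-form input gradient: dD/dx[i] = w[i] * (1 - tanh(x[i])^2); returns its L2 norm.
static double CriticGradientNorm(double[] w, double[] x)
{
    double sumSquares = 0.0;
    for (int i = 0; i < x.Length; i++)
    {
        double tanh = Math.Tanh(x[i]);
        double g = w[i] * (1.0 - tanh * tanh);
        sumSquares += g * g;
    }
    return Math.Sqrt(sumSquares);
}

// WGAN-GP losses for one batch. epsilons[b] is the interpolation weight
// for sample b, normally drawn from U[0, 1].
static (double criticLoss, double generatorLoss) Losses(
    double[] w, double[][] real, double[][] fake, double[] epsilons, double lambda)
{
    int batch = real.Length;
    double realScore = 0.0, fakeScore = 0.0, penalty = 0.0;
    for (int b = 0; b < batch; b++)
    {
        realScore += Critic(w, real[b]);
        fakeScore += Critic(w, fake[b]);

        // Random point on the line between the real and the fake sample.
        var xHat = new double[real[b].Length];
        for (int i = 0; i < xHat.Length; i++)
        {
            xHat[i] = epsilons[b] * real[b][i] + (1.0 - epsilons[b]) * fake[b][i];
        }

        double deviation = CriticGradientNorm(w, xHat) - 1.0;
        penalty += deviation * deviation;
    }
    realScore /= batch;
    fakeScore /= batch;
    penalty /= batch;

    double criticLoss = fakeScore - realScore + lambda * penalty; // includes the penalty
    double generatorLoss = -fakeScore;
    return (criticLoss, generatorLoss);
}

var random = new Random(0);
double[] weights = { 0.8, -0.5, 1.2 };
double[][] realBatch = { new[] { 0.5, 1.0, -0.3 }, new[] { 1.1, 0.2, 0.4 } };
double[][] fakeBatch = { new[] { -0.2, 0.1, 0.0 }, new[] { 0.3, -0.7, 0.9 } };
double[] eps = { random.NextDouble(), random.NextDouble() };

var (batchCriticLoss, batchGeneratorLoss) = Losses(weights, realBatch, fakeBatch, eps, 10.0);
Console.WriteLine($"critic={batchCriticLoss:F4}, generator={batchGeneratorLoss:F4}");
```

In a full training step this computation would run criticIterations times, each followed by a critic update, before the generator loss is minimized once.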

UpdateParameters(Vector<T>)

Updates the parameters of both the generator and critic networks.

public override void UpdateParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

A vector containing all parameters for both networks.
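This page does not say how the combined vector is laid out. A common convention, assumed here, is to place the generator's parameters first and the critic's after them, so that the vector length equals ParameterCount. The C# sketch below shows that split with a length check; SplitParameters is a hypothetical helper, and plain double[] arrays stand in for Vector<T>.

```csharp
using System;

// Splits a flat parameter vector into generator and critic parts.
// ASSUMPTION: generator parameters come first, followed by critic parameters.
static (double[] generator, double[] critic) SplitParameters(
    double[] parameters, int generatorCount, int criticCount)
{
    if (parameters.Length != generatorCount + criticCount)
    {
        throw new ArgumentException(
            $"Expected {generatorCount + criticCount} parameters but got {parameters.Length}.");
    }
    return (parameters[..generatorCount], parameters[generatorCount..]);
}

var (gen, crit) = SplitParameters(new[] { 1.0, 2.0, 3.0, 4.0, 5.0 }, 2, 3);
Console.WriteLine($"generator: [{string.Join(", ", gen)}], critic: [{string.Join(", ", crit)}]");
```

The length check catches a mismatched vector early, before any values are copied into the wrong network.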