Class GenerativeAdversarialNetwork<T>
- Namespace
- AiDotNet.NeuralNetworks
- Assembly
- AiDotNet.dll
Represents a Generative Adversarial Network (GAN), a deep learning architecture that consists of two neural networks (a generator and a discriminator) competing against each other in a zero-sum game.
public class GenerativeAdversarialNetwork<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable, IAuxiliaryLossLayer<T>, IDiagnosticsProvider
Type Parameters
T: The numeric type used for calculations, typically float or double.
- Inheritance
- NeuralNetworkBase<T> → GenerativeAdversarialNetwork<T>
Remarks
A Generative Adversarial Network (GAN) is a powerful machine learning architecture that uses two neural networks - a generator and a discriminator - that are trained simultaneously through adversarial training. The generator network learns to create realistic synthetic data samples (like images), while the discriminator network learns to distinguish between real data and the generator's synthetic outputs. As training progresses, the generator becomes better at creating realistic data, and the discriminator becomes better at distinguishing real from fake, pushing each other to improve in a competitive process.
For Beginners: A GAN is like an art forger and an art detective competing against each other.
Think of it this way:
- The generator is like an art forger trying to create fake paintings that look real
- The discriminator is like an art detective trying to tell which paintings are real and which are fake
- As the forger gets better, the detective has to improve too
- As the detective gets better, the forger is forced to create more convincing fakes
- Eventually, the forger becomes so good that their fake paintings are nearly indistinguishable from real ones
This continuous competition drives both networks to improve, resulting in a generator that can create remarkably realistic synthetic data like images, music, or text.
Constructors
GenerativeAdversarialNetwork(NeuralNetworkArchitecture<T>, NeuralNetworkArchitecture<T>, InputType, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?)
Initializes a new instance of the GenerativeAdversarialNetwork<T> class.
public GenerativeAdversarialNetwork(NeuralNetworkArchitecture<T> generatorArchitecture, NeuralNetworkArchitecture<T> discriminatorArchitecture, InputType inputType, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? generatorOptimizer = null, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? discriminatorOptimizer = null, ILossFunction<T>? lossFunction = null)
Parameters
generatorArchitecture (NeuralNetworkArchitecture<T>): The neural network architecture for the generator.
discriminatorArchitecture (NeuralNetworkArchitecture<T>): The neural network architecture for the discriminator.
inputType (InputType): The type of input the GAN will process.
generatorOptimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>): The optimizer for the generator. If null, the Adam optimizer is used.
discriminatorOptimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>): The optimizer for the discriminator. If null, the Adam optimizer is used.
lossFunction (ILossFunction<T>): The loss function used to compute loss values during training.
Remarks
This constructor initializes a new Generative Adversarial Network with the specified generator and discriminator architectures. It also sets up the optimization parameters and initializes tracking collections for monitoring training progress. The GAN's architecture is a combination of the generator and discriminator architectures.
For Beginners: This sets up the complete GAN system with both networks.
When creating a new GAN:
- You provide separate architectures for the generator and discriminator
- You can optionally provide custom optimizers for each network
- The inputType specifies what kind of data the GAN will work with
- If you don't specify optimizers, Adam optimizer is used by default
Think of it like establishing the rules and roles for the forger and detective before their competition begins.
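Examples
A minimal construction sketch. The two architecture objects and the InputType value are assumptions here: how you configure them depends on your data and is not covered by this reference. Leaving the optimizers and loss function null falls back to the defaults described above.
// Hypothetical helpers: replace these with your own architecture and input-type setup.
NeuralNetworkArchitecture<double> generatorArchitecture = BuildGeneratorArchitecture();
NeuralNetworkArchitecture<double> discriminatorArchitecture = BuildDiscriminatorArchitecture();
InputType inputType = ChooseInputTypeForMyData();

// Optimizers and loss function are omitted, so the Adam optimizer is used for both networks.
var gan = new GenerativeAdversarialNetwork<double>(
    generatorArchitecture,
    discriminatorArchitecture,
    inputType);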
Properties
AuxiliaryLossWeight
Gets or sets the weight for auxiliary losses (gradient penalty, feature matching). Default is 10.0 for gradient penalty (standard for WGAN-GP).
public T AuxiliaryLossWeight { get; set; }
Property Value
- T
Discriminator
Gets the discriminator network that distinguishes between real and synthetic data.
public NeuralNetworkBase<T> Discriminator { get; }
Property Value
- NeuralNetworkBase<T>
A neural network that classifies data as real or synthetic.
Remarks
The Discriminator is a neural network that takes data (either real or generated) as input and outputs a probability that the data is real. During training, it learns to better distinguish between real data and the Generator's synthetic data. The network type is chosen based on the input type (FeedForward for 1D, Convolutional for 3D, etc.).
For Beginners: This is the "detective" network that tries to spot fakes.
Think of the Discriminator as:
- An art expert examining paintings to determine if they're authentic
- It analyzes data (like images) and gives a probability that it's real
- Its goal is to correctly identify real data and detect generated fakes
- It improves by learning from its mistakes
For example, in an image generation task, the Discriminator essentially answers the question "Is this a real photograph or a computer-generated image?" and becomes increasingly sophisticated in its ability to tell the difference.
DiscriminatorOptimizer
Gets the optimizer used for updating discriminator parameters.
protected IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>> DiscriminatorOptimizer { get; }
Property Value
- IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>
Remarks
Provides access to the discriminator's optimizer for derived classes that need custom training logic.
FeatureMatchingLayers
Gets or sets the indices of discriminator layers to use for feature matching. If null, uses middle layers by default.
public int[]? FeatureMatchingLayers { get; set; }
Property Value
- int[]
Array of layer indices, or null to use defaults.
Remarks
Specifies which discriminator layers to extract features from for feature matching. Typically, intermediate layers (not too early, not too late) work best. If null, the implementation will automatically select appropriate middle layers.
For Beginners: This chooses which internal layers to compare.
Layer selection matters:
- Early layers capture low-level features (edges, textures)
- Middle layers capture mid-level features (shapes, parts)
- Late layers capture high-level features (object identity)
- If not specified, sensible defaults are used automatically
FeatureMatchingWeight
Gets or sets the weight applied to the feature matching loss.
public double FeatureMatchingWeight { get; set; }
Property Value
- double
The multiplier for the feature matching loss component.
Remarks
This weight balances the feature matching loss against the standard adversarial loss. Typical values range from 0.1 to 1.0. Higher values make the generator focus more on matching feature statistics rather than fooling the discriminator directly.
For Beginners: This controls how much the generator focuses on feature matching.
The weight determines:
- How much to prioritize matching internal patterns vs. fooling the discriminator
- Higher values mean more focus on feature matching
- Lower values mean more focus on the adversarial objective
- Typical values are around 0.1 to 1.0
Generator
Gets the generator network that creates synthetic data.
public NeuralNetworkBase<T> Generator { get; }
Property Value
- NeuralNetworkBase<T>
A convolutional neural network that generates synthetic data.
Remarks
The Generator is a neural network that takes random noise as input and produces synthetic data (such as images) as output. During training, it learns to create increasingly realistic data that can fool the Discriminator. In this implementation, it's specifically a convolutional neural network, which is well-suited for image generation tasks.
For Beginners: This is the "forger" network that creates fake data.
Think of the Generator as:
- An artist creating paintings from random starting points
- It takes random noise (like static) and shapes it into structured data
- Its goal is to create outputs so realistic they fool the Discriminator
- It improves by learning from the feedback of the Discriminator
For example, in an image generation task, the Generator might start by creating blurry, unrealistic images, but gradually learn to create sharp, detailed, and realistic images as training progresses.
GeneratorOptimizer
Gets the optimizer used for updating generator parameters.
protected IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>> GeneratorOptimizer { get; }
Property Value
- IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>
Remarks
Provides access to the generator's optimizer for derived classes that need custom training logic.
UseAuxiliaryLoss
Gets or sets whether to use auxiliary losses (gradient penalty, feature matching) during training. Default is true for improved training stability.
public bool UseAuxiliaryLoss { get; set; }
Property Value
- bool
True if auxiliary losses should be used during training; false otherwise.
UseFeatureMatching
Gets or sets whether feature matching is enabled for generator training.
public bool UseFeatureMatching { get; set; }
Property Value
- bool
True if feature matching should be used; false otherwise.
Remarks
Feature matching is a technique from Salimans et al. (2016) that helps stabilize GAN training and prevent mode collapse. Instead of training the generator to fool the discriminator directly, it trains the generator to match the statistics of real data features at intermediate layers of the discriminator.
For Beginners: This enables a more stable way of training the generator.
Instead of just trying to fool the discriminator:
- The generator learns to match the internal patterns of real images
- This helps create more diverse and realistic outputs
- It reduces the risk of mode collapse (generating the same image repeatedly)
- Training tends to be more stable with this enabled
Important: When enabled, batches are stored but ComputeFeatureMatchingLoss() must be manually called and added to the generator loss in your training loop. The base Train() method does not automatically integrate this loss. See ComputeFeatureMatchingLoss() documentation for integration examples.
Methods
ComputeAuxiliaryLoss()
Computes the auxiliary loss for the GAN, which includes gradient penalty and feature matching losses.
public T ComputeAuxiliaryLoss()
Returns
- T
The total auxiliary loss value.
Remarks
This method computes auxiliary losses that improve GAN training stability:
- Gradient Penalty (WGAN-GP): Penalizes deviations from a gradient norm of 1
- Feature Matching: Encourages matching statistics of intermediate activations
For Beginners: This calculates extra losses that make training more stable.
The auxiliary losses:
- Gradient Penalty: Keeps the discriminator's gradients well-behaved
- Feature Matching: Encourages realistic feature distributions
- Combined, they prevent common GAN training problems like mode collapse
- Make training more reliable and convergent
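Examples
A short sketch with T fixed to double. It assumes you want to weight the auxiliary loss yourself inside a custom training loop, using AuxiliaryLossWeight as the multiplier; how the weighted value is folded into your loss is up to that loop.
// Combined gradient penalty and feature matching loss, weighted by AuxiliaryLossWeight.
double auxiliaryLoss = gan.ComputeAuxiliaryLoss();
double weightedAuxiliaryLoss = gan.AuxiliaryLossWeight * auxiliaryLoss;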
ComputeFeatureMatchingLoss()
Computes the feature matching loss between real and generated data.
public T ComputeFeatureMatchingLoss()
Returns
- T
The feature matching loss value.
Remarks
This method implements feature matching loss from Salimans et al. (2016). Instead of training the generator to maximize discriminator confusion directly, it trains the generator to match the statistics (mean activations) of real data at intermediate layers of the discriminator. This approach helps stabilize training and prevent mode collapse.
The loss is computed as the L2 distance between the mean feature activations of real and generated samples across specified discriminator layers. If no layers are specified via FeatureMatchingLayers, the method automatically selects middle layers of the discriminator.
For Beginners: This measures how well generated images match real images internally.
How it works:
- Pass real images through the discriminator and extract internal features
- Pass fake images through the discriminator and extract the same features
- Compare the average features from real vs. fake images
- Return a score showing how different they are
Why this helps:
- Forces the generator to match internal patterns, not just fool the discriminator
- Helps create more diverse outputs (prevents mode collapse)
- Makes training more stable
- Results in more realistic generated images
The loss should be minimized during generator training, typically weighted and combined with the standard adversarial loss.
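Examples
A hedged integration sketch with T fixed to double, following the manual pattern described above. The adversarialGeneratorLoss value is assumed to come from your own generator update step; only the weighting via FeatureMatchingWeight follows this reference.
// Enable feature matching so batches are stored for the comparison.
gan.UseFeatureMatching = true;

// Combine the feature matching loss with a generator loss you computed yourself;
// the base Train() method does not do this combination for you.
double featureMatchingLoss = gan.ComputeFeatureMatchingLoss();
double totalGeneratorLoss = adversarialGeneratorLoss + gan.FeatureMatchingWeight * featureMatchingLoss;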
ComputeGradientPenalty(Tensor<T>, Tensor<T>, double)
Computes the gradient penalty for WGAN-GP (Wasserstein GAN with Gradient Penalty).
public T ComputeGradientPenalty(Tensor<T> realSamples, Tensor<T> fakeSamples, double lambda = 10)
Parameters
realSamples (Tensor<T>): Batch of real samples.
fakeSamples (Tensor<T>): Batch of generated (fake) samples.
lambda (double): Weight for the gradient penalty term (default: 10.0).
Returns
- T
The gradient penalty loss value.
Remarks
This method implements the gradient penalty from Gulrajani et al. (2017) "Improved Training of Wasserstein GANs". The gradient penalty enforces the Lipschitz constraint by penalizing the discriminator when gradients deviate from unit norm at interpolated points between real and fake samples.
The penalty is computed as: λ * E[(||∇_x D(x)|| - 1)²] where x is sampled uniformly along straight lines between real and fake samples. This replaces weight clipping and leads to more stable training and higher quality results.
This implementation uses symbolic differentiation (autodiff) to compute gradients with respect to the input. This is more accurate and efficient than numerical differentiation and is the standard approach in modern WGAN-GP implementations.
For Beginners: This helps stabilize WGAN training by constraining gradients.
How it works:
- Create interpolated samples between real and fake images (mix them randomly)
- Compute how the discriminator output changes with respect to input (gradient)
- Measure how far the gradient norm is from 1.0
- Penalize the discriminator if gradients are too large or too small
Why this helps:
- Enforces the mathematical constraint needed for Wasserstein distance
- Prevents discriminator gradients from exploding or vanishing
- More stable than weight clipping (older WGAN approach)
- Results in higher quality generated images
Important: This method computes the gradient penalty but does not automatically integrate it into training. To use WGAN-GP, you must:
- Call this method during discriminator training
- Add the returned penalty to the discriminator loss
- Use the combined loss to update discriminator parameters
The base Train() method does not automatically include the gradient penalty. A typical lambda value for images is 10.0.
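Examples
A hedged sketch of the integration steps listed above, with T fixed to double. The realSamples batch and the baseDiscriminatorLoss value are assumed to come from your own data pipeline and custom discriminator step; the batch and noise sizes are placeholders.
// Generate a fake batch to pair with the real batch.
Tensor<double> noise = gan.GenerateRandomNoiseTensor(batchSize: 64, noiseSize: 100);
Tensor<double> fakeSamples = gan.GenerateImages(noise);

// Compute the WGAN-GP penalty and add it to the discriminator loss from your custom loop.
// Applying the combined loss to update the discriminator is custom logic not shown here.
double gradientPenalty = gan.ComputeGradientPenalty(realSamples, fakeSamples, lambda: 10.0);
double discriminatorLoss = baseDiscriminatorLoss + gradientPenalty;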
CreateNewInstance()
Creates a new instance of the GenerativeAdversarialNetwork with the same configuration as the current instance.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
A new GenerativeAdversarialNetwork instance with the same architecture as the current instance.
Remarks
This method creates a new instance of the GenerativeAdversarialNetwork with the same generator and discriminator architectures as the current instance. This is useful for model cloning, ensemble methods, or cross-validation scenarios where multiple instances of the same model with identical configurations are needed.
For Beginners: This method creates a fresh copy of the GAN's blueprint.
When you need multiple versions of the same GAN with identical settings:
- This method creates a new, empty GAN with the same configuration
- It copies the architecture of both the generator and discriminator networks
- The new GAN has the same structure but no trained data
- This is useful for techniques that need multiple models, like ensemble methods
For example, when experimenting with different training approaches, you'd want to start with identical model configurations.
DeserializeNetworkSpecificData(BinaryReader)
Deserializes GAN-specific data from a binary reader.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader): The binary reader to read from.
Remarks
This method loads the state of a previously saved GAN from a binary stream. It restores both the generator and discriminator networks, as well as optimizer parameters like momentum and learning rate settings. This allows training to resume from exactly where it left off, maintaining all networks and parameters.
For Beginners: This loads a complete GAN from a saved file.
When loading the GAN:
- Both the generator and discriminator networks are restored
- The optimizer state (momentum, learning rates, etc.) is recovered
- Recent training history is loaded
- All parameters resume their previous values
This lets you continue working with a model exactly where you left off, or use a model that someone else has trained.
DiscriminateImages(Tensor<T>)
Evaluates how real a batch of images appears to the discriminator.
public Tensor<T> DiscriminateImages(Tensor<T> images)
Parameters
images (Tensor<T>): The tensor containing images to evaluate.
Returns
- Tensor<T>
A tensor containing discriminator scores for each image.
Remarks
This method evaluates how realistic a batch of images appears to the discriminator, returning a score between 0 and 1 for each image where higher values indicate more realistic images.
For Beginners: This checks how convincing the generated images are.
The discriminator's evaluation:
- Examines each image in the batch
- Scores each image between 0 and 1
- Higher scores mean more convincing/realistic images
- Provides feedback on the generator's performance
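Examples
A small sketch with T fixed to double: generate a fresh batch and ask the discriminator how convincing it looks. The noise size of 100 is a placeholder that must match your generator architecture.
// Generate fake images, then score them; each score is between 0 and 1, higher meaning more realistic.
Tensor<double> noise = gan.GenerateRandomNoiseTensor(batchSize: 16, noiseSize: 100);
Tensor<double> fakeImages = gan.GenerateImages(noise);
Tensor<double> scores = gan.DiscriminateImages(fakeImages);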
EnableFeatureMatching(bool)
Enables feature matching loss to encourage the generator to match statistics of real data.
public void EnableFeatureMatching(bool enable = true)
Parameters
enable (bool): Whether to enable feature matching.
Remarks
Feature matching encourages the generator to match the statistics of intermediate layer activations of real data, rather than directly maximizing the discriminator output. This can improve training stability.
For Beginners: This helps the generator create more realistic data.
Feature matching:
- Makes the generator match patterns found in real data
- Works at a deeper level than just fooling the discriminator
- Improves diversity and realism of generated samples
- Helps prevent mode collapse
EnableGradientPenalty(bool)
Enables gradient penalty (WGAN-GP) for improved training stability.
public void EnableGradientPenalty(bool enable = true)
Parameters
enable (bool): Whether to enable gradient penalty.
Remarks
Gradient penalty is a regularization technique used in Wasserstein GANs with Gradient Penalty (WGAN-GP). It enforces the Lipschitz constraint by penalizing the gradient norm deviation from 1, which stabilizes training.
For Beginners: This helps prevent training instability.
Gradient penalty:
- Adds a regularization term that keeps gradients under control
- Prevents mode collapse (when the generator produces limited variety)
- Improves convergence and stability
- Is standard practice in modern GAN training (WGAN-GP)
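Examples
A short setup sketch showing both stability aids being switched on before training. The related weights (AuxiliaryLossWeight, FeatureMatchingWeight) keep their defaults here; tune them for your own task.
// Turn on WGAN-GP style gradient penalty and feature matching.
gan.EnableGradientPenalty();
gan.EnableFeatureMatching();

// Include auxiliary losses during training (this is already the default).
gan.UseAuxiliaryLoss = true;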
EvaluateModel(int)
Evaluates the GAN by generating a batch of images and calculating metrics for their quality.
public Dictionary<string, double> EvaluateModel(int sampleSize = 100)
Parameters
sampleSize (int): The number of images to generate for evaluation.
Returns
- Dictionary<string, double>
A dictionary containing evaluation metrics.
Remarks
This tensor-based method evaluates the current performance of the GAN by generating a batch of images and calculating several metrics. It computes statistics on the discriminator scores, checks for diversity in the outputs, and detects potential mode collapse. This efficient implementation processes all images in parallel for better performance.
For Beginners: This tests how well the GAN is performing using batch processing.
The tensor-based evaluation:
- Generates multiple sample images in a single batch operation
- Has the discriminator score all images at once
- Calculates statistics like average score and diversity measures
- Identifies potential issues like mode collapse
This provides comprehensive metrics on GAN performance in an efficient manner.
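Examples
A quick evaluation sketch. The metric names in the returned dictionary are defined by the library, so they are simply enumerated here rather than assumed.
// Generate 200 samples, score them, and print every reported metric.
var metrics = gan.EvaluateModel(sampleSize: 200);
foreach (var metric in metrics)
{
    Console.WriteLine($"{metric.Key}: {metric.Value:F4}");
}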
EvaluateModelWithTensors(int)
Evaluates the GAN using tensor operations.
public Dictionary<string, double> EvaluateModelWithTensors(int sampleSize = 100)
Parameters
sampleSize (int): The number of images to generate for evaluation.
Returns
- Dictionary<string, double>
A dictionary containing evaluation metrics.
Remarks
This method evaluates GAN performance by generating images and calculating metrics using tensor operations throughout. This provides a more efficient evaluation compared to the previous vector-based approach.
For Beginners: This tests how well the GAN is performing.
The tensor-based evaluation:
- Generates multiple images in a single batch operation
- Has the discriminator evaluate all images at once
- Calculates statistics on the quality and diversity of the outputs
- Provides metrics to track training progress
GenerateImages(Tensor<T>)
Generates synthetic images using tensor operations.
public Tensor<T> GenerateImages(Tensor<T> noise)
Parameters
noise (Tensor<T>): The tensor containing the noise input.
Returns
- Tensor<T>
A tensor containing generated images.
Remarks
This method generates synthetic images by passing noise through the generator network using tensor operations. It supports both single inputs and batches.
For Beginners: This creates fake images from random noise patterns.
The process:
- Takes random noise as input (the creative inspiration)
- Passes it through the generator network
- Produces synthetic images as output
- Works efficiently with batches of inputs
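Examples
A minimal generation sketch with T fixed to double. The batch and noise sizes are placeholders; the noise size must match what your generator architecture expects.
// Create a batch of noise vectors and turn them into synthetic images.
Tensor<double> noise = gan.GenerateRandomNoiseTensor(batchSize: 32, noiseSize: 100);
Tensor<double> generatedImages = gan.GenerateImages(noise);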
GenerateQualityImages(int, double)
Generates high-quality images by filtering based on discriminator scores.
public Tensor<T> GenerateQualityImages(int count, double minDiscriminatorScore = 0.7)
Parameters
count (int): The number of images to generate.
minDiscriminatorScore (double): The minimum score threshold for quality images.
Returns
- Tensor<T>
A tensor of images that meet the quality threshold.
Remarks
This method generates multiple images and filters them based on discriminator scores. Only images that exceed the specified quality threshold are returned, ensuring better overall output quality.
For Beginners: This creates multiple high-quality fake images.
The process:
- Generates more images than requested
- Uses the discriminator to evaluate their quality
- Keeps only the most convincing/realistic ones
- Returns images that meet your quality standards
This is useful for applications where you want only the best outputs.
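Examples
A one-call sketch: request eight images that the discriminator scores at 0.8 or higher. Both numbers are illustrative.
// Only images whose discriminator score meets the threshold are returned.
Tensor<double> bestImages = gan.GenerateQualityImages(count: 8, minDiscriminatorScore: 0.8);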
GenerateRandomNoiseTensor(int, int)
Generates a tensor of random noise for the generator.
public Tensor<T> GenerateRandomNoiseTensor(int batchSize, int noiseSize)
Parameters
batchSize (int): The number of noise vectors to generate (the batch size).
noiseSize (int): The length of each noise vector.
Returns
- Tensor<T>
A tensor containing random noise from a normal distribution.
Remarks
This method efficiently generates a batch of noise vectors using a normal distribution, which serves as input to the generator for creating synthetic images.
For Beginners: This creates random starting points for image generation.
The noise generation:
- Creates multiple random inputs in a single operation
- Uses a normal distribution (bell curve) for better results
- Each noise pattern will result in a different image
- Efficient batch processing for better performance
GetAuxiliaryLossDiagnostics()
Gets diagnostic information about the auxiliary losses.
public Dictionary<string, string> GetAuxiliaryLossDiagnostics()
Returns
- Dictionary<string, string>
A dictionary containing diagnostic information about GAN training.
Remarks
This method provides insights into GAN training dynamics, including:
- Generator and discriminator losses
- Gradient penalty values
- Feature matching statistics
- Wasserstein distance estimates
For Beginners: This gives you information to track GAN training health.
The diagnostics include:
- Generator Loss: How well the generator is fooling the discriminator
- Discriminator Loss: How well the discriminator is distinguishing real from fake
- Gradient Penalty: The regularization term value
- Feature Matching: How well features match between real and fake data
- Wasserstein Distance: An estimate of the distribution distance (for WGAN)
These help you:
- Detect training instabilities early
- Monitor convergence progress
- Tune hyperparameters effectively
- Diagnose issues like mode collapse
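Examples
A short monitoring sketch. The diagnostic keys are defined by the library, so the dictionary is printed as reported rather than assuming specific names.
// Dump the current training-health diagnostics.
var diagnostics = gan.GetAuxiliaryLossDiagnostics();
foreach (var entry in diagnostics)
{
    Console.WriteLine($"{entry.Key} = {entry.Value}");
}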
GetDiagnostics()
Gets diagnostic information about this component's state and behavior. Provides GAN-specific auxiliary loss diagnostics.
public Dictionary<string, string> GetDiagnostics()
Returns
- Dictionary<string, string>
A dictionary containing diagnostic metrics including auxiliary loss diagnostics from GetAuxiliaryLossDiagnostics().
GetModelMetadata()
Gets metadata about the GAN model, including information about both generator and discriminator components.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata<T> object containing information about the GAN.
Remarks
This method returns comprehensive metadata about the GAN, including its architecture, training state, and key parameters. This information is useful for model management, tracking experiments, and reporting. The metadata includes details about both the generator and discriminator networks, as well as optimization settings like the current learning rate.
For Beginners: This provides detailed information about the GAN's configuration and state.
The metadata includes:
- What this model is and what it does (generate synthetic data)
- The architecture details of both the generator and discriminator
- Current training parameters like learning rate
- The model's creation date and type
This information is useful for keeping track of different models, comparing experimental results, and documenting your work.
InitializeLayers()
Initializes the layers of the Generative Adversarial Network.
protected override void InitializeLayers()
Remarks
This method is overridden from the base class but is empty because a GAN doesn't use layers directly. Instead, the GAN architecture consists of two separate neural networks (Generator and Discriminator) that each have their own layers. These networks are initialized separately in the constructor.
For Beginners: This method is empty because GANs work differently from standard neural networks.
Unlike traditional neural networks:
- GANs don't have a single sequence of layers
- Instead, they consist of two separate networks (Generator and Discriminator)
- Each of these networks has its own layers
- These networks are initialized separately in the constructor
This method is only included because it's required by the base class, but it doesn't need to do anything in a GAN implementation.
Predict(Tensor<T>)
Performs a forward pass through the generator network using a tensor input.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input (Tensor<T>): The input tensor containing noise vectors to generate images from.
Returns
- Tensor<T>
A tensor containing the generated images.
Remarks
This method is part of the INeuralNetwork interface implementation. In the context of a GAN, "prediction" means using the generator to create synthetic data from random noise input. The method supports batch processing by handling tensor inputs that may contain multiple noise vectors.
For Beginners: This method creates synthetic images from random noise.
When you call Predict:
- The input tensor contains one or more random noise patterns
- These noise patterns are passed through the generator network
- The generator transforms the noise into synthetic images
- The output tensor contains the resulting synthetic images
This is the same underlying process as GenerateImage(), but works with tensors instead of vectors, allowing for batch processing of multiple inputs at once.
SerializeNetworkSpecificData(BinaryWriter)
Serializes GAN-specific data to a binary writer.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter): The binary writer to write to.
Remarks
This method saves the state of the GAN to a binary stream. It serializes both the generator and discriminator networks, as well as optimizer parameters like momentum and learning rate settings. This allows the GAN to be restored later with its full state intact, including both networks and training parameters.
For Beginners: This saves the complete state of the GAN to a file.
When saving the GAN:
- Both the generator and discriminator networks are saved
- The optimizer state (momentum, learning rates, etc.) is saved
- Recent training history is saved
- All the parameters needed to resume training are preserved
This allows you to save your progress and continue training later, share trained models with others, or deploy them in applications.
Train(Tensor<T>, Tensor<T>)
Trains both the generator and discriminator using tensor-based operations throughout.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>): The noise input for the generator (for batch training).
expectedOutput (Tensor<T>): The real images used to train the discriminator.
Remarks
This tensor-native implementation trains both networks efficiently by processing entire batches at once through tensor operations. It eliminates the vector conversion overhead from the previous implementation.
For Beginners: This trains both networks efficiently with multiple examples.
The fully tensor-based training process:
- Processes entire batches of data in parallel
- Trains the discriminator on both real and fake images
- Trains the generator to create more convincing fake images
- Updates the networks using batch operations for better performance
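Examples
A single-call sketch with T fixed to double. The realImages tensor is assumed to come from your own data pipeline, and the noise size is a placeholder.
// Noise drives the generator; the real images train the discriminator.
Tensor<double> noise = gan.GenerateRandomNoiseTensor(batchSize: 32, noiseSize: 100);
gan.Train(noise, realImages);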
TrainStep(Tensor<T>, Tensor<T>)
Performs one step of training for both the generator and discriminator using tensor batches.
public (T discriminatorLoss, T generatorLoss) TrainStep(Tensor<T> realImages, Tensor<T> noise)
Parameters
realImages (Tensor<T>): A tensor containing real images for training the discriminator.
noise (Tensor<T>): A tensor containing random noise for training the generator.
Returns
- (T discriminatorLoss, T generatorLoss)
A tuple containing the discriminator loss and the generator loss for this training step.
Remarks
This method performs one complete training iteration for the GAN using tensor-based operations for maximum efficiency. It first trains the Discriminator on batches of both real images and fake images generated by the Generator. Then it trains the Generator to create images that can fool the Discriminator. This adversarial training process is optimized for batch processing.
For Beginners: This is one round of the competition between generator and discriminator.
The tensor-based training step:
- Processes entire batches of images in parallel
- First trains the discriminator on both real and generated images
- Then trains the generator to create more convincing fake images
- Returns the loss values to track progress
- Is much faster than training with individual vectors
This efficient implementation is critical for training GANs in reasonable timeframes.
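Examples
A minimal training-loop sketch with T fixed to double. LoadRealImageBatches is a hypothetical data loader, and the epoch count, batch size, and noise size are placeholders to adapt to your task.
const int batchSize = 64;
const int noiseSize = 100;

for (int epoch = 0; epoch < 50; epoch++)
{
    // Hypothetical helper that yields Tensor<double> batches of real images.
    foreach (Tensor<double> realImages in LoadRealImageBatches(batchSize))
    {
        Tensor<double> noise = gan.GenerateRandomNoiseTensor(batchSize, noiseSize);
        (double discriminatorLoss, double generatorLoss) = gan.TrainStep(realImages, noise);

        Console.WriteLine($"epoch {epoch}: D loss = {discriminatorLoss:F4}, G loss = {generatorLoss:F4}");
    }
}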
UpdateParameters(Vector<T>)
Updates the parameters of both the Generator and Discriminator networks.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): A vector containing the combined parameters for both networks.
Remarks
This method splits the incoming parameter vector between the Generator and Discriminator, updates each network accordingly, and adjusts the learning rate based on the magnitude of parameter changes. It also includes a mechanism to reset the optimizer state if exceptionally large changes are detected.
For Beginners: This method updates both parts of the GAN at once.
The process:
- Splits the incoming parameters between Generator and Discriminator
- Updates each network with its respective parameters
- Adjusts the learning rate based on how big the changes are
- If changes are very large, it resets some internal values to stabilize training
This approach allows for efficient updating of the entire GAN structure.