Class ConditionalGAN<T>
Namespace: AiDotNet.NeuralNetworks
Assembly: AiDotNet.dll
Represents a Conditional Generative Adversarial Network (cGAN), which generates data conditioned on additional information such as class labels, attributes, or other contextual data.
public class ConditionalGAN<T> : GenerativeAdversarialNetwork<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable, IAuxiliaryLossLayer<T>, IDiagnosticsProvider
Type Parameters
T: The numeric type used for calculations, typically float or double.
- Inheritance
- GenerativeAdversarialNetwork<T> → ConditionalGAN<T>
Remarks
Conditional GANs extend the basic GAN framework by:
- Conditioning both the generator and discriminator on additional information
- Allowing controlled generation (e.g., "generate a digit 7")
- Enabling class-conditional image synthesis
- Providing explicit control over the characteristics of the generated output
For Beginners: A cGAN lets you control what kind of image is generated.
Key features:
- You can specify what you want to generate (e.g., "cat" vs. "dog")
- Both the generator and discriminator see the conditioning information
- Generator: "Given this label, create a matching image"
- Discriminator: "Is this image both real AND matching the label?"
Example use cases:
- Generate a specific digit (0-9) in MNIST
- Create images of specific object classes
- Generate faces with specific attributes (smiling, glasses, etc.)
Reference: Mirza and Osindero, "Conditional Generative Adversarial Nets" (2014)
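The core idea above can be sketched in plain C# with no AiDotNet types: conditioning simply means the generator's input is the noise vector concatenated with a one-hot class label (the values here are illustrative only).

```csharp
// Illustrative sketch: the generator sees noise AND the desired class.
double[] noise = { 0.21, -1.3, 0.05, 0.88 };          // random latent vector
double[] label = { 0, 0, 0, 0, 0, 0, 0, 1, 0, 0 };    // one-hot for class 7 of 10

double[] generatorInput = new double[noise.Length + label.Length];
noise.CopyTo(generatorInput, 0);
label.CopyTo(generatorInput, noise.Length);
// generatorInput now carries both "which random variation" and "which class".
```

The discriminator receives the same kind of concatenation, pairing an image with the label it is supposed to match.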
Constructors
ConditionalGAN(NeuralNetworkArchitecture<T>, NeuralNetworkArchitecture<T>, int, InputType, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?)
Initializes a new instance of the ConditionalGAN<T> class.
public ConditionalGAN(NeuralNetworkArchitecture<T> generatorArchitecture, NeuralNetworkArchitecture<T> discriminatorArchitecture, int numConditionClasses, InputType inputType, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? generatorOptimizer = null, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? discriminatorOptimizer = null, ILossFunction<T>? lossFunction = null)
Parameters
generatorArchitecture (NeuralNetworkArchitecture<T>): The neural network architecture for the generator.
discriminatorArchitecture (NeuralNetworkArchitecture<T>): The neural network architecture for the discriminator.
numConditionClasses (int): The number of conditioning classes/categories.
inputType (InputType): The type of input the cGAN will process.
generatorOptimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?): Optional optimizer for the generator. If null, the Adam optimizer is used.
discriminatorOptimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?): Optional optimizer for the discriminator. If null, the Adam optimizer is used.
lossFunction (ILossFunction<T>?): Optional loss function.
Remarks
This constructor creates a conditional GAN where both the generator and discriminator receive conditioning information. The generator takes noise concatenated with a condition vector, and the discriminator takes an image concatenated with the same condition vector.
For Beginners: This sets up a GAN that can generate specific types of images.
Parameters:
- generatorArchitecture: How the generator network is structured
- discriminatorArchitecture: How the discriminator network is structured
- numConditionClasses: How many different types/classes you have
- inputType: What kind of data (usually images)
- generatorOptimizer/discriminatorOptimizer: Custom learning algorithms (optional)
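A minimal construction sketch, assuming generatorArch and discriminatorArch are NeuralNetworkArchitecture<double> instances you have already configured for your data, and that an InputType member suitable for images exists (the member name used here is an assumption):

```csharp
// Hypothetical setup for a 10-class problem such as MNIST digits.
// InputType.Image is an assumed enum member; substitute the one your version defines.
var cgan = new ConditionalGAN<double>(
    generatorArchitecture: generatorArch,
    discriminatorArchitecture: discriminatorArch,
    numConditionClasses: 10,
    inputType: InputType.Image);
// generatorOptimizer, discriminatorOptimizer, and lossFunction are omitted,
// so the documented defaults (Adam optimizers) are used.
```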
Methods
CreateNewInstance()
Creates a new instance of the GenerativeAdversarialNetwork with the same configuration as the current instance.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
A new GenerativeAdversarialNetwork instance with the same architecture as the current instance.
Remarks
This method creates a new instance of the GenerativeAdversarialNetwork with the same generator and discriminator architectures as the current instance. This is useful for model cloning, ensemble methods, or cross-validation scenarios where multiple instances of the same model with identical configurations are needed.
For Beginners: This method creates a fresh copy of the GAN's blueprint.
When you need multiple versions of the same GAN with identical settings:
- This method creates a new, empty GAN with the same configuration
- It copies the architecture of both the generator and discriminator networks
- The new GAN has the same structure but none of the trained parameter values
- This is useful for techniques that need multiple models, like ensemble methods
For example, when experimenting with different training approaches, you'd want to start with identical model configurations.
CreateOneHotCondition(int, int)
Creates a one-hot encoded condition tensor.
public Tensor<T> CreateOneHotCondition(int batchSize, int classIndex)
Parameters
batchSize (int): The number of samples in the batch.
classIndex (int): The index of the class to encode.
Returns
- Tensor<T>: A one-hot encoded condition tensor with one entry set per batch sample.
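A short usage sketch, assuming cgan is a ConditionalGAN<double> constructed with numConditionClasses: 10:

```csharp
// A batch of 4 samples, all conditioned on class 7 (out of 10 classes).
Tensor<double> conditions = cgan.CreateOneHotCondition(batchSize: 4, classIndex: 7);
// Each sample's condition vector presumably has length 10 with a 1 at index 7.
```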
DeserializeNetworkSpecificData(BinaryReader)
Deserializes GAN-specific data from a binary reader.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader): The binary reader to read from.
Remarks
This method loads the state of a previously saved GAN from a binary stream. It restores both the generator and discriminator networks, as well as optimizer parameters like momentum and learning rate settings. This allows training to resume from exactly where it left off, maintaining all networks and parameters.
For Beginners: This loads a complete GAN from a saved file.
When loading the GAN:
- Both the generator and discriminator networks are restored
- The optimizer state (momentum, learning rates, etc.) is recovered
- Recent training history is loaded
- All parameters resume their previous values
This lets you continue working with a model exactly where you left off, or use a model that someone else has trained.
GenerateConditional(Tensor<T>, Tensor<T>)
Generates images conditioned on specific labels.
public Tensor<T> GenerateConditional(Tensor<T> noise, Tensor<T> conditions)
Parameters
noise (Tensor<T>): A tensor of random noise for the generator.
conditions (Tensor<T>): A tensor of conditioning labels (one-hot encoded).
Returns
- Tensor<T>: The generated images, conditioned on the given labels.
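A sketch combining this method with the other public helpers, assuming cgan is a ConditionalGAN<double> with 10 condition classes and that a noise length of 100 matches the generator's expected input (both assumptions here):

```csharp
// Generate 4 images, all of class 3.
Tensor<double> noise = cgan.GenerateRandomNoiseTensor(batchSize: 4, noiseSize: 100);
Tensor<double> conditions = cgan.CreateOneHotCondition(batchSize: 4, classIndex: 3);
Tensor<double> fakes = cgan.GenerateConditional(noise, conditions);
```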
GenerateRandomNoiseTensor(int, int)
Generates random noise tensor using vectorized Gaussian noise generation.
public Tensor<T> GenerateRandomNoiseTensor(int batchSize, int noiseSize)
Parameters
batchSize (int): The number of noise vectors to generate.
noiseSize (int): The length of each noise vector.
Returns
- Tensor<T>: A tensor of Gaussian random noise for the given batch size and noise length.
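A minimal call, assuming cgan is a ConditionalGAN<double>; the noise size should match what the generator architecture expects, so 100 here is only an example value:

```csharp
// A batch of 8 Gaussian-distributed latent vectors, each of length 100.
Tensor<double> z = cgan.GenerateRandomNoiseTensor(batchSize: 8, noiseSize: 100);
```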
GetModelMetadata()
Gets metadata about the GAN model, including information about both generator and discriminator components.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata<T> object containing information about the GAN.
Remarks
This method returns comprehensive metadata about the GAN, including its architecture, training state, and key parameters. This information is useful for model management, tracking experiments, and reporting. The metadata includes details about both the generator and discriminator networks, as well as optimization settings like the current learning rate.
For Beginners: This provides detailed information about the GAN's configuration and state.
The metadata includes:
- What this model is and what it does (generate synthetic data)
- The architecture details of both the generator and discriminator
- Current training parameters like learning rate
- The model's creation date and type
This information is useful for keeping track of different models, comparing experimental results, and documenting your work.
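A quick inspection sketch, assuming cgan is a ConditionalGAN<double>; the exact properties exposed by ModelMetadata<T> are not shown here, so this just prints the object:

```csharp
// Retrieve and display the model's metadata (architecture, training state, etc.).
ModelMetadata<double> meta = cgan.GetModelMetadata();
Console.WriteLine(meta);
```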
SerializeNetworkSpecificData(BinaryWriter)
Serializes GAN-specific data to a binary writer.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter): The binary writer to write to.
Remarks
This method saves the state of the GAN to a binary stream. It serializes both the generator and discriminator networks, as well as optimizer parameters like momentum and learning rate settings. This allows the GAN to be restored later with its full state intact, including both networks and training parameters.
For Beginners: This saves the complete state of the GAN to a file.
When saving the GAN:
- Both the generator and discriminator networks are saved
- The optimizer state (momentum, learning rates, etc.) is saved
- Recent training history is saved
- All the parameters needed to resume training are preserved
This allows you to save your progress and continue training later, share trained models with others, or deploy them in applications.
Train(Tensor<T>, Tensor<T>)
Trains the conditional GAN on a batch of data.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>): The input noise tensor for the generator.
expectedOutput (Tensor<T>): The tensor containing real images.
Remarks
This method implements the standard Train interface by:
1. Generating random conditions for training
2. Using the input as noise for the generator
3. Using expectedOutput as the real images for the discriminator
4. Delegating to TrainStep for the actual training
For Beginners: This is the main training method that follows the base class contract.
How it works:
- The 'input' tensor is used as random noise for the generator
- The 'expectedOutput' tensor contains real images to train the discriminator
- Random class conditions are generated for conditional training
- Both generator and discriminator are updated in each call
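A sketch of one Train call, assuming cgan is a ConditionalGAN<double>, realImages is a Tensor<double> batch of real images you loaded yourself, and the batch and noise sizes used are illustrative:

```csharp
// 'input' is noise for the generator; 'expectedOutput' is real images.
Tensor<double> noise = cgan.GenerateRandomNoiseTensor(batchSize: 32, noiseSize: 100);
cgan.Train(noise, realImages);  // class conditions are sampled internally at random
```

For explicit control over which classes are trained in a step, call TrainStep directly instead.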
TrainStep(Tensor<T>, Tensor<T>, Tensor<T>)
Performs one training step for the conditional GAN.
public (T discriminatorLoss, T generatorLoss) TrainStep(Tensor<T> realImages, Tensor<T> conditions, Tensor<T> noise)
Parameters
realImages (Tensor<T>): A tensor containing real images.
conditions (Tensor<T>): A tensor containing conditioning labels (one-hot encoded).
noise (Tensor<T>): A tensor containing random noise for the generator.
Returns
- (T discriminatorLoss, T generatorLoss): The discriminator and generator loss values for this training step.
Remarks
This method trains both the generator and discriminator with conditioning information:
1. Train the discriminator on real images with their true labels
2. Train the discriminator on fake images with the generator's conditioning labels
3. Train the generator to create images that fool the discriminator for the given conditions
For Beginners: One training round for conditional GAN.
The training process:
- Discriminator learns to verify image-label pairs are correct
- Generator learns to create images matching the specified labels
- Both networks use the conditioning information to guide learning
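One explicit training step might look like the following sketch, assuming cgan is a ConditionalGAN<double> with 10 condition classes, realImages is a Tensor<double> batch you supply, and the sizes (batch 32, noise 100, class 5) are assumptions for illustration:

```csharp
// Train one step where every sample in the batch is conditioned on class 5.
Tensor<double> noise = cgan.GenerateRandomNoiseTensor(batchSize: 32, noiseSize: 100);
Tensor<double> labels = cgan.CreateOneHotCondition(batchSize: 32, classIndex: 5);
var (dLoss, gLoss) = cgan.TrainStep(realImages, labels, noise);
Console.WriteLine($"D loss: {dLoss}, G loss: {gLoss}");
```

Tracking both losses over many steps helps diagnose whether one network is overpowering the other.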