Class InfoGAN<T>
- Namespace
- AiDotNet.NeuralNetworks
- Assembly
- AiDotNet.dll
Represents an Information Maximizing Generative Adversarial Network (InfoGAN), which learns disentangled representations in an unsupervised manner by maximizing mutual information between latent codes and generated observations.
public class InfoGAN<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable
Type Parameters
- T
The numeric type used for calculations, typically float or double.
Remarks
InfoGAN extends the GAN framework by:
- Decomposing the input noise into incompressible noise (z) and latent codes (c)
- Maximizing the mutual information I(c; G(z,c)) between codes and generated images
- Learning interpretable and disentangled representations automatically
- Using an auxiliary network Q to approximate the posterior P(c|x)
- Enabling control over semantic features without labeled data
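The objective behind these points can be written compactly, following the formulation in the InfoGAN paper (Chen et al., 2016). Here V(D,G) is the standard GAN value function, L_I is a variational lower bound on the mutual information, and λ weights that bound. Reading λ as this class's mutualInfoCoefficient is an assumption; this page does not state the exact loss the implementation uses.

```latex
\min_{G,Q}\,\max_{D}\; V_{\text{InfoGAN}}(D,G,Q) = V(D,G) - \lambda\, L_I(G,Q)

L_I(G,Q) = \mathbb{E}_{c \sim P(c),\; x \sim G(z,c)}\big[\log Q(c \mid x)\big] + H(c)
\;\le\; I\big(c;\, G(z,c)\big)
```

Because H(c) is constant for a fixed code prior, maximizing L_I amounts to making Q's prediction of c from the generated image as accurate as possible.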
For Beginners: InfoGAN learns to separate different features automatically.
Key concept:
- Splits random input into two parts:
  - Random noise (z): provides variety
  - Latent codes (c): control specific features
- Learns what each code controls WITHOUT labels
- Example: for faces, it might learn codes for:
  - Code 1: controls rotation
  - Code 2: controls width
  - Code 3: controls lighting
How it works:
- Generator uses both z and c to create images
- Auxiliary network Q tries to predict c from the generated image
- If Q can predict c accurately, the codes are meaningful
- This forces codes to represent interpretable features
Use cases:
- Discover semantic features in datasets
- Disentangled representation learning
- Controllable image generation
- Feature manipulation (change one aspect, keep others)
Reference: Chen et al., "InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets" (2016)
Constructors
InfoGAN(NeuralNetworkArchitecture<T>, NeuralNetworkArchitecture<T>, NeuralNetworkArchitecture<T>, int, InputType, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?, double)
Initializes a new instance of the InfoGAN<T> class with the specified architecture and training parameters.
public InfoGAN(NeuralNetworkArchitecture<T> generatorArchitecture, NeuralNetworkArchitecture<T> discriminatorArchitecture, NeuralNetworkArchitecture<T> qNetworkArchitecture, int latentCodeSize, InputType inputType, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? generatorOptimizer = null, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? discriminatorOptimizer = null, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? qNetworkOptimizer = null, ILossFunction<T>? lossFunction = null, double mutualInfoCoefficient = 1)
Parameters
- generatorArchitecture NeuralNetworkArchitecture<T>
The architecture for the generator network.
- discriminatorArchitecture NeuralNetworkArchitecture<T>
The architecture for the discriminator network.
- qNetworkArchitecture NeuralNetworkArchitecture<T>
The architecture for the Q network (should output latentCodeSize values).
- latentCodeSize int
The size of the latent code (number of controllable features).
- inputType InputType
The type of input data (e.g., ThreeDimensional for images).
- generatorOptimizer IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>
Optional optimizer for the generator. If null, an Adam optimizer with default settings is created.
- discriminatorOptimizer IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>
Optional optimizer for the discriminator. If null, an Adam optimizer with default settings is created.
- qNetworkOptimizer IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>
Optional optimizer for the Q network. If null, an Adam optimizer with default settings is created.
- lossFunction ILossFunction<T>
Optional loss function. If null, the default loss function for generative tasks is used.
- mutualInfoCoefficient double
The coefficient for the mutual information loss. Higher values prioritize feature learning. Default is 1.0.
Remarks
This constructor creates an InfoGAN with three networks:
- Generator: Creates images from noise and latent codes
- Discriminator: Determines if images are real or fake
- Q Network: Predicts latent codes from generated images
The mutual information loss encourages the generator to use the latent codes in meaningful ways that can be recovered by the Q network.
For Beginners: InfoGAN learns controllable features automatically:
- The generator creates images using random noise + controllable codes
- The Q network tries to guess which codes were used
- This forces the codes to represent real, interpretable features
- After training, you can manipulate specific features by changing the codes
Exceptions
- ArgumentNullException
Thrown when any of the architecture parameters is null.
- ArgumentOutOfRangeException
Thrown when latentCodeSize is not positive or mutualInfoCoefficient is negative.
Properties
Discriminator
Gets the discriminator network.
public ConvolutionalNeuralNetwork<T> Discriminator { get; }
Property Value
- ConvolutionalNeuralNetwork<T>
Generator
Gets the generator network.
public ConvolutionalNeuralNetwork<T> Generator { get; }
Property Value
- ConvolutionalNeuralNetwork<T>
ParameterCount
Gets the total number of trainable parameters in the InfoGAN.
public override int ParameterCount { get; }
Property Value
- int
QNetwork
Gets the auxiliary Q network that predicts latent codes from images.
public ConvolutionalNeuralNetwork<T> QNetwork { get; }
Property Value
- ConvolutionalNeuralNetwork<T>
Remarks
In the original InfoGAN paper, the Q network shares all of its layers except the last with the discriminator. This class takes a separate qNetworkArchitecture for it. The Q network outputs the predicted latent code distribution given an image and is key to maximizing mutual information.
For Beginners: The Q network is the "feature detector".
- Takes an image as input
- Outputs: "I think these codes were used to make this"
- Training makes Q better at guessing codes
- This forces the generator to use codes meaningfully
Methods
CreateNewInstance()
Creates a new instance of the InfoGAN with the same configuration.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
A new InfoGAN instance with the same architecture and hyperparameters.
Remarks
This method creates a fresh InfoGAN instance with the same network architectures and hyperparameters. The new instance has freshly initialized optimizers.
For Beginners: This method creates a copy of the InfoGAN structure but with new, untrained networks and fresh optimizers.
DeserializeNetworkSpecificData(BinaryReader)
Deserializes InfoGAN-specific data from a binary reader.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
- reader BinaryReader
The binary reader to read from.
Remarks
This method deserializes the InfoGAN-specific configuration and all three networks. After deserialization, the optimizers are reset to their initial state.
For Beginners: This method loads the InfoGAN's settings and all three networks (generator, discriminator, and Q network) from a file.
Generate(Tensor<T>, Tensor<T>)
Generates images with specific latent codes.
public Tensor<T> Generate(Tensor<T> noise, Tensor<T> latentCodes)
Parameters
- noise Tensor<T>
Random noise.
- latentCodes Tensor<T>
Latent codes to control generation.
Returns
- Tensor<T>
Generated images.
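One common way to inspect what a single code controls is a latent traversal: hold the noise fixed and sweep one code dimension across its range. The sketch below only builds the grid of code values as a plain array. LatentTraversal and Build are hypothetical names, not part of AiDotNet, and the [-1, 1] sweep range is an assumption chosen to match the range stated for GenerateRandomLatentCodes.

```csharp
using System;

// Hypothetical helper for illustration; not part of AiDotNet.
public static class LatentTraversal
{
    // Builds a [steps, latentCodeSize] grid in which one code dimension sweeps
    // linearly from -1 to 1 while every other dimension stays at 0.
    public static double[,] Build(int latentCodeSize, int dimension, int steps)
    {
        if (latentCodeSize <= 0)
            throw new ArgumentOutOfRangeException(nameof(latentCodeSize));
        if (dimension < 0 || dimension >= latentCodeSize)
            throw new ArgumentOutOfRangeException(nameof(dimension));
        if (steps < 2)
            throw new ArgumentOutOfRangeException(nameof(steps));

        var codes = new double[steps, latentCodeSize]; // zero-initialized by the runtime
        for (int i = 0; i < steps; i++)
        {
            // Evenly spaced values: i = 0 gives -1, i = steps - 1 gives +1.
            codes[i, dimension] = -1.0 + 2.0 * i / (steps - 1);
        }
        return codes;
    }
}
```

Each row would then be copied into a Tensor<T> and passed to Generate together with the same noise sample repeated for every row, so that only the swept code changes between outputs. How a Tensor<T> is constructed is not covered on this page, so that step is left out.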
GenerateRandomLatentCodes(int)
Generates random latent codes (continuous, uniform in [-1, 1]).
public Tensor<T> GenerateRandomLatentCodes(int batchSize)
Parameters
- batchSize int
The number of latent code samples in the batch.
Returns
- Tensor<T>
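As a rough illustration of what "continuous, uniform in [-1, 1]" means, the sketch below draws such codes into a plain array. LatentCodeSampler and SampleUniform are hypothetical names. The library method returns a Tensor<T> and takes only batchSize, so it presumably reads the code size from the model; here the code size is passed explicitly.

```csharp
using System;

// Hypothetical stand-in for illustration; not the AiDotNet implementation.
public static class LatentCodeSampler
{
    // Fills a [batchSize, latentCodeSize] array with independent uniform samples.
    public static double[,] SampleUniform(int batchSize, int latentCodeSize, Random rng)
    {
        var codes = new double[batchSize, latentCodeSize];
        for (int i = 0; i < batchSize; i++)
        {
            for (int j = 0; j < latentCodeSize; j++)
            {
                // NextDouble() lies in [0, 1), so the scaled value lies in [-1, 1).
                codes[i, j] = rng.NextDouble() * 2.0 - 1.0;
            }
        }
        return codes;
    }
}
```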
GenerateRandomNoiseTensor(int, int)
Generates a random noise tensor using vectorized Gaussian noise generation with CPU/GPU acceleration.
public Tensor<T> GenerateRandomNoiseTensor(int batchSize, int noiseSize)
Parameters
- batchSize int
The number of noise samples in the batch.
- noiseSize int
The size of each noise sample.
Returns
- Tensor<T>
A tensor of shape [batchSize, noiseSize] filled with Gaussian noise.
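For readers who want to see what Gaussian noise of this shape looks like without the library, here is a minimal scalar sketch using the Box-Muller transform. GaussianNoise and Sample are hypothetical names. Unlike the method documented here, this version is not vectorized or accelerated, and Box-Muller is a swapped-in technique rather than a description of how AiDotNet generates its samples.

```csharp
using System;

// Hypothetical scalar sketch; the library method is described as vectorized.
public static class GaussianNoise
{
    // Returns a [batchSize, noiseSize] array of standard normal samples.
    public static double[,] Sample(int batchSize, int noiseSize, Random rng)
    {
        var noise = new double[batchSize, noiseSize];
        for (int i = 0; i < batchSize; i++)
        {
            for (int j = 0; j < noiseSize; j++)
            {
                double u1 = 1.0 - rng.NextDouble(); // in (0, 1], so Log(u1) is finite
                double u2 = rng.NextDouble();       // in [0, 1)
                // Box-Muller: maps two uniform samples to one standard normal sample.
                noise[i, j] = Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2);
            }
        }
        return noise;
    }
}
```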
GetModelMetadata()
Gets the metadata for this neural network model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata<T> object containing information about the model.
InitializeLayers()
Initializes the layers of the neural network based on the architecture.
protected override void InitializeLayers()
Remarks
For Beginners: This method sets up all the layers in your neural network according to the architecture you've defined. It's like assembling the parts of your network before you can use it.
Predict(Tensor<T>)
Makes a prediction using the neural network.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
- input Tensor<T>
The input data to process.
Returns
- Tensor<T>
The network's prediction.
Remarks
For Beginners: This is the main method you'll use to get results from your trained neural network. You provide some input data (like an image or text), and the network processes it through all its layers to produce an output (like a classification or prediction).
ResetOptimizerState()
Resets the state of all optimizers to their initial values.
public void ResetOptimizerState()
Remarks
This method resets all three optimizers (generator, discriminator, and Q network) to their initial state. This is useful when restarting training or when you want to clear accumulated momentum and adaptive learning rate information.
For Beginners: Call this method when you want to start fresh with training, as if the model had never been trained before. The network weights remain unchanged, but the optimizer's memory of past gradients is cleared.
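To make the distinction between weights and optimizer memory concrete, the sketch below uses a small SGD-with-momentum optimizer. This is a deliberately simpler stand-in: the defaults described for this class are Adam optimizers, and MomentumSgd with all of its members is hypothetical, not an AiDotNet type. The point it illustrates is that resetting clears the accumulated velocity while leaving the weights untouched.

```csharp
using System;

// Hypothetical optimizer for illustration; not an AiDotNet type.
public sealed class MomentumSgd
{
    private readonly double[] _velocity; // accumulated state ("memory" of past gradients)
    private readonly double _learningRate;
    private readonly double _momentum;

    public MomentumSgd(int parameterCount, double learningRate = 0.1, double momentum = 0.9)
    {
        _velocity = new double[parameterCount];
        _learningRate = learningRate;
        _momentum = momentum;
    }

    // One update: blend the previous velocity with the new gradient, then move the weights.
    public void Step(double[] weights, double[] gradients)
    {
        for (int i = 0; i < weights.Length; i++)
        {
            _velocity[i] = _momentum * _velocity[i] - _learningRate * gradients[i];
            weights[i] += _velocity[i];
        }
    }

    // Clears the optimizer state only; the caller's weights are not modified.
    public void ResetState() => Array.Clear(_velocity, 0, _velocity.Length);

    public double VelocityAt(int index) => _velocity[index];
}
```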
SerializeNetworkSpecificData(BinaryWriter)
Serializes InfoGAN-specific data to a binary writer.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
- writer BinaryWriter
The binary writer to write to.
Remarks
This method serializes the InfoGAN-specific configuration and all three networks. Optimizer state is managed by the optimizer implementations themselves.
For Beginners: This method saves the InfoGAN's settings and all three networks (generator, discriminator, and Q network) to a file.
Train(Tensor<T>, Tensor<T>)
Trains the neural network on a single input-output pair.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
- input Tensor<T>
The input data.
- expectedOutput Tensor<T>
The expected output for the given input.
Remarks
This method performs one training step on the neural network using the provided input and expected output. It updates the network's parameters to reduce the error between the network's prediction and the expected output.
For Beginners: This is how your neural network learns. You provide:
- An input (what the network should process)
- The expected output (what the correct answer should be)
The network then:
- Makes a prediction based on the input
- Compares its prediction to the expected output
- Calculates how wrong it was (the loss)
- Adjusts its internal values to do better next time
After training, you can get the loss value using the GetLastLoss() method to see how well the network is learning.
TrainStep(Tensor<T>, Tensor<T>, Tensor<T>)
Performs one training step for InfoGAN.
public (T discriminatorLoss, T generatorLoss, T mutualInfoLoss) TrainStep(Tensor<T> realImages, Tensor<T> noise, Tensor<T> latentCodes)
Parameters
- realImages Tensor<T>
Real images.
- noise Tensor<T>
Random noise (z).
- latentCodes Tensor<T>
Latent codes (c) to condition generation.
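Returns
- (T discriminatorLoss, T generatorLoss, T mutualInfoLoss)
A tuple containing the discriminator loss, the generator loss, and the mutual information loss for this step.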
Remarks
InfoGAN training:
1. Train discriminator (standard GAN objective)
2. Train generator with GAN loss + mutual information loss
3. Train Q network to predict latent codes from generated images
For Beginners: One round of InfoGAN training.
Steps:
- Generate images using noise + latent codes
- Train discriminator to spot fakes (standard GAN)
- Train generator to fool discriminator
- Train Q network to guess the codes from images
- Make generator use codes that Q can predict
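The way the generator's two loss terms combine can be sketched in a few lines. This is an illustration under stated assumptions, not the library's code: it assumes continuous codes and uses mean squared error between the sampled codes and Q's prediction as the mutual information loss (for a fixed-variance Gaussian posterior, the negative log-likelihood reduces to squared error up to constants). This page does not say which loss the implementation uses, and InfoGanLossSketch and its methods are hypothetical names.

```csharp
using System;

// Hypothetical sketch of the loss arithmetic; not the AiDotNet implementation.
public static class InfoGanLossSketch
{
    // Assumed mutual information loss: mean squared error between the codes c
    // fed to the generator and the codes Q predicts from the generated image.
    public static double MutualInfoLoss(double[] codes, double[] predictedCodes)
    {
        if (codes.Length == 0 || codes.Length != predictedCodes.Length)
            throw new ArgumentException("Code vectors must be non-empty and of equal length.");

        double sum = 0.0;
        for (int i = 0; i < codes.Length; i++)
        {
            double diff = codes[i] - predictedCodes[i];
            sum += diff * diff;
        }
        return sum / codes.Length;
    }

    // Generator objective: adversarial loss plus the weighted mutual information loss.
    public static double GeneratorObjective(
        double adversarialLoss, double mutualInfoLoss, double mutualInfoCoefficient)
    {
        return adversarialLoss + mutualInfoCoefficient * mutualInfoLoss;
    }
}
```

A larger mutualInfoCoefficient makes code recovery count for more relative to fooling the discriminator, which matches the constructor's note that higher values prioritize feature learning.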
UpdateParameters(Vector<T>)
Updates the parameters of all networks in the InfoGAN.
public override void UpdateParameters(Vector<T> parameters)
Parameters
- parameters Vector<T>
The new parameters vector containing parameters for all networks.