Class ACGAN<T>
Namespace: AiDotNet.NeuralNetworks
Assembly: AiDotNet.dll
Represents an Auxiliary Classifier Generative Adversarial Network (AC-GAN), which extends conditional GANs by having the discriminator also predict the class label of the input.
public class ACGAN<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable
Type Parameters
T: The numeric type used for calculations, typically float or double.
- Inheritance
- NeuralNetworkBase<T>
- ACGAN<T>
Remarks
AC-GAN improves upon conditional GANs by:
- Making the discriminator predict both authenticity and class label
- Providing stronger gradient signals for class-conditional generation
- Improving image quality and class separability
- Enabling better control over generated samples
- Training more stably than basic conditional GANs
For Beginners: AC-GAN generates specific types of images with better quality.
Key improvements over cGAN:
- Discriminator has two tasks: "Is it real?" AND "What class is it?"
- This dual task helps the discriminator learn better features
- Generator must create images that fool both checks
- Results in higher quality and more class-consistent images
Example use case:
- Generate digit "7" that looks very realistic
- Discriminator checks: 1) Is it real? 2) Is it a "7"?
- This forces the generator to make better "7"s
Reference: Odena et al., "Conditional Image Synthesis with Auxiliary Classifier GANs" (2017)
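For reference, the objective from Odena et al. can be written as below, where S is the source (real or fake) and C is the class label. This is the paper's formulation; whether TrainStep weights or combines the two terms in exactly this way is not stated on this page.

```latex
\begin{align}
L_S &= \mathbb{E}\left[\log P(S = \text{real} \mid X_{\text{real}})\right]
     + \mathbb{E}\left[\log P(S = \text{fake} \mid X_{\text{fake}})\right] \\
L_C &= \mathbb{E}\left[\log P(C = c \mid X_{\text{real}})\right]
     + \mathbb{E}\left[\log P(C = c \mid X_{\text{fake}})\right]
\end{align}
```

In the paper, the discriminator is trained to maximize L_S + L_C and the generator to maximize L_C - L_S.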
Constructors
ACGAN(NeuralNetworkArchitecture<T>, NeuralNetworkArchitecture<T>, int, InputType, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?)
Initializes a new instance of the ACGAN<T> class.
public ACGAN(NeuralNetworkArchitecture<T> generatorArchitecture, NeuralNetworkArchitecture<T> discriminatorArchitecture, int numClasses, InputType inputType, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? generatorOptimizer = null, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? discriminatorOptimizer = null, ILossFunction<T>? lossFunction = null)
Parameters
generatorArchitecture (NeuralNetworkArchitecture<T>): The neural network architecture for the generator.
discriminatorArchitecture (NeuralNetworkArchitecture<T>): The neural network architecture for the discriminator. Note: the output size should be 1 + numClasses (authenticity probability plus class probabilities). All outputs must be in the range (0, 1), so use sigmoid or softmax activations in the final layer.
numClasses (int): The number of classes.
inputType (InputType): The type of input.
generatorOptimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>): Optional optimizer for the generator. If null, an Adam optimizer is used.
discriminatorOptimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>): Optional optimizer for the discriminator. If null, an Adam optimizer is used.
lossFunction (ILossFunction<T>): Optional loss function.
Properties
Discriminator
Gets the discriminator network that predicts both authenticity and class.
public ConvolutionalNeuralNetwork<T> Discriminator { get; }
Property Value
- ConvolutionalNeuralNetwork<T>
Remarks
Unlike standard GANs, the AC-GAN discriminator has two outputs:
- Authenticity score (real vs fake): 1 output
- Class probability distribution: numClasses outputs
For Beginners: The discriminator is a multi-task network.
Two outputs:
- "Is this real or fake?" (1 number: 0-1)
- "What class is this?" (probability for each class)
This dual purpose makes it a better feature learner.
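A minimal sketch of how one row of discriminator output could be split into its two parts. It uses plain arrays rather than Tensor<T>, and it assumes the authenticity score sits at index 0 with the class probabilities after it; the page only states the total size of 1 + numClasses, so check the ordering against your own network. SplitOutput and ArgMax are hypothetical helpers, not AiDotNet members.

```csharp
using System;
using System.Linq;

// Hypothetical helper: splits one output row of length 1 + numClasses.
// Assumption: index 0 is the authenticity score, the rest are class probabilities.
static (double realScore, double[] classProbs) SplitOutput(double[] output, int numClasses)
{
    if (output.Length != 1 + numClasses)
        throw new ArgumentException($"Expected {1 + numClasses} values, got {output.Length}.");
    return (output[0], output.Skip(1).ToArray());
}

// Hypothetical helper: index of the largest value (the predicted class).
static int ArgMax(double[] values)
{
    int best = 0;
    for (int i = 1; i < values.Length; i++)
        if (values[i] > values[best]) best = i;
    return best;
}

double[] row = { 0.9, 0.05, 0.15, 0.8 };
var (authenticity, classProbabilities) = SplitOutput(row, numClasses: 3);
Console.WriteLine($"real score = {authenticity}, predicted class = {ArgMax(classProbabilities)}");
```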
Generator
Gets the generator network that creates class-conditional synthetic data.
public ConvolutionalNeuralNetwork<T> Generator { get; }
Property Value
- ConvolutionalNeuralNetwork<T>
ParameterCount
Gets the total number of trainable parameters in the ACGAN.
public override int ParameterCount { get; }
Property Value
- int
Methods
CreateNewInstance()
Creates a new instance of the same type as this neural network.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
A new instance of the same neural network type.
Remarks
For Beginners: This creates a blank version of the same type of neural network.
It's used internally by methods like DeepCopy and Clone to create the right type of network before copying the data into it.
CreateOneHotLabels(int, int)
Creates one-hot encoded class labels.
public Tensor<T> CreateOneHotLabels(int batchSize, int classIndex)
Parameters
batchSize (int)
classIndex (int)
Returns
- Tensor<T>
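To illustrate what one-hot encoding means here, the sketch below builds the labels with a plain 2D array. It assumes a [batchSize, numClasses] layout in which every row marks the same classIndex; the actual tensor shape returned by CreateOneHotLabels is not documented on this page. OneHot is a hypothetical stand-in, not the library method.

```csharp
using System;

// Hypothetical stand-in for one-hot label creation, using a plain array.
// Assumption: one row per sample, one column per class.
static double[,] OneHot(int batchSize, int classIndex, int numClasses)
{
    if (classIndex < 0 || classIndex >= numClasses)
        throw new ArgumentOutOfRangeException(nameof(classIndex));

    var labels = new double[batchSize, numClasses]; // zero-initialized by the runtime
    for (int i = 0; i < batchSize; i++)
        labels[i, classIndex] = 1.0; // mark the requested class in every row
    return labels;
}

// Two samples, both labelled as class 7 out of 10.
double[,] sevens = OneHot(batchSize: 2, classIndex: 7, numClasses: 10);
Console.WriteLine(sevens[0, 7]);
```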
DeserializeNetworkSpecificData(BinaryReader)
Deserializes AC-GAN-specific data including networks and optimizer states.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader): The binary reader to deserialize data from.
Remarks
This method restores all components needed to continue AC-GAN training from a saved state:
- Number of classes for classification
- Loss histories for training progress visualization
- Generator and Discriminator networks with all learned weights
- Optimizer states (momentum vectors, adaptive learning rates, timesteps)
For Beginners: When you load a saved AC-GAN, this method restores everything needed to continue training exactly where you left off:
- The networks remember everything they learned
- The optimizers remember their momentum and learning rate adjustments
- Training can resume smoothly without any "warm-up" period
This is especially important for the Adam optimizer, which maintains momentum vectors (m and v) and a timestep counter; losing these would cause training instability after loading.
GenerateConditional(Tensor<T>, Tensor<T>)
Generates class-conditional images.
public Tensor<T> GenerateConditional(Tensor<T> noise, Tensor<T> classLabels)
Parameters
noise (Tensor<T>)
classLabels (Tensor<T>)
Returns
- Tensor<T>
GenerateRandomNoiseTensor(int, int)
Generates a random noise tensor using vectorized Gaussian sampling.
public Tensor<T> GenerateRandomNoiseTensor(int batchSize, int noiseSize)
Parameters
batchSize (int)
noiseSize (int)
Returns
- Tensor<T>
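As an illustration of Gaussian noise generation, the sketch below fills a [batchSize, noiseSize] array with standard normal samples using the Box-Muller transform. This is a simple scalar loop over plain arrays, not the vectorized implementation the method description refers to; GaussianNoise is a hypothetical helper and the choice of Box-Muller is an assumption.

```csharp
using System;

// Hypothetical helper: standard normal samples via the Box-Muller transform.
static double[,] GaussianNoise(int batchSize, int noiseSize, Random rng)
{
    var noise = new double[batchSize, noiseSize];
    for (int i = 0; i < batchSize; i++)
    {
        for (int j = 0; j < noiseSize; j++)
        {
            double u1 = 1.0 - rng.NextDouble(); // shifts [0, 1) to (0, 1] so Log never sees 0
            double u2 = rng.NextDouble();
            noise[i, j] = Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2);
        }
    }
    return noise;
}

// A batch of 4 noise vectors of length 100, seeded for reproducibility.
double[,] z = GaussianNoise(batchSize: 4, noiseSize: 100, rng: new Random(42));
Console.WriteLine($"{z.GetLength(0)} x {z.GetLength(1)}");
```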
GetModelMetadata()
Gets the metadata for this neural network model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata<T> object containing information about the model.
InitializeLayers()
Initializes the layers of the neural network based on the architecture.
protected override void InitializeLayers()
Remarks
For Beginners: This method sets up all the layers in your neural network according to the architecture you've defined. It's like assembling the parts of your network before you can use it.
Predict(Tensor<T>)
Makes a prediction using the neural network.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input (Tensor<T>): The input data to process.
Returns
- Tensor<T>
The network's prediction.
Remarks
For Beginners: This is the main method you'll use to get results from your trained neural network. You provide some input data (like an image or text), and the network processes it through all its layers to produce an output (like a classification or prediction).
ResetOptimizerState()
Resets both optimizer states for a fresh training run.
public void ResetOptimizerState()
SerializeNetworkSpecificData(BinaryWriter)
Serializes AC-GAN-specific data including networks and optimizer states.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter): The binary writer to serialize data to.
Remarks
This method serializes all components needed to fully restore an AC-GAN's training state:
- Number of classes
- Loss histories for monitoring training progress
- Generator and Discriminator networks with all learned weights
- Optimizer states (momentum, adaptive learning rates, timesteps)
For Beginners: When you save an AC-GAN during training, this method ensures that everything needed to resume training is saved:
- The networks' learned knowledge (weights and biases)
- The optimizers' "memory" (like Adam's momentum vectors)
- Training history (loss values for monitoring)
Without saving optimizer states, resuming training would be like starting with a new optimizer that has forgotten all the momentum and adaptive learning rates it built up, which can cause unstable training after loading.
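The idea of persisting optimizer state can be sketched with a toy round trip through BinaryWriter and BinaryReader. The layout below (timestep, vector length, then the m and v values) is invented for illustration and is not AiDotNet's file format; WriteState and ReadState are hypothetical helpers.

```csharp
using System;
using System.IO;
using System.Text;

// Hypothetical toy layout: timestep, vector length, all m values, all v values.
static void WriteState(BinaryWriter writer, double[] m, double[] v, int timestep)
{
    writer.Write(timestep);
    writer.Write(m.Length);
    foreach (double x in m) writer.Write(x);
    foreach (double x in v) writer.Write(x);
}

// Reads the values back in the same order they were written.
static (double[] m, double[] v, int timestep) ReadState(BinaryReader reader)
{
    int step = reader.ReadInt32();
    int length = reader.ReadInt32();
    var first = new double[length];
    var second = new double[length];
    for (int i = 0; i < length; i++) first[i] = reader.ReadDouble();
    for (int i = 0; i < length; i++) second[i] = reader.ReadDouble();
    return (first, second, step);
}

using var stream = new MemoryStream();
using (var output = new BinaryWriter(stream, Encoding.UTF8, leaveOpen: true))
{
    WriteState(output, new[] { 0.1, -0.2 }, new[] { 0.01, 0.04 }, timestep: 500);
}

stream.Position = 0; // rewind before reading
using var input = new BinaryReader(stream);
var restored = ReadState(input);
Console.WriteLine($"restored timestep = {restored.timestep}");
```

Because the moment vectors and the timestep come back unchanged, an optimizer restored this way can continue from the same point instead of rebuilding its statistics from zero.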
Train(Tensor<T>, Tensor<T>)
Performs a single training iteration using the standard neural network interface.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>): The noise tensor used as input to the generator network. Shape should be [batchSize, noiseSize] where noiseSize matches the generator's expected input.
expectedOutput (Tensor<T>): The real images tensor used for discriminator training. Shape should be [batchSize, height, width, channels] or equivalent flattened form.
Remarks
This method adapts the AC-GAN's specialized training to the standard Train(Tensor<T>, Tensor<T>) interface by automatically generating random class labels for both real and fake samples.
The AC-GAN training process differs from standard neural networks because it requires:
- Real images with their class labels
- Noise vectors for generating fake images
- Target class labels for the generated images
When using this simplified interface, random class labels are generated using AiDotNet.Tensors.Helpers.RandomHelper.ThreadSafeRandom for thread-safe, cryptographically-seeded random number generation. For more control over class labels, use the TrainStep(Tensor<T>, Tensor<T>, Tensor<T>, Tensor<T>) method directly.
For Beginners: This method lets you train an AC-GAN using the same interface as other neural networks. Just provide:
- input: Random noise vectors (like random seeds for image generation)
- expectedOutput: Real images to learn from
The method automatically assigns random class labels (like "digit 3", "digit 7", etc.) to both the real images and the images to generate. While this is convenient, for best results you should use TrainStep(Tensor<T>, Tensor<T>, Tensor<T>, Tensor<T>) with actual class labels from your dataset.
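What "random class labels" amounts to can be sketched as follows: each sample gets one uniformly chosen class, encoded as a one-hot row. The sketch uses System.Random and a plain array as stand-ins for the library's ThreadSafeRandom and Tensor<T>; RandomOneHot is a hypothetical helper. Note that labels drawn this way carry no information about what the real images actually show, which is why supplying true labels through TrainStep is preferable.

```csharp
using System;

// Hypothetical helper: one random class per sample, encoded one-hot.
static double[,] RandomOneHot(int batchSize, int numClasses, Random rng)
{
    var labels = new double[batchSize, numClasses];
    for (int i = 0; i < batchSize; i++)
    {
        int classIndex = rng.Next(numClasses); // uniform in [0, numClasses)
        labels[i, classIndex] = 1.0;
    }
    return labels;
}

double[,] randomLabels = RandomOneHot(batchSize: 8, numClasses: 10, rng: new Random(0));
Console.WriteLine($"{randomLabels.GetLength(0)} x {randomLabels.GetLength(1)}");
```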
Exceptions
- ArgumentNullException
Thrown when input or expectedOutput is null.
TrainStep(Tensor<T>, Tensor<T>, Tensor<T>, Tensor<T>)
Performs one training step for the AC-GAN.
public (T discriminatorLoss, T generatorLoss) TrainStep(Tensor<T> realImages, Tensor<T> realLabels, Tensor<T> noise, Tensor<T> fakeLabels)
Parameters
realImages (Tensor<T>): Real images tensor.
realLabels (Tensor<T>): Real image class labels (one-hot encoded).
noise (Tensor<T>): Random noise for generator.
fakeLabels (Tensor<T>): Class labels for images to generate (one-hot encoded).
Returns
- (T discriminatorLoss, T generatorLoss)
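To give a feel for what the two returned losses measure, the sketch below computes a paper-style discriminator loss for a single sample: binary cross-entropy on the authenticity output plus categorical cross-entropy on the class outputs. This assumes the formulation from Odena et al.; the loss TrainStep actually computes may differ, and AuthenticityLoss and ClassLoss are hypothetical helpers. It also shows why outputs need to stay inside (0, 1): both terms take a logarithm of them.

```csharp
using System;

// Binary cross-entropy for the authenticity output of one sample.
static double AuthenticityLoss(double predictedReal, bool isReal)
{
    const double eps = 1e-12; // keeps the argument of Log away from 0
    double p = Math.Clamp(predictedReal, eps, 1.0 - eps);
    return isReal ? -Math.Log(p) : -Math.Log(1.0 - p);
}

// Categorical cross-entropy of class probabilities against a one-hot label.
static double ClassLoss(double[] classProbabilities, double[] oneHotLabel)
{
    const double eps = 1e-12;
    double loss = 0.0;
    for (int i = 0; i < classProbabilities.Length; i++)
        loss -= oneHotLabel[i] * Math.Log(Math.Max(classProbabilities[i], eps));
    return loss;
}

// A real image of class 1 that the discriminator rates 0.9 real and 0.8 class 1.
double sampleLoss = AuthenticityLoss(0.9, isReal: true)
                  + ClassLoss(new[] { 0.1, 0.8, 0.1 }, new[] { 0.0, 1.0, 0.0 });
Console.WriteLine(sampleLoss);
```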
UpdateParameters(Vector<T>)
Updates the network's parameters with new values.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): The new parameter values to set.
Remarks
For Beginners: During training, a neural network's internal values (parameters) get adjusted to improve its performance. This method allows you to update all those values at once by providing a complete set of new parameters.
This is typically used by optimization algorithms that calculate better parameter values based on training data.