Class BigGAN<T>
- Namespace
- AiDotNet.NeuralNetworks
- Assembly
- AiDotNet.dll
BigGAN implementation for large-scale high-fidelity image generation.
For Beginners: BigGAN is a state-of-the-art GAN architecture that generates extremely high-quality images by scaling up training in several ways:
- Using very large batch sizes (256-2048 images at once)
- Increasing model capacity (more parameters and feature maps)
- Using class information to generate specific types of images
Think of it like training an artist:
- Small batch = showing the artist 1-2 examples at a time
- BigGAN batch = showing 256+ examples at once for better learning
- Class conditioning = telling the artist exactly what to draw ("draw a cat" vs "draw something")
Key innovations:
- Large Batch Training: Uses batch sizes of 256-2048 (vs typical 32-128)
- Spectral Normalization: Stabilizes training for both G and D
- Self-Attention: Helps model long-range dependencies in images
- Class Conditioning: Uses class embeddings for controlled generation
- Truncation Trick: Trades diversity for quality at generation time
- Orthogonal Initialization: Better weight initialization
- Skip Connections: Direct paths in generator architecture
Based on "Large Scale GAN Training for High Fidelity Natural Image Synthesis" by Brock et al. (2019)
public class BigGAN<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable
Type Parameters
T: The numeric type for computations (e.g., double, float).
- Inheritance
- NeuralNetworkBase<T> → BigGAN<T>
Constructors
BigGAN(NeuralNetworkArchitecture<T>, NeuralNetworkArchitecture<T>, int, int, int, int, int, int, int, int, InputType, ILossFunction<T>?, double)
Initializes a new instance of BigGAN.
public BigGAN(NeuralNetworkArchitecture<T> generatorArchitecture, NeuralNetworkArchitecture<T> discriminatorArchitecture, int latentSize = 120, int numClasses = 1000, int classEmbeddingDim = 128, int imageChannels = 3, int imageHeight = 128, int imageWidth = 128, int generatorChannels = 96, int discriminatorChannels = 96, InputType inputType = InputType.TwoDimensional, ILossFunction<T>? lossFunction = null, double initialLearningRate = 0.0001)
Parameters
generatorArchitecture (NeuralNetworkArchitecture<T>): Architecture for the generator network.
discriminatorArchitecture (NeuralNetworkArchitecture<T>): Architecture for the discriminator network.
latentSize (int): Size of the latent noise vector (default 120).
numClasses (int): Number of classes for conditional generation (default 1000).
classEmbeddingDim (int): Dimension of class embeddings (default 128).
imageChannels (int): Number of image channels (1 for grayscale, 3 for RGB; default 3).
imageHeight (int): Height of generated images (default 128).
imageWidth (int): Width of generated images (default 128).
generatorChannels (int): Base number of channels in the generator (default 96).
discriminatorChannels (int): Base number of channels in the discriminator (default 96).
inputType (InputType): The type of input (default InputType.TwoDimensional).
lossFunction (ILossFunction<T>?): Loss function for training (defaults to hinge loss).
initialLearningRate (double): Initial learning rate (default 0.0001).
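Example (illustrative): a hypothetical configuration for 64x64 grayscale images across 10 classes. As before, genArch and discArch are assumed to be architectures sized for that output.
var gan = new BigGAN<float>(
    generatorArchitecture: genArch,
    discriminatorArchitecture: discArch,
    latentSize: 120,
    numClasses: 10,
    classEmbeddingDim: 128,
    imageChannels: 1,          // grayscale
    imageHeight: 64,
    imageWidth: 64,
    generatorChannels: 64,
    discriminatorChannels: 64);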
Properties
ClassEmbeddingDim
Gets the dimension of class embeddings. These learned embeddings represent each class.
public int ClassEmbeddingDim { get; }
Property Value
- int
Discriminator
Gets the discriminator network that evaluates images and predicts their class. Uses a projection discriminator for efficient class conditioning.
public ConvolutionalNeuralNetwork<T> Discriminator { get; }
Property Value
- ConvolutionalNeuralNetwork<T>
Generator
Gets the generator network that produces images from noise and class labels.
public ConvolutionalNeuralNetwork<T> Generator { get; }
Property Value
- ConvolutionalNeuralNetwork<T>
LatentSize
Gets the size of the latent noise vector. BigGAN typically uses 120-dimensional latent codes.
public int LatentSize { get; }
Property Value
- int
NumClasses
Gets the number of classes for conditional generation. For example, ImageNet has 1000 classes.
public int NumClasses { get; }
Property Value
- int
ParameterCount
Gets the total number of trainable parameters in the BigGAN.
public override int ParameterCount { get; }
Property Value
- int
Remarks
This includes all parameters from both the Generator and Discriminator networks.
TruncationThreshold
Gets or sets the truncation threshold for the truncation trick. Values in range [0, 2], where lower values trade diversity for quality. Typical value: 0.5 for high quality, 1.0 for balanced, 2.0 for high diversity.
public double TruncationThreshold { get; set; }
Property Value
- double
UseSelfAttention
Gets or sets whether to use self-attention layers.
public bool UseSelfAttention { get; set; }
Property Value
- bool
UseSpectralNormalization
Gets or sets whether to use spectral normalization in both generator and discriminator.
public bool UseSpectralNormalization { get; set; }
Property Value
- bool
UseTruncation
Gets or sets whether to use the truncation trick during generation. When enabled, samples are resampled if they fall outside the truncation threshold.
public bool UseTruncation { get; set; }
Property Value
- bool
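Example (illustrative): the snippet below is a conceptual sketch of what the truncation trick does to each latent value, not this class's internal code. Values that fall outside [-TruncationThreshold, +TruncationThreshold] are simply redrawn, so the generator only ever sees "typical" noise.
using System;
// Conceptual sketch of truncated sampling (not the library's implementation).
static double SampleTruncatedNormal(Random rng, double threshold)
{
    double z;
    do
    {
        // Box-Muller transform: one standard-normal draw from two uniforms.
        double u1 = 1.0 - rng.NextDouble();
        double u2 = rng.NextDouble();
        z = Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2);
    } while (Math.Abs(z) > threshold);   // resample until inside the threshold
    return z;
}
With UseTruncation = true and TruncationThreshold = 0.5, generated images tend to look cleaner but cover fewer modes; raising the threshold toward 2.0 restores diversity.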
Methods
CreateNewInstance()
Creates a new instance of the same type as this neural network.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
A new instance of the same neural network type.
Remarks
For Beginners: This creates a blank version of the same type of neural network.
It's used internally by methods like DeepCopy and Clone to create the right type of network before copying the data into it.
DeserializeNetworkSpecificData(BinaryReader)
Deserializes network-specific data that was not covered by the general deserialization process.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader): The BinaryReader to read the data from.
Remarks
This method is called at the end of the general deserialization process to allow derived classes to read any additional data specific to their implementation.
For Beginners: Continuing the suitcase analogy, this is like unpacking that special compartment. After the main deserialization method has unpacked the common items (layers, parameters), this method allows each specific type of neural network to unpack its own unique items that were stored during serialization.
Generate(Tensor<T>, int[])
Generates images from latent codes and class labels.
public Tensor<T> Generate(Tensor<T> latentCodes, int[] classIndices)
Parameters
latentCodes (Tensor<T>): Latent noise vectors.
classIndices (int[]): Class indices for each sample (must be in range [0, NumClasses)).
Returns
- Tensor<T>
Generated images
Exceptions
- ArgumentException
Thrown when the number of class indices doesn't match the batch size, or when an index is out of range.
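Example (illustrative): reuses the bigGan instance from the earlier sketch. The variable latents is assumed to be a noise tensor with one LatentSize-dimensional row per image, and the class indices 207 and 281 are arbitrary ImageNet-style labels.
int[] classes = { 207, 281 };
Tensor<double> conditioned = bigGan.Generate(latents, classes);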
Generate(int)
Generates random images with random class labels.
public Tensor<T> Generate(int numImages)
Parameters
numImages (int): Number of images to generate.
Returns
- Tensor<T>
Generated images
Exceptions
- ArgumentOutOfRangeException
Thrown when numImages is not positive.
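Example (illustrative), again using the bigGan instance from the earlier sketch:
// 16 images, each with a randomly chosen class label.
Tensor<double> samples = bigGan.Generate(16);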
GetModelMetadata()
Gets the metadata for this neural network model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata<T> object containing information about the model.
GetParameters()
Gets all trainable parameters of the network as a single vector.
public override Vector<T> GetParameters()
Returns
- Vector<T>
A vector containing all parameters of the network.
Remarks
For Beginners: Neural networks learn by adjusting their "parameters" (also called weights and biases). This method collects all those adjustable values into a single list so they can be updated during training.
InitializeLayers()
Initializes the layers of the neural network based on the architecture.
protected override void InitializeLayers()
Remarks
For Beginners: This method sets up all the layers in your neural network according to the architecture you've defined. It's like assembling the parts of your network before you can use it.
Predict(Tensor<T>)
Makes a prediction using the neural network.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input (Tensor<T>): The input data to process.
Returns
- Tensor<T>
The network's prediction.
Remarks
For Beginners: This is the main method you'll use to get results from your trained neural network. You provide some input data (like an image or text), and the network processes it through all its layers to produce an output (like a classification or prediction).
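Example (illustrative): a minimal call, assuming input is a tensor of the shape this network expects.
Tensor<double> output = bigGan.Predict(input);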
SerializeNetworkSpecificData(BinaryWriter)
Serializes network-specific data that is not covered by the general serialization process.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter): The BinaryWriter to write the data to.
Remarks
This method is called at the end of the general serialization process to allow derived classes to write any additional data specific to their implementation.
For Beginners: Think of this as packing a special compartment in your suitcase. While the main serialization method packs the common items (layers, parameters), this method allows each specific type of neural network to pack its own unique items that other networks might not have.
SetTrainingMode(bool)
Sets the neural network to either training or inference mode.
public override void SetTrainingMode(bool isTraining)
Parameters
isTraining (bool): True to enable training mode; false to enable inference mode.
Remarks
For Beginners: Neural networks behave differently during training versus when making predictions.
When in training mode (isTraining = true):
- The network keeps track of intermediate calculations needed for learning
- Certain layers like Dropout and BatchNormalization behave differently
- The network uses more memory but can learn from its mistakes
When in inference/prediction mode (isTraining = false):
- The network only performs forward calculations
- It uses less memory and runs faster
- It cannot learn or update its parameters
Think of it like the difference between taking a practice test (training mode) where you can check your answers and learn from mistakes, versus taking the actual exam (inference mode) where you just give your best answers based on what you've already learned.
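Example (illustrative): the usual toggle around training and generation, assuming a bigGan instance as constructed in the earlier sketch.
bigGan.SetTrainingMode(true);
// ... run training steps here ...
bigGan.SetTrainingMode(false);               // switch to inference before sampling
Tensor<double> preview = bigGan.Generate(4);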
Train(Tensor<T>, Tensor<T>)
Trains the neural network on a single input-output pair.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>): The input data.
expectedOutput (Tensor<T>): The expected output for the given input.
Remarks
This method performs one training step on the neural network using the provided input and expected output. It updates the network's parameters to reduce the error between the network's prediction and the expected output.
For Beginners: This is how your neural network learns. You provide:
- An input (what the network should process)
- The expected output (what the correct answer should be)
The network then:
- Makes a prediction based on the input
- Compares its prediction to the expected output
- Calculates how wrong it was (the loss)
- Adjusts its internal values to do better next time
After training, you can get the loss value using the GetLastLoss() method to see how well the network is learning.
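Example (illustrative): one supervised step, assuming input and expectedOutput are appropriately shaped tensors. GetLastLoss() is the method mentioned in the remark above.
bigGan.Train(input, expectedOutput);
var loss = bigGan.GetLastLoss();   // how wrong the last prediction was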
TrainStep(Tensor<T>, int[], int)
Performs a single training step on a batch of real images with labels. Uses hinge loss by default for improved stability.
public (T discriminatorLoss, T generatorLoss) TrainStep(Tensor<T> realImages, int[] realLabels, int batchSize)
Parameters
realImages (Tensor<T>): Batch of real images.
realLabels (int[]): Class labels for real images.
batchSize (int): Number of images in the batch.
Returns
- (T discriminatorLoss, T generatorLoss)
A tuple containing the discriminator loss and the generator loss for this training step.
Exceptions
- ArgumentNullException
Thrown when realImages or realLabels is null.
- ArgumentOutOfRangeException
Thrown when batchSize is not positive.
- ArgumentException
Thrown when array lengths don't match or labels are out of range.
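Example (illustrative): a sketch of an adversarial training loop. The names numEpochs, batches, and batchSize are placeholders for your own data pipeline.
bigGan.SetTrainingMode(true);
for (int epoch = 0; epoch < numEpochs; epoch++)
{
    foreach (var (realImages, realLabels) in batches)
    {
        var (dLoss, gLoss) = bigGan.TrainStep(realImages, realLabels, batchSize);
        Console.WriteLine($"epoch {epoch}: D loss = {dLoss}, G loss = {gLoss}");
    }
}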
UpdateParameters(Vector<T>)
Updates the network's parameters with new values.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): The new parameter values to set.
Remarks
For Beginners: During training, a neural network's internal values (parameters) get adjusted to improve its performance. This method allows you to update all those values at once by providing a complete set of new parameters.
This is typically used by optimization algorithms that calculate better parameter values based on training data.
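Example (illustrative): a round trip of the parameter vector. In practice an optimizer produces updatedParameters from the current parameters and gradients.
Vector<double> parameters = bigGan.GetParameters();
// ... an optimizer computes updatedParameters from parameters and gradients ...
bigGan.UpdateParameters(updatedParameters);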