Class CycleGAN<T>
Namespace: AiDotNet.NeuralNetworks
Assembly: AiDotNet.dll
Represents a CycleGAN for unpaired image-to-image translation.
public class CycleGAN<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable
Type Parameters
T: The numeric type.
Remarks
CycleGAN enables image-to-image translation without paired training data:
- Uses two generators (A→B and B→A) and two discriminators
- Enforces cycle consistency: translating A→B→A should approximately reproduce the original A
- Works without paired examples (e.g., it can learn horses→zebras from two separate photo collections)
- Trains with a combination of adversarial loss, cycle-consistency loss, and identity loss
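For reference, the combined objective from Zhu et al. can be written as follows, with G: A→B and F: B→A. Mapping the two weights onto the constructor's cycleConsistencyLambda and identityLambda is an assumption based on their names; the adversarial terms are left abstract because their exact form in this class depends on the configured loss function.

```latex
\begin{aligned}
\mathcal{L}(G, F, D_A, D_B) &= \mathcal{L}_{\mathrm{GAN}}(G, D_B) + \mathcal{L}_{\mathrm{GAN}}(F, D_A)
  + \lambda_{\mathrm{cyc}}\,\mathcal{L}_{\mathrm{cyc}}(G, F) + \lambda_{\mathrm{id}}\,\mathcal{L}_{\mathrm{id}}(G, F) \\
\mathcal{L}_{\mathrm{cyc}}(G, F) &= \mathbb{E}_{a}\big[\lVert F(G(a)) - a \rVert_1\big] + \mathbb{E}_{b}\big[\lVert G(F(b)) - b \rVert_1\big] \\
\mathcal{L}_{\mathrm{id}}(G, F) &= \mathbb{E}_{b}\big[\lVert G(b) - b \rVert_1\big] + \mathbb{E}_{a}\big[\lVert F(a) - a \rVert_1\big]
\end{aligned}
```

The generators minimize this objective while the discriminators maximize the adversarial terms. With the defaults (10 and 5), the identity term carries half the weight of the cycle term.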
For Beginners: CycleGAN translates images without matched pairs.
Key innovation:
- Doesn't need paired training data
- Learns from two separate collections of images
- Example: Photos of horses + Photos of zebras → can convert horses to zebras
How it works:
- Two generators: G (A→B) and F (B→A)
- Two discriminators: D_A and D_B
- Cycle consistency: G(F(B)) ≈ B and F(G(A)) ≈ A
- This constrains the two generators to act as approximate inverses, which helps preserve image content and discourages mode collapse
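The cycle-consistency constraints above can be illustrated with plain arrays. In this sketch the two generators are hypothetical stand-in functions rather than AiDotNet networks, and each L1 term is averaged over elements (mean absolute error), which is one common normalization, not necessarily the one this class uses.

```csharp
using System;

// Minimal sketch of cycle-consistency loss on plain arrays.
// The "generators" are hypothetical stand-in functions, not AiDotNet networks.
public static class CycleConsistencySketch
{
    // Mean absolute error (averaged L1 distance) between two equally sized arrays.
    public static double MeanAbsoluteError(double[] x, double[] y)
    {
        if (x.Length != y.Length) throw new ArgumentException("Length mismatch.");
        if (x.Length == 0) throw new ArgumentException("Arrays must be non-empty.");
        double sum = 0;
        for (int i = 0; i < x.Length; i++) sum += Math.Abs(x[i] - y[i]);
        return sum / x.Length;
    }

    // Cycle loss: |F(G(a)) - a| + |G(F(b)) - b|, each term averaged over elements.
    public static double CycleLoss(
        Func<double[], double[]> g, // A -> B
        Func<double[], double[]> f, // B -> A
        double[] realA,
        double[] realB)
    {
        double forward = MeanAbsoluteError(f(g(realA)), realA);  // A -> B -> A
        double backward = MeanAbsoluteError(g(f(realB)), realB); // B -> A -> B
        return forward + backward;
    }
}
```

If f exactly inverts g (for example, doubling and halving), the cycle loss is zero; if f ignores g, the loss grows with how far the round trip drifts from the original.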
Applications:
- Style transfer (Monet → Photo, Photo → Monet)
- Season transfer (Summer → Winter)
- Object transfiguration (Horse → Zebra)
- Domain adaptation
Reference: Zhu et al., "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks" (2017)
Constructors
CycleGAN(NeuralNetworkArchitecture<T>, NeuralNetworkArchitecture<T>, NeuralNetworkArchitecture<T>, NeuralNetworkArchitecture<T>, InputType, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>?, ILossFunction<T>?, double, double)
Initializes a new instance of the CycleGAN<T> class with the specified architecture and training parameters.
public CycleGAN(NeuralNetworkArchitecture<T> generatorAtoB, NeuralNetworkArchitecture<T> generatorBtoA, NeuralNetworkArchitecture<T> discriminatorA, NeuralNetworkArchitecture<T> discriminatorB, InputType inputType, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? generatorAtoBOptimizer = null, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? generatorBtoAOptimizer = null, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? discriminatorAOptimizer = null, IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>? discriminatorBOptimizer = null, ILossFunction<T>? lossFunction = null, double cycleConsistencyLambda = 10, double identityLambda = 5)
Parameters
generatorAtoB (NeuralNetworkArchitecture<T>): The architecture for the generator that transforms images from domain A to domain B.
generatorBtoA (NeuralNetworkArchitecture<T>): The architecture for the generator that transforms images from domain B to domain A.
discriminatorA (NeuralNetworkArchitecture<T>): The architecture for the discriminator that evaluates images in domain A.
discriminatorB (NeuralNetworkArchitecture<T>): The architecture for the discriminator that evaluates images in domain B.
inputType (InputType): The type of input data (e.g., ThreeDimensional for images).
generatorAtoBOptimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>): Optional optimizer for the A→B generator. If null, an Adam optimizer with default GAN settings is created.
generatorBtoAOptimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>): Optional optimizer for the B→A generator. If null, an Adam optimizer with default GAN settings is created.
discriminatorAOptimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>): Optional optimizer for discriminator A. If null, an Adam optimizer with default GAN settings is created.
discriminatorBOptimizer (IGradientBasedOptimizer<T, Tensor<T>, Tensor<T>>): Optional optimizer for discriminator B. If null, an Adam optimizer with default GAN settings is created.
lossFunction (ILossFunction<T>): Optional loss function. If null, the default loss function for generative tasks is used.
cycleConsistencyLambda (double): The coefficient for cycle-consistency loss. Higher values enforce stronger cycle consistency. Default is 10.0.
identityLambda (double): The coefficient for identity loss, which helps preserve color composition. Default is 5.0.
Remarks
This constructor creates a CycleGAN with four separate networks and optimizers:
- Generator A→B: Transforms images from domain A to domain B
- Generator B→A: Transforms images from domain B to domain A
- Discriminator A: Evaluates whether images in domain A are real or generated
- Discriminator B: Evaluates whether images in domain B are real or generated
For Beginners: CycleGAN needs four networks to work:
- Two generators to translate images in both directions
- Two discriminators to judge images in each domain
The cycle consistency loss ensures that translating A→B→A gets back to the original, which helps maintain content while only changing style.
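The two coefficients scale the reconstruction terms relative to the adversarial terms. The sketch below assumes least-squares adversarial terms, the variant Zhu et al. used in their experiments; the loss actually used by CycleGAN<T> depends on the lossFunction argument, and the helper names here are hypothetical.

```csharp
using System;

// Sketch of how cycleConsistencyLambda and identityLambda weight the generator objective.
// Assumes least-squares adversarial terms; CycleGAN<T> may use a different loss.
public static class GeneratorObjectiveSketch
{
    // Least-squares generator term: pushes the discriminator's scores on fakes toward 1 ("real").
    public static double LeastSquaresGeneratorLoss(double[] discScoresOnFakes)
    {
        if (discScoresOnFakes.Length == 0) throw new ArgumentException("No scores provided.");
        double sum = 0;
        foreach (double s in discScoresOnFakes) sum += (s - 1.0) * (s - 1.0);
        return sum / discScoresOnFakes.Length;
    }

    // Total = adversarial(A->B) + adversarial(B->A) + lambdaCyc * cycle + lambdaId * identity.
    public static double TotalGeneratorLoss(
        double advAtoB,
        double advBtoA,
        double cycleLoss,
        double identityLoss,
        double cycleConsistencyLambda = 10.0,
        double identityLambda = 5.0)
    {
        return advAtoB + advBtoA
             + cycleConsistencyLambda * cycleLoss
             + identityLambda * identityLoss;
    }
}
```

With the defaults, a cycle loss of 0.1 already contributes 1.0 to the total, which is why small reconstruction errors dominate the adversarial terms early in training.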
Exceptions
- ArgumentNullException
Thrown when any of the architecture parameters is null.
- ArgumentOutOfRangeException
Thrown when cycleConsistencyLambda or identityLambda is negative.
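The documented exceptions correspond to guard clauses along the lines below. This is a hypothetical stand-in, not the library's code: object replaces NeuralNetworkArchitecture<T>, and the optimizer and loss-function parameters are omitted.

```csharp
using System;

// Hypothetical stand-in for the CycleGAN<T> constructor guards.
// 'object' replaces NeuralNetworkArchitecture<T>; optimizers and the loss function are omitted.
public sealed class CycleGanConfigSketch
{
    public double CycleConsistencyLambda { get; }
    public double IdentityLambda { get; }

    public CycleGanConfigSketch(
        object generatorAtoB, object generatorBtoA,
        object discriminatorA, object discriminatorB,
        double cycleConsistencyLambda = 10, double identityLambda = 5)
    {
        // ArgumentNullException for any missing architecture.
        if (generatorAtoB is null) throw new ArgumentNullException(nameof(generatorAtoB));
        if (generatorBtoA is null) throw new ArgumentNullException(nameof(generatorBtoA));
        if (discriminatorA is null) throw new ArgumentNullException(nameof(discriminatorA));
        if (discriminatorB is null) throw new ArgumentNullException(nameof(discriminatorB));

        // ArgumentOutOfRangeException for negative loss weights (zero is allowed).
        if (cycleConsistencyLambda < 0)
            throw new ArgumentOutOfRangeException(nameof(cycleConsistencyLambda), "Must be non-negative.");
        if (identityLambda < 0)
            throw new ArgumentOutOfRangeException(nameof(identityLambda), "Must be non-negative.");

        CycleConsistencyLambda = cycleConsistencyLambda;
        IdentityLambda = identityLambda;
    }
}
```

A plain `< 0` comparison lets NaN through; whether the real constructor rejects NaN is not documented here.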
Properties
DiscriminatorA
Discriminator for domain A.
public NeuralNetworkBase<T> DiscriminatorA { get; }
Property Value: NeuralNetworkBase<T>
DiscriminatorB
Discriminator for domain B.
public NeuralNetworkBase<T> DiscriminatorB { get; }
Property Value: NeuralNetworkBase<T>
GeneratorAtoB
Generator that translates images from domain A to domain B.
public NeuralNetworkBase<T> GeneratorAtoB { get; }
Property Value: NeuralNetworkBase<T>
GeneratorBtoA
Generator that translates images from domain B to domain A.
public NeuralNetworkBase<T> GeneratorBtoA { get; }
Property Value: NeuralNetworkBase<T>
Methods
CreateNewInstance()
Creates a new instance of the CycleGAN with the same configuration.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
A new CycleGAN instance with the same architecture and hyperparameters.
Remarks
This method creates a fresh CycleGAN instance with the same network architectures and hyperparameters. The new instance has freshly initialized optimizers.
For Beginners: This method creates a copy of the CycleGAN structure but with new, untrained networks and fresh optimizers.
DeserializeNetworkSpecificData(BinaryReader)
Deserializes CycleGAN-specific data from a binary reader.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader): The binary reader to read from.
Remarks
This method deserializes the CycleGAN-specific configuration and all four networks. After deserialization, the optimizers are reset to their initial state.
For Beginners: This method loads the CycleGAN's settings and all four networks (two generators and two discriminators) from a file.
GetModelMetadata()
Gets the metadata for this neural network model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
A ModelMetadata<T> object containing information about the model.
InitializeLayers()
Initializes the layers of the neural network based on the architecture.
protected override void InitializeLayers()
Remarks
For Beginners: This method sets up all the layers in your neural network according to the architecture you've defined. It's like assembling the parts of your network before you can use it.
Predict(Tensor<T>)
Makes a prediction using the neural network.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input (Tensor<T>): The input data to process.
Returns
- Tensor<T>
The network's prediction.
Remarks
For Beginners: This is the main method you'll use to get results from your trained neural network. You provide some input data (like an image or text), and the network processes it through all its layers to produce an output (like a classification or prediction).
ResetOptimizerState()
Resets the state of all optimizers to their initial values.
public void ResetOptimizerState()
Remarks
This method resets all four optimizers (both generators and both discriminators) to their initial state. This is useful when restarting training or when you want to clear accumulated momentum and adaptive learning rate information.
For Beginners: Call this method when you want to restart training with fresh optimizer state. The network weights remain unchanged; only the optimizers' memory of past gradients (momentum and adaptive learning-rate statistics) is cleared.
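Optimizer state here means per-parameter statistics such as Adam's first- and second-moment estimates plus its step counter. The minimal Adam sketch below is not AiDotNet's optimizer; it only shows that a reset clears those statistics without touching the weights. The default hyperparameters (learning rate 0.0002, beta1 0.5) are assumed values commonly used for GAN training, not necessarily the library's defaults.

```csharp
using System;

// Minimal Adam-style optimizer used only to illustrate what "resetting optimizer state" means.
// Not AiDotNet's implementation; the hyperparameter defaults are assumptions.
public sealed class AdamStateSketch
{
    private readonly double _lr, _beta1, _beta2, _eps;
    private readonly double[] _m; // first-moment (momentum) estimates
    private readonly double[] _v; // second-moment (squared-gradient) estimates
    private int _t;               // number of steps taken

    public AdamStateSketch(int size, double lr = 0.0002, double beta1 = 0.5,
                           double beta2 = 0.999, double eps = 1e-8)
    {
        _lr = lr; _beta1 = beta1; _beta2 = beta2; _eps = eps;
        _m = new double[size];
        _v = new double[size];
    }

    public int StepCount => _t;
    public double FirstMoment(int i) => _m[i];

    // One bias-corrected Adam update, applied in place to 'weights'.
    public void Step(double[] weights, double[] gradients)
    {
        if (weights.Length != _m.Length || gradients.Length != _m.Length)
            throw new ArgumentException("Size mismatch.");
        _t++;
        double bias1 = 1.0 - Math.Pow(_beta1, _t);
        double bias2 = 1.0 - Math.Pow(_beta2, _t);
        for (int i = 0; i < weights.Length; i++)
        {
            double g = gradients[i];
            _m[i] = _beta1 * _m[i] + (1.0 - _beta1) * g;
            _v[i] = _beta2 * _v[i] + (1.0 - _beta2) * g * g;
            weights[i] -= _lr * (_m[i] / bias1) / (Math.Sqrt(_v[i] / bias2) + _eps);
        }
    }

    // Clears moments and the step counter; the weights live elsewhere and are not modified.
    public void Reset()
    {
        Array.Clear(_m, 0, _m.Length);
        Array.Clear(_v, 0, _v.Length);
        _t = 0;
    }
}
```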
SerializeNetworkSpecificData(BinaryWriter)
Serializes CycleGAN-specific data to a binary writer.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter): The binary writer to write to.
Remarks
This method serializes the CycleGAN-specific configuration and all four networks. Optimizer state is managed by the optimizer implementations themselves.
For Beginners: This method saves the CycleGAN's settings and all four networks (two generators and two discriminators) to a file.
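SerializeNetworkSpecificData and DeserializeNetworkSpecificData must read fields back in exactly the order they were written. The sketch below shows a symmetric pair for the two loss weights only, using the standard BinaryWriter/BinaryReader APIs. The field order, and the assumption that the lambdas are among the serialized settings, are illustrative; the four networks that the real methods also write are omitted.

```csharp
using System;
using System.IO;

// Sketch of a symmetric write/read pair for CycleGAN-level settings.
// Field order (cycle lambda, then identity lambda) is an assumption; networks are omitted.
public static class CycleGanSettingsSerializationSketch
{
    public static void Write(BinaryWriter writer, double cycleLambda, double identityLambda)
    {
        writer.Write(cycleLambda);    // 8 bytes
        writer.Write(identityLambda); // 8 bytes
    }

    public static (double CycleLambda, double IdentityLambda) Read(BinaryReader reader)
    {
        // Must mirror Write exactly: same types, same order.
        double cycle = reader.ReadDouble();
        double identity = reader.ReadDouble();
        return (cycle, identity);
    }
}
```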
Train(Tensor<T>, Tensor<T>)
Trains the neural network on a single input-output pair.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>): The input data.
expectedOutput (Tensor<T>): The expected output for the given input.
Remarks
This method performs one training step on the neural network using the provided input and expected output. It updates the network's parameters to reduce the error between the network's prediction and the expected output.
For Beginners: This is how your neural network learns. You provide:
- An input (what the network should process)
- The expected output (what the correct answer should be)
The network then:
- Makes a prediction based on the input
- Compares its prediction to the expected output
- Calculates how wrong it was (the loss)
- Adjusts its internal values to do better next time
After training, you can get the loss value using the GetLastLoss() method to see how well the network is learning.
TrainStep(Tensor<T>, Tensor<T>)
Performs one training step for CycleGAN.
public (T discLoss, T genLoss, T cycleLoss) TrainStep(Tensor<T> realA, Tensor<T> realB)
Parameters
realA (Tensor<T>): Real images from domain A.
realB (Tensor<T>): Real images from domain B.
Returns
- (T discLoss, T genLoss, T cycleLoss)
A tuple containing the discriminator loss, generator loss, and cycle-consistency loss.
Exceptions
- ArgumentNullException
Thrown when realA or realB is null.
- ArgumentException
Thrown when batch dimensions don't match or batch size is zero.
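The TrainStep exceptions amount to checks like the following. This sketch uses int[] shapes in place of Tensor<T> and assumes dimension 0 is the batch dimension; the helper name is hypothetical.

```csharp
using System;

// Sketch of the documented TrainStep guards on tensor shapes.
// Assumes shape[0] is the batch dimension; int[] stands in for Tensor<T>.
public static class TrainStepValidationSketch
{
    // Returns the shared batch size, or throws as the TrainStep documentation describes.
    public static int ValidateBatches(int[] shapeA, int[] shapeB)
    {
        if (shapeA is null) throw new ArgumentNullException(nameof(shapeA));
        if (shapeB is null) throw new ArgumentNullException(nameof(shapeB));
        if (shapeA.Length == 0 || shapeB.Length == 0)
            throw new ArgumentException("Tensors must have at least one dimension.");

        int batch = shapeA[0];
        if (batch != shapeB[0])
            throw new ArgumentException($"Batch size mismatch: {shapeA[0]} vs {shapeB[0]}.");
        if (batch == 0)
            throw new ArgumentException("Batch size must be greater than zero.");
        return batch;
    }
}
```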
TranslateAtoB(Tensor<T>)
Translates an image from domain A to domain B.
public Tensor<T> TranslateAtoB(Tensor<T> imageA)
Parameters
imageA (Tensor<T>): The domain A image to translate.
Returns
- Tensor<T>
The translated image in domain B.
Remarks
This method temporarily sets the generator to evaluation mode for inference, then restores the original training mode after prediction. This ensures batch normalization and dropout behave correctly during both inference and subsequent training steps.
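The save-and-restore behavior described above is typically written with try/finally, so the previous mode comes back even if inference throws. The IModeSwitchable interface and ToyModel class below are hypothetical; AiDotNet's actual training-mode API may be named differently.

```csharp
using System;

// Hypothetical minimal interface for a model with a training/evaluation switch.
public interface IModeSwitchable
{
    bool IsTrainingMode { get; }
    void SetTrainingMode(bool training);
}

public static class InferenceModeSketch
{
    // Runs 'predict' in evaluation mode and always restores the previous mode,
    // even if prediction throws.
    public static TOut RunInEvalMode<TOut>(IModeSwitchable model, Func<TOut> predict)
    {
        bool wasTraining = model.IsTrainingMode;
        model.SetTrainingMode(false);
        try
        {
            return predict();
        }
        finally
        {
            model.SetTrainingMode(wasTraining);
        }
    }
}

// Toy model used only to exercise the helper.
public sealed class ToyModel : IModeSwitchable
{
    public bool IsTrainingMode { get; private set; } = true;
    public void SetTrainingMode(bool training) => IsTrainingMode = training;
}
```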
TranslateBtoA(Tensor<T>)
Translates an image from domain B to domain A.
public Tensor<T> TranslateBtoA(Tensor<T> imageB)
Parameters
imageB (Tensor<T>): The domain B image to translate.
Returns
- Tensor<T>
The translated image in domain A.
Remarks
This method temporarily sets the generator to evaluation mode for inference, then restores the original training mode after prediction. This ensures batch normalization and dropout behave correctly during both inference and subsequent training steps.
UpdateParameters(Vector<T>)
Updates the parameters of all networks in the CycleGAN.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): The new parameters vector containing parameters for all networks.
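Because a single Vector<T> carries the parameters of all four networks, an implementation has to slice it into consecutive segments. This sketch uses double[] in place of Vector<T>, and the segment order (A→B generator, B→A generator, discriminator A, discriminator B) is an assumption; check how the library orders the vector before relying on it.

```csharp
using System;

// Sketch of splitting one flat parameter vector into per-network segments.
// The segment order is an assumption, not taken from the library.
public static class ParameterPartitionSketch
{
    public static double[][] Split(double[] parameters, params int[] counts)
    {
        if (parameters is null) throw new ArgumentNullException(nameof(parameters));

        // The segment sizes must account for every element exactly once.
        int total = 0;
        foreach (int c in counts)
        {
            if (c < 0) throw new ArgumentOutOfRangeException(nameof(counts), "Counts must be non-negative.");
            total += c;
        }
        if (total != parameters.Length)
            throw new ArgumentException($"Expected {total} parameters but got {parameters.Length}.");

        var parts = new double[counts.Length][];
        int offset = 0;
        for (int i = 0; i < counts.Length; i++)
        {
            parts[i] = new double[counts[i]];
            Array.Copy(parameters, offset, parts[i], 0, counts[i]);
            offset += counts[i];
        }
        return parts;
    }
}
```

Validating the total length up front keeps a mis-sized vector from silently shifting parameters between networks.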