Class RealESRGANLoss<T>
- Namespace: AiDotNet.LossFunctions
- Assembly: AiDotNet.dll
Combined loss function for Real-ESRGAN super-resolution training.
public class RealESRGANLoss<T> : LossFunctionBase<T>, ILossFunction<T>
Type Parameters
T: The numeric type used for calculations (e.g., float, double).
- Inheritance: LossFunctionBase<T> → RealESRGANLoss<T>
- Implements: ILossFunction<T>
Remarks
Real-ESRGAN uses a combination of three loss functions for training:
- L1 (pixel-wise) loss: Ensures pixel-level accuracy
- Perceptual (VGG) loss: Ensures perceptual quality using deep features
- GAN (adversarial) loss: Ensures realistic details and textures
The total loss is computed as:
L_total = λ_L1 * L_L1 + λ_perceptual * L_perceptual + λ_GAN * L_GAN
For Beginners: This loss function guides Real-ESRGAN training by balancing three goals:
L1 Loss (pixel accuracy): Makes sure each pixel is close to the ground truth. Like comparing photos pixel-by-pixel.
Perceptual Loss (looks right): Uses a pre-trained network (VGG) to compare high-level features. Ensures the output "looks right" even if pixels aren't exact.
GAN Loss (realistic details): The discriminator judges if output looks real. This adds fine details and textures that make images look natural.
The weights control how much each goal matters:
- Higher L1 weight = more pixel-accurate but potentially blurry
- Higher perceptual weight = better visual quality
- Higher GAN weight = more realistic textures but potential artifacts
The default weights (1.0, 1.0, 0.1) are from the Real-ESRGAN paper.
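As a worked example of the weighted sum above, the following minimal sketch combines three loss terms that have already been computed. The class `WeightedLossSketch` and its `Combine` method are hypothetical names introduced for illustration; they are not part of AiDotNet and only mirror the documented formula and default weights.

```csharp
using System;

// Illustrative helper, not part of AiDotNet: it only mirrors the
// documented formula L_total = λ_L1 * L_L1 + λ_perceptual * L_perceptual + λ_GAN * L_GAN.
public static class WeightedLossSketch
{
    // Combines three already-computed loss terms.
    // The default weights match the documented defaults (1.0, 1.0, 0.1).
    public static double Combine(
        double l1Loss,
        double perceptualLoss,
        double ganLoss,
        double l1Weight = 1.0,
        double perceptualWeight = 1.0,
        double ganWeight = 0.1)
    {
        return l1Weight * l1Loss
             + perceptualWeight * perceptualLoss
             + ganWeight * ganLoss;
    }
}
```

With an L1 term of 0.05, a perceptual term of 0.30, and a GAN term of 0.70, the default weights give 0.05 + 0.30 + 0.07 = 0.42. Doubling the GAN weight to 0.2 raises the total to about 0.49, which shows how the adversarial term gains influence over the result.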
Reference: Wang et al., "Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data", ICCV Workshops 2021. https://arxiv.org/abs/2107.10833
Constructors
RealESRGANLoss(double, double, double, Func<Tensor<T>, Tensor<T>>?)
Initializes a new instance of the RealESRGANLoss class.
public RealESRGANLoss(double l1Weight = 1, double perceptualWeight = 1, double ganWeight = 0.1, Func<Tensor<T>, Tensor<T>>? featureExtractor = null)
Parameters
l1Weight (double): Weight for L1 loss. Default: 1.0 (from Real-ESRGAN paper).
perceptualWeight (double): Weight for perceptual loss. Default: 1.0 (from Real-ESRGAN paper).
ganWeight (double): Weight for GAN loss. Default: 0.1 (from Real-ESRGAN paper).
featureExtractor (Func<Tensor<T>, Tensor<T>>?): Optional VGG feature extractor for perceptual loss. Default: null.
Remarks
For Beginners: Create this loss with default weights from the paper:
var loss = new RealESRGANLoss<double>();
Or customize weights for different trade-offs:
// More pixel-accurate (potentially blurrier)
var pixelAccurateLoss = new RealESRGANLoss<double>(l1Weight: 2.0, ganWeight: 0.05);
// More realistic textures (potential artifacts)
var texturedLoss = new RealESRGANLoss<double>(ganWeight: 0.2);
Properties
GANWeight
Gets the GAN loss weight.
public double GANWeight { get; }
Property Value
- double
L1Weight
Gets the L1 loss weight.
public double L1Weight { get; }
Property Value
- double
PerceptualWeight
Gets the perceptual loss weight.
public double PerceptualWeight { get; }
Property Value
- double
Methods
CalculateCombinedLoss(Vector<T>, Vector<T>, Vector<T>)
Calculates the full combined loss including GAN component.
public T CalculateCombinedLoss(Vector<T> predicted, Vector<T> actual, Vector<T> discriminatorOutput)
Parameters
predicted (Vector<T>): The predicted (super-resolved) output.
actual (Vector<T>): The ground truth high-resolution target.
discriminatorOutput (Vector<T>): The discriminator's output for the predicted image.
Returns
- T
The combined loss value including GAN loss.
Remarks
Use this method during training when you have access to the discriminator output. The GAN loss encourages the generator to produce outputs that fool the discriminator.
CalculateDerivative(Vector<T>, Vector<T>)
Calculates the derivative of the combined loss for backpropagation.
public override Vector<T> CalculateDerivative(Vector<T> predicted, Vector<T> actual)
Parameters
predicted (Vector<T>): The predicted values.
actual (Vector<T>): The actual (target) values.
Returns
- Vector<T>
The gradient vector.
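To make the gradient concrete, here is a minimal sketch of the L1 term alone. It assumes the L1 loss is a mean absolute error, which this page does not confirm, and it ignores the perceptual term entirely. `L1GradientSketch` is a hypothetical name, and plain `double[]` arrays stand in for `Vector<T>`.

```csharp
using System;

// Illustrative only: gradient of l1Weight * mean(|predicted - actual|)
// with respect to each predicted element. Not AiDotNet's implementation.
public static class L1GradientSketch
{
    public static double[] Gradient(double[] predicted, double[] actual, double l1Weight = 1.0)
    {
        if (predicted.Length != actual.Length)
            throw new ArgumentException("predicted and actual must have the same length.");
        if (predicted.Length == 0)
            throw new ArgumentException("Inputs must not be empty.");

        int n = predicted.Length;
        var gradient = new double[n];
        for (int i = 0; i < n; i++)
        {
            // d|p - a|/dp is sign(p - a); the subgradient 0 is used where p == a.
            gradient[i] = l1Weight * Math.Sign(predicted[i] - actual[i]) / n;
        }
        return gradient;
    }
}
```

Each element's gradient depends only on whether the prediction is above or below the target, not on how far off it is. This is one reason an L1 term is less sensitive to outliers than a squared-error term.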
CalculateDiscriminatorLoss(Vector<T>, Vector<T>)
Calculates the discriminator loss.
public T CalculateDiscriminatorLoss(Vector<T> realOutput, Vector<T> fakeOutput)
Parameters
realOutput (Vector<T>): Discriminator output for real images.
fakeOutput (Vector<T>): Discriminator output for generated images.
Returns
- T
The discriminator loss value.
Remarks
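The discriminator wants to output 1 for real images and 0 for fake images. This computes: -E[log(D(y))] - E[log(1 - D(G(x)))], where y is a real high-resolution image and G(x) is the generator's output for a low-resolution input x.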
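The discriminator formula can be sketched on plain arrays as follows. This is an illustration under assumptions, not the library's code: it assumes the discriminator outputs are probabilities in (0, 1), and the epsilon clamp that guards against log(0) is an addition of this sketch. `DiscriminatorLossSketch` is a hypothetical name.

```csharp
using System;

// Illustrative only: -mean(log(D(real))) - mean(log(1 - D(fake))).
// Not AiDotNet's implementation.
public static class DiscriminatorLossSketch
{
    // Assumed safeguard: keeps probabilities away from exactly 0 and 1.
    private const double Epsilon = 1e-12;

    public static double Loss(double[] realOutput, double[] fakeOutput)
    {
        if (realOutput.Length == 0 || fakeOutput.Length == 0)
            throw new ArgumentException("Inputs must not be empty.");

        double realSum = 0.0;
        foreach (double d in realOutput)
            realSum += Math.Log(Clamp(d));        // rewards D(real) near 1

        double fakeSum = 0.0;
        foreach (double d in fakeOutput)
            fakeSum += Math.Log(1.0 - Clamp(d));  // rewards D(fake) near 0

        return -(realSum / realOutput.Length) - (fakeSum / fakeOutput.Length);
    }

    private static double Clamp(double p) =>
        Math.Min(Math.Max(p, Epsilon), 1.0 - Epsilon);
}
```

An undecided discriminator that outputs 0.5 everywhere scores 2·ln 2 ≈ 1.386, while a confident and correct one (0.9 for real, 0.1 for fake) scores about 0.211.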
CalculateGeneratorGANLoss(Vector<T>)
Calculates the generator's GAN loss component.
public T CalculateGeneratorGANLoss(Vector<T> discriminatorOutput)
Parameters
discriminatorOutput (Vector<T>): The discriminator's output for generated images.
Returns
- T
The GAN loss value.
Remarks
The generator wants the discriminator to classify its output as real (1.0). This computes: -E[log(D(G(x)))]
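A minimal sketch of the generator-side formula, under the same assumptions as before: discriminator outputs are treated as probabilities in (0, 1), and the epsilon clamp is this sketch's own safeguard. `GeneratorGanLossSketch` is a hypothetical name and does not reflect the library's internals.

```csharp
using System;

// Illustrative only: -mean(log(D(G(x)))). Not AiDotNet's implementation.
public static class GeneratorGanLossSketch
{
    // Assumed safeguard against log(0).
    private const double Epsilon = 1e-12;

    public static double Loss(double[] discriminatorOutput)
    {
        if (discriminatorOutput.Length == 0)
            throw new ArgumentException("Input must not be empty.");

        double sum = 0.0;
        foreach (double d in discriminatorOutput)
        {
            double p = Math.Min(Math.Max(d, Epsilon), 1.0 - Epsilon);
            sum += Math.Log(p);  // close to 0 when the discriminator is fooled
        }
        return -sum / discriminatorOutput.Length;
    }
}
```

The loss falls toward 0 as the discriminator's outputs for generated images approach 1, and grows without bound as they approach 0. An output of 0.5 gives ln 2 ≈ 0.693.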
CalculateLoss(Vector<T>, Vector<T>)
Calculates the combined Real-ESRGAN loss.
public override T CalculateLoss(Vector<T> predicted, Vector<T> actual)
Parameters
predicted (Vector<T>): The predicted (super-resolved) output.
actual (Vector<T>): The ground truth high-resolution target.
Returns
- T
The combined loss value.
Remarks
This method computes the L1 and perceptual components of the loss. The GAN loss component should be computed separately during training using the discriminator output.
For Beginners: This calculates how "wrong" the prediction is. Lower values mean the prediction is closer to the ground truth.
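The pixel-wise part of this calculation can be illustrated with a short sketch. It assumes the L1 term is a mean absolute error over all elements; this page does not say whether the library averages or sums, so treat that as an assumption. `L1LossSketch` is a hypothetical name, and plain `double[]` arrays stand in for `Vector<T>`.

```csharp
using System;

// Illustrative only: mean absolute error between prediction and target.
// Not AiDotNet's implementation.
public static class L1LossSketch
{
    public static double MeanAbsoluteError(double[] predicted, double[] actual)
    {
        if (predicted.Length != actual.Length)
            throw new ArgumentException("predicted and actual must have the same length.");
        if (predicted.Length == 0)
            throw new ArgumentException("Inputs must not be empty.");

        double sum = 0.0;
        for (int i = 0; i < predicted.Length; i++)
            sum += Math.Abs(predicted[i] - actual[i]);

        return sum / predicted.Length;
    }
}
```

For predicted pixels (0.2, 0.5, 0.9) and targets (0.0, 0.5, 1.0), the absolute differences are 0.2, 0, and 0.1, so the mean is about 0.1. A perfect prediction gives 0.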