Class WeightedCrossEntropyLoss<T>
- Namespace: AiDotNet.LossFunctions
- Assembly: AiDotNet.dll
Implements the Weighted Cross Entropy loss function for classification problems with uneven class importance.
public class WeightedCrossEntropyLoss<T> : LossFunctionBase<T>, ILossFunction<T>
Type Parameters
T
The numeric type used for calculations (e.g., float or double).
- Inheritance: LossFunctionBase<T> → WeightedCrossEntropyLoss<T>
- Implements: ILossFunction<T>
Remarks
For Beginners: Weighted Cross Entropy is a variation of the standard cross-entropy loss that applies different weights to different samples or classes.
Standard cross-entropy gives every sample's error the same weight, but in some cases:
- Some classes might be more important to classify correctly
- Some classes might be rare in the training data but important in practice
- Some samples might be more reliable or representative than others
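Weighted Cross Entropy lets you control the importance of individual samples by multiplying each sample's loss term by a weight. A sample with weight 2 contributes twice as much to the total loss (and to its gradient) as it would with weight 1, so higher weights make the model focus more on getting those samples right.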
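A common way to write per-sample weighted binary cross-entropy is shown below. This is a sketch that assumes predictions p_i in (0, 1), targets y_i, per-sample weights w_i, and averaging over the n samples. The library may clamp predictions or normalize the sum differently.

```latex
L(\mathbf{p}, \mathbf{y}) = -\frac{1}{n} \sum_{i=1}^{n} w_i \left[ y_i \log p_i + (1 - y_i) \log(1 - p_i) \right]
```

With every w_i = 1, this reduces to ordinary binary cross-entropy.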
This loss function is particularly useful for:
- Imbalanced datasets where some classes are underrepresented
- Problems where misclassifying certain classes is more costly than others
- Situations where you have varying confidence in your training data
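The constructor takes one weight per sample, so class-level importance has to be expanded into per-sample weights. For imbalanced datasets, one common heuristic is inverse class frequency. The sketch below shows that heuristic under several assumptions. It assumes binary labels (0 or 1). The helper `InverseFrequencyWeights` is hypothetical and is not part of AiDotNet. It uses plain `double[]` rather than `Vector<T>`, to avoid assuming a particular `Vector<T>` constructor. It is written as C# 9+ top-level statements.

```csharp
using System;
using System.Linq;

// Hypothetical helper (not an AiDotNet API): converts binary labels (0 or 1)
// into per-sample weights using inverse class frequency, so that each class
// contributes the same total weight (n / 2) to the loss.
static double[] InverseFrequencyWeights(double[] targets)
{
    int n = targets.Length;
    int positives = targets.Count(y => y == 1.0);
    int negatives = n - positives;
    if (positives == 0 || negatives == 0)
        throw new ArgumentException("Both classes must be present in the labels.");

    // weight = n / (2 * classCount): rarer classes get larger weights
    double wPos = n / (2.0 * positives);
    double wNeg = n / (2.0 * negatives);
    return targets.Select(y => y == 1.0 ? wPos : wNeg).ToArray();
}

// Three negatives and one positive: the single positive gets weight 2,
// each negative gets 2/3, and the weights sum to n = 4.
double[] labels = { 0, 0, 0, 1 };
double[] weights = InverseFrequencyWeights(labels);
Console.WriteLine(string.Join(", ", weights));
```

The resulting values would then be copied into the `Vector<T>` passed to `WeightedCrossEntropyLoss(Vector<T>? weights)`. Check the AiDotNet version you use for how to construct that vector.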
Constructors
WeightedCrossEntropyLoss(Vector<T>?)
Initializes a new instance of the WeightedCrossEntropyLoss class.
public WeightedCrossEntropyLoss(Vector<T>? weights = null)
Parameters
weights Vector<T>
The weights vector, one entry per sample. If null, every sample has weight 1.
Methods
CalculateDerivative(Vector<T>, Vector<T>)
Calculates the derivative of the Weighted Cross Entropy loss function.
public override Vector<T> CalculateDerivative(Vector<T> predicted, Vector<T> actual)
Parameters
predicted Vector<T>
The predicted values (probabilities between 0 and 1).
actual Vector<T>
The actual (target) values (typically 0 or 1).
Returns
- Vector<T>
A vector containing the derivatives of the weighted cross entropy loss with respect to each prediction.
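For intuition about what this method returns, the sketch below differentiates the weighted binary cross-entropy with respect to each prediction. It assumes the loss is averaged over n samples, which gives dL/dp_i = (w_i / n)(p_i - y_i) / (p_i(1 - p_i)). It also clamps predictions to an assumed epsilon. The function name `WeightedBceDerivative` is hypothetical, `double[]` stands in for `Vector<T>`, and the library's own scaling or clamping may differ.

```csharp
using System;

// Hypothetical sketch (not the AiDotNet implementation) of the gradient of
// per-sample weighted binary cross-entropy averaged over n samples:
//   dL/dp_i = (w_i / n) * (p_i - y_i) / (p_i * (1 - p_i))
static double[] WeightedBceDerivative(double[] predicted, double[] actual, double[]? weights = null)
{
    if (predicted.Length != actual.Length)
        throw new ArgumentException("predicted and actual must have the same length.");

    const double eps = 1e-15; // assumed clamp to avoid division by zero at p = 0 or 1
    int n = predicted.Length;
    var grad = new double[n];
    for (int i = 0; i < n; i++)
    {
        double p = Math.Clamp(predicted[i], eps, 1 - eps);
        double w = weights?[i] ?? 1.0; // null weights => every sample has weight 1
        grad[i] = w / n * (p - actual[i]) / (p * (1 - p));
    }
    return grad;
}

double[] predictions = { 0.9, 0.2 };
double[] targets = { 1, 0 };
double[] sampleWeights = { 1, 3 };
double[] gradient = WeightedBceDerivative(predictions, targets, sampleWeights);
Console.WriteLine(string.Join(", ", gradient));
```

Because the weight enters as a plain multiplier, tripling a sample's weight triples its gradient entry.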
CalculateLoss(Vector<T>, Vector<T>)
Calculates the Weighted Cross Entropy loss between predicted and actual values.
public override T CalculateLoss(Vector<T> predicted, Vector<T> actual)
Parameters
predicted Vector<T>
The predicted values (probabilities between 0 and 1).
actual Vector<T>
The actual (target) values (typically 0 or 1).
Returns
- T
The weighted cross entropy loss value.
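The sketch below computes a weighted cross entropy value and compares it with the unweighted case. It assumes per-sample weighted binary cross-entropy averaged over n, and it clamps predictions to an assumed epsilon. The clamp prevents a log(0) term, which would otherwise produce infinity, or NaN when multiplied by a zero target. `WeightedBceLoss` is a hypothetical name, `double[]` stands in for `Vector<T>`, and this is not presented as the library's exact implementation.

```csharp
using System;

// Hypothetical sketch (not the AiDotNet implementation):
//   L = -(1/n) * sum_i w_i * [y_i * log(p_i) + (1 - y_i) * log(1 - p_i)]
static double WeightedBceLoss(double[] predicted, double[] actual, double[]? weights = null)
{
    if (predicted.Length != actual.Length)
        throw new ArgumentException("predicted and actual must have the same length.");

    const double eps = 1e-15; // assumed clamp so log never receives 0
    int n = predicted.Length;
    double sum = 0.0;
    for (int i = 0; i < n; i++)
    {
        double p = Math.Clamp(predicted[i], eps, 1 - eps);
        double w = weights?[i] ?? 1.0; // null weights => every sample has weight 1
        sum += w * (actual[i] * Math.Log(p) + (1 - actual[i]) * Math.Log(1 - p));
    }
    return -sum / n;
}

double[] predictions = { 0.9, 0.4 };
double[] targets = { 1, 1 };
double unweighted = WeightedBceLoss(predictions, targets);
double weighted = WeightedBceLoss(predictions, targets, new double[] { 1, 3 });
Console.WriteLine($"unweighted = {unweighted:F4}, weighted = {weighted:F4}");
```

Here the weighted loss is larger than the unweighted one. The poorly predicted second sample (p = 0.4 for a positive target) carries weight 3, so its error dominates the total.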