Class MarginLoss<T>
Namespace: AiDotNet.LossFunctions
Assembly: AiDotNet.dll
Implements the Margin loss function, specifically designed for Capsule Networks.
public class MarginLoss<T> : LossFunctionBase<T>, ILossFunction<T>
Type Parameters
T: The numeric type used for calculations (e.g., float, double).
Inheritance: LossFunctionBase<T> → MarginLoss<T>
Implements: ILossFunction<T>
Remarks
For Beginners: Margin loss is a special loss function used in Capsule Networks.
For each class c, the loss term is: L_c = T_c * max(0, m+ - ||v_c||)^2 + lambda * (1 - T_c) * max(0, ||v_c|| - m-)^2
Where:
- T_c is 1 if class c is present, 0 otherwise
- ||v_c|| is the length of the output vector of the capsule for class c
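- m+ is the upper bound (usually 0.9): a present class is penalized only while ||v_c|| is below it
- m- is the lower bound (usually 0.1): an absent class is penalized only while ||v_c|| is above it
- lambda is a down-weighting factor applied to the absent-class term (usually 0.5)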
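As a worked illustration of the formula, the following evaluates the per-class term for three hypothetical capsule lengths, assuming the usual values m+ = 0.9, m- = 0.1, and lambda = 0.5. The lengths 0.6, 0.95, and 0.3 are made up for the example.

```latex
\begin{align*}
L_c &= T_c \max(0,\, m^+ - \lVert v_c \rVert)^2
      + \lambda (1 - T_c) \max(0,\, \lVert v_c \rVert - m^-)^2 \\
\text{present } (T_c = 1),\ \lVert v_c \rVert = 0.6:\quad
L_c &= \max(0,\, 0.9 - 0.6)^2 = 0.3^2 = 0.09 \\
\text{present } (T_c = 1),\ \lVert v_c \rVert = 0.95:\quad
L_c &= \max(0,\, 0.9 - 0.95)^2 = 0 \\
\text{absent } (T_c = 0),\ \lVert v_c \rVert = 0.3:\quad
L_c &= 0.5 \cdot \max(0,\, 0.3 - 0.1)^2 = 0.5 \cdot 0.04 = 0.02
\end{align*}
```

The second case shows that a present class contributes nothing once its capsule length exceeds m+. The third shows how lambda halves the penalty for an absent class.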
Key properties:
- Encourages long output vectors (length of at least m+) for classes that are present
- Discourages long output vectors (length above m-) for classes that are absent
- Helps in learning to represent different aspects of the input
Margin loss is ideal for Capsule Networks because:
- It allows multiple classes to be present in the same image, because each class has its own independent loss term
- It encourages the network to learn to represent different viewpoints and transformations
- It helps in achieving equivariance, a key property of Capsule Networks
Constructors
MarginLoss(double, double, double)
Initializes a new instance of the MarginLoss class with the specified parameters.
public MarginLoss(double mPlus = 0.9, double mMinus = 0.1, double lambda = 0.5)
Parameters
mPlus (double): The upper bound. Default is 0.9.
mMinus (double): The lower bound. Default is 0.1.
lambda (double): The down-weighting factor. Default is 0.5.
Methods
CalculateDerivative(Vector<T>, Vector<T>)
Calculates the derivative of the Margin loss function.
public override Vector<T> CalculateDerivative(Vector<T> predicted, Vector<T> actual)
Parameters
predicted (Vector<T>): The predicted values from the model.
actual (Vector<T>): The actual (target) values.
Returns
Vector<T>: A vector containing the derivatives of Margin loss for each prediction.
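To make the derivative concrete, here is a minimal standalone sketch obtained by differentiating the per-class formula with respect to each predicted value. It uses plain double arrays rather than Vector<T>. The names MarginLossGradientSketch and Derivative are hypothetical, and the sketch assumes no averaging over the number of classes. The library's actual scaling is not stated on this page and may differ.

```csharp
using System;

// Standalone illustration only; not the AiDotNet implementation.
public static class MarginLossGradientSketch
{
    // d/dp [ t * max(0, mPlus - p)^2 + lambda * (1 - t) * max(0, p - mMinus)^2 ]
    //   = -2 * t * max(0, mPlus - p) + 2 * lambda * (1 - t) * max(0, p - mMinus)
    public static double[] Derivative(
        double[] predicted, double[] actual,
        double mPlus = 0.9, double mMinus = 0.1, double lambda = 0.5)
    {
        if (predicted.Length != actual.Length)
            throw new ArgumentException("predicted and actual must have the same length.");

        var gradient = new double[predicted.Length];
        for (int i = 0; i < predicted.Length; i++)
        {
            // Distance below the upper bound (zero once the bound is reached).
            double shortfall = Math.Max(0.0, mPlus - predicted[i]);
            // Distance above the lower bound (zero once below the bound).
            double excess = Math.Max(0.0, predicted[i] - mMinus);

            // Assumption: no division by the number of classes.
            gradient[i] = -2.0 * actual[i] * shortfall
                        + 2.0 * lambda * (1.0 - actual[i]) * excess;
        }
        return gradient;
    }
}
```

For example, `MarginLossGradientSketch.Derivative(new[] { 0.5, 0.3 }, new[] { 1.0, 0.0 })` gives approximately -0.8 for the present class (pushing its value up) and 0.2 for the absent class (pushing its value down). Values already on the correct side of their bound produce a zero gradient.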
CalculateLoss(Vector<T>, Vector<T>)
Calculates the Margin loss between predicted and actual values.
public override T CalculateLoss(Vector<T> predicted, Vector<T> actual)
Parameters
predicted (Vector<T>): The predicted values from the model.
actual (Vector<T>): The actual (target) values.
Returns
T: The Margin loss value.
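The calculation can be sketched without the library as follows. This is a minimal standalone version of the formula from the Remarks, using plain double arrays instead of Vector<T>. The names MarginLossSketch, ClassTerm, and Loss are hypothetical. The sketch sums the per-class terms, which is an assumption: this page does not say whether MarginLoss<T> sums or averages them.

```csharp
using System;

// Standalone illustration only; not the AiDotNet implementation.
public static class MarginLossSketch
{
    // Loss term for a single class: target is 1 if the class is present, 0 otherwise.
    public static double ClassTerm(
        double predicted, double target,
        double mPlus = 0.9, double mMinus = 0.1, double lambda = 0.5)
    {
        double shortfall = Math.Max(0.0, mPlus - predicted);  // present-class penalty base
        double excess = Math.Max(0.0, predicted - mMinus);    // absent-class penalty base
        return target * shortfall * shortfall
             + lambda * (1.0 - target) * excess * excess;
    }

    // Assumption: per-class terms are summed; the library may average instead.
    public static double Loss(
        double[] predicted, double[] actual,
        double mPlus = 0.9, double mMinus = 0.1, double lambda = 0.5)
    {
        if (predicted.Length != actual.Length)
            throw new ArgumentException("predicted and actual must have the same length.");

        double total = 0.0;
        for (int i = 0; i < predicted.Length; i++)
            total += ClassTerm(predicted[i], actual[i], mPlus, mMinus, lambda);
        return total;
    }
}
```

With the default parameters, `MarginLossSketch.Loss(new[] { 0.95, 0.3 }, new[] { 1.0, 0.0 })` is approximately 0.02: the present class is already above 0.9 and contributes nothing, while the absent class contributes 0.5 * (0.3 - 0.1)^2.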