
Class MarginLoss<T>

Namespace
AiDotNet.LossFunctions
Assembly
AiDotNet.dll

Implements the Margin loss function, specifically designed for Capsule Networks.

public class MarginLoss<T> : LossFunctionBase<T>, ILossFunction<T>

Type Parameters

T

The numeric type used for calculations (e.g., float, double).

Inheritance
object
LossFunctionBase<T>
MarginLoss<T>
Implements
ILossFunction<T>

Remarks

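For Beginners: Margin loss is the classification loss introduced for Capsule Networks. Each class has a capsule whose output vector length is read as the probability that the class is present, and margin loss pushes that length up for present classes and down for absent ones.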

For each class c, the loss is: L_c = T_c * max(0, m+ - ||v_c||)^2 + lambda * (1 - T_c) * max(0, ||v_c|| - m-)^2

Where:

  • T_c is 1 if class c is present, 0 otherwise
  • ||v_c|| is the length of the output vector of the capsule for class c
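  • m+ is the upper margin: the length a present class's vector should reach or exceed (usually 0.9)
  • m- is the lower margin: the length an absent class's vector should stay at or below (usually 0.1)
  • lambda is a down-weighting factor applied to the absent-class term (usually 0.5)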
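As a worked example with the default parameters (m+ = 0.9, m- = 0.1, lambda = 0.5), here are the per-class terms for a few illustrative capsule lengths (the lengths are made up for the arithmetic, not taken from a trained model):

```latex
\begin{aligned}
T_c = 1,\ \|v_c\| = 0.7  &:\quad \max(0,\, 0.9 - 0.7)^2 = 0.2^2 = 0.04 \\
T_c = 0,\ \|v_c\| = 0.3  &:\quad 0.5 \cdot \max(0,\, 0.3 - 0.1)^2 = 0.5 \cdot 0.04 = 0.02 \\
T_c = 1,\ \|v_c\| = 0.95 &:\quad \max(0,\, 0.9 - 0.95)^2 = 0 \\
T_c = 0,\ \|v_c\| = 0.05 &:\quad 0.5 \cdot \max(0,\, 0.05 - 0.1)^2 = 0
\end{aligned}
```

Classes already on the correct side of their margin contribute nothing.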

Key properties:

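  • Pushes the output length for each present class up to m+; once ||v_c|| >= m+, that class adds no loss
  • Pushes the output length for each absent class down to m-; once ||v_c|| <= m-, that class adds no loss
  • Down-weights the absent-class term by lambda, so the many absent classes do not shrink the lengths of all capsules early in training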

Margin loss is ideal for Capsule Networks because:

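  • It allows multiple classes to be present in the same image, because each class is scored independently rather than through a softmax over classes
  • It constrains only the length of each capsule's output vector, leaving the vector's orientation free to encode viewpoint and other transformations
  • This supports equivariance, a key property of Capsule Networks, since a change in pose can change a capsule's orientation without changing its predicted class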

Constructors

MarginLoss(double, double, double)

Initializes a new instance of the MarginLoss class with the specified parameters.

public MarginLoss(double mPlus = 0.9, double mMinus = 0.1, double lambda = 0.5)

Parameters

mPlus double

The upper margin m+, the output length targeted for present classes. Default is 0.9.

mMinus double

The lower margin m-, the maximum output length targeted for absent classes. Default is 0.1.

lambda double

The down-weighting factor lambda applied to the loss for absent classes. Default is 0.5.

Methods

CalculateDerivative(Vector<T>, Vector<T>)

Calculates the derivative of the Margin loss function with respect to the predicted values.

public override Vector<T> CalculateDerivative(Vector<T> predicted, Vector<T> actual)

Parameters

predicted Vector<T>

The predicted values from the model.

actual Vector<T>

The actual (target) values.

Returns

Vector<T>

A vector containing the partial derivative of the Margin loss with respect to each predicted value.
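The per-class derivative follows directly from the formula: dL_c/d||v_c|| = -2 * T_c * max(0, m+ - ||v_c||) + 2 * lambda * (1 - T_c) * max(0, ||v_c|| - m-). The sketch below is a hypothetical helper (`MarginLossGradient` is not part of AiDotNet) that assumes each predicted value is a capsule length ||v_c|| and each target is T_c (0 or 1); the library's CalculateDerivative may organize or scale this differently.

```csharp
using System;

// Hypothetical helper, not the AiDotNet implementation: derivative of the margin
// loss with respect to each capsule length l_c = ||v_c||, assuming targets are 0 or 1.
static double[] MarginLossGradient(double[] lengths, double[] targets,
                                   double mPlus = 0.9, double mMinus = 0.1, double lambda = 0.5)
{
    if (lengths.Length != targets.Length)
        throw new ArgumentException("lengths and targets must have the same length");

    var grad = new double[lengths.Length];
    for (int c = 0; c < lengths.Length; c++)
    {
        // d/dl of T * max(0, m+ - l)^2 is -2 * T * max(0, m+ - l)
        double presentTerm = -2.0 * targets[c] * Math.Max(0.0, mPlus - lengths[c]);
        // d/dl of lambda * (1 - T) * max(0, l - m-)^2 is 2 * lambda * (1 - T) * max(0, l - m-)
        double absentTerm = 2.0 * lambda * (1.0 - targets[c]) * Math.Max(0.0, lengths[c] - mMinus);
        grad[c] = presentTerm + absentTerm;
    }
    return grad;
}

double[] lengths = { 0.7, 0.3, 0.95, 0.05 };
double[] targets = { 1.0, 0.0, 1.0, 0.0 };
double[] grad = MarginLossGradient(lengths, targets);
Console.WriteLine(string.Join(", ", grad)); // roughly -0.4, 0.2, 0, 0
```

A negative derivative for a short present-class capsule means gradient descent lengthens it, and a positive one for a long absent-class capsule shortens it. In a full Capsule Network, backpropagation would continue from ||v_c|| to the vector components through v_c / ||v_c||.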

CalculateLoss(Vector<T>, Vector<T>)

Calculates the Margin loss between predicted and actual values.

public override T CalculateLoss(Vector<T> predicted, Vector<T> actual)

Parameters

predicted Vector<T>

The predicted values from the model.

actual Vector<T>

The actual (target) values.

Returns

T

The Margin loss value.
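The following is a minimal sketch of the loss computation, assuming each predicted value is a capsule length ||v_c||, each target is T_c (0 or 1), and the per-class terms are summed as in the original Capsule Network formulation. `MarginLossSum` is a hypothetical helper for illustration; AiDotNet's CalculateLoss may aggregate the per-class terms differently (for example, by averaging).

```csharp
using System;

// Hypothetical helper, not the AiDotNet implementation: margin loss summed over
// classes, assuming lengths[c] holds ||v_c|| and targets[c] holds T_c (0 or 1).
static double MarginLossSum(double[] lengths, double[] targets,
                            double mPlus = 0.9, double mMinus = 0.1, double lambda = 0.5)
{
    if (lengths.Length != targets.Length)
        throw new ArgumentException("lengths and targets must have the same length");

    double total = 0.0;
    for (int c = 0; c < lengths.Length; c++)
    {
        double present = Math.Max(0.0, mPlus - lengths[c]); // shortfall of a present class below m+
        double absent = Math.Max(0.0, lengths[c] - mMinus); // excess of an absent class above m-
        total += targets[c] * present * present
               + lambda * (1.0 - targets[c]) * absent * absent;
    }
    return total;
}

double[] lengths = { 0.7, 0.3, 0.95, 0.05 };
double[] targets = { 1.0, 0.0, 1.0, 0.0 };
double loss = MarginLossSum(lengths, targets);
Console.WriteLine(loss); // a value close to 0.06 (0.04 + 0.02)
```

Because targets are per class rather than one-hot, the same call handles images with several present classes: set T_c = 1 for each of them.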