Class DeepRitzMethod<T>

Namespace
AiDotNet.PhysicsInformed.PINNs
Assembly
AiDotNet.dll

Implements the Deep Ritz Method for solving variational problems and PDEs.

public class DeepRitzMethod<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable

Type Parameters

T

The numeric type used for calculations.

Inheritance
DeepRitzMethod<T>
Implements
IFullModel<T, Tensor<T>, Tensor<T>>
IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>
IParameterizable<T, Tensor<T>, Tensor<T>>
ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>
IGradientComputable<T, Tensor<T>, Tensor<T>>

Remarks

For Beginners: The Deep Ritz Method is a variational approach to solving PDEs using neural networks. Instead of minimizing the PDE residual directly (like standard PINNs), it minimizes an energy functional.

The Ritz Method (Classical): Many PDEs can be reformulated as minimization problems. For example:

  • Poisson equation: -∇²u = f (with u prescribed on the boundary) is equivalent to minimizing E(u) = ½∫|∇u|² dx - ∫fu dx
  • This is called the "variational formulation"
  • The solution minimizes the energy functional
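
Why the minimizer solves the PDE can be seen from a short calculation. The following is a sketch under stated assumptions: Ω denotes the domain, u is fixed on ∂Ω, v is any perturbation that vanishes on ∂Ω, and ε is a small real number (Ω, v, and ε are notation introduced here, not part of the API).

```latex
\begin{aligned}
E(u + \varepsilon v) &= E(u)
  + \varepsilon \left( \int_\Omega \nabla u \cdot \nabla v \, dx - \int_\Omega f v \, dx \right)
  + \frac{\varepsilon^2}{2} \int_\Omega |\nabla v|^2 \, dx, \\
\left. \frac{d}{d\varepsilon} E(u + \varepsilon v) \right|_{\varepsilon = 0}
  &= \int_\Omega \nabla u \cdot \nabla v \, dx - \int_\Omega f v \, dx
   = \int_\Omega \left( -\nabla^2 u - f \right) v \, dx .
\end{aligned}
```

The last equality is integration by parts, with the boundary term dropping out because v vanishes on ∂Ω. Requiring the derivative to be zero for every admissible v gives -∇²u = f, and the nonnegative ε² term shows that this stationary point is a minimum.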

Deep Ritz (Modern):

  • Use a neural network to represent u(x)
  • Compute the energy functional using automatic differentiation
  • Train the network to minimize the energy
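  • Natural (Neumann-type) boundary conditions follow from the energy itself; essential (Dirichlet) conditions are typically enforced through a boundary penalty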

Advantages over Standard PINNs:

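  1. Can train more stably on such problems (minimizing an energy rather than a squared residual)
  2. Natural framework for problems with variational structure
  3. May converge faster, in part because the energy involves only first derivatives of u where the residual of a second-order PDE involves second derivatives
  4. Physical interpretation (energy minimization)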

Applications:

  • Elasticity (minimize strain energy)
  • Electrostatics (minimize electrostatic energy)
  • Fluid dynamics (minimize dissipation)
  • Quantum mechanics (minimize expected energy)
  • Optimal control problems

Key Difference from PINNs:

  • PINN: Minimize ||PDE residual||²
  • Deep Ritz: Minimize ∫ Energy(u, ∇u) dx

Both solve the same PDE, but Deep Ritz uses the variational (energy) formulation, which can be more natural and stable for certain problems.
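
To make the comparison concrete, here is a minimal, self-contained sketch that does not use this library. It assumes a 1D Poisson problem -u'' = 2 on [0, 1] with u(0) = u(1) = 0 and a hypothetical one-parameter trial family u_a(x) = a·x·(1 - x), whose exact solution corresponds to a = 1. All names (`U`, `ResidualLoss`, `Energy`, `candidates`) are illustrative only.

```csharp
using System;

// Source term f for the illustrative problem -u'' = f on [0, 1].
const double F = 2.0;

// Hypothetical trial family u_a(x) = a * x * (1 - x); it satisfies u(0) = u(1) = 0.
double U(double a, double x) => a * x * (1.0 - x);

// PINN-style objective: mean squared PDE residual, which needs the SECOND derivative of u.
double ResidualLoss(double a, int n, double h)
{
    double sum = 0.0;
    for (int i = 0; i < n; i++)
    {
        double x = (i + 0.5) / n; // midpoint quadrature node
        double uxx = (U(a, x + h) - 2.0 * U(a, x) + U(a, x - h)) / (h * h);
        double r = -uxx - F;
        sum += r * r / n;
    }
    return sum;
}

// Ritz-style objective: energy 0.5 * |u'|^2 - f * u, which needs only the FIRST derivative.
double Energy(double a, int n, double h)
{
    double sum = 0.0;
    for (int i = 0; i < n; i++)
    {
        double x = (i + 0.5) / n;
        double ux = (U(a, x + h) - U(a, x - h)) / (2.0 * h);
        sum += (0.5 * ux * ux - F * U(a, x)) / n;
    }
    return sum;
}

double[] candidates = { 0.5, 0.75, 1.0, 1.25, 1.5 };
double bestResidualA = candidates[0];
double bestEnergyA = candidates[0];
foreach (double a in candidates)
{
    double residual = ResidualLoss(a, 200, 1e-3);
    double energy = Energy(a, 200, 1e-3);
    if (residual < ResidualLoss(bestResidualA, 200, 1e-3)) bestResidualA = a;
    if (energy < Energy(bestEnergyA, 200, 1e-3)) bestEnergyA = a;
    Console.WriteLine($"a = {a}: residual loss = {residual}, energy = {energy}");
}
Console.WriteLine($"Best a by residual: {bestResidualA}, best a by energy: {bestEnergyA}");
```

In this toy setting both objectives pick the same candidate, but the energy is evaluated from first derivatives only, while the residual needs a second-derivative estimate.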

Constructors

DeepRitzMethod(NeuralNetworkArchitecture<T>, Func<T[], T[], T[,], T>, Func<T[], bool>?, Func<T[], T[]>?, int)

Initializes a new instance of the Deep Ritz Method.

public DeepRitzMethod(NeuralNetworkArchitecture<T> architecture, Func<T[], T[], T[,], T> energyFunctional, Func<T[], bool>? boundaryCheck = null, Func<T[], T[]>? boundaryValue = null, int numQuadraturePoints = 10000)

Parameters

architecture NeuralNetworkArchitecture<T>

The neural network architecture.

energyFunctional Func<T[], T[], T[,], T>

The energy functional to minimize: E(x, u, ∇u).

boundaryCheck Func<T[], bool>

Function to check if a point is on the boundary.

boundaryValue Func<T[], T[]>

Function returning the boundary value at a point.

numQuadraturePoints int

Number of quadrature points for numerical integration.

Remarks

For Beginners: The energy functional should encode the physics of your problem.

Example - Poisson Equation (-∇²u = f):

  • Energy: E(u) = ½∫|∇u|² dx - ∫fu dx
  • Implementation: energyFunctional = (x, u, grad_u) => 0.5 * ||grad_u||² - f(x) * u

Example - Linear Elasticity:

  • Energy: E(u) = ∫ strain_energy(∇u) dx
  • Implementation: energyFunctional = (x, u, grad_u) => compute_strain_energy(grad_u)

The method will integrate this over the domain using quadrature points.
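
The Poisson pseudocode above might translate into delegates of the documented shapes roughly as follows, with T = double. This is a sketch under assumptions: the source term `f`, the unit-square domain, and in particular the layout `gradU[i, j] = ∂u_i/∂x_j` are not stated on this page and should be checked against the implementation. The snippet only builds and evaluates the delegates; it does not construct a `DeepRitzMethod<T>`.

```csharp
using System;

// Hypothetical source term f(x) for -∇²u = f; chosen only for illustration.
Func<double[], double> f = p => 2.0;

// Energy density with the documented shape Func<T[], T[], T[,], T>, here with T = double.
// ASSUMPTION: gradU[i, j] holds the derivative of output i with respect to input j.
Func<double[], double[], double[,], double> energyFunctional = (p, u, gradU) =>
{
    double gradSquared = 0.0;
    for (int j = 0; j < gradU.GetLength(1); j++)
        gradSquared += gradU[0, j] * gradU[0, j];
    return 0.5 * gradSquared - f(p) * u[0];
};

// Boundary test for the unit square [0, 1]^2 (an assumed domain).
Func<double[], bool> boundaryCheck = p =>
{
    const double tol = 1e-9;
    foreach (double c in p)
        if (Math.Abs(c) < tol || Math.Abs(c - 1.0) < tol) return true;
    return false;
};

// Homogeneous Dirichlet data: u = 0 on the boundary.
Func<double[], double[]> boundaryValue = p => new[] { 0.0 };

// Evaluate the density at one sample point: 0.5 * (1 + 4) - 2 * 0.25.
double density = energyFunctional(
    new[] { 0.5, 0.5 },
    new[] { 0.25 },
    new double[,] { { 1.0, -2.0 } });
Console.WriteLine(density);
```

These three delegates correspond to the `energyFunctional`, `boundaryCheck`, and `boundaryValue` constructor parameters; the `architecture` argument is omitted because its construction is not described on this page.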

Properties

SupportsTraining

Indicates whether this model supports training.

public override bool SupportsTraining { get; }

Property Value

bool

Methods

ComputeTotalEnergy()

Computes the total energy functional by integrating over the domain.

public T ComputeTotalEnergy()

Returns

T

The total energy value.

Remarks

For Beginners: This is the key method that computes ∫E(u, ∇u)dx numerically.

Steps:

  1. For each quadrature point x_i:
     a) Evaluate u(x_i) using the network
     b) Compute ∇u(x_i) using automatic differentiation
     c) Evaluate the energy density E(x_i, u(x_i), ∇u(x_i))
  2. Sum weighted energies: Total = Σ w_i * E_i
  3. Add boundary penalty if needed

The gradient of this total energy with respect to network parameters tells us how to update the network to minimize energy.
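
The steps above can be sketched in a few lines, independent of this library. The sketch makes several assumptions that may differ from the actual implementation: a closed-form function stands in for the network, central finite differences stand in for automatic differentiation, quadrature points lie on a uniform midpoint grid with equal weights, and the boundary penalty is a weighted squared mismatch with a hypothetical `penaltyWeight`.

```csharp
using System;

// Stand-in for the network output u(x); here the exact solution of -u'' = 2, u(0) = u(1) = 0.
Func<double[], double> u = p => p[0] * (1.0 - p[0]);
// Hypothetical source term.
Func<double[], double> f = p => 2.0;

// Central finite-difference gradient (a stand-in for automatic differentiation).
double[] Gradient(Func<double[], double> fn, double[] point, double h)
{
    var grad = new double[point.Length];
    for (int j = 0; j < point.Length; j++)
    {
        var plus = (double[])point.Clone();
        var minus = (double[])point.Clone();
        plus[j] += h;
        minus[j] -= h;
        grad[j] = (fn(plus) - fn(minus)) / (2.0 * h);
    }
    return grad;
}

int n = 1000;
double weight = 1.0 / n; // equal weights: domain volume (1) divided by the number of points
double total = 0.0;
for (int i = 0; i < n; i++)
{
    double[] point = { (i + 0.5) / n };            // quadrature point x_i
    double value = u(point);                        // step 1a: u(x_i)
    double[] gradU = Gradient(u, point, 1e-4);      // step 1b: ∇u(x_i)
    double gradSquared = 0.0;
    foreach (double g in gradU) gradSquared += g * g;
    double density = 0.5 * gradSquared - f(point) * value; // step 1c: energy density
    total += weight * density;                      // step 2: weighted sum
}

// Step 3: boundary penalty for u = 0 at x = 0 and x = 1.
double penaltyWeight = 100.0;
double penalty = 0.0;
foreach (double xb in new[] { 0.0, 1.0 })
{
    double mismatch = u(new[] { xb }) - 0.0;
    penalty += mismatch * mismatch;
}
total += penaltyWeight * penalty;

Console.WriteLine($"Total energy: {total}");
```

For this stand-in function the analytic energy is -1/6 and the boundary penalty vanishes, which gives a quick sanity check on the quadrature.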

CreateNewInstance()

Creates a new instance with the same configuration.

protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()

Returns

IFullModel<T, Tensor<T>, Tensor<T>>

New Deep Ritz instance.

DeserializeNetworkSpecificData(BinaryReader)

Deserializes Deep Ritz-specific data.

protected override void DeserializeNetworkSpecificData(BinaryReader reader)

Parameters

reader BinaryReader

Binary reader.

Forward(Tensor<T>)

Performs a forward pass through the network.

public Tensor<T> Forward(Tensor<T> input)

Parameters

input Tensor<T>

Input tensor for evaluation.

Returns

Tensor<T>

Network output tensor.

GetModelMetadata()

Gets metadata about the Deep Ritz model.

public override ModelMetadata<T> GetModelMetadata()

Returns

ModelMetadata<T>

Model metadata.

GetSolution(T[])

Gets the solution at a specific point.

public T[] GetSolution(T[] point)

Parameters

point T[]

The coordinates of the point at which to evaluate the solution.

Returns

T[]

The solution values predicted by the network at the given point.

InitializeLayers()

Initializes the layers of the neural network based on the architecture.

protected override void InitializeLayers()

Remarks

For Beginners: This method sets up all the layers in your neural network according to the architecture you've defined. It's like assembling the parts of your network before you can use it.

Predict(Tensor<T>)

Makes a prediction using the Deep Ritz network.

public override Tensor<T> Predict(Tensor<T> input)

Parameters

input Tensor<T>

Input tensor.

Returns

Tensor<T>

Predicted output tensor.

SerializeNetworkSpecificData(BinaryWriter)

Serializes Deep Ritz-specific data.

protected override void SerializeNetworkSpecificData(BinaryWriter writer)

Parameters

writer BinaryWriter

Binary writer.

Solve(int, double, bool, int, double)

Trains the network to minimize the energy functional.

public TrainingHistory<T> Solve(int epochs = 1000, double learningRate = 0.001, bool verbose = true, int batchSize = 256, double derivativeStep = 0.0001)

Parameters

epochs int

Number of training epochs.

learningRate double

Learning rate for optimization.

verbose bool

Whether to print progress.

batchSize int

Number of quadrature points per batch.

derivativeStep double

Finite-difference step size for input derivatives.

Returns

TrainingHistory<T>

The history recorded during training.
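
The `derivativeStep` parameter is described as a finite-difference step size. As a rough illustration of how such a step affects accuracy, the sketch below applies a central difference to sin(x) at x = 1, where the exact derivative is cos(1). Whether `Solve` uses a central, forward, or other difference scheme is not stated on this page, so the scheme and the helper name `CentralDifference` are assumptions.

```csharp
using System;

// Central difference approximation of df/dx with step h (assumed scheme, for illustration).
double CentralDifference(Func<double, double> fn, double x, double h) =>
    (fn(x + h) - fn(x - h)) / (2.0 * h);

double exact = Math.Cos(1.0); // exact derivative of sin(x) at x = 1
double[] steps = { 1e-1, 1e-2, 1e-4, 1e-6 };
var errors = new double[steps.Length];
for (int i = 0; i < steps.Length; i++)
{
    double approx = CentralDifference(Math.Sin, 1.0, steps[i]);
    errors[i] = Math.Abs(approx - exact);
    Console.WriteLine($"h = {steps[i]:E0}: error = {errors[i]:E2}");
}
```

Larger steps increase truncation error, while very small steps tend to amplify floating-point rounding, so a moderate value such as the default of 0.0001 is a common compromise for double precision.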

Train(Tensor<T>, Tensor<T>)

Performs a basic supervised training step using MSE loss.

public override void Train(Tensor<T> input, Tensor<T> expectedOutput)

Parameters

input Tensor<T>

Training input tensor.

expectedOutput Tensor<T>

Expected output tensor.

UpdateParameters(Vector<T>)

Updates the network parameters from a flattened vector.

public override void UpdateParameters(Vector<T> parameters)

Parameters

parameters Vector<T>

Parameter vector.