Class DeepRitzMethod<T>
Namespace: AiDotNet.PhysicsInformed.PINNs
Assembly: AiDotNet.dll
Implements the Deep Ritz Method for solving variational problems and PDEs.
public class DeepRitzMethod<T> : NeuralNetworkBase<T>, INeuralNetworkModel<T>, INeuralNetwork<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>, IInterpretableModel<T>, IInputGradientComputable<T>, IDisposable
Type Parameters
T: The numeric type used for calculations.
Remarks
For Beginners: The Deep Ritz Method is a variational approach to solving PDEs using neural networks. Instead of minimizing the PDE residual directly (like standard PINNs), it minimizes an energy functional.
The Ritz Method (Classical): Many PDEs can be reformulated as minimization problems. For example:
- Poisson equation: -∇²u = f with u = 0 on the boundary is equivalent to minimizing E(u) = ½∫|∇u|² dx - ∫fu dx over functions that vanish on the boundary
- This is called the "variational formulation"
- The solution minimizes the energy functional
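As a quick check of this equivalence, the C# sketch below evaluates E(u) with the midpoint rule for the one-dimensional problem -u'' = π² sin(πx) on [0, 1] with u(0) = u(1) = 0, whose exact solution is sin(πx). The trial family a·sin(πx) and the function name Energy are illustrative choices, not part of AiDotNet. Analytically E(a) = (π²/4)(a² - 2a), which is smallest at a = 1, the exact solution.

```csharp
using System;

// Midpoint-rule approximation of E(u) = ∫ (½ u'² - f u) dx on [0, 1]
// for the trial function u_a(x) = a·sin(πx), with f(x) = π² sin(πx).
// Every u_a satisfies u(0) = u(1) = 0, so only the energy decides among them.
double Energy(double a, int n = 1000)
{
    double h = 1.0 / n, sum = 0.0;
    for (int i = 0; i < n; i++)
    {
        double x = (i + 0.5) * h;
        double u = a * Math.Sin(Math.PI * x);
        double du = a * Math.PI * Math.Cos(Math.PI * x);   // exact derivative u'(x)
        double f = Math.PI * Math.PI * Math.Sin(Math.PI * x);
        sum += (0.5 * du * du - f * u) * h;
    }
    return sum;
}

foreach (double scale in new[] { 0.8, 0.9, 1.0, 1.1, 1.2 })
    Console.WriteLine($"a = {scale:F1}  E = {Energy(scale):F4}");
```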
Deep Ritz (Modern):
- Use a neural network to represent u(x)
- Compute the energy functional using automatic differentiation
- Train the network to minimize the energy
- Natural (Neumann) boundary conditions are built into the energy; essential (Dirichlet) conditions are typically enforced with a penalty term
Advantages over Standard PINNs:
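- Often more stable training: for second-order PDEs the energy needs only first derivatives of u, while the PDE residual needs second derivatives
- Natural framework for problems with variational structure
- Can converge faster on some problems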
- Physical interpretation (energy minimization)
Applications:
- Elasticity (minimize strain energy)
- Electrostatics (minimize electrostatic energy)
- Fluid dynamics (minimize dissipation)
- Quantum mechanics (minimize expected energy)
- Optimal control problems
Key Difference from PINNs:
- PINN: minimize ||PDE residual||²
- Deep Ritz: minimize ∫ Energy(u, ∇u) dx
Both solve the same PDE, but Deep Ritz uses the variational (energy) formulation, which can be more natural and stable for certain problems.
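The difference between the two objectives can be sketched in C# for the one-dimensional Poisson problem -u'' = π² sin(πx) on [0, 1]. The helper names PinnLoss and RitzLoss are hypothetical, a plain delegate stands in for the network, and derivatives use central finite differences. The PINN loss needs u'', while the energy needs only u'.

```csharp
using System;

const double H = 1e-4;   // finite-difference step
const int N = 1000;      // midpoint quadrature points on [0, 1]

double F(double x) => Math.PI * Math.PI * Math.Sin(Math.PI * x);

// PINN-style loss: mean squared residual of -u'' - f, which needs u''.
double PinnLoss(Func<double, double> u)
{
    double sum = 0.0;
    for (int i = 0; i < N; i++)
    {
        double x = (i + 0.5) / N;
        double d2u = (u(x + H) - 2 * u(x) + u(x - H)) / (H * H);
        double r = -d2u - F(x);
        sum += r * r / N;
    }
    return sum;
}

// Deep Ritz-style loss: quadrature of ½u'² - f·u, which needs only u'.
double RitzLoss(Func<double, double> u)
{
    double sum = 0.0;
    for (int i = 0; i < N; i++)
    {
        double x = (i + 0.5) / N;
        double du = (u(x + H) - u(x - H)) / (2 * H);
        sum += (0.5 * du * du - F(x) * u(x)) / N;
    }
    return sum;
}

Func<double, double> exact = x => Math.Sin(Math.PI * x);
Func<double, double> perturbed = x => 0.9 * Math.Sin(Math.PI * x);
Console.WriteLine($"PINN loss: exact {PinnLoss(exact):E2}, perturbed {PinnLoss(perturbed):E2}");
Console.WriteLine($"Ritz loss: exact {RitzLoss(exact):F4}, perturbed {RitzLoss(perturbed):F4}");
```

The residual loss is zero at the exact solution, whereas the energy reaches a nonzero minimum (here -π²/4); for the energy, only its value relative to other candidates matters.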
Constructors
DeepRitzMethod(NeuralNetworkArchitecture<T>, Func<T[], T[], T[,], T>, Func<T[], bool>?, Func<T[], T[]>?, int)
Initializes a new instance of the Deep Ritz Method.
public DeepRitzMethod(NeuralNetworkArchitecture<T> architecture, Func<T[], T[], T[,], T> energyFunctional, Func<T[], bool>? boundaryCheck = null, Func<T[], T[]>? boundaryValue = null, int numQuadraturePoints = 10000)
Parameters
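architecture (NeuralNetworkArchitecture<T>): The neural network architecture.
energyFunctional (Func<T[], T[], T[,], T>): The energy density E(x, u, ∇u), the integrand of the energy functional to minimize. It receives the point x, the network output u(x), and the gradient ∇u(x) as a 2-D array.
boundaryCheck (Func<T[], bool>?, optional): Function that returns true if a point lies on the boundary.
boundaryValue (Func<T[], T[]>?, optional): Function returning the boundary value at a point.
numQuadraturePoints (int, default 10000): Number of quadrature points for numerical integration.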
Remarks
For Beginners: The energy functional should encode the physics of your problem.
Example: Poisson equation (-∇²u = f)
- Energy: E(u) = ½∫|∇u|² dx - ∫fu dx
- Implementation: energyFunctional = (x, u, grad_u) => 0.5 * ||grad_u||² - f(x) * u
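With T = double, the Poisson energy density above can be written as a delegate matching the energyFunctional parameter type Func<T[], T[], T[,], T>. This is a sketch under assumptions: the source term SourceF is a hypothetical example, and the layout gradU[outputIndex, dimension] for ∇u is assumed, since this page does not document it.

```csharp
using System;

// Hypothetical source term f(x, y) = 2π² sin(πx) sin(πy) on the unit square.
double SourceF(double[] x) =>
    2 * Math.PI * Math.PI * Math.Sin(Math.PI * x[0]) * Math.Sin(Math.PI * x[1]);

// Energy density e(x, u, ∇u) = ½|∇u|² - f(x)·u for a scalar solution.
// Assumed layout: u[0] is the scalar value and gradU[0, d] = ∂u/∂x_d.
Func<double[], double[], double[,], double> poissonEnergy = (x, u, gradU) =>
{
    double gradSq = 0.0;
    for (int d = 0; d < gradU.GetLength(1); d++)
        gradSq += gradU[0, d] * gradU[0, d];
    return 0.5 * gradSq - SourceF(x) * u[0];
};

double density = poissonEnergy(new[] { 0.5, 0.5 }, new[] { 1.0 }, new double[,] { { 1.0, 2.0 } });
Console.WriteLine(density);
```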
Example: Linear elasticity
- Energy: E(u) = ∫ strain_energy(∇u) dx
- Implementation: energyFunctional = (x, u, grad_u) => compute_strain_energy(grad_u)
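For an isotropic linear-elastic material, one standard strain energy density is W(∇u) = ½λ(tr ε)² + μ ε:ε with the small-strain tensor ε = ½(∇u + ∇uᵀ). The sketch below expresses it as an energyFunctional delegate with T = double. The Lamé parameters are illustrative values, gradU[i, j] = ∂u_i/∂x_j is an assumed layout, and the body-force term is omitted to match the example above.

```csharp
using System;

// Illustrative Lamé parameters (not library defaults).
const double Lambda = 1.0, Mu = 1.0;

// W(∇u) = ½ λ (tr ε)² + μ ε:ε with ε = ½(∇u + ∇uᵀ).
// Assumed layout: gradU[i, j] = ∂u_i/∂x_j for a displacement field u.
Func<double[], double[], double[,], double> elasticEnergy = (x, u, gradU) =>
{
    int dim = gradU.GetLength(0);
    double trace = 0.0, epsSq = 0.0;
    for (int i = 0; i < dim; i++)
    {
        trace += gradU[i, i];
        for (int j = 0; j < dim; j++)
        {
            double eps = 0.5 * (gradU[i, j] + gradU[j, i]);   // symmetric strain ε_ij
            epsSq += eps * eps;                              // accumulates ε:ε
        }
    }
    return 0.5 * Lambda * trace * trace + Mu * epsSq;
};

Console.WriteLine(elasticEnergy(new double[2], new double[2], new double[,] { { 0.0, 0.2 }, { 0.0, 0.0 } }));
```

A rigid rotation (antisymmetric ∇u) produces zero strain energy, which is a useful sanity check for any strain energy density.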
The method will integrate this over the domain using quadrature points.
Properties
SupportsTraining
Indicates whether this model supports training.
public override bool SupportsTraining { get; }
Property Value
- bool
Methods
ComputeTotalEnergy()
Computes the total energy functional by integrating over the domain.
public T ComputeTotalEnergy()
Returns
- T
The total energy value.
Remarks
For Beginners: This is the key method that computes ∫E(u, ∇u)dx numerically.
Steps:
- For each quadrature point x_i:
  a) Evaluate u(x_i) using the network
  b) Compute ∇u(x_i) using automatic differentiation
  c) Evaluate the energy density E(x_i, u(x_i), ∇u(x_i))
- Sum the weighted energies: Total = Σ w_i * E_i
- Add a boundary penalty if needed (for example, to enforce Dirichlet boundary values)
The gradient of this total energy with respect to network parameters tells us how to update the network to minimize energy.
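The steps above can be sketched in C# for the one-dimensional domain [0, 1]. This is not the library's implementation: the candidate solution is a plain delegate, quadrature uses the midpoint rule with weights w_i = 1/n, ∇u uses central differences, and the boundary penalty β·mean(u - g)² with beta = 100 is an assumed form. TotalEnergy is a hypothetical name.

```csharp
using System;

double TotalEnergy(
    Func<double[], double[]> u,                          // stand-in for the network
    Func<double[], double[], double[,], double> energy,  // energy density E(x, u, ∇u)
    Func<double[], double[]> boundaryValue,              // g(x) at boundary points
    int n = 1000, double h = 1e-4, double beta = 100.0)
{
    double total = 0.0;
    for (int i = 0; i < n; i++)
    {
        var x = new[] { (i + 0.5) / n };
        double[] ux = u(x);                                   // a) u(x_i)
        double[] up = u(new[] { x[0] + h });
        double[] um = u(new[] { x[0] - h });
        var grad = new double[ux.Length, 1];                  // b) ∇u(x_i) by central differences
        for (int k = 0; k < ux.Length; k++)
            grad[k, 0] = (up[k] - um[k]) / (2 * h);
        total += energy(x, ux, grad) / n;                     // c) w_i·E_i with w_i = 1/n
    }

    // Penalty: mean squared mismatch with g over the boundary points x = 0 and x = 1.
    double[] boundaryPoints = { 0.0, 1.0 };
    double penalty = 0.0;
    foreach (double xb in boundaryPoints)
    {
        var pb = new[] { xb };
        double[] ub = u(pb), gb = boundaryValue(pb);
        for (int k = 0; k < ub.Length; k++)
            penalty += (ub[k] - gb[k]) * (ub[k] - gb[k]) / boundaryPoints.Length;
    }
    return total + beta * penalty;
}

// Poisson density for -u'' = π² sin(πx), with grad[0, 0] = u'(x).
Func<double[], double[], double[,], double> poisson = (x, u, grad) =>
    0.5 * grad[0, 0] * grad[0, 0] - Math.PI * Math.PI * Math.Sin(Math.PI * x[0]) * u[0];
Func<double[], double[]> zeroBoundary = x => new[] { 0.0 };

double exactTotal = TotalEnergy(x => new[] { Math.Sin(Math.PI * x[0]) }, poisson, zeroBoundary);
double shiftedTotal = TotalEnergy(x => new[] { Math.Sin(Math.PI * x[0]) + 0.1 }, poisson, zeroBoundary);
Console.WriteLine($"exact: {exactTotal:F4}, shifted: {shiftedTotal:F4}");
```

Shifting the exact solution by 0.1 lowers the interior energy (from about -2.467 to about -3.096) because the shifted function violates u = 0 at the boundary; the penalty term is what makes the exact solution the better candidate again.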
CreateNewInstance()
Creates a new instance with the same configuration.
protected override IFullModel<T, Tensor<T>, Tensor<T>> CreateNewInstance()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
New Deep Ritz instance.
DeserializeNetworkSpecificData(BinaryReader)
Deserializes Deep Ritz-specific data.
protected override void DeserializeNetworkSpecificData(BinaryReader reader)
Parameters
reader (BinaryReader): Binary reader.
Forward(Tensor<T>)
Performs a forward pass through the network.
public Tensor<T> Forward(Tensor<T> input)
Parameters
input (Tensor<T>): Input tensor for evaluation.
Returns
- Tensor<T>
Network output tensor.
GetModelMetadata()
Gets metadata about the Deep Ritz model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
Model metadata.
GetSolution(T[])
Gets the solution at a specific point.
public T[] GetSolution(T[] point)
Parameters
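point (T[]): Coordinates of the point at which to evaluate the solution.
Returns
- T[]
The solution values u(x) predicted by the network at that point.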
InitializeLayers()
Initializes the layers of the neural network based on the architecture.
protected override void InitializeLayers()
Remarks
For Beginners: This method sets up all the layers in your neural network according to the architecture you've defined. It's like assembling the parts of your network before you can use it.
Predict(Tensor<T>)
Makes a prediction using the Deep Ritz network.
public override Tensor<T> Predict(Tensor<T> input)
Parameters
input (Tensor<T>): Input tensor.
Returns
- Tensor<T>
Predicted output tensor.
SerializeNetworkSpecificData(BinaryWriter)
Serializes Deep Ritz-specific data.
protected override void SerializeNetworkSpecificData(BinaryWriter writer)
Parameters
writer (BinaryWriter): Binary writer.
Solve(int, double, bool, int, double)
Trains the network to minimize the energy functional.
public TrainingHistory<T> Solve(int epochs = 1000, double learningRate = 0.001, bool verbose = true, int batchSize = 256, double derivativeStep = 0.0001)
Parameters
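epochs (int, default 1000): Number of training epochs.
learningRate (double, default 0.001): Learning rate for optimization.
verbose (bool, default true): Whether to print progress.
batchSize (int, default 256): Number of quadrature points per batch.
derivativeStep (double, default 0.0001): Finite-difference step size for input derivatives.
Returns
- TrainingHistory<T>
The training history recorded during optimization.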
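The kind of loop Solve performs can be illustrated with a toy C# version for -u'' = π² sin(πx) on [0, 1]. Everything here is an assumption for illustration, not AiDotNet code: a two-parameter trial function θ0·sin(πx) + θ1·sin(2πx) replaces the network, every quadrature point is used in each step instead of batches of batchSize, the input derivative uses a central difference with step DerivativeStep, and parameter gradients use central differences as a stand-in for backpropagation.

```csharp
using System;

const int Points = 200;                 // quadrature points (full batch each step)
const double DerivativeStep = 1e-4;     // step for the input derivative u'(x)
const double LearningRate = 0.05;
const double ParamStep = 1e-4;          // step for parameter gradients

// Hypothetical trial function standing in for the network.
double U(double[] theta, double x) =>
    theta[0] * Math.Sin(Math.PI * x) + theta[1] * Math.Sin(2 * Math.PI * x);

// Energy ∫ (½u'² - f·u) dx with f(x) = π² sin(πx), midpoint rule.
double Energy(double[] theta)
{
    double sum = 0.0;
    for (int i = 0; i < Points; i++)
    {
        double x = (i + 0.5) / Points;
        double du = (U(theta, x + DerivativeStep) - U(theta, x - DerivativeStep)) / (2 * DerivativeStep);
        double f = Math.PI * Math.PI * Math.Sin(Math.PI * x);
        sum += (0.5 * du * du - f * U(theta, x)) / Points;
    }
    return sum;
}

var weights = new[] { 0.5, 0.5 };
for (int epoch = 0; epoch < 200; epoch++)
{
    // Central-difference gradient of the energy with respect to each parameter.
    var grad = new double[weights.Length];
    for (int k = 0; k < weights.Length; k++)
    {
        var plus = (double[])weights.Clone();
        var minus = (double[])weights.Clone();
        plus[k] += ParamStep;
        minus[k] -= ParamStep;
        grad[k] = (Energy(plus) - Energy(minus)) / (2 * ParamStep);
    }
    // Plain gradient-descent update.
    for (int k = 0; k < weights.Length; k++)
        weights[k] -= LearningRate * grad[k];
}
Console.WriteLine($"weights = [{weights[0]:F4}, {weights[1]:F4}], energy = {Energy(weights):F4}");
```

Minimizing the energy drives the parameters toward θ0 = 1, θ1 = 0, i.e. the exact solution sin(πx).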
Train(Tensor<T>, Tensor<T>)
Performs a basic supervised training step using MSE loss.
public override void Train(Tensor<T> input, Tensor<T> expectedOutput)
Parameters
input (Tensor<T>): Training input tensor.
expectedOutput (Tensor<T>): Expected output tensor.
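The MSE loss this supervised step is documented to use, together with its gradient with respect to the predictions, can be written as below. This is a standalone sketch on plain arrays rather than Tensor<T>; the helper name Mse is hypothetical, and how the library propagates this gradient is not shown on this page.

```csharp
using System;

// L = (1/n) Σ (ŷ_i - y_i)²  and  ∂L/∂ŷ_i = (2/n)(ŷ_i - y_i).
(double Loss, double[] Grad) Mse(double[] predicted, double[] expected)
{
    int n = predicted.Length;
    double loss = 0.0;
    var grad = new double[n];
    for (int i = 0; i < n; i++)
    {
        double diff = predicted[i] - expected[i];
        loss += diff * diff / n;
        grad[i] = 2.0 * diff / n;
    }
    return (loss, grad);
}

var (mse, dLdy) = Mse(new[] { 1.0, 2.0 }, new[] { 1.0, 4.0 });
Console.WriteLine($"loss = {mse}, grad = [{string.Join(", ", dLdy)}]");
```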
UpdateParameters(Vector<T>)
Updates the network parameters from a flattened vector.
public override void UpdateParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): Parameter vector.
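Conceptually, updating from a flattened vector means slicing it back into per-layer parameter blocks. The sketch below assumes, hypothetically, that parameters are stored layer by layer in a fixed order with known per-layer counts; the actual ordering is defined by the library and is not documented on this page. SplitParameters is an illustrative name.

```csharp
using System;
using System.Collections.Generic;

// Splits a flat parameter vector into consecutive per-layer blocks.
List<double[]> SplitParameters(double[] flat, int[] layerSizes)
{
    int total = 0;
    foreach (int s in layerSizes) total += s;
    if (total != flat.Length)
        throw new ArgumentException($"Expected {total} parameters, got {flat.Length}.");

    var layers = new List<double[]>();
    int offset = 0;
    foreach (int size in layerSizes)
    {
        var chunk = new double[size];
        Array.Copy(flat, offset, chunk, 0, size);   // copy this layer's slice
        layers.Add(chunk);
        offset += size;
    }
    return layers;
}

// A 1-3-1 dense network: 3 weights + 3 biases, then 3 weights + 1 bias.
var parts = SplitParameters(new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }, new[] { 6, 4 });
Console.WriteLine($"{parts.Count} layers; layer 2 starts with {parts[1][0]}");
```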