Class GradLSTMCellInputOp

Namespace
AiDotNet.JitCompiler.IR.Operations
Assembly
AiDotNet.dll

Backward operation for LSTMCellOp - computes gradient for input.

public class GradLSTMCellInputOp : BackwardOp
Inheritance
GradLSTMCellInputOp
Remarks

The LSTM backward pass uses the chain rule through the gate computations:

  • The gradient flows back through the output gate, the cell state, and the forget/input gates.
  • It requires saved forward activations for correct gradient computation.
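For reference, the chain rule described above can be written out for the textbook LSTM cell. This is a sketch under assumptions: the gate ordering (input, forget, cell candidate, output), the pre-activation symbol a, and the δ notation are not taken from LSTMCellOp itself, whose internal layout may differ. Here ∂L/∂h_t and ∂L/∂c_t are the gradients arriving from later computations.

```latex
\begin{aligned}
a &= W_{ih}\, x_t + W_{hh}\, h_{t-1} + b, \qquad a = [a_i;\, a_f;\, a_g;\, a_o] \\
i &= \sigma(a_i), \quad f = \sigma(a_f), \quad g = \tanh(a_g), \quad o = \sigma(a_o) \\
c_t &= f \odot c_{t-1} + i \odot g, \qquad h_t = o \odot \tanh(c_t) \\[4pt]
\delta c &= \frac{\partial L}{\partial c_t} + \frac{\partial L}{\partial h_t} \odot o \odot \bigl(1 - \tanh^2(c_t)\bigr) \\
\delta a_i &= \delta c \odot g \odot i \odot (1 - i) \\
\delta a_f &= \delta c \odot c_{t-1} \odot f \odot (1 - f) \\
\delta a_g &= \delta c \odot i \odot (1 - g^2) \\
\delta a_o &= \frac{\partial L}{\partial h_t} \odot \tanh(c_t) \odot o \odot (1 - o) \\[4pt]
\frac{\partial L}{\partial x_t} &= W_{ih}^{\top}\, \delta a, \qquad
\frac{\partial L}{\partial h_{t-1}} = W_{hh}^{\top}\, \delta a, \qquad
\frac{\partial L}{\partial c_{t-1}} = \delta c \odot f \\
\frac{\partial L}{\partial W_{ih}} &= \delta a\, x_t^{\top}, \qquad
\frac{\partial L}{\partial W_{hh}} = \delta a\, h_{t-1}^{\top}, \qquad
\frac{\partial L}{\partial b} = \delta a
\end{aligned}
```

Every right-hand side uses values from the forward pass (i, f, g, o, c_t, c_{t-1}), which is why the saved activations are required.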

For Beginners: An LSTM has multiple paths along which gradients flow.

The LSTM has 4 gates (input, forget, cell candidate, output) and 2 states (hidden, cell). During backpropagation, we need to compute how the loss changes when we change:

  1. The input at this timestep
  2. The hidden state from previous timestep
  3. The cell state from previous timestep
  4. All the weights (W_ih, W_hh) and biases

These multiple gradient paths, in particular the path through the cell state, are what let an LSTM learn dependencies across long sequences.
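To make item 1 in the list above concrete, here is a minimal sketch of the gradient with respect to the input for a single cell and a batch size of 1. It assumes the textbook LSTM formulation with gate order input, forget, cell candidate, output; `GradInput` and all parameter names are hypothetical and are not part of AiDotNet. The code is written as a C# top-level program (C# 9 or later).

```csharp
using System;

// Hypothetical helper, not an AiDotNet API: dL/dx for one LSTM cell (batch size 1).
// Assumed gate order in the stacked weight matrix: input, forget, cell candidate, output.
static double[] GradInput(
    double[,] wIh,                                   // shape [4 * hiddenSize, inputSize]
    double[] i, double[] f, double[] g, double[] o,  // saved forward gate activations
    double[] c, double[] cPrev,                      // saved current and previous cell state
    double[] gradH, double[] gradC)                  // upstream gradients for h_t and c_t
{
    int hidden = i.Length;
    int inputSize = wIh.GetLength(1);
    var gradPre = new double[4 * hidden];            // gradient w.r.t. gate pre-activations

    for (int k = 0; k < hidden; k++)
    {
        double tanhC = Math.Tanh(c[k]);
        // Gradient reaching the cell state: direct path plus the path through h = o * tanh(c).
        double dc = gradC[k] + gradH[k] * o[k] * (1 - tanhC * tanhC);

        gradPre[k]              = dc * g[k] * i[k] * (1 - i[k]);         // input gate (sigmoid)
        gradPre[hidden + k]     = dc * cPrev[k] * f[k] * (1 - f[k]);     // forget gate (sigmoid)
        gradPre[2 * hidden + k] = dc * i[k] * (1 - g[k] * g[k]);         // cell candidate (tanh)
        gradPre[3 * hidden + k] = gradH[k] * tanhC * o[k] * (1 - o[k]);  // output gate (sigmoid)
    }

    // dL/dx = W_ih^T * gradPre
    var gradX = new double[inputSize];
    for (int r = 0; r < 4 * hidden; r++)
        for (int col = 0; col < inputSize; col++)
            gradX[col] += wIh[r, col] * gradPre[r];
    return gradX;
}

// Tiny example: hidden size 1, input size 1.
var weights = new double[,] { { 1 }, { 1 }, { 2 }, { 1 } };
double[] result = GradInput(
    weights,
    new[] { 0.5 }, new[] { 0.5 }, new[] { 0.0 }, new[] { 0.5 },  // i, f, g, o
    new[] { 0.0 }, new[] { 0.0 },                                // c, cPrev
    new[] { 1.0 }, new[] { 0.0 });                               // gradH, gradC
Console.WriteLine(result[0]);
```

In this example only the cell-candidate path contributes (its pre-activation gradient is 0.25 and its weight is 2), so the input gradient is 0.5. The gradients for the previous hidden state and the weights reuse the same `gradPre` vector, which is one reason a single backward operation can serve several gradient targets.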

Properties

HiddenSize

Hidden state size.

public int HiddenSize { get; set; }

Property Value

int

InputIndex

Selects which gradient this operation computes: 0 = input, 1 = hidden state, 2 = cell state, 3 = W_ih, 4 = W_hh, 5 = bias.

public int InputIndex { get; set; }

Property Value

int
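The index-to-gradient mapping documented for InputIndex can be made explicit in code. The helper below is a hypothetical illustration, not an AiDotNet API; only the six index values and their meanings come from the property description.

```csharp
using System;

// Hypothetical helper (not part of AiDotNet): names the gradient selected by InputIndex,
// following the documented mapping 0..5.
static string DescribeGradTarget(int inputIndex) => inputIndex switch
{
    0 => "input",
    1 => "hidden",
    2 => "cell",
    3 => "W_ih",
    4 => "W_hh",
    5 => "bias",
    _ => throw new ArgumentOutOfRangeException(nameof(inputIndex), "expected a value from 0 to 5"),
};

Console.WriteLine(DescribeGradTarget(3)); // prints "W_ih"
```

Rejecting values outside 0 to 5 is a design choice of this sketch; the source does not say how the library treats an out-of-range InputIndex.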

Methods

ToString()

Gets a string representation of this operation for debugging.

public override string ToString()

Returns

string

A string describing this operation.

Remarks

The string format is: "tOutput = OpType(tInput1, tInput2, ...) : Type [Shape]"

For Beginners: This creates a readable description of the operation.

Example outputs:

  • "t2 = Add(t0, t1) : Float32 [3, 4]"
  • "t5 = MatMul(t3, t4) : Float32 [128, 256]"
  • "t8 = ReLU(t7) : Float32 [32, 128]"

This is super helpful for debugging - you can see exactly what each operation does and what shape tensors flow through the graph.
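The documented string layout can be reproduced with a few lines of string formatting. The following is a sketch only: `FormatOp` and its parameters are hypothetical and do not show how the library's ToString() is actually implemented.

```csharp
using System;

// Hypothetical formatter (not the library's implementation) that reproduces the
// documented layout: "tOutput = OpType(tInput1, tInput2, ...) : Type [Shape]".
static string FormatOp(int outputId, string opType, int[] inputIds, string dataType, int[] shape)
{
    string inputs = string.Join(", ", Array.ConvertAll(inputIds, id => "t" + id));
    string dims = string.Join(", ", shape);
    return $"t{outputId} = {opType}({inputs}) : {dataType} [{dims}]";
}

Console.WriteLine(FormatOp(2, "Add", new[] { 0, 1 }, "Float32", new[] { 3, 4 }));
// prints: t2 = Add(t0, t1) : Float32 [3, 4]
```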

Validate()

Validates that this operation is correctly formed.

public override bool Validate()

Returns

bool

True if valid, false otherwise.

Remarks

Basic validation checks that the operation has required information. Derived classes can override to add operation-specific validation.

For Beginners: This checks that the operation makes sense.

Basic checks:

  • Output ID is valid (non-negative)
  • Has the right number of inputs
  • Shapes are compatible

Specific operations add their own checks:

  • MatMul: inner dimensions must match
  • Conv2D: kernel size must be valid
  • Reshape: total elements must be preserved

If validation fails, the operation can't be compiled.
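The checks listed above can be sketched as small predicates. These functions are hypothetical illustrations of the kinds of checks described, not AiDotNet code, and they assume shapes are plain integer arrays.

```csharp
using System;
using System.Linq;

// Hypothetical checks (not AiDotNet code) mirroring the examples in the remarks.

// Basic check: non-negative output ID and the expected number of inputs.
static bool ValidateBasic(int outputId, int inputCount, int expectedInputs) =>
    outputId >= 0 && inputCount == expectedInputs;

// MatMul on 2-D shapes: [m, k] x [k, n] requires the inner dimension k to match.
static bool ValidateMatMul(int[] left, int[] right) =>
    left.Length == 2 && right.Length == 2 && left[1] == right[0];

// Reshape: the total number of elements must be preserved.
static bool ValidateReshape(int[] source, int[] target) =>
    source.Aggregate(1L, (acc, d) => acc * d) == target.Aggregate(1L, (acc, d) => acc * d);

Console.WriteLine(ValidateMatMul(new[] { 128, 64 }, new[] { 64, 256 }));  // True
Console.WriteLine(ValidateReshape(new[] { 3, 4 }, new[] { 2, 5 }));       // False
```

Returning a bool rather than throwing matches the documented signature of Validate(), which lets a compiler pass collect invalid operations before deciding how to report them.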