
Class SoftSplitOp

Namespace
AiDotNet.JitCompiler.IR.Operations
Assembly
AiDotNet.dll

Represents a soft split operation for differentiable decision trees in the IR.

public class SoftSplitOp : IROp
Inheritance
IROp → SoftSplitOp

Remarks

Implements differentiable decision tree nodes using sigmoid gating instead of hard branching. This enables gradient-based learning and JIT compilation of tree-based models.

The soft split computes the following, where σ is the logistic sigmoid, σ(z) = 1 / (1 + e^(-z)):

p_left = σ((threshold - x[featureIndex]) / temperature)
output = p_left * leftValue + (1 - p_left) * rightValue
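A minimal C# sketch of the formula above, written as a top-level-statements program. The `Sigmoid` and `SoftSplit` helpers and their parameter names are hypothetical illustrations, not members of `SoftSplitOp`, and the compiled op is assumed to work over tensors rather than a single `double[]`.

```csharp
using System;

// Logistic sigmoid: σ(z) = 1 / (1 + e^(-z)).
static double Sigmoid(double z) => 1.0 / (1.0 + Math.Exp(-z));

// Hypothetical helper: evaluates one soft split on a feature vector x.
// leftValue and rightValue stand in for whatever the two branches produce.
static double SoftSplit(double[] x, int featureIndex, double threshold,
                        double temperature, double leftValue, double rightValue)
{
    double pLeft = Sigmoid((threshold - x[featureIndex]) / temperature);
    return pLeft * leftValue + (1.0 - pLeft) * rightValue;
}

double[] sample = { 0.2, 3.0, -1.5 };

// Feature 1 equals the threshold, so p_left = 0.5 and the branches are averaged.
Console.WriteLine(SoftSplit(sample, featureIndex: 1, threshold: 3.0,
                            temperature: 1.0, leftValue: 10.0, rightValue: 20.0)); // prints 15
```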

For Beginners: Normal decision trees make hard yes/no decisions at each node. A soft split makes a "probabilistic" decision: instead of choosing either left or right, it takes a weighted average of both paths. The weights depend on which side of the threshold the input falls on and how far from the threshold it lies.

Example with temperature=1:

  • If x[featureIndex] is much less than the threshold: p_left ≈ 1 (mostly goes left)
  • If x[featureIndex] equals the threshold: p_left = 0.5 (50/50 split)
  • If x[featureIndex] is much greater than the threshold: p_left ≈ 0 (mostly goes right)
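The three cases can be checked numerically. This sketch assumes a threshold of 0 and inputs of -5, 0 and 5; `PLeftAtUnitTemperature` is a hypothetical helper that applies the formula with the temperature fixed at 1.

```csharp
using System;

static double Sigmoid(double z) => 1.0 / (1.0 + Math.Exp(-z));

// Hypothetical helper: p_left with temperature fixed at 1, as in the example.
static double PLeftAtUnitTemperature(double featureValue, double threshold) =>
    Sigmoid((threshold - featureValue) / 1.0);

// Threshold 0; inputs well below, exactly at, and well above it.
// p_left comes out at roughly 0.993, 0.5 and 0.007.
foreach (double value in new[] { -5.0, 0.0, 5.0 })
{
    Console.WriteLine($"x[featureIndex] = {value}: p_left = {PLeftAtUnitTemperature(value, 0.0):F4}");
}
```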

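Because the sigmoid is smooth, the output is differentiable with respect to the input, the threshold, and the branch values, so the tree can be trained with gradient-based methods. As the temperature approaches zero, p_left approaches a 0/1 step and the soft split approaches an ordinary hard split.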
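To make "differentiable" concrete, this sketch derives the partial derivatives of the output with the chain rule, using σ'(z) = σ(z)(1 − σ(z)) and z = (threshold − x) / temperature. The helper names are hypothetical; this page does not describe how AiDotNet itself computes gradients for this op.

```csharp
using System;

static double Sigmoid(double z) => 1.0 / (1.0 + Math.Exp(-z));

// Output of a single soft split for a scalar feature value x.
static double Output(double x, double threshold, double temperature, double left, double right)
{
    double p = Sigmoid((threshold - x) / temperature);
    return p * left + (1.0 - p) * right;
}

// Analytic partial derivatives of Output via the chain rule.
static (double dX, double dThreshold, double dLeft, double dRight) Gradients(
    double x, double threshold, double temperature, double left, double right)
{
    double p = Sigmoid((threshold - x) / temperature);
    double dOutDz = (left - right) * p * (1.0 - p); // ∂output/∂z
    return (dOutDz * (-1.0 / temperature),          // ∂z/∂x = -1/temperature
            dOutDz * (1.0 / temperature),           // ∂z/∂threshold = 1/temperature
            p,                                      // ∂output/∂left = p_left
            1.0 - p);                               // ∂output/∂right = 1 - p_left
}

var g = Gradients(x: 0.5, threshold: 1.0, temperature: 0.5, left: 2.0, right: -1.0);
Console.WriteLine($"d/dx = {g.dX}, d/dthreshold = {g.dThreshold}, d/dleft = {g.dLeft}, d/dright = {g.dRight}");
```

The gradient with respect to x is the negative of the gradient with respect to the threshold, since both enter the formula only through their difference.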

Properties

FeatureIndex

Gets or sets the index of the feature to split on.

public int FeatureIndex { get; set; }

Property Value

int

Temperature

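Gets or sets the temperature parameter that controls split sharpness. Lower values make the split sharper (closer to a hard split); higher values make it softer. Because the formula divides by the temperature, it must be nonzero, and a negative value would reverse which side of the threshold is favored.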

public double Temperature { get; set; }

Property Value

double
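The effect of `Temperature` shows up when the input is held fixed just below the threshold while the temperature shrinks. This is a hypothetical sketch of the formula from the remarks, not a call into `SoftSplitOp`; the constants `SplitAt` and `Input` are illustrative values.

```csharp
using System;

static double Sigmoid(double z) => 1.0 / (1.0 + Math.Exp(-z));

// Hypothetical helper: p_left from the formula in the remarks.
static double PLeft(double featureValue, double threshold, double temperature) =>
    Sigmoid((threshold - featureValue) / temperature);

const double SplitAt = 2.0;
const double Input = 1.9;   // 0.1 below the threshold

foreach (double t in new[] { 10.0, 1.0, 0.1, 0.01 })
{
    // Smaller temperatures push p_left from about 0.5 toward 1 (a hard "go left").
    Console.WriteLine($"Temperature = {t}: p_left = {PLeft(Input, SplitAt, t):F4}");
}
```

At a temperature of 10 the split barely distinguishes the two sides; at 0.01 it behaves almost like a hard split, while staying differentiable.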

Threshold

Gets or sets the threshold value for the split.

public double Threshold { get; set; }

Property Value

double

Methods

ToString()

Gets a string representation of this operation for debugging.

public override string ToString()

Returns

string

A string describing this operation.

Remarks

The string format is: "tOutput = OpType(tInput1, tInput2, ...) : Type [Shape]"

For Beginners: This creates a readable description of the operation.

Example outputs:

  • "t2 = Add(t0, t1) : Float32 [3, 4]"
  • "t5 = MatMul(t3, t4) : Float32 [128, 256]"
  • "t8 = ReLU(t7) : Float32 [32, 128]"

This is useful for debugging: you can see what each operation does and the shapes of the tensors that flow through the graph.
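The documented format can be reproduced with a small formatter. `FormatOp` and its parameters are hypothetical; the real `ToString()` reads these values from the operation itself, and the exact text `SoftSplitOp` produces (for example, whether it includes `FeatureIndex`, `Threshold` or `Temperature`) is not shown on this page.

```csharp
using System;
using System.Linq;

// Hypothetical formatter for "tOutput = OpType(tInput1, tInput2, ...) : Type [Shape]".
static string FormatOp(int outputId, string opType, int[] inputIds, string dataType, int[] shape)
{
    string inputs = string.Join(", ", inputIds.Select(id => "t" + id));
    string dims = string.Join(", ", shape);
    return $"t{outputId} = {opType}({inputs}) : {dataType} [{dims}]";
}

Console.WriteLine(FormatOp(2, "Add", new[] { 0, 1 }, "Float32", new[] { 3, 4 }));
// prints: t2 = Add(t0, t1) : Float32 [3, 4]
```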

Validate()

Validates that this operation is correctly formed.

public override bool Validate()

Returns

bool

True if valid, false otherwise.

Remarks

Basic validation checks that the operation has required information. Derived classes can override to add operation-specific validation.

For Beginners: This checks that the operation makes sense.

Basic checks:

  • Output ID is valid (non-negative)
  • Has the right number of inputs
  • Shapes are compatible

Specific operations add their own checks:

  • MatMul: inner dimensions must match
  • Conv2D: kernel size must be valid
  • Reshape: total elements must be preserved

If validation fails, the operation can't be compiled.
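As an illustration of operation-specific checks, here is a sketch of what validating a soft split might involve. The conditions are assumptions drawn from the formula (a non-negative feature index, a finite threshold, and a positive finite temperature), not the documented behavior of `SoftSplitOp.Validate()`, and `ValidateSoftSplit` is a hypothetical helper.

```csharp
using System;

// Hypothetical soft-split checks; the real Validate() may check different or
// additional conditions, such as the number and shapes of inputs.
static bool ValidateSoftSplit(int outputId, int featureIndex, double threshold, double temperature)
{
    if (outputId < 0) return false;              // base check: output ID must be non-negative
    if (featureIndex < 0) return false;          // assumed: feature index cannot be negative
    if (!double.IsFinite(threshold)) return false;  // assumed: threshold must be finite
    // assumed: the formula divides by temperature, so require a finite positive value
    if (!(temperature > 0.0) || !double.IsFinite(temperature)) return false;
    return true;
}

Console.WriteLine(ValidateSoftSplit(3, featureIndex: 0, threshold: 0.5, temperature: 1.0)); // True
Console.WriteLine(ValidateSoftSplit(3, featureIndex: 0, threshold: 0.5, temperature: 0.0)); // False
```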