Class SoftSplitOp
- Namespace: AiDotNet.JitCompiler.IR.Operations
- Assembly: AiDotNet.dll
Represents a soft split operation for differentiable decision trees in the IR.
public class SoftSplitOp : IROp
- Inheritance: object → IROp → SoftSplitOp
Remarks
Implements differentiable decision tree nodes using sigmoid gating instead of hard branching. This enables gradient-based learning and JIT compilation of tree-based models.
The soft split computes:
p_left = σ((threshold - x[featureIndex]) / temperature)
output = p_left * leftValue + (1 - p_left) * rightValue
For Beginners: Normal decision trees make hard yes/no decisions at each node. A soft split makes a "probabilistic" decision - instead of choosing left OR right, it takes a weighted average of both paths based on how close the input is to the threshold.
Example with temperature=1:
- If x[feature] is much less than threshold: p_left ≈ 1 (mostly goes left)
- If x[feature] equals threshold: p_left = 0.5 (50/50 split)
- If x[feature] is much greater than threshold: p_left ≈ 0 (mostly goes right)
This makes the tree differentiable (can compute gradients for training) while still approximating hard decision behavior when temperature is low.
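To make the formulas concrete, here is a minimal stand-alone sketch of the soft split math in C#. The class and method names are illustrative, not the actual AiDotNet API:

```csharp
using System;

// Illustrative sketch of the soft split computation; not the AiDotNet API.
static class SoftSplitSketch
{
    static double Sigmoid(double z) => 1.0 / (1.0 + Math.Exp(-z));

    // Mirrors: p_left = σ((threshold - x[featureIndex]) / temperature)
    //          output = p_left * leftValue + (1 - p_left) * rightValue
    public static double Evaluate(double[] x, int featureIndex, double threshold,
                                  double temperature, double leftValue, double rightValue)
    {
        double pLeft = Sigmoid((threshold - x[featureIndex]) / temperature);
        return pLeft * leftValue + (1 - pLeft) * rightValue;
    }
}
```

With threshold = 0, temperature = 1, leftValue = 1, and rightValue = 0, Evaluate returns ≈ 0.993 at x = -5, exactly 0.5 at x = 0, and ≈ 0.007 at x = 5, matching the bullets above.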
Properties
FeatureIndex
Gets or sets the index of the feature to split on.
public int FeatureIndex { get; set; }
Property Value
- int
Temperature
Gets or sets the temperature parameter controlling split sharpness. Lower temperatures give a sharper split (closer to a hard split); higher temperatures give a softer one (see the numeric sketch after the property list).
public double Temperature { get; set; }
Property Value
- double
Threshold
Gets or sets the threshold value for the split.
public double Threshold { get; set; }
Property Value
- double
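To see the Temperature effect numerically, a quick stand-alone C# sketch (again not the AiDotNet API) compares p_left for the same input at a sharp and a soft temperature:

```csharp
using System;

class TemperatureDemo
{
    static double Sigmoid(double z) => 1.0 / (1.0 + Math.Exp(-z));

    static void Main()
    {
        double threshold = 0.0, x = 0.5;  // input sits 0.5 above the threshold

        // Sharp split (temperature 0.1): p_left ≈ 0.0067, almost a hard "go right".
        Console.WriteLine(Sigmoid((threshold - x) / 0.1));

        // Soft split (temperature 10): p_left ≈ 0.4875, nearly a 50/50 blend.
        Console.WriteLine(Sigmoid((threshold - x) / 10.0));
    }
}
```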
Methods
ToString()
Gets a string representation of this operation for debugging.
public override string ToString()
Returns
- string
A string describing this operation.
Remarks
The string format is: "tOutput = OpType(tInput1, tInput2, ...) : Type [Shape]"
For Beginners: This creates a readable description of the operation.
Example outputs:
- "t2 = Add(t0, t1) : Float32 [3, 4]"
- "t5 = MatMul(t3, t4) : Float32 [128, 256]"
- "t8 = ReLU(t7) : Float32 [32, 128]"
This is super helpful for debugging - you can see exactly what each operation does and the shapes of the tensors flowing through the graph.
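For orientation, here is a hypothetical SoftSplitOp override that follows the documented pattern. The member names used (OutputId, InputIds, OutputType, OutputShape) are assumptions about the IROp base class, not confirmed API:

```csharp
// Hypothetical sketch; the IROp member names are assumed, not confirmed.
public override string ToString()
{
    // Would produce e.g. "t3 = SoftSplit(t1) : Float32 [32, 1]".
    return $"t{OutputId} = SoftSplit(t{string.Join(", t", InputIds)}) " +
           $": {OutputType} [{string.Join(", ", OutputShape)}]";
}
```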
Validate()
Validates that this operation is correctly formed.
public override bool Validate()
Returns
- bool
True if valid, false otherwise.
Remarks
Basic validation checks that the operation has required information. Derived classes can override to add operation-specific validation.
For Beginners: This checks that the operation makes sense.
Basic checks:
- Output ID is valid (non-negative)
- Has the right number of inputs
- Shapes are compatible
Specific operations add their own checks:
- MatMul: inner dimensions must match
- Conv2D: kernel size must be valid
- Reshape: total elements must be preserved
If validation fails, the operation can't be compiled.
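As a sketch of what operation-specific validation could look like for SoftSplitOp (the FeatureIndex and Temperature properties are documented above; the exact checks are illustrative, not the confirmed implementation):

```csharp
// Illustrative override; the specific checks are assumptions, not confirmed behavior.
public override bool Validate()
{
    if (!base.Validate()) return false;  // generic checks: output ID, input count, shapes
    if (FeatureIndex < 0) return false;  // must index an existing feature
    if (Temperature <= 0) return false;  // sigmoid scale must be positive
    return true;
}
```

A non-positive temperature would make the sigmoid's division undefined or flip the split direction, so rejecting it at validation time is a natural guard.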