Class SoftKNNOp
- Namespace: AiDotNet.JitCompiler.IR.Operations
- Assembly: AiDotNet.dll
Represents a soft K-Nearest Neighbors operation for differentiable instance-based learning.
public class SoftKNNOp : IROp
- Inheritance: object → IROp → SoftKNNOp
Remarks
Implements differentiable KNN using attention-weighted contributions from all support vectors instead of a hard top-k selection. Because every step is differentiable, the operation supports gradient-based optimization and JIT compilation.
The soft KNN computes (shown for the L2 metric, DistanceType = 0):
distances[i] = ||input - supportVectors[i]||²
weights = softmax(-distances / temperature)
output = Σ weights[i] * labels[i]
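The three formulas above can be sketched as a plain C# forward pass. This is a minimal standalone illustration, not AiDotNet's implementation: the local function SoftKnnForward is hypothetical, plain double arrays stand in for IR tensors, and labels are assumed to be one vector per support vector (for example one-hot class labels). Subtracting the largest logit before exponentiating is a standard numerical-stability step that leaves the softmax result unchanged.

```csharp
using System;

// Hypothetical standalone sketch of the soft KNN forward pass (not AiDotNet code).
static double[] SoftKnnForward(double[] input, double[][] supportVectors,
                               double[][] labels, double temperature)
{
    int n = supportVectors.Length;

    // distances[i] = ||input - supportVectors[i]||² (squared Euclidean)
    var distances = new double[n];
    for (int i = 0; i < n; i++)
    {
        double sum = 0.0;
        for (int j = 0; j < input.Length; j++)
        {
            double diff = input[j] - supportVectors[i][j];
            sum += diff * diff;
        }
        distances[i] = sum;
    }

    // weights = softmax(-distances / temperature), shifted by the max logit for stability
    var logits = new double[n];
    double maxLogit = double.NegativeInfinity;
    for (int i = 0; i < n; i++)
    {
        logits[i] = -distances[i] / temperature;
        maxLogit = Math.Max(maxLogit, logits[i]);
    }
    var weights = new double[n];
    double total = 0.0;
    for (int i = 0; i < n; i++)
    {
        weights[i] = Math.Exp(logits[i] - maxLogit);
        total += weights[i];
    }

    // output = Σ weights[i] * labels[i]
    var output = new double[labels[0].Length];
    for (int i = 0; i < n; i++)
    {
        double w = weights[i] / total;
        for (int k = 0; k < output.Length; k++)
            output[k] += w * labels[i][k];
    }
    return output;
}

// Usage: two one-hot-labelled support vectors; the query sits on the first one.
double[][] support = { new[] { 0.0, 0.0 }, new[] { 3.0, 4.0 } };
double[][] oneHot = { new[] { 1.0, 0.0 }, new[] { 0.0, 1.0 } };
double[] prediction = SoftKnnForward(new[] { 0.0, 0.0 }, support, oneHot, 1.0);
Console.WriteLine(string.Join(", ", prediction));
```

Because the query coincides with the first support vector (distance 0) and the second is at squared distance 25, almost all of the weight goes to the first label, yet the second still receives a tiny nonzero weight, so gradients reach every support vector.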
For Beginners: Standard KNN finds the k closest neighbors and averages their labels. Soft KNN considers ALL neighbors but weights them by how close they are:
- Very close neighbors get high weights (they contribute more to the prediction)
- Far neighbors get very low weights (they contribute almost nothing)
This is like "all neighbors vote, but closer neighbors have louder voices." The temperature controls how much we favor close neighbors over far ones.
Properties
DistanceType
Gets or sets the distance metric type (0=L2/Euclidean, 1=L1/Manhattan).
public int DistanceType { get; set; }
Property Value
- int
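The two DistanceType codes can be illustrated with a small distance helper. This is a sketch under stated assumptions: the local function Distance is hypothetical, code 0 is assumed to use the squared Euclidean distance shown in the Remarks formula, and code 1 the Manhattan distance (sum of absolute differences). AiDotNet's exact handling of each code is not confirmed here.

```csharp
using System;

// Hypothetical helper mirroring the documented DistanceType codes.
// Assumption: 0 = squared Euclidean (as in the Remarks formula), 1 = Manhattan.
static double Distance(double[] a, double[] b, int distanceType)
{
    if (a.Length != b.Length)
        throw new ArgumentException("Vectors must have the same length.");
    if (distanceType != 0 && distanceType != 1)
        throw new ArgumentOutOfRangeException(nameof(distanceType), "Expected 0 (L2) or 1 (L1).");

    double sum = 0.0;
    for (int j = 0; j < a.Length; j++)
    {
        double diff = a[j] - b[j];
        // L2 accumulates squared differences; L1 accumulates absolute differences.
        sum += distanceType == 0 ? diff * diff : Math.Abs(diff);
    }
    return sum;
}

double l2 = Distance(new[] { 0.0, 0.0 }, new[] { 3.0, 4.0 }, 0);
double l1 = Distance(new[] { 0.0, 0.0 }, new[] { 3.0, 4.0 }, 1);
Console.WriteLine($"L2 (squared): {l2}, L1: {l1}");
```

For the points (0, 0) and (3, 4), the squared L2 distance is 9 + 16 = 25 and the L1 distance is 3 + 4 = 7.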
Temperature
Gets or sets the temperature parameter that controls attention sharpness. Lower temperatures concentrate the weights on the nearest neighbors; higher temperatures spread them more evenly across all support vectors.
public double Temperature { get; set; }
Property Value
- double
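The effect of Temperature can be seen by evaluating softmax(-distances / temperature) for a fixed set of distances. This is a minimal sketch: the local function SoftmaxWeights and the sample distances are illustrative only and are not part of AiDotNet's API.

```csharp
using System;
using System.Linq;

// Illustrative sketch: softmax(-distances / temperature) for fixed distances.
static double[] SoftmaxWeights(double[] distances, double temperature)
{
    double[] logits = distances.Select(d => -d / temperature).ToArray();
    double max = logits.Max();                            // shift for numerical stability
    double[] exps = logits.Select(l => Math.Exp(l - max)).ToArray();
    double total = exps.Sum();
    return exps.Select(e => e / total).ToArray();
}

double[] sampleDistances = { 1.0, 2.0, 4.0 };
double[] sharp = SoftmaxWeights(sampleDistances, 0.1);   // low temperature
double[] smooth = SoftmaxWeights(sampleDistances, 10.0); // high temperature

Console.WriteLine("T=0.1:  " + string.Join(", ", sharp.Select(w => w.ToString("F4"))));
Console.WriteLine("T=10.0: " + string.Join(", ", smooth.Select(w => w.ToString("F4"))));
```

At T = 0.1 nearly all of the weight lands on the closest neighbor, which behaves much like hard 1-NN. At T = 10 the three weights are of similar size, so every neighbor influences the prediction.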
Methods
ToString()
Gets a string representation of this operation for debugging.
public override string ToString()
Returns
- string
A string describing this operation.
Remarks
The string format is: "tOutput = OpType(tInput1, tInput2, ...) : Type [Shape]"
For Beginners: This creates a readable description of the operation.
Example outputs:
- "t2 = Add(t0, t1) : Float32 [3, 4]"
- "t5 = MatMul(t3, t4) : Float32 [128, 256]"
- "t8 = ReLU(t7) : Float32 [32, 128]"
This is very helpful for debugging: you can see exactly what each operation does and what tensor shapes flow through the graph.
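The documented format string can be reproduced with a small formatter. This is a sketch only: the local function FormatOp and its parameters (outputId, opType, inputIds, dataType, shape) are hypothetical stand-ins, not the actual IROp members that ToString() reads.

```csharp
using System;
using System.Linq;

// Hypothetical formatter for "tOutput = OpType(tInput1, tInput2, ...) : Type [Shape]".
// Parameter names are illustrative, not AiDotNet's IROp members.
static string FormatOp(int outputId, string opType, int[] inputIds,
                       string dataType, int[] shape)
{
    string inputs = string.Join(", ", inputIds.Select(id => $"t{id}"));
    string dims = string.Join(", ", shape);
    return $"t{outputId} = {opType}({inputs}) : {dataType} [{dims}]";
}

string text = FormatOp(2, "Add", new[] { 0, 1 }, "Float32", new[] { 3, 4 });
Console.WriteLine(text); // t2 = Add(t0, t1) : Float32 [3, 4]
```

With these arguments the formatter prints the first example output listed above.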
Validate()
Validates that this operation is correctly formed.
public override bool Validate()
Returns
- bool
True if valid, false otherwise.
Remarks
Basic validation checks that the operation has required information. Derived classes can override to add operation-specific validation.
For Beginners: This checks that the operation makes sense.
Basic checks:
- Output ID is valid (non-negative)
- Has the right number of inputs
- Shapes are compatible
Specific operations add their own checks:
- MatMul: inner dimensions must match
- Conv2D: kernel size must be valid
- Reshape: total elements must be preserved
If validation fails, the operation can't be compiled.
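As an illustration of operation-specific validation, a soft KNN op could reject parameter values that break the softmax formula. These rules are assumptions, not AiDotNet's actual Validate() logic: the hypothetical local function ValidateSoftKnn checks a non-negative output ID (the base check listed above), a finite, strictly positive Temperature (the formula divides by it), and a DistanceType of 0 or 1. It does not check input counts or shapes.

```csharp
using System;

// Hypothetical checks a soft KNN op could add; not AiDotNet's Validate() implementation.
static bool ValidateSoftKnn(int outputId, double temperature, int distanceType)
{
    if (outputId < 0) return false;                            // base check: valid output ID
    if (!(temperature > 0.0) || double.IsInfinity(temperature))
        return false;                                          // -d / T needs finite T > 0 (also rejects NaN)
    if (distanceType != 0 && distanceType != 1) return false;  // 0 = L2, 1 = L1
    return true;
}

Console.WriteLine(ValidateSoftKnn(3, 1.0, 0)); // True: well-formed
Console.WriteLine(ValidateSoftKnn(3, 0.0, 0)); // False: zero temperature
Console.WriteLine(ValidateSoftKnn(3, 1.0, 2)); // False: unknown distance metric
```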