Class DreamFusionModel<T>
DreamFusion model for text-to-3D generation via Score Distillation Sampling (SDS). Uses a 2D diffusion prior to optimize a 3D neural radiance field representation. Based on "DreamFusion: Text-to-3D using 2D Diffusion" (Poole et al., 2022).
public class DreamFusionModel<T> : LatentDiffusionModelBase<T>, ILatentDiffusionModel<T>, IDiffusionModel<T>, IFullModel<T, Tensor<T>, Tensor<T>>, IModel<Tensor<T>, Tensor<T>, ModelMetadata<T>>, IModelSerializer, ICheckpointableModel, IParameterizable<T, Tensor<T>, Tensor<T>>, IFeatureAware, IFeatureImportance<T>, ICloneable<IFullModel<T, Tensor<T>, Tensor<T>>>, IGradientComputable<T, Tensor<T>, Tensor<T>>, IJitCompilable<T>
Type Parameters
T: The numeric type for computations.
Inheritance
- DreamFusionModel<T>
Remarks
DreamFusion generates 3D content from text by using a pretrained 2D diffusion model to guide the optimization of a 3D scene representation, a neural radiance field (NeRF). The key insight is that a 2D diffusion model can serve as a "critic" for 3D content through Score Distillation Sampling (SDS).
For Beginners: Think of DreamFusion as using an AI art critic (the 2D diffusion model) to guide a 3D sculptor (the NeRF). The critic looks at 2D views of the sculpture and gives feedback on how to make it look more like the text description.
How it works:
- You describe what you want: "a DSLR photo of a peacock on a surfboard"
- DreamFusion renders 2D images of a 3D shape from random viewpoints
- The 2D diffusion model evaluates: "Does this look like the prompt?"
- Gradients flow back to improve the 3D representation
- After many iterations, you get a full 3D object you can view from any angle
Key features:
- Creates full 3D assets from text descriptions
- View-consistent: looks correct from any angle
- Leverages the quality of 2D image generators
- No 3D training data required
Technical details:
- Uses NeRF (Neural Radiance Field) for 3D representation
- Employs Score Distillation Sampling (SDS) loss
- Samples random camera views during optimization
- Uses classifier-free guidance with a high scale (typically 100)
- Supports mesh extraction via marching cubes
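The SDS feedback loop described above can be sketched in a few lines. This is a minimal illustration, not the library's implementation: `ToyPredictNoise` is a hypothetical stand-in for the 2D diffusion prior, the "rendered view" is a two-element array rather than a NeRF render, and `alphaBar` and `weight` are assumed names for the cumulative noise-schedule term and the timestep weighting w(t). The sketch computes the SDS gradient with respect to the image, w(t) * (predicted noise - sampled noise); a real pipeline would backpropagate that gradient through the renderer into the NeRF parameters.

```csharp
using System;

// Hypothetical stand-in for the 2D diffusion prior: it "believes" the ideal
// image is `target` and predicts the noise that would explain `noisy` under
// that belief. A real prior is a text-conditioned neural network.
double[] ToyPredictNoise(double[] noisy, double[] target, double alphaBar)
{
    double sqrtA = Math.Sqrt(alphaBar);
    double sqrtOneMinusA = Math.Sqrt(1.0 - alphaBar);
    var predicted = new double[noisy.Length];
    for (int i = 0; i < noisy.Length; i++)
        predicted[i] = (noisy[i] - sqrtA * target[i]) / sqrtOneMinusA;
    return predicted;
}

// Standard normal sample via the Box-Muller transform.
double NextGaussian(Random rng)
{
    double u1 = 1.0 - rng.NextDouble(); // in (0, 1], so Log is finite
    double u2 = rng.NextDouble();
    return Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2);
}

// SDS gradient with respect to the rendered image:
// weight * (predictedNoise - sampledNoise).
double[] SdsImageGradient(double[] rendered, double[] target,
                          double alphaBar, double weight, Random rng)
{
    double sqrtA = Math.Sqrt(alphaBar);
    double sqrtOneMinusA = Math.Sqrt(1.0 - alphaBar);
    var noise = new double[rendered.Length];
    var noisy = new double[rendered.Length];
    for (int i = 0; i < rendered.Length; i++)
    {
        noise[i] = NextGaussian(rng);
        // Forward diffusion: mix the rendered image with fresh noise.
        noisy[i] = sqrtA * rendered[i] + sqrtOneMinusA * noise[i];
    }
    double[] predicted = ToyPredictNoise(noisy, target, alphaBar);
    var grad = new double[rendered.Length];
    for (int i = 0; i < rendered.Length; i++)
        grad[i] = weight * (predicted[i] - noise[i]);
    return grad;
}

var random = new Random(42);
double[] image = { 1.0, -2.0 }; // stands in for a rendered view
double[] ideal = { 0.0, 0.0 };  // what the toy prior prefers
double[] gradient = SdsImageGradient(image, ideal, alphaBar: 0.5, weight: 1.0, random);

// One gradient-descent step moves the image toward what the prior prefers.
const double learningRate = 0.1;
for (int k = 0; k < image.Length; k++)
    image[k] -= learningRate * gradient[k];
Console.WriteLine(string.Join(", ", image));
```

With this toy prior the sampled noise cancels out of the gradient, so each step pulls the image straight toward `ideal`. With a real diffusion model the gradient is noisy, which is one reason the optimization needs many iterations.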
Constructors
DreamFusionModel(IDiffusionModel<T>?, DreamFusionConfig?, IConditioningModule<T>?, int?)
Initializes a new instance of the DreamFusionModel.
public DreamFusionModel(IDiffusionModel<T>? diffusionPrior = null, DreamFusionConfig? config = null, IConditioningModule<T>? conditioner = null, int? seed = null)
Parameters
diffusionPrior (IDiffusionModel<T>): The 2D diffusion model to use as the prior.
config (DreamFusionConfig): Optional configuration settings.
conditioner (IConditioningModule<T>): Optional conditioning module.
seed (int?): Optional random seed for reproducibility.
Properties
Conditioner
Gets the conditioning module (optional, for conditioned generation).
public override IConditioningModule<T>? Conditioner { get; }
Property Value
- IConditioningModule<T>
Config
Gets the configuration for the DreamFusion model.
public DreamFusionConfig Config { get; }
Property Value
- DreamFusionConfig
LatentChannels
Gets the number of latent channels.
public override int LatentChannels { get; }
Property Value
- int
Remarks
Typically 4 for Stable Diffusion models.
NoisePredictor
Gets the noise predictor model (U-Net, DiT, etc.).
public override INoisePredictor<T> NoisePredictor { get; }
Property Value
- INoisePredictor<T>
ParameterCount
Gets the number of parameters in the model.
public override int ParameterCount { get; }
Property Value
- int
Remarks
This property returns the total count of trainable parameters in the model. It's useful for understanding model complexity and memory requirements.
VAE
Gets the VAE model used for encoding and decoding.
public override IVAEModel<T> VAE { get; }
Property Value
- IVAEModel<T>
Methods
Clone()
Creates a deep copy of the model.
public override IDiffusionModel<T> Clone()
Returns
- IDiffusionModel<T>
A new instance with the same parameters.
DeepCopy()
Creates a deep copy of this object.
public override IFullModel<T, Tensor<T>, Tensor<T>> DeepCopy()
Returns
- IFullModel<T, Tensor<T>, Tensor<T>>
ExtractMesh(NeRFResult<T>, int, double)
Generates a mesh from the trained NeRF using marching cubes.
public DreamMesh<T> ExtractMesh(NeRFResult<T> result, int gridResolution = 128, double threshold = 0.5)
Parameters
result (NeRFResult<T>): The NeRF result from training.
gridResolution (int): Resolution of the marching cubes grid.
threshold (double): Density threshold for surface extraction.
Returns
- DreamMesh<T>
The extracted mesh.
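To make the roles of gridResolution and threshold concrete, the sketch below performs only the first stage of marching cubes: sampling a density field on a grid and classifying each cell by which of its corners reach the threshold. Everything here is hypothetical illustration rather than the library's code. `Density` is a made-up soft-ball field, the corner numbering is a simple binary x/y/z order, and triangle generation from the case tables is omitted.

```csharp
using System;

// Marching-cubes case index for one cell: bit i is set when corner i is at or
// above the threshold (a convention chosen here for illustration).
int CubeIndex(double[] cornerDensities, double threshold)
{
    int index = 0;
    for (int i = 0; i < 8; i++)
        if (cornerDensities[i] >= threshold) index |= 1 << i;
    return index;
}

// Hypothetical density field: dense at the origin, fading to zero at radius 1.
double Density(double x, double y, double z) =>
    Math.Max(0.0, 1.0 - Math.Sqrt(x * x + y * y + z * z));

// Counts the cells that straddle the iso-surface, i.e. the cells for which
// marching cubes would emit triangles.
int CountSurfaceCells(int gridResolution, double threshold)
{
    double step = 2.0 / (gridResolution - 1); // grid spans [-1, 1] on each axis
    int surfaceCells = 0;
    var corners = new double[8];
    for (int ix = 0; ix < gridResolution - 1; ix++)
    for (int iy = 0; iy < gridResolution - 1; iy++)
    for (int iz = 0; iz < gridResolution - 1; iz++)
    {
        for (int c = 0; c < 8; c++)
        {
            double x = -1.0 + (ix + (c & 1)) * step;
            double y = -1.0 + (iy + ((c >> 1) & 1)) * step;
            double z = -1.0 + (iz + ((c >> 2) & 1)) * step;
            corners[c] = Density(x, y, z);
        }
        int index = CubeIndex(corners, threshold);
        // 0 = entirely outside, 255 = entirely inside; anything else crosses the surface.
        if (index != 0 && index != 255) surfaceCells++;
    }
    return surfaceCells;
}

int coarse = CountSurfaceCells(gridResolution: 16, threshold: 0.5);
int fine = CountSurfaceCells(gridResolution: 32, threshold: 0.5);
Console.WriteLine($"surface cells at 16^3: {coarse}, at 32^3: {fine}");
```

A finer grid produces more surface cells and therefore a more detailed mesh, at a cost that grows with the cube of the resolution. A full implementation must also number the corners to match its triangle lookup table.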
GenerateAsync(string, int, double, double, IProgress<double>?, CancellationToken)
Generates a 3D representation from a text prompt using Score Distillation Sampling.
public Task<NeRFResult<T>> GenerateAsync(string prompt, int numIterations = 10000, double learningRate = 0.001, double guidanceScale = 100, IProgress<double>? progress = null, CancellationToken cancellationToken = default)
Parameters
prompt (string): The text prompt describing the desired 3D object.
numIterations (int): Number of optimization iterations.
learningRate (double): Learning rate for NeRF optimization.
guidanceScale (double): Classifier-free guidance scale.
progress (IProgress<double>): Optional progress reporter.
cancellationToken (CancellationToken): Optional cancellation token.
Returns
- Task<NeRFResult<T>>
A task that completes with the optimized NeRF as a NeRFResult<T>.
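The guidanceScale parameter controls how strongly the text prompt steers the noise prediction. The sketch below shows one common formulation of classifier-free guidance, which extrapolates from the unconditional prediction toward the conditional one. It is an assumption that this model uses exactly this parametrization (the DreamFusion paper writes the same idea with a slightly different weight convention), and `ApplyGuidance` is a hypothetical helper, not part of the documented API.

```csharp
using System;

// Classifier-free guidance, one common formulation:
//   guided = unconditional + scale * (conditional - unconditional)
// scale = 0 ignores the prompt, scale = 1 returns the conditional prediction,
// and larger values exaggerate the prompt's influence.
double[] ApplyGuidance(double[] unconditional, double[] conditional, double guidanceScale)
{
    if (unconditional.Length != conditional.Length)
        throw new ArgumentException("Noise predictions must have the same length.");
    var guided = new double[unconditional.Length];
    for (int i = 0; i < guided.Length; i++)
        guided[i] = unconditional[i] + guidanceScale * (conditional[i] - unconditional[i]);
    return guided;
}

double[] uncond = { 0.0, 1.0 }; // prediction without the prompt
double[] cond = { 1.0, 1.0 };   // prediction with the prompt
double[] guidedNoise = ApplyGuidance(uncond, cond, guidanceScale: 100);
Console.WriteLine(string.Join(", ", guidedNoise));
```

Only the first element differs between the two predictions, so only that element is amplified. This shows why a scale near 100 pushes the result so strongly toward the prompt.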
GetModelMetadata()
Retrieves metadata and performance metrics about the trained model.
public override ModelMetadata<T> GetModelMetadata()
Returns
- ModelMetadata<T>
An object containing metadata and performance metrics about the trained model.
Remarks
This method provides information about the model's structure, parameters, and performance metrics.
For Beginners: Model metadata is like a report card for your machine learning model.
Just as a report card shows how well a student is performing in different subjects, model metadata shows how well your model is performing and provides details about its structure.
This information typically includes:
- Accuracy measures: How well do the model's predictions match actual values?
- Error metrics: How far off are the model's predictions on average?
- Model parameters: What patterns did the model learn from the data?
- Training information: How long did training take? How many iterations were needed?
For example, in a house price prediction model, metadata might include:
- Average prediction error (e.g., off by $15,000 on average)
- How strongly each feature (bedrooms, location) influences the prediction
- How well the model fits the training data
This information helps you understand your model's strengths and weaknesses, and decide if it's ready to use or needs more training.
GetParameters()
Gets the parameters that can be optimized.
public override Vector<T> GetParameters()
Returns
- Vector<T>
PredictNoise(Tensor<T>, int)
Predicts the noise in a noisy sample at a given timestep.
public override Tensor<T> PredictNoise(Tensor<T> noisySample, int timestep)
Parameters
noisySample (Tensor<T>): The noisy input sample.
timestep (int): The current timestep in the diffusion process.
Returns
- Tensor<T>
The predicted noise tensor.
Remarks
This is the core prediction that the model learns. Given a noisy sample at timestep t, predict what noise was added to create it.
For Beginners: The model looks at a noisy image and guesses "what noise was added to make it look like this?" This prediction is then used to remove that noise and get a cleaner image.
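As a small worked example of how a noise prediction is "used to remove that noise", the sketch below applies the standard DDPM identity that inverts the forward noising step. It is illustrative only: `EstimateCleanSample` and `alphaBar` (the cumulative noise-schedule product at the timestep) are hypothetical names, plain arrays stand in for Tensor<T>, and the true noise is passed in as if the predictor were perfect.

```csharp
using System;

// Inverts the forward step  noisy = sqrt(alphaBar) * clean + sqrt(1 - alphaBar) * noise
// to estimate the clean sample from a noise prediction.
double[] EstimateCleanSample(double[] noisySample, double[] predictedNoise, double alphaBar)
{
    double sqrtA = Math.Sqrt(alphaBar);
    double sqrtOneMinusA = Math.Sqrt(1.0 - alphaBar);
    var clean = new double[noisySample.Length];
    for (int i = 0; i < clean.Length; i++)
        clean[i] = (noisySample[i] - sqrtOneMinusA * predictedNoise[i]) / sqrtA;
    return clean;
}

double[] original = { 2.0, -1.0 };
double[] addedNoise = { 0.5, 0.25 };
double alphaBarT = 0.64;

// Forward diffusion: blend the clean sample with noise.
var noisy = new double[original.Length];
for (int k = 0; k < original.Length; k++)
    noisy[k] = Math.Sqrt(alphaBarT) * original[k] + Math.Sqrt(1.0 - alphaBarT) * addedNoise[k];

// With a perfect noise prediction, the clean sample is recovered.
double[] recovered = EstimateCleanSample(noisy, addedNoise, alphaBarT);
Console.WriteLine(string.Join(", ", recovered));
```

A trained model's prediction is only approximate, so in practice the estimate is refined over many timesteps rather than recovered in one shot.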
RenderView(NeRFResult<T>, CameraPose, int)
Renders an image from the trained NeRF at a specific camera pose.
public Tensor<T> RenderView(NeRFResult<T> result, CameraPose cameraPose, int resolution = 256)
Parameters
result (NeRFResult<T>): The NeRF result from training.
cameraPose (CameraPose): The camera pose to render from.
resolution (int): The output resolution.
Returns
- Tensor<T>
The rendered image as a tensor.
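Rendering a view from a NeRF typically means marching a ray through the volume for each pixel and alpha-compositing the samples along it. The sketch below shows that standard compositing rule for a single grayscale ray. Whether RenderView uses exactly this quadrature is an assumption, and `CompositeRay` with its parameters is a hypothetical helper rather than part of the documented API.

```csharp
using System;

// Standard NeRF volume-rendering quadrature for one ray:
//   alpha_i = 1 - exp(-density_i * stepSize)
//   pixel   = sum_i T_i * alpha_i * color_i, where T_i = prod_{j<i} (1 - alpha_j)
double CompositeRay(double[] densities, double[] colors, double stepSize)
{
    double transmittance = 1.0; // fraction of light that has not yet been absorbed
    double pixel = 0.0;
    for (int i = 0; i < densities.Length; i++)
    {
        double alpha = 1.0 - Math.Exp(-densities[i] * stepSize);
        pixel += transmittance * alpha * colors[i];
        transmittance *= 1.0 - alpha;
    }
    return pixel;
}

// Empty space, then an effectively opaque surface, then a sample hidden behind it.
double[] sigma = { 0.0, 1000.0, 5.0 };
double[] shade = { 1.0, 0.5, 0.9 };
double renderedPixel = CompositeRay(sigma, shade, stepSize: 1.0);
Console.WriteLine(renderedPixel);
```

The pixel takes the color of the first opaque sample: empty space contributes nothing, and samples behind the surface are occluded because the transmittance has dropped to zero.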
SetParameters(Vector<T>)
Sets the model parameters.
public override void SetParameters(Vector<T> parameters)
Parameters
parameters (Vector<T>): The parameter vector to set.
Remarks
This method allows direct modification of the model's internal parameters, which is useful for optimization algorithms that update parameters iteratively. If the length of parameters does not match ParameterCount, an ArgumentException is thrown.
Exceptions
- ArgumentException
Thrown when the length of parameters does not match ParameterCount.
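The read-modify-write pattern that GetParameters() and SetParameters(Vector<T>) support, together with the length check described above, can be sketched as follows. This is a stand-alone illustration under stated assumptions: a plain double[] stands in for Vector<T>, `storedParameters` is a hypothetical three-element parameter store, and the local functions only mimic the documented contract.

```csharp
using System;

// Hypothetical stand-in for a model's flat parameter storage.
double[] storedParameters = { 0.5, -1.0, 2.0 };

// Returns a copy so callers cannot mutate the model's state by accident.
double[] GetParameters() => (double[])storedParameters.Clone();

// Mirrors the documented contract: reject vectors of the wrong length.
void SetParameters(double[] parameters)
{
    if (parameters.Length != storedParameters.Length)
        throw new ArgumentException(
            $"Expected {storedParameters.Length} parameters but received {parameters.Length}.",
            nameof(parameters));
    Array.Copy(parameters, storedParameters, parameters.Length);
}

// An optimizer step: read the parameters, apply an update, write them back.
double[] current = GetParameters();
double[] gradients = { 1.0, 1.0, 1.0 }; // placeholder gradient values
for (int k = 0; k < current.Length; k++)
    current[k] -= 0.1 * gradients[k];
SetParameters(current);

// A vector of the wrong length is rejected.
bool rejected = false;
try { SetParameters(new double[2]); }
catch (ArgumentException) { rejected = true; }
Console.WriteLine($"mismatched length rejected: {rejected}");
```

Validating the length up front keeps a mis-sized update from silently corrupting part of the parameter vector.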