SG++
sgpp::datadriven::ModelFittingBase Class Reference (abstract)

Base class for arbitrary machine learning models based on adaptive sparse grids.

#include <ModelFittingBase.hpp>

Inheritance diagram for sgpp::datadriven::ModelFittingBase (derived classes):
sgpp::datadriven::ModelFittingDensityEstimation, sgpp::datadriven::ModelFittingLeastSquares

Public Member Functions

virtual double evaluate (const DataVector &sample) const =0
 Evaluate the fitted model at a single data point.
 
virtual void evaluate (DataMatrix &samples, DataVector &results)=0
 Evaluate the fitted model on a set of data points.
 
virtual void fit (Dataset &dataset)=0
 Fit the grid to the dataset by determining the weights of an initial grid.
 
const FitterConfiguration & getFitterConfiguration () const
 Get the configuration of the fitter object.
 
const Grid & getGrid () const
 Get the underlying grid object for the current model.
 
const DataVector & getSurpluses () const
 Get the surpluses of the current grid.
 
 ModelFittingBase ()
 Default constructor.
 
 ModelFittingBase (const ModelFittingBase &rhs)=delete
 Copy constructor - we cannot deep copy all member variables yet.
 
 ModelFittingBase (ModelFittingBase &&rhs)=default
 Move constructor.
 
ModelFittingBase & operator= (const ModelFittingBase &rhs)=delete
 Copy assignment operator - we cannot deep copy all member variables yet.
 
ModelFittingBase & operator= (ModelFittingBase &&rhs)=default
 Move assignment operator.
 
virtual bool refine ()=0
 Improve accuracy of the model on the given training data by adaptive refinement of the grid.
 
virtual void update (Dataset &dataset)=0
 Train the grid of an existing model with new samples.
 
virtual ~ModelFittingBase ()=default
 Virtual destructor.
 

Protected Member Functions

Grid * buildGrid (const RegularGridConfiguration &gridConfig) const
 Factory member function that generates a grid from configuration.
 
SLESolver * buildSolver (const SLESolverConfiguration &config) const
 Factory member function to build the solver for the least squares regression problem according to the configuration.
 
void reconfigureSolver (SLESolver &solver, const SLESolverConfiguration &config) const
 Configure the solver based on the desired configuration.
 

Protected Attributes

DataVector alpha
 Hierarchical surpluses of the grid.
 
std::unique_ptr< FitterConfiguration > config
 Configuration object for the fitter.
 
Dataset * dataset
 Pointer to sgpp::datadriven::Dataset.
 
std::unique_ptr< Grid > grid
 The sparse grid that approximates the data.
 
std::unique_ptr< SLESolver > solver
 Solver for the learning problem.
 

Detailed Description

Base class for arbitrary machine learning models based on adaptive sparse grids.

A model tries to generalize high-dimensional training data by using sparse grids. A model is first trained on training data; its accuracy can then be improved by exploiting the adaptivity of sparse grids, and its underlying grid(s) can be retrained on other data. Once a model is trained, it can be evaluated on unseen data.
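The workflow described above (train, refine, retrain, evaluate) maps onto the pure virtual members fit(), refine(), update() and evaluate(). As a minimal sketch, assuming plain std::vector stand-ins for DataVector and Dataset, the hypothetical classes ToyModelBase and MeanModel mirror that interface shape; MeanModel is deliberately trivial and is not one of the SG++ models.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins: SG++ uses DataVector and Dataset; plain
// std::vector types are used so that the sketch compiles on its own.
using Vector = std::vector<double>;

struct ToyDataset {
  std::vector<Vector> samples;  // one entry per sample
  Vector targets;               // one target value per sample
};

// Mirrors the shape of the ModelFittingBase interface (not the real class).
class ToyModelBase {
 public:
  virtual ~ToyModelBase() = default;
  virtual void fit(ToyDataset& dataset) = 0;     // initial training
  virtual bool refine() = 0;                     // adaptive refinement
  virtual void update(ToyDataset& dataset) = 0;  // retrain on new samples
  virtual double evaluate(const Vector& sample) const = 0;
};

// Trivial model that predicts the mean target of the data it was trained on.
class MeanModel : public ToyModelBase {
 public:
  void fit(ToyDataset& dataset) override {
    assert(dataset.samples.size() == dataset.targets.size());
    double sum = 0.0;
    for (double t : dataset.targets) sum += t;
    mean_ = dataset.targets.empty()
                ? 0.0
                : sum / static_cast<double>(dataset.targets.size());
  }
  bool refine() override { return false; }  // a constant has nothing to refine
  void update(ToyDataset& dataset) override { fit(dataset); }
  double evaluate(const Vector& /*sample*/) const override { return mean_; }

 private:
  double mean_ = 0.0;
};
```

A caller drives any concrete model through a base-class reference: fit() once, optionally refine() (for example until it returns false or a refinement budget is used up), update() when new samples arrive, and evaluate() on unseen points.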

Constructor & Destructor Documentation

sgpp::datadriven::ModelFittingBase::ModelFittingBase ( )

Default constructor.

References alpha, dataset, grid, and solver.

sgpp::datadriven::ModelFittingBase::ModelFittingBase ( const ModelFittingBase & rhs)
delete

Copy constructor - we cannot deep copy all member variables yet.

Parameters
rhs: const reference to the model object to copy from.

sgpp::datadriven::ModelFittingBase::ModelFittingBase ( ModelFittingBase && rhs)
default

Move constructor.

Parameters
rhs: R-value reference to the model object to move from.

virtual sgpp::datadriven::ModelFittingBase::~ModelFittingBase ( )
virtual, default

Virtual destructor.
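The deleted copy operations and defaulted move operations are consistent with the member types listed under Protected Attributes: config, grid and solver are std::unique_ptr members, which can be moved but not copied. A minimal sketch with hypothetical stand-ins (ToyGrid, ToySolver, MoveOnlyModel, takeOver) shows the resulting move-only behaviour.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>
#include <utility>

// Hypothetical stand-ins for owned members such as Grid and SLESolver.
struct ToyGrid { int numPoints = 0; };
struct ToySolver { double eps = 1e-6; };

// Same special-member layout as ModelFittingBase: copying is deleted because
// the owned objects cannot be deep-copied, while the defaulted move
// operations transfer ownership of the std::unique_ptr members.
class MoveOnlyModel {
 public:
  MoveOnlyModel()
      : grid(std::make_unique<ToyGrid>()), solver(std::make_unique<ToySolver>()) {}
  MoveOnlyModel(const MoveOnlyModel&) = delete;
  MoveOnlyModel(MoveOnlyModel&&) = default;
  MoveOnlyModel& operator=(const MoveOnlyModel&) = delete;
  MoveOnlyModel& operator=(MoveOnlyModel&&) = default;
  virtual ~MoveOnlyModel() = default;

  std::unique_ptr<ToyGrid> grid;
  std::unique_ptr<ToySolver> solver;
};

static_assert(!std::is_copy_constructible<MoveOnlyModel>::value, "copying is deleted");
static_assert(std::is_move_constructible<MoveOnlyModel>::value, "moving is allowed");

// Moves a model into a new object; the source is left without grid/solver.
MoveOnlyModel takeOver(MoveOnlyModel& from) {
  MoveOnlyModel to = std::move(from);
  assert(from.grid == nullptr && from.solver == nullptr);  // moved-from unique_ptrs are null
  return to;
}
```

Passing such a model by value therefore requires std::move, and the moved-from object must not be used for fitting or evaluation afterwards.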

Member Function Documentation

Grid * sgpp::datadriven::ModelFittingBase::buildGrid ( const RegularGridConfiguration & gridConfig) const
protected

Factory member function that generates a grid from configuration.

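This reference does not state which grid buildGrid() produces beyond "from configuration". As an illustration, assuming a regular sparse grid of level n with linear basis functions and no boundary points, the grid consists of all level vectors with l_k >= 1 and l_1 + ... + l_d <= n + d - 1, each combined with all odd indices 1, 3, ..., 2^{l_k} - 1 per dimension. The sketch only enumerates those points; GridPoint, addPoints and buildRegularGrid are hypothetical names, and the real factory returns an SG++ Grid object instead of a point list.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One grid point with a level and an odd index per dimension
// (hypothetical layout, not SG++'s grid storage).
struct GridPoint {
  std::vector<int> level;
  std::vector<int> index;
};

// Recursively choose l_k >= 1 for each dimension such that
// l_1 + ... + l_d <= n + d - 1, then emit all odd index combinations.
void addPoints(std::size_t dim, int n, std::vector<int>& level, int levelSum,
               std::vector<GridPoint>& out) {
  if (level.size() == dim) {
    std::vector<int> index(dim, 1);
    while (true) {
      out.push_back({level, index});
      // odometer step over odd indices 1, 3, ..., 2^{l_k} - 1
      std::size_t k = 0;
      while (k < dim && (index[k] += 2) > (1 << level[k]) - 1) {
        index[k] = 1;
        ++k;
      }
      if (k == dim) return;  // every combination has been emitted
    }
  }
  // each dimension after this one still needs at least level 1
  const int remaining = static_cast<int>(dim - level.size()) - 1;
  const int maxSum = n + static_cast<int>(dim) - 1;
  for (int l = 1; levelSum + l + remaining <= maxSum; ++l) {
    level.push_back(l);
    addPoints(dim, n, level, levelSum + l, out);
    level.pop_back();
  }
}

// Regular sparse grid of level n in dim dimensions, without boundary points.
std::vector<GridPoint> buildRegularGrid(std::size_t dim, int n) {
  assert(dim >= 1 && n >= 1);
  std::vector<GridPoint> points;
  std::vector<int> level;
  addPoints(dim, n, level, 0, points);
  return points;
}
```

For example, buildRegularGrid(2, 3) yields the 17 points of the two-dimensional level-3 sparse grid (level vectors (1,1), (1,2), (2,1), (1,3), (2,2), (3,1)).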
SLESolver * sgpp::datadriven::ModelFittingBase::buildSolver ( const SLESolverConfiguration & config) const
protected

Factory member function to build the solver for the least squares regression problem according to the config.

Parameters
config: configuration for the solver object.

References sgpp::datadriven::CG, sgpp::solver::SLESolverConfiguration::eps_, sgpp::solver::SLESolverConfiguration::maxIterations_, and sgpp::solver::SLESolverConfiguration::type_.

Referenced by sgpp::datadriven::ModelFittingLeastSquares::ModelFittingLeastSquares().
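The factory pattern above can be sketched as a switch on the solver type that forwards the stopping criteria named in the reference list (type_, eps_, maxIterations_). All types here are hypothetical stand-ins, not the sgpp::solver classes; only CG is named in the references, so the second enum value is a placeholder. The documented member returns a raw SLESolver *, while the sketch returns std::unique_ptr to make ownership explicit.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>

// Hypothetical stand-ins mirroring the configuration fields named above.
enum class ToySolverType { CG, BiCGSTAB };

struct ToySolverConfiguration {
  ToySolverType type_ = ToySolverType::CG;
  double eps_ = 1e-10;
  std::size_t maxIterations_ = 100;
};

class ToySLESolver {
 public:
  ToySLESolver(std::size_t maxIterations, double eps)
      : maxIterations_(maxIterations), eps_(eps) {}
  virtual ~ToySLESolver() = default;
  std::size_t maxIterations() const { return maxIterations_; }
  double epsilon() const { return eps_; }

 private:
  std::size_t maxIterations_;
  double eps_;
};

class ToyCG : public ToySLESolver {
 public:
  using ToySLESolver::ToySLESolver;  // conjugate gradients (algorithm omitted)
};

class ToyBiCGSTAB : public ToySLESolver {
 public:
  using ToySLESolver::ToySLESolver;  // placeholder for a second solver type
};

// Factory: choose the solver class from config.type_ and forward the
// stopping criteria to its constructor.
std::unique_ptr<ToySLESolver> buildSolver(const ToySolverConfiguration& config) {
  assert(config.eps_ > 0.0 && config.maxIterations_ > 0);  // sanity check in this sketch
  switch (config.type_) {
    case ToySolverType::CG:
      return std::make_unique<ToyCG>(config.maxIterations_, config.eps_);
    case ToySolverType::BiCGSTAB:
      return std::make_unique<ToyBiCGSTAB>(config.maxIterations_, config.eps_);
  }
  throw std::invalid_argument("unsupported solver type");
}
```

For example, a configuration with type_ set to CG, eps_ = 1e-8 and maxIterations_ = 50 yields a CG stand-in whose epsilon() and maxIterations() report those values.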

virtual double sgpp::datadriven::ModelFittingBase::evaluate ( const DataVector & sample) const
pure virtual

Evaluate the fitted model at a single data point.

Parameters
sample: vector with the coordinates of the sample in all dimensions.
Returns
value of the fitted model at the sample.

Implemented in sgpp::datadriven::ModelFittingDensityEstimation, and sgpp::datadriven::ModelFittingLeastSquares.

Referenced by sgpp::datadriven::Scorer::test().
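Conceptually, evaluating a sparse-grid model at a point x sums the hierarchical surpluses alpha_j weighted by the basis functions: f(x) = sum_j alpha_j * phi_j(x). The sketch assumes the standard piecewise-linear hat basis without boundary points, phi_{l,i}(x) = max(1 - |2^l x - i|, 0), combined across dimensions by a tensor product. SG++ supports other bases and uses its own grid storage and evaluation operations, so GridPoint, hat and evaluateSparseGrid are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// One grid point with a level l_k and an odd index i_k per dimension
// (hypothetical layout; SG++ keeps grid points in its own storage).
struct GridPoint {
  std::vector<int> level;
  std::vector<int> index;
};

// 1-D hat function phi_{l,i}(x) = max(1 - |2^l x - i|, 0) on [0, 1].
double hat(int l, int i, double x) {
  const double v = 1.0 - std::fabs(std::ldexp(x, l) - i);
  return v > 0.0 ? v : 0.0;
}

// f(x) = sum_j alpha_j * prod_k phi_{l_jk, i_jk}(x_k)
double evaluateSparseGrid(const std::vector<GridPoint>& grid,
                          const std::vector<double>& alpha,
                          const std::vector<double>& sample) {
  assert(grid.size() == alpha.size());
  double result = 0.0;
  for (std::size_t j = 0; j < grid.size(); ++j) {
    double basis = 1.0;  // tensor product over all dimensions
    for (std::size_t k = 0; k < sample.size(); ++k) {
      basis *= hat(grid[j].level[k], grid[j].index[k], sample[k]);
    }
    result += alpha[j] * basis;
  }
  return result;
}
```

With the 1-D points (1,1), (2,1), (2,3) and surpluses 1.0, 0.5, 0.25, the model value at x = 0.25 is 0.5 * 1.0 + 1.0 * 0.5 = 1.0, because the hat function of (2,3) vanishes there.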

virtual void sgpp::datadriven::ModelFittingBase::evaluate ( DataMatrix & samples,
DataVector & results
)
pure virtual

Evaluate the fitted model on a set of data points.

Parameters
samples: matrix in which each row represents a sample and the columns contain the coordinates of that sample in all dimensions.
results: vector whose i-th entry will contain the evaluation of the i-th sample on the current model.

Implemented in sgpp::datadriven::ModelFittingDensityEstimation, and sgpp::datadriven::ModelFittingLeastSquares.
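The batch overload has the same meaning as calling the single-point overload once per row of samples. The sketch shows only that input/output contract, with a hypothetical row-major Matrix in place of DataMatrix and a callback in place of the single-point evaluate(). It resizes results itself; whether the real implementations expect a pre-sized results vector is not stated here, and they may use batched operations instead of a loop.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical row-major stand-in for DataMatrix:
// rows = samples, columns = dimensions.
struct Matrix {
  std::size_t rows = 0;
  std::size_t cols = 0;
  std::vector<double> data;  // rows * cols entries

  std::vector<double> row(std::size_t r) const {
    auto first = data.begin() + static_cast<std::ptrdiff_t>(r * cols);
    return std::vector<double>(first, first + static_cast<std::ptrdiff_t>(cols));
  }
};

using PointEvaluator = std::function<double(const std::vector<double>&)>;

// results[r] receives the model value for sample (row) r.
void evaluateBatch(const PointEvaluator& evaluatePoint, const Matrix& samples,
                   std::vector<double>& results) {
  assert(samples.data.size() == samples.rows * samples.cols);
  results.assign(samples.rows, 0.0);
  for (std::size_t r = 0; r < samples.rows; ++r) {
    results[r] = evaluatePoint(samples.row(r));
  }
}
```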

virtual void sgpp::datadriven::ModelFittingBase::fit ( Dataset & dataset)
pure virtual

Fit the grid to the dataset by determining the weights of an initial grid.

Parameters
dataset: the training dataset that is used to fit the model.

Implemented in sgpp::datadriven::ModelFittingDensityEstimation, and sgpp::datadriven::ModelFittingLeastSquares.

Referenced by sgpp::datadriven::Scorer::train().
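For a least-squares model, determining the weights amounts to solving a linear system, which fits the CG entry in the references of buildSolver(). The sketch assumes a generic ridge-regularized formulation, (B^T B + lambda I) alpha = B^T y with B(m, j) = phi_j(x_m), on a one-dimensional hat basis, and solves it with conjugate gradients. The system SG++ actually assembles (regularization operator, scaling by the number of samples) may differ, and Basis1D, applySystem and fitSurpluses are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;

struct Basis1D { int level; int index; };  // hat function phi_{l,i}

// phi_{l,i}(x) = max(1 - |2^l x - i|, 0)
double hat(const Basis1D& b, double x) {
  const double v = 1.0 - std::fabs(std::ldexp(x, b.level) - b.index);
  return v > 0.0 ? v : 0.0;
}

double dot(const Vec& u, const Vec& v) {
  double s = 0.0;
  for (std::size_t k = 0; k < u.size(); ++k) s += u[k] * v[k];
  return s;
}

// Matrix-free product (B^T B + lambda I) a with B(m, j) = phi_j(x_m).
Vec applySystem(const std::vector<Basis1D>& basis, const Vec& xs, double lambda, const Vec& a) {
  Vec out(basis.size(), 0.0);
  for (double x : xs) {
    double fx = 0.0;  // (B a)_m: current model value at sample x
    for (std::size_t j = 0; j < basis.size(); ++j) fx += a[j] * hat(basis[j], x);
    for (std::size_t j = 0; j < basis.size(); ++j) out[j] += hat(basis[j], x) * fx;
  }
  for (std::size_t j = 0; j < basis.size(); ++j) out[j] += lambda * a[j];
  return out;
}

// Conjugate gradients for the SPD system (B^T B + lambda I) alpha = B^T y.
Vec fitSurpluses(const std::vector<Basis1D>& basis, const Vec& xs, const Vec& ys,
                 double lambda, double eps, std::size_t maxIterations) {
  assert(xs.size() == ys.size());
  const std::size_t n = basis.size();
  Vec rhs(n, 0.0);  // B^T y
  for (std::size_t m = 0; m < xs.size(); ++m)
    for (std::size_t j = 0; j < n; ++j) rhs[j] += hat(basis[j], xs[m]) * ys[m];

  Vec alpha(n, 0.0);  // initial guess 0, so the residual starts as rhs
  Vec r = rhs;
  Vec p = r;
  double rr = dot(r, r);
  for (std::size_t it = 0; it < maxIterations && std::sqrt(rr) > eps; ++it) {
    const Vec ap = applySystem(basis, xs, lambda, p);
    const double step = rr / dot(p, ap);
    for (std::size_t j = 0; j < n; ++j) {
      alpha[j] += step * p[j];
      r[j] -= step * ap[j];
    }
    const double rrNew = dot(r, r);
    for (std::size_t j = 0; j < n; ++j) p[j] = r[j] + (rrNew / rr) * p[j];
    rr = rrNew;
  }
  return alpha;
}
```

A positive lambda keeps the system positive definite even when B does not have full column rank; with lambda = 0 the sketch relies on the samples determining every surplus.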

const FitterConfiguration & sgpp::datadriven::ModelFittingBase::getFitterConfiguration ( ) const

Get the configuration of the fitter object.

Returns
configuration of the fitter object

References config.

const Grid & sgpp::datadriven::ModelFittingBase::getGrid ( ) const

Get the underlying grid object for the current model.

Returns
the grid object.

References grid.

Referenced by sgpp::datadriven::Scorer::refine().

const DataVector & sgpp::datadriven::ModelFittingBase::getSurpluses ( ) const

Get the surpluses of the current grid.

Returns
vector of surpluses.

References alpha.
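getFitterConfiguration(), getGrid() and getSurpluses() return const references to the members config, grid and alpha (see their References entries). A minimal sketch of that accessor style, with hypothetical stand-ins ToyFitterConfiguration, ToyGrid and AccessorDemo; setSurpluses() is a hypothetical helper standing in for fit() or update() changing the surpluses.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-ins for FitterConfiguration and Grid.
struct ToyFitterConfiguration { double lambda = 0.0; };
struct ToyGrid { std::size_t size = 0; };

// Read-only accessors: callers receive const references to state that the
// model owns (through std::unique_ptr or by value) and cannot modify it.
class AccessorDemo {
 public:
  AccessorDemo()
      : config_(std::make_unique<ToyFitterConfiguration>()),
        grid_(std::make_unique<ToyGrid>()) {}

  const ToyFitterConfiguration& getFitterConfiguration() const { return *config_; }
  const ToyGrid& getGrid() const {
    assert(grid_ != nullptr);  // dereferencing requires an existing grid
    return *grid_;
  }
  const std::vector<double>& getSurpluses() const { return alpha_; }

  // Hypothetical helper standing in for fit()/update() changing alpha.
  void setSurpluses(std::vector<double> alpha) { alpha_ = std::move(alpha); }

 private:
  std::unique_ptr<ToyFitterConfiguration> config_;
  std::unique_ptr<ToyGrid> grid_;
  std::vector<double> alpha_;
};
```

Because the accessors return references rather than copies, a returned reference must not outlive the model, and a member that is replaced (for example a grid rebuilt by a new fit()) can invalidate references obtained earlier; querying again after training is the safe pattern.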

ModelFittingBase& sgpp::datadriven::ModelFittingBase::operator= ( const ModelFittingBase & rhs)
delete

Copy assignment operator - we cannot deep copy all member variables yet.

Parameters
rhs: const reference to the model object to copy from.
Returns
reference to this with updated values.

ModelFittingBase& sgpp::datadriven::ModelFittingBase::operator= ( ModelFittingBase && rhs)
default

Move assignment operator.

Parameters
rhs: R-value reference to the model object to move from.
Returns
reference to this with updated values.

void sgpp::datadriven::ModelFittingBase::reconfigureSolver ( SLESolver & solver,
const SLESolverConfiguration & config
) const
protected

Configure the solver based on the desired configuration.

Parameters
solver: the solver object to be modified.
config: the configuration to apply to the solver.

References sgpp::solver::SLESolverConfiguration::eps_, sgpp::solver::SLESolverConfiguration::maxIterations_, sgpp::solver::SGSolver::setEpsilon(), and sgpp::solver::SGSolver::setMaxIterations().

Referenced by sgpp::datadriven::ModelFittingLeastSquares::update().
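Unlike buildSolver(), reconfigureSolver() modifies an existing solver in place. The sketch assumes only what the reference list names: eps_ and maxIterations_ are forwarded through setEpsilon() and setMaxIterations(). ToySolverConfiguration and ToySolver are hypothetical stand-ins, not the sgpp::solver classes.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-ins; only the setter names setEpsilon() and
// setMaxIterations() are taken from the reference list above.
struct ToySolverConfiguration {
  double eps_ = 1e-10;
  std::size_t maxIterations_ = 100;
};

class ToySolver {
 public:
  void setEpsilon(double eps) { eps_ = eps; }
  void setMaxIterations(std::size_t n) { maxIterations_ = n; }
  double epsilon() const { return eps_; }
  std::size_t maxIterations() const { return maxIterations_; }

 private:
  double eps_ = 0.0;
  std::size_t maxIterations_ = 0;
};

// Update the stopping criteria of an existing solver instead of building
// a new one.
void reconfigureSolver(ToySolver& solver, const ToySolverConfiguration& config) {
  assert(config.eps_ >= 0.0);  // sanity check in this sketch
  solver.setEpsilon(config.eps_);
  solver.setMaxIterations(config.maxIterations_);
}
```

Reusing the solver object this way keeps its type unchanged; only the tolerance and the iteration limit follow the new configuration.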

virtual bool sgpp::datadriven::ModelFittingBase::refine ( )
pure virtual

Improve accuracy of the model on the given training data by adaptive refinement of the grid.

Returns
true if refinement was performed, else false.

Implemented in sgpp::datadriven::ModelFittingDensityEstimation, and sgpp::datadriven::ModelFittingLeastSquares.

Referenced by sgpp::datadriven::Scorer::refine().
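A common way to decide where to refine is a surplus-based criterion: points whose hierarchical surplus has a large magnitude get their children added. The sketch assumes this criterion on a one-dimensional grid without boundary points, refines at most one point per call, and caps the level with a hypothetical maxLevel parameter; the concrete models may use other criteria and refine several points per step. The return value follows the documented contract: true if refinement was performed, else false.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <initializer_list>
#include <set>
#include <utility>
#include <vector>

// 1-D hierarchical grid point (level l, odd index i); hypothetical layout.
using Point = std::pair<int, int>;

struct ToyGrid1D {
  std::vector<Point> points;
  std::vector<double> alpha;  // one hierarchical surplus per point
};

// Surplus-based refinement: among points below maxLevel that still miss a
// child, pick the one with the largest |alpha| and add its children
// (l+1, 2i-1) and (l+1, 2i+1) with surplus 0 (to be set by the next fit).
bool refine(ToyGrid1D& grid, int maxLevel) {
  assert(grid.points.size() == grid.alpha.size());
  const std::set<Point> present(grid.points.begin(), grid.points.end());
  const std::size_t none = grid.points.size();
  std::size_t best = none;
  for (std::size_t j = 0; j < grid.points.size(); ++j) {
    const int l = grid.points[j].first;
    const int i = grid.points[j].second;
    const bool missingChild = present.count({l + 1, 2 * i - 1}) == 0 ||
                              present.count({l + 1, 2 * i + 1}) == 0;
    if (l < maxLevel && missingChild &&
        (best == none || std::fabs(grid.alpha[j]) > std::fabs(grid.alpha[best]))) {
      best = j;
    }
  }
  if (best == none) return false;  // nothing left to refine
  const Point parent = grid.points[best];  // copy: push_back may reallocate
  for (const Point& child : {Point{parent.first + 1, 2 * parent.second - 1},
                             Point{parent.first + 1, 2 * parent.second + 1}}) {
    if (present.count(child) == 0) {
      grid.points.push_back(child);
      grid.alpha.push_back(0.0);
    }
  }
  return true;
}
```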

virtual void sgpp::datadriven::ModelFittingBase::update ( Dataset & dataset)
pure virtual

Train the grid of an existing model with new samples.

Parameters
dataset: the dataset containing the new samples used to retrain the model.

Implemented in sgpp::datadriven::ModelFittingDensityEstimation, and sgpp::datadriven::ModelFittingLeastSquares.

Member Data Documentation

Dataset* sgpp::datadriven::ModelFittingBase::dataset
protected

Pointer to sgpp::datadriven::Dataset.

The initial grid is fitted on the given data. Adaptive refinement is then performed on the very same data. The dataset used for refinement is overwritten once either fit() or update() introduces a new dataset.

Referenced by sgpp::datadriven::ModelFittingLeastSquares::fit(), ModelFittingBase(), and sgpp::datadriven::ModelFittingLeastSquares::update().
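The dataset bookkeeping described above can be sketched with hypothetical stand-ins (ToyDataset, ToyModel), assuming the raw pointer is non-owning, as its type suggests: fit() and update() store a pointer to their dataset, refine() works on whichever dataset was stored last, and the caller must keep that dataset alive. The grid-size counter is only a placeholder for real grid changes.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for sgpp::datadriven::Dataset.
struct ToyDataset {
  std::vector<double> targets;
};

class ToyModel {
 public:
  void fit(ToyDataset& dataset) {
    dataset_ = &dataset;
    gridSize_ = 1;  // fit() starts from an initial grid
  }
  void update(ToyDataset& dataset) {
    assert(gridSize_ > 0);  // update() retrains an existing model
    dataset_ = &dataset;    // the grid itself is kept
  }
  bool refine() {
    if (dataset_ == nullptr || dataset_->targets.empty()) return false;
    ++gridSize_;  // placeholder for "add grid points chosen using *dataset_"
    return true;
  }
  const ToyDataset* currentDataset() const { return dataset_; }
  std::size_t gridSize() const { return gridSize_; }

 private:
  ToyDataset* dataset_ = nullptr;  // not owned: the caller keeps the data alive
  std::size_t gridSize_ = 0;
};
```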

std::unique_ptr<SLESolver> sgpp::datadriven::ModelFittingBase::solver
protected

Solver for the learning problem.

