SG++-Doxygen-Documentation
sgpp::datadriven::RegressionLearner Class Reference

The RegressionLearner class solves a regression problem with a continuous target vector.

#include <RegressionLearner.hpp>

Classes

class  Solver
 

Public Member Functions

sgpp::base::Grid & getGrid ()
 getGrid
 
size_t getGridSize () const
 getGridSize
 
double getMSE (sgpp::base::DataMatrix &data, const sgpp::base::DataVector &y)
 getMSE
 
sgpp::base::DataVector getWeights () const
 getWeights
 
sgpp::base::DataVector predict (sgpp::base::DataMatrix &data)
 predict
 
 RegressionLearner (sgpp::base::RegularGridConfiguration gridConfig, sgpp::base::AdaptivityConfiguration adaptivityConfig, sgpp::solver::SLESolverConfiguration solverConfig, sgpp::solver::SLESolverConfiguration finalSolverConfig, datadriven::RegularizationConfiguration regularizationConfig, std::vector< std::vector< size_t >> terms)
 RegressionLearner.
 
 RegressionLearner (sgpp::base::RegularGridConfiguration gridConfig, sgpp::base::AdaptivityConfiguration adaptivityConfig, sgpp::solver::SLESolverConfiguration solverConfig, sgpp::solver::SLESolverConfiguration finalSolverConfig, datadriven::RegularizationConfiguration regularizationConfig)
 RegressionLearner.
 
void setWeights (sgpp::base::DataVector weights)
 setWeights
 
void train (sgpp::base::DataMatrix &trainDataset, sgpp::base::DataVector &classes)
 train fits a sparse grid regression model.
 

Detailed Description

The RegressionLearner class solves a regression problem with a continuous target vector.
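For orientation, sparse grid regression is commonly posed as a penalized least-squares problem. The formulation below is a sketch that assumes a simple ridge-type penalty; the penalty actually used is determined by regularizationConfig. The symbols are introduced here for illustration: M training points x_i with targets y_i, N sparse grid basis functions φ_j with weights α_j, a regularization parameter λ, and the M×N matrix B with B_ij = φ_j(x_i). Setting the gradient of the objective to zero gives the linear system in the second line.

```latex
\begin{aligned}
&f_N(\mathbf{x}) = \sum_{j=1}^{N} \alpha_j \varphi_j(\mathbf{x}), \qquad
\min_{\boldsymbol{\alpha}} \; \frac{1}{M} \sum_{i=1}^{M} \bigl( f_N(\mathbf{x}_i) - y_i \bigr)^2
  + \lambda \lVert \boldsymbol{\alpha} \rVert_2^2 \\
&\bigl( B^{\top} B + \lambda M I \bigr)\, \boldsymbol{\alpha} = B^{\top} \mathbf{y}
\end{aligned}
```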

Constructor & Destructor Documentation

◆ RegressionLearner() [1/2]

sgpp::datadriven::RegressionLearner::RegressionLearner ( sgpp::base::RegularGridConfiguration  gridConfig,
sgpp::base::AdaptivityConfiguration  adaptivityConfig,
sgpp::solver::SLESolverConfiguration  solverConfig,
sgpp::solver::SLESolverConfiguration  finalSolverConfig,
datadriven::RegularizationConfiguration  regularizationConfig,
std::vector< std::vector< size_t >>  terms 
)

RegressionLearner.

Parameters
gridConfig is the configuration of the regular grid
adaptivityConfig is the configuration of the adaptive refinement
solverConfig is the configuration of the solver used during each adaptivity step
finalSolverConfig is the configuration of the solver used to build the final model
regularizationConfig is the configuration of the regularization
terms is a vector that contains all desired interaction terms. For example, to include grid points that model an interaction between the first and the second predictor, include the vector [1,2] in terms.
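The terms argument can also be assembled programmatically. The sketch below builds one interaction term for every pair of predictors. pairwiseTerms is a hypothetical helper, not part of SG++, and its default firstIndex = 1 only mirrors the [1,2] example in the parameter description; check which index base your SG++ version expects before relying on it.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of SG++): one interaction term per unordered
// pair of predictors, in the std::vector<std::vector<size_t>> shape that the
// `terms` constructor argument uses.
// firstIndex = 1 follows the "[1,2] = first and second predictor" example;
// pass 0 if your build numbers dimensions from zero (an assumption to verify).
std::vector<std::vector<std::size_t>> pairwiseTerms(std::size_t numPredictors,
                                                    std::size_t firstIndex = 1) {
  std::vector<std::vector<std::size_t>> terms;
  for (std::size_t i = 0; i < numPredictors; ++i) {
    for (std::size_t j = i + 1; j < numPredictors; ++j) {
      terms.push_back({firstIndex + i, firstIndex + j});
    }
  }
  return terms;
}
```

For three predictors, pairwiseTerms(3) yields [1,2], [1,3] and [2,3], which could then be passed as the last constructor argument.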

◆ RegressionLearner() [2/2]

sgpp::datadriven::RegressionLearner::RegressionLearner ( sgpp::base::RegularGridConfiguration  gridConfig,
sgpp::base::AdaptivityConfiguration  adaptivityConfig,
sgpp::solver::SLESolverConfiguration  solverConfig,
sgpp::solver::SLESolverConfiguration  finalSolverConfig,
datadriven::RegularizationConfiguration  regularizationConfig 
)

RegressionLearner.

Parameters
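gridConfig is the configuration of the regular grid
adaptivityConfig is the configuration of the adaptive refinement
solverConfig is the configuration of the solver used during each adaptivity step
finalSolverConfig is the configuration of the solver used to build the final model
regularizationConfig is the configuration of the regularization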

Member Function Documentation

◆ getGrid()

base::Grid & sgpp::datadriven::RegressionLearner::getGrid ( )

getGrid

Returns
the grid

◆ getGridSize()

size_t sgpp::datadriven::RegressionLearner::getGridSize ( ) const

getGridSize

Returns
the size of the grid

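◆ getMSE()

double sgpp::datadriven::RegressionLearner::getMSE ( sgpp::base::DataMatrix & data,
const sgpp::base::DataVector & y 
)

getMSE

Parameters
data are the observations
y is the (continuous) target
Returns
the mean squared error of the predictions for data with respect to y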
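Assuming getMSE follows the usual definition of the mean squared error, MSE = (1/n) Σ_i (ŷ_i − y_i)², it compares the predictions for data with the targets y as sketched below. meanSquaredError is a hypothetical stand-in that works on std::vector instead of the SG++ DataMatrix and DataVector types; it is not the library's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in (not SG++): mean squared error between predictions,
// e.g. the output of predict(data), and the observed targets y.
double meanSquaredError(const std::vector<double>& predictions,
                        const std::vector<double>& y) {
  // Both vectors must describe the same, non-empty set of observations.
  assert(predictions.size() == y.size() && !y.empty());
  double sum = 0.0;
  for (std::size_t i = 0; i < y.size(); ++i) {
    const double residual = predictions[i] - y[i];
    sum += residual * residual;
  }
  return sum / static_cast<double>(y.size());
}
```

For example, predictions {0, 0} against targets {1, 3} give (1 + 9) / 2 = 5.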

◆ getWeights()

base::DataVector sgpp::datadriven::RegressionLearner::getWeights ( ) const

getWeights

Returns
the weights

◆ predict()

base::DataVector sgpp::datadriven::RegressionLearner::predict ( sgpp::base::DataMatrix & data)

predict

Parameters
data are the observations
Returns
the predicted target for matrix data

References sgpp::op_factory::createOperationMultipleEval(), and sgpp::base::DataMatrix::getNrows().

Referenced by getMSE().
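Since predict uses createOperationMultipleEval, each prediction is presumably the sparse grid function f(x) = Σ_j α_j φ_j(x) evaluated at one row of data. The sketch below performs that evaluation by brute force, assuming the piecewise linear (hat) basis φ_{l,i}(x) = max(0, 1 − |2^l x − i|) on the unit hypercube, combined across dimensions as a tensor product. The basis is an assumption, since the actual one depends on the grid type in gridConfig; GridPoint and predictRows are hypothetical names, not SG++ types.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical grid point: level l_k and odd index i_k for each dimension k.
struct GridPoint {
  std::vector<unsigned> level;
  std::vector<unsigned> index;
};

// One-dimensional hat function phi_{l,i}(x) = max(0, 1 - |2^l * x - i|).
double hat(unsigned level, unsigned index, double x) {
  const double scaled = std::ldexp(x, static_cast<int>(level)) - index;
  return std::fmax(0.0, 1.0 - std::fabs(scaled));
}

// Evaluates f(x) = sum_j alpha_j * prod_k phi_{l_jk, i_jk}(x_k) for every row.
std::vector<double> predictRows(const std::vector<GridPoint>& grid,
                                const std::vector<double>& alpha,
                                const std::vector<std::vector<double>>& data) {
  assert(grid.size() == alpha.size());  // one weight per grid point
  std::vector<double> result(data.size(), 0.0);
  for (std::size_t r = 0; r < data.size(); ++r) {
    for (std::size_t j = 0; j < grid.size(); ++j) {
      double basis = 1.0;
      for (std::size_t k = 0; k < data[r].size(); ++k) {
        basis *= hat(grid[j].level[k], grid[j].index[k], data[r][k]);
      }
      result[r] += alpha[j] * basis;
    }
  }
  return result;
}
```

For a one-dimensional grid with points (l, i) = (1, 1), (2, 1), (2, 3) and weights 2, 0.5, 0.25, predictRows returns 2 at x = 0.5 and 1.5 at x = 0.25. A brute-force loop like this costs O(rows × grid points × dimensions); that cost is what dedicated multiple-evaluation operations exist to reduce.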

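◆ setWeights()

void sgpp::datadriven::RegressionLearner::setWeights ( sgpp::base::DataVector weights)

setWeights

Parameters
weights are the weights to assign to the model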

◆ train()

void sgpp::datadriven::RegressionLearner::train ( sgpp::base::DataMatrix & trainDataset,
sgpp::base::DataVector & classes 
)

train fits a sparse grid regression model.

Parameters
trainDataset is the design matrix
classes is the (continuous) target

References sgpp::datadriven::RegressionLearner::Solver::cg, sgpp::op_factory::createOperationMultipleEval(), sgpp::base::DataMatrix::getNrows(), sgpp::base::DataVector::getSize(), and sgpp::base::AdaptivityConfiguration::numRefinements_.
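train references Solver::cg, so the linear system behind the fit is presumably solved with conjugate gradients. The sketch below assumes the ridge-type system (BᵀB + λ M I) α = Bᵀy, where B_ij = φ_j(x_i) for M training points and N basis functions. It stores B densely and omits the adaptivity loop driven by numRefinements_ for brevity. Vec, Mat, applySystem and fitCG are hypothetical names, not SG++ functions.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>;  // row-major, B[i][j] = phi_j(x_i)

// Applies A = B^T B + lambda * M * I to v without forming B^T B explicitly.
Vec applySystem(const Mat& B, double lambda, const Vec& v) {
  const std::size_t M = B.size(), N = v.size();
  Vec Bv(M, 0.0), out(N, 0.0);
  for (std::size_t i = 0; i < M; ++i)
    for (std::size_t j = 0; j < N; ++j) Bv[i] += B[i][j] * v[j];
  for (std::size_t j = 0; j < N; ++j) {
    for (std::size_t i = 0; i < M; ++i) out[j] += B[i][j] * Bv[i];
    out[j] += lambda * static_cast<double>(M) * v[j];
  }
  return out;
}

double dot(const Vec& a, const Vec& b) {
  double s = 0.0;
  for (std::size_t k = 0; k < a.size(); ++k) s += a[k] * b[k];
  return s;
}

// Conjugate gradients for (B^T B + lambda*M*I) alpha = B^T y, starting at 0.
Vec fitCG(const Mat& B, const Vec& y, double lambda, std::size_t maxIter,
          double tol) {
  assert(!B.empty() && B.size() == y.size());
  const std::size_t N = B[0].size();
  Vec alpha(N, 0.0), r(N, 0.0);
  // Initial residual r = B^T y - A * 0 = B^T y.
  for (std::size_t j = 0; j < N; ++j)
    for (std::size_t i = 0; i < B.size(); ++i) r[j] += B[i][j] * y[i];
  Vec p = r;
  double rr = dot(r, r);
  for (std::size_t it = 0; it < maxIter && std::sqrt(rr) > tol; ++it) {
    const Vec Ap = applySystem(B, lambda, p);
    const double step = rr / dot(p, Ap);
    for (std::size_t j = 0; j < N; ++j) {
      alpha[j] += step * p[j];
      r[j] -= step * Ap[j];
    }
    const double rrNew = dot(r, r);
    for (std::size_t j = 0; j < N; ++j) p[j] = r[j] + (rrNew / rr) * p[j];
    rr = rrNew;
  }
  return alpha;
}
```

Applying the system matrix as two products with B, rather than assembling BᵀB, keeps memory proportional to the size of B. For λ > 0 the matrix is symmetric positive definite, which is the condition CG requires.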

