SG++-Doxygen-Documentation
|
The RegressionLearner class solves a regression problem with a continuous target vector. More...
#include <RegressionLearner.hpp>
Classes | |
class | Solver |
The RegressionLearner class solves a regression problem with a continuous target vector.
sgpp::datadriven::RegressionLearner::RegressionLearner | ( | sgpp::base::RegularGridConfiguration | gridConfig, |
sgpp::base::AdaptivityConfiguration | adaptivityConfig, | ||
sgpp::solver::SLESolverConfiguration | solverConfig, | ||
sgpp::solver::SLESolverConfiguration | finalSolverConfig, | ||
datadriven::RegularizationConfiguration | regularizationConfig, | ||
std::vector< std::vector< size_t >> | terms | ||
) |
gridConfig | is the configuration of the regular sparse grid |
adaptivityConfig | is the configuration of the adaptive refinement |
solverConfig | is the solver used during each adaptivity step |
finalSolverConfig | is the solver used to build the final model |
regularizationConfig | is the configuration of the regularization (type and parameters) |
terms | is a vector that contains all desired interaction terms. For example, if we want to include grid points that model an interaction between the first and the second predictor, we would include the vector [1,2] in terms. |
sgpp::datadriven::RegressionLearner::RegressionLearner | ( | sgpp::base::RegularGridConfiguration | gridConfig, |
sgpp::base::AdaptivityConfiguration | adaptivityConfig, | ||
sgpp::solver::SLESolverConfiguration | solverConfig, | ||
sgpp::solver::SLESolverConfiguration | finalSolverConfig, | ||
datadriven::RegularizationConfiguration | regularizationConfig | ||
) |
gridConfig | is the configuration of the regular sparse grid |
adaptivityConfig | is the configuration of the adaptive refinement |
solverConfig | is the solver used during each adaptivity step |
finalSolverConfig | is the solver used to build the final model |
regularizationConfig | is the configuration of the regularization (type and parameters) |
base::Grid & sgpp::datadriven::RegressionLearner::getGrid | ( | ) |
getGrid returns a reference to the grid used by the learner.
size_t sgpp::datadriven::RegressionLearner::getGridSize | ( | ) | const |
getGridSize returns the size of the grid used by the learner.
double sgpp::datadriven::RegressionLearner::getMSE | ( | sgpp::base::DataMatrix & | data, |
const sgpp::base::DataVector & | y | ||
) |
getMSE computes the mean squared error of the model's predictions for data with respect to the target y.
data | is the design matrix |
y | is the target |
References chess::b, sgpp::datadriven::RegressionLearner::Solver::cg, sgpp::op_factory::createOperationMultipleEval(), python.statsfileInfo::data, sgpp::datadriven::RegressionLearner::Solver::fista, sgpp::datadriven::RegressionLearner::Solver::getL(), sgpp::base::DataVector::getSize(), sgpp::solver::SLESolverConfiguration::maxIterations_, sgpp::datadriven::RegressionLearner::Solver::none, sgpp::base::AdaptivityConfiguration::noPoints_, predict(), sgpp::base::DataVector::resizeZero(), sgpp::datadriven::RegressionLearner::Solver::solveCG(), sgpp::datadriven::RegressionLearner::Solver::solveFista(), sgpp::solver::SLESolverConfiguration::threshold_, sgpp::base::AdaptivityConfiguration::threshold_, and sgpp::datadriven::RegressionLearner::Solver::type.
Referenced by setWeights().
base::DataVector sgpp::datadriven::RegressionLearner::getWeights | ( | ) | const |
getWeights returns a copy of the current weights of the model.
base::DataVector sgpp::datadriven::RegressionLearner::predict | ( | sgpp::base::DataMatrix & | data | ) |
predict evaluates the fitted model at the given observations.
data | is the matrix of observations (one observation per row) |
References sgpp::op_factory::createOperationMultipleEval(), and sgpp::base::DataMatrix::getNrows().
Referenced by getMSE().
void sgpp::datadriven::RegressionLearner::setWeights | ( | sgpp::base::DataVector | weights | ) |
setWeights replaces the current weights of the model.
weights | are the new weights. |
References sgpp::datadriven::CG, sgpp::op_factory::createOperationDiagonal(), sgpp::op_factory::createOperationIdentity(), sgpp::op_factory::createOperationLaplace(), sgpp::datadriven::Diagonal, sgpp::base::GeneralGridConfiguration::dim_, sgpp::datadriven::ElasticNet, sgpp::solver::SLESolverConfiguration::eps_, sgpp::datadriven::RegularizationConfiguration::exponentBase_, getMSE(), sgpp::base::DataVector::getSize(), sgpp::datadriven::GroupLasso, sgpp::datadriven::Identity, sgpp::datadriven::RegularizationConfiguration::l1Ratio_, lambda, sgpp::datadriven::RegularizationConfiguration::lambda_, sgpp::datadriven::Laplace, sgpp::datadriven::Lasso, sgpp::base::GeneralGridConfiguration::level_, sgpp::solver::SLESolverConfiguration::maxIterations_, sgpp::base::DataVector::setAll(), sgpp::base::DataVector::sqr(), sgpp::base::DataVector::sub(), sgpp::base::DataVector::sum(), sgpp::base::GeneralGridConfiguration::t_, sgpp::datadriven::RegularizationConfiguration::type_, sgpp::solver::SLESolverConfiguration::type_, and sgpp::base::GeneralGridConfiguration::type_.
void sgpp::datadriven::RegressionLearner::train | ( | sgpp::base::DataMatrix & | trainDataset, |
sgpp::base::DataVector & | classes | ||
) |
train fits a sparse grid regression model.
trainDataset | is the design matrix |
classes | is the (continuous) target |
References sgpp::datadriven::RegressionLearner::Solver::cg, sgpp::op_factory::createOperationMultipleEval(), sgpp::base::DataMatrix::getNrows(), sgpp::base::DataVector::getSize(), and sgpp::base::AdaptivityConfiguration::numRefinements_.