SG++
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
CrossValidationExample.cpp

This example can be found under datadriven/examples/CrossValidationExample.cpp.

/* Copyright (C) 2008-today The SG++ project
* This file is part of the SG++ project. For conditions of distribution and
* use, please see the copyright notice provided with SG++ or at
* sgpp.sparsegrids.org
*
* CrossValidationExample.cpp
*
* Created on: 01.06.2016
* Author: Michael Lettrich
*/
#include <iostream>
#include <memory>
#include <string>
// using sgpp::datadriven::SimpleSplittingScorer;
/**
 * Example driver: reads a dataset from the file given on the command line,
 * configures a ModelFittingLeastSquares model on a level-2 ModLinear grid and
 * prints the score and standard deviation of a 5-fold cross validation using
 * an MSE metric and a RandomShufflingFunctor seeded with 42.
 *
 * @param argc number of command line arguments; must be exactly 2
 * @param argv argv[1] is the path of the input data file
 * @return 0 on success, -1 if the path argument is missing or extra arguments
 *         were supplied
 */
int main(int argc, char **argv) {
  // input: exactly one argument (the dataset path) is expected
  if (argc != 2) {
    // Report usage errors on the error stream and leave main directly;
    // this yields the same exit status as exit(-1) without needing <cstdlib>.
    std::cerr << "No or bad path given, aborting" << std::endl;
    return -1;
  }
  std::string path(argv[1]);

  auto dataSource = std::unique_ptr<DataSource>(DataSourceBuilder().withPath(path).assemble());
  std::cout << "reading input file: " << path << std::endl;
  auto dataset = std::unique_ptr<Dataset>(dataSource->getNextSamples());

  // regression
  auto config = DataMiningConfigurationLeastSquares();

  // set grid dim: the grid dimension follows the dimension of the loaded dataset
  auto gridConfig = config.getGridConfig();
  gridConfig.level_ = 2;
  gridConfig.type_ = GridType::ModLinear;
  gridConfig.dim_ = dataset->getDimension();
  config.setGridConfig(gridConfig);
  // NOTE(review): 10e-1 evaluates to 1.0; if 0.1 was intended this should be
  // written as 1e-1 -- confirm before changing the value.
  config.setLambda(10e-1);

  std::cout << "starting 5 fold cross validation with seed 42" << std::endl;
  auto model = std::make_unique<ModelFittingLeastSquares>(config);
  auto metric = std::make_unique<MSE>();
  auto shuffling = std::make_unique<RandomShufflingFunctor>();

  // Initialised so that a defined value is printed even if calculateScore
  // should not write to the out-parameter.
  double stdDeviation = 0.0;

  // The raw pointers are released to CrossValidation; presumably it takes
  // ownership of the metric and the shuffling functor -- TODO confirm.
  CrossValidation crossValidation(metric.release(), shuffling.release(), 42);
  const double score = crossValidation.calculateScore(*model, *dataset, 5, &stdDeviation);

  std::cout << "Score = " << score << " with stdDeviation " << stdDeviation << std::endl;
  return 0;
}