SG++-Doxygen-Documentation
#include <sgpp/datadriven/datamining/builder/DataSourceBuilder.hpp>
#include <sgpp/datadriven/datamining/modules/dataSource/DataSource.hpp>
#include <sgpp/datadriven/datamining/modules/dataSource/DataSourceCrossValidation.hpp>
#include <sgpp/datadriven/datamining/modules/fitting/FitterConfigurationLeastSquares.hpp>
#include <sgpp/datadriven/datamining/modules/fitting/ModelFittingLeastSquares.hpp>
#include <sgpp/datadriven/datamining/modules/scoring/MSE.hpp>
#include <sgpp/datadriven/datamining/base/SparseGridMinerCrossValidation.hpp>
#include <sgpp/globaldef.hpp>
#include <iostream>
#include <memory>
#include <string>
Functions
int main (int argc, char **argv)

Function Documentation
int main (int argc, char **argv)
Use an immediately invoked lambda expression to get the path to the input file from the command-line arguments.
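The immediately invoked lambda idea can be sketched as follows. This is a minimal illustration, not the example's own code. `runExample` is a hypothetical stand-in for `main`, and it returns an error code instead of terminating the process. The lambda lets `path` be declared `const` even though its value depends on a check of `argc`.

```cpp
#include <cassert>
#include <iostream>
#include <string>

// Hypothetical stand-in for main(); illustrates the pattern only.
int runExample(int argc, char** argv) {
  // The lambda is defined and called in one expression, so `path` can be
  // const although its value depends on a branch over the arguments.
  const std::string path = [argc, argv]() -> std::string {
    if (argc != 2) {
      return std::string{};  // empty string signals "no usable path"
    }
    return std::string{argv[1]};
  }();  // the trailing () invokes the lambda immediately

  if (path.empty()) {
    std::cerr << "Expected exactly one argument: the path to the dataset file\n";
    return 1;
  }
  std::cout << "Reading dataset from " << path << '\n';
  return 0;
}
```

Calling `return runExample(argc, argv);` from `main` gives the usual command-line behavior.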
In order to read a dataset from disk, we need an instance of sgpp::datadriven::DataSource, which is constructed using the builder pattern. Since we only want to read a dataset from disk and use all the samples it provides, we only pass in the path. Everything else is handled by default values and by automatic detection of the file type from its extension.
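The builder idea can be sketched with hypothetical toy types (`ToyDataSourceBuilder`, `ToySourceConfig`, `FileType`). They are not the SG++ `DataSourceBuilder` API, whose options and supported file types may differ. Only the path is set explicitly; the batch size keeps a default meaning "use all samples", and the file type is inferred from the extension.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>

// Hypothetical toy types, NOT the SG++ API.
enum class FileType { NONE, ARFF, CSV };

struct ToySourceConfig {
  std::string filePath;
  FileType fileType = FileType::NONE;  // NONE: detect from the extension
  std::size_t batchSize = 0;           // 0: use all samples
};

class ToyDataSourceBuilder {
 public:
  ToyDataSourceBuilder& withPath(const std::string& path) {
    config_.filePath = path;
    return *this;  // returning *this allows calls to be chained
  }

  ToyDataSourceBuilder& withBatchSize(std::size_t size) {
    config_.batchSize = size;
    return *this;
  }

  // Fills in everything that was not set explicitly.
  ToySourceConfig assemble() const {
    assert(!config_.filePath.empty() && "withPath() must be called first");
    ToySourceConfig result = config_;
    if (result.fileType == FileType::NONE) {
      result.fileType = detectType(result.filePath);
    }
    return result;
  }

 private:
  static bool endsWith(const std::string& s, const std::string& suffix) {
    return s.size() >= suffix.size() &&
           s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
  }

  static FileType detectType(const std::string& path) {
    if (endsWith(path, ".arff")) return FileType::ARFF;
    if (endsWith(path, ".csv")) return FileType::CSV;
    throw std::invalid_argument("cannot detect file type of " + path);
  }

  ToySourceConfig config_;
};
```

For example, `ToyDataSourceBuilder().withPath("train.arff").assemble()` yields a configuration with `FileType::ARFF` and `batchSize == 0`.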
Once we have a data source, we can read the contents of the stored dataset.
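To make "reading the contents" concrete, a dataset can be viewed as a matrix of samples plus one target value per sample. The sketch below uses a hypothetical `ToyDataset` (not `sgpp::datadriven::Dataset`, although it mirrors the idea behind `getDimension()`). It assumes plain comma-separated lines whose last column is the target.

```cpp
#include <cassert>
#include <cstddef>
#include <istream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical toy dataset, NOT the SG++ class.
struct ToyDataset {
  std::vector<std::vector<double>> samples;  // one row per data point
  std::vector<double> targets;               // one target value per row

  std::size_t getDimension() const {
    return samples.empty() ? 0 : samples.front().size();
  }
  std::size_t getNumberInstances() const { return samples.size(); }
};

// Assumes comma-separated values with the target in the last column.
ToyDataset readCsv(std::istream& in) {
  ToyDataset data;
  std::string line;
  while (std::getline(in, line)) {
    if (line.empty()) continue;  // skip blank lines
    std::stringstream row(line);
    std::vector<double> values;
    std::string cell;
    while (std::getline(row, cell, ',')) {
      values.push_back(std::stod(cell));
    }
    assert(!values.empty());
    data.targets.push_back(values.back());  // last column is the target
    values.pop_back();
    data.samples.push_back(values);
  }
  return data;
}
```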
Next, we want to perform least squares regression. First, we need to set up our fitter using a configuration structure.
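As background, here is a common formulation of least squares regression on a sparse grid. It assumes M training pairs (x_i, y_i), N grid basis functions φ_j, a regularization parameter λ, and the identity as regularization operator. The fitter's actual configuration may use a different operator or scaling.

```latex
f_N(x) = \sum_{j=1}^{N} \alpha_j \varphi_j(x), \qquad
\min_{\alpha \in \mathbb{R}^N} \; \frac{1}{M} \sum_{i=1}^{M} \bigl( f_N(x_i) - y_i \bigr)^2 + \lambda \lVert \alpha \rVert_2^2

\text{With } B \in \mathbb{R}^{N \times M},\; B_{ji} = \varphi_j(x_i), \text{ setting the gradient to zero gives}
\bigl( B B^\top + M \lambda I \bigr)\, \alpha = B\, y
```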
We start from the default parameters provided for least squares regression.
Any parameter that should deviate from its default value is then adjusted.
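The "start from defaults, then adapt" approach can be sketched with a hypothetical `ToyFitterConfig`. It is not SG++'s `FitterConfigurationLeastSquares`, and both the field names and the override values are illustrative only.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical configuration struct; every member has a default.
struct ToyFitterConfig {
  std::size_t gridLevel = 2;              // refinement level of the grid
  double lambda = 1e-6;                   // regularization strength
  double solverTolerance = 1e-10;         // stopping tolerance of the solver
  std::size_t maxSolverIterations = 100;  // iteration cap of the solver
};

// Start from the defaults, then override only what differs.
ToyFitterConfig makeRegressionConfig() {
  ToyFitterConfig config;  // all members hold their defaults
  config.gridLevel = 4;    // illustrative override
  config.lambda = 1e-3;    // illustrative override
  return config;
}
```

Because every member has a default initializer, the override function only has to mention the parameters that actually change.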
Based on this configuration, we can then create a fitter object.
We want to perform 5-fold cross-validation on our model. To assess the quality of the regression, we use the mean squared error (MSE) as the error metric. To ensure that training and testing data are not drawn from an ordered sequence of samples, we randomly permute the samples before assigning them to the training and testing sets.
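The error metric and the permuted fold assignment can be sketched with the standard library. `meanSquaredError` and `makeShuffledFolds` are hypothetical helpers. The fixed `seed` is an assumption added for reproducibility. The sketch does not reproduce SG++'s internal implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Mean squared error: average of the squared differences.
double meanSquaredError(const std::vector<double>& predicted,
                        const std::vector<double>& actual) {
  assert(predicted.size() == actual.size() && !actual.empty());
  double sum = 0.0;
  for (std::size_t i = 0; i < actual.size(); ++i) {
    const double diff = predicted[i] - actual[i];
    sum += diff * diff;
  }
  return sum / static_cast<double>(actual.size());
}

// Shuffles the indices 0..n-1, then cuts them into `numFolds` contiguous
// chunks of nearly equal size.
std::vector<std::vector<std::size_t>> makeShuffledFolds(std::size_t n,
                                                        std::size_t numFolds,
                                                        unsigned seed) {
  assert(numFolds > 0 && numFolds <= n);
  std::vector<std::size_t> indices(n);
  std::iota(indices.begin(), indices.end(), std::size_t{0});
  std::mt19937 rng(seed);
  std::shuffle(indices.begin(), indices.end(), rng);

  std::vector<std::vector<std::size_t>> folds(numFolds);
  for (std::size_t k = 0; k < numFolds; ++k) {
    const auto begin = static_cast<std::ptrdiff_t>(k * n / numFolds);
    const auto end = static_cast<std::ptrdiff_t>((k + 1) * n / numFolds);
    folds[k].assign(indices.begin() + begin, indices.begin() + end);
  }
  return folds;
}
```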
We create a sparse grid miner that performs cross-validation. By default, the number of folds is 5.
Here the actual learning process is launched. The miner performs k-fold cross-validation and prints the mean score as well as its standard deviation.
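The miner's learning step can be approximated by the following self-contained sketch. To avoid depending on SG++, a one-dimensional ordinary least squares line (`fitLine`) stands in for the sparse grid fitter. The hypothetical `crossValidate` reports the mean MSE over the folds together with a population standard deviation. The miner may use a different convention.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Toy model y = a + b*x; a stand-in for the sparse grid fitter.
struct Line {
  double a = 0.0;
  double b = 0.0;
  double operator()(double x) const { return a + b * x; }
};

Line fitLine(const std::vector<double>& x, const std::vector<double>& y) {
  const double n = static_cast<double>(x.size());
  const double meanX = std::accumulate(x.begin(), x.end(), 0.0) / n;
  const double meanY = std::accumulate(y.begin(), y.end(), 0.0) / n;
  double sxy = 0.0, sxx = 0.0;
  for (std::size_t i = 0; i < x.size(); ++i) {
    sxy += (x[i] - meanX) * (y[i] - meanY);
    sxx += (x[i] - meanX) * (x[i] - meanX);
  }
  assert(sxx > 0.0 && "training inputs must not all be equal");
  Line line;
  line.b = sxy / sxx;
  line.a = meanY - line.b * meanX;
  return line;
}

struct CvResult {
  double meanScore;  // mean MSE over the folds
  double stdDev;     // population standard deviation over the folds
};

// Shuffle, split into k folds, train on k-1 folds, score the held-out fold.
CvResult crossValidate(const std::vector<double>& x, const std::vector<double>& y,
                       std::size_t k, unsigned seed) {
  const std::size_t n = x.size();
  assert(y.size() == n && k >= 2 && k <= n);
  std::vector<std::size_t> perm(n);
  std::iota(perm.begin(), perm.end(), std::size_t{0});
  std::mt19937 rng(seed);
  std::shuffle(perm.begin(), perm.end(), rng);

  std::vector<double> scores;
  for (std::size_t fold = 0; fold < k; ++fold) {
    const std::size_t begin = fold * n / k, end = (fold + 1) * n / k;
    std::vector<double> trainX, trainY, testX, testY;
    for (std::size_t pos = 0; pos < n; ++pos) {
      const std::size_t i = perm[pos];
      const bool inTest = pos >= begin && pos < end;
      (inTest ? testX : trainX).push_back(x[i]);
      (inTest ? testY : trainY).push_back(y[i]);
    }
    const Line model = fitLine(trainX, trainY);
    double sq = 0.0;
    for (std::size_t j = 0; j < testX.size(); ++j) {
      const double d = model(testX[j]) - testY[j];
      sq += d * d;
    }
    scores.push_back(sq / static_cast<double>(testX.size()));
  }
  const double mean =
      std::accumulate(scores.begin(), scores.end(), 0.0) / static_cast<double>(k);
  double var = 0.0;
  for (double s : scores) var += (s - mean) * (s - mean);
  return {mean, std::sqrt(var / static_cast<double>(k))};
}
```

On exactly linear data, every fold's MSE is zero up to rounding, so both the mean score and its standard deviation vanish. On noisy data, the standard deviation indicates how much the score depends on the particular split.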
References sgpp::datadriven::DataSourceBuilder::crossValidationAssemble(), dataset, sgpp::datadriven::Dataset::getDimension(), sgpp::datadriven::SparseGridMinerCrossValidation::learn(), and sgpp::datadriven::DataSourceBuilder::withPath().