SG++
sgpp::datadriven::CrossValidation Class Reference

Supervised learning with cross validation used to fit a model and quantify accuracy using a sgpp::datadriven::Metric.

#include <CrossValidation.hpp>

Inheritance diagram for sgpp::datadriven::CrossValidation: derives from sgpp::datadriven::Scorer.

Public Member Functions

double calculateScore (ModelFittingBase &model, Dataset &dataset, double *stdDeviation=nullptr) override
 Train and test a model on a dataset and provide a score to quantify the approximation quality.
 
Scorer * clone () const override
 Polymorphic clone pattern.
 
 CrossValidation (Metric *metric, ShufflingFunctor *shuffling, int64_t seed=-1, size_t foldNumber=5)
 Constructor.
 
- Public Member Functions inherited from sgpp::datadriven::Scorer
Scorer & operator= (const Scorer &rhs)
 Copy assignment operator.
 
Scorer & operator= (Scorer &&rhs)=default
 Move assignment operator.
 
 Scorer (Metric *metric, ShufflingFunctor *shuffling, int64_t seed=-1)
 Constructor.
 
 Scorer (const Scorer &rhs)
 Copy constructor.
 
 Scorer (Scorer &&rhs)=default
 Move constructor.
 
virtual ~Scorer ()=default
 Virtual destructor.
 

Additional Inherited Members

- Protected Member Functions inherited from sgpp::datadriven::Scorer
void randomizeIndices (const Dataset &data, std::vector< size_t > &randomizedIndices)
 Helper method to generate an ordering for the samples of the dataset based on the shuffling functor.
 
double refine (ModelFittingBase &model, Dataset &testDataset)
 Fit the model on the train dataset and evaluate the accuracy on the test set.
 
void splitSet (const Dataset &fullDataset, Dataset &trainDataset, Dataset &testDataset, const std::vector< size_t > &randomizedIndices, size_t offset=0)
 Split the dataset into testing and training sets.
 
double test (ModelFittingBase &model, Dataset &testDataset)
 Evaluate the accuracy on the test set using the metric.
 
double train (ModelFittingBase &model, Dataset &trainDataset, Dataset &testDataset)
 Fit the model on the train dataset and evaluate the accuracy on the test set.
 
- Protected Attributes inherited from sgpp::datadriven::Scorer
std::unique_ptr< Metric > metric
 sgpp::datadriven::Metric used to quantify the accuracy of the fit.
 
std::unique_ptr< ShufflingFunctor > shuffling
 sgpp::datadriven::ShufflingFunctor used to rearrange the samples of a dataset in the desired manner, ready to be split into testing and training sets.
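The protected helpers randomizeIndices() and splitSet() can be illustrated with the sketch below, written in standard C++ only. ToyDataset, the explicit testSize parameter, and the reading of offset as the start of the contiguous test block are assumptions made for illustration and are not the SG++ implementation; the convention that a seed of -1 requests a random seed is taken from the constructor documentation.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical stand-in for sgpp::datadriven::Dataset:
// one feature row and one target value per sample.
struct ToyDataset {
  std::vector<std::vector<double>> x;
  std::vector<double> y;
};

// Sketch of randomizeIndices(): return a shuffled permutation of 0..n-1.
// A seed of -1 draws a nondeterministic seed from std::random_device.
std::vector<std::size_t> randomizeIndices(std::size_t n, std::int64_t seed) {
  std::vector<std::size_t> idx(n);
  std::iota(idx.begin(), idx.end(), 0);
  std::mt19937_64 rng(seed == -1 ? std::random_device{}()
                                 : static_cast<std::uint64_t>(seed));
  std::shuffle(idx.begin(), idx.end(), rng);
  return idx;
}

// Sketch of splitSet(): positions [offset, offset + testSize) of the shuffled
// order form the test set; all other samples form the training set.
// Assumption: testSize is a hypothetical parameter added for this sketch.
void splitSet(const ToyDataset& full, ToyDataset& train, ToyDataset& test,
              const std::vector<std::size_t>& order, std::size_t testSize,
              std::size_t offset) {
  train = ToyDataset{};
  test = ToyDataset{};
  for (std::size_t i = 0; i < order.size(); ++i) {
    ToyDataset& target = (i >= offset && i < offset + testSize) ? test : train;
    target.x.push_back(full.x[order[i]]);
    target.y.push_back(full.y[order[i]]);
  }
}
```

Moving offset forward by testSize on each call selects the next fold from the same shuffled order.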
 

Detailed Description

Supervised learning with cross validation used to fit a model and quantify accuracy using a sgpp::datadriven::Metric.

Splits a dataset into foldNumber parts (folds). Each fold serves once as the test set while the model is trained on the remaining folds; the scorer then reports the average accuracy and the standard deviation over all fits.
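A minimal, self-contained sketch of the k-fold procedure, assuming a toy "model" that predicts the mean of its training targets and mean squared error as the "metric" (both hypothetical stand-ins for ModelFittingBase and Metric). It assumes 2 <= foldNumber <= number of samples and skips the shuffling step that the ShufflingFunctor would perform.

```cpp
#include <cstddef>
#include <vector>

// Sketch of k-fold cross validation on a 1-D toy regression problem.
// Returns one score (mean squared error, lower is better) per fold.
std::vector<double> kFoldScores(const std::vector<double>& y,
                                std::size_t foldNumber) {
  const std::size_t n = y.size();
  const std::size_t foldSize = n / foldNumber;
  std::vector<double> scores;
  for (std::size_t fold = 0; fold < foldNumber; ++fold) {
    // Held-out block for this fold; the last fold absorbs any remainder.
    const std::size_t begin = fold * foldSize;
    const std::size_t end = (fold + 1 == foldNumber) ? n : begin + foldSize;

    // "Train": fit the mean on all samples outside [begin, end).
    double sum = 0.0;
    for (std::size_t i = 0; i < n; ++i)
      if (i < begin || i >= end) sum += y[i];
    const double prediction = sum / static_cast<double>(n - (end - begin));

    // "Test": mean squared error on the held-out fold.
    double mse = 0.0;
    for (std::size_t i = begin; i < end; ++i)
      mse += (y[i] - prediction) * (y[i] - prediction);
    scores.push_back(mse / static_cast<double>(end - begin));
  }
  return scores;
}
```

Averaging the returned per-fold scores gives the single value that calculateScore() is documented to return.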

Constructor & Destructor Documentation

sgpp::datadriven::CrossValidation::CrossValidation (Metric *metric, ShufflingFunctor *shuffling, int64_t seed = -1, size_t foldNumber = 5)

Constructor.

Parameters
metric: sgpp::datadriven::Metric used to quantify the approximation quality of a trained model. Scorer takes ownership of this object.
shuffling: sgpp::datadriven::ShufflingFunctor used to rearrange the samples of a dataset in the desired manner, ready to be split into testing and training sets. Scorer takes ownership of this object.
seed: Seed for the randomization in sgpp::datadriven::ShufflingFunctor. The default, -1, selects a random seed.
foldNumber: Number of folds used for cross validation.

References sgpp::datadriven::Scorer::metric.
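The ownership contract and the seed default can be sketched as follows. ToyMetric, ToyShuffling, and ToyCrossValidation are hypothetical stand-ins, and drawing the replacement seed from std::random_device when seed is -1 is an assumption of this sketch; the documentation only states that -1 selects a random seed.

```cpp
#include <cstddef>
#include <cstdint>
#include <memory>
#include <random>

// Hypothetical stand-ins; ToyMetric counts live instances so the
// ownership transfer can be observed.
struct ToyMetric {
  static int live;
  ToyMetric() { ++live; }
  ~ToyMetric() { --live; }
};
int ToyMetric::live = 0;

struct ToyShuffling {};

// Raw pointers passed to the constructor are adopted by std::unique_ptr
// members, so the scorer deletes them when it is destroyed.
class ToyCrossValidation {
 public:
  ToyCrossValidation(ToyMetric* metric, ToyShuffling* shuffling,
                     std::int64_t seed = -1, std::size_t foldNumber = 5)
      : metric_(metric),
        shuffling_(shuffling),
        // Assumption: -1 is replaced by a nondeterministic seed.
        seed_(seed == -1 ? static_cast<std::int64_t>(std::random_device{}())
                         : seed),
        foldNumber_(foldNumber) {}

  std::int64_t seed() const { return seed_; }
  std::size_t foldNumber() const { return foldNumber_; }

 private:
  std::unique_ptr<ToyMetric> metric_;
  std::unique_ptr<ToyShuffling> shuffling_;
  std::int64_t seed_;
  std::size_t foldNumber_;
};
```

Because the members are std::unique_ptr, callers should pass freshly allocated objects and must not delete them afterwards.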

Member Function Documentation

double sgpp::datadriven::CrossValidation::calculateScore (ModelFittingBase &model, Dataset &dataset, double *stdDeviation = nullptr)
[override, virtual]

Train and test a model on a dataset and provide a score to quantify the approximation quality.

If multiple models are trained, calculate the standard deviation between the different fits.

Parameters
model: Model to be fitted on the training part of the dataset.
dataset: Set of samples used for fitting and testing the model.
stdDeviation: Optional output parameter (defaults to nullptr) that receives the standard deviation between the different runs.
Returns
Average accuracy of all fits as calculated by the provided metric.

Implements sgpp::datadriven::Scorer.

References sgpp::datadriven::Dataset::getDimension(), sgpp::datadriven::Dataset::getNumberInstances(), sgpp::combigrid::pow(), sgpp::datadriven::Scorer::randomizeIndices(), sgpp::datadriven::Scorer::refine(), sgpp::datadriven::Scorer::splitSet(), and sgpp::datadriven::Scorer::train().
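The aggregation step of calculateScore() can be sketched as below: average the per-fold scores and, only when the caller passes a non-null stdDeviation pointer, write their standard deviation. The function name aggregateScores is hypothetical, and the population form (dividing by the number of folds) is an assumption; the documentation does not state which form SG++ uses. The input must be non-empty.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Returns the mean of foldScores; writes the population standard deviation
// to *stdDeviation only if the pointer is non-null (optional output).
double aggregateScores(const std::vector<double>& foldScores,
                       double* stdDeviation = nullptr) {
  const double k = static_cast<double>(foldScores.size());
  double mean = 0.0;
  for (double s : foldScores) mean += s;
  mean /= k;

  if (stdDeviation != nullptr) {
    double variance = 0.0;
    for (double s : foldScores) variance += std::pow(s - mean, 2);
    *stdDeviation = std::sqrt(variance / k);
  }
  return mean;
}
```

Defaulting the pointer to nullptr lets callers that only need the average skip the extra output entirely.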

Scorer * sgpp::datadriven::CrossValidation::clone () const
[override, virtual]

Polymorphic clone pattern.

Returns
Deep copy of this object. The caller owns the new object.

Implements sgpp::datadriven::Scorer.
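The polymorphic clone pattern can be sketched with hypothetical stand-in classes. This sketch assumes the owned metric exposes its own clone() so that copying a scorer deep-copies the metric instead of sharing it; that detail is not stated in this documentation.

```cpp
#include <cstddef>
#include <memory>

// Hypothetical metric hierarchy with its own polymorphic clone().
struct ToyMetric {
  virtual ~ToyMetric() = default;
  virtual ToyMetric* clone() const = 0;
};
struct ToyMSE : ToyMetric {
  ToyMetric* clone() const override { return new ToyMSE(*this); }
};

// Hypothetical base class owning its metric.
struct ToyScorer {
  explicit ToyScorer(ToyMetric* m) : metric(m) {}
  // Copy constructor deep-copies the owned metric through its clone().
  ToyScorer(const ToyScorer& rhs) : metric(rhs.metric->clone()) {}
  virtual ~ToyScorer() = default;
  // Polymorphic clone: returns a new object owned by the caller.
  virtual ToyScorer* clone() const = 0;
  std::unique_ptr<ToyMetric> metric;
};

struct ToyCrossValidation : ToyScorer {
  ToyCrossValidation(ToyMetric* m, std::size_t folds)
      : ToyScorer(m), foldNumber(folds) {}
  // The implicit copy constructor reuses ToyScorer's deep copy.
  ToyScorer* clone() const override { return new ToyCrossValidation(*this); }
  std::size_t foldNumber;
};
```

Wrapping the result in std::unique_ptr at the call site is a simple way to honor the rule that the caller owns the returned object.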

