SG++-Doxygen-Documentation
sgpp::datadriven::LearnerSGDE Class Reference

#include <LearnerSGDE.hpp>

Inheritance diagram for sgpp::datadriven::LearnerSGDE:
sgpp::datadriven::DensityEstimator

Public Member Functions

void cov (base::DataMatrix &cov, base::DataMatrix *bounds=nullptr) override
 WARNING: Not yet implemented. More...
 
virtual double getAccuracy (base::DataMatrix &testDataset, const base::DataVector &referenceLabels, const double threshold)
 Computes the classification accuracy. More...
 
virtual double getAccuracy (const base::DataVector &referenceLabels, const double threshold, const base::DataVector &predictedLabels)
 Computes the classification accuracy. More...
 
size_t getDim () override
 get number of dimensions More...
 
virtual double getError (base::DataMatrix &data, const base::DataVector &labels, const double threshold, std::string errorType)
 Error evaluation required for convergence-based refinement. More...
 
virtual base::Grid * getGrid ()
 returns the grid More...
 
size_t getNsamples () override
 get number of samples More...
 
std::shared_ptr< base::DataVector > getSamples (size_t dim) override
 returns the samples in the given dimension More...
 
std::shared_ptr< base::DataMatrix > getSamples () override
 returns the complete sample set More...
 
std::shared_ptr< base::Grid > getSharedGrid ()
 
std::shared_ptr< base::DataVector > getSharedSurpluses ()
 
virtual base::DataVector * getSurpluses ()
 returns the surpluses More...
 
void initialize (base::DataMatrix &samples) override
 Create grid and perform cross-validation if enabled. More...
 
 LearnerSGDE (sgpp::base::RegularGridConfiguration &gridConfig, sgpp::base::AdaptivityConfiguration &adaptivityConfig, sgpp::solver::SLESolverConfiguration &solverConfig, sgpp::datadriven::RegularizationConfiguration &regularizationConfig, CrossvalidationConfiguration &crossvalidationConfig)
 Constructor. More...
 
 LearnerSGDE (LearnerSGDEConfiguration &learnerSGDEConfig)
 
 LearnerSGDE (const LearnerSGDE &learnerSGDE)
 
double mean () override
 This method computes the mean of the density function. More...
 
double pdf (base::DataVector &x) override
 This method evaluates the sparse grid density at a single point. More...
 
void pdf (base::DataMatrix &points, base::DataVector &res) override
 Evaluation of the sparse grid density at a set of points. More...
 
virtual void predict (base::DataMatrix &testDataset, base::DataVector &predictedLabels)
 Predicts class labels based on the trained model. More...
 
virtual void storeResults (base::DataMatrix &testDataset)
 Stores classified data, grids and density function evaluations to csv files. More...
 
virtual void train (base::Grid &grid, base::DataVector &alpha, base::DataMatrix &trainData, double lambdaReg)
 Does the learning step (i.e. More...
 
virtual void train ()
 Learns the data. More...
 
virtual void trainOnline (base::DataVector &labels, base::DataMatrix &testData, base::DataVector &testLabels, base::DataMatrix *validData, base::DataVector *validLabels, base::DataVector &classLabels, size_t maxDataPasses, std::string refType, std::string refMonitor, size_t refPeriod, double accDeclineThreshold, size_t accDeclineBufferSize, size_t minRefInterval, bool usePrior)
 Performs the sparse grid density estimation via online learning. More...
 
double variance () override
 Computes the variance of the density function. More...
 
virtual ~LearnerSGDE ()
 
- Public Member Functions inherited from sgpp::datadriven::DensityEstimator
virtual void corrcoef (base::DataMatrix &corr, base::DataMatrix *bounds=nullptr)
 
double crossEntropy (sgpp::base::DataMatrix &samples)
 
 DensityEstimator ()
 
virtual double std_deviation ()
 
virtual ~DensityEstimator ()
 

Public Attributes

base::DataVector avgErrors
 
double error
 

Protected Member Functions

base::OperationMatrix * computeRegularizationMatrix (base::Grid &grid)
 generates the regularization matrix More...
 
double computeResidual (base::Grid &grid, base::DataVector &alpha, base::DataMatrix &test, double lambdaReg)
 Compute the residual for a given test data set on a learned grid. More...
 
std::shared_ptr< base::Grid > createRegularGrid ()
 generates a regular grid More...
 
double mean (base::Grid &grid, base::DataVector &alpha)
 
double optimizeLambdaCV ()
 Does cross-validation to obtain a suitable regularization parameter. More...
 
void splitset (std::vector< std::shared_ptr< base::DataMatrix >> &strain, std::vector< std::shared_ptr< base::DataMatrix >> &stest)
 splits the complete sample set into a set of smaller training and test samples for cross-validation. More...
 
double variance (base::Grid &grid, base::DataVector &alpha)
 

Protected Attributes

sgpp::base::AdaptivityConfiguration adaptivityConfig
 
std::shared_ptr< base::DataVector > alpha
 
std::map< int, std::shared_ptr< base::DataVector > > alphas
 
std::map< int, size_t > appearances
 
sgpp::datadriven::CrossvalidationConfiguration crossvalidationConfig
 
std::shared_ptr< base::Grid > grid
 
sgpp::base::RegularGridConfiguration gridConfig
 
std::map< int, std::shared_ptr< base::Grid > > grids
 
double lambdaReg
 
std::map< int, double > priors
 
sgpp::datadriven::RegularizationConfiguration regularizationConfig
 
sgpp::solver::SLESolverConfiguration solverConfig
 
std::shared_ptr< base::DataMatrix > trainData
 
std::shared_ptr< base::DataVector > trainLabels
 
bool usePrior
 

Constructor & Destructor Documentation

◆ LearnerSGDE() [1/3]

sgpp::datadriven::LearnerSGDE::LearnerSGDE ( sgpp::base::RegularGridConfiguration & gridConfig,
sgpp::base::AdaptivityConfiguration & adaptivityConfig,
sgpp::solver::SLESolverConfiguration & solverConfig,
sgpp::datadriven::RegularizationConfiguration & regularizationConfig,
CrossvalidationConfiguration & crossvalidationConfig 
)

Constructor.

Parameters
gridConfig: grid configuration
adaptivityConfig: adaptive refinement configuration
solverConfig: solver configuration (CG)
regularizationConfig: configuration for the regularization operator
crossvalidationConfig: configuration for the cross-validation

◆ LearnerSGDE() [2/3]

sgpp::datadriven::LearnerSGDE::LearnerSGDE ( LearnerSGDEConfiguration & learnerSGDEConfig)
explicit

◆ LearnerSGDE() [3/3]

sgpp::datadriven::LearnerSGDE::LearnerSGDE ( const LearnerSGDE & learnerSGDE)

◆ ~LearnerSGDE()

sgpp::datadriven::LearnerSGDE::~LearnerSGDE ( )
virtual

Member Function Documentation

◆ computeRegularizationMatrix()

◆ computeResidual()

double sgpp::datadriven::LearnerSGDE::computeResidual ( base::Grid & grid,
base::DataVector & alpha,
base::DataMatrix & test,
double  lambdaReg 
)
protected

Compute the residual for a given test data set on a learned grid.

$\| (A - \lambda C)\,\alpha - \frac{1}{n} B \|$

This is used as quality criterion for the estimated density.

Parameters
grid: grid
alpha: coefficient vector
test: test set
lambdaReg: regularization parameter
Returns

References python.utils.pca_normalize_dataset::C, computeRegularizationMatrix(), sgpp::base::Grid::getSize(), and python.statsfileInfo::i.

Referenced by optimizeLambdaCV().

◆ cov()

void sgpp::datadriven::LearnerSGDE::cov ( base::DataMatrix & cov,
base::DataMatrix * bounds = nullptr 
)
override virtual

WARNING: Not yet implemented.

Implements sgpp::datadriven::DensityEstimator.

References alpha, sgpp::op_factory::createOperationCovariance(), and grid.

◆ createRegularGrid()

◆ getAccuracy() [1/2]

double sgpp::datadriven::LearnerSGDE::getAccuracy ( base::DataMatrix & testDataset,
const base::DataVector & referenceLabels,
const double  threshold 
)
virtual

Computes the classification accuracy.

Parameters
testDataset: The data for which class labels should be predicted
referenceLabels: The corresponding actual class labels
threshold: The decision threshold (e.g. for class labels -1, 1 -> threshold = 0)
Returns
The resulting accuracy

References sgpp::base::DataMatrix::getNrows(), and predict().

Referenced by getError(), and trainOnline().

◆ getAccuracy() [2/2]

double sgpp::datadriven::LearnerSGDE::getAccuracy ( const base::DataVector & referenceLabels,
const double  threshold,
const base::DataVector & predictedLabels 
)
virtual

Computes the classification accuracy.

Parameters
referenceLabels: The actual class labels
threshold: The decision threshold (e.g. for class labels -1, 1 -> threshold = 0)
predictedLabels: The predicted class labels
Returns
The resulting accuracy

References sgpp::base::DataVector::get(), sgpp::base::DataVector::getSize(), and python.statsfileInfo::i.

◆ getDim()

size_t sgpp::datadriven::LearnerSGDE::getDim ( )
override virtual

◆ getError()

double sgpp::datadriven::LearnerSGDE::getError ( base::DataMatrix & data,
const base::DataVector & labels,
const double  threshold,
std::string  errorType 
)
virtual

Error evaluation required for convergence-based refinement.

Parameters
data: The data points to measure the error on
labels: The corresponding class labels
threshold: The decision threshold (e.g. for class labels -1, 1 -> threshold = 0)
errorType: The error type (only "Acc" possible, i.e. classification error based on accuracy)
Returns
The error evaluation

References getAccuracy().

Referenced by trainOnline().

◆ getGrid()

base::Grid * sgpp::datadriven::LearnerSGDE::getGrid ( )
virtual

returns the grid

References grid.

◆ getNsamples()

size_t sgpp::datadriven::LearnerSGDE::getNsamples ( )
override virtual

get number of samples

Implements sgpp::datadriven::DensityEstimator.

References trainData.

Referenced by getSamples(), and optimizeLambdaCV().

◆ getSamples() [1/2]

std::shared_ptr< base::DataVector > sgpp::datadriven::LearnerSGDE::getSamples ( size_t  dim)
override virtual

returns the samples in the given dimension

Parameters
dim

Implements sgpp::datadriven::DensityEstimator.

References getNsamples(), and trainData.

◆ getSamples() [2/2]

std::shared_ptr< base::DataMatrix > sgpp::datadriven::LearnerSGDE::getSamples ( )
override virtual

returns the complete sample set

Implements sgpp::datadriven::DensityEstimator.

References trainData.

◆ getSharedGrid()

std::shared_ptr< base::Grid > sgpp::datadriven::LearnerSGDE::getSharedGrid ( )

◆ getSharedSurpluses()

std::shared_ptr< base::DataVector > sgpp::datadriven::LearnerSGDE::getSharedSurpluses ( )

◆ getSurpluses()

base::DataVector * sgpp::datadriven::LearnerSGDE::getSurpluses ( )
virtual

returns the surpluses

References alpha.

◆ initialize()

void sgpp::datadriven::LearnerSGDE::initialize ( base::DataMatrix & samples)
override virtual

◆ mean() [1/2]

double sgpp::datadriven::LearnerSGDE::mean ( )
override virtual

◆ mean() [2/2]

◆ optimizeLambdaCV()

◆ pdf() [1/2]

double sgpp::datadriven::LearnerSGDE::pdf ( base::DataVector & x)
override virtual

This method evaluates the sparse grid density at a single point.

Parameters
x: DataVector with length equal to the dimensionality

Implements sgpp::datadriven::DensityEstimator.

References alpha, sgpp::op_factory::createOperationEval(), and sgpp::base::OperationEval::eval().

◆ pdf() [2/2]

void sgpp::datadriven::LearnerSGDE::pdf ( base::DataMatrix & points,
base::DataVector & res 
)
override virtual

Evaluation of the sparse grid density at a set of points.

Parameters
points: DataMatrix (nrows = number of samples, ncols = dimensionality)
res: DataVector (size = number of samples) where the results are stored

Implements sgpp::datadriven::DensityEstimator.

References alpha, sgpp::op_factory::createOperationMultipleEval(), sgpp::base::OperationMultipleEval::eval(), and grid.

◆ predict()

void sgpp::datadriven::LearnerSGDE::predict ( base::DataMatrix & testDataset,
base::DataVector & predictedLabels 
)
virtual

Predicts class labels based on the trained model.

Parameters
testDataset: The data for which class labels should be predicted
predictedLabels: The predicted class labels

References alphas, sgpp::op_factory::createOperationEval(), chess::dim, g, sgpp::base::DataMatrix::getNcols(), sgpp::base::DataMatrix::getNrows(), sgpp::base::DataMatrix::getRow(), grids, python.statsfileInfo::i, priors, sgpp::base::DataVector::set(), and usePrior.

Referenced by getAccuracy(), and storeResults().

◆ splitset()

void sgpp::datadriven::LearnerSGDE::splitset ( std::vector< std::shared_ptr< base::DataMatrix >> &  strain,
std::vector< std::shared_ptr< base::DataMatrix >> &  stest 
)
protected

splits the complete sample set into a set of smaller training and test samples for cross-validation.

Parameters
strain: vector containing the training samples for cv
stest: vector containing the test samples for cv

References crossvalidationConfig, python.statsfileInfo::i, python.utils.statsfile2gnuplot::j, sgpp::datadriven::CrossvalidationConfiguration::kfold_, friedman::p, chess::r, create_scripts::s, sgpp::datadriven::CrossvalidationConfiguration::seed_, sgpp::datadriven::CrossvalidationConfiguration::shuffle_, sgpp::datadriven::CrossvalidationConfiguration::silent_, and analyse_erg::tmp.

Referenced by optimizeLambdaCV().

◆ storeResults()

◆ train() [1/2]

void sgpp::datadriven::LearnerSGDE::train ( base::Grid & grid,
base::DataVector & alpha,
base::DataMatrix & trainData,
double  lambdaReg 
)
virtual

◆ train() [2/2]

void sgpp::datadriven::LearnerSGDE::train ( )
virtual

Learns the data.

References alpha, grid, lambdaReg, and trainData.

Referenced by optimizeLambdaCV().

◆ trainOnline()

void sgpp::datadriven::LearnerSGDE::trainOnline ( base::DataVector & labels,
base::DataMatrix & testData,
base::DataVector & testLabels,
base::DataMatrix * validData,
base::DataVector * validLabels,
base::DataVector & classLabels,
size_t  maxDataPasses,
std::string  refType,
std::string  refMonitor,
size_t  refPeriod,
double  accDeclineThreshold,
size_t  accDeclineBufferSize,
size_t  minRefInterval,
bool  usePrior 
)
virtual

Performs the sparse grid density estimation via online learning.

Parameters
labels: The training labels
testData: The test data
testLabels: The corresponding test labels
validData: The validation data
validLabels: The corresponding validation labels
classLabels: The occurring class labels (e.g. -1,1)
maxDataPasses: The number of passes over the whole training data
refType: The refinement indicator (surplus, zero-crossings or data-based)
refMonitor: The refinement strategy (periodic or convergence-based)
refPeriod: The refinement interval (if periodic refinement is chosen)
accDeclineThreshold: The convergence threshold (if convergence-based refinement is chosen)
accDeclineBufferSize: The number of accuracy measurements which are used to check convergence (if convergence-based refinement is chosen)
minRefInterval: The minimum number of data points (or data batches) which have to be processed before the next refinement can be scheduled (if convergence-based refinement is chosen)
usePrior: Specifies whether prior probabilities should be used to predict class labels

References adaptivityConfig, alpha, alphas, appearances, sgpp::base::DataVector::append(), avgErrors, python.utils.pca_normalize_dataset::C, computeRegularizationMatrix(), sgpp::base::Grid::createLinearGrid(), sgpp::base::Grid::createModLinearGrid(), sgpp::op_factory::createOperationEval(), chess::dim, sgpp::base::GeneralGridConfiguration::dim_, sgpp::solver::SLESolverConfiguration::eps_, error, g, sgpp::datadriven::DensitySystemMatrix::generateb(), getAccuracy(), getError(), sgpp::base::HashGridStorage::getPoint(), sgpp::base::DataVector::getSize(), sgpp::base::HashGridPoint::getStandardCoordinates(), grid, gridConfig, grids, python.utils.sg_projections::gridStorage, python.statsfileInfo::i, python.utils.statsfile2gnuplot::j, lambdaReg, sgpp::base::GeneralGridConfiguration::level_, sgpp::base::Linear, sgpp::solver::SLESolverConfiguration::maxIterations_, sgpp::base::ModLinear, sgpp::base::AdaptivityConfiguration::noPoints_, sgpp::base::AdaptivityConfiguration::numRefinements_, friedman::p, sgpp::datadriven::MultiGridRefinementFunctor::preComputeEvaluations(), priors, sgpp::datadriven::RefinementMonitor::pushToBuffer(), sgpp::datadriven::RefinementMonitor::refinementsNecessary(), sgpp::datadriven::MultiGridRefinementFunctor::setGridIndex(), sgpp::solver::ConjugateGradients::solve(), solverConfig, sgpp::solver::SLESolverConfiguration::threshold_, sgpp::base::AdaptivityConfiguration::threshold_, trainData, trainLabels, sgpp::base::GeneralGridConfiguration::type_, and usePrior.

◆ variance() [1/2]

double sgpp::datadriven::LearnerSGDE::variance ( )
override virtual

Computes the variance of the density function.

Implements sgpp::datadriven::DensityEstimator.

References alpha, and grid.

◆ variance() [2/2]

double sgpp::datadriven::LearnerSGDE::variance ( base::Grid & grid,
base::DataVector & alpha 
)
protected

Member Data Documentation

◆ adaptivityConfig

sgpp::base::AdaptivityConfiguration sgpp::datadriven::LearnerSGDE::adaptivityConfig
protected

Referenced by LearnerSGDE(), train(), and trainOnline().

◆ alpha

◆ alphas

std::map<int, std::shared_ptr<base::DataVector> > sgpp::datadriven::LearnerSGDE::alphas
protected

Referenced by predict(), storeResults(), and trainOnline().

◆ appearances

std::map<int, size_t> sgpp::datadriven::LearnerSGDE::appearances
protected

Referenced by trainOnline().

◆ avgErrors

base::DataVector sgpp::datadriven::LearnerSGDE::avgErrors

Referenced by trainOnline().

◆ crossvalidationConfig

sgpp::datadriven::CrossvalidationConfiguration sgpp::datadriven::LearnerSGDE::crossvalidationConfig
protected

◆ error

double sgpp::datadriven::LearnerSGDE::error

Referenced by LearnerSGDE(), and trainOnline().

◆ grid

◆ gridConfig

sgpp::base::RegularGridConfiguration sgpp::datadriven::LearnerSGDE::gridConfig
protected

◆ grids

std::map<int, std::shared_ptr<base::Grid> > sgpp::datadriven::LearnerSGDE::grids
protected

Referenced by predict(), storeResults(), and trainOnline().

◆ lambdaReg

double sgpp::datadriven::LearnerSGDE::lambdaReg
protected

◆ priors

std::map<int, double> sgpp::datadriven::LearnerSGDE::priors
protected

Referenced by predict(), and trainOnline().

◆ regularizationConfig

sgpp::datadriven::RegularizationConfiguration sgpp::datadriven::LearnerSGDE::regularizationConfig
protected

◆ solverConfig

sgpp::solver::SLESolverConfiguration sgpp::datadriven::LearnerSGDE::solverConfig
protected

Referenced by LearnerSGDE(), train(), and trainOnline().

◆ trainData

◆ trainLabels

std::shared_ptr<base::DataVector> sgpp::datadriven::LearnerSGDE::trainLabels
protected

Referenced by trainOnline().

◆ usePrior

bool sgpp::datadriven::LearnerSGDE::usePrior
protected

Referenced by LearnerSGDE(), predict(), and trainOnline().


The documentation for this class was generated from the following files: