SG++
sgpp::datadriven::LearnerSGDE Class Reference

#include <LearnerSGDE.hpp>

Inheritance diagram for sgpp::datadriven::LearnerSGDE:
sgpp::datadriven::DensityEstimator

Public Member Functions

virtual void cov (base::DataMatrix &cov)
 WARNING: Not yet implemented. More...
 
virtual double getAccuracy (base::DataMatrix &testDataset, const base::DataVector &referenceLabels, const double threshold)
 Computes the classification accuracy. More...
 
virtual double getAccuracy (const base::DataVector &referenceLabels, const double threshold, const base::DataVector &predictedLabels)
 Computes the classification accuracy. More...
 
virtual size_t getDim ()
 get number of dimensions More...
 
virtual double getError (base::DataMatrix &data, const base::DataVector &labels, const double threshold, std::string errorType)
 Error evaluation required for convergence-based refinement. More...
 
virtual std::shared_ptr< base::Grid > getGrid ()
 returns the grid More...
 
virtual size_t getNsamples ()
 get number of samples More...
 
virtual std::shared_ptr< base::DataVector > getSamples (size_t dim)
 returns the samples in the given dimension More...
 
virtual std::shared_ptr< base::DataMatrix > getSamples ()
 returns the complete sample set More...
 
virtual std::shared_ptr< base::DataVector > getSurpluses ()
 returns the surpluses More...
 
virtual void initialize (base::DataMatrix &samples)
 Create grid and perform cross-validation if enabled. More...
 
 LearnerSGDE (sgpp::base::RegularGridConfiguration &gridConfig, sgpp::base::AdpativityConfiguration &adaptivityConfig, sgpp::solver::SLESolverConfiguration &solverConfig, sgpp::datadriven::RegularizationConfiguration &regularizationConfig, CrossvalidationConfiguration &crossvalidationConfig)
 Constructor. More...
 
 LearnerSGDE (LearnerSGDEConfiguration &learnerSGDEConfig)
 
 LearnerSGDE (const LearnerSGDE &learnerSGDE)
 
virtual double mean ()
 This method computes the mean of the density function. More...
 
virtual double pdf (base::DataVector &x)
 This method evaluates the sparse grid density at a single point. More...
 
virtual void pdf (base::DataMatrix &points, base::DataVector &res)
 Evaluation of the sparse grid density at a set of points. More...
 
virtual void predict (base::DataMatrix &testDataset, base::DataVector &predictedLabels)
 Predicts class labels based on the trained model. More...
 
virtual void storeResults (base::DataMatrix &testDataset)
 Stores classified data, grids and density function evaluations to csv files. More...
 
virtual void train (base::Grid &grid, base::DataVector &alpha, base::DataMatrix &trainData, double lambdaReg)
 Does the learning step (i.e. More...
 
virtual void train ()
 Learns the data. More...
 
virtual void trainOnline (base::DataVector &labels, base::DataMatrix &testData, base::DataVector &testLabels, base::DataMatrix *validData, base::DataVector *validLabels, base::DataVector &classLabels, size_t maxDataPasses, std::string refType, std::string refMonitor, size_t refPeriod, double accDeclineThreshold, size_t accDeclineBufferSize, size_t minRefInterval, bool usePrior)
 Performs the sparse grid density estimation via online learning. More...
 
virtual double variance ()
 Computes the variance of the density function. More...
 
virtual ~LearnerSGDE ()
 
- Public Member Functions inherited from sgpp::datadriven::DensityEstimator
virtual void corrcoef (base::DataMatrix &corr)
 
 DensityEstimator ()
 
virtual double std_deviation ()
 
virtual ~DensityEstimator ()
 

Public Attributes

base::DataVector avgErrors
 
double error
 

Protected Member Functions

std::unique_ptr< base::OperationMatrix > computeRegularizationMatrix (base::Grid &grid)
 generates the regularization matrix More...
 
double computeResidual (base::Grid &grid, base::DataVector &alpha, base::DataMatrix &test, double lambdaReg)
 Compute the residual for a given test data set on a learned grid. More...
 
std::shared_ptr< base::Grid > createRegularGrid ()
 generates a regular grid More...
 
double mean (base::Grid &grid, base::DataVector &alpha)
 
double optimizeLambdaCV ()
 Does cross-validation to obtain a suitable regularization parameter. More...
 
void splitset (std::vector< std::shared_ptr< base::DataMatrix >> &strain, std::vector< std::shared_ptr< base::DataMatrix >> &stest)
 splits the complete sample set into a set of smaller training and test samples for cross-validation. More...
 
double variance (base::Grid &grid, base::DataVector &alpha)
 

Protected Attributes

sgpp::base::AdpativityConfiguration adaptivityConfig
 
std::shared_ptr< base::DataVector > alpha
 
std::map< int, std::shared_ptr< base::DataVector > > alphas
 
std::map< int, size_t > appearances
 
sgpp::datadriven::CrossvalidationConfiguration crossvalidationConfig
 
std::shared_ptr< base::Grid > grid
 
sgpp::base::RegularGridConfiguration gridConfig
 
std::map< int, std::shared_ptr< base::Grid > > grids
 
double lambdaReg
 
std::map< int, double > priors
 
sgpp::datadriven::RegularizationConfiguration regularizationConfig
 
sgpp::solver::SLESolverConfiguration solverConfig
 
std::shared_ptr< base::DataMatrix > trainData
 
std::shared_ptr< base::DataVector > trainLabels
 
bool usePrior
 

Constructor & Destructor Documentation

sgpp::datadriven::LearnerSGDE::LearnerSGDE ( sgpp::base::RegularGridConfiguration & gridConfig,
sgpp::base::AdpativityConfiguration & adaptivityConfig,
sgpp::solver::SLESolverConfiguration & solverConfig,
sgpp::datadriven::RegularizationConfiguration & regularizationConfig,
CrossvalidationConfiguration & crossvalidationConfig
)

Constructor.

Parameters
gridConfig: grid configuration
adaptivityConfig: adaptive refinement configuration
solverConfig: solver configuration (CG)
regularizationConfig: configuration for the regularization operator
crossvalidationConfig: configuration for the cross validation
sgpp::datadriven::LearnerSGDE::LearnerSGDE ( LearnerSGDEConfiguration & learnerSGDEConfig)
explicit
sgpp::datadriven::LearnerSGDE::LearnerSGDE ( const LearnerSGDE & learnerSGDE)
sgpp::datadriven::LearnerSGDE::~LearnerSGDE ( )
virtual

Member Function Documentation

std::unique_ptr< base::OperationMatrix > sgpp::datadriven::LearnerSGDE::computeRegularizationMatrix ( base::Grid & grid)
protected
double sgpp::datadriven::LearnerSGDE::computeResidual ( base::Grid & grid,
base::DataVector & alpha,
base::DataMatrix & test,
double  lambdaReg
)
protected

Compute the residual for a given test data set on a learned grid.

$\left| (A - \lambda C)\,\alpha - \frac{1}{n} B \right|$

This is used as quality criterion for the estimated density.

Parameters
grid: grid
alpha: coefficient vector
test: test set
lambdaReg: regularization parameter
Returns

References computeRegularizationMatrix(), and sgpp::base::Grid::getSize().

Referenced by optimizeLambdaCV().

double sgpp::datadriven::LearnerSGDE::getAccuracy ( base::DataMatrix & testDataset,
const base::DataVector & referenceLabels,
const double  threshold
)
virtual

Computes the classification accuracy.

Parameters
testDataset: The data for which class labels should be predicted
referenceLabels: The corresponding actual class labels
threshold: The decision threshold (e.g. for class labels -1, 1 -> threshold = 0)
Returns
The resulting accuracy

References sgpp::base::DataMatrix::getNrows(), and predict().

Referenced by getError(), and trainOnline().

double sgpp::datadriven::LearnerSGDE::getAccuracy ( const base::DataVector & referenceLabels,
const double  threshold,
const base::DataVector & predictedLabels
)
virtual

Computes the classification accuracy.

Parameters
referenceLabels: The actual class labels
threshold: The decision threshold (e.g. for class labels -1, 1 -> threshold = 0)
predictedLabels: The predicted class labels
Returns
The resulting accuracy

References sgpp::base::DataVector::get(), and sgpp::base::DataVector::getSize().

size_t sgpp::datadriven::LearnerSGDE::getDim ( )
virtual

get number of dimensions

Implements sgpp::datadriven::DensityEstimator.

References sgpp::base::RegularGridConfiguration::dim_, and gridConfig.

double sgpp::datadriven::LearnerSGDE::getError ( base::DataMatrix & data,
const base::DataVector & labels,
const double  threshold,
std::string  errorType
)
virtual

Error evaluation required for convergence-based refinement.

Parameters
data: The data points to measure the error on
labels: The corresponding class labels
threshold: The decision threshold (e.g. for class labels -1, 1 -> threshold = 0)
errorType: The error type (only "Acc" possible, i.e. classification error based on accuracy)
Returns
The error evaluation

References getAccuracy().

Referenced by trainOnline().

std::shared_ptr< base::Grid > sgpp::datadriven::LearnerSGDE::getGrid ( )
virtual

returns the grid

References grid.

size_t sgpp::datadriven::LearnerSGDE::getNsamples ( )
virtual

get number of samples

Implements sgpp::datadriven::DensityEstimator.

References trainData.

Referenced by getSamples(), and optimizeLambdaCV().

std::shared_ptr< base::DataVector > sgpp::datadriven::LearnerSGDE::getSamples ( size_t  dim)
virtual

returns the samples in the given dimension

Parameters
dim

Implements sgpp::datadriven::DensityEstimator.

References getNsamples(), and trainData.

std::shared_ptr< base::DataMatrix > sgpp::datadriven::LearnerSGDE::getSamples ( )
virtual

returns the complete sample set

Implements sgpp::datadriven::DensityEstimator.

References trainData.

std::shared_ptr< base::DataVector > sgpp::datadriven::LearnerSGDE::getSurpluses ( )
virtual

returns the surpluses

References alpha.

void sgpp::datadriven::LearnerSGDE::initialize ( base::DataMatrix & samples)
virtual
double sgpp::datadriven::LearnerSGDE::mean ( )
virtual
double sgpp::datadriven::LearnerSGDE::pdf ( base::DataVector & x)
virtual

This method evaluates the sparse grid density at a single point.

Parameters
x: DataVector with length equal to dimensionality

Implements sgpp::datadriven::DensityEstimator.

References alpha, sgpp::op_factory::createOperationEval(), and sgpp::base::OperationEval::eval().

void sgpp::datadriven::LearnerSGDE::pdf ( base::DataMatrix & points,
base::DataVector & res
)
virtual

Evaluation of the sparse grid density at a set of points.

Parameters
points: DataMatrix (nrows = number of samples, ncols = dimensionality)
res: DataVector (size = number of samples) where the results are stored

Implements sgpp::datadriven::DensityEstimator.

References alpha, sgpp::op_factory::createOperationMultipleEval(), sgpp::base::OperationMultipleEval::eval(), and grid.

void sgpp::datadriven::LearnerSGDE::predict ( base::DataMatrix & testDataset,
base::DataVector & predictedLabels
)
virtual

Predicts class labels based on the trained model.

Parameters
testDataset: The data for which class labels should be predicted
predictedLabels: The predicted class labels

References alphas, sgpp::op_factory::createOperationEval(), g, sgpp::base::DataMatrix::getNcols(), sgpp::base::DataMatrix::getNrows(), sgpp::base::DataMatrix::getRow(), grids, priors, sgpp::base::DataVector::set(), and usePrior.

Referenced by getAccuracy(), and storeResults().

void sgpp::datadriven::LearnerSGDE::splitset ( std::vector< std::shared_ptr< base::DataMatrix >> &  strain,
std::vector< std::shared_ptr< base::DataMatrix >> &  stest 
)
protected

splits the complete sample set into a set of smaller training and test samples for cross-validation.

Parameters
strain: vector containing the training samples for cv
stest: vector containing the test samples for cv

References crossvalidationConfig, sgpp::datadriven::CrossvalidationConfiguration::kfold_, sgpp::datadriven::CrossvalidationConfiguration::seed_, sgpp::datadriven::CrossvalidationConfiguration::shuffle_, sgpp::datadriven::CrossvalidationConfiguration::silent_, and analyse_erg::tmp.

Referenced by optimizeLambdaCV().

void sgpp::datadriven::LearnerSGDE::train ( )
virtual

Learns the data.

References alpha, grid, lambdaReg, and trainData.

Referenced by optimizeLambdaCV().

void sgpp::datadriven::LearnerSGDE::trainOnline ( base::DataVector & labels,
base::DataMatrix & testData,
base::DataVector & testLabels,
base::DataMatrix * validData,
base::DataVector * validLabels,
base::DataVector & classLabels,
size_t  maxDataPasses,
std::string  refType,
std::string  refMonitor,
size_t  refPeriod,
double  accDeclineThreshold,
size_t  accDeclineBufferSize,
size_t  minRefInterval,
bool  usePrior
)
virtual

Performs the sparse grid density estimation via online learning.

Parameters
labels: The training labels
testData: The test data
testLabels: The corresponding test labels
validData: The validation data
validLabels: The corresponding validation labels
classLabels: The occurring class labels (e.g. -1,1)
maxDataPasses: The number of passes over the whole training data
refType: The refinement indicator (surplus, zero-crossings or data-based)
refMonitor: The refinement strategy (periodic or convergence-based)
refPeriod: The refinement interval (if periodic refinement is chosen)
accDeclineThreshold: The convergence threshold (if convergence-based refinement is chosen)
accDeclineBufferSize: The number of accuracy measurements which are used to check convergence (if convergence-based refinement is chosen)
minRefInterval: The minimum number of data points (or data batches) which have to be processed before next refinement can be scheduled (if convergence-based refinement is chosen)
usePrior: Specifies whether prior probabilities should be used to predict class labels

References adaptivityConfig, alpha, alphas, appearances, sgpp::base::DataVector::append(), avgErrors, computeRegularizationMatrix(), sgpp::base::Grid::createLinearGrid(), sgpp::base::Grid::createModLinearGrid(), sgpp::op_factory::createOperationEval(), sgpp::base::RegularGridConfiguration::dim_, sgpp::solver::SLESolverConfiguration::eps_, error, g, sgpp::datadriven::DensitySystemMatrix::generateb(), getAccuracy(), getError(), sgpp::base::HashGridStorage::getPoint(), sgpp::base::DataVector::getSize(), sgpp::base::HashGridPoint::getStandardCoordinates(), grid, gridConfig, grids, lambdaReg, sgpp::base::RegularGridConfiguration::level_, sgpp::base::Linear, sgpp::solver::SLESolverConfiguration::maxIterations_, sgpp::base::ModLinear, sgpp::base::AdpativityConfiguration::noPoints_, sgpp::base::AdpativityConfiguration::numRefinements_, sgpp::datadriven::MultiGridRefinementFunctor::preComputeEvaluations(), priors, sgpp::datadriven::MultiGridRefinementFunctor::setGridIndex(), sgpp::solver::ConjugateGradients::solve(), solverConfig, sgpp::solver::SLESolverConfiguration::threshold_, sgpp::base::AdpativityConfiguration::threshold_, trainData, trainLabels, sgpp::base::RegularGridConfiguration::type_, and usePrior.

double sgpp::datadriven::LearnerSGDE::variance ( )
virtual

Computes the variance of the density function.

Implements sgpp::datadriven::DensityEstimator.

References alpha, and grid.

Referenced by cov().

double sgpp::datadriven::LearnerSGDE::variance ( base::Grid & grid,
base::DataVector & alpha
)
protected

Member Data Documentation

sgpp::base::AdpativityConfiguration sgpp::datadriven::LearnerSGDE::adaptivityConfig
protected

Referenced by LearnerSGDE(), train(), and trainOnline().

std::map<int, std::shared_ptr<base::DataVector> > sgpp::datadriven::LearnerSGDE::alphas
protected

Referenced by predict(), storeResults(), and trainOnline().

std::map<int, size_t> sgpp::datadriven::LearnerSGDE::appearances
protected

Referenced by trainOnline().

base::DataVector sgpp::datadriven::LearnerSGDE::avgErrors

Referenced by trainOnline().

sgpp::datadriven::CrossvalidationConfiguration sgpp::datadriven::LearnerSGDE::crossvalidationConfig
protected
double sgpp::datadriven::LearnerSGDE::error

Referenced by trainOnline().

sgpp::base::RegularGridConfiguration sgpp::datadriven::LearnerSGDE::gridConfig
protected
std::map<int, std::shared_ptr<base::Grid> > sgpp::datadriven::LearnerSGDE::grids
protected

Referenced by predict(), storeResults(), and trainOnline().

double sgpp::datadriven::LearnerSGDE::lambdaReg
protected

Referenced by initialize(), train(), and trainOnline().

std::map<int, double> sgpp::datadriven::LearnerSGDE::priors
protected

Referenced by predict(), and trainOnline().

sgpp::datadriven::RegularizationConfiguration sgpp::datadriven::LearnerSGDE::regularizationConfig
protected
sgpp::solver::SLESolverConfiguration sgpp::datadriven::LearnerSGDE::solverConfig
protected

Referenced by LearnerSGDE(), train(), and trainOnline().

std::shared_ptr<base::DataMatrix> sgpp::datadriven::LearnerSGDE::trainData
protected
std::shared_ptr<base::DataVector> sgpp::datadriven::LearnerSGDE::trainLabels
protected

Referenced by trainOnline().

bool sgpp::datadriven::LearnerSGDE::usePrior
protected

Referenced by predict(), and trainOnline().


The documentation for this class was generated from the following files: