SG++
sgpp::datadriven::LearnerSGDEOnOff Class Reference

LearnerSGDEOnOff learns the data using sparse grid density estimation.

#include <LearnerSGDEOnOff.hpp>

Derived class: sgpp::datadriven::LearnerSGDEOnOffParallel

Public Member Functions

double getAccuracy () const
 Returns the accuracy of the classifier measured on the test data.
 
void getAvgErrors (DataVector &result) const
 
void getDensities (DataVector &point, DataVector &density) const
 Returns the values of all density functions for a specified data point.
 
ClassDensityConntainer & getDensityFunctions ()
 Returns the density functions mapped to class labels.
 
double getError (Dataset &dataset) const
 Error evaluation required for convergence-based refinement.
 
size_t getNumClasses () const
 Returns the number of classes.
 
 LearnerSGDEOnOff (sgpp::base::RegularGridConfiguration &gridConfig, sgpp::base::AdpativityConfiguration &adaptivityConfig, sgpp::datadriven::RegularizationConfiguration &regularizationConfig, sgpp::datadriven::DensityEstimationConfiguration &densityEstimationConfig, Dataset &trainData, Dataset &testData, Dataset *validationData, DataVector &classLabels, size_t classNumber, bool usePrior, double beta, std::string matrixfile="")
 Constructor.
 
void predict (DataMatrix &test, DataVector &classLabels) const
 Predicts the class labels of the test data points.
 
int predict (DataVector &p) const
 Predicts the class label of the given data point.
 
void setCrossValidationParameters (int lambdaStep, double lambdaStart, double lambdaEnd, DataMatrix *test, DataMatrix *testRes, bool logscale)
 Sets the cross-validation parameters.
 
void storeResults ()
 Stores classified data, grids and density function evaluations to CSV files.
 
void train (size_t batchSize, size_t maxDataPasses, std::string refType, std::string refMonitor, size_t refPeriod, double accDeclineThreshold, size_t accDeclineBufferSize, size_t minRefInterval, bool enableCv, size_t nextCvStep)
 Trains the learner with the given dataset.
 
void train (Dataset &dataset, bool doCv=false, std::vector< std::pair< std::list< size_t >, size_t >> *refineCoarse=nullptr)
 Trains the learner with the given data batch.
 
void train (std::vector< std::pair< DataMatrix *, double >> &trainDataClasses, bool doCv=false, std::vector< std::pair< std::list< size_t >, size_t >> *refineCoarse=nullptr)
 Trains the learner with the given data batch that is already split up with respect to its classes.
 

Protected Member Functions

void refine (ConvergenceMonitor &monitor, std::vector< std::pair< std::list< size_t >, size_t >> &refineCoarse, std::string &refType)
 

Protected Attributes

DataVector avgErrors
 
double beta
 
DataVector classLabels
 
ClassDensityConntainer densityFunctions
 
size_t numClasses
 
std::unique_ptr< DBMatOffline > offline
 
std::vector< std::unique_ptr< DBMatOffline > > offlineContainer
 
std::map< double, double > prior
 
size_t processedPoints
 
Dataset & testData
 
Dataset & trainData
 
bool trained
 
bool usePrior
 
Dataset * validationData
 

Detailed Description

LearnerSGDEOnOff learns the data using sparse grid density estimation.

The system matrix is precomputed and factorized using an eigen, LU or Cholesky decomposition (offline step). Then, in every iteration, a density function is computed for each class by solving the system (online step). If the Cholesky decomposition is chosen, refinement and coarsening can also be applied.
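The offline/online split can be illustrated with a minimal, library-independent C++ sketch of the Cholesky case: the system matrix is factorized once, and the factor is then reused to solve for new right-hand sides, e.g. one per class. This is not SG++ code; `Matrix`, `choleskyFactor` and `choleskySolve` are hypothetical names, and regularization, refinement and the library's own DataMatrix/DBMatOffline types are left out.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Dense square matrix stored row by row (hypothetical helper type).
using Matrix = std::vector<std::vector<double>>;

// Offline step (sketch): factorize a symmetric positive definite matrix
// A = L * L^T once and return the lower-triangular factor L.
Matrix choleskyFactor(const Matrix& a) {
  const std::size_t n = a.size();
  Matrix l(n, std::vector<double>(n, 0.0));
  for (std::size_t j = 0; j < n; ++j) {
    double diag = a[j][j];
    for (std::size_t k = 0; k < j; ++k) diag -= l[j][k] * l[j][k];
    assert(diag > 0.0 && "matrix must be symmetric positive definite");
    l[j][j] = std::sqrt(diag);
    for (std::size_t i = j + 1; i < n; ++i) {
      double s = a[i][j];
      for (std::size_t k = 0; k < j; ++k) s -= l[i][k] * l[j][k];
      l[i][j] = s / l[j][j];
    }
  }
  return l;
}

// Online step (sketch): reuse the factor to solve A * x = b for a new
// right-hand side b by forward substitution followed by back substitution.
std::vector<double> choleskySolve(const Matrix& l, const std::vector<double>& b) {
  const std::size_t n = l.size();
  std::vector<double> y(n), x(n);
  for (std::size_t i = 0; i < n; ++i) {  // solve L * y = b
    double s = b[i];
    for (std::size_t k = 0; k < i; ++k) s -= l[i][k] * y[k];
    y[i] = s / l[i][i];
  }
  for (std::size_t i = n; i-- > 0;) {  // solve L^T * x = y
    double s = y[i];
    for (std::size_t k = i + 1; k < n; ++k) s -= l[k][i] * x[k];
    x[i] = s / l[i][i];
  }
  return x;
}
```

The factorization costs O(n^3) operations while each solve costs only O(n^2), so reusing the factor pays off when many right-hand sides (classes or iterations) share one system matrix.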

Constructor & Destructor Documentation

sgpp::datadriven::LearnerSGDEOnOff::LearnerSGDEOnOff ( sgpp::base::RegularGridConfiguration & gridConfig,
sgpp::base::AdpativityConfiguration & adaptivityConfig,
sgpp::datadriven::RegularizationConfiguration & regularizationConfig,
sgpp::datadriven::DensityEstimationConfiguration & densityEstimationConfig,
Dataset & trainData,
Dataset & testData,
Dataset * validationData,
DataVector & classLabels,
size_t  classNumber,
bool  usePrior,
double  beta,
std::string  matrixfile = "" 
)

Constructor.

Parameters
gridConfig  The configuration of the grid
adaptivityConfig  The configuration of the grid adaptivity
regularizationConfig  The configuration of the regularization
densityEstimationConfig  The configuration of the matrix decomposition
trainData  The (mandatory) training dataset
testData  The (mandatory) test dataset
validationData  The (optional) validation dataset
classLabels  The class labels (e.g. -1, 1)
classNumber  The total number of classes
usePrior  Determines whether prior probabilities are used to compute class labels
beta  The initial weighting factor
matrixfile  Path to a file containing a decomposed matrix

References avgErrors, beta, sgpp::datadriven::DBMatOnlineDEFactory::buildDBMatOnlineDE(), sgpp::datadriven::DBMatOfflineFactory::buildFromFile(), sgpp::datadriven::DBMatOfflineFactory::buildOfflineObject(), classLabels, densityFunctions, numClasses, offline, offlineContainer, prior, processedPoints, testData, trained, usePrior, and validationData.

Member Function Documentation

double sgpp::datadriven::LearnerSGDEOnOff::getAccuracy ( ) const

Returns the accuracy of the classifier measured on the test data.

Returns
The classification accuracy measured on the test data

References sgpp::base::DataVector::get(), sgpp::datadriven::Dataset::getData(), sgpp::datadriven::Dataset::getNumberInstances(), sgpp::datadriven::Dataset::getTargets(), predict(), and testData.

Referenced by train().
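A minimal sketch of the accuracy measure, assuming it is the fraction of test points whose predicted label matches the target label. `classificationAccuracy` is a hypothetical helper operating on plain vectors, not part of SG++.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Sketch: fraction of points whose predicted label equals the target label.
// Labels are small integral values stored as doubles (e.g. -1.0, 1.0), so
// exact comparison is used here.
double classificationAccuracy(const std::vector<double>& predicted,
                              const std::vector<double>& targets) {
  assert(predicted.size() == targets.size() && !targets.empty());
  std::size_t correct = 0;
  for (std::size_t i = 0; i < targets.size(); ++i) {
    if (predicted[i] == targets[i]) ++correct;
  }
  return static_cast<double>(correct) / static_cast<double>(targets.size());
}
```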

void sgpp::datadriven::LearnerSGDEOnOff::getAvgErrors ( DataVector & result) const

References avgErrors.

void sgpp::datadriven::LearnerSGDEOnOff::getDensities ( DataVector & point,
DataVector & density 
) const

Returns the values of all density functions for a specified data point.

Parameters
point  The point at which the density functions are evaluated
density  The resulting function evaluations

References densityFunctions.

double sgpp::datadriven::LearnerSGDEOnOff::getError ( Dataset & dataset) const

Error evaluation required for convergence-based refinement.

Parameters
dataset  The data on which the error is measured
Returns
The error evaluation

References sgpp::base::DataVector::get(), sgpp::datadriven::Dataset::getData(), sgpp::datadriven::Dataset::getNumberInstances(), sgpp::datadriven::Dataset::getTargets(), and predict().

Referenced by sgpp::datadriven::RefinementHandler::checkRefinementNecessary(), and train().

size_t sgpp::datadriven::LearnerSGDEOnOff::getNumClasses ( ) const

void sgpp::datadriven::LearnerSGDEOnOff::predict ( DataMatrix & test,
DataVector & classLabels 
) const

Predicts the class labels of the test data points.

Parameters
test  The data points whose class labels are predicted
classLabels  Vector containing the predicted class labels

References densityFunctions, sgpp::base::DataMatrix::getNrows(), numClasses, and prior.

Referenced by getAccuracy(), getError(), predict(), and storeResults().
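The reference only shows that the batch predict() uses densityFunctions, numClasses and prior. A plausible decision rule, sketched here under that assumption, picks the class with the largest density value, weighted by the class prior when usePrior is set. `predictLabel` is a hypothetical helper; it returns the label as a double, whereas the single-point predict() returns an int.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Sketch of a density-based decision rule. densities[c] is the value of the
// density function of class classLabels[c] at the point being classified.
// With usePrior, each density is multiplied by the prior of its class.
double predictLabel(const std::vector<double>& densities,
                    const std::vector<double>& classLabels,
                    const std::map<double, double>& prior, bool usePrior) {
  assert(!densities.empty() && densities.size() == classLabels.size());
  std::size_t best = 0;
  double bestScore = 0.0;
  for (std::size_t c = 0; c < densities.size(); ++c) {
    double score = densities[c];
    // std::map::at throws std::out_of_range if no prior is stored for the label
    if (usePrior) score *= prior.at(classLabels[c]);
    if (c == 0 || score > bestScore) {
      best = c;
      bestScore = score;
    }
  }
  return classLabels[best];
}
```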

int sgpp::datadriven::LearnerSGDEOnOff::predict ( DataVector & p) const

Predicts the class label of the given data point.

Parameters
p  The data point
Returns
The predicted class label

References sgpp::base::DataVector::getSize(), predict(), and sgpp::base::DataMatrix::setRow().

void sgpp::datadriven::LearnerSGDEOnOff::setCrossValidationParameters ( int  lambdaStep,
double  lambdaStart,
double  lambdaEnd,
DataMatrix * test,
DataMatrix * testRes,
bool  logscale 
)

Sets the cross-validation parameters.

They are passed directly to the DBMatOnlineDE instance.

Parameters
lambdaStep  Defines how many different lambdas are tried out
lambdaStart  The smallest possible lambda
lambdaEnd  The largest possible lambda
test  The test matrix
testRes  The results of the points in the test matrix
logscale  Indicates whether the values between lambdaStart and lambdaEnd are searched on a logarithmic scale

References densityFunctions.
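One way to turn lambdaStep, lambdaStart, lambdaEnd and logscale into a list of candidate lambdas is sketched below. That both endpoints are included and that the logarithmic scale is base 10 are assumptions; `lambdaCandidates` is a hypothetical helper, not the DBMatOnlineDE implementation.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Sketch: lambdaStep candidate regularization parameters from lambdaStart to
// lambdaEnd (both endpoints included), spaced evenly either on a linear or on
// a log10 scale.
std::vector<double> lambdaCandidates(int lambdaStep, double lambdaStart,
                                     double lambdaEnd, bool logscale) {
  assert(lambdaStep >= 2 && lambdaStart < lambdaEnd);
  assert(!logscale || lambdaStart > 0.0);  // log10 needs positive values
  const double lo = logscale ? std::log10(lambdaStart) : lambdaStart;
  const double hi = logscale ? std::log10(lambdaEnd) : lambdaEnd;
  std::vector<double> lambdas;
  lambdas.reserve(static_cast<std::size_t>(lambdaStep));
  for (int i = 0; i < lambdaStep; ++i) {
    const double t = lo + (hi - lo) * i / (lambdaStep - 1);
    lambdas.push_back(logscale ? std::pow(10.0, t) : t);
  }
  return lambdas;
}
```

A logarithmic grid is the usual choice when the plausible lambdas span several orders of magnitude, since a linear grid would place almost all candidates near lambdaEnd.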

void sgpp::datadriven::LearnerSGDEOnOff::train ( size_t  batchSize,
size_t  maxDataPasses,
std::string  refType,
std::string  refMonitor,
size_t  refPeriod,
double  accDeclineThreshold,
size_t  accDeclineBufferSize,
size_t  minRefInterval,
bool  enableCv,
size_t  nextCvStep 
)

Trains the learner with the given dataset.

Parameters
batchSize  Size of the subset of data points used in each training step
maxDataPasses  The number of passes over the whole training data
refType  The refinement indicator (surplus, zero-crossings or data-based)
refMonitor  The refinement strategy (periodic or convergence-based)
refPeriod  The refinement interval (if periodic refinement is chosen)
accDeclineThreshold  The convergence threshold (if convergence-based refinement is chosen)
accDeclineBufferSize  The number of accuracy measurements used to check convergence (if convergence-based refinement is chosen)
minRefInterval  The minimum number of data points (or data batches) that have to be processed before the next refinement can be scheduled (if convergence-based refinement is chosen)
enableCv  Specifies whether cross-validation is performed during training
nextCvStep  Determines when the next cross-validation is triggered

References sgpp::base::DataVector::append(), avgErrors, sgpp::base::DataVector::get(), getAccuracy(), sgpp::datadriven::Dataset::getData(), getDensityFunctions(), sgpp::datadriven::Dataset::getDimension(), getError(), sgpp::datadriven::Dataset::getNumberInstances(), sgpp::base::DataMatrix::getRow(), sgpp::base::Grid::getSize(), sgpp::datadriven::Dataset::getTargets(), grid(), numClasses, offline, processedPoints, refine(), sgpp::base::DataVector::set(), sgpp::base::DataMatrix::setRow(), trainData, and validationData.

Referenced by train().
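The interplay of batchSize, maxDataPasses and refPeriod under periodic refinement can be sketched as the loop below. It assumes that refPeriod counts processed batches and that the data is traversed in order; the convergence-based monitor (accDeclineThreshold, accDeclineBufferSize, minRefInterval) and cross-validation are not modelled. `trainPeriodic` and its callbacks are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>

// Sketch of a mini-batch training loop with periodic refinement: the training
// data is traversed maxDataPasses times in chunks of at most batchSize points,
// and refine() is called after every refPeriod batches.
// trainBatch receives the half-open index range [begin, end) of each batch.
// Returns the number of refinement steps that were triggered.
std::size_t trainPeriodic(
    std::size_t numInstances, std::size_t batchSize, std::size_t maxDataPasses,
    std::size_t refPeriod,
    const std::function<void(std::size_t, std::size_t)>& trainBatch,
    const std::function<void()>& refine) {
  assert(batchSize > 0 && refPeriod > 0);
  std::size_t batchCount = 0;
  std::size_t refinements = 0;
  for (std::size_t pass = 0; pass < maxDataPasses; ++pass) {
    for (std::size_t begin = 0; begin < numInstances; begin += batchSize) {
      const std::size_t end = std::min(begin + batchSize, numInstances);
      trainBatch(begin, end);
      ++batchCount;
      if (batchCount % refPeriod == 0) {
        refine();
        ++refinements;
      }
    }
  }
  return refinements;
}
```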

void sgpp::datadriven::LearnerSGDEOnOff::train ( Dataset & dataset,
bool  doCv = false,
std::vector< std::pair< std::list< size_t >, size_t >> *  refineCoarse = nullptr 
)

Trains the learner with the given data batch.

Parameters
dataset  The next data batch to process
doCv  Enables cross-validation
refineCoarse  Vector of pairs, each containing a list of the indices of removed grid points and an unsigned integer representing the added grid points

References classLabels, sgpp::datadriven::Dataset::getData(), sgpp::datadriven::Dataset::getDimension(), sgpp::datadriven::Dataset::getNumberInstances(), sgpp::base::DataMatrix::getRow(), sgpp::base::DataVector::getSize(), sgpp::datadriven::Dataset::getTargets(), and train().
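This overload presumably groups the batch by class before handing it to the overload that takes per-class data (it references classLabels, getTargets() and train()). The grouping step can be sketched with plain vectors as follows; `Rows` and `splitByClass` are hypothetical, and skipping points whose label is not in classLabels is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Data points stored row by row (hypothetical helper type).
using Rows = std::vector<std::vector<double>>;

// Sketch: split a labelled batch into one block of points per class, in the
// order given by classLabels. Each result pair holds the points of one class
// and that class's label; points with an unknown label are skipped.
std::vector<std::pair<Rows, double>> splitByClass(
    const Rows& data, const std::vector<double>& targets,
    const std::vector<double>& classLabels) {
  assert(data.size() == targets.size());
  std::vector<std::pair<Rows, double>> result;
  for (double label : classLabels) result.emplace_back(Rows{}, label);
  for (std::size_t i = 0; i < data.size(); ++i) {
    for (auto& entry : result) {
      if (entry.second == targets[i]) {
        entry.first.push_back(data[i]);
        break;
      }
    }
  }
  return result;
}
```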

void sgpp::datadriven::LearnerSGDEOnOff::train ( std::vector< std::pair< DataMatrix *, double >> &  trainDataClasses,
bool  doCv = false,
std::vector< std::pair< std::list< size_t >, size_t >> *  refineCoarse = nullptr 
)

Trains the learner with the given data batch that is already split up with respect to its classes.

Parameters
trainDataClasses  A vector of pairs; each pair contains the data points that belong to one class and the corresponding class label
doCv  Enables cross-validation
refineCoarse  Vector of pairs, each containing a list of the indices of removed grid points and an unsigned integer representing the added grid points

References densityFunctions, prior, processedPoints, trained, and usePrior.
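Since this overload touches prior and processedPoints and depends on usePrior, a running prior estimate may look like the sketch below. It assumes that the prior of a class is its share of all points processed so far; `PriorTracker` is a hypothetical class, not SG++ code.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <utility>
#include <vector>

// Sketch: running estimate of class priors across successive batches. Each
// update receives (number of points, class label) pairs for one batch; the
// prior of a class is its share of all points processed so far.
class PriorTracker {
 public:
  void update(const std::vector<std::pair<std::size_t, double>>& batchCounts) {
    for (const auto& entry : batchCounts) {
      classCounts_[entry.second] += entry.first;
      processedPoints_ += entry.first;
    }
    if (processedPoints_ == 0) return;  // nothing seen yet, no prior defined
    for (const auto& entry : classCounts_) {
      prior_[entry.first] = static_cast<double>(entry.second) /
                            static_cast<double>(processedPoints_);
    }
  }
  const std::map<double, double>& prior() const { return prior_; }
  std::size_t processedPoints() const { return processedPoints_; }

 private:
  std::map<double, std::size_t> classCounts_;  // label -> accumulated count
  std::map<double, double> prior_;             // label -> estimated prior
  std::size_t processedPoints_ = 0;
};
```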

Member Data Documentation

DataVector sgpp::datadriven::LearnerSGDEOnOff::avgErrors
protected

double sgpp::datadriven::LearnerSGDEOnOff::beta
protected

Referenced by LearnerSGDEOnOff().

size_t sgpp::datadriven::LearnerSGDEOnOff::numClasses
protected

std::vector<std::unique_ptr<DBMatOffline> > sgpp::datadriven::LearnerSGDEOnOff::offlineContainer
protected

Referenced by LearnerSGDEOnOff().

std::map<double, double> sgpp::datadriven::LearnerSGDEOnOff::prior
protected

size_t sgpp::datadriven::LearnerSGDEOnOff::processedPoints
protected

Dataset& sgpp::datadriven::LearnerSGDEOnOff::testData
protected

bool sgpp::datadriven::LearnerSGDEOnOff::trained
protected

Dataset* sgpp::datadriven::LearnerSGDEOnOff::validationData
protected

The documentation for this class was generated from the following files: