SG++-Doxygen-Documentation
sgpp::datadriven::LearnerSVM Class Reference

LearnerSVM learns from data using support vector machines with sparse grid kernels.

#include <LearnerSVM.hpp>

Public Member Functions

double getAccuracy (sgpp::base::DataMatrix &testDataset, const sgpp::base::DataVector &referenceLabels, const double threshold)
 Computes the classification accuracy on the given dataset.
 
double getAccuracy (const sgpp::base::DataVector &referenceLabels, const double threshold, const sgpp::base::DataVector &predictedLabels)
 Computes the classification accuracy.
 
double getError (sgpp::base::DataMatrix &data, sgpp::base::DataVector &labels, std::string errorType)
 Computes the specified error type (e.g. MSE).
 
void initialize (size_t budget)
 Initializes the SVM learner.
 
 LearnerSVM (base::RegularGridConfiguration &gridConfig, base::AdaptivityConfiguration &adaptConfig, base::DataMatrix &pTrainData, base::DataVector &pTrainLabels, base::DataMatrix &pTestData, base::DataVector &pTestLabels, base::DataMatrix *pValidData, base::DataVector *pValidLabels)
 Constructor.
 
void predict (sgpp::base::DataMatrix &testData, sgpp::base::DataVector &predictedLabels)
 Predicts class labels based on the trained model.
 
void storeResults (sgpp::base::DataMatrix &testDataset)
 Stores classified data, grids and function evaluations in CSV files.
 
void train (size_t maxDataPasses, double lambda, double betaRef, std::string refType, std::string refMonitor, size_t refPeriod, double errorDeclineThreshold, size_t errorDeclineBufferSize, size_t minRefInterval)
 Implements support vector learning with sparse grid kernels.
 
 ~LearnerSVM ()
 Destructor.
 

Public Attributes

sgpp::base::DataVector avgErrors
 
double error
 

Protected Member Functions

std::unique_ptr< base::Grid > createRegularGrid ()
 Generates a regular sparse grid.
 

Protected Attributes

base::AdaptivityConfiguration adaptivityConfig
 
std::unique_ptr< base::Grid > grid
 
base::RegularGridConfiguration gridConfig
 
std::unique_ptr< PrimalDualSVM > svm
 
base::DataMatrix & testData
 
base::DataVector & testLabels
 
base::DataMatrix & trainData
 
base::DataVector & trainLabels
 
base::DataMatrix * validData
 
base::DataVector * validLabels
 

Detailed Description

LearnerSVM learns from data using support vector machines with sparse grid kernels.

The learning algorithm is the Pegasos method.
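For orientation, the Pegasos update can be sketched as follows. This is a minimal linear (primal) version, not the SG++ implementation: a plain dot product stands in for the sparse grid kernel, points are visited in order during each data pass rather than sampled at random, and the function name pegasosTrain is hypothetical. At step t the learning rate is 1/(lambda * t); every step shrinks w, and points with margin below 1 additionally pull w towards y_i * x_i.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: linear Pegasos (stochastic sub-gradient descent on the
// regularized hinge loss). Plain dot product instead of a sparse grid kernel.
std::vector<double> pegasosTrain(const std::vector<std::vector<double>>& X,
                                 const std::vector<double>& y,
                                 double lambda, std::size_t maxDataPasses) {
  assert(!X.empty() && X.size() == y.size() && lambda > 0.0);
  const std::size_t dim = X[0].size();
  std::vector<double> w(dim, 0.0);
  std::size_t t = 0;
  for (std::size_t pass = 0; pass < maxDataPasses; ++pass) {
    for (std::size_t i = 0; i < X.size(); ++i) {
      ++t;
      const double eta = 1.0 / (lambda * static_cast<double>(t));
      // Margin y_i * <w, x_i> evaluated with the weights before this step.
      double margin = 0.0;
      for (std::size_t d = 0; d < dim; ++d) margin += w[d] * X[i][d];
      margin *= y[i];
      // Gradient of the regularizer: shrink w on every step.
      for (double& wd : w) wd *= (1.0 - eta * lambda);
      // Hinge loss active: move towards the correctly signed point.
      if (margin < 1.0) {
        for (std::size_t d = 0; d < dim; ++d) w[d] += eta * y[i] * X[i][d];
      }
    }
  }
  return w;
}
```

In train(), maxDataPasses and lambda play the roles of the loop count and regularization parameter shown here; the grid refinement that train() interleaves with learning is omitted from this sketch.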

Constructor & Destructor Documentation

sgpp::datadriven::LearnerSVM::LearnerSVM ( base::RegularGridConfiguration &gridConfig,
base::AdaptivityConfiguration &adaptConfig,
base::DataMatrix &pTrainData,
base::DataVector &pTrainLabels,
base::DataMatrix &pTestData,
base::DataVector &pTestLabels,
base::DataMatrix *pValidData,
base::DataVector *pValidLabels
)

Constructor.

Parameters
gridConfig: The grid configuration
adaptConfig: The refinement configuration
pTrainData: The training dataset
pTrainLabels: The corresponding training labels
pTestData: The test dataset
pTestLabels: The corresponding test labels
pValidData: The validation dataset
pValidLabels: The corresponding validation labels

sgpp::datadriven::LearnerSVM::~LearnerSVM ( )

Destructor.

Member Function Documentation

std::unique_ptr< base::Grid > sgpp::datadriven::LearnerSVM::createRegularGrid ( )
protected

Generates a regular sparse grid.

References gridConfig.

Referenced by initialize().

double sgpp::datadriven::LearnerSVM::getAccuracy ( sgpp::base::DataMatrix &testDataset,
const sgpp::base::DataVector &referenceLabels,
const double threshold
)

Computes the classification accuracy on the given dataset.

Parameters
testDataset: The data for which class labels should be predicted
referenceLabels: The corresponding actual class labels
threshold: The decision threshold (e.g. threshold = 0 for class labels -1 and 1)
Returns
The resulting accuracy

References sgpp::base::DataMatrix::getNrows(), and predict().

Referenced by train().

double sgpp::datadriven::LearnerSVM::getAccuracy ( const sgpp::base::DataVector &referenceLabels,
const double threshold,
const sgpp::base::DataVector &predictedLabels
)

Computes the classification accuracy.

Parameters
referenceLabels: The actual class labels
threshold: The decision threshold (e.g. threshold = 0 for class labels -1 and 1)
predictedLabels: The predicted class labels
Returns
The resulting accuracy

References sgpp::base::DataVector::get() and sgpp::base::DataVector::getSize().
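The role of the threshold can be illustrated with a small standalone sketch. It assumes that a prediction counts as correct when it lies on the same side of the threshold as the reference label and that the accuracy is returned as a fraction in [0, 1]; SG++ may handle ties or scaling differently, and thresholdAccuracy is a hypothetical name operating on std::vector instead of sgpp::base::DataVector.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: fraction of predictions on the same side of
// `threshold` as the reference label (values >= threshold count as positive).
double thresholdAccuracy(const std::vector<double>& referenceLabels, double threshold,
                         const std::vector<double>& predictedLabels) {
  assert(referenceLabels.size() == predictedLabels.size());
  if (referenceLabels.empty()) return 0.0;
  std::size_t correct = 0;
  for (std::size_t i = 0; i < referenceLabels.size(); ++i) {
    const bool refPositive = referenceLabels[i] >= threshold;
    const bool predPositive = predictedLabels[i] >= threshold;
    if (refPositive == predPositive) ++correct;
  }
  return static_cast<double>(correct) / static_cast<double>(referenceLabels.size());
}
```

For labels -1 and 1 with threshold 0, reference labels {1, -1, 1, -1} and predictions {1, 1, 1, -1} give an accuracy of 0.75.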

double sgpp::datadriven::LearnerSVM::getError ( sgpp::base::DataMatrix &data,
sgpp::base::DataVector &labels,
std::string errorType
)

Computes the specified error type (e.g. MSE).

Parameters
data: The data points
labels: The corresponding class labels
errorType: The type of error measurement (MSE or hinge loss)
Returns
The error estimation

References error, sgpp::base::DataVector::get(), sgpp::base::DataMatrix::getNcols(), sgpp::base::DataMatrix::getNrows(), sgpp::base::DataMatrix::getRow(), grid, sgpp::base::DataVector::set(), sgpp::base::DataVector::setAll(), and svm.

Referenced by train().
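The two error types named for getError() can be written out as follows. This sketch works on precomputed decision values f(x_i) rather than on a grid and an SVM, averages over the number of points (an assumption), and uses the hypothetical name errorMeasure; the accepted strings "MSE" and "Hinge" are illustrative and not necessarily the values SG++ expects.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper: mean squared error or mean hinge loss of decision
// values f(x_i) against labels y_i. The errorType strings are assumptions.
double errorMeasure(const std::vector<double>& labels,
                    const std::vector<double>& decisionValues,
                    const std::string& errorType) {
  assert(!labels.empty() && labels.size() == decisionValues.size());
  const bool useMSE = (errorType == "MSE");
  if (!useMSE && errorType != "Hinge") {
    throw std::invalid_argument("unknown error type: " + errorType);
  }
  double sum = 0.0;
  for (std::size_t i = 0; i < labels.size(); ++i) {
    if (useMSE) {
      const double diff = labels[i] - decisionValues[i];
      sum += diff * diff;  // squared residual
    } else {
      sum += std::max(0.0, 1.0 - labels[i] * decisionValues[i]);  // hinge loss
    }
  }
  return sum / static_cast<double>(labels.size());
}
```

With labels {1, -1} and decision values {0.5, -2.0}, the MSE is (0.25 + 1) / 2 = 0.625, while the hinge loss is (0.5 + 0) / 2 = 0.25: the second point is beyond the margin and contributes nothing to the hinge loss, but it is still penalized by the MSE.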

void sgpp::datadriven::LearnerSVM::initialize ( size_t  budget)

Initializes the SVM learner.

Parameters
budget: The maximum number of stored support vectors

References createRegularGrid(), sgpp::base::DataMatrix::getNcols(), grid, svm, and trainData.
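One way to picture a support vector budget is a store that never holds more than budget entries. The sketch below is hypothetical throughout: the class BudgetedSVStore and its eviction policy (drop the vector with the smallest absolute coefficient) are assumptions chosen for illustration, and the strategy used by PrimalDualSVM may differ.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

struct SupportVector {
  std::vector<double> x;  // stored data point
  double alpha;           // signed expansion coefficient
};

// Hypothetical budget-bounded store. Eviction policy (smallest |alpha|) is an assumption.
class BudgetedSVStore {
 public:
  explicit BudgetedSVStore(std::size_t budget) : budget_(budget) { assert(budget > 0); }

  void add(std::vector<double> x, double alpha) {
    svs_.push_back({std::move(x), alpha});
    if (svs_.size() > budget_) {
      // Over budget: remove the vector that contributes least to the expansion.
      std::size_t weakest = 0;
      for (std::size_t i = 1; i < svs_.size(); ++i) {
        if (std::fabs(svs_[i].alpha) < std::fabs(svs_[weakest].alpha)) weakest = i;
      }
      svs_.erase(svs_.begin() + static_cast<std::ptrdiff_t>(weakest));
    }
  }

  std::size_t size() const { return svs_.size(); }
  const std::vector<SupportVector>& vectors() const { return svs_; }

 private:
  std::size_t budget_;
  std::vector<SupportVector> svs_;
};
```

A budget bounds both memory use and prediction cost, since each prediction evaluates the kernel once per stored support vector.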

void sgpp::datadriven::LearnerSVM::predict ( sgpp::base::DataMatrix &testData,
sgpp::base::DataVector &predictedLabels
)

Predicts class labels based on the trained model.

Parameters
testData: The data for which class labels should be predicted
predictedLabels: The predicted class labels

References sgpp::base::DataMatrix::getNcols(), sgpp::base::DataMatrix::getNrows(), sgpp::base::DataMatrix::getRow(), grid, sgpp::base::DataVector::set(), and svm.

Referenced by getAccuracy(), storeResults(), and train().
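Prediction with a sparse grid kernel can be sketched in one dimension. The assumptions here are that the kernel is the inner product of the feature vectors formed by the hierarchical hat functions phi_{l,i}(x) = max(0, 1 - |2^l x - i|) on [0, 1] (odd i, levels 1 to maxLevel), and that the label is the sign of f(x) = sum_j c_j K(sv_j, x). The names hatFeatures, gridKernel and predictLabels are hypothetical; the real predict() works on a d-dimensional base::Grid and the PrimalDualSVM model.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical 1D feature map: values of all hierarchical hat functions at x.
std::vector<double> hatFeatures(double x, int maxLevel) {
  std::vector<double> phi;
  for (int l = 1; l <= maxLevel; ++l) {
    const double scale = std::ldexp(1.0, l);  // 2^l
    for (int i = 1; i < (1 << l); i += 2) {
      phi.push_back(std::fmax(0.0, 1.0 - std::fabs(scale * x - i)));
    }
  }
  return phi;
}

// Kernel as the inner product of the two feature vectors.
double gridKernel(double x, double y, int maxLevel) {
  const std::vector<double> px = hatFeatures(x, maxLevel);
  const std::vector<double> py = hatFeatures(y, maxLevel);
  double k = 0.0;
  for (std::size_t j = 0; j < px.size(); ++j) k += px[j] * py[j];
  return k;
}

// Labels in {-1, 1} from the sign of the kernel expansion f(x).
std::vector<double> predictLabels(const std::vector<double>& testData,
                                  const std::vector<double>& supportVectors,
                                  const std::vector<double>& coefficients, int maxLevel) {
  assert(supportVectors.size() == coefficients.size());
  std::vector<double> predicted;
  predicted.reserve(testData.size());
  for (double x : testData) {
    double f = 0.0;
    for (std::size_t j = 0; j < supportVectors.size(); ++j) {
      f += coefficients[j] * gridKernel(supportVectors[j], x, maxLevel);
    }
    predicted.push_back(f >= 0.0 ? 1.0 : -1.0);
  }
  return predicted;
}
```

With support vectors 0.25 (coefficient 1) and 0.75 (coefficient -1) at maxLevel 2, the point 0.2 gets f = 1.0 - 0.2 = 0.8 and label 1, and the point 0.8 gets f = -0.8 and label -1.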

void sgpp::datadriven::LearnerSVM::train ( size_t maxDataPasses,
double lambda,
double betaRef,
std::string refType,
std::string refMonitor,
size_t refPeriod,
double errorDeclineThreshold,
size_t errorDeclineBufferSize,
size_t minRefInterval
)

Implements support vector learning with sparse grid kernels.

Parameters
maxDataPasses: The number of passes over the whole training data
lambda: The regularization parameter
betaRef: Weighting factor for grid points; used within combined-measure refinement
refType: The refinement indicator (surplus, zero-crossings or data-based)
refMonitor: The refinement strategy (periodic or convergence-based)
refPeriod: The refinement interval (if periodic refinement is chosen)
errorDeclineThreshold: The convergence threshold (if convergence-based refinement is chosen)
errorDeclineBufferSize: The number of error measurements used to check convergence (if convergence-based refinement is chosen)
minRefInterval: The minimum number of data points that must be processed before the next refinement can be scheduled (if convergence-based refinement is chosen)

References adaptivityConfig, sgpp::base::DataVector::append(), avgErrors, error, sgpp::base::ImpurityRefinement::free_refine(), sgpp::base::ForwardSelectorRefinement::free_refine(), sgpp::base::DataVector::get(), getAccuracy(), getError(), sgpp::base::DataMatrix::getNcols(), sgpp::base::DataMatrix::getNrows(), sgpp::base::DataMatrix::getRow(), grid, sgpp::base::AdaptivityConfiguration::noPoints_, predict(), sgpp::datadriven::RefinementMonitor::pushToBuffer(), sgpp::datadriven::RefinementMonitor::refinementsNecessary(), svm, testData, testLabels, sgpp::base::AdaptivityConfiguration::threshold_, trainData, trainLabels, validData, and validLabels.
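How errorDeclineThreshold, errorDeclineBufferSize and minRefInterval could interact in convergence-based refinement is sketched below. The class ConvergenceMonitor and its methods are hypothetical and only loosely echo RefinementMonitor::pushToBuffer() and refinementsNecessary(); the decline measure (relative drop between the oldest and newest buffered error) is an assumption, not the SG++ criterion.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>

// Hypothetical monitor: request a refinement once the error stops declining noticeably.
class ConvergenceMonitor {
 public:
  ConvergenceMonitor(double errorDeclineThreshold, std::size_t errorDeclineBufferSize,
                     std::size_t minRefInterval)
      : threshold_(errorDeclineThreshold),
        bufferSize_(errorDeclineBufferSize),
        minRefInterval_(minRefInterval) {
    assert(bufferSize_ >= 2);
  }

  // Record the current error after one more processed data point.
  void pushError(double err) {
    ++pointsSinceRefinement_;
    buffer_.push_back(err);
    if (buffer_.size() > bufferSize_) buffer_.pop_front();  // keep only the latest values
  }

  // True if enough points were processed and the relative error decline across
  // the full buffer is below the threshold; then a new interval starts.
  bool refinementNecessary() {
    if (pointsSinceRefinement_ < minRefInterval_ || buffer_.size() < bufferSize_) return false;
    const double oldest = buffer_.front();
    const double newest = buffer_.back();
    const double decline = oldest > 0.0 ? (oldest - newest) / oldest : 0.0;
    if (decline >= threshold_) return false;
    pointsSinceRefinement_ = 0;
    buffer_.clear();
    return true;
  }

 private:
  double threshold_;
  std::size_t bufferSize_;
  std::size_t minRefInterval_;
  std::size_t pointsSinceRefinement_ = 0;
  std::deque<double> buffer_;
};
```

With a threshold of 0.1 and a buffer of three errors, the sequence 1.0, 0.5, 0.45, 0.44 keeps declining fast enough, while appending 0.435 leaves a relative decline of about 3 % across the buffer and triggers a refinement. Periodic refinement (refPeriod) would instead simply refine after every refPeriod processed points.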

Member Data Documentation

base::AdaptivityConfiguration sgpp::datadriven::LearnerSVM::adaptivityConfig
protected

Referenced by train().

sgpp::base::DataVector sgpp::datadriven::LearnerSVM::avgErrors

Referenced by train().

double sgpp::datadriven::LearnerSVM::error

Referenced by getError(), and train().

base::RegularGridConfiguration sgpp::datadriven::LearnerSVM::gridConfig
protected

Referenced by createRegularGrid().

std::unique_ptr<PrimalDualSVM> sgpp::datadriven::LearnerSVM::svm
protected

base::DataMatrix& sgpp::datadriven::LearnerSVM::testData
protected

Referenced by train().

base::DataVector& sgpp::datadriven::LearnerSVM::testLabels
protected

Referenced by train().

base::DataVector& sgpp::datadriven::LearnerSVM::trainLabels
protected

Referenced by train().

base::DataMatrix* sgpp::datadriven::LearnerSVM::validData
protected

Referenced by train().

base::DataVector* sgpp::datadriven::LearnerSVM::validLabels
protected

Referenced by train().


The documentation for this class was generated from the following files: