SG++
sgpp::datadriven::LearnerSVM Class Reference

LearnerSVM learns from data using support vector machines with sparse grid kernels.

#include <LearnerSVM.hpp>

Public Member Functions

double getAccuracy (sgpp::base::DataMatrix &testDataset, const sgpp::base::DataVector &referenceLabels, const double threshold)
 Computes the classification accuracy on the given dataset.
 
double getAccuracy (const sgpp::base::DataVector &referenceLabels, const double threshold, const sgpp::base::DataVector &predictedLabels)
 Computes the classification accuracy.
 
double getError (sgpp::base::DataMatrix &data, sgpp::base::DataVector &labels, std::string errorType)
 Computes the specified error type (e.g. MSE or hinge loss).
 
void initialize (size_t budget)
 Initializes the SVM learner.
 
 LearnerSVM (base::RegularGridConfiguration &gridConfig, base::AdpativityConfiguration &adaptConfig, base::DataMatrix &pTrainData, base::DataVector &pTrainLabels, base::DataMatrix &pTestData, base::DataVector &pTestLabels, base::DataMatrix *pValidData, base::DataVector *pValidLabels)
 Constructor.
 
void predict (sgpp::base::DataMatrix &testData, sgpp::base::DataVector &predictedLabels)
 Predicts class labels based on the trained model.
 
void storeResults (sgpp::base::DataMatrix &testDataset)
 Stores classified data, grids and function evaluations in CSV files.
 
void train (size_t maxDataPasses, double lambda, double betaRef, std::string refType, std::string refMonitor, size_t refPeriod, double errorDeclineThreshold, size_t errorDeclineBufferSize, size_t minRefInterval)
 Implements support vector learning with sparse grid kernels.
 
 ~LearnerSVM ()
 Destructor.
 

Public Attributes

sgpp::base::DataVector avgErrors
 
double error
 

Protected Member Functions

std::unique_ptr< base::Grid > createRegularGrid ()
 Generates a regular sparse grid. More...
 

Protected Attributes

base::AdpativityConfiguration adaptivityConfig
 
std::unique_ptr< base::Grid > grid

base::RegularGridConfiguration gridConfig

std::unique_ptr< PrimalDualSVM > svm

base::DataMatrix & testData

base::DataVector & testLabels

base::DataMatrix & trainData

base::DataVector & trainLabels

base::DataMatrix * validData

base::DataVector * validLabels
 

Detailed Description

LearnerSVM learns from data using support vector machines with sparse grid kernels.

The Pegasos method is implemented as the learning algorithm.
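Pegasos is a stochastic sub-gradient method for the SVM objective. In its kernelized form, each step checks the margin of a single training point and, if the margin is violated, stores that point as a support vector. The C++ sketch below illustrates that idea under stated assumptions: the class KernelPegasosSketch, the generic Kernel callable, and the rule of dropping the oldest support vector once the budget (compare initialize()) is full are hypothetical. They do not reproduce SG++'s PrimalDualSVM or its sparse grid kernels.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Illustrative sketch only: kernelized Pegasos with a fixed support-vector budget.
// NOT the SG++ PrimalDualSVM API; names and the budget rule are hypothetical.
using Point = std::vector<double>;
using Kernel = std::function<double(const Point&, const Point&)>;

class KernelPegasosSketch {
 public:
  KernelPegasosSketch(Kernel kernel, double lambda, std::size_t budget)
      : kernel_(std::move(kernel)), lambda_(lambda), budget_(budget) {
    assert(lambda > 0.0 && budget > 0);
  }

  // One Pegasos step on the sample (x, y) with y in {-1, +1}.
  void update(const Point& x, double y) {
    ++t_;
    // Margin check y * f_t(x) < 1, with f_t as computed by decision().
    if (y * decision(x) < 1.0) {
      if (sv_.size() == budget_) {
        // Budget exhausted: drop the oldest support vector (hypothetical rule).
        sv_.erase(sv_.begin());
        svLabels_.erase(svLabels_.begin());
      }
      sv_.push_back(x);
      svLabels_.push_back(y);
    }
  }

  // f_t(x) = (1 / (lambda * t)) * sum_j y_j * K(x_j, x); 0 before any update.
  double decision(const Point& x) const {
    if (t_ == 0) return 0.0;
    double sum = 0.0;
    for (std::size_t j = 0; j < sv_.size(); ++j) {
      sum += svLabels_[j] * kernel_(sv_[j], x);
    }
    return sum / (lambda_ * static_cast<double>(t_));
  }

  // Class label in {-1, +1}, using threshold 0 on the decision value.
  double predict(const Point& x) const { return decision(x) > 0.0 ? 1.0 : -1.0; }

  std::size_t numSupportVectors() const { return sv_.size(); }

 private:
  Kernel kernel_;
  double lambda_;
  std::size_t budget_;
  std::size_t t_ = 0;             // number of update steps so far
  std::vector<Point> sv_;         // one entry per margin violation
  std::vector<double> svLabels_;  // labels y_j of the stored support vectors
};
```

For example, with a linear kernel (a plain dot product) on the 1-D points -2, -1, 1, 2 labelled -1, -1, 1, 1 and lambda = 0.1, five passes of update() produce a model that classifies all four points correctly. In LearnerSVM, a sparse grid kernel takes the place of the Kernel callable.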

Constructor & Destructor Documentation

sgpp::datadriven::LearnerSVM::LearnerSVM ( base::RegularGridConfiguration & gridConfig,
base::AdpativityConfiguration & adaptConfig,
base::DataMatrix & pTrainData,
base::DataVector & pTrainLabels,
base::DataMatrix & pTestData,
base::DataVector & pTestLabels,
base::DataMatrix * pValidData,
base::DataVector * pValidLabels 
)

Constructor.

Parameters
gridConfig: The grid configuration
adaptConfig: The refinement configuration
pTrainData: The training dataset
pTrainLabels: The corresponding training labels
pTestData: The test dataset
pTestLabels: The corresponding test labels
pValidData: The validation dataset
pValidLabels: The corresponding validation labels

sgpp::datadriven::LearnerSVM::~LearnerSVM ( )

Destructor.

Member Function Documentation

std::unique_ptr< base::Grid > sgpp::datadriven::LearnerSVM::createRegularGrid ( )
protected

Generates a regular sparse grid.

References gridConfig.

Referenced by initialize().
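To estimate how many grid points (and hence basis functions) a regular sparse grid implies, the sketch below counts them. It assumes the usual convention for regular sparse grids without boundary points: level vectors l with l_i >= 1 and |l|_1 <= n + d - 1, where each level vector contributes 2^(|l|_1 - d) points. The function names are hypothetical, and the code does not construct an SG++ base::Grid.

```cpp
#include <cassert>
#include <cstddef>

// Counts points of a regular sparse grid without boundary (assumed convention:
// l_i >= 1, |l|_1 <= level + dim - 1). Each level vector l adds prod_i 2^(l_i - 1) points.
// 'remaining' is the part of the |l|_1 budget still available for 'dim' dimensions.
std::size_t countPoints(std::size_t dim, std::size_t remaining) {
  if (dim == 0) return 1;
  std::size_t total = 0;
  // Every later dimension needs at least level 1, so reserve (dim - 1) of the budget.
  for (std::size_t l = 1; l + (dim - 1) <= remaining; ++l) {
    total += (std::size_t{1} << (l - 1)) * countPoints(dim - 1, remaining - l);
  }
  return total;
}

// Hypothetical helper: number of points of a regular sparse grid of the given level.
std::size_t regularSparseGridSize(std::size_t dim, std::size_t level) {
  assert(dim >= 1 && level >= 1);
  return countPoints(dim, level + dim - 1);
}
```

For instance, regularSparseGridSize(2, 3) returns 17, compared with 7 x 7 = 49 points for the corresponding full grid without boundary.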
double sgpp::datadriven::LearnerSVM::getAccuracy ( sgpp::base::DataMatrix & testDataset,
const sgpp::base::DataVector & referenceLabels,
const double  threshold 
)

Computes the classification accuracy on the given dataset.

Parameters
testDataset: The data for which class labels should be predicted
referenceLabels: The corresponding actual class labels
threshold: The decision threshold (e.g. threshold = 0 for class labels -1 and 1)
Returns
The resulting accuracy

References sgpp::base::DataMatrix::getNrows(), and predict().

Referenced by train().

double sgpp::datadriven::LearnerSVM::getAccuracy ( const sgpp::base::DataVector & referenceLabels,
const double  threshold,
const sgpp::base::DataVector & predictedLabels 
)

Computes the classification accuracy.

Parameters
referenceLabels: The actual class labels
threshold: The decision threshold (e.g. threshold = 0 for class labels -1 and 1)
predictedLabels: The predicted class labels
Returns
The resulting accuracy

References sgpp::base::DataVector::get(), and sgpp::base::DataVector::getSize().
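The role of the threshold can be illustrated with a small sketch. It assumes that a prediction counts as correct when the predicted value and the reference label fall on the same side of the threshold (values >= threshold form the positive class) and that the result is a fraction in [0, 1]. The function name thresholdAccuracy is hypothetical; LearnerSVM::getAccuracy may handle ties or scaling differently.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical thresholded accuracy: a point is correct when prediction and
// reference lie on the same side of 'threshold' (>= threshold means positive).
double thresholdAccuracy(const std::vector<double>& referenceLabels,
                         double threshold,
                         const std::vector<double>& predictedLabels) {
  assert(referenceLabels.size() == predictedLabels.size());
  if (referenceLabels.empty()) return 0.0;
  std::size_t correct = 0;
  for (std::size_t i = 0; i < referenceLabels.size(); ++i) {
    const bool refPositive = referenceLabels[i] >= threshold;
    const bool predPositive = predictedLabels[i] >= threshold;
    if (refPositive == predPositive) ++correct;
  }
  // Fraction of correctly classified points.
  return static_cast<double>(correct) / static_cast<double>(referenceLabels.size());
}
```

With reference labels {-1, -1, 1, 1}, raw predictions {-0.5, 0.2, 0.7, 1.3} and threshold 0, three of four points are on the correct side, giving 0.75.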

double sgpp::datadriven::LearnerSVM::getError ( sgpp::base::DataMatrix & data,
sgpp::base::DataVector & labels,
std::string  errorType 
)

Computes the specified error type (e.g. MSE).

Parameters
data: The data points
labels: The corresponding class labels
errorType: The type of the error measurement (MSE or hinge loss)
Returns
The error estimation

References error, sgpp::base::DataVector::get(), sgpp::base::DataMatrix::getNcols(), sgpp::base::DataMatrix::getNrows(), sgpp::base::DataMatrix::getRow(), grid, sgpp::base::DataVector::set(), sgpp::base::DataVector::setAll(), and svm.

Referenced by train().
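The two error measures named above can be sketched as follows, assuming raw decision values f(x_i) and labels y_i in {-1, +1}: the MSE averages (y_i - f(x_i))^2 and the hinge loss averages max(0, 1 - y_i f(x_i)). The function errorMeasure and the strings "MSE" and "Hinge" are hypothetical; the exact spelling expected by getError(), and whether it averages, may differ.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical error measure over decision values 'scores' and labels in {-1, +1}.
double errorMeasure(const std::vector<double>& scores,
                    const std::vector<double>& labels,
                    const std::string& errorType) {
  assert(scores.size() == labels.size() && !scores.empty());
  double sum = 0.0;
  for (std::size_t i = 0; i < scores.size(); ++i) {
    if (errorType == "MSE") {
      const double diff = labels[i] - scores[i];
      sum += diff * diff;                                 // squared error
    } else if (errorType == "Hinge") {
      sum += std::max(0.0, 1.0 - labels[i] * scores[i]);  // hinge loss
    } else {
      throw std::invalid_argument("unknown error type: " + errorType);
    }
  }
  return sum / static_cast<double>(scores.size());        // mean over all points
}
```

For scores {0.5, -2.0} and labels {1, -1}, the sketch gives an MSE of 0.625 and a mean hinge loss of 0.25: the second point is classified with a margin above 1 and therefore contributes no hinge loss.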

void sgpp::datadriven::LearnerSVM::initialize ( size_t  budget)

Initializes the SVM learner.

Parameters
budget: The maximum number of stored support vectors

References createRegularGrid(), sgpp::base::DataMatrix::getNcols(), grid, svm, and trainData.

void sgpp::datadriven::LearnerSVM::predict ( sgpp::base::DataMatrix & testData,
sgpp::base::DataVector & predictedLabels 
)

Predicts class labels based on the trained model.

Parameters
testData: The data for which class labels should be predicted
predictedLabels: Receives the predicted class labels

References sgpp::base::DataMatrix::getNcols(), sgpp::base::DataMatrix::getNrows(), sgpp::base::DataMatrix::getRow(), grid, sgpp::base::DataVector::set(), and svm.

Referenced by getAccuracy(), storeResults(), and train().

void sgpp::datadriven::LearnerSVM::train ( size_t  maxDataPasses,
double  lambda,
double  betaRef,
std::string  refType,
std::string  refMonitor,
size_t  refPeriod,
double  errorDeclineThreshold,
size_t  errorDeclineBufferSize,
size_t  minRefInterval 
)

Implements support vector learning with sparse grid kernels.

Parameters
maxDataPasses: The maximum number of passes over the training data
lambda: The regularization parameter
betaRef: Weighting factor for grid points, used by the combined-measure refinement
refType: The refinement indicator (surplus, zero-crossings or data-based)
refMonitor: The refinement strategy (periodic or convergence-based)
refPeriod: The refinement interval (if periodic refinement is chosen)
errorDeclineThreshold: The convergence threshold (if convergence-based refinement is chosen)
errorDeclineBufferSize: The number of error measurements used to check convergence (if convergence-based refinement is chosen)
minRefInterval: The minimum number of data points that must be processed before the next refinement can be scheduled (if convergence-based refinement is chosen)

References adaptivityConfig, sgpp::base::DataVector::append(), avgErrors, error, sgpp::base::ImpurityRefinement::free_refine(), sgpp::base::ForwardSelectorRefinement::free_refine(), sgpp::base::DataVector::get(), getAccuracy(), getError(), sgpp::base::DataMatrix::getNcols(), sgpp::base::DataMatrix::getNrows(), sgpp::base::DataMatrix::getRow(), grid, sgpp::base::AdpativityConfiguration::noPoints_, sgpp::base::AdpativityConfiguration::numRefinements_, predict(), svm, testData, testLabels, sgpp::base::AdpativityConfiguration::threshold_, trainData, trainLabels, validData, and validLabels.
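The convergence-based refinement monitor can be pictured as a sliding window of recent error measurements. The sketch below is loosely modelled on errorDeclineThreshold, errorDeclineBufferSize and minRefInterval; the decision rule (relative decline between the oldest and newest buffered error) and the class name ConvergenceMonitorSketch are assumptions, not SG++'s actual criterion.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>

// Hypothetical convergence-based refinement monitor (not the SG++ implementation).
class ConvergenceMonitorSketch {
 public:
  ConvergenceMonitorSketch(double errorDeclineThreshold,
                           std::size_t errorDeclineBufferSize,
                           std::size_t minRefInterval)
      : threshold_(errorDeclineThreshold),
        bufferSize_(errorDeclineBufferSize),
        minRefInterval_(minRefInterval) {
    assert(bufferSize_ >= 2);
  }

  // Records an error measured after 'pointsProcessed' further data points.
  // Returns true if a refinement should be scheduled now.
  bool push(double error, std::size_t pointsProcessed) {
    sinceLastRefinement_ += pointsProcessed;
    errors_.push_back(error);
    if (errors_.size() > bufferSize_) errors_.pop_front();
    // Not enough history yet, or too soon after the previous refinement.
    if (errors_.size() < bufferSize_ || sinceLastRefinement_ < minRefInterval_) {
      return false;
    }
    const double oldest = errors_.front();
    const double newest = errors_.back();
    // Relative decline of the error over the buffered window.
    const double decline = oldest > 0.0 ? (oldest - newest) / oldest : 0.0;
    if (decline < threshold_) {
      // Error has stagnated: refine and start a fresh observation window.
      errors_.clear();
      sinceLastRefinement_ = 0;
      return true;
    }
    return false;
  }

 private:
  double threshold_;
  std::size_t bufferSize_;
  std::size_t minRefInterval_;
  std::size_t sinceLastRefinement_ = 0;
  std::deque<double> errors_;
};
```

Under the periodic strategy, the corresponding check would simply be whether refPeriod data points have been processed since the last refinement.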

Member Data Documentation

base::AdpativityConfiguration sgpp::datadriven::LearnerSVM::adaptivityConfig
protected

Referenced by train().

sgpp::base::DataVector sgpp::datadriven::LearnerSVM::avgErrors

Referenced by train().

double sgpp::datadriven::LearnerSVM::error

Referenced by getError(), and train().

base::RegularGridConfiguration sgpp::datadriven::LearnerSVM::gridConfig
protected

Referenced by createRegularGrid().

std::unique_ptr<PrimalDualSVM> sgpp::datadriven::LearnerSVM::svm
protected

base::DataMatrix& sgpp::datadriven::LearnerSVM::testData
protected

Referenced by train().

base::DataVector& sgpp::datadriven::LearnerSVM::testLabels
protected

Referenced by train().

base::DataMatrix& sgpp::datadriven::LearnerSVM::trainData
protected

Referenced by initialize(), and train().

base::DataVector& sgpp::datadriven::LearnerSVM::trainLabels
protected

Referenced by train().

base::DataMatrix* sgpp::datadriven::LearnerSVM::validData
protected

Referenced by train().

base::DataVector* sgpp::datadriven::LearnerSVM::validLabels
protected

Referenced by train().

