SG++-Doxygen-Documentation
sgpp::datadriven::LearnerSGD Class Reference

LearnerSGD learns the data using stochastic gradient descent.

#include <LearnerSGD.hpp>

Public Member Functions

double getAccuracy (sgpp::base::DataMatrix &testData, sgpp::base::DataVector &testLabels, double threshold)
 Computes the classification accuracy on the given dataset.
 
void initialize ()
 Initializes the SGD learner (creates grid etc.).
 
 LearnerSGD (base::RegularGridConfiguration &gridConfig, base::AdaptivityConfiguration &adaptivityConfig, base::DataMatrix &pTrainData, base::DataVector &pTrainLabels, base::DataMatrix &pTestData, base::DataVector &pTestLabels, base::DataMatrix *pValData, base::DataVector *pValLabels, double lambda, double gamma, size_t batchSize, bool useValidData)
 Constructor.
 
void storeResults (base::DataMatrix &testDataset)
 Stores classified data, grids and function evaluations to csv files.
 
void train (size_t maxDataPasses, std::string refType, std::string refMonitor, size_t refPeriod, double errorDeclineThreshold, size_t errorDeclineBufferSize, size_t minRefInterval)
 Implements online learning using stochastic gradient descent.
 
 ~LearnerSGD ()
 Destructor.
 

Public Attributes

sgpp::base::DataVector avgErrors
 
double error
 

Protected Member Functions

std::unique_ptr< base::Grid > createRegularGrid ()
 Generates a regular grid.
 
double getAccuracy (sgpp::base::DataVector &testLabels, double threshold, sgpp::base::DataVector &predictedLabels)
 Computes the classification accuracy.
 
void getBatchError (sgpp::base::DataMatrix &data, const sgpp::base::DataVector &labels)
 Computes error contribution for each data point of the given data set (required for predictive refinement indicator).
 
double getError (sgpp::base::DataMatrix &data, sgpp::base::DataVector &labels, std::string errorType)
 Computes the specified error type (e.g. MSE).
 
void predict (base::DataMatrix &testData, base::DataVector &predictedLabels)
 Predicts class labels based on the trained model.
 
void pushToBatch (sgpp::base::DataVector &x, double y)
 Stores the last 'batchSize' processed data points if no validation data is provided.
 

Protected Attributes

base::AdaptivityConfiguration adaptivityConfig
 
base::DataVector alpha
 
base::DataVector alphaAvg
 
base::DataMatrix * batchData
 
base::DataVector batchError
 
base::DataVector * batchLabels
 
size_t batchSize
 
double currentGamma
 
double gamma
 
std::unique_ptr< base::Grid > grid
 
base::RegularGridConfiguration gridConfig
 
double lambda
 
base::DataMatrix & testData
 
base::DataVector & testLabels
 
base::DataMatrix & trainData
 
base::DataVector & trainLabels
 
bool useValidData
 

Detailed Description

LearnerSGD learns the data using stochastic gradient descent.
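
This page does not spell out the update rule itself. As a rough orientation only, the following minimal sketch shows one regularized SGD step with weight averaging for a generic model f(x) = sum_j alpha[j] * phi[j]. The functions `sgdStep` and `updateAverage`, the squared loss and the precomputed basis values `phi` are assumptions made for illustration; only the names `alpha`, `alphaAvg`, `lambda` and `gamma` are taken from the class members listed above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in, not the SG++ implementation.
// One SGD step for the assumed per-point loss
//   0.5 * (f(x) - y)^2 + 0.5 * lambda * ||alpha||^2,
// where phi holds the basis function values at the current point x.
void sgdStep(std::vector<double>& alpha, const std::vector<double>& phi,
             double y, double lambda, double gamma) {
  assert(alpha.size() == phi.size());
  double prediction = 0.0;
  for (std::size_t j = 0; j < alpha.size(); ++j) {
    prediction += alpha[j] * phi[j];
  }
  const double residual = prediction - y;
  for (std::size_t j = 0; j < alpha.size(); ++j) {
    // Gradient of the assumed loss with respect to alpha[j].
    alpha[j] -= gamma * (residual * phi[j] + lambda * alpha[j]);
  }
}

// Running mean of the weight vectors after t updates (t >= 1).
void updateAverage(std::vector<double>& alphaAvg,
                   const std::vector<double>& alpha, std::size_t t) {
  assert(t >= 1 && alphaAvg.size() == alpha.size());
  for (std::size_t j = 0; j < alpha.size(); ++j) {
    alphaAvg[j] += (alpha[j] - alphaAvg[j]) / static_cast<double>(t);
  }
}
```

Averaged weights are a common way to reduce the noise of single-point updates; whether LearnerSGD averages in exactly this manner is not stated on this page.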

Constructor & Destructor Documentation

sgpp::datadriven::LearnerSGD::LearnerSGD ( base::RegularGridConfiguration & gridConfig,
base::AdaptivityConfiguration & adaptivityConfig,
base::DataMatrix & pTrainData,
base::DataVector & pTrainLabels,
base::DataMatrix & pTestData,
base::DataVector & pTestLabels,
base::DataMatrix * pValData,
base::DataVector * pValLabels,
double  lambda,
double  gamma,
size_t  batchSize,
bool  useValidData 
)

Constructor.

Parameters
gridConfig        The grid configuration
adaptivityConfig  The refinement configuration
pTrainData        The training dataset
pTrainLabels      The corresponding training labels
pTestData         The test dataset
pTestLabels       The corresponding test labels
pValData          The validation dataset
pValLabels        The corresponding validation labels
lambda            The regularization parameter
gamma             The learning rate (i.e. the step width)
batchSize         The number of data points considered when computing the error contributions for predictive refinement
useValidData      Specifies whether validation data should be used for all error computations

References batchData, batchLabels, sgpp::base::DataMatrix::getNcols(), sgpp::base::DataMatrix::reserveAdditionalRows(), sgpp::base::DataVector::setAll(), and trainData.

sgpp::datadriven::LearnerSGD::~LearnerSGD ( )

Destructor.

Member Function Documentation

double sgpp::datadriven::LearnerSGD::getAccuracy ( sgpp::base::DataMatrix & testData,
sgpp::base::DataVector & testLabels,
double  threshold 
)

Computes the classification accuracy on the given dataset.

Parameters
testData    The data for which class labels should be predicted
testLabels  The corresponding actual class labels
threshold   The decision threshold (e.g. for class labels -1, 1 -> threshold = 0)
Returns
The resulting accuracy

References sgpp::base::DataMatrix::getNrows(), and predict().

Referenced by train().

double sgpp::datadriven::LearnerSGD::getAccuracy ( sgpp::base::DataVector & testLabels,
double  threshold,
sgpp::base::DataVector & predictedLabels 
)
protected

Computes the classification accuracy.

Parameters
testLabels       The actual class labels
threshold        The decision threshold (e.g. for class labels -1, 1 -> threshold = 0)
predictedLabels  The predicted class labels
Returns
The resulting accuracy

References sgpp::base::DataVector::get(), sgpp::base::DataVector::getSize(), and python.statsfileInfo::i.
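
The role of the decision threshold can be illustrated with a minimal sketch. The free function `accuracy` below is a hypothetical stand-in that uses `std::vector<double>` instead of `sgpp::base::DataVector`; the convention that values equal to the threshold count as the positive class is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in, not the SG++ implementation.
// Returns the fraction of points whose predicted value lies on the same
// side of the threshold as the actual label.
double accuracy(const std::vector<double>& testLabels, double threshold,
                const std::vector<double>& predictedLabels) {
  assert(testLabels.size() == predictedLabels.size());
  if (testLabels.empty()) {
    return 0.0;
  }
  std::size_t correct = 0;
  for (std::size_t i = 0; i < testLabels.size(); ++i) {
    // Assumed tie rule: a value equal to the threshold is "positive".
    const bool actualPositive = testLabels[i] >= threshold;
    const bool predictedPositive = predictedLabels[i] >= threshold;
    if (actualPositive == predictedPositive) {
      ++correct;
    }
  }
  return static_cast<double>(correct) /
         static_cast<double>(testLabels.size());
}
```

With labels {1, -1, 1, -1}, predictions {0.8, -0.3, -0.2, -0.9} and threshold 0, three of the four points are classified correctly, so the sketch returns 0.75.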

void sgpp::datadriven::LearnerSGD::getBatchError ( sgpp::base::DataMatrix & data,
const sgpp::base::DataVector & labels 
)
protected

Computes error contribution for each data point of the given data set (required for predictive refinement indicator).

Parameters
data    The data points
labels  The corresponding class labels

References alphaAvg, batchError, sgpp::op_factory::createOperationMultipleEval(), sgpp::base::DataVector::get(), sgpp::base::DataMatrix::getNrows(), grid, python.statsfileInfo::i, sgpp::combigrid::pow(), and sgpp::base::DataVector::set().

Referenced by train().
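
A per-point error contribution can be sketched as follows. The squared residual used here is an assumption (the reference to pow() above merely hints at a power of the residual), and `pointwiseSquaredError` is a hypothetical helper that takes already evaluated predictions instead of a grid and a data matrix.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in, not the SG++ implementation.
// Returns one error contribution per data point: (prediction - label)^2.
std::vector<double> pointwiseSquaredError(
    const std::vector<double>& predictions,
    const std::vector<double>& labels) {
  assert(predictions.size() == labels.size());
  std::vector<double> contributions(predictions.size(), 0.0);
  for (std::size_t i = 0; i < predictions.size(); ++i) {
    const double residual = predictions[i] - labels[i];
    contributions[i] = residual * residual;
  }
  return contributions;
}
```

A refinement indicator can then weight grid points by the contributions of the data points in their support; that step is not shown here.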

double sgpp::datadriven::LearnerSGD::getError ( sgpp::base::DataMatrix & data,
sgpp::base::DataVector & labels,
std::string  errorType 
)
protected

Computes the specified error type (e.g. MSE).

Parameters
data       The data points
labels     The corresponding class labels
errorType  The type of the error measurement (MSE or Hinge loss)
Returns
The error estimation

References alphaAvg, sgpp::op_factory::createOperationMultipleEval(), error, sgpp::base::DataVector::get(), sgpp::base::DataMatrix::getNrows(), grid, python.statsfileInfo::i, sgpp::base::DataVector::set(), and sgpp::base::DataVector::setAll().

Referenced by train().
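
The two error types named above can be sketched as follows. `errorMeasure` is a hypothetical helper operating on already evaluated predictions, and the accepted strings "MSE" and "Hinge" are assumptions; this page does not state the exact values that errorType accepts.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in, not the SG++ implementation.
// "MSE":   mean of (prediction - label)^2
// "Hinge": mean of max(0, 1 - label * prediction), labels in {-1, +1}
double errorMeasure(const std::vector<double>& predictions,
                    const std::vector<double>& labels,
                    const std::string& errorType) {
  assert(predictions.size() == labels.size());
  assert(!predictions.empty());
  double sum = 0.0;
  for (std::size_t i = 0; i < predictions.size(); ++i) {
    if (errorType == "MSE") {
      const double residual = predictions[i] - labels[i];
      sum += residual * residual;
    } else if (errorType == "Hinge") {
      sum += std::max(0.0, 1.0 - labels[i] * predictions[i]);
    } else {
      throw std::invalid_argument("unknown error type: " + errorType);
    }
  }
  return sum / static_cast<double>(predictions.size());
}
```

The hinge loss is zero for points classified correctly with a margin of at least 1, whereas the MSE also penalizes confident correct predictions that overshoot the label.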

void sgpp::datadriven::LearnerSGD::initialize ( )

Initializes the SGD learner (creates grid etc.).

void sgpp::datadriven::LearnerSGD::predict ( base::DataMatrix & testData,
base::DataVector & predictedLabels 
)
protected

Predicts class labels based on the trained model.

Parameters
testData         The data for which class labels should be predicted
predictedLabels  The predicted class labels

References alphaAvg, sgpp::op_factory::createOperationMultipleEval(), sgpp::base::DataMatrix::getNrows(), grid, python.statsfileInfo::i, and sgpp::base::DataVector::set().

Referenced by getAccuracy(), storeResults(), and train().

void sgpp::datadriven::LearnerSGD::pushToBatch ( sgpp::base::DataVector & x,
double  y 
)
protected

Stores the last 'batchSize' processed data points if no validation data is provided.

Parameters
x  The current data point
y  The corresponding class label

References sgpp::base::DataMatrix::appendRow(), batchData, batchSize, sgpp::base::DataMatrix::getNrows(), and sgpp::base::DataMatrix::setRow().

Referenced by train().
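
Keeping only the last 'batchSize' points is commonly done with a ring buffer: rows are appended until the buffer is full, after which the oldest row is overwritten. The `Batch` struct below is a hypothetical stand-in built on `std::vector`; the overwrite order is an assumption, suggested only by the appendRow() and setRow() references above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in, not the SG++ implementation.
struct Batch {
  std::size_t capacity;                  // plays the role of batchSize
  std::size_t next = 0;                  // row to overwrite once full
  std::vector<std::vector<double>> data; // stored data points
  std::vector<double> labels;            // stored class labels

  explicit Batch(std::size_t batchSize) : capacity(batchSize) {
    assert(batchSize > 0);
  }

  void push(const std::vector<double>& x, double y) {
    if (data.size() < capacity) {
      // Buffer not yet full: append a new row.
      data.push_back(x);
      labels.push_back(y);
    } else {
      // Buffer full: overwrite the oldest row and advance the cursor.
      data[next] = x;
      labels[next] = y;
      next = (next + 1) % capacity;
    }
  }
};
```

Rows in such a buffer are not stored in chronological order once it wraps around, which is harmless as long as the error contributions are computed per point.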

void sgpp::datadriven::LearnerSGD::train ( size_t  maxDataPasses,
std::string  refType,
std::string  refMonitor,
size_t  refPeriod,
double  errorDeclineThreshold,
size_t  errorDeclineBufferSize,
size_t  minRefInterval 
)

Implements online learning using stochastic gradient descent.

Parameters
maxDataPasses           The number of passes over the whole training data
refType                 The refinement indicator (surplus, zero-crossings or data-based)
refMonitor              The refinement strategy (periodic or convergence-based)
refPeriod               The refinement interval (if periodic refinement is chosen)
errorDeclineThreshold   The convergence threshold (if convergence-based refinement is chosen)
errorDeclineBufferSize  The number of error measurements used to check convergence (if convergence-based refinement is chosen)
minRefInterval          The minimum number of data points that must be processed before the next refinement can be scheduled (if convergence-based refinement is chosen)

References adaptivityConfig, alpha, alphaAvg, sgpp::base::DataVector::append(), avgErrors, sgpp::base::DataVector::axpy(), batchData, batchError, batchLabels, sgpp::op_factory::createOperationMultipleEval(), currentGamma, chess::dim, error, sgpp::base::ImpurityRefinement::free_refine(), sgpp::base::PredictiveRefinement::free_refine(), gamma, sgpp::base::DataVector::get(), getAccuracy(), getBatchError(), getError(), sgpp::base::DataMatrix::getNcols(), sgpp::base::DataMatrix::getNrows(), sgpp::base::DataVector::getPointer(), sgpp::base::DataMatrix::getRow(), sgpp::base::DataVector::getSize(), grid, python.utils.sg_projections::gridStorage, lambda, mu, sgpp::base::DataVector::mult(), sgpp::base::AdaptivityConfiguration::noPoints_, sgpp::base::AdaptivityConfiguration::numRefinements_, sgpp::combigrid::pow(), predict(), pushToBatch(), sgpp::datadriven::RefinementMonitor::pushToBuffer(), sgpp::datadriven::RefinementMonitor::refinementsNecessary(), sgpp::base::DataVector::resizeZero(), testData, testLabels, sgpp::base::AdaptivityConfiguration::threshold_, trainData, trainLabels, and useValidData.
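
How the three convergence-related parameters might interact can be sketched as follows. `ConvergenceCheck` is a hypothetical class, not sgpp::datadriven::RefinementMonitor, and the rule it applies (refine once the error decline across the buffer falls below the threshold, but no sooner than minRefInterval points after the previous refinement) is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>

// Hypothetical stand-in, not the SG++ implementation.
class ConvergenceCheck {
 public:
  ConvergenceCheck(double errorDeclineThreshold,
                   std::size_t errorDeclineBufferSize,
                   std::size_t minRefInterval)
      : threshold_(errorDeclineThreshold),
        bufferSize_(errorDeclineBufferSize),
        minRefInterval_(minRefInterval) {
    assert(errorDeclineBufferSize >= 2);
  }

  // Records the error measured after processing one more data point.
  void push(double error) {
    buffer_.push_back(error);
    if (buffer_.size() > bufferSize_) {
      buffer_.pop_front();
    }
    ++pointsSinceRefinement_;
  }

  // Returns true if the error has stopped declining and enough points
  // have been processed since the last refinement.
  bool refinementDue() {
    if (buffer_.size() < bufferSize_ ||
        pointsSinceRefinement_ < minRefInterval_) {
      return false;
    }
    const double decline = buffer_.front() - buffer_.back();
    if (decline >= threshold_) {
      return false;
    }
    // Start a fresh observation window after scheduling a refinement.
    buffer_.clear();
    pointsSinceRefinement_ = 0;
    return true;
  }

 private:
  double threshold_;
  std::size_t bufferSize_;
  std::size_t minRefInterval_;
  std::size_t pointsSinceRefinement_ = 0;
  std::deque<double> buffer_;
};
```

A periodic strategy would replace the decline test with a simple counter that fires every refPeriod points.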

Member Data Documentation

base::AdaptivityConfiguration sgpp::datadriven::LearnerSGD::adaptivityConfig
protected

Referenced by train().

base::DataVector sgpp::datadriven::LearnerSGD::alpha
protected

base::DataVector sgpp::datadriven::LearnerSGD::alphaAvg
protected

sgpp::base::DataVector sgpp::datadriven::LearnerSGD::avgErrors

Referenced by train().

base::DataMatrix* sgpp::datadriven::LearnerSGD::batchData
protected

Referenced by LearnerSGD(), pushToBatch(), and train().

base::DataVector sgpp::datadriven::LearnerSGD::batchError
protected

Referenced by getBatchError(), initialize(), and train().

base::DataVector* sgpp::datadriven::LearnerSGD::batchLabels
protected

Referenced by LearnerSGD(), and train().

size_t sgpp::datadriven::LearnerSGD::batchSize
protected

Referenced by initialize(), and pushToBatch().

double sgpp::datadriven::LearnerSGD::currentGamma
protected

Referenced by train().

double sgpp::datadriven::LearnerSGD::error

Referenced by getError(), and train().

double sgpp::datadriven::LearnerSGD::gamma
protected

Referenced by train().

base::RegularGridConfiguration sgpp::datadriven::LearnerSGD::gridConfig
protected

Referenced by createRegularGrid(), and initialize().

double sgpp::datadriven::LearnerSGD::lambda
protected

Referenced by train().

base::DataMatrix& sgpp::datadriven::LearnerSGD::testData
protected

Referenced by train().

base::DataVector& sgpp::datadriven::LearnerSGD::testLabels
protected

Referenced by train().

base::DataVector& sgpp::datadriven::LearnerSGD::trainLabels
protected

Referenced by train().

bool sgpp::datadriven::LearnerSGD::useValidData
protected

Referenced by train().


The documentation for this class was generated from the following files: