SG++-Doxygen-Documentation
sgpp::datadriven::LearnerSGD Class Reference

LearnerSGD learns a classification model from the data using stochastic gradient descent.

#include <LearnerSGD.hpp>

Public Member Functions

double getAccuracy (sgpp::base::DataMatrix &testData, sgpp::base::DataVector &testLabels, double threshold)
 Computes the classification accuracy on the given dataset.
 
void initialize ()
 Initializes the SGD learner (creates the grid, etc.).
 
 LearnerSGD (base::RegularGridConfiguration &gridConfig, base::AdaptivityConfiguration &adaptivityConfig, base::DataMatrix &pTrainData, base::DataVector &pTrainLabels, base::DataMatrix &pTestData, base::DataVector &pTestLabels, base::DataMatrix *pValData, base::DataVector *pValLabels, double lambda, double gamma, size_t batchSize, bool useValidData)
 Constructor.
 
void storeResults (base::DataMatrix &testDataset)
 Stores the classified data, grids, and function evaluations in CSV files.
 
void train (size_t maxDataPasses, std::string refType, std::string refMonitor, size_t refPeriod, double errorDeclineThreshold, size_t errorDeclineBufferSize, size_t minRefInterval)
 Implements online learning using stochastic gradient descent.
 
 ~LearnerSGD ()
 Destructor.
 

Public Attributes

sgpp::base::DataVector avgErrors
 
double error
 

Protected Member Functions

std::unique_ptr< base::Grid > createRegularGrid ()
 Generates a regular grid.
 
double getAccuracy (sgpp::base::DataVector &testLabels, double threshold, sgpp::base::DataVector &predictedLabels)
 Computes the classification accuracy.
 
void getBatchError (sgpp::base::DataMatrix &data, const sgpp::base::DataVector &labels)
 Computes the error contribution of each data point in the given dataset (required for the predictive refinement indicator).
 
double getError (sgpp::base::DataMatrix &data, sgpp::base::DataVector &labels, std::string errorType)
 Computes the specified error type (e.g. MSE).
 
void predict (base::DataMatrix &testData, base::DataVector &predictedLabels)
 Predicts class labels based on the trained model.
 
void pushToBatch (sgpp::base::DataVector &x, double y)
 Stores the last 'batchSize' processed data points if no validation data is provided.
 

Protected Attributes

base::AdaptivityConfiguration adaptivityConfig
 
base::DataVector alpha
 
base::DataVector alphaAvg
 
base::DataMatrix * batchData
 
base::DataVector batchError
 
base::DataVector * batchLabels
 
size_t batchSize
 
double currentGamma
 
double gamma
 
std::unique_ptr< base::Grid > grid
 
base::RegularGridConfiguration gridConfig
 
double lambda
 
base::DataMatrix & testData
 
base::DataVector & testLabels
 
base::DataMatrix & trainData
 
base::DataVector & trainLabels
 
bool useValidData
 

Detailed Description

LearnerSGD learns a classification model from the data using stochastic gradient descent.
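The LearnerSGD source is not reproduced on this page, so the following is only a minimal sketch of what one training step in this style can look like. It assumes a model that is linear in its coefficients, f(x) = sum_j alpha_j phi_j(x), a squared loss with L2 regularization weighted by lambda, a step width gamma, and a running average of the iterates. The names alpha and alphaAvg mirror the protected members, but sgdStep and updateAverage are hypothetical helpers, and the actual update in train() may use a different loss, regularizer, or step-size schedule (see currentGamma).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch: one SGD step for f(x) = sum_j alpha[j] * phi[j], where
// phi[j] is the value of basis function j at the current training point x.
// Per-point loss: 0.5 * (f(x) - y)^2 + 0.5 * lambda * ||alpha||^2.
void sgdStep(std::vector<double>& alpha, const std::vector<double>& phi,
             double y, double lambda, double gamma) {
  assert(alpha.size() == phi.size());
  double fx = 0.0;
  for (std::size_t j = 0; j < alpha.size(); ++j) fx += alpha[j] * phi[j];
  const double residual = fx - y;
  for (std::size_t j = 0; j < alpha.size(); ++j) {
    // Step against the gradient of the per-point loss w.r.t. alpha[j].
    alpha[j] -= gamma * (residual * phi[j] + lambda * alpha[j]);
  }
}

// Running average of the iterates after t >= 1 steps:
// alphaAvg <- alphaAvg + (alpha - alphaAvg) / t.
void updateAverage(std::vector<double>& alphaAvg,
                   const std::vector<double>& alpha, std::size_t t) {
  assert(alphaAvg.size() == alpha.size() && t >= 1);
  for (std::size_t j = 0; j < alpha.size(); ++j) {
    alphaAvg[j] += (alpha[j] - alphaAvg[j]) / static_cast<double>(t);
  }
}
```

Averaging the iterates (Polyak-Ruppert averaging) smooths out the noise of individual stochastic steps, which is one plausible reason why predict(), getError(), and getBatchError() reference alphaAvg rather than alpha.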

Constructor & Destructor Documentation

◆ LearnerSGD()

sgpp::datadriven::LearnerSGD::LearnerSGD ( base::RegularGridConfiguration & gridConfig,
base::AdaptivityConfiguration & adaptivityConfig,
base::DataMatrix & pTrainData,
base::DataVector & pTrainLabels,
base::DataMatrix & pTestData,
base::DataVector & pTestLabels,
base::DataMatrix * pValData,
base::DataVector * pValLabels,
double lambda,
double gamma,
size_t batchSize,
bool useValidData
)

Constructor.

Parameters
gridConfig: The grid configuration
adaptivityConfig: The refinement configuration
pTrainData: The training dataset
pTrainLabels: The corresponding training labels
pTestData: The test dataset
pTestLabels: The corresponding test labels
pValData: The validation dataset
pValLabels: The corresponding validation labels
lambda: The regularization parameter
gamma: The learning rate (i.e. the step width)
batchSize: The number of data points used to compute the error contributions for predictive refinement
useValidData: Specifies whether the validation data should be used for all error computations

References batchData, batchLabels, sgpp::base::DataMatrix::getNcols(), sgpp::base::DataMatrix::reserveAdditionalRows(), sgpp::base::DataVector::setAll(), and trainData.

◆ ~LearnerSGD()

sgpp::datadriven::LearnerSGD::~LearnerSGD ( )

Destructor.

Member Function Documentation

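◆ createRegularGrid()

std::unique_ptr<base::Grid> sgpp::datadriven::LearnerSGD::createRegularGrid ( )
protected

Generates a regular grid.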
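gridConfig determines the grid type, dimension, and level; this page does not say which grid types are supported. As an illustration only, assuming a regular sparse grid of level n in d dimensions without boundary points and the usual level convention (l_k >= 1, |l|_1 <= n + d - 1, odd indices i_k), the grid points can be enumerated as below. SG++ itself builds grids through its own grid and storage classes; regularSparseGrid and addPoints are hypothetical names used only here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Point = std::vector<double>;

// Hypothetical sketch: fixes level l and odd index i in dimension k
// (coordinate x_k = i / 2^l), then recurses into dimension k + 1.
void addPoints(std::size_t k, std::size_t d, int levelBudget, Point& x,
               std::vector<Point>& points) {
  if (k == d) {
    points.push_back(x);
    return;
  }
  // Every remaining dimension still needs at least level 1.
  const int remainingDims = static_cast<int>(d - k - 1);
  for (int l = 1; l <= levelBudget - remainingDims; ++l) {
    const int numIndices = 1 << l;  // 2^l
    for (int i = 1; i < numIndices; i += 2) {
      x[k] = static_cast<double>(i) / numIndices;
      addPoints(k + 1, d, levelBudget - l, x, points);
    }
  }
}

// Inner points of a regular sparse grid of level n in d dimensions.
std::vector<Point> regularSparseGrid(std::size_t d, int n) {
  assert(d >= 1 && n >= 1);
  std::vector<Point> points;
  Point x(d, 0.0);
  // Level multi-indices satisfy |l|_1 <= n + d - 1.
  addPoints(0, d, n + static_cast<int>(d) - 1, x, points);
  return points;
}
```

For d = 2 and n = 3 this yields 17 points, compared with 49 for the corresponding full grid ((2^3 - 1)^2).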

◆ getAccuracy() [1/2]

double sgpp::datadriven::LearnerSGD::getAccuracy ( sgpp::base::DataMatrix & testData,
sgpp::base::DataVector & testLabels,
double threshold
)

Computes the classification accuracy on the given dataset.

Parameters
testData: The data for which class labels should be predicted
testLabels: The corresponding actual class labels
threshold: The decision threshold (e.g. threshold = 0 for class labels -1 and 1)
Returns
The resulting accuracy

References sgpp::base::DataMatrix::getNrows(), and predict().

Referenced by train().

◆ getAccuracy() [2/2]

double sgpp::datadriven::LearnerSGD::getAccuracy ( sgpp::base::DataVector & testLabels,
double threshold,
sgpp::base::DataVector & predictedLabels
)
protected

Computes the classification accuracy.

Parameters
testLabels: The actual class labels
threshold: The decision threshold (e.g. threshold = 0 for class labels -1 and 1)
predictedLabels: The predicted class labels
Returns
The resulting accuracy

References sgpp::base::DataVector::get() and sgpp::base::DataVector::getSize().
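The threshold parameter turns real-valued model outputs into class labels. A minimal sketch, assuming labels in {-1, 1}, that outputs strictly above the threshold are assigned to class 1, and that the accuracy is returned as a fraction in [0, 1] (this page does not say whether LearnerSGD returns a fraction or a percentage). thresholdAccuracy is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch of threshold-based binary accuracy for labels in {-1, 1}:
// a predicted value above `threshold` is mapped to class 1, otherwise to -1.
double thresholdAccuracy(const std::vector<double>& trueLabels, double threshold,
                         const std::vector<double>& predictedValues) {
  assert(trueLabels.size() == predictedValues.size() && !trueLabels.empty());
  std::size_t correct = 0;
  for (std::size_t i = 0; i < trueLabels.size(); ++i) {
    const double predictedClass = predictedValues[i] > threshold ? 1.0 : -1.0;
    if (predictedClass == trueLabels[i]) ++correct;
  }
  return static_cast<double>(correct) / static_cast<double>(trueLabels.size());
}
```

For labels in {0, 1} the same logic applies with threshold = 0.5 and class 0 in place of -1.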

◆ getBatchError()

void sgpp::datadriven::LearnerSGD::getBatchError ( sgpp::base::DataMatrix & data,
const sgpp::base::DataVector & labels
)
protected

Computes the error contribution of each data point in the given dataset (required for the predictive refinement indicator).

Parameters
data: The data points
labels: The corresponding class labels

References alphaAvg, batchError, sgpp::op_factory::createOperationMultipleEval(), sgpp::base::DataVector::get(), sgpp::base::DataMatrix::getNrows(), grid, sgpp::combigrid::pow(), and sgpp::base::DataVector::set().

Referenced by train().
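This page does not define the error contribution itself. Assuming it is the squared residual between label and model output (the reference to a pow() function hints at squaring, but that is a guess), the per-point contributions for a batch can be computed as below. batchErrors is a hypothetical helper that works on model outputs which have already been evaluated.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch: per-point squared residual (y_i - f(x_i))^2, one value
// per data point, as input for a data-driven refinement indicator.
std::vector<double> batchErrors(const std::vector<double>& labels,
                                const std::vector<double>& fValues) {
  assert(labels.size() == fValues.size());
  std::vector<double> errors(labels.size());
  for (std::size_t i = 0; i < labels.size(); ++i) {
    const double r = labels[i] - fValues[i];
    errors[i] = r * r;
  }
  return errors;
}
```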

◆ getError()

double sgpp::datadriven::LearnerSGD::getError ( sgpp::base::DataMatrix & data,
sgpp::base::DataVector & labels,
std::string errorType
)
protected

Computes the specified error type (e.g. MSE).

Parameters
data: The data points
labels: The corresponding class labels
errorType: The type of error measurement (MSE or hinge loss)
Returns
The error estimate

References alphaAvg, sgpp::op_factory::createOperationMultipleEval(), error, sgpp::base::DataVector::get(), sgpp::base::DataMatrix::getNrows(), grid, sgpp::base::DataVector::set(), and sgpp::base::DataVector::setAll().

Referenced by train().
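Both error types named above can be written as averages over the data points. This sketch assumes labels in {-1, 1} for the hinge loss and uses the strings "MSE" and "Hinge" as hypothetical errorType values; the exact strings accepted by getError() are not documented on this page.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical sketch: average error of model outputs fValues against labels.
// "MSE":   mean of (y - f)^2
// "Hinge": mean of max(0, 1 - y * f), for labels y in {-1, 1}
double averageError(const std::vector<double>& labels,
                    const std::vector<double>& fValues,
                    const std::string& errorType) {
  assert(labels.size() == fValues.size() && !labels.empty());
  double sum = 0.0;
  for (std::size_t i = 0; i < labels.size(); ++i) {
    if (errorType == "MSE") {
      const double r = labels[i] - fValues[i];
      sum += r * r;
    } else if (errorType == "Hinge") {
      sum += std::max(0.0, 1.0 - labels[i] * fValues[i]);
    } else {
      throw std::invalid_argument("unknown error type: " + errorType);
    }
  }
  return sum / static_cast<double>(labels.size());
}
```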

◆ initialize()

void sgpp::datadriven::LearnerSGD::initialize ( )

Initializes the SGD learner (creates the grid, etc.).

◆ predict()

void sgpp::datadriven::LearnerSGD::predict ( base::DataMatrix & testData,
base::DataVector & predictedLabels
)
protected

Predicts class labels based on the trained model.

Parameters
testData: The data for which class labels should be predicted
predictedLabels: The predicted class labels

References alphaAvg, sgpp::op_factory::createOperationMultipleEval(), sgpp::base::DataMatrix::getNrows(), grid, and sgpp::base::DataVector::set().

Referenced by getAccuracy(), storeResults(), and train().
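predict() evaluates the trained sparse grid function at every row of testData (via createOperationMultipleEval). A minimal sketch of that evaluation, assuming the standard piecewise linear hat basis without boundary points, where each grid point is given by a level and an odd index per dimension. GridPoint, hat, and evaluateAll are hypothetical names, and the real operation works on SG++'s grid storage rather than on plain vectors.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical grid point description: level l_k and odd index i_k per dimension.
struct GridPoint {
  std::vector<int> level;
  std::vector<int> index;
};

// 1D piecewise linear hat function phi_{l,i}(x) = max(0, 1 - |2^l * x - i|).
double hat(int l, int i, double x) {
  const double v = 1.0 - std::fabs(std::ldexp(x, l) - i);
  return v > 0.0 ? v : 0.0;
}

// Evaluates f(x) = sum_j alpha_j * prod_k phi_{l_jk, i_jk}(x_k) for every row
// of `points`, writing one raw function value per row into `values`.
void evaluateAll(const std::vector<GridPoint>& grid,
                 const std::vector<double>& alpha,
                 const std::vector<std::vector<double>>& points,
                 std::vector<double>& values) {
  assert(grid.size() == alpha.size());
  values.assign(points.size(), 0.0);
  for (std::size_t p = 0; p < points.size(); ++p) {
    for (std::size_t j = 0; j < grid.size(); ++j) {
      double basis = 1.0;
      for (std::size_t k = 0; k < points[p].size(); ++k) {
        basis *= hat(grid[j].level[k], grid[j].index[k], points[p][k]);
      }
      values[p] += alpha[j] * basis;
    }
  }
}
```

The sketch writes raw function values rather than thresholded labels. The protected getAccuracy() overload takes predictedLabels together with a threshold, which suggests that LearnerSGD also stores raw values here, but that is an inference from the signatures.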

◆ pushToBatch()

void sgpp::datadriven::LearnerSGD::pushToBatch ( sgpp::base::DataVector & x,
double y
)
protected

Stores the last 'batchSize' processed data points if no validation data is provided.

Parameters
x: The current data point
y: The corresponding class label

References sgpp::base::DataMatrix::appendRow(), batchData, batchSize, sgpp::base::DataMatrix::getNrows(), and sgpp::base::DataMatrix::setRow().

Referenced by train().
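Keeping only the last batchSize data points is a fixed-capacity buffer problem. This sketch assumes a ring-buffer policy: rows are appended until the buffer is full, after which the oldest row is overwritten in place. The references to appendRow() and setRow() are consistent with this, but the actual replacement order in pushToBatch() is not documented here. RecentBatch is a hypothetical class.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch of a buffer that keeps the last `capacity` (x, y) pairs.
class RecentBatch {
 public:
  explicit RecentBatch(std::size_t capacity) : capacity_(capacity) {
    assert(capacity > 0);
  }

  void push(const std::vector<double>& x, double y) {
    if (data_.size() < capacity_) {
      // Buffer not full yet: append a new row.
      data_.push_back(x);
      labels_.push_back(y);
    } else {
      // Buffer full: overwrite the oldest row.
      data_[next_] = x;
      labels_[next_] = y;
    }
    next_ = (next_ + 1) % capacity_;
  }

  std::size_t size() const { return data_.size(); }
  const std::vector<double>& labels() const { return labels_; }

 private:
  std::size_t capacity_;
  std::size_t next_ = 0;  // slot written next (the oldest entry once full)
  std::vector<std::vector<double>> data_;
  std::vector<double> labels_;
};
```

Overwriting in place keeps each push O(d) for d-dimensional points, instead of shifting all stored rows.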

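◆ storeResults()

void sgpp::datadriven::LearnerSGD::storeResults ( base::DataMatrix & testDataset )

Stores the classified data, grids, and function evaluations in CSV files.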

◆ train()

void sgpp::datadriven::LearnerSGD::train ( size_t maxDataPasses,
std::string refType,
std::string refMonitor,
size_t refPeriod,
double errorDeclineThreshold,
size_t errorDeclineBufferSize,
size_t minRefInterval
)

Implements online learning using stochastic gradient descent.

Parameters
maxDataPasses: The number of passes over the whole training dataset
refType: The refinement indicator (surplus, zero-crossings, or data-based)
refMonitor: The refinement strategy (periodic or convergence-based)
refPeriod: The refinement interval (if periodic refinement is chosen)
errorDeclineThreshold: The convergence threshold (if convergence-based refinement is chosen)
errorDeclineBufferSize: The number of error measurements used to check convergence (if convergence-based refinement is chosen)
minRefInterval: The minimum number of data points that must be processed before the next refinement can be scheduled (if convergence-based refinement is chosen)

References adaptivityConfig, alpha, alphaAvg, sgpp::base::DataVector::append(), avgErrors, sgpp::base::DataVector::axpy(), batchData, batchError, batchLabels, sgpp::op_factory::createOperationMultipleEval(), currentGamma, error, sgpp::base::ImpurityRefinement::free_refine(), sgpp::base::PredictiveRefinement::free_refine(), gamma, sgpp::base::DataVector::get(), getAccuracy(), getBatchError(), getError(), sgpp::base::DataMatrix::getNcols(), sgpp::base::DataMatrix::getNrows(), sgpp::base::DataVector::getPointer(), sgpp::base::DataMatrix::getRow(), sgpp::base::DataVector::getSize(), grid, lambda, mu, sgpp::base::DataVector::mult(), sgpp::base::AdaptivityConfiguration::noPoints_, sgpp::base::AdaptivityConfiguration::numRefinements_, sgpp::combigrid::pow(), predict(), pushToBatch(), sgpp::datadriven::RefinementMonitor::pushToBuffer(), sgpp::datadriven::RefinementMonitor::refinementsNecessary(), sgpp::base::DataVector::resizeZero(), testData, testLabels, sgpp::base::AdaptivityConfiguration::threshold_, trainData, trainLabels, and useValidData.
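For convergence-based refinement, errorDeclineThreshold, errorDeclineBufferSize, and minRefInterval together decide when to refine. The following sketch assumes one plausible rule: keep the last errorDeclineBufferSize error values and trigger a refinement when the relative decline between the oldest and newest value drops below errorDeclineThreshold, provided at least minRefInterval data points have been processed since the previous refinement. ConvergenceMonitor is a hypothetical class, not SG++'s RefinementMonitor (whose pushToBuffer() and refinementsNecessary() methods train() references), and the real criterion may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>

// Hypothetical sketch of a convergence-based refinement trigger.
class ConvergenceMonitor {
 public:
  ConvergenceMonitor(double threshold, std::size_t bufferSize,
                     std::size_t minInterval)
      : threshold_(threshold), bufferSize_(bufferSize), minInterval_(minInterval) {
    assert(bufferSize >= 2);
  }

  // Records the error measured after processing one more data point.
  void push(double error) {
    ++sinceRefinement_;
    errors_.push_back(error);
    if (errors_.size() > bufferSize_) errors_.pop_front();
  }

  // Returns true, and resets the monitor, if a refinement should happen now.
  bool refinementNecessary() {
    if (errors_.size() < bufferSize_ || sinceRefinement_ < minInterval_) {
      return false;
    }
    const double oldest = errors_.front();
    const double newest = errors_.back();
    // Relative decline over the buffered window; treated as 0 if oldest is 0.
    const double decline = oldest > 0.0 ? (oldest - newest) / oldest : 0.0;
    if (decline >= threshold_) return false;  // error is still dropping
    errors_.clear();
    sinceRefinement_ = 0;
    return true;
  }

 private:
  double threshold_;
  std::size_t bufferSize_;
  std::size_t minInterval_;
  std::size_t sinceRefinement_ = 0;
  std::deque<double> errors_;
};
```

With periodic refinement, the equivalent check would simply fire every refPeriod processed data points.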

Member Data Documentation

◆ adaptivityConfig

base::AdaptivityConfiguration sgpp::datadriven::LearnerSGD::adaptivityConfig
protected

Referenced by train().

◆ alpha

base::DataVector sgpp::datadriven::LearnerSGD::alpha
protected

◆ alphaAvg

base::DataVector sgpp::datadriven::LearnerSGD::alphaAvg
protected

◆ avgErrors

sgpp::base::DataVector sgpp::datadriven::LearnerSGD::avgErrors

Referenced by train().

◆ batchData

base::DataMatrix* sgpp::datadriven::LearnerSGD::batchData
protected

Referenced by LearnerSGD(), pushToBatch(), and train().

◆ batchError

base::DataVector sgpp::datadriven::LearnerSGD::batchError
protected

Referenced by getBatchError(), initialize(), and train().

◆ batchLabels

base::DataVector* sgpp::datadriven::LearnerSGD::batchLabels
protected

Referenced by LearnerSGD(), and train().

◆ batchSize

size_t sgpp::datadriven::LearnerSGD::batchSize
protected

Referenced by initialize(), and pushToBatch().

◆ currentGamma

double sgpp::datadriven::LearnerSGD::currentGamma
protected

Referenced by train().

◆ error

double sgpp::datadriven::LearnerSGD::error

Referenced by getError(), and train().

◆ gamma

double sgpp::datadriven::LearnerSGD::gamma
protected

Referenced by train().

◆ grid

std::unique_ptr<base::Grid> sgpp::datadriven::LearnerSGD::grid
protected

◆ gridConfig

base::RegularGridConfiguration sgpp::datadriven::LearnerSGD::gridConfig
protected

Referenced by createRegularGrid(), and initialize().

◆ lambda

double sgpp::datadriven::LearnerSGD::lambda
protected

Referenced by train().

◆ testData

base::DataMatrix& sgpp::datadriven::LearnerSGD::testData
protected

Referenced by train().

◆ testLabels

base::DataVector& sgpp::datadriven::LearnerSGD::testLabels
protected

Referenced by train().

◆ trainData

base::DataMatrix& sgpp::datadriven::LearnerSGD::trainData
protected

◆ trainLabels

base::DataVector& sgpp::datadriven::LearnerSGD::trainLabels
protected

Referenced by train().

◆ useValidData

bool sgpp::datadriven::LearnerSGD::useValidData
protected

Referenced by train().


The documentation for this class was generated from the following files: