SG++-Doxygen-Documentation
sgpp::datadriven::LearnerSGDEOnOffParallel Class Reference

LearnerSGDEOnOffParallel learns the data using sparse grid density estimation.

#include <LearnerSGDEOnOffParallel.hpp>

Public Member Functions

void assembleNextBatchData (Dataset *dataBatch, size_t *batchOffset) const
 Copies the data from the training set into the data batch.
 
size_t assignBatchToWorker (size_t batchOffset, bool doCrossValidation)
 Asks the scheduler which worker to assign the next batch to and sends the corresponding MPI request.
 
bool checkAllGridsConsistent ()
 Checks whether none of the grids is in a temporarily inconsistent state.
 
bool checkGridStateConsistent (size_t classIndex)
 Checks whether the grid is in a final state in which learning can occur.
 
void computeNewSystemMatrixDecomposition (size_t classIndex, size_t gridVersion)
 Updates the system matrix decomposition after a refinement step.
 
double getAccuracy () const
 Returns the accuracy of the classifier measured on the test data.
 
std::vector< std::pair< std::unique_ptr< DBMatOnlineDE >, size_t > > & getDensityFunctions ()
 Returns the density functions mapped to class labels.
 
size_t getDimensionality ()
 Returns the dimensionality of the learner as determined from its training set.
 
double getError (Dataset &dataset) const
 Error evaluation required for convergence-based refinement.
 
Grid & getGrid (size_t classIndex)
 Retrieves the grid for a certain class.
 
size_t getLocalGridVersion (size_t classIndex)
 Returns the internally stored current version of the grid.
 
size_t getNumClasses () const
 Returns the number of existing classes.
 
std::unique_ptr< DBMatOffline > & getOffline ()
 Gets the DBMatOffline object.
 
RefinementHandler & getRefinementHandler ()
 Returns a reference to the refinement handler, which contains the logic for handling the master's refinement cycles.
 
MPITaskScheduler & getScheduler ()
 Gets a reference to the currently installed MPI Scheduler.
 
Dataset & getTrainData ()
 Returns a reference to the currently used training data set.
 
Dataset * getValidationData ()
 Returns a pointer to the currently used validation data set.
 
 LearnerSGDEOnOffParallel (sgpp::base::RegularGridConfiguration &gridConfig, sgpp::base::AdaptivityConfiguration &adaptivityConfig, sgpp::datadriven::RegularizationConfiguration &regularizationConfig, sgpp::datadriven::DensityEstimationConfiguration &densityEstimationConfig, Dataset &trainData, Dataset &testData, Dataset *validationData, DataVector &classLabels, size_t numClassesInit, bool usePrior, double beta, MPITaskScheduler &mpiTaskScheduler)
 
void mergeAlphaValues (size_t classIndex, size_t remoteGridVersion, DataVector dataVector, size_t batchOffset, size_t batchSize, bool isLastPacketInSeries)
 Merges alpha values received from a remote process into the local alpha vector.
 
void predict (DataMatrix &test, DataVector &classLabels) const
 Predicts the class labels of the test data points.
 
void setLocalGridVersion (size_t classIndex, size_t gridVersion)
 Sets the grid version.
 
void shutdownMPINodes ()
 If run on the master, issues shutdown requests to all workers and waits for them to return.
 
void train (Dataset &dataBatch, bool doCrossValidation)
 Trains the learner with the given data batch.
 
void train (std::vector< std::pair< sgpp::base::DataMatrix *, double > > &trainDataClasses, bool doCrossValidation)
 Trains the learner with the given data batch that has already been split up by class.
 
void trainParallel (size_t batchSize, size_t maxDataPasses, std::string refinementFunctorType, std::string refMonitor, size_t refPeriod, double accDeclineThreshold, size_t accDeclineBufferSize, size_t minRefInterval)
 Trains the learner on its training data set in parallel, distributing data batches to the MPI workers.
 
void updateAlpha (size_t classIndex, std::list< size_t > *deletedPoints, size_t newPoints)
 Updates the surplus vector of a certain class.
 
void workBatch (Dataset dataset, size_t batchOffset, bool doCrossValidation)
 Trains from a batch.
 
virtual ~LearnerSGDEOnOffParallel ()
 Finalizes MPI when the learner is destroyed.
 

Static Public Member Functions

static bool isVersionConsistent (size_t version)
 Checks whether a specific grid version is consistent, i.e. whether it is higher than MINIMUM_CONSISTENT_GRID_VERSION.
 

Protected Member Functions

void allocateClassMatrices (size_t dim, std::vector< std::pair< base::DataMatrix *, double >> &trainDataClasses, std::map< double, int > &classIndices) const
 Allocates memory for every class to hold training data before learning.
 
void doRefinementForAll (const std::string &refinementFunctorType, const std::string &refinementMonitorType, const std::vector< std::pair< std::unique_ptr< DBMatOnlineDE >, size_t >> &onlineObjects, RefinementMonitor &monitor)
 Performs an entire refinement cycle for all classes.
 
void printGridSizeStatistics (const char *messageString, std::vector< std::pair< std::unique_ptr< DBMatOnlineDE >, size_t >> &onlineObjects)
 Shows grid size statistics along with a message.
 
void splitBatchIntoClasses (const Dataset &dataset, size_t dim, const std::vector< std::pair< DataMatrix *, double >> &trainDataClasses, std::map< double, int > &classIndices) const
 
void waitForAllGridsConsistent ()
 Waits for all grids to reach a consistent state before continuing.
 

Protected Attributes

sgpp::base::AdaptivityConfiguration & adaptivityConfig
 
std::vector< DataVector * > alphas
 
DataVector avgErrors
 
double beta
 
DataVector classLabels
 
sgpp::datadriven::DensityEstimationConfiguration & densityEstimationConfig
 
std::vector< std::pair< std::unique_ptr< DBMatOnlineDE >, size_t > > densityFunctions
 
sgpp::base::GeneralGridConfiguration & gridConfig
 
std::vector< std::unique_ptr< Grid > > grids
 
std::vector< size_t > localGridVersions
 Vector that holds the grid version for every class.
 
MPITaskScheduler & mpiTaskScheduler
 Reference to the currently installed MPI Task Scheduler.
 
size_t numClasses
 
std::unique_ptr< DBMatOffline > offline
 
std::vector< std::unique_ptr< DBMatOffline > > offlineContainer
 
std::map< double, double > prior
 
size_t processedPoints
 
RefinementHandler refinementHandler
 Instance of the currently installed refinement handler.
 
sgpp::datadriven::RegularizationConfiguration & regularizationConfig
 
Dataset & testData
 
Dataset & trainData
 
bool trained
 
bool usePrior
 
Dataset * validationData
 
bool workerActive
 Boolean used to detect when a shutdown of a worker has been requested.
 

Detailed Description

LearnerSGDEOnOffParallel learns the data using sparse grid density estimation.

The system matrix is precomputed and factorized using an eigen-, LU or Cholesky decomposition (offline step). Then, in every iteration, a density function is computed for each class by solving the system (online step). If the Cholesky decomposition is chosen, refinement and coarsening can be applied. This learner uses MPI to parallelize the learning phase across multiple nodes.
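The offline/online split can be illustrated with a minimal, self-contained sketch of the Cholesky variant. It assumes that all classes share one grid, so a single factorization serves every class and only the right-hand side differs per class. Plain std::vector types stand in for SG++'s DataMatrix and DataVector, and the function names are hypothetical, not part of the SG++ API.

```cpp
// Illustrative sketch only (not SG++ code): factorize once (offline),
// then solve one right-hand side per class (online).
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;
using Vector = std::vector<double>;

// Offline step: compute the lower-triangular factor L with A = L * L^T.
Matrix choleskyFactorize(const Matrix& a) {
  const std::size_t n = a.size();
  Matrix l(n, Vector(n, 0.0));
  for (std::size_t j = 0; j < n; ++j) {
    double diag = a[j][j];
    for (std::size_t k = 0; k < j; ++k) diag -= l[j][k] * l[j][k];
    assert(diag > 0.0 && "system matrix must be symmetric positive definite");
    l[j][j] = std::sqrt(diag);
    for (std::size_t i = j + 1; i < n; ++i) {
      double s = a[i][j];
      for (std::size_t k = 0; k < j; ++k) s -= l[i][k] * l[j][k];
      l[i][j] = s / l[j][j];
    }
  }
  return l;
}

// Online step: solve L * L^T * x = b by forward and backward substitution.
Vector choleskySolve(const Matrix& l, const Vector& b) {
  const std::size_t n = l.size();
  Vector y(n), x(n);
  for (std::size_t i = 0; i < n; ++i) {  // forward substitution: L y = b
    double s = b[i];
    for (std::size_t k = 0; k < i; ++k) s -= l[i][k] * y[k];
    y[i] = s / l[i][i];
  }
  for (std::size_t i = n; i-- > 0;) {  // backward substitution: L^T x = y
    double s = y[i];
    for (std::size_t k = i + 1; k < n; ++k) s -= l[k][i] * x[k];
    x[i] = s / l[i][i];
  }
  return x;
}

// Factorize once, then reuse the factor for every class's right-hand side.
std::vector<Vector> solveForAllClasses(const Matrix& systemMatrix,
                                       const std::vector<Vector>& rhsPerClass) {
  const Matrix l = choleskyFactorize(systemMatrix);  // offline, done once
  std::vector<Vector> alphas;
  for (const Vector& b : rhsPerClass) alphas.push_back(choleskySolve(l, b));
  return alphas;
}
```

For a dense n-by-n matrix, the factorization costs O(n^3) once, while each per-class solve costs only O(n^2), which is why precomputing the decomposition offline pays off.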

Constructor & Destructor Documentation

◆ LearnerSGDEOnOffParallel()

◆ ~LearnerSGDEOnOffParallel()

sgpp::datadriven::LearnerSGDEOnOffParallel::~LearnerSGDEOnOffParallel ( )
virtual

Finalizes MPI when the learner is destroyed.

References sgpp::datadriven::MPIMethods::finalizeMPI().

Member Function Documentation

◆ allocateClassMatrices()

void sgpp::datadriven::LearnerSGDEOnOffParallel::allocateClassMatrices ( size_t  dim,
std::vector< std::pair< base::DataMatrix *, double >> &  trainDataClasses,
std::map< double, int > &  classIndices 
) const
protected

Allocates memory for every class to hold training data before learning.

Parameters
dim: The dimensionality of the current problem
trainDataClasses: Storage to be allocated, holding space for the data and label of each class
classIndices: A map from each class label to its index

References classLabels, getNumClasses(), and m.

Referenced by train().

◆ assembleNextBatchData()

void sgpp::datadriven::LearnerSGDEOnOffParallel::assembleNextBatchData ( Dataset * dataBatch,
size_t *  batchOffset 
) const

Copies the data from the training set into the data batch.

Parameters
dataBatch: Batch of data to fill; its dimensionality and size must already be set
batchOffset: The offset in the training data from which to start copying

References D, sgpp::base::DataVector::get(), sgpp::datadriven::Dataset::getData(), sgpp::datadriven::Dataset::getDimension(), sgpp::datadriven::Dataset::getNumberInstances(), sgpp::base::DataMatrix::getRow(), sgpp::datadriven::Dataset::getTargets(), sgpp::base::DataVector::set(), sgpp::base::DataMatrix::setRow(), and trainData.

Referenced by workBatch().
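A minimal sketch of what copying a batch out of the training set involves, assuming that the caller has already sized the batch and that the batch lies entirely inside the training set (no wrap-around). SimpleDataset and assembleBatch are hypothetical stand-ins for Dataset and assembleNextBatchData().

```cpp
// Hypothetical sketch: copy batchSize consecutive training instances,
// starting at *batchOffset, into a batch whose size the caller has preset.
// SimpleDataset is a stand-in for sgpp::datadriven::Dataset.
#include <cassert>
#include <cstddef>
#include <vector>

struct SimpleDataset {
  std::vector<std::vector<double>> data;  // one row per instance
  std::vector<double> targets;            // one label per instance
};

void assembleBatch(const SimpleDataset& train, SimpleDataset& batch,
                   const std::size_t* batchOffset) {
  const std::size_t batchSize = batch.data.size();  // preset by the caller
  assert(batch.targets.size() == batchSize);
  // Assumption: the batch lies entirely inside the training set.
  assert(*batchOffset + batchSize <= train.data.size());
  for (std::size_t j = 0; j < batchSize; ++j) {
    batch.data[j] = train.data[*batchOffset + j];        // copy row j
    batch.targets[j] = train.targets[*batchOffset + j];  // and its label
  }
}
```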

◆ assignBatchToWorker()

size_t sgpp::datadriven::LearnerSGDEOnOffParallel::assignBatchToWorker ( size_t  batchOffset,
bool  doCrossValidation 
)

Asks the scheduler which worker to assign the next batch to and sends the corresponding MPI request.

Parameters
batchOffset: Starting offset of the new batch
doCrossValidation: Whether the worker should perform cross-validation
Returns
The size of the batch assigned by the scheduler

References sgpp::datadriven::MPIMethods::assignBatch(), sgpp::datadriven::MPITaskScheduler::assignTaskVariableTaskSize(), sgpp::datadriven::Dataset::getNumberInstances(), mpiTaskScheduler, sgpp::datadriven::TRAIN_FROM_BATCH, and trainData.

Referenced by trainParallel().
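The scheduler's role can be sketched with a hypothetical round-robin policy: each new batch goes to the next worker in turn, and its size is capped at the number of training points left after batchOffset. The class, its interface, and the rank layout (master on rank 0, workers on ranks 1 to N) are illustrative assumptions; SG++'s actual MPITaskScheduler implementations may size tasks differently.

```cpp
// Hypothetical round-robin scheduler sketch; not the SG++ MPITaskScheduler API.
#include <algorithm>
#include <cassert>
#include <cstddef>

struct Assignment {
  int workerRank;         // MPI rank that receives the batch
  std::size_t batchSize;  // number of instances in the batch
};

class RoundRobinSketch {
 public:
  RoundRobinSketch(int numWorkers, std::size_t preferredBatchSize)
      : numWorkers_(numWorkers), preferredBatchSize_(preferredBatchSize) {
    assert(numWorkers > 0 && preferredBatchSize > 0);
  }

  // Picks the next worker in turn and caps the batch at the remaining data.
  Assignment assign(std::size_t batchOffset, std::size_t numInstances) {
    assert(batchOffset < numInstances);
    const std::size_t remaining = numInstances - batchOffset;
    const int rank = 1 + nextWorker_;  // assumed: rank 0 is the master
    nextWorker_ = (nextWorker_ + 1) % numWorkers_;
    return {rank, std::min(preferredBatchSize_, remaining)};
  }

 private:
  int numWorkers_;
  std::size_t preferredBatchSize_;
  int nextWorker_ = 0;
};
```

In such a scheme the master would advance batchOffset by the returned batch size after each assignment, mirroring how assignBatchToWorker() returns the size assigned by the scheduler.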

◆ checkAllGridsConsistent()

bool sgpp::datadriven::LearnerSGDEOnOffParallel::checkAllGridsConsistent ( )

Checks whether none of the grids is in a temporarily inconsistent state.

Returns
Whether all grids are consistent

References isVersionConsistent(), and localGridVersions.

Referenced by sgpp::datadriven::RefinementHandler::checkReadyForRefinement().

◆ checkGridStateConsistent()

bool sgpp::datadriven::LearnerSGDEOnOffParallel::checkGridStateConsistent ( size_t  classIndex)

Checks whether the grid is in a final state in which learning can occur.

This is not the case while refinement results are being received or the system matrix decomposition is being updated.

Parameters
classIndex: The class for which to check consistency
Returns
Whether the grid is currently in a consistent state

References isVersionConsistent(), and localGridVersions.

Referenced by sgpp::datadriven::MPIMethods::receiveGridComponentsUpdate(), setLocalGridVersion(), sgpp::datadriven::RefinementHandler::updateClassVariablesAfterRefinement(), waitForAllGridsConsistent(), and sgpp::datadriven::MPIMethods::waitForGridConsistent().

◆ computeNewSystemMatrixDecomposition()

void sgpp::datadriven::LearnerSGDEOnOffParallel::computeNewSystemMatrixDecomposition ( size_t  classIndex,
size_t  gridVersion 
)

◆ doRefinementForAll()

void sgpp::datadriven::LearnerSGDEOnOffParallel::doRefinementForAll ( const std::string &  refinementFunctorType,
const std::string &  refinementMonitorType,
const std::vector< std::pair< std::unique_ptr< DBMatOnlineDE >, size_t >> &  onlineObjects,
RefinementMonitor & monitor 
)
protected

Performs an entire refinement cycle for all classes.

Parameters
refinementFunctorType: String constant specifying the functor to use in refinement
refinementMonitorType: String constant specifying the monitor to use in refinement
onlineObjects: Reference to the online objects for density estimation
monitor: The setup of the convergence monitor for refinement

References adaptivityConfig, alphas, sgpp::datadriven::PendingMPIRequest::buffer, CHECK_SIZE_T_TO_INT, D, sgpp::datadriven::RefinementHandler::doRefinementForClass(), sgpp::datadriven::Dataset::getData(), getNumClasses(), sgpp::datadriven::RefinementHandler::getRefinementResult(), sgpp::datadriven::Dataset::getTargets(), grids, isVersionConsistent(), sgpp::base::AdaptivityConfiguration::noPoints_, sgpp::datadriven::MPI_Packet::payload, refinementHandler, sgpp::datadriven::SYSTEM_MATRIX_DECOMPOSITION, trainData, sgpp::datadriven::UPDATE_GRID, and sgpp::datadriven::MPIMethods::waitForIncomingMessageType().

Referenced by trainParallel().

◆ getAccuracy()

double sgpp::datadriven::LearnerSGDEOnOffParallel::getAccuracy ( ) const

Returns the accuracy of the classifier measured on the test data.

Returns
The classification accuracy measured on the test data

References sgpp::base::DataVector::get(), sgpp::datadriven::Dataset::getData(), sgpp::datadriven::Dataset::getNumberInstances(), sgpp::datadriven::Dataset::getTargets(), predict(), and testData.

◆ getDensityFunctions()

std::vector< std::pair< std::unique_ptr< DBMatOnlineDE >, size_t > > & sgpp::datadriven::LearnerSGDEOnOffParallel::getDensityFunctions ( )

Returns the density functions mapped to class labels.

Returns
The density function objects mapped to class labels

References densityFunctions.

Referenced by computeNewSystemMatrixDecomposition(), sgpp::datadriven::MPIMethods::receiveGridComponentsUpdate(), trainParallel(), and workBatch().

◆ getDimensionality()

size_t sgpp::datadriven::LearnerSGDEOnOffParallel::getDimensionality ( )

◆ getError()

double sgpp::datadriven::LearnerSGDEOnOffParallel::getError ( Dataset & dataset) const

Error evaluation required for convergence-based refinement.

Parameters
dataset: The data to measure the error on
Returns
The computed error

References sgpp::base::DataVector::get(), sgpp::datadriven::Dataset::getData(), sgpp::datadriven::Dataset::getNumberInstances(), sgpp::datadriven::Dataset::getTargets(), and predict().

Referenced by sgpp::datadriven::RefinementHandler::checkRefinementNecessary().
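Accuracy, as returned by getAccuracy(), is the fraction of test points whose predicted label matches the true label. The sketch below computes it from two label vectors. Modelling the error of getError() as the misclassification rate (1 - accuracy) is an assumption, since the documentation only calls it an error evaluation; both function names are hypothetical.

```cpp
// Sketch: classification accuracy and an assumed misclassification error.
#include <cassert>
#include <cstddef>
#include <vector>

double accuracy(const std::vector<double>& predicted,
                const std::vector<double>& actual) {
  assert(predicted.size() == actual.size() && !actual.empty());
  std::size_t correct = 0;
  for (std::size_t i = 0; i < actual.size(); ++i)
    if (predicted[i] == actual[i]) ++correct;  // labels are exact class values
  return static_cast<double>(correct) / static_cast<double>(actual.size());
}

// Assumption: the error is the fraction of misclassified points.
double misclassificationError(const std::vector<double>& predicted,
                              const std::vector<double>& actual) {
  return 1.0 - accuracy(predicted, actual);
}
```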

◆ getGrid()

Grid & sgpp::datadriven::LearnerSGDEOnOffParallel::getGrid ( size_t  classIndex)

Retrieves the grid for a certain class.

Parameters
classIndex: The index of the desired class
Returns
The underlying grid

References grids.

Referenced by sgpp::datadriven::MPIMethods::receiveGridComponentsUpdate().

◆ getLocalGridVersion()

◆ getNumClasses()

size_t sgpp::datadriven::LearnerSGDEOnOffParallel::getNumClasses ( ) const

Returns the number of existing classes.

Returns
The number of classes

References numClasses.

Referenced by allocateClassMatrices(), sgpp::datadriven::RoundRobinScheduler::assignTaskStaticTaskSize(), doRefinementForAll(), train(), and workBatch().

◆ getOffline()

std::unique_ptr< DBMatOffline > & sgpp::datadriven::LearnerSGDEOnOffParallel::getOffline ( )

◆ getRefinementHandler()

RefinementHandler & sgpp::datadriven::LearnerSGDEOnOffParallel::getRefinementHandler ( )

Returns a reference to the refinement handler, which contains the logic for handling the master's refinement cycles.

Returns
A reference to the refinement handler

References refinementHandler.

Referenced by sgpp::datadriven::MPIMethods::receiveGridComponentsUpdate().

◆ getScheduler()

MPITaskScheduler & sgpp::datadriven::LearnerSGDEOnOffParallel::getScheduler ( )

Gets a reference to the currently installed MPI Scheduler.

The scheduler assigns tasks of variable or static size to workers.

Returns
A reference to the installed MPI Task Scheduler

References mpiTaskScheduler.

Referenced by sgpp::datadriven::RefinementHandler::checkReadyForRefinement(), and sgpp::datadriven::RefinementHandler::updateClassVariablesAfterRefinement().

◆ getTrainData()

Dataset & sgpp::datadriven::LearnerSGDEOnOffParallel::getTrainData ( )

Returns a reference to the currently used training data set.

Returns
A reference to the training data set

References trainData.

Referenced by sgpp::datadriven::RefinementHandler::checkRefinementNecessary(), and sgpp::datadriven::RefinementHandler::handleSurplusBasedRefinement().

◆ getValidationData()

Dataset * sgpp::datadriven::LearnerSGDEOnOffParallel::getValidationData ( )

Returns a pointer to the currently used validation data set.

Returns
A pointer to the validation data set

References validationData.

Referenced by sgpp::datadriven::RefinementHandler::checkRefinementNecessary().

◆ isVersionConsistent()

bool sgpp::datadriven::LearnerSGDEOnOffParallel::isVersionConsistent ( size_t  version)
static

Checks whether a specific grid version is consistent, i.e. whether it is higher than MINIMUM_CONSISTENT_GRID_VERSION.

Parameters
version: The grid version to check
Returns
Whether the version indicates consistency.

Referenced by checkAllGridsConsistent(), checkGridStateConsistent(), doRefinementForAll(), mergeAlphaValues(), sgpp::datadriven::MPIMethods::receiveGridComponentsUpdate(), and setLocalGridVersion().
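The version check and its use in checkAllGridsConsistent() can be sketched as follows. The documentation names MINIMUM_CONSISTENT_GRID_VERSION but not its value, so the constant below is a hypothetical placeholder; any version at or below it marks a grid that is still being updated, for example while refinement results are received or the decomposition is recomputed.

```cpp
// Sketch of version-based consistency checks. kMinimumConsistentGridVersion
// is a hypothetical placeholder for MINIMUM_CONSISTENT_GRID_VERSION.
#include <cassert>
#include <cstddef>
#include <vector>

constexpr std::size_t kMinimumConsistentGridVersion = 10;  // assumed value

bool isVersionConsistent(std::size_t version) {
  return version > kMinimumConsistentGridVersion;  // strictly "higher than"
}

// Mirrors checkAllGridsConsistent(): every class's grid must be consistent.
bool allGridsConsistent(const std::vector<std::size_t>& localGridVersions) {
  for (std::size_t version : localGridVersions)
    if (!isVersionConsistent(version)) return false;
  return true;
}
```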

◆ mergeAlphaValues()

void sgpp::datadriven::LearnerSGDEOnOffParallel::mergeAlphaValues ( size_t  classIndex,
size_t  remoteGridVersion,
DataVector  dataVector,
size_t  batchOffset,
size_t  batchSize,
bool  isLastPacketInSeries 
)

Merges alpha values received from a remote process into the local alpha vector.

Parameters
classIndex: The class to which the alpha vector belongs
remoteGridVersion: The grid version the remote process trained this alpha vector on
dataVector: The alpha vector itself
batchOffset: The offset from the start of the training set of the batch this vector was trained on
batchSize: The size of the batch this vector was trained on
isLastPacketInSeries: Whether this is the last of several merges for the same class and batch

References sgpp::base::DataVector::add(), sgpp::datadriven::RefinementResult::addedGridPoints, alphas, classLabels, D, sgpp::datadriven::RefinementResult::deletedGridPointsIndices, sgpp::base::DataVector::get(), getLocalGridVersion(), sgpp::datadriven::RefinementHandler::getRefinementResult(), sgpp::base::DataVector::getSize(), isVersionConsistent(), localGridVersions, mpiTaskScheduler, sgpp::datadriven::MPITaskScheduler::onMergeRequestIncoming(), prior, refinementHandler, sgpp::base::DataVector::resizeZero(), usePrior, and sgpp::datadriven::MPIMethods::waitForGridConsistent().

Referenced by sgpp::datadriven::MPIMethods::receiveMergeGridNetworkMessage().

◆ predict()

void sgpp::datadriven::LearnerSGDEOnOffParallel::predict ( DataMatrix & test,
DataVector & classLabels 
) const

Predicts the class labels of the test data points.

Parameters
test: The data points for which labels will be predicted
classLabels: Vector that receives the predicted class labels

References alphas, classLabels, densityFunctions, sgpp::base::DataMatrix::getNrows(), grids, numClasses, and prior.

Referenced by getAccuracy(), and getError().
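A common way to turn per-class density estimates into a classifier, and a plausible reading of why predict() references prior and densityFunctions, is to choose the label that maximizes the prior-weighted density P(c) * f_c(x), or f_c(x) alone when priors are disabled. The sketch below assumes this rule; ClassModel and predictLabel are hypothetical, and the densities are plain callables instead of grid-based estimates.

```cpp
// Sketch of density-based classification (assumed decision rule):
// pick the class c maximizing P(c) * f_c(x), or f_c(x) without priors.
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

struct ClassModel {
  double label;
  double prior;  // P(c), e.g. the relative class frequency
  std::function<double(const std::vector<double>&)> density;  // f_c(x)
};

double predictLabel(const std::vector<ClassModel>& classes,
                    const std::vector<double>& point, bool usePrior) {
  assert(!classes.empty());
  auto score = [&](const ClassModel& c) {
    return c.density(point) * (usePrior ? c.prior : 1.0);
  };
  std::size_t best = 0;
  double bestScore = score(classes[0]);
  for (std::size_t i = 1; i < classes.size(); ++i) {
    const double s = score(classes[i]);
    if (s > bestScore) {  // keep the highest-scoring class
      bestScore = s;
      best = i;
    }
  }
  return classes[best].label;
}
```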

◆ printGridSizeStatistics()

void sgpp::datadriven::LearnerSGDEOnOffParallel::printGridSizeStatistics ( const char *  messageString,
std::vector< std::pair< std::unique_ptr< DBMatOnlineDE >, size_t >> &  onlineObjects 
)
protected

Shows grid size statistics along with a message.

Parameters
messageString: The message to display alongside the statistics
onlineObjects: The current density estimation objects

References sgpp::base::Grid::getSize(), grid(), and grids.

Referenced by trainParallel().

◆ setLocalGridVersion()

void sgpp::datadriven::LearnerSGDEOnOffParallel::setLocalGridVersion ( size_t  classIndex,
size_t  gridVersion 
)

Sets the grid version of a class.

Parameters
classIndex: The class whose grid version to set
gridVersion: The new version of the grid

References checkGridStateConsistent(), D, isVersionConsistent(), and localGridVersions.

Referenced by computeNewSystemMatrixDecomposition(), sgpp::datadriven::MPIMethods::receiveGridComponentsUpdate(), and sgpp::datadriven::RefinementHandler::updateClassVariablesAfterRefinement().

◆ shutdownMPINodes()

void sgpp::datadriven::LearnerSGDEOnOffParallel::shutdownMPINodes ( )

◆ splitBatchIntoClasses()

void sgpp::datadriven::LearnerSGDEOnOffParallel::splitBatchIntoClasses ( const Dataset & dataset,
size_t  dim,
const std::vector< std::pair< DataMatrix *, double >> &  trainDataClasses,
std::map< double, int > &  classIndices 
) const
protected

◆ train() [1/2]

void sgpp::datadriven::LearnerSGDEOnOffParallel::train ( Dataset & dataBatch,
bool  doCrossValidation 
)

Trains the learner with the given data batch.

Parameters
dataBatch: The next data batch to process
doCrossValidation: Whether to enable cross-validation

References allocateClassMatrices(), D, sgpp::datadriven::Dataset::getDimension(), sgpp::datadriven::Dataset::getNumberInstances(), getNumClasses(), and splitBatchIntoClasses().

Referenced by workBatch().
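The per-class split that train() performs via allocateClassMatrices() and splitBatchIntoClasses() can be sketched with a map from class label to class index, as the classIndices parameter suggests. Plain vectors replace DataMatrix, and splitByClass is a hypothetical helper; a label that is not listed in classLabels makes std::map::at throw std::out_of_range.

```cpp
// Sketch: split a labelled batch into one row set per class, paired with the
// class label (cf. trainDataClasses). Names are illustrative, not SG++ API.
#include <cassert>
#include <cstddef>
#include <map>
#include <utility>
#include <vector>

using Rows = std::vector<std::vector<double>>;

std::vector<std::pair<Rows, double>> splitByClass(
    const Rows& data, const std::vector<double>& targets,
    const std::vector<double>& classLabels) {
  assert(data.size() == targets.size());
  std::map<double, std::size_t> classIndices;  // label -> class index
  std::vector<std::pair<Rows, double>> perClass;
  for (double label : classLabels) {  // "allocate" one slot per class
    classIndices[label] = perClass.size();
    perClass.emplace_back(Rows{}, label);
  }
  for (std::size_t i = 0; i < data.size(); ++i)  // route each row by label
    perClass.at(classIndices.at(targets[i])).first.push_back(data[i]);
  return perClass;
}
```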

◆ train() [2/2]

void sgpp::datadriven::LearnerSGDEOnOffParallel::train ( std::vector< std::pair< sgpp::base::DataMatrix *, double > > &  trainDataClasses,
bool  doCrossValidation 
)

Trains the learner with the given data batch that has already been split up by class.

Parameters
trainDataClasses: A vector of pairs; each pair contains the data points that belong to one class and the corresponding class label
doCrossValidation: Whether to enable cross-validation

References sgpp::datadriven::RefinementResult::addedGridPoints, alphas, D, sgpp::datadriven::RefinementResult::deletedGridPointsIndices, densityEstimationConfig, densityFunctions, sgpp::datadriven::RefinementHandler::getRefinementResult(), grids, prior, processedPoints, refinementHandler, trained, and usePrior.

◆ trainParallel()

void sgpp::datadriven::LearnerSGDEOnOffParallel::trainParallel ( size_t  batchSize,
size_t  maxDataPasses,
std::string  refinementFunctorType,
std::string  refMonitor,
size_t  refPeriod,
double  accDeclineThreshold,
size_t  accDeclineBufferSize,
size_t  minRefInterval 
)

Trains the learner on its training data set in parallel, distributing data batches to the MPI workers.

Parameters
batchSize: Size of the subset of data points used in each training step
maxDataPasses: The number of passes over the whole training data set
refinementFunctorType: The refinement indicator (surplus, zero-crossings or data-based)
refMonitor: The refinement strategy (periodic or convergence-based)
refPeriod: The refinement interval (if periodic refinement is chosen)
accDeclineThreshold: The convergence threshold (if convergence-based refinement is chosen)
accDeclineBufferSize: The number of accuracy measurements used to check convergence (if convergence-based refinement is chosen)
minRefInterval: The minimum number of data points (or data batches) that must be processed before the next refinement can be scheduled (if convergence-based refinement is chosen)

References adaptivityConfig, assignBatchToWorker(), sgpp::datadriven::RefinementHandler::checkReadyForRefinement(), sgpp::datadriven::RefinementHandler::checkRefinementNecessary(), D, doRefinementForAll(), getDensityFunctions(), sgpp::datadriven::Dataset::getNumberInstances(), sgpp::datadriven::MPIMethods::getQueueSize(), sgpp::datadriven::MPIMethods::hasPendingOutgoingRequests(), sgpp::datadriven::MPIMethods::isMaster(), MPI_MASTER_RANK, mpiTaskScheduler, sgpp::datadriven::MPITaskScheduler::onRefinementStarted(), printGridSizeStatistics(), sgpp::datadriven::MPIMethods::processCompletedMPIRequests(), processedPoints, refinementHandler, sgpp::datadriven::MPIMethods::sendCommandNoArgs(), shutdownMPINodes(), trainData, sgpp::datadriven::MPIMethods::waitForAnyMPIRequestsToComplete(), sgpp::datadriven::WORKER_SHUTDOWN_SUCCESS, and workerActive.
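The two refinement strategies selected by refMonitor can be sketched as a small trigger that is called once per processed batch. The monitor names "periodic" and "convergence", counting minRefInterval in batches, and the convergence test (accuracy improved by less than accDeclineThreshold across the last accDeclineBufferSize measurements) are all assumptions for illustration; SG++'s RefinementHandler may apply a different criterion.

```cpp
// Hypothetical refinement trigger; monitor names and convergence test assumed.
#include <cassert>
#include <cstddef>
#include <deque>
#include <string>
#include <utility>

class RefinementTrigger {
 public:
  RefinementTrigger(std::string monitor, std::size_t refPeriod,
                    double accDeclineThreshold,
                    std::size_t accDeclineBufferSize,
                    std::size_t minRefInterval)
      : monitor_(std::move(monitor)),
        refPeriod_(refPeriod),
        threshold_(accDeclineThreshold),
        bufferSize_(accDeclineBufferSize),
        minRefInterval_(minRefInterval) {
    assert(monitor_ == "periodic" ? refPeriod_ > 0 : bufferSize_ >= 2);
  }

  // Call once per processed batch; `accuracy` is only used by "convergence".
  bool shouldRefine(double accuracy) {
    ++batchesSinceRefinement_;
    bool refine = false;
    if (monitor_ == "periodic") {
      refine = batchesSinceRefinement_ >= refPeriod_;
    } else {  // convergence-based: keep a sliding window of accuracies
      history_.push_back(accuracy);
      if (history_.size() > bufferSize_) history_.pop_front();
      const bool converged = history_.size() == bufferSize_ &&
                             history_.back() - history_.front() < threshold_;
      refine = converged && batchesSinceRefinement_ >= minRefInterval_;
    }
    if (refine) {  // start a fresh interval after each refinement
      batchesSinceRefinement_ = 0;
      history_.clear();
    }
    return refine;
  }

 private:
  std::string monitor_;
  std::size_t refPeriod_;
  double threshold_;
  std::size_t bufferSize_;
  std::size_t minRefInterval_;
  std::size_t batchesSinceRefinement_ = 0;
  std::deque<double> history_;
};
```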

◆ updateAlpha()

void sgpp::datadriven::LearnerSGDEOnOffParallel::updateAlpha ( size_t  classIndex,
std::list< size_t > *  deletedPoints,
size_t  newPoints 
)

Updates the surplus vector of a certain class.

Parameters
classIndex: The index of the class
deletedPoints: A list of indices of deleted grid points (coarsening)
newPoints: The number of new grid points (refinement)

References alpha, alphas, sgpp::base::DataVector::getSize(), sgpp::base::DataVector::remove(), and sgpp::base::DataVector::resizeZero().

Referenced by sgpp::datadriven::RefinementHandler::updateClassVariablesAfterRefinement().
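Adapting a surplus vector after a refinement step can be sketched with std::vector: entries of deleted grid points are dropped and one zero is appended per new grid point, in line with the referenced DataVector::remove() and DataVector::resizeZero() calls. That new points are appended at the end of the vector is an assumption, and updateAlphaSketch is a hypothetical name.

```cpp
// Sketch: drop surpluses of coarsened points, append zeros for refined ones.
// Assumes new grid points are appended at the end of the alpha vector.
#include <cassert>
#include <cstddef>
#include <list>
#include <utility>
#include <vector>

void updateAlphaSketch(std::vector<double>& alpha,
                       const std::list<std::size_t>& deletedPoints,
                       std::size_t newPoints) {
  std::vector<bool> keep(alpha.size(), true);
  for (std::size_t index : deletedPoints) {  // mark coarsened grid points
    assert(index < alpha.size());
    keep[index] = false;
  }
  std::vector<double> updated;
  for (std::size_t i = 0; i < alpha.size(); ++i)
    if (keep[i]) updated.push_back(alpha[i]);  // survivors keep their order
  updated.resize(updated.size() + newPoints, 0.0);  // new points start at zero
  alpha = std::move(updated);
}
```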

◆ waitForAllGridsConsistent()

void sgpp::datadriven::LearnerSGDEOnOffParallel::waitForAllGridsConsistent ( )
protected

Wait for all grids to reach a consistent state before continuing.

References checkGridStateConsistent(), getLocalGridVersion(), localGridVersions, and sgpp::datadriven::MPIMethods::waitForGridConsistent().

Referenced by workBatch().

◆ workBatch()

void sgpp::datadriven::LearnerSGDEOnOffParallel::workBatch ( Dataset  dataset,
size_t  batchOffset,
bool  doCrossValidation 
)

Trains from a batch.

Waits until all grids are consistent, fills the dataset, learns from it, and sends the new alpha vector to the master.

Parameters
dataset: An empty dataset whose size and dimension are already set
batchOffset: The offset from the start of the training set at which to assemble the batch
doCrossValidation: Whether to cross-validate the results

References alphas, assembleNextBatchData(), D, getDensityFunctions(), getLocalGridVersion(), sgpp::datadriven::Dataset::getNumberInstances(), getNumClasses(), sgpp::datadriven::MPIMethods::sendMergeGridNetworkMessage(), train(), and waitForAllGridsConsistent().

Referenced by sgpp::datadriven::MPIMethods::runBatch().

Member Data Documentation

◆ adaptivityConfig

sgpp::base::AdaptivityConfiguration& sgpp::datadriven::LearnerSGDEOnOffParallel::adaptivityConfig
protected

◆ alphas

std::vector<DataVector*> sgpp::datadriven::LearnerSGDEOnOffParallel::alphas
protected

◆ avgErrors

DataVector sgpp::datadriven::LearnerSGDEOnOffParallel::avgErrors
protected

◆ beta

double sgpp::datadriven::LearnerSGDEOnOffParallel::beta
protected

◆ classLabels

DataVector sgpp::datadriven::LearnerSGDEOnOffParallel::classLabels
protected

◆ densityEstimationConfig

sgpp::datadriven::DensityEstimationConfiguration& sgpp::datadriven::LearnerSGDEOnOffParallel::densityEstimationConfig
protected

◆ densityFunctions

std::vector<std::pair<std::unique_ptr<DBMatOnlineDE>, size_t> > sgpp::datadriven::LearnerSGDEOnOffParallel::densityFunctions
protected

◆ gridConfig

sgpp::base::GeneralGridConfiguration& sgpp::datadriven::LearnerSGDEOnOffParallel::gridConfig
protected

◆ grids

std::vector<std::unique_ptr<Grid> > sgpp::datadriven::LearnerSGDEOnOffParallel::grids
protected

◆ localGridVersions

std::vector<size_t> sgpp::datadriven::LearnerSGDEOnOffParallel::localGridVersions
protected

◆ mpiTaskScheduler

MPITaskScheduler& sgpp::datadriven::LearnerSGDEOnOffParallel::mpiTaskScheduler
protected

Reference to the currently installed MPI Task Scheduler.

Referenced by assignBatchToWorker(), getScheduler(), LearnerSGDEOnOffParallel(), mergeAlphaValues(), and trainParallel().

◆ numClasses

size_t sgpp::datadriven::LearnerSGDEOnOffParallel::numClasses
protected

◆ offline

std::unique_ptr<DBMatOffline> sgpp::datadriven::LearnerSGDEOnOffParallel::offline
protected

◆ offlineContainer

std::vector<std::unique_ptr<DBMatOffline> > sgpp::datadriven::LearnerSGDEOnOffParallel::offlineContainer
protected

◆ prior

std::map<double, double> sgpp::datadriven::LearnerSGDEOnOffParallel::prior
protected

◆ processedPoints

size_t sgpp::datadriven::LearnerSGDEOnOffParallel::processedPoints
protected

◆ refinementHandler

RefinementHandler sgpp::datadriven::LearnerSGDEOnOffParallel::refinementHandler
protected

◆ regularizationConfig

sgpp::datadriven::RegularizationConfiguration& sgpp::datadriven::LearnerSGDEOnOffParallel::regularizationConfig
protected

◆ testData

Dataset& sgpp::datadriven::LearnerSGDEOnOffParallel::testData
protected

◆ trainData

◆ trained

bool sgpp::datadriven::LearnerSGDEOnOffParallel::trained
protected

Referenced by LearnerSGDEOnOffParallel(), and train().

◆ usePrior

bool sgpp::datadriven::LearnerSGDEOnOffParallel::usePrior
protected

◆ validationData

Dataset* sgpp::datadriven::LearnerSGDEOnOffParallel::validationData
protected

◆ workerActive

bool sgpp::datadriven::LearnerSGDEOnOffParallel::workerActive
protected

Boolean used to detect when a shutdown of a worker has been requested.

Referenced by LearnerSGDEOnOffParallel(), shutdownMPINodes(), and trainParallel().

