SG++-Doxygen-Documentation
LearnerSGDEOnOffParallel learns the data using sparse grid density estimation.
#include <LearnerSGDEOnOffParallel.hpp>
Public Member Functions
void | assembleNextBatchData (Dataset *dataBatch, size_t *batchOffset) const
Copies the data from the training set into the data batch.
size_t | assignBatchToWorker (size_t batchOffset, bool doCrossValidation)
Asks the scheduler where to assign the next batch and sends the MPI request.
bool | checkAllGridsConsistent ()
Check that no grid is in a temporarily inconsistent state.
bool | checkGridStateConsistent (size_t classIndex)
Check whether the grid is in a final state where learning can occur.
void | computeNewSystemMatrixDecomposition (size_t classIndex, size_t gridVersion)
Update the system matrix decomposition after a refinement step.
double | getAccuracy () const
Returns the accuracy of the classifier measured on the test data.
std::vector< std::pair< std::unique_ptr< DBMatOnlineDE >, size_t > > & | getDensityFunctions ()
Returns the density functions mapped to class labels.
size_t | getDimensionality ()
Returns the dimensionality of the learner as determined from its training set.
double | getError (Dataset &dataset) const
Error evaluation required for convergence-based refinement.
Grid & | getGrid (size_t classIndex)
Retrieves the grid for a certain class.
size_t | getLocalGridVersion (size_t classIndex)
Returns the internally stored current version of the grid.
size_t | getNumClasses () const
Returns the number of existing classes.
std::unique_ptr< DBMatOffline > & | getOffline ()
Gets the DBMatOffline object.
RefinementHandler & | getRefinementHandler ()
Returns a reference to the refinement handler, which contains the logic to handle the master's refinement cycles.
MPITaskScheduler & | getScheduler ()
Gets a reference to the currently installed MPI Scheduler.
Dataset & | getTrainData ()
Returns a reference to the currently used training data set.
Dataset * | getValidationData ()
Returns a pointer to the currently used validation data set.
LearnerSGDEOnOffParallel (sgpp::base::RegularGridConfiguration &gridConfig, sgpp::base::AdaptivityConfiguration &adaptivityConfig, sgpp::datadriven::RegularizationConfiguration &regularizationConfig, sgpp::datadriven::DensityEstimationConfiguration &densityEstimationConfig, Dataset &trainData, Dataset &testData, Dataset *validationData, DataVector &classLabels, size_t numClassesInit, bool usePrior, double beta, MPITaskScheduler &mpiTaskScheduler)
void | mergeAlphaValues (size_t classIndex, size_t remoteGridVersion, DataVector dataVector, size_t batchOffset, size_t batchSize, bool isLastPacketInSeries)
Merge alpha values received from a remote process into the local alpha vector.
void | predict (DataMatrix &test, DataVector &classLabels) const
Predicts the class labels of the test data points.
void | setLocalGridVersion (size_t classIndex, size_t gridVersion)
Set the grid version.
void | shutdownMPINodes ()
If this is run on master, it issues shutdown requests to all workers and waits for them to return.
void | train (Dataset &dataBatch, bool doCrossValidation)
Trains the learner with the given data batch.
void | train (std::vector< std::pair< sgpp::base::DataMatrix *, double > > &trainDataClasses, bool doCrossValidation)
Trains the learner with the given data batch that is already split up with respect to its different classes.
void | trainParallel (size_t batchSize, size_t maxDataPasses, std::string refinementFunctorType, std::string refMonitor, size_t refPeriod, double accDeclineThreshold, size_t accDeclineBufferSize, size_t minRefInterval)
Trains the learner with the given dataset.
void | updateAlpha (size_t classIndex, std::list< size_t > *deletedPoints, size_t newPoints)
Updates the surplus vector of a certain class.
void | workBatch (Dataset dataset, size_t batchOffset, bool doCrossValidation)
Train from a batch.
virtual | ~LearnerSGDEOnOffParallel ()
Runs MPI finalize when destructing the learner.
Static Public Member Functions
static bool | isVersionConsistent (size_t version)
Check whether a specific grid version is consistent, i.e. whether it is higher than MINIMUM_CONSISTENT_GRID_VERSION.
Protected Member Functions
void | allocateClassMatrices (size_t dim, std::vector< std::pair< base::DataMatrix *, double >> &trainDataClasses, std::map< double, int > &classIndices) const
Allocates memory for every class to hold training data before learning.
void | doRefinementForAll (const std::string &refinementFunctorType, const std::string &refinementMonitorType, const std::vector< std::pair< std::unique_ptr< DBMatOnlineDE >, size_t >> &onlineObjects, RefinementMonitor &monitor)
Do an entire refinement cycle for all classes.
void | printGridSizeStatistics (const char *messageString, std::vector< std::pair< std::unique_ptr< DBMatOnlineDE >, size_t >> &onlineObjects)
Shows grid size statistics along with a message.
void | splitBatchIntoClasses (const Dataset &dataset, size_t dim, const std::vector< std::pair< DataMatrix *, double >> &trainDataClasses, std::map< double, int > &classIndices) const
void | waitForAllGridsConsistent ()
Wait for all grids to reach a consistent state before continuing.
LearnerSGDEOnOffParallel learns the data using sparse grid density estimation.
The system matrix is precomputed and factorized using Eigen-, LU- or Cholesky decomposition (offline step). Then, for each class a density function is computed by solving the system in every iteration (online step). If Cholesky decomposition is chosen, refinement/coarsening can be applied. This learner uses MPI to parallelize the learning phase across multiple nodes.
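The split into an offline and an online step can be illustrated with a small, self-contained Cholesky example. This is only a sketch of the general technique, not the SG++ implementation: `choleskyFactorize`, `choleskySolve` and the `Matrix`/`Vector` aliases are hypothetical names, and the matrix is assumed to be symmetric positive definite.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;
using Vector = std::vector<double>;

// Offline step (done once): factorize a symmetric positive definite
// matrix A into A = L * L^T and return the lower-triangular factor L.
Matrix choleskyFactorize(const Matrix& a) {
  const std::size_t n = a.size();
  Matrix l(n, Vector(n, 0.0));
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t j = 0; j <= i; ++j) {
      double sum = a[i][j];
      for (std::size_t k = 0; k < j; ++k) sum -= l[i][k] * l[j][k];
      l[i][j] = (i == j) ? std::sqrt(sum) : sum / l[j][j];
    }
  }
  return l;
}

// Online step (repeated for every new right-hand side b):
// solve L * L^T * x = b by forward and backward substitution.
Vector choleskySolve(const Matrix& l, const Vector& b) {
  const std::size_t n = l.size();
  Vector y(n), x(n);
  for (std::size_t i = 0; i < n; ++i) {  // forward: L * y = b
    double sum = b[i];
    for (std::size_t k = 0; k < i; ++k) sum -= l[i][k] * y[k];
    y[i] = sum / l[i][i];
  }
  for (std::size_t i = n; i-- > 0;) {    // backward: L^T * x = y
    double sum = y[i];
    for (std::size_t k = i + 1; k < n; ++k) sum -= l[k][i] * x[k];
    x[i] = sum / l[i][i];
  }
  return x;
}
```

The factor is computed once and reused, so each online iteration only costs two triangular solves instead of a full factorization.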
sgpp::datadriven::LearnerSGDEOnOffParallel::LearnerSGDEOnOffParallel (sgpp::base::RegularGridConfiguration &gridConfig, sgpp::base::AdaptivityConfiguration &adaptivityConfig, sgpp::datadriven::RegularizationConfiguration &regularizationConfig, sgpp::datadriven::DensityEstimationConfiguration &densityEstimationConfig, Dataset &trainData, Dataset &testData, Dataset *validationData, DataVector &classLabels, size_t numClassesInit, bool usePrior, double beta, MPITaskScheduler &mpiTaskScheduler)
References adaptivityConfig, alpha, alphas, avgErrors, beta, sgpp::datadriven::DBMatOnlineDEFactory::buildDBMatOnlineDE(), sgpp::datadriven::DBMatOfflineFactory::buildOfflineObject(), classLabels, sgpp::datadriven::GridFactory::createGrid(), densityEstimationConfig, densityFunctions, grid(), gridConfig, grids, sgpp::datadriven::MPIMethods::initMPI(), sgpp::datadriven::RegularizationConfiguration::lambda_, localGridVersions, mpiTaskScheduler, numClasses, offline, offlineContainer, prior, processedPoints, refinementHandler, regularizationConfig, sgpp::datadriven::MPITaskScheduler::setLearnerInstance(), testData, trained, usePrior, validationData, and workerActive.
virtual sgpp::datadriven::LearnerSGDEOnOffParallel::~LearnerSGDEOnOffParallel ()
virtual |
Runs MPI finalize when destructing the learner.
References sgpp::datadriven::MPIMethods::finalizeMPI().
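void sgpp::datadriven::LearnerSGDEOnOffParallel::allocateClassMatrices (size_t dim, std::vector< std::pair< base::DataMatrix *, double >> &trainDataClasses, std::map< double, int > &classIndices) const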
protected |
Allocates memory for every class to hold training data before learning.
dim | The dimensionality of the current problem |
trainDataClasses | Storage to be allocated, holding space for the data and label of each class |
classIndices | A map from each class's label to its index |
References classLabels, getNumClasses(), python.statsfileInfo::i, m, and friedman::p.
Referenced by train().
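The role of the `classIndices` map can be sketched as follows: one container is prepared per class, and each row of a batch is then appended to the container that the map selects for its label. This is a simplified illustration using standard containers; `allocateByClass`, `splitByClass`, `Row` and `ClassData` are hypothetical names and do not mirror the actual SG++ data structures.

```cpp
#include <cstddef>
#include <map>
#include <utility>
#include <vector>

using Row = std::vector<double>;
// One entry per class: the rows belonging to that class and the class label.
using ClassData = std::vector<std::pair<std::vector<Row>, double>>;

// Builds the label-to-index map and one (empty) row container per class.
ClassData allocateByClass(const std::vector<double>& classLabels,
                          std::map<double, int>& classIndices) {
  ClassData perClass;
  for (std::size_t i = 0; i < classLabels.size(); ++i) {
    classIndices[classLabels[i]] = static_cast<int>(i);
    perClass.emplace_back(std::vector<Row>{}, classLabels[i]);
  }
  return perClass;
}

// Appends every row of the batch to the container of its class.
// Rows with an unknown label are skipped (an assumption of this sketch).
void splitByClass(const std::vector<Row>& rows,
                  const std::vector<double>& targets,
                  const std::map<double, int>& classIndices,
                  ClassData& perClass) {
  for (std::size_t i = 0; i < rows.size(); ++i) {
    const auto it = classIndices.find(targets[i]);
    if (it != classIndices.end()) {
      perClass[static_cast<std::size_t>(it->second)].first.push_back(rows[i]);
    }
  }
}
```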
void sgpp::datadriven::LearnerSGDEOnOffParallel::assembleNextBatchData (Dataset *dataBatch, size_t *batchOffset) const
Copies the data from the training set into the data batch.
dataBatch | Batch of data to fill, with set dimensionality and size |
batchOffset | The offset in the training data from which to start copying |
References D, sgpp::base::DataVector::get(), sgpp::datadriven::Dataset::getData(), sgpp::datadriven::Dataset::getDimension(), sgpp::datadriven::Dataset::getNumberInstances(), sgpp::base::DataMatrix::getRow(), sgpp::datadriven::Dataset::getTargets(), python.utils.statsfile2gnuplot::j, sgpp::base::DataVector::set(), sgpp::base::DataMatrix::setRow(), and trainData.
Referenced by workBatch().
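A minimal sketch of the copy step, using plain standard containers instead of the SG++ `Dataset` class. `SimpleDataset` and `assembleBatch` are hypothetical names. The sketch assumes that the offset is advanced while rows are copied (the pointer parameter suggests this, but it is not stated above) and that enough rows remain in the training set.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a dataset: one row of features and one target per instance.
struct SimpleDataset {
  std::vector<std::vector<double>> data;
  std::vector<double> targets;
};

// Fills the pre-sized batch with consecutive rows of the training set,
// starting at *batchOffset, and advances the offset past the copied rows.
void assembleBatch(const SimpleDataset& train, SimpleDataset* batch,
                   std::size_t* batchOffset) {
  // Precondition of this sketch: the batch fits into the remaining data.
  assert(*batchOffset + batch->data.size() <= train.data.size());
  for (std::size_t j = 0; j < batch->data.size(); ++j) {
    batch->data[j] = train.data[*batchOffset];
    batch->targets[j] = train.targets[*batchOffset];
    ++(*batchOffset);
  }
}
```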
size_t sgpp::datadriven::LearnerSGDEOnOffParallel::assignBatchToWorker (size_t batchOffset, bool doCrossValidation)
Asks the scheduler where to assign the next batch and sends the MPI request.
batchOffset | Starting offset of the new batch |
doCrossValidation | Whether the client should do cross-validation |
References sgpp::datadriven::MPIMethods::assignBatch(), sgpp::datadriven::MPITaskScheduler::assignTaskVariableTaskSize(), sgpp::datadriven::Dataset::getNumberInstances(), mpiTaskScheduler, sgpp::datadriven::TRAIN_FROM_BATCH, and trainData.
Referenced by trainParallel().
bool sgpp::datadriven::LearnerSGDEOnOffParallel::checkAllGridsConsistent ()
Check that no grid is in a temporarily inconsistent state.
References isVersionConsistent(), and localGridVersions.
Referenced by sgpp::datadriven::RefinementHandler::checkReadyForRefinement().
bool sgpp::datadriven::LearnerSGDEOnOffParallel::checkGridStateConsistent (size_t classIndex)
Check whether the grid is in a final state where learning can occur.
This is not the case while receiving refinement results or updating the system matrix decomposition.
classIndex | The class for which to check consistency. |
References isVersionConsistent(), and localGridVersions.
Referenced by sgpp::datadriven::MPIMethods::receiveGridComponentsUpdate(), setLocalGridVersion(), sgpp::datadriven::RefinementHandler::updateClassVariablesAfterRefinement(), waitForAllGridsConsistent(), and sgpp::datadriven::MPIMethods::waitForGridConsistent().
void sgpp::datadriven::LearnerSGDEOnOffParallel::computeNewSystemMatrixDecomposition (size_t classIndex, size_t gridVersion)
Update the system matrix decomposition after a refinement step.
This waits until the refinement results have been received completely. After the computation, the updated system matrix decomposition is sent back to the master.
classIndex | The class for which to update the system matrix decomposition |
gridVersion | The new grid version to set after updating the matrix |
References sgpp::datadriven::RefinementResult::addedGridPoints, D, sgpp::datadriven::RefinementResult::deletedGridPointsIndices, densityEstimationConfig, sgpp::datadriven::DBMatOffline::getDecomposedMatrix(), getDensityFunctions(), getLocalGridVersion(), sgpp::datadriven::DBMatOnline::getOfflineObject(), sgpp::datadriven::RefinementHandler::getRefinementResult(), GRID_RECEIVED_ADDED_POINTS, grids, sgpp::datadriven::RegularizationConfiguration::lambda_, refinementHandler, regularizationConfig, sgpp::datadriven::MPIMethods::sendSystemMatrixDecomposition(), setLocalGridVersion(), sgpp::datadriven::UPDATE_GRID, sgpp::datadriven::DBMatOnline::updateSystemMatrixDecomposition(), and sgpp::datadriven::MPIMethods::waitForIncomingMessageType().
Referenced by sgpp::datadriven::MPIMethods::processIncomingMPICommands().
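void sgpp::datadriven::LearnerSGDEOnOffParallel::doRefinementForAll (const std::string &refinementFunctorType, const std::string &refinementMonitorType, const std::vector< std::pair< std::unique_ptr< DBMatOnlineDE >, size_t >> &onlineObjects, RefinementMonitor &monitor)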
protected |
Do an entire refinement cycle for all classes.
refinementFunctorType | String constant specifying the functor to use in refinement |
refinementMonitorType | String constant specifying the monitor to use in refinement |
onlineObjects | Reference to the online objects for density estimation |
monitor | The setup of the convergence monitor for refinement |
References adaptivityConfig, alphas, sgpp::datadriven::PendingMPIRequest::buffer, CHECK_SIZE_T_TO_INT, D, sgpp::datadriven::RefinementHandler::doRefinementForClass(), sgpp::datadriven::Dataset::getData(), getNumClasses(), sgpp::datadriven::RefinementHandler::getRefinementResult(), sgpp::datadriven::Dataset::getTargets(), grids, isVersionConsistent(), sgpp::base::AdaptivityConfiguration::noPoints_, sgpp::datadriven::MPI_Packet::payload, refinementHandler, sgpp::datadriven::SYSTEM_MATRIX_DECOMPOSITION, trainData, sgpp::datadriven::UPDATE_GRID, and sgpp::datadriven::MPIMethods::waitForIncomingMessageType().
Referenced by trainParallel().
double sgpp::datadriven::LearnerSGDEOnOffParallel::getAccuracy () const
Returns the accuracy of the classifier measured on the test data.
References sgpp::base::DataVector::get(), sgpp::datadriven::Dataset::getData(), sgpp::datadriven::Dataset::getNumberInstances(), sgpp::datadriven::Dataset::getTargets(), python.statsfileInfo::i, predict(), and testData.
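As an illustration, accuracy can be computed as the share of predicted labels that match the reference labels. The helper below is a hypothetical sketch; whether `getAccuracy()` reports a fraction in [0, 1] or a percentage is not stated here, and the sketch assumes a fraction.

```cpp
#include <cstddef>
#include <vector>

// Fraction of predictions that equal the reference labels, in [0, 1].
// Returns 0 for empty or mismatched inputs (a choice made for this sketch).
double accuracy(const std::vector<double>& predicted,
                const std::vector<double>& expected) {
  if (expected.empty() || predicted.size() != expected.size()) return 0.0;
  std::size_t correct = 0;
  for (std::size_t i = 0; i < expected.size(); ++i) {
    if (predicted[i] == expected[i]) ++correct;
  }
  return static_cast<double>(correct) / static_cast<double>(expected.size());
}
```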
std::vector< std::pair< std::unique_ptr< DBMatOnlineDE >, size_t > > & sgpp::datadriven::LearnerSGDEOnOffParallel::getDensityFunctions ()
Returns the density functions mapped to class labels.
References densityFunctions.
Referenced by computeNewSystemMatrixDecomposition(), sgpp::datadriven::MPIMethods::receiveGridComponentsUpdate(), trainParallel(), and workBatch().
size_t sgpp::datadriven::LearnerSGDEOnOffParallel::getDimensionality ()
Returns the dimensionality of the learner as determined from its training set.
References sgpp::datadriven::Dataset::getDimension(), and trainData.
Referenced by sgpp::datadriven::RefinementHandler::doRefinementForClass(), sgpp::datadriven::MPIMethods::receiveGridComponentsUpdate(), sgpp::datadriven::MPIMethods::runBatch(), and sgpp::datadriven::RefinementHandler::updateClassVariablesAfterRefinement().
double sgpp::datadriven::LearnerSGDEOnOffParallel::getError (Dataset &dataset) const
Error evaluation required for convergence-based refinement.
dataset | The data to measure the error on |
References sgpp::base::DataVector::get(), sgpp::datadriven::Dataset::getData(), sgpp::datadriven::Dataset::getNumberInstances(), sgpp::datadriven::Dataset::getTargets(), python.statsfileInfo::i, and predict().
Referenced by sgpp::datadriven::RefinementHandler::checkRefinementNecessary().
Grid & sgpp::datadriven::LearnerSGDEOnOffParallel::getGrid (size_t classIndex)
Retrieves the grid for a certain class.
classIndex | the index of the desired class |
References grids.
Referenced by sgpp::datadriven::MPIMethods::receiveGridComponentsUpdate().
size_t sgpp::datadriven::LearnerSGDEOnOffParallel::getLocalGridVersion (size_t classIndex)
Returns the internally stored current version of the grid.
classIndex | The class of the grid to search for |
References localGridVersions.
Referenced by sgpp::datadriven::MPIMethods::assignSystemMatrixUpdate(), computeNewSystemMatrixDecomposition(), mergeAlphaValues(), sgpp::datadriven::MPIMethods::receiveGridComponentsUpdate(), sgpp::datadriven::MPIMethods::sendMergeGridNetworkMessage(), sgpp::datadriven::MPIMethods::sendSystemMatrixDecomposition(), sgpp::datadriven::RefinementHandler::updateClassVariablesAfterRefinement(), waitForAllGridsConsistent(), sgpp::datadriven::MPIMethods::waitForGridConsistent(), and workBatch().
size_t sgpp::datadriven::LearnerSGDEOnOffParallel::getNumClasses () const
Returns the number of existing classes.
References numClasses.
Referenced by allocateClassMatrices(), sgpp::datadriven::RoundRobinScheduler::assignTaskStaticTaskSize(), doRefinementForAll(), train(), and workBatch().
std::unique_ptr< DBMatOffline > & sgpp::datadriven::LearnerSGDEOnOffParallel::getOffline ()
Gets the DBMatOffline object.
References offline.
Referenced by sgpp::datadriven::RefinementHandler::checkRefinementNecessary(), and sgpp::datadriven::RefinementHandler::updateClassVariablesAfterRefinement().
RefinementHandler & sgpp::datadriven::LearnerSGDEOnOffParallel::getRefinementHandler ()
Returns a reference to the refinement handler, which contains the logic to handle the master's refinement cycles.
References refinementHandler.
Referenced by sgpp::datadriven::MPIMethods::receiveGridComponentsUpdate().
MPITaskScheduler & sgpp::datadriven::LearnerSGDEOnOffParallel::getScheduler ()
Gets a reference to the currently installed MPI Scheduler.
The scheduler assigns tasks of variable or static size to workers.
References mpiTaskScheduler.
Referenced by sgpp::datadriven::RefinementHandler::checkReadyForRefinement(), and sgpp::datadriven::RefinementHandler::updateClassVariablesAfterRefinement().
Dataset & sgpp::datadriven::LearnerSGDEOnOffParallel::getTrainData ()
Returns a reference to the currently used training data set.
References trainData.
Referenced by sgpp::datadriven::RefinementHandler::checkRefinementNecessary(), and sgpp::datadriven::RefinementHandler::handleSurplusBasedRefinement().
Dataset * sgpp::datadriven::LearnerSGDEOnOffParallel::getValidationData ()
Returns a pointer to the currently used validation data set.
References validationData.
Referenced by sgpp::datadriven::RefinementHandler::checkRefinementNecessary().
static bool sgpp::datadriven::LearnerSGDEOnOffParallel::isVersionConsistent (size_t version)
static |
Check whether a specific grid version is consistent, i.e. whether it is higher than MINIMUM_CONSISTENT_GRID_VERSION.
version | The version of the grid to check against |
Referenced by checkAllGridsConsistent(), checkGridStateConsistent(), doRefinementForAll(), mergeAlphaValues(), sgpp::datadriven::MPIMethods::receiveGridComponentsUpdate(), and setLocalGridVersion().
void sgpp::datadriven::LearnerSGDEOnOffParallel::mergeAlphaValues (size_t classIndex, size_t remoteGridVersion, DataVector dataVector, size_t batchOffset, size_t batchSize, bool isLastPacketInSeries)
Merge alpha values received from a remote process into the local alpha vector.
classIndex | The class to which the alpha vector belongs |
remoteGridVersion | The remote grid version this alpha vector was trained on |
dataVector | The alpha vector itself |
batchOffset | The offset from the start of the training set this vector was trained from |
batchSize | The size of the batch this vector was trained from |
isLastPacketInSeries | Whether this is the last in a series of merges for the same class and batch |
References sgpp::base::DataVector::add(), sgpp::datadriven::RefinementResult::addedGridPoints, alphas, classLabels, D, sgpp::datadriven::RefinementResult::deletedGridPointsIndices, sgpp::base::DataVector::get(), getLocalGridVersion(), sgpp::datadriven::RefinementHandler::getRefinementResult(), sgpp::base::DataVector::getSize(), python.statsfileInfo::i, isVersionConsistent(), localGridVersions, mpiTaskScheduler, sgpp::datadriven::MPITaskScheduler::onMergeRequestIncoming(), prior, refinementHandler, sgpp::base::DataVector::resizeZero(), usePrior, and sgpp::datadriven::MPIMethods::waitForGridConsistent().
Referenced by sgpp::datadriven::MPIMethods::receiveMergeGridNetworkMessage().
void sgpp::datadriven::LearnerSGDEOnOffParallel::predict (DataMatrix &test, DataVector &classLabels) const
Predicts the class labels of the test data points.
test | The data points for which labels will be predicted |
classLabels | Vector containing the predicted class labels |
References alphas, classLabels, densityFunctions, sgpp::base::DataMatrix::getNrows(), grids, numClasses, chess::point, prior, and python.utils.pca_normalize_dataset::u.
Referenced by getAccuracy(), and getError().
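A common way to classify with per-class density estimates is to pick the class whose prior-weighted density is largest at the query point. The sketch below illustrates that general idea only. It is an assumption that `predict()` follows this rule, and `ClassModel` and `predictLabel` are hypothetical names that do not exist in SG++.

```cpp
#include <functional>
#include <limits>
#include <vector>

using Point = std::vector<double>;

// Hypothetical stand-in for one class: a density function, a prior and a label.
struct ClassModel {
  std::function<double(const Point&)> density;  // evaluates the class density
  double prior;                                 // class prior weight
  double label;                                 // class label to return
};

// Returns the label of the class with the largest prior * density(x).
// Ties keep the first class; an empty model list returns 0.
double predictLabel(const std::vector<ClassModel>& models, const Point& x) {
  double bestLabel = 0.0;
  double bestScore = -std::numeric_limits<double>::infinity();
  for (const ClassModel& m : models) {
    const double score = m.prior * m.density(x);
    if (score > bestScore) {
      bestScore = score;
      bestLabel = m.label;
    }
  }
  return bestLabel;
}
```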
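void sgpp::datadriven::LearnerSGDEOnOffParallel::printGridSizeStatistics (const char *messageString, std::vector< std::pair< std::unique_ptr< DBMatOnlineDE >, size_t >> &onlineObjects)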
protected |
Shows grid size statistics along with a message.
messageString | The message to display alongside the statistics |
onlineObjects | The current density estimation objects |
References sgpp::base::Grid::getSize(), grid(), and grids.
Referenced by trainParallel().
void sgpp::datadriven::LearnerSGDEOnOffParallel::setLocalGridVersion (size_t classIndex, size_t gridVersion)
Set the grid version.
classIndex | The class of the grid to search for |
gridVersion | The new version of the grid |
References checkGridStateConsistent(), D, isVersionConsistent(), and localGridVersions.
Referenced by computeNewSystemMatrixDecomposition(), sgpp::datadriven::MPIMethods::receiveGridComponentsUpdate(), and sgpp::datadriven::RefinementHandler::updateClassVariablesAfterRefinement().
void sgpp::datadriven::LearnerSGDEOnOffParallel::shutdownMPINodes ()
If this is run on master, it issues shutdown requests to all workers and waits for them to return.
If this is run on a worker, it sets the shutdown flag.
References sgpp::datadriven::MPIMethods::bcastCommandNoArgs(), sgpp::datadriven::MPIMethods::getWorldSize(), sgpp::datadriven::MPIMethods::isMaster(), sgpp::datadriven::SHUTDOWN, sgpp::datadriven::MPIMethods::waitForIncomingMessageType(), sgpp::datadriven::WORKER_SHUTDOWN_SUCCESS, and workerActive.
Referenced by sgpp::datadriven::MPIMethods::processIncomingMPICommands(), and trainParallel().
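void sgpp::datadriven::LearnerSGDEOnOffParallel::splitBatchIntoClasses (const Dataset &dataset, size_t dim, const std::vector< std::pair< DataMatrix *, double >> &trainDataClasses, std::map< double, int > &classIndices) const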
protected |
void sgpp::datadriven::LearnerSGDEOnOffParallel::train (Dataset &dataBatch, bool doCrossValidation)
Trains the learner with the given data batch.
dataBatch | The next data batch to process |
doCrossValidation | Enable cross-validation |
References allocateClassMatrices(), D, chess::dim, sgpp::datadriven::Dataset::getDimension(), sgpp::datadriven::Dataset::getNumberInstances(), getNumClasses(), and splitBatchIntoClasses().
Referenced by workBatch().
void sgpp::datadriven::LearnerSGDEOnOffParallel::train (std::vector< std::pair< sgpp::base::DataMatrix *, double > > &trainDataClasses, bool doCrossValidation)
Trains the learner with the given data batch that is already split up with respect to its different classes.
trainDataClasses | A vector of pairs; each pair contains the data points that belong to one class and the corresponding class label |
doCrossValidation | Enable cross-validation |
References sgpp::datadriven::RefinementResult::addedGridPoints, alphas, D, sgpp::datadriven::RefinementResult::deletedGridPointsIndices, densityEstimationConfig, densityFunctions, sgpp::datadriven::RefinementHandler::getRefinementResult(), grids, friedman::p, prior, processedPoints, refinementHandler, trained, and usePrior.
void sgpp::datadriven::LearnerSGDEOnOffParallel::trainParallel (size_t batchSize, size_t maxDataPasses, std::string refinementFunctorType, std::string refMonitor, size_t refPeriod, double accDeclineThreshold, size_t accDeclineBufferSize, size_t minRefInterval)
Trains the learner with the given dataset.
batchSize | Size of subset of data points used for each training step |
maxDataPasses | The number of passes over the whole training data |
refinementFunctorType | The refinement indicator (surplus, zero-crossings or data-based) |
refMonitor | The refinement strategy (periodic or convergence-based) |
refPeriod | The refinement interval (if periodic refinement is chosen) |
accDeclineThreshold | The convergence threshold (if convergence-based refinement is chosen) |
accDeclineBufferSize | The number of accuracy measurements which are used to check convergence (if convergence-based refinement is chosen) |
minRefInterval | The minimum number of data points (or data batches) which have to be processed before next refinement can be scheduled (if convergence-based refinement is chosen) |
References adaptivityConfig, assignBatchToWorker(), sgpp::datadriven::RefinementHandler::checkReadyForRefinement(), sgpp::datadriven::RefinementHandler::checkRefinementNecessary(), python.leja::count, D, doRefinementForAll(), getDensityFunctions(), sgpp::datadriven::Dataset::getNumberInstances(), sgpp::datadriven::MPIMethods::getQueueSize(), sgpp::datadriven::MPIMethods::hasPendingOutgoingRequests(), sgpp::datadriven::MPIMethods::isMaster(), MPI_MASTER_RANK, mpiTaskScheduler, sgpp::datadriven::MPITaskScheduler::onRefinementStarted(), printGridSizeStatistics(), sgpp::datadriven::MPIMethods::processCompletedMPIRequests(), processedPoints, refinementHandler, sgpp::datadriven::MPIMethods::sendCommandNoArgs(), shutdownMPINodes(), trainData, sgpp::datadriven::MPIMethods::waitForAnyMPIRequestsToComplete(), sgpp::datadriven::WORKER_SHUTDOWN_SUCCESS, and workerActive.
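How `batchSize` and `maxDataPasses` interact can be sketched by enumerating the batches that cover the training set. This is a simplified illustration with fixed-size batches: `planBatches` is a hypothetical helper, and the installed scheduler may in practice assign batches of variable size.

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Returns (offset, size) pairs for every batch over all data passes.
// The last batch of a pass is shortened if the data does not divide evenly.
std::vector<std::pair<std::size_t, std::size_t>> planBatches(
    std::size_t numInstances, std::size_t batchSize,
    std::size_t maxDataPasses) {
  std::vector<std::pair<std::size_t, std::size_t>> batches;
  if (batchSize == 0) return batches;  // nothing sensible to plan
  for (std::size_t pass = 0; pass < maxDataPasses; ++pass) {
    for (std::size_t offset = 0; offset < numInstances; offset += batchSize) {
      batches.emplace_back(offset, std::min(batchSize, numInstances - offset));
    }
  }
  return batches;
}
```

With 10 instances, a batch size of 4 and 2 passes, this yields 6 batches; the third batch of each pass starts at offset 8 and holds only 2 instances.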
void sgpp::datadriven::LearnerSGDEOnOffParallel::updateAlpha (size_t classIndex, std::list< size_t > *deletedPoints, size_t newPoints)
Updates the surplus vector of a certain class.
classIndex | the index of the class |
deletedPoints | a list of indices of deleted points (coarsening) |
newPoints | the number of new grid points (refinement) |
References alpha, alphas, sgpp::base::DataVector::getSize(), sgpp::base::DataVector::remove(), and sgpp::base::DataVector::resizeZero().
Referenced by sgpp::datadriven::RefinementHandler::updateClassVariablesAfterRefinement().
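The effect on a surplus vector can be sketched with a plain `std::vector`: entries of deleted grid points are dropped, and one zero entry is appended for each new grid point. `updateSurpluses` is a hypothetical helper. Performing the removal before the zero-padding is an assumption of this sketch, as is interpreting the indices relative to the vector before any removal.

```cpp
#include <cstddef>
#include <list>
#include <vector>

// Removes the surpluses of deleted grid points (coarsening), then appends
// one zero surplus for each new grid point (refinement).
void updateSurpluses(std::vector<double>& alpha,
                     const std::list<std::size_t>* deletedPoints,
                     std::size_t newPoints) {
  if (deletedPoints != nullptr && !deletedPoints->empty()) {
    // Mark the indices to drop; out-of-range indices are ignored.
    std::vector<bool> removed(alpha.size(), false);
    for (std::size_t idx : *deletedPoints) {
      if (idx < alpha.size()) removed[idx] = true;
    }
    std::vector<double> kept;
    for (std::size_t i = 0; i < alpha.size(); ++i) {
      if (!removed[i]) kept.push_back(alpha[i]);
    }
    alpha.swap(kept);
  }
  alpha.resize(alpha.size() + newPoints, 0.0);
}
```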
void sgpp::datadriven::LearnerSGDEOnOffParallel::waitForAllGridsConsistent ()
protected |
Wait for all grids to reach a consistent state before continuing.
References checkGridStateConsistent(), getLocalGridVersion(), localGridVersions, and sgpp::datadriven::MPIMethods::waitForGridConsistent().
Referenced by workBatch().
void sgpp::datadriven::LearnerSGDEOnOffParallel::workBatch (Dataset dataset, size_t batchOffset, bool doCrossValidation)
Train from a batch.
Waits until all grids are consistent, fills the dataset, learns from it, and sends the new alpha vector to the master.
dataset | An empty dataset with size and dimension set. |
batchOffset | The offset from the start of the training set to assemble the batch from. |
doCrossValidation | Whether to cross validate results. |
References alphas, assembleNextBatchData(), D, getDensityFunctions(), getLocalGridVersion(), sgpp::datadriven::Dataset::getNumberInstances(), getNumClasses(), sgpp::datadriven::MPIMethods::sendMergeGridNetworkMessage(), train(), and waitForAllGridsConsistent().
Referenced by sgpp::datadriven::MPIMethods::runBatch().
|
protected |
Referenced by doRefinementForAll(), LearnerSGDEOnOffParallel(), and trainParallel().
|
protected |
Referenced by doRefinementForAll(), LearnerSGDEOnOffParallel(), mergeAlphaValues(), predict(), train(), updateAlpha(), and workBatch().
|
protected |
Referenced by LearnerSGDEOnOffParallel().
|
protected |
Referenced by LearnerSGDEOnOffParallel().
|
protected |
Referenced by allocateClassMatrices(), LearnerSGDEOnOffParallel(), mergeAlphaValues(), and predict().
|
protected |
Referenced by computeNewSystemMatrixDecomposition(), LearnerSGDEOnOffParallel(), and train().
|
protected |
Referenced by getDensityFunctions(), LearnerSGDEOnOffParallel(), predict(), and train().
|
protected |
Referenced by LearnerSGDEOnOffParallel().
|
protected |
|
protected |
Vector that holds the grid version for every class.
Referenced by checkAllGridsConsistent(), checkGridStateConsistent(), getLocalGridVersion(), LearnerSGDEOnOffParallel(), mergeAlphaValues(), setLocalGridVersion(), and waitForAllGridsConsistent().
|
protected |
Reference to the currently installed MPI Task Scheduler.
Referenced by assignBatchToWorker(), getScheduler(), LearnerSGDEOnOffParallel(), mergeAlphaValues(), and trainParallel().
|
protected |
Referenced by getNumClasses(), LearnerSGDEOnOffParallel(), and predict().
|
protected |
Referenced by getOffline(), and LearnerSGDEOnOffParallel().
|
protected |
Referenced by LearnerSGDEOnOffParallel().
|
protected |
Referenced by LearnerSGDEOnOffParallel(), mergeAlphaValues(), predict(), and train().
|
protected |
Referenced by LearnerSGDEOnOffParallel(), train(), and trainParallel().
|
protected |
Instance of the currently installed refinement handler.
Referenced by computeNewSystemMatrixDecomposition(), doRefinementForAll(), getRefinementHandler(), LearnerSGDEOnOffParallel(), mergeAlphaValues(), train(), and trainParallel().
|
protected |
Referenced by computeNewSystemMatrixDecomposition(), and LearnerSGDEOnOffParallel().
|
protected |
Referenced by getAccuracy(), and LearnerSGDEOnOffParallel().
|
protected |
Referenced by python.uq.dists.KDEDist.KDEDist::__init__(), assembleNextBatchData(), assignBatchToWorker(), doRefinementForAll(), getDimensionality(), getTrainData(), python.uq.dists.KDEDist.KDEDist::marginalize(), python.uq.dists.KDEDist.KDEDist::marginalizeToDimX(), python.uq.dists.KDEDist.KDEDist::toJson(), and trainParallel().
|
protected |
Referenced by LearnerSGDEOnOffParallel(), and train().
|
protected |
Referenced by LearnerSGDEOnOffParallel(), mergeAlphaValues(), and train().
|
protected |
Referenced by getValidationData(), and LearnerSGDEOnOffParallel().
|
protected |
Boolean used to detect when a shutdown of a worker has been requested.
Referenced by LearnerSGDEOnOffParallel(), shutdownMPINodes(), and trainParallel().