SG++
sgpp::datadriven::RefinementHandler Class Reference

#include <RefinementHandler.hpp>

Public Member Functions

bool checkReadyForRefinement () const
 Check whether all grids are consistent and the scheduler is currently allowing refinement.

bool checkRefinementNecessary (const std::string &refMonitor, size_t refPeriod, size_t totalInstances, double currentValidError, double currentTrainError, size_t numberOfCompletedRefinements, ConvergenceMonitor &monitor)
 Check whether refinement is currently necessary according to the guidelines set by the user.

void doRefinementForClass (const std::string &refType, RefinementResult *refinementResult, const ClassDensityConntainer &onlineObjects, bool preCompute, MultiGridRefinementFunctor *refinementFunctor, size_t classIndex)
 Handles refinement for a specific class.

RefinementResult & getRefinementResult (size_t classIndex)
 Fetches the currently stored refinement results for a specific class.

 RefinementHandler (LearnerSGDEOnOffParallel *learnerInstance, size_t numClasses)
 Creates the refinement handler for a specific learner instance.

void updateClassVariablesAfterRefinement (size_t classIndex, RefinementResult *refinementResult, DBMatOnlineDE *densEst)
 After refinement completes locally or refinement results have been received over MPI, this method uses the results to adjust the grid and alpha vector.
 

Protected Member Functions

size_t handleDataAndZeroBasedRefinement (bool preCompute, MultiGridRefinementFunctor *func, size_t idx, base::Grid &grid, base::GridGenerator &gridGen) const
 Logic that handles data-based and zero-crossing refinement functors.

size_t handleSurplusBasedRefinement (DBMatOnlineDE *densEst, Grid &grid, base::GridGenerator &gridGen) const
 Logic that handles surplus-based refinement functors.
 

Protected Attributes

LearnerSGDEOnOffParallel * learnerInstance

std::vector< RefinementResult > vectorRefinementResults
 

Constructor & Destructor Documentation

sgpp::datadriven::RefinementHandler::RefinementHandler ( LearnerSGDEOnOffParallel * learnerInstance,
size_t  numClasses 
)

Creates the refinement handler for a specific learner instance.

Parameters
learnerInstance  The learner instance for which refinement is handled
numClasses  The number of classes for the current problem

References learnerInstance, and vectorRefinementResults.

Member Function Documentation

bool sgpp::datadriven::RefinementHandler::checkReadyForRefinement ( ) const

Check whether all grids are consistent and the scheduler is currently allowing refinement.

Returns
Whether refinement is currently possible

References sgpp::datadriven::LearnerSGDEOnOffParallel::checkAllGridsConsistent(), sgpp::datadriven::LearnerSGDEOnOffParallel::getScheduler(), sgpp::datadriven::MPITaskScheduler::isReadyForRefinement(), and learnerInstance.

Referenced by sgpp::datadriven::LearnerSGDEOnOffParallel::trainParallel().
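The check is a conjunction of two conditions: every class grid must be consistent, and the MPI task scheduler must currently permit a refinement cycle. The following is a minimal sketch of that logic only, using hypothetical stand-in types; GridState, SchedulerGate and readyForRefinement are illustrative names, not part of SG++.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for per-class grid state (not an SG++ type).
struct GridState {
  bool consistent;  // true once all pending updates for this grid have been applied
};

// Hypothetical stand-in for the scheduler's refinement gate (not an SG++ type).
struct SchedulerGate {
  bool allowsRefinement;
};

// Sketch: refinement may start only if every grid is consistent
// AND the scheduler currently allows a refinement cycle.
bool readyForRefinement(const std::vector<GridState>& grids,
                        const SchedulerGate& scheduler) {
  for (const GridState& g : grids) {
    if (!g.consistent) {
      return false;  // at least one grid still has outstanding updates
    }
  }
  return scheduler.allowsRefinement;
}
```

Checking grid consistency first means a single stale grid blocks refinement regardless of what the scheduler says.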

bool sgpp::datadriven::RefinementHandler::checkRefinementNecessary ( const std::string &  refMonitor,
size_t  refPeriod,
size_t  totalInstances,
double  currentValidError,
double  currentTrainError,
size_t  numberOfCompletedRefinements,
ConvergenceMonitor & monitor 
)

Check whether refinement is currently necessary according to the guidelines set by the user.

Parameters
refMonitor  String constant specifying the monitor to use for refinement
refPeriod  The minimum period in which refinement cycles are allowed
totalInstances  The number of batches that have already completed
currentValidError  The current validation error
currentTrainError  The current training error
numberOfCompletedRefinements  The number of refinement cycles already completed
monitor  The convergence monitor, if any
Returns
Whether a refinement cycle should be started

References sgpp::datadriven::ConvergenceMonitor::checkConvergence(), sgpp::datadriven::LearnerSGDEOnOff::getError(), sgpp::datadriven::LearnerSGDEOnOffParallel::getOffline(), sgpp::datadriven::LearnerSGDEOnOffParallel::getTrainData(), sgpp::datadriven::LearnerSGDEOnOffParallel::getValidationData(), learnerInstance, sgpp::datadriven::ConvergenceMonitor::nextRefCnt, and sgpp::datadriven::ConvergenceMonitor::pushToBuffer().

Referenced by sgpp::datadriven::LearnerSGDEOnOffParallel::trainParallel().
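The reference list mentions ConvergenceMonitor::nextRefCnt, pushToBuffer() and checkConvergence(), which suggests a periodic mode and a convergence-driven mode. The sketch below shows one way such a decision could be structured. It is not the SG++ implementation: the monitor strings "periodic" and "convergence", the cap maxRefinements, the SimpleMonitor type and its convergence criterion (validation error improving by less than a fixed threshold) are all assumptions, and the training error is ignored for brevity.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Hypothetical, simplified monitor; NOT SG++'s ConvergenceMonitor.
// Tracks a countdown (cf. nextRefCnt) and the last two validation errors.
struct SimpleMonitor {
  std::size_t nextRefCnt = 0;   // checks to skip before testing convergence again
  double previousError = -1.0;  // negative means "no value pushed yet"
  double latestError = -1.0;
  double threshold = 1e-3;      // assumed convergence tolerance

  void push(double validError) {
    previousError = latestError;
    latestError = validError;
  }
  // Assumed criterion: validation error stopped improving by at least threshold.
  bool converged() const {
    return previousError >= 0.0 && (previousError - latestError) < threshold;
  }
};

// Hypothetical decision function; maxRefinements is an assumed upper bound.
bool refinementNecessary(const std::string& refMonitor, std::size_t refPeriod,
                         std::size_t totalInstances, double currentValidError,
                         std::size_t completedRefinements,
                         std::size_t maxRefinements, SimpleMonitor& monitor) {
  if (completedRefinements >= maxRefinements) {
    return false;  // refinement budget exhausted
  }
  if (refMonitor == "periodic") {
    // Refine once every refPeriod completed batches.
    return refPeriod > 0 && totalInstances > 0 && totalInstances % refPeriod == 0;
  }
  if (refMonitor == "convergence") {
    monitor.push(currentValidError);
    if (monitor.nextRefCnt > 0) {
      --monitor.nextRefCnt;  // count down the waiting period
    }
    return monitor.nextRefCnt == 0 && monitor.converged();
  }
  return false;  // unknown monitor string: never refine
}
```

In this sketch the convergence mode needs at least two recorded errors before it can ever report convergence.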

void sgpp::datadriven::RefinementHandler::doRefinementForClass ( const std::string &  refType,
RefinementResult * refinementResult,
const ClassDensityConntainer & onlineObjects,
bool  preCompute,
MultiGridRefinementFunctor * refinementFunctor,
size_t  classIndex 
)

Handles refinement for a specific class.

Parameters
refType  String constant specifying the type of refinement functor
refinementResult  The RefinementResult in which the grid changes are stored
onlineObjects  The density estimation online objects
preCompute  Whether to precompute the functor's evaluation step
refinementFunctor  The refinement functor to use
classIndex  The index of the current class for which refinement is taking place

References sgpp::datadriven::RefinementResult::addedGridPoints, D, sgpp::datadriven::RefinementResult::deletedGridPointsIndices, sgpp::base::HashGridPoint::get(), sgpp::datadriven::LearnerSGDEOnOffParallel::getDimensionality(), sgpp::base::Grid::getGenerator(), sgpp::base::Grid::getSize(), sgpp::base::Grid::getStorage(), grid(), handleDataAndZeroBasedRefinement(), handleSurplusBasedRefinement(), learnerInstance, and updateClassVariablesAfterRefinement().

Referenced by sgpp::datadriven::LearnerSGDEOnOffParallel::doRefinementForAll().
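Judging by its references, doRefinementForClass() delegates to either handleDataAndZeroBasedRefinement() or handleSurplusBasedRefinement() depending on refType, and then records the grid changes. A sketch of that dispatch follows. The refType strings "data", "zero" and "surplus" are assumptions, as is the idea that newly created points are appended to the end of the grid storage; RefinementPath, selectRefinementPath and newPointIndices are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Which protected helper would handle a given refinement type (hypothetical).
enum class RefinementPath { DataOrZeroCrossing, Surplus };

RefinementPath selectRefinementPath(const std::string& refType) {
  if (refType == "data" || refType == "zero") {
    return RefinementPath::DataOrZeroCrossing;  // cf. handleDataAndZeroBasedRefinement()
  }
  if (refType == "surplus") {
    return RefinementPath::Surplus;  // cf. handleSurplusBasedRefinement()
  }
  throw std::invalid_argument("unknown refinement type: " + refType);
}

// Assumption: refinement appends new points, so they occupy the
// storage index range [sizeBefore, sizeAfter).
std::vector<std::size_t> newPointIndices(std::size_t sizeBefore, std::size_t sizeAfter) {
  std::vector<std::size_t> indices;
  for (std::size_t i = sizeBefore; i < sizeAfter; ++i) {
    indices.push_back(i);
  }
  return indices;
}
```

Throwing on an unknown string makes a misconfigured refType fail loudly instead of silently skipping refinement.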

RefinementResult & sgpp::datadriven::RefinementHandler::getRefinementResult ( size_t  classIndex)

Fetches the currently stored refinement results for a specific class.

Parameters
classIndex  The class for which to fetch refinement results
Returns
A reference to the stored refinement results

References vectorRefinementResults.

Referenced by sgpp::datadriven::LearnerSGDEOnOffParallel::computeNewSystemMatrixDecomposition(), sgpp::datadriven::LearnerSGDEOnOffParallel::doRefinementForAll(), sgpp::datadriven::LearnerSGDEOnOffParallel::mergeAlphaValues(), sgpp::datadriven::MPIMethods::receiveGridComponentsUpdate(), and sgpp::datadriven::LearnerSGDEOnOffParallel::train().

size_t sgpp::datadriven::RefinementHandler::handleDataAndZeroBasedRefinement ( bool  preCompute,
MultiGridRefinementFunctor * func,
size_t  idx,
base::Grid & grid,
base::GridGenerator & gridGen 
) const
protected

Logic that handles data-based and zero-crossing refinement functors.

Parameters
preCompute  Whether to precompute evaluations in the functor
func  Pointer to the refinement functor itself
idx  Class index
grid  The grid for the current class
gridGen  The grid's generator for the current grid
Returns
The number of added grid points

References sgpp::base::Grid::getSize(), sgpp::datadriven::MultiGridRefinementFunctor::preComputeEvaluations(), sgpp::base::GridGenerator::refine(), and sgpp::datadriven::MultiGridRefinementFunctor::setGridIndex().

Referenced by doRefinementForClass().
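The helper returns the number of added grid points and references Grid::getSize() and GridGenerator::refine(), so one natural way to obtain the return value is the difference in grid size before and after refinement. The sketch below illustrates this with mock types; FunctorSketch, GridSketch and refineAndCount are hypothetical, and calling setGridIndex() before preComputeEvaluations() is an assumed order.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical functor mirroring the referenced calls; NOT MultiGridRefinementFunctor.
struct FunctorSketch {
  std::size_t gridIndex = 0;
  bool precomputed = false;
  void setGridIndex(std::size_t idx) { gridIndex = idx; }
  void preComputeEvaluations() { precomputed = true; }
};

// Hypothetical grid whose refine() appends a fixed number of points.
struct GridSketch {
  std::size_t size = 0;
  std::size_t pointsPerRefine = 0;
  void refine(FunctorSketch&) { size += pointsPerRefine; }
};

// Sketch: select the class grid, optionally precompute, refine,
// and report the growth in grid size as the number of added points.
std::size_t refineAndCount(bool preCompute, FunctorSketch& func, std::size_t idx,
                           GridSketch& grid) {
  func.setGridIndex(idx);
  if (preCompute) {
    func.preComputeEvaluations();
  }
  const std::size_t before = grid.size;
  grid.refine(func);
  return grid.size - before;  // valid because refinement only adds points
}
```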

size_t sgpp::datadriven::RefinementHandler::handleSurplusBasedRefinement ( DBMatOnlineDE * densEst,
base::Grid & grid,
base::GridGenerator & gridGen 
) const
protected

Logic that handles surplus-based refinement functors.

void sgpp::datadriven::RefinementHandler::updateClassVariablesAfterRefinement ( size_t  classIndex,
RefinementResult * refinementResult,
DBMatOnlineDE * densEst 
)

After refinement completes locally or refinement results have been received over MPI, this method uses the results to adjust the grid and alpha vector.

If this is run on the master, the grid changes will be exported over MPI. If the system matrix is refineable, a system matrix update will be assigned to workers over MPI.

Parameters
classIndex  The index of the class being updated
refinementResult  The grid changes from the refinement cycle
densEst  A pointer to the online object specific to this class

References sgpp::datadriven::RefinementResult::addedGridPoints, sgpp::datadriven::MPIMethods::assignSystemMatrixUpdate(), sgpp::datadriven::MPITaskScheduler::assignTaskStaticTaskSize(), sgpp::datadriven::LearnerSGDEOnOffParallel::checkGridStateConsistent(), D, sgpp::datadriven::RefinementResult::deletedGridPointsIndices, sgpp::base::HashGridStorage::deletePoints(), sgpp::datadriven::DBMatOnlineDE::getAlpha(), sgpp::datadriven::LearnerSGDEOnOffParallel::getDimensionality(), sgpp::datadriven::DBMatOffline::getGrid(), sgpp::datadriven::LearnerSGDEOnOffParallel::getLocalGridVersion(), sgpp::datadriven::LearnerSGDEOnOffParallel::getOffline(), sgpp::datadriven::DBMatOnline::getOfflineObject(), sgpp::datadriven::LearnerSGDEOnOffParallel::getScheduler(), sgpp::base::Grid::getSize(), sgpp::base::Grid::getStorage(), grid(), sgpp::base::HashGridStorage::insert(), sgpp::datadriven::MPIMethods::isMaster(), learnerInstance, level, sgpp::base::HashGridStorage::recalcLeafProperty(), sgpp::datadriven::RECOMPUTE_SYSTEM_MATRIX_DECOMPOSITION, sgpp::datadriven::MPIMethods::sendRefinementUpdates(), sgpp::datadriven::LearnerSGDEOnOffParallel::setLocalGridVersion(), and sgpp::datadriven::DBMatOnlineDE::updateAlpha().

Referenced by doRefinementForClass(), and sgpp::datadriven::MPIMethods::receiveGridComponentsUpdate().
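Adjusting the alpha vector means keeping each coefficient aligned with its grid point after points are deleted and added. One simple policy is sketched below: drop the coefficients at the deleted indices, then append zero-initialized coefficients for the new points. This assumes (not confirmed by this reference) that deleting points compacts the storage while preserving the order of the remaining points, and that new points are appended at the end; updateAlphaSketch is a hypothetical name, not DBMatOnlineDE::updateAlpha().

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_set>
#include <vector>

// Hypothetical coefficient update after a refinement cycle.
std::vector<double> updateAlphaSketch(const std::vector<double>& alpha,
                                      const std::vector<std::size_t>& deletedIndices,
                                      std::size_t numAddedPoints) {
  const std::unordered_set<std::size_t> deleted(deletedIndices.begin(),
                                                deletedIndices.end());
  std::vector<double> result;
  for (std::size_t i = 0; i < alpha.size(); ++i) {
    if (deleted.count(i) == 0) {
      result.push_back(alpha[i]);  // keep surviving coefficients in their original order
    }
  }
  result.insert(result.end(), numAddedPoints, 0.0);  // new points start with zero coefficients
  return result;
}
```

Using a hash set for the deleted indices keeps the pass linear in the size of alpha, regardless of the order in which deletions were recorded.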

Member Data Documentation

std::vector<RefinementResult> sgpp::datadriven::RefinementHandler::vectorRefinementResults
protected
