SG++-Doxygen-Documentation
sgpp::datadriven::RefinementHandler Class Reference

#include <RefinementHandler.hpp>

Public Member Functions

bool checkReadyForRefinement () const
 Check whether all grids are consistent and the scheduler is currently allowing refinement.
 
size_t checkRefinementNecessary (const std::string &refMonitor, size_t refPeriod, size_t batchSize, double currentValidError, double currentTrainError, size_t numberOfCompletedRefinements, RefinementMonitor &monitor, sgpp::base::AdaptivityConfiguration adaptivityConfig)
 Check whether refinement is currently necessary according to the guidelines set by the user.
 
void doRefinementForClass (const std::string &refType, RefinementResult *refinementResult, const std::vector< std::pair< std::unique_ptr< DBMatOnlineDE >, size_t >> &onlineObjects, Grid &grid, DataVector &alpha, bool preCompute, MultiGridRefinementFunctor *refinementFunctor, size_t classIndex, sgpp::base::AdaptivityConfiguration &adaptivityConfig)
 Handles refinement for a specific class.
 
RefinementResult & getRefinementResult (size_t classIndex)
 Fetches the currently stored refinement results for a specific class.
 
 RefinementHandler (LearnerSGDEOnOffParallel *learnerInstance, size_t numClasses)
 Creates the refinement handler for a specific learner instance.
 
void updateClassVariablesAfterRefinement (size_t classIndex, RefinementResult *refinementResult, DBMatOnlineDE *densEst, Grid &grid)
 After refinement completes locally or refinement results have been received over MPI, this method uses the results to adjust the grid and alpha vector.
 

Protected Member Functions

size_t handleDataAndZeroBasedRefinement (bool preCompute, MultiGridRefinementFunctor *func, size_t idx, base::Grid &grid, base::GridGenerator &gridGen) const
 Logic that handles data-based and zero-crossing refinement functors.
 
size_t handleSurplusBasedRefinement (DBMatOnlineDE *densEst, Grid &grid, DataVector &alpha, base::GridGenerator &gridGen, sgpp::base::AdaptivityConfiguration adaptivityConfig) const
 Logic that handles surplus-based refinement functors.
 

Protected Attributes

LearnerSGDEOnOffParallel * learnerInstance
 
std::vector< RefinementResult > vectorRefinementResults
 

Constructor & Destructor Documentation

◆ RefinementHandler()

sgpp::datadriven::RefinementHandler::RefinementHandler ( LearnerSGDEOnOffParallel * learnerInstance,
size_t  numClasses 
)

Creates the refinement handler for a specific learner instance.

Parameters
learnerInstance: The instance of the learner to handle refinement for
numClasses: The number of classes for the current problem

References learnerInstance, and vectorRefinementResults.
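The constructor's documented effect can be sketched as follows. This is a minimal, self-contained illustration, not SG++ code: `LearnerStub`, `RefinementResultStub`, `RefinementHandlerSketch` and `numResultSlots()` are hypothetical. The page only states that the constructor references both members, so the sketch assumes that `vectorRefinementResults` receives one default-constructed entry per class and that the handler does not own the learner.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for the SG++ types; only what the sketch needs.
struct LearnerStub {};
struct RefinementResultStub {
  std::vector<std::vector<double>> addedGridPoints;   // coordinates of new points
  std::vector<std::size_t> deletedGridPointsIndices;  // indices of removed points
};

// Sketch of the constructor: remember the (non-owning) learner pointer and
// create one refinement-result slot per class (assumed behavior).
class RefinementHandlerSketch {
 public:
  RefinementHandlerSketch(LearnerStub* learnerInstance, std::size_t numClasses)
      : learnerInstance(learnerInstance), vectorRefinementResults(numClasses) {
    assert(learnerInstance != nullptr);  // checks the constructor parameter
  }

  // Hypothetical helper, only to make the slot count observable.
  std::size_t numResultSlots() const { return vectorRefinementResults.size(); }

 protected:
  LearnerStub* learnerInstance;  // not owned; assumed to outlive the handler
  std::vector<RefinementResultStub> vectorRefinementResults;
};
```

Because the handler keeps only a raw pointer, the learner would have to stay alive for as long as the handler is used. This lifetime rule is an assumption about the design and is not stated on this page.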

Member Function Documentation

◆ checkReadyForRefinement()

bool sgpp::datadriven::RefinementHandler::checkReadyForRefinement ( ) const

Check whether all grids are consistent and the scheduler is currently allowing refinement.

Returns
Whether refinement is currently possible

References sgpp::datadriven::LearnerSGDEOnOffParallel::checkAllGridsConsistent(), sgpp::datadriven::LearnerSGDEOnOffParallel::getScheduler(), sgpp::datadriven::MPITaskScheduler::isReadyForRefinement(), and learnerInstance.

Referenced by sgpp::datadriven::LearnerSGDEOnOffParallel::trainParallel().
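The check is a logical AND of two conditions. The sketch below models it with hypothetical stubs (`LearnerStub`, `SchedulerStub`). The member names `checkAllGridsConsistent`, `getScheduler` and `isReadyForRefinement` echo the references listed above, but their bodies here are toy assumptions and not the SG++ implementations.

```cpp
#include <cassert>
#include <vector>

// Hypothetical scheduler stub: a single flag says whether refinement is allowed.
struct SchedulerStub {
  bool readyForRefinement = false;
  bool isReadyForRefinement() const { return readyForRefinement; }
};

// Hypothetical learner stub: one consistency flag per class grid.
struct LearnerStub {
  std::vector<bool> gridConsistent;
  SchedulerStub scheduler;

  bool checkAllGridsConsistent() const {
    for (bool ok : gridConsistent) {
      if (!ok) return false;
    }
    return true;
  }
  const SchedulerStub& getScheduler() const { return scheduler; }
};

// Refinement is possible only if every grid is consistent AND the scheduler agrees.
bool checkReadyForRefinement(const LearnerStub& learner) {
  return learner.checkAllGridsConsistent() &&
         learner.getScheduler().isReadyForRefinement();
}
```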

◆ checkRefinementNecessary()

size_t sgpp::datadriven::RefinementHandler::checkRefinementNecessary ( const std::string &  refMonitor,
size_t  refPeriod,
size_t  batchSize,
double  currentValidError,
double  currentTrainError,
size_t  numberOfCompletedRefinements,
RefinementMonitor & monitor,
sgpp::base::AdaptivityConfiguration  adaptivityConfig 
)

Check whether refinement is currently necessary according to the guidelines set by the user.

Parameters
refMonitor: String constant specifying the monitor to use for refinement
refPeriod: The minimum period in which refinement cycles are allowed
batchSize: The number of instances that were added by the current batch
currentValidError: The current validation error
currentTrainError: The current training error
numberOfCompletedRefinements: The number of refinement cycles already completed
monitor: The convergence monitor, if any
adaptivityConfig: The configuration for the adaptivity of the grids
Returns
How many refinement cycles should be started

References sgpp::datadriven::LearnerSGDEOnOffParallel::getError(), sgpp::datadriven::LearnerSGDEOnOffParallel::getOffline(), sgpp::datadriven::LearnerSGDEOnOffParallel::getTrainData(), sgpp::datadriven::LearnerSGDEOnOffParallel::getValidationData(), learnerInstance, sgpp::base::AdaptivityConfiguration::numRefinements_, sgpp::datadriven::RefinementMonitor::pushToBuffer(), and sgpp::datadriven::RefinementMonitor::refinementsNecessary().

Referenced by sgpp::datadriven::LearnerSGDEOnOffParallel::trainParallel().
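The decision can be sketched under several stated assumptions. The monitor names `"periodic"` and `"convergence"`, the rule "refine once every `refPeriod` processed instances", the hypothetical `ConvergenceMonitorStub` (which requests one refinement when the validation error improves by less than a threshold) and the parameter `maxRefinements` (standing in for `AdaptivityConfiguration::numRefinements_`) are all illustrative choices, not documented SG++ behavior. The function name and signature are hypothetical and differ from the real method.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>

// Hypothetical convergence monitor: refine once the validation error stops
// improving by more than `threshold` between consecutive observations.
struct ConvergenceMonitorStub {
  explicit ConvergenceMonitorStub(double t) : threshold(t) {}
  double threshold;
  double lastValidError = -1.0;     // negative means "no observation yet"
  double currentValidError = -1.0;

  void pushToBuffer(double validError) {
    lastValidError = currentValidError;
    currentValidError = validError;
  }
  std::size_t refinementsNecessary() const {
    if (lastValidError < 0.0) return 0;  // nothing to compare against yet
    return (lastValidError - currentValidError) < threshold ? 1 : 0;
  }
};

// Returns how many refinement cycles to start (0 or 1 in this sketch).
std::size_t checkRefinementNecessarySketch(const std::string& refMonitor,
                                           std::size_t refPeriod,
                                           std::size_t processedInstances,
                                           double currentValidError,
                                           std::size_t numberOfCompletedRefinements,
                                           std::size_t maxRefinements,
                                           ConvergenceMonitorStub& monitor) {
  // Never exceed the configured refinement budget.
  if (numberOfCompletedRefinements >= maxRefinements) return 0;

  if (refMonitor == "periodic") {
    // Refine once every refPeriod processed instances.
    return (refPeriod > 0 && processedInstances > 0 &&
            processedInstances % refPeriod == 0) ? 1 : 0;
  }
  if (refMonitor == "convergence") {
    monitor.pushToBuffer(currentValidError);
    return monitor.refinementsNecessary();
  }
  throw std::invalid_argument("unknown refinement monitor: " + refMonitor);
}
```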

◆ doRefinementForClass()

void sgpp::datadriven::RefinementHandler::doRefinementForClass ( const std::string &  refType,
RefinementResult * refinementResult,
const std::vector< std::pair< std::unique_ptr< DBMatOnlineDE >, size_t >> &  onlineObjects,
Grid & grid,
DataVector & alpha,
bool  preCompute,
MultiGridRefinementFunctor * refinementFunctor,
size_t  classIndex,
sgpp::base::AdaptivityConfiguration & adaptivityConfig 
)

Handles refinement for a specific class.

Parameters
refType: String constant specifying the type of refinement functor
refinementResult: The RefinementResult used to store changes to the grid
onlineObjects: The density estimation online objects
grid: The grid of the online object of the current class
alpha: The surpluses of the current class
preCompute: Whether to precompute the functor's evaluation step
refinementFunctor: The refinement functor to use
classIndex: The index of the current class for which refinement is taking place
adaptivityConfig: The configuration for the adaptivity of the grids

References sgpp::datadriven::RefinementResult::addedGridPoints, D, sgpp::datadriven::RefinementResult::deletedGridPointsIndices, sgpp::base::HashGridPoint::get(), sgpp::datadriven::LearnerSGDEOnOffParallel::getDimensionality(), sgpp::base::Grid::getGenerator(), sgpp::base::Grid::getSize(), sgpp::base::Grid::getStorage(), handleDataAndZeroBasedRefinement(), handleSurplusBasedRefinement(), python.statsfileInfo::i, learnerInstance, and updateClassVariablesAfterRefinement().

Referenced by sgpp::datadriven::LearnerSGDEOnOffParallel::doRefinementForAll().
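The flow described above (dispatch on `refType`, run the matching helper, record the resulting grid changes) might look roughly like this. All names are hypothetical: points are plain coordinate vectors, the two refinement helpers are passed in as callbacks, the keys `"surplus"`, `"data"` and `"zero"` are assumed, and new points are assumed to be appended at the end of the grid.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <string>
#include <vector>

using Point = std::vector<double>;  // hypothetical: a grid point as coordinates

struct RefinementResultStub {
  std::vector<Point> addedGridPoints;
  std::vector<std::size_t> deletedGridPointsIndices;
};

// A refinement strategy appends points to the grid and returns how many it added.
using RefinementStrategy = std::function<std::size_t(std::vector<Point>&)>;

void doRefinementForClassSketch(const std::string& refType,
                                RefinementResultStub& result,
                                std::vector<Point>& grid,
                                const RefinementStrategy& surplusBased,
                                const RefinementStrategy& dataOrZeroBased) {
  const std::size_t sizeBefore = grid.size();
  std::size_t added = 0;
  if (refType == "surplus") {
    added = surplusBased(grid);
  } else if (refType == "data" || refType == "zero") {
    added = dataOrZeroBased(grid);
  } else {
    throw std::invalid_argument("unknown refinement type: " + refType);
  }
  assert(grid.size() == sizeBefore + added);

  // Record the appended points so the change can be applied or sent elsewhere.
  result.addedGridPoints.assign(
      grid.begin() + static_cast<std::ptrdiff_t>(sizeBefore), grid.end());
}
```

The real method additionally calls `updateClassVariablesAfterRefinement()` with the filled `RefinementResult`; that step is left out of this sketch.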

◆ getRefinementResult()

RefinementResult & sgpp::datadriven::RefinementHandler::getRefinementResult ( size_t  classIndex)

Fetches the currently stored refinement results for a specific class.

Parameters
classIndexThe class to search refinement results for
Returns
A reference to the stored refinement results

References vectorRefinementResults.

Referenced by sgpp::datadriven::LearnerSGDEOnOffParallel::computeNewSystemMatrixDecomposition(), sgpp::datadriven::LearnerSGDEOnOffParallel::doRefinementForAll(), sgpp::datadriven::LearnerSGDEOnOffParallel::mergeAlphaValues(), sgpp::datadriven::MPIMethods::receiveGridComponentsUpdate(), and sgpp::datadriven::LearnerSGDEOnOffParallel::train().
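Because the accessor returns a reference rather than a copy, callers such as the MPI receive path can fill a class's result in place. The sketch below shows that pattern with hypothetical types (`RefinementResultStub`, `ResultStoreSketch`); whether the real accessor checks bounds is not documented here, so the use of `.at()` is a choice of this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct RefinementResultStub {
  std::vector<std::size_t> deletedGridPointsIndices;
};

class ResultStoreSketch {
 public:
  explicit ResultStoreSketch(std::size_t numClasses) : results(numClasses) {}

  // Returns a reference into the per-class vector; writes through it persist.
  // .at() throws std::out_of_range for an invalid class index.
  RefinementResultStub& getRefinementResult(std::size_t classIndex) {
    return results.at(classIndex);
  }

 private:
  std::vector<RefinementResultStub> results;
};
```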

◆ handleDataAndZeroBasedRefinement()

size_t sgpp::datadriven::RefinementHandler::handleDataAndZeroBasedRefinement ( bool  preCompute,
MultiGridRefinementFunctor * func,
size_t  idx,
base::Grid & grid,
base::GridGenerator & gridGen 
) const
protected

Logic that handles data-based and zero-crossing refinement functors.

Parameters
preCompute: Whether to precompute evaluations in the functor
func: Pointer to the refinement functor itself
idx: Class index
grid: The grid for the current class
gridGen: The grid generator for the current grid
Returns
The number of added grid points

References sgpp::base::Grid::getSize(), sgpp::datadriven::MultiGridRefinementFunctor::preComputeEvaluations(), sgpp::base::GridGenerator::refine(), and sgpp::datadriven::MultiGridRefinementFunctor::setGridIndex().

Referenced by doRefinementForClass().
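The documented steps (point the functor at class `idx`, optionally precompute its evaluations, refine, report the number of added points) can be sketched as below. `MultiGridFunctorStub`, `GridStub` and `GridGeneratorStub` are hypothetical toys: the stub generator simply appends a fixed number of points per call, and the real `GridGenerator::refine` signature is not reproduced here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical functor stub that records how it was configured.
class MultiGridFunctorStub {
 public:
  void setGridIndex(std::size_t idx) { gridIndex = idx; }
  void preComputeEvaluations() { precomputed = true; }
  std::size_t gridIndex = 0;
  bool precomputed = false;
};

// Hypothetical grid and generator: refining appends `pointsPerRefine` points.
struct GridStub {
  std::size_t size = 0;
};
struct GridGeneratorStub {
  GridStub& grid;
  std::size_t pointsPerRefine;
  void refine(MultiGridFunctorStub& /*func*/) { grid.size += pointsPerRefine; }
};

std::size_t handleDataAndZeroBasedRefinementSketch(bool preCompute,
                                                  MultiGridFunctorStub* func,
                                                  std::size_t idx,
                                                  GridStub& grid,
                                                  GridGeneratorStub& gridGen) {
  assert(func != nullptr);
  const std::size_t sizeBefore = grid.size;
  func->setGridIndex(idx);  // evaluate against this class's grid
  if (preCompute) {
    func->preComputeEvaluations();
  }
  gridGen.refine(*func);
  return grid.size - sizeBefore;  // number of added grid points
}
```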

◆ handleSurplusBasedRefinement()

size_t sgpp::datadriven::RefinementHandler::handleSurplusBasedRefinement ( DBMatOnlineDE * densEst,
base::Grid & grid,
DataVector & alpha,
base::GridGenerator & gridGen,
sgpp::base::AdaptivityConfiguration  adaptivityConfig 
) const
protected

Logic that handles surplus-based refinement functors.

Parameters
densEst: Online object used for density estimation for the current class
grid: The current class's grid
alpha: The current surplus vector
gridGen: The current grid's generator
adaptivityConfig: The configuration for the adaptivity of the grids
Returns
The number of added grid points

References alpha, sgpp::op_factory::createOperationEval(), sgpp::base::DataVector::get(), sgpp::datadriven::Dataset::getDimension(), sgpp::base::HashGridStorage::getPoint(), sgpp::base::HashGridStorage::getSize(), sgpp::base::DataVector::getSize(), sgpp::base::Grid::getSize(), sgpp::base::HashGridPoint::getStandardCoordinates(), sgpp::base::Grid::getStorage(), sgpp::datadriven::LearnerSGDEOnOffParallel::getTrainData(), python.utils.sg_projections::gridStorage, learnerInstance, sgpp::base::AdaptivityConfiguration::noPoints_, friedman::p, and sgpp::base::GridGenerator::refine().

Referenced by doRefinementForClass().
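This page does not spell out the exact surplus criterion, and the method also references an evaluation operation and the training data. As a stand-in, the sketch below uses a common surplus rule: select the `noPoints` grid points whose coefficients have the largest magnitude `|alpha_i|`. The function `selectBySurplus` is hypothetical, and `noPoints` mirrors `AdaptivityConfiguration::noPoints_` only by name.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <vector>

// Returns the indices of the (at most) noPoints largest-|alpha| grid points,
// ordered from largest to smallest magnitude.
std::vector<std::size_t> selectBySurplus(const std::vector<double>& alpha,
                                         std::size_t noPoints) {
  std::vector<std::size_t> order(alpha.size());
  std::iota(order.begin(), order.end(), std::size_t{0});
  const std::size_t k = std::min(noPoints, order.size());
  // Partially sort so the first k indices have the largest |alpha|.
  std::partial_sort(order.begin(),
                    order.begin() + static_cast<std::ptrdiff_t>(k), order.end(),
                    [&alpha](std::size_t a, std::size_t b) {
                      return std::fabs(alpha[a]) > std::fabs(alpha[b]);
                    });
  order.resize(k);
  return order;
}
```

A generator would then refine the selected points; the count of added children, which the real method returns, depends on the grid type and is not modeled here.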

◆ updateClassVariablesAfterRefinement()

void sgpp::datadriven::RefinementHandler::updateClassVariablesAfterRefinement ( size_t  classIndex,
RefinementResult * refinementResult,
DBMatOnlineDE * densEst,
Grid & grid 
)

After refinement completes locally or refinement results have been received over MPI, this method uses the results to adjust the grid and alpha vector.

If this is run on the master, the grid changes will be exported over MPI. If the system matrix is refinable, a system matrix update will be assigned to workers over MPI.

Parameters
classIndex: The index of the class being updated
refinementResult: The grid changes from the refinement cycle
densEst: A pointer to the online object specific to this class
grid: The grid of the current density function

References sgpp::datadriven::RefinementResult::addedGridPoints, sgpp::datadriven::MPIMethods::assignSystemMatrixUpdate(), sgpp::datadriven::MPITaskScheduler::assignTaskStaticTaskSize(), sgpp::datadriven::LearnerSGDEOnOffParallel::checkGridStateConsistent(), D, sgpp::datadriven::RefinementResult::deletedGridPointsIndices, sgpp::base::HashGridStorage::deletePoints(), sgpp::datadriven::LearnerSGDEOnOffParallel::getDimensionality(), sgpp::datadriven::LearnerSGDEOnOffParallel::getLocalGridVersion(), sgpp::datadriven::LearnerSGDEOnOffParallel::getOffline(), sgpp::datadriven::LearnerSGDEOnOffParallel::getScheduler(), sgpp::base::Grid::getSize(), sgpp::base::Grid::getStorage(), sgpp::base::HashGridStorage::insert(), sgpp::datadriven::MPIMethods::isMaster(), learnerInstance, level, sgpp::base::HashGridStorage::recalcLeafProperty(), sgpp::datadriven::RECOMPUTE_SYSTEM_MATRIX_DECOMPOSITION, sgpp::datadriven::MPIMethods::sendRefinementUpdates(), sgpp::datadriven::LearnerSGDEOnOffParallel::setLocalGridVersion(), and sgpp::datadriven::LearnerSGDEOnOffParallel::updateAlpha().

Referenced by doRefinementForClass(), and sgpp::datadriven::MPIMethods::receiveGridComponentsUpdate().
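The local part of this update (apply deletions, then insertions, keep the coefficient vector aligned with the grid, bump the grid version) can be sketched on toy data. `ClassStateStub`, `RefinementResultStub` and `applyRefinementResult` are hypothetical; giving new points a zero coefficient is an assumption, and the MPI steps (sending updates from the master, assigning system matrix updates) are not modeled.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using Point = std::vector<double>;

struct RefinementResultStub {
  std::vector<Point> addedGridPoints;
  std::vector<std::size_t> deletedGridPointsIndices;
};

struct ClassStateStub {
  std::vector<Point> grid;    // one entry per grid point
  std::vector<double> alpha;  // one coefficient per grid point
  int gridVersion = 0;
};

void applyRefinementResult(const RefinementResultStub& result, ClassStateStub& state) {
  assert(state.grid.size() == state.alpha.size());

  // Erase from the highest index down so the remaining indices stay valid.
  std::vector<std::size_t> toDelete = result.deletedGridPointsIndices;
  std::sort(toDelete.rbegin(), toDelete.rend());
  toDelete.erase(std::unique(toDelete.begin(), toDelete.end()), toDelete.end());
  for (std::size_t idx : toDelete) {
    assert(idx < state.grid.size());
    state.grid.erase(state.grid.begin() + static_cast<std::ptrdiff_t>(idx));
    state.alpha.erase(state.alpha.begin() + static_cast<std::ptrdiff_t>(idx));
  }

  // Append new points; their coefficients start at zero (assumption).
  for (const Point& p : result.addedGridPoints) {
    state.grid.push_back(p);
    state.alpha.push_back(0.0);
  }
  ++state.gridVersion;  // stand-in for setLocalGridVersion()
}
```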

Member Data Documentation

◆ learnerInstance

LearnerSGDEOnOffParallel* sgpp::datadriven::RefinementHandler::learnerInstance
protected

◆ vectorRefinementResults

std::vector<RefinementResult> sgpp::datadriven::RefinementHandler::vectorRefinementResults
protected

The documentation for this class was generated from the following files: