SG++-Doxygen-Documentation
#include <RefinementHandler.hpp>
Public Member Functions | |
bool | checkReadyForRefinement () const |
Check whether all grids are consistent and the scheduler is currently allowing refinement.
size_t | checkRefinementNecessary (const std::string &refMonitor, size_t refPeriod, size_t batchSize, double currentValidError, double currentTrainError, size_t numberOfCompletedRefinements, RefinementMonitor &monitor, sgpp::base::AdaptivityConfiguration adaptivityConfig) |
Check whether refinement is currently necessary according to the guidelines set by the user.
void | doRefinementForClass (const std::string &refType, RefinementResult *refinementResult, const std::vector< std::pair< std::unique_ptr< DBMatOnlineDE >, size_t >> &onlineObjects, Grid &grid, DataVector &alpha, bool preCompute, MultiGridRefinementFunctor *refinementFunctor, size_t classIndex, sgpp::base::AdaptivityConfiguration &adaptivityConfig) |
Handles refinement for a specific class.
RefinementResult & | getRefinementResult (size_t classIndex) |
Fetches the currently stored refinement results for a specific class.
RefinementHandler (LearnerSGDEOnOffParallel *learnerInstance, size_t numClasses) | |
Creates the refinement handler for a specific learner instance.
void | updateClassVariablesAfterRefinement (size_t classIndex, RefinementResult *refinementResult, DBMatOnlineDE *densEst, Grid &grid) |
After refinement completes locally or refinement results have been received over MPI, this method uses the results to adjust the grid and alpha vector.
Protected Member Functions | |
size_t | handleDataAndZeroBasedRefinement (bool preCompute, MultiGridRefinementFunctor *func, size_t idx, base::Grid &grid, base::GridGenerator &gridGen) const |
Logic that handles data-based and zero-crossing refinement functors.
size_t | handleSurplusBasedRefinement (DBMatOnlineDE *densEst, Grid &grid, DataVector &alpha, base::GridGenerator &gridGen, sgpp::base::AdaptivityConfiguration adaptivityConfig) const |
Logic that handles surplus-based refinement functors.
Protected Attributes | |
LearnerSGDEOnOffParallel * | learnerInstance |
std::vector< RefinementResult > | vectorRefinementResults |
sgpp::datadriven::RefinementHandler::RefinementHandler | ( | LearnerSGDEOnOffParallel * | learnerInstance, |
size_t | numClasses | ||
) |
Creates the refinement handler for a specific learner instance.
learnerInstance | The instance of the learner to handle refinement for |
numClasses | The number of classes for the current problem |
References learnerInstance, and vectorRefinementResults.
bool sgpp::datadriven::RefinementHandler::checkReadyForRefinement | ( | ) | const |
Check whether all grids are consistent and the scheduler is currently allowing refinement.
References sgpp::datadriven::LearnerSGDEOnOffParallel::checkAllGridsConsistent(), sgpp::datadriven::LearnerSGDEOnOffParallel::getScheduler(), sgpp::datadriven::MPITaskScheduler::isReadyForRefinement(), and learnerInstance.
Referenced by sgpp::datadriven::LearnerSGDEOnOffParallel::trainParallel().
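The two conditions named above (consistent grids and a scheduler that permits refinement) can be illustrated with a minimal sketch. `LearnerSketch` and `SchedulerSketch` are hypothetical stand-ins, not the SG++ classes; only the method names `checkAllGridsConsistent()`, `getScheduler()` and `isReadyForRefinement()` are taken from the reference list, and combining them with a logical AND is an assumption about the implementation.

```cpp
// Hypothetical stand-in for MPITaskScheduler; only the queried flag is modelled.
struct SchedulerSketch {
  bool readyForRefinement;
  bool isReadyForRefinement() const { return readyForRefinement; }
};

// Hypothetical stand-in for LearnerSGDEOnOffParallel.
struct LearnerSketch {
  bool gridsConsistent;
  SchedulerSketch scheduler;
  bool checkAllGridsConsistent() const { return gridsConsistent; }
  const SchedulerSketch& getScheduler() const { return scheduler; }
};

// Assumed logic: refinement may start only when both conditions hold.
inline bool checkReadyForRefinementSketch(const LearnerSketch& learner) {
  return learner.checkAllGridsConsistent() &&
         learner.getScheduler().isReadyForRefinement();
}
```

A caller such as a training loop would presumably skip the refinement step for the current batch whenever this check returns false.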
size_t sgpp::datadriven::RefinementHandler::checkRefinementNecessary | ( | const std::string & | refMonitor, |
size_t | refPeriod, | ||
size_t | batchSize, | ||
double | currentValidError, | ||
double | currentTrainError, | ||
size_t | numberOfCompletedRefinements, | ||
RefinementMonitor & | monitor, | ||
sgpp::base::AdaptivityConfiguration | adaptivityConfig | ||
) |
Check whether refinement is currently necessary according to the guidelines set by the user.
refMonitor | String constant specifying the monitor to use for refinement |
refPeriod | The minimum period in which refinement cycles are allowed |
batchSize | The number of instances that were added by the current batch |
currentValidError | The current validation error |
currentTrainError | The current training error |
numberOfCompletedRefinements | The number of refinement cycles already completed |
monitor | The convergence monitor, if any |
adaptivityConfig | The configuration for the adaptivity of the grids |
References sgpp::datadriven::LearnerSGDEOnOffParallel::getError(), sgpp::datadriven::LearnerSGDEOnOffParallel::getOffline(), sgpp::datadriven::LearnerSGDEOnOffParallel::getTrainData(), sgpp::datadriven::LearnerSGDEOnOffParallel::getValidationData(), learnerInstance, sgpp::base::AdaptivityConfiguration::numRefinements_, sgpp::datadriven::RefinementMonitor::pushToBuffer(), and sgpp::datadriven::RefinementMonitor::refinementsNecessary().
Referenced by sgpp::datadriven::LearnerSGDEOnOffParallel::trainParallel().
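As a rough illustration of how a monitor string, a refinement period and a cap on refinement cycles could interact, consider the sketch below. Everything in it is an assumption made for illustration: the monitor names `"periodic"` and `"convergence"`, the running total `instancesSeen`, the parameter `maxRefinements` (standing in for `AdaptivityConfiguration::numRefinements_`) and the callback that replaces `RefinementMonitor::refinementsNecessary()`. It is not the SG++ implementation.

```cpp
#include <cstddef>
#include <functional>
#include <string>

// Hypothetical sketch: returns how many refinement cycles to run now.
inline std::size_t refinementsNeededSketch(
    const std::string& refMonitor, std::size_t refPeriod,
    std::size_t instancesSeen,                 // assumed running total of instances
    std::size_t numberOfCompletedRefinements,
    std::size_t maxRefinements,                // assumed cap on refinement cycles
    const std::function<std::size_t()>& convergenceMonitor) {
  // Never exceed the configured number of refinement cycles.
  if (numberOfCompletedRefinements >= maxRefinements) return 0;

  if (refMonitor == "periodic") {
    // Assumed rule: refine once every refPeriod processed instances.
    const bool due =
        refPeriod > 0 && instancesSeen > 0 && instancesSeen % refPeriod == 0;
    return due ? 1 : 0;
  }
  if (refMonitor == "convergence") {
    // Delegate the decision to the (hypothetical) convergence monitor.
    return convergenceMonitor();
  }
  return 0;  // Unknown monitor: request no refinement.
}
```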
void sgpp::datadriven::RefinementHandler::doRefinementForClass | ( | const std::string & | refType, |
RefinementResult * | refinementResult, | ||
const std::vector< std::pair< std::unique_ptr< DBMatOnlineDE >, size_t >> & | onlineObjects, | ||
Grid & | grid, | ||
DataVector & | alpha, | ||
bool | preCompute, | ||
MultiGridRefinementFunctor * | refinementFunctor, | ||
size_t | classIndex, | ||
sgpp::base::AdaptivityConfiguration & | adaptivityConfig | ||
) |
Handles refinement for a specific class.
refType | String constant specifying the type of refinement functor |
refinementResult | The RefinementResult used to store changes for the grid |
onlineObjects | The density estimation online objects |
grid | The grid of the online object of the current class |
alpha | The surpluses of the current class |
preCompute | Whether to precompute the functor's evaluation step |
refinementFunctor | The refinement functor to use |
classIndex | The index of the current class for which refinement is taking place |
adaptivityConfig | The configuration for the adaptivity of the grids |
References sgpp::datadriven::RefinementResult::addedGridPoints, D, sgpp::datadriven::RefinementResult::deletedGridPointsIndices, sgpp::base::HashGridPoint::get(), sgpp::datadriven::LearnerSGDEOnOffParallel::getDimensionality(), sgpp::base::Grid::getGenerator(), sgpp::base::Grid::getSize(), sgpp::base::Grid::getStorage(), handleDataAndZeroBasedRefinement(), handleSurplusBasedRefinement(), python.statsfileInfo::i, learnerInstance, and updateClassVariablesAfterRefinement().
Referenced by sgpp::datadriven::LearnerSGDEOnOffParallel::doRefinementForAll().
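The reference list shows that this method calls both `handleSurplusBasedRefinement()` and `handleDataAndZeroBasedRefinement()`, which suggests a dispatch on `refType`. The sketch below shows one way such a dispatch could look. The string values `"surplus"`, `"data"` and `"zero"`, the callback type `RefineStep` and the exception on unknown input are assumptions for illustration, not documented behaviour.

```cpp
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <string>

// Hypothetical callback type: runs one refinement step and returns the number
// of grid points it added.
using RefineStep = std::function<std::size_t()>;

// Assumed dispatch: pick the handler that matches the functor type.
inline std::size_t dispatchRefinementSketch(const std::string& refType,
                                            const RefineStep& surplusBased,
                                            const RefineStep& dataOrZeroBased) {
  if (refType == "surplus") return surplusBased();
  if (refType == "data" || refType == "zero") return dataOrZeroBased();
  throw std::invalid_argument("unknown refinement type: " + refType);
}
```

In the real method the result of the chosen handler is then passed on through `updateClassVariablesAfterRefinement()`, as the reference list indicates.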
RefinementResult & sgpp::datadriven::RefinementHandler::getRefinementResult | ( | size_t | classIndex | ) |
Fetches the currently stored refinement results for a specific class.
classIndex | The class to search refinement results for |
References vectorRefinementResults.
Referenced by sgpp::datadriven::LearnerSGDEOnOffParallel::computeNewSystemMatrixDecomposition(), sgpp::datadriven::LearnerSGDEOnOffParallel::doRefinementForAll(), sgpp::datadriven::LearnerSGDEOnOffParallel::mergeAlphaValues(), sgpp::datadriven::MPIMethods::receiveGridComponentsUpdate(), and sgpp::datadriven::LearnerSGDEOnOffParallel::train().
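size_t sgpp::datadriven::RefinementHandler::handleDataAndZeroBasedRefinement | ( | bool | preCompute, |
MultiGridRefinementFunctor * | func, | ||
size_t | idx, | ||
base::Grid & | grid, | ||
base::GridGenerator & | gridGen | ||
) | const |
protected |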
Logic that handles data-based and zero-crossing refinement functors.
preCompute | Whether to precompute evaluations in the functor |
func | Pointer to the refinement functor itself |
idx | Class index |
grid | The grid for the current class |
gridGen | The grid's generator for the current grid |
References sgpp::base::Grid::getSize(), sgpp::datadriven::MultiGridRefinementFunctor::preComputeEvaluations(), sgpp::base::GridGenerator::refine(), and sgpp::datadriven::MultiGridRefinementFunctor::setGridIndex().
Referenced by doRefinementForClass().
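size_t sgpp::datadriven::RefinementHandler::handleSurplusBasedRefinement | ( | DBMatOnlineDE * | densEst, |
Grid & | grid, | ||
DataVector & | alpha, | ||
base::GridGenerator & | gridGen, | ||
sgpp::base::AdaptivityConfiguration | adaptivityConfig | ||
) | const |
protected |
Logic that handles surplus-based refinement functors.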
densEst | Online objects for use in density estimation for the current class |
grid | The current class's grid |
alpha | The current surplus vector |
gridGen | The current grid's grid generator |
adaptivityConfig | The configuration for the adaptivity |
References alpha, sgpp::op_factory::createOperationEval(), sgpp::base::DataVector::get(), sgpp::datadriven::Dataset::getDimension(), sgpp::base::HashGridStorage::getPoint(), sgpp::base::HashGridStorage::getSize(), sgpp::base::DataVector::getSize(), sgpp::base::Grid::getSize(), sgpp::base::HashGridPoint::getStandardCoordinates(), sgpp::base::Grid::getStorage(), sgpp::datadriven::LearnerSGDEOnOffParallel::getTrainData(), python.utils.sg_projections::gridStorage, learnerInstance, sgpp::base::AdaptivityConfiguration::noPoints_, friedman::p, and sgpp::base::GridGenerator::refine().
Referenced by doRefinementForClass().
void sgpp::datadriven::RefinementHandler::updateClassVariablesAfterRefinement | ( | size_t | classIndex, |
RefinementResult * | refinementResult, | ||
DBMatOnlineDE * | densEst, | ||
Grid & | grid | ||
) |
After refinement completes locally or refinement results have been received over MPI, this method uses the results to adjust the grid and alpha vector.
If this runs on the master, the grid changes are exported over MPI. If the system matrix is refinable, a system matrix update is assigned to the workers over MPI.
classIndex | The index of the class being updated |
refinementResult | The grid changes from the refinement cycle |
densEst | A pointer to the online object specific to this class |
grid | The grid of the current density function |
References sgpp::datadriven::RefinementResult::addedGridPoints, sgpp::datadriven::MPIMethods::assignSystemMatrixUpdate(), sgpp::datadriven::MPITaskScheduler::assignTaskStaticTaskSize(), sgpp::datadriven::LearnerSGDEOnOffParallel::checkGridStateConsistent(), D, sgpp::datadriven::RefinementResult::deletedGridPointsIndices, sgpp::base::HashGridStorage::deletePoints(), sgpp::datadriven::LearnerSGDEOnOffParallel::getDimensionality(), sgpp::datadriven::LearnerSGDEOnOffParallel::getLocalGridVersion(), sgpp::datadriven::LearnerSGDEOnOffParallel::getOffline(), sgpp::datadriven::LearnerSGDEOnOffParallel::getScheduler(), sgpp::base::Grid::getSize(), sgpp::base::Grid::getStorage(), sgpp::base::HashGridStorage::insert(), sgpp::datadriven::MPIMethods::isMaster(), learnerInstance, level, sgpp::base::HashGridStorage::recalcLeafProperty(), sgpp::datadriven::RECOMPUTE_SYSTEM_MATRIX_DECOMPOSITION, sgpp::datadriven::MPIMethods::sendRefinementUpdates(), sgpp::datadriven::LearnerSGDEOnOffParallel::setLocalGridVersion(), and sgpp::datadriven::LearnerSGDEOnOffParallel::updateAlpha().
Referenced by doRefinementForClass(), and sgpp::datadriven::MPIMethods::receiveGridComponentsUpdate().
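The grid adjustment described above can be pictured as deleting the points listed in `deletedGridPointsIndices` and then appending those in `addedGridPoints`. The sketch below does this on a plain `std::vector<int>` standing in for the grid storage. `RefinementResultSketch`, the `int` point type and the delete-then-insert order are assumptions for illustration; the MPI export, the grid version bookkeeping and the alpha update are left out.

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for RefinementResult; grid points are plain ints here.
struct RefinementResultSketch {
  std::vector<int> addedGridPoints;
  std::vector<std::size_t> deletedGridPointsIndices;
};

// Applies the recorded grid changes to a simplified grid storage.
inline void applyRefinementResultSketch(const RefinementResultSketch& result,
                                        std::vector<int>& gridPoints) {
  // Erase from the highest index down so earlier erasures do not shift the
  // positions of points that still have to be removed.
  std::vector<std::size_t> toDelete = result.deletedGridPointsIndices;
  std::sort(toDelete.begin(), toDelete.end(), std::greater<std::size_t>());
  toDelete.erase(std::unique(toDelete.begin(), toDelete.end()), toDelete.end());
  for (std::size_t idx : toDelete) {
    if (idx < gridPoints.size()) {
      gridPoints.erase(gridPoints.begin() + static_cast<std::ptrdiff_t>(idx));
    }
  }
  // Newly refined points are appended at the end of the storage.
  gridPoints.insert(gridPoints.end(), result.addedGridPoints.begin(),
                    result.addedGridPoints.end());
}
```

Because the same function works for locally computed and received results, a worker could apply an update that arrived over MPI in the same way the master applies its own.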
LearnerSGDEOnOffParallel * sgpp::datadriven::RefinementHandler::learnerInstance |
protected |
std::vector< RefinementResult > sgpp::datadriven::RefinementHandler::vectorRefinementResults |
protected |
Referenced by getRefinementResult(), and RefinementHandler().