SG++-Doxygen-Documentation
A refinement indicator for classification problems based on impurity measures (e.g. Gini impurity, entropy impurity, ...).
#include <ImpurityRefinementIndicator.hpp>
Public Types
  typedef GridPoint counter_key_type
  typedef std::pair<size_t, double> value_type
Public Types inherited from sgpp::base::RefinementFunctor
  typedef double value_type
Public Member Functions
  size_t getRefinementsNum() const override
    Returns the maximal number of points that should be refined.
  double getRefinementThreshold() const override
    Returns the threshold for refinement.
  ImpurityRefinementIndicator(Grid& grid, DataMatrix& dataset, DataVector* alphas, DataVector* w1, DataVector* w2, DataVector& classesComputed, double threshold = 0.0, size_t refinementsNum = 1)
    Constructor.
  virtual double operator()(GridPoint& point) const
    Returns a refinement indicator for the specified grid point.
  double operator()(GridStorage& storage, size_t seq) const override
    Returns the refinement value of a single grid point; it is evaluated for every grid point.
  double start() const override
    Returns the lower bound of the refinement criterion (e.g., alpha or error).
  void update(GridPoint& point)
    Updates the normal vector of the SVM.
Public Member Functions inherited from sgpp::base::RefinementFunctor
  virtual double getTotalRefinementValue(GridStorage& storage) const
    Returns the total sum of local (error) indicators used for refinement.
  RefinementFunctor()
    Constructor.
  virtual ~RefinementFunctor()
    Destructor.
Public Attributes
  DataVector* alphas
  DataVector* w1
  DataVector* w2
Protected Attributes
  DataVector& classesComputed
  DataMatrix& dataset
  size_t refinementsNum
  double threshold
Detailed Description
A refinement indicator for classification problems based on impurity measures (e.g. Gini impurity, entropy impurity, ...). It computes local impurities from the information in the provided data set. If the indicator is used within the SVM learner, the normal vector must be extended after each refinement (see update()).
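The two impurity measures named above can be sketched on the labels of the data points that fall into one region. This is a minimal, self-contained sketch using the textbook definitions (Gini: 1 - sum_k p_k^2; entropy: -sum_k p_k log2 p_k, where p_k is the relative frequency of class k). This page does not state which formula, weighting, or log base ImpurityRefinementIndicator actually uses, and the function names below are hypothetical.

```cpp
#include <cmath>
#include <map>
#include <vector>

// Relative frequency p_k of each label in `labels` (e.g. predicted classes).
// Hypothetical helper for illustration; not part of the SG++ API.
std::map<double, double> classFrequencies(const std::vector<double>& labels) {
  std::map<double, double> freq;
  for (double label : labels) freq[label] += 1.0;
  for (auto& entry : freq) entry.second /= static_cast<double>(labels.size());
  return freq;
}

// Gini impurity: 1 - sum_k p_k^2. Returns 0 for an empty or pure set.
double giniImpurity(const std::vector<double>& labels) {
  if (labels.empty()) return 0.0;
  double sumSq = 0.0;
  for (const auto& entry : classFrequencies(labels)) sumSq += entry.second * entry.second;
  return 1.0 - sumSq;
}

// Entropy impurity: -sum_k p_k * log2(p_k). Returns 0 for an empty or pure set.
double entropyImpurity(const std::vector<double>& labels) {
  if (labels.empty()) return 0.0;
  double h = 0.0;
  for (const auto& entry : classFrequencies(labels)) h -= entry.second * std::log2(entry.second);
  return h;
}
```

Both measures are 0 for a region containing a single class and maximal for an even class mix, so a high value marks a region near a decision boundary.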
Member Typedef Documentation
typedef std::pair<size_t, double> sgpp::base::ImpurityRefinementIndicator::value_type
Constructor & Destructor Documentation
sgpp::base::ImpurityRefinementIndicator::ImpurityRefinementIndicator(Grid& grid, DataMatrix& dataset, DataVector* alphas, DataVector* w1, DataVector* w2, DataVector& classesComputed, double threshold = 0.0, size_t refinementsNum = 1)
Constructor.
Parameters
  grid: The grid to refine.
  dataset: The set of data points used to compute impurities.
  alphas: The weights corresponding to the support vectors (only required for the SVM learner).
  w1: Normal vector (only required for the SVM learner).
  w2: Normal vector computed with absolute values (only required for the SVM learner).
  classesComputed: The predicted labels for the data points in dataset.
  threshold: The refinement threshold; only grid points with indicator values greater than this threshold are refined.
  refinementsNum: The maximum number of grid points to refine.
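The threshold and refinementsNum parameters together decide which points get refined. The sketch below encodes one plausible reading of the parameter descriptions: keep only indicator values strictly greater than threshold, rank them in descending order (the point with the highest value is refined first), and keep at most refinementsNum of them. The helper selectPointsToRefine and its (sequence number, indicator) input, modeled on the class's value_type typedef, are hypothetical; this is not the SG++ refinement routine.

```cpp
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// (sequence number, indicator value), mirroring the class's value_type typedef.
using IndicatorEntry = std::pair<std::size_t, double>;

// Hypothetical helper: returns the sequence numbers of the points that would be
// refined, highest indicator first, at most refinementsNum of them.
std::vector<std::size_t> selectPointsToRefine(std::vector<IndicatorEntry> entries,
                                              double threshold,
                                              std::size_t refinementsNum) {
  // Drop every point whose indicator does not exceed the threshold.
  entries.erase(std::remove_if(entries.begin(), entries.end(),
                               [threshold](const IndicatorEntry& e) { return !(e.second > threshold); }),
                entries.end());
  // Highest indicator first; ties keep their original order.
  std::stable_sort(entries.begin(), entries.end(),
                   [](const IndicatorEntry& a, const IndicatorEntry& b) { return a.second > b.second; });
  // Respect the maximum number of refinements.
  if (entries.size() > refinementsNum) entries.resize(refinementsNum);

  std::vector<std::size_t> selected;
  for (const auto& e : entries) selected.push_back(e.first);
  return selected;
}
```

With the default arguments (threshold = 0.0, refinementsNum = 1), this reading refines only the single point with the largest positive indicator.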
Member Function Documentation
size_t sgpp::base::ImpurityRefinementIndicator::getRefinementsNum() const [override, virtual]
Returns the maximal number of points that should be refined.
The maximal number of points to refine is set in the constructor of the implementing class.
Reimplemented from sgpp::base::RefinementFunctor.
References refinementsNum.
double sgpp::base::ImpurityRefinementIndicator::getRefinementThreshold() const [override, virtual]
Returns the threshold for refinement.
Only grid points whose refinement criterion has an absolute value greater than this threshold are refined.
Implements sgpp::base::RefinementFunctor.
References threshold.
Referenced by sgpp::base::ImpurityRefinement::getIndicator().
double sgpp::base::ImpurityRefinementIndicator::operator()(GridPoint& point) const [virtual]
Returns a refinement indicator for the specified grid point.
The point with the highest value is refined first.
Parameters
  point: The grid point for which to calculate an indicator value.
References classesComputed, dataset, sgpp::base::DataVector::get(), sgpp::base::DataMatrix::get(), sgpp::base::HashGridPoint::getDimension(), sgpp::base::HashGridPoint::getIndex(), sgpp::base::HashGridPoint::getLevel(), sgpp::base::DataMatrix::getNrows(), h, level, sgpp::base::DataVector::set(), and sgpp::base::DataVector::setAll().
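One way to read the References list above (level, index, h, the data set, and the computed classes) is that the indicator gathers the predicted classes of the data points lying in the support of the point's basis function and evaluates an impurity on them. The sketch below implements that reading under explicit assumptions: the standard piecewise linear hat basis on [0, 1]^d, whose support in dimension t is |x_t - i_t * h_t| < h_t with h_t = 2^(-l_t), and the Gini impurity as the measure. GridPointSketch, inSupport, and localGiniIndicator are hypothetical names, and the actual SG++ computation may weight or normalize the result differently.

```cpp
#include <cmath>
#include <cstddef>
#include <map>
#include <vector>

// Hypothetical stand-in for a d-dimensional grid point: level l_t and index i_t per dimension.
struct GridPointSketch {
  std::vector<int> level;
  std::vector<int> index;
};

// True if data point x lies in the open support of the hat function of p:
// |x_t - i_t * h_t| < h_t with h_t = 2^(-l_t) in every dimension t.
bool inSupport(const GridPointSketch& p, const std::vector<double>& x) {
  for (std::size_t t = 0; t < p.level.size(); ++t) {
    const double h = std::pow(2.0, -p.level[t]);
    if (std::abs(x[t] - p.index[t] * h) >= h) return false;
  }
  return true;
}

// ASSUMED indicator: Gini impurity of the predicted classes of all data points
// in the support of p (row j of dataset has predicted class classesComputed[j]).
double localGiniIndicator(const GridPointSketch& p,
                          const std::vector<std::vector<double>>& dataset,
                          const std::vector<double>& classesComputed) {
  std::map<double, double> counts;
  double n = 0.0;
  for (std::size_t row = 0; row < dataset.size(); ++row) {
    if (inSupport(p, dataset[row])) {
      counts[classesComputed[row]] += 1.0;
      n += 1.0;
    }
  }
  if (n == 0.0) return 0.0;  // no data in the support: no impurity
  double sumSq = 0.0;
  for (const auto& entry : counts) sumSq += (entry.second / n) * (entry.second / n);
  return 1.0 - sumSq;
}
```

Under this reading, a coarse point whose support spans both classes scores high, while finer points whose supports each see only one class score 0.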
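double sgpp::base::ImpurityRefinementIndicator::operator()(GridStorage& storage, size_t seq) const [override, virtual]
Returns the refinement value of the grid point with sequence number seq; it is evaluated for every grid point.
The point with the highest value is refined first.
Parameters
  storage: Reference to the grid's storage object.
  seq: Sequence number in the coefficients array.
Implements sgpp::base::RefinementFunctor.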
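The storage overload is evaluated once per sequence number. The sketch below models that calling pattern with a hypothetical abstract class: a driver evaluates the functor for every sequence number and stores (seq, value) pairs, the shape of the class's value_type typedef. A second helper sums the values, assuming the inherited getTotalRefinementValue() is a plain sum of the local indicators. FunctorSketch, TableFunctor, evaluateAll, and totalRefinementValue are illustrative names, not SG++ API, and a plain point count stands in for GridStorage.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical, simplified counterpart of a refinement functor; the grid
// storage is reduced to the number of points it holds.
class FunctorSketch {
 public:
  virtual ~FunctorSketch() = default;
  // Refinement value of the grid point with sequence number seq.
  virtual double operator()(std::size_t seq) const = 0;
};

// Evaluates the functor for every sequence number 0, ..., numPoints - 1.
std::vector<std::pair<std::size_t, double>> evaluateAll(const FunctorSketch& f,
                                                        std::size_t numPoints) {
  std::vector<std::pair<std::size_t, double>> values;
  values.reserve(numPoints);
  for (std::size_t seq = 0; seq < numPoints; ++seq) values.emplace_back(seq, f(seq));
  return values;
}

// ASSUMED analogue of getTotalRefinementValue(): plain sum of all local indicators.
double totalRefinementValue(const FunctorSketch& f, std::size_t numPoints) {
  double total = 0.0;
  for (const auto& entry : evaluateAll(f, numPoints)) total += entry.second;
  return total;
}

// Example functor: looks up precomputed indicator values by sequence number.
class TableFunctor : public FunctorSketch {
 public:
  explicit TableFunctor(std::vector<double> table) : table_(std::move(table)) {}
  double operator()(std::size_t seq) const override { return table_.at(seq); }

 private:
  std::vector<double> table_;
};
```

Keeping the per-point evaluation behind a virtual call lets the same driver work for any indicator, which is the role RefinementFunctor plays for ImpurityRefinementIndicator.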
double sgpp::base::ImpurityRefinementIndicator::start() const [override, virtual]
Returns the lower bound of the refinement criterion (e.g., alpha or error).
The refinement value of a grid point must be larger than this value for the point to be refined.
Implements sgpp::base::RefinementFunctor.
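getRefinementThreshold() and start() both describe when a value is large enough for refinement, but they are worded differently: the threshold is compared against the absolute value of the criterion, while start() is a plain lower bound. The two predicates below spell out those documented conditions side by side so the difference is visible for negative values. The function names are hypothetical, and whether a given SG++ refinement routine checks one or both conditions is not stated on this page.

```cpp
#include <cmath>

// getRefinementThreshold(): refine only if |criterion| > threshold.
bool passesThreshold(double criterion, double threshold) {
  return std::abs(criterion) > threshold;
}

// start(): the refinement value must be larger than the lower bound.
bool exceedsLowerBound(double value, double lowerBound) {
  return value > lowerBound;
}
```

A criterion of -0.3 passes a threshold of 0.2 but does not exceed a lower bound of 0.2.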
void sgpp::base::ImpurityRefinementIndicator::update(GridPoint& point)
Updates the normal vector of the SVM.
For each new grid point, the normal vector has to be extended by one component. Only required for the SVMLearner.
Parameters
  point: The new grid point.
References alphas, sgpp::base::DataVector::append(), dataset, sgpp::base::Basis< LT, IT >::eval(), sgpp::base::DataVector::get(), sgpp::base::DataMatrix::get(), sgpp::base::Grid::getBasis(), sgpp::base::HashGridPoint::getDimension(), sgpp::base::HashGridPoint::getIndex(), sgpp::base::HashGridPoint::getLevel(), sgpp::base::DataMatrix::getNrows(), level, w1, and w2.
Referenced by sgpp::base::ImpurityRefinement::refineGridpointsCollection().
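update() appends one component to each normal vector when a grid point is added. Judging only from the References list (alphas, dataset, Basis::eval(), DataVector::append(), w1, w2), a plausible reading is that the new component of w1 is sum_j alpha_j * phi(x_j) over the rows x_j of dataset, and that w2 uses |alpha_j| instead. The sketch below implements exactly that assumption with the standard linear hat basis phi(x) = prod_t max(0, 1 - |2^(l_t) * x_t - i_t|). The update formula, the choice of basis, and all names are assumptions, not the documented SG++ implementation.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Linear hat function of level l and index i on [0, 1], extended to d
// dimensions as a tensor product: prod_t max(0, 1 - |2^l_t * x_t - i_t|).
double hatBasis(const std::vector<int>& level, const std::vector<int>& index,
                const std::vector<double>& x) {
  double value = 1.0;
  for (std::size_t t = 0; t < level.size(); ++t) {
    value *= std::max(0.0, 1.0 - std::abs(std::pow(2.0, level[t]) * x[t] - index[t]));
  }
  return value;
}

// ASSUMED update rule: append sum_j alpha_j * phi(x_j) to w1 and
// sum_j |alpha_j| * phi(x_j) to w2, where x_j are the rows of dataset.
void appendNormalVectorComponent(const std::vector<int>& level, const std::vector<int>& index,
                                 const std::vector<std::vector<double>>& dataset,
                                 const std::vector<double>& alphas,
                                 std::vector<double>& w1, std::vector<double>& w2) {
  double component1 = 0.0;
  double component2 = 0.0;
  for (std::size_t j = 0; j < dataset.size(); ++j) {
    const double phi = hatBasis(level, index, dataset[j]);  // new basis function at x_j
    component1 += alphas[j] * phi;
    component2 += std::abs(alphas[j]) * phi;
  }
  w1.push_back(component1);  // one new component per new grid point
  w2.push_back(component2);
}
```

Because each new grid point adds exactly one basis function, the normal vectors grow by exactly one entry per call, matching the description above.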
Member Data Documentation
DataVector* sgpp::base::ImpurityRefinementIndicator::alphas
Referenced by sgpp::base::ImpurityRefinement::refineGridpointsCollection(), and update().
DataVector& sgpp::base::ImpurityRefinementIndicator::classesComputed [protected]
Referenced by operator()().
DataMatrix& sgpp::base::ImpurityRefinementIndicator::dataset [protected]
Referenced by operator()(), and update().
size_t sgpp::base::ImpurityRefinementIndicator::refinementsNum [protected]
Referenced by getRefinementsNum().
double sgpp::base::ImpurityRefinementIndicator::threshold [protected]
Referenced by getRefinementThreshold().
DataVector* sgpp::base::ImpurityRefinementIndicator::w1
Referenced by sgpp::base::ImpurityRefinement::refineGridpointsCollection(), and update().
DataVector* sgpp::base::ImpurityRefinementIndicator::w2
Referenced by update().