SG++-Doxygen-Documentation
Fitter object that encapsulates density-based classification using instances of ModelFittingDensityEstimation for each class.
#include <ModelFittingClassification.hpp>
Public Member Functions

| Return type | Member | Description |
|---|---|---|
| double | evaluate (const DataVector &sample) override | Predict the class of a data sample based on the density of the sample for each model. |
| void | evaluate (DataMatrix &samples, DataVector &results) override | Predicts the class for a set of data points based on the learned densities for each class. |
| void | fit (Dataset &dataset) override | Fits the models for all classes based on the data given in the dataset parameter. |
| | ModelFittingClassification (const FitterConfigurationClassification &config) | Constructor. |
| bool | refine () override | Improve the accuracy of the classification by refining the grids of each class. |
| void | reset () override | Resets the state of the entire model. |
| void | update (Dataset &dataset) override | Updates the models for each class based on new data (streaming or batch learning). |
Public Member Functions inherited from sgpp::datadriven::ModelFittingBase

| Return type | Member | Description |
|---|---|---|
| const FitterConfiguration & | getFitterConfiguration () const | Get the configuration of the fitter object. |
| | ModelFittingBase () | Default constructor. |
| | ModelFittingBase (const ModelFittingBase &rhs)=delete | Copy constructor (deleted): not all member variables can be deep-copied yet. |
| | ModelFittingBase (ModelFittingBase &&rhs)=default | Move constructor. |
| ModelFittingBase & | operator= (const ModelFittingBase &rhs)=delete | Copy assignment operator (deleted): not all member variables can be deep-copied yet. |
| ModelFittingBase & | operator= (ModelFittingBase &&rhs)=default | Move assignment operator. |
| virtual | ~ModelFittingBase ()=default | Virtual destructor. |
Additional Inherited Members

Public Attributes inherited from sgpp::datadriven::ModelFittingBase

| Type | Member | Description |
|---|---|---|
| bool | verboseSolver | Whether the solver produces output or not. |

Protected Member Functions inherited from sgpp::datadriven::ModelFittingBase

| Return type | Member | Description |
|---|---|---|
| Grid * | buildGrid (const RegularGridConfiguration &gridConfig) const | Factory member function that generates a grid from the configuration. |
| SLESolver * | buildSolver (const SLESolverConfiguration &config) const | Factory member function that builds the solver for the least-squares regression problem according to the configuration. |
| void | reconfigureSolver (SLESolver &solver, const SLESolverConfiguration &config) const | Configures the solver based on the desired configuration. |

Protected Attributes inherited from sgpp::datadriven::ModelFittingBase

| Type | Member | Description |
|---|---|---|
| std::unique_ptr< FitterConfiguration > | config | Configuration object for the fitter. |
| Dataset * | dataset | Pointer to sgpp::datadriven::Dataset. |
| std::unique_ptr< SLESolver > | solver | Solver for the learning problem. |
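The base class deletes its copy operations, defaults its move operations, and owns config and solver through std::unique_ptr. A std::unique_ptr member cannot be copied implicitly, so without a hand-written deep copy such a class is move-only. The C++17 sketch below models that ownership pattern; Config, Solver, and MoveOnlyFitter are hypothetical stand-ins, not SG++ types.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>
#include <utility>

// Hypothetical stand-ins for the configuration and solver objects.
struct Config { int value = 0; };
struct Solver { int iterations = 0; };

// Owns its state through std::unique_ptr: copying is deleted because no
// deep copy is implemented, while moving transfers ownership.
class MoveOnlyFitter {
 public:
  MoveOnlyFitter()
      : config_(std::make_unique<Config>()), solver_(std::make_unique<Solver>()) {}
  MoveOnlyFitter(const MoveOnlyFitter&) = delete;
  MoveOnlyFitter(MoveOnlyFitter&&) = default;
  MoveOnlyFitter& operator=(const MoveOnlyFitter&) = delete;
  MoveOnlyFitter& operator=(MoveOnlyFitter&&) = default;
  virtual ~MoveOnlyFitter() = default;

  // True while this object still owns both members (false after a move).
  bool ownsState() const { return config_ != nullptr && solver_ != nullptr; }

 private:
  std::unique_ptr<Config> config_;
  std::unique_ptr<Solver> solver_;
};

// Compile-time checks of the resulting special member functions.
static_assert(!std::is_copy_constructible_v<MoveOnlyFitter>, "copy is deleted");
static_assert(std::is_move_constructible_v<MoveOnlyFitter>, "move is allowed");
```

After a move, the source object's std::unique_ptr members are null, so only the destination owns the configuration and solver.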
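ModelFittingClassification::ModelFittingClassification (const FitterConfigurationClassification &config) [explicit]
Constructor.
Parameters:
config: configuration object that specifies grid, refinement, and regularization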
References sgpp::datadriven::ModelFittingBase::config.
double ModelFittingClassification::evaluate (const DataVector &sample) [override, virtual]
Predict the class of a data sample based on the density of the sample for each model.
Parameters:
sample: the sample point to classify
Implements sgpp::datadriven::ModelFittingBase.
References sgpp::datadriven::ModelFittingBase::config.
Referenced by evaluate().
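One way to read "based on the density of the sample for each model" is an argmax rule: evaluate every class density at the sample and return the label with the largest value. The sketch below assumes that rule and uses a hypothetical Density callable in place of ModelFittingDensityEstimation; whether the actual implementation also weights the densities, for example by class priors, is not stated here.

```cpp
#include <cassert>
#include <functional>
#include <limits>
#include <map>
#include <vector>

// Hypothetical stand-in for one fitted per-class density model.
using Density = std::function<double(const std::vector<double>&)>;

// Argmax rule: evaluate every class density at `sample` and return the
// label whose density is largest. std::map visits labels in ascending
// order and only a strictly larger density replaces the current best,
// so ties resolve to the smaller label.
double predictClass(const std::map<double, Density>& models,
                    const std::vector<double>& sample) {
  assert(!models.empty() && "at least one class model is required");
  double bestLabel = models.begin()->first;
  double bestDensity = -std::numeric_limits<double>::infinity();
  for (const auto& [label, density] : models) {
    const double value = density(sample);
    if (value > bestDensity) {
      bestDensity = value;
      bestLabel = label;
    }
  }
  return bestLabel;
}
```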
void ModelFittingClassification::evaluate (DataMatrix &samples, DataVector &results) [override, virtual]
Predicts the class for a set of data points based on the learned densities for each class.
Parameters:
samples: matrix where each row represents a data sample
results: vector to output the predicted classes
Implements sgpp::datadriven::ModelFittingBase.
References evaluate(), sgpp::base::DataMatrix::getNcols(), sgpp::base::DataMatrix::getNrows(), sgpp::base::DataMatrix::getRow(), and sgpp::base::DataVector::set().
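The referenced calls (getNrows(), getRow(), DataVector::set(), and the single-sample evaluate()) suggest that the batch overload visits each row and classifies it individually. A minimal sketch of that loop, assuming a plain row-major std::vector<std::vector<double>> in place of DataMatrix and a hypothetical classify callable in place of evaluate(sample):

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical row-major sample matrix: one inner vector per data point.
using Matrix = std::vector<std::vector<double>>;
// Hypothetical single-sample classifier standing in for evaluate(sample).
using Classifier = std::function<double(const std::vector<double>&)>;

// Classifies each row of `samples` and stores the label of row i in
// results[i]; `results` is resized to one entry per row.
void evaluateBatch(const Matrix& samples, std::vector<double>& results,
                   const Classifier& classify) {
  assert(classify && "a classifier must be provided");
  results.assign(samples.size(), 0.0);
  for (std::size_t i = 0; i < samples.size(); ++i) {
    results[i] = classify(samples[i]);
  }
}
```

This sketch resizes results itself; whether the SG++ overload expects a pre-sized DataVector is not specified on this page.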
void ModelFittingClassification::fit (Dataset &dataset) [override, virtual]
Fits the models for all classes based on the data given in the dataset parameter.
Parameters:
dataset: the training dataset that is used to fit the models
Implements sgpp::datadriven::ModelFittingBase.
References sgpp::datadriven::CG, sgpp::datadriven::ModelFittingBase::config, sgpp::datadriven::ModelFittingBase::dataset, sgpp::datadriven::Decomposition, sgpp::datadriven::Dataset::getData(), sgpp::datadriven::FitterConfiguration::getDensityEstimationConfig(), sgpp::datadriven::Dataset::getTargets(), sgpp::base::AdaptivityConfiguration::levelPenalize, sgpp::base::AdaptivityConfiguration::noPoints_, sgpp::base::AdaptivityConfiguration::precomputeEvaluations, sgpp::base::AdaptivityConfiguration::refinementFunctorType, reset(), sgpp::base::AdaptivityConfiguration::scalingCoefficients, sgpp::base::AdaptivityConfiguration::threshold_, sgpp::datadriven::DensityEstimationConfiguration::type_, and update().
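Because fit() references reset() and update() (and both list fit() under "Referenced by"), a batch fit can be read as discarding the current state and then passing the whole dataset through update(). The toy class below illustrates only that control flow; ToyClassFitter is hypothetical and merely counts points per class label instead of estimating densities.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Hypothetical toy fitter (not the SG++ class): it only tracks how many
// training points have been seen for each class label.
class ToyClassFitter {
 public:
  // Discards everything learned so far.
  void reset() { counts_.clear(); }

  // Accumulates new labelled data on top of the current state.
  void update(const std::vector<double>& targets) {
    for (double label : targets) ++counts_[label];
  }

  // Batch fit: start from scratch, then learn from the whole dataset.
  void fit(const std::vector<double>& targets) {
    reset();
    update(targets);
  }

  // Number of points seen for `label` since the last reset.
  std::size_t count(double label) const {
    auto it = counts_.find(label);
    return it == counts_.end() ? 0 : it->second;
  }

 private:
  std::map<double, std::size_t> counts_;
};
```

Repeated update() calls accumulate, while fit() replaces whatever was learned before.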
bool ModelFittingClassification::refine () [override, virtual]
Improve the accuracy of the classification by refining the grids of each class.
Implements sgpp::datadriven::ModelFittingBase.
References sgpp::datadriven::ModelFittingBase::config, sgpp::base::AdaptivityConfiguration::numRefinements_, sgpp::datadriven::MultiGridRefinementFunctor::preComputeEvaluations(), sgpp::base::AdaptivityConfiguration::precomputeEvaluations, sgpp::datadriven::MultipleClassRefinementFunctor::refine(), sgpp::base::AdaptivityConfiguration::refinementFunctorType, and sgpp::datadriven::MultiGridRefinementFunctor::setGridIndex().
void ModelFittingClassification::reset () [override, virtual]
Resets the state of the entire model.
Implements sgpp::datadriven::ModelFittingBase.
Referenced by fit().
void ModelFittingClassification::update (Dataset &dataset) [override, virtual]
Updates the models for each class based on new data (streaming or batch learning).
Parameters:
dataset: the new data
Implements sgpp::datadriven::ModelFittingBase.
References sgpp::datadriven::ModelFittingBase::dataset, sgpp::base::DataVector::get(), sgpp::datadriven::Dataset::getData(), sgpp::datadriven::Dataset::getDimension(), sgpp::base::DataMatrix::getNrows(), sgpp::datadriven::Dataset::getNumberInstances(), sgpp::base::DataMatrix::getRow(), and sgpp::datadriven::Dataset::getTargets().
Referenced by fit().
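update() reads the targets of the incoming Dataset alongside its rows, which suggests that each new point is routed to the model of its class label. A minimal sketch of that partitioning step, assuming plain std::vector containers instead of Dataset and DataMatrix and a hypothetical splitByClass helper; fitting a density on each subset is left out:

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Hypothetical row-major data matrix: one inner vector per data point.
using Matrix = std::vector<std::vector<double>>;

// Groups the rows of `data` by their target label so that each class
// model can be updated with only its own points.
std::map<double, Matrix> splitByClass(const Matrix& data,
                                      const std::vector<double>& targets) {
  assert(data.size() == targets.size() && "one target per data point");
  std::map<double, Matrix> perClass;
  for (std::size_t i = 0; i < data.size(); ++i) {
    perClass[targets[i]].push_back(data[i]);
  }
  return perClass;
}
```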