SG++-Doxygen-Documentation
sgpp::datadriven::ModelFittingClassification Class Reference
Fitter object that encapsulates density-based classification using an instance of ModelFittingDensityEstimation for each class.
#include <ModelFittingClassification.hpp>
Public Member Functions
double evaluate(const DataVector &sample) override
  Predict the class of a data sample based on the density of the sample for each model.
void evaluate(DataMatrix &samples, DataVector &results) override
  Predicts the class for a set of data points based on the learned densities for each class.
void fit(Dataset &dataset) override
  Fits the models for all classes based on the data given in the dataset parameter.
ModelFittingClassification(const FitterConfigurationClassification &config)
  Constructor.
bool refine() override
  Improve the accuracy of the classification by refining the grids of each class.
void reset() override
  Resets the state of the entire model.
void update(Dataset &dataset) override
  Updates the models for each class based on new data (streaming or batch learning).
Public Member Functions inherited from sgpp::datadriven::ModelFittingBase
const FitterConfiguration & getFitterConfiguration() const
  Get the configuration of the fitter object.
ModelFittingBase()
  Default constructor.
ModelFittingBase(const ModelFittingBase &rhs) = delete
  Copy constructor, deleted because not all member variables can be deep-copied yet.
ModelFittingBase(ModelFittingBase &&rhs) = default
  Move constructor.
ModelFittingBase & operator=(const ModelFittingBase &rhs) = delete
  Copy assignment operator, deleted because not all member variables can be deep-copied yet.
ModelFittingBase & operator=(ModelFittingBase &&rhs) = default
  Move assignment operator.
virtual ~ModelFittingBase() = default
  Virtual destructor.
Additional Inherited Members
Public Attributes inherited from sgpp::datadriven::ModelFittingBase
bool verboseSolver
  Whether the solver produces output.
Protected Member Functions inherited from sgpp::datadriven::ModelFittingBase
Grid * buildGrid(const RegularGridConfiguration &gridConfig) const
  Factory member function that generates a grid from a configuration.
SLESolver * buildSolver(const SLESolverConfiguration &config) const
  Factory member function that builds the solver for the least-squares regression problem according to the configuration.
void reconfigureSolver(SLESolver &solver, const SLESolverConfiguration &config) const
  Configure the solver based on the desired configuration.
Protected Attributes inherited from sgpp::datadriven::ModelFittingBase
std::unique_ptr<FitterConfiguration> config
  Configuration object for the fitter.
Dataset * dataset
  Pointer to the sgpp::datadriven::Dataset.
std::unique_ptr<SLESolver> solver
  Solver for the learning problem.
Detailed Description
Fitter object that encapsulates density-based classification using an instance of ModelFittingDensityEstimation for each class.
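Read literally, density-based classification with one density estimator per class amounts to the decision rule below. The symbols are introduced here for illustration only: f_k is the density fitted to the training samples of class k, l_k is that class's label, and K is the number of classes. Whether the implementation additionally weights each density by a class prior (for example, the fraction of training samples in class k) is not stated on this page.

```latex
\hat{y}(\mathbf{x}) = \ell_{k^\ast}, \qquad
k^\ast = \operatorname*{arg\,max}_{k \in \{1, \dots, K\}} \hat{f}_k(\mathbf{x})
```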
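Constructor & Destructor Documentation
explicit ModelFittingClassification(const FitterConfigurationClassification &config)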
Constructor.
Parameters
  config  configuration object that specifies grid, refinement, and regularization
References sgpp::datadriven::ModelFittingBase::config.
Member Function Documentation
double evaluate(const DataVector &sample) [override, virtual]
Predict the class of a data sample based on the density of the sample for each model.
Parameters
  sample  the sample point to classify
Implements sgpp::datadriven::ModelFittingBase.
References sgpp::datadriven::ModelFittingBase::config.
Referenced by evaluate().
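The decision step of evaluate(sample) can be sketched as an arg-max over per-class densities. This is a minimal, standalone illustration, not SG++ code: DensityFn and predictClass are hypothetical names, each fitted ModelFittingDensityEstimation is replaced by a plain callable that returns a density value, and no class-prior weighting is applied (whether the real method applies one is not stated here).

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for a fitted per-class density model:
// maps a sample to its estimated density under that class.
using DensityFn = std::function<double(const std::vector<double>&)>;

// Returns the label of the class whose density is largest at `sample`.
// classLabels[k] is the label that belongs to classDensities[k].
double predictClass(const std::vector<DensityFn>& classDensities,
                    const std::vector<double>& classLabels,
                    const std::vector<double>& sample) {
  assert(!classDensities.empty() && classDensities.size() == classLabels.size());
  std::size_t best = 0;
  double bestDensity = classDensities[0](sample);
  for (std::size_t k = 1; k < classDensities.size(); ++k) {
    const double density = classDensities[k](sample);
    if (density > bestDensity) {  // strict '>' keeps the first class on ties
      bestDensity = density;
      best = k;
    }
  }
  return classLabels[best];
}
```

Ties resolve to the class listed first. A prior-weighted variant would multiply each density by its class prior before the comparison.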
void evaluate(DataMatrix &samples, DataVector &results) [override, virtual]
Predicts the class for a set of data points based on the learned densities for each class.
Parameters
  samples  matrix where each row represents a data sample
  results  vector to output the predicted classes
Implements sgpp::datadriven::ModelFittingBase.
References evaluate(), sgpp::base::DataMatrix::getNcols(), sgpp::base::DataMatrix::getNrows(), sgpp::base::DataMatrix::getRow(), and sgpp::base::DataVector::set().
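The batch overload's cross-references (DataMatrix::getRow(), DataVector::set(), and the single-sample evaluate()) point to a row-by-row loop. The sketch below mimics that loop with standard containers. Matrix and predictBatch are hypothetical stand-ins for DataMatrix/DataVector and the method itself, and requiring `results` to be pre-sized to one entry per row is an assumption of this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical row-major matrix: nrows x ncols values stored contiguously.
struct Matrix {
  std::size_t nrows;
  std::size_t ncols;
  std::vector<double> data;

  // Copies row i into a new vector (analogous to fetching one sample).
  std::vector<double> row(std::size_t i) const {
    assert(i < nrows);
    return std::vector<double>(data.begin() + i * ncols, data.begin() + (i + 1) * ncols);
  }
};

// Classifies every row of `samples` with `classify` and stores the label in results[i].
// Assumption of this sketch: the caller sizes `results` to samples.nrows beforehand.
void predictBatch(const std::function<double(const std::vector<double>&)>& classify,
                  const Matrix& samples, std::vector<double>& results) {
  assert(samples.data.size() == samples.nrows * samples.ncols);
  assert(results.size() == samples.nrows);
  for (std::size_t i = 0; i < samples.nrows; ++i) {
    results[i] = classify(samples.row(i));
  }
}
```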
void fit(Dataset &dataset) [override, virtual]
Fits the models for all classes based on the data given in the dataset parameter.
Parameters
  dataset  the training dataset that is used to fit the models
Implements sgpp::datadriven::ModelFittingBase.
References sgpp::datadriven::CG, sgpp::datadriven::ModelFittingBase::config, sgpp::datadriven::ModelFittingBase::dataset, sgpp::datadriven::Decomposition, sgpp::datadriven::Dataset::getData(), sgpp::datadriven::FitterConfiguration::getDensityEstimationConfig(), sgpp::datadriven::Dataset::getTargets(), sgpp::base::AdaptivityConfiguration::levelPenalize, sgpp::base::AdaptivityConfiguration::noPoints_, sgpp::base::AdaptivityConfiguration::precomputeEvaluations, sgpp::base::AdaptivityConfiguration::refinementFunctorType, reset(), sgpp::base::AdaptivityConfiguration::scalingCoefficients, sgpp::base::AdaptivityConfiguration::threshold_, sgpp::datadriven::DensityEstimationConfiguration::type_, and update().
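The cross-references show fit() referencing both reset() and update(), and both of those list fit() among their callers. That suggests fit() is batch learning built on top of the streaming path: clear all per-class state, then update from the given dataset. The sketch below illustrates only that contract. CountingClassifier is hypothetical, and its per-class "models" are reduced to sample counts instead of density estimators.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Hypothetical fitter whose per-class "models" only count samples; a stand-in
// for per-class density estimators, used to show the fit/update/reset contract.
class CountingClassifier {
 public:
  // Batch learning: forget everything, then learn from this dataset only.
  void fit(const std::vector<double>& targets) {
    reset();
    update(targets);
  }

  // Streaming learning: fold new samples into the existing per-class models,
  // creating a model for any label not seen before.
  void update(const std::vector<double>& targets) {
    for (double label : targets) ++countPerClass[label];
  }

  // Drops all per-class models.
  void reset() { countPerClass.clear(); }

  // Number of samples seen so far for `label` (0 if the class is unknown).
  std::size_t count(double label) const {
    auto it = countPerClass.find(label);
    return it == countPerClass.end() ? 0 : it->second;
  }

 private:
  std::map<double, std::size_t> countPerClass;
};
```

A second call to fit() discards what earlier calls learned, whereas update() accumulates.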
bool refine() [override, virtual]
Improve the accuracy of the classification by refining the grids of each class.
Implements sgpp::datadriven::ModelFittingBase.
References sgpp::datadriven::ModelFittingBase::config, sgpp::base::AdaptivityConfiguration::numRefinements_, sgpp::datadriven::MultiGridRefinementFunctor::preComputeEvaluations(), sgpp::base::AdaptivityConfiguration::precomputeEvaluations, sgpp::datadriven::MultipleClassRefinementFunctor::refine(), sgpp::base::AdaptivityConfiguration::refinementFunctorType, and sgpp::datadriven::MultiGridRefinementFunctor::setGridIndex().
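One way to picture refine() is as a call that refines every class grid once until a refinement budget is exhausted. This sketch rests on two assumptions that this page does not spell out: that AdaptivityConfiguration::numRefinements_, which refine() references, caps the number of refinement steps, and that the bool result reports whether a step was actually performed. ClassGrid and RefinementDriver are hypothetical. MultipleClassRefinementFunctor, judging by its name, considers several classes together, but this simplified sketch refines each grid independently by a fixed number of points.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for one class's sparse grid: only its size is tracked.
struct ClassGrid {
  std::size_t numPoints;
  void refine(std::size_t newPoints) { numPoints += newPoints; }
};

// Hypothetical refinement driver with a fixed budget of refinement steps.
class RefinementDriver {
 public:
  RefinementDriver(std::size_t maxSteps, std::size_t pointsPerStep)
      : maxSteps_(maxSteps), pointsPerStep_(pointsPerStep) {}

  // Refines every class grid once; returns false (and changes nothing)
  // once the budget of refinement steps is used up.
  bool refine(std::vector<ClassGrid>& grids) {
    if (performed_ >= maxSteps_) return false;
    for (ClassGrid& grid : grids) grid.refine(pointsPerStep_);
    ++performed_;
    return true;
  }

 private:
  std::size_t maxSteps_;
  std::size_t pointsPerStep_;
  std::size_t performed_ = 0;
};
```

A caller can then loop `while (driver.refine(grids)) { /* refit or re-evaluate */ }`.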
void reset() [override, virtual]
Resets the state of the entire model.
Implements sgpp::datadriven::ModelFittingBase.
Referenced by fit().
void update(Dataset &dataset) [override, virtual]
Updates the models for each class based on new data (streaming or batch learning).
Parameters
  dataset  the new data
Implements sgpp::datadriven::ModelFittingBase.
References sgpp::datadriven::ModelFittingBase::dataset, sgpp::base::DataVector::get(), sgpp::datadriven::Dataset::getData(), sgpp::datadriven::Dataset::getDimension(), sgpp::base::DataMatrix::getNrows(), sgpp::datadriven::Dataset::getNumberInstances(), sgpp::base::DataMatrix::getRow(), and sgpp::datadriven::Dataset::getTargets().
Referenced by fit().
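update() references Dataset::getTargets(), getNumberInstances(), getDimension(), and DataMatrix::getRow(), which suggests that it splits the incoming data by target label before handing each subset to the matching class model. The partitioning step can be sketched as follows. splitByClass is a hypothetical helper, and a row-major std::vector<double> stands in for the dataset's DataMatrix.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <stdexcept>
#include <vector>

// Hypothetical partitioning step: splits a row-major dataset
// (targets.size() rows of length dim) into one row-major block per
// class label, preserving the original row order within each block.
std::map<double, std::vector<double>> splitByClass(const std::vector<double>& data,
                                                   const std::vector<double>& targets,
                                                   std::size_t dim) {
  if (dim == 0 || data.size() != targets.size() * dim) {
    throw std::invalid_argument("data must hold targets.size() rows of length dim");
  }
  std::map<double, std::vector<double>> blocks;
  for (std::size_t i = 0; i < targets.size(); ++i) {
    // Append row i to the block of its class (created on first use).
    std::vector<double>& block = blocks[targets[i]];
    block.insert(block.end(), data.begin() + i * dim, data.begin() + (i + 1) * dim);
  }
  return blocks;
}
```

Each resulting block would then update the density model of its class; labels that were not seen before would need a new model.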