SG++
sgpp::datadriven::BatchLearner Class Reference

The BatchLearner learns the data provided as input in batches.

#include <BatchLearner.hpp>

Public Member Functions

 BatchLearner (base::BatchConfiguration batchConfig, base::RegularGridConfiguration gridConfig, solver::SLESolverConfiguration solverConfig, base::AdpativityConfiguration adaptConfig)
 constructor taking all relevant parameters
 
void closeFile ()
 close the stream
 
double getAccCurrent ()
 Get the accuracy of the last batch predicted.
 
double getAccGlobal ()
 Get the accuracy over all predictions.
 
bool getIsFinished ()
 Get whether the stream has been read to the end.
 
base::DataVector predict (base::DataMatrix &entries, bool updateNorm)
 predict the labels of the data provided and return a vector of the predicted labels
 
void trainBatch ()
 read all lines needed for one batch (and possibly test data) and call processBatch(..); if testing is requested by the user, call predict(..) on the test data and calculate, save and output the accuracies
 

Protected Member Functions

base::DataVector applyWeight (base::DataVector alpha, int grid)
 apply one of the weighting methods to the alpha provided and return the weighted alpha
 
void processBatch (std::string workData)
 process the data provided as a string, learn the new data and refine the grids (if requested by the user)
 
void stringToDataMatrix (std::string &input, base::DataMatrix &dataFound, base::DataVector &classesFound, bool mapData)
 parse a string containing many items: if (mapData), map the data by class into dataInBatch (which is cleared before use; dataFound and classesFound are not touched); if (!mapData), write the data to a DataMatrix and the classes to a DataVector (dataInBatch is not touched)
 
void stringToDataVector (std::string input, base::DataVector &dataFound, int &classFound)
 parse a string into data and class
 

Protected Attributes

double acc_current = -1.0
 accuracy of the last call of predict(..)
 
double acc_global = -1.0
 accuracy over all predictions done so far (including those made when batchConfig.testsize > 0)
 
base::AdpativityConfiguration adaptConf
 configuration for the adaptivity
 
std::map< int, std::deque< base::DataVector > > alphaStorage
 mapping used to store the previous alphas
 
std::map< int, base::DataVector * > alphaVectors
 mapping of alpha vectors to labels
 
std::string batch
 the current batch, stored as a member because test data is added to the following batch
 
base::BatchConfiguration batchConf
 configuration for the BatchLearner
 
int batchnum = 0
 number of the current batch
 
unsigned int bs = 0
 count of lines read for the batch
 
std::map< int, base::DataMatrix * > dataInBatch
 mapping of the data in each batch to labels
 
int dataLine = 0
 stores the current line number counted from @DATA
 
size_t dimensions = 0
 number of dimensions in the data
 
base::RegularGridConfiguration gridConf
 configuration for the grids
 
std::map< int, base::LinearGrid * > grids
 mapping of grids to labels
 
bool isFinished = false
 indicates whether the stream has been read to the end
 
solver::SLESolver* myCG
 solver
 
std::map< int, double > normFactors
 mapping of normalization factors to labels
 
std::map< int, size_t > occurences
 mapping of item counts to labels
 
bool reachedData = false
 flag indicating whether "@DATA" has been reached in the ARFF file
 
std::fstream reader
 stream used to read the ARFF file
 
solver::SLESolverConfiguration solverConf
 configuration for the solver
 
int t_correct = 0
 number of items predicted correctly
 
int t_total = 0
 total number of items tested
 

Detailed Description

The Batchlearner learns the data provided as input in batches.

The batches can be weighted with different functions. Labels are predicted using density functions. Adaptivity can be enabled.
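One way to read "labels are predicted using density functions": each class c has its own grid (grids) and coefficient vector α_c (alphaVectors), which together define a density estimate, and a point is assigned to the class with the largest weighted density. The formula below is a sketch under that assumption; the symbols φ (basis functions), N_c and ν_c are introduced here for illustration, and the exact role of the normalization factors ν_c (normFactors) is not specified on this page.

```latex
\hat f_c(\mathbf{x}) = \sum_{i=1}^{N_c} \alpha_{c,i}\,\varphi_{c,i}(\mathbf{x}),
\qquad
\hat y(\mathbf{x}) = \operatorname*{arg\,max}_{c}\; \nu_c\,\hat f_c(\mathbf{x})
```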

Constructor & Destructor Documentation

sgpp::datadriven::BatchLearner::BatchLearner ( base::BatchConfiguration  batchConfig,
base::RegularGridConfiguration  gridConfig,
solver::SLESolverConfiguration  solverConfig,
base::AdpativityConfiguration  adaptConfig 
)

constructor taking all relevant parameters

Parameters
batchConfig: struct containing all parameters specific to the BatchLearner
gridConfig: configuration for the grids
solverConfig: configuration for the solver
adaptConfig: configuration for the adaptivity of the solver

References adaptConf, adaptConfig, batchConf, sgpp::solver::BiCGSTAB, sgpp::solver::CG, sgpp::solver::SLESolverConfiguration::eps_, sgpp::base::BatchConfiguration::filename_, gridConf, sgpp::solver::SLESolverConfiguration::maxIterations_, myCG, reader, solverConf, and sgpp::solver::SLESolverConfiguration::type_.

Member Function Documentation

base::DataVector sgpp::datadriven::BatchLearner::applyWeight ( base::DataVector  alpha,
int  grid 
)
protected

function to apply one of the weighting methods to the alpha provided, returns the weighted alpha

See also
batchConfig.wMode for the different weighting methods
Parameters
alpha: the most recent alpha, which has to be taken into account for the calculation
grid: the class the alphas refer to

References sgpp::base::DataVector::add(), alpha, alphaStorage, batchConf, dataInBatch, sgpp::base::DataVector::getSize(), sgpp::base::DataVector::mult(), occurences, sgpp::combigrid::pow(), sgpp::base::DataVector::resizeZero(), sgpp::base::DataVector::setAll(), sgpp::base::BatchConfiguration::stack_, sgpp::base::BatchConfiguration::verbose_, sgpp::base::BatchConfiguration::wArgument_, and sgpp::base::BatchConfiguration::wMode_.

Referenced by processBatch().
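The weighting methods selected by batchConf.wMode_ are not described on this page. Purely as an illustration, the sketch below assumes one plausible mode: keep at most `stack` previous alphas per class in a deque (as alphaStorage does, with `stack` standing in for stack_) and average them with geometrically decaying weights w^k (w standing in for wArgument_). The function weightAlphas and its decay rule are hypothetical, not the SG++ implementation, and std::vector replaces base::DataVector to keep the example self-contained.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <stdexcept>
#include <vector>

// Hypothetical sketch: geometric weighting of the newest alpha with stored ones.
// history holds previous alphas, newest at the back (like one alphaStorage entry).
// Result = sum_k w^k * alpha_k / sum_k w^k, where alpha_0 is the newest vector.
std::vector<double> weightAlphas(std::deque<std::vector<double>>& history,
                                 const std::vector<double>& alpha,
                                 std::size_t stack, double w) {
  if (stack == 0) throw std::invalid_argument("stack must be >= 1");
  history.push_back(alpha);
  while (history.size() > stack) history.pop_front();  // drop the oldest alpha

  std::vector<double> result(alpha.size(), 0.0);
  double weight = 1.0;
  double weightSum = 0.0;
  // Walk from newest to oldest. Assumption: if an older alpha belongs to a
  // smaller grid, its missing entries count as zero; extra entries are ignored.
  for (auto it = history.rbegin(); it != history.rend(); ++it) {
    for (std::size_t i = 0; i < it->size() && i < result.size(); ++i)
      result[i] += weight * (*it)[i];
    weightSum += weight;
    weight *= w;
  }
  for (double& r : result) r /= weightSum;  // normalize so the weights sum to 1
  return result;
}
```

With w = 1 this reduces to a plain average over the last `stack` batches; with w close to 0 it keeps essentially only the newest alpha.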

void sgpp::datadriven::BatchLearner::closeFile ( )
inline

close the stream

double sgpp::datadriven::BatchLearner::getAccCurrent ( )
inline

Get the accuracy of the last batch predicted.

References acc_current.

double sgpp::datadriven::BatchLearner::getAccGlobal ( )
inline

Get the accuracy over all predictions.

References acc_global, predict(), and trainBatch().

bool sgpp::datadriven::BatchLearner::getIsFinished ( )
inline

Get whether stream has been read to the end.

References isFinished.

base::DataVector sgpp::datadriven::BatchLearner::predict ( base::DataMatrix &  entries,
bool  updateNorm 
)

predict labels of the data provided, return vector of predicted labels

Parameters
entries: DataMatrix containing the data to test
updateNorm: whether the normalization factors should be updated before predicting

References alphaVectors, batchConf, sgpp::op_factory::createOperationEval(), dimensions, g, sgpp::base::DataMatrix::getNcols(), sgpp::base::DataMatrix::getNrows(), sgpp::base::DataMatrix::getRow(), grids, normFactors, sgpp::base::BatchConfiguration::samples_, sgpp::base::DataVector::setAll(), sgpp::base::DataVector::toString(), and sgpp::base::BatchConfiguration::verbose_.

Referenced by getAccGlobal(), and trainBatch().
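The following sketch illustrates what predict(entries, updateNorm) could look like, assuming that the normalization factors are relative class frequencies computed from per-class item counts (as tracked by occurences) and that each row gets the label with the largest weighted density. The names updateNormFactors, predictRows, Point and Density are hypothetical; densities are plain std::function callables instead of grids evaluated via createOperationEval(), so the code runs without SG++.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <limits>
#include <map>
#include <stdexcept>
#include <vector>

using Point = std::vector<double>;
using Density = std::function<double(const Point&)>;

// Hypothetical: recompute per-class factors as relative class frequencies
// (class priors) from item counts. Assumption, not the documented SG++ rule.
void updateNormFactors(const std::map<int, std::size_t>& occurrences,
                       std::map<int, double>& normFactors) {
  std::size_t total = 0;
  for (const auto& entry : occurrences) total += entry.second;
  normFactors.clear();
  if (total == 0) return;
  for (const auto& entry : occurrences)
    normFactors[entry.first] =
        static_cast<double>(entry.second) / static_cast<double>(total);
}

// Predict one label per row: argmax over labels of normFactor * density(row).
// Labels without a normalization factor are weighted with 1.0.
std::vector<int> predictRows(const std::vector<Point>& rows,
                             const std::map<int, Density>& densities,
                             const std::map<int, double>& normFactors) {
  if (densities.empty()) throw std::invalid_argument("no class densities given");
  std::vector<int> labels;
  labels.reserve(rows.size());
  for (const Point& x : rows) {
    int best = densities.begin()->first;
    double bestScore = -std::numeric_limits<double>::infinity();
    for (const auto& entry : densities) {
      auto it = normFactors.find(entry.first);
      double factor = (it == normFactors.end()) ? 1.0 : it->second;
      double score = factor * entry.second(x);
      if (score > bestScore) {
        bestScore = score;
        best = entry.first;
      }
    }
    labels.push_back(best);
  }
  return labels;
}
```

Calling updateNormFactors before predictRows corresponds to updateNorm == true; skipping it keeps the factors from the previous call.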

void sgpp::datadriven::BatchLearner::stringToDataMatrix ( std::string &  input,
base::DataMatrix &  dataFound,
base::DataVector &  classesFound,
bool  mapData 
)
protected

function that parses a string containing many items. If (mapData), the data is mapped by class into dataInBatch (which is cleared before use; dataFound and classesFound are not touched). If (!mapData), the data is written to a DataMatrix and the classes to a DataVector (dataInBatch is not touched).

Parameters
input: many items separated by '\n'
dataFound: DataMatrix that will contain the found data afterwards (cleared before use)
classesFound: DataVector that will contain the found classes afterwards (cleared before use)
mapData: controls whether the data is mapped to dataInBatch or saved to dataFound and classesFound

References sgpp::base::DataVector::append(), sgpp::base::DataMatrix::appendRow(), dataInBatch, dimensions, sgpp::base::DataMatrix::resize(), sgpp::base::DataVector::setAll(), sgpp::base::DataMatrix::setAll(), and stringToDataVector().

Referenced by processBatch(), and trainBatch().
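A minimal sketch of the mapData branch of stringToDataMatrix, assuming each line has the form "v1,...,vd,class": the input is split at '\n', each line is parsed, and the rows are grouped by class label, much like dataInBatch maps labels to DataMatrix objects. The function groupByClass and the Row alias are hypothetical, and std::vector stands in for base::DataMatrix.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

using Row = std::vector<double>;

// Hypothetical sketch: split newline-separated items, parse "v1,...,vd,class",
// and group the value rows by their integer class label.
std::map<int, std::vector<Row>> groupByClass(const std::string& input) {
  std::map<int, std::vector<Row>> byClass;
  std::istringstream lines(input);
  std::string line;
  while (std::getline(lines, line)) {
    if (line.empty()) continue;  // tolerate blank lines and a trailing '\n'
    std::istringstream fields(line);
    std::string field;
    std::vector<std::string> tokens;
    while (std::getline(fields, field, ',')) tokens.push_back(field);
    if (tokens.size() < 2)
      throw std::invalid_argument("expected values and a class: " + line);
    Row row;
    for (std::size_t i = 0; i + 1 < tokens.size(); ++i)
      row.push_back(std::stod(tokens[i]));  // all but the last token are values
    byClass[std::stoi(tokens.back())].push_back(row);  // last token is the class
  }
  return byClass;
}
```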

void sgpp::datadriven::BatchLearner::stringToDataVector ( std::string  input,
base::DataVector &  dataFound,
int &  classFound 
)
protected

function that parses a string to data and class

Parameters
input: string of one item, values separated by ','; the last entry is the class as int
dataFound: the data found in the string
classFound: the class found in the string

References sgpp::base::DataVector::copyFrom(), dimensions, sgpp::base::DataVector::getSize(), and sgpp::base::DataVector::set().

Referenced by stringToDataMatrix().
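The parameter description above can be illustrated with a small stand-alone parser. This is a sketch under stated assumptions, not the SG++ code: parseItem is a hypothetical name, std::vector replaces base::DataVector, and the optional expectedDims check imitates the role the dimensions member might play.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for stringToDataVector: parses "v1,...,vd,class",
// writes the d values into dataFound and the trailing integer into classFound.
// If expectedDims > 0, the number of values must match it.
void parseItem(const std::string& input, std::vector<double>& dataFound,
               int& classFound, std::size_t expectedDims = 0) {
  std::vector<std::string> tokens;
  std::istringstream in(input);
  std::string token;
  while (std::getline(in, token, ',')) tokens.push_back(token);
  if (tokens.size() < 2) throw std::invalid_argument("item needs values and a class");

  dataFound.clear();
  for (std::size_t i = 0; i + 1 < tokens.size(); ++i)
    dataFound.push_back(std::stod(tokens[i]));
  if (expectedDims != 0 && dataFound.size() != expectedDims)
    throw std::invalid_argument("unexpected number of values");
  classFound = std::stoi(tokens.back());
}
```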

void sgpp::datadriven::BatchLearner::trainBatch ( )
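According to its brief description, trainBatch() calculates, saves and outputs accuracies, using the members acc_current, acc_global, t_correct and t_total (which start at -1.0, -1.0, 0 and 0). The sketch below shows one plausible form of that bookkeeping; the struct AccuracyTracker and its member names are hypothetical, and treating -1.0 as "no prediction made yet" is an interpretation of the initial values.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical accuracy bookkeeping modeled on acc_current, acc_global,
// t_correct and t_total. -1.0 means "no prediction made yet".
struct AccuracyTracker {
  double accCurrent = -1.0;  // accuracy of the last batch
  double accGlobal = -1.0;   // accuracy over all batches so far
  int correctTotal = 0;      // running count of correct predictions
  int testedTotal = 0;       // running count of tested items

  // Compare one batch of predicted labels with the true labels.
  void update(const std::vector<int>& predicted, const std::vector<int>& actual) {
    if (predicted.size() != actual.size()) throw std::invalid_argument("size mismatch");
    if (predicted.empty()) return;  // nothing tested: keep the previous values
    int correct = 0;
    for (std::size_t i = 0; i < predicted.size(); ++i)
      if (predicted[i] == actual[i]) ++correct;
    accCurrent = static_cast<double>(correct) / static_cast<double>(predicted.size());
    correctTotal += correct;
    testedTotal += static_cast<int>(predicted.size());
    accGlobal = static_cast<double>(correctTotal) / static_cast<double>(testedTotal);
  }
};
```

Because the global accuracy is computed from running counts rather than by averaging per-batch accuracies, batches of different sizes are weighted by their number of items.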

Member Data Documentation

double sgpp::datadriven::BatchLearner::acc_current = -1.0
protected

accuracy of the last call of predict(..)

Referenced by getAccCurrent(), and trainBatch().

double sgpp::datadriven::BatchLearner::acc_global = -1.0
protected

accuracy over all predictions done so far (including those made when batchConfig.testsize > 0)

Referenced by getAccGlobal(), and trainBatch().

base::AdpativityConfiguration sgpp::datadriven::BatchLearner::adaptConf
protected

configuration for the adaptivity

Referenced by BatchLearner(), and processBatch().

std::map<int, std::deque<base::DataVector> > sgpp::datadriven::BatchLearner::alphaStorage
protected

mapping used to store the previous alphas

Referenced by applyWeight().

std::map<int, base::DataVector*> sgpp::datadriven::BatchLearner::alphaVectors
protected

mapping of alpha vectors to labels

Referenced by predict(), and processBatch().

std::string sgpp::datadriven::BatchLearner::batch
protected

the current batch, stored as a member because test data is added to the following batch

Referenced by trainBatch().

base::BatchConfiguration sgpp::datadriven::BatchLearner::batchConf
protected

configuration for the BatchLearner

Referenced by applyWeight(), BatchLearner(), predict(), processBatch(), and trainBatch().

int sgpp::datadriven::BatchLearner::batchnum = 0
protected

number of the current batch

Referenced by processBatch(), and trainBatch().

unsigned int sgpp::datadriven::BatchLearner::bs = 0
protected

count of lines read for the batch

Referenced by trainBatch().

std::map<int, base::DataMatrix*> sgpp::datadriven::BatchLearner::dataInBatch
protected

mapping of the data in each batch to labels

Referenced by applyWeight(), processBatch(), and stringToDataMatrix().

int sgpp::datadriven::BatchLearner::dataLine = 0
protected

stores the current line number counted from @DATA

Referenced by trainBatch().

size_t sgpp::datadriven::BatchLearner::dimensions = 0
protected

number of dimensions in the data

Referenced by predict(), processBatch(), stringToDataMatrix(), and stringToDataVector().

base::RegularGridConfiguration sgpp::datadriven::BatchLearner::gridConf
protected

configuration for the grids

Referenced by BatchLearner(), and processBatch().

std::map<int, base::LinearGrid*> sgpp::datadriven::BatchLearner::grids
protected

mapping of grids to labels

Referenced by predict(), and processBatch().

bool sgpp::datadriven::BatchLearner::isFinished = false
protected

indicates whether the stream has been read to the end

accuracy over all predictions done so far (including the ones if batchConfig.testsize > 0 )

Referenced by getIsFinished(), and trainBatch().

solver::SLESolver* sgpp::datadriven::BatchLearner::myCG
protected

solver

Referenced by BatchLearner(), and processBatch().

std::map<int, double> sgpp::datadriven::BatchLearner::normFactors
protected

mapping of normalization factors to labels

Referenced by predict(), and processBatch().

std::map<int, size_t> sgpp::datadriven::BatchLearner::occurences
protected

mapping of item counts to labels

Referenced by applyWeight(), and processBatch().

bool sgpp::datadriven::BatchLearner::reachedData = false
protected

flag indicating whether "@DATA" has been reached in the ARFF file

Referenced by trainBatch().
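The members reader, reachedData, isFinished and bs suggest that the ARFF file is read line by line, the header is skipped until the @DATA marker, and data lines are then collected until a batch is full. The sketch below shows that pattern on any std::istream; ArffBatchReader and nextBatch are hypothetical names. It treats @DATA case-insensitively and skips lines starting with '%', as ARFF permits; whether BatchLearner does the same is an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstddef>
#include <istream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical reader state modeled on reader, reachedData and isFinished.
struct ArffBatchReader {
  std::istream& in;
  bool reachedData = false;  // set once the @DATA line has been seen
  bool finished = false;     // set once the stream is exhausted

  explicit ArffBatchReader(std::istream& stream) : in(stream) {}

  // Returns up to batchSize data lines; header lines before @DATA are skipped.
  std::vector<std::string> nextBatch(std::size_t batchSize) {
    std::vector<std::string> batch;
    std::string line;
    while (batch.size() < batchSize && std::getline(in, line)) {
      if (!reachedData) {
        std::string upper = line;
        std::transform(upper.begin(), upper.end(), upper.begin(),
                       [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
        if (upper.rfind("@DATA", 0) == 0) reachedData = true;  // keyword is case-insensitive
        continue;
      }
      if (line.empty() || line[0] == '%') continue;  // skip blank lines and comments
      batch.push_back(line);
    }
    if (!in) finished = true;  // getline failed: end of stream reached
    return batch;
  }
};
```

The last batch may be shorter than batchSize; a caller would loop until finished is set, similar to checking getIsFinished().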

std::fstream sgpp::datadriven::BatchLearner::reader
protected

stream used to read the ARFF file

Referenced by BatchLearner(), and trainBatch().

solver::SLESolverConfiguration sgpp::datadriven::BatchLearner::solverConf
protected

configuration for the solver

Referenced by BatchLearner(), and processBatch().

int sgpp::datadriven::BatchLearner::t_correct = 0
protected

number of items predicted correctly

Referenced by trainBatch().

int sgpp::datadriven::BatchLearner::t_total = 0
protected

total number of items tested

Referenced by trainBatch().


The documentation for this class was generated from the following files: