SG++-Doxygen-Documentation
DataSourceCrossValidation is a high-level interface for processing data in a cross-validation environment.
#include <DataSourceCrossValidation.hpp>
Public Member Functions | |
DataSourceCrossValidation (const DataSourceConfig &dataSourceConfig, const CrossvalidationConfiguration &crossValidationconfig, DataShufflingFunctorCrossValidation *shuffling, SampleProvider *sampleProvider) | |
Constructor. | |
const CrossvalidationConfiguration & | getCrossValidationConfig () const |
Gets the configuration for the cross validation. | |
Dataset * | getValidationData () override |
Returns the data that is used for validation, i.e. the current fold. | |
void | reset () |
Resets the state of the sample provider to begin a new training epoch. | |
void | setFold (size_t foldIdx) |
Sets the index of the next fold to be used for cross validation. | |
Public Member Functions inherited from sgpp::datadriven::DataSource | |
DataSourceIterator | begin () |
Return an iterator object pointing to the first batch of this DataSource. | |
DataSource (DataSourceConfig config, SampleProvider *sampleProvider) | |
Constructor. | |
DataSourceIterator | end () |
Return an iterator object pointing to the last possible batch of this DataSource. | |
const DataSourceConfig & | getConfig () const |
Read only access to the configuration used by DataSource and underlying SampleProvider. | |
size_t | getCurrentIteration () const |
Return how many batches have already been requested from this DataSource. | |
virtual Dataset * | getNextSamples () |
Request data from the underlying SampleProvider as specified in the provided configuration object upon construction. | |
virtual | ~DataSource ()=default |
Additional Inherited Members | |
Protected Attributes inherited from sgpp::datadriven::DataSource | |
DataSourceConfig | config |
Configuration file that determines all relevant properties of the object. | |
size_t | currentIteration |
Counter variable used when data is requested in batches. | |
DataTransformation * | dataTransformation |
Pointer to the DataTransformation that performs transformations on init. | |
std::unique_ptr< SampleProvider > | sampleProvider |
Pointer to the sample provider that actually handles data acquisition. | |
DataSourceCrossValidation is a high-level interface for processing data in a cross-validation environment.
That is, it retrieves a certain fold for validation and the rest of the data for training. Note that this is very costly in terms of memory and is not tractable for large data sets.
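The split into one validation fold and the remaining training data can be illustrated with a small stand-alone sketch. It does not use SG++ at all: `FoldSplit` and `splitFold` are hypothetical names introduced here, and the contiguous fold layout is an assumption for illustration only, not a statement about how the shuffling functor actually assigns samples.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper type, not part of SG++: index sets for one fold.
struct FoldSplit {
  std::vector<std::size_t> validation;  // indices of the held-out fold
  std::vector<std::size_t> training;    // indices of all remaining samples
};

// Splits the sample indices 0..numSamples-1 into numFolds contiguous folds
// and returns the split for fold foldIdx. The first (numSamples % numFolds)
// folds receive one extra sample so that every sample belongs to one fold.
FoldSplit splitFold(std::size_t numSamples, std::size_t numFolds,
                    std::size_t foldIdx) {
  assert(numFolds > 0 && foldIdx < numFolds);
  const std::size_t base = numSamples / numFolds;
  const std::size_t remainder = numSamples % numFolds;
  const std::size_t begin = foldIdx * base + std::min(foldIdx, remainder);
  const std::size_t size = base + (foldIdx < remainder ? 1 : 0);

  FoldSplit split;
  for (std::size_t i = 0; i < numSamples; ++i) {
    if (i >= begin && i < begin + size) {
      split.validation.push_back(i);
    } else {
      split.training.push_back(i);
    }
  }
  return split;
}
```

Calling `splitFold(10, 3, 1)` in this sketch holds out the indices 4, 5 and 6 and leaves the other seven for training. Working with indices rather than copied samples is a design choice of the sketch.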
sgpp::datadriven::DataSourceCrossValidation::DataSourceCrossValidation | ( | const DataSourceConfig & | dataSourceConfig, |
const CrossvalidationConfiguration & | crossValidationconfig, | ||
DataShufflingFunctorCrossValidation * | shuffling, | ||
SampleProvider * | sampleProvider | ||
) |
Constructor.
dataSourceConfig | configuration of the data source |
crossValidationconfig | configuration of the cross validation |
shuffling | cross validation shuffling that is used by the sample provider instance |
sampleProvider | the sample provider to operate on. |
const CrossvalidationConfiguration & sgpp::datadriven::DataSourceCrossValidation::getCrossValidationConfig | ( | ) | const |
Gets the configuration for the cross validation.
Dataset * sgpp::datadriven::DataSourceCrossValidation::getValidationData | ( | ) | override virtual |
Returns the data that is used for validation, i.e. the current fold. If all folds were already iterated over, this method throws.
Implements sgpp::datadriven::DataSource.
void sgpp::datadriven::DataSourceCrossValidation::reset | ( | ) |
Resets the state of the sample provider to begin a new training epoch.
References sgpp::datadriven::DataShufflingFunctorCrossValidation::getCurrentFoldSize().
void sgpp::datadriven::DataSourceCrossValidation::setFold | ( | size_t | foldIdx | ) |
Sets the index of the next fold to be used for cross validation.
foldIdx | index of the fold |
References sgpp::datadriven::DataShufflingFunctorCrossValidation::setFold().
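The member functions documented above suggest a call pattern of selecting a fold, resetting, consuming the training data and then reading the validation fold. The following is a minimal sketch of that pattern using a toy stand-in class. `ToyCrossValidationSource`, `getNextSample` and `countTrainingSamples` are hypothetical names; the interleaved fold assignment and the exception type are assumptions made for this sketch and are not taken from the SG++ implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Toy stand-in, NOT the SG++ class: it only mirrors the call pattern
// setFold() -> reset() -> read training samples -> getValidationData().
class ToyCrossValidationSource {
 public:
  ToyCrossValidationSource(std::vector<double> samples, std::size_t numFolds)
      : samples_(std::move(samples)), numFolds_(numFolds) {
    if (numFolds_ == 0) throw std::invalid_argument("numFolds must be > 0");
  }

  // Select the fold that is held out for validation.
  void setFold(std::size_t foldIdx) { foldIdx_ = foldIdx; }

  // Restart iteration over the training samples (a new epoch).
  void reset() { cursor_ = 0; }

  // Held-out samples of the selected fold. Throwing std::out_of_range for a
  // non-existent fold is an assumption of this toy.
  std::vector<double> getValidationData() const {
    if (foldIdx_ >= numFolds_) throw std::out_of_range("no such fold");
    std::vector<double> fold;
    for (std::size_t i = foldIdx_; i < samples_.size(); i += numFolds_) {
      fold.push_back(samples_[i]);
    }
    return fold;
  }

  // Writes the next training sample (one outside the held-out fold) to out.
  // Returns false once the epoch is exhausted.
  bool getNextSample(double& out) {
    while (cursor_ < samples_.size()) {
      const std::size_t i = cursor_++;
      if (i % numFolds_ != foldIdx_) {
        out = samples_[i];
        return true;
      }
    }
    return false;
  }

 private:
  std::vector<double> samples_;
  std::size_t numFolds_;
  std::size_t foldIdx_ = 0;
  std::size_t cursor_ = 0;
};

// Visits every fold once and returns how many training samples were read.
std::size_t countTrainingSamples(ToyCrossValidationSource& source,
                                 std::size_t numFolds) {
  assert(numFolds > 0);
  std::size_t visited = 0;
  for (std::size_t fold = 0; fold < numFolds; ++fold) {
    source.setFold(fold);  // choose the validation fold
    source.reset();        // begin a new training epoch
    double sample = 0.0;
    while (source.getNextSample(sample)) ++visited;
  }
  return visited;
}
```

With six samples and three folds, each epoch in this toy reads four training samples and holds out two, so a full pass over all folds reads twelve.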