SG++-Doxygen-Documentation
sgpp::datadriven::DataSourceCrossValidation Class Reference

DataSourceCrossValidation is a high-level interface that provides functionality for processing data in a cross validation environment.

#include <DataSourceCrossValidation.hpp>

Inheritance diagram for sgpp::datadriven::DataSourceCrossValidation:
sgpp::datadriven::DataSource

Public Member Functions

 DataSourceCrossValidation (const DataSourceConfig &dataSourceConfig, const CrossvalidationConfiguration &crossValidationconfig, DataShufflingFunctorCrossValidation *shuffling, SampleProvider *sampleProvider)
 Constructor.
 
const CrossvalidationConfiguration & getCrossValidationConfig () const
 Gets the configuration for the cross validation.
 
Dataset * getValidationData () override
 Returns the data that is used for validation, i.e. the current fold.
 
void reset ()
 Resets the state of the sample provider to begin a new training epoch.
 
void setFold (size_t foldIdx)
 Sets the index of the next fold to be used for cross validation.
 
- Public Member Functions inherited from sgpp::datadriven::DataSource
DataSourceIterator begin ()
 Return an iterator object pointing to the first batch of this DataSource.
 
 DataSource (DataSourceConfig config, SampleProvider *sampleProvider)
 Constructor.
 
DataSourceIterator end ()
 Return an iterator object pointing to the last possible batch of this DataSource.
 
const DataSourceConfig & getConfig () const
 Read-only access to the configuration used by DataSource and the underlying SampleProvider.
 
size_t getCurrentIteration () const
 Return how many batches have already been requested from this DataSource.
 
virtual Dataset * getNextSamples ()
 Request data from the underlying SampleProvider as specified in the configuration object provided upon construction.
 
virtual ~DataSource ()=default
 

Additional Inherited Members

- Protected Attributes inherited from sgpp::datadriven::DataSource
DataSourceConfig config
 Configuration object that determines all relevant properties of the object.
 
size_t currentIteration
 Counter variable used if data is requested in batches.
 
DataTransformation * dataTransformation
 Pointer to the DataTransformation used to perform transformations on initialization.
 
std::unique_ptr< SampleProvider > sampleProvider
 Pointer to the sample provider that actually handles data acquisition.
 

Detailed Description

DataSourceCrossValidation is a high-level interface that provides functionality for processing data in a cross validation environment.

That is, it retrieves a certain fold for validation and uses the rest of the data for training. Note that this is very costly in terms of memory and not tractable for large datasets.

Constructor & Destructor Documentation

◆ DataSourceCrossValidation()

sgpp::datadriven::DataSourceCrossValidation::DataSourceCrossValidation ( const DataSourceConfig & dataSourceConfig,
const CrossvalidationConfiguration & crossValidationconfig,
DataShufflingFunctorCrossValidation * shuffling,
SampleProvider * sampleProvider 
)

Constructor.

Parameters
dataSourceConfig  configuration of the data source
crossValidationconfig  configuration of the cross validation
shuffling  cross validation shuffling that is used by the sample provider instance
sampleProvider  the sample provider to operate on

Member Function Documentation

◆ getCrossValidationConfig()

const CrossvalidationConfiguration & sgpp::datadriven::DataSourceCrossValidation::getCrossValidationConfig ( ) const

Gets the configuration for the cross validation.

Returns
configuration for the cross validation

◆ getValidationData()

Dataset * sgpp::datadriven::DataSourceCrossValidation::getValidationData ( )
override virtual

Returns the data that is used for validation, i.e. the current fold. If all folds have already been iterated over, this method throws.

Returns
pointer to the validation dataset

Implements sgpp::datadriven::DataSource.

◆ reset()

void sgpp::datadriven::DataSourceCrossValidation::reset ( )

Resets the state of the sample provider to begin a new training epoch.

References sgpp::datadriven::DataShufflingFunctorCrossValidation::getCurrentFoldSize().

◆ setFold()

void sgpp::datadriven::DataSourceCrossValidation::setFold ( size_t  foldIdx)

Sets the index of the next fold to be used for cross validation.

Parameters
foldIdx  index of the fold

References sgpp::datadriven::DataShufflingFunctorCrossValidation::setFold().

