This example shows how to perform online classification using sparse grid density estimation and the conjugate gradient method.
It creates an instance of LearnerSGDE and calls its trainOnline() method, which implements the main functionality.
Currently, only binary classification with class labels -1 and 1 is possible.
The example provides the option to execute several runs over differently ordered data and to perform a 5-fold cross-validation within each run. For this, the data must already be randomly ordered and partitioned. Results averaged over several runs may be more reliable in an online-learning scenario, because the order in which the learner sees the data points can affect the result.
Specify the number of runs (totalSets) and folds (totalFolds) to perform. To execute only a single run, set totalSets=1 and totalFolds=1.
size_t totalSets = 1;
size_t totalFolds = 1;
double avgError = 0.0;
double avgErrorFolds = 0.0;
for (size_t numSets = 0; numSets < totalSets; numSets++) {
A vector to compute the average classification error throughout the learning process. The length of the vector determines the total number of error observations.
for (size_t numFolds = 0; numFolds < totalFolds; numFolds++) {
Get the training, test and validation data
std::string filename = "../../datasets/ripley/ripleyGarcke.train.arff";
std::cout << "# loading file: " << filename << std::endl;
filename = "../../datasets/ripley/ripleyGarcke.test.arff";
std::cout << "# loading file: " << filename << std::endl;
Specify the occurring class labels.
size_t classNum = 2;
classLabels[0] = -1;
classLabels[1] = 1;
The grid configuration.
std::cout << "# create grid config" << std::endl;
Configure adaptive refinement. Either the periodic monitor or the convergence monitor can be chosen as the refinement monitor. Possible refinement indicators are surplus refinement, data-based refinement, and zero-crossings-based refinement.
std::cout << "# create adaptive refinement config" << std::endl;
std::string refMonitor;
refMonitor = "periodic";
size_t refPeriod = 40;
double accDeclineThreshold = 0.001;
size_t accDeclineBufferSize = 140;
size_t minRefInterval = 10;
std::cout << "Refinement monitor: " << refMonitor << std::endl;
std::string refType;
refType = "zero";
std::cout << "Refinement type: " << refType << std::endl;
Specify the number of refinement steps and the maximum number of grid points to refine in each step.
Configure the CG solver. Note that the maximum number of iterations should be limited in order to obtain feasible runtimes, especially for large grids.
std::cout << "# create solver config" << std::endl;
solverConfig.
eps_ = 1e-10;
Configure regularization.
std::cout << "# create regularization config" << std::endl;
Configure cross-validation.
std::cout << "# create cross-validation config" << std::endl;
crossvalidationConfig.
lambda_ = 0.01;
crossvalidationConfig.
enable_ =
false;
crossvalidationConfig.
kfold_ = 5;
crossvalidationConfig.
seed_ = 1234567;
crossvalidationConfig.
silent_ =
true;
Create the learner.
std::cout << "# creating the learner" << std::endl;
regularizationConfig, crossvalidationConfig);
learner.initialize(trainData);
bool usePrior = false;
size_t maxDataPasses = 2;
Learn the data.
std::cout << "# start to train the learner" << std::endl;
learner.trainOnline(trainLabels, testData, testLabels, validData, validLabels, classLabels,
maxDataPasses, refType, refMonitor, refPeriod, accDeclineThreshold,
accDeclineBufferSize, minRefInterval, usePrior);
std::cout << "# finished training" << std::endl;
Accuracy on test and current training data.
double accTrain = learner.getAccuracy(trainData, trainLabels, 0.0);
std::cout << "Acc (train): " << accTrain << std::endl;
double accTest = learner.getAccuracy(testData, testLabels, 0.0);
std::cout << "Acc (test): " << accTest << std::endl;
avgErrorFolds += learner.error;
avgErrorsFolds.add(learner.avgErrors);
}
avgErrorFolds = avgErrorFolds / static_cast<double>(totalFolds);
Average accuracy on test data over the 5-fold cross-validation.
if ((totalSets > 1) && (totalFolds > 1)) {
std::cout << "Average accuracy on test data (set " + std::to_string(numSets + 1) + "): "
<< (1.0 - avgErrorFolds) << "\n";
}
avgError += avgErrorFolds;
avgErrorFolds = 0.0;
avgErrorsFolds.mult(1.0 / static_cast<double>(totalFolds));
}