This example shows how to perform offline/online classification using sparse grid density estimation and matrix decomposition methods.
It creates an instance of LearnerSGDEOnOff and calls its training function (trainParallel() in this MPI-parallel example), where the main functionality is implemented.
Currently, only binary classification with class labels -1 and 1 is possible.
The example provides the option to execute several runs over differently ordered data and to perform a 5-fold cross-validation within each run. This requires input data that is already randomly ordered and partitioned. Averaging results over several runs may be more reliable in an online-learning scenario, because the order in which the learner sees the data points can affect the result.
// Entry point of the MPI-parallel online/offline SGDE classification example.
//
// Expects exactly four command-line arguments (see the usage message below):
//   argv[1] training data file, argv[2] test data file,
//   argv[3] batch size, argv[4] refinement period.
// argv[3] and argv[4] are converted with parseInputValue.
// Returns -1 if the argument count is wrong.
// NOTE(review): the body is compiled only when both USE_MPI and USE_GSL are
// defined; without USE_MPI nothing runs and main returns 0. Several lines of
// this function (dataset loading, learner construction head, the condition
// matching the `else` near the end) are not visible here — confirm against the
// full source.
int main(
int argc,
char *argv[]) {
#ifdef USE_MPI
#ifdef USE_GSL
// Limit OpenMP to a single thread per process; presumably parallelism is
// meant to come from the MPI ranks instead — confirm.
omp_set_num_threads(1);
std::cout << "LearnerSGDEOnOffParallelTest" << std::endl;
// Validate the argument count before touching argv[3] / argv[4].
if (argc != 5) {
std::cout << "Usage:" << std::endl
<< "learnerSGDEOnOffParallelTest <trainDataFile> "
<< "<testDataFile> <batchSize> <refPeriod>"
<< std::endl;
return -1;
}
Specify the number of runs to perform. If only one specific example should be executed, set totalSets=1.
Get the training, test and validation data
Specify the number of classes and the corresponding class labels.
// Binary classification: two classes labelled -1 and 1.
size_t classNum = 2;
classLabels[0] = -1;
classLabels[1] = 1;
The grid configuration.
std::cout << "# create grid config" << std::endl;
Configure regularization.
std::cout << "# create regularization config" << std::endl;
// Regularization parameter lambda is fixed to 0.01.
regularizationConfig.
lambda_ = 0.01;
Select the desired decomposition type for the offline step. Note: Refinement/Coarsening only possible for Cholesky decomposition.
// NOTE(review): the selected type is the *incomplete* Cholesky variant;
// confirm it supports the zero-crossing refinement configured below.
std::string decompType;
decompType = "Incomplete Cholesky decomposition on Dense Matrix";
std::cout << "Decomposition type: " << decompType << std::endl;
Configure adaptive refinement (if Cholesky is chosen). As refinement monitor the periodic monitor or the convergence monitor can be chosen. Possible refinement indicators are surplus refinement, data-based refinement, zero-crossings-based refinement.
std::cout << "# create adaptive refinement configuration" << std::endl;
std::string refMonitor;
refMonitor = "periodic";
// Refinement period is taken from the 4th command-line argument.
size_t refPeriod = 0;
parseInputValue(argv[4], refPeriod);
// NOTE(review): these three values look like convergence-monitor settings;
// with refMonitor == "periodic" they may be unused — confirm.
double accDeclineThreshold = 0.001;
size_t accDeclineBufferSize = 140;
size_t minRefInterval = 10;
std::cout << "Refinement monitor: " << refMonitor << std::endl;
std::string refType;
refType = "zero";
std::cout << "Refinement type: " << refType << std::endl;
Specify number of refinement steps and the max number of grid points to refine each step.
double beta = 0.0;
bool usePrior = false;
// Batch size is taken from the 3rd command-line argument.
size_t batchSize = 0;
parseInputValue(argv[3], batchSize);
Create the learner.
std::cout << "# create learner" << std::endl;
// NOTE(review): the beginning of the learner construction (type and the
// variable `learner` used below) is not visible; these are its arguments.
adaptConfig,
regularizationConfig,
densityEstimationConfig,
trainDataset,
testDataset,
nullptr,
classLabels,
classNum,
usePrior,
beta,
scheduler);
// Each data point is passed over only once.
size_t maxDataPasses = 1;
Learn the data.
// Synchronize all MPI ranks before training starts.
MPI_Barrier(MPI_COMM_WORLD);
std::cout << "# start to train the learner" << std::endl;
learner.trainParallel(batchSize, maxDataPasses, refType, refMonitor, refPeriod,
accDeclineThreshold,
accDeclineBufferSize, minRefInterval);
// NOTE(review): the stopwatch start is not visible here; presumably it is
// started before training — confirm.
double deltaTime = stopwatch.
stop();
// Synchronize all MPI ranks again after training.
MPI_Barrier(MPI_COMM_WORLD);
Accuracy on test data.
double acc = learner.getAccuracy();
std::cout << "# accuracy (test data): " << acc << std::endl;
std::cout << "# delta time training: " << deltaTime << std::endl;
// NOTE(review): this `else` pairs with a condition not visible here
// (presumably a master/rank-0 check); acc and deltaTime must be in scope in
// this branch as well — confirm against the full source.
} else {
std::cout << "# accuracy (client, test data): " << acc << std::endl;
std::cout << "# delta time training (client): " << deltaTime << std::endl;
}
#else
std::cout << "GSL not enabled at compile time" << std::endl;
#endif // USE_GSL
#endif
}
#ifdef USE_MPI
std::cout << "# loading file: " << filename << std::endl;
std::cout << "# Failed to read dataset! " << filename << std::endl;
exit(-1);
} else {
std::cout <<
"# dataset dimensionality: " << dataset.
getDimension() << std::endl;
}
}
/**
 * Parses a non-negative integer command-line argument into a size_t.
 *
 * Leading and trailing whitespace is accepted. If the argument is malformed
 * (empty, not a number, negative, out of range, or followed by extra
 * characters such as "12abc"), an error is printed and the program exits
 * with -1. This is the same way the example handles other fatal input errors.
 *
 * @param inputString raw argument text (e.g. an argv entry); may be null
 * @param outputValue receives the parsed value; it is only written on success
 *                    (on failure the process exits)
 */
void parseInputValue(char *inputString, size_t &outputValue) {
  // Constructing a std::string from a null pointer is undefined behavior.
  const std::string rawInput = (inputString != nullptr) ? inputString : "";

  // operator>> into an unsigned type accepts "-5" and silently wraps it to a
  // huge value, so a leading minus sign is rejected explicitly.
  const std::string::size_type firstChar =
      rawInput.find_first_not_of(" \t\n\v\f\r");
  bool valid = (firstChar != std::string::npos) && (rawInput[firstChar] != '-');

  size_t parsedValue = 0;
  if (valid) {
    std::stringstream argumentParser(rawInput);
    argumentParser >> parsedValue;  // sets failbit on non-numeric/out-of-range input
    valid = !argumentParser.fail();
    if (valid) {
      // Anything other than whitespace after the number makes it malformed.
      std::string trailing;
      argumentParser >> trailing;
      valid = trailing.empty();
    }
  }

  if (!valid) {
    // Previously the failure was silently ignored, leaving a 0 value behind.
    std::cout << "# Failed to parse input value: \"" << rawInput << "\""
              << std::endl;
    exit(-1);
  }
  outputValue = parsedValue;
}
#endif