24 from daal
import step1Local, step2Master
25 from daal.algorithms
import classifier
26 from daal.algorithms.multinomial_naive_bayes
import training, prediction
27 from daal.data_management
import FileDataSource, DataSourceIface
# Make the shared example helpers (the ``utils`` module) importable by putting the
# parent of this script's directory at the front of sys.path.
#
# __file__ is made absolute *before* stripping path components: on Python < 3.9
# __file__ can be a bare filename when a script is run as ``python script.py``,
# and dirname(dirname('script.py')) == '' would resolve to the current working
# directory instead of the script's parent directory.
utils_folder = os.path.realpath(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if utils_folder not in sys.path:
    sys.path.insert(0, utils_folder)
from utils import printNumericTables, createSparseTable  # noqa: E402
# Root of the example data sets; a relative path, so it is resolved against the
# current working directory when the files are opened.
DAAL_PREFIX = os.path.join('..', 'data')

# Input data set parameters.
# All four training blocks read the same batch CSR feature file and the same
# batch label file.
trainDatasetFileNames = [os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_train_csr.csv')] * 4
trainGroundTruthFileNames = [os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_train_labels.csv')] * 4

testDatasetFileName = os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_test_csr.csv')
testGroundTruthFileName = os.path.join(DAAL_PREFIX, 'batch', 'naivebayes_test_labels.csv')

# Row counts passed to loadDataBlock() for each training-label source and for the
# test-label source, respectively.
nTrainVectorsInBlock = 8000
nTestObservations = 2000
# Module-level state written by the training/prediction steps via ``global``.
predictionResult = None
# One placeholder slot per training block; each slot is replaced by the sparse
# table read for that block during training.
# NOTE(review): nBlocks is not defined in this view; the training loop indexes
# trainDatasetFileNames[i] for i in range(nBlocks), so it must not exceed that
# list's length.
trainData = [0 for _ in range(nBlocks)]
# NOTE(review): this chunk is the body of a training routine. Its `def` header,
# its indentation, and some lines (e.g. the closing parenthesis of the
# FileDataSource(...) call below) are not present in this view; restore them
# from upstream before running.
# Distributed multinomial Naive Bayes training on CSR inputs:
#   step 1 (local)  - one training.Distributed(step1Local, ...) per data block;
#   step 2 (master) - merges the per-block partial models and finalizes them.
# Writes the module-level names trainData (per-block input tables) and
# trainingResult.
66 global trainData, trainingResult
# Master-side algorithm that accumulates the partial models of every block.
68 masterAlgorithm = training.Distributed(step2Master, nClasses, method=training.fastCSR)
# Process each training block in turn (the local step runs sequentially here).
70 for i
in range(nBlocks):
# Read this block's training features from its CSR file into a sparse table.
72 trainData[i] = createSparseTable(trainDatasetFileNames[i])
# Data source for this block's ground-truth labels; per their names, the flags
# request automatic numeric-table allocation and a dictionary built from the data.
75 trainLabelsSource = FileDataSource(
76 trainGroundTruthFileNames[i], DataSourceIface.doAllocateNumericTable,
77 DataSourceIface.doDictionaryFromContext
# Load up to nTrainVectorsInBlock label rows (presumably; confirm DAAL semantics).
81 trainLabelsSource.loadDataBlock(nTrainVectorsInBlock)
# Local-step training algorithm for this block.
84 localAlgorithm = training.Distributed(step1Local, nClasses, method=training.fastCSR)
# Inputs: this block's feature table and its label table.
87 localAlgorithm.input.set(classifier.training.data, trainData[i])
88 localAlgorithm.input.set(classifier.training.labels, trainLabelsSource.getNumericTable())
# Train locally and hand the resulting partial model to the master algorithm.
92 masterAlgorithm.input.add(training.partialModels, localAlgorithm.compute())
# Master step: merge the partial models, then finalize into the trained result.
95 masterAlgorithm.compute()
96 trainingResult = masterAlgorithm.finalizeCompute()
# NOTE(review): this chunk is the body of a prediction routine. Its `def` header
# and indentation are not present in this view; restore them from upstream
# before running.
# Runs batch multinomial Naive Bayes prediction on the CSR test set with the model
# held in trainingResult, so the training step must have run first.
# Writes the module-level names testData and predictionResult.
100 global predictionResult, testData
# Read the test features from their CSR file into a sparse table.
103 testData = createSparseTable(testDatasetFileName)
# Batch prediction algorithm for nClasses classes using the CSR method.
106 algorithm = prediction.Batch(nClasses, method=prediction.fastCSR)
# Inputs: the test data table and the trained model extracted from trainingResult.
109 algorithm.input.setTable(classifier.prediction.data, testData)
110 algorithm.input.setModel(classifier.prediction.model, trainingResult.get(classifier.training.model))
# Run prediction; the result object is later queried for the predicted labels.
113 predictionResult = algorithm.compute()
# NOTE(review): this chunk is the body of a result-reporting routine. The
# following are not present in this view: its `def` header, its indentation,
# the closing parenthesis of FileDataSource(...), and the opening of the call
# that receives the arguments below (presumably printNumericTables, imported
# above). Restore them from upstream before running.
# Data source for the test-set ground-truth labels.
118 testGroundTruth = FileDataSource(
119 testGroundTruthFileName, DataSourceIface.doAllocateNumericTable,
120 DataSourceIface.doDictionaryFromContext
# Load up to nTestObservations label rows (presumably; confirm DAAL semantics).
122 testGroundTruth.loadDataBlock(nTestObservations)
# Print the ground truth side by side with the predicted labels from
# predictionResult, using two column titles and a header line. Per the header
# text, 20 is the number of observations shown; the meaning of 15 is not visible
# here (TODO confirm against utils.printNumericTables). flt64=False is passed
# through.
125 testGroundTruth.getNumericTable(),
126 predictionResult.get(classifier.prediction.prediction),
127 "Ground truth",
"Classification results",
128 "NaiveBayes classification results (first 20 observations):", 20, 15, flt64=
False
131 if __name__ ==
"__main__":