import os
import sys

from daal.algorithms import classifier
from daal.algorithms.multinomial_naive_bayes import prediction, training
from daal.data_management import FileDataSource, DataSourceIface
# Put the parent of this script's directory on the import path so that
# modules living there can be imported by the statements that follow.
# The path is made absolute and symlink-free before the membership test so
# the same directory is not inserted twice under different spellings.
utils_folder = os.path.realpath(
    os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
)
if utils_folder not in sys.path:
    sys.path.insert(0, utils_folder)
# Kept below the sys.path adjustment on purpose: `utils` is resolved through
# the directory that was just inserted, so this import cannot be hoisted to
# the top of the file.
from utils import printNumericTables, createSparseTable
# Relative location of the example data directory.
DAAL_PREFIX = os.path.join('..', 'data')

# Training inputs: feature file and the matching label file.
trainDatasetFileName = os.path.join(
    DAAL_PREFIX, 'batch', 'naivebayes_train_csr.csv'
)
trainGroundTruthFileName = os.path.join(
    DAAL_PREFIX, 'batch', 'naivebayes_train_labels.csv'
)

# Test inputs: feature file and the matching label file.
testDatasetFileName = os.path.join(
    DAAL_PREFIX, 'batch', 'naivebayes_test_csr.csv'
)
testGroundTruthFileName = os.path.join(
    DAAL_PREFIX, 'batch', 'naivebayes_test_labels.csv'
)

# Row counts handed to loadDataBlock() when the label files are read.
nTrainObservations = 8000
nTestObservations = 2000
# NOTE(review): `nClasses` is used when the training and prediction
# algorithms are constructed, but its assignment is not visible in this
# part of the file -- confirm it is defined before those calls.

# Module-level slot for the prediction output; it is rebound later through
# a `global predictionResult` declaration and read when results are printed.
predictionResult = None
54 trainGroundTruthSource = FileDataSource(
55 trainGroundTruthFileName,
56 DataSourceIface.doAllocateNumericTable,
57 DataSourceIface.doDictionaryFromContext
61 trainData = createSparseTable(trainDatasetFileName)
62 trainGroundTruthSource.loadDataBlock(nTrainObservations)
65 algorithm = training.Batch(nClasses, method=training.fastCSR)
68 algorithm.input.set(classifier.training.data, trainData)
69 algorithm.input.set(classifier.training.labels, trainGroundTruthSource.getNumericTable())
72 trainingResult = algorithm.compute()
76 global predictionResult
79 testData = createSparseTable(testDatasetFileName)
82 algorithm = prediction.Batch(nClasses, method=prediction.fastCSR)
85 algorithm.input.setTable(classifier.prediction.data, testData)
86 algorithm.input.setModel(classifier.prediction.model, trainingResult.get(classifier.training.model))
89 predictionResult = algorithm.compute()
94 testGroundTruth = FileDataSource(
95 testGroundTruthFileName,
96 DataSourceIface.doAllocateNumericTable,
97 DataSourceIface.doDictionaryFromContext
100 testGroundTruth.loadDataBlock(nTestObservations)
103 testGroundTruth.getNumericTable(),
104 predictionResult.get(classifier.prediction.prediction),
105 "Ground truth",
"Classification results",
106 "NaiveBayes classification results (first 20 observations):", 20, 15, flt64=
False
109 if __name__ ==
"__main__":