import os
import sys

from daal.algorithms.decision_tree.classification import prediction, training
from daal.algorithms import classifier
from daal.data_management import (
    FileDataSource, DataSourceIface, NumericTableIface, HomogenNumericTable, MergedNumericTable
)

# Put the parent of this file's directory on sys.path so that the `utils`
# module imported below can be found. It is inserted at position 0 so it
# takes precedence over other entries, and only if not already present.
utils_folder = os.path.realpath(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
if utils_folder not in sys.path:
    sys.path.insert(0, utils_folder)
from utils import printNumericTables

# Location of the sample data directory, expressed relative to the current
# working directory.
DAAL_PREFIX = os.path.join('..', 'data')

# Input data set parameters: CSV files used for training, pruning and testing.
trainDatasetFileName = os.path.join(DAAL_PREFIX, 'batch', 'decision_tree_train.csv')
pruneDatasetFileName = os.path.join(DAAL_PREFIX, 'batch', 'decision_tree_prune.csv')
testDatasetFileName = os.path.join(DAAL_PREFIX, 'batch', 'decision_tree_test.csv')

# NOTE(review): nFeatures and nClasses are referenced by the training and
# testing code, but their definitions were lost from the damaged listing.
# The values below are believed to match the upstream sample data set --
# confirm against the CSV files before relying on them.
nFeatures = 5
nClasses = 5

# State shared between the training, testing and printing steps.
# NOTE(review): the `model` initialisation was also missing from the listing
# and is restored here alongside the two names that were present.
model = None
predictionResult = None
testGroundTruth = None

def trainModel():
    """Train the decision tree classifier and store it in the global ``model``.

    Loads the training and pruning CSV files into numeric tables, passes
    them to the batch training algorithm, and keeps the resulting model in
    the module-level ``model`` variable.

    NOTE(review): the ``def`` header, the ``global model`` statement, the
    file-name arguments and the closing parentheses of the
    ``FileDataSource`` calls were missing from the damaged listing and are
    reconstructed here -- confirm against the upstream sample.
    """
    global model

    # Data source for the training CSV file.
    trainDataSource = FileDataSource(
        trainDatasetFileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext
    )

    # Separate tables for the features and the label column, created with
    # 0 rows and the notAllocate flag, then wrapped in one merged table.
    trainData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    trainGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    mergedData = MergedNumericTable(trainData, trainGroundTruth)

    # Load the file contents through the merged table.
    trainDataSource.loadDataBlock(mergedData)

    # Same procedure for the pruning data set.
    pruneDataSource = FileDataSource(
        pruneDatasetFileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext
    )

    pruneData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    pruneGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    pruneMergedData = MergedNumericTable(pruneData, pruneGroundTruth)

    pruneDataSource.loadDataBlock(pruneMergedData)

    # Batch training algorithm configured with the number of classes.
    algorithm = training.Batch(nClasses)

    # Training data/labels go in through the generic classifier input ids;
    # the pruning tables use the decision-tree-specific ids.
    algorithm.input.set(classifier.training.data, trainData)
    algorithm.input.set(classifier.training.labels, trainGroundTruth)
    algorithm.input.setTable(training.dataForPruning, pruneData)
    algorithm.input.setTable(training.labelsForPruning, pruneGroundTruth)

    # Run training and keep the model for the prediction step.
    trainingResult = algorithm.compute()
    model = trainingResult.get(classifier.training.model)


def testModel():
    """Run the trained model on the test set.

    Loads the test CSV file, runs batch prediction with the global
    ``model``, and stores the test labels and the prediction result in the
    module-level ``testGroundTruth`` and ``predictionResult``.

    NOTE(review): the ``def`` header, the file-name argument and the
    closing parenthesis of the ``FileDataSource`` call were missing from
    the damaged listing and are reconstructed here.
    """
    global testGroundTruth, predictionResult

    # Data source for the test CSV file.
    testDataSource = FileDataSource(
        testDatasetFileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext
    )

    # Feature table and label table, merged so both are loaded together.
    testData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    testGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    mergedData = MergedNumericTable(testData, testGroundTruth)

    testDataSource.loadDataBlock(mergedData)

    # Batch prediction algorithm fed with the test features and the model.
    algorithm = prediction.Batch()

    algorithm.input.setTable(classifier.prediction.data, testData)
    algorithm.input.setModel(classifier.prediction.model, model)

    # Keep the result object; the predicted labels are read from it later.
    predictionResult = algorithm.compute()


def printResults():
    """Print the ground-truth labels next to the predicted labels.

    NOTE(review): the ``def`` header, the opening of the call with its
    first argument, and the trailing ``20, flt64=False`` arguments were
    missing from the damaged listing. They are reconstructed from the
    "first 20 observations" message and the upstream sample -- confirm
    against ``utils.printNumericTables``.
    """
    printNumericTables(
        testGroundTruth,
        predictionResult.get(classifier.prediction.prediction),
        "Ground truth", "Classification results",
        "Decision tree classification results (first 20 observations):",
        20, flt64=False
    )


if __name__ == "__main__":
    # NOTE(review): the body of this guard was missing from the damaged
    # listing; the three steps are called in dependency order.
    trainModel()
    testModel()
    printResults()