24 from daal.algorithms.decision_tree.regression
import prediction, training
25 from daal.data_management
import (
26 FileDataSource, DataSourceIface, NumericTableIface, HomogenNumericTable, MergedNumericTable
# Make the sibling `utils` module importable: compute the absolute, symlink-resolved
# path of the parent of the directory that contains this file, and put it at the
# front of sys.path unless it is already listed there.
# NOTE(review): the `os` and `sys` imports are not visible in this excerpt —
# confirm they exist above this point.
29 utils_folder = os.path.realpath(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
30 if utils_folder
not in sys.path:
31 sys.path.insert(0, utils_folder)
32 from utils
import printNumericTables
# Root of the example data sets, expressed as the relative path '../data'.
# NOTE(review): being relative, it is presumably resolved against the current
# working directory when the files are opened, not against this file — confirm.
34 DAAL_PREFIX = os.path.join(
'..',
'data')
# CSV inputs, all under <DAAL_PREFIX>/batch: one for training, one for pruning,
# one for testing.
37 trainDatasetFileName = os.path.join(DAAL_PREFIX,
'batch',
'decision_tree_train.csv')
38 pruneDatasetFileName = os.path.join(DAAL_PREFIX,
'batch',
'decision_tree_prune.csv')
39 testDatasetFileName = os.path.join(DAAL_PREFIX,
'batch',
'decision_tree_test.csv')
# Module-level placeholders. Both names are later declared `global` and assigned
# by the test/prediction code, then read by the result-printing call.
45 predictionResult =
None
46 testGroundTruth =
None
# Load the training set.
# NOTE(review): the embedded original line numbers jump (53 -> 55, 56 -> 60), so
# some arguments of this FileDataSource(...) call and its closing parenthesis are
# not visible here — presumably the file name (trainDatasetFileName); confirm.
# The enclosing `def` line, if any, is also not visible.
53 trainDataSource = FileDataSource(
55 DataSourceIface.notAllocateNumericTable,
56 DataSourceIface.doDictionaryFromContext
# Two tables created with the `notAllocate` flag: one sized by nFeatures for the
# inputs and one with a single column for the dependent variable.
# NOTE(review): `nFeatures` is not defined in the visible excerpt — confirm it is
# a module-level constant.
60 trainData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
61 trainGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
# The merged table wraps both tables so one load call can target them together;
# trainData and trainGroundTruth are what the training algorithm is given later.
# NOTE(review): presumably the CSV's last column is the dependent variable — confirm.
62 mergedData = MergedNumericTable(trainData, trainGroundTruth)
65 trainDataSource.loadDataBlock(mergedData)
# Load the pruning set, following the same steps as the training set.
# NOTE(review): original lines 69 and 72-74 are not visible in this excerpt —
# presumably the file-name argument (pruneDatasetFileName) and the closing
# parenthesis of the FileDataSource(...) call; confirm.
68 pruneDataSource = FileDataSource(
70 DataSourceIface.notAllocateNumericTable,
71 DataSourceIface.doDictionaryFromContext
# Feature table (nFeatures columns) and single-column dependent-variable table,
# both created with the `notAllocate` flag.
75 pruneData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
76 pruneGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
# Merge the two tables and load the data source into the merged view.
77 pruneMergedData = MergedNumericTable(pruneData, pruneGroundTruth)
80 pruneDataSource.loadDataBlock(pruneMergedData)
# Create the decision tree regression training algorithm with default arguments.
83 algorithm = training.Batch()
# Pass the training tables and the pruning tables as the algorithm's four inputs.
86 algorithm.input.set(training.data, trainData)
87 algorithm.input.set(training.dependentVariables, trainGroundTruth)
88 algorithm.input.set(training.dataForPruning, pruneData)
89 algorithm.input.set(training.dependentVariablesForPruning, pruneGroundTruth)
# Run training and extract the model from the result.
# NOTE(review): whether `model` is local or module-level here cannot be
# determined from this excerpt (no visible `global model` / `return`) — confirm
# how it reaches the prediction code that reads `model`.
92 trainingResult = algorithm.compute()
93 model = trainingResult.get(training.model)
# The `global` statement makes the assignments to testGroundTruth and
# predictionResult below rebind the module-level names.
# NOTE(review): this suggests the following lines sit inside a function whose
# `def` line is not visible in this excerpt — confirm.
96 global testGroundTruth, predictionResult
# Load the test set.
# NOTE(review): original lines 100 and 103-105 are not visible — presumably the
# file-name argument (testDatasetFileName) and the closing parenthesis; confirm.
99 testDataSource = FileDataSource(
101 DataSourceIface.notAllocateNumericTable,
102 DataSourceIface.doDictionaryFromContext
# Feature table (nFeatures columns) and single-column ground-truth table, both
# created with the `notAllocate` flag; testGroundTruth is the module-level name.
106 testData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
107 testGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
# Merge the two tables and load the data source into the merged view.
108 mergedData = MergedNumericTable(testData, testGroundTruth)
111 testDataSource.loadDataBlock(mergedData)
# Create the decision tree regression prediction algorithm with default arguments.
114 algorithm = prediction.Batch()
# Inputs: the test feature table and the model to predict with.
# NOTE(review): where `model` comes from (parameter or module-level name) is not
# visible in this excerpt — confirm.
118 algorithm.input.setTable(prediction.data, testData)
119 algorithm.input.setModel(prediction.model, model)
# Run prediction; the result is stored in the module-level predictionResult.
122 predictionResult = algorithm.compute()
# Print the ground-truth table next to the prediction table taken from
# predictionResult, using the column captions and heading given below.
# NOTE(review): the rest of this call (any further arguments and the closing
# parenthesis) is not visible in this excerpt; the heading mentions "first 20
# observations", so a row-limit argument presumably follows — confirm.
127 printNumericTables(testGroundTruth, predictionResult.get(prediction.prediction),
128 "Ground truth",
"Regression results",
129 "Decision tree regression results (first 20 observations):",
132 if __name__ ==
"__main__":