from __future__ import print_function

from daal.algorithms import regression
from daal.algorithms import decision_tree
import daal.algorithms.decision_tree.regression
import daal.algorithms.decision_tree.regression.training

from daal.data_management import FileDataSource, DataSourceIface, NumericTableIface, HomogenNumericTable, MergedNumericTable
# Relative paths of the CSV input data sets: the training set and the
# separate set used for pruning the trained tree.
trainDatasetFileName = "../data/batch/decision_tree_train.csv"
pruneDatasetFileName = "../data/batch/decision_tree_prune.csv"
def _loadDataset(fileName):
    """Load one CSV data set into a (features, dependent variable) table pair.

    A MergedNumericTable built from a features table (``nFeatures`` columns)
    followed by a one-column dependent-variable table is filled from the file,
    which populates both child tables.

    NOTE(review): ``nFeatures`` is expected to be a module-level constant
    defined outside this view.

    :param fileName: path of the CSV file to read.
    :return: tuple ``(data, groundTruth)`` of HomogenNumericTable objects.
    """
    dataSource = FileDataSource(
        fileName, DataSourceIface.notAllocateNumericTable, DataSourceIface.doDictionaryFromContext
    )

    # Tables start unallocated; loadDataBlock sizes and fills them.
    data = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    groundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    mergedData = MergedNumericTable(data, groundTruth)

    dataSource.loadDataBlock(mergedData)
    return data, groundTruth


def trainModel():
    """Train a decision-tree regression model that uses a separate pruning set.

    Reads the training set from ``trainDatasetFileName`` and the pruning set
    from ``pruneDatasetFileName``, passes both to a batch
    ``decision_tree.regression.training`` algorithm and runs it.

    :return: the result of ``algorithm.compute()``; the trained model is
        obtained from it via ``.get(decision_tree.regression.training.model)``.
    """
    trainData, trainGroundTruth = _loadDataset(trainDatasetFileName)
    pruneData, pruneGroundTruth = _loadDataset(pruneDatasetFileName)

    training = decision_tree.regression.training
    algorithm = training.Batch()

    algorithm.input.set(training.data, trainData)
    algorithm.input.set(training.dependentVariables, trainGroundTruth)
    algorithm.input.set(training.dataForPruning, pruneData)
    algorithm.input.set(training.dependentVariablesForPruning, pruneGroundTruth)

    return algorithm.compute()
class PrintNodeVisitor(regression.TreeNodeVisitor):
    """Visitor that prints every node of a regression decision tree.

    Pass an instance to a model's ``traverseDF``. Each visited node is printed
    on its own line, prefixed by an indent proportional to its level.
    """

    # Indent unit repeated once per tree level.
    # NOTE(review): the original per-level indentation loop body is missing
    # from this copy of the file; two spaces per level is an assumption.
    _INDENT = "  "

    def __init__(self):
        # Initialise the TreeNodeVisitor base class.
        super(PrintNodeVisitor, self).__init__()

    def _printNode(self, level, text):
        """Print ``text`` indented by ``level`` indent units."""
        print(self._INDENT * level + text)

    def onLeafNode(self, level, response):
        """Print a leaf node at depth ``level`` and its response value.

        :return: True, so that the traversal continues (False would cancel it).
        """
        self._printNode(level, "Level {}, leaf node. Response value = {:.4g}".format(level, response))
        return True

    def onSplitNode(self, level, featureIndex, featureValue):
        """Print a split node at depth ``level`` with its split feature and value.

        :return: True, so that the traversal continues (False would cancel it).
        """
        self._printNode(level, "Level {}, split node. Feature index = {}, feature value = {:.4g}".format(level, featureIndex, featureValue))
        return True
def printModel(m, visitor=None):
    """Walk a trained regression tree model depth-first with a node visitor.

    :param m: trained model; only its ``traverseDF(visitor)`` method is used.
    :param visitor: visitor passed to ``traverseDF``. Defaults to a new
        ``PrintNodeVisitor``, which prints each node to stdout.
    """
    if visitor is None:
        visitor = PrintNodeVisitor()
    m.traverseDF(visitor)
if __name__ == "__main__":
    # Train the model from the CSV data sets, then print its tree structure.
    trainingResult = trainModel()
    printModel(trainingResult.get(decision_tree.regression.training.model))