from __future__ import print_function

from daal.algorithms import classifier
from daal.algorithms import decision_tree
import daal.algorithms.decision_tree.classification
import daal.algorithms.decision_tree.classification.training

from daal.data_management import (
    DataSourceIface, NumericTableIface, HomogenNumericTable, MergedNumericTable, FileDataSource
)
# Input data sets: CSV files read by the training step, given relative to the
# directory the example is run from.
trainDatasetFileName = "../data/batch/decision_tree_train.csv"
pruneDatasetFileName = "../data/batch/decision_tree_prune.csv"
def trainModel():
    """Train a decision tree classifier and return the training result.

    Reads the training and pruning CSV files named by
    ``trainDatasetFileName`` and ``pruneDatasetFileName``, passes both to a
    ``decision_tree.classification.training.Batch`` algorithm and returns
    whatever ``algorithm.compute()`` returns. The caller retrieves the model
    from that result with ``.get(classifier.training.model)``.
    """
    # NOTE(review): nFeatures and nClasses are module-level names used below
    # whose definitions are not visible in this listing; restore them from the
    # original example before running.

    # Training set: one table for the feature columns and one single-column
    # table for the labels, filled together through a merged view.
    # NOTE(review): presumably the first nFeatures CSV columns are features and
    # the last column is the class label; confirm against the data files.
    trainDataSource = FileDataSource(
        trainDatasetFileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext,
    )
    trainData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    trainGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    mergedData = MergedNumericTable(trainData, trainGroundTruth)
    trainDataSource.loadDataBlock(mergedData)

    # Pruning set: loaded the same way as the training set.
    pruneDataSource = FileDataSource(
        pruneDatasetFileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext,
    )
    pruneData = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    pruneGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    pruneMergedData = MergedNumericTable(pruneData, pruneGroundTruth)
    pruneDataSource.loadDataBlock(pruneMergedData)

    # Configure the batch training algorithm with both data sets.
    algorithm = decision_tree.classification.training.Batch(nClasses)
    algorithm.input.set(classifier.training.data, trainData)
    algorithm.input.set(classifier.training.labels, trainGroundTruth)
    algorithm.input.set(decision_tree.classification.training.dataForPruning, pruneData)
    algorithm.input.set(decision_tree.classification.training.labelsForPruning, pruneGroundTruth)

    return algorithm.compute()
class PrintNodeVisitor(classifier.TreeNodeVisitor):
    """Tree node visitor that prints one line per visited node."""

    def __init__(self):
        super(PrintNodeVisitor, self).__init__()

    def onLeafNode(self, level, response):
        """Print the level and response value of a leaf node.

        level    -- value reported in the output and used as the loop count
        response -- value printed as the leaf's response
        """
        # NOTE(review): the loop body is missing from the listing; printing
        # one indentation unit per level is assumed -- confirm against the
        # original example.
        for _ in range(level):
            print("  ", end="")
        print("Level {}, leaf node. Response value = {}".format(level, response))
        # NOTE(review): the return statement is not visible in the listing;
        # True presumably tells the traversal to continue -- confirm.
        return True

    def onSplitNode(self, level, featureIndex, featureValue):
        """Print the level, feature index and split value of a split node.

        The feature value is formatted with 4 significant digits.
        """
        # NOTE(review): the loop body is missing from the listing; printing
        # one indentation unit per level is assumed -- confirm against the
        # original example.
        for _ in range(level):
            print("  ", end="")
        print(
            "Level {}, split node. Feature index = {}, feature value = {:.4g}".format(
                level, featureIndex, featureValue
            )
        )
        # NOTE(review): the return statement is not visible in the listing;
        # True presumably tells the traversal to continue -- confirm.
        return True
def printModel(m):
    """Print the nodes of model ``m`` by traversing it with a PrintNodeVisitor.

    m -- object exposing ``traverseDF(visitor)``; the caller passes the model
         taken from the training result.
    """
    visitor = PrintNodeVisitor()
    m.traverseDF(visitor)
if __name__ == "__main__":
    # Train the model, then print its tree structure.
    trainingResult = trainModel()
    printModel(trainingResult.get(classifier.training.model))