from __future__ import print_function

from daal.algorithms import classifier
from daal.algorithms import decision_forest
import daal.algorithms.decision_forest.classification
import daal.algorithms.decision_forest.classification.training

from daal.data_management import (
    FileDataSource, HomogenNumericTable, MergedNumericTable, NumericTableIface, DataSourceIface, features
)
# Path of the CSV file that trainModel() hands to loadData()
trainDatasetFileName = "../data/batch/df_classification_train.csv"

# Indices of the feature columns that loadData() marks as categorical
categoricalFeaturesIndices = [2]

# Value copied into algorithm.parameter.minObservationsInLeafNode by trainModel()
minObservationsInLeafNode = 8

# NOTE(review): nFeatures, nTrees, maxTreeDepth and nClasses are read further
# down in this file, but their assignments are not present in this copy of the
# listing (its embedded numbering jumps 43 -> 48 -> 57). Restore them from the
# original example before running; no values are invented here.
def trainModel():
    """Train a decision forest classifier on the training data set.

    Loads the table pair from ``trainDatasetFileName`` via ``loadData``,
    configures a ``decision_forest.classification.training.Batch`` algorithm
    from the module-level settings and runs it.

    Returns:
        Whatever ``algorithm.compute()`` returns; the caller in this file
        reads the model from it with ``.get(classifier.training.model)``.
    """
    # NOTE(review): the "def" line of this function is missing from the
    # listing; the name and empty parameter list are taken from the
    # ``trainModel()`` call in the ``__main__`` block.
    # NOTE(review): nClasses, nTrees, nFeatures and maxTreeDepth are
    # module-level names whose assignments are not visible in this listing.
    trainData, trainDependentVariable = loadData(trainDatasetFileName)

    algorithm = decision_forest.classification.training.Batch(nClasses)

    # Training inputs: feature table and label table
    algorithm.input.set(classifier.training.data, trainData)
    algorithm.input.set(classifier.training.labels, trainDependentVariable)

    # Forest settings copied from the module-level configuration
    algorithm.parameter.nTrees = nTrees
    algorithm.parameter.featuresPerNode = nFeatures
    algorithm.parameter.minObservationsInLeafNode = minObservationsInLeafNode
    algorithm.parameter.maxTreeDepth = maxTreeDepth

    return algorithm.compute()
def loadData(fileName):
    """Read a CSV file into a feature table and a dependent-variable table.

    Args:
        fileName: Path passed to ``FileDataSource``.

    Returns:
        A ``(data, dependentVar)`` tuple of ``HomogenNumericTable`` objects.
        Both are filled through a ``MergedNumericTable`` that wraps them, and
        the columns of ``data`` listed in ``categoricalFeaturesIndices`` have
        their dictionary feature type set to ``features.DAAL_CATEGORICAL``.
    """
    trainDataSource = FileDataSource(
        fileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext
    )

    # Both tables are created with the notAllocate flag and 0 as the second
    # argument; presumably loadDataBlock() sizes them while reading - confirm
    # against the DAAL numeric table documentation.
    data = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    dependentVar = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)
    mergedData = MergedNumericTable(data, dependentVar)

    trainDataSource.loadDataBlock(mergedData)

    # Iterate over the configured indices directly instead of range(len(...))
    dictionary = data.getDictionary()
    for featureIndex in categoricalFeaturesIndices:
        dictionary[featureIndex].featureType = features.DAAL_CATEGORICAL

    return data, dependentVar
class PrintNodeVisitor(classifier.TreeNodeVisitor):
    """Tree node visitor that prints one line per visited node.

    An instance is passed to ``m.traverseDF(i, visitor)`` in ``printModel``.
    """

    # NOTE(review): the "def __init__" line is missing from the listing; it is
    # restored here because the super().__init__() call must live in it.
    def __init__(self):
        super(PrintNodeVisitor, self).__init__()

    def onLeafNode(self, level, response):
        """Print the level and response value of a leaf node."""
        # NOTE(review): the body of this loop (listing line 106) is missing
        # from this copy; writing one indent unit per level without a newline
        # is an assumption - confirm against the original example.
        for _ in range(level):
            print("  ", end="")
        print("Level {}, leaf node. Response value = {}".format(level, response))
        # NOTE(review): listing line 108 is missing. The DAAL visitor callbacks
        # are declared as returning a bool; True is assumed so that traversal
        # continues - confirm.
        return True

    def onSplitNode(self, level, featureIndex, featureValue):
        """Print the level, feature index and feature value of a split node."""
        # NOTE(review): loop body (listing line 113) missing; same assumption
        # as in onLeafNode - confirm.
        for _ in range(level):
            print("  ", end="")
        print("Level {}, split node. Feature index = {}, feature value = {:.6g}".format(
            level, featureIndex, featureValue))
        # NOTE(review): listing line 115 is missing; True assumed - confirm.
        return True
def printModel(m):
    """Print the number of trees in a model, then traverse and print each tree.

    Args:
        m: Model object providing ``getNumberOfTrees()`` and
            ``traverseDF(treeIndex, visitor)``; the ``__main__`` block passes
            ``trainingResult.get(classifier.training.model)``.
    """
    # NOTE(review): the "def" line is missing from the listing; the name and
    # the parameter ``m`` are taken from the call site and from the body.
    visitor = PrintNodeVisitor()

    # Query the tree count once and reuse it for the header and the loop
    treeCount = m.getNumberOfTrees()
    print("Number of trees: {}".format(treeCount))
    for treeIndex in range(treeCount):
        print("Tree #{}".format(treeIndex))
        m.traverseDF(treeIndex, visitor)
if __name__ == "__main__":
    # Train the model, then print every tree of the resulting forest
    trainingResult = trainModel()
    printModel(trainingResult.get(classifier.training.model))