30 from __future__
import print_function
32 from daal
import algorithms
33 from daal.algorithms
import decision_forest
34 import daal.algorithms.decision_forest.regression
35 import daal.algorithms.decision_forest.regression.training
37 from daal.data_management
import (
38 FileDataSource, DataSourceIface, NumericTableIface, HomogenNumericTable, MergedNumericTable, features
# Input data set parameters.
# Path of the training data set handed to loadData().
trainDatasetFileName = "../data/batch/df_regression_train.csv"
# Column indices whose dictionary entries loadData() marks as categorical.
categoricalFeaturesIndices = [3]
53 trainData, trainDependentVariable = loadData(trainDatasetFileName)
56 algorithm = decision_forest.regression.training.Batch()
59 algorithm.input.set(decision_forest.regression.training.data, trainData)
60 algorithm.input.set(decision_forest.regression.training.dependentVariable, trainDependentVariable)
62 algorithm.parameter.nTrees = nTrees
65 return algorithm.compute()
def loadData(fileName):
    """Read the data set in ``fileName`` into two numeric tables.

    The feature table is created with ``nFeatures`` columns and the
    dependent-variable table with one column; ``nFeatures`` and
    ``categoricalFeaturesIndices`` are module-level settings.

    Args:
        fileName: Path passed to ``FileDataSource``.

    Returns:
        A ``(data, dependentVar)`` tuple of ``HomogenNumericTable`` objects.
    """
    trainDataSource = FileDataSource(
        fileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext,
    )

    # Both tables are created with 0 rows and the notAllocate flag; they are
    # populated by the load call below.
    data = HomogenNumericTable(nFeatures, 0, NumericTableIface.notAllocate)
    dependentVar = HomogenNumericTable(1, 0, NumericTableIface.notAllocate)

    # Loading through the merged table fills `data` and `dependentVar`
    # with a single loadDataBlock() call.
    mergedData = MergedNumericTable(data, dependentVar)
    trainDataSource.loadDataBlock(mergedData)

    # Flag the configured feature columns as categorical in the dictionary
    # of the feature table.
    dictionary = data.getDictionary()
    for featureIndex in categoricalFeaturesIndices:
        dictionary[featureIndex].featureType = features.DAAL_CATEGORICAL

    return data, dependentVar
91 class PrintNodeVisitor(algorithms.regression.TreeNodeVisitor):
94 super(PrintNodeVisitor, self).__init__()
96 def onLeafNode(self, level, response):
98 for i
in range(level):
100 print(
"Level {}, leaf node. Response value = {:.4g}".format(level, response))
104 def onSplitNode(self, level, featureIndex, featureValue):
106 for i
in range(level):
108 print(
"Level {}, split node. Feature index = {}, feature value = {:.4g}".format(level, featureIndex, featureValue))
113 visitor = PrintNodeVisitor()
114 print(
"Number of trees: {}".format(m.getNumberOfTrees()))
115 for i
in range(m.getNumberOfTrees()):
116 print(
"Tree #{}".format(i))
117 m.traverseDF(i, visitor)
if __name__ == "__main__":
    # Train first, then pull the model out of the training result and
    # print its trees.
    trainedModel = trainModel().get(decision_forest.regression.training.model)
    printModel(trainedModel)