# NOTE(review): os, sys and numpy (as np) are used below but not imported in
# this section; presumably they are imported earlier in the file - confirm.
from daal.algorithms.classifier.quality_metric import multiclass_confusion_matrix
from daal.algorithms import svm
from daal.algorithms import kernel_function
from daal.algorithms import multi_class_classifier
from daal.algorithms import classifier
from daal.data_management import (
    DataSourceIface, FileDataSource, readOnly, BlockDescriptor, HomogenNumericTable,
    NumericTableIface, MergedNumericTable
)

# Prepend the parent of this script's directory to sys.path (only once) so the
# shared `utils` printing helpers below can be imported.
utils_folder = os.path.realpath(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
if utils_folder not in sys.path:
    sys.path.insert(0, utils_folder)
from utils import printNumericTables, printNumericTable
# Input data set parameters. DATA_PREFIX is a relative path, so the CSV files
# are resolved against the current working directory when they are opened.
DATA_PREFIX = os.path.join('..', 'data', 'batch')
trainDatasetFileName = os.path.join(DATA_PREFIX, 'svm_multi_class_train_dense.csv')
testDatasetFileName = os.path.join(DATA_PREFIX, 'svm_multi_class_test_dense.csv')

# Number of feature columns per CSV row (the label is the extra last column)
# and number of target classes.
# NOTE(review): these names are used by the training/testing/quality code but
# were not defined anywhere in this section; 20 and 5 presumably match the
# stock DAAL svm_multi_class_*_dense.csv files - confirm against the data.
nFeatures = 20
nClasses = 5
# SVM training/prediction algorithms that the multi-class classifier uses
# internally (assigned to its parameter.training / parameter.prediction in the
# training and testing steps; their kernel is set in the __main__ block).
training = svm.training.Batch(fptype=np.float64)
prediction = svm.prediction.Batch(fptype=np.float64)

# Results shared between the pipeline steps; each step rebinds the ones it
# produces through a `global` declaration.
trainingResult = None
predictionResult = None

# Kernel function for the SVM algorithms
kernel = kernel_function.linear.Batch(fptype=np.float64)

qualityMetricSetResult = None
predictedLabels = None
groundTruthLabels = None
def trainModel():
    """Train the multi-class SVM model on the training CSV file.

    Stores the result of the training algorithm in the module-level
    ``trainingResult``, from which the testing step reads the model.
    """
    # NOTE(review): the function header was missing from the file; the name
    # follows the DAAL example convention and must match the call in __main__.
    global trainingResult

    # The data source does not allocate its own table: data is loaded straight
    # into the merged table built below.
    trainDataSource = FileDataSource(
        trainDatasetFileName, DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext
    )

    # Feature table (nFeatures columns) and label table (1 column), created with
    # 0 rows and no allocation; presumably sized by loadDataBlock - confirm.
    trainData = HomogenNumericTable(nFeatures, 0, NumericTableIface.doNotAllocate)
    trainGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.doNotAllocate)
    mergedData = MergedNumericTable(trainData, trainGroundTruth)

    # Read the CSV into features + labels in one pass
    trainDataSource.loadDataBlock(mergedData)

    # Multi-class classifier built on top of the binary SVM algorithms
    algorithm = multi_class_classifier.training.Batch(nClasses, fptype=np.float64)
    algorithm.parameter.training = training
    algorithm.parameter.prediction = prediction

    algorithm.input.set(classifier.training.data, trainData)
    algorithm.input.set(classifier.training.labels, trainGroundTruth)

    trainingResult = algorithm.compute()
def testModel():
    """Predict labels for the test CSV file with the trained model.

    Reads the model from the module-level ``trainingResult`` and stores the
    prediction result in ``predictionResult`` and the true labels in
    ``groundTruthLabels``.
    """
    # NOTE(review): the function header was missing from the file; the name
    # follows the DAAL example convention and must match the call in __main__.
    global predictionResult, groundTruthLabels

    # NOTE(review): training uses notAllocateNumericTable, while this uses
    # doAllocateNumericTable even though data is also loaded into mergedData
    # below; the table the source allocates may go unused - confirm and align.
    testDataSource = FileDataSource(
        testDatasetFileName, DataSourceIface.doAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext
    )

    # Feature table and label table, merged so one load fills both
    testData = HomogenNumericTable(nFeatures, 0, NumericTableIface.doNotAllocate)
    groundTruthLabels = HomogenNumericTable(1, 0, NumericTableIface.doNotAllocate)
    mergedData = MergedNumericTable(testData, groundTruthLabels)

    testDataSource.loadDataBlock(mergedData)

    algorithm = multi_class_classifier.prediction.Batch(nClasses, fptype=np.float64)
    algorithm.parameter.training = training
    algorithm.parameter.prediction = prediction

    algorithm.input.setTable(classifier.prediction.data, testData)
    algorithm.input.setModel(classifier.prediction.model, trainingResult.get(classifier.training.model))

    predictionResult = algorithm.compute()
def testModelQuality():
    """Compute the confusion-matrix quality metrics for the predictions.

    Extracts the predicted labels from ``predictionResult``, compares them with
    ``groundTruthLabels`` and stores ``predictedLabels`` and
    ``qualityMetricSetResult`` as module-level globals.
    """
    global predictedLabels, qualityMetricSetResult

    predictedLabels = predictionResult.get(classifier.prediction.prediction)

    qualityMetricSet = multi_class_classifier.quality_metric_set.Batch(nClasses)
    # Named explicitly instead of `input` to avoid shadowing the builtin
    confusionMatrixInput = qualityMetricSet.getInputDataCollection().getInput(
        multi_class_classifier.quality_metric_set.confusionMatrix
    )
    confusionMatrixInput.set(multiclass_confusion_matrix.predictedLabels, predictedLabels)
    confusionMatrixInput.set(multiclass_confusion_matrix.groundTruthLabels, groundTruthLabels)

    qualityMetricSetResult = qualityMetricSet.compute()
def printResults():
    """Print the classification results and the confusion-matrix quality metrics.

    Reads the module-level ``groundTruthLabels``, ``predictedLabels`` and
    ``qualityMetricSetResult`` produced by the testing and quality steps.
    """
    # NOTE(review): the function header and the opening of this call were
    # missing from the file; printNumericTables is the only imported helper
    # these arguments fit. The name must match the call in __main__.
    printNumericTables(
        groundTruthLabels, predictedLabels,
        "Ground truth", "Classification results",
        "SVM classification results (first 20 observations):", 20, interval=15, flt64=False
    )

    qualityMetricResult = qualityMetricSetResult.getResult(
        multi_class_classifier.quality_metric_set.confusionMatrix
    )
    printNumericTable(qualityMetricResult.get(multiclass_confusion_matrix.confusionMatrix),
                      "Confusion matrix:")

    # Copy the first row of the multi-class metrics table. flatten() returns a
    # copy, so the block can be released before printing; `finally` guarantees
    # the release even if reading the block fails.
    block = BlockDescriptor()
    qualityMetricsTable = qualityMetricResult.get(multiclass_confusion_matrix.multiClassMetrics)
    qualityMetricsTable.getBlockOfRows(0, 1, readOnly, block)
    try:
        qualityMetricsData = block.getArray().flatten()
    finally:
        qualityMetricsTable.releaseBlockOfRows(block)

    # (label, column index into the metrics row), printed in this order
    metricsToPrint = (
        ("Average accuracy", multiclass_confusion_matrix.averageAccuracy),
        ("Error rate", multiclass_confusion_matrix.errorRate),
        ("Micro precision", multiclass_confusion_matrix.microPrecision),
        ("Micro recall", multiclass_confusion_matrix.microRecall),
        ("Micro F-score", multiclass_confusion_matrix.microFscore),
        ("Macro precision", multiclass_confusion_matrix.macroPrecision),
        ("Macro recall", multiclass_confusion_matrix.macroRecall),
        ("Macro F-score", multiclass_confusion_matrix.macroFscore),
    )
    for label, metricIndex in metricsToPrint:
        print("{0}: {1:.3f}".format(label, qualityMetricsData[metricIndex]))
176 if __name__ ==
"__main__":
177 training.parameter.cacheSize = 100000000
178 training.parameter.kernel = kernel
179 prediction.parameter.kernel = kernel