32 from daal.algorithms
import kernel_function
33 from daal.algorithms.classifier.quality_metric
import binary_confusion_matrix
34 from daal.algorithms
import svm
35 from daal.algorithms
import classifier
36 from daal.data_management
import (
37 DataSourceIface, FileDataSource, readOnly, BlockDescriptor,
38 HomogenNumericTable, NumericTableIface, MergedNumericTable
# Put the directory two levels above this file at the front of sys.path
# (only if it is not already listed), ahead of the `from utils import ...`
# statement that follows.
utils_folder = os.path.realpath(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
if utils_folder not in sys.path:
    sys.path.insert(0, utils_folder)
44 from utils
import printNumericTables, printNumericTable
# Locations of the input data sets, expressed relative to the current
# working directory.
DATA_PREFIX = os.path.join('..', 'data', 'batch')
trainDatasetFileName = os.path.join(DATA_PREFIX, 'svm_two_class_train_dense.csv')
testDatasetFileName = os.path.join(DATA_PREFIX, 'svm_two_class_test_dense.csv')
# Linear kernel object; the same instance is assigned to the `kernel`
# parameter of both the training and the prediction algorithm below.
kernel = kernel_function.linear.Batch()
# Module-level slots shared between the pipeline steps. Each one stays None
# until the step that assigns it has run.
predictionResult = None
qualityMetricSetResult = None

predictedLabels = None
groundTruthLabels = None
# NOTE(review): no enclosing `def` line is visible in this excerpt; these
# statements read like the body of the training step -- confirm against the
# full file and re-indent under its header if so.

# NOTE(review): this source is opened with `notAllocateNumericTable` while the
# test source uses `doAllocateNumericTable`. Presumably irrelevant because an
# explicit table is handed to loadDataBlock() -- confirm.
trainDataSource = FileDataSource(
    trainDatasetFileName,
    DataSourceIface.notAllocateNumericTable,
    DataSourceIface.doDictionaryFromContext,
)

# The merged table is what gets loaded; its two component tables are then
# passed to the algorithm separately as data and labels.
trainData = HomogenNumericTable(nFeatures, 0, NumericTableIface.doNotAllocate)
trainGroundTruth = HomogenNumericTable(1, 0, NumericTableIface.doNotAllocate)
mergedData = MergedNumericTable(trainData, trainGroundTruth)

trainDataSource.loadDataBlock(mergedData)

# Configure the SVM training algorithm with the shared kernel.
algorithm = svm.training.Batch()
algorithm.parameter.kernel = kernel
algorithm.parameter.cacheSize = 600000000

algorithm.input.set(classifier.training.data, trainData)
algorithm.input.set(classifier.training.labels, trainGroundTruth)

# The result is kept so the prediction step can fetch the model from it.
trainingResult = algorithm.compute()
# NOTE(review): no enclosing `def` line is visible in this excerpt; the
# `global` statement indicates these statements belong inside a function --
# confirm against the full file and re-indent under its header.
global predictionResult, groundTruthLabels

testDataSource = FileDataSource(
    testDatasetFileName,
    DataSourceIface.doAllocateNumericTable,
    DataSourceIface.doDictionaryFromContext,
)

# The merged table is what gets loaded; `testData` feeds the prediction and
# `groundTruthLabels` is kept at module level for the quality-metric step.
testData = HomogenNumericTable(nFeatures, 0, NumericTableIface.doNotAllocate)
groundTruthLabels = HomogenNumericTable(1, 0, NumericTableIface.doNotAllocate)
mergedData = MergedNumericTable(testData, groundTruthLabels)

testDataSource.loadDataBlock(mergedData)

# Configure the SVM prediction algorithm with the same kernel used in training.
algorithm = svm.prediction.Batch()
algorithm.parameter.kernel = kernel

algorithm.input.setTable(classifier.prediction.data, testData)
algorithm.input.setModel(
    classifier.prediction.model,
    trainingResult.get(classifier.training.model),
)

predictionResult = algorithm.compute()
def testModelQuality():
    """Compute the confusion-matrix quality metrics for the SVM predictions.

    Reads the module-level ``predictionResult`` and ``groundTruthLabels`` and
    stores its outputs in the module-level ``predictedLabels`` and
    ``qualityMetricSetResult``. Takes no arguments and returns nothing.
    """
    # groundTruthLabels is only read here, so it needs no `global` declaration.
    global predictedLabels, qualityMetricSetResult

    predictedLabels = predictionResult.get(classifier.prediction.prediction)

    qualityMetricSet = svm.quality_metric_set.Batch()

    # Named so that it does not shadow the built-in input().
    confusionMatrixInput = qualityMetricSet.getInputDataCollection().getInput(
        svm.quality_metric_set.confusionMatrix
    )
    confusionMatrixInput.set(binary_confusion_matrix.predictedLabels, predictedLabels)
    confusionMatrixInput.set(binary_confusion_matrix.groundTruthLabels, groundTruthLabels)

    qualityMetricSetResult = qualityMetricSet.compute()
149 groundTruthLabels, predictedLabels,
150 "Ground truth",
"Classification results",
151 "SVM classification results (first 20 observations):", 20, interval=15, flt64=
False
# NOTE(review): no enclosing `def` line is visible in this excerpt; these
# statements read like the tail of the result-printing step -- confirm against
# the full file and re-indent under its header if so.
qualityMetricResult = qualityMetricSetResult.getResult(svm.quality_metric_set.confusionMatrix)
printNumericTable(
    qualityMetricResult.get(binary_confusion_matrix.confusionMatrix),
    "Confusion matrix:",
)

qualityMetricsTable = qualityMetricResult.get(binary_confusion_matrix.binaryMetrics)
block = BlockDescriptor()
qualityMetricsTable.getBlockOfRows(0, 1, readOnly, block)
try:
    qualityMetricsData = block.getArray().flatten()
    # One line per metric, in the original order; each index selects the
    # metric's position in the flattened row.
    for template, index in (
        ("Accuracy: {0:.3f}", binary_confusion_matrix.accuracy),
        ("Precision: {0:.3f}", binary_confusion_matrix.precision),
        ("Recall: {0:.3f}", binary_confusion_matrix.recall),
        ("F-score: {0:.3f}", binary_confusion_matrix.fscore),
        ("Specificity: {0:.3f}", binary_confusion_matrix.specificity),
        ("AUC: {0:.3f}", binary_confusion_matrix.AUC),
    ):
        print(template.format(qualityMetricsData[index]))
finally:
    # Release the acquired rows even if formatting or printing raises.
    qualityMetricsTable.releaseBlockOfRows(block)
170 if __name__ ==
"__main__":