Python* API Reference for Intel® Data Analytics Acceleration Library 2020 Update 1

svm_multi_class_metrics_dense_batch.py

1 # file: svm_multi_class_metrics_dense_batch.py
2 #===============================================================================
3 # Copyright 2014-2020 Intel Corporation
4 #
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
8 #
9 # http://www.apache.org/licenses/LICENSE-2.0
10 #
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
16 #===============================================================================
17 
18 #
19 # ! Content:
20 # ! Python example of multi-class support vector machine (SVM) quality metrics
21 # !
22 # !*****************************************************************************
23 
24 #
25 
26 
27 #
28 
29 import os
30 import sys
31 import numpy as np
32 
33 from daal.algorithms.classifier.quality_metric import multiclass_confusion_matrix
34 from daal.algorithms import svm
35 from daal.algorithms import kernel_function
36 from daal.algorithms import multi_class_classifier
37 from daal.algorithms import classifier
38 from daal.data_management import (
39  DataSourceIface, FileDataSource, readOnly, BlockDescriptor, HomogenNumericTable,
40  NumericTableIface, MergedNumericTable
41 )
42 
43 utils_folder = os.path.realpath(os.path.abspath(os.path.dirname(os.path.dirname(__file__))))
44 if utils_folder not in sys.path:
45  sys.path.insert(0, utils_folder)
46 from utils import printNumericTables, printNumericTable
47 
# Location of the input CSV files. The prefix is a relative path, so it is
# resolved against the current working directory rather than this script.
DATA_PREFIX = os.path.join('..', 'data', 'batch')
trainDatasetFileName = os.path.join(DATA_PREFIX, 'svm_multi_class_train_dense.csv')
testDatasetFileName = os.path.join(DATA_PREFIX, 'svm_multi_class_test_dense.csv')

# Number of feature columns per observation and number of target classes.
nFeatures, nClasses = 20, 5

# Double-precision SVM training/prediction algorithms that the multi-class
# classifier is configured to use, and the linear kernel assigned to both
# of them in the __main__ section.
training = svm.training.Batch(fptype=np.float64)
prediction = svm.prediction.Batch(fptype=np.float64)
kernel = kernel_function.linear.Batch(fptype=np.float64)

# Placeholders populated by trainModel(), testModel() and testModelQuality().
trainingResult = predictionResult = None
qualityMetricSetResult = predictedLabels = groundTruthLabels = None
69 
70 
def trainModel():
    """Train the multi-class SVM model from the training CSV file.

    Loads ``trainDatasetFileName`` into an ``nFeatures``-column feature table
    and a one-column label table. It then configures a
    ``multi_class_classifier`` training algorithm to use the module-level SVM
    ``training`` and ``prediction`` algorithms. The result of ``compute()`` is
    stored in the global ``trainingResult``.
    """
    global trainingResult

    # The destination tables are created explicitly below, so the data
    # source is told not to allocate a numeric table of its own.
    source = FileDataSource(
        trainDatasetFileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext,
    )

    # Empty (not yet allocated) tables for the features and the labels,
    # merged so that a single loadDataBlock call fills both of them.
    features = HomogenNumericTable(nFeatures, 0, NumericTableIface.doNotAllocate)
    labels = HomogenNumericTable(1, 0, NumericTableIface.doNotAllocate)
    merged = MergedNumericTable(features, labels)
    source.loadDataBlock(merged)

    # Multi-class wrapper that delegates to the module-level SVM algorithms.
    trainer = multi_class_classifier.training.Batch(nClasses, fptype=np.float64)
    trainer.parameter.training = training
    trainer.parameter.prediction = prediction

    # Feed the training observations and their class labels.
    trainer.input.set(classifier.training.data, features)
    trainer.input.set(classifier.training.labels, labels)

    trainingResult = trainer.compute()
96  algorithm.input.set(classifier.training.labels, trainGroundTruth)
97 
98  # Build the multi-class SVM model and get the algorithm results
99  trainingResult = algorithm.compute()
100 
101 
def testModel():
    """Run multi-class SVM prediction on the test CSV file.

    Loads ``testDatasetFileName`` into an ``nFeatures``-column data table and
    a one-column ground-truth table. The ground-truth table is kept in the
    global ``groundTruthLabels`` for the quality-metric step.

    Prediction uses the model from the global ``trainingResult``, so
    ``trainModel`` must have been called first. The prediction result is
    stored in the global ``predictionResult``.
    """
    global predictionResult, groundTruthLabels

    # Initialize FileDataSource to retrieve the test data from a .csv file.
    # The destination tables are created explicitly and passed to
    # loadDataBlock, so (exactly as in trainModel) the data source must not
    # allocate an extra numeric table of its own that would never be used.
    testDataSource = FileDataSource(
        testDatasetFileName, DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext
    )

    # Create numeric tables for the test data and the ground-truth labels;
    # the merged view lets one loadDataBlock call fill both.
    testData = HomogenNumericTable(nFeatures, 0, NumericTableIface.doNotAllocate)
    groundTruthLabels = HomogenNumericTable(1, 0, NumericTableIface.doNotAllocate)
    mergedData = MergedNumericTable(testData, groundTruthLabels)

    # Retrieve the data from the input file
    testDataSource.loadDataBlock(mergedData)

    # Create an algorithm object to predict multi-class SVM values,
    # configured with the same SVM algorithms used for training.
    algorithm = multi_class_classifier.prediction.Batch(nClasses, fptype=np.float64)

    algorithm.parameter.training = training
    algorithm.parameter.prediction = prediction

    # Pass the test data set and the trained model to the algorithm
    algorithm.input.setTable(classifier.prediction.data, testData)
    algorithm.input.setModel(classifier.prediction.model, trainingResult.get(classifier.training.model))

    # Predict multi-class SVM values and keep the classifier prediction Result
    predictionResult = algorithm.compute()
131 
132 
def testModelQuality():
    """Compute confusion-matrix quality metrics for the multi-class SVM.

    Takes the predicted labels from the global ``predictionResult`` and keeps
    them in the global ``predictedLabels``. It pairs them with the global
    ``groundTruthLabels`` loaded by ``testModel`` and runs the multi-class
    classifier quality metric set. The resulting collection is stored in the
    global ``qualityMetricSetResult``.
    """
    global predictedLabels, qualityMetricSetResult

    # Retrieve predicted labels
    predictedLabels = predictionResult.get(classifier.prediction.prediction)

    # Create a quality metric set object to compute quality metrics of the
    # multi-class classifier algorithm
    qualityMetricSet = multi_class_classifier.quality_metric_set.Batch(nClasses)

    # Input of the confusion-matrix metric (named so it does not shadow the
    # built-in ``input``).
    confusionMatrixInput = qualityMetricSet.getInputDataCollection().getInput(
        multi_class_classifier.quality_metric_set.confusionMatrix
    )
    confusionMatrixInput.set(multiclass_confusion_matrix.predictedLabels, predictedLabels)
    confusionMatrixInput.set(multiclass_confusion_matrix.groundTruthLabels, groundTruthLabels)

    # Compute the quality metrics; returns the ResultCollection from
    # daal.algorithms.multi_class_classifier.quality_metric_set
    qualityMetricSetResult = qualityMetricSet.compute()
149 
def printResults():
    """Print classification results, the confusion matrix and summary metrics.

    Reads the globals populated by ``testModel`` and ``testModelQuality``.
    The first row of the multi-class metrics table is read through a
    ``BlockDescriptor``, and the block is always released, even if printing
    fails.
    """
    # Print the classification results
    printNumericTables(
        groundTruthLabels, predictedLabels,
        "Ground truth", "Classification results",
        "SVM classification results (first 20 observations):", 20, interval=15, flt64=False
    )

    # Print the quality metrics
    qualityMetricResult = qualityMetricSetResult.getResult(multi_class_classifier.quality_metric_set.confusionMatrix)
    printNumericTable(qualityMetricResult.get(multiclass_confusion_matrix.confusionMatrix), "Confusion matrix:")

    # (label, column index) pairs into the metrics row, in print order.
    metrics = (
        ("Average accuracy", multiclass_confusion_matrix.averageAccuracy),
        ("Error rate", multiclass_confusion_matrix.errorRate),
        ("Micro precision", multiclass_confusion_matrix.microPrecision),
        ("Micro recall", multiclass_confusion_matrix.microRecall),
        ("Micro F-score", multiclass_confusion_matrix.microFscore),
        ("Macro precision", multiclass_confusion_matrix.macroPrecision),
        ("Macro recall", multiclass_confusion_matrix.macroRecall),
        ("Macro F-score", multiclass_confusion_matrix.macroFscore),
    )

    block = BlockDescriptor()
    qualityMetricsTable = qualityMetricResult.get(multiclass_confusion_matrix.multiClassMetrics)
    qualityMetricsTable.getBlockOfRows(0, 1, readOnly, block)
    try:
        qualityMetricsData = block.getArray().flatten()
        for label, index in metrics:
            print("{0}: {1:.3f}".format(label, qualityMetricsData[index]))
    finally:
        # Hand the block back to the table even if formatting/printing raises.
        qualityMetricsTable.releaseBlockOfRows(block)
175 
if __name__ == "__main__":
    # Configure the SVM algorithms: set the training cache size, then give
    # the training and the prediction algorithm the same kernel object.
    training.parameter.cacheSize = 100000000
    for svmAlgorithm in (training, prediction):
        svmAlgorithm.parameter.kernel = kernel

    # Run the example end to end: train, predict, evaluate, report.
    for step in (trainModel, testModel, testModelQuality, printResults):
        step()

For more complete information about compiler optimizations, see our Optimization Notice.