34 import daal.algorithms.optimization_solver
as optimization_solver
35 import daal.algorithms.optimization_solver.logistic_loss
36 import daal.algorithms.optimization_solver.sgd
37 import daal.algorithms.optimization_solver.iterative_solver
39 from daal.data_management
import (
40 DataSourceIface, FileDataSource, HomogenNumericTable, MergedNumericTable, NumericTableIface
# Put the parent of this script's directory on the import path so the shared
# ``utils`` module can be imported from there.
utils_folder = os.path.realpath(
    os.path.abspath(
        os.path.dirname(os.path.dirname(__file__))
    )
)
if utils_folder not in sys.path:
    # Prepend so this folder is searched before the other entries.
    sys.path.insert(0, utils_folder)
from utils import printNumericTable
# Relative path to the input CSV file.
datasetFileName = os.path.join('..', 'data', 'batch', 'custom.csv')

# Solver settings read by the script body.
# NOTE(review): nIterations, nFeatures and learningRate are used by the script
# body but were not assigned in this part of the file, which makes the script
# fail with NameError. The values below follow the reference DAAL
# SGD / logistic-loss sample -- confirm them before relying on the results.
nIterations = 1000
nFeatures = 4
learningRate = 0.01
accuracyThreshold = 0.02

# Starting point handed to the solver: a 5x1 column of ones
# (presumably one entry per feature plus an intercept term -- TODO confirm).
initialPoint = np.array([[1], [1], [1], [1], [1]], dtype=np.float64)
if __name__ == "__main__":
    # Short handles for the solver sub-modules used repeatedly below.
    logisticLoss = optimization_solver.logistic_loss
    iterativeSolver = optimization_solver.iterative_solver

    # Data source over the input CSV file.
    dataSource = FileDataSource(
        datasetFileName,
        DataSourceIface.notAllocateNumericTable,
        DataSourceIface.doDictionaryFromContext,
    )

    # Unallocated tables for the features and the dependent variable, merged
    # so that a single load fills both of them.
    data = HomogenNumericTable(nFeatures, 0, NumericTableIface.doNotAllocate)
    dependentVariables = HomogenNumericTable(1, 0, NumericTableIface.doNotAllocate)
    mergedData = MergedNumericTable(data, dependentVariables)

    # Read the file contents into the merged table.
    dataSource.loadDataBlock(mergedData)

    nVectors = data.getNumberOfRows()

    # Logistic-loss objective built for the number of loaded rows.
    logistic_lossObjectiveFunction = logisticLoss.Batch(nVectors)
    logistic_lossObjectiveFunction.input.set(logisticLoss.data, data)
    logistic_lossObjectiveFunction.input.set(
        logisticLoss.dependentVariables, dependentVariables
    )

    # SGD solver that minimises the objective above.
    sgdAlgorithm = optimization_solver.sgd.Batch(logistic_lossObjectiveFunction)

    # Starting point and solver parameters.
    startingPointTable = HomogenNumericTable(initialPoint)
    sgdAlgorithm.input.setInput(iterativeSolver.inputArgument, startingPointTable)
    # The learning rate is passed as a 1x1 table filled with ``learningRate``.
    learningRateTable = HomogenNumericTable(
        1, 1, NumericTableIface.doAllocate, learningRate
    )
    sgdAlgorithm.parameter.learningRateSequence = learningRateTable
    sgdAlgorithm.parameter.nIterations = nIterations
    sgdAlgorithm.parameter.accuracyThreshold = accuracyThreshold

    # Run the solver.
    res = sgdAlgorithm.compute()

    # Report the minimum found and the number of iterations performed.
    minimumTable = res.getResult(iterativeSolver.minimum)
    printNumericTable(minimumTable, "Minimum:")
    iterationsTable = res.getResult(iterativeSolver.nIterations)
    printNumericTable(iterationsTable, "Number of iterations performed:")