from zoo.common.nncontext import init_nncontext
## Initialize a Spark context via Analytics Zoo's helper.
## 'init_nncontext' sets up the Spark configuration required by
## BigDL/Analytics Zoo; the string argument becomes the Spark app name.
sc = init_nncontext("Regression Example")
print('Spark UI running on http://localhost:' + sc.uiWebUrl.split(':')[2])
sc
Here we will walk through a simple regression example using Analytics Zoo.
# Load data
# Build a tiny 4-row toy dataset where each row has a 2-element
# 'features' array and a 2-element 'label' array (a simple
# swap mapping: (2,1)->(1,2) and (1,2)->(2,1)).
from pyspark.sql.types import StructField, StructType, ArrayType, DoubleType
data = sc.parallelize([
((2.0, 1.0), (1.0, 2.0)),
((1.0, 2.0), (2.0, 1.0)),
((2.0, 1.0), (1.0, 2.0)),
((1.0, 2.0), (2.0, 1.0))])
# Both columns are non-nullable arrays of doubles.
schema = StructType([
StructField("features", ArrayType(DoubleType(), False), False),
StructField("label", ArrayType(DoubleType(), False), False)])
# NOTE(review): 'sqlContext' is not defined in this file — presumably it is
# provided by the notebook runtime; verify before running as a plain script.
df = sqlContext.createDataFrame(data, schema)
df.show()
from bigdl.nn.layer import Sequential, Linear
from bigdl.nn.criterion import MSECriterion
from zoo.pipeline.nnframes.nn_classifier import NNEstimator
from zoo.pipeline.nnframes.nn_classifier import SeqToTensor, ArrayToTensor
# A single linear layer mapping 2 inputs -> 2 outputs, trained with
# mean-squared-error loss (a plain linear regression in NN form).
model = Sequential().add(Linear(2, 2))
criterion = MSECriterion()
## Configure training: batch size 4 (the whole toy dataset per batch),
## learning rate 0.2, and 20 epochs. SeqToTensor/ArrayToTensor convert
## the DataFrame's array columns into 2-element tensors.
estimator = NNEstimator(model, criterion, SeqToTensor([2]), ArrayToTensor([2]))\
    .setBatchSize(4).setLearningRate(0.2).setMaxEpoch(20)
%%time
# training
print("training starting...")
## TODO : start training by calling 'fit' method on 'df'
nnModel = estimator.???(???)
print("training done")
## predict
## TODO : create predictions by running 'transform' on 'df'
results = nnModel.???(???)
results.show()
Analytics Zoo provides tools to help us load images into DataFrames.
! ./06-1-prep.sh
from pyspark.sql.functions import col, udf
from zoo.pipeline.nnframes import NNImageReader
from pyspark.sql.types import DoubleType, StringType
import re
image_path = "../data/cat-dog/sample/train/*.jpg"
## TODO : Use 'NNImageReader.readImages' function
## - first argument is 'image_path'
## - second argument is SparkContext 'sc'
imageDF = NNImageReader.readImages(???, ???)
getName = udf(lambda row:
re.search(r'(cat|dog)\.([\d]*)\.jpg', row[0], re.IGNORECASE).group(0),
StringType())
## TODO : return 1.0 for cat 2.0 for dog
getLabel = udf(lambda name: ??? if name.startswith('cat') else ???, DoubleType())
labelDF = imageDF.withColumn("name", getName(col("image"))) \
.withColumn("label", getLabel(col('name')))
## TODO : sample 10% (0.1) and display
labelDF.sample(False, ???).select("name","label").show(10)
## TODO : split 80% training and 20% validation
(trainingDF, validationDF) = labelDF.randomSplit([???, ???])
## TODO : print training / validation stats
print("training set count ", trainingDF.??? ())
print("validation set count ", validationDF.??? ())