web123456

Spark MLlib model training — classification algorithms: Multilayer Perceptron Classifier

import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{VectorAssembler, StringIndexer}
import org.apache.spark.sql.SparkSession

// End-to-end example: train a multilayer perceptron classifier on a tiny
// in-memory dataset, predict on a small labelled test set, and report accuracy.

// Create the SparkSession (local mode, using all available cores).
val spark = SparkSession.builder()
  .appName("MLPClassifierExample")
  .master("local[*]")
  .getOrCreate()

// Prepare the training dataset: one label column followed by four numeric features.
val data = spark.createDataFrame(Seq(
  (0.0, 0.0, 0.0, 0.0, 0.0),
  (1.0, 1.0, 1.0, 1.0, 1.0),
  (1.0, 0.0, 1.0, 0.0, 0.0),
  (0.0, 1.0, 0.0, 1.0, 1.0),
  (0.0, 1.0, 1.0, 0.0, 0.0)
)).toDF("label", "feature1", "feature2", "feature3", "feature4")

// Assemble the four feature columns into a single vector column named "features".
val assembler = new VectorAssembler()
  .setInputCols(Array("feature1", "feature2", "feature3", "feature4"))
  .setOutputCol("features")

val trainingData = assembler.transform(data).select("label", "features")

// Define the network structure: 4 nodes in the input layer (one per feature),
// 5 nodes in the first hidden layer, 4 nodes in the second hidden layer,
// and 2 nodes in the output layer (one per class label, 0.0 and 1.0).
val layers = Array[Int](4, 5, 4, 2)

// Configure the MLP classifier.
val trainer = new MultilayerPerceptronClassifier()
  .setLayers(layers)
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setMaxIter(100) // Maximum number of training iterations

// Train the model.
val model = trainer.fit(trainingData)

// Prepare the test data using the same column layout as the training data.
val testData = spark.createDataFrame(Seq(
  (0.0, 1.0, 1.0, 0.0, 0.0),
  (1.0, 0.0, 0.0, 1.0, 1.0)
)).toDF("label", "feature1", "feature2", "feature3", "feature4")

// Keep "label" alongside "features": the evaluator below reads the "label"
// column from the prediction output, so dropping it here would make
// evaluate() fail on a missing column.
val testFeatures = assembler.transform(testData).select("label", "features")

// Make predictions; transform() appends the model's output columns to the input.
val predictions = model.transform(testFeatures)
predictions.show()

// Evaluate the model by comparing "prediction" against "label".
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")

val accuracy = evaluator.evaluate(predictions)
println(s"Test set accuracy = $accuracy")

// Close the SparkSession.
spark.stop()