web123456

Decision tree algorithm hyperparameter tuning

import org.apache.spark.ml.classification.{DecisionTreeClassifier, DecisionTreeClassificationModel}
import org.apache.spark.ml.evaluation.{MulticlassClassificationEvaluator, BinaryClassificationEvaluator}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Create (or reuse) the SparkSession that every later step reads from and runs on.
// "spark.master" = "local" runs Spark in-process; the original config key was lost,
// and "spark.master" is the only standard key for which the value "local" makes sense.
val spark = SparkSession.builder()
  .appName("BilibiliAnalysis")
  .config("spark.master", "local")
  .getOrCreate()

// Read the CSV file (first row is the header) and select the desired columns.
// NOTE(review): this path names a directory, so Spark will load every file under it;
// point it at the specific CSV file if the directory contains anything else.
val filePath = "file:///usr/local/hadoop/"
val df = spark.read.option("header", "true").csv(filePath)

// Columns fed to the model as features.
val featureCols = Array(
  "Views", "Danmaku_Count", "Comment_Count", "Favorite_Count",
  "Coin_Count", "Share_Count", "Like_Count"
)

// Column the binary label is derived from.
val rankingCol = "Partition_Ranking"

// Cast every field we need to integer and replace null values with 0.
// The explicit `.as(name)` keeps each column's original name after the cast.
// NOTE(review): filling a missing Partition_Ranking with 0 makes that row satisfy
// `<= 10` below, so unranked rows are labelled as top-10 — confirm this is intended.
val convertedDF = df
  .select((featureCols :+ rankingCol).map(name => col(name).cast("int").as(name)): _*)
  .na.fill(0)

// Create the label column: 1 when the video ranks in the top 10 of its partition, else 0.
val labeledDF = convertedDF.withColumn(
  "label",
  when(col(rankingCol) <= 10, 1).otherwise(0)
)

// Use VectorAssembler to combine the feature columns into a single feature vector.
val assembler = new VectorAssembler()
  .setInputCols(featureCols)
  .setOutputCol("features")

val assembledDF = assembler.transform(labeledDF)

// Split into training (80%) and test (20%) sets; the fixed seed makes the split repeatable.
val Array(trainData, testData) = assembledDF.randomSplit(Array(0.8, 0.2), seed = 1234)

// Instantiate the decision tree classifier on the assembled features and derived label.
val dt = new DecisionTreeClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")

// Hyperparameter grid: 3 depths x 3 bin counts = 9 candidate settings.
// The grid entries must be the classifier's own Param objects (dt.maxDepth / dt.maxBins).
val paramGrid = new ParamGridBuilder()
  .addGrid(dt.maxDepth, Array(5, 10, 15))
  .addGrid(dt.maxBins, Array(16, 32, 64))
  .build()

// Metric used by cross-validation to rank the candidate settings.
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")

// 5-fold cross-validation over the grid above.
val cv = new CrossValidator()
  .setEstimator(dt)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(5)

// Run cross-validation on the training data; the result holds the best model found.
val cvModel = cv.fit(trainData)

// Extract the best model selected by cross-validation. A pattern match is used instead
// of an unchecked cast so an unexpected model type fails with a clear message.
val bestModel = cvModel.bestModel match {
  case model: DecisionTreeClassificationModel => model
  case other =>
    throw new IllegalStateException(
      s"Expected a DecisionTreeClassificationModel but got ${other.getClass.getName}"
    )
}

// Predict the held-out test set once and reuse the result for every metric below.
val predictions = bestModel.transform(testData)

// Accuracy on the test set, using the same evaluator that drove cross-validation.
val accuracy = evaluator.evaluate(predictions)
println(s"Test Accuracy: $accuracy")

// Use MulticlassClassificationEvaluator to evaluate the classification accuracy.
val multiEvaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")

val multiAccuracy = multiEvaluator.evaluate(predictions)
println(s"Multiclass Classification Accuracy: $multiAccuracy")

// Evaluate AUC using BinaryClassificationEvaluator.
// The ROC curve needs per-row scores, so read the "rawPrediction" column; feeding the
// hard 0/1 "prediction" column collapses the curve to a single threshold.
val binaryEvaluator = new BinaryClassificationEvaluator()
  .setLabelCol("label")
  .setRawPredictionCol("rawPrediction")
  .setMetricName("areaUnderROC")

val auc = binaryEvaluator.evaluate(predictions)
println(s"Binary Classification AUC: $auc")

// Report the hyperparameters of the winning model.
val bestMaxDepth = bestModel.getMaxDepth
val bestMaxBins = bestModel.getMaxBins
println(s"Best model parameters: maxDepth = $bestMaxDepth, maxBins = $bestMaxBins")

// Close the SparkSession and release its resources.
spark.stop()