CatBoost (developed by Yandex) is one of the great open-source gradient boosting libraries, offering strong performance without a lot of additional tuning. It supports categorical features natively, with no need for encoding, and predictions are pretty fast as well. No wonder it is one of the algorithms increasingly popular in the data science community for ranking, recommendation, classification, and regression problems.
Until now, CatBoost supported training only in Python and R, while predictions (applying the model) could be made in a multitude of languages – Java (JVM packages), Python, C++, and R.
Distributed CatBoost Training
There was limited support for training a model on a big data set in a distributed manner on CPU, apart from some support via GPU training. The CatBoost team at Yandex started working on a Spark version of CatBoost for training and inference, and they have recently released it; it is available in the Maven repository. The CatBoost Spark implementation follows general Spark MLlib conventions and supports Spark ML Pipelines.
It supports the following functionality as of now –
Support for Spark 2.3-3.0 and Scala 2.11-2.12
Support for both Scala Spark and PySpark
Distributed Training for Binary Classification, MultiClass Classification, and Regression.
Save trained model in Spark MLLib Serialization Format or Catboost Native Format (.cbm) files.
Get Feature Importance for the CatBoost Models.
Prediction/Inference over Spark for the Catboost Models.
Limitations – As of now, it doesn’t support training on text and embedding features, which may not be a big deal for a large number of users.
I thought of giving it a try on some models. Below is a snapshot of how it can be used with Spark; the full source code is available at my GitHub link – https://github.com/saurzcode/catboost-spark-examples
You just need to add the catboost-spark dependency to your POM and you should be okay; please look at the GitHub sample above for the full set of dependencies needed for end-to-end Spark code –
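For reference, the dependency looks roughly like this – the artifact name encodes the Spark and Scala versions (e.g. catboost-spark_3.0_2.12), so pick the variant matching your cluster, and treat the version below as a placeholder for the latest release on Maven Central:

```xml
<dependency>
    <groupId>ai.catboost</groupId>
    <!-- artifact name pattern: catboost-spark_<spark_version>_<scala_version> -->
    <artifactId>catboost-spark_3.0_2.12</artifactId>
    <!-- placeholder; check Maven Central for the latest version -->
    <version>1.0.3</version>
</dependency>
```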
We can then use the CatBoost classes in Spark code to train or score a model, as follows.
CatBoost Binary Classification Model –
import org.apache.spark.ml.linalg.{SQLDataTypes, Vectors}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import ai.catboost.spark.{CatBoostClassificationModel, CatBoostClassifier, Pool}

val srcDataSchema = Seq(
  StructField("features", SQLDataTypes.VectorType),
  StructField("label", StringType)
)

// training data containing features and label
val trainData = Seq(
  Row(Vectors.dense(0.11, 0.22, 0.13, 0.45, 0.89), "0"),
  Row(Vectors.dense(0.99, 0.82, 0.33, 0.89, 0.97), "1"),
  Row(Vectors.dense(0.12, 0.21, 0.23, 0.42, 0.24), "1"),
  Row(Vectors.dense(0.81, 0.63, 0.02, 0.55, 0.65), "0")
)
val trainDf = spark.createDataFrame(spark.sparkContext.parallelize(trainData), StructType(srcDataSchema))
val trainPool = new Pool(trainDf)

// evaluation data containing features and label
val evalData = Seq(
  Row(Vectors.dense(0.22, 0.34, 0.9, 0.66, 0.99), "1"),
  Row(Vectors.dense(0.16, 0.1, 0.21, 0.67, 0.46), "0"),
  Row(Vectors.dense(0.78, 0.0, 0.0, 0.22, 0.12), "1")
)
val evalDf = spark.createDataFrame(spark.sparkContext.parallelize(evalData), StructType(srcDataSchema))
val evalPool = new Pool(evalDf)

// train the model
val classifier = new CatBoostClassifier
val model: CatBoostClassificationModel = classifier.fit(trainPool, Array[Pool](evalPool))

// apply the model
val predictions: DataFrame = model.transform(evalPool.data)
println("predictions")
predictions.show(false)
Output
rawPrediction – confidence scores (raw margins) for each class of the classification model,
probability – probability scores, the sigmoid of the raw predictions for each class, and
prediction – the predicted class, 0 or 1, with class 1 assigned when its probability is greater than 0.5.
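To make the thresholding concrete, here is a tiny plain-Scala sketch (illustration only, not CatBoost library code) of how a raw score maps to a probability and then to a class:

```scala
// Illustration: sigmoid maps a raw margin to a probability in (0, 1);
// the predicted class is 1 exactly when that probability exceeds 0.5,
// i.e. when the raw score is positive.
def sigmoid(x: Double): Double = 1.0 / (1.0 + math.exp(-x))

def predictedClass(rawScore: Double): Int =
  if (sigmoid(rawScore) > 0.5) 1 else 0

println(sigmoid(0.0))         // 0.5 – the decision boundary
println(predictedClass(2.3))  // 1
println(predictedClass(-1.7)) // 0
```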
Saving the Model –
// save the model in Spark MLLib serialization format
val savedModelPath = "models/binclass_model"
model.write.overwrite().save(savedModelPath)

// save the model as a local file in CatBoost native format
val savedNativeModelPath = "models/binclass_model.cbm"
model.saveNativeModel(savedNativeModelPath)
CatBoost Model Feature Importance Calculation –
val loadedModel = CatBoostClassificationModel.loadNativeModel("models/binclass_model.cbm")
val featureImportance = loadedModel.getFeatureImportancePrettified()
featureImportance.foreach(fi => println("[" + fi.featureName + "," + fi.importance + "]"))
Output – the feature importance percentage for each feature in the model.