How to Train and Score a CatBoost Model on Spark

About CatBoost

CatBoost (developed by Yandex) is one of the leading open-source gradient boosting libraries, delivering strong performance without much additional tuning. It supports categorical features natively, with no need for manual encoding, and its predictions are fast as well. No wonder it is increasingly popular in the data science community for ranking, recommendation, classification, and regression problems.

Until now, CatBoost supported training only in Python and R, and prediction (applying the model) in several languages: Java (JVM packages), Python, C++, and R.

Distributed CatBoost Training

Support for distributed training on large datasets on CPU was limited, apart from some support via GPU training. The CatBoost team at Yandex started working on a Spark version of CatBoost for training and inference. They have recently released it, and it is available in the Maven repository. The CatBoost Spark implementation follows the general Spark MLlib design and supports Spark ML Pipelines.

As of now, it supports the following functionality –

  • Support for Spark 2.3-3.0 and Scala 2.11-2.12
  • Support for both Scala Spark and PySpark
  • Distributed training for binary classification, multiclass classification, and regression
  • Saving trained models in Spark MLlib serialization format or as CatBoost native format (.cbm) files
  • Feature importance calculation for CatBoost models
  • Prediction/inference over Spark for CatBoost models

Limitations – As of now, it doesn't support training with text and embedding features, which may not matter for many users.

I highly recommend watching the CatBoost team's videos that explain the implementation in more detail: CatBoost for Apache Spark introduction and CatBoost for Apache Spark Architecture.

I gave it a try on some models; below is a snapshot of how it can be used with Spark. The full source code is available at my GitHub link –

You just need to add this dependency to your POM. Please look at the GitHub sample above for the full set of dependencies needed for end-to-end Spark code –






Then we can use the CatBoost classes in Spark code to train or score the model as follows.

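CatBoost Binary Classification Model –
import ai.catboost.spark._
import org.apache.spark.ml.linalg.{SQLDataTypes, Vectors}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("CatBoostBinaryClassification")
  .getOrCreate()

// schema: a feature vector column and a string label column
val srcDataSchema = Seq(
  StructField("features", SQLDataTypes.VectorType),
  StructField("label", StringType)
)
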
//training data containing features and label.
val trainData = Seq(
  Row(Vectors.dense(0.11, 0.22, 0.13, 0.45, 0.89), "0"),
  Row(Vectors.dense(0.99, 0.82, 0.33, 0.89, 0.97), "1"),
  Row(Vectors.dense(0.12, 0.21, 0.23, 0.42, 0.24), "1"),
  Row(Vectors.dense(0.81, 0.63, 0.02, 0.55, 0.65), "0")
)

val trainDf = spark.createDataFrame(spark.sparkContext.parallelize(trainData), StructType(srcDataSchema))

val trainPool = new Pool(trainDf)

//evaluation data containing features and label.
val evalData = Seq(
  Row(Vectors.dense(0.22, 0.34, 0.9, 0.66, 0.99), "1"),
  Row(Vectors.dense(0.16, 0.1, 0.21, 0.67, 0.46), "0"),
  Row(Vectors.dense(0.78, 0.0, 0.0, 0.22, 0.12), "1")
)

val evalDf = spark.createDataFrame(spark.sparkContext.parallelize(evalData), StructType(srcDataSchema))

val evalPool = new Pool(evalDf)

val classifier = new CatBoostClassifier

// train model, passing evalPool as the evaluation dataset
val model: CatBoostClassificationModel = classifier.fit(trainPool, Array[Pool](evalPool))

// apply model to the evaluation data
val predictions: DataFrame = model.transform(evalPool.data)
predictions.show()


The predictions DataFrame contains these output columns:

rawPrediction – raw (confidence) scores for each class,

probability – class probabilities; for binary classification, the probability of class 1 is the sigmoid of the raw score, and

prediction – the predicted class, 0 or 1, where 1 is assigned when the probability of class 1 exceeds 0.5.
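To make the relationship between these three columns concrete, here is a minimal plain-Scala sketch (no Spark or CatBoost dependency) of how a raw score maps to probabilities and a predicted label. It assumes the binary raw score is a log-odds value, as with CatBoost's Logloss objective, and uses softmax for the multiclass case. The object and function names are hypothetical, and catboost-spark's exact rawPrediction vector layout is not modeled here.

```scala
object ScoreToPredictionSketch {
  // Logistic sigmoid: maps a raw log-odds score to P(class 1) in (0, 1).
  def sigmoid(raw: Double): Double = 1.0 / (1.0 + math.exp(-raw))

  // Softmax over per-class raw scores (multiclass case).
  // Subtracting the max keeps exp() from overflowing without changing the result.
  def softmax(raw: Seq[Double]): Seq[Double] = {
    val m = raw.max
    val exps = raw.map(r => math.exp(r - m))
    val total = exps.sum
    exps.map(_ / total)
  }

  // Predicted binary label: 1 when P(class 1) > threshold (0.5 by default), else 0.
  def predictLabel(raw: Double, threshold: Double = 0.5): Int =
    if (sigmoid(raw) > threshold) 1 else 0

  def main(args: Array[String]): Unit = {
    for (raw <- Seq(-2.0, 0.0, 1.5)) {
      val p1 = sigmoid(raw)
      println(f"raw=$raw%.2f  P(0)=${1 - p1}%.4f  P(1)=$p1%.4f  prediction=${predictLabel(raw)}")
    }
    // With two classes and raw scores (0, x), softmax gives the same P(class 1) as sigmoid(x).
    println(softmax(Seq(0.0, 1.5))(1) - sigmoid(1.5))
  }
}
```

Note that a raw score of exactly 0 gives P(class 1) = 0.5, which this sketch labels as 0 because it requires the probability to strictly exceed the threshold.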

Saving the Model –
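// save model in Spark MLlib serialization format
val savedModelPath = "models/binclass_model"
model.write.save(savedModelPath)

// save model as a local file in CatBoost native format
val savedNativeModelPath = "models/binclass_model.cbm"
model.saveNativeModel(savedNativeModelPath)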

CatBoost Model Feature Importance Calculation –
// load the model saved in CatBoost native format
val loadedModel = CatBoostClassificationModel.loadNativeModel("models/binclass_model.cbm")

val featureImportance = loadedModel.getFeatureImportancePrettified()

featureImportance.foreach(fi => println("[" + fi.featureName + "," + fi.importance + "]"))
Output – the importance of each feature in the model, as a percentage.
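If you only care about the most influential features, the printed pairs can be normalized and ranked with ordinary Scala collections. This is a minimal sketch, assuming the importances have been collected as (name, value) pairs; `Importance` is a hypothetical stand-in case class, not a catboost-spark type, and the sample values are illustrative, not real model output.

```scala
object FeatureImportanceSketch {
  // Hypothetical stand-in for a (featureName, importance) pair; NOT a catboost-spark class.
  final case class Importance(featureName: String, importance: Double)

  // Rescale values so they sum to 100 (leaves them unchanged if they already do).
  def toPercentages(items: Seq[Importance]): Seq[Importance] = {
    val total = items.map(_.importance).sum
    require(total > 0, "importances must have a positive sum")
    items.map(i => i.copy(importance = 100.0 * i.importance / total))
  }

  // Keep the k most important features, highest first.
  def topK(items: Seq[Importance], k: Int): Seq[Importance] =
    items.sortBy(-_.importance).take(k)

  def main(args: Array[String]): Unit = {
    // Illustrative values only, not real model output.
    val sample = Seq(
      Importance("0", 10.0), Importance("1", 30.0), Importance("2", 5.0),
      Importance("3", 40.0), Importance("4", 15.0))
    topK(toPercentages(sample), 3).foreach(fi =>
      println("[" + fi.featureName + "," + fi.importance + "]"))
  }
}
```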


Please feel free to comment with any questions.


