How to Train and Score a CatBoost Model on Spark

A practical, developer-focused guide to distributed training and inference with CatBoost on Apache Spark, including code examples and best practices.



About CatBoost

CatBoost (by Yandex) is a high-performance, open-source gradient boosting library. It is popular for:

  • Native support for categorical features (no need for manual encoding)
  • Fast and accurate predictions
  • Minimal parameter tuning required
  • Wide use in ranking, recommendation, classification, and regression tasks

Previously, CatBoost supported training only in Python and R, but could make predictions in Java, Python, C++, and R.


Distributed CatBoost Training on Spark

Until the Spark package arrived, training on large datasets was limited to GPU or single-node CPU setups. CatBoost now provides a Spark package for distributed training and inference that follows the Spark MLlib API and supports Spark ML Pipelines.

  • Official Spark package: catboost4j-spark
  • Supported Spark versions: 2.3–3.0
  • Supported Scala versions: 2.11–2.12
  • Works with both Scala Spark and PySpark

Key Features and Limitations

Features:

  • Distributed training for binary, multiclass classification, and regression
  • Save models in Spark MLlib or CatBoost native format (.cbm)
  • Feature importance calculation
  • Distributed prediction/inference

Limitations:

  • No support for text and embedding features (as of now)

Getting Started: Setup and Dependencies

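Add the CatBoost Spark dependency to your Maven pom.xml. The artifact name encodes the Spark and Scala versions (here Spark 2.4 and Scala 2.12), so pick the variant that matches your cluster: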

<dependency>
  <groupId>ai.catboost</groupId>
  <artifactId>catboost-spark_2.4_2.12</artifactId>
  <version>0.25</version>
</dependency>
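
For sbt-based builds, the same coordinates could be declared roughly as follows. This is a sketch that simply mirrors the Maven snippet above; adjust the artifact suffix and version to your own Spark and Scala setup.

```scala
// build.sbt (sketch): same coordinates as the Maven snippet above.
// A single % is used because the Scala suffix (_2.12) is already part of the artifact name.
libraryDependencies += "ai.catboost" % "catboost-spark_2.4_2.12" % "0.25"
```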

For a full working example and all dependencies, see the GitHub sample project.


Example: Training and Scoring a CatBoost Model

Training a Binary Classification Model

// SQLDataTypes is needed for the vector column type used in the schema below
import org.apache.spark.ml.linalg.{SQLDataTypes, Vectors}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types._
import ai.catboost.spark._

val srcDataSchema = Seq(
  StructField("features", SQLDataTypes.VectorType),
  StructField("label", StringType)
)

// Training data
val trainData = Seq(
  Row(Vectors.dense(0.11, 0.22, 0.13, 0.45, 0.89), "0"),
  Row(Vectors.dense(0.99, 0.82, 0.33, 0.89, 0.97), "1"),
  Row(Vectors.dense(0.12, 0.21, 0.23, 0.42, 0.24), "1"),
  Row(Vectors.dense(0.81, 0.63, 0.02, 0.55, 0.65), "0")
)
val spark = SparkSession.builder().getOrCreate()
val trainDf = spark.createDataFrame(spark.sparkContext.parallelize(trainData), StructType(srcDataSchema))
val trainPool = new Pool(trainDf)

// Evaluation data
val evalData = Seq(
  Row(Vectors.dense(0.22, 0.34, 0.9, 0.66, 0.99), "1"),
  Row(Vectors.dense(0.16, 0.1, 0.21, 0.67, 0.46), "0"),
  Row(Vectors.dense(0.78, 0.0, 0.0, 0.22, 0.12), "1")
)
val evalDf = spark.createDataFrame(spark.sparkContext.parallelize(evalData), StructType(srcDataSchema))
val evalPool = new Pool(evalDf)

val classifier = new CatBoostClassifier
val model: CatBoostClassificationModel = classifier.fit(trainPool, Array(evalPool))

// Apply model
val predictions = model.transform(evalPool.data)
predictions.show(false)

Output columns:

  • rawPrediction: Raw confidence scores for each class
  • probability: Sigmoid of raw predictions (class probabilities)
  • prediction: Predicted class (0 or 1)
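
To make the link between a raw score and a class probability concrete, here is a minimal sketch in plain Scala. It assumes a single raw score for the positive class and a 0.5 decision threshold; how the Spark package lays out the raw prediction vector internally is not covered here. `RawToProbability`, `sigmoid`, and `predictClass` are hypothetical helper names, not part of the CatBoost API.

```scala
object RawToProbability {
  // Logistic sigmoid: maps any real-valued raw score into the interval (0, 1).
  def sigmoid(raw: Double): Double = 1.0 / (1.0 + math.exp(-raw))

  // Assumed decision rule: predict class 1 when the probability reaches the threshold.
  def predictClass(raw: Double, threshold: Double = 0.5): Int =
    if (sigmoid(raw) >= threshold) 1 else 0

  def main(args: Array[String]): Unit = {
    // A raw score of 0 sits exactly on the decision boundary (probability 0.5).
    println(sigmoid(0.0))
    println(predictClass(2.0))
    println(predictClass(-2.0))
  }
}
```

Positive raw scores push the probability above 0.5 and negative ones below it, which is why the sign of the raw score alone decides the predicted class under this threshold.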

Saving the Model

// Save model in Spark MLlib format
val savedModelPath = "models/binclass_model"
model.write.overwrite().save(savedModelPath)

// Save model in CatBoost native format
val savedNativeModelPath = "models/binclass_model.cbm"
model.saveNativeModel(savedNativeModelPath)
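
A model saved in Spark MLlib format can later be reloaded for scoring. The sketch below assumes that the model's companion object follows the standard Spark `MLReadable` convention and therefore exposes `load(path)`; `SavedModelScoring` and `score` are hypothetical names introduced only for illustration.

```scala
import org.apache.spark.sql.DataFrame
import ai.catboost.spark.CatBoostClassificationModel

object SavedModelScoring {
  // Reload a model written with model.write.overwrite().save(modelPath) and
  // score a DataFrame that has the same "features" vector column used for training.
  // Assumption: CatBoostClassificationModel.load follows the usual MLReadable API.
  def score(modelPath: String, data: DataFrame): DataFrame = {
    val reloaded = CatBoostClassificationModel.load(modelPath)
    reloaded.transform(data)
  }
}
```

With the variables from the training example, a call such as `SavedModelScoring.score("models/binclass_model", evalPool.data)` would return the input rows with the prediction columns appended.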

Feature Importance

// Load the model saved in native format and print one [feature,importance] pair per line
val loadedModel = CatBoostClassificationModel.loadNativeModel("models/binclass_model.cbm")
val featureImportance = loadedModel.getFeatureImportancePrettified()
featureImportance.foreach(fi => println(s"[${fi.featureName},${fi.importance}]"))

Sample Output:

[2,47.26]
[4,30.27]
[1,12.31]
[3,10.16]
[0,0.0]
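
The importances in the sample output add up to 100, so each value can be read as a percentage share. The plain-Scala sketch below ranks features and keeps the smallest top-ranked set that covers a chosen share, using the sample values above. `FeatureShare` and `topFeatures` are hypothetical helpers, not part of the CatBoost API.

```scala
object FeatureShare {
  // (feature name, importance) pairs copied from the sample output above.
  val sample: Seq[(String, Double)] = Seq(
    "2" -> 47.26, "4" -> 30.27, "1" -> 12.31, "3" -> 10.16, "0" -> 0.0
  )

  // Names of the smallest set of top-ranked features whose combined importance reaches minShare.
  def topFeatures(importances: Seq[(String, Double)], minShare: Double): Seq[String] = {
    // Sort by importance, highest first.
    val ranked = importances.sortBy(pair => -pair._2)
    // Running totals: cumulative(i) is the share covered by the first i + 1 features.
    val cumulative = ranked.scanLeft(0.0)((acc, pair) => acc + pair._2).tail
    val needed = cumulative.indexWhere(_ >= minShare) match {
      case -1 => ranked.length // threshold never reached: keep every feature
      case i  => i + 1
    }
    ranked.take(needed).map(_._1)
  }

  def main(args: Array[String]): Unit = {
    // Features that together explain at least 75 of the 100 importance points.
    println(topFeatures(sample, 75.0).mkString(", "))
  }
}
```

For the sample values, features 2 and 4 together account for 77.53 of the 100 points, so a 75-point threshold keeps just those two.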
