Bank Customer Churn Prediction — Decision Tree Classifier
Author: Emma Choukroun
Course: CSC310 - Fall 2025 (University of Rhode Island)
Organization: CSC310-fall25
License: MIT
Version: 1.0.0
Overview
This project focuses on predicting bank customer churn — identifying clients likely to close their bank accounts — using a Decision Tree Classifier.
It demonstrates how interpretable machine learning models can uncover key behavioral and transactional factors driving customer attrition.
The model was trained and evaluated as part of an academic exercise in the CSC310 course (Fall 2025) at the University of Rhode Island, under the organization CSC310-fall25.
Dataset Description
For this project, I chose not to use standard UCI datasets. Instead, I selected a real-world, finance-related dataset to align with my professional interests.
Dataset: Bank Customer Churn Dataset
Rows: 10,127 customers
Columns: 22 (demographics, account activity, and engagement metrics)
Target Variable: Attrition_Flag (Existing Customer vs. Attrited Customer)
Column Overview
| Column | Description |
|---|---|
| CLIENTNUM | Unique customer identifier |
| Attrition_Flag | Target — whether the customer left or stayed |
| Customer_Age | Age of the customer |
| Gender | M/F |
| Dependent_count | Number of dependents |
| Education_Level | Level of education |
| Marital_Status | Marital status |
| Income_Category | Annual income range |
| Card_Category | Credit card type |
| Months_on_book | Duration of relationship (in months) |
| Contacts_Count_12_mon | Number of contacts in the last 12 months |
| Credit_Limit | Credit card limit |
| Total_Revolving_Bal | Total revolving balance |
| Avg_Open_To_Buy | Average amount available to spend |
| Total_Amt_Chng_Q4_Q1 | Transaction amount change ratio (Q4 vs Q1) |
| Total_Trans_Amt | Total transaction amount |
| Total_Trans_Ct | Total number of transactions |
| Total_Ct_Chng_Q4_Q1 | Transaction count change ratio (Q4 vs Q1) |
| Avg_Utilization_Ratio | Average utilization ratio |
| Unnamed: 21 | Empty column (dropped) |
Note: Only 16.07% of customers are “Attrited Customers,” making this an imbalanced classification problem.
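The class shares behind a figure like this can be checked in one line with pandas. The sketch below is illustrative only: `class_shares` is a hypothetical helper and `toy_flags` is a made-up 100-row stand-in for the Attrition_Flag column, not the real data.

```python
import pandas as pd

def class_shares(labels: pd.Series) -> pd.Series:
    """Return the share of each class as fractions that sum to 1."""
    return labels.value_counts(normalize=True)

# Toy stand-in for the Attrition_Flag column (hypothetical, not the real data).
toy_flags = pd.Series(
    ["Existing Customer"] * 84 + ["Attrited Customer"] * 16,
    name="Attrition_Flag",
)

shares = class_shares(toy_flags)
print(shares.round(2))
```

Once the real data is loaded, the same helper could be called on the Attrition_Flag column to reproduce the reported share.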
Data Visualization
Before modeling, exploratory analysis was performed to understand feature distributions and relationships. The data can be loaded directly from the URL below to reproduce the plots.
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
url_churn = "https://raw.githubusercontent.com/adin786/bank_churn/refs/heads/main/data/BankChurners.csv"
df_churn = pd.read_csv(url_churn)
plt.figure(figsize=(15, 10))
for i, column in enumerate(df_churn.select_dtypes(include=['object'])):
    plt.subplot(2, 3, i + 1)
    sns.countplot(x=column, data=df_churn, hue="Attrition_Flag", palette='husl')
    plt.title(f'Count of {column}')
    plt.xticks(rotation=60)
plt.tight_layout()
plt.show()
Why Decision Trees?
At first, a Naive Bayes classifier was considered.
However, the dataset does not satisfy Naive Bayes assumptions:
- Features are correlated (e.g., Total_Trans_Amt and Total_Trans_Ct)
- Distributions are non-Gaussian (e.g., Credit_Limit and Total_Revolving_Bal are heavily skewed)
- The relationship between features and the target is nonlinear
Thus, the Decision Tree Classifier was chosen because it:
- Handles non-linear relationships naturally
- Is interpretable and visualizable
- Handles correlated and categorical features gracefully
- Requires no normalization or scaling
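The two data properties cited against Naive Bayes, correlated features and skewed distributions, can be given a rough first look with pandas. This is only a sketch on synthetic data: `amount` and `count` are hypothetical stand-ins for columns such as Total_Trans_Amt and Total_Trans_Ct. It inspects overall (not class-conditional) statistics, so it is not a formal test of the assumptions.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Synthetic stand-ins (hypothetical): a right-skewed amount and a count that tracks it.
amount = rng.lognormal(mean=8.0, sigma=1.0, size=2000)
count = 0.01 * amount + rng.normal(scale=5.0, size=2000)
toy = pd.DataFrame({"amount": amount, "count": count})

# Pearson correlation between the two features (near 0 would suggest independence).
corr = toy["amount"].corr(toy["count"])

# Sample skewness (near 0 for a symmetric, Gaussian-like distribution).
skewness = toy["amount"].skew()

print(f"correlation: {corr:.2f}, skewness: {skewness:.2f}")
```

A high correlation or a skewness far from zero on the real columns would support the choice of a model that does not depend on those assumptions.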
Model Training
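The identifier and empty columns were dropped, and categorical variables were encoded with LabelEncoder.
The data was split with stratification to preserve class proportions: 40% was held out as the test set, and the remaining 60% was divided into train (45% of all rows) and validation (15% of all rows) sets.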
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.preprocessing import LabelEncoder
df = df_churn.copy()
df = df.drop(columns=['CLIENTNUM', 'Unnamed: 21'], errors='ignore')
encoder = LabelEncoder()
for col in df.select_dtypes(include='object'):
    df[col] = encoder.fit_transform(df[col].astype(str))
X = df.drop('Attrition_Flag', axis=1)
y = df['Attrition_Flag']
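# First split: hold out 40% of the rows as the test set
X_train_full, X_test, y_train_full, y_test = train_test_split(
    X, y, test_size=0.4, stratify=y, random_state=42
)
# Second split: 25% of the remaining rows (15% of the total) become the validation set
X_train, X_val, y_train, y_val = train_test_split(
    X_train_full, y_train_full, test_size=0.25, stratify=y_train_full, random_state=42
)
# Shallow tree with balanced class weights to counter the class imbalance
dt_model = DecisionTreeClassifier(
    criterion='gini',
    max_depth=5,
    class_weight='balanced',
    random_state=42
)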
dt_model.fit(X_train, y_train)
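In a two-step split, the second `test_size` is a fraction of what remains after the first split, not of the full dataset, which makes the final proportions easy to misjudge. The sketch below counts the resulting rows on 1,000 synthetic indices. `two_step_split_sizes` is a hypothetical helper, and the second call shows one combination of fractions that would give a 60/20/20 split.

```python
import numpy as np
from sklearn.model_selection import train_test_split

def two_step_split_sizes(n_rows, first_test_size, second_test_size):
    """Return (train, val, test) row counts for a two-step split."""
    idx = np.arange(n_rows)
    # Step 1: carve the test set out of all rows.
    rest, test = train_test_split(idx, test_size=first_test_size, random_state=0)
    # Step 2: carve the validation set out of the remaining rows.
    train, val = train_test_split(rest, test_size=second_test_size, random_state=0)
    return len(train), len(val), len(test)

sizes_40_25 = two_step_split_sizes(1000, 0.4, 0.25)  # fractions used in the code above
sizes_20_25 = two_step_split_sizes(1000, 0.2, 0.25)  # alternative fractions
print(sizes_40_25)
print(sizes_20_25)
```

With 10,127 rows, `test_size=0.4` corresponds to the 4,051 test rows that appear in the Support column of the evaluation report.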
Model Evaluation
The model was evaluated on the held-out test set (40% of the data, 4,051 rows) using standard classification metrics: accuracy, precision, recall, and F1-score.
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
# Predictions
y_pred = dt_model.predict(X_test)
# Classification report
print(classification_report(y_test, y_pred))
# Confusion matrix visualization
# LabelEncoder sorts labels alphabetically: 0 = Attrited Customer, 1 = Existing Customer
cm = confusion_matrix(y_test, y_pred, labels=dt_model.classes_)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['Attrited', 'Existing'])
disp.plot(cmap='Blues', values_format='d')
plt.title("Confusion Matrix - Decision Tree Classifier")
plt.show()
| Class | Precision | Recall | F1-Score | Support |
|---|---|---|---|---|
| 0 (Attrited) | 0.65 | 0.87 | 0.74 | 651 |
| 1 (Existing) | 0.97 | 0.91 | 0.94 | 3400 |
| Accuracy | — | — | 0.90 | 4051 |
| Macro Avg | 0.81 | 0.89 | 0.84 | 4051 |
| Weighted Avg | 0.92 | 0.90 | 0.91 | 4051 |
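The F1 and averaged rows can be re-derived from the per-class precision, recall, and support values, which gives a quick consistency check on a report like the one above. The sketch below copies the two per-class rows from the table. The dictionary layout and the helper `f1` are illustrative choices, not part of the original code.

```python
# Per-class rows copied from the classification report table.
rows = {
    "attrited": {"precision": 0.65, "recall": 0.87, "support": 651},
    "existing": {"precision": 0.97, "recall": 0.91, "support": 3400},
}

def f1(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)

total = sum(r["support"] for r in rows.values())
f1_scores = {name: f1(r["precision"], r["recall"]) for name, r in rows.items()}

# Macro average: unweighted mean over classes.
macro_f1 = sum(f1_scores.values()) / len(f1_scores)

# Support-weighted recall, which equals overall accuracy.
weighted_recall = sum(r["recall"] * r["support"] for r in rows.values()) / total

print({name: round(score, 2) for name, score in f1_scores.items()})
print(round(macro_f1, 2), round(weighted_recall, 2))
```

Because support-weighted recall is the same quantity as accuracy, the last value should agree with the Accuracy row up to rounding of the inputs.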
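The model achieved 90% overall accuracy and strong recall (0.87) on the 'Attrited' class, which is critical for churn prediction, where missing potential churners is costly. The trade-off is lower precision (0.65) on that class, so some retained customers are flagged as likely churners.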
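Which integer stands for which class follows from how LabelEncoder orders labels: it sorts them, so 'Attrited Customer' becomes 0 and 'Existing Customer' becomes 1, matching the class column of the table above. A minimal sketch, using a hypothetical three-item list rather than the real column (`flag_encoder` is an illustrative name):

```python
from sklearn.preprocessing import LabelEncoder

flag_encoder = LabelEncoder()
codes = flag_encoder.fit_transform(
    ["Existing Customer", "Attrited Customer", "Existing Customer"]
)

# classes_ holds the sorted labels; a label's position is its integer code.
mapping = {label: code for code, label in enumerate(flag_encoder.classes_)}
print(mapping)
print(list(codes))
```

Any list of display names passed to a plotting function has to follow this same order.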
Decision Tree Visualization
The trained tree can be visualized to understand the key splits and thresholds used in classification.
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
plt.figure(figsize=(25, 12))
plot_tree(
    dt_model,
    feature_names=list(X.columns),
    # Names must follow dt_model.classes_ order: 0 = Attrited, 1 = Existing
    class_names=['Attrited', 'Existing'],
    filled=True,
    rounded=True,
    fontsize=8
)
plt.title("Decision Tree Structure")
plt.show()
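When the plotted tree is too dense to read, the same splits can be printed as text, and impurity-based feature importances give a compact ranking of which features drive them. The sketch below runs on synthetic data from `make_classification`. The names `toy_tree` and `feature_0` to `feature_4` are placeholders, not the churn features.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_text

# Synthetic stand-in data; the feature names are placeholders.
X_toy, y_toy = make_classification(
    n_samples=500, n_features=5, n_informative=3, random_state=0
)
names = [f"feature_{i}" for i in range(X_toy.shape[1])]

toy_tree = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_toy, y_toy)

# Text rendering of the splits and thresholds that plot_tree draws.
print(export_text(toy_tree, feature_names=names))

# Impurity-based importances: one value per feature, normalized to sum to 1.
importances = pd.Series(toy_tree.feature_importances_, index=names)
print(importances.sort_values(ascending=False))
```

Applied to the trained model, the same two calls would take `dt_model` and the column names of `X` instead of the toy objects.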
Model Strengths and Limitations
| Aspect | Strengths | Limitations |
|---|---|---|
| Interpretability | Easy to visualize and understand with decision paths | Can become complex if depth increases |
| Handling of Data | Works well with both categorical and numerical data | Sensitive to small data changes (instability) |
| Preprocessing | Requires minimal feature scaling or normalization | Still needs encoding for categorical variables |
| Performance | Performs well on non-linear relationships | May overfit if not properly pruned or regularized |
| Computation | Fast to train and predict on small datasets | Computational cost grows with dataset size |
| Class Imbalance | Handles imbalanced data with class_weight='balanced' | Performance may still degrade with extreme imbalance |
| Explainability | Feature importance helps identify key churn drivers | Probability estimates are coarse (leaf-level class frequencies) compared to ensemble models |
Overall, the Decision Tree model is an interpretable baseline, ideal for explainability and educational use, but less robust than ensemble methods like Random Forests or Gradient Boosting.
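The validation split created during training is not used in the code shown in this card. One way it might be used to address the overfitting risk noted in the table is to compare several `max_depth` values on validation data before touching the test set. The sketch below does this on synthetic imbalanced data. The candidate depths, the choice of F1 on the minority class as the selection metric, and all variable names are assumptions for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic imbalanced data standing in for the churn features (hypothetical).
X_syn, y_syn = make_classification(
    n_samples=2000, n_features=10, n_informative=5,
    weights=[0.84, 0.16], random_state=0,
)
X_tr, X_va, y_tr, y_va = train_test_split(
    X_syn, y_syn, test_size=0.25, stratify=y_syn, random_state=0
)

val_f1 = {}
for depth in [2, 3, 5, 8, None]:  # None lets the tree grow until leaves are pure
    tree = DecisionTreeClassifier(
        max_depth=depth, class_weight="balanced", random_state=0
    )
    tree.fit(X_tr, y_tr)
    # F1 on the minority class (label 1 in this synthetic data).
    val_f1[depth] = f1_score(y_va, tree.predict(X_va), pos_label=1)

best_depth = max(val_f1, key=val_f1.get)
print(val_f1)
print("best max_depth on validation data:", best_depth)
```

The test set would then be scored once, with the selected depth, so that the reported metrics are not biased by the tuning step.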
Ethical Considerations
While the model offers insights into customer behavior, it should be applied responsibly.
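Financial institutions must ensure that features do not directly or indirectly encode sensitive information such as gender, ethnicity, or socioeconomic status. This dataset includes Gender and Income_Category, so that concern applies here. Financial data is also highly sensitive in its own right and must be handled with care.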
This model is for illustrative and educational purposes only.
Downloading and using a model
from skops import io as sio
from skops.io import get_untrusted_types
from huggingface_hub import hf_hub_download
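model_path = hf_hub_download(
    repo_id="CSC310-fall25/bank_churn_tree",
    filename="model.pkl",
    local_dir="."
)
# List the types in the file that skops does not trust by default; review them before loading
untrusted = get_untrusted_types(file=model_path)
print("Untrusted types detected:", untrusted)
# Load the model, explicitly trusting the types listed above
dt_loaded = sio.load(file=model_path, trusted=untrusted)
print("OK: model loaded")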
You can now compute the model's accuracy, for example on the X_test and y_test split built in the training section:
score = dt_loaded.score(X_test, y_test)
print(f"Accuracy of the model : {score:.2%}")