Bank Customer Churn Prediction — Decision Tree Classifier

Author: Emma Choukroun
Course: CSC310 - Fall 2025 (University of Rhode Island)
Organization: CSC310-fall25
License: MIT
Version: 1.0.0


Overview

This project focuses on predicting bank customer churn — identifying clients likely to close their bank accounts — using a Decision Tree Classifier.
It demonstrates how interpretable machine learning models can uncover key behavioral and transactional factors driving customer attrition.

The model was trained and evaluated as part of an academic exercise in the CSC310 course (Fall 2025) at the University of Rhode Island, under the organization CSC310-fall25.


Dataset Description

For this project, I chose not to use standard UCI datasets. Instead, I selected a real-world, finance-related dataset to align with my professional interests.

Dataset: Bank Customer Churn Dataset
Rows: 10,127 customers
Columns: 22 (demographics, account activity, and engagement metrics)
Target Variable: Attrition_Flag (Existing Customer vs. Attrited Customer)

Column Overview

| Column | Description |
|--------|-------------|
| CLIENTNUM | Unique customer identifier |
| Attrition_Flag | Target — whether the customer left or stayed |
| Customer_Age | Age of the customer |
| Gender | M/F |
| Dependent_count | Number of dependents |
| Education_Level | Level of education |
| Marital_Status | Marital status |
| Income_Category | Annual income range |
| Card_Category | Credit card type |
| Months_on_book | Duration of relationship (in months) |
| Contacts_Count_12_mon | Number of contacts in the last 12 months |
| Credit_Limit | Credit card limit |
| Total_Revolving_Bal | Total revolving balance |
| Avg_Open_To_Buy | Average amount available to spend |
| Total_Amt_Chng_Q4_Q1 | Transaction amount change ratio (Q4 vs Q1) |
| Total_Trans_Amt | Total transaction amount |
| Total_Trans_Ct | Total number of transactions |
| Total_Ct_Chng_Q4_Q1 | Transaction count change ratio (Q4 vs Q1) |
| Avg_Utilization_Ratio | Average utilization ratio |
| Unnamed: 21 | Empty column (dropped) |

Note: Only 16.07% of customers are “Attrited Customers,” making this an imbalanced classification problem.
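
As a quick sanity check, this ratio can be recomputed directly from the raw CSV (the same URL used in the Data Visualization section below):

import pandas as pd

# Load the raw data (same source as in the Data Visualization section)
url_churn = "https://raw.githubusercontent.com/adin786/bank_churn/refs/heads/main/data/BankChurners.csv"
df_churn = pd.read_csv(url_churn)

# Proportion of each target class
print(df_churn['Attrition_Flag'].value_counts(normalize=True))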


Data Visualization

Before modeling, exploratory analysis was performed to understand feature distributions and relationships. The dataset can be loaded directly from the URL below to reproduce these plots.

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

url_churn = "https://raw.githubusercontent.com/adin786/bank_churn/refs/heads/main/data/BankChurners.csv"
df_churn = pd.read_csv(url_churn)

# Count plot for each categorical column, split by churn status
plt.figure(figsize=(15, 10))
for i, column in enumerate(df_churn.select_dtypes(include=['object'])):
    plt.subplot(2, 3, i + 1)
    sns.countplot(x=column, data=df_churn, hue="Attrition_Flag", palette='husl')
    plt.title(f'Count of {column}')
    plt.xticks(rotation=60)
plt.tight_layout()
plt.show()
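
Numeric features can be inspected the same way; a minimal sketch using pandas' built-in histogram plotting:

# Histograms of all numeric columns to spot skew and outliers
df_churn.select_dtypes(include='number').hist(figsize=(15, 10), bins=30)
plt.tight_layout()
plt.show()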

Why Decision Trees?

At first, a Naive Bayes classifier was considered.
However, the dataset does not satisfy Naive Bayes assumptions:

  • Features are correlated (e.g., Total_Trans_Amt and Total_Trans_Ct)
  • Distributions are non-Gaussian (e.g., Credit_Limit and Total_Revolving_Bal are heavily skewed)
  • The relationship between features and the target is nonlinear
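
These violations are easy to verify empirically; a quick check, assuming df_churn is loaded as above:

# Correlation between transaction amount and count (strongly correlated features)
print(df_churn[['Total_Trans_Amt', 'Total_Trans_Ct']].corr())

# Skewness of monetary features (values far from 0 indicate non-Gaussian distributions)
print(df_churn[['Credit_Limit', 'Total_Revolving_Bal']].skew())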

Thus, the Decision Tree Classifier was chosen because it:

  • Handles non-linear relationships naturally
  • Is interpretable and visualizable
  • Handles correlated and categorical features gracefully
  • Requires no normalization or scaling

Model Training

The dataset was cleaned and encoded using LabelEncoder for categorical variables.
Data was split into train (45%), validation (15%), and test (40%) sets with stratification to preserve class balance.

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.preprocessing import LabelEncoder

df = df_churn.copy()
df = df.drop(columns=['CLIENTNUM', 'Unnamed: 21'], errors='ignore')

# LabelEncoder sorts labels alphabetically, so for the target:
# 'Attrited Customer' -> 0 and 'Existing Customer' -> 1
encoder = LabelEncoder()
for col in df.select_dtypes(include='object'):
    df[col] = encoder.fit_transform(df[col].astype(str))

X = df.drop('Attrition_Flag', axis=1)
y = df['Attrition_Flag']

# First split: hold out 40% of the data as the stratified test set
X_train_full, X_test, y_train_full, y_test = train_test_split(
    X, y, test_size=0.4, stratify=y, random_state=42
)

# Second split: 25% of the remaining 60% becomes the validation set (15% overall)
X_train, X_val, y_train, y_val = train_test_split(
    X_train_full, y_train_full, test_size=0.25, stratify=y_train_full, random_state=42
)

dt_model = DecisionTreeClassifier(
    criterion='gini',
    max_depth=5,
    class_weight='balanced',
    random_state=42
)

dt_model.fit(X_train, y_train)
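
The validation set created above can be used to justify the choice of max_depth before the test set is touched. A minimal sketch of that selection loop (the candidate depths are an assumption, not part of the original experiment):

from sklearn.metrics import f1_score

# Try a few candidate depths and score each on the validation set
for depth in [3, 4, 5, 6, 8]:
    candidate = DecisionTreeClassifier(
        criterion='gini', max_depth=depth,
        class_weight='balanced', random_state=42
    )
    candidate.fit(X_train, y_train)
    val_f1 = f1_score(y_val, candidate.predict(X_val), pos_label=0)  # F1 on class 0 ('Attrited')
    print(f"max_depth={depth}: validation F1 (Attrited) = {val_f1:.3f}")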

Model Evaluation

The model was evaluated on the held-out test set (40% of the data) using standard classification metrics: Accuracy, Precision, Recall, and F1-Score.

from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# Predictions
y_pred = dt_model.predict(X_test)

# Classification report
print(classification_report(y_test, y_pred))

# Confusion matrix visualization
cm = confusion_matrix(y_test, y_pred, labels=dt_model.classes_)
# Class 0 = 'Attrited Customer', class 1 = 'Existing Customer' (alphabetical LabelEncoder order)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['Attrited', 'Existing'])
disp.plot(cmap='Blues', values_format='d')
plt.title("Confusion Matrix - Decision Tree Classifier")
plt.show()

| Class | Precision | Recall | F1-Score | Support |
|-------|-----------|--------|----------|---------|
| 0 (Attrited) | 0.65 | 0.87 | 0.74 | 651 |
| 1 (Existing) | 0.97 | 0.91 | 0.94 | 3400 |
| Accuracy | | | 0.90 | 4051 |
| Macro Avg | 0.81 | 0.89 | 0.84 | 4051 |
| Weighted Avg | 0.92 | 0.90 | 0.91 | 4051 |

The model achieved 90% overall accuracy and strong recall (0.87) on the 'Attrited' class, which is critical for churn prediction, where missing potential churners is costly.

Decision Tree Visualization

The trained tree can be visualized to understand the key splits and thresholds used in classification.


import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
plt.figure(figsize=(25, 12))
plot_tree(
    dt_model,
    feature_names=X.columns,
    class_names=['Attrited', 'Existing'],  # ordered by encoded class: 0 = Attrited, 1 = Existing
    filled=True,
    rounded=True,
    fontsize=8
)
plt.title("Decision Tree Structure")
plt.show()
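
Beyond the tree diagram, feature importances give a quick ranking of the main churn drivers:

import pandas as pd

# Rank features by their contribution to the tree's splits
importances = pd.Series(dt_model.feature_importances_, index=X.columns)
print(importances.sort_values(ascending=False).head(10))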

Model Strengths and Limitations

| Aspect | Strengths | Limitations |
|--------|-----------|-------------|
| Interpretability | Easy to visualize and understand with decision paths | Can become complex if depth increases |
| Handling of Data | Works well with both categorical and numerical data | Sensitive to small data changes (instability) |
| Preprocessing | Requires minimal feature scaling or normalization | Still needs encoding for categorical variables |
| Performance | Performs well on non-linear relationships | May overfit if not properly pruned or regularized |
| Computation | Fast to train and predict on small datasets | Computational cost grows with dataset size |
| Class Imbalance | Handles imbalanced data with class_weight='balanced' | Performance may still degrade with extreme imbalance |
| Explainability | Feature importance helps identify key churn drivers | Lacks probabilistic confidence compared to ensemble models |

Overall, the Decision Tree model is an interpretable baseline, ideal for explainability and educational use, but less robust than ensemble methods like Random Forests or Gradient Boosting.
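
For reference, swapping in an ensemble on the same split is a small change; the sketch below uses RandomForestClassifier purely as an illustration and is not part of the original experiment:

from sklearn.ensemble import RandomForestClassifier

# Drop-in ensemble baseline on the same train/test split (illustrative comparison only)
rf_model = RandomForestClassifier(
    n_estimators=200, class_weight='balanced', random_state=42
)
rf_model.fit(X_train, y_train)
print(f"Random Forest test accuracy: {rf_model.score(X_test, y_test):.2%}")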

Ethical Considerations

While the model offers insights into customer behavior, it should be applied responsibly.
Financial institutions must ensure that features do not directly or indirectly encode sensitive
attributes such as gender, ethnicity, or socioeconomic status; financial data of this kind is inherently sensitive and must be handled with care.

This model is for illustrative and educational purposes only.

Downloading and Using the Model

from skops import io as sio
from skops.io import get_untrusted_types
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(
    repo_id="CSC310-fall25/bank_churn_tree",
    filename="model.pkl",
    local_dir="."
)

# Inspect any types that skops does not trust by default before loading
untrusted = get_untrusted_types(file=model_path)
print("Untrusted types detected:", untrusted)

# Load the model, explicitly trusting the inspected types
dt_loaded = sio.load(file=model_path, trusted=untrusted)

print("Model loaded successfully")

You can then evaluate the loaded model, for example its accuracy on the test set:

score = dt_loaded.score(X_test, y_test)
print(f"Accuracy of the model: {score:.2%}")