This project implements and compares multiple classification models (MLP with ReLU, MLP with LeakyReLU, and Logistic Regression) for predicting customer churn (Exited). The pipeline includes:
- Data preprocessing (dropping identifiers, one-hot encoding categorical features, normalization).
- Model training (custom MLP with weighted sampling to handle class imbalance).
- Model evaluation (confusion matrix, classification report, ROC curve).
- Comparison of pretrained models (visual and numerical evaluation).
- Export of predictions on the test dataset to CSV.
-
Input files:
train.csv(with targetExited)test.csv(without target)
-
Processing steps:
- Remove identifiers (
id,CustomerId,Surname). - One-hot encode categorical features.
- Normalize numeric features using
StandardScaler. - Split
train.csvinto train/validation sets (80/20).
- Remove identifiers (
- MLP (ReLU)
- MLP (LeakyReLU)
- Logistic Regression
Each model outputs logits that are converted into probabilities of class 1 (Exited).
- Model trained:
MLP_relu. - Loss:
CrossEntropyLoss. - Optimizer:
Adam. - Balanced sampling with
WeightedRandomSampler. - Number of epochs: 10.
Outputs:
- Confusion matrix + ROC curve plots (default threshold=0.9).
→ Saved under
exercise1_plot/Trained_MLP_ReLU_visualization.jpg. - Trained model parameters saved to
checkpoints/Trained_MLP_ReLU.pthif save_model flag is true (see parameter setting below).
Three pretrained models are evaluated on the validation set:
simple_mlp_normalized.pthsimple_mlp_leakyrelu.pthlogitregression.pth
Outputs:
- Confusion matrix + ROC curve plots for each model.
→ Saved under
exercise1_plot/{model_name}_visualization.jpg. - ROC comparison plot across all three models.
→ Saved under
exercise1_plot/ROC_3_model_comparison.jpg.
-
Each model predicts churn (
Exited) ontest.csv. -
Predictions are thresholded at 0.5.
-
Outputs:
model_predictions_test.csv→ predictions from all three pretrained models.Predicted_Exited_from_test.csv→ predictions from the trained model.
The script can be controlled via command-line arguments:
python main.py \
--train_file train.csv \
--test_file test.csv \
--model mlp_relu \
--epochs 10 \
--lr 0.001 \
--thres 0.9 \
--pretrained \
--save_model--train_file(str, default=train.csv): Path to training CSV file--test_file(str, default=test.csv): Path to test CSV file--model(str, default=mlp_relu): Model to train and evaluate. Options:mlp_relu,mlp_leakyrelu,logreg--epochs(int, default=10): Number of training epochs--lr(float, default=0.001): Learning rate--thres(float, default=0.9): Threshold for binary classification--pretrained(flag): If set, evaluates pretrained models instead of training--save_model(flag): If set, saves the trained model to disk
- ✅
exercise1_plot/Trained_MLP_ReLU_visualization.jpg(confusion matrix + ROC) - ✅
checkpoints/Trained_MLP_ReLU.pth(model weights) - ✅
exercise1_predicted/Predicted_Exited_from_test_Trained_MLP_ReLU.csv(predictions on test set)
- ✅
exercise1_plot/{model_name}_visualization.jpg(confusion matrix + ROC for each model) - ✅
exercise1_plot/ROC_3_model_comparison.jpg(combined ROC comparison) - ✅
exercise1_predicted/Predicted_Exited_from_test_{model_name}.csv(predictions on test set)
- Python 3.8+
- Libraries:
torch,pandas,numpy,scikit-learn,matplotlib,seaborn,tqdm
-
Model architecture and features
- The current MLP may be limited in capacity. Accuracy could potentially be improved by exploring deeper neural networks or alternative architectures beyond simple MLPs.
- Feature engineering can also help. For example, numeric features like
Agecould be categorized into levels and one-hot encoded, or all categorical features could be encoded more comprehensively.
-
Feature selection
- Not all features may provide useful information. Some, such as
Genderor rawAge, might be less informative. - Dynamic or iterative feature selection could help identify the most relevant subset of features, improving model performance and reducing overfitting.
- See Dynamic Feature Selection for a reference implementation.
- Not all features may provide useful information. Some, such as
-
Data imbalance
- The current approach uses weighted sampling to handle class imbalance.
- Further improvements could include adaptive weighting of the loss function (e.g., assigning higher weights to minority classes) or using oversampling/undersampling techniques.
-
Loss function alternatives
- Currently,
CrossEntropyLossis used for multi-class outputs (or BCE for single output). - More robust loss functions could be considered, such as angular loss or other specialized classification losses, to improve learning under class imbalance or noisy labels.
- Currently,