A Python package for dropout-based uncertainty quantification and dataset pruning in binary classification tasks. TML implements a two-level training pipeline that uses Monte Carlo Dropout to produce reliable probability scores with associated uncertainty estimates.
- Two-Level Training Pipeline: First prunes unreliable samples, then trains on high-confidence data
- Monte Carlo Dropout: Uncertainty estimation through stochastic forward passes at inference time
- Balanced Sampling: Automatic handling of class imbalance during training
- Custom Architectures: Bring your own PyTorch Lightning models
- Analysis Tools: Built-in metrics (ROC-AUC, Brier score, calibration) and visualizations
- GPU Acceleration: Seamless CUDA support via PyTorch Lightning

```bash
git clone https://github.com/EhsanKA/tml.git
cd tml
conda env create --file environment.yaml
conda activate tml
pip install .
```

The conda environment includes:

| Package | Purpose |
|---|---|
| PyTorch | Deep learning framework |
| PyTorch Lightning | Training orchestration |
| scikit-learn | Metrics and evaluation |
| pandas | Data manipulation |
| matplotlib / seaborn | Visualization |
| tensorboard | Training logging |

```python
import torch
from tml.pipeline import Pipeline, ModelHandler
from models.mnist import CNNBinaryMNISTClassifier

# Prepare your data (must be torch tensors with binary labels 0/1)
X_train = ...  # Shape: (n_samples, *input_dims)
y_train = ...  # Shape: (n_samples,) with values {0, 1}

# Initialize your model
model = CNNBinaryMNISTClassifier(learning_rate=1e-3, dropout_rate=0.5)
model_handler = ModelHandler(model_instance=model)

# Create and run the pipeline
pipeline = Pipeline(
    model_handler=model_handler,
    data=X_train,
    hard_targets=y_train,
    batch_size=64,
    max_epochs=10,
    lower_threshold=0.3,
    upper_threshold=0.7,
    drop_iterations=10,
    seed=42,
)

# Run n_steps iterations of the two-level training
pipeline.run(n_steps=5)

# Access results
probability_scores = pipeline.probability_scores  # Mean predictions
uncertainty_scores = pipeline.uncertainty_scores  # Prediction variance
```

TML implements a two-level training strategy designed to improve prediction reliability:

- Balanced Sampling: Creates a balanced subset from imbalanced data
- Initial Training: Trains the model on the balanced subset
- Prediction: Generates predictions on the full dataset
- Pruning: Identifies high-confidence predictions:
  - True Positives: samples with label=1 and prediction > `lower_threshold`
  - True Negatives: samples with label=0 and prediction < `upper_threshold`
- Refined Training: Re-trains on the pruned (high-confidence) subset
- Standard Prediction: Generates probability scores
- MC Dropout: Multiple forward passes with dropout enabled to estimate uncertainty

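The balanced sampling and pruning steps listed above can be sketched in plain Python. This is only an illustration under stated assumptions: the helper names `balanced_indices` and `prune_indices` are hypothetical, and undersampling the majority class is one plausible way to balance the data; TML's own `BalancedSampler` and prune function in `tml_dataset.py` may work differently.

```python
import random


def balanced_indices(labels, seed=42):
    """Hypothetical helper: undersample the majority class to get a balanced subset."""
    rng = random.Random(seed)
    pos = [i for i, y in enumerate(labels) if y == 1]
    neg = [i for i, y in enumerate(labels) if y == 0]
    n = min(len(pos), len(neg))  # size of the minority class
    return sorted(rng.sample(pos, n) + rng.sample(neg, n))


def prune_indices(labels, predictions, lower_threshold=0.3, upper_threshold=0.7):
    """Hypothetical helper: keep samples whose prediction agrees with their label.

    Positives are kept when prediction > lower_threshold,
    negatives when prediction < upper_threshold.
    """
    keep = []
    for i, (y, p) in enumerate(zip(labels, predictions)):
        if (y == 1 and p > lower_threshold) or (y == 0 and p < upper_threshold):
            keep.append(i)
    return keep


labels = [1, 1, 0, 0, 0, 0]
predictions = [0.9, 0.2, 0.1, 0.8, 0.4, 0.65]

subset = balanced_indices(labels)          # two positives and two negatives
kept = prune_indices(labels, predictions)  # indices that survive pruning
print(subset, kept)
```

In this toy input, the positive scored 0.2 and the negative scored 0.8 are dropped because the model disagrees with their labels; the remaining samples form the subset used for the second training round.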
```
┌─────────────────────────────────────────────────────────────────┐
│                          TML Pipeline                           │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌─────────────┐    ┌─────────────┐    ┌─────────────────────┐  │
│  │  Balanced   │───▶│    Train    │───▶│   Predict & Prune   │  │
│  │  Sampling   │    │   Level 1   │    │ (remove uncertain)  │  │
│  └─────────────┘    └─────────────┘    └──────────┬──────────┘  │
│                                                   │             │
│                                                   ▼             │
│  ┌─────────────┐    ┌─────────────┐    ┌─────────────────────┐  │
│  │ Uncertainty │◀───│    Train    │◀───│   High-Confidence   │  │
│  │  (MC Drop)  │    │   Level 2   │    │       Subset        │  │
│  └─────────────┘    └─────────────┘    └─────────────────────┘  │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```

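The final uncertainty step can be illustrated without any deep learning framework. The sketch below is a toy, not TML's implementation: it assumes a single logistic unit with dropout applied to its inputs, and the function names `stochastic_forward` and `mc_dropout` are hypothetical. It shows the core idea of MC Dropout: run several forward passes with dropout still active, then report the mean as the probability score and the variance as the uncertainty. In PyTorch, the equivalent is to keep the `nn.Dropout` modules in training mode during inference.

```python
import math
import random
import statistics


def stochastic_forward(x, weights, bias, dropout_rate, rng):
    """One forward pass of a logistic unit with inverted dropout on its inputs."""
    keep_prob = 1.0 - dropout_rate
    z = bias
    for xi, wi in zip(x, weights):
        if rng.random() < keep_prob:   # keep this input with probability keep_prob
            z += wi * xi / keep_prob   # rescale so the expected pre-activation is unchanged
    return 1.0 / (1.0 + math.exp(-z))


def mc_dropout(x, weights, bias, dropout_rate=0.5, drop_iterations=10, seed=42):
    """Return (mean, variance) of the predictions over several stochastic passes."""
    rng = random.Random(seed)
    passes = [
        stochastic_forward(x, weights, bias, dropout_rate, rng)
        for _ in range(drop_iterations)
    ]
    return statistics.fmean(passes), statistics.pvariance(passes)


x = [1.0, -2.0, 0.5]
weights = [0.8, 0.3, -1.2]

mean_p, var_p = mc_dropout(x, weights, bias=0.1)
# With dropout disabled every pass is identical, so the variance collapses to zero.
det_p, det_var = mc_dropout(x, weights, bias=0.1, dropout_rate=0.0)
print(mean_p, var_p, det_p, det_var)
```

A larger `drop_iterations` gives a steadier variance estimate at the cost of more forward passes.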
TML provides tools for evaluating model performance and calibration:

```python
from tml.analysis import BinaryClassificationAnalysis

# Create analysis object
analysis = BinaryClassificationAnalysis(
    labels=y_test,
    probability_scores=pipeline.probability_scores,
    uncertainty_scores=pipeline.uncertainty_scores,
)

# Metrics
roc_auc = analysis.calculate_roc_auc()
brier = analysis.calculate_brier_score()
ece = analysis.expected_calibration_error(n_bins=10)
optimal_thresh = analysis.find_optimal_threshold()

# Visualizations
analysis.plot_roc_curve()
analysis.plot_reliability_diagram()
analysis.plot_uncertainty_distribution()
analysis.plot_uncertainty_vs_confidence()
```

For genomics applications (e.g., SNP classification):

```python
from tml.plotting import tml_plots

cutoff = tml_plots(
    final=results_array,     # (n_samples, 2) - [prob_score, uncertainty]
    neg_ind=negative_indices,
    hpos_ind=positive_indices,
    minScore=0.5,
    auc_cf=0.9,
    tpr_cf=0.95,
    out="output_prefix",
)
```

```
tml/
├── tml/
│   ├── pipeline.py          # Main Pipeline and ModelHandler classes
│   ├── tml_dataset.py       # TMLDataset, BalancedSampler, prune function
│   ├── analysis.py          # BinaryClassificationAnalysis metrics & plots
│   ├── plotting.py          # Domain-specific visualization functions
│   └── utils.py             # Helper functions
├── models/
│   ├── mnist.py             # Example CNN for MNIST binary classification
│   └── model.py             # Generic binary classification models
├── notebooks/
│   ├── MNIST_MLP_0_7.ipynb  # Usage example
│   ├── MNIST_MLP_0_7_test_model_performance.ipynb
│   └── MNIST_MLP_1_7.ipynb
├── environment.yaml         # Conda environment specification
├── pyproject.toml           # Package configuration
└── README.md
```

| Parameter | Type | Default | Description |
|---|---|---|---|
| `model_handler` | `ModelHandler` | required | Wrapper for your PyTorch Lightning model |
| `data` | `Tensor` | required | Input features |
| `hard_targets` | `Tensor` | required | Binary labels (0 or 1) |
| `batch_size` | `int` | 64 | Training batch size |
| `max_epochs` | `int` | 1 | Epochs per training phase |
| `learning_rate` | `float` | 1e-3 | Optimizer learning rate |
| `lower_threshold` | `float` | 0.3 | Pruning threshold for positive class |
| `upper_threshold` | `float` | 0.7 | Pruning threshold for negative class |
| `drop_iterations` | `int` | 2 | MC Dropout forward passes |
| `seed` | `int` | 42 | Random seed for reproducibility |

TML works with any PyTorch Lightning model. Requirements:
- Must include `nn.Dropout` layers for MC Dropout to work
- Output should be a probability (use `nn.Sigmoid()` for binary classification)
- Model class should be re-instantiable

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl

from tml.pipeline import ModelHandler


class MyBinaryClassifier(pl.LightningModule):
    def __init__(self, input_dim, dropout_rate=0.5, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # Drop only the trailing dimension so a batch of size 1 keeps its shape
        y_pred = self(x).squeeze(-1)
        loss = nn.BCELoss()(y_pred, y.float())
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)


# Use with TML
model = MyBinaryClassifier(input_dim=100)
handler = ModelHandler(model_instance=model)
```

See the notebooks directory for complete examples:

- MNIST 0 vs 7 Classification: Binary digit classification with CNN
- Model Performance Testing: Evaluation and analysis workflows

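As background for the evaluation workflows, the two calibration metrics named among the analysis tools have short textbook definitions. The sketch below uses the standard formulas with equal-width bins; it is an assumption that `BinaryClassificationAnalysis` follows the same binning convention, and the function bodies here are not taken from TML.

```python
def brier_score(labels, probs):
    """Mean squared difference between predicted probability and the 0/1 label."""
    return sum((p - y) ** 2 for y, p in zip(labels, probs)) / len(labels)


def expected_calibration_error(labels, probs, n_bins=10):
    """Weighted average gap between observed positive rate and mean prediction per bin."""
    bins = [[] for _ in range(n_bins)]
    for y, p in zip(labels, probs):
        idx = min(int(p * n_bins), n_bins - 1)  # p == 1.0 falls into the last bin
        bins[idx].append((y, p))

    total = len(labels)
    ece = 0.0
    for members in bins:
        if not members:
            continue  # empty bins contribute nothing
        positive_rate = sum(y for y, _ in members) / len(members)
        mean_prob = sum(p for _, p in members) / len(members)
        ece += len(members) / total * abs(positive_rate - mean_prob)
    return ece


labels = [1, 0, 1, 0]
probs = [0.5, 0.5, 0.5, 0.5]
print(brier_score(labels, probs), expected_calibration_error(labels, probs))
```

Predicting 0.5 for a half-positive set is perfectly calibrated, so the ECE is zero even though the Brier score is not: calibration and sharpness are separate properties.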
This project is licensed under the GNU General Public License v3.0; see the LICENSE file for details.

Ehsan Karimiara: e.karimiara@gmail.com

If you use TML in your research, please cite:

```bibtex
@software{tml2024,
  author = {Karimiara, Ehsan},
  title = {TML: Transductive Machine Learning},
  year = {2024},
  url = {https://github.com/EhsanKA/tml}
}
```

Built with PyTorch Lightning ⚡