TML: Transductive Machine Learning

A Python package for dropout-based uncertainty quantification and dataset pruning in binary classification tasks. TML implements a two-level training pipeline that uses Monte Carlo Dropout to produce reliable probability scores with associated uncertainty estimates.

✨ Key Features

  • Two-Level Training Pipeline: First prunes unreliable samples, then trains on high-confidence data
  • Monte Carlo Dropout: Uncertainty estimation through stochastic forward passes at inference time
  • Balanced Sampling: Automatic handling of class imbalance during training
  • Custom Architectures: Bring your own PyTorch Lightning models
  • Analysis Tools: Built-in metrics (ROC-AUC, Brier score, calibration) and visualizations
  • GPU Acceleration: Seamless CUDA support via PyTorch Lightning

📦 Installation

From Source (Recommended)

git clone https://github.com/EhsanKA/tml.git
cd tml
conda env create --file environment.yaml
conda activate tml
pip install .

Dependencies

The conda environment includes:

| Package | Purpose |
|---|---|
| PyTorch | Deep learning framework |
| PyTorch Lightning | Training orchestration |
| scikit-learn | Metrics and evaluation |
| pandas | Data manipulation |
| matplotlib / seaborn | Visualization |
| tensorboard | Training logging |

🚀 Quick Start

import torch
from tml.pipeline import Pipeline, ModelHandler
from models.mnist import CNNBinaryMNISTClassifier

# Prepare your data (must be torch tensors with binary labels 0/1)
X_train = ...  # Shape: (n_samples, *input_dims)
y_train = ...  # Shape: (n_samples,) with values {0, 1}

# Initialize your model
model = CNNBinaryMNISTClassifier(learning_rate=1e-3, dropout_rate=0.5)
model_handler = ModelHandler(model_instance=model)

# Create and run the pipeline
pipeline = Pipeline(
    model_handler=model_handler,
    data=X_train,
    hard_targets=y_train,
    batch_size=64,
    max_epochs=10,
    lower_threshold=0.3,
    upper_threshold=0.7,
    drop_iterations=10,
    seed=42
)

# Run n_steps iterations of the two-level training
pipeline.run(n_steps=5)

# Access results
probability_scores = pipeline.probability_scores  # Mean predictions
uncertainty_scores = pipeline.uncertainty_scores  # Prediction variance
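
The Quick Start leaves `X_train` and `y_train` as placeholders. As a rough illustration of how a multi-class dataset could be reduced to the binary 0/1 format described above (for example MNIST digits 0 vs 7, as in the notebooks), here is a minimal sketch. The helper `make_binary_subset` is hypothetical and not part of TML, and the choice of integer labels is an assumption; check what your model's loss expects.

```python
import torch

def make_binary_subset(features, labels, negative_class, positive_class):
    """Keep only two classes and relabel them as 0 (negative) and 1 (positive).

    Hypothetical helper for illustration; it is not part of the TML package.
    """
    keep = (labels == negative_class) | (labels == positive_class)
    X = features[keep].float()
    y = (labels[keep] == positive_class).long()
    return X, y

# Toy stand-in for an image dataset: 8 samples of shape (1, 4, 4), classes in 0..9.
features = torch.arange(8 * 16).reshape(8, 1, 4, 4)
labels = torch.tensor([0, 7, 3, 7, 0, 1, 7, 9])

X_train, y_train = make_binary_subset(features, labels, negative_class=0, positive_class=7)
print(tuple(X_train.shape))  # (5, 1, 4, 4)
print(y_train.tolist())      # [0, 1, 1, 0, 1]
```

The resulting tensors have the shapes the Quick Start comments ask for: `(n_samples, *input_dims)` for the data and `(n_samples,)` for the labels.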

🔬 How It Works

TML implements a two-level training strategy designed to improve prediction reliability:

Level 1: Initial Training & Pruning

  1. Balanced Sampling: Creates a balanced subset from imbalanced data
  2. Initial Training: Trains the model on the balanced subset
  3. Prediction: Generates predictions on the full dataset
  4. Pruning: Retains the high-confidence predictions and discards the rest:
    • True Positives: samples with label=1 and prediction > lower_threshold
    • True Negatives: samples with label=0 and prediction < upper_threshold
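
The sampling and pruning steps above can be sketched in plain PyTorch. This is an illustration of the stated rules only, not the package's own `BalancedSampler` or prune function from `tml_dataset.py`. The names `balanced_indices` and `keep_mask` are hypothetical, and balancing by undersampling the majority class is an assumption; the package may balance differently.

```python
import torch

def balanced_indices(labels, generator=None):
    """Undersample the majority class so both classes contribute equally.

    Assumption: balancing is done by undersampling; TML may use another scheme.
    """
    pos = torch.nonzero(labels == 1).flatten()
    neg = torch.nonzero(labels == 0).flatten()
    n = min(len(pos), len(neg))
    pos = pos[torch.randperm(len(pos), generator=generator)[:n]]
    neg = neg[torch.randperm(len(neg), generator=generator)[:n]]
    return torch.cat([pos, neg])

def keep_mask(labels, predictions, lower_threshold=0.3, upper_threshold=0.7):
    """Boolean mask of samples that survive pruning under the rules above."""
    keep_pos = (labels == 1) & (predictions > lower_threshold)  # true positives
    keep_neg = (labels == 0) & (predictions < upper_threshold)  # true negatives
    return keep_pos | keep_neg

labels = torch.tensor([1, 1, 0, 0, 0, 0])
predictions = torch.tensor([0.90, 0.20, 0.10, 0.80, 0.65, 0.30])

g = torch.Generator().manual_seed(42)
subset = balanced_indices(labels, generator=g)
print(labels[subset].tolist().count(1), labels[subset].tolist().count(0))  # 2 2

mask = keep_mask(labels, predictions)
print(mask.tolist())  # [True, False, True, False, True, True]
```

With the default thresholds, a positive sample is dropped only when its score is 0.3 or lower, and a negative sample only when its score is 0.7 or higher, so the rule removes samples whose predictions clearly contradict their labels.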

Level 2: Refined Training & Uncertainty Estimation

  1. Refined Training: Re-trains on the pruned (high-confidence) subset
  2. Standard Prediction: Generates probability scores
  3. MC Dropout: Multiple forward passes with dropout enabled to estimate uncertainty

┌─────────────────────────────────────────────────────────────────┐
│                         TML Pipeline                            │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌─────────────┐    ┌─────────────┐    ┌─────────────────────┐  │
│  │  Balanced   │───▶│   Train     │───▶│  Predict & Prune    │  │
│  │  Sampling   │    │   Level 1   │    │  (remove uncertain) │  │
│  └─────────────┘    └─────────────┘    └──────────┬──────────┘  │
│                                                   │             │
│                                                   ▼             │
│  ┌─────────────┐    ┌─────────────┐    ┌─────────────────────┐  │
│  │ Uncertainty │◀───│   Train     │◀───│  High-Confidence    │  │
│  │  (MC Drop)  │    │   Level 2   │    │     Subset          │  │
│  └─────────────┘    └─────────────┘    └─────────────────────┘  │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
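
The MC Dropout step can be illustrated with a short, self-contained sketch in plain PyTorch. It is not TML's internal implementation: `mc_dropout_predict` and `n_passes` are hypothetical names (`n_passes` plays the role of `drop_iterations`), and the sketch assumes the model uses `nn.Dropout` layers and returns one probability per sample.

```python
import torch
import torch.nn as nn

def mc_dropout_predict(model, x, n_passes=10):
    """Return the mean and variance of n_passes stochastic forward passes."""
    model.eval()  # keep layers such as BatchNorm in inference mode
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()  # re-enable dropout only
    with torch.no_grad():
        samples = torch.stack([model(x).squeeze(-1) for _ in range(n_passes)])
    model.eval()  # restore fully deterministic inference
    # Mean over passes is the probability score; variance is the uncertainty.
    return samples.mean(dim=0), samples.var(dim=0, unbiased=False)

torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(4, 16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 1), nn.Sigmoid()
)
x = torch.randn(3, 4)
mean, var = mc_dropout_predict(model, x, n_passes=20)
print(mean.shape, var.shape)  # one mean and one variance per input sample
```

Only `nn.Dropout` modules are switched back to training mode here; variants such as `nn.Dropout2d` are separate classes and would need to be added to the `isinstance` check. A larger number of passes gives a less noisy variance estimate at the cost of proportionally more inference time.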

📊 Analysis & Visualization

TML provides tools for evaluating model performance and calibration:

from tml.analysis import BinaryClassificationAnalysis

# Create analysis object (labels must align one-to-one with the scored samples)
analysis = BinaryClassificationAnalysis(
    labels=y_test,
    probability_scores=pipeline.probability_scores,
    uncertainty_scores=pipeline.uncertainty_scores
)

# Metrics
roc_auc = analysis.calculate_roc_auc()
brier = analysis.calculate_brier_score()
ece = analysis.expected_calibration_error(n_bins=10)
optimal_thresh = analysis.find_optimal_threshold()

# Visualizations
analysis.plot_roc_curve()
analysis.plot_reliability_diagram()
analysis.plot_uncertainty_distribution()
analysis.plot_uncertainty_vs_confidence()
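
For readers unfamiliar with the calibration metrics named above, the following dependency-free sketch shows one common definition of the Brier score and of the expected calibration error (ECE) for binary labels. It is meant only to clarify what the numbers measure; the functions here are illustrative and `BinaryClassificationAnalysis` may bin or weight differently.

```python
def brier_score(labels, probs):
    """Mean squared difference between predicted probability and the 0/1 label."""
    return sum((p - y) ** 2 for y, p in zip(labels, probs)) / len(labels)

def expected_calibration_error(labels, probs, n_bins=10):
    """Bin-size-weighted gap between mean predicted probability and positive rate."""
    bins = [[] for _ in range(n_bins)]
    for y, p in zip(labels, probs):
        idx = min(int(p * n_bins), n_bins - 1)  # p == 1.0 goes in the last bin
        bins[idx].append((y, p))
    ece = 0.0
    for members in bins:
        if not members:
            continue
        positive_rate = sum(y for y, _ in members) / len(members)
        mean_prob = sum(p for _, p in members) / len(members)
        ece += len(members) / len(labels) * abs(positive_rate - mean_prob)
    return ece

labels = [0, 0, 1, 1]
probs = [0.1, 0.4, 0.6, 0.9]
print(round(brier_score(labels, probs), 3))                 # 0.085
print(round(expected_calibration_error(labels, probs), 3))  # 0.25
```

Both metrics are 0 for a perfectly calibrated, perfectly confident model, and lower is better in each case.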

Domain-Specific Plotting

For genomics applications (e.g., SNP classification):

from tml.plotting import tml_plots

cutoff = tml_plots(
    final=results_array,      # (n_samples, 2) - [prob_score, uncertainty]
    neg_ind=negative_indices,
    hpos_ind=positive_indices,
    minScore=0.5,
    auc_cf=0.9,
    tpr_cf=0.95,
    out="output_prefix"
)

🏗️ Project Structure

tml/
├── tml/
│   ├── pipeline.py      # Main Pipeline and ModelHandler classes
│   ├── tml_dataset.py   # TMLDataset, BalancedSampler, prune function
│   ├── analysis.py      # BinaryClassificationAnalysis metrics & plots
│   ├── plotting.py      # Domain-specific visualization functions
│   └── utils.py         # Helper functions
├── models/
│   ├── mnist.py         # Example CNN for MNIST binary classification
│   └── model.py         # Generic binary classification models
├── notebooks/
│   ├── MNIST_MLP_0_7.ipynb                      # Usage example
│   ├── MNIST_MLP_0_7_test_model_performance.ipynb
│   └── MNIST_MLP_1_7.ipynb
├── environment.yaml     # Conda environment specification
├── pyproject.toml       # Package configuration
└── README.md

⚙️ Pipeline Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| model_handler | ModelHandler | required | Wrapper for your PyTorch Lightning model |
| data | Tensor | required | Input features |
| hard_targets | Tensor | required | Binary labels (0 or 1) |
| batch_size | int | 64 | Training batch size |
| max_epochs | int | 1 | Epochs per training phase |
| learning_rate | float | 1e-3 | Optimizer learning rate |
| lower_threshold | float | 0.3 | Pruning threshold for positive class |
| upper_threshold | float | 0.7 | Pruning threshold for negative class |
| drop_iterations | int | 2 | MC Dropout forward passes |
| seed | int | 42 | Random seed for reproducibility |

🎯 Custom Models

TML works with any PyTorch Lightning model. Requirements:

  1. Must include nn.Dropout layers for MC Dropout to work
  2. Output should be probability (use nn.Sigmoid() for binary classification)
  3. Model class should be re-instantiable

import torch
import torch.nn as nn
import pytorch_lightning as pl

from tml.pipeline import ModelHandler

class MyBinaryClassifier(pl.LightningModule):
    def __init__(self, input_dim, dropout_rate=0.5, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        
        self.model = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        # squeeze(-1) keeps the batch dimension even when the batch has one sample
        y_pred = self(x).squeeze(-1)
        loss = nn.BCELoss()(y_pred, y.float())
        self.log('train_loss', loss)
        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

# Use with TML
model = MyBinaryClassifier(input_dim=100)
handler = ModelHandler(model_instance=model)
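
Before handing a custom model to TML, the first two requirements above can be checked mechanically. The sketch below is one way to do that in plain PyTorch; `check_mc_dropout_ready` is a hypothetical helper, not part of the package, and it does not cover the third requirement (re-instantiability).

```python
import torch
import torch.nn as nn

def check_mc_dropout_ready(model, sample_input):
    """Return a list of problems; an empty list means the basic checks passed.

    Hypothetical pre-flight check. Note that it leaves the model in eval mode.
    """
    problems = []
    if not any(isinstance(m, nn.Dropout) for m in model.modules()):
        problems.append("no nn.Dropout layer found")
    model.eval()
    with torch.no_grad():
        out = model(sample_input)
    if out.min() < 0 or out.max() > 1:
        problems.append("outputs fall outside [0, 1]; is a final nn.Sigmoid missing?")
    return problems

good = nn.Sequential(
    nn.Linear(8, 4), nn.ReLU(), nn.Dropout(0.5), nn.Linear(4, 1), nn.Sigmoid()
)

# A deliberately broken model: no dropout, and a constant raw output of 5.0.
bad = nn.Sequential(nn.Linear(8, 1))
with torch.no_grad():
    bad[0].weight.zero_()
    bad[0].bias.fill_(5.0)

x = torch.randn(16, 8)
print(check_mc_dropout_ready(good, x))  # []
print(len(check_mc_dropout_ready(bad, x)))  # 2
```

The range check only inspects one batch, so passing it does not prove the output is a probability; it is a quick guard against forgetting the final sigmoid.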

📓 Examples

See the notebooks directory for complete examples:

  • MNIST 0 vs 7 Classification: Binary digit classification with CNN
  • Model Performance Testing: Evaluation and analysis workflows

📄 License

This project is licensed under the GNU General Public License v3.0; see the LICENSE file for details.

👤 Author

Ehsan Karimiara (e.karimiara@gmail.com)

📚 Citation

If you use TML in your research, please cite:

@software{tml2024,
  author = {Karimiara, Ehsan},
  title = {TML: Transductive Machine Learning},
  year = {2024},
  url = {https://github.com/EhsanKA/tml}
}

Built with PyTorch Lightning ⚡
