diff --git a/bensemble/__init__.py b/bensemble/__init__.py index 0f4c23c..2081d8a 100644 --- a/bensemble/__init__.py +++ b/bensemble/__init__.py @@ -11,13 +11,13 @@ from .methods import ( LaplaceApproximation, - ProbabilisticBackpropagation, + PBPEngine, ) from .search import RandomSearcher, EvolutionarySearcher, SearchSpace __all__ = [ "LaplaceApproximation", - "ProbabilisticBackpropagation", + "PBPEngine", "RandomSearcher", "EvolutionarySearcher", "SearchSpace", diff --git a/bensemble/core/base.py b/bensemble/core/base.py deleted file mode 100644 index 331af6e..0000000 --- a/bensemble/core/base.py +++ /dev/null @@ -1,76 +0,0 @@ -import abc -from typing import Any, Dict, List, Optional, Tuple - -import torch -import torch.nn as nn - - -# TODO: remove this, new class is in layers/base.py - - -class BaseBayesianEnsemble(abc.ABC): - """Базовый класс для всех методов байесовского ансамблирования""" - - """The base class for all Bayesian ensembling methods""" - - def __init__(self, model: nn.Module, **kwargs): - self.model = model - self.is_fitted = False - self.ensemble = [] - - @abc.abstractmethod - def fit( - self, - train_loader: torch.utils.data.DataLoader, - val_loader: Optional[torch.utils.data.DataLoader] = None, - **kwargs, - ) -> Dict[str, List[float]]: - """Обучение ансамбля""" - """Ensemble training""" - ... - - @abc.abstractmethod - def predict( - self, X: torch.Tensor, n_samples: int = 100 - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Предсказание с оценкой неопределенности""" - """Prediction with uncertainty estimation""" - ... - - @abc.abstractmethod - def sample_models(self, n_models: int = 10) -> List[nn.Module]: - """Сэмплирование моделей из апостериорного распределения""" - """Sampling models from a posteriori distribution""" - ... - - def save(self, path: str): - """Сохранение обученного ансамбля""" - """Saving a trained ensemble""" - torch.save( - { - "model_state_dict": self.model.state_dict(), - "ensemble_state": self._get_ensemble_state(), - "is_fitted": self.is_fitted, - }, - path, - ) - - def load(self, path: str): - """Загрузка обученного ансамбля""" - """Loading a trained ensemble""" - checkpoint = torch.load(path) - self.model.load_state_dict(checkpoint["model_state_dict"]) - self._set_ensemble_state(checkpoint["ensemble_state"]) - self.is_fitted = checkpoint["is_fitted"] - - @abc.abstractmethod - def _get_ensemble_state(self) -> Dict[str, Any]: - """Получение внутреннего состояния ансамбля""" - """Getting the internal state of the ensemble""" - ... - - @abc.abstractmethod - def _set_ensemble_state(self, state: Dict[str, Any]): - """Установка внутреннего состояния ансамбля""" - """Setting the internal state of the ensemble""" - ... diff --git a/bensemble/methods/__init__.py b/bensemble/methods/__init__.py index 993fd2c..d90c248 100644 --- a/bensemble/methods/__init__.py +++ b/bensemble/methods/__init__.py @@ -2,11 +2,9 @@ LaplaceApproximation, ) -from .probabilistic_backpropagation import ( - ProbabilisticBackpropagation, -) +from .probabilistic_backpropagation import PBPEngine __all__ = [ "LaplaceApproximation", - "ProbabilisticBackpropagation", + "PBPEngine", ] diff --git a/bensemble/methods/laplace_approximation.py b/bensemble/methods/laplace_approximation.py index 30b5a6e..3dd55c1 100644 --- a/bensemble/methods/laplace_approximation.py +++ b/bensemble/methods/laplace_approximation.py @@ -1,39 +1,28 @@ import copy -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader -from ..core.base import BaseBayesianEnsemble +from bensemble.core.ensemble import Ensemble -class LaplaceApproximation(BaseBayesianEnsemble): +class LaplaceApproximation: """ Kronecker-factored Laplace approximation for neural networks. - - Backwards-compatible with the old API: - fit(...) - compute_posterior(...) - sample_models(n_models=...) - predict(...) - - Also compatible with posterior-source style APIs: - sample_models(n_members=...) """ def __init__( self, model: nn.Module, - pretrained: bool = True, - likelihood: str = "classification", - verbose: bool = False, + likelihood: str = "regression", + prior_precision: float = 1.0, damping: float = 1e-6, regularization: str = "legacy", + verbose: bool = False, ): - super().__init__(model) - if likelihood not in ["classification", "regression"]: raise ValueError(f"Unsupported likelihood: {likelihood}") if regularization not in ["legacy", "paper"]: @@ -44,15 +33,14 @@ def __init__( self.device = next(model.parameters()).device self.likelihood = likelihood - self.pretrained = pretrained - self.verbose = verbose + self.prior_precision = float(prior_precision) self.damping = damping self.regularization = regularization + self.verbose = verbose self.kronecker_factors: Dict[str, Dict[str, torch.Tensor]] = {} self.sampling_factors: Dict[str, Dict[str, Any]] = {} self.dataset_size = 1 - self.prior_precision = 1.0 self.hook_handles = [] self.activations: Dict[str, torch.Tensor] = {} @@ -62,76 +50,24 @@ def toggle_verbose(self): self.verbose = not self.verbose print("Verbose:", "on" if self.verbose else "off") - def fit( - self, - train_loader: DataLoader, - val_loader: Optional[DataLoader] = None, - num_epochs: int = 100, - lr: float = 1e-3, - prior_precision: float = 1.0, - num_samples: int = 1000, - ) -> Dict[str, List[float]]: - history: Dict[str, List[float]] = {} - - if not self.pretrained: - if self.verbose: - print("Training model...") - - optimizer = torch.optim.Adam(self.model.parameters(), lr=lr) - history["train_loss"] = [] - - for epoch in range(num_epochs): - self.model.train() - train_loss = 0.0 - - for batch_x, batch_y in train_loader: - batch_x = batch_x.to(self.device) - batch_y = batch_y.to(self.device) - - optimizer.zero_grad(set_to_none=True) - output = self.model(batch_x) - - if self.likelihood == "classification": - loss = F.cross_entropy(output, batch_y.long()) - else: - batch_y = batch_y.to(output.dtype) - if batch_y.shape != output.shape: - batch_y = batch_y.view_as(output) - loss = F.mse_loss(output, batch_y) - - loss.backward() - optimizer.step() - train_loss += loss.item() - - total_loss = train_loss / max(1, len(train_loader)) - history["train_loss"].append(total_loss) - - if self.verbose: - print(f"Epoch {epoch}: training loss = {total_loss:.4f}") - - self.pretrained = True - - self.compute_posterior(train_loader, prior_precision, num_samples) - return history - - def compute_posterior( + def compute_curvature( self, train_loader: DataLoader, - prior_precision: float = 1.0, num_samples: int = 1000, ) -> None: - self.prior_precision = float(prior_precision) + """ + Estimates the Kronecker factors of the Hessian using the training data. + This must be called before sampling models. + """ self.dataset_size = len(train_loader.dataset) if self.verbose: print("Registering hooks...") - self._register_hooks() try: if self.verbose: print("Estimating Kronecker factors...") - self._estimate_kronecker_factors(train_loader, num_samples) finally: if self.verbose: @@ -141,7 +77,7 @@ def compute_posterior( self.is_fitted = True if self.verbose: - print("Posterior computation completed!") + print("Curvature computation completed!") def _register_hooks(self) -> None: self._remove_hooks() @@ -154,6 +90,7 @@ def forward_hook(module, inputs, output): if x.dim() > 2: x = x.flatten(start_dim=1) self.activations[layer_name] = x.detach() + return forward_hook for name, module in self.model.named_modules(): @@ -202,9 +139,7 @@ def _backward_hessian( return hessians current_hessian = ( - hessian_final.mean(0) - if hessian_final.dim() == 3 - else hessian_final + hessian_final.mean(0) if hessian_final.dim() == 3 else hessian_final ) final_layer_name, _ = linear_layers[-1] @@ -232,9 +167,6 @@ def _estimate_kronecker_factors( accumulators: Dict[str, Dict[str, Any]] = {} sample_count = 0 - if self.verbose: - print(f"Processing up to {num_samples} samples...") - for batch_idx, (data, target) in enumerate(train_loader): if sample_count >= num_samples: break @@ -269,16 +201,10 @@ def _estimate_kronecker_factors( accumulators[name] = { "Q_sum": torch.zeros( - in_dim, - in_dim, - device=self.device, - dtype=a.dtype, + in_dim, in_dim, device=self.device, dtype=a.dtype ), "H_sum": torch.zeros( - out_dim, - out_dim, - device=self.device, - dtype=h.dtype, + out_dim, out_dim, device=self.device, dtype=h.dtype ), "count": 0, "in_dim": in_dim, @@ -297,9 +223,6 @@ def _estimate_kronecker_factors( if sample_count % 1000 == 0 and self.verbose: print(f"Processed {sample_count} samples...") - if self.verbose: - print("Computing final Kronecker factors...") - with torch.no_grad(): self.kronecker_factors.clear() self.sampling_factors.clear() @@ -324,8 +247,8 @@ def _estimate_kronecker_factors( q_reg = n * q + tau * eye_q h_reg = n * h + tau * eye_h else: - q_reg = (n ** 0.5) * q + (tau ** 0.5) * eye_q - h_reg = (n ** 0.5) * h + (tau ** 0.5) * eye_h + q_reg = (n**0.5) * q + (tau**0.5) * eye_q + h_reg = (n**0.5) * h + (tau**0.5) * eye_h q_reg = self._stabilize_spd(q_reg) h_reg = self._stabilize_spd(h_reg) @@ -335,14 +258,6 @@ def _estimate_kronecker_factors( "H": h_reg.detach().clone(), } - if self.verbose: - print(f"Layer {name}:") - print(f" Q shape: {q_reg.shape}, H shape: {h_reg.shape}") - print(f" Q norm: {torch.norm(q_reg).item():.6f}") - print(f" H norm: {torch.norm(h_reg).item():.6f}") - print(f" cond(Q): {torch.linalg.cond(q_reg).item():.2e}") - print(f" cond(H): {torch.linalg.cond(h_reg).item():.2e}") - q_cov = torch.linalg.inv(q_reg) h_cov = torch.linalg.inv(h_reg) @@ -362,16 +277,15 @@ def _matrix_sqrt(self, matrix: torch.Tensor) -> torch.Tensor: return eigvecs @ torch.diag(torch.sqrt(eigvals)) @ eigvecs.T def sample_models( - self, - n_models: int = 10, - temperature: float = 1.0, - n_members: Optional[int] = None, + self, n_models: int = 10, temperature: float = 1.0 ) -> List[nn.Module]: - if n_members is not None: - n_models = n_members - + """ + Samples models from the approximated posterior. + """ if not self.is_fitted: - raise RuntimeError("LaplaceApproximation must be fitted before sampling.") + raise RuntimeError( + "Laplace curvature not computed. Call compute_curvature() first." + ) samples = [] modules = dict(self.model.named_modules()) @@ -387,7 +301,9 @@ def sample_models( l_h = factors["L_V"].to(device=self.device, dtype=mean_weight.dtype) weight_shape = factors["weight_shape"] - z = torch.randn(weight_shape, device=self.device, dtype=mean_weight.dtype) + z = torch.randn( + weight_shape, device=self.device, dtype=mean_weight.dtype + ) sampled_weight = mean_weight + temperature * (l_h @ z @ l_q.T) sampled_state[f"{name}.weight"] = sampled_weight.detach().cpu() @@ -403,75 +319,28 @@ def sample_models( return samples - def predict( - self, - X: torch.Tensor, - n_samples: int = 100, - temperature: float = 1.0, - ) -> Tuple[torch.Tensor, torch.Tensor]: - X = X.to(self.device) - predictions = [] - - for sampled_model in self.sample_models( - n_models=n_samples, - temperature=temperature, - ): - sampled_model.eval() - with torch.no_grad(): - predictions.append(sampled_model(X)) - - predictions = torch.stack(predictions, dim=0) - - if self.likelihood == "classification": - probs = F.softmax(predictions, dim=-1) - mean_probs = probs.mean(dim=0) - uncertainty = -( - mean_probs * torch.log(mean_probs.clamp_min(1e-8)) - ).sum(dim=-1) - return mean_probs, uncertainty - - mean = predictions.mean(dim=0) - variance = predictions.var(dim=0, unbiased=False) - return mean, variance - - def build_ensemble(self, n_members: int = 10, **kwargs): - from ..core.ensemble import Ensemble - - return Ensemble.from_posterior(self, n_members=n_members, **kwargs) + def build_ensemble(self, n_members: int = 10, temperature: float = 1.0) -> Ensemble: + return Ensemble.from_posterior( + self, n_members=n_members, temperature=temperature + ) def _get_ensemble_state(self) -> Dict[str, Any]: return { - "model": self.model, + "model_state": self.model.state_dict(), "is_fitted": self.is_fitted, - "device": self.device, "likelihood": self.likelihood, - "pretrained": self.pretrained, - "kronecker_factors": self.kronecker_factors, "sampling_factors": self.sampling_factors, "dataset_size": self.dataset_size, "prior_precision": self.prior_precision, - "damping": self.damping, - "regularization": self.regularization, - "verbose": self.verbose, } def _set_ensemble_state(self, state: Dict[str, Any]): - if state["likelihood"] in ["classification", "regression"]: - self.likelihood = state["likelihood"] - else: - raise ValueError(f"Unsupported likelihood: {state['likelihood']}") - - self.model = state["model"] + self.model.load_state_dict(state["model_state"]) self.is_fitted = state["is_fitted"] - self.device = state["device"] - self.pretrained = state["pretrained"] - self.kronecker_factors = state["kronecker_factors"] + self.likelihood = state["likelihood"] self.sampling_factors = state["sampling_factors"] self.dataset_size = state["dataset_size"] - self.prior_precision = state.get("prior_precision", 1.0) - self.damping = state.get("damping", 1e-6) - self.regularization = state.get("regularization", "legacy") - self.verbose = state.get("verbose", False) + self.prior_precision = state["prior_precision"] self.hook_handles = [] def _stabilize_spd(self, matrix: torch.Tensor) -> torch.Tensor: @@ -480,4 +349,4 @@ def _stabilize_spd(self, matrix: torch.Tensor) -> torch.Tensor: @staticmethod def _symmetrize(matrix: torch.Tensor) -> torch.Tensor: - return 0.5 * (matrix + matrix.T) \ No newline at end of file + return 0.5 * (matrix + matrix.T) diff --git a/bensemble/methods/probabilistic_backpropagation.py b/bensemble/methods/probabilistic_backpropagation.py index 6412b50..9ac8264 100644 --- a/bensemble/methods/probabilistic_backpropagation.py +++ b/bensemble/methods/probabilistic_backpropagation.py @@ -3,8 +3,9 @@ import torch import torch.nn as nn +from torch.utils.data import DataLoader -from ..core.base import BaseBayesianEnsemble +from bensemble.core.ensemble import Ensemble from ..utils import standard_normal_cdf, standard_normal_pdf @@ -118,9 +119,9 @@ def forward_moments(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: return mz, vz -class ProbabilisticBackpropagation(BaseBayesianEnsemble): +class PBPEngine: """ - Probabilistic Backpropagation (PBP) for Bayesian regression with moment matching. + Probabilistic Backpropagation (PBP) Engine for Bayesian regression. """ def __init__( @@ -136,19 +137,18 @@ def __init__( ): if model is None: if layer_sizes is None: - raise ValueError( - "Specify either a ready PBP model or layer_sizes to construct it." - ) + raise ValueError("Specify either a ready PBP model or layer_sizes.") model = PBPNet(layer_sizes, dtype=dtype, device=device) + self.device = device or torch.device("cpu") self.dtype = dtype - super().__init__(model.to(self.device)) + self.model = model.to(self.device) + self.is_fitted = False self.alpha_g = torch.tensor(noise_alpha, dtype=self.dtype, device=self.device) self.beta_g = torch.tensor(noise_beta, dtype=self.dtype, device=self.device) self.alpha_l = torch.tensor(weight_alpha, dtype=self.dtype, device=self.device) self.beta_l = torch.tensor(weight_beta, dtype=self.dtype, device=self.device) - # Initialize weights to match the stated Gaussian prior (before 1/sqrt(d) scaling). self._init_from_prior() def _init_from_prior(self) -> None: @@ -365,8 +365,8 @@ def _collect_dataset( def fit( self, - train_loader: torch.utils.data.DataLoader, - val_loader: Optional[torch.utils.data.DataLoader] = None, + train_loader: DataLoader, + val_loader: Optional[DataLoader] = None, num_epochs: int = 100, step_clip: Optional[float] = 2.0, prior_refresh: int = 1, @@ -453,26 +453,14 @@ def _predictive_mean_var( var = torch.clamp(noise_var + vz, min=1e-12) return mz, var - def predict( - self, X: torch.Tensor, n_samples: int = 100 - ) -> Tuple[torch.Tensor, torch.Tensor]: - if not self.is_fitted: - raise RuntimeError("Model not fitted. Call fit() first.") - - mean, var = self._predictive_mean_var( - X.to(device=self.device, dtype=self.dtype) - ) - std = torch.sqrt(torch.clamp(var, min=1e-12)) - samples = mean.unsqueeze(0) + std.unsqueeze(0) * torch.randn( - (n_samples,) + mean.shape, device=self.device, dtype=self.dtype - ) - return mean.detach(), samples.detach() - def noise_variance(self) -> torch.Tensor: alpha = torch.clamp(self.alpha_g, min=1.0 + 1e-6) return self.beta_g / (alpha - 1.0) - def sample_models(self, n_models: int = 10) -> List[nn.Module]: + def sample_models(self, n_models: int = 10, **kwargs) -> List[nn.Module]: + if not self.is_fitted: + raise RuntimeError("PBPEngine not fitted. Call fit() first.") + models = [] for _ in range(n_models): model_copy = self._sample_single_model() @@ -504,6 +492,9 @@ def _sample_single_model(self) -> nn.Module: model.eval() return model + def build_ensemble(self, n_members: int = 10) -> Ensemble: + return Ensemble.from_posterior(self, n_members=n_members) + def _get_ensemble_state(self) -> Dict[str, Any]: return { "alpha_g": self.alpha_g, diff --git a/tests/methods/test_laplace_approximation.py b/tests/methods/test_laplace_approximation.py index f71408e..9d7d648 100644 --- a/tests/methods/test_laplace_approximation.py +++ b/tests/methods/test_laplace_approximation.py @@ -3,7 +3,8 @@ import torch.nn as nn from torch.utils.data import TensorDataset, DataLoader -from bensemble.methods import LaplaceApproximation +from bensemble.methods.laplace_approximation import LaplaceApproximation +from bensemble.core.ensemble import Ensemble @pytest.fixture @@ -74,24 +75,18 @@ def test_verbose_toggle(regression_setup): assert laplace.verbose is False -def test_fit_regression_no_pretrained(regression_setup): +def test_compute_curvature_regression(regression_setup): """ - Tests full pipeline for regression (MSE): - 1. Training from scratch (pretrained=False). - 2. Computing Laplace factors (K-FAC). + Tests full pipeline for regression: """ model, loader, _, _ = regression_setup - laplace = LaplaceApproximation( - model, pretrained=False, likelihood="regression", verbose=True - ) - - history = laplace.fit(loader, num_epochs=1, num_samples=10) + laplace = LaplaceApproximation(model, likelihood="regression", verbose=True) - assert "train_loss" in history - assert len(history["train_loss"]) == 1 - assert laplace.pretrained is True + # Use the new method name + laplace.compute_curvature(loader, num_samples=10) + assert laplace.is_fitted is True assert len(laplace.kronecker_factors) == 2 assert len(laplace.sampling_factors) == 2 @@ -101,19 +96,15 @@ def test_fit_regression_no_pretrained(regression_setup): assert "weight_shape" in factors -def test_fit_classification_pretrained(classification_setup): +def test_compute_curvature_classification(classification_setup): """ - Tests pipeline for classification (CrossEntropy): - 1. Skips MAP training (pretrained=True). - 2. Computes factors using CrossEntropy Hessian approximation. + Tests pipeline for classification: """ model, loader, _, _ = classification_setup - laplace = LaplaceApproximation( - model, pretrained=True, likelihood="classification", verbose=False - ) + laplace = LaplaceApproximation(model, likelihood="classification", verbose=False) - laplace.fit(loader, num_samples=10) + laplace.compute_curvature(loader, num_samples=10) assert len(laplace.kronecker_factors) > 0 for name, factors in laplace.kronecker_factors.items(): @@ -121,51 +112,49 @@ def test_fit_classification_pretrained(classification_setup): assert not torch.isnan(factors["H"]).any() -def test_predict_regression_shapes(regression_setup): +def test_ensemble_integration_regression(regression_setup): """ - Tests prediction output shapes and values for regression. + Tests integration for regression. """ model, loader, X, y = regression_setup laplace = LaplaceApproximation(model, likelihood="regression") - laplace.fit(loader, num_epochs=0, num_samples=10) # Fast fit + laplace.compute_curvature(loader, num_samples=10) - mean, var = laplace.predict(X, n_samples=5) + # Use the new API + ensemble = Ensemble.from_posterior(laplace, n_members=5) - assert mean.shape == y.shape - assert var.shape == y.shape + # Predict members directly + with torch.no_grad(): + member_preds = ensemble.predict_members(X) - assert (var >= 0).all() + assert member_preds.shape == (5, 20, 1) -def test_predict_classification_shapes(classification_setup): +def test_ensemble_integration_classification(classification_setup): """ - Tests prediction output shapes for classification. - Expects probabilities and entropy/uncertainty. + Tests integration for classification. """ model, loader, X, _ = classification_setup laplace = LaplaceApproximation(model, likelihood="classification") - laplace.fit(loader, num_epochs=0, num_samples=10) + laplace.compute_curvature(loader, num_samples=10) - probs, uncertainty = laplace.predict(X, n_samples=5) + ensemble = Ensemble.from_posterior(laplace, n_members=5) - assert probs.shape == (20, 3) + with torch.no_grad(): + member_preds = ensemble.predict_members(X) - sums = probs.sum(dim=1) - assert torch.allclose(sums, torch.ones_like(sums)) - - assert uncertainty.shape == (20,) - assert (uncertainty >= 0).all() + assert member_preds.shape == (5, 20, 3) def test_sample_models_diversity(regression_setup): """ Ensures that sampled models are: 1. Valid nn.Modules. - 2. Have different weights (stochasticity). + 2. Have different weights. """ model, loader, _, _ = regression_setup laplace = LaplaceApproximation(model, likelihood="regression") - laplace.fit(loader, num_epochs=0, num_samples=10) + laplace.compute_curvature(loader, num_samples=10) samples = laplace.sample_models(n_models=2) assert len(samples) == 2 @@ -180,15 +169,14 @@ def test_sample_models_diversity(regression_setup): def test_hooks_cleanup(regression_setup): """ - Ensures that PyTorch forward hooks are removed after fitting. - Leaving hooks can cause memory leaks or unexpected behavior. + Ensures that PyTorch forward hooks are removed after computing curvature. """ model, loader, _, _ = regression_setup laplace = LaplaceApproximation(model, likelihood="regression") assert len(laplace.hook_handles) == 0 - laplace.fit(loader, num_samples=5) + laplace.compute_curvature(loader, num_samples=5) assert len(laplace.hook_handles) == 0 @@ -196,22 +184,20 @@ def test_hooks_cleanup(regression_setup): def test_state_management(regression_setup): """ Tests _get_ensemble_state and _set_ensemble_state. - Assumes bugs with empty keys have been fixed. """ model, loader, _, _ = regression_setup laplace = LaplaceApproximation(model, likelihood="regression") - laplace.fit(loader, num_epochs=0, num_samples=5) + laplace.compute_curvature(loader, num_samples=5) state = laplace._get_ensemble_state() assert state["likelihood"] == "regression" - required_keys = ["kronecker_factors", "sampling_factors", "dataset_size"] + required_keys = ["sampling_factors", "dataset_size"] for key in required_keys: assert key in state new_laplace = LaplaceApproximation(model, likelihood="regression") - new_laplace._set_ensemble_state(state) assert new_laplace.dataset_size == laplace.dataset_size diff --git a/tests/methods/test_probabilistic_backpropagation.py b/tests/methods/test_probabilistic_backpropagation.py index 2408eed..623479c 100644 --- a/tests/methods/test_probabilistic_backpropagation.py +++ b/tests/methods/test_probabilistic_backpropagation.py @@ -3,14 +3,15 @@ import torch.nn as nn from torch.utils.data import TensorDataset, DataLoader - from bensemble.methods.probabilistic_backpropagation import ( - ProbabilisticBackpropagation, + PBPEngine, PBPNet, ProbLinear, relu_moments, ) +from bensemble.core.ensemble import Ensemble + @pytest.fixture def pbp_data(): @@ -67,11 +68,11 @@ def test_pbp_net_forward(): def test_initialization(pbp_model_setup): """Tests initialization of the PBP wrapper.""" - pbp = ProbabilisticBackpropagation(layer_sizes=pbp_model_setup) + pbp = PBPEngine(layer_sizes=pbp_model_setup) assert isinstance(pbp.model, PBPNet) with pytest.raises(ValueError, match="Specify either"): - ProbabilisticBackpropagation(model=None, layer_sizes=None) + PBPEngine(model=None, layer_sizes=None) def test_fit_loop(pbp_data, pbp_model_setup): @@ -80,7 +81,7 @@ def test_fit_loop(pbp_data, pbp_model_setup): Checks if parameters update and history is returned. """ X, y, loader = pbp_data - pbp = ProbabilisticBackpropagation(layer_sizes=pbp_model_setup, dtype=torch.float64) + pbp = PBPEngine(layer_sizes=pbp_model_setup, dtype=torch.float64) alpha_old = pbp.alpha_g.item() beta_old = pbp.beta_g.item() @@ -96,32 +97,30 @@ def test_fit_loop(pbp_data, pbp_model_setup): def test_prior_refresh(pbp_data, pbp_model_setup): """Tests the prior refresh mechanism (updating alpha_l, beta_l).""" _, _, loader = pbp_data - pbp = ProbabilisticBackpropagation(layer_sizes=pbp_model_setup, dtype=torch.float64) + pbp = PBPEngine(layer_sizes=pbp_model_setup, dtype=torch.float64) pbp.fit(loader, num_epochs=1, prior_refresh=1) assert isinstance(pbp.alpha_l, torch.Tensor) -def test_predict(pbp_data, pbp_model_setup): - """Tests prediction output (mean and samples).""" +def test_ensemble_integration(pbp_data, pbp_model_setup): + """Tests integration with the new Ensemble API.""" X, y, loader = pbp_data - pbp = ProbabilisticBackpropagation(layer_sizes=pbp_model_setup, dtype=torch.float64) + pbp = PBPEngine(layer_sizes=pbp_model_setup, dtype=torch.float64) pbp.fit(loader, num_epochs=1) - mean, samples = pbp.predict(X, n_samples=5) - - assert mean.shape == y.shape + ensemble = Ensemble.from_posterior(pbp, n_members=5) - assert samples.shape == (5,) + y.shape + with torch.no_grad(): + member_preds = ensemble.predict_members(X) - noise_var = pbp.noise_variance() - assert noise_var > 0 + assert member_preds.shape == (5, 20, 1) def test_sample_models(pbp_data, pbp_model_setup): """Tests sampling of PyTorch models from PBP posterior.""" X, y, loader = pbp_data - pbp = ProbabilisticBackpropagation(layer_sizes=pbp_model_setup, dtype=torch.float64) + pbp = PBPEngine(layer_sizes=pbp_model_setup, dtype=torch.float64) pbp.fit(loader, num_epochs=1) models = pbp.sample_models(n_models=2) @@ -136,7 +135,7 @@ def test_sample_models(pbp_data, pbp_model_setup): def test_state_management(pbp_model_setup): """Tests saving and loading ensemble state.""" - pbp = ProbabilisticBackpropagation(layer_sizes=pbp_model_setup, dtype=torch.float64) + pbp = PBPEngine(layer_sizes=pbp_model_setup, dtype=torch.float64) state = pbp._get_ensemble_state() @@ -152,7 +151,7 @@ def test_state_management(pbp_model_setup): def test_val_loader(pbp_data, pbp_model_setup): """Tests fit with validation loader.""" X, y, loader = pbp_data - pbp = ProbabilisticBackpropagation(layer_sizes=pbp_model_setup, dtype=torch.float64) + pbp = PBPEngine(layer_sizes=pbp_model_setup, dtype=torch.float64) history = pbp.fit(loader, val_loader=loader, num_epochs=1) @@ -161,11 +160,10 @@ def test_val_loader(pbp_data, pbp_model_setup): def test_predict_not_fitted(pbp_model_setup): - """Ensures predict raises error if not fitted.""" - pbp = ProbabilisticBackpropagation(layer_sizes=pbp_model_setup) - X = torch.randn(5, 1) - with pytest.raises(RuntimeError, match="Model not fitted"): - pbp.predict(X) + """Ensures sample_models raises error if not fitted.""" + pbp = PBPEngine(layer_sizes=pbp_model_setup) + with pytest.raises(RuntimeError, match="PBPEngine not fitted"): + pbp.sample_models(n_models=5) def test_relu_moments_positive_mean(): diff --git a/uv.lock b/uv.lock index 30dd145..fc1b723 100644 --- a/uv.lock +++ b/uv.lock @@ -54,13 +54,12 @@ wheels = [ [[package]] name = "bensemble" -version = "0.1.0" +version = "0.2.1" source = { editable = "." } dependencies = [ { name = "matplotlib" }, { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "scikit-learn" }, { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "scipy", version = "1.16.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "torch" }, @@ -68,9 +67,6 @@ dependencies = [ [package.optional-dependencies] dev = [ - { name = "autopep8" }, - { name = "black" }, - { name = "isort" }, { name = "mypy" }, { name = "nbqa" }, { name = "pytest" }, @@ -81,13 +77,12 @@ docs = [ { name = "mkdocs-material" }, { name = "mkdocs-minify-plugin" }, { name = "mkdocstrings", extra = ["python"] }, + { name = "pymdown-extensions" }, + { name = "python-markdown-math" }, ] [package.metadata] requires-dist = [ - { name = "autopep8", marker = "extra == 'dev'" }, - { name = "black", marker = "extra == 'dev'" }, - { name = "isort", marker = "extra == 'dev'" }, { name = "matplotlib", specifier = ">=3.10.7" }, { name = "mkdocs-material", marker = "extra == 'docs'" }, { name = "mkdocs-minify-plugin", marker = "extra == 'docs'" }, @@ -95,54 +90,16 @@ requires-dist = [ { name = "mypy", marker = "extra == 'dev'" }, { name = "nbqa", marker = "extra == 'dev'" }, { name = "numpy", specifier = ">=2.2.6" }, + { name = "pymdown-extensions", marker = "extra == 'docs'" }, { name = "pytest", marker = "extra == 'dev'" }, { name = "pytest-cov", marker = "extra == 'dev'" }, + { name = "python-markdown-math", marker = "extra == 'docs'" }, { name = "ruff", marker = "extra == 'dev'" }, - { name = "scikit-learn", specifier = ">=1.7.2" }, { name = "scipy", specifier = ">=1.15.3" }, { name = "torch", specifier = ">=2.9.1" }, ] provides-extras = ["dev", "docs"] -[[package]] -name = "black" -version = "25.11.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "click" }, - { name = "mypy-extensions" }, - { name = "packaging" }, - { name = "pathspec" }, - { name = "platformdirs" }, - { name = "pytokens" }, - { name = "tomli", marker = "python_full_version < '3.11'" }, - { name = "typing-extensions", marker = "python_full_version < '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/8c/ad/33adf4708633d047950ff2dfdea2e215d84ac50ef95aff14a614e4b6e9b2/black-25.11.0.tar.gz", hash = "sha256:9a323ac32f5dc75ce7470501b887250be5005a01602e931a15e45593f70f6e08", size = 655669, upload-time = "2025-11-10T01:53:50.558Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b3/d2/6caccbc96f9311e8ec3378c296d4f4809429c43a6cd2394e3c390e86816d/black-25.11.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ec311e22458eec32a807f029b2646f661e6859c3f61bc6d9ffb67958779f392e", size = 1743501, upload-time = "2025-11-10T01:59:06.202Z" }, - { url = "https://files.pythonhosted.org/packages/69/35/b986d57828b3f3dccbf922e2864223197ba32e74c5004264b1c62bc9f04d/black-25.11.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1032639c90208c15711334d681de2e24821af0575573db2810b0763bcd62e0f0", size = 1597308, upload-time = "2025-11-10T01:57:58.633Z" }, - { url = "https://files.pythonhosted.org/packages/39/8e/8b58ef4b37073f52b64a7b2dd8c9a96c84f45d6f47d878d0aa557e9a2d35/black-25.11.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0c0f7c461df55cf32929b002335883946a4893d759f2df343389c4396f3b6b37", size = 1656194, upload-time = "2025-11-10T01:57:10.909Z" }, - { url = "https://files.pythonhosted.org/packages/8d/30/9c2267a7955ecc545306534ab88923769a979ac20a27cf618d370091e5dd/black-25.11.0-cp310-cp310-win_amd64.whl", hash = "sha256:f9786c24d8e9bd5f20dc7a7f0cdd742644656987f6ea6947629306f937726c03", size = 1347996, upload-time = "2025-11-10T01:57:22.391Z" }, - { url = "https://files.pythonhosted.org/packages/c4/62/d304786b75ab0c530b833a89ce7d997924579fb7484ecd9266394903e394/black-25.11.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:895571922a35434a9d8ca67ef926da6bc9ad464522a5fe0db99b394ef1c0675a", size = 1727891, upload-time = "2025-11-10T02:01:40.507Z" }, - { url = "https://files.pythonhosted.org/packages/82/5d/ffe8a006aa522c9e3f430e7b93568a7b2163f4b3f16e8feb6d8c3552761a/black-25.11.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:cb4f4b65d717062191bdec8e4a442539a8ea065e6af1c4f4d36f0cdb5f71e170", size = 1581875, upload-time = "2025-11-10T01:57:51.192Z" }, - { url = "https://files.pythonhosted.org/packages/cb/c8/7c8bda3108d0bb57387ac41b4abb5c08782b26da9f9c4421ef6694dac01a/black-25.11.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d81a44cbc7e4f73a9d6ae449ec2317ad81512d1e7dce7d57f6333fd6259737bc", size = 1642716, upload-time = "2025-11-10T01:56:51.589Z" }, - { url = "https://files.pythonhosted.org/packages/34/b9/f17dea34eecb7cc2609a89627d480fb6caea7b86190708eaa7eb15ed25e7/black-25.11.0-cp311-cp311-win_amd64.whl", hash = "sha256:7eebd4744dfe92ef1ee349dc532defbf012a88b087bb7ddd688ff59a447b080e", size = 1352904, upload-time = "2025-11-10T01:59:26.252Z" }, - { url = "https://files.pythonhosted.org/packages/7f/12/5c35e600b515f35ffd737da7febdb2ab66bb8c24d88560d5e3ef3d28c3fd/black-25.11.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:80e7486ad3535636657aa180ad32a7d67d7c273a80e12f1b4bfa0823d54e8fac", size = 1772831, upload-time = "2025-11-10T02:03:47Z" }, - { url = "https://files.pythonhosted.org/packages/1a/75/b3896bec5a2bb9ed2f989a970ea40e7062f8936f95425879bbe162746fe5/black-25.11.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6cced12b747c4c76bc09b4db057c319d8545307266f41aaee665540bc0e04e96", size = 1608520, upload-time = "2025-11-10T01:58:46.895Z" }, - { url = "https://files.pythonhosted.org/packages/f3/b5/2bfc18330eddbcfb5aab8d2d720663cd410f51b2ed01375f5be3751595b0/black-25.11.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6cb2d54a39e0ef021d6c5eef442e10fd71fcb491be6413d083a320ee768329dd", size = 1682719, upload-time = "2025-11-10T01:56:55.24Z" }, - { url = "https://files.pythonhosted.org/packages/96/fb/f7dc2793a22cdf74a72114b5ed77fe3349a2e09ef34565857a2f917abdf2/black-25.11.0-cp312-cp312-win_amd64.whl", hash = "sha256:ae263af2f496940438e5be1a0c1020e13b09154f3af4df0835ea7f9fe7bfa409", size = 1362684, upload-time = "2025-11-10T01:57:07.639Z" }, - { url = "https://files.pythonhosted.org/packages/ad/47/3378d6a2ddefe18553d1115e36aea98f4a90de53b6a3017ed861ba1bd3bc/black-25.11.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0a1d40348b6621cc20d3d7530a5b8d67e9714906dfd7346338249ad9c6cedf2b", size = 1772446, upload-time = "2025-11-10T02:02:16.181Z" }, - { url = "https://files.pythonhosted.org/packages/ba/4b/0f00bfb3d1f7e05e25bfc7c363f54dc523bb6ba502f98f4ad3acf01ab2e4/black-25.11.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:51c65d7d60bb25429ea2bf0731c32b2a2442eb4bd3b2afcb47830f0b13e58bfd", size = 1607983, upload-time = "2025-11-10T02:02:52.502Z" }, - { url = "https://files.pythonhosted.org/packages/99/fe/49b0768f8c9ae57eb74cc10a1f87b4c70453551d8ad498959721cc345cb7/black-25.11.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:936c4dd07669269f40b497440159a221ee435e3fddcf668e0c05244a9be71993", size = 1682481, upload-time = "2025-11-10T01:57:12.35Z" }, - { url = "https://files.pythonhosted.org/packages/55/17/7e10ff1267bfa950cc16f0a411d457cdff79678fbb77a6c73b73a5317904/black-25.11.0-cp313-cp313-win_amd64.whl", hash = "sha256:f42c0ea7f59994490f4dccd64e6b2dd49ac57c7c84f38b8faab50f8759db245c", size = 1363869, upload-time = "2025-11-10T01:58:24.608Z" }, - { url = "https://files.pythonhosted.org/packages/67/c0/cc865ce594d09e4cd4dfca5e11994ebb51604328489f3ca3ae7bb38a7db5/black-25.11.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:35690a383f22dd3e468c85dc4b915217f87667ad9cce781d7b42678ce63c4170", size = 1771358, upload-time = "2025-11-10T02:03:33.331Z" }, - { url = "https://files.pythonhosted.org/packages/37/77/4297114d9e2fd2fc8ab0ab87192643cd49409eb059e2940391e7d2340e57/black-25.11.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:dae49ef7369c6caa1a1833fd5efb7c3024bb7e4499bf64833f65ad27791b1545", size = 1612902, upload-time = "2025-11-10T01:59:33.382Z" }, - { url = "https://files.pythonhosted.org/packages/de/63/d45ef97ada84111e330b2b2d45e1dd163e90bd116f00ac55927fb6bf8adb/black-25.11.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5bd4a22a0b37401c8e492e994bce79e614f91b14d9ea911f44f36e262195fdda", size = 1680571, upload-time = "2025-11-10T01:57:04.239Z" }, - { url = "https://files.pythonhosted.org/packages/ff/4b/5604710d61cdff613584028b4cb4607e56e148801ed9b38ee7970799dab6/black-25.11.0-cp314-cp314-win_amd64.whl", hash = "sha256:aa211411e94fdf86519996b7f5f05e71ba34835d8f0c0f03c00a26271da02664", size = 1382599, upload-time = "2025-11-10T01:57:57.427Z" }, - { url = "https://files.pythonhosted.org/packages/00/5d/aed32636ed30a6e7f9efd6ad14e2a0b0d687ae7c8c7ec4e4a557174b895c/black-25.11.0-py3-none-any.whl", hash = "sha256:e3f562da087791e96cefcd9dda058380a442ab322a02e222add53736451f604b", size = 204918, upload-time = "2025-11-10T01:53:48.917Z" }, -] - [[package]] name = "certifi" version = "2026.4.22" @@ -768,15 +725,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c", size = 8074, upload-time = "2025-01-17T11:24:33.271Z" }, ] -[[package]] -name = "isort" -version = "7.0.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/63/53/4f3c058e3bace40282876f9b553343376ee687f3c35a525dc79dbd450f88/isort-7.0.0.tar.gz", hash = "sha256:5513527951aadb3ac4292a41a16cbc50dd1642432f5e8c20057d414bdafb4187", size = 805049, upload-time = "2025-10-11T13:30:59.107Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7f/ed/e3705d6d02b4f7aea715a353c8ce193efd0b5db13e204df895d38734c244/isort-7.0.0-py3-none-any.whl", hash = "sha256:1bcabac8bc3c36c7fb7b98a76c8abb18e0f841a3ba81decac7691008592499c1", size = 94672, upload-time = "2025-10-11T13:30:57.665Z" }, -] - [[package]] name = "jedi" version = "0.19.2" @@ -801,15 +749,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, ] -[[package]] -name = "joblib" -version = "1.5.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e8/5d/447af5ea094b9e4c4054f82e223ada074c552335b9b4b2d14bd9b35a67c4/joblib-1.5.2.tar.gz", hash = "sha256:3faa5c39054b2f03ca547da9b2f52fde67c06240c31853f306aea97f13647b55", size = 331077, upload-time = "2025-08-27T12:15:46.575Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/e8/685f47e0d754320684db4425a0967f7d3fa70126bffd76110b7009a0090f/joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241", size = 308396, upload-time = "2025-08-27T12:15:45.188Z" }, -] - [[package]] name = "jsmin" version = "3.0.1" @@ -1990,12 +1929,15 @@ wheels = [ ] [[package]] -name = "pytokens" -version = "0.3.0" +name = "python-markdown-math" +version = "0.9" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/4e/8d/a762be14dae1c3bf280202ba3172020b2b0b4c537f94427435f19c413b72/pytokens-0.3.0.tar.gz", hash = "sha256:2f932b14ed08de5fcf0b391ace2642f858f1394c0857202959000b68ed7a458a", size = 17644, upload-time = "2025-11-05T13:36:35.34Z" } +dependencies = [ + { name = "markdown" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4c/68/fbea05ec6fb318bdcf56ea47596605614554f51d77bfd14f6fb481139ad8/python_markdown_math-0.9.tar.gz", hash = "sha256:567395553dc4941e79b3789a1096dcabb3fda9539d150d558ef3507948b264a3", size = 8680, upload-time = "2025-04-10T10:10:31.84Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/84/25/d9db8be44e205a124f6c98bc0324b2bb149b7431c53877fc6d1038dddaf5/pytokens-0.3.0-py3-none-any.whl", hash = "sha256:95b2b5eaf832e469d141a378872480ede3f251a5a5041b8ec6e581d3ac71bbf3", size = 12195, upload-time = "2025-11-05T13:36:33.183Z" }, + { url = "https://files.pythonhosted.org/packages/eb/68/ecf3535c40845de2efd8ac2d092dd5fca0868219fa3684d9e58ef7abeece/python_markdown_math-0.9-py3-none-any.whl", hash = "sha256:ac9932df517a5c0f6d01c56e7a44d065eca4a420893ac45f7a6937c67cb41e86", size = 6046, upload-time = "2025-04-10T10:10:30.318Z" }, ] [[package]] @@ -2115,52 +2057,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1d/d2/1637f4360ada6a368d3265bf39f2cf737a0aaab15ab520fc005903e883f8/ruff-0.14.7-py3-none-win_arm64.whl", hash = "sha256:be4d653d3bea1b19742fcc6502354e32f65cd61ff2fbdb365803ef2c2aec6228", size = 13609215, upload-time = "2025-11-28T20:55:15.375Z" }, ] -[[package]] -name = "scikit-learn" -version = "1.7.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "joblib" }, - { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "scipy", version = "1.15.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "scipy", version = "1.16.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "threadpoolctl" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/98/c2/a7855e41c9d285dfe86dc50b250978105dce513d6e459ea66a6aeb0e1e0c/scikit_learn-1.7.2.tar.gz", hash = "sha256:20e9e49ecd130598f1ca38a1d85090e1a600147b9c02fa6f15d69cb53d968fda", size = 7193136, upload-time = "2025-09-09T08:21:29.075Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ba/3e/daed796fd69cce768b8788401cc464ea90b306fb196ae1ffed0b98182859/scikit_learn-1.7.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6b33579c10a3081d076ab403df4a4190da4f4432d443521674637677dc91e61f", size = 9336221, upload-time = "2025-09-09T08:20:19.328Z" }, - { url = "https://files.pythonhosted.org/packages/1c/ce/af9d99533b24c55ff4e18d9b7b4d9919bbc6cd8f22fe7a7be01519a347d5/scikit_learn-1.7.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:36749fb62b3d961b1ce4fedf08fa57a1986cd409eff2d783bca5d4b9b5fce51c", size = 8653834, upload-time = "2025-09-09T08:20:22.073Z" }, - { url = "https://files.pythonhosted.org/packages/58/0e/8c2a03d518fb6bd0b6b0d4b114c63d5f1db01ff0f9925d8eb10960d01c01/scikit_learn-1.7.2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7a58814265dfc52b3295b1900cfb5701589d30a8bb026c7540f1e9d3499d5ec8", size = 9660938, upload-time = "2025-09-09T08:20:24.327Z" }, - { url = "https://files.pythonhosted.org/packages/2b/75/4311605069b5d220e7cf5adabb38535bd96f0079313cdbb04b291479b22a/scikit_learn-1.7.2-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4a847fea807e278f821a0406ca01e387f97653e284ecbd9750e3ee7c90347f18", size = 9477818, upload-time = "2025-09-09T08:20:26.845Z" }, - { url = "https://files.pythonhosted.org/packages/7f/9b/87961813c34adbca21a6b3f6b2bea344c43b30217a6d24cc437c6147f3e8/scikit_learn-1.7.2-cp310-cp310-win_amd64.whl", hash = "sha256:ca250e6836d10e6f402436d6463d6c0e4d8e0234cfb6a9a47835bd392b852ce5", size = 8886969, upload-time = "2025-09-09T08:20:29.329Z" }, - { url = "https://files.pythonhosted.org/packages/43/83/564e141eef908a5863a54da8ca342a137f45a0bfb71d1d79704c9894c9d1/scikit_learn-1.7.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c7509693451651cd7361d30ce4e86a1347493554f172b1c72a39300fa2aea79e", size = 9331967, upload-time = "2025-09-09T08:20:32.421Z" }, - { url = "https://files.pythonhosted.org/packages/18/d6/ba863a4171ac9d7314c4d3fc251f015704a2caeee41ced89f321c049ed83/scikit_learn-1.7.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:0486c8f827c2e7b64837c731c8feff72c0bd2b998067a8a9cbc10643c31f0fe1", size = 8648645, upload-time = "2025-09-09T08:20:34.436Z" }, - { url = "https://files.pythonhosted.org/packages/ef/0e/97dbca66347b8cf0ea8b529e6bb9367e337ba2e8be0ef5c1a545232abfde/scikit_learn-1.7.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:89877e19a80c7b11a2891a27c21c4894fb18e2c2e077815bcade10d34287b20d", size = 9715424, upload-time = "2025-09-09T08:20:36.776Z" }, - { url = "https://files.pythonhosted.org/packages/f7/32/1f3b22e3207e1d2c883a7e09abb956362e7d1bd2f14458c7de258a26ac15/scikit_learn-1.7.2-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8da8bf89d4d79aaec192d2bda62f9b56ae4e5b4ef93b6a56b5de4977e375c1f1", size = 9509234, upload-time = "2025-09-09T08:20:38.957Z" }, - { url = "https://files.pythonhosted.org/packages/9f/71/34ddbd21f1da67c7a768146968b4d0220ee6831e4bcbad3e03dd3eae88b6/scikit_learn-1.7.2-cp311-cp311-win_amd64.whl", hash = "sha256:9b7ed8d58725030568523e937c43e56bc01cadb478fc43c042a9aca1dacb3ba1", size = 8894244, upload-time = "2025-09-09T08:20:41.166Z" }, - { url = "https://files.pythonhosted.org/packages/a7/aa/3996e2196075689afb9fce0410ebdb4a09099d7964d061d7213700204409/scikit_learn-1.7.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8d91a97fa2b706943822398ab943cde71858a50245e31bc71dba62aab1d60a96", size = 9259818, upload-time = "2025-09-09T08:20:43.19Z" }, - { url = "https://files.pythonhosted.org/packages/43/5d/779320063e88af9c4a7c2cf463ff11c21ac9c8bd730c4a294b0000b666c9/scikit_learn-1.7.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:acbc0f5fd2edd3432a22c69bed78e837c70cf896cd7993d71d51ba6708507476", size = 8636997, upload-time = "2025-09-09T08:20:45.468Z" }, - { url = "https://files.pythonhosted.org/packages/5c/d0/0c577d9325b05594fdd33aa970bf53fb673f051a45496842caee13cfd7fe/scikit_learn-1.7.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e5bf3d930aee75a65478df91ac1225ff89cd28e9ac7bd1196853a9229b6adb0b", size = 9478381, upload-time = "2025-09-09T08:20:47.982Z" }, - { url = "https://files.pythonhosted.org/packages/82/70/8bf44b933837ba8494ca0fc9a9ab60f1c13b062ad0197f60a56e2fc4c43e/scikit_learn-1.7.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b4d6e9deed1a47aca9fe2f267ab8e8fe82ee20b4526b2c0cd9e135cea10feb44", size = 9300296, upload-time = "2025-09-09T08:20:50.366Z" }, - { url = "https://files.pythonhosted.org/packages/c6/99/ed35197a158f1fdc2fe7c3680e9c70d0128f662e1fee4ed495f4b5e13db0/scikit_learn-1.7.2-cp312-cp312-win_amd64.whl", hash = "sha256:6088aa475f0785e01bcf8529f55280a3d7d298679f50c0bb70a2364a82d0b290", size = 8731256, upload-time = "2025-09-09T08:20:52.627Z" }, - { url = "https://files.pythonhosted.org/packages/ae/93/a3038cb0293037fd335f77f31fe053b89c72f17b1c8908c576c29d953e84/scikit_learn-1.7.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0b7dacaa05e5d76759fb071558a8b5130f4845166d88654a0f9bdf3eb57851b7", size = 9212382, upload-time = "2025-09-09T08:20:54.731Z" }, - { url = "https://files.pythonhosted.org/packages/40/dd/9a88879b0c1104259136146e4742026b52df8540c39fec21a6383f8292c7/scikit_learn-1.7.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:abebbd61ad9e1deed54cca45caea8ad5f79e1b93173dece40bb8e0c658dbe6fe", size = 8592042, upload-time = "2025-09-09T08:20:57.313Z" }, - { url = "https://files.pythonhosted.org/packages/46/af/c5e286471b7d10871b811b72ae794ac5fe2989c0a2df07f0ec723030f5f5/scikit_learn-1.7.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:502c18e39849c0ea1a5d681af1dbcf15f6cce601aebb657aabbfe84133c1907f", size = 9434180, upload-time = "2025-09-09T08:20:59.671Z" }, - { url = "https://files.pythonhosted.org/packages/f1/fd/df59faa53312d585023b2da27e866524ffb8faf87a68516c23896c718320/scikit_learn-1.7.2-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7a4c328a71785382fe3fe676a9ecf2c86189249beff90bf85e22bdb7efaf9ae0", size = 9283660, upload-time = "2025-09-09T08:21:01.71Z" }, - { url = "https://files.pythonhosted.org/packages/a7/c7/03000262759d7b6f38c836ff9d512f438a70d8a8ddae68ee80de72dcfb63/scikit_learn-1.7.2-cp313-cp313-win_amd64.whl", hash = "sha256:63a9afd6f7b229aad94618c01c252ce9e6fa97918c5ca19c9a17a087d819440c", size = 8702057, upload-time = "2025-09-09T08:21:04.234Z" }, - { url = "https://files.pythonhosted.org/packages/55/87/ef5eb1f267084532c8e4aef98a28b6ffe7425acbfd64b5e2f2e066bc29b3/scikit_learn-1.7.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:9acb6c5e867447b4e1390930e3944a005e2cb115922e693c08a323421a6966e8", size = 9558731, upload-time = "2025-09-09T08:21:06.381Z" }, - { url = "https://files.pythonhosted.org/packages/93/f8/6c1e3fc14b10118068d7938878a9f3f4e6d7b74a8ddb1e5bed65159ccda8/scikit_learn-1.7.2-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:2a41e2a0ef45063e654152ec9d8bcfc39f7afce35b08902bfe290c2498a67a6a", size = 9038852, upload-time = "2025-09-09T08:21:08.628Z" }, - { url = "https://files.pythonhosted.org/packages/83/87/066cafc896ee540c34becf95d30375fe5cbe93c3b75a0ee9aa852cd60021/scikit_learn-1.7.2-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:98335fb98509b73385b3ab2bd0639b1f610541d3988ee675c670371d6a87aa7c", size = 9527094, upload-time = "2025-09-09T08:21:11.486Z" }, - { url = "https://files.pythonhosted.org/packages/9c/2b/4903e1ccafa1f6453b1ab78413938c8800633988c838aa0be386cbb33072/scikit_learn-1.7.2-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:191e5550980d45449126e23ed1d5e9e24b2c68329ee1f691a3987476e115e09c", size = 9367436, upload-time = "2025-09-09T08:21:13.602Z" }, - { url = "https://files.pythonhosted.org/packages/b5/aa/8444be3cfb10451617ff9d177b3c190288f4563e6c50ff02728be67ad094/scikit_learn-1.7.2-cp313-cp313t-win_amd64.whl", hash = "sha256:57dc4deb1d3762c75d685507fbd0bc17160144b2f2ba4ccea5dc285ab0d0e973", size = 9275749, upload-time = "2025-09-09T08:21:15.96Z" }, - { url = "https://files.pythonhosted.org/packages/d9/82/dee5acf66837852e8e68df6d8d3a6cb22d3df997b733b032f513d95205b7/scikit_learn-1.7.2-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:fa8f63940e29c82d1e67a45d5297bdebbcb585f5a5a50c4914cc2e852ab77f33", size = 9208906, upload-time = "2025-09-09T08:21:18.557Z" }, - { url = "https://files.pythonhosted.org/packages/3c/30/9029e54e17b87cb7d50d51a5926429c683d5b4c1732f0507a6c3bed9bf65/scikit_learn-1.7.2-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:f95dc55b7902b91331fa4e5845dd5bde0580c9cd9612b1b2791b7e80c3d32615", size = 8627836, upload-time = "2025-09-09T08:21:20.695Z" }, - { url = "https://files.pythonhosted.org/packages/60/18/4a52c635c71b536879f4b971c2cedf32c35ee78f48367885ed8025d1f7ee/scikit_learn-1.7.2-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9656e4a53e54578ad10a434dc1f993330568cfee176dff07112b8785fb413106", size = 9426236, upload-time = "2025-09-09T08:21:22.645Z" }, - { url = "https://files.pythonhosted.org/packages/99/7e/290362f6ab582128c53445458a5befd471ed1ea37953d5bcf80604619250/scikit_learn-1.7.2-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:96dc05a854add0e50d3f47a1ef21a10a595016da5b007c7d9cd9d0bffd1fcc61", size = 9312593, upload-time = "2025-09-09T08:21:24.65Z" }, - { url = "https://files.pythonhosted.org/packages/8e/87/24f541b6d62b1794939ae6422f8023703bbf6900378b2b34e0b4384dfefd/scikit_learn-1.7.2-cp314-cp314-win_amd64.whl", hash = "sha256:bb24510ed3f9f61476181e4db51ce801e2ba37541def12dc9333b946fc7a9cf8", size = 8820007, upload-time = "2025-09-09T08:21:26.713Z" }, -] - [[package]] name = "scipy" version = "1.15.3" @@ -2339,15 +2235,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, ] -[[package]] -name = "threadpoolctl" -version = "3.6.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b7/4d/08c89e34946fce2aec4fbb45c9016efd5f4d7f24af8e5d93296e935631d8/threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e", size = 21274, upload-time = "2025-03-13T13:49:23.031Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" }, -] - [[package]] name = "tokenize-rt" version = "6.2.0"