Skip to content
Merged
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:


- name: 'Build only'
uses: shalzz/zola-deploy-action@master
uses: shalzz/zola-deploy-action@v0.22.1
env:
BUILD_DIR: docs/website/
TOKEN: ${{ secrets.TOKEN }}
Expand Down
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ linfa-datasets = { path = "datasets", features = [
"generate",
] }
statrs = { git = "https://github.com/statrs-dev/statrs", branch = "master" }
linfa-linear = { path = "algorithms/linfa-linear" }
linfa-svm = { path = "algorithms/linfa-svm" }

[target.'cfg(not(windows))'.dependencies]
pprof = { version = "0.15", features = [
Expand Down
134 changes: 134 additions & 0 deletions algorithms/linfa-ensemble/examples/adaboost_iris.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
use linfa::prelude::{Fit, Predict, ToConfusionMatrix};
use linfa_ensemble::AdaBoostParams;
use linfa_trees::DecisionTree;
use ndarray_rand::rand::SeedableRng;
use rand::rngs::SmallRng;

fn adaboost_with_stumps(n_estimators: usize, learning_rate: f64) {
// Load dataset
let mut rng = SmallRng::seed_from_u64(42);
let (train, test) = linfa_datasets::iris()
.shuffle(&mut rng)
.split_with_ratio(0.8);

// Train AdaBoost model with decision tree stumps (max_depth=1)
// Stumps are weak learners commonly used with AdaBoost
let model = AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(1)), rng)
.n_estimators(n_estimators)
.learning_rate(learning_rate)
.fit(&train)
.unwrap();

// Make predictions
let predictions = model.predict(&test);
println!("Final Predictions: \n{predictions:?}");

let cm = predictions.confusion_matrix(&test).unwrap();
println!("{cm:?}");
println!(
"Test accuracy: {:.2}%\nwith Decision Tree stumps (max_depth=1),\nn_estimators: {n_estimators},\nlearning_rate: {learning_rate}.\n",
100.0 * cm.accuracy()
);
println!("Number of models trained: {}", model.n_estimators());
}

fn adaboost_with_shallow_trees(n_estimators: usize, learning_rate: f64, max_depth: usize) {
let mut rng = SmallRng::seed_from_u64(42);
let (train, test) = linfa_datasets::iris()
.shuffle(&mut rng)
.split_with_ratio(0.8);

// Train AdaBoost model with shallow decision trees
let model =
AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(max_depth)), rng)
.n_estimators(n_estimators)
.learning_rate(learning_rate)
.fit(&train)
.unwrap();

// Make predictions
let predictions = model.predict(&test);
println!("Final Predictions: \n{predictions:?}");

let cm = predictions.confusion_matrix(&test).unwrap();
println!("{cm:?}");
println!(
"Test accuracy: {:.2}%\nwith Decision Trees (max_depth={max_depth}),\nn_estimators: {n_estimators},\nlearning_rate: {learning_rate}.\n",
100.0 * cm.accuracy()
);

// Display model weights
println!("Model weights (alpha values):");
for (i, weight) in model.weights().iter().enumerate() {
println!(" Model {}: {:.4}", i + 1, weight);
}
println!();
}

fn main() {
println!("{}", "=".repeat(80));
println!("AdaBoost Examples on Iris Dataset");
println!("{}", "=".repeat(80));
println!();

// Example 1: AdaBoost with decision stumps (most common configuration)
println!("Example 1: AdaBoost with Decision Stumps");
println!("{}", "-".repeat(80));
adaboost_with_stumps(50, 1.0);
println!();

// Example 2: AdaBoost with lower learning rate
println!("Example 2: AdaBoost with Lower Learning Rate");
println!("{}", "-".repeat(80));
adaboost_with_stumps(100, 0.5);
println!();

// Example 3: AdaBoost with shallow trees
println!("Example 3: AdaBoost with Shallow Decision Trees");
println!("{}", "-".repeat(80));
adaboost_with_shallow_trees(50, 1.0, 2);
println!();

// Example 4: Comparing different configurations
println!("Example 4: Comparing Configurations");
println!("{}", "-".repeat(80));
let configs = vec![
(25, 1.0, 1, "Few stumps, high learning rate"),
(50, 1.0, 1, "Medium stumps, high learning rate"),
(100, 0.5, 1, "Many stumps, low learning rate"),
(50, 1.0, 2, "Shallow trees, high learning rate"),
];

for (n_est, lr, depth, desc) in configs {
let mut rng = SmallRng::seed_from_u64(42);
let (train, test) = linfa_datasets::iris()
.shuffle(&mut rng)
.split_with_ratio(0.8);

let model =
AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(depth)), rng)
.n_estimators(n_est)
.learning_rate(lr)
.fit(&train)
.unwrap();

let predictions = model.predict(&test);
let cm = predictions.confusion_matrix(&test).unwrap();

println!(
"{desc:50} => Accuracy: {:.2}% (models trained: {})",
100.0 * cm.accuracy(),
model.n_estimators()
);
}

println!();
println!("{}", "=".repeat(80));
println!("Notes:");
println!("- AdaBoost works by training weak learners sequentially");
println!("- Each learner focuses on samples misclassified by previous learners");
println!("- Decision stumps (depth=1) are the most common weak learners");
println!("- Lower learning_rate provides regularization but needs more estimators");
println!("- Model weights (alpha) reflect each learner's contribution to prediction");
println!("{}", "=".repeat(80));
}
28 changes: 16 additions & 12 deletions algorithms/linfa-ensemble/src/adaboost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use linfa::{
dataset::{AsTargets, AsTargetsMut, FromTargetArrayOwned},
error::Error,
traits::*,
DatasetBase, ParamGuard,
DatasetBase,
};
use ndarray::{Array1, Array2, Axis};
use ndarray_rand::rand::distr::weighted::WeightedIndex;
Expand Down Expand Up @@ -155,23 +155,27 @@ where
}
}

impl<D, T, P, R> Fit<Array2<D>, T, Error> for AdaBoostValidParams<P, R>
// Mirrors the bounds shape of `EnsembleLearnerValidParams` (algorithm.rs:137)
// with one extra wrinkle: AdaBoost calls `predict_inplace` on the inner model
// during fit (to compute weighted error), so the inner model type needs a
// `PredictInplace` bound. Naming that bound directly as `P::Object: ...` would
// re-trigger Linfa's `ParamGuard` blanket-impl recursion under ndarray 0.17 by
// forcing the trait solver to resolve `P: Fit` to compute `P::Object`. We
// avoid that by introducing a separate type parameter `M` and binding the
// associated type via `Fit<..., Object = M>` so the model type is named at
// the impl signature (where the solver doesn't have to chase blankets).
impl<D, T, P, M, R> Fit<Array2<D>, T, Error> for AdaBoostValidParams<P, R>
where
D: Clone + ndarray::ScalarOperand,
T: FromTargetArrayOwned<Owned = T> + AsTargets + Clone,
T: FromTargetArrayOwned + AsTargets,
T::Elem: Copy + Eq + Hash + std::fmt::Debug + Into<usize>,
P: linfa::ParamGuard + Clone,
<P as linfa::ParamGuard>::Checked: Fit<Array2<D>, T, Error>,
Error: From<<P as linfa::ParamGuard>::Error>,
<<P as linfa::ParamGuard>::Checked as Fit<Array2<D>, T, Error>>::Object:
PredictInplace<Array2<D>, T>,
T::Owned: AsTargets<Elem = <T as AsTargets>::Elem>,
P: Fit<Array2<D>, T::Owned, Error, Object = M>,
M: PredictInplace<Array2<D>, T::Owned>,
R: Rng + Clone,
usize: Into<T::Elem>,
{
type Object = AdaBoost<
<<P as linfa::ParamGuard>::Checked as Fit<Array2<D>, T, Error>>::Object,
T::Elem,
>;
type Object = AdaBoost<M, T::Elem>;

fn fit(
&self,
Expand Down
28 changes: 13 additions & 15 deletions algorithms/linfa-ensemble/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//! This crate (`linfa-ensemble`), provides pure Rust implementations of popular ensemble techniques, such as
//! * [Boostrap Aggregation](EnsembleLearner)
//! * [Random Forest](RandomForest)
//! * [AdaBoost]
//!
//! ## Bootstrap Aggregation (aka Bagging)
//!
Expand All @@ -18,6 +19,14 @@
//! selection. A typical number of random prediction to be selected is $\sqrt{p}$ with $p$ being
//! the number of available features.
//!
//! ## AdaBoost
//!
//! AdaBoost (Adaptive Boosting) is a boosting ensemble method that trains weak learners sequentially.
//! Each subsequent learner focuses on the examples that previous learners misclassified by increasing
//! their sample weights. The final prediction is a weighted vote of all learners, where better-performing
//! learners receive higher weights. Unlike bagging methods, boosting creates a strong classifier from
//! weak learners (typically shallow decision trees or "stumps").
//!
//! ## Reference
//!
//! * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/ensemble.html)
Expand Down Expand Up @@ -81,16 +90,13 @@
//! let predictions = random_forest.predict(&test);
//! ```

// NOTE: AdaBoost (mod adaboost, mod adaboost_hyperparams) is temporarily orphaned
// pending a structural rework of its `Fit` impl trait bounds — the current
// `where P: Fit + ParamGuard` shape triggers infinite recursion in Rust's trait
// solver under ndarray 0.17 because Linfa's `Fit` blanket impl requires
// `P::Checked: Fit` which itself only resolves through the same blanket. See
// the open issue tracking the rework. Source files remain in tree for the
// follow-up PR to revive.
mod adaboost;
mod adaboost_hyperparams;
mod algorithm;
mod hyperparams;

pub use adaboost::*;
pub use adaboost_hyperparams::*;
pub use algorithm::*;
pub use hyperparams::*;

Expand Down Expand Up @@ -143,8 +149,6 @@ mod tests {
assert!(acc >= 0.9, "Expected accuracy to be above 90%, got {}", acc);
}

// AdaBoost tests removed — tracked under follow-up issue (see note in lib.rs above).
#[cfg(any())] // disabled until AdaBoost trait bounds are reworked
#[test]
fn test_adaboost_accuracy_on_iris_dataset() {
let mut rng = SmallRng::seed_from_u64(42);
Expand All @@ -170,7 +174,6 @@ mod tests {
);
}

#[cfg(any())]
#[test]
fn test_adaboost_with_low_learning_rate() {
let mut rng = SmallRng::seed_from_u64(42);
Expand All @@ -196,7 +199,6 @@ mod tests {
);
}

#[cfg(any())]
#[test]
fn test_adaboost_model_weights() {
let mut rng = SmallRng::seed_from_u64(42);
Expand All @@ -222,7 +224,6 @@ mod tests {
assert_eq!(model.n_estimators(), 10);
}

#[cfg(any())]
#[test]
fn test_adaboost_different_learning_rates() {
// Test that different learning rates produce different model weights
Expand Down Expand Up @@ -265,7 +266,6 @@ mod tests {
);
}

#[cfg(any())]
#[test]
fn test_adaboost_early_stopping_on_perfect_fit() {
use linfa::DatasetBase;
Expand Down Expand Up @@ -301,7 +301,6 @@ mod tests {
);
}

#[cfg(any())]
#[test]
fn test_adaboost_single_class_error() {
use linfa::DatasetBase;
Expand All @@ -321,7 +320,6 @@ mod tests {
assert!(result.is_err(), "Should fail with single class dataset");
}

#[cfg(any())]
#[test]
fn test_adaboost_classes_method() {
let mut rng = SmallRng::seed_from_u64(42);
Expand Down
Loading
Loading