diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 3ff991e12..c3139ea7f 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -18,7 +18,7 @@ jobs: - name: 'Build only' - uses: shalzz/zola-deploy-action@master + uses: shalzz/zola-deploy-action@v0.22.1 env: BUILD_DIR: docs/website/ TOKEN: ${{ secrets.TOKEN }} diff --git a/Cargo.toml b/Cargo.toml index 90430e99b..2e2f13b7c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,6 +64,8 @@ linfa-datasets = { path = "datasets", features = [ "generate", ] } statrs = { git = "https://github.com/statrs-dev/statrs", branch = "master" } +linfa-linear = { path = "algorithms/linfa-linear" } +linfa-svm = { path = "algorithms/linfa-svm" } [target.'cfg(not(windows))'.dependencies] pprof = { version = "0.15", features = [ diff --git a/algorithms/linfa-ensemble/examples/adaboost_iris.rs b/algorithms/linfa-ensemble/examples/adaboost_iris.rs new file mode 100644 index 000000000..1a76d57c6 --- /dev/null +++ b/algorithms/linfa-ensemble/examples/adaboost_iris.rs @@ -0,0 +1,134 @@ +use linfa::prelude::{Fit, Predict, ToConfusionMatrix}; +use linfa_ensemble::AdaBoostParams; +use linfa_trees::DecisionTree; +use ndarray_rand::rand::SeedableRng; +use rand::rngs::SmallRng; + +fn adaboost_with_stumps(n_estimators: usize, learning_rate: f64) { + // Load dataset + let mut rng = SmallRng::seed_from_u64(42); + let (train, test) = linfa_datasets::iris() + .shuffle(&mut rng) + .split_with_ratio(0.8); + + // Train AdaBoost model with decision tree stumps (max_depth=1) + // Stumps are weak learners commonly used with AdaBoost + let model = AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(1)), rng) + .n_estimators(n_estimators) + .learning_rate(learning_rate) + .fit(&train) + .unwrap(); + + // Make predictions + let predictions = model.predict(&test); + println!("Final Predictions: \n{predictions:?}"); + + let cm = predictions.confusion_matrix(&test).unwrap(); + println!("{cm:?}"); + println!( + "Test accuracy: {:.2}%\nwith Decision Tree stumps (max_depth=1),\nn_estimators: {n_estimators},\nlearning_rate: {learning_rate}.\n", + 100.0 * cm.accuracy() + ); + println!("Number of models trained: {}", model.n_estimators()); +} + +fn adaboost_with_shallow_trees(n_estimators: usize, learning_rate: f64, max_depth: usize) { + let mut rng = SmallRng::seed_from_u64(42); + let (train, test) = linfa_datasets::iris() + .shuffle(&mut rng) + .split_with_ratio(0.8); + + // Train AdaBoost model with shallow decision trees + let model = + AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(max_depth)), rng) + .n_estimators(n_estimators) + .learning_rate(learning_rate) + .fit(&train) + .unwrap(); + + // Make predictions + let predictions = model.predict(&test); + println!("Final Predictions: \n{predictions:?}"); + + let cm = predictions.confusion_matrix(&test).unwrap(); + println!("{cm:?}"); + println!( + "Test accuracy: {:.2}%\nwith Decision Trees (max_depth={max_depth}),\nn_estimators: {n_estimators},\nlearning_rate: {learning_rate}.\n", + 100.0 * cm.accuracy() + ); + + // Display model weights + println!("Model weights (alpha values):"); + for (i, weight) in model.weights().iter().enumerate() { + println!(" Model {}: {:.4}", i + 1, weight); + } + println!(); +} + +fn main() { + println!("{}", "=".repeat(80)); + println!("AdaBoost Examples on Iris Dataset"); + println!("{}", "=".repeat(80)); + println!(); + + // Example 1: AdaBoost with decision stumps (most common configuration) + println!("Example 1: AdaBoost with Decision Stumps"); + println!("{}", "-".repeat(80)); + adaboost_with_stumps(50, 1.0); + println!(); + + // Example 2: AdaBoost with lower learning rate + println!("Example 2: AdaBoost with Lower Learning Rate"); + println!("{}", "-".repeat(80)); + adaboost_with_stumps(100, 0.5); + println!(); + + // Example 3: AdaBoost with shallow trees + println!("Example 3: AdaBoost with Shallow Decision Trees"); + println!("{}", "-".repeat(80)); + adaboost_with_shallow_trees(50, 1.0, 2); + println!(); + + // Example 4: Comparing different configurations + println!("Example 4: Comparing Configurations"); + println!("{}", "-".repeat(80)); + let configs = vec![ + (25, 1.0, 1, "Few stumps, high learning rate"), + (50, 1.0, 1, "Medium stumps, high learning rate"), + (100, 0.5, 1, "Many stumps, low learning rate"), + (50, 1.0, 2, "Shallow trees, high learning rate"), + ]; + + for (n_est, lr, depth, desc) in configs { + let mut rng = SmallRng::seed_from_u64(42); + let (train, test) = linfa_datasets::iris() + .shuffle(&mut rng) + .split_with_ratio(0.8); + + let model = + AdaBoostParams::new_fixed_rng(DecisionTree::params().max_depth(Some(depth)), rng) + .n_estimators(n_est) + .learning_rate(lr) + .fit(&train) + .unwrap(); + + let predictions = model.predict(&test); + let cm = predictions.confusion_matrix(&test).unwrap(); + + println!( + "{desc:50} => Accuracy: {:.2}% (models trained: {})", + 100.0 * cm.accuracy(), + model.n_estimators() + ); + } + + println!(); + println!("{}", "=".repeat(80)); + println!("Notes:"); + println!("- AdaBoost works by training weak learners sequentially"); + println!("- Each learner focuses on samples misclassified by previous learners"); + println!("- Decision stumps (depth=1) are the most common weak learners"); + println!("- Lower learning_rate provides regularization but needs more estimators"); + println!("- Model weights (alpha) reflect each learner's contribution to prediction"); + println!("{}", "=".repeat(80)); +} diff --git a/algorithms/linfa-ensemble/src/adaboost.rs b/algorithms/linfa-ensemble/src/adaboost.rs index 357fd262f..c4f6181c3 100644 --- a/algorithms/linfa-ensemble/src/adaboost.rs +++ b/algorithms/linfa-ensemble/src/adaboost.rs @@ -3,7 +3,7 @@ use linfa::{ dataset::{AsTargets, AsTargetsMut, FromTargetArrayOwned}, error::Error, traits::*, - DatasetBase, ParamGuard, + DatasetBase, }; use ndarray::{Array1, Array2, Axis}; use ndarray_rand::rand::distr::weighted::WeightedIndex; @@ -155,23 +155,27 @@ where } } -impl Fit, T, Error> for AdaBoostValidParams +// Mirrors the bounds shape of `EnsembleLearnerValidParams` (algorithm.rs:137) +// with one extra wrinkle: AdaBoost calls `predict_inplace` on the inner model +// during fit (to compute weighted error), so the inner model type needs a +// `PredictInplace` bound. Naming that bound directly as `P::Object: ...` would +// re-trigger Linfa's `ParamGuard` blanket-impl recursion under ndarray 0.17 by +// forcing the trait solver to resolve `P: Fit` to compute `P::Object`. We +// avoid that by introducing a separate type parameter `M` and binding the +// associated type via `Fit<..., Object = M>` so the model type is named at +// the impl signature (where the solver doesn't have to chase blankets). +impl Fit, T, Error> for AdaBoostValidParams where D: Clone + ndarray::ScalarOperand, - T: FromTargetArrayOwned + AsTargets + Clone, + T: FromTargetArrayOwned + AsTargets, T::Elem: Copy + Eq + Hash + std::fmt::Debug + Into, - P: linfa::ParamGuard + Clone, -

::Checked: Fit, T, Error>, - Error: From<

::Error>, - <

::Checked as Fit, T, Error>>::Object: - PredictInplace, T>, + T::Owned: AsTargets::Elem>, + P: Fit, T::Owned, Error, Object = M>, + M: PredictInplace, T::Owned>, R: Rng + Clone, usize: Into, { - type Object = AdaBoost< - <

::Checked as Fit, T, Error>>::Object, - T::Elem, - >; + type Object = AdaBoost; fn fit( &self, diff --git a/algorithms/linfa-ensemble/src/lib.rs b/algorithms/linfa-ensemble/src/lib.rs index 9573e2d9b..87f129a78 100644 --- a/algorithms/linfa-ensemble/src/lib.rs +++ b/algorithms/linfa-ensemble/src/lib.rs @@ -6,6 +6,7 @@ //! This crate (`linfa-ensemble`), provides pure Rust implementations of popular ensemble techniques, such as //! * [Boostrap Aggregation](EnsembleLearner) //! * [Random Forest](RandomForest) +//! * [AdaBoost] //! //! ## Bootstrap Aggregation (aka Bagging) //! @@ -18,6 +19,14 @@ //! selection. A typical number of random prediction to be selected is $\sqrt{p}$ with $p$ being //! the number of available features. //! +//! ## AdaBoost +//! +//! AdaBoost (Adaptive Boosting) is a boosting ensemble method that trains weak learners sequentially. +//! Each subsequent learner focuses on the examples that previous learners misclassified by increasing +//! their sample weights. The final prediction is a weighted vote of all learners, where better-performing +//! learners receive higher weights. Unlike bagging methods, boosting creates a strong classifier from +//! weak learners (typically shallow decision trees or "stumps"). +//! //! ## Reference //! //! * [Scikit-Learn User Guide](https://scikit-learn.org/stable/modules/ensemble.html) @@ -81,16 +90,13 @@ //! let predictions = random_forest.predict(&test); //! ``` -// NOTE: AdaBoost (mod adaboost, mod adaboost_hyperparams) is temporarily orphaned -// pending a structural rework of its `Fit` impl trait bounds — the current -// `where P: Fit + ParamGuard` shape triggers infinite recursion in Rust's trait -// solver under ndarray 0.17 because Linfa's `Fit` blanket impl requires -// `P::Checked: Fit` which itself only resolves through the same blanket. See -// the open issue tracking the rework. Source files remain in tree for the -// follow-up PR to revive. +mod adaboost; +mod adaboost_hyperparams; mod algorithm; mod hyperparams; +pub use adaboost::*; +pub use adaboost_hyperparams::*; pub use algorithm::*; pub use hyperparams::*; @@ -143,8 +149,6 @@ mod tests { assert!(acc >= 0.9, "Expected accuracy to be above 90%, got {}", acc); } - // AdaBoost tests removed — tracked under follow-up issue (see note in lib.rs above). - #[cfg(any())] // disabled until AdaBoost trait bounds are reworked #[test] fn test_adaboost_accuracy_on_iris_dataset() { let mut rng = SmallRng::seed_from_u64(42); @@ -170,7 +174,6 @@ mod tests { ); } - #[cfg(any())] #[test] fn test_adaboost_with_low_learning_rate() { let mut rng = SmallRng::seed_from_u64(42); @@ -196,7 +199,6 @@ mod tests { ); } - #[cfg(any())] #[test] fn test_adaboost_model_weights() { let mut rng = SmallRng::seed_from_u64(42); @@ -222,7 +224,6 @@ mod tests { assert_eq!(model.n_estimators(), 10); } - #[cfg(any())] #[test] fn test_adaboost_different_learning_rates() { // Test that different learning rates produce different model weights @@ -265,7 +266,6 @@ mod tests { ); } - #[cfg(any())] #[test] fn test_adaboost_early_stopping_on_perfect_fit() { use linfa::DatasetBase; @@ -301,7 +301,6 @@ mod tests { ); } - #[cfg(any())] #[test] fn test_adaboost_single_class_error() { use linfa::DatasetBase; @@ -321,7 +320,6 @@ mod tests { assert!(result.is_err(), "Should fail with single class dataset"); } - #[cfg(any())] #[test] fn test_adaboost_classes_method() { let mut rng = SmallRng::seed_from_u64(42); diff --git a/algorithms/linfa-logistic/src/lib.rs b/algorithms/linfa-logistic/src/lib.rs index cc112b6d0..ea2386d91 100644 --- a/algorithms/linfa-logistic/src/lib.rs +++ b/algorithms/linfa-logistic/src/lib.rs @@ -301,32 +301,33 @@ where { let y = y.as_single_targets(); - // counts the instances of two distinct class labels let mut binary_classes = [None, None]; - // find binary classes of our target dataset for class in y { binary_classes = match binary_classes { - // count the first class label [None, None] => [Some((class, 1)), None], - // if the class has already been counted, increment the count [Some((c, count)), c2] if c == class => [Some((class, count + 1)), c2], [c1, Some((c, count))] if c == class => [c1, Some((class, count + 1))], - // count the second class label [Some(c1), None] => [Some(c1), Some((class, 1))], - - // should not be possible [None, Some(_)] => unreachable!("impossible binary class array"), - // found 3rd distinct class [Some(_), Some(_)] => return Err(Error::TooManyClasses), }; } - let (pos_class, neg_class) = match binary_classes { + let (class_a, class_b) = match binary_classes { [Some(a), Some(b)] => (a, b), _ => return Err(Error::TooFewClasses), }; - let mut target_array = y + // Sort by label value (Ord), not by encounter order or count. + // The smaller label is always negative (-1), + // the larger label is always positive (+1). + let (neg_class, pos_class) = if class_a.0 < class_b.0 { + (class_a, class_b) + } else { + (class_b, class_a) + }; + + let target_array = y .into_iter() .map(|x| { if x == pos_class.0 { @@ -337,24 +338,14 @@ where }) .collect::>(); - let (pos_cl, neg_cl) = if pos_class.1 < neg_class.1 { - // If we found the larger class first, flip the sign in the target - // vector, so that -1.0 is always the label for the smaller class - // and 1.0 the label for the larger class - target_array *= -F::one(); - (neg_class.0.clone(), pos_class.0.clone()) - } else { - (pos_class.0.clone(), neg_class.0.clone()) - }; - Ok(( BinaryClassLabels { pos: ClassLabel { - class: pos_cl, + class: pos_class.0.clone(), label: F::POSITIVE_LABEL, }, neg: ClassLabel { - class: neg_cl, + class: neg_class.0.clone(), label: F::NEGATIVE_LABEL, }, }, @@ -989,7 +980,7 @@ mod test { let dataset = Dataset::new(x, y); let res = log_reg.fit(&dataset).unwrap(); assert_abs_diff_eq!(res.intercept(), 0.0); - assert!(res.params().abs_diff_eq(&array![-0.681], 1e-3)); + assert!(res.params().abs_diff_eq(&array![0.681], 1e-3)); assert_eq!( &res.predict(dataset.records()), dataset.targets().as_single_targets() @@ -1172,7 +1163,7 @@ mod test { let dataset = Dataset::new(x, y); let res = log_reg.fit(&dataset).unwrap(); assert_abs_diff_eq!(res.intercept(), 0.0_f32); - assert!(res.params().abs_diff_eq(&array![-0.682_f32], 1e-3)); + assert!(res.params().abs_diff_eq(&array![0.682_f32], 1e-3)); assert_eq!( &res.predict(dataset.records()), dataset.targets().as_single_targets() @@ -1375,4 +1366,28 @@ mod test { } )); } + + #[test] + fn label_order_independent() { + let x1 = array![[-1.0], [1.0], [-0.5], [0.5]]; + let y1 = array!["cat", "dog", "cat", "dog"]; + + let x2 = array![[1.0], [-1.0], [0.5], [-0.5]]; + let y2 = array!["dog", "cat", "dog", "cat"]; + + let model1 = LogisticRegression::default() + .fit(&Dataset::new(x1, y1)) + .unwrap(); + let model2 = LogisticRegression::default() + .fit(&Dataset::new(x2, y2)) + .unwrap(); + + assert_eq!(model1.labels().pos.class, "dog"); + assert_eq!(model1.labels().neg.class, "cat"); + assert_eq!(model2.labels().pos.class, "dog"); + assert_eq!(model2.labels().neg.class, "cat"); + + assert_abs_diff_eq!(model1.intercept(), model2.intercept()); + assert!(model1.params().abs_diff_eq(model2.params(), 1e-6)); + } } diff --git a/algorithms/linfa-preprocessing/src/error.rs b/algorithms/linfa-preprocessing/src/error.rs index 7d0089c8d..e99121b48 100644 --- a/algorithms/linfa-preprocessing/src/error.rs +++ b/algorithms/linfa-preprocessing/src/error.rs @@ -13,9 +13,9 @@ pub enum PreprocessingError { NotEnoughSamples, #[error("not a valid float")] InvalidFloat, - #[error("minimum value for MinMax scaler cannot be greater than the maximum")] - TokenizerNotSet, #[error("Tokenizer must be defined after deserializing CountVectorizer by calling force_tokenizer_redefinition")] + TokenizerNotSet, + #[error("minimum value for MinMax scaler cannot be greater than the maximum")] FlippedMinMaxRange, #[error("n_gram boundaries cannot be zero (min = {0}, max = {1})")] InvalidNGramBoundaries(usize, usize), diff --git a/algorithms/linfa-tsne/Cargo.toml b/algorithms/linfa-tsne/Cargo.toml index e537589bd..f77a2191a 100644 --- a/algorithms/linfa-tsne/Cargo.toml +++ b/algorithms/linfa-tsne/Cargo.toml @@ -16,14 +16,14 @@ categories = ["algorithms", "mathematics", "science"] [dependencies] thiserror = "2.0" ndarray = { version = "0.17" } -ndarray-rand = "0.16" -bhtsne = "0.4.0" -pdqselect = "=0.1.1" +bhtsne = { version = "0.5.4", default-features = false } linfa = { version = "0.8.1", path = "../.." } +linfa-nn = { version = "0.8.1", path = "../linfa-nn" } [dev-dependencies] rand = "0.9" +ndarray-rand = "0.16" approx = "0.5" linfa-datasets = { version = "0.8.1", path = "../../datasets", features = [ diff --git a/algorithms/linfa-tsne/src/hyperparams.rs b/algorithms/linfa-tsne/src/hyperparams.rs index ac2f8da7c..270961be7 100644 --- a/algorithms/linfa-tsne/src/hyperparams.rs +++ b/algorithms/linfa-tsne/src/hyperparams.rs @@ -1,5 +1,4 @@ use linfa::{Float, ParamGuard}; -use ndarray_rand::rand::{rngs::SmallRng, Rng, SeedableRng}; use crate::TSneError; @@ -32,16 +31,16 @@ use crate::TSneError; /// /// A verified hyper-parameter set ready for prediction #[derive(Debug, Clone, PartialEq)] -pub struct TSneValidParams { +pub struct TSneValidParams { embedding_size: usize, approx_threshold: F, perplexity: F, max_iter: usize, preliminary_iter: Option, - rng: R, + metric: D, } -impl TSneValidParams { +impl TSneValidParams { pub fn embedding_size(&self) -> usize { self.embedding_size } @@ -62,45 +61,46 @@ impl TSneValidParams { &self.preliminary_iter } - pub fn rng(&self) -> &R { - &self.rng + pub fn metric(&self) -> &D { + &self.metric } } #[derive(Debug, Clone, PartialEq)] -pub struct TSneParams(TSneValidParams); +pub struct TSneParams(TSneValidParams); -impl TSneParams { +impl TSneParams { /// Create a t-SNE param set with given embedding size /// /// # Defaults to: /// * `approx_threshold`: 0.5 /// * `perplexity`: 5.0 /// * `max_iter`: 2000 - /// * `rng`: SmallRng with seed 42 - pub fn embedding_size(embedding_size: usize) -> TSneParams { - Self::embedding_size_with_rng(embedding_size, SmallRng::seed_from_u64(42)) + pub fn embedding_size(embedding_size: usize) -> TSneParams { + Self::embedding_size_with_metric(embedding_size, linfa_nn::distance::L2Dist) } } -impl TSneParams { - /// Create a t-SNE param set with given embedding size and random number generator +impl> TSneParams { + /// Create a t-SNE param set with given embedding size and distance metric /// /// # Defaults to: /// * `approx_threshold`: 0.5 /// * `perplexity`: 5.0 /// * `max_iter`: 2000 - pub fn embedding_size_with_rng(embedding_size: usize, rng: R) -> TSneParams { + pub fn embedding_size_with_metric(embedding_size: usize, metric: D) -> Self { Self(TSneValidParams { embedding_size, - rng, approx_threshold: F::cast(0.5), perplexity: F::cast(5.0), max_iter: 2000, preliminary_iter: None, + metric, }) } +} +impl TSneParams { /// Set the approximation threshold of the Barnes Hut algorithm /// /// The threshold decides whether a cluster centroid can be used as a summary for the whole @@ -139,8 +139,8 @@ impl TSneParams { } } -impl ParamGuard for TSneParams { - type Checked = TSneValidParams; +impl ParamGuard for TSneParams { + type Checked = TSneValidParams; type Error = TSneError; /// Validates parameters diff --git a/algorithms/linfa-tsne/src/lib.rs b/algorithms/linfa-tsne/src/lib.rs index d175f1c66..2279e6d0c 100644 --- a/algorithms/linfa-tsne/src/lib.rs +++ b/algorithms/linfa-tsne/src/lib.rs @@ -1,8 +1,8 @@ #![doc = include_str!("../README.md")] +use std::convert::TryFrom; -use ndarray::Array2; -use ndarray_rand::rand::Rng; -use ndarray_rand::rand_distr::Normal; +use linfa_nn::distance::Distance; +use ndarray::{Array2, ArrayView1}; use linfa::{dataset::DatasetBase, traits::Transformer, Float, ParamGuard}; @@ -12,8 +12,8 @@ mod hyperparams; pub use error::{Result, TSneError}; pub use hyperparams::{TSneParams, TSneValidParams}; -impl Transformer, Result>> for TSneValidParams { - fn transform(&self, mut data: Array2) -> Result> { +impl> Transformer, Result>> for TSneValidParams { + fn transform(&self, data: Array2) -> Result> { let (nfeatures, nsamples) = (data.ncols(), data.nrows()); // validate parameter-data constraints @@ -21,6 +21,10 @@ impl Transformer, Result>> for TSn return Err(TSneError::EmbeddingSizeTooLarge); } + let Ok(embedding_size) = u8::try_from(self.embedding_size()) else { + return Err(TSneError::EmbeddingSizeTooLarge); + }; + if F::cast(nsamples - 1) < F::cast(3) * self.perplexity() { return Err(TSneError::PerplexityTooLarge); } @@ -31,43 +35,47 @@ impl Transformer, Result>> for TSn None => usize::min(self.max_iter() / 2, 250), }; - let data = data.as_slice_mut().unwrap(); - - let mut rng = self.rng().clone(); - let normal = Normal::new(0.0, 1e-4 * 10e-4).unwrap(); - - let mut embedding: Vec = (0..nsamples * self.embedding_size()) - .map(|_| rng.sample(normal)) - .map(F::cast) - .collect(); - - bhtsne::run( - data, - nsamples, - nfeatures, - &mut embedding, - self.embedding_size(), - self.perplexity(), - self.approx_threshold(), - true, - self.max_iter() as u64, - preliminary_iter as u64, - preliminary_iter as u64, - ); + let data: Vec<_> = data.as_slice().unwrap().chunks(nfeatures).collect(); + + let mut tsne = bhtsne::tSNE::new(&data); + let tsne = tsne + .embedding_dim(embedding_size) + .perplexity(self.perplexity()) + .epochs(self.max_iter()) + .stop_lying_epoch(preliminary_iter) + .momentum_switch_epoch(preliminary_iter); + + let tsne = if self.approx_threshold() <= F::zero() { + // compute exact t-SNE + tsne.exact(|a, b| { + let a = ArrayView1::from(a); + let b = ArrayView1::from(b); + self.metric().distance(a, b) + }) + } else { + // compute barnes-hut t-SNE + tsne.barnes_hut(self.approx_threshold(), |a, b| { + let a = ArrayView1::from(a); + let b = ArrayView1::from(b); + self.metric().distance(a, b) + }) + }; + + let embedding = tsne.embedding(); Array2::from_shape_vec((nsamples, self.embedding_size()), embedding).map_err(|e| e.into()) } } -impl Transformer, Result>> for TSneParams { +impl> Transformer, Result>> for TSneParams { fn transform(&self, x: Array2) -> Result> { self.check_ref()?.transform(x) } } -impl +impl> Transformer, T>, Result, T>>> - for TSneValidParams + for TSneValidParams { fn transform(&self, ds: DatasetBase, T>) -> Result, T>> { let DatasetBase { @@ -82,8 +90,8 @@ impl } } -impl - Transformer, T>, Result, T>>> for TSneParams +impl> + Transformer, T>, Result, T>>> for TSneParams { fn transform(&self, ds: DatasetBase, T>) -> Result, T>> { self.check_ref()?.transform(ds) @@ -103,17 +111,16 @@ mod tests { #[test] fn autotraits() { fn has_autotraits() {} - has_autotraits::>>(); - has_autotraits::>>(); + has_autotraits::>(); + has_autotraits::>(); has_autotraits::(); } #[test] fn iris_separate() -> Result<()> { let ds = linfa_datasets::iris(); - let rng = SmallRng::seed_from_u64(42); - let ds = TSneParams::embedding_size_with_rng(2, rng) + let ds = TSneParams::embedding_size(2) .perplexity(10.0) .approx_threshold(0.0) .transform(ds)?; @@ -123,6 +130,19 @@ mod tests { Ok(()) } + #[test] + fn iris_separate_bharnes_hut() -> Result<()> { + let ds = linfa_datasets::iris(); + + let ds = TSneParams::embedding_size(2) + .perplexity(10.0) + .transform(ds)?; + + assert!(ds.silhouette_score()? > 0.4); + + Ok(()) + } + #[test] fn blob_separate() -> Result<()> { let mut rng = SmallRng::seed_from_u64(42); @@ -137,7 +157,7 @@ mod tests { let targets = (0..200).map(|x| x < 100).collect::>(); let dataset = Dataset::new(entries, targets); - let ds = TSneParams::embedding_size_with_rng(2, rng) + let ds = TSneParams::embedding_size(2) .perplexity(60.0) .approx_threshold(0.0) .transform(dataset)?; diff --git a/docs/website/config.toml b/docs/website/config.toml index 9a3851ad6..322a487e4 100644 --- a/docs/website/config.toml +++ b/docs/website/config.toml @@ -7,11 +7,9 @@ compile_sass = true # Whether to build a search index to be used later on by a JavaScript library build_search_index = true -[markdown] -# Whether to do syntax highlighting -# Theme can be customised by setting the `highlight_theme` variable to a theme supported by Zola -highlight_code = true -highlight_theme = "inspired-github" +[markdown.highlighting] +# Theme can be customised by setting the `theme` variable to a theme supported by Zola +theme = "github-light-default" [extra] # Put all your custom variables here diff --git a/src/composing/mod.rs b/src/composing/mod.rs index a1f2acc37..bb7271889 100644 --- a/src/composing/mod.rs +++ b/src/composing/mod.rs @@ -1,12 +1,15 @@ //! Composition models //! -//! This module contains three composition models: +//! This module contains four composition models: //! * `MultiClassModel`: combine multiple binary decision models to a single multi-class model //! * `MultiTargetModel`: combine multiple univariate models to a single multi-target model //! * `Platt`: calibrate a classifier (i.e. SVC) to predicted posterior probabilities +//! * `ResidualChain`: fit models sequentially on the residuals of the previous one +//! (forward stagewise additive modeling / L2Boosting); see [`residual_chain::Stagewise`] mod multi_class_model; mod multi_target_model; pub mod platt_scaling; +pub mod residual_chain; pub use multi_class_model::MultiClassModel; pub use multi_target_model::MultiTargetModel; diff --git a/src/composing/residual_chain.rs b/src/composing/residual_chain.rs new file mode 100644 index 000000000..188c7ce74 --- /dev/null +++ b/src/composing/residual_chain.rs @@ -0,0 +1,591 @@ +//! L2Boosting (forward stagewise additive modelling with squared-error loss) +//! for the linfa ML framework. +//! +//! This module provides [`ResidualChain`], which fits models sequentially on +//! residuals. Chain as many stages as you like via [`Stagewise`]: +//! +//! 1. Fit `base` on `(X, Y)` +//! 2. Compute residuals: `R = Y - base.predict(X)` +//! 3. Fit `corrector` on `(X, R)` +//! 4. Repeat for any further correctors stacked on top +//! +//! When predicting, all stages' outputs are summed. +//! +//! This is the special case of FSAM (Friedman 2001) where the loss is squared +//! error. Shrinkage (learning rate ν ∈ (0, 1]) can be set per corrector via +//! [`Shrunk::with_shrinkage`]; the default is ν = 1 (no scaling). +//! +//! # References +//! +//! - J. H. Friedman (2001). "Greedy function approximation: A gradient boosting machine." +//! +//! +//! # Examples +//! +//! ## Linear + linear +//! +//! Two `linfa_linear::LinearRegression` models stacked: the corrector fits +//! the residuals left by the base. +//! +//! ``` +//! use linfa::traits::{Fit, Predict}; +//! use linfa::DatasetBase; +//! use linfa_linear::LinearRegression; +//! use linfa::composing::residual_chain::{ResidualChain, Stagewise}; +//! use ndarray::{array, Array2}; +//! +//! // y = 2x: perfectly linear, so the corrector should see zero residuals. +//! let x = Array2::from_shape_fn((5, 1), |(i, _)| i as f64); +//! let y = array![0., 2., 4., 6., 8.]; +//! let dataset = DatasetBase::new(x.clone(), y); +//! +//! let fitted = LinearRegression::default() +//! .chain(LinearRegression::default()) +//! .fit(&dataset) +//! .unwrap(); +//! +//! let _preds = fitted.predict(&x); +//! ``` +//! +//! ## The second model learns nothing when the first fits perfectly +//! +//! If the first model already captures the data exactly, the residuals are all +//! zero and the second model has nothing to learn — its parameters come out +//! at (or very near) zero. +//! +//! ``` +//! use linfa::traits::{Fit, Predict}; +//! use linfa::DatasetBase; +//! use linfa_linear::LinearRegression; +//! use linfa::composing::residual_chain::Stagewise; +//! use ndarray::{array, Array2}; +//! +//! // y = 2x: one linear model is enough to fit this perfectly. +//! let x = Array2::from_shape_fn((5, 1), |(i, _)| i as f64); +//! let y = array![0., 2., 4., 6., 8.]; +//! let dataset = DatasetBase::new(x.clone(), y); +//! +//! let fitted = LinearRegression::default() +//! .chain(LinearRegression::default()) +//! .fit(&dataset) +//! .unwrap(); +//! +//! // The corrector trained on zero residuals — nothing left to correct. +//! assert!(fitted.corrector().model().params().iter().all(|&c: &f64| c.abs() < 1e-10)); +//! assert!(fitted.corrector().model().intercept().abs() < 1e-10); +//! ``` +//! +//! ## Chained SVMs and linear regression +//! +//! A linear-kernel `linfa_svm::Svm` captures the overall trend; two +//! Gaussian-kernel SVMs and a `linfa_linear::LinearRegression` then fit +//! successive residuals in a four-model chain. +//! +//! ``` +//! use linfa::traits::{Fit, Predict}; +//! use linfa::DatasetBase; +//! use linfa_linear::LinearRegression; +//! use linfa::composing::residual_chain::{ResidualChain, Stagewise}; +//! use linfa_svm::Svm; +//! use ndarray::Array; +//! +//! // y = sin(x): the linear SVM captures the slope; the RBF SVM captures +//! // the curvature left in the residuals. +//! let x = Array::linspace(0f64, 6., 20) +//! .into_shape_with_order((20, 1)) +//! .unwrap(); +//! let y = x.column(0).mapv(f64::sin); +//! let dataset = DatasetBase::new(x.clone(), y); +//! +//! let fitted = Svm::::params() +//! .c_svr(1., None) +//! .linear_kernel() +//! .chain( +//! Svm::::params() +//! .c_svr(10., Some(0.1)) +//! .gaussian_kernel(1.), +//! ) +//! .chain(LinearRegression::default()) +//! .chain( +//! Svm::::params() +//! .c_svr(10., Some(0.1)) +//! .gaussian_kernel(3.), +//! ) +//! .fit(&dataset) +//! .unwrap(); +//! +//! let _preds = fitted.predict(&x); +//! ``` + +use crate::dataset::{AsTargets, DatasetBase, Records}; +use crate::param_guard::ParamGuard; +use crate::traits::{Fit, Predict, PredictInplace}; +use crate::Float; +use ndarray::{Array1, Array2, ArrayBase, Data, Ix1, Ix2, RawDataClone}; +#[cfg(feature = "serde")] +use serde_crate::{Deserialize, Serialize}; +use std::ops::{AddAssign, Mul}; + +type Arr2 = ArrayBase; + +/// Error returned by [`ResidualChain::fit`]. +/// +/// Wraps the error from whichever of the two model fits failed, keeping them +/// distinguishable without requiring both models to share the same error type. +#[derive(Debug, thiserror::Error)] +pub enum ResidualChainError { + #[error("base model: {0}")] + Base(E1), + #[error("corrector: {0}")] + Corrector(E2), + // Satisfies the `Fit` trait's `E: From` bound. + #[error(transparent)] + BaseCrate(#[from] crate::Error), +} + +/// A pair of [`Fit`] params that fits sequentially on residuals. +/// +/// `base` is fit on the original targets; `corrector` (a [`Shrunk`] model) is +/// fit on the residuals left by `base` and scaled by its shrinkage factor ν. +/// Prediction sums `base` and the scaled corrector output. +#[cfg_attr( + feature = "serde", + derive(Serialize, Deserialize), + serde(crate = "serde_crate") +)] +#[derive(Debug, Clone, Copy)] +pub struct ResidualChain { + base: B, + corrector: Shrunk, +} + +impl ResidualChain { + pub fn base(&self) -> &B { + &self.base + } + pub fn corrector(&self) -> &Shrunk { + &self.corrector + } +} + +/// Extension trait that adds residual-chain composition methods to any type. +/// +/// Blanket-implemented for all `Sized` types, so any model params type gains +/// these methods automatically: +/// +/// - [`chain`](Stagewise::chain): compose `self` (as the base) with a corrector +/// that will be trained on the residuals left by `self`. The corrector is used +/// without shrinkage (ν = 1). Returns a [`ResidualChainParams`] whose `.fit()` +/// runs both stages. Calls can be chained to build arbitrarily deep sequences. +/// - [`chain_shrunk`](Stagewise::chain_shrunk): like `chain`, but accepts a +/// [`Shrunk`]-wrapped corrector so you can control the learning rate ν +/// explicitly via [`shrink_by`](Stagewise::shrink_by). +/// - [`shrink_by`](Stagewise::shrink_by): wrap `self` in a [`Shrunk`] with the +/// given learning rate ν ∈ (0, 1], making it ready to pass as the `corrector` +/// argument to [`Stagewise::chain_shrunk`]. +/// +/// # Example +/// +/// ``` +/// use linfa::traits::Fit; +/// use linfa::DatasetBase; +/// use linfa_linear::LinearRegression; +/// use linfa::composing::residual_chain::Stagewise; +/// use ndarray::{array, Array2}; +/// +/// let x = Array2::from_shape_fn((5, 1), |(i, _)| i as f64); +/// let y = array![0., 2., 4., 6., 8.]; +/// let dataset = DatasetBase::new(x.clone(), y); +/// +/// let fitted = LinearRegression::default() +/// .chain(LinearRegression::default()) +/// .fit(&dataset) +/// .unwrap(); +/// ``` +pub trait Stagewise: Sized { + /// Compose `self` (as the base model) with a [`Shrunk`]-wrapped `corrector`, + /// which will be trained on the residuals left by `self`. Further stages can + /// be appended by calling `.chain(...)` or `.chain_shrunk(...)` on the + /// returned [`ResidualChainParams`]. + /// + /// Use [`chain`](Stagewise::chain) instead when you don't need to shrink + /// the corrector. + fn chain_shrunk(self, corrector: Shrunk) -> ResidualChainParams; + + /// Compose `self` (as the base model) with `corrector`, which will be + /// trained on the residuals left by `self`. The corrector is used without + /// shrinkage (equivalent to `shrink_by(1.0)`). Further stages can be + /// appended by calling `.chain(...)` or `.chain_shrunk(...)` on the + /// returned [`ResidualChainParams`]. + /// + /// Use [`chain_shrunk`](Stagewise::chain_shrunk) together with + /// [`shrink_by`](Stagewise::shrink_by) when you need to control the + /// learning rate ν of the corrector explicitly. + fn chain(self, corrector: C) -> ResidualChainParams + where + C: Fit, Array1, E>, + E: std::error::Error + From; + + /// Wrap `self` in a [`Shrunk`] with learning rate `shrinkage` ∈ (0, 1], + /// making it ready to pass as the `corrector` argument to [`Stagewise::chain_shrunk`]. + /// + /// The bound `Self: Fit, Array1, E>` ensures at compile time + /// that the model's element type matches the shrinkage type `F`. + fn shrink_by(self, shrinkage: F) -> Shrunk + where + Self: Fit, Array1, E>, + E: std::error::Error + From; +} + +impl Stagewise for B { + fn chain_shrunk(self, corrector: Shrunk) -> ResidualChainParams { + ResidualChainParams(ResidualChain { + base: self, + corrector, + }) + } + fn chain(self, corrector: C) -> ResidualChainParams + where + C: Fit, Array1, E>, + E: std::error::Error + From, + { + self.chain_shrunk(corrector.shrink_by(F::one())) + } + fn shrink_by(self, shrinkage: F) -> Shrunk + where + Self: Fit, Array1, E>, + E: std::error::Error + From, + { + Shrunk { + model: self, + shrinkage, + } + } +} + +// Same pattern as `AdaBoostValidParams::Fit` impl in linfa-ensemble: introduce +// fresh generics `M1`/`M2` for the inner-model types and bind the associated +// `Object` types via `Fit<..., Object = MN>`. This sidesteps Linfa's +// `ParamGuard` blanket-impl recursion under ndarray 0.17 (otherwise +// `F1::Object: Predict` would force the solver to *resolve* `F1: Fit`, +// leading to the infinite `<::Checked>::Checked: ParamGuard` +// regress). +impl + RawDataClone, T, E1, E2> + Fit, T, ResidualChainError> for ResidualChain +where + Arr2: Records, + F1: Fit, T, E1, Object = M1>, + F2: Fit, Array1, E2, Object = M2>, + for<'a> M1: Predict<&'a Arr2, Array1>, + T: AsTargets, + E1: std::error::Error + From, + E2: std::error::Error + From, +{ + type Object = ResidualChain; + + fn fit( + &self, + dataset: &DatasetBase, T>, + ) -> Result> { + let base = self.base.fit(dataset).map_err(ResidualChainError::Base)?; + + let y_pred = base.predict(dataset.records()); + let residuals = &dataset.targets().as_targets() - &y_pred; + + let residual_dataset = DatasetBase::new(dataset.records().clone(), residuals); + let corrector_model = self + .corrector + .model + .fit(&residual_dataset) + .map_err(ResidualChainError::Corrector)?; + + Ok(ResidualChain { + base, + corrector: Shrunk { + model: corrector_model, + shrinkage: self.corrector.shrinkage, + }, + }) + } +} + +impl> PredictInplace, Array1> + for ResidualChain +where + R1: PredictInplace, Array1>, + R2: PredictInplace, Array1>, +{ + fn predict_inplace<'a>(&'a self, x: &'a Arr2, y: &mut Array1) { + self.base.predict_inplace(x, y); + y.add_assign( + &self + .corrector + .model + .predict(x) + .mul(self.corrector.shrinkage), + ); + } + + fn default_target(&self, x: &Arr2) -> Array1 { + Array1::zeros(x.nrows()) + } +} + +/// A model (params or fitted) paired with a shrinkage factor ν ∈ (0, 1]. +/// +/// Used in two roles: +/// - **Before fitting**: `Shrunk` wraps corrector params `C`; created by +/// [`Stagewise::shrink_by`] and stored in [`ResidualChain`] / [`ResidualChainParams`]. +/// - **After fitting**: `Shrunk` wraps the fitted corrector model; +/// prediction scales the corrector's output by ν before summing with the base. +#[cfg_attr( + feature = "serde", + derive(Serialize, Deserialize), + serde(crate = "serde_crate") +)] +#[derive(Debug, Clone, Copy)] +pub struct Shrunk { + model: M, + shrinkage: F, +} + +impl Shrunk { + pub fn model(&self) -> &M { + &self.model + } + pub fn shrinkage(&self) -> F { + self.shrinkage + } + /// Set the shrinkage factor. Validation happens when the containing + /// [`ResidualChainParams`] is checked via [`ParamGuard`]. + pub fn with_shrinkage(mut self, shrinkage: F) -> Self { + self.shrinkage = shrinkage; + self + } +} + +/// Unvalidated [`ResidualChain`] parameters returned by [`Stagewise::chain_shrunk`]. +/// +/// Call `.fit()` to validate and fit in one step — the [`ParamGuard`] blanket +/// impl runs `check_ref` first, which verifies that the outermost corrector's +/// shrinkage factor is in (0, 1]. Inner chains validate lazily when their own +/// `.fit()` is called. You can also call `.check()` / `.check_unwrap()` to +/// validate explicitly. +/// +/// To set an explicit shrinkage factor on the corrector use +/// [`Shrunk::with_shrinkage`]: +/// +/// ``` +/// use linfa::traits::{Fit, Predict}; +/// use linfa::DatasetBase; +/// use linfa_linear::LinearRegression; +/// use linfa::composing::residual_chain::{Shrunk, Stagewise}; +/// use ndarray::{array, Array2}; +/// +/// let x = Array2::from_shape_fn((5, 1), |(i, _)| i as f64); +/// let y = array![0., 2., 4., 6., 8.]; +/// let dataset = DatasetBase::new(x.clone(), y); +/// +/// // The corrector's contribution is scaled by 0.1. +/// let fitted = LinearRegression::default() +/// .chain_shrunk(LinearRegression::default().shrink_by(0.1)) +/// .fit(&dataset) +/// .unwrap(); +/// ``` +#[cfg_attr( + feature = "serde", + derive(Serialize, Deserialize), + serde(crate = "serde_crate") +)] +#[derive(Debug, Clone, Copy)] +pub struct ResidualChainParams(ResidualChain); + +impl ParamGuard for ResidualChainParams { + type Checked = ResidualChain; + type Error = crate::Error; + + fn check_ref(&self) -> Result<&ResidualChain, crate::Error> { + let v = self.0.corrector.shrinkage; + if v > F::zero() && v <= F::one() { + Ok(&self.0) + } else { + Err(crate::Error::Parameters(format!( + "shrinkage must be in (0, 1], got {v}" + ))) + } + } + + fn check(self) -> Result, crate::Error> { + self.check_ref()?; + Ok(self.0) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::Error as LinfaError; + use crate::DatasetBase; + use ndarray::{array, Array1, Array2}; + + #[derive(thiserror::Error, Debug)] + #[error("dummy error")] + struct DummyError(#[from] LinfaError); + + // Params that fits by recording the mean of the targets. + struct MeanParams; + + // Model that predicts the mean it saw during fit. + struct MeanModel(f64); + + impl Fit, Array1, DummyError> for MeanParams { + type Object = MeanModel; + fn fit( + &self, + dataset: &DatasetBase, Array1>, + ) -> Result { + let mean = dataset.targets().iter().sum::() / dataset.targets().len() as f64; + Ok(MeanModel(mean)) + } + } + + impl PredictInplace, Array1> for MeanModel { + fn predict_inplace(&self, x: &Array2, y: &mut Array1) { + y.assign(&Array1::from_elem(x.nrows(), self.0)); + } + fn default_target(&self, x: &Array2) -> Array1 { + Array1::zeros(x.nrows()) + } + } + + #[test] + fn corrector_is_fit_on_residuals() { + // targets = [1, 3]. base sees mean=2, predicts 2 for all. + // residuals = [1-2, 3-2] = [-1, 1]. corrector sees mean=0. + let model = MeanParams.chain(MeanParams); + let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]); + let fitted = model.fit(&dataset).unwrap(); + + assert_eq!(fitted.base().0, 2.0); // mean of [1, 3] + assert_eq!(fitted.corrector().model().0, 0.0); // mean of residuals [-1, 1] + } + + #[test] + fn predict_sums_both_models() { + // base predicts 2.0, corrector predicts 0.0 → sum = 2.0 + let model = MeanParams.chain(MeanParams); + let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]); + let fitted = model.fit(&dataset).unwrap(); + + let records = array![[0.0_f64], [1.0]]; + let predictions = fitted.predict(&records); + assert_eq!(predictions, array![2.0, 2.0]); + } + + #[test] + fn predict_recovers_targets_when_residuals_fit_perfectly() { + // If the corrector perfectly fits the residuals, the combined prediction = original targets. + struct FixedParams(f64); + struct FixedModel(f64); + + impl Fit, Array1, DummyError> for FixedParams { + type Object = FixedModel; + fn fit( + &self, + _dataset: &DatasetBase, Array1>, + ) -> Result { + Ok(FixedModel(self.0)) + } + } + + impl PredictInplace, Array1> for FixedModel { + fn predict_inplace(&self, x: &Array2, y: &mut Array1) { + y.assign(&Array1::from_elem(x.nrows(), self.0)); + } + fn default_target(&self, x: &Array2) -> Array1 { + Array1::zeros(x.nrows()) + } + } + + // base predicts 3.0, corrector predicts 1.0 → sum = 4.0 + let model = FixedParams(3.0) + .chain(FixedParams(1.0)) + .chain(FixedParams(0.0)); + let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![4.0, 4.0]); + let fitted = model.fit(&dataset).unwrap(); + + let predictions = fitted.predict(&array![[0.0_f64], [1.0]]); + assert_eq!(predictions, array![4.0, 4.0]); + } + + #[test] + fn deep_chain_accessors() { + let model = MeanParams + .chain(MeanParams) + .chain(MeanParams) + .chain(MeanParams); + let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]); + let fitted = model.fit(&dataset).unwrap(); + + assert_eq!(fitted.base().base().base().0, 2.0); // params trained on original targets + } + + #[test] + fn shrinkage_scales_corrector_prediction() { + // base predicts mean=2.0, corrector predicts mean=0.0 (residuals [-1,1]). + // With shrinkage=0.5, corrector contributes 0.5*0.0 = 0.0 → total = 2.0. + let model = MeanParams.chain_shrunk(MeanParams.shrink_by(0.5)); + let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]); + let fitted = model.fit(&dataset).unwrap(); + + let preds = fitted.predict(&array![[0.0_f64], [1.0]]); + assert_eq!(preds, array![2.0, 2.0]); + assert_eq!(fitted.corrector().shrinkage(), 0.5); + } + + #[test] + fn shrinkage_corrector_sees_scaled_residuals() { + // base predicts 3.0 always. targets = [4.0, 4.0]. + // residuals = [1.0, 1.0]. corrector (mean) sees mean=1.0. + // With shrinkage=0.5: prediction = 3.0 + 0.5*1.0 = 3.5. + struct FixedParams(f64); + struct FixedModel(f64); + + impl Fit, Array1, DummyError> for FixedParams { + type Object = FixedModel; + fn fit( + &self, + _dataset: &DatasetBase, Array1>, + ) -> Result { + Ok(FixedModel(self.0)) + } + } + + impl PredictInplace, Array1> for FixedModel { + fn predict_inplace(&self, x: &Array2, y: &mut Array1) { + y.assign(&Array1::from_elem(x.nrows(), self.0)); + } + fn default_target(&self, x: &Array2) -> Array1 { + Array1::zeros(x.nrows()) + } + } + + let model = FixedParams(3.0).chain_shrunk(MeanParams.shrink_by(0.5)); + let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![4.0, 4.0]); + let fitted = model.fit(&dataset).unwrap(); + + let preds = fitted.predict(&array![[0.0_f64], [1.0]]); + // corrector saw residuals [1.0, 1.0], mean=1.0, shrunk by 0.5 → 0.5 + assert!((preds[0] - 3.5_f64).abs() < 1e-10); + } + + #[test] + fn shrinkage_invalid_value_returns_error() { + let dataset = DatasetBase::new(array![[0.0_f64], [1.0]], array![1.0, 3.0]); + let model = MeanParams.chain_shrunk(MeanParams.shrink_by(0.0)); + assert!(model.fit(&dataset).is_err()); + + let model = MeanParams.chain_shrunk(MeanParams.shrink_by(1.5)); + assert!(model.fit(&dataset).is_err()); + } +} diff --git a/src/metrics_regression.rs b/src/metrics_regression.rs index 56508cd6a..c15c8db08 100644 --- a/src/metrics_regression.rs +++ b/src/metrics_regression.rs @@ -8,7 +8,7 @@ use crate::{ Float, }; use ndarray::prelude::*; -use ndarray::Data; +use ndarray::{Data, Zip}; use std::ops::{Div, Sub}; /// Regression metrices trait for single targets. @@ -80,6 +80,29 @@ pub trait SingleTargetRegression>: .ok_or(Error::NotEnoughSamples) } + /// Symmetric mean absolute percentage error between two continuous variables. + /// This implementation follows the Adjusted sMAPE (Makridakis, 1993) + /// sMAPE = (200 / n) * SUM(abs(y_hat - y) / (abs(y) + abs(y_hat))) + fn symmetric_mean_absolute_percentage_error(&self, compare_to: &T) -> Result { + let y = self.as_single_targets(); + let y_hat = compare_to.as_single_targets(); + if y.is_empty() { + return Err(Error::NotEnoughSamples); + } + let sum: F = Zip::from(&y) + .and(&y_hat) + .fold(F::cast(0.0), |acc, &yi, &yhi| { + let num = (yhi - yi).abs(); + let den = yi.abs() + yhi.abs(); + if den <= F::epsilon() { + acc + } else { + acc + (num / den) + } + }); + Ok((F::cast(200.0) / F::cast(y.len())) * sum) + } + /// R squared coefficient, is the proportion of the variance in the dependent variable that is /// predictable from the independent variable // r2 = 1 - sum((pred_i - y_i)^2)/sum((mean_y - y_i)^2) @@ -193,6 +216,17 @@ pub trait MultiTargetRegression>: .collect() } + /// Symmetric mean absolute percentage error between two continuous variables + /// This implementation follows the Adjusted sMAPE (Makridakis, 1993) + /// sMAPE = (200 / n) * SUM(abs(y_hat - y) / (abs(y) + abs(y_hat))) + fn symmetric_mean_absolute_percentage_error(&self, other: &T) -> Result> { + self.as_multi_targets() + .axis_iter(Axis(1)) + .zip(other.as_multi_targets().axis_iter(Axis(1))) + .map(|(a, b)| a.symmetric_mean_absolute_percentage_error(&b)) + .collect() + } + /// R squared coefficient, is the proportion of the variance in the dependent variable that is /// predictable from the independent variable fn r2(&self, other: &T) -> Result> { @@ -225,7 +259,7 @@ impl, T2: AsMultiTargets, D: Dat #[cfg(test)] mod tests { - use super::SingleTargetRegression; + use super::{MultiTargetRegression, SingleTargetRegression}; use crate::dataset::DatasetBase; use approx::assert_abs_diff_eq; use ndarray::prelude::*; @@ -242,6 +276,10 @@ mod tests { assert_abs_diff_eq!(a.r2(&a).unwrap(), 1.0f32); assert_abs_diff_eq!(a.explained_variance(&a).unwrap(), 1.0f32); assert_abs_diff_eq!(a.mean_absolute_percentage_error(&a).unwrap(), 0.0f32); + assert_abs_diff_eq!( + a.symmetric_mean_absolute_percentage_error(&a).unwrap(), + 0.0f32 + ); } #[test] @@ -281,6 +319,18 @@ mod tests { ); } + #[test] + fn test_symmetric_mean_absolute_percentage_error() { + let a = array![0.5, 0.1, 0.2, 0.3, 0.4]; + let b = array![0.1, 0.2, 0.3, 0.4, 0.5]; + + assert_abs_diff_eq!( + a.symmetric_mean_absolute_percentage_error(&b).unwrap(), + 58.15873014693111, + epsilon = 1e-5 + ); + } + #[test] fn test_max_error_for_single_targets() { let records = array![[0.0, 0.0], [0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4]]; @@ -339,6 +389,23 @@ mod tests { assert_abs_diff_eq!(pct_err_from_arr1, pct_err_from_ds); } + #[test] + fn test_symmetric_mean_absolute_percentage_error_for_single_targets() { + let records = array![[0.0, 0.0], [0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4]]; + let targets = array![0.0, 0.1, 0.2, 0.3, 0.4]; + let st_dataset: DatasetBase<_, _> = (records.view(), targets).into(); + let prediction = array![0.1, 0.3, 0.2, 0.5, 0.7]; + let err_from_arr = prediction + .symmetric_mean_absolute_percentage_error(st_dataset.targets()) + .unwrap(); + let prediction_ds: DatasetBase<_, _> = (records.view(), prediction).into(); + let err_from_ds = prediction_ds + .symmetric_mean_absolute_percentage_error(&st_dataset) + .unwrap(); + assert_abs_diff_eq!(err_from_arr, 80.90909086184916, epsilon = 1e-5); + assert_abs_diff_eq!(err_from_arr, err_from_ds); + } + #[test] fn test_mean_squared_log_error_for_single_targets() { let records = array![[0.0, 0.0], [0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4]]; @@ -410,4 +477,13 @@ mod tests { assert_abs_diff_eq!(abs_err_from_arr1, 0.8, epsilon = 1e-5); assert_abs_diff_eq!(abs_err_from_arr1, abs_err_from_ds); } + + #[test] + fn test_symmetric_mean_absolute_percentage_error_multi_target() { + let a = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]; + let b = array![[1.1, 1.9], [2.9, 4.1], [5.1, 5.9]]; + let err = a.symmetric_mean_absolute_percentage_error(&b).unwrap(); + let expected = array![4.964612684, 3.092671067]; + assert_abs_diff_eq!(err, expected, epsilon = 1e-5); + } }