From 720e00b58fb6124c3f5bc463a362716e7cb83a68 Mon Sep 17 00:00:00 2001 From: Olga Date: Thu, 28 May 2026 13:17:50 +0200 Subject: [PATCH 1/3] added adaptive thresholding + adaptation to all files to include it --- .../online_localizer_lsh.cpp | 4 +- .../similarity_matrix_no_hashing.cpp | 10 +- src/localization/database/idatabase.h | 6 +- src/localization/database/online_database.cpp | 18 + src/localization/database/online_database.h | 5 +- .../database/similarity_matrix_database.cpp | 8 + .../database/similarity_matrix_database.h | 3 +- .../online_localizer/CMakeLists.txt | 4 + .../math_tools/CMakeLists.txt | 19 ++ .../online_localizer/math_tools/constants.h | 311 +++++++++++++++++ .../derivation_boundary_estimation.md | 83 +++++ .../online_localizer/math_tools/gmm.cpp | 286 ++++++++++++++++ .../online_localizer/math_tools/gmm.h | 87 +++++ .../math_tools/math_tools.cpp | 121 +++++++ .../online_localizer/math_tools/math_tools.h | 44 +++ .../math_tools/statistical_test.cpp | 194 +++++++++++ .../math_tools/statistical_test.h | 51 +++ .../online_localizer/online_localizer.cpp | 151 ++++++++- .../online_localizer/online_localizer.h | 24 +- .../online_localizer/path_element.cpp | 12 + .../online_localizer/path_element.h | 6 + .../successor_manager/successor_manager.cpp | 21 ++ .../successor_manager/successor_manager.h | 2 + .../tools/config_parser/config_parser.cpp | 318 +++++++++--------- .../tools/config_parser/config_parser.h | 2 + src/localization_protos.proto | 38 +++ src/test/online_localizer_test.cpp | 2 +- 27 files changed, 1650 insertions(+), 180 deletions(-) create mode 100644 src/localization/online_localizer/math_tools/CMakeLists.txt create mode 100644 src/localization/online_localizer/math_tools/constants.h create mode 100644 src/localization/online_localizer/math_tools/derivation_boundary_estimation.md create mode 100644 src/localization/online_localizer/math_tools/gmm.cpp create mode 100644 src/localization/online_localizer/math_tools/gmm.h create mode 100644 src/localization/online_localizer/math_tools/math_tools.cpp create mode 100644 src/localization/online_localizer/math_tools/math_tools.h create mode 100644 src/localization/online_localizer/math_tools/statistical_test.cpp create mode 100644 src/localization/online_localizer/math_tools/statistical_test.h diff --git a/src/apps/similarity_matrix_based_matching/online_localizer_lsh.cpp b/src/apps/similarity_matrix_based_matching/online_localizer_lsh.cpp index fc15835..204622b 100644 --- a/src/apps/similarity_matrix_based_matching/online_localizer_lsh.cpp +++ b/src/apps/similarity_matrix_based_matching/online_localizer_lsh.cpp @@ -90,9 +90,9 @@ int main(int argc, char *argv[]) { std::make_unique( database.get(), relocalizer.get(), parser.fanOut); loc::online_localizer::OnlineLocalizer localizer{ - successorManager.get(), parser.expansionRate, parser.matchingThreshold}; + successorManager.get(), parser.expansionRate, parser.matchingThreshold, parser.adaptThreshold}; const loc::online_localizer::Matches imageMatches = - localizer.findMatchesTill(parser.querySize); + localizer.findMatchesTill(parser.querySize, parser.debugProto); loc::online_localizer::storeMatchesAsProto(imageMatches, parser.matchingResult); diff --git a/src/apps/similarity_matrix_based_matching/similarity_matrix_no_hashing.cpp b/src/apps/similarity_matrix_based_matching/similarity_matrix_no_hashing.cpp index 3d8e953..7f44b3f 100644 --- a/src/apps/similarity_matrix_based_matching/similarity_matrix_no_hashing.cpp +++ b/src/apps/similarity_matrix_based_matching/similarity_matrix_no_hashing.cpp @@ -1,7 +1,7 @@ /* By O. Vysotska in 2023 */ -#include "database/similarity_matrix_database.h" #include "database/idatabase.h" +#include "database/similarity_matrix_database.h" #include "online_localizer/online_localizer.h" #include "online_localizer/path_element.h" #include "relocalizers/default_relocalizer.h" @@ -31,7 +31,8 @@ int main(int argc, char *argv[]) { parser.print(); const auto database = - std::make_unique(parser.similarityMatrix); + std::make_unique( + parser.similarityMatrix); const auto relocalizer = std::make_unique( @@ -41,9 +42,10 @@ int main(int argc, char *argv[]) { std::make_unique( database.get(), relocalizer.get(), parser.fanOut); loc::online_localizer::OnlineLocalizer localizer{ - successorManager.get(), parser.expansionRate, parser.matchingThreshold}; + successorManager.get(), parser.expansionRate, parser.matchingThreshold, + parser.adaptThreshold}; const loc::online_localizer::Matches imageMatches = - localizer.findMatchesTill(parser.querySize); + localizer.findMatchesTill(parser.querySize, parser.debugProto); loc::online_localizer::storeMatchesAsProto(imageMatches, parser.matchingResult); diff --git a/src/localization/database/idatabase.h b/src/localization/database/idatabase.h index c8ca122..0e8d066 100644 --- a/src/localization/database/idatabase.h +++ b/src/localization/database/idatabase.h @@ -24,6 +24,8 @@ #ifndef SRC_DATABASE_IDATABASE_H_ #define SRC_DATABASE_IDATABASE_H_ +#include + namespace localization::database { /** @@ -31,7 +33,7 @@ namespace localization::database { */ class iDatabase { public: - virtual int refSize() = 0; + virtual int refSize() const = 0; /** * @brief Gets the cost. This cost goes directly in the graph structure. * Smaller costs correspond to bigger similarities. @@ -42,6 +44,8 @@ class iDatabase { * @return The cost. */ virtual double getCost(int quId, int refId) = 0; + virtual std::optional getCostIfComputed(int quId, + int refId) const = 0; iDatabase() = default; iDatabase(const iDatabase &) = delete; diff --git a/src/localization/database/online_database.cpp b/src/localization/database/online_database.cpp index 0916a95..c745aa2 100644 --- a/src/localization/database/online_database.cpp +++ b/src/localization/database/online_database.cpp @@ -99,6 +99,24 @@ double OnlineDatabase::getCost(int quId, int refId) { return cost; } +std::optional OnlineDatabase::getCostIfComputed(int quId, + int refId) const { + if (precomputedScores_) { + if (refId >= refSize()) { + return {}; + } + return precomputedScores_->at(quId, refId); + } + auto rowIter = costs_.find(quId); + if (rowIter != costs_.end()) { + auto elementIter = rowIter->second.find(refId); + if (elementIter != rowIter->second.end()) { + return 1. / elementIter->second; + } + } + return {}; +} + const features::iFeature &OnlineDatabase::getQueryFeature(int quId) { return addFeatureIfNeeded(*queryBuffer_, quFeaturesNames_, featureType_, quId); diff --git a/src/localization/database/online_database.h b/src/localization/database/online_database.h index 1fc9ca3..c514f01 100644 --- a/src/localization/database/online_database.h +++ b/src/localization/database/online_database.h @@ -26,8 +26,8 @@ #ifndef SRC_DATABASE_ONLINE_DATABASE_H_ #define SRC_DATABASE_ONLINE_DATABASE_H_ -#include "database/similarity_matrix.h" #include "database/idatabase.h" +#include "database/similarity_matrix.h" #include "features/feature_buffer.h" #include "features/feature_factory.h" @@ -48,8 +48,9 @@ class OnlineDatabase : public iDatabase { const std::string &refFeaturesDir, features::FeatureType type, int bufferSize, const std::string &similarityMatrixFile = ""); - inline int refSize() override { return refFeaturesNames_.size(); } + inline int refSize() const override { return refFeaturesNames_.size(); } double getCost(int quId, int refId) override; + std::optional getCostIfComputed(int quId, int refId) const override; double computeMatchingCost(int quId, int refId); diff --git a/src/localization/database/similarity_matrix_database.cpp b/src/localization/database/similarity_matrix_database.cpp index 78da35b..7cdf1fb 100644 --- a/src/localization/database/similarity_matrix_database.cpp +++ b/src/localization/database/similarity_matrix_database.cpp @@ -38,4 +38,12 @@ double SimilarityMatrixDatabase::getCost(int quId, int refId) { return similarityMatrix_.getCost(quId, refId); } +std::optional SimilarityMatrixDatabase::getCostIfComputed(int quId, + int refId) const { + if (refId >= refSize()) { + return {}; + } + return similarityMatrix_.at(quId, refId); +} + } // namespace localization::database diff --git a/src/localization/database/similarity_matrix_database.h b/src/localization/database/similarity_matrix_database.h index aee03f6..9655fac 100644 --- a/src/localization/database/similarity_matrix_database.h +++ b/src/localization/database/similarity_matrix_database.h @@ -38,8 +38,9 @@ class SimilarityMatrixDatabase : public iDatabase { public: explicit SimilarityMatrixDatabase(const std::string &costMatrixFile); - int refSize() override { return similarityMatrix_.cols(); } + int refSize() const override { return similarityMatrix_.cols(); } double getCost(int quId, int refId) override; + std::optional getCostIfComputed(int quId, int refId) const override; private: SimilarityMatrix similarityMatrix_; diff --git a/src/localization/online_localizer/CMakeLists.txt b/src/localization/online_localizer/CMakeLists.txt index f338a00..32e3e9f 100644 --- a/src/localization/online_localizer/CMakeLists.txt +++ b/src/localization/online_localizer/CMakeLists.txt @@ -1,3 +1,5 @@ +add_subdirectory(math_tools) + add_library(path_element path_element.cpp) target_link_libraries(path_element cxx_flags @@ -14,6 +16,8 @@ target_link_libraries(online_localizer successor_manager node timer + gmm + statistical_test protos glog::glog ) \ No newline at end of file diff --git a/src/localization/online_localizer/math_tools/CMakeLists.txt b/src/localization/online_localizer/math_tools/CMakeLists.txt new file mode 100644 index 0000000..294c83a --- /dev/null +++ b/src/localization/online_localizer/math_tools/CMakeLists.txt @@ -0,0 +1,19 @@ +add_library(math_tools math_tools.cpp) +target_link_libraries(math_tools + cxx_flags + glog::glog +) + +add_library(gmm gmm.cpp) +target_link_libraries(gmm + math_tools + cxx_flags + glog::glog +) + +add_library(statistical_test statistical_test.cpp constants.h) +target_link_libraries(statistical_test + math_tools + cxx_flags + glog::glog +) diff --git a/src/localization/online_localizer/math_tools/constants.h b/src/localization/online_localizer/math_tools/constants.h new file mode 100644 index 0000000..e54f805 --- /dev/null +++ b/src/localization/online_localizer/math_tools/constants.h @@ -0,0 +1,311 @@ +#ifndef SRC_ONLINE_LOCALIZER_CONSTANTS_H_ +#define SRC_ONLINE_LOCALIZER_CONSTANTS_H_ + +#include + +namespace localization::online_localizer::constants { +inline const std::vector kolmogorovX = { + 0.000000000000000000e+00, 1.000000000000000021e-02, + 2.000000000000000042e-02, 2.999999999999999889e-02, + 4.000000000000000083e-02, 5.000000000000000278e-02, + 5.999999999999999778e-02, 7.000000000000000666e-02, + 8.000000000000000167e-02, 8.999999999999999667e-02, + 1.000000000000000056e-01, 1.100000000000000006e-01, + 1.199999999999999956e-01, 1.300000000000000044e-01, + 1.400000000000000133e-01, 1.499999999999999944e-01, + 1.600000000000000033e-01, 1.700000000000000122e-01, + 1.799999999999999933e-01, 1.900000000000000022e-01, + 2.000000000000000111e-01, 2.099999999999999922e-01, + 2.200000000000000011e-01, 2.300000000000000100e-01, + 2.399999999999999911e-01, 2.500000000000000000e-01, + 2.600000000000000089e-01, 2.700000000000000178e-01, + 2.800000000000000266e-01, 2.899999999999999800e-01, + 2.999999999999999889e-01, 3.099999999999999978e-01, + 3.200000000000000067e-01, 3.300000000000000155e-01, + 3.400000000000000244e-01, 3.500000000000000333e-01, + 3.599999999999999867e-01, 3.699999999999999956e-01, + 3.800000000000000044e-01, 3.900000000000000133e-01, + 4.000000000000000222e-01, 4.100000000000000311e-01, + 4.199999999999999845e-01, 4.299999999999999933e-01, + 4.400000000000000022e-01, 4.500000000000000111e-01, + 4.600000000000000200e-01, 4.700000000000000289e-01, + 4.799999999999999822e-01, 4.899999999999999911e-01, + 5.000000000000000000e-01, 5.100000000000000089e-01, + 5.200000000000000178e-01, 5.300000000000000266e-01, + 5.400000000000000355e-01, 5.500000000000000444e-01, + 5.600000000000000533e-01, 5.700000000000000622e-01, + 5.799999999999999600e-01, 5.899999999999999689e-01, + 5.999999999999999778e-01, 6.099999999999999867e-01, + 6.199999999999999956e-01, 6.300000000000000044e-01, + 6.400000000000000133e-01, 6.500000000000000222e-01, + 6.600000000000000311e-01, 6.700000000000000400e-01, + 6.800000000000000488e-01, 6.900000000000000577e-01, + 7.000000000000000666e-01, 7.099999999999999645e-01, + 7.199999999999999734e-01, 7.299999999999999822e-01, + 7.399999999999999911e-01, 7.500000000000000000e-01, + 7.600000000000000089e-01, 7.700000000000000178e-01, + 7.800000000000000266e-01, 7.900000000000000355e-01, + 8.000000000000000444e-01, 8.100000000000000533e-01, + 8.200000000000000622e-01, 8.300000000000000711e-01, + 8.399999999999999689e-01, 8.499999999999999778e-01, + 8.599999999999999867e-01, 8.699999999999999956e-01, + 8.800000000000000044e-01, 8.900000000000000133e-01, + 9.000000000000000222e-01, 9.100000000000000311e-01, + 9.200000000000000400e-01, 9.300000000000000488e-01, + 9.400000000000000577e-01, 9.500000000000000666e-01, + 9.599999999999999645e-01, 9.699999999999999734e-01, + 9.799999999999999822e-01, 9.899999999999999911e-01, + 1.000000000000000000e+00, 1.010000000000000009e+00, + 1.020000000000000018e+00, 1.030000000000000027e+00, + 1.040000000000000036e+00, 1.050000000000000044e+00, + 1.060000000000000053e+00, 1.070000000000000062e+00, + 1.080000000000000071e+00, 1.090000000000000080e+00, + 1.100000000000000089e+00, 1.110000000000000098e+00, + 1.120000000000000107e+00, 1.130000000000000115e+00, + 1.140000000000000124e+00, 1.150000000000000133e+00, + 1.159999999999999920e+00, 1.169999999999999929e+00, + 1.179999999999999938e+00, 1.189999999999999947e+00, + 1.199999999999999956e+00, 1.209999999999999964e+00, + 1.219999999999999973e+00, 1.229999999999999982e+00, + 1.239999999999999991e+00, 1.250000000000000000e+00, + 1.260000000000000009e+00, 1.270000000000000018e+00, + 1.280000000000000027e+00, 1.290000000000000036e+00, + 1.300000000000000044e+00, 1.310000000000000053e+00, + 1.320000000000000062e+00, 1.330000000000000071e+00, + 1.340000000000000080e+00, 1.350000000000000089e+00, + 1.360000000000000098e+00, 1.370000000000000107e+00, + 1.380000000000000115e+00, 1.390000000000000124e+00, + 1.400000000000000133e+00, 1.409999999999999920e+00, + 1.419999999999999929e+00, 1.429999999999999938e+00, + 1.439999999999999947e+00, 1.449999999999999956e+00, + 1.459999999999999964e+00, 1.469999999999999973e+00, + 1.479999999999999982e+00, 1.489999999999999991e+00, + 1.500000000000000000e+00, 1.510000000000000009e+00, + 1.520000000000000018e+00, 1.530000000000000027e+00, + 1.540000000000000036e+00, 1.550000000000000044e+00, + 1.560000000000000053e+00, 1.570000000000000062e+00, + 1.580000000000000071e+00, 1.590000000000000080e+00, + 1.600000000000000089e+00, 1.610000000000000098e+00, + 1.620000000000000107e+00, 1.630000000000000115e+00, + 1.640000000000000124e+00, 1.650000000000000133e+00, + 1.660000000000000142e+00, 1.669999999999999929e+00, + 1.679999999999999938e+00, 1.689999999999999947e+00, + 1.699999999999999956e+00, 1.709999999999999964e+00, + 1.719999999999999973e+00, 1.729999999999999982e+00, + 1.739999999999999991e+00, 1.750000000000000000e+00, + 1.760000000000000009e+00, 1.770000000000000018e+00, + 1.780000000000000027e+00, 1.790000000000000036e+00, + 1.800000000000000044e+00, 1.810000000000000053e+00, + 1.820000000000000062e+00, 1.830000000000000071e+00, + 1.840000000000000080e+00, 1.850000000000000089e+00, + 1.860000000000000098e+00, 1.870000000000000107e+00, + 1.880000000000000115e+00, 1.890000000000000124e+00, + 1.900000000000000133e+00, 1.910000000000000142e+00, + 1.919999999999999929e+00, 1.929999999999999938e+00, + 1.939999999999999947e+00, 1.949999999999999956e+00, + 1.959999999999999964e+00, 1.969999999999999973e+00, + 1.979999999999999982e+00, 1.989999999999999991e+00, + 2.000000000000000000e+00, 2.010000000000000231e+00, + 2.020000000000000018e+00, 2.030000000000000249e+00, + 2.040000000000000036e+00, 2.049999999999999822e+00, + 2.060000000000000053e+00, 2.069999999999999840e+00, + 2.080000000000000071e+00, 2.089999999999999858e+00, + 2.100000000000000089e+00, 2.109999999999999876e+00, + 2.120000000000000107e+00, 2.129999999999999893e+00, + 2.140000000000000124e+00, 2.149999999999999911e+00, + 2.160000000000000142e+00, 2.169999999999999929e+00, + 2.180000000000000160e+00, 2.189999999999999947e+00, + 2.200000000000000178e+00, 2.209999999999999964e+00, + 2.220000000000000195e+00, 2.229999999999999982e+00, + 2.240000000000000213e+00, 2.250000000000000000e+00, + 2.260000000000000231e+00, 2.270000000000000018e+00, + 2.280000000000000249e+00, 2.290000000000000036e+00, + 2.300000000000000266e+00, 2.310000000000000053e+00, + 2.319999999999999840e+00, 2.330000000000000071e+00, + 2.339999999999999858e+00, 2.350000000000000089e+00, + 2.359999999999999876e+00, 2.370000000000000107e+00, + 2.379999999999999893e+00, 2.390000000000000124e+00, + 2.399999999999999911e+00, 2.410000000000000142e+00, + 2.419999999999999929e+00, 2.430000000000000160e+00, + 2.439999999999999947e+00, 2.450000000000000178e+00, + 2.459999999999999964e+00, 2.470000000000000195e+00, + 2.479999999999999982e+00, 2.490000000000000213e+00, + 2.500000000000000000e+00, 2.510000000000000231e+00, + 2.520000000000000018e+00, 2.530000000000000249e+00, + 2.540000000000000036e+00, 2.550000000000000266e+00, + 2.560000000000000053e+00, 2.569999999999999840e+00, + 2.580000000000000071e+00, 2.589999999999999858e+00, + 2.600000000000000089e+00, 2.609999999999999876e+00, + 2.620000000000000107e+00, 2.629999999999999893e+00, + 2.640000000000000124e+00, 2.649999999999999911e+00, + 2.660000000000000142e+00, 2.669999999999999929e+00, + 2.680000000000000160e+00, 2.689999999999999947e+00, + 2.700000000000000178e+00, 2.709999999999999964e+00, + 2.720000000000000195e+00, 2.729999999999999982e+00, + 2.740000000000000213e+00, 2.750000000000000000e+00, + 2.760000000000000231e+00, 2.770000000000000018e+00, + 2.780000000000000249e+00, 2.790000000000000036e+00, + 2.800000000000000266e+00, 2.810000000000000053e+00, + 2.819999999999999840e+00, 2.830000000000000071e+00, + 2.839999999999999858e+00, 2.850000000000000089e+00, + 2.859999999999999876e+00, 2.870000000000000107e+00, + 2.879999999999999893e+00, 2.890000000000000124e+00, + 2.899999999999999911e+00, 2.910000000000000142e+00, + 2.919999999999999929e+00, 2.930000000000000160e+00, + 2.939999999999999947e+00, 2.950000000000000178e+00, + 2.959999999999999964e+00, 2.970000000000000195e+00, + 2.979999999999999982e+00, 2.990000000000000213e+00}; +inline const std::vector kolmogorovPvalue = { + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 1.000000000000000000e+00, 1.000000000000000000e+00, + 9.999999999999995559e-01, 9.999999999999810152e-01, + 9.999999999994949595e-01, 9.999999999915384352e-01, + 9.999999999030295683e-01, 9.999999991890031747e-01, + 9.999999947882880491e-01, 9.999999731761899380e-01, + 9.999998856493185206e-01, 9.999995849712886020e-01, + 9.999986881499901559e-01, 9.999963201361131704e-01, + 9.999906941986654862e-01, 9.999785020570597371e-01, + 9.999541201308875804e-01, 9.999086804678157803e-01, + 9.998290948601502581e-01, 9.996971473702905842e-01, + 9.994887769610756312e-01, 9.991736569658024036e-01, + 9.987151458373864088e-01, 9.980706413847183489e-01, + 9.971923267772982991e-01, 9.960282580366814909e-01, + 9.945237142469574021e-01, 9.926227153740532216e-01, + 9.902696081545563356e-01, 9.874106261529367323e-01, + 9.839953430838830339e-01, 9.799779559351086133e-01, + 9.753183536399832043e-01, 9.699829458942970106e-01, + 9.639452436648751066e-01, 9.571861970810587872e-01, + 9.496943073294776827e-01, 9.414655368786881695e-01, + 9.325030471043408431e-01, 9.228167945891518897e-01, + 9.124230175285631406e-01, 9.013436422831755390e-01, + 8.896056376475568461e-01, 8.772403412558946734e-01, + 8.642827790506042440e-01, 8.507709951598498854e-01, + 8.367454060556157591e-01, 8.222481896276203628e-01, + 8.073227168928092112e-01, 7.920130315089249917e-01, + 7.763633800874639723e-01, 7.604177944983462334e-01, + 7.442197259033892509e-01, 7.278117291165672187e-01, + 7.112351950296890157e-01, 6.945301282244591423e-01, + 6.777349664784746430e-01, 6.608864386282611303e-01, + 6.440194571446776761e-01, 6.271670417762615912e-01, + 6.103602706992852234e-01, 5.936282557581773478e-01, + 5.769981385685900577e-01, 5.604951044730810406e-01, + 5.441424115741980705e-01, 5.279614323123148489e-01, + 5.119717052984921191e-01, 4.961909953505034832e-01, + 4.806353599086878869e-01, 4.653192202251582588e-01, + 4.502554359224337222e-01, 4.354553817055706322e-01, + 4.209290251842246477e-01, 4.066850049180472104e-01, + 3.927307079406542889e-01, 3.790723461446418030e-01, + 3.657150310235156021e-01, 3.526628463667678681e-01, + 3.399189185925663437e-01, 3.274854844795591635e-01, + 3.153639561260539459e-01, 3.035549830222197265e-01, + 2.920585111698588143e-01, 2.808738392255489269e-01, + 2.699996716773545580e-01, 2.594341690935974554e-01, + 2.491749955050625021e-01, 2.392193630001340499e-01, + 2.295640736263188653e-01, 2.202055587019502714e-01, + 2.111399156490972873e-01, 2.023629424632721607e-01, + 1.938701699378505838e-01, 1.856568917615732672e-01, + 1.777181926064012041e-01, 1.700489743206650417e-01, + 1.626439803391253047e-01, 1.554978184174867217e-01, + 1.486049817942788343e-01, 1.419598688780007445e-01, + 1.355568015521786407e-01, 1.293900421856188365e-01, + 1.234538094297657068e-01, 1.177422928797716567e-01, + 1.122496666707249735e-01, 1.069701020755133430e-01, + 1.018977791660635596e-01, 9.702689759522084567e-02, + 9.235168655233264023e-02, 8.786641394169107666e-02, + 8.356539482936792274e-02, 7.944299920064742948e-02, + 7.549365906721043629e-02, 7.171187496044884035e-02, + 6.809222184476636242e-02, 6.462935448244652714e-02, + 6.131801227961464901e-02, 5.815302364097883064e-02, + 5.512930985938035988e-02, 5.224188856470460968e-02, + 4.948587675537787639e-02, 4.685649343449063003e-02, + 4.434906187152343193e-02, 4.195901150969805204e-02, + 3.968187953811435492e-02, 3.751331214705481976e-02, + 3.544906548412805891e-02, 3.348500632827312862e-02, + 3.161711249804311646e-02, 2.984147301002937347e-02, + 2.815428800275951091e-02, 2.655186844090486747e-02, + 2.503063561415656141e-02, 2.358712044466907895e-02, + 2.221796261652512691e-02, 2.091990954023876734e-02, + 1.968981516488541722e-02, 1.852463865002384330e-02, + 1.742144290915547192e-02, 1.637739303604922661e-02, + 1.538975462484470751e-02, 1.445589199443203694e-02, + 1.357326632719390441e-02, 1.273943373178294987e-02, + 1.195204323919661638e-02, 1.120883474100282245e-02, + 1.050763687816192241e-02, 9.846364888486520278e-03, + 9.223018420378748666e-03, 8.635679320087892227e-03, + 8.082509399337856060e-03, 7.561748189787020381e-03, + 7.071710690401440159e-03, 6.610785113447655882e-03, + 6.177430634444128557e-03, 5.770175151050935432e-03, + 5.387613055526675569e-03, 5.028403025039829451e-03, + 4.691265833789841876e-03, 4.374982190571073161e-03, + 4.078390605101352151e-03, 3.800385286135859672e-03, + 3.539914074097407371e-03, 3.295976410675510899e-03, + 3.067621347579705879e-03, 2.853945596376883300e-03, + 2.654091621098549677e-03, 2.467245775071747523e-03, + 2.292636483206679770e-03, 2.129532470765164651e-03, + 1.977241039436586183e-03, 1.835106391361817930e-03, + 1.702508001570808517e-03, 1.578859039135579925e-03, + 1.463604837187349253e-03, 1.356221411803981442e-03, + 1.256214029641705395e-03, 1.163115824062832884e-03, + 1.076486459398666006e-03, 9.959108428835800140e-04, + 9.209978837021620159e-04, 8.513792985058264836e-04, + 7.867084626782404506e-04, 7.266593065597688316e-04, + 6.709252557796953342e-04, 6.192182147907223646e-04, + 5.712675926529839571e-04, 5.268193700739778494e-04, + 4.856352066763015558e-04, 4.474915874362286774e-04, + 4.121790072129495153e-04, 3.795011922700794395e-04, + 3.492743576776980645e-04, 3.213264994742838520e-04, + 2.954967204631142076e-04, 2.716345885168294891e-04, + 2.495995262663892249e-04, 2.292602310565636497e-04, + 2.104941240588425794e-04, 1.931868274441938952e-04, + 1.772316685319751106e-04, 1.625292098474701995e-04, + 1.489868040385482651e-04, 1.365181726217763530e-04, + 1.250430075496097566e-04, 1.144865946129515428e-04, + 1.047794577171236373e-04, 9.585702309404538050e-05, + 8.765930253891871730e-05, 8.013059478589698744e-05, + 7.321920416384974451e-05, 6.687717570036445813e-05, + 6.106004586935513646e-05, 5.572660820503811925e-05, + 5.083869303239757281e-05, 4.636096059159284464e-05, + 4.226070686088361970e-05, 3.850768140955769455e-05, + 3.507391663892059114e-05, 3.193356779560930435e-05, + 2.906276316726267902e-05, 2.643946389585209638e-05, + 2.404333286873791737e-05, 2.185561217168985901e-05, + 1.985900861170214197e-05, 1.803758684038606281e-05, + 1.637666963103638843e-05, 1.486274488410742304e-05, + 1.348337895680140905e-05, 1.222713593274277897e-05, + 1.108350246729316785e-05, 1.004281786293862474e-05, + 9.096209047365308938e-06, 8.235530144320954595e-06, + 7.453306344157341876e-06, 6.742681797056310516e-06, + 6.097371267379711650e-06, 5.511615302353663057e-06, + 4.980138682419505245e-06, 4.498111934064677381e-06, + 4.061115698807367185e-06, 3.665107764259547482e-06, + 3.306392574855062154e-06, 2.981593050916146976e-06, + 2.687624555263038175e-06, 2.421670856565512852e-06, + 2.181161948107961976e-06, 1.963753589613636236e-06, + 1.767308448263122681e-06, 1.589878723069828145e-06, + 1.429690144355897495e-06, 1.285127247227409751e-06, + 1.154719824693183656e-06, 1.037130472427221596e-06, + 9.311431431566156289e-07, 8.356526342833310594e-07, + 7.496549376350442197e-07, 6.722383852050892364e-07, + 6.025755293990676294e-07, 5.399157006726028162e-07, + 4.835781895349435402e-07, 4.329460037221953012e-07, + 3.874601549249640873e-07, 3.466144328024234159e-07, + 3.099506271405782374e-07, 2.770541619298303274e-07, + 2.475501078572927207e-07, 2.210995422445311533e-07, + 1.973962278221192788e-07, 1.761635839292111419e-07, + 1.571519257691956181e-07, 1.401359492507390953e-07, + 1.249124407063549179e-07, 1.112981924164049207e-07, + 9.912810638344996235e-08, 8.825347020762937518e-08, + 7.854039021568849526e-08, 6.986836820115264203e-08, + 6.212900924755241274e-08, 5.522484913656069790e-08, + 4.906829079430924200e-08, 4.358064010729986662e-08, + 3.869123224946007697e-08, 3.433664040870189657e-08}; +} // namespace localization::online_localizer::constants + +#endif diff --git a/src/localization/online_localizer/math_tools/derivation_boundary_estimation.md b/src/localization/online_localizer/math_tools/derivation_boundary_estimation.md new file mode 100644 index 0000000..84f19f9 --- /dev/null +++ b/src/localization/online_localizer/math_tools/derivation_boundary_estimation.md @@ -0,0 +1,83 @@ +# Derivation for boundary estimation between two weighted Gaussians + +Assuming that were are given two Gaussian distributions $\mathcal{N}(\mu_1, \sigma_1)$ and $\mathcal{N}(\mu_2, \sigma_2)$ and weights $\pi_1$ and $\pi_2$ respectively. +Then, we find select the boundary at the point $x$ such that +$$ + \pi_1\mathcal{N}(x | \mu_1, \sigma_1) = \pi_2\mathcal{N}(x | \mu_2, \sigma_2) +$$ + +the point x has a equal probability to belong to two different distributions. + +$$ + +\begin{align} + \frac{\pi_1}{\sqrt{2\pi}\sigma_1} \exp{(-\frac{(x-\mu_1)^2}{2\sigma_1^2})} = \frac{\pi_2}{\sqrt{2\pi}\sigma_2} \exp{(-\frac{(x-\mu_2)^2}{2\sigma_2^2})} && \text{(Multiple by ${\sqrt{2\pi}}$ and by $\ln$)} \\ + + \ln(\frac{\pi_1}{\sigma_1}) - \frac{(x-\mu_1)^2}{2\sigma_1^2} = \ln(\frac{\pi_2}{\sigma_2}) - \frac{(x-\mu_2)^2}{2\sigma_2^2} && \text{(Rearrage)} \\ + - \frac{(x-\mu_1)^2}{2\sigma_1^2} + \frac{(x-\mu_2)^2}{2\sigma_2^2} = \ln(\frac{\pi_2}{\sigma_2}) - \ln(\frac{\pi_1}{\sigma_1}) && \text{(Open brackets)} \\ + + - \frac{1}{2\sigma_1^2} \left( x^2 - 2x\mu_1 + \mu_1^2 \right) + + \frac{1}{2\sigma_2^2} \left( x^2 - 2x\mu_2 + \mu_2^2 \right) = \ln\left(\frac{\pi_2}{\sigma_2}\right) - \ln\left(\frac{\pi_1}{\sigma_1}\right) && \text{(Collect for x)} \\ + + \left( + \frac{1}{2\sigma_2^2} - \frac{1}{2\sigma_1^2} + \right) x^2 + + \left( + \frac{2\mu_1}{2\sigma_1^2} - \frac{2\mu_2}{2\sigma_2^2} + \right) x + + \frac{\mu_2^2}{2\sigma_2^2} - \frac{\mu_1^2}{2\sigma_1^2} = + \ln \left( \frac{\pi_2}{\sigma_2}\right) - \ln \left( \frac{\pi_1}{\sigma_1}\right) + +\end{align} +$$ + +Now the quadratic equation of form $Ax^2 + Bx + C = 0$ can be seen where: + +$$ +\begin{align} + A = \frac{1}{2\sigma_2^2} - \frac{1}{2\sigma_1^2} \\ + B = \frac{\mu_1}{\sigma_1^2} - \frac{\mu_2}{\sigma_2^2} \\ + C = \frac{\mu_2^2}{2\sigma_2^2} - \frac{\mu_1^2}{2\sigma_1^2} - \ln \left( \frac{\pi_2 \sigma_1}{\sigma_2 \pi_1} \right) +\end{align} +$$ + +When solving this quadratic equation, we may encounter the following situations, The discriminant $D = B^2 - 4AC$ + +1. D > 0 + + The equation has two real solutions. $x_1 = -B + \frac{\sqrt{D}}{2A}$ and $x_2 = -B - \frac{\sqrt{D}}{2A}$. +2. D = 0 + + The equation has a unique soluion + $x = \frac{-B}{2A}$. +3. D < 0 + + The equation does not have real solutions. + +For the cases, where we do not have a unique solution, case 1. We select the solution that lies between $\mu_1 < x < \mu_2$ if $\mu_2 > \mu_1$ and $\mu_2 < x < \mu_1$ otherwise. + + +### Border condition + +In case $\sigma_1 = \sigma_2$ the coefficient $A = 0$ and we have a linear equation with respect to $x$. + +$$ +\begin{align} + \left( + \frac{2\mu_1}{2\sigma^2} - \frac{2\mu_2}{2\sigma^2} + \right) x + + \frac{\mu_2^2}{2\sigma^2} - \frac{\mu_1^2}{2\sigma^2} = + \ln \left( \frac{\pi_2}{\sigma}\right) - \ln \left( \frac{\pi_1}{\sigma}\right) \\ + + \left( + \frac{\mu_1 - \mu_2}{\sigma^2} + \right) x + + \frac{\mu_2^2 - \mu_1^2}{2\sigma^2} = + \ln \left( \frac{\pi_2}{\pi_1}\right) \\ + + x = \frac{2\sigma^2 \ln{\frac{\pi_2}{\pi_1}} - \left(\mu_2^2 - \mu_1^2 \right)}{2\left( \mu_1-\mu_2 \right)} + +\end{align} +$$ + +That is it! \ No newline at end of file diff --git a/src/localization/online_localizer/math_tools/gmm.cpp b/src/localization/online_localizer/math_tools/gmm.cpp new file mode 100644 index 0000000..b603c97 --- /dev/null +++ b/src/localization/online_localizer/math_tools/gmm.cpp @@ -0,0 +1,286 @@ +/** Copyright (c) 2026 Olga Vysotska, RSL, ETH Zurich +** Permission is hereby granted, free of charge, to any person obtaining a copy +** of this software and associated documentation files (the "Software"), to deal +** in the Software without restriction, including without limitation the rights +** to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +** copies of the Software, and to permit persons to whom the Software is +** furnished to do so, subject to the following conditions: +** +** The above copyright notice and this permission notice shall be included in +** all copies or substantial portions of the Software. +** +** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +** FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +** AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +** LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +** OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +** SOFTWARE. +**/ +#include "online_localizer/math_tools/gmm.h" +#include "online_localizer/math_tools/math_tools.h" + +#include "glog/logging.h" + +#include +#include +#include +#include + +namespace localization::online_localizer { + +namespace { +const int kGmmIterationNumber = 100; +const double kModelChangeTolerance = 1e-08; +} // namespace + +bool valuesAreFromUnitSegment(const std::vector &values) { + for (const double &value : values) { + if (value < 0 - std::numeric_limits::epsilon() || + value > 1.0 + std::numeric_limits::epsilon()) { + LOG(WARNING) << "Value is outside the unit segment " << value; + return false; + } + } + return true; +} + +void GaussianModel::update(const std::vector &values, + const std::vector &associations) { + + double associationProbabilitySum = + std::reduce(associations.begin(), associations.end()); + int valuesNum = values.size(); + + // update mu + mu_ = std::inner_product(associations.begin(), associations.end(), + values.begin(), 0.0) / + associationProbabilitySum; + + // update sigma + sigma_ = 0; + for (int i = 0; i < valuesNum; ++i) { + sigma_ += associations[i] * (values[i] - mu_) * (values[i] - mu_); + } + sigma_ /= associationProbabilitySum; + if (std::fabs(sigma_) < 1e-08) { + sigma_ = 0; + } + sigma_ = std::sqrt(sigma_); + + // update weight + weight_ = associationProbabilitySum / valuesNum; +} + +double GaussianModel::computeAssociationProbability(double value) const { + return weight_ * getNormalPdfValueAtX(value, mu_, sigma_); +} + +std::vector GaussianModel::computeModelAssociations( + const std::vector &values) const { + std::vector associations; + associations.reserve(values.size()); + for (int i = 0; i < values.size(); ++i) { + associations.push_back(computeAssociationProbability(values[i])); + } + return associations; +} + +double GaussianModel::computeDifferenceTo(const GaussianModel &other) const { + // This could be a more elaborate comparison. + return std::abs(mu_ - other.mu_); +} + +bool GaussianModel::isDegenerate() const { + if (std::isnan(mu_) || std::isnan(sigma_)) { + return true; + } + return false; +} + +void GaussianMixtureModels::initializeForUnitValues( + const std::vector &values) { + int valuesNum = values.size(); + CHECK(valuesNum >= 2) << "Need at least 2 samples to compute GMM."; + DCHECK(valuesAreFromUnitSegment(values)) + << "Values are outside the unit segment."; + // This initialization only holds, because we assume all the values to be in + // [0,1]. So we assume values to already represent the probability to belong + // to class "Matching". + std::vector matchingAssociations = values; + std::vector nonMatchingAssociations; + nonMatchingAssociations.reserve(valuesNum); + for (int i = 0; i < valuesNum; i++) { + nonMatchingAssociations.emplace_back(1 - values[i]); + } + matchingModel_.update(values, matchingAssociations); + nonMatchingModel_.update(values, nonMatchingAssociations); +} + +void normalizeAssociationProbabilities( + std::vector &matchingAssociations, + std::vector &nonMatchingAssociations) { + CHECK(matchingAssociations.size() == nonMatchingAssociations.size()) + << "Associations of of different sizes."; + + for (int i = 0; i < matchingAssociations.size(); ++i) { + double probabilitySum = + matchingAssociations[i] + nonMatchingAssociations[i]; + CHECK(probabilitySum > 1e-08) + << "Critical value of probability sum " << probabilitySum; + matchingAssociations[i] /= probabilitySum; + nonMatchingAssociations[i] /= probabilitySum; + } +} + +void GaussianMixtureModels::fitTwoModels(const std::vector &values) { + + // Probably can be omitted by the power of math. + this->initializeForUnitValues(values); + + GaussianModel previousMatching = matchingModel_; + GaussianModel previousNonMatching = nonMatchingModel_; + + for (int iter = 0; iter < kGmmIterationNumber; ++iter) { + // E-step + std::vector matchingAssociations = + matchingModel_.computeModelAssociations(values); + std::vector nonMatchingAssociations = + nonMatchingModel_.computeModelAssociations(values); + normalizeAssociationProbabilities(matchingAssociations, + nonMatchingAssociations); + + // M-step + matchingModel_.update(values, matchingAssociations); + nonMatchingModel_.update(values, nonMatchingAssociations); + + if (matchingModel_.computeDifferenceTo(previousMatching) < + kModelChangeTolerance || + nonMatchingModel_.computeDifferenceTo(previousNonMatching) < + kModelChangeTolerance) { + break; + } else { + previousMatching = matchingModel_; + previousNonMatching = nonMatchingModel_; + } + } +} + +double computeModelsSeparationThreshold(const GaussianModel &matching, + const GaussianModel &nonMatching) { + // Computes a mean value between the ends of the 1-sigma confidence interval. + double boundaryClass0Right = nonMatching.mu() + nonMatching.sigma(); + double boundaryClass0Left = nonMatching.mu() - nonMatching.sigma(); + double boundaryClass1Left = matching.mu() - matching.sigma(); + double boundaryClass1Right = matching.mu() + matching.sigma(); + if (boundaryClass0Right < boundaryClass1Left || + (boundaryClass0Left < boundaryClass1Left && + boundaryClass1Left < boundaryClass0Right)) { + return (boundaryClass1Left + boundaryClass0Right) / 2; + } + return (boundaryClass0Left + boundaryClass1Right) / 2; +} + +bool xBetweenValues(double x, double valueLeft, double valueRight) { + CHECK(valueLeft <= valueRight) << "Left value > right value"; + return (valueLeft <= x && x <= valueRight); +} + +std::optional +computeDecisionBoundary(const GaussianModel &matching, + const GaussianModel &nonMatching) { + // This function compute the boundary between two Gaussian distribution + // The boundary is the point x that has equal probability to belong to both + // classes Please check derivation_boundary_estimation.md for details. + double mu_1 = matching.mu(); + double sigma_1 = matching.sigma(); + double pi_1 = matching.weight(); + + double mu_2 = nonMatching.mu(); + double sigma_2 = nonMatching.sigma(); + double pi_2 = nonMatching.weight(); + + if (std::abs(mu_1 - mu_2) < 1e-08) { + // means are the same -> no separation boundary + LOG(WARNING) << "Means are too close" << mu_1 << ", " << mu_2; + return {}; + } + + // if variances are equal + if (std::abs(sigma_1 - sigma_2) < 1e-03 || std::abs(sigma_1) < 1e-08 || + std::abs(sigma_2) < 1e-08) { + // having a linear equation now. + return (2.0 * (sigma_1 * sigma_1) * std::log(pi_2 / pi_1) - + (mu_2 * mu_2 - mu_1 * mu_1)) / + (2.0 * (mu_1 - mu_2)); + } + + double A = 1. / (2 * sigma_2 * sigma_2) - 1. / (2 * sigma_1 * sigma_1); + double B = mu_1 / (sigma_1 * sigma_1) - mu_2 / (sigma_2 * sigma_2); + double C = (mu_2 * mu_2) / (2 * sigma_2 * sigma_2) - + (mu_1 * mu_1) / (2 * sigma_1 * sigma_1) - + std::log((pi_2 * sigma_1) / (sigma_2 * pi_1)); + + double discriminant = B * B - 4 * A * C; + if (discriminant > 0) { + double solution_1 = (-B + std::sqrt(discriminant)) / (2 * A); + double solution_2 = (-B - std::sqrt(discriminant)) / (2 * A); + + if (mu_1 <= mu_2) { + if (xBetweenValues(solution_1, mu_1, mu_2)) { + return solution_1; + } else if (xBetweenValues(solution_2, mu_1, mu_2)) { + return solution_2; + } + LOG(WARNING) + << " None of the solution of quadratic equation is between the mu_1: " + << mu_1 << " and mu_2: " << mu_2; + LOG(INFO) << "Solution 1: " << solution_1; + LOG(INFO) << "Solution 2: " << solution_2; + return {}; + } + if (mu_2 < mu_1) { + if (xBetweenValues(solution_1, mu_2, mu_1)) { + return solution_1; + } else if (xBetweenValues(solution_2, mu_2, mu_1)) { + return solution_2; + } + LOG(WARNING) + << " None of the solution of quadratic equation is between the mu_2: " + << mu_2 << " and mu_1: " << mu_1; + LOG(INFO) << "Solution 1: " << solution_1; + LOG(INFO) << "Solution 2: " << solution_2; + return {}; + } + } else if (std::abs(discriminant) < 1e-08) { + return -B / (2 * A); + } + + LOG(FATAL) << "Quadratic equation does not have real solutions"; +} + +std::optional> +estimateSeparationThreshold(const std::vector &values, + bool distance_based) { + GaussianMixtureModels models; + models.fitTwoModels(values); + if (models.matchingModel().isDegenerate() || + models.nonMatchingModel().isDegenerate()) { + LOG(WARNING) << "One of the models is degenerate"; + return {}; + } + if (distance_based) { + return std::tuple(computeModelsSeparationThreshold( + models.matchingModel(), models.nonMatchingModel()), + models); + } + + std::optional decision_boundary = computeDecisionBoundary( + models.matchingModel(), models.nonMatchingModel()); + + if (decision_boundary) { + return std::tuple(decision_boundary.value(), models); + } + return {}; +} +} // namespace localization::online_localizer \ No newline at end of file diff --git a/src/localization/online_localizer/math_tools/gmm.h b/src/localization/online_localizer/math_tools/gmm.h new file mode 100644 index 0000000..50c15cc --- /dev/null +++ b/src/localization/online_localizer/math_tools/gmm.h @@ -0,0 +1,87 @@ +/** Copyright (c) 2026 Olga Vysotska, RSL, ETH Zurich +** Permission is hereby granted, free of charge, to any person obtaining a copy +** of this software and associated documentation files (the "Software"), to deal +** in the Software without restriction, including without limitation the rights +** to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +** copies of the Software, and to permit persons to whom the Software is +** furnished to do so, subject to the following conditions: +** +** The above copyright notice and this permission notice shall be included in +** all copies or substantial portions of the Software. +** +** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +** FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +** AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +** LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +** OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +** SOFTWARE. +**/ +#ifndef SRC_ONLINE_LOCALIZER_MATH_TOOLS_GMM_H_ +#define SRC_ONLINE_LOCALIZER_MATH_TOOLS_GMM_H_ + +#include +#include +#include +#include + +namespace localization::online_localizer { + +class GaussianModel { +public: + GaussianModel(double mu, double sigma, double weight) + : mu_(mu), sigma_(sigma), weight_(weight) {} + double computeAssociationProbability(double value) const; + std::vector + computeModelAssociations(const std::vector &values) const; + bool isDegenerate() const; + void update(const std::vector &values, + const std::vector &associations); + double computeDifferenceTo(const GaussianModel &other) const; + const double &mu() const { return mu_; } + const double &sigma() const { return sigma_; } + const double &weight() const { return weight_; } + friend std::ostream &operator<<(std::ostream &stream, + const GaussianModel &model) { + stream << "mu " << model.mu_ << " sigma " << model.sigma_ << " weight " + << model.weight_; + return stream; + } + +private: + double mu_ = 0.0; + double sigma_ = 1.0; + double weight_ = 0.0; +}; + +class GaussianMixtureModels { +public: + void fitTwoModels(const std::vector &values); + const GaussianModel &matchingModel() const { return matchingModel_; } + const GaussianModel &nonMatchingModel() const { return nonMatchingModel_; } + + void initializeForUnitValues(const std::vector &values); + +private: + GaussianModel matchingModel_ = {1.0, 1.0, 0.5}; + GaussianModel nonMatchingModel_ = {0.0, 1.0, 0.5}; +}; + +double computeModelsSeparationThreshold(const GaussianModel &matching, + const GaussianModel &nonMatching); + +bool xBetweenValues(double x, double valueLeft, double valueRight); +std::optional computeDecisionBoundary(const GaussianModel &matching, + const GaussianModel &nonMatching); + +void normalizeAssociationProbabilities( + std::vector &matchingAssociations, + std::vector &nonMatchingAssociations); + +std::optional> +estimateSeparationThreshold(const std::vector &values, + bool distance_based = true); + +} // namespace localization::online_localizer + +#endif \ No newline at end of file diff --git a/src/localization/online_localizer/math_tools/math_tools.cpp b/src/localization/online_localizer/math_tools/math_tools.cpp new file mode 100644 index 0000000..6967c17 --- /dev/null +++ b/src/localization/online_localizer/math_tools/math_tools.cpp @@ -0,0 +1,121 @@ +/** Copyright (c) 2026 Olga Vysotska, RSL, ETH Zurich +** Permission is hereby granted, free of charge, to any person obtaining a copy +** of this software and associated documentation files (the "Software"), to deal +** in the Software without restriction, including without limitation the rights +** to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +** copies of the Software, and to permit persons to whom the Software is +** furnished to do so, subject to the following conditions: +** +** The above copyright notice and this permission notice shall be included in +** all copies or substantial portions of the Software. +** +** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +** FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +** AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +** LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +** OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +** SOFTWARE. +**/ +#include "math_tools.h" +#include "glog/logging.h" + +#include +#include +#include + +namespace localization::online_localizer { + +namespace { +constexpr inline double kInvSqrt2Pi = 0.3989422804014327; +} // namespace + +std::vector computeCumulativeSum(const std::vector &array) { + + std::vector cumSum; + double sum = 0; + cumSum.reserve(array.size()); + for (auto number : array) { + sum += number; + cumSum.emplace_back(sum); + } + return cumSum; +} + +double estimateMean(const std::vector &values) { + CHECK(!values.empty()) << "Error: Cannot calculate mean for an empty vector."; + double sum = std::accumulate(values.begin(), values.end(), 0.0); + double mean = sum / values.size(); + + return mean; +} + +double estimateVariance(const std::vector &values) { + CHECK(values.size() >= 2) + << "Error: Variance calculation requires at least two elements."; + double mean = estimateMean(values); + double sumSquaredDifferences = std::accumulate( + values.begin(), values.end(), 0.0, [mean](double acc, double value) { + return acc + (value - mean) * (value - mean); + }); + + double variance = sumSquaredDifferences / (values.size()); + + return variance; +} + +// Assumes x and y are sorted. +double interpolate1D(const std::vector &x, const std::vector &y, + double point) { + + CHECK(x.size() == y.size()) << "x and y vectors must have the same size"; + + LOG_IF(WARNING, point < x.front() || point > x.back()) + << "point " << point << " is outside range [" << x.front() << "," + << x.back() << "]"; + point = std::max(x.front(), point); + point = std::min(x.back(), point); + + // Find the index i such that point < x[i] + auto upper = std::upper_bound(x.begin(), x.end(), point); + int i = 0; + if (upper != x.begin()) { + i = std::distance(x.begin(), upper) - 1; + } + + // Perform linear interpolation + double interpolatedValue = + y[i] + ((y[i + 1] - y[i]) / (x[i + 1] - x[i])) * (point - x[i]); + return interpolatedValue; +} + +double getNormalPdfValueAtX(double x, double mu, double sigma) { + if (std::abs(sigma) < 1e-08) { + LOG(WARNING) << "Sigma is too close to 0, sigma= " << sigma; + sigma = 1e-08; + } + double tmp = (x - mu) / sigma; + return kInvSqrt2Pi / sigma * std::exp(-0.5 * tmp * tmp); +} + +ValueEstimate kalmanFilterUpdate(ValueEstimate previousEstimate, + ValueEstimate measurement) { + CHECK(previousEstimate.uncertainty > 0) << "Previous uncertainty <= 0."; + CHECK(measurement.uncertainty > 0) << "Measuremerent uncertainty <= 0."; + // TODO(olga, when visualization is ready). Add process noise to avoid too + // certain estimates. + previousEstimate.uncertainty += 0.01; + double kalmanGain = previousEstimate.uncertainty / + (previousEstimate.uncertainty + measurement.uncertainty); + + ValueEstimate updatedEstimate; + updatedEstimate.value = + previousEstimate.value + + kalmanGain * (measurement.value - previousEstimate.value); + + updatedEstimate.uncertainty = (1 - kalmanGain) * previousEstimate.uncertainty; + + return updatedEstimate; +} + +} // namespace localization::online_localizer \ No newline at end of file diff --git a/src/localization/online_localizer/math_tools/math_tools.h b/src/localization/online_localizer/math_tools/math_tools.h new file mode 100644 index 0000000..400e8e7 --- /dev/null +++ b/src/localization/online_localizer/math_tools/math_tools.h @@ -0,0 +1,44 @@ +/** Copyright (c) 2026 Olga Vysotska, RSL, ETH Zurich +** Permission is hereby granted, free of charge, to any person obtaining a copy +** of this software and associated documentation files (the "Software"), to deal +** in the Software without restriction, including without limitation the rights +** to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +** copies of the Software, and to permit persons to whom the Software is +** furnished to do so, subject to the following conditions: +** +** The above copyright notice and this permission notice shall be included in +** all copies or substantial portions of the Software. +** +** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +** FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +** AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +** LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +** OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +** SOFTWARE. +**/ +#ifndef SRC_ONLINE_LOCALIZER_MATH_TOOLS_H_ +#define SRC_ONLINE_LOCALIZER_MATH_TOOLS_H_ + +#include + +namespace localization::online_localizer { + +double getNormalPdfValueAtX(double x, double mu, double sigma); +std::vector computeCumulativeSum(const std::vector &array); +double interpolate1D(const std::vector &x, const std::vector &y, + double t); + +double estimateMean(const std::vector &values); +double estimateVariance(const std::vector &values); + +struct ValueEstimate { + double value; + double uncertainty; +}; +ValueEstimate kalmanFilterUpdate(ValueEstimate previousEstimate, + ValueEstimate measurement); + +} // namespace localization::online_localizer + +#endif \ No newline at end of file diff --git a/src/localization/online_localizer/math_tools/statistical_test.cpp b/src/localization/online_localizer/math_tools/statistical_test.cpp new file mode 100644 index 0000000..8cda714 --- /dev/null +++ b/src/localization/online_localizer/math_tools/statistical_test.cpp @@ -0,0 +1,194 @@ +/** Copyright (c) 2026 Olga Vysotska, RSL, ETH Zurich +** Permission is hereby granted, free of charge, to any person obtaining a copy +** of this software and associated documentation files (the "Software"), to deal +** in the Software without restriction, including without limitation the rights +** to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +** copies of the Software, and to permit persons to whom the Software is +** furnished to do so, subject to the following conditions: +** +** The above copyright notice and this permission notice shall be included in +** all copies or substantial portions of the Software. +** +** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +** FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +** AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +** LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +** OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +** SOFTWARE. +**/ +#include "statistical_test.h" +#include "constants.h" +#include "glog/logging.h" +#include "math_tools.h" + +#include +#include +#include +#include +#include + +namespace localization::online_localizer { + +namespace { +const int kHistogramBinsNum = 50; +const double kSignificanceValue = 0.05; +} // namespace + +const EmpiricalDistribution kGaussianDistribution = createStandardGaussian(); +const EmpiricalDistribution kKolmogorovPvalueDistribution = + createKolmogorovDistribution(); + +EmpiricalDistribution createStandardGaussian() { + const float kGaussianResolution = 0.001; + const float kGaussianBoundaries = 5.0; + std::vector gaussianX(2 * kGaussianBoundaries / kGaussianResolution, + 0.0); + std::vector gaussianPdf(2 * kGaussianBoundaries / kGaussianResolution, + 0.0); + std::vector gaussianCdf(2 * kGaussianBoundaries / kGaussianResolution, + 0.0); + gaussianX[0] = -kGaussianBoundaries; + gaussianPdf[0] = getNormalPdfValueAtX(gaussianX[0], 0.0, 1.0); + for (int i = 1; i < gaussianX.size(); ++i) { + gaussianX[i] = gaussianX[i - 1] + kGaussianResolution; + gaussianPdf[i] = getNormalPdfValueAtX(gaussianX[i], 0.0, 1.0); + } + + std::vector cumsumValues = computeCumulativeSum(gaussianPdf); + // We need to multiply by bin_size to turn cumsum into cdf (describes the cdf + // of a distribution) + for (int idx = 0; idx < cumsumValues.size(); ++idx) { + gaussianCdf[idx] = cumsumValues[idx] * kGaussianResolution; + } + return EmpiricalDistribution(gaussianX, gaussianPdf, gaussianCdf); +} + +double getNormalCdfAtPoint(double point, double mu, double sigma, + const EmpiricalDistribution &gaussianDistribution) { + double pointNorm = (point - mu) / sigma; + return interpolate1D(gaussianDistribution.x, gaussianDistribution.cdf, + pointNorm); +} + +EmpiricalDistribution createKolmogorovDistribution() { + return EmpiricalDistribution(constants::kolmogorovX, + constants::kolmogorovPvalue, {}); +} + +EmpiricalDistribution::EmpiricalDistribution(const std::vector &x, + const std::vector &values, + const std::vector &cdf) { + // Make sure that the values make sense. No further checks are provided here. + this->x = x; + this->values = values; + this->cdf = cdf; +} + +EmpiricalDistribution::EmpiricalDistribution(const std::vector &samples, + int numBins) { + + CHECK(samples.size() > 1) << "Samples should contain at least 2 values."; + CHECK(numBins > 0) << "Num of Bins must be higher then 0."; + + computeHistogram(samples, numBins); + computeCdf(); +} + +void EmpiricalDistribution::computeHistogram(const std::vector &samples, + int numBins) { + double minValue = *std::min_element(samples.begin(), samples.end()); + double maxValue = *std::max_element(samples.begin(), samples.end()); + + CHECK(maxValue > minValue) << "Max value should be bigger then min value"; + + double binSize = (maxValue - minValue) / numBins; + double tmpX = minValue; + while (tmpX <= maxValue) { + x.push_back(tmpX); + tmpX += binSize; + } + + values.resize(x.size(), 0.0); + for (double value : samples) { + int binIndex = std::floor((value - minValue) / binSize); + if (binIndex == x.size() && std::abs(value - maxValue) < 1e-08) { + values[binIndex - 1] += 1; + continue; + } + LOG_IF(WARNING, binIndex > x.size()) << "Bin index is outside boundaries"; + values[binIndex] += 1; + } +} + +void EmpiricalDistribution::computeCdf() { + // In theory, the cdf is the integral over the pdf (lossely speaking). + // However, we don't compute a real PDF before, we just use histogram, + // so just computing the cumulative sum and dividing by the last element (the + // sum of all the values) represents the same cdf if we would have computed + // the proper pdf and the integral over it. + std::vector cumsumValues = computeCumulativeSum(values); + cdf.resize(values.size(), 0.0); + for (int idx = 0; idx < cumsumValues.size(); ++idx) { + cdf[idx] = cumsumValues[idx] / *(cumsumValues.end() - 1); + } +} + +std::tuple +patchContainsPath(const std::vector &values) { + // This function checks if the patch contains the path using + // the Kolmogorov-Smirnov statistical test. + // https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test + // We consider that the patch contains NO path is the values form a unimodal + // distribution. If the observed distribution diverges from the unimodal + // Gaussian distribution, then there exists a path in the patch. + // The null hypothesis in the ks-test is: The patch is unimodal. + + // The idea: We compute a unimodal Gaussian parameters from the set of values. + // Then we compare the theoretical CDF of this unimodal Gaussian with the + // empirical CDF of the values (real, obtained values). + // If the CDF differs "significantly" (according to the kSignificanceValue) + // then we reject the null hypothesis. And the patch contains the path. + + const double expectedMean = estimateMean(values); + const double expectedVariance = estimateVariance(values); + const double expectedSigma = std::sqrt(expectedVariance); + + EmpiricalDistribution observedDistribution = + EmpiricalDistribution(values, kHistogramBinsNum); + + std::vector gaussianCdfsAtObservedPoints; + for (const auto &observedPoint : observedDistribution.x) { + gaussianCdfsAtObservedPoints.push_back(getNormalCdfAtPoint( + observedPoint, expectedMean, expectedSigma, kGaussianDistribution)); + } + + // Compute D statistics. Notation from KS test. + std::vector dStatistics(gaussianCdfsAtObservedPoints.size()); + for (int pointIdx = 0; pointIdx < gaussianCdfsAtObservedPoints.size(); + ++pointIdx) { + dStatistics[pointIdx] = observedDistribution.cdf[pointIdx] - + gaussianCdfsAtObservedPoints[pointIdx]; + } + + auto dMax = std::max_element(dStatistics.begin(), dStatistics.end()); + auto dMin = std::min_element(dStatistics.begin(), dStatistics.end()); + + double Dn, location; + if (std::abs(*dMax) > std::abs(*dMin)) { + Dn = std::abs(*dMax); + location = observedDistribution.x[dMax - dStatistics.begin()]; + } else { + Dn = std::abs(*dMin); + location = observedDistribution.x[dMin - dStatistics.begin()]; + } + double Kn = std::sqrt(observedDistribution.x.size()) * Dn; + double pvalue = interpolate1D(kKolmogorovPvalueDistribution.x, + kKolmogorovPvalueDistribution.values, Kn); + + // If pvalue < 0.05, we reject the null hypothesis that patch is unimodal + // with probability of 95% + return std::make_tuple(pvalue < kSignificanceValue, pvalue, location); +} + +} // namespace localization::online_localizer \ No newline at end of file diff --git a/src/localization/online_localizer/math_tools/statistical_test.h b/src/localization/online_localizer/math_tools/statistical_test.h new file mode 100644 index 0000000..9796133 --- /dev/null +++ b/src/localization/online_localizer/math_tools/statistical_test.h @@ -0,0 +1,51 @@ +/** Copyright (c) 2026 Olga Vysotska, RSL, ETH Zurich +** Permission is hereby granted, free of charge, to any person obtaining a copy +** of this software and associated documentation files (the "Software"), to deal +** in the Software without restriction, including without limitation the rights +** to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +** copies of the Software, and to permit persons to whom the Software is +** furnished to do so, subject to the following conditions: +** +** The above copyright notice and this permission notice shall be included in +** all copies or substantial portions of the Software. +** +** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +** FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +** AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +** LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +** OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +** SOFTWARE. +**/ +#ifndef SRC_ONLINE_LOCALIZER_MATH_TOOLS_STATISTICAL_TEST_H_ +#define SRC_ONLINE_LOCALIZER_MATH_TOOLS_STATISTICAL_TEST_H_ + +#include +#include +#include + +namespace localization::online_localizer { + +class EmpiricalDistribution { +public: + EmpiricalDistribution(const std::vector &samples, int numBins); + EmpiricalDistribution(const std::vector &x, + const std::vector &values, + const std::vector &cdf); + + void computeCdf(); + void computeHistogram(const std::vector &samples, int numBins); + + std::vector x; + std::vector values; // f(x) + std::vector cdf; +}; + +EmpiricalDistribution createKolmogorovDistribution(); +EmpiricalDistribution createStandardGaussian(); +std::tuple +patchContainsPath(const std::vector &values); + +} // namespace localization::online_localizer + +#endif \ No newline at end of file diff --git a/src/localization/online_localizer/online_localizer.cpp b/src/localization/online_localizer/online_localizer.cpp index 9915ba5..e0e27ba 100644 --- a/src/localization/online_localizer/online_localizer.cpp +++ b/src/localization/online_localizer/online_localizer.cpp @@ -25,6 +25,8 @@ #include "online_localizer/online_localizer.h" #include "localization_protos.pb.h" +#include "online_localizer/math_tools/gmm.h" +#include "online_localizer/math_tools/statistical_test.h" #include "online_localizer/path_element.h" #include "tools/timer/timer.h" @@ -44,11 +46,37 @@ namespace localization::online_localizer { using std::string; using std::vector; +namespace { + const float kMaxLostNodesRatio = 0.8; // 80% +const int kPatchWidth = 20; +const int kMinPatchSize = 2; + +} // namespace + +void writePatchToProto(const std::vector patch, + image_sequence_localizer::Patch *patchProto) { + for (const auto &node : patch) { + image_sequence_localizer::Patch::Element *element = + patchProto->add_elements(); + element->set_row(node.quId); + element->set_col(node.refId); + element->set_similarity_value(node.idvCost); + } +} + +void writeGmmStatsToProto(const GaussianMixtureModels &models, + image_sequence_localizer::PatchInfo::GMM *gmmProto) { + gmmProto->set_mu_class_0(models.nonMatchingModel().mu()); + gmmProto->set_mu_class_1(models.matchingModel().mu()); + gmmProto->set_sigma_class_0(models.nonMatchingModel().sigma()); + gmmProto->set_sigma_class_1(models.matchingModel().sigma()); +} + OnlineLocalizer::OnlineLocalizer( successor_manager::SuccessorManager *successorManager, double expansionRate, - double matchingThreshold) { + double matchingThreshold, bool adaptThreshold) { CHECK(successorManager) << "Successor manager is not set."; CHECK(expansionRate > 0 && expansionRate <= 1) @@ -58,7 +86,13 @@ OnlineLocalizer::OnlineLocalizer( successorManager_ = successorManager; expansionRate_ = expansionRate; - matchingThreshold_ = matchingThreshold; + matchingThreshold_.value = matchingThreshold; + adaptThreshold_ = adaptThreshold; + LOG(INFO) << "Adapting matching threshold: " << (adaptThreshold_ ? "yes": "no"); + + debug_.set_start_matching_threshold(matchingThreshold_.value); + debug_.set_stats_patch_width(kPatchWidth); + debug_.set_sliding_window_size(kSlidingWindowSize_); pred_[kSourceNode.quId][kSourceNode.refId] = kSourceNode; Node source = kSourceNode; @@ -68,15 +102,19 @@ OnlineLocalizer::OnlineLocalizer( currentBestHyp_ = source; } -Matches OnlineLocalizer::findMatchesTill(int queryId) { +Matches OnlineLocalizer::findMatchesTill(int queryId, + const std::string &debugFilename) { CHECK(queryId >= 0) << "Number of queries is <= 0: " << queryId; Timer timer; // For the first image consider lost // for every image in the query set for (int qu = 0; qu < queryId; ++qu) { + // Setting up debug message + image_sequence_localizer::OnlineLocalizerDebugPerStep *debugPerStep = + debug_.add_debug_per_step(); // while the graph is not expanded till row 'qu' timer.start(); - processImage(qu); + processImage(qu, debugPerStep); timer.stop(); LOG(INFO) << "Matched image " << qu; @@ -85,6 +123,18 @@ Matches OnlineLocalizer::findMatchesTill(int queryId) { visualize(); } LOG(INFO) << "Finished matching."; + + if (!debugFilename.empty()) { + std::fstream out(debugFilename, + std::ios::out | std::ios::trunc | std::ios::binary); + if (!debug_.SerializeToOstream(&out)) { + LOG(ERROR) << "Couldn't open the file: " << debugFilename; + } + out.close(); + + LOG(INFO) << "Debug is written to: " << debugFilename; + } + if (_vis) { _vis->processFinished(); } @@ -109,6 +159,54 @@ void OnlineLocalizer::writeOutExpanded(const std::string &filename) const { LOG(INFO) << "Wrote patch " << filename; } +std::optional OnlineLocalizer::estimateMatchingThreshold( + image_sequence_localizer::PatchInfo *patchInfoProto) const { + + std::vector patch = successorManager_->getPatchCosts( + currentBestHyp_.quId, currentBestHyp_.refId, kPatchWidth); + patchInfoProto->set_best_row(currentBestHyp_.quId); + patchInfoProto->set_best_col(currentBestHyp_.refId); + + if (patch.size() < kMinPatchSize) { + LOG(WARNING) << "Patch is too small to make decisions: " << patch.size(); + return {}; + } + + writePatchToProto(patch, patchInfoProto->mutable_patch()); + + std::vector patchValues; + patchValues.reserve(patch.size()); + for (const auto &node : patch) { + patchValues.push_back(node.idvCost); + } + + const auto &[pathInsidePatch, pvalue, location] = + patchContainsPath(patchValues); + patchInfoProto->set_path_exists(pathInsidePatch); + patchInfoProto->set_pvalue(pvalue); + patchInfoProto->set_pvalue_location(location); + + if (!pathInsidePatch) { + LOG(INFO) << "No path inside the patch"; + return {}; + } + + auto matchThresholdResult = estimateSeparationThreshold(patchValues); + + if (!matchThresholdResult) { + return {}; + } + const auto &[matchThreshold, gaussianModels] = matchThresholdResult.value(); + writeGmmStatsToProto(gaussianModels, patchInfoProto->mutable_gmm_stats()); + + if (std::abs(matchThreshold) < 1e-08) { + LOG(WARNING) << "Match threshold is too close to 0"; + return {}; + } + patchInfoProto->set_threshold(matchThreshold); + return 1.0 / matchThreshold; +} + // frontier picking up routine void OnlineLocalizer::matchImage(int quId) { expandedRecently_.clear(); @@ -156,12 +254,15 @@ void OnlineLocalizer::matchImage(int quId) { } } -void OnlineLocalizer::processImage(int quId) { +void OnlineLocalizer::processImage( + int quId, + image_sequence_localizer::OnlineLocalizerDebugPerStep *debugProto) { LOG(INFO) << "Checking image " << quId; if (quId == 0) { needReloc_ = true; } matchImage(quId); + queryIdToThreshMap_[quId] = matchingThreshold_.value; CHECK(!frontier_.empty()) << "Frontier is empty! Something bad happened."; @@ -172,6 +273,40 @@ void OnlineLocalizer::processImage(int quId) { } else { needReloc_ = false; } + + if (!adaptThreshold_){ + return; + } + + // Perform threshold adaptation + + image_sequence_localizer::PatchInfo *ksDebugInfo = + debugProto->add_ks_patches(); + debugProto->set_matching_threshold(matchingThreshold_.value); + debugProto->set_row(quId); + // TODO(olga). Not sure this is the safest way to code this. + writePatchToProto({expandedRecently_.begin(), expandedRecently_.end()}, + debugProto->mutable_expanded_nodes()); + + debugProto->set_lost(needReloc_); + convertMatchesToProto(getCurrentPath(), debugProto->mutable_path()); + + // Threshold is estimated only if path exists and GMM was able to find a + // separation threshold + only if in Localization mode (NOT LOST) + // or LOST but first path is not found yet + if (!needReloc_ || (needReloc_ && !thresholdFoundFirstTime_)) { + std::optional estimatedThreshold = + estimateMatchingThreshold(ksDebugInfo); + if (estimatedThreshold) { + ValueEstimate estimate; + estimate.value = estimatedThreshold.value(); + estimate.uncertainty = 20.0; + + matchingThreshold_ = kalmanFilterUpdate(matchingThreshold_, estimate); + thresholdFoundFirstTime_ = true; + debugProto->set_estimated_matching_threshold(matchingThreshold_.value); + } + } } bool OnlineLocalizer::nodeWorthExpanding(const Node &node) const { @@ -315,7 +450,8 @@ std::vector OnlineLocalizer::getCurrentPath() const { source_reached = true; continue; } - NodeState state = pred.idvCost > matchingThreshold_ ? HIDDEN : REAL; + NodeState state = + pred.idvCost > queryIdToThreshMap_.at(pred.quId) ? HIDDEN : REAL; PathElement pathEl(pred.quId, pred.refId, state); path.push_back(pathEl); pred = pred_.at(pred.quId).at(pred.refId); @@ -362,7 +498,8 @@ std::vector OnlineLocalizer::getLastNmatches(int N) const { source_reached = true; continue; } - NodeState state = pred.idvCost > matchingThreshold_ ? HIDDEN : REAL; + NodeState state = + pred.idvCost > queryIdToThreshMap_.at(pred.quId) ? HIDDEN : REAL; PathElement pathEl(pred.quId, pred.refId, state); path.push_back(pathEl); pred = pred_.at(pred.quId).at(pred.refId); diff --git a/src/localization/online_localizer/online_localizer.h b/src/localization/online_localizer/online_localizer.h index c3616d6..cd0e19b 100644 --- a/src/localization/online_localizer/online_localizer.h +++ b/src/localization/online_localizer/online_localizer.h @@ -24,15 +24,17 @@ #ifndef SRC_ONLINE_LOCALIZER_ONLINE_LOCALIZER_H_ #define SRC_ONLINE_LOCALIZER_ONLINE_LOCALIZER_H_ +#include #include #include -#include #include #include #include #include +#include "localization_protos.pb.h" #include "online_localizer/ilocvisualizer.h" +#include "online_localizer/math_tools/math_tools.h" #include "online_localizer/path_element.h" #include "successor_manager/node.h" #include "successor_manager/successor_manager.h" @@ -46,14 +48,17 @@ class OnlineLocalizer { using AccCostsMap = std::unordered_map>; OnlineLocalizer(successor_manager::SuccessorManager *successorManager, - double expansionRate, double matchingThreshold); + double expansionRate, double matchingThreshold, + bool adaptThreshold = false); ~OnlineLocalizer() {} - Matches findMatchesTill(int queryId); + Matches findMatchesTill(int queryId, const std::string &debugFilename); void writeOutExpanded(const std::string &filename) const; protected: - void processImage(int quId); + void processImage( + int quId, + image_sequence_localizer::OnlineLocalizerDebugPerStep *debugProto); void matchImage(int quId); std::vector getCurrentPath() const; @@ -66,6 +71,8 @@ class OnlineLocalizer { double computeAveragePathCost() const; bool isLost(int N, double perc) const; + std::optional estimateMatchingThreshold( + image_sequence_localizer::PatchInfo *patchInfoProto) const; void visualize() const; @@ -73,19 +80,24 @@ class OnlineLocalizer { int kSlidingWindowSize_ = 5; // frames bool needReloc_ = false; double expansionRate_ = -1.0; - double matchingThreshold_ = -1.0; + + ValueEstimate matchingThreshold_ = {-1.0, 100}; + bool thresholdFoundFirstTime_ = false; + bool adaptThreshold_ = false; std::priority_queue frontier_; // stores parent for each node PredMap pred_; - // stores the accumulative cost for each node + // stores the accumulative cost for each node AccCostsMap accCosts_; Node currentBestHyp_; successor_manager::SuccessorManager *successorManager_ = nullptr; iLocVisualizer::Ptr _vis = nullptr; + std::map queryIdToThreshMap_; NodeSet expandedRecently_; + image_sequence_localizer::OnlineLocalizerDebug debug_; }; } // namespace localization::online_localizer diff --git a/src/localization/online_localizer/path_element.cpp b/src/localization/online_localizer/path_element.cpp index 5998f91..899656b 100644 --- a/src/localization/online_localizer/path_element.cpp +++ b/src/localization/online_localizer/path_element.cpp @@ -31,6 +31,18 @@ namespace localization::online_localizer { +void convertMatchesToProto( + const Matches &matches, + image_sequence_localizer::MatchingResult *matching_result_proto) { + for (const auto &match : matches) { + image_sequence_localizer::MatchingResult::Match *match_proto = + matching_result_proto->add_matches(); + match_proto->set_query_id(match.quId); + match_proto->set_ref_id(match.refId); + match_proto->set_real(match.state == NodeState::HIDDEN ? 0 : 1); + } +} + void storeMatchesAsProto(const Matches &matches, const std::string &protoFilename) { image_sequence_localizer::MatchingResult matching_result_proto; diff --git a/src/localization/online_localizer/path_element.h b/src/localization/online_localizer/path_element.h index 1596fb5..a5b5f50 100644 --- a/src/localization/online_localizer/path_element.h +++ b/src/localization/online_localizer/path_element.h @@ -24,6 +24,8 @@ #ifndef SRC_ONLINE_LOCALIZER_PATH_ELEMENT_H_ #define SRC_ONLINE_LOCALIZER_PATH_ELEMENT_H_ +#include "localization_protos.pb.h" + #include #include @@ -51,6 +53,10 @@ class PathElement { using Matches = std::vector; void storeMatchesAsProto(const Matches &matches, const std::string &protoFilename); + +void convertMatchesToProto( + const Matches &matches, + image_sequence_localizer::MatchingResult *matching_result_proto); }; // namespace localization::online_localizer #endif // SRC_ONLINE_LOCALIZER_PATH_ELEMENT_H_ diff --git a/src/localization/successor_manager/successor_manager.cpp b/src/localization/successor_manager/successor_manager.cpp index 1ab198e..94bb27e 100644 --- a/src/localization/successor_manager/successor_manager.cpp +++ b/src/localization/successor_manager/successor_manager.cpp @@ -171,4 +171,25 @@ SuccessorManager::getSuccessorsIfLost(const Node &node) { } return _successors; } + +std::vector SuccessorManager::getPatchCosts(int quId, int refId, + int patchWidth) const { + CHECK_GE(patchWidth, 1); + std::vector patchCosts; + int halfPatchWidth = (patchWidth > 1) ? patchWidth / 2 : patchWidth; + int topLeftRow = std::max(quId - (patchWidth - 1), 0); + int topLeftCol = std::max(refId - (patchWidth - 1), 0); + for (int patchQueryId = topLeftRow; patchQueryId <= quId; patchQueryId++) { + for (int patchRefId = topLeftCol; patchRefId <= refId; patchRefId++) { + // Return directly similarity matching costs [0,1]. + auto cost = database_->getCostIfComputed(patchQueryId, patchRefId); + if (cost) { + patchCosts.push_back( + Node(patchQueryId, patchRefId, std::abs(cost.value()))); + } + } + } + return patchCosts; +} + } // namespace localization::successor_manager diff --git a/src/localization/successor_manager/successor_manager.h b/src/localization/successor_manager/successor_manager.h index 7f25c15..a424760 100644 --- a/src/localization/successor_manager/successor_manager.h +++ b/src/localization/successor_manager/successor_manager.h @@ -62,6 +62,8 @@ class SuccessorManager { void getSuccessorFanOut(int quId, int refId); void getSuccessorsSimPlaces(int quId, int refId); + std::vector getPatchCosts(int quId, int refId, int patchWidth) const; + protected: database::iDatabase *database_ = nullptr; int fanOut_ = 0; diff --git a/src/localization/tools/config_parser/config_parser.cpp b/src/localization/tools/config_parser/config_parser.cpp index 77bc0fc..7287f8d 100644 --- a/src/localization/tools/config_parser/config_parser.cpp +++ b/src/localization/tools/config_parser/config_parser.cpp @@ -31,166 +31,172 @@ using std::string; bool ConfigParser::parse(const std::string &iniFile) { - std::ifstream in(iniFile.c_str()); - if (!in) { - printf("[ERROR][ConfigParser] The file \"%s\" cannot be opened.\n", - iniFile.c_str()); - return false; - } - while (!in.eof()) { - string line; - std::getline(in, line); - if (line.empty() || line[0] == '#') { - // it should be a comment - continue; - } - std::stringstream ss(line); - while (!ss.eof()) { - string header; - ss >> header; - if (header == "path2qu") { - ss >> header; // reads "=" - ss >> path2qu; - continue; - } - - if (header == "path2ref") { - ss >> header; // reads "=" - ss >> path2ref; - continue; - } - - if (header == "querySize") { - ss >> header; // reads "=" - ss >> querySize; - } - - if (header == "matchingThreshold") { - ss >> header; // reads "=" - ss >> matchingThreshold; - continue; - } - - if (header == "expansionRate") { - ss >> header; // reads "=" - ss >> expansionRate; - continue; - } - if (header == "fanOut") { - ss >> header; // reads "=" - ss >> fanOut; - continue; - } - if (header == "bufferSize") { - ss >> header; // reads "=" - ss >> bufferSize; - continue; - } - - if (header == "path2quImg") { - ss >> header; // reads "=" - ss >> path2quImg; - continue; - } - - if (header == "path2refImg") { - ss >> header; // reads "=" - ss >> path2refImg; - continue; - } - if (header == "imgExt") { - ss >> header; // reads "=" - ss >> imgExt; - continue; - } - if (header == "similarityMatrix") { - ss >> header; // reads "=" - ss >> similarityMatrix; - continue; - } - if (header == "simPlaces") { - ss >> header; // reads "=" - ss >> simPlaces; - continue; - } - } // end of line parsing - } // end of file - return true; + std::ifstream in(iniFile.c_str()); + if (!in) { + printf("[ERROR][ConfigParser] The file \"%s\" cannot be opened.\n", + iniFile.c_str()); + return false; + } + while (!in.eof()) { + string line; + std::getline(in, line); + if (line.empty() || line[0] == '#') { + // it should be a comment + continue; + } + std::stringstream ss(line); + while (!ss.eof()) { + string header; + ss >> header; + if (header == "path2qu") { + ss >> header; // reads "=" + ss >> path2qu; + continue; + } + + if (header == "path2ref") { + ss >> header; // reads "=" + ss >> path2ref; + continue; + } + + if (header == "querySize") { + ss >> header; // reads "=" + ss >> querySize; + } + + if (header == "matchingThreshold") { + ss >> header; // reads "=" + ss >> matchingThreshold; + continue; + } + + if (header == "expansionRate") { + ss >> header; // reads "=" + ss >> expansionRate; + continue; + } + if (header == "fanOut") { + ss >> header; // reads "=" + ss >> fanOut; + continue; + } + if (header == "bufferSize") { + ss >> header; // reads "=" + ss >> bufferSize; + continue; + } + + if (header == "path2quImg") { + ss >> header; // reads "=" + ss >> path2quImg; + continue; + } + + if (header == "path2refImg") { + ss >> header; // reads "=" + ss >> path2refImg; + continue; + } + if (header == "imgExt") { + ss >> header; // reads "=" + ss >> imgExt; + continue; + } + if (header == "similarityMatrix") { + ss >> header; // reads "=" + ss >> similarityMatrix; + continue; + } + if (header == "simPlaces") { + ss >> header; // reads "=" + ss >> simPlaces; + continue; + } + } // end of line parsing + } // end of file + return true; } void ConfigParser::print() const { - printf("== Read parameters ==\n"); - printf("== Path2query: %s\n", path2qu.c_str()); - printf("== Path2ref: %s\n", path2ref.c_str()); - - printf("== Query size: %d\n", querySize); - printf("== matchingThreshold: %3.4f\n", matchingThreshold); - printf("== Expansion Rate: %3.4f\n", expansionRate); - printf("== FanOut: %d\n", fanOut); - - printf("== Path2query images: %s\n", path2quImg.c_str()); - printf("== Path2reference images: %s\n", path2refImg.c_str()); - printf("== Image extension: %s\n", imgExt.c_str()); - printf("== Buffer size: %d\n", bufferSize); - - printf("== similarityMatrix: %s\n", similarityMatrix.c_str()); - printf("== matchingResult: %s\n", matchingResult.c_str()); - printf("== simPlaces: %s\n", simPlaces.c_str()); + printf("== Read parameters ==\n"); + printf("== Path2query: %s\n", path2qu.c_str()); + printf("== Path2ref: %s\n", path2ref.c_str()); + + printf("== Query size: %d\n", querySize); + printf("== matchingThreshold: %3.4f\n", matchingThreshold); + printf("== Expansion Rate: %3.4f\n", expansionRate); + printf("== FanOut: %d\n", fanOut); + + printf("== Path2query images: %s\n", path2quImg.c_str()); + printf("== Path2reference images: %s\n", path2refImg.c_str()); + printf("== Image extension: %s\n", imgExt.c_str()); + printf("== Buffer size: %d\n", bufferSize); + + printf("== similarityMatrix: %s\n", similarityMatrix.c_str()); + printf("== matchingResult: %s\n", matchingResult.c_str()); + printf("== simPlaces: %s\n", simPlaces.c_str()); } bool ConfigParser::parseYaml(const std::string &yamlFile) { - YAML::Node config; - try { - config = YAML::LoadFile(yamlFile.c_str()); - } catch (...) { - printf("[ERROR][ConfigParser] File %s cannot be opened\n", - yamlFile.c_str()); - return false; - } - if (config["path2ref"]) { - path2ref = config["path2ref"].as(); - } - if (config["path2qu"]) { - path2qu = config["path2qu"].as(); - } - if (config["querySize"]) { - querySize = config["querySize"].as(); - } - if (config["fanOut"]) { - fanOut = config["fanOut"].as(); - } - if (config["matchingThreshold"]) { - matchingThreshold = config["matchingThreshold"].as(); - } - if (config["expansionRate"]) { - expansionRate = config["expansionRate"].as(); - } - if (config["path2quImg"]) { - path2quImg = config["path2quImg"].as(); - } - - if (config["path2refImg"]) { - path2refImg = config["path2refImg"].as(); - } - if (config["imgExt"]) { - imgExt = config["imgExt"].as(); - } - if (config["bufferSize"]) { - bufferSize = config["bufferSize"].as(); - } - if (config["similarityMatrix"]) { - similarityMatrix = config["similarityMatrix"].as(); - } - if (config["simPlaces"]) { - simPlaces = config["simPlaces"].as(); - } - - if (config["hashTable"]) { - hashTable = config["hashTable"].as(); - } - if (config["matchingResult"]) { - matchingResult = config["matchingResult"].as(); - } - - return true; + YAML::Node config; + try { + config = YAML::LoadFile(yamlFile.c_str()); + } catch (...) { + printf("[ERROR][ConfigParser] File %s cannot be opened\n", + yamlFile.c_str()); + return false; + } + if (config["path2ref"]) { + path2ref = config["path2ref"].as(); + } + if (config["path2qu"]) { + path2qu = config["path2qu"].as(); + } + if (config["querySize"]) { + querySize = config["querySize"].as(); + } + if (config["fanOut"]) { + fanOut = config["fanOut"].as(); + } + if (config["matchingThreshold"]) { + matchingThreshold = config["matchingThreshold"].as(); + } + if (config["expansionRate"]) { + expansionRate = config["expansionRate"].as(); + } + if (config["path2quImg"]) { + path2quImg = config["path2quImg"].as(); + } + + if (config["path2refImg"]) { + path2refImg = config["path2refImg"].as(); + } + if (config["imgExt"]) { + imgExt = config["imgExt"].as(); + } + if (config["bufferSize"]) { + bufferSize = config["bufferSize"].as(); + } + if (config["similarityMatrix"]) { + similarityMatrix = config["similarityMatrix"].as(); + } + if (config["simPlaces"]) { + simPlaces = config["simPlaces"].as(); + } + + if (config["hashTable"]) { + hashTable = config["hashTable"].as(); + } + if (config["matchingResult"]) { + matchingResult = config["matchingResult"].as(); + } + if (config["debugProto"]) { + debugProto = config["debugProto"].as(); + } +if (config["adaptThreshold"]) { + adaptThreshold = config["adaptThreshold"].as(); + } + + return true; } diff --git a/src/localization/tools/config_parser/config_parser.h b/src/localization/tools/config_parser/config_parser.h index a54b1e8..800e656 100644 --- a/src/localization/tools/config_parser/config_parser.h +++ b/src/localization/tools/config_parser/config_parser.h @@ -45,12 +45,14 @@ class ConfigParser { std::string simPlaces = ""; std::string hashTable = ""; std::string matchingResult = "matches.MatchingResult.pb"; + std::string debugProto = ""; int querySize = -1; int fanOut = -1; int bufferSize = -1; double matchingThreshold = -1.0; double expansionRate = -1.0; + bool adaptThreshold = false; }; /*! \var std::string ConfigParser::path2qu diff --git a/src/localization_protos.proto b/src/localization_protos.proto index ae91d11..a81bbb4 100644 --- a/src/localization_protos.proto +++ b/src/localization_protos.proto @@ -33,4 +33,42 @@ message Patch { optional int32 similarity_value = 3; } repeated Element elements = 1; +} + +message PatchInfo{ + message GMM { + optional double mu_class_0 = 1; + optional double mu_class_1 = 2; + optional double sigma_class_0 = 3; + optional double sigma_class_1 = 4; + } + optional Patch patch = 1; + optional double pvalue = 2; + optional double pvalue_location = 3; + optional bool path_exists = 4; + optional double threshold = 5; + optional int32 best_row = 6; + optional int32 best_col = 7; + // GMM result + optional GMM gmm_stats = 10; + +} + +message OnlineLocalizerDebugPerStep{ + optional double matching_threshold = 1; + optional double estimated_matching_threshold = 2; + optional int32 row = 3; + repeated PatchInfo ks_patches = 4; + optional Patch expanded_nodes = 5; + optional MatchingResult path = 6; + optional bool lost = 7; + +} + +message OnlineLocalizerDebug { + optional double start_matching_threshold = 1; + optional int32 stats_patch_width = 2; + optional int32 sliding_window_size = 3; + optional double significance_value = 4; + repeated OnlineLocalizerDebugPerStep debug_per_step = 10; } \ No newline at end of file diff --git a/src/test/online_localizer_test.cpp b/src/test/online_localizer_test.cpp index f585d3a..5af7b2a 100644 --- a/src/test/online_localizer_test.cpp +++ b/src/test/online_localizer_test.cpp @@ -60,7 +60,7 @@ class OnlineLocalizerTest : public ::testing::Test { }; TEST_F(OnlineLocalizerTest, Get) { - loc::online_localizer::Matches matches = localizer->findMatchesTill(4); + loc::online_localizer::Matches matches = localizer->findMatchesTill(4, ""); // Expecting diagonal elements as path in reverse. for (int i = 0; i < matches.size(); ++i) { From 44d3f2d9cfeab0d330270d366a82bb8d8bb1123f Mon Sep 17 00:00:00 2001 From: Olga Date: Thu, 28 May 2026 14:49:36 +0200 Subject: [PATCH 2/3] exposing the parameter for adapting the threshold --- readme.md | 24 ++++++++++++++----- .../online_localizer/math_tools/gmm.cpp | 3 +-- .../math_tools/statistical_test.cpp | 4 +--- .../math_tools/statistical_test.h | 2 -- src/python/matching_scripts.py | 9 +++++-- src/python/run_matching_from_features.py | 10 ++++++++ src/python/run_matching_from_images.py | 9 +++++++ 7 files changed, 46 insertions(+), 15 deletions(-) diff --git a/readme.md b/readme.md index 6d74761..7bb0476 100644 --- a/readme.md +++ b/readme.md @@ -74,13 +74,25 @@ For more details about the parameters, please use `python run_matching_from_*.py For more details about the underlying method and the interpretation of the results, please have a look at [paper](http://www.ipb.uni-bonn.de/pdfs/vysotska16ral-icra.pdf). Here is a sketch of what roughly is happening for those who don't like to read much ![](doc/cost_matrix_view.png) -## Parent project +### Adaptive thresholding -This repository is a continuation of my previous works [vpr_relocalization](https://github.com/PRBonn/vpr_relocalization) and [online_place_recognition](https://github.com/PRBonn/online_place_recognition). +This algorithm requires user to provide a value for matching threshold, a value that defines starting with which similarity two images are no longer showing the same place. Selecting a correct value for this is not straightforward and depends on the degree of challenge two image sequences exhibit. For example, when matching summer sequence with respect to winter sequence one threshold value would be fitting, however, the same value may not be good when matching day and night sequences. Moreover, the environment can be gradually changing within one image sequence, for example, it becomes gradually darker when the sun goes down. + +Here we provide an option to dynamically adapt the initial threshold. The code will check the current similarity values and adapt the matching threshold accordingly. +To use **adaptive thresholding** add `--adaptThreshold` parameter in the matching call, as shown here: + +```bash +python run_matching_from_images.py \ + --query_images \ + --reference_images \ + --dataset_name \ + --output_dir + --write_image_matches + --adaptThreshold +``` -The plan is to gradually modernize and improve the code by preserving the essential capabilities of the system. +For more details, please refer to the [paper](https://www.research-collection.ethz.ch/server/api/core/bitstreams/39b0396e-4743-4bf1-97c6-4cbcf97b0270/content). -**Essential capabilities**: +## Parent project -1. Given two sequences of images compute the matching image pairs. -2. Scripts to visualize the results. +This repository is a continuation of my previous works [vpr_relocalization](https://github.com/PRBonn/vpr_relocalization) and [online_place_recognition](https://github.com/PRBonn/online_place_recognition). diff --git a/src/localization/online_localizer/math_tools/gmm.cpp b/src/localization/online_localizer/math_tools/gmm.cpp index b603c97..7ab367e 100644 --- a/src/localization/online_localizer/math_tools/gmm.cpp +++ b/src/localization/online_localizer/math_tools/gmm.cpp @@ -1,4 +1,4 @@ -/** Copyright (c) 2026 Olga Vysotska, RSL, ETH Zurich +/** Copyright (c) 2026 Olga Vysotska, RSL, ETH Zurich ** Permission is hereby granted, free of charge, to any person obtaining a copy ** of this software and associated documentation files (the "Software"), to deal ** in the Software without restriction, including without limitation the rights @@ -22,7 +22,6 @@ #include "glog/logging.h" -#include #include #include #include diff --git a/src/localization/online_localizer/math_tools/statistical_test.cpp b/src/localization/online_localizer/math_tools/statistical_test.cpp index 8cda714..36cf9cc 100644 --- a/src/localization/online_localizer/math_tools/statistical_test.cpp +++ b/src/localization/online_localizer/math_tools/statistical_test.cpp @@ -24,8 +24,6 @@ #include #include -#include -#include #include namespace localization::online_localizer { @@ -139,7 +137,7 @@ patchContainsPath(const std::vector &values) { // This function checks if the patch contains the path using // the Kolmogorov-Smirnov statistical test. // https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test - // We consider that the patch contains NO path is the values form a unimodal + // We consider that the patch contains NO path if the values form a unimodal // distribution. If the observed distribution diverges from the unimodal // Gaussian distribution, then there exists a path in the patch. // The null hypothesis in the ks-test is: The patch is unimodal. diff --git a/src/localization/online_localizer/math_tools/statistical_test.h b/src/localization/online_localizer/math_tools/statistical_test.h index 9796133..b3c35e5 100644 --- a/src/localization/online_localizer/math_tools/statistical_test.h +++ b/src/localization/online_localizer/math_tools/statistical_test.h @@ -20,8 +20,6 @@ #ifndef SRC_ONLINE_LOCALIZER_MATH_TOOLS_STATISTICAL_TEST_H_ #define SRC_ONLINE_LOCALIZER_MATH_TOOLS_STATISTICAL_TEST_H_ -#include -#include #include namespace localization::online_localizer { diff --git a/src/python/matching_scripts.py b/src/python/matching_scripts.py index 5f6fdcc..1b22298 100644 --- a/src/python/matching_scripts.py +++ b/src/python/matching_scripts.py @@ -15,6 +15,7 @@ class RunParameters: expansionRate: float = 0.3 fanOut: int = 5 matchingThreshold: float = 3.7 + adaptThreshold: bool = False querySize: int = None bufferSize: int = 100 @@ -112,14 +113,18 @@ def computeSimilarityMatrix(run_params): def runMatching(config_yaml_file): - binary = "../../build/src/apps/similarity_matrix_based_matching/online_localizer_lsh" + binary = ( + "../../build/src/apps/similarity_matrix_based_matching/online_localizer_lsh" + ) command = binary + " " + str(config_yaml_file) print("Calling:", command) os.system(command) def runLocalizationResultVisualization(run_params): - params = "--similarity_matrix {similarity_matrix} ".format(similarity_matrix=run_params.similarityMatrix) + params = "--similarity_matrix {similarity_matrix} ".format( + similarity_matrix=run_params.similarityMatrix + ) params += "--matching_result {matching_result} ".format( matching_result=run_params.matchingResult ) diff --git a/src/python/run_matching_from_features.py b/src/python/run_matching_from_features.py index e4db806..cab1f59 100644 --- a/src/python/run_matching_from_features.py +++ b/src/python/run_matching_from_features.py @@ -34,11 +34,17 @@ def parseParams(): required=True, help="Path to output directory to store results.", ) + parser.add_argument( + "--adaptThreshold", + action="store_true", + help="Adapts the matching threshold based on the similarity values.", + ) return parser.parse_args() def setRunParameters(args): run_parameters = matching.RunParameters() + print(run_parameters) run_parameters.path2qu = args.query_features.as_posix() run_parameters.path2ref = args.reference_features.as_posix() run_parameters.similarityMatrix = ( @@ -69,6 +75,10 @@ def main(): args.output_dir.mkdir() run_params = setRunParameters(args) + if args.adaptThreshold: + run_params.adaptThreshold = True + # equals to 1. / (0.5 for similarity values between [0, 1] + run_params.matchingThreshold = 2.0 param_as_dict = matching.convertToDictWithoutNoneEntries(run_params) yaml_config_file = args.output_dir / (args.dataset_name + "_config.yml") diff --git a/src/python/run_matching_from_images.py b/src/python/run_matching_from_images.py index 976cd0e..e494eeb 100644 --- a/src/python/run_matching_from_images.py +++ b/src/python/run_matching_from_images.py @@ -46,6 +46,11 @@ def parseParams(): action="store_true", help="Creates and writes the pair of matching images.", ) + parser.add_argument( + "--adaptThreshold", + action="store_true", + help="Adapts the matching threshold based on the similarity values.", + ) parser.add_argument( "--link_images", action="store_true", @@ -104,6 +109,10 @@ def main(): ) run_params = setRunParameters(args) + if args.adaptThreshold: + run_params.adaptThreshold = True + # equals to 1. / (0.5 for similarity values between [0, 1] + run_params.matchingThreshold = 2.0 params_as_dict = matching.convertToDictWithoutNoneEntries(run_params) yaml_config_file = args.output_dir / (args.dataset_name + "_config.yml") From 12dbd264f2036e411c7d5ded94dd14670f0fc2f0 Mon Sep 17 00:00:00 2001 From: Olga Date: Fri, 29 May 2026 11:30:04 +0200 Subject: [PATCH 3/3] added possibility to store the debug file. The visualization of the debug file is still missing. --- src/python/matching_scripts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/python/matching_scripts.py b/src/python/matching_scripts.py index 1b22298..86dcc46 100644 --- a/src/python/matching_scripts.py +++ b/src/python/matching_scripts.py @@ -18,6 +18,7 @@ class RunParameters: adaptThreshold: bool = False querySize: int = None bufferSize: int = 100 + debugProto: str = None def initializeFromDict(params, params_as_dict):