diff --git a/.github/workflows/build-and-test.yaml b/.github/workflows/build-and-test.yaml
index be85d2bf5..02b8cb806 100644
--- a/.github/workflows/build-and-test.yaml
+++ b/.github/workflows/build-and-test.yaml
@@ -2,59 +2,52 @@ name: ABLkit-CI
on:
push:
- branches: [ main ]
+ branches: [main]
pull_request:
- branches: [ main ]
+ branches: [main]
jobs:
build:
runs-on: ${{ matrix.os }}
strategy:
+ fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
- python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
+ python-version: ["3.8", "3.9", "3.10", "3.11"]
steps:
- - uses: actions/checkout@v2
-
- - name: Set up Python
- uses: actions/setup-python@v2
- with:
- python-version: ${{ matrix.python-version }}
-
- - uses: syphar/restore-virtualenv@v1
- id: cache-virtualenv
- with:
- custom_cache_key_element: ABLkit
- requirement_files: requirements.txt
-
- - uses: syphar/restore-pip-download-cache@v1
- if: steps.cache-virtualenv.outputs.cache-hit != 'true'
-
- - name: Install SWI-Prolog on Ubuntu
- if: matrix.os == 'ubuntu-latest'
- run: sudo apt-get install swi-prolog
- - name: Install SWI-Prolog on Windows
- if: matrix.os == 'windows-latest'
- run: choco install swi-prolog
- - name: Install SWI-Prolog on MACOS
- if: matrix.os == 'macos-latest'
- run: brew install swi-prolog
-
- - name: Install package dependencies
- if : steps.cache-virtualenv.outputs.cache-hit != 'true'
- run: |
- python -m pip install --upgrade pip
- pip install pytest pytest-cov
- - name: Install
- if : steps.cache-virtualenv.outputs.cache-hit != 'true'
- run: pip install -v -e .
-
- - name: Run tests
- run: |
- pytest --cov-config=.coveragerc --cov-report=xml --cov=ablkit ./tests
-
- - name: Publish code coverage
- uses: codecov/codecov-action@v1
- with:
- token: ${{ secrets.CODECOV_TOKEN }}
- file: ./coverage.xml
\ No newline at end of file
+ - uses: actions/checkout@v4
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: pip
+ cache-dependency-path: pyproject.toml
+
+ - name: Install SWI-Prolog (Ubuntu)
+ if: matrix.os == 'ubuntu-latest'
+ run: sudo apt-get update && sudo apt-get install -y swi-prolog
+
+ - name: Install SWI-Prolog (Windows)
+ if: matrix.os == 'windows-latest'
+ run: choco install -y swi-prolog
+
+ - name: Install SWI-Prolog (macOS)
+ if: matrix.os == 'macos-latest'
+ run: brew install swi-prolog
+
+ - name: Install package and test dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -e ".[test]"
+
+ - name: Run tests
+ run: pytest --cov-config=.coveragerc --cov-report=xml --cov=ablkit ./tests
+
+ - name: Publish code coverage
+ if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.10'
+ uses: codecov/codecov-action@v4
+ with:
+ token: ${{ secrets.CODECOV_TOKEN }}
+ files: ./coverage.xml
+ fail_ci_if_error: false
diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
index b79d64b40..fd3d0164e 100644
--- a/.github/workflows/lint.yaml
+++ b/.github/workflows/lint.yaml
@@ -1,24 +1,27 @@
-name: flake8 Lint
-
-on:
- push:
- branches: [ main ]
- pull_request:
- branches: [ main ]
-
-jobs:
- flake8-lint:
- runs-on: ubuntu-latest
- name: Lint
- steps:
- - name: Check out source repository
- uses: actions/checkout@v3
- - name: Set up Python environment
- uses: actions/setup-python@v4
- with:
- python-version: "3.8"
- - name: flake8 Lint
- uses: py-actions/flake8@v2
- with:
- max-line-length: "100"
- args: --ignore=E203,W503,F821
\ No newline at end of file
+name: flake8 Lint
+
+on:
+ push:
+ branches: [main]
+ pull_request:
+ branches: [main]
+
+jobs:
+ flake8-lint:
+ runs-on: ubuntu-latest
+ name: Lint
+ steps:
+ - name: Check out source repository
+ uses: actions/checkout@v4
+
+ - name: Set up Python environment
+ uses: actions/setup-python@v5
+ with:
+ python-version: "3.10"
+
+ - name: flake8 Lint
+ uses: py-actions/flake8@v2
+ with:
+ path: ablkit
+ max-line-length: "100"
+ args: --ignore=E203,W503,F821
diff --git a/.gitignore b/.gitignore
index ba96e9d3a..80cfd18f9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,13 +1,31 @@
*.pyc
examples/**/*.png
+examples/**/*.jpg
*.pk
+*.pkl
*.pth
*.json
*.ckpt
results
raw/
ablkit.egg-info/
-examples/**/*.jpg
.idea/
build/
-.history
\ No newline at end of file
+.history
+
+# Datasets and large binaries
+*.zip
+*.tar
+*.tar.gz
+*.tgz
+*.gz
+*.npz
+*.npy
+*.h5
+*.hdf5
+*.parquet
+*.feather
+*.arrow
+*-ubyte
+*.data
+*.arff
diff --git a/README.md b/README.md
index c83ebb93f..5a431ba09 100644
--- a/README.md
+++ b/README.md
@@ -2,13 +2,13 @@
-[](https://pypi.org/project/ablkit/) [](https://pypi.org/project/ablkit/) [](https://ablkit.readthedocs.io/en/latest/?badge=latest) [](https://github.com/AbductiveLearning/ABLkit/blob/main/LICENSE) [](https://github.com/AbductiveLearning/ABLkit/actions/workflows/lint.yaml) [](https://github.com/psf/black) [](https://github.com/AbductiveLearning/ABLkit/actions/workflows/build-and-test.yaml)
+[](https://github.com/AbductiveLearning/ABLkit/blob/main/LICENSE) [](https://img.shields.io/github/last-commit/AbductiveLearning/ablkit) [](https://pypi.org/project/ablkit/) [](https://pypi.org/project/ablkit/) [](https://ablkit.readthedocs.io/en/latest/?badge=latest) [](https://github.com/AbductiveLearning/ABLkit/actions/workflows/build-and-test.yaml) [](https://github.com/AbductiveLearning/ABLkit/actions/workflows/lint.yaml) [](https://github.com/psf/black) [](https://pypi.org/project/ablkit/)
-[📘Documentation](https://ablkit.readthedocs.io/en/latest/index.html) | [📄Paper](https://journal.hep.com.cn/fcs/EN/10.1007/s11704-024-40085-7) | [📚Examples](https://github.com/AbductiveLearning/ABLkit/tree/main/examples) | [💬Reporting Issues](https://github.com/AbductiveLearning/ABLkit/issues/new)
+[📘Documentation](https://ablkit.readthedocs.io/en/latest/index.html) | [📄Paper](https://journal.hep.com.cn/fcs/EN/10.1007/s11704-024-40085-7) | [🧪Examples](https://github.com/AbductiveLearning/ABLkit/tree/main/examples) | [💬Reporting Issues](https://github.com/AbductiveLearning/ABLkit/issues/new)
-# ABLkit: A Toolkit for Abductive Learning
+# 🧰 ABLkit: A Toolkit for Abductive Learning 📊📐
**ABLkit** is an efficient Python toolkit for [**Abductive Learning (ABL)**](https://www.lamda.nju.edu.cn/publication/chap_ABL.pdf). ABL is a novel paradigm that integrates machine learning and logical reasoning in a unified framework. It is suitable for tasks where both data and (logical) domain knowledge are available.
@@ -28,7 +28,7 @@ ABLkit encapsulates advanced ABL techniques, providing users with an efficient a
-## Installation
+## 🛠️ Installation
### Install from PyPI
@@ -60,7 +60,7 @@ sudo apt-get install swi-prolog
For Windows and Mac users, please refer to the [SWI-Prolog Install Guide](https://github.com/yuce/pyswip/blob/master/INSTALL.md).
-## Quick Start
+## ⚡ Quick Start
We use the MNIST Addition task as a quick start example. In this task, pairs of MNIST handwritten images and their sums are given, alongwith a domain knowledge base which contains information on how to perform addition operations. Our objective is to input a pair of handwritten images and accurately determine their sum.
@@ -184,7 +184,7 @@ bridge.test(test_data)
To explore detailed tutorials and information, please refer to: [Documentation on Read the Docs](https://ablkit.readthedocs.io/en/latest/index.html).
-## Examples
+## 🧪 Examples
We provide several examples in `examples/`. Each example is stored in a separate folder containing a README file.
@@ -192,8 +192,9 @@ We provide several examples in `examples/`. Each example is stored in a separate
+ [Handwritten Formula (HWF)](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/hwf)
+ [Handwritten Equation Decipherment](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/hed)
+ [Zoo](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/zoo)
++ [BDD-OIA](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/bdd_oia)
-## References
+## 📚 References
For more information about ABL, please refer to: [Zhou, 2019](http://scis.scichina.com/en/2019/076101.pdf) and [Zhou and Huang, 2022](https://www.lamda.nju.edu.cn/publication/chap_ABL.pdf).
@@ -220,7 +221,7 @@ For more information about ABL, please refer to: [Zhou, 2019](http://scis.scichi
}
```
-## Citation
+## 📝 Citation
To cite ABLkit, please cite the following paper: [Huang et al., 2024](https://journal.hep.com.cn/fcs/EN/10.1007/s11704-024-40085-7).
@@ -234,4 +235,54 @@ To cite ABLkit, please cite the following paper: [Huang et al., 2024](https://j
pages = {186354},
year = {2024}
}
-```
\ No newline at end of file
+```
+
+## ✨ Contributors
+
+We would like to thank the following contributors for their efforts on this project: (*: current maintainer)
+
+
+
+We also thank the following users for their helpful suggestions and feedback:
+
+- [Hao-Yuan He](https://github.com/Hao-Yuan-He)
+- [Lin-Han Jia](https://github.com/YGZWQZD)
+- [Wang-Zhou Dai](https://github.com/haldai)
\ No newline at end of file
diff --git a/ablkit/bridge/__init__.py b/ablkit/bridge/__init__.py
index 502a118cf..3d0f3f8e5 100644
--- a/ablkit/bridge/__init__.py
+++ b/ablkit/bridge/__init__.py
@@ -1,4 +1,6 @@
+from .a3bl_bridge import A3BLBridge
from .base_bridge import BaseBridge
from .simple_bridge import SimpleBridge
+from .verification_bridge import VerificationBridge
-__all__ = ["BaseBridge", "SimpleBridge"]
+__all__ = ["BaseBridge", "SimpleBridge", "A3BLBridge", "VerificationBridge"]
diff --git a/ablkit/bridge/a3bl_bridge.py b/ablkit/bridge/a3bl_bridge.py
new file mode 100644
index 000000000..d6c53faa2
--- /dev/null
+++ b/ablkit/bridge/a3bl_bridge.py
@@ -0,0 +1,166 @@
+import os.path as osp
+from typing import Any, List, Optional, Tuple, Union
+
+from ..data.evaluation.base_metric import BaseMetric
+from ..data.structures.list_data import ListData
+from ..learning import ABLModel
+from ..reasoning import A3BLReasoner
+from ..utils import print_log
+from .simple_bridge import SimpleBridge
+
+
+class A3BLBridge(SimpleBridge):
+ """
+ An ambiguity-aware implementation for bridging machine learning and reasoning parts.
+
+ Reference: https://github.com/Hao-Yuan-He/A3BL
+
+ Involves the following five steps:
+ - Predict class probabilities and indices for the given data examples.
+ - Map indices into pseudo-labels.
+ - Enumerate all valid pseudo-labels.
+ - Revise pseudo-labels into a label distribution based on the class probabilities.
+ - Train the model.
+
+ Parameters
+ ----------
+ model : ABLModel
+ The machine learning model wrapped in ``ABLModel``, used for prediction
+ and training. The wrapped base model should expose ``extract_features``
+ so embeddings are available for the soft-label aggregation.
+ reasoner : A3BLReasoner
+ The reasoning part wrapped in ``A3BLReasoner``, used for pseudo-label
+ enumeration and soft-label aggregation.
+ metric_list : List[BaseMetric]
+ A list of metrics used for evaluating the model's performance.
+ """
+
+ def __init__(
+ self,
+ model: ABLModel,
+ reasoner: A3BLReasoner,
+ metric_list: List[BaseMetric],
+ ):
+ super().__init__(model, reasoner, metric_list)
+
+ def abduce_soft_label(self, data_examples: ListData) -> List[List[Any]]:
+ """
+ Revise predicted pseudo-labels of the given data examples into soft labels via abduction.
+
+ Parameters
+ ----------
+ data_examples : ListData
+ Data examples containing predicted pseudo-labels.
+
+ Returns
+ -------
+ List[List[Any]]
+ A list of abduced soft labels for the given data examples.
+ """
+ self.reasoner.batch_abduce(data_examples)
+ return data_examples.abduced_soft_label
+
+ def train_data_iter(
+ self,
+ train_data,
+ val_data=None,
+ segment_size=1.0,
+ ):
+ data_examples = self.data_preprocess("train", train_data)
+
+ if val_data is not None:
+ val_data_examples = self.data_preprocess("val", val_data)
+ else:
+ val_data_examples = data_examples
+
+ if isinstance(segment_size, int):
+ if segment_size <= 0:
+ raise ValueError("segment_size should be positive.")
+ elif isinstance(segment_size, float):
+ if 0 < segment_size <= 1:
+ segment_size = int(segment_size * len(data_examples))
+ else:
+ raise ValueError("segment_size should be in (0, 1].")
+ else:
+ raise ValueError("segment_size should be int or float.")
+
+ for seg_idx in range((len(data_examples) - 1) // segment_size + 1):
+ sub = data_examples[seg_idx * segment_size: (seg_idx + 1) * segment_size]
+ yield sub, val_data_examples
+
+ def train(
+ self,
+ train_data: Union[
+ ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]
+ ],
+ val_data: Optional[
+ Union[
+ ListData,
+ Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]],
+ ]
+ ] = None,
+ loops: int = 50,
+ segment_size: Union[int, float] = 1.0,
+ eval_interval: int = 1,
+ save_interval: Optional[int] = None,
+ save_dir: Optional[str] = None,
+ ):
+ """
+ A typical training pipeline of Abductive Learning.
+
+ Parameters
+ ----------
+ train_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
+ Training data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData``
+ object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes.
+ - ``X`` is a list of sublists representing the input data.
+ - ``gt_pseudo_label`` is only used to evaluate the performance of the ``ABLModel`` but
+ not to train. ``gt_pseudo_label`` can be ``None``.
+ - ``Y`` is a list representing the ground truth reasoning result for each sublist
+ in ``X``.
+ label_data : not a parameter of this method
+ NOTE(review): unlike ``SimpleBridge.train``, this ``train`` signature does not
+ accept ``label_data``, so labeled data cannot be passed here. This entry looks
+ copied from ``SimpleBridge.train``; confirm and drop it or add the parameter.
+ val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional # noqa: E501 pylint: disable=line-too-long
+ Validation data should be in the same format as ``train_data``. Both ``gt_pseudo_label``
+ and ``Y`` can be either None or not, which depends on the evaluation metrics in
+ ``self.metric_list``. If ``val_data`` is None, ``train_data`` will be used to validate
+ the model during training time. Defaults to None.
+ loops : int
+ Learning part and Reasoning part will be iteratively optimized
+ for ``loops`` times. Defaults to 50.
+ segment_size : Union[int, float]
+ Data will be split into segments of this size and data in each segment
+ will be used together to train the model. Defaults to 1.0.
+ eval_interval : int
+ The model will be evaluated every ``eval_interval`` loops during training.
+ Defaults to 1.
+ save_interval : int, optional
+ The model will be saved every ``save_interval`` loops during training.
+ Defaults to None.
+ save_dir : str, optional
+ Directory to save the model. Defaults to None.
+ """
+ for loop in range(loops):
+ iterator = self.train_data_iter(train_data, val_data, segment_size)
+ for train_examples_batch, val_examples_batch in iterator:
+ print_log(
+ f"loop(train) [{loop + 1}/{loops}] segment(train) ", logger="current"
+ )
+ self.predict(train_examples_batch)
+ self.idx_to_pseudo_label(train_examples_batch)
+ self.abduce_pseudo_label(train_examples_batch)
+ self.filter_pseudo_label(train_examples_batch)
+ self.pseudo_label_to_idx(train_examples_batch)
+ self.model.train(train_examples_batch)
+
+ if (loop + 1) % eval_interval == 0 or loop == loops - 1:
+ print_log(f"Eval start: loop(val) [{loop + 1}]", logger="current")
+ self._valid(val_examples_batch)
+
+ if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1):
+ print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current")
+ self.model.save(
+ save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth")
+ )
diff --git a/ablkit/bridge/base_bridge.py b/ablkit/bridge/base_bridge.py
index a3a40add2..8f5e55f35 100644
--- a/ablkit/bridge/base_bridge.py
+++ b/ablkit/bridge/base_bridge.py
@@ -37,9 +37,10 @@ class BaseBridge(metaclass=ABCMeta):
def __init__(self, model: ABLModel, reasoner: Reasoner) -> None:
if not isinstance(model, ABLModel):
raise TypeError(f"Expected an instance of ABLModel, but received type: {type(model)}")
- if not isinstance(reasoner, Reasoner):
+ if not (hasattr(reasoner, "idx_to_label") and hasattr(reasoner, "label_to_idx")):
raise TypeError(
- f"Expected an instance of Reasoner, but received type: {type(reasoner)}"
+ "Expected a reasoner exposing idx_to_label / label_to_idx (e.g. Reasoner "
+ f"or VerificationReasoner), but received type: {type(reasoner)}"
)
self.model = model
diff --git a/ablkit/bridge/simple_bridge.py b/ablkit/bridge/simple_bridge.py
index 5c2cbfbb9..e48254714 100644
--- a/ablkit/bridge/simple_bridge.py
+++ b/ablkit/bridge/simple_bridge.py
@@ -94,6 +94,23 @@ def abduce_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
self.reasoner.batch_abduce(data_examples)
return data_examples.abduced_pseudo_label
+ def supervised_abduce_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
+ """
+ Revise predicted pseudo-labels of the given data examples using ground truth.
+
+ Parameters
+ ----------
+ data_examples : ListData
+ Data examples containing predicted pseudo-labels.
+
+ Returns
+ -------
+ List[List[Any]]
+ A list of ground truth/abduced pseudo-labels for the given data examples.
+ """
+ self.reasoner.batch_supervised_abduce(data_examples)
+ return data_examples.abduced_pseudo_label
+
def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
"""
Map indices of data examples into pseudo-labels.
@@ -211,10 +228,14 @@ def train(
Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]]
] = None,
val_data: Optional[
- Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]]
+ Union[
+ ListData,
+ Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]],
+ ]
] = None,
loops: int = 50,
segment_size: Union[int, float] = 1.0,
+ use_supervised_data: bool = False,
eval_interval: int = 1,
save_interval: Optional[int] = None,
save_dir: Optional[str] = None,
@@ -257,58 +278,95 @@ def train(
Directory to save the model. Defaults to None.
"""
data_examples = self.data_preprocess("train", train_data)
+ label_data_examples = (
+ self.data_preprocess("label", label_data) if label_data is not None else None
+ )
+ val_data_examples = (
+ self.data_preprocess("val", val_data) if val_data is not None else data_examples
+ )
- if label_data is not None:
- label_data_examples = self.data_preprocess("label", label_data)
- else:
- label_data_examples = None
-
- if val_data is not None:
- val_data_examples = self.data_preprocess("val", val_data)
- else:
- val_data_examples = data_examples
-
- if isinstance(segment_size, int):
- if segment_size <= 0:
- raise ValueError("segment_size should be positive.")
- elif isinstance(segment_size, float):
- if 0 < segment_size <= 1:
- segment_size = int(segment_size * len(data_examples))
- else:
- raise ValueError("segment_size should be in (0, 1].")
- else:
- raise ValueError("segment_size should be int or float.")
+ segment_size = self._resolve_segment_size(segment_size, len(data_examples))
+ num_segments = (len(data_examples) - 1) // segment_size + 1
for loop in range(loops):
- for seg_idx in range((len(data_examples) - 1) // segment_size + 1):
+ for seg_idx in range(num_segments):
print_log(
f"loop(train) [{loop + 1}/{loops}] segment(train) "
- f"[{(seg_idx + 1)}/{(len(data_examples) - 1) // segment_size + 1}] ",
+ f"[{seg_idx + 1}/{num_segments}] ",
logger="current",
)
-
sub_data_examples = data_examples[
seg_idx * segment_size : (seg_idx + 1) * segment_size
]
- self.predict(sub_data_examples)
- self.idx_to_pseudo_label(sub_data_examples)
- self.abduce_pseudo_label(sub_data_examples)
- self.filter_pseudo_label(sub_data_examples)
- self.concat_data_examples(sub_data_examples, label_data_examples)
- self.pseudo_label_to_idx(sub_data_examples)
- self.model.train(sub_data_examples)
-
- if (loop + 1) % eval_interval == 0 or loop == loops - 1:
- print_log(f"Eval start: loop(val) [{loop + 1}]", logger="current")
- self._valid(val_data_examples)
-
- if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1):
- print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current")
- self.model.save(
- save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth")
+ self._train_one_segment(
+ sub_data_examples, label_data_examples, use_supervised_data
)
- def _valid(self, data_examples: ListData) -> None:
+ self._maybe_eval(val_data_examples, loop, loops, eval_interval)
+ self._maybe_save(loop, loops, save_interval, save_dir)
+
+ @staticmethod
+ def _resolve_segment_size(segment_size: Union[int, float], dataset_len: int) -> int:
+ """Validate and convert ``segment_size`` into an absolute number of examples."""
+ if isinstance(segment_size, int):
+ if segment_size <= 0:
+ raise ValueError("segment_size should be positive.")
+ return segment_size
+ if isinstance(segment_size, float):
+ if not (0 < segment_size <= 1):
+ raise ValueError("segment_size should be in (0, 1].")
+ return int(segment_size * dataset_len)
+ raise ValueError("segment_size should be int or float.")
+
+ def _train_one_segment(
+ self,
+ sub_data_examples: ListData,
+ label_data_examples: Optional[ListData],
+ use_supervised_data: bool,
+ ) -> None:
+ """Run prediction, abduction, label-data concat, and a single model.train step."""
+ self.predict(sub_data_examples)
+ self.idx_to_pseudo_label(sub_data_examples)
+ if use_supervised_data:
+ self.supervised_abduce_pseudo_label(sub_data_examples)
+ else:
+ self.abduce_pseudo_label(sub_data_examples)
+ self.filter_pseudo_label(sub_data_examples)
+ self.concat_data_examples(sub_data_examples, label_data_examples)
+ self.pseudo_label_to_idx(sub_data_examples)
+ if len(sub_data_examples) == 0:
+ return
+ self.model.train(sub_data_examples)
+
+ def _maybe_eval(
+ self,
+ val_data_examples: ListData,
+ loop: int,
+ loops: int,
+ eval_interval: int,
+ ) -> None:
+ """Evaluate on ``val_data_examples`` at the configured interval (and on the last loop)."""
+ if (loop + 1) % eval_interval == 0 or loop == loops - 1:
+ print_log(f"Eval start: loop(val) [{loop + 1}]", logger="current")
+ self._valid(val_data_examples, prefix="val")
+
+ def _maybe_save(
+ self,
+ loop: int,
+ loops: int,
+ save_interval: Optional[int],
+ save_dir: Optional[str],
+ ) -> None:
+ """Persist the model at the configured interval (and on the last loop)."""
+ if save_interval is None:
+ return
+ if (loop + 1) % save_interval == 0 or loop == loops - 1:
+ print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current")
+ self.model.save(
+ save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth")
+ )
+
+ def _valid(self, data_examples: ListData, prefix: str = "val") -> None:
"""
Internal method for validating the model with given data examples.
@@ -320,21 +378,31 @@ def _valid(self, data_examples: ListData) -> None:
self.predict(data_examples)
self.idx_to_pseudo_label(data_examples)
+ for metric in self.metric_list:
+ metric.prefix = prefix
+
for metric in self.metric_list:
metric.process(data_examples)
res = dict()
for metric in self.metric_list:
res.update(metric.evaluate())
+
msg = "Evaluation ended, "
for k, v in res.items():
- msg += k + f": {v:.3f} "
+ try:
+ v = float(v)
+ msg += k + f": {v:.3f} "
+ except (TypeError, ValueError):
+ pass
+
print_log(msg, logger="current")
def valid(
self,
val_data: Union[
- ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]
+ ListData,
+ Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]],
],
) -> None:
"""
@@ -349,12 +417,13 @@ def valid(
``self.metric_list``.
"""
val_data_examples = self.data_preprocess("val", val_data)
- self._valid(val_data_examples)
+ self._valid(val_data_examples, prefix="val")
def test(
self,
test_data: Union[
- ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]
+ ListData,
+ Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]],
],
) -> None:
"""
@@ -370,4 +439,4 @@ def test(
"""
print_log("Test start:", logger="current")
test_data_examples = self.data_preprocess("test", test_data)
- self._valid(test_data_examples)
+ self._valid(test_data_examples, prefix="test")
diff --git a/ablkit/bridge/verification_bridge.py b/ablkit/bridge/verification_bridge.py
new file mode 100644
index 000000000..b87537d24
--- /dev/null
+++ b/ablkit/bridge/verification_bridge.py
@@ -0,0 +1,117 @@
+"""
+Bridge for Verification Learning.
+
+:class:`VerificationBridge` replaces the single-candidate abduction step of
+:class:`SimpleBridge` with a top-K enumeration provided by
+:class:`~ablkit.reasoning.VerificationReasoner`. For each
+segment the bridge trains the model once per top-K candidate, exposing
+the model to every assignment that is consistent with the knowledge base.
+
+Reference: https://github.com/VerificationLearning/VerificationLearning
+"""
+
+from typing import Any, List, Optional, Tuple, Union
+
+from ..data.evaluation import BaseMetric
+from ..data.structures import ListData
+from ..learning import ABLModel
+from ..reasoning.reasoner import VerificationReasoner
+from ..utils import print_log
+from .simple_bridge import SimpleBridge
+
+
+class VerificationBridge(SimpleBridge):
+ """
+ Bridge implementing the Verification Learning training loop.
+
+ Parameters
+ ----------
+ model : ABLModel
+ Wrapped learning model.
+ reasoner : VerificationReasoner
+ Top-K reasoner. The bridge reads ``reasoner.top_k`` to decide how
+ many training passes to run per segment.
+ metric_list : List[BaseMetric]
+ Evaluation metrics, identical to :class:`SimpleBridge`.
+ """
+
+ def __init__(
+ self,
+ model: ABLModel,
+ reasoner: VerificationReasoner,
+ metric_list: List[BaseMetric],
+ ) -> None:
+ if not isinstance(reasoner, VerificationReasoner):
+ raise TypeError(
+ "VerificationBridge requires a VerificationReasoner; "
+ f"got {type(reasoner).__name__}."
+ )
+ super().__init__(model, reasoner, metric_list)
+
+ def train(
+ self,
+ train_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]],
+ val_data: Optional[
+ Union[
+ ListData,
+ Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]],
+ ]
+ ] = None,
+ loops: int = 50,
+ segment_size: Union[int, float] = 1.0,
+ eval_interval: int = 1,
+ save_interval: Optional[int] = None,
+ save_dir: Optional[str] = None,
+ ) -> None:
+ """
+ Verification Learning training loop. For each segment we predict
+ once, enumerate the top-K consistent candidates, then run a
+ ``model.train`` pass per candidate.
+ """
+ data_examples = self.data_preprocess("train", train_data)
+ val_data_examples = (
+ self.data_preprocess("val", val_data) if val_data is not None else data_examples
+ )
+
+ segment_size = self._resolve_segment_size(segment_size, len(data_examples))
+ num_segments = (len(data_examples) - 1) // segment_size + 1
+
+ for loop in range(loops):
+ for seg_idx in range(num_segments):
+ print_log(
+ f"loop(train) [{loop + 1}/{loops}] segment(train) "
+ f"[{seg_idx + 1}/{num_segments}] ",
+ logger="current",
+ )
+ sub_data_examples = data_examples[
+ seg_idx * segment_size : (seg_idx + 1) * segment_size
+ ]
+ self._train_one_segment_verification(sub_data_examples)
+
+ self._maybe_eval(val_data_examples, loop, loops, eval_interval)
+ self._maybe_save(loop, loops, save_interval, save_dir)
+
+ def _train_one_segment_verification(self, sub_data_examples: ListData) -> None:
+ """
+ Predict, enumerate top-K candidates, then train once per candidate.
+ Each example's k-th training pass uses its k-th candidate (or, if
+ the example yielded fewer than k candidates, its last available
+ candidate, repeated).
+ """
+ self.predict(sub_data_examples)
+ self.idx_to_pseudo_label(sub_data_examples)
+ per_example_candidates = self.reasoner.batch_top_k(sub_data_examples)
+
+ if not per_example_candidates:
+ return
+
+ max_k = max(len(cands) for cands in per_example_candidates)
+ for k_idx in range(max_k):
+ sub_data_examples.abduced_pseudo_label = [
+ cands[min(k_idx, len(cands) - 1)] for cands in per_example_candidates
+ ]
+ self.filter_pseudo_label(sub_data_examples)
+ self.pseudo_label_to_idx(sub_data_examples)
+ if len(sub_data_examples) == 0:
+ continue
+ self.model.train(sub_data_examples)
diff --git a/ablkit/data/structures/base_data_element.py b/ablkit/data/structures/base_data_element.py
index 2e9b00052..e9bd60c26 100644
--- a/ablkit/data/structures/base_data_element.py
+++ b/ablkit/data/structures/base_data_element.py
@@ -306,14 +306,12 @@ def keys(self) -> list:
Returns:
list: Contains all keys in data_fields.
"""
- # We assume that the name of the attribute related to property is
- # '_' + the name of the property. We use this rule to filter out
- # private keys.
- # TODO: Use a more robust way to solve this problem
+ cls = type(self)
private_keys = {
- "_" + key
- for key in self._data_fields
- if isinstance(getattr(type(self), key, None), property)
+ name
+ for name in self._data_fields
+ if name.startswith("_")
+ and isinstance(getattr(cls, name[1:], None), property)
}
return list(self._data_fields - private_keys)
diff --git a/ablkit/data/structures/list_data.py b/ablkit/data/structures/list_data.py
index 256deda7f..6fea13ad1 100644
--- a/ablkit/data/structures/list_data.py
+++ b/ablkit/data/structures/list_data.py
@@ -194,7 +194,7 @@ def __getitem__(self, item: IndexType) -> "ListData":
new_data[k] = None
else:
new_data[k] = v[item]
- return new_data # type:ignore
+ return new_data # type: ignore
def flatten(self, item: str) -> List:
"""
diff --git a/ablkit/learning/__init__.py b/ablkit/learning/__init__.py
index ad016a654..bd2f2a7f7 100644
--- a/ablkit/learning/__init__.py
+++ b/ablkit/learning/__init__.py
@@ -1,5 +1,19 @@
-from .abl_model import ABLModel
-from .basic_nn import BasicNN
-from .torch_dataset import ClassificationDataset, PredictionDataset, RegressionDataset
+from .abl_model import ABLModel, MultiLabelABLModel
+from .basic_nn import BasicNN, MultiLabelBasicNN
+from .torch_dataset import (
+ ClassificationDataset,
+ MultiLabelClassificationDataset,
+ PredictionDataset,
+ RegressionDataset,
+)
-__all__ = ["ABLModel", "BasicNN", "ClassificationDataset", "PredictionDataset", "RegressionDataset"]
+__all__ = [
+ "ABLModel",
+ "BasicNN",
+ "MultiLabelABLModel",
+ "MultiLabelBasicNN",
+ "ClassificationDataset",
+ "MultiLabelClassificationDataset",
+ "PredictionDataset",
+ "RegressionDataset",
+]
diff --git a/ablkit/learning/abl_model.py b/ablkit/learning/abl_model.py
index 7ceaa1510..db7f7cdee 100644
--- a/ablkit/learning/abl_model.py
+++ b/ablkit/learning/abl_model.py
@@ -8,6 +8,8 @@
import pickle
from typing import Any, Dict
+import numpy as np
+
from ..data.structures import ListData
from ..utils import reform_list
@@ -20,9 +22,10 @@ class ABLModel:
----------
base_model : Machine Learning Model
The machine learning base model used for training and prediction. This model should
- implement the ``fit`` and ``predict`` methods. It's recommended, but not required, for the
- model to also implement the ``predict_proba`` method for generating
- predictions on the probabilities.
+ implement the ``fit`` and ``predict`` methods. It's recommended, but not required, for
+ the model to also implement ``predict_proba`` (used to populate ``pred_prob``) and
+ ``extract_features`` (used to populate ``data_example.embeddings`` for distance
+ functions such as ``similarity``).
"""
def __init__(self, base_model: Any) -> None:
@@ -31,7 +34,7 @@ def __init__(self, base_model: Any) -> None:
self.base_model = base_model
- def predict(self, data_examples: ListData) -> Dict:
+ def predict(self, data_examples: ListData) -> Dict[str, Any]:
"""
Predict the labels and probabilities for the given data.
@@ -47,8 +50,14 @@ def predict(self, data_examples: ListData) -> Dict:
"""
model = self.base_model
data_X = data_examples.flatten("X")
+ embeddings = None
if hasattr(model, "predict_proba"):
prob = model.predict_proba(X=data_X)
+ if hasattr(model, "extract_features"):
+ try:
+ embeddings = model.extract_features(X=data_X)
+ except AttributeError:
+ embeddings = None
label = prob.argmax(axis=1)
prob = reform_list(prob, data_examples.X)
else:
@@ -58,6 +67,8 @@ def predict(self, data_examples: ListData) -> Dict:
data_examples.pred_idx = label
data_examples.pred_prob = prob
+ if embeddings is not None:
+ data_examples.embeddings = reform_list(embeddings, data_examples.X)
return {"label": label, "prob": prob}
@@ -138,3 +149,57 @@ def load(self, *args, **kwargs) -> None:
this method should match those expected by the ``load`` method of self.base_model.
"""
self._model_operation("load", *args, **kwargs)
+
+
+# =============================================================================
+# Multi-label variants
+# =============================================================================
+
+
+class MultiLabelABLModel(ABLModel):
+    """
+    :class:`ABLModel` adapted to multi-label prediction.
+
+    Where :class:`ABLModel.predict` takes the ``argmax`` of each row of
+    class probabilities, this subclass treats every column as its own
+    binary decision: a label is marked active (``1``) when its
+    probability is strictly greater than 0.5 and inactive (``0``)
+    otherwise. The indicator vectors are stored on ``pred_idx``.
+
+    It is intended to wrap :class:`~ablkit.learning.MultiLabelBasicNN`,
+    whose ``predict_proba`` yields ``(num_samples, num_labels)``
+    sigmoid probabilities.
+    """
+
+    def predict(self, data_examples: ListData) -> Dict[str, Any]:
+        """
+        Compute binary label indicators and probabilities for a batch.
+
+        Parameters
+        ----------
+        data_examples : ListData
+            The batch to run the base model on.
+
+        Returns
+        -------
+        Dict[str, Any]
+            ``{"label": ..., "prob": ...}``, both grouped per example.
+            ``"prob"`` is ``None`` when the base model does not expose
+            ``predict_proba``; in that case ``"label"`` comes from the
+            base model's ``predict``.
+        """
+        base = self.base_model
+        flat_X = data_examples.flatten("X")
+        grouped_prob = None
+        if not hasattr(base, "predict_proba"):
+            flat_label = base.predict(X=flat_X)
+        else:
+            flat_prob = base.predict_proba(X=flat_X)
+            flat_label = np.where(flat_prob > 0.5, 1, 0).astype(int)
+            grouped_prob = reform_list(flat_prob, data_examples.X)
+        grouped_label = reform_list(flat_label, data_examples.X)
+
+        data_examples.pred_idx = grouped_label
+        data_examples.pred_prob = grouped_prob
+
+        return {"label": grouped_label, "prob": grouped_prob}
diff --git a/ablkit/learning/basic_nn.py b/ablkit/learning/basic_nn.py
index 5ad44f03d..e921ceabe 100644
--- a/ablkit/learning/basic_nn.py
+++ b/ablkit/learning/basic_nn.py
@@ -15,7 +15,11 @@
from torch.utils.data import DataLoader
from ..utils.logger import print_log
-from .torch_dataset import ClassificationDataset, PredictionDataset
+from .torch_dataset import (
+ ClassificationDataset,
+ MultiLabelClassificationDataset,
+ PredictionDataset,
+)
class BasicNN:
@@ -367,6 +371,86 @@ def predict_proba(
)
return self._predict(data_loader).softmax(axis=1).cpu().numpy()
+    def _extract_features(self, data_loader: DataLoader) -> torch.Tensor:
+        """
+        Concatenate ``self.model.extract_features`` outputs over all batches.
+
+        Parameters
+        ----------
+        data_loader : DataLoader
+            Source of input batches.
+
+        Returns
+        -------
+        torch.Tensor
+            Per-batch features concatenated along dimension 0.
+
+        Raises
+        ------
+        TypeError, AttributeError
+            On a non-DataLoader argument or a model lacking ``extract_features``.
+        """
+        if not isinstance(data_loader, DataLoader):
+            raise TypeError(
+                "data_loader must be an instance of torch.utils.data.DataLoader, "
+                f"but got {type(data_loader)}"
+            )
+        net = self.model
+        if not hasattr(net, "extract_features"):
+            raise AttributeError(
+                f"{type(net).__name__} does not implement extract_features(x). "
+                "Add such a method to your PyTorch model to enable feature extraction "
+                "(used by dist_func='similarity', among others)."
+            )
+        net.eval()
+        with torch.no_grad():
+            batches = [net.extract_features(batch.to(self.device)) for batch in data_loader]
+        return torch.cat(batches, dim=0)
+
+    def extract_features(
+        self,
+        data_loader: Optional[DataLoader] = None,
+        X: Optional[List[Any]] = None,
+    ) -> numpy.ndarray:
+        """
+        Return feature embeddings as a NumPy array, computed from a ready-made
+        ``data_loader`` or from raw inputs ``X``. If both are given, ``X`` is
+        ignored and a warning is logged.
+
+        The underlying PyTorch model has to provide ``extract_features(x)``;
+        its output (typically penultimate-layer activations) is what
+        consumers such as ``dist_func='similarity'`` operate on.
+
+        Parameters
+        ----------
+        data_loader : DataLoader, optional
+            Loader to iterate over as-is. Defaults to None.
+        X : List[Any], optional
+            Raw inputs, wrapped in a ``PredictionDataset`` with
+            ``self.test_transform`` when no loader is given. Defaults to None.
+
+        Returns
+        -------
+        numpy.ndarray
+            Feature embeddings of shape ``(num_samples, embedding_dim)``.
+        """
+        loader = data_loader
+        if loader is not None and X is not None:
+            print_log(
+                "Extracting features from data_loader; ignoring X.",
+                logger="current",
+                level=logging.WARNING,
+            )
+        if loader is None:
+            loader = DataLoader(
+                PredictionDataset(X, self.test_transform),
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                collate_fn=self.collate_fn,
+                pin_memory=torch.cuda.is_available(),
+            )
+        return self._extract_features(loader).cpu().numpy()
+
def _score(self, data_loader: DataLoader) -> Tuple[float, float]:
"""
Internal method to compute loss and accuracy for the data provided through a DataLoader.
@@ -455,7 +539,7 @@ def score(
else:
data_loader = self._data_loader(X, y)
mean_loss, accuracy = self._score(data_loader)
- print_log(f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current")
+ print_log(f"mean loss: {mean_loss:.3f}, accuracy: {accuracy:.3f}", logger="current")
return accuracy
def _data_loader(
@@ -528,12 +612,12 @@ def save(self, epoch_id: int = 0, save_path: Optional[str] = None) -> None:
print_log(f"Checkpoints will be saved to {save_path}", logger="current")
- save_parma_dic = {
+ save_param_dict = {
"model": self.model.state_dict(),
"optimizer": self.optimizer.state_dict(),
}
- torch.save(save_parma_dic, save_path)
+ torch.save(save_param_dict, save_path)
def load(self, load_path: str) -> None:
"""
@@ -557,3 +641,109 @@ def load(self, load_path: str) -> None:
self.model.load_state_dict(param_dic["model"])
if "optimizer" in param_dic.keys():
self.optimizer.load_state_dict(param_dic["optimizer"])
+
+
+# =============================================================================
+# Multi-label variants
+# =============================================================================
+
+
+class MultiLabelBasicNN(BasicNN):
+    """
+    A multi-label variant of :class:`BasicNN`.
+
+    The standard :class:`BasicNN` assumes a single-label, multi-class
+    classification setting (softmax output, argmax prediction). In
+    contrast, :class:`MultiLabelBasicNN` treats each output dimension as
+    an independent binary decision (sigmoid output, threshold-based binary
+    vector prediction) and uses
+    :class:`~ablkit.learning.MultiLabelClassificationDataset` so that
+    targets can be fed straight into losses like ``BCEWithLogitsLoss``.
+
+    Apart from prediction and dataset handling, the class reuses the full
+    training and evaluation pipeline from :class:`BasicNN`.
+    """
+
+    def _prediction_loader(
+        self,
+        data_loader: Optional[DataLoader],
+        X: Optional[List[Any]],
+        warning: str,
+    ) -> DataLoader:
+        """
+        Resolve the DataLoader shared by ``predict`` and ``predict_proba``.
+
+        A given ``data_loader`` is returned unchanged, after logging
+        ``warning`` if ``X`` was passed too (``X`` is then ignored). Otherwise
+        ``X`` is wrapped in a ``PredictionDataset`` with ``self.test_transform``.
+        """
+        if data_loader is not None:
+            if X is not None:
+                print_log(warning, logger="current", level=logging.WARNING)
+            return data_loader
+        return DataLoader(
+            PredictionDataset(X, self.test_transform),
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            collate_fn=self.collate_fn,
+            pin_memory=torch.cuda.is_available(),
+        )
+
+    def predict(
+        self,
+        data_loader: Optional[DataLoader] = None,
+        X: Optional[List[Any]] = None,
+    ) -> numpy.ndarray:
+        """
+        Return a binary indicator vector for each sample by thresholding
+        the per-label sigmoid probabilities at 0.5.
+        """
+        data_loader = self._prediction_loader(
+            data_loader, X, "Predict the class of input data in data_loader instead of X."
+        )
+        pred_probs = self._predict(data_loader).sigmoid()
+        # A direct mask cast avoids torch.where's scalar overload, absent in older PyTorch.
+        return (pred_probs > 0.5).int().cpu().numpy()
+
+    def predict_proba(
+        self,
+        data_loader: Optional[DataLoader] = None,
+        X: Optional[List[Any]] = None,
+    ) -> numpy.ndarray:
+        """
+        Return per-label sigmoid probabilities of shape
+        ``(num_samples, num_labels)``.
+        """
+        data_loader = self._prediction_loader(
+            data_loader,
+            X,
+            "Predict the class probability of input data in data_loader instead of X.",
+        )
+        return self._predict(data_loader).sigmoid().cpu().numpy()
+
+    def _data_loader(
+        self,
+        X: Optional[List[Any]],
+        y: Optional[List[Any]] = None,
+        shuffle: Optional[bool] = True,
+    ) -> DataLoader:
+        """
+        Build a DataLoader backed by
+        :class:`~ablkit.learning.MultiLabelClassificationDataset`.
+        A missing ``y`` is replaced by one placeholder ``0`` per sample.
+        Raises ``ValueError`` if ``X`` is None or the lengths differ.
+        """
+        if X is None:
+            raise ValueError("X should not be None.")
+        if y is None:
+            y = [0] * len(X)
+        if not len(y) == len(X):
+            raise ValueError("X and y should have equal length.")
+        return DataLoader(
+            MultiLabelClassificationDataset(X, y, transform=self.train_transform),
+            batch_size=self.batch_size,
+            shuffle=shuffle,
+            num_workers=self.num_workers,
+            collate_fn=self.collate_fn,
+            pin_memory=torch.cuda.is_available(),
+        )
diff --git a/ablkit/learning/torch_dataset/__init__.py b/ablkit/learning/torch_dataset/__init__.py
index b8237a2d7..f1a0d21e8 100644
--- a/ablkit/learning/torch_dataset/__init__.py
+++ b/ablkit/learning/torch_dataset/__init__.py
@@ -1,9 +1,11 @@
from .classification_dataset import ClassificationDataset
+from .multi_label_classification_dataset import MultiLabelClassificationDataset
from .prediction_dataset import PredictionDataset
from .regression_dataset import RegressionDataset
__all__ = [
"ClassificationDataset",
+ "MultiLabelClassificationDataset",
"PredictionDataset",
"RegressionDataset",
]
diff --git a/ablkit/learning/torch_dataset/multi_label_classification_dataset.py b/ablkit/learning/torch_dataset/multi_label_classification_dataset.py
new file mode 100644
index 000000000..3e5c68dd9
--- /dev/null
+++ b/ablkit/learning/torch_dataset/multi_label_classification_dataset.py
@@ -0,0 +1,47 @@
+"""
+Implementation of PyTorch dataset class used for multi-label classification.
+
+Copyright (c) 2024 LAMDA. All rights reserved.
+"""
+
+from typing import Any, Callable, List, Optional
+
+import numpy as np
+import torch
+
+from .classification_dataset import ClassificationDataset
+
+
+class MultiLabelClassificationDataset(ClassificationDataset):
+    """
+    :class:`ClassificationDataset` whose targets are multi-label: every
+    ``Y[i]`` is a binary indicator vector with one entry per label instead
+    of a single class index. The targets are held as one ``float32``
+    tensor, ready for ``BCEWithLogitsLoss`` and similar losses.
+
+    Parameters
+    ----------
+    X : List[Any]
+        The input data.
+    Y : List[Any]
+        One label vector per sample. All vectors are combined with
+        ``numpy.stack`` and stored as a single ``FloatTensor``.
+    transform : Callable[..., Any], optional
+        A function/transform that takes an object and returns a transformed
+        version. Defaults to None.
+    """
+
+    def __init__(
+        self,
+        X: List[Any],
+        Y: List[Any],
+        transform: Optional[Callable[..., Any]] = None,
+    ) -> None:
+        if not (isinstance(X, list) and isinstance(Y, list)):
+            raise ValueError("X and Y should be of type list.")
+        if len(Y) != len(X):
+            raise ValueError("Length of X and Y must be equal.")
+        stacked_targets = np.stack(Y, axis=0)
+        self.X = X
+        self.Y = torch.FloatTensor(stacked_targets)
+        self.transform = transform
diff --git a/ablkit/reasoning/__init__.py b/ablkit/reasoning/__init__.py
index 9d5a219fe..b9e329f77 100644
--- a/ablkit/reasoning/__init__.py
+++ b/ablkit/reasoning/__init__.py
@@ -1,4 +1,11 @@
from .kb import GroundKB, KBBase, PrologKB
-from .reasoner import Reasoner
+from .reasoner import A3BLReasoner, Reasoner, VerificationReasoner
-__all__ = ["KBBase", "GroundKB", "PrologKB", "Reasoner"]
+__all__ = [
+ "KBBase",
+ "GroundKB",
+ "PrologKB",
+ "Reasoner",
+ "A3BLReasoner",
+ "VerificationReasoner",
+]
diff --git a/ablkit/reasoning/kb.py b/ablkit/reasoning/kb.py
index 3202d8dd1..9fcf7b58b 100644
--- a/ablkit/reasoning/kb.py
+++ b/ablkit/reasoning/kb.py
@@ -487,11 +487,14 @@ def __init__(self, pseudo_label_list: List[Any], pl_file: str):
try:
import pyswip # pylint: disable=import-outside-toplevel
except (IndexError, ImportError):
- print(
- "A Prolog-based knowledge base is in use. Please install SWI-Prolog using the"
- + "command 'sudo apt-get install swi-prolog' for Linux users, or download it "
- + "following the guide in https://github.com/yuce/pyswip/blob/master/INSTALL.md "
- + "for Windows and Mac users."
+ print_log(
+ "A Prolog-based knowledge base is in use. Please install SWI-Prolog using "
+ "the command 'sudo apt-get install swi-prolog' for Linux users, or download "
+ "it following the guide in "
+ "https://github.com/yuce/pyswip/blob/master/INSTALL.md "
+ "for Windows and Mac users.",
+ logger="current",
+ level=logging.WARNING,
)
self.prolog = pyswip.Prolog()
diff --git a/ablkit/reasoning/reasoner.py b/ablkit/reasoning/reasoner.py
index 2905007dc..2b5b934be 100644
--- a/ablkit/reasoning/reasoner.py
+++ b/ablkit/reasoning/reasoner.py
@@ -5,15 +5,22 @@
Copyright (c) 2024 LAMDA. All rights reserved.
"""
+import heapq
import inspect
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
import numpy as np
from zoopt import Dimension, Objective, Opt, Parameter, Solution
from ..data.structures import ListData
from ..reasoning import KBBase
-from ..utils.utils import hamming_dist, confidence_dist, avg_confidence_dist
+from ..utils.utils import (
+ avg_confidence_dist,
+ confidence_dist,
+ hamming_dist,
+ rejection_dist,
+ similarity_dist,
+)
class Reasoner:
@@ -30,12 +37,17 @@ class Reasoner:
measure, wherein the candidate with lowest cost is selected as the final
abduced label. It can be either a string representing a predefined distance
function or a callable function. The available predefined distance functions:
- 'hamming' | 'confidence' | 'avg_confidence'. 'hamming' directly calculates the
- Hamming distance between the predicted pseudo-label in the data example and each
- candidate. 'confidence' and 'avg_confidence' calculates the confidence distance
- between the predicted probabilities in the data example and each candidate, where
- the confidence distance is defined as 1 - the product of prediction probabilities
- in 'confidence' and 1 - the average of prediction probabilities in 'avg_confidence'.
+ 'hamming' | 'confidence' | 'avg_confidence' | 'similarity' | 'rejection'.
+ 'hamming' directly calculates the Hamming distance between the predicted
+ pseudo-label in the data example and each candidate. 'confidence' and
+ 'avg_confidence' calculate the confidence distance between the predicted
+ probabilities and each candidate, defined as ``1 - product`` and
+ ``1 - average`` of the candidate's per-symbol probabilities respectively.
+ 'similarity' compares candidates against the geometry of the model's
+ embeddings (requires the base model to expose ``extract_features``;
+ ``ABLModel`` then stores the result on ``data_example.embeddings``).
+ 'rejection' combines confidence distance with a candidate-complexity penalty,
+ favoring shorter candidates when scores are close.
Alternatively, the callable function should have the signature
``dist_func(data_example, candidates, candidate_idxs, reasoning_results)`` and must
return a cost list. Each element in this cost list should be a numerical value
@@ -83,10 +95,11 @@ def __init__(
def _check_valid_dist(self, dist_func):
if isinstance(dist_func, str):
- if dist_func not in ["hamming", "confidence", "avg_confidence"]:
+ valid = ["hamming", "confidence", "avg_confidence", "similarity", "rejection"]
+ if dist_func not in valid:
raise NotImplementedError(
- 'Valid options for predefined dist_func include "hamming", '
- + f'"confidence" and "avg_confidence", but got {dist_func}.'
+ f"Valid options for predefined dist_func are {valid}, "
+ f"but got {dist_func!r}."
)
return
elif callable(dist_func):
@@ -179,9 +192,22 @@ def _get_cost_list(
elif self.dist_func == "avg_confidence":
candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
return avg_confidence_dist(data_example.pred_prob, candidates_idxs)
+ elif self.dist_func == "similarity":
+ embeddings = getattr(data_example, "embeddings", None)
+ if embeddings is None:
+ raise ValueError(
+ "dist_func='similarity' requires the base model to expose an "
+ "extract_features(X=...) method so ABLModel can populate "
+ "data_example.embeddings."
+ )
+ candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
+ return similarity_dist(embeddings, candidates_idxs=candidates_idxs)
+ elif self.dist_func == "rejection":
+ candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
+ return rejection_dist(data_example.pred_prob, candidates_idxs=candidates_idxs)
else:
- candidate_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
- cost_list = self.dist_func(data_example, candidates, candidate_idxs, reasoning_results)
+ candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates]
+ cost_list = self.dist_func(data_example, candidates, candidates_idxs, reasoning_results)
if len(cost_list) != len(candidates):
raise ValueError(
"The length of the array returned by dist_func must be equal to the number "
@@ -354,5 +380,444 @@ def batch_abduce(self, data_examples: ListData) -> List[List[Any]]:
data_examples.abduced_pseudo_label = abduced_pseudo_label
return abduced_pseudo_label
+    def batch_supervised_abduce(self, data_examples: ListData) -> List[List[Any]]:
+        """
+        Abduce a batch, reusing supervision where it exists: an example with a
+        truthy ``gt_pseudo_label`` keeps it, any other example goes through
+        ``abduce``. Stores and returns the resulting list of pseudo-labels.
+        """
+        abduced_pseudo_label = []
+        for data_example in data_examples:
+            pseudo_label = data_example.gt_pseudo_label
+            if not pseudo_label:
+                pseudo_label = self.abduce(data_example)
+            abduced_pseudo_label.append(pseudo_label)
+        data_examples.abduced_pseudo_label = abduced_pseudo_label
+        return abduced_pseudo_label
+
+    def __call__(self, data_examples: ListData) -> List[List[Any]]:
+        """Calling the reasoner is shorthand for ``batch_abduce``."""
+        return self.batch_abduce(data_examples)
+
+
+# =============================================================================
+# A3BL: Ambiguity-Aware Abductive Learning
+#
+# Reference: https://github.com/Hao-Yuan-He/A3BL
+# =============================================================================
+
+
+class A3BLReasoner(Reasoner):
+    """
+    Reasoner for A3BL (Ambiguity-Aware Abductive Learning): it aggregates the top-k
+    candidates that satisfy the knowledge base into a soft label.
+
+    Parameters
+    ----------
+    kb : class KBBase
+        The knowledge base to be used for reasoning.
+    dist_func : Union[str, Callable], optional
+        Forwarded to :class:`Reasoner`, which describes the available options.
+        ``abduce`` as overridden here ranks candidates with its own softmax
+        weighting rather than with this function. Defaults to 'confidence'.
+    idx_to_label : dict, optional
+        A mapping from index in the base model to label; order-based by default.
+    max_revision : Union[int, float], optional
+        The upper limit on the number of revisions for each data example. If float,
+        denotes the fraction of the total length that can be revised. A value of -1
+        implies no restriction on the number of revisions. Defaults to -1.
+    require_more_revision : int, optional
+        Extra revisions permitted beyond the minimum required. Defaults to 0.
+    use_zoopt : bool, optional
+        Whether to use ZOOpt library during abductive reasoning. Defaults to False.
+    topK : int, optional
+        Number of top-ranked candidates to keep when forming the soft label. ``-1``
+        keeps all candidates. Defaults to 16.
+    temperature : float, optional
+        Softmax temperature used when aggregating candidate probabilities into a
+        soft label. Lower values produce sharper distributions. Defaults to 0.2.
+    multi_label : bool, optional
+        Whether the underlying task is multi-label (each symbol is a binary vector
+        rather than a single class index). Defaults to False.
+    """
+
+    def __init__(
+        self,
+        kb,
+        dist_func="confidence",
+        idx_to_label=None,
+        max_revision: Union[int, float] = -1,
+        require_more_revision: int = 0,
+        use_zoopt: bool = False,
+        topK: int = 16,
+        temperature: float = 0.2,
+        multi_label: bool = False,
+    ):
+        super().__init__(
+            kb, dist_func, idx_to_label, max_revision, require_more_revision, use_zoopt
+        )
+        import torch
+
+        self.topK = topK
+        self.temperature = temperature
+        self.class_num = len(self.kb.pseudo_label_list)
+        self.multi_label = multi_label
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    def _confidence_dist(
+        self, pred_probs: np.ndarray, candidate_idxs: List[List[Any]], temp: float = 1.0
+    ) -> np.ndarray:
+        """
+        Weight the candidates: softmax over the sum of the per-symbol probabilities
+        each candidate selects, divided by ``temp``. Larger means more plausible.
+        """
+        from scipy.special import softmax
+
+        candidates_array = np.array(candidate_idxs)
+        _, symbol_num = candidates_array.shape
+        row_indices = np.arange(symbol_num)[:, np.newaxis]
+        selected_probs = pred_probs[row_indices, candidates_array.T]
+        return softmax(np.sum(selected_probs, axis=0) / temp)
+
+    def _confidence_dist_multi_label(
+        self, pred_probs: np.ndarray, candidate_idxs: List[List[Any]], temp: float = 1.0
+    ) -> np.ndarray:
+        """
+        Multi-label counterpart of ``_confidence_dist``: softmax over the dot
+        product of ``pred_probs`` with each candidate vector, divided by ``temp``.
+        """
+        from scipy.special import softmax
+
+        candidate_probs = pred_probs @ np.array(candidate_idxs).T / temp
+        return softmax(candidate_probs.squeeze(axis=0))
+
+    def _candidates_idxs(self, candidates: List[List[Any]]):
+        """Map every pseudo-label in ``candidates`` to its base-model index."""
+        return [[self.label_to_idx[x] for x in c] for c in candidates]
+
+    def _topk(
+        self, candidates: List[Any], candidate_probs: np.ndarray, K: int = -1
+    ) -> Tuple[List[List[Any]], List[Any]]:
+        """
+        Select the ``K`` most probable candidates, most probable first.
+
+        A negative ``K`` (conventionally -1), or one not smaller than the number
+        of candidates, keeps every candidate; the result is still sorted so that
+        index 0 always holds the top-1 candidate. Ties keep their input order.
+
+        Returns the selected candidates and their probabilities as two lists.
+        """
+        probs = np.asarray(candidate_probs)
+        # Rank indices rather than (prob, candidate) pairs, so that ties never
+        # fall back to comparing the candidates themselves.
+        order = np.argsort(-probs, kind="stable")
+        if 0 <= K < len(order):
+            order = order[:K]
+        return [candidates[i] for i in order], probs[order].tolist()
+
+    def multi_label_aggregate(self, candidates: List[List[int]], candidate_probs: List[float]):
+        """
+        Multi-label aggregation for A3BL: for every position, sum the
+        probabilities of the candidates whose entry at that position equals 1.
+        Returns a one-element list holding a tensor of length ``len(candidates[0])``.
+        """
+        import torch
+
+        with torch.no_grad():
+            symbol_num = len(candidates[0])
+            aggregate_label = torch.zeros(size=(symbol_num, 1))
+            for candidate, prob in zip(candidates, candidate_probs):
+                for i, item in enumerate(candidate):
+                    if item == 1:
+                        aggregate_label[i] += prob
+        return list(aggregate_label.unbind(1))
+
+    def aggregate(self, candidates: List[List[int]], candidate_probs: List[float]):
+        """
+        Build the soft label as the probability-weighted sum of one-hot encoded
+        ``candidates`` (class indices). Returns one CPU tensor per symbol, of size ``class_num``.
+        """
+        import torch
+        import torch.nn.functional as F
+
+        with torch.no_grad():
+            candidates_tensor = torch.tensor(candidates, device=self.device, dtype=torch.long)
+            probs_tensor = torch.tensor(candidate_probs, device=self.device, dtype=torch.float32)
+            one_hot = F.one_hot(candidates_tensor, num_classes=self.class_num).float()  # [N, M, C]
+            weighted_one_hot = one_hot * probs_tensor.unsqueeze(-1).unsqueeze(-1)  # [N, M, C]
+            aggregate_label = weighted_one_hot.sum(dim=0)  # [M, C]
+        return [tensor.cpu() for tensor in aggregate_label.unbind(0)]
+
+    def abduce(self, data_example: ListData) -> Tuple[List[Any], List[Any]]:
+        """
+        Perform abduction and get a soft label distribution aggregated from
+        the top-k valid candidates that satisfy the underlying rules.
+
+        Parameters
+        ----------
+        data_example : ListData
+            Data example.
+
+        Returns
+        -------
+        soft_label : List[Any]
+            Soft label aggregated from the top-k valid candidates.
+        pseudo_label : List[Any]
+            Hard pseudo-label revision (the top-1 candidate) that is
+            consistent with the knowledge base. Both values are empty lists
+            when the knowledge base yields no candidate.
+        """
+        max_revision_num = data_example.elements_num("pred_pseudo_label")
+        max_revision_num = self._get_max_revision_num(self.max_revision, max_revision_num)
+        candidates, _ = self.kb.abduce_candidates(
+            pseudo_label=data_example.pred_pseudo_label,
+            y=data_example.Y,
+            x=data_example.X,
+            max_revision_num=max_revision_num,
+            require_more_revision=self.require_more_revision,
+        )
+        if len(candidates) == 0:
+            return [], []
+
+        weigh = self._confidence_dist_multi_label if self.multi_label else self._confidence_dist
+        candidate_probs = weigh(
+            data_example.pred_prob, self._candidates_idxs(candidates), self.temperature
+        )
+        topk_candidates, topk_probs = self._topk(candidates, candidate_probs, self.topK)
+        if self.multi_label:
+            soft_label = self.multi_label_aggregate(topk_candidates, topk_probs)
+        else:
+            # One-hot encoding needs base-model indices, not raw pseudo-labels.
+            soft_label = self.aggregate(self._candidates_idxs(topk_candidates), topk_probs)
+        return soft_label, topk_candidates[0]  # _topk sorts best-first
+
+    def batch_abduce(self, data_examples: ListData) -> List[List[Any]]:
+        """
+        Perform abductive reasoning on the given prediction data examples.
+        For detailed information, refer to ``abduce``. Stores the soft labels on
+        ``abduced_soft_label`` (returned) and the hard ones on ``abduced_pseudo_label``.
+        """
+        abduced_soft_label, abduced_pseudo_label = [], []
+        # Plain accumulation (rather than zip(*...)) also copes with an empty batch.
+        for data_example in data_examples:
+            soft_label, pseudo_label = self.abduce(data_example)
+            abduced_soft_label.append(soft_label)
+            abduced_pseudo_label.append(pseudo_label)
+        data_examples.abduced_soft_label = abduced_soft_label
+        data_examples.abduced_pseudo_label = abduced_pseudo_label
+        return abduced_soft_label
+
def __call__(self, data_examples: ListData) -> List[List[Any]]:
return self.batch_abduce(data_examples)
+
+
+# =============================================================================
+# Verification Learning
+#
+# Walks the per-symbol probability lattice in descending joint-probability
+# order, collecting the first top_k assignments that satisfy the knowledge
+# base. Reference: https://github.com/VerificationLearning/VerificationLearning
+# =============================================================================
+
+
+def enumerate_label_assignments(
+    pred_prob: np.ndarray, max_iter: int = 10000
+) -> Iterator[Tuple[List[int], float, List[float]]]:
+    """
+    Generate label-index assignments for one data example, most probable
+    first (by joint probability). This is a Lawler-style best-first search
+    over tuples of per-symbol ranks: rank 0 is a symbol's most probable
+    class, and a state's successors each move exactly one symbol to its
+    next-best class.
+
+    Parameters
+    ----------
+    pred_prob : np.ndarray
+        Per-symbol probability matrix with shape ``(num_symbols, num_classes)``.
+    max_iter : int, optional
+        Upper bound on the number of assignments yielded. Defaults to 10000.
+
+    Yields
+    ------
+    labels : List[int]
+        Class indices for each symbol.
+    joint_prob : float
+        Product of the chosen per-symbol probabilities.
+    per_symbol_probs : List[float]
+        The chosen probability for each symbol.
+    """
+    prob_matrix = np.asarray(pred_prob, dtype=float)
+    num_symbols, num_classes = prob_matrix.shape
+    if num_symbols == 0:
+        return
+
+    # Column r of these arrays holds every symbol's rank-r class and its probability.
+    ranked_classes = np.argsort(-prob_matrix, axis=1)
+    ranked_probs = np.take_along_axis(prob_matrix, ranked_classes, axis=1)
+    symbols = range(num_symbols)
+
+    start = (0,) * num_symbols
+    visited = {start}
+    # heapq is a min-heap, so joint probabilities are stored negated.
+    frontier: List[Tuple[float, Tuple[int, ...]]] = [
+        (-float(np.prod(ranked_probs[:, 0])), start)
+    ]
+    emitted = 0
+    while frontier and emitted < max_iter:
+        neg_prob, state = heapq.heappop(frontier)
+        joint_prob = -neg_prob
+        labels = [int(ranked_classes[sym, state[sym]]) for sym in symbols]
+        per_symbol_probs = [float(ranked_probs[sym, state[sym]]) for sym in symbols]
+        yield labels, joint_prob, per_symbol_probs
+        emitted += 1
+
+        for sym in symbols:
+            rank = state[sym]
+            if rank + 1 >= num_classes:
+                continue
+            successor = state[:sym] + (rank + 1,) + state[sym + 1 :]
+            if successor in visited:
+                continue
+            visited.add(successor)
+            base_p = ranked_probs[sym, rank]
+            # Swap this symbol's factor in the product; a zero factor pins it at 0.
+            scaled = 0.0 if base_p <= 0 else joint_prob * (ranked_probs[sym, rank + 1] / base_p)
+            heapq.heappush(frontier, (-scaled, successor))
+
+
+def top_k_satisfying(
+    pred_prob: np.ndarray,
+    predicate: Callable[[List[Any]], bool],
+    top_k: int = 1,
+    max_iter: int = 10000,
+    idx_to_label: Optional[dict] = None,
+) -> Tuple[List[List[Any]], List[float]]:
+    """
+    Collect the first ``top_k`` label assignments accepted by ``predicate``
+    while visiting assignments from most to least probable. When the
+    enumeration (capped at ``max_iter`` steps) produces no accepted
+    assignment, the most probable one is returned instead, so that callers
+    always get a usable label.
+
+    Parameters
+    ----------
+    pred_prob : np.ndarray
+        Per-symbol probability matrix with shape ``(num_symbols, num_classes)``.
+    predicate : Callable[[List[Any]], bool]
+        Called on each candidate label sequence; a truthy result means the
+        candidate is consistent with the knowledge base.
+    top_k : int, optional
+        Maximum number of satisfying candidates to return. Defaults to 1.
+    max_iter : int, optional
+        Hard cap on enumeration steps. Defaults to 10000.
+    idx_to_label : dict, optional
+        Mapping from class index to pseudo-label, applied before calling
+        ``predicate``. When omitted, raw class indices are used.
+
+    Returns
+    -------
+    candidates : List[List[Any]]
+        Label assignments that satisfy ``predicate`` (or the fallback).
+    probs : List[float]
+        Joint probability of each returned candidate.
+    """
+    accepted: List[List[Any]] = []
+    accepted_probs: List[float] = []
+    first_seen: Optional[Tuple[List[Any], float]] = None
+
+    for class_idxs, joint_prob, _ in enumerate_label_assignments(pred_prob, max_iter):
+        labels = class_idxs if idx_to_label is None else [idx_to_label[i] for i in class_idxs]
+        if first_seen is None:
+            # The enumeration is descending, so the first item is the most probable.
+            first_seen = (labels, joint_prob)
+        if not predicate(labels):
+            continue
+        accepted.append(labels)
+        accepted_probs.append(joint_prob)
+        if len(accepted) >= top_k:
+            break
+
+    if accepted or first_seen is None:
+        return accepted, accepted_probs
+    return [first_seen[0]], [first_seen[1]]
+
+
+class VerificationReasoner:
+    """
+    Reasoner used by :class:`~ablkit.bridge.VerificationBridge`. It does not
+    rank abduced candidates with a distance function; instead it walks label
+    assignments in descending joint probability and keeps the first ``top_k``
+    that the knowledge base accepts. The bridge then trains the model on
+    each of those candidates.
+
+    Parameters
+    ----------
+    kb : KBBase
+        The knowledge base used to verify candidates. The value returned by
+        ``kb.logic_forward`` is compared for equality with each data
+        example's ``Y``.
+    top_k : int, optional
+        Number of satisfying candidates to enumerate per example; must be
+        at least 1. Defaults to 1.
+    max_iter : int, optional
+        Maximum number of enumeration steps per example before giving up
+        and returning the fallback; must be at least 1. Defaults to 10000.
+    idx_to_label : dict, optional
+        A mapping from base-model index to pseudo-label. If omitted a
+        default order-based mapping is built from ``kb.pseudo_label_list``.
+    """
+
+    def __init__(
+        self,
+        kb: KBBase,
+        top_k: int = 1,
+        max_iter: int = 10000,
+        idx_to_label: Optional[dict] = None,
+    ) -> None:
+        if top_k < 1:
+            raise ValueError("top_k must be >= 1.")
+        if max_iter < 1:
+            raise ValueError("max_iter must be >= 1.")
+        self.kb = kb
+        self.top_k = top_k
+        self.max_iter = max_iter
+        self.idx_to_label = (
+            dict(enumerate(kb.pseudo_label_list)) if idx_to_label is None else idx_to_label
+        )
+        self.label_to_idx = {label: idx for idx, label in self.idx_to_label.items()}
+
+    def top_k_candidates(
+        self, pred_prob: np.ndarray, y: Any
+    ) -> Tuple[List[List[Any]], List[float]]:
+        """
+        Find, for one data example, up to ``top_k`` label assignments on
+        which ``kb.logic_forward`` equals ``y``. See ``top_k_satisfying``
+        for the enumeration order and the fallback.
+        """
+
+        def matches_target(labels: List[Any]) -> bool:
+            return self.kb.logic_forward(labels) == y
+
+        return top_k_satisfying(
+            pred_prob,
+            matches_target,
+            top_k=self.top_k,
+            max_iter=self.max_iter,
+            idx_to_label=self.idx_to_label,
+        )
+
+    def batch_top_k(self, data_examples) -> List[List[List[Any]]]:
+        """
+        Apply :meth:`top_k_candidates` to each example of ``data_examples``,
+        using its ``pred_prob`` and ``Y``. The per-example candidate lists
+        and probability lists are saved on ``data_examples.top_k_candidates``
+        and ``data_examples.top_k_probs``; the candidate lists are returned.
+        """
+        results = [
+            self.top_k_candidates(data_example.pred_prob, data_example.Y)
+            for data_example in data_examples
+        ]
+        all_candidates: List[List[List[Any]]] = [cands for cands, _ in results]
+        all_probs: List[List[float]] = [probs for _, probs in results]
+        data_examples.top_k_candidates = all_candidates
+        data_examples.top_k_probs = all_probs
+        return all_candidates
diff --git a/ablkit/utils/__init__.py b/ablkit/utils/__init__.py
index 7962745cf..c1214d290 100644
--- a/ablkit/utils/__init__.py
+++ b/ablkit/utils/__init__.py
@@ -1,13 +1,15 @@
from .cache import Cache, abl_cache
from .logger import ABLLogger, print_log
from .utils import (
- confidence_dist,
avg_confidence_dist,
+ confidence_dist,
flatten,
hamming_dist,
reform_list,
- to_hashable,
+ rejection_dist,
+ similarity_dist,
tab_data_to_tuple,
+ to_hashable,
)
__all__ = [
@@ -19,6 +21,8 @@
"flatten",
"hamming_dist",
"reform_list",
+ "rejection_dist",
+ "similarity_dist",
"to_hashable",
"abl_cache",
"tab_data_to_tuple",
diff --git a/ablkit/utils/utils.py b/ablkit/utils/utils.py
index abc1e3561..02c60a4ab 100644
--- a/ablkit/utils/utils.py
+++ b/ablkit/utils/utils.py
@@ -138,6 +138,79 @@ def avg_confidence_dist(pred_prob: np.ndarray, candidates_idxs: List[List[Any]])
return 1 - np.average(pred_prob[cols, candidates_idxs], axis=1)
+def similarity_dist(
+    pred_embeddings: np.ndarray, candidates_idxs: List[List[Any]]
+) -> np.ndarray:
+    """
+    Score candidate label assignments against pairwise embedding similarity.
+
+    Cosine similarities are taken over every unordered pair of distinct
+    symbols. A candidate's cost is the mean similarity of the pairs it
+    labels differently minus the mean similarity of the pairs it labels
+    identically; a group with no pairs contributes ``0.0``.
+
+    Parameters
+    ----------
+    pred_embeddings : np.ndarray
+        One row per symbol of a single data example, i.e. shape
+        ``(num_symbols, embedding_dim)``.
+    candidates_idxs : List[List[Any]]
+        Candidate label assignments, one label per symbol.
+
+    Returns
+    -------
+    np.ndarray
+        One cost per candidate, in input order; lower means more consistent.
+    """
+    embeddings = np.asarray(pred_embeddings, dtype=float)
+    # The small epsilon keeps all-zero rows from dividing by zero.
+    unit_rows = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-8)
+    cosine = unit_rows @ unit_rows.T
+    rows, cols = np.triu_indices(cosine.shape[0], k=1)  # each i < j pair once
+    pair_sims = cosine[rows, cols]
+
+    def cost_of(cand: List[Any]) -> float:
+        labels = np.asarray(cand)
+        same = labels[rows] == labels[cols]
+        diff = ~same
+        intra = pair_sims[same].mean() if same.any() else 0.0
+        inter = pair_sims[diff].mean() if diff.any() else 0.0
+        return inter - intra
+
+    return np.asarray([cost_of(cand) for cand in candidates_idxs])
+
+
+def rejection_dist(
+    pred_prob: np.ndarray, candidates_idxs: List[List[Any]], alpha: float = 0.5
+) -> np.ndarray:
+    """
+    Blend the confidence distance with a candidate-length penalty: the cost
+    is ``(1 - alpha) * confidence_dist(...) + alpha * length``, where
+    ``length`` is each candidate's ``len`` divided by the largest ``len``
+    among the supplied candidates (left as-is when that maximum is 0).
+
+    Parameters
+    ----------
+    pred_prob : np.ndarray
+        Prediction probability distributions for the symbols in a single
+        data example; forwarded to ``confidence_dist``.
+    candidates_idxs : List[List[Any]]
+        Candidate label assignments.
+    alpha : float, optional
+        Weight of the length term (expected in ``[0, 1]``). Defaults to 0.5.
+
+    Returns
+    -------
+    np.ndarray
+        Cost for each candidate.
+    """
+    confidence_cost = confidence_dist(pred_prob, candidates_idxs)
+    lengths = np.array([len(cand) for cand in candidates_idxs], dtype=float)
+    longest = lengths.max()
+    length_cost = lengths / longest if longest > 0 else lengths  # scale to [0, 1]
+    return (1 - alpha) * confidence_cost + alpha * length_cost
+
+
def to_hashable(x: Union[List[Any], Any]) -> Union[Tuple[Any, ...], Any]:
"""
Convert a nested list to a nested tuple so it is hashable.
diff --git a/docs/API/ablkit.bridge.rst b/docs/API/ablkit.bridge.rst
index 3cfc12702..7140bc24d 100644
--- a/docs/API/ablkit.bridge.rst
+++ b/docs/API/ablkit.bridge.rst
@@ -1,7 +1,7 @@
ablkit.bridge
-==================
+=============
.. automodule:: ablkit.bridge
:members:
:undoc-members:
- :show-inheritance:
\ No newline at end of file
+ :show-inheritance:
diff --git a/docs/API/ablkit.data.rst b/docs/API/ablkit.data.rst
index 32bc5d26d..d935b04f3 100644
--- a/docs/API/ablkit.data.rst
+++ b/docs/API/ablkit.data.rst
@@ -1,5 +1,5 @@
ablkit.data
-===================
+===========
``structures``
--------------
diff --git a/docs/API/ablkit.learning.rst b/docs/API/ablkit.learning.rst
index f26a866e2..0b0c2683f 100644
--- a/docs/API/ablkit.learning.rst
+++ b/docs/API/ablkit.learning.rst
@@ -1,5 +1,5 @@
ablkit.learning
-==================
+===============
.. autoclass:: ablkit.learning.ABLModel
:members:
diff --git a/docs/API/ablkit.reasoning.rst b/docs/API/ablkit.reasoning.rst
index 8707fc308..9de28556c 100644
--- a/docs/API/ablkit.reasoning.rst
+++ b/docs/API/ablkit.reasoning.rst
@@ -1,7 +1,36 @@
ablkit.reasoning
-==================
+================
-.. automodule:: ablkit.reasoning
+.. autoclass:: ablkit.reasoning.KBBase
:members:
:undoc-members:
- :show-inheritance:
\ No newline at end of file
+ :show-inheritance:
+
+.. autoclass:: ablkit.reasoning.GroundKB
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: ablkit.reasoning.PrologKB
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: ablkit.reasoning.Reasoner
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: ablkit.reasoning.A3BLReasoner
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: ablkit.reasoning.VerificationReasoner
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autofunction:: ablkit.reasoning.reasoner.enumerate_label_assignments
+
+.. autofunction:: ablkit.reasoning.reasoner.top_k_satisfying
diff --git a/docs/API/ablkit.utils.rst b/docs/API/ablkit.utils.rst
index b963a8963..604024ce0 100644
--- a/docs/API/ablkit.utils.rst
+++ b/docs/API/ablkit.utils.rst
@@ -1,7 +1,20 @@
ablkit.utils
-==================
+============
.. automodule:: ablkit.utils
:members:
+ ABLLogger,
+ print_log,
+ Cache,
+ abl_cache,
+ hamming_dist,
+ confidence_dist,
+ avg_confidence_dist,
+ similarity_dist,
+ rejection_dist,
+ flatten,
+ reform_list,
+ to_hashable,
+ tab_data_to_tuple
:undoc-members:
- :show-inheritance:
\ No newline at end of file
+ :show-inheritance:
diff --git a/docs/Examples/BDD-OIA.rst b/docs/Examples/BDD-OIA.rst
new file mode 100644
index 000000000..df75278ae
--- /dev/null
+++ b/docs/Examples/BDD-OIA.rst
@@ -0,0 +1,216 @@
+BDD-OIA
+=======
+
+.. raw:: html
+
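+ <p>For detailed code implementation, please view it on <a class="reference external" href="https://github.com/AbductiveLearning/ABLkit/tree/main/examples/bdd_oia" target="_blank">GitHub</a>.</p>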
+
+Below shows an implementation of `BDD-OIA `__.
+The BDD-OIA dataset comprises frames extracted from driving scene videos
+that are used for autonomous driving predictions. Each frame is
+annotated with 4 binary action labels (:math:`\textsf{move_forward}`,
+:math:`\textsf{stop}`, :math:`\textsf{turn_left}`, :math:`\textsf{turn_right}`),
+as well as 21 intermediate binary concept labels such as
+:math:`\textsf{red_light}` and :math:`\textsf{road_clear}` that explain those
+actions.
+
+The objective is to predict the possible actions for each frame.
+During training we use only the action-level supervision together with
+a knowledge base that captures the relations between concepts and
+actions, e.g.,
+:math:`\textsf{red_light} \lor \textsf{traffic_sign} \lor \textsf{obstacle} \implies \textsf{stop}`.
+The training set contains 16,000 frames; the test set contains 4,500.
+
+Intuitively, the learning part predicts the 21 binary concept
+pseudo-labels from each frame, and the reasoning part uses the
+knowledge base to derive the four action labels from those concepts.
+When the learning part's predictions conflict with the ground-truth
+actions, the reasoner revises the concepts via abductive reasoning,
+and those revised concepts are used to further train the learning
+part.
+
+The dataset was preprocessed by `Marconato et al. (2023) `__
+with a pretrained Faster-RCNN on BDD-100k together with the first
+module of CBM-AUC `(Sawada & Nakamura, 2022) `__,
+yielding a 2048-dimensional visual feature for each frame.
+
+.. code:: python
+
+ # Import necessary libraries and modules
+ import os.path as osp
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch import optim
+
+ from ablkit.data.evaluation import SymbolAccuracy
+ from ablkit.learning import MultiLabelABLModel, MultiLabelBasicNN
+ from ablkit.reasoning import KBBase, Reasoner
+ from ablkit.utils import ABLLogger, print_log
+
+ from bridge import BDDBridge
+ from dataset.data_util import get_dataset
+ from metric import BDDReasoningMetric
+ from models.nn import ConceptNet
+
+Working with Data
+-----------------
+
+First, we load the training, validation, and testing splits:
+
+.. code:: python
+
+ train_data = get_dataset(fname="train.npz", get_pseudo_label=True)
+ val_data = get_dataset(fname="val.npz", get_pseudo_label=True)
+ test_data = get_dataset(fname="test.npz", get_pseudo_label=True)
+
+Each split consists of three components (``X``, ``gt_pseudo_label``,
+and ``Y``) with one entry per frame:
+
+- ``X[i]`` is a list with a single ndarray of shape ``(2048,)``, the
+ pre-extracted visual feature for the frame.
+- ``gt_pseudo_label[i]`` is a list of length 21 holding the binary
+ concept annotations (``red_light``, ``road_clear``, …).
+- ``Y[i]`` is a tuple of length 4 holding the binary action labels
+ (``move_forward``, ``stop``, ``turn_left``, ``turn_right``).
+
+During training only ``X`` and ``Y`` are used; ``gt_pseudo_label`` is
+held back for evaluation.
+
+Building the Learning Part
+--------------------------
+
+To build the learning part we first construct a PyTorch model,
+``ConceptNet``, then wrap it in
+:class:`~ablkit.learning.MultiLabelBasicNN` to obtain an sklearn-style
+base model. ``MultiLabelBasicNN`` is a multi-label variant of
+``BasicNN``: the output uses sigmoid activations rather than softmax,
+predictions are binary vectors rather than single class indices, and
+the dataset is a
+:class:`~ablkit.learning.MultiLabelClassificationDataset`. The 21
+outputs therefore correspond to the 21 binary concept labels.
+
+.. code:: python
+
+ net = ConceptNet()
+ loss_fn = nn.BCEWithLogitsLoss()
+ optimizer = optim.Adam(net.parameters(), lr=0.002)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ scheduler = optim.lr_scheduler.OneCycleLR(
+ optimizer,
+ max_lr=0.002,
+ pct_start=0.15,
+ epochs=2,
+ steps_per_epoch=int(1 / 0.01) + 1,
+ )
+
+ base_model = MultiLabelBasicNN(
+ net,
+ loss_fn,
+ optimizer,
+ scheduler=scheduler,
+ device=device,
+ batch_size=32,
+ num_epochs=1,
+ )
+
+``MultiLabelBasicNN`` operates on a single frame at a time. To work at
+the example level (a frame together with its label set), we wrap the
+base model in :class:`~ablkit.learning.MultiLabelABLModel`, an
+``ABLModel`` subclass that threshold-binarises the sigmoid
+probabilities into per-concept 0/1 pseudo-labels.
+
+.. code:: python
+
+ model = MultiLabelABLModel(base_model)
+
+Building the Reasoning Part
+---------------------------
+
+The knowledge base ``BDDKB`` encodes the rules linking the 21
+concepts to the 4 actions (e.g., ``red_light`` or ``obstacle`` imply
+``stop``; ``green_light`` together with ``road_clear`` implies
+``move_forward``). It subclasses ``KBBase``; the ``pseudo_label_list``
+parameter is ``[0, 1]`` because each pseudo-label is binary, and the
+``logic_forward`` method computes the 4-tuple of action labels from
+the 21 concept attributes.
+
+.. code:: python
+
+ from reasoning.bddkb import BDDKB
+
+ kb = BDDKB()
+
+Since abductive reasoning is non-deterministic, multiple concept
+revisions can be consistent with the ground-truth actions. The
+``Reasoner`` picks the revision that minimises a user-supplied
+distance function. For BDD-OIA we provide
+``multi_label_confidence_dist``, which sums ``-log(p)`` over the
+concept-by-concept probabilities so that revisions consistent with the
+learning part's per-concept confidence are preferred:
+
+.. code:: python
+
+ def multi_label_confidence_dist(data_example, candidates, candidates_idxs, reasoning_results):
+ pred_prob = data_example.pred_prob.T # nc x 1
+ pred_prob = np.concatenate([1 - pred_prob, pred_prob], axis=1) # nc x 2
+ cols = np.arange(len(candidates_idxs[0]))[None, :]
+ corr_prob = pred_prob[cols, candidates_idxs]
+ costs = -np.sum(np.log(corr_prob + 1e-6), axis=1)
+ return costs
+
+ reasoner = Reasoner(
+ kb,
+ dist_func=multi_label_confidence_dist,
+ max_revision=3,
+ require_more_revision=3,
+ )
+
+``max_revision`` caps how many concept flips the reasoner may apply per
+example; ``require_more_revision`` lets the search continue that many
+flips beyond the smallest revision found to be consistent.
+
+Building Evaluation Metrics
+---------------------------
+
+We track two metrics. ``SymbolAccuracy`` measures how often the
+predicted concepts match the ground-truth concepts, and
+``BDDReasoningMetric`` measures the per-action accuracy after
+passing the predicted concepts through ``logic_forward``.
+
+.. code:: python
+
+ metric_list = [
+ SymbolAccuracy(prefix="bdd_oia"),
+ BDDReasoningMetric(kb=kb, prefix="bdd_oia"),
+ ]
+
+Bridging Learning and Reasoning
+-------------------------------
+
+Finally we bridge the learning and reasoning parts via ``BDDBridge``,
+a thin subclass of ``SimpleBridge`` that handles the
+multi-label-specific shape of ``pred_idx`` (a ``[1, nc]`` ndarray per
+example).
+
+.. code:: python
+
+ bridge = BDDBridge(model, reasoner, metric_list)
+
+Training and testing reuse the standard ``SimpleBridge`` interface:
+
+.. code:: python
+
+ print_log("Abductive Learning on the BDD_OIA example.", logger="current")
+ log_dir = ABLLogger.get_current_instance().log_dir
+ weights_dir = osp.join(log_dir, "weights")
+
+ bridge.train(
+ train_data,
+ loops=2,
+ segment_size=0.01,
+ save_interval=1,
+ save_dir=weights_dir,
+ )
+ bridge.test(test_data)
diff --git a/docs/Examples/HED.rst b/docs/Examples/HED.rst
index d8fbed02a..6c4851445 100644
--- a/docs/Examples/HED.rst
+++ b/docs/Examples/HED.rst
@@ -179,13 +179,13 @@ sklearn-style interface.
.. code:: python
# class of symbol may be one of ['0', '1', '+', '='], total of 4 classes
- cls = SymbolNet(num_classes=4)
+ net = SymbolNet(num_classes=4)
loss_fn = nn.CrossEntropyLoss()
- optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-4)
+ optimizer = torch.optim.RMSprop(net.parameters(), lr=0.001, weight_decay=1e-4)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
+
base_model = BasicNN(
- cls,
+ net,
loss_fn,
optimizer,
device=device,
diff --git a/docs/Examples/HWF.rst b/docs/Examples/HWF.rst
index 9ab0c54d4..4b801c1d5 100644
--- a/docs/Examples/HWF.rst
+++ b/docs/Examples/HWF.rst
@@ -8,7 +8,7 @@ Handwritten Formula (HWF)
Below shows an implementation of `Handwritten
Formula `__. In this
task, handwritten images of decimal formulas and their computed results
-are given, alongwith a domain knowledge base containing information on
+are given, along with a domain knowledge base containing information on
how to compute the decimal formula. The task is to recognize the symbols
(which can be digits or operators ‘+’, ‘-’, ‘×’, ‘÷’) of handwritten
images and accurately determine their results.
@@ -100,7 +100,7 @@ Out:
The ith element of X, gt_pseudo_label, and Y together constitute the ith
data example. Here we use two of them (the 1001st and the 3001st) as
-illstrations:
+illustrations:
.. code:: python
@@ -177,14 +177,14 @@ sklearn-style interface.
.. code:: python
- # class of symbol may be one of ['1', ..., '9', '+', '-', '*', '/'], total of 14 classes
- cls = SymbolNet(num_classes=13, image_size=(45, 45, 1))
+ # class of symbol may be one of ['1', ..., '9', '+', '-', '*', '/'], total of 13 classes
+ net = SymbolNet(num_classes=13, image_size=(45, 45, 1))
loss_fn = nn.CrossEntropyLoss()
- optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))
+ optimizer = torch.optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.99))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
+
base_model = BasicNN(
- model=cls,
+ model=net,
loss_fn=loss_fn,
optimizer=optimizer,
device=device,
diff --git a/docs/Examples/MNISTAdd.rst b/docs/Examples/MNISTAdd.rst
index b90bb86e3..850921958 100644
--- a/docs/Examples/MNISTAdd.rst
+++ b/docs/Examples/MNISTAdd.rst
@@ -7,7 +7,7 @@ MNIST Addition
Below shows an implementation of `MNIST
Addition `__. In this task, pairs of
-MNIST handwritten images and their sums are given, alongwith a domain
+MNIST handwritten images and their sums are given, along with a domain
knowledge base containing information on how to perform addition
operations. The task is to recognize the digits of handwritten images
and accurately determine their sum.
@@ -147,14 +147,14 @@ model with a sklearn-style interface.
.. code:: python
- cls = LeNet5(num_classes=10)
+ net = LeNet5(num_classes=10)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
- optimizer = RMSprop(cls.parameters(), lr=0.001, alpha=0.9)
+ optimizer = RMSprop(net.parameters(), lr=0.001, alpha=0.9)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, pct_start=0.1, total_steps=100)
base_model = BasicNN(
- cls,
+ net,
loss_fn,
optimizer,
scheduler=scheduler,
@@ -298,8 +298,10 @@ candidate that has the highest consistency.
customized within the ``dist_func`` parameter. In the code above, we
employ a consistency measurement based on confidence, which calculates
the consistency between the data example and candidates based on the
- confidence derived from the predicted probability. In ``examples/mnist_add/main.py``, we
- provide options for utilizing other forms of consistency measurement.
+ confidence derived from the predicted probability. In
+ ``examples/mnist_add/main.py``, the ``--dist-func`` flag lets you swap in
+ other predefined options (``hamming``, ``avg_confidence``, ``similarity``,
+ ``rejection``) without editing the code.
Also, during the process of inconsistency minimization, we can leverage
`ZOOpt library `__ for acceleration.
@@ -368,7 +370,57 @@ Log:
abl - INFO - Eval start: loop(val) [2]
abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.993 mnist_add/reasoning_accuracy: 0.986
abl - INFO - Test start:
- abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.991 mnist_add/reasoning_accuracy: 0.980
+ abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.991 mnist_add/reasoning_accuracy: 0.980
+
+
+Command-line options
+--------------------
+
+In addition to the standard pipeline above, ``examples/mnist_add/main.py`` accepts
+several flags that switch in alternative methods. The defaults reproduce the
+standard pipeline, so existing usage is unaffected.
+
+- ``--method {standard,a3bl,verification}``: choose the learning/reasoning
+ pipeline.
+ ``standard`` uses ``Reasoner`` / ``SimpleBridge``.
+ ``a3bl`` uses the ambiguity-aware ``A3BLReasoner`` together with
+ ``A3BLBridge``.
+ ``verification`` uses ``VerificationReasoner`` and ``VerificationBridge``:
+ rather than picking one best candidate by ``dist_func``, it enumerates
+ the top ``--top-k`` consistent candidates by joint probability and trains
+ the model once per candidate.
+ All methods share the same ``BasicNN`` / ``ABLModel`` wrappers.
+- ``--dist-func {hamming,confidence,avg_confidence,similarity,rejection}``:
+ passed straight to the reasoner. ``similarity`` requires the wrapped
+ PyTorch model to implement ``extract_features(x)`` (the ``LeNet5`` in
+ this example does); ``BasicNN`` then surfaces those embeddings to the
+ reasoner automatically. Ignored when ``--method verification``.
+- ``--labeled-ratio FLOAT``: fraction in ``(0, 1]`` of training samples that
+ keep their ground-truth pseudo-labels. Values below ``1.0`` enable the
+ semi-supervised pipeline (``use_supervised_data=True`` on
+ ``bridge.train``). Only valid with ``--method standard``.
+- ``--top-k INT``: number of consistent candidates the verification
+ reasoner enumerates per example. Only used with ``--method verification``.
+ Defaults to ``1``.
+
+Examples:
+
+.. code:: bash
+
+ # Standard pipeline (default)
+ python main.py
+
+ # A3BL with similarity-based consistency
+ python main.py --method a3bl --dist-func similarity
+
+ # Semi-supervised: keep 30% of pseudo-labels, abduce the rest
+ python main.py --labeled-ratio 0.3
+
+ # Rejection-aware reasoning
+ python main.py --dist-func rejection
+
+ # Verification Learning: train on the top-3 consistent candidates
+ python main.py --method verification --top-k 3
Environment
diff --git a/docs/Intro/Advanced.rst b/docs/Intro/Advanced.rst
new file mode 100644
index 000000000..b10ab57b3
--- /dev/null
+++ b/docs/Intro/Advanced.rst
@@ -0,0 +1,179 @@
+.. _Advanced:
+
+Advanced Topics
+===============
+
+The standard ABL pipeline (``BasicNN`` + ``ABLModel`` + ``Reasoner`` +
+``SimpleBridge``) covers the majority of tasks. ABLkit also ships a few
+drop-in variants for settings where the standard pipeline does not quite
+fit. This page collects them in one place.
+
+The four topics below are independent: pick whichever ones apply to your
+task.
+
+* :ref:`advanced-multilabel`: when each instance can carry **multiple
+ active labels** (sigmoid + binary indicator vectors) rather than a
+ single class.
+* :ref:`advanced-semisupervised`: when **part of the training set
+ carries ground-truth pseudo-labels** and the rest must be abduced.
+* :ref:`advanced-a3bl`: when many label assignments are consistent with
+ the knowledge base and we want to aggregate them into a **soft
+ label** instead of picking the single best one.
+* :ref:`advanced-verification`: when we want to train against the
+ **top-K consistent label assignments** by joint probability rather
+ than a single best candidate.
+
+
+.. _advanced-multilabel:
+
+Multi-Label Models
+------------------
+
+By default ``BasicNN`` and ``ABLModel`` assume a single-label
+multi-class setting: softmax over classes, ``argmax`` at prediction
+time, one integer label per instance. For tasks where each instance is
+described by a *vector* of independent binary attributes (e.g., the 21
+binary concepts in BDD-OIA), ABLkit provides multi-label drop-in
+replacements:
+
+* :class:`~ablkit.learning.MultiLabelBasicNN`: sigmoid output,
+ threshold at 0.5 for prediction, ``MultiLabelClassificationDataset``
+ for training.
+* :class:`~ablkit.learning.MultiLabelABLModel`: wraps a multi-label
+ base model and thresholds per-label probabilities into binary
+ indicator vectors stored on ``pred_idx``.
+* :class:`~ablkit.learning.MultiLabelClassificationDataset`: stores
+ ``Y`` as a ``FloatTensor`` so it can be fed directly into
+ ``BCEWithLogitsLoss``.
+
+Typical usage swaps the standard classes 1-for-1:
+
+.. code:: python
+
+ import torch.nn as nn
+ from torch import optim
+
+ from ablkit.learning import MultiLabelABLModel, MultiLabelBasicNN
+
+ net = MyMultiLabelNet() # PyTorch model with num_labels outputs
+ loss_fn = nn.BCEWithLogitsLoss()
+ optimizer = optim.Adam(net.parameters(), lr=2e-3)
+
+ base_model = MultiLabelBasicNN(net, loss_fn, optimizer, device="cpu",
+ batch_size=32, num_epochs=1)
+ model = MultiLabelABLModel(base_model)
+
+See the BDD-OIA example for an end-to-end multi-label pipeline.
+
+
+.. _advanced-semisupervised:
+
+Semi-Supervised Training
+------------------------
+
+When part of the training set already carries ground-truth
+pseudo-labels (and the rest is unlabeled), ``SimpleBridge`` can be
+asked to use those labels directly instead of abducing them.
+
+The mechanism is purely a flag on ``SimpleBridge.train``:
+
+* Provide a ``train_data`` tuple ``(X, gt_pseudo_label, Y)`` where the
+ ``gt_pseudo_label`` for unlabeled examples is ``None``.
+* Pass ``use_supervised_data=True``.
+
+Under the hood the bridge calls
+``Reasoner.batch_supervised_abduce``, which keeps existing
+``gt_pseudo_label`` values verbatim and only abduces a candidate for
+the ``None`` entries:
+
+.. code:: python
+
+ bridge.train(
+ train_data=(X, pseudo_label_with_some_None, Y),
+ use_supervised_data=True,
+ loops=50,
+ segment_size=0.01,
+ )
+
+The ``--labeled-ratio`` flag in the MNIST Addition example
+demonstrates how to mask out a fraction of pseudo-labels and feed the
+result through this flow.
+
+
+.. _advanced-a3bl:
+
+A3BL: Ambiguity-Aware Abductive Learning
+----------------------------------------
+
+When many label assignments are consistent with the knowledge base
+for a given example, picking only the lowest-distance candidate
+discards useful signal. A3BL (Ambiguity-Aware Abductive Learning)
+keeps the top candidates, weights them by their joint probability,
+and trains the model on the resulting *soft label distribution*.
+
+ABLkit ships two classes:
+
+* :class:`~ablkit.reasoning.A3BLReasoner`: enumerates valid
+ candidates, scores them via a softmax over per-symbol probabilities,
+ and aggregates the top-K into a soft label.
+* :class:`~ablkit.bridge.A3BLBridge`: runs the ambiguity-aware
+ prediction → soft-label-abduction → train loop.
+
+Minimal wiring:
+
+.. code:: python
+
+ from ablkit.bridge import A3BLBridge
+ from ablkit.reasoning import A3BLReasoner
+
+ reasoner = A3BLReasoner(kb, topK=16, temperature=0.2)
+ bridge = A3BLBridge(model, reasoner, metric_list)
+ bridge.train(train_data, loops=2, segment_size=0.01)
+
+Reference: https://github.com/Hao-Yuan-He/A3BL
+
+
+.. _advanced-verification:
+
+Verification Learning
+---------------------
+
+Verification Learning replaces the standard "abduce the single best
+candidate" step with a top-K enumeration: starting from the most
+probable joint label assignment, the search walks the per-symbol
+probability lattice in **descending joint-probability order** and
+collects the first ``top_k`` candidates that satisfy the knowledge
+base. The model is then trained once per candidate per segment.
+
+ABLkit ships two classes, defined in
+``ablkit/reasoning/reasoner.py`` and
+``ablkit/bridge/verification_bridge.py`` respectively:
+
+* :class:`~ablkit.reasoning.VerificationReasoner`: exposes
+ ``top_k_candidates(pred_prob, y)`` and the batched variant.
+* :class:`~ablkit.bridge.VerificationBridge`: drives the
+ predict → enumerate → train-per-candidate loop.
+
+Helpers usable without a reasoner instance:
+
+* :func:`ablkit.reasoning.reasoner.enumerate_label_assignments`: a
+ generator over label assignments in descending joint-probability
+ order.
+* :func:`ablkit.reasoning.reasoner.top_k_satisfying`: wraps the
+ generator with a user predicate and a fallback when nothing matches.
+
+Minimal wiring:
+
+.. code:: python
+
+ from ablkit.bridge import VerificationBridge
+ from ablkit.reasoning import VerificationReasoner
+
+ reasoner = VerificationReasoner(kb, top_k=3, max_iter=10000)
+ bridge = VerificationBridge(model, reasoner, metric_list)
+ bridge.train(train_data, loops=2, segment_size=0.01)
+
+Reference: https://github.com/VerificationLearning/VerificationLearning
+
+The ``--method verification --top-k K`` flags in the MNIST Addition
+example demonstrate the full pipeline.
diff --git a/docs/Intro/Basics.rst b/docs/Intro/Basics.rst
index 5cb83847b..73af0b69b 100644
--- a/docs/Intro/Basics.rst
+++ b/docs/Intro/Basics.rst
@@ -10,7 +10,7 @@ Learn the Basics
================
Modules in ABLkit
-----------------------
+-----------------
ABLkit is an efficient toolkit for `Abductive Learning <../Overview/Abductive-Learning.html>`_ (ABL),
a paradigm which integrates machine learning and logical reasoning in a balanced-loop.
@@ -50,7 +50,7 @@ learning, and reasoning, facilitating the training and testing
of the entire ABL framework.
Use ABLkit Step by Step
-----------------------------
+-----------------------
In a typical ABL process, as illustrated below,
data inputs are first predicted by the learning model ``ABLModel.predict``, and the outcomes are pseudo-labels.
diff --git a/docs/Intro/Learning.rst b/docs/Intro/Learning.rst
index efb1e9444..d460ea387 100644
--- a/docs/Intro/Learning.rst
+++ b/docs/Intro/Learning.rst
@@ -43,13 +43,13 @@ For a PyTorch-based neural network, we need to encapsulate it within a ``BasicNN
.. code:: python
# Load a PyTorch-based neural network
- cls = torchvision.models.resnet18(pretrained=True)
+ net = torchvision.models.resnet18(pretrained=True)
# loss function and optimizer are used for training
- loss_fn = torch.nn.CrossEntropyLoss()
- optimizer = torch.optim.Adam(cls.parameters())
+ loss_fn = torch.nn.CrossEntropyLoss()
+ optimizer = torch.optim.Adam(net.parameters())
- base_model = BasicNN(cls, loss_fn, optimizer)
+ base_model = BasicNN(net, loss_fn, optimizer)
BasicNN
^^^^^^^
diff --git a/docs/Intro/Quick-Start.rst b/docs/Intro/Quick-Start.rst
index 4563673e5..3eb6252e6 100644
--- a/docs/Intro/Quick-Start.rst
+++ b/docs/Intro/Quick-Start.rst
@@ -41,7 +41,7 @@ In this example, we build a simple LeNet5 network as the base model.
# The 'models' module below is located in 'examples/mnist_add/'
from models.nn import LeNet5
- cls = LeNet5(num_classes=10)
+ net = LeNet5(num_classes=10)
To facilitate uniform processing, ABLkit provides the ``BasicNN`` class to convert a PyTorch-based neural network into a format compatible with scikit-learn models. To construct a ``BasicNN`` instance, aside from the network itself, we also need to define a loss function, an optimizer, and the computing device.
@@ -51,9 +51,9 @@ To facilitate uniform processing, ABLkit provides the ``BasicNN`` class to conve
from ablkit.learning import BasicNN
loss_fn = torch.nn.CrossEntropyLoss()
- optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001)
+ optimizer = torch.optim.RMSprop(net.parameters(), lr=0.001)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- base_model = BasicNN(model=cls, loss_fn=loss_fn, optimizer=optimizer, device=device)
+ base_model = BasicNN(model=net, loss_fn=loss_fn, optimizer=optimizer, device=device)
The base model built above is trained to make predictions on instance-level data (e.g., a single image), while ABL deals with example-level data. To bridge this gap, we wrap the ``base_model`` into an instance of ``ABLModel``. This class serves as a unified wrapper for base models, facilitating the learning part to train, test, and predict on example-level data, (e.g., images that comprise an equation).
@@ -109,7 +109,7 @@ ABLkit provides two basic metrics, namely ``SymbolAccuracy`` and ``ReasoningMetr
Read more about `building evaluation metrics `_
Bridging Learning and Reasoning
----------------------------------------
+-------------------------------
Now, we use ``SimpleBridge`` to combine learning and reasoning in a unified ABL framework.
diff --git a/docs/Intro/Reasoning.rst b/docs/Intro/Reasoning.rst
index 33ce16319..935cf62af 100644
--- a/docs/Intro/Reasoning.rst
+++ b/docs/Intro/Reasoning.rst
@@ -8,7 +8,7 @@
Reasoning part
-===============
+==============
In this section, we will look at how to build the reasoning part, which
leverages domain knowledge and performs deductive or abductive reasoning.
@@ -318,11 +318,20 @@ specify:
used when determining consistency between your prediction and
candidate returned from knowledge base. This can be either a user-defined function
or one that is predefined. Valid predefined options include
- “hamming”, “confidence” and “avg_confidence”. For “hamming”, it directly calculates the Hamming distance between the
- predicted pseudo-label in the data example and candidate. For “confidence”, it
- calculates the confidence distance between the predicted probabilities in the data
- example and each candidate, where the confidence distance is defined as 1 - the product
- of prediction probabilities in “confidence” and 1 - the average of prediction probabilities in “avg_confidence”.
+ “hamming”, “confidence”, “avg_confidence”, “similarity” and “rejection”.
+ For “hamming”, it directly calculates the Hamming distance between the
+ predicted pseudo-label in the data example and candidate. For “confidence” and
+ “avg_confidence”, it calculates the confidence distance between the predicted
+ probabilities and each candidate, defined as ``1 - product`` and ``1 - average``
+ of the candidate's per-symbol probabilities respectively. For “similarity”,
+ it compares candidates against the geometry of the model's embeddings.
+ This requires the wrapped PyTorch model to implement
+ ``extract_features(x)`` (returning, for example, penultimate-layer
+ activations); ``BasicNN`` then surfaces them via its own
+ ``extract_features`` method, and ``ABLModel`` automatically stores the
+ resulting embeddings on ``data_example.embeddings`` for the reasoner.
+ For “rejection”, it combines the confidence distance with a candidate-complexity
+ penalty so that shorter candidates are favored when scores are close.
Defaults to “confidence”.
- ``idx_to_label`` (dict, optional), a mapping from index in the base model to label.
If not provided, a default order-based index to label mapping is created.
@@ -333,7 +342,7 @@ The main method implemented by ``Reasoner`` is
based on the distance function defined in ``dist_func``.
MNIST Addition example (cont.)
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
As an example, consider these data examples for MNIST Addition:
@@ -377,7 +386,7 @@ Out:
Specifically, as mentioned before, “confidence” calculates the distance between the data
example and candidates based on the confidence derived from the predicted probability.
-Take ``example1`` as an example, the ``pred_prob`` in it indicates a higher
-confidence that the first label should be "1" rather than "7". Therefore, among the
+Taking ``example1`` as an example, the ``pred_prob`` in it indicates a higher
+confidence that the first label should be "1" rather than "7". Therefore, among the
candidates [1,7] and [7,1], it would be closer to [1,7] (as its first label is "1").
diff --git a/docs/index.rst b/docs/index.rst
index d398c4d5a..f3ce3914e 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -127,6 +127,7 @@ To cite ABLkit, please cite the following paper: `Huang et al., 2024 List[List[Any]]:
+ pred_idx = data_examples.pred_idx # [ ndarray(1,nc),... ]
+ pred_pseudo_label = []
+ for sub_list in pred_idx:
+ sub_list = sub_list.squeeze() # 1 x nc -> nc
+ pred_pseudo_label.append([self.reasoner.idx_to_label[_idx] for _idx in sub_list])
+ data_examples.pred_pseudo_label = pred_pseudo_label
+ return data_examples.pred_pseudo_label
+
+    def pseudo_label_to_idx(self, data_examples: ListData) -> List[List[Any]]:
+        """Convert abduced pseudo-labels into model indices.
+
+        Every label in ``data_examples.abduced_pseudo_label`` is looked up in
+        ``self.reasoner.label_to_idx``; each example becomes one ``np.ndarray``
+        of indices. The result is stored on ``data_examples.abduced_idx`` and
+        also returned.
+        """
+        data_examples.abduced_idx = [
+            np.array([self.reasoner.label_to_idx[label] for label in labels])
+            for labels in data_examples.abduced_pseudo_label
+        ]
+        return data_examples.abduced_idx
diff --git a/examples/bdd_oia/dataset/data_util.py b/examples/bdd_oia/dataset/data_util.py
new file mode 100644
index 000000000..a4a9ee965
--- /dev/null
+++ b/examples/bdd_oia/dataset/data_util.py
@@ -0,0 +1,19 @@
+import os
+import numpy as np
+
+# Directory containing this module; relative dataset names are resolved against it.
+CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))
+
+
+def get_dataset(fname, get_pseudo_label=True):
+    """Load one BDD-OIA split from an ``.npz`` archive.
+
+    Args:
+        fname: Archive file name, joined onto ``CURRENT_DIR`` (an absolute
+            path is therefore used unchanged by ``os.path.join``).
+        get_pseudo_label: When False the ``pseudo_label`` array is not read
+            and ``None`` is returned in its place.
+
+    Returns:
+        A tuple ``(X, pseudo_label, Y)`` where ``X`` is a list of one-element
+        lists holding each row of ``data["X"]`` cast to ``float32``,
+        ``pseudo_label`` is a nested list of ints (or ``None``) and ``Y`` is a
+        list of tuples built from the first four columns of ``data["Y"]``.
+    """
+    fname = os.path.join(CURRENT_DIR, fname)
+    # ``np.load`` on an .npz returns a lazy NpzFile that keeps the archive open;
+    # the context manager closes the handle once every array has been materialised.
+    with np.load(fname) as data:
+        X = [[emb.astype(np.float32)] for emb in data["X"]]
+        pseudo_label = data["pseudo_label"].astype(int).tolist() if get_pseudo_label else None
+        Y = [tuple(y) for y in data["Y"][:, :4].astype(int).tolist()]
+    return X, pseudo_label, Y
+
+
+if __name__ == "__main__":
+    dataset = get_dataset("val.npz")
diff --git a/examples/bdd_oia/main.py b/examples/bdd_oia/main.py
new file mode 100644
index 000000000..d720b6493
--- /dev/null
+++ b/examples/bdd_oia/main.py
@@ -0,0 +1,147 @@
+import argparse
+import os.path as osp
+import numpy as np
+import torch
+from torch import optim
+import torch.nn as nn
+
+from ablkit.data.evaluation import SymbolAccuracy
+from ablkit.learning import MultiLabelABLModel, MultiLabelBasicNN
+from ablkit.reasoning import Reasoner
+from ablkit.utils import ABLLogger, print_log
+
+from models.nn import ConceptNet
+from reasoning.bddkb import BDDKB
+from dataset.data_util import get_dataset
+from bridge import BDDBridge
+from metric import BDDReasoningMetric
+
+
+def multi_label_confidence_dist(data_example, candidates, candidates_idxs, reasoning_results):
+    """Negative log-likelihood cost of each candidate under the predicted probabilities.
+
+    ``data_example.pred_prob`` is transposed and expanded into two columns,
+    ``[1 - p, p]``, so that a candidate index of 0 or 1 selects the matching
+    probability for every position. The cost of a candidate is the negated sum
+    of the logs of the selected probabilities (with ``1e-6`` added before the
+    log). ``candidates`` and ``reasoning_results`` are not used.
+
+    Returns:
+        One cost per row of ``candidates_idxs``.
+    """
+    positive = data_example.pred_prob.T  # nc x 1
+    binary_prob = np.concatenate([1 - positive, positive], axis=1)  # nc x 2
+    num_positions = len(candidates_idxs[0])
+    position_rows = np.arange(num_positions)[None, :]
+    selected = binary_prob[position_rows, candidates_idxs]
+    return -np.sum(np.log(selected + 1e-6), axis=1)
+
+
+def parse_args():
+    """Parse the command-line options of the BDD-OIA example.
+
+    Options are declared in a table and registered in order, so the generated
+    ``--help`` output lists them in the order given below.
+
+    Returns:
+        The ``argparse.Namespace`` produced from ``sys.argv``.
+    """
+    option_table = [
+        ("--no-cuda", dict(action="store_true", default=False, help="disables CUDA training")),
+        (
+            "--epochs",
+            dict(
+                type=int,
+                default=1,
+                help="number of epochs in each learning loop iteration (default : 1)",
+            ),
+        ),
+        ("--lr", dict(type=float, default=2e-3, help="base model learning rate (default : 0.002)")),
+        (
+            "--batch-size",
+            dict(type=int, default=32, help="base model batch size (default : 32)"),
+        ),
+        ("--loops", dict(type=int, default=2, help="number of loop iterations (default : 2)")),
+        (
+            "--segment_size",
+            dict(type=float, default=0.01, help="segment size (default : 0.01)"),
+        ),
+        ("--save_interval", dict(type=int, default=1, help="save interval (default : 1)")),
+        (
+            "--max-revision",
+            dict(type=int, default=3, help="maximum revision in reasoner (default : 3)"),
+        ),
+        (
+            "--require-more-revision",
+            dict(
+                type=int,
+                default=3,
+                help="require more revision in reasoner (default : 3)",
+            ),
+        ),
+    ]
+
+    parser = argparse.ArgumentParser(description="BDD-OIA example")
+    for flag, options in option_table:
+        parser.add_argument(flag, **options)
+    return parser.parse_args()
+
+
+def main():
+    """Run the BDD-OIA abductive-learning example end to end.
+
+    Parses the command line, loads the train/val/test ``.npz`` splits, builds the
+    multi-label base model (``ConceptNet`` + ``BCEWithLogitsLoss`` + Adam with a
+    ``OneCycleLR`` schedule), wraps it in ``MultiLabelABLModel``, creates a
+    ``Reasoner`` over ``BDDKB`` that scores candidates with
+    ``multi_label_confidence_dist``, then trains and tests through ``BDDBridge``.
+    Model weights are saved under ``<log_dir>/weights``.
+    """
+    args = parse_args()
+
+    # Build logger
+    print_log("Abductive Learning on the BDD-OIA example.", logger="current")
+
+    # -- Working with Data ------------------------------
+    print_log("Working with Data.", logger="current")
+    train_data = get_dataset(fname="train.npz", get_pseudo_label=True)
+    val_data = get_dataset(fname="val.npz", get_pseudo_label=True)
+    test_data = get_dataset(fname="test.npz", get_pseudo_label=True)
+
+    # -- Building the Learning Part ---------------------
+    print_log("Building the Learning Part.", logger="current")
+
+    # Build necessary components for MultiLabelBasicNN
+    net = ConceptNet()
+    loss_fn = nn.BCEWithLogitsLoss()
+    optimizer = optim.Adam(net.parameters(), lr=args.lr)
+    # CUDA is used only when it is available and --no-cuda was not passed.
+    use_cuda = not args.no_cuda and torch.cuda.is_available()
+    device = torch.device("cuda" if use_cuda else "cpu")
+    # NOTE(review): the schedule length is loops * (int(1 / segment_size) + 1);
+    # presumably one scheduler step per training segment -- confirm against
+    # how MultiLabelBasicNN steps the scheduler.
+    scheduler = optim.lr_scheduler.OneCycleLR(
+        optimizer,
+        max_lr=args.lr,
+        pct_start=0.15,
+        epochs=args.loops,
+        steps_per_epoch=int(1 / args.segment_size) + 1,
+    )
+
+    base_model = MultiLabelBasicNN(
+        net,
+        loss_fn,
+        optimizer,
+        scheduler=scheduler,
+        device=device,
+        batch_size=args.batch_size,
+        num_epochs=args.epochs,
+    )
+
+    model = MultiLabelABLModel(base_model)
+
+    # -- Building the Reasoning Part --------------------
+    print_log("Building the Reasoning Part.", logger="current")
+
+    # Build knowledge base
+    kb = BDDKB()
+
+    # Create reasoner
+    reasoner = Reasoner(
+        kb,
+        dist_func=multi_label_confidence_dist,
+        max_revision=args.max_revision,
+        require_more_revision=args.require_more_revision,
+    )
+
+    # -- Building Evaluation Metrics --------------------
+    print_log("Building Evaluation Metrics.", logger="current")
+    metric_list = [SymbolAccuracy(prefix="bdd_oia"), BDDReasoningMetric(kb=kb, prefix="bdd_oia")]
+
+    # -- Bridging Learning and Reasoning ----------------
+    print_log("Bridge Learning and Reasoning.", logger="current")
+    bridge = BDDBridge(model, reasoner, metric_list)
+
+    # Retrieve the directory of the Log file and define the directory for saving the model weights.
+    log_dir = ABLLogger.get_current_instance().log_dir
+    weights_dir = osp.join(log_dir, "weights")
+
+    # Train and Test
+    bridge.train(
+        train_data=train_data,
+        val_data=val_data,
+        loops=args.loops,
+        segment_size=args.segment_size,
+        save_interval=args.save_interval,
+        save_dir=weights_dir,
+    )
+    bridge.test(test_data)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/bdd_oia/metric.py b/examples/bdd_oia/metric.py
new file mode 100644
index 000000000..e2042ecfc
--- /dev/null
+++ b/examples/bdd_oia/metric.py
@@ -0,0 +1,27 @@
+from typing import Optional
+
+from ablkit.reasoning import KBBase
+from ablkit.data import BaseMetric, ListData
+
+
+class BDDReasoningMetric(BaseMetric):
+    """Per-output accuracy of the knowledge base's deduction on predicted pseudo-labels.
+
+    For every example the knowledge base's ``logic_forward`` is run on the
+    predicted pseudo-labels, and each element of its result is compared with the
+    corresponding element of the ground-truth ``Y``.
+
+    Args:
+        kb: Knowledge base whose ``logic_forward`` performs the deduction.
+        prefix: Optional metric-name prefix forwarded to ``BaseMetric``.
+    """
+
+    def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None:
+        super().__init__(prefix)
+        self.kb = kb
+
+    def process(self, data_examples: ListData) -> None:
+        """Append one 0/1 entry to ``self.results`` per element of each deduced result."""
+        pred_pseudo_label_list = data_examples.pred_pseudo_label
+        y_list = data_examples.Y
+        x_list = data_examples.X
+        for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list):
+            # The raw input is passed along only when the knowledge base reports
+            # that ``logic_forward`` takes two arguments.
+            extra_args = (x,) if self.kb._num_args == 2 else ()
+            pred_y = self.kb.logic_forward(pred_pseudo_label, *extra_args)
+            for py, yy in zip(pred_y, y):
+                self.results.append(int(py == yy))
+
+    def compute_metrics(self) -> dict:
+        """Return ``{"reasoning_accuracy": mean of the recorded 0/1 results}``.
+
+        When nothing has been processed the accuracy is reported as ``0.0``
+        instead of raising ``ZeroDivisionError``.
+        """
+        results = self.results
+        if not results:
+            return {"reasoning_accuracy": 0.0}
+        return {"reasoning_accuracy": sum(results) / len(results)}
diff --git a/examples/bdd_oia/models/__init__.py b/examples/bdd_oia/models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/bdd_oia/models/nn.py b/examples/bdd_oia/models/nn.py
new file mode 100644
index 000000000..8ff817335
--- /dev/null
+++ b/examples/bdd_oia/models/nn.py
@@ -0,0 +1,24 @@
+from torch import nn
+
+
+class SimpleNet(nn.Module):
+    """Single linear layer mapping ``num_features`` inputs to ``num_concepts`` outputs."""
+
+    def __init__(self, num_features=2048, num_concepts=21):
+        super().__init__()
+        self.fc = nn.Linear(in_features=num_features, out_features=num_concepts)
+
+    def forward(self, x):
+        """Apply the linear layer to ``x`` and return its raw output."""
+        out = self.fc(x)
+        return out
+
+
+class ConceptNet(nn.Module):
+    """Two-layer perceptron: Linear -> SiLU -> Linear with a 256-unit hidden layer."""
+
+    def __init__(self, num_features=2048, num_concepts=21):
+        super().__init__()
+        hidden_dim = 256
+        layers = [
+            nn.Linear(num_features, hidden_dim),
+            nn.SiLU(),
+            nn.Linear(hidden_dim, num_concepts),
+        ]
+        # Positional construction keeps the sub-module names "0", "1", "2".
+        self.fc = nn.Sequential(*layers)
+
+    def forward(self, x):
+        """Run ``x`` through the sequential stack and return its raw output."""
+        out = self.fc(x)
+        return out
diff --git a/examples/bdd_oia/reasoning/__init__.py b/examples/bdd_oia/reasoning/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/bdd_oia/reasoning/bddkb.py b/examples/bdd_oia/reasoning/bddkb.py
new file mode 100644
index 000000000..02e3d6a4b
--- /dev/null
+++ b/examples/bdd_oia/reasoning/bddkb.py
@@ -0,0 +1,70 @@
+# -*- coding: utf-8 -*-
+from ablkit.reasoning import KBBase
+
+
+class BDDKB(KBBase):
+    """Knowledge base deducing driving actions from 21 binary scene concepts."""
+
+    def __init__(self, pseudo_label_list=None):
+        # Default to the binary pseudo-label set; the list is created per call
+        # rather than shared as a mutable default argument.
+        if pseudo_label_list is None:
+            pseudo_label_list = [0, 1]
+        super().__init__(pseudo_label_list)
+
+    def logic_forward(self, attrs):
+        """
+        Deduce ``(move_forward, stop, turn_left, turn_right)`` from ``attrs``.
+
+        ``attrs`` must contain exactly 21 values, unpacked in the order listed
+        in the body (``green_light`` first, ``right_solid_line`` last);
+        otherwise ``ValueError`` is raised. ``(0, 0, 0, 0)`` is returned for
+        inputs treated as illegal: red and green light both equal to 1, or
+        ``road_clear`` equal to the combined obstacle flag.
+
+        NOTE(review): the table below presumably lists, for each legal output
+        tuple, how many concept assignments deduce it -- confirm.
+
+        Abduction space
+        (0, 1, 0, 0) 610812
+        (0, 1, 0, 1) 75012
+        (0, 1, 1, 0) 75012
+        (0, 1, 1, 1) 9212
+        (1, 0, 0, 0) 12996
+        (1, 0, 0, 1) 1596
+        (1, 0, 1, 0) 1596
+        (1, 0, 1, 1) 196
+        """
+        if len(attrs) != 21:
+            raise ValueError(
+                f"BDDKB.logic_forward expects exactly 21 concept attributes, got {len(attrs)}."
+            )
+        (
+            green_light,
+            follow,
+            road_clear,
+            red_light,
+            traffic_sign,
+            car,
+            person,
+            rider,
+            other_obstacle,
+            left_lane,
+            left_green_light,
+            left_follow,
+            no_left_lane,
+            left_obstacle,
+            left_solid_line,
+            right_lane,
+            right_green_light,
+            right_follow,
+            no_right_lane,
+            right_obstacle,
+            right_solid_line,
+        ) = attrs
+
+        illegal_return = (0, 0, 0, 0)
+        # Red and green light cannot both be on.
+        if red_light == green_light == 1:
+            return illegal_return
+        obstacle = car or person or rider or other_obstacle
+        # The road must be clear exactly when no obstacle is present.
+        if road_clear == obstacle:
+            return illegal_return
+        move_forward = green_light or follow or road_clear
+        stop = red_light or traffic_sign or obstacle
+        # Any reason to stop overrides the reasons to move forward.
+        if stop:
+            move_forward = 0
+
+        # A turn needs at least one enabling concept and no blocking concept.
+        can_turn_left = left_lane or left_green_light or left_follow
+        cannot_turn_left = no_left_lane or left_obstacle or left_solid_line
+        turn_left = can_turn_left and int(not cannot_turn_left)
+
+        can_turn_right = right_lane or right_green_light or right_follow
+        cannot_turn_right = no_right_lane or right_obstacle or right_solid_line
+        turn_right = can_turn_right and int(not cannot_turn_right)
+
+        return move_forward, stop, turn_left, turn_right
diff --git a/examples/bdd_oia/requirements.txt b/examples/bdd_oia/requirements.txt
new file mode 100644
index 000000000..a1d6490d8
--- /dev/null
+++ b/examples/bdd_oia/requirements.txt
@@ -0,0 +1 @@
+ablkit
diff --git a/examples/hed/bridge.py b/examples/hed/bridge.py
index 82ab1779b..08605b9c3 100644
--- a/examples/hed/bridge.py
+++ b/examples/hed/bridge.py
@@ -26,7 +26,7 @@ def __init__(
) -> None:
super().__init__(model, reasoner, metric_list)
- def pretrain(self, weights_dir):
+ def pretrain(self, weights_dir: str) -> None:
if not os.path.exists(os.path.join(weights_dir, "pretrain_weights.pth")):
print_log("Pretrain Start", logger="current")
@@ -64,7 +64,7 @@ def pretrain(self, weights_dir):
self.model.load(load_path=os.path.join(weights_dir, "pretrain_weights.pth"))
- def select_mapping_and_abduce(self, data_examples: ListData):
+ def select_mapping_and_abduce(self, data_examples: ListData) -> List[List[Any]]:
candidate_mappings = gen_mappings([0, 1, 2, 3], ["+", "=", 0, 1])
mapping_score = []
abduced_pseudo_label_list = []
@@ -87,11 +87,13 @@ def select_mapping_and_abduce(self, data_examples: ListData):
return data_examples.abduced_pseudo_label
- def abduce_pseudo_label(self, data_examples: ListData):
+ def abduce_pseudo_label(self, data_examples: ListData) -> List[List[Any]]:
self.reasoner.abduce(data_examples)
return data_examples.abduced_pseudo_label
- def check_training_impact(self, filtered_data_examples, data_examples):
+ def check_training_impact(
+ self, filtered_data_examples: ListData, data_examples: ListData
+ ) -> bool:
character_accuracy = self.model.valid(filtered_data_examples)
revisible_ratio = len(filtered_data_examples.X) / len(data_examples.X)
log_string = (
@@ -104,7 +106,7 @@ def check_training_impact(self, filtered_data_examples, data_examples):
return True
return False
- def check_rule_quality(self, rule, val_data, equation_len):
+ def check_rule_quality(self, rule: Any, val_data: Any, equation_len: int) -> bool:
val_X_true = self.data_preprocess(val_data[1], equation_len)
val_X_false = self.data_preprocess(val_data[0], equation_len)
@@ -121,7 +123,7 @@ def check_rule_quality(self, rule, val_data, equation_len):
return True
return False
- def calc_consistent_ratio(self, data_examples, rule):
+ def calc_consistent_ratio(self, data_examples: ListData, rule: Any) -> float:
self.predict(data_examples)
pred_pseudo_label = self.idx_to_pseudo_label(data_examples)
consistent_num = sum(
@@ -129,7 +131,9 @@ def calc_consistent_ratio(self, data_examples, rule):
)
return consistent_num / len(data_examples.X)
- def get_rules_from_data(self, data_examples, samples_per_rule, samples_num):
+ def get_rules_from_data(
+ self, data_examples: ListData, samples_per_rule: int, samples_num: int
+ ) -> List[Any]:
rules = []
sampler = InfiniteSampler(len(data_examples), batch_size=samples_per_rule)
@@ -159,7 +163,7 @@ def get_rules_from_data(self, data_examples, samples_per_rule, samples_num):
return rules
@staticmethod
- def filter_empty(data_examples: ListData):
+ def filter_empty(data_examples: ListData) -> ListData:
consistent_dix = [
i
for i in range(len(data_examples.abduced_pseudo_label))
@@ -168,7 +172,7 @@ def filter_empty(data_examples: ListData):
return data_examples[consistent_dix]
@staticmethod
- def select_rules(rule_dict):
+ def select_rules(rule_dict: dict) -> List[Any]:
add_nums_dict = {}
for r in list(rule_dict):
add_nums = str(r.split("]")[0].split("[")[1]) + str(
@@ -185,7 +189,7 @@ def select_rules(rule_dict):
add_nums_dict[add_nums] = r
return list(rule_dict)
- def data_preprocess(self, data, equation_len) -> ListData:
+ def data_preprocess(self, data: Any, equation_len: int) -> ListData:
data_examples = ListData()
data_examples.X = data[equation_len] + data[equation_len + 1]
data_examples.gt_pseudo_label = None
@@ -193,7 +197,15 @@ def data_preprocess(self, data, equation_len) -> ListData:
return data_examples
- def train(self, train_data, val_data, segment_size=10, min_len=5, max_len=8, save_dir="./"):
+ def train(
+ self,
+ train_data: Any,
+ val_data: Any,
+ segment_size: int = 10,
+ min_len: int = 5,
+ max_len: int = 8,
+ save_dir: str = "./",
+ ) -> None:
for equation_len in range(min_len, max_len):
print_log(
f"============== equation_len: {equation_len}-{equation_len + 1} ================",
diff --git a/examples/hed/datasets/get_dataset.py b/examples/hed/datasets/get_dataset.py
index cea07bc38..f16701bb1 100644
--- a/examples/hed/datasets/get_dataset.py
+++ b/examples/hed/datasets/get_dataset.py
@@ -10,6 +10,8 @@
import numpy as np
from torchvision.transforms import transforms
+from ablkit.utils import print_log
+
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))
@@ -79,10 +81,10 @@ def get_dataset(dataset="mnist", train=True):
data_dir = CURRENT_DIR + "/mnist_images"
if not os.path.exists(data_dir):
- print("Dataset not exist, downloading it...")
+ print_log("Dataset not present, downloading it...", logger="current")
url = "https://drive.google.com/u/0/uc?id=1W2AUn_fnXa4XkgLk4d17K3bEgpae8GMg&export=download"
download_and_unzip(url, os.path.join(CURRENT_DIR, "HED.zip"))
- print("Download and extraction complete.")
+ print_log("Download and extraction complete.", logger="current")
if train:
file = os.path.join(data_dir, "expr_train.json")
diff --git a/examples/hed/main.py b/examples/hed/main.py
index c1d7e76c8..e25e645a2 100644
--- a/examples/hed/main.py
+++ b/examples/hed/main.py
@@ -61,15 +61,15 @@ def main():
print_log("Building the Learning Part.", logger="current")
# Build necessary components for BasicNN
- cls = SymbolNet(num_classes=4)
+ net = SymbolNet(num_classes=4)
loss_fn = nn.CrossEntropyLoss()
- optimizer = torch.optim.RMSprop(cls.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+ optimizer = torch.optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# Build BasicNN
base_model = BasicNN(
- cls,
+ net,
loss_fn,
optimizer,
device=device,
diff --git a/examples/hed/models/__init__.py b/examples/hed/models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/hed/reasoning/reasoning.py b/examples/hed/reasoning/reasoning.py
index 1d27763f8..22293a6d5 100644
--- a/examples/hed/reasoning/reasoning.py
+++ b/examples/hed/reasoning/reasoning.py
@@ -1,8 +1,10 @@
import math
import os
+from typing import Any, List, Optional, Union
import numpy as np
+from ablkit.data.structures import ListData
from ablkit.reasoning import PrologKB, Reasoner
from ablkit.utils import reform_list
@@ -11,37 +13,44 @@
class HedKB(PrologKB):
def __init__(
- self, pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "learn_add.pl")
- ):
+ self,
+ pseudo_label_list: List[Any] = [1, 0, "+", "="],
+ pl_file: str = os.path.join(CURRENT_DIR, "learn_add.pl"),
+ ) -> None:
pl_file = pl_file.replace("\\", "/")
super().__init__(pseudo_label_list, pl_file)
- self.learned_rules = {}
+ self.learned_rules: dict = {}
- def consist_rule(self, exs, rules):
+ def consist_rule(self, exs: Any, rules: Any) -> bool:
rules = str(rules).replace("'", "")
return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0
- def abduce_rules(self, pred_res):
+ def abduce_rules(self, pred_res: Any) -> Optional[List[Any]]:
prolog_result = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res))
if len(prolog_result) == 0:
return None
prolog_rules = prolog_result[0]["X"]
- rules = [rule.value for rule in prolog_rules]
- return rules
+ return [rule.value for rule in prolog_rules]
class HedReasoner(Reasoner):
- def revise_at_idx(self, data_example):
+ def revise_at_idx(self, data_example: ListData) -> Any:
revision_idx = np.where(np.array(data_example.flatten("revision_flag")) != 0)[0]
candidate = self.kb.revise_at_idx(
data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx
)
return candidate
- def zoopt_budget(self, symbol_num):
+ def zoopt_budget(self, symbol_num: int) -> int:
return 200
- def zoopt_score(self, symbol_num, data_example, sol, get_score=True):
+ def zoopt_score(
+ self,
+ symbol_num: int,
+ data_example: ListData,
+ sol: Any,
+ get_score: bool = True,
+ ) -> Union[float, List[int]]:
revision_flag = reform_list(
list(sol.get_x().astype(np.int32)), data_example.pred_pseudo_label
)
@@ -82,7 +91,7 @@ def zoopt_score(self, symbol_num, data_example, sol, get_score=True):
else:
return max_consistent_idxs
- def abduce(self, data_example):
+ def abduce(self, data_example: ListData) -> List[List[Any]]:
symbol_num = data_example.elements_num("pred_pseudo_label")
max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num)
@@ -98,5 +107,5 @@ def abduce(self, data_example):
data_example.abduced_pseudo_label = abduced_pseudo_label
return abduced_pseudo_label
- def abduce_rules(self, pred_res):
+ def abduce_rules(self, pred_res: Any) -> Optional[List[Any]]:
return self.kb.abduce_rules(pred_res)
diff --git a/examples/hwf/datasets/get_dataset.py b/examples/hwf/datasets/get_dataset.py
index 0700d3ceb..4109910a0 100644
--- a/examples/hwf/datasets/get_dataset.py
+++ b/examples/hwf/datasets/get_dataset.py
@@ -6,6 +6,8 @@
from PIL import Image
from torchvision.transforms import transforms
+from ablkit.utils import print_log
+
CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))
img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1,))])
@@ -30,10 +32,10 @@ def get_dataset(train=True, get_pseudo_label=False):
data_dir = CURRENT_DIR + "/data"
if not os.path.exists(data_dir):
- print("Dataset not exist, downloading it...")
+ print_log("Dataset not present, downloading it...", logger="current")
url = "https://drive.google.com/u/0/uc?id=1t52OE2Wdm5GdShX1jD2Wy8phCllk0r8I&export=download"
download_and_unzip(url, os.path.join(CURRENT_DIR, "HWF.zip"))
- print("Download and extraction complete.")
+ print_log("Download and extraction complete.", logger="current")
if train:
file = os.path.join(data_dir, "expr_train.json")
diff --git a/examples/hwf/main.py b/examples/hwf/main.py
index 6ea03804c..45c427f44 100644
--- a/examples/hwf/main.py
+++ b/examples/hwf/main.py
@@ -1,5 +1,8 @@
import argparse
+import ast
+import operator
import os.path as osp
+from typing import List
import numpy as np
import torch
@@ -14,56 +17,80 @@
from datasets import get_dataset
from models.nn import SymbolNet
+# Symbol sets used for membership checks when validating a formula.
+DIGITS = {"1", "2", "3", "4", "5", "6", "7", "8", "9"}
+OPERATORS = {"+", "-", "*", "/"}
+# The order of this list defines the index <-> label mapping, so it is spelled
+# out explicitly: ``sorted(OPERATORS)`` would yield ["*", "+", "-", "/"] and
+# silently change the mapping used by the previous hard-coded default.
+PSEUDO_LABEL_LIST = ["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"]
+
+
+def _is_well_formed(formula: List[str]) -> bool:
+    """Check that ``formula`` has odd length and alternates digit, operator, digit, ...
+
+    Even positions must hold a member of ``DIGITS`` and odd positions a member
+    of ``OPERATORS``; an empty formula is rejected because its length is even.
+    """
+    if len(formula) % 2 == 0:
+        return False
+    return all(
+        symbol in (OPERATORS if position % 2 else DIGITS)
+        for position, symbol in enumerate(formula)
+    )
+
+
+# Whitelist of binary operator AST nodes and the functions that implement them.
+_SAFE_BINOPS = {
+    ast.Add: operator.add,
+    ast.Sub: operator.sub,
+    ast.Mult: operator.mul,
+    ast.Div: operator.truediv,
+}
+
+
+def _safe_eval(node: ast.AST) -> float:
+    """Evaluate an arithmetic AST restricted to the four basic operators.
+
+    Args:
+        node: An ``ast.Expression`` or one of its sub-nodes.
+
+    Returns:
+        The numeric value of the expression.
+
+    Raises:
+        ValueError: If the tree contains anything other than int/float
+            constants combined with ``+``, ``-``, ``*`` or ``/``.
+        ZeroDivisionError: Propagated from a division by zero.
+    """
+    if isinstance(node, ast.Expression):
+        return _safe_eval(node.body)
+    if isinstance(node, ast.Constant):
+        value = node.value
+        # ``bool`` is a subclass of ``int``; ``True``/``False`` are not arithmetic
+        # literals and must not slip through the numeric check.
+        if isinstance(value, (int, float)) and not isinstance(value, bool):
+            return value
+    elif isinstance(node, ast.BinOp):
+        apply_op = _SAFE_BINOPS.get(type(node.op))
+        if apply_op is not None:
+            return apply_op(_safe_eval(node.left), _safe_eval(node.right))
+    raise ValueError(f"unsupported AST node: {type(node).__name__}")
+
+
+def _evaluate(formula: List[str]) -> float:
+    """Join the symbols of ``formula`` and evaluate them with normal operator precedence.
+
+    The joined string is parsed with ``ast.parse`` and evaluated by
+    ``_safe_eval``; ``np.inf`` is returned when parsing or evaluation raises
+    ``SyntaxError``, ``ZeroDivisionError`` or ``ValueError``.
+    """
+    expression = "".join(formula)
+    try:
+        return _safe_eval(ast.parse(expression, mode="eval"))
+    except (SyntaxError, ZeroDivisionError, ValueError):
+        return np.inf
+
+
+def _hwf_logic_forward(formula: List[str]) -> float:
+    """Deduce the value of ``formula``; shared by ``HwfKB`` and ``HwfGroundKB``.
+
+    Returns ``np.inf`` for a formula that is not well formed, otherwise the
+    result of ``_evaluate``.
+    """
+    return _evaluate(formula) if _is_well_formed(formula) else np.inf
+
class HwfKB(KBBase):
def __init__(
self,
- pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"],
- max_err=1e-10,
- ):
+ pseudo_label_list: List[str] = PSEUDO_LABEL_LIST,
+ max_err: float = 1e-10,
+ ) -> None:
super().__init__(pseudo_label_list, max_err)
- def _valid_candidate(self, formula):
- if len(formula) % 2 == 0:
- return False
- for i in range(len(formula)):
- if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
- return False
- if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]:
- return False
- return True
-
- # Implement the deduction function
- def logic_forward(self, formula):
- if not self._valid_candidate(formula):
- return np.inf
- return eval("".join(formula))
+ def logic_forward(self, formula: List[str]) -> float:
+ return _hwf_logic_forward(formula)
class HwfGroundKB(GroundKB):
def __init__(
self,
- pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"],
- GKB_len_list=[1, 3, 5, 7],
- max_err=1e-10,
- ):
+ pseudo_label_list: List[str] = PSEUDO_LABEL_LIST,
+ GKB_len_list: List[int] = [1, 3, 5, 7],
+ max_err: float = 1e-10,
+ ) -> None:
super().__init__(pseudo_label_list, GKB_len_list, max_err)
- def _valid_candidate(self, formula):
- if len(formula) % 2 == 0:
- return False
- for i in range(len(formula)):
- if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]:
- return False
- if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]:
- return False
- return True
-
- # Implement the deduction function
- def logic_forward(self, formula):
- if not self._valid_candidate(formula):
- return np.inf
- return eval("".join(formula))
+ def logic_forward(self, formula: List[str]) -> float:
+ return _hwf_logic_forward(formula)
def main():
@@ -81,7 +108,7 @@ def main():
"--label-smoothing",
type=float,
default=0.2,
- help="label smoothing in cross entropy loss (default : 0.2)"
+ help="label smoothing in cross entropy loss (default : 0.2)",
)
parser.add_argument(
"--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)"
@@ -130,15 +157,20 @@ def main():
print_log("Building the Learning Part.", logger="current")
# Build necessary components for BasicNN
- cls = SymbolNet(num_classes=13, image_size=(45, 45, 1))
+ net = SymbolNet(num_classes=13, image_size=(45, 45, 1))
loss_fn = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
- optimizer = torch.optim.Adam(cls.parameters(), lr=args.lr)
+ optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# Build BasicNN
base_model = BasicNN(
- cls, loss_fn, optimizer, device=device, batch_size=args.batch_size, num_epochs=args.epochs,
+ net,
+ loss_fn,
+ optimizer,
+ device=device,
+ batch_size=args.batch_size,
+ num_epochs=args.epochs,
)
# Build ABLModel
diff --git a/examples/hwf/models/__init__.py b/examples/hwf/models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/mnist_add/main.py b/examples/mnist_add/main.py
index c992c4e5e..41f4434a5 100644
--- a/examples/mnist_add/main.py
+++ b/examples/mnist_add/main.py
@@ -1,14 +1,22 @@
import argparse
import os.path as osp
+import random
import torch
from torch import nn
from torch.optim import RMSprop, lr_scheduler
-from ablkit.bridge import SimpleBridge
+from ablkit.bridge import A3BLBridge, SimpleBridge, VerificationBridge
from ablkit.data.evaluation import ReasoningMetric, SymbolAccuracy
from ablkit.learning import ABLModel, BasicNN
-from ablkit.reasoning import GroundKB, KBBase, PrologKB, Reasoner
+from ablkit.reasoning import (
+ A3BLReasoner,
+ GroundKB,
+ KBBase,
+ PrologKB,
+ Reasoner,
+ VerificationReasoner,
+)
from ablkit.utils import ABLLogger, print_log
from datasets import get_dataset
@@ -31,7 +39,10 @@ def logic_forward(self, nums):
return sum(nums)
-def main():
+DIST_FUNCS = ["hamming", "confidence", "avg_confidence", "similarity", "rejection"]
+
+
+def parse_args():
parser = argparse.ArgumentParser(description="MNIST Addition example")
parser.add_argument(
"--no-cuda", action="store_true", default=False, help="disables CUDA training"
@@ -59,7 +70,7 @@ def main():
"--loops", type=int, default=2, help="number of loop iterations (default : 2)"
)
parser.add_argument(
- "--segment_size", type=int, default=0.01, help="segment size (default : 0.01)"
+ "--segment_size", type=float, default=0.01, help="segment size (default : 0.01)"
)
parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)")
parser.add_argument(
@@ -78,37 +89,78 @@ def main():
kb_type.add_argument(
"--ground", action="store_true", default=False, help="use GroundKB (default: False)"
)
+ parser.add_argument(
+ "--method",
+ choices=["standard", "a3bl", "verification"],
+ default="standard",
+ help="learning/reasoning pipeline to use (default: standard)",
+ )
+ parser.add_argument(
+ "--dist-func",
+ choices=DIST_FUNCS,
+ default="confidence",
+ help="distance function used by the reasoner (default: confidence)",
+ )
+ parser.add_argument(
+ "--labeled-ratio",
+ type=float,
+ default=1.0,
+ help=(
+ "fraction in (0, 1] of training samples that keep their ground-truth pseudo-labels. "
+ "Values below 1.0 enable the semi-supervised pipeline (default: 1.0)"
+ ),
+ )
+ parser.add_argument(
+ "--top-k",
+ type=int,
+ default=1,
+ help=(
+ "number of consistent candidates the verification reasoner enumerates per example. "
+ "Only used when --method verification (default: 1)"
+ ),
+ )
+ parser.add_argument(
+ "--seed", type=int, default=0, help="random seed for semi-supervised split (default: 0)"
+ )
args = parser.parse_args()
+ if not (0.0 < args.labeled_ratio <= 1.0):
+ parser.error("--labeled-ratio must be in (0, 1].")
+ if args.method == "a3bl" and args.labeled_ratio < 1.0:
+ parser.error("--method a3bl does not support --labeled-ratio < 1.0.")
+ if args.method == "a3bl" and args.dist_func == "rejection":
+ parser.error("--method a3bl is not compatible with --dist-func rejection.")
+ if args.method == "verification" and args.labeled_ratio < 1.0:
+ parser.error("--method verification does not support --labeled-ratio < 1.0.")
+ if args.top_k < 1:
+ parser.error("--top-k must be >= 1.")
+ return args
- # Build logger
- print_log("Abductive Learning on the MNIST Addition example.", logger="current")
- # -- Working with Data ------------------------------
- print_log("Working with Data.", logger="current")
- train_data = get_dataset(train=True, get_pseudo_label=True)
- test_data = get_dataset(train=False, get_pseudo_label=True)
+def build_kb(args):
+    """Create the knowledge base selected on the command line.
+
+    ``args.prolog`` takes precedence and yields a ``PrologKB`` loaded from
+    ``add.pl``; otherwise ``args.ground`` yields an ``AddGroundKB``; with
+    neither flag set an ``AddKB`` is returned.
+    """
+    if args.prolog:
+        kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl")
+    elif args.ground:
+        kb = AddGroundKB()
+    else:
+        kb = AddKB()
+    return kb
- # -- Building the Learning Part ---------------------
- print_log("Building the Learning Part.", logger="current")
- # Build necessary components for BasicNN
- cls = LeNet5(num_classes=10)
+def build_base_model(args):
+ net = LeNet5(num_classes=10)
loss_fn = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
- optimizer = RMSprop(cls.parameters(), lr=args.lr, alpha=args.alpha)
+ optimizer = RMSprop(net.parameters(), lr=args.lr, alpha=args.alpha)
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
+ train_passes_per_segment = args.top_k if args.method == "verification" else 1
scheduler = lr_scheduler.OneCycleLR(
optimizer,
max_lr=args.lr,
pct_start=0.15,
epochs=args.loops,
- steps_per_epoch=int(1 / args.segment_size),
+ steps_per_epoch=int(1 / args.segment_size) * train_passes_per_segment,
)
-
- # Build BasicNN
- base_model = BasicNN(
- cls,
+ return BasicNN(
+ net,
loss_fn,
optimizer,
scheduler=scheduler,
@@ -117,45 +169,88 @@ def main():
num_epochs=args.epochs,
)
- # Build ABLModel
+
+def mask_pseudo_labels(pseudo_label, ratio, seed):
+    """
+    Return a copy of ``pseudo_label`` in which a random subset of entries is
+    replaced by ``None``.
+
+    ``int(round(ratio * len(pseudo_label)))`` positions are drawn without
+    replacement from a ``random.Random(seed)`` generator and keep their
+    original value; every other position becomes ``None`` so that it is
+    treated as unlabeled. The input sequence itself is not modified.
+    """
+    total = len(pseudo_label)
+    num_kept = int(round(ratio * total))
+    generator = random.Random(seed)
+    kept_positions = set(generator.sample(range(total), num_kept))
+
+    masked = []
+    for position in range(total):
+        if position in kept_positions:
+            masked.append(pseudo_label[position])
+        else:
+            masked.append(None)
+    return masked
+
+
+def main():
+ """
+ Run the MNIST Addition example end to end: parse the command-line
+ arguments, load the data, build the learning and reasoning parts, bridge
+ them, then train and test.
+ """
+ args = parse_args()
+
+ print_log("Abductive Learning on the MNIST Addition example.", logger="current")
+
+ print_log("Working with Data.", logger="current")
+ train_data = get_dataset(train=True, get_pseudo_label=True)
+ test_data = get_dataset(train=False, get_pseudo_label=True)
+
+ # val_data stays None unless only a fraction of the pseudo-labels is kept.
+ val_data = None
+ if args.labeled_ratio < 1.0:
+ X, pseudo_label, Y = train_data
+ # Keep the untouched training tuple as val_data before the pseudo-labels are masked.
+ val_data = train_data
+ train_data = (X, mask_pseudo_labels(pseudo_label, args.labeled_ratio, args.seed), Y)
+ print_log(
+ f"Semi-supervised: keeping {args.labeled_ratio:.0%} of pseudo-labels.",
+ logger="current",
+ )
+
+ print_log("Building the Learning Part.", logger="current")
+ base_model = build_base_model(args)
model = ABLModel(base_model)
- # -- Building the Reasoning Part --------------------
print_log("Building the Reasoning Part.", logger="current")
-
- # Build knowledge base
- if args.prolog:
- kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl")
- elif args.ground:
- kb = AddGroundKB()
+ kb = build_kb(args)
+ # The reasoner class is selected by --method; only VerificationReasoner takes top_k,
+ # the other two receive dist_func and the revision settings.
+ if args.method == "verification":
+ reasoner = VerificationReasoner(kb, top_k=args.top_k)
+ elif args.method == "a3bl":
+ reasoner = A3BLReasoner(
+ kb,
+ dist_func=args.dist_func,
+ max_revision=args.max_revision,
+ require_more_revision=args.require_more_revision,
+ )
else:
- kb = AddKB()
-
- # Create reasoner
- reasoner = Reasoner(
- kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision
- )
+ reasoner = Reasoner(
+ kb,
+ dist_func=args.dist_func,
+ max_revision=args.max_revision,
+ require_more_revision=args.require_more_revision,
+ )
- # -- Building Evaluation Metrics --------------------
print_log("Building Evaluation Metrics.", logger="current")
metric_list = [SymbolAccuracy(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]
- # -- Bridging Learning and Reasoning ----------------
print_log("Bridge Learning and Reasoning.", logger="current")
- bridge = SimpleBridge(model, reasoner, metric_list)
+ # Pick the bridge class with the same --method switch used for the reasoner.
+ if args.method == "verification":
+ bridge_cls = VerificationBridge
+ elif args.method == "a3bl":
+ bridge_cls = A3BLBridge
+ else:
+ bridge_cls = SimpleBridge
+ bridge = bridge_cls(model, reasoner, metric_list)
- # Retrieve the directory of the Log file and define the directory for saving the model weights.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")
- # Train and Test
- bridge.train(
- train_data,
+ # Collect the common training options first so method-specific extras can be added below.
+ train_kwargs = dict(
loops=args.loops,
segment_size=args.segment_size,
save_interval=args.save_interval,
save_dir=weights_dir,
)
+ # Extra keyword arguments are only passed when they apply to the chosen configuration.
+ if args.method == "standard" and args.labeled_ratio < 1.0:
+ train_kwargs["use_supervised_data"] = True
+ if val_data is not None and args.method != "a3bl":
+ train_kwargs["val_data"] = val_data
+
+ bridge.train(train_data, **train_kwargs)
bridge.test(test_data)
diff --git a/examples/mnist_add/models/__init__.py b/examples/mnist_add/models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/examples/mnist_add/models/nn.py b/examples/mnist_add/models/nn.py
index 93f5cda11..f1019df1e 100644
--- a/examples/mnist_add/models/nn.py
+++ b/examples/mnist_add/models/nn.py
@@ -1,4 +1,5 @@
from torch import nn
+import torch.nn.functional as F
class LeNet5(nn.Module):
@@ -20,6 +21,9 @@ def __init__(self, num_classes=10, image_size=(28, 28, 1)):
nn.ReLU(),
nn.Linear(84, num_classes),
)
+
+ def extract_features(self, x):
+ # Apply only the encoder to x, the same first step that forward performs.
+ return self.encoder(x)
def forward(self, x):
x = self.encoder(x)
diff --git a/examples/zoo/kb.py b/examples/zoo/kb.py
index 9a08077f1..d39d98536 100644
--- a/examples/zoo/kb.py
+++ b/examples/zoo/kb.py
@@ -1,5 +1,5 @@
import openml
-from z3 import If, Implies, Int, Not, Solver, Sum, sat # noqa: F401
+from z3 import If, Implies, Int, Not, Solver, Sum, sat
from ablkit.reasoning import KBBase
@@ -10,84 +10,71 @@ def __init__(self):
self.solver = Solver()
- # Load information of Zoo dataset
dataset = openml.datasets.get_dataset(
dataset_id=62,
download_data=False,
download_qualities=False,
download_features_meta_data=False,
)
- X, y, categorical_indicator, attribute_names = dataset.get_data(
+ _, y, _, attribute_names = dataset.get_data(
target=dataset.default_target_attribute
)
self.attribute_names = attribute_names
self.target_names = y.cat.categories.tolist()
- # print("Attribute names are: ", self.attribute_names)
- # print("Target names are: ", self.target_names)
- # self.attribute_names = ["hair", "feathers", "eggs", "milk", "airborne", "aquatic", "predator", "toothed", "backbone", "breathes", "venomous", "fins", "legs", "tail", "domestic", "catsize"] # noqa: E501
- # self.target_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "invertebrate"] # noqa: E501
- # Define variables
- for name in self.attribute_names + self.target_names:
- exec(
- f"globals()['{name}'] = Int('{name}')"
- ) # or use dict to create var and modify rules
- # Define rules
+ # Create a Z3 Int variable for every attribute and target, keyed by name.
+ self.vars = {name: Int(name) for name in self.attribute_names + self.target_names}
+ v = self.vars
+
rules = [
- Implies(milk == 1, mammal == 1),
- Implies(mammal == 1, milk == 1),
- Implies(mammal == 1, backbone == 1),
- Implies(mammal == 1, breathes == 1),
- Implies(feathers == 1, bird == 1),
- Implies(bird == 1, feathers == 1),
- Implies(bird == 1, eggs == 1),
- Implies(bird == 1, backbone == 1),
- Implies(bird == 1, breathes == 1),
- Implies(bird == 1, legs == 2),
- Implies(bird == 1, tail == 1),
- Implies(reptile == 1, backbone == 1),
- Implies(reptile == 1, breathes == 1),
- Implies(reptile == 1, tail == 1),
- Implies(fish == 1, aquatic == 1),
- Implies(fish == 1, toothed == 1),
- Implies(fish == 1, backbone == 1),
- Implies(fish == 1, Not(breathes == 1)),
- Implies(fish == 1, fins == 1),
- Implies(fish == 1, legs == 0),
- Implies(fish == 1, tail == 1),
- Implies(amphibian == 1, eggs == 1),
- Implies(amphibian == 1, aquatic == 1),
- Implies(amphibian == 1, backbone == 1),
- Implies(amphibian == 1, breathes == 1),
- Implies(amphibian == 1, legs == 4),
- Implies(insect == 1, eggs == 1),
- Implies(insect == 1, Not(backbone == 1)),
- Implies(insect == 1, legs == 6),
- Implies(invertebrate == 1, Not(backbone == 1)),
+ Implies(v["milk"] == 1, v["mammal"] == 1),
+ Implies(v["mammal"] == 1, v["milk"] == 1),
+ Implies(v["mammal"] == 1, v["backbone"] == 1),
+ Implies(v["mammal"] == 1, v["breathes"] == 1),
+ Implies(v["feathers"] == 1, v["bird"] == 1),
+ Implies(v["bird"] == 1, v["feathers"] == 1),
+ Implies(v["bird"] == 1, v["eggs"] == 1),
+ Implies(v["bird"] == 1, v["backbone"] == 1),
+ Implies(v["bird"] == 1, v["breathes"] == 1),
+ Implies(v["bird"] == 1, v["legs"] == 2),
+ Implies(v["bird"] == 1, v["tail"] == 1),
+ Implies(v["reptile"] == 1, v["backbone"] == 1),
+ Implies(v["reptile"] == 1, v["breathes"] == 1),
+ Implies(v["reptile"] == 1, v["tail"] == 1),
+ Implies(v["fish"] == 1, v["aquatic"] == 1),
+ Implies(v["fish"] == 1, v["toothed"] == 1),
+ Implies(v["fish"] == 1, v["backbone"] == 1),
+ Implies(v["fish"] == 1, Not(v["breathes"] == 1)),
+ Implies(v["fish"] == 1, v["fins"] == 1),
+ Implies(v["fish"] == 1, v["legs"] == 0),
+ Implies(v["fish"] == 1, v["tail"] == 1),
+ Implies(v["amphibian"] == 1, v["eggs"] == 1),
+ Implies(v["amphibian"] == 1, v["aquatic"] == 1),
+ Implies(v["amphibian"] == 1, v["backbone"] == 1),
+ Implies(v["amphibian"] == 1, v["breathes"] == 1),
+ Implies(v["amphibian"] == 1, v["legs"] == 4),
+ Implies(v["insect"] == 1, v["eggs"] == 1),
+ Implies(v["insect"] == 1, Not(v["backbone"] == 1)),
+ Implies(v["insect"] == 1, v["legs"] == 6),
+ Implies(v["invertebrate"] == 1, Not(v["backbone"] == 1)),
]
- # Define weights and sum of violated weights
self.weights = {rule: 1 for rule in rules}
self.total_violation_weight = Sum(
[If(Not(rule), self.weights[rule], 0) for rule in self.weights]
)
def logic_forward(self, pseudo_label, data_point):
+ """
+ Return the summed weight of the rules violated by one labelled example.
+
+ Only the first element of ``pseudo_label`` and of ``data_point`` is used.
+ After resetting the solver, each attribute variable is constrained to its
+ value in the data point, and each target variable is constrained to 1 when
+ its category equals the pseudo-label and to 0 otherwise.
+
+ Returns the solver model's evaluation of ``total_violation_weight`` as an
+ integer, or ``1e10`` when ``check()`` does not return ``sat``.
+ """
- attribute_names, target_names = self.attribute_names, self.target_names
- solver = self.solver
- total_violation_weight = self.total_violation_weight
pseudo_label, data_point = pseudo_label[0], data_point[0]
-
self.solver.reset()
- for name, value in zip(attribute_names, data_point):
- solver.add(eval(f"{name} == {value}"))
- for cate, name in zip(self.pseudo_label_list, target_names):
- value = 1 if (cate == pseudo_label) else 0
- solver.add(eval(f"{name} == {value}"))
- if solver.check() == sat:
- model = solver.model()
- total_weight = model.evaluate(total_violation_weight)
- return total_weight.as_long()
- else:
- # No solution found
- return 1e10
+ for name, value in zip(self.attribute_names, data_point):
+ self.solver.add(self.vars[name] == value)
+ for cate, name in zip(self.pseudo_label_list, self.target_names):
+ value = 1 if cate == pseudo_label else 0
+ self.solver.add(self.vars[name] == value)
+
+ if self.solver.check() == sat:
+ model = self.solver.model()
+ return model.evaluate(self.total_violation_weight).as_long()
+ # No solution found
+ return 1e10
diff --git a/examples/zoo/main.py b/examples/zoo/main.py
index 64f8109a7..41d3f4103 100644
--- a/examples/zoo/main.py
+++ b/examples/zoo/main.py
@@ -14,7 +14,7 @@
from kb import ZooKB
-def consitency(data_example, candidates, candidate_idxs, reasoning_results):
+def consistency(data_example, candidates, candidate_idxs, reasoning_results):
pred_prob = data_example.pred_prob
model_scores = avg_confidence_dist(pred_prob, candidate_idxs)
rule_scores = np.array(reasoning_results)
@@ -57,7 +57,7 @@ def main():
kb = ZooKB()
# Create reasoner
- reasoner = Reasoner(kb, dist_func=consitency)
+ reasoner = Reasoner(kb, dist_func=consistency)
# -- Building Evaluation Metrics --------------------
print_log("Building Evaluation Metrics.", logger="current")
diff --git a/pyproject.toml b/pyproject.toml
index 887dd04ff..9a6d5094b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,7 @@ authors = [
]
description = "Abductive Learning (ABL) toolkit"
readme = "README.md"
-requires-python = ">=3.7.0"
+requires-python = ">=3.8.0"
license = {text = "MIT LICENSE"}
keywords = ["machine-learning", "framework", "abductive-learning", "neuro-symbolic"]
classifiers = [
@@ -24,7 +24,6 @@ classifiers = [
"Operating System :: POSIX :: Linux",
"Operating System :: Microsoft :: Windows",
"Operating System :: MacOS :: MacOS X",
- "Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
@@ -45,6 +44,7 @@ Issues = "https://github.com/AbductiveLearning/ABLkit/issues"
[project.optional-dependencies]
test = [
+ "pytest",
"pytest-cov",
"black==22.10.0",
]
\ No newline at end of file
diff --git a/tests/test_reasoning.py b/tests/test_reasoning.py
index fa96dcee9..70b9c96a5 100644
--- a/tests/test_reasoning.py
+++ b/tests/test_reasoning.py
@@ -114,11 +114,10 @@ def test_reasoner_init(self, reasoner_instance):
def test_invalid_predefined_dist_func(self, kb_add):
with pytest.raises(NotImplementedError) as excinfo:
Reasoner(kb_add, "invalid_dist_func")
- assert (
- 'Valid options for predefined dist_func include "hamming", "confidence" '
- + 'and "avg_confidence"'
- in str(excinfo.value)
- )
+ # Check that each option name and the rejected value appear in the error text,
+ # instead of matching one exact sentence.
+ message = str(excinfo.value)
+ for option in ("hamming", "confidence", "avg_confidence", "similarity", "rejection"):
+ assert option in message
+ assert "invalid_dist_func" in message
def random_dist(self, data_example, candidates, candidate_idxs, reasoning_results):
cost_list = [np.random.rand() for _ in candidates]