diff --git a/.github/workflows/build-and-test.yaml b/.github/workflows/build-and-test.yaml index be85d2bf5..02b8cb806 100644 --- a/.github/workflows/build-and-test.yaml +++ b/.github/workflows/build-and-test.yaml @@ -2,59 +2,52 @@ name: ABLkit-CI on: push: - branches: [ main ] + branches: [main] pull_request: - branches: [ main ] + branches: [main] jobs: build: runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@v2 - - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - - uses: syphar/restore-virtualenv@v1 - id: cache-virtualenv - with: - custom_cache_key_element: ABLkit - requirement_files: requirements.txt - - - uses: syphar/restore-pip-download-cache@v1 - if: steps.cache-virtualenv.outputs.cache-hit != 'true' - - - name: Install SWI-Prolog on Ubuntu - if: matrix.os == 'ubuntu-latest' - run: sudo apt-get install swi-prolog - - name: Install SWI-Prolog on Windows - if: matrix.os == 'windows-latest' - run: choco install swi-prolog - - name: Install SWI-Prolog on MACOS - if: matrix.os == 'macos-latest' - run: brew install swi-prolog - - - name: Install package dependencies - if : steps.cache-virtualenv.outputs.cache-hit != 'true' - run: | - python -m pip install --upgrade pip - pip install pytest pytest-cov - - name: Install - if : steps.cache-virtualenv.outputs.cache-hit != 'true' - run: pip install -v -e . - - - name: Run tests - run: | - pytest --cov-config=.coveragerc --cov-report=xml --cov=ablkit ./tests - - - name: Publish code coverage - uses: codecov/codecov-action@v1 - with: - token: ${{ secrets.CODECOV_TOKEN }} - file: ./coverage.xml \ No newline at end of file + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + cache-dependency-path: pyproject.toml + + - name: Install SWI-Prolog (Ubuntu) + if: matrix.os == 'ubuntu-latest' + run: sudo apt-get update && sudo apt-get install -y swi-prolog + + - name: Install SWI-Prolog (Windows) + if: matrix.os == 'windows-latest' + run: choco install -y swi-prolog + + - name: Install SWI-Prolog (macOS) + if: matrix.os == 'macos-latest' + run: brew install swi-prolog + + - name: Install package and test dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[test]" + + - name: Run tests + run: pytest --cov-config=.coveragerc --cov-report=xml --cov=ablkit ./tests + + - name: Publish code coverage + if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.10' + uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: ./coverage.xml + fail_ci_if_error: false diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index b79d64b40..fd3d0164e 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -1,24 +1,27 @@ -name: flake8 Lint - -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] - -jobs: - flake8-lint: - runs-on: ubuntu-latest - name: Lint - steps: - - name: Check out source repository - uses: actions/checkout@v3 - - name: Set up Python environment - uses: actions/setup-python@v4 - with: - python-version: "3.8" - - name: flake8 Lint - uses: py-actions/flake8@v2 - with: - max-line-length: "100" - args: --ignore=E203,W503,F821 \ No newline at end of file +name: flake8 Lint + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + flake8-lint: + runs-on: ubuntu-latest + name: Lint + steps: + - name: Check out source repository + uses: actions/checkout@v4 + + - name: Set up Python environment + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: flake8 Lint + uses: py-actions/flake8@v2 + with: + path: ablkit + max-line-length: "100" + args: --ignore=E203,W503,F821 diff --git a/.gitignore b/.gitignore index ba96e9d3a..80cfd18f9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,13 +1,31 @@ *.pyc examples/**/*.png +examples/**/*.jpg *.pk +*.pkl *.pth *.json *.ckpt results raw/ ablkit.egg-info/ -examples/**/*.jpg .idea/ build/ -.history \ No newline at end of file +.history + +# Datasets and large binaries +*.zip +*.tar +*.tar.gz +*.tgz +*.gz +*.npz +*.npy +*.h5 +*.hdf5 +*.parquet +*.feather +*.arrow +*-ubyte +*.data +*.arff diff --git a/README.md b/README.md index c83ebb93f..5a431ba09 100644 --- a/README.md +++ b/README.md @@ -2,13 +2,13 @@ -[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/ablkit)](https://pypi.org/project/ablkit/) [![PyPI version](https://badgen.net/pypi/v/ablkit)](https://pypi.org/project/ablkit/) [![Documentation Status](https://readthedocs.org/projects/ablkit/badge/?version=latest)](https://ablkit.readthedocs.io/en/latest/?badge=latest) [![license](https://img.shields.io/github/license/mashape/apistatus.svg?maxAge=2592000)](https://github.com/AbductiveLearning/ABLkit/blob/main/LICENSE) [![flake8 Lint](https://github.com/AbductiveLearning/ABLkit/actions/workflows/lint.yaml/badge.svg)](https://github.com/AbductiveLearning/ABLkit/actions/workflows/lint.yaml) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![ABLkit-CI](https://github.com/AbductiveLearning/ABLkit/actions/workflows/build-and-test.yaml/badge.svg)](https://github.com/AbductiveLearning/ABLkit/actions/workflows/build-and-test.yaml) +[![license](https://img.shields.io/github/license/mashape/apistatus.svg?maxAge=2592000)](https://github.com/AbductiveLearning/ABLkit/blob/main/LICENSE) [![last commit](https://img.shields.io/github/last-commit/AbductiveLearning/ablkit)](https://img.shields.io/github/last-commit/AbductiveLearning/ablkit) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/ablkit)](https://pypi.org/project/ablkit/) [![PyPI version](https://badgen.net/pypi/v/ablkit)](https://pypi.org/project/ablkit/) [![Documentation Status](https://readthedocs.org/projects/ablkit/badge/?version=latest)](https://ablkit.readthedocs.io/en/latest/?badge=latest) [![ABLkit-CI](https://github.com/AbductiveLearning/ABLkit/actions/workflows/build-and-test.yaml/badge.svg)](https://github.com/AbductiveLearning/ABLkit/actions/workflows/build-and-test.yaml) [![flake8 Lint](https://github.com/AbductiveLearning/ABLkit/actions/workflows/lint.yaml/badge.svg)](https://github.com/AbductiveLearning/ABLkit/actions/workflows/lint.yaml) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![PyPI - Downloads](https://img.shields.io/pypi/dm/ablkit)](https://pypi.org/project/ablkit/) -[📘Documentation](https://ablkit.readthedocs.io/en/latest/index.html) | [📄Paper](https://journal.hep.com.cn/fcs/EN/10.1007/s11704-024-40085-7) | [📚Examples](https://github.com/AbductiveLearning/ABLkit/tree/main/examples) | [💬Reporting Issues](https://github.com/AbductiveLearning/ABLkit/issues/new) +[📘Documentation](https://ablkit.readthedocs.io/en/latest/index.html) | [📄Paper](https://journal.hep.com.cn/fcs/EN/10.1007/s11704-024-40085-7) | [🧪Examples](https://github.com/AbductiveLearning/ABLkit/tree/main/examples) | [💬Reporting Issues](https://github.com/AbductiveLearning/ABLkit/issues/new) -# ABLkit: A Toolkit for Abductive Learning +# 🧰 ABLkit: A Toolkit for Abductive Learning 📊📐 **ABLkit** is an efficient Python toolkit for [**Abductive Learning (ABL)**](https://www.lamda.nju.edu.cn/publication/chap_ABL.pdf). ABL is a novel paradigm that integrates machine learning and logical reasoning in a unified framework. It is suitable for tasks where both data and (logical) domain knowledge are available. @@ -28,7 +28,7 @@ ABLkit encapsulates advanced ABL techniques, providing users with an efficient a ABLkit

-## Installation +## 🛠️ Installation ### Install from PyPI @@ -60,7 +60,7 @@ sudo apt-get install swi-prolog For Windows and Mac users, please refer to the [SWI-Prolog Install Guide](https://github.com/yuce/pyswip/blob/master/INSTALL.md). -## Quick Start +## ⚡ Quick Start We use the MNIST Addition task as a quick start example. In this task, pairs of MNIST handwritten images and their sums are given, alongwith a domain knowledge base which contains information on how to perform addition operations. Our objective is to input a pair of handwritten images and accurately determine their sum. @@ -184,7 +184,7 @@ bridge.test(test_data) To explore detailed tutorials and information, please refer to: [Documentation on Read the Docs](https://ablkit.readthedocs.io/en/latest/index.html). -## Examples +## 🧪 Examples We provide several examples in `examples/`. Each example is stored in a separate folder containing a README file. @@ -192,8 +192,9 @@ We provide several examples in `examples/`. Each example is stored in a separate + [Handwritten Formula (HWF)](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/hwf) + [Handwritten Equation Decipherment](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/hed) + [Zoo](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/zoo) ++ [BDD-OIA](https://github.com/AbductiveLearning/ABLkit/tree/main/examples/bdd_oia) -## References +## 📚 References For more information about ABL, please refer to: [Zhou, 2019](http://scis.scichina.com/en/2019/076101.pdf) and [Zhou and Huang, 2022](https://www.lamda.nju.edu.cn/publication/chap_ABL.pdf). @@ -220,7 +221,7 @@ For more information about ABL, please refer to: [Zhou, 2019](http://scis.scichi } ``` -## Citation +## 📝 Citation To cite ABLkit, please cite the following paper: [Huang et al., 2024](https://journal.hep.com.cn/fcs/EN/10.1007/s11704-024-40085-7). @@ -234,4 +235,54 @@ To cite ABLkit, please cite the following paper: [Huang et al., 2024](https://j pages = {186354}, year = {2024} } -``` \ No newline at end of file +``` + +## ✨ Contributors + +We would like to thank the following contributors for their efforts on this project: (*: current maintainer) + + + + + + + + + +
+ + +
+ En-Hao Gao +
+
+ + +
+ Yu-Xuan Huang +
+
+ + +
+ Wen-Chao Hu +
* +
+ + +
+ Qi-Jie Li +
+
+ + +
+ Yang Hang +
* +
+ +We also thank the following users for their helpful suggestions and feedback: + +- [Hao-Yuan He](https://github.com/Hao-Yuan-He) +- [Lin-Han Jia](https://github.com/YGZWQZD) +- [Wang-Zhou Dai](https://github.com/haldai) \ No newline at end of file diff --git a/ablkit/bridge/__init__.py b/ablkit/bridge/__init__.py index 502a118cf..3d0f3f8e5 100644 --- a/ablkit/bridge/__init__.py +++ b/ablkit/bridge/__init__.py @@ -1,4 +1,6 @@ +from .a3bl_bridge import A3BLBridge from .base_bridge import BaseBridge from .simple_bridge import SimpleBridge +from .verification_bridge import VerificationBridge -__all__ = ["BaseBridge", "SimpleBridge"] +__all__ = ["BaseBridge", "SimpleBridge", "A3BLBridge", "VerificationBridge"] diff --git a/ablkit/bridge/a3bl_bridge.py b/ablkit/bridge/a3bl_bridge.py new file mode 100644 index 000000000..d6c53faa2 --- /dev/null +++ b/ablkit/bridge/a3bl_bridge.py @@ -0,0 +1,166 @@ +import os.path as osp +from typing import Any, List, Optional, Tuple, Union + +from ..data.evaluation.base_metric import BaseMetric +from ..data.structures.list_data import ListData +from ..learning import ABLModel +from ..reasoning import A3BLReasoner +from ..utils import print_log +from .simple_bridge import SimpleBridge + + +class A3BLBridge(SimpleBridge): + """ + An ambiguity-aware implementation for bridging machine learning and reasoning parts. + + Reference: https://github.com/Hao-Yuan-He/A3BL + + Involves the following five steps: + - Predict class probabilities and indices for the given data examples. + - Map indices into pseudo-labels. + - Enumerate all valid pseudo-labels. + - Revise pseudo-labels to label distribution based on the class probabilities. + - Train the model. + + Parameters + ---------- + model : ABLModel + The machine learning model wrapped in ``ABLModel``, used for prediction + and training. The wrapped base model should expose ``extract_features`` + so embeddings are available for the soft-label aggregation. + reasoner : A3BLReasoner + The reasoning part wrapped in ``A3BLReasoner``, used for pseudo-label + enumeration and soft-label aggregation. + metric_list : List[BaseMetric] + A list of metrics used for evaluating the model's performance. + """ + + def __init__( + self, + model: ABLModel, + reasoner: A3BLReasoner, + metric_list: List[BaseMetric], + ): + super().__init__(model, reasoner, metric_list) + + def abduce_soft_label(self, data_examples: ListData) -> List[List[Any]]: + """ + Revise predicted pseudo-labels to a soft label, given data examples using abduction. + + Parameters + ---------- + data_examples : ListData + Data examples containing predicted pseudo-labels. + + Returns + ------- + List[List[Any]] + A list of abduced soft labels for the given data examples. + """ + self.reasoner.batch_abduce(data_examples) + return data_examples.abduced_soft_label + + def train_data_iter( + self, + train_data, + val_data=None, + segment_size=1.0, + ): + data_examples = self.data_preprocess("train", train_data) + + if val_data is not None: + val_data_examples = self.data_preprocess("val", val_data) + else: + val_data_examples = data_examples + + if isinstance(segment_size, int): + if segment_size <= 0: + raise ValueError("segment_size should be positive.") + elif isinstance(segment_size, float): + if 0 < segment_size <= 1: + segment_size = int(segment_size * len(data_examples)) + else: + raise ValueError("segment_size should be in (0, 1].") + else: + raise ValueError("segment_size should be int or float.") + + for seg_idx in range((len(data_examples) - 1) // segment_size + 1): + sub = data_examples[seg_idx * segment_size: (seg_idx + 1) * segment_size] + yield sub, val_data_examples + + def train( + self, + train_data: Union[ + ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]] + ], + val_data: Optional[ + Union[ + ListData, + Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]], + ] + ] = None, + loops: int = 50, + segment_size: Union[int, float] = 1.0, + eval_interval: int = 1, + save_interval: Optional[int] = None, + save_dir: Optional[str] = None, + ): + """ + A typical training pipeline of Abuductive Learning. + + Parameters + ---------- + train_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]] + Training data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` + object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. + - ``X`` is a list of sublists representing the input data. + - ``gt_pseudo_label`` is only used to evaluate the performance of the ``ABLModel`` but + not to train. ``gt_pseudo_label`` can be ``None``. + - ``Y`` is a list representing the ground truth reasoning result for each sublist + in ``X``. + label_data : Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]], optional + Labeled data should be in the same format as ``train_data``. The only difference is + that the ``gt_pseudo_label`` in ``label_data`` should not be ``None`` and will be + utilized to train the model. Defaults to None. + val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional # noqa: E501 pylint: disable=line-too-long + Validation data should be in the same format as ``train_data``. Both ``gt_pseudo_label`` + and ``Y`` can be either None or not, which depends on the evaluation metircs in + ``self.metric_list``. If ``val_data`` is None, ``train_data`` will be used to validate + the model during training time. Defaults to None. + loops : int + Learning part and Reasoning part will be iteratively optimized + for ``loops`` times. Defaults to 50. + segment_size : Union[int, float] + Data will be split into segments of this size and data in each segment + will be used together to train the model. Defaults to 1.0. + eval_interval : int + The model will be evaluated every ``eval_interval`` loop during training, + Defaults to 1. + save_interval : int, optional + The model will be saved every ``eval_interval`` loop during training. + Defaults to None. + save_dir : str, optional + Directory to save the model. Defaults to None. + """ + for loop in range(loops): + iterator = self.train_data_iter(train_data, val_data, segment_size) + for train_examples_batch, val_examples_batch in iterator: + print_log( + f"loop(train) [{loop + 1}/{loops}] segment(train) ", logger="current" + ) + self.predict(train_examples_batch) + self.idx_to_pseudo_label(train_examples_batch) + self.abduce_pseudo_label(train_examples_batch) + self.filter_pseudo_label(train_examples_batch) + self.pseudo_label_to_idx(train_examples_batch) + self.model.train(train_examples_batch) + + if (loop + 1) % eval_interval == 0 or loop == loops - 1: + print_log(f"Eval start: loop(val) [{loop + 1}]", logger="current") + self._valid(val_examples_batch) + + if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1): + print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current") + self.model.save( + save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth") + ) diff --git a/ablkit/bridge/base_bridge.py b/ablkit/bridge/base_bridge.py index a3a40add2..8f5e55f35 100644 --- a/ablkit/bridge/base_bridge.py +++ b/ablkit/bridge/base_bridge.py @@ -37,9 +37,10 @@ class BaseBridge(metaclass=ABCMeta): def __init__(self, model: ABLModel, reasoner: Reasoner) -> None: if not isinstance(model, ABLModel): raise TypeError(f"Expected an instance of ABLModel, but received type: {type(model)}") - if not isinstance(reasoner, Reasoner): + if not (hasattr(reasoner, "idx_to_label") and hasattr(reasoner, "label_to_idx")): raise TypeError( - f"Expected an instance of Reasoner, but received type: {type(reasoner)}" + "Expected a reasoner exposing idx_to_label / label_to_idx (e.g. Reasoner " + f"or VerificationReasoner), but received type: {type(reasoner)}" ) self.model = model diff --git a/ablkit/bridge/simple_bridge.py b/ablkit/bridge/simple_bridge.py index 5c2cbfbb9..e48254714 100644 --- a/ablkit/bridge/simple_bridge.py +++ b/ablkit/bridge/simple_bridge.py @@ -94,6 +94,23 @@ def abduce_pseudo_label(self, data_examples: ListData) -> List[List[Any]]: self.reasoner.batch_abduce(data_examples) return data_examples.abduced_pseudo_label + def supervised_abduce_pseudo_label(self, data_examples: ListData) -> List[List[Any]]: + """ + Revise predicted pseudo-labels of the given data examples using ground truth. + + Parameters + ---------- + data_examples : ListData + Data examples containing predicted pseudo-labels. + + Returns + ------- + List[List[Any]] + A list of ground truth/abduced pseudo-labels for the given data examples. + """ + self.reasoner.batch_supervised_abduce(data_examples) + return data_examples.abduced_pseudo_label + def idx_to_pseudo_label(self, data_examples: ListData) -> List[List[Any]]: """ Map indices of data examples into pseudo-labels. @@ -211,10 +228,14 @@ def train( Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]] ] = None, val_data: Optional[ - Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] + Union[ + ListData, + Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]], + ] ] = None, loops: int = 50, segment_size: Union[int, float] = 1.0, + use_supervised_data: bool = False, eval_interval: int = 1, save_interval: Optional[int] = None, save_dir: Optional[str] = None, @@ -257,58 +278,95 @@ def train( Directory to save the model. Defaults to None. """ data_examples = self.data_preprocess("train", train_data) + label_data_examples = ( + self.data_preprocess("label", label_data) if label_data is not None else None + ) + val_data_examples = ( + self.data_preprocess("val", val_data) if val_data is not None else data_examples + ) - if label_data is not None: - label_data_examples = self.data_preprocess("label", label_data) - else: - label_data_examples = None - - if val_data is not None: - val_data_examples = self.data_preprocess("val", val_data) - else: - val_data_examples = data_examples - - if isinstance(segment_size, int): - if segment_size <= 0: - raise ValueError("segment_size should be positive.") - elif isinstance(segment_size, float): - if 0 < segment_size <= 1: - segment_size = int(segment_size * len(data_examples)) - else: - raise ValueError("segment_size should be in (0, 1].") - else: - raise ValueError("segment_size should be int or float.") + segment_size = self._resolve_segment_size(segment_size, len(data_examples)) + num_segments = (len(data_examples) - 1) // segment_size + 1 for loop in range(loops): - for seg_idx in range((len(data_examples) - 1) // segment_size + 1): + for seg_idx in range(num_segments): print_log( f"loop(train) [{loop + 1}/{loops}] segment(train) " - f"[{(seg_idx + 1)}/{(len(data_examples) - 1) // segment_size + 1}] ", + f"[{seg_idx + 1}/{num_segments}] ", logger="current", ) - sub_data_examples = data_examples[ seg_idx * segment_size : (seg_idx + 1) * segment_size ] - self.predict(sub_data_examples) - self.idx_to_pseudo_label(sub_data_examples) - self.abduce_pseudo_label(sub_data_examples) - self.filter_pseudo_label(sub_data_examples) - self.concat_data_examples(sub_data_examples, label_data_examples) - self.pseudo_label_to_idx(sub_data_examples) - self.model.train(sub_data_examples) - - if (loop + 1) % eval_interval == 0 or loop == loops - 1: - print_log(f"Eval start: loop(val) [{loop + 1}]", logger="current") - self._valid(val_data_examples) - - if save_interval is not None and ((loop + 1) % save_interval == 0 or loop == loops - 1): - print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current") - self.model.save( - save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth") + self._train_one_segment( + sub_data_examples, label_data_examples, use_supervised_data ) - def _valid(self, data_examples: ListData) -> None: + self._maybe_eval(val_data_examples, loop, loops, eval_interval) + self._maybe_save(loop, loops, save_interval, save_dir) + + @staticmethod + def _resolve_segment_size(segment_size: Union[int, float], dataset_len: int) -> int: + """Validate and convert ``segment_size`` into an absolute number of examples.""" + if isinstance(segment_size, int): + if segment_size <= 0: + raise ValueError("segment_size should be positive.") + return segment_size + if isinstance(segment_size, float): + if not (0 < segment_size <= 1): + raise ValueError("segment_size should be in (0, 1].") + return int(segment_size * dataset_len) + raise ValueError("segment_size should be int or float.") + + def _train_one_segment( + self, + sub_data_examples: ListData, + label_data_examples: Optional[ListData], + use_supervised_data: bool, + ) -> None: + """Run prediction, abduction, label-data concat, and a single model.train step.""" + self.predict(sub_data_examples) + self.idx_to_pseudo_label(sub_data_examples) + if use_supervised_data: + self.supervised_abduce_pseudo_label(sub_data_examples) + else: + self.abduce_pseudo_label(sub_data_examples) + self.filter_pseudo_label(sub_data_examples) + self.concat_data_examples(sub_data_examples, label_data_examples) + self.pseudo_label_to_idx(sub_data_examples) + if len(sub_data_examples) == 0: + return + self.model.train(sub_data_examples) + + def _maybe_eval( + self, + val_data_examples: ListData, + loop: int, + loops: int, + eval_interval: int, + ) -> None: + """Evaluate on ``val_data_examples`` at the configured interval (and on the last loop).""" + if (loop + 1) % eval_interval == 0 or loop == loops - 1: + print_log(f"Eval start: loop(val) [{loop + 1}]", logger="current") + self._valid(val_data_examples, prefix="val") + + def _maybe_save( + self, + loop: int, + loops: int, + save_interval: Optional[int], + save_dir: Optional[str], + ) -> None: + """Persist the model at the configured interval (and on the last loop).""" + if save_interval is None: + return + if (loop + 1) % save_interval == 0 or loop == loops - 1: + print_log(f"Saving model: loop(save) [{loop + 1}]", logger="current") + self.model.save( + save_path=osp.join(save_dir, f"model_checkpoint_loop_{loop + 1}.pth") + ) + + def _valid(self, data_examples: ListData, prefix: str = "val") -> None: """ Internal method for validating the model with given data examples. @@ -320,21 +378,31 @@ def _valid(self, data_examples: ListData) -> None: self.predict(data_examples) self.idx_to_pseudo_label(data_examples) + for metric in self.metric_list: + metric.prefix = prefix + for metric in self.metric_list: metric.process(data_examples) res = dict() for metric in self.metric_list: res.update(metric.evaluate()) + msg = "Evaluation ended, " for k, v in res.items(): - msg += k + f": {v:.3f} " + try: + v = float(v) + msg += k + f": {v:.3f} " + except (TypeError, ValueError): + pass + print_log(msg, logger="current") def valid( self, val_data: Union[ - ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]] + ListData, + Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]], ], ) -> None: """ @@ -349,12 +417,13 @@ def valid( ``self.metric_list``. """ val_data_examples = self.data_preprocess("val", val_data) - self._valid(val_data_examples) + self._valid(val_data_examples, prefix="val") def test( self, test_data: Union[ - ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]] + ListData, + Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]], ], ) -> None: """ @@ -370,4 +439,4 @@ def test( """ print_log("Test start:", logger="current") test_data_examples = self.data_preprocess("test", test_data) - self._valid(test_data_examples) + self._valid(test_data_examples, prefix="test") diff --git a/ablkit/bridge/verification_bridge.py b/ablkit/bridge/verification_bridge.py new file mode 100644 index 000000000..b87537d24 --- /dev/null +++ b/ablkit/bridge/verification_bridge.py @@ -0,0 +1,117 @@ +""" +Bridge for Verification Learning. + +:class:`VerificationBridge` replaces the single-candidate abduction step of +:class:`SimpleBridge` with a top-K enumeration provided by +:class:`~ablkit.reasoning.VerificationReasoner`. For each +segment the bridge trains the model once per top-K candidate, exposing +the model to every assignment that is consistent with the knowledge base. + +Reference: https://github.com/VerificationLearning/VerificationLearning +""" + +from typing import Any, List, Optional, Tuple, Union + +from ..data.evaluation import BaseMetric +from ..data.structures import ListData +from ..learning import ABLModel +from ..reasoning.reasoner import VerificationReasoner +from ..utils import print_log +from .simple_bridge import SimpleBridge + + +class VerificationBridge(SimpleBridge): + """ + Bridge implementing the Verification Learning training loop. + + Parameters + ---------- + model : ABLModel + Wrapped learning model. + reasoner : VerificationReasoner + Top-K reasoner. The bridge reads ``reasoner.top_k`` to decide how + many training passes to run per segment. + metric_list : List[BaseMetric] + Evaluation metrics, identical to :class:`SimpleBridge`. + """ + + def __init__( + self, + model: ABLModel, + reasoner: VerificationReasoner, + metric_list: List[BaseMetric], + ) -> None: + if not isinstance(reasoner, VerificationReasoner): + raise TypeError( + "VerificationBridge requires a VerificationReasoner; " + f"got {type(reasoner).__name__}." + ) + super().__init__(model, reasoner, metric_list) + + def train( + self, + train_data: Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]], + val_data: Optional[ + Union[ + ListData, + Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]], + ] + ] = None, + loops: int = 50, + segment_size: Union[int, float] = 1.0, + eval_interval: int = 1, + save_interval: Optional[int] = None, + save_dir: Optional[str] = None, + ) -> None: + """ + Verification Learning training loop. For each segment we predict + once, enumerate the top-K consistent candidates, then run a + ``model.train`` pass per candidate. + """ + data_examples = self.data_preprocess("train", train_data) + val_data_examples = ( + self.data_preprocess("val", val_data) if val_data is not None else data_examples + ) + + segment_size = self._resolve_segment_size(segment_size, len(data_examples)) + num_segments = (len(data_examples) - 1) // segment_size + 1 + + for loop in range(loops): + for seg_idx in range(num_segments): + print_log( + f"loop(train) [{loop + 1}/{loops}] segment(train) " + f"[{seg_idx + 1}/{num_segments}] ", + logger="current", + ) + sub_data_examples = data_examples[ + seg_idx * segment_size : (seg_idx + 1) * segment_size + ] + self._train_one_segment_verification(sub_data_examples) + + self._maybe_eval(val_data_examples, loop, loops, eval_interval) + self._maybe_save(loop, loops, save_interval, save_dir) + + def _train_one_segment_verification(self, sub_data_examples: ListData) -> None: + """ + Predict, enumerate top-K candidates, then train once per candidate. + Each example's k-th training pass uses its k-th candidate (or, if + the example yielded fewer than k candidates, its last available + candidate, repeated). + """ + self.predict(sub_data_examples) + self.idx_to_pseudo_label(sub_data_examples) + per_example_candidates = self.reasoner.batch_top_k(sub_data_examples) + + if not per_example_candidates: + return + + max_k = max(len(cands) for cands in per_example_candidates) + for k_idx in range(max_k): + sub_data_examples.abduced_pseudo_label = [ + cands[min(k_idx, len(cands) - 1)] for cands in per_example_candidates + ] + self.filter_pseudo_label(sub_data_examples) + self.pseudo_label_to_idx(sub_data_examples) + if len(sub_data_examples) == 0: + continue + self.model.train(sub_data_examples) diff --git a/ablkit/data/structures/base_data_element.py b/ablkit/data/structures/base_data_element.py index 2e9b00052..e9bd60c26 100644 --- a/ablkit/data/structures/base_data_element.py +++ b/ablkit/data/structures/base_data_element.py @@ -306,14 +306,12 @@ def keys(self) -> list: Returns: list: Contains all keys in data_fields. """ - # We assume that the name of the attribute related to property is - # '_' + the name of the property. We use this rule to filter out - # private keys. - # TODO: Use a more robust way to solve this problem + cls = type(self) private_keys = { - "_" + key - for key in self._data_fields - if isinstance(getattr(type(self), key, None), property) + name + for name in self._data_fields + if name.startswith("_") + and isinstance(getattr(cls, name[1:], None), property) } return list(self._data_fields - private_keys) diff --git a/ablkit/data/structures/list_data.py b/ablkit/data/structures/list_data.py index 256deda7f..6fea13ad1 100644 --- a/ablkit/data/structures/list_data.py +++ b/ablkit/data/structures/list_data.py @@ -194,7 +194,7 @@ def __getitem__(self, item: IndexType) -> "ListData": new_data[k] = None else: new_data[k] = v[item] - return new_data # type:ignore + return new_data # type: ignore def flatten(self, item: str) -> List: """ diff --git a/ablkit/learning/__init__.py b/ablkit/learning/__init__.py index ad016a654..bd2f2a7f7 100644 --- a/ablkit/learning/__init__.py +++ b/ablkit/learning/__init__.py @@ -1,5 +1,19 @@ -from .abl_model import ABLModel -from .basic_nn import BasicNN -from .torch_dataset import ClassificationDataset, PredictionDataset, RegressionDataset +from .abl_model import ABLModel, MultiLabelABLModel +from .basic_nn import BasicNN, MultiLabelBasicNN +from .torch_dataset import ( + ClassificationDataset, + MultiLabelClassificationDataset, + PredictionDataset, + RegressionDataset, +) -__all__ = ["ABLModel", "BasicNN", "ClassificationDataset", "PredictionDataset", "RegressionDataset"] +__all__ = [ + "ABLModel", + "BasicNN", + "MultiLabelABLModel", + "MultiLabelBasicNN", + "ClassificationDataset", + "MultiLabelClassificationDataset", + "PredictionDataset", + "RegressionDataset", +] diff --git a/ablkit/learning/abl_model.py b/ablkit/learning/abl_model.py index 7ceaa1510..db7f7cdee 100644 --- a/ablkit/learning/abl_model.py +++ b/ablkit/learning/abl_model.py @@ -8,6 +8,8 @@ import pickle from typing import Any, Dict +import numpy as np + from ..data.structures import ListData from ..utils import reform_list @@ -20,9 +22,10 @@ class ABLModel: ---------- base_model : Machine Learning Model The machine learning base model used for training and prediction. This model should - implement the ``fit`` and ``predict`` methods. It's recommended, but not required, for the - model to also implement the ``predict_proba`` method for generating - predictions on the probabilities. + implement the ``fit`` and ``predict`` methods. It's recommended, but not required, for + the model to also implement ``predict_proba`` (used to populate ``pred_prob``) and + ``extract_features`` (used to populate ``data_example.embeddings`` for distance + functions such as ``similarity``). """ def __init__(self, base_model: Any) -> None: @@ -31,7 +34,7 @@ def __init__(self, base_model: Any) -> None: self.base_model = base_model - def predict(self, data_examples: ListData) -> Dict: + def predict(self, data_examples: ListData) -> Dict[str, Any]: """ Predict the labels and probabilities for the given data. @@ -47,8 +50,14 @@ def predict(self, data_examples: ListData) -> Dict: """ model = self.base_model data_X = data_examples.flatten("X") + embeddings = None if hasattr(model, "predict_proba"): prob = model.predict_proba(X=data_X) + if hasattr(model, "extract_features"): + try: + embeddings = model.extract_features(X=data_X) + except AttributeError: + embeddings = None label = prob.argmax(axis=1) prob = reform_list(prob, data_examples.X) else: @@ -58,6 +67,8 @@ def predict(self, data_examples: ListData) -> Dict: data_examples.pred_idx = label data_examples.pred_prob = prob + if embeddings is not None: + data_examples.embeddings = reform_list(embeddings, data_examples.X) return {"label": label, "prob": prob} @@ -138,3 +149,57 @@ def load(self, *args, **kwargs) -> None: this method should match those expected by the ``load`` method of self.base_model. """ self._model_operation("load", *args, **kwargs) + + +# ============================================================================= +# Multi-label variants +# ============================================================================= + + +class MultiLabelABLModel(ABLModel): + """ + Multi-label variant of :class:`ABLModel`. + + The standard :class:`ABLModel.predict` selects a single class index per + instance via ``argmax`` over a softmax distribution. For multi-label + settings (each instance can have multiple active labels), this class + instead thresholds the per-label sigmoid probabilities at 0.5 and + stores the resulting binary indicator vectors on ``pred_idx``. + + Pair it with :class:`~ablkit.learning.MultiLabelBasicNN` (which + provides ``predict_proba`` returning ``(num_samples, num_labels)`` + sigmoid probabilities) for the typical multi-label workflow. + """ + + def predict(self, data_examples: ListData) -> Dict[str, Any]: + """ + Predict per-label binary indicators and per-label probabilities. + + Parameters + ---------- + data_examples : ListData + A batch of data to predict on. + + Returns + ------- + Dict[str, Any] + A dictionary with keys ``"label"`` (binary indicator vectors + grouped per example) and ``"prob"`` (per-label probabilities + grouped per example, or ``None`` if the base model does not + expose ``predict_proba``). + """ + model = self.base_model + data_X = data_examples.flatten("X") + if hasattr(model, "predict_proba"): + prob = model.predict_proba(X=data_X) + label = np.where(prob > 0.5, 1, 0).astype(int) + prob = reform_list(prob, data_examples.X) + else: + prob = None + label = model.predict(X=data_X) + label = reform_list(label, data_examples.X) + + data_examples.pred_idx = label + data_examples.pred_prob = prob + + return {"label": label, "prob": prob} diff --git a/ablkit/learning/basic_nn.py b/ablkit/learning/basic_nn.py index 5ad44f03d..e921ceabe 100644 --- a/ablkit/learning/basic_nn.py +++ b/ablkit/learning/basic_nn.py @@ -15,7 +15,11 @@ from torch.utils.data import DataLoader from ..utils.logger import print_log -from .torch_dataset import ClassificationDataset, PredictionDataset +from .torch_dataset import ( + ClassificationDataset, + MultiLabelClassificationDataset, + PredictionDataset, +) class BasicNN: @@ -367,6 +371,86 @@ def predict_proba( ) return self._predict(data_loader).softmax(axis=1).cpu().numpy() + def _extract_features(self, data_loader: DataLoader) -> torch.Tensor: + """ + Internal method to compute feature embeddings via ``self.model.extract_features`` + over every batch in ``data_loader``. + + Parameters + ---------- + data_loader : DataLoader + DataLoader providing input samples. + + Returns + ------- + torch.Tensor + Concatenated feature tensor across all batches. + """ + if not isinstance(data_loader, DataLoader): + raise TypeError( + "data_loader must be an instance of torch.utils.data.DataLoader, " + f"but got {type(data_loader)}" + ) + if not hasattr(self.model, "extract_features"): + raise AttributeError( + f"{type(self.model).__name__} does not implement extract_features(x). " + "Add such a method to your PyTorch model to enable feature extraction " + "(used by dist_func='similarity', among others)." + ) + model = self.model + device = self.device + model.eval() + with torch.no_grad(): + results = [] + for data in data_loader: + data = data.to(device) + results.append(model.extract_features(data)) + return torch.cat(results, dim=0) + + def extract_features( + self, + data_loader: Optional[DataLoader] = None, + X: Optional[List[Any]] = None, + ) -> numpy.ndarray: + """ + Compute feature embeddings for ``X`` (or a prebuilt ``data_loader``). + When both are provided, ``data_loader`` takes precedence. + + The wrapped PyTorch model must implement ``extract_features(x)`` + returning the embedding tensor (typically penultimate-layer + activations) used by downstream consumers such as + ``dist_func='similarity'``. + + Parameters + ---------- + data_loader : DataLoader, optional + DataLoader to use directly. Defaults to None. + X : List[Any], optional + Raw input list; converted to a ``PredictionDataset`` when used. + Defaults to None. + + Returns + ------- + numpy.ndarray + Feature embeddings of shape ``(num_samples, embedding_dim)``. + """ + if data_loader is not None and X is not None: + print_log( + "Extracting features from data_loader; ignoring X.", + logger="current", + level=logging.WARNING, + ) + if data_loader is None: + dataset = PredictionDataset(X, self.test_transform) + data_loader = DataLoader( + dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + pin_memory=torch.cuda.is_available(), + ) + return self._extract_features(data_loader).cpu().numpy() + def _score(self, data_loader: DataLoader) -> Tuple[float, float]: """ Internal method to compute loss and accuracy for the data provided through a DataLoader. @@ -455,7 +539,7 @@ def score( else: data_loader = self._data_loader(X, y) mean_loss, accuracy = self._score(data_loader) - print_log(f"mean loss: {mean_loss:.3f}, accuray: {accuracy:.3f}", logger="current") + print_log(f"mean loss: {mean_loss:.3f}, accuracy: {accuracy:.3f}", logger="current") return accuracy def _data_loader( @@ -528,12 +612,12 @@ def save(self, epoch_id: int = 0, save_path: Optional[str] = None) -> None: print_log(f"Checkpoints will be saved to {save_path}", logger="current") - save_parma_dic = { + save_param_dict = { "model": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), } - torch.save(save_parma_dic, save_path) + torch.save(save_param_dict, save_path) def load(self, load_path: str) -> None: """ @@ -557,3 +641,109 @@ def load(self, load_path: str) -> None: self.model.load_state_dict(param_dic["model"]) if "optimizer" in param_dic.keys(): self.optimizer.load_state_dict(param_dic["optimizer"]) + + +# ============================================================================= +# Multi-label variants +# ============================================================================= + + +class MultiLabelBasicNN(BasicNN): + """ + A multi-label variant of :class:`BasicNN`. + + The standard :class:`BasicNN` assumes a single-label, multi-class + classification setting (softmax output, argmax prediction). In + contrast, :class:`MultiLabelBasicNN` treats each output dimension as + an independent binary decision (sigmoid output, threshold-based binary + vector prediction) and uses + :class:`~ablkit.learning.MultiLabelClassificationDataset` so that + targets can be fed straight into losses like ``BCEWithLogitsLoss``. + + Apart from prediction and dataset handling, the class reuses the full + training and evaluation pipeline from :class:`BasicNN`. + """ + + def predict( + self, + data_loader: Optional[DataLoader] = None, + X: Optional[List[Any]] = None, + ) -> numpy.ndarray: + """ + Return a binary indicator vector for each sample by thresholding + the per-label sigmoid probabilities at 0.5. + """ + if data_loader is not None and X is not None: + print_log( + "Predict the class of input data in data_loader instead of X.", + logger="current", + level=logging.WARNING, + ) + + if data_loader is None: + dataset = PredictionDataset(X, self.test_transform) + data_loader = DataLoader( + dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + pin_memory=torch.cuda.is_available(), + ) + pred_probs = self._predict(data_loader).sigmoid() + pred = torch.where(pred_probs > 0.5, 1, 0).int() + return pred.cpu().numpy() + + def predict_proba( + self, + data_loader: Optional[DataLoader] = None, + X: Optional[List[Any]] = None, + ) -> numpy.ndarray: + """ + Return per-label sigmoid probabilities of shape + ``(num_samples, num_labels)``. + """ + if data_loader is not None and X is not None: + print_log( + "Predict the class probability of input data in data_loader instead of X.", + logger="current", + level=logging.WARNING, + ) + + if data_loader is None: + dataset = PredictionDataset(X, self.test_transform) + data_loader = DataLoader( + dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + pin_memory=torch.cuda.is_available(), + ) + return self._predict(data_loader).sigmoid().cpu().numpy() + + def _data_loader( + self, + X: Optional[List[Any]], + y: Optional[List[int]] = None, + shuffle: Optional[bool] = True, + ) -> DataLoader: + """ + Build a DataLoader backed by + :class:`~ablkit.learning.MultiLabelClassificationDataset`. + """ + if X is None: + raise ValueError("X should not be None.") + if y is None: + y = [0] * len(X) + if not len(y) == len(X): + raise ValueError("X and y should have equal length.") + + dataset = MultiLabelClassificationDataset(X, y, transform=self.train_transform) + data_loader = DataLoader( + dataset, + batch_size=self.batch_size, + shuffle=shuffle, + num_workers=self.num_workers, + collate_fn=self.collate_fn, + pin_memory=torch.cuda.is_available(), + ) + return data_loader diff --git a/ablkit/learning/torch_dataset/__init__.py b/ablkit/learning/torch_dataset/__init__.py index b8237a2d7..f1a0d21e8 100644 --- a/ablkit/learning/torch_dataset/__init__.py +++ b/ablkit/learning/torch_dataset/__init__.py @@ -1,9 +1,11 @@ from .classification_dataset import ClassificationDataset +from .multi_label_classification_dataset import MultiLabelClassificationDataset from .prediction_dataset import PredictionDataset from .regression_dataset import RegressionDataset __all__ = [ "ClassificationDataset", + "MultiLabelClassificationDataset", "PredictionDataset", "RegressionDataset", ] diff --git a/ablkit/learning/torch_dataset/multi_label_classification_dataset.py b/ablkit/learning/torch_dataset/multi_label_classification_dataset.py new file mode 100644 index 000000000..3e5c68dd9 --- /dev/null +++ b/ablkit/learning/torch_dataset/multi_label_classification_dataset.py @@ -0,0 +1,47 @@ +""" +Implementation of PyTorch dataset class used for multi-label classification. + +Copyright (c) 2024 LAMDA. All rights reserved. +""" + +from typing import Any, Callable, List, Optional + +import numpy as np +import torch + +from .classification_dataset import ClassificationDataset + + +class MultiLabelClassificationDataset(ClassificationDataset): + """ + Dataset used for multi-label classification, where each target ``Y[i]`` + is a binary indicator vector (one entry per label) rather than a single + class index. ``Y`` is stored as a ``float32`` tensor so it can be fed + directly into ``BCEWithLogitsLoss`` and similar losses. + + Parameters + ---------- + X : List[Any] + The input data. + Y : List[Any] + The per-sample label vectors. Each entry is converted via + ``numpy.stack`` and stored as a ``FloatTensor``. + transform : Callable[..., Any], optional + A function/transform that takes an object and returns a transformed + version. Defaults to None. + """ + + def __init__( + self, + X: List[Any], + Y: List[Any], + transform: Optional[Callable[..., Any]] = None, + ) -> None: + if (not isinstance(X, list)) or (not isinstance(Y, list)): + raise ValueError("X and Y should be of type list.") + if len(X) != len(Y): + raise ValueError("Length of X and Y must be equal.") + + self.X = X + self.Y = torch.FloatTensor(np.stack(Y, axis=0)) + self.transform = transform diff --git a/ablkit/reasoning/__init__.py b/ablkit/reasoning/__init__.py index 9d5a219fe..b9e329f77 100644 --- a/ablkit/reasoning/__init__.py +++ b/ablkit/reasoning/__init__.py @@ -1,4 +1,11 @@ from .kb import GroundKB, KBBase, PrologKB -from .reasoner import Reasoner +from .reasoner import A3BLReasoner, Reasoner, VerificationReasoner -__all__ = ["KBBase", "GroundKB", "PrologKB", "Reasoner"] +__all__ = [ + "KBBase", + "GroundKB", + "PrologKB", + "Reasoner", + "A3BLReasoner", + "VerificationReasoner", +] diff --git a/ablkit/reasoning/kb.py b/ablkit/reasoning/kb.py index 3202d8dd1..9fcf7b58b 100644 --- a/ablkit/reasoning/kb.py +++ b/ablkit/reasoning/kb.py @@ -487,11 +487,14 @@ def __init__(self, pseudo_label_list: List[Any], pl_file: str): try: import pyswip # pylint: disable=import-outside-toplevel except (IndexError, ImportError): - print( - "A Prolog-based knowledge base is in use. Please install SWI-Prolog using the" - + "command 'sudo apt-get install swi-prolog' for Linux users, or download it " - + "following the guide in https://github.com/yuce/pyswip/blob/master/INSTALL.md " - + "for Windows and Mac users." + print_log( + "A Prolog-based knowledge base is in use. Please install SWI-Prolog using " + "the command 'sudo apt-get install swi-prolog' for Linux users, or download " + "it following the guide in " + "https://github.com/yuce/pyswip/blob/master/INSTALL.md " + "for Windows and Mac users.", + logger="current", + level=logging.WARNING, ) self.prolog = pyswip.Prolog() diff --git a/ablkit/reasoning/reasoner.py b/ablkit/reasoning/reasoner.py index 2905007dc..2b5b934be 100644 --- a/ablkit/reasoning/reasoner.py +++ b/ablkit/reasoning/reasoner.py @@ -5,15 +5,22 @@ Copyright (c) 2024 LAMDA. All rights reserved. """ +import heapq import inspect -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, Iterator, List, Optional, Tuple, Union import numpy as np from zoopt import Dimension, Objective, Opt, Parameter, Solution from ..data.structures import ListData from ..reasoning import KBBase -from ..utils.utils import hamming_dist, confidence_dist, avg_confidence_dist +from ..utils.utils import ( + avg_confidence_dist, + confidence_dist, + hamming_dist, + rejection_dist, + similarity_dist, +) class Reasoner: @@ -30,12 +37,17 @@ class Reasoner: measure, wherein the candidate with lowest cost is selected as the final abduced label. It can be either a string representing a predefined distance function or a callable function. The available predefined distance functions: - 'hamming' | 'confidence' | 'avg_confidence'. 'hamming' directly calculates the - Hamming distance between the predicted pseudo-label in the data example and each - candidate. 'confidence' and 'avg_confidence' calculates the confidence distance - between the predicted probabilities in the data example and each candidate, where - the confidence distance is defined as 1 - the product of prediction probabilities - in 'confidence' and 1 - the average of prediction probabilities in 'avg_confidence'. + 'hamming' | 'confidence' | 'avg_confidence' | 'similarity' | 'rejection'. + 'hamming' directly calculates the Hamming distance between the predicted + pseudo-label in the data example and each candidate. 'confidence' and + 'avg_confidence' calculate the confidence distance between the predicted + probabilities and each candidate, defined as ``1 - product`` and + ``1 - average`` of the candidate's per-symbol probabilities respectively. + 'similarity' compares candidates against the geometry of the model's + embeddings (requires the base model to expose ``extract_features``; + ``ABLModel`` then stores the result on ``data_example.embeddings``). + 'rejection' combines confidence distance with a candidate-complexity penalty, + favoring shorter candidates when scores are close. Alternatively, the callable function should have the signature ``dist_func(data_example, candidates, candidate_idxs, reasoning_results)`` and must return a cost list. Each element in this cost list should be a numerical value @@ -83,10 +95,11 @@ def __init__( def _check_valid_dist(self, dist_func): if isinstance(dist_func, str): - if dist_func not in ["hamming", "confidence", "avg_confidence"]: + valid = ["hamming", "confidence", "avg_confidence", "similarity", "rejection"] + if dist_func not in valid: raise NotImplementedError( - 'Valid options for predefined dist_func include "hamming", ' - + f'"confidence" and "avg_confidence", but got {dist_func}.' + f"Valid options for predefined dist_func are {valid}, " + f"but got {dist_func!r}." ) return elif callable(dist_func): @@ -179,9 +192,22 @@ def _get_cost_list( elif self.dist_func == "avg_confidence": candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] return avg_confidence_dist(data_example.pred_prob, candidates_idxs) + elif self.dist_func == "similarity": + embeddings = getattr(data_example, "embeddings", None) + if embeddings is None: + raise ValueError( + "dist_func='similarity' requires the base model to expose an " + "extract_features(X=...) method so ABLModel can populate " + "data_example.embeddings." + ) + candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] + return similarity_dist(embeddings, candidates_idxs=candidates_idxs) + elif self.dist_func == "rejection": + candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] + return rejection_dist(data_example.pred_prob, candidates_idxs=candidates_idxs) else: - candidate_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] - cost_list = self.dist_func(data_example, candidates, candidate_idxs, reasoning_results) + candidates_idxs = [[self.label_to_idx[x] for x in c] for c in candidates] + cost_list = self.dist_func(data_example, candidates, candidates_idxs, reasoning_results) if len(cost_list) != len(candidates): raise ValueError( "The length of the array returned by dist_func must be equal to the number " @@ -354,5 +380,444 @@ def batch_abduce(self, data_examples: ListData) -> List[List[Any]]: data_examples.abduced_pseudo_label = abduced_pseudo_label return abduced_pseudo_label + def batch_supervised_abduce(self, data_examples: ListData) -> List[List[Any]]: + """ + Perform abductive reasoning on the given prediction data examples, using supervised data + when gt_pseudo_label is given. + """ + abduced_pseudo_label = [ + ( + data_example.gt_pseudo_label + if data_example.gt_pseudo_label + else self.abduce(data_example) + ) + for data_example in data_examples + ] + data_examples.abduced_pseudo_label = abduced_pseudo_label + return abduced_pseudo_label + + def __call__(self, data_examples: ListData) -> List[List[Any]]: + return self.batch_abduce(data_examples) + + +# ============================================================================= +# A3BL: Ambiguity-Aware Abductive Learning +# +# Reference: https://github.com/Hao-Yuan-He/A3BL +# ============================================================================= + + +class A3BLReasoner(Reasoner): + """ + Reasoner for minimizing the inconsistency between the knowledge base and learning models. + + Parameters + ---------- + kb : class KBBase + The knowledge base to be used for reasoning. + dist_func : Union[str, Callable], optional + The distance function used to determine the cost list between each + candidate and the given prediction. The cost is also referred to as a consistency + measure, wherein the candidate with the lowest cost is selected as the final + abduced label. It can be either a string representing a predefined distance + function or a callable function. The available predefined distance functions: + 'hamming' | 'confidence' | 'avg_confidence' | 'similarity' | 'rejection'. + See :class:`Reasoner` for the full description of each option. + Defaults to 'confidence'. + idx_to_label : dict, optional + A mapping from index in the base model to label. If not provided, a default + order-based index to label mapping is created. Defaults to None. + max_revision : Union[int, float], optional + The upper limit on the number of revisions for each data example when + performing abductive reasoning. If float, denotes the fraction of the total + length that can be revised. A value of -1 implies no restriction on the + number of revisions. Defaults to -1. + require_more_revision : int, optional + Specifies additional number of revisions permitted beyond the minimum required + when performing abductive reasoning. Defaults to 0. + use_zoopt : bool, optional + Whether to use ZOOpt library during abductive reasoning. Defaults to False. + topK : int, optional + Number of top-ranked candidates to keep when forming the soft label. ``-1`` + keeps all candidates. Defaults to 16. + temperature : float, optional + Softmax temperature used when aggregating candidate probabilities into a + soft label. Lower values produce sharper distributions. Defaults to 0.2. + multi_label : bool, optional + Whether the underlying task is multi-label (each symbol is a binary vector + rather than a single class index). Defaults to False. + """ + + def __init__( + self, + kb, + dist_func="confidence", + idx_to_label=None, + max_revision: Union[int, float] = -1, + require_more_revision: int = 0, + use_zoopt: bool = False, + topK: int = 16, + temperature: float = 0.2, + multi_label: bool = False, + ): + super().__init__( + kb, dist_func, idx_to_label, max_revision, require_more_revision, use_zoopt + ) + import torch + + self.topK = topK + self.temperature = temperature + self.class_num = len(self.kb.pseudo_label_list) + self.multi_label = multi_label + self.device = "cuda" if torch.cuda.is_available() else "cpu" + + def _confidence_dist( + self, pred_probs: np.ndarray, candidate_idxs: List[List[Any]], temp: float = 1.0 + ) -> np.ndarray: + from scipy.special import softmax + + candidates_array = np.array(candidate_idxs) + _, symbol_num = candidates_array.shape + row_indices = np.arange(symbol_num)[:, np.newaxis] + selected_probs = pred_probs[row_indices, candidates_array.T] + candidate_probs = np.sum(selected_probs, axis=0) / temp + return softmax(candidate_probs) + + def _confidence_dist_multi_label( + self, pred_probs: np.ndarray, candidate_idxs: List[List[Any]], temp: float = 1.0 + ) -> np.ndarray: + from scipy.special import softmax + + candidate_probs = pred_probs @ np.array(candidate_idxs).T / temp + return softmax(candidate_probs.squeeze(axis=0)) + + def _candidates_idxs(self, candidates: List[List[Any]]): + return [[self.label_to_idx[x] for x in c] for c in candidates] + + def _topk( + self, candidates: List[Any], candidate_probs: np.ndarray, K: int = -1 + ) -> Tuple[List[List[Any]], List[Any]]: + """ + Performs a top-k selection from the candidate_set based on candidate_probs. + If `K` is set to -1, all candidates are chosen. + Returns a tuple containing the selected candidates and their corresponding probabilities. + """ + import heapq + + if K == -1 or len(candidates) <= K: + return candidates, candidate_probs + + # Iterate over all candidates and maintain a heap of size K with the largest probabilities + heap = [] + for i, (candidate, prob) in enumerate(zip(candidates, candidate_probs)): + if i < K: + heapq.heappush(heap, (prob, candidate)) + else: + if prob > heap[0][0]: + heapq.heappop(heap) + heapq.heappush(heap, (prob, candidate)) + + # Extract top-k elements from the heap, + # and reverse them to get the highest probabilities first + topk_probs, topk_candidates = zip(*heap) + return list(topk_candidates), list(topk_probs) + + def multi_label_aggregate(self, candidates: List[List[int]], candidate_probs: List[float]): + """ + An multi-label version of A3BL. + """ + import torch + + with torch.no_grad(): + symbol_num = len(candidates[0]) + aggregate_label = torch.zeros(size=(symbol_num, 1)) + for candidate, prob in zip(candidates, candidate_probs): + for i, item in enumerate(candidate): + if item == 1: + aggregate_label[i] += prob + return list(aggregate_label.unbind(1)) + + def aggregate(self, candidates: List[List[int]], candidate_probs: List[float]): + import torch + import torch.nn.functional as F + + with torch.no_grad(): + candidates_tensor = torch.tensor(candidates, device=self.device, dtype=torch.long) + probs_tensor = torch.tensor(candidate_probs, device=self.device, dtype=torch.float32) + one_hot = F.one_hot(candidates_tensor, num_classes=self.class_num).float() # [N, M, C] + weighted_one_hot = one_hot * probs_tensor.unsqueeze(-1).unsqueeze(-1) # [N, M, C] + aggregate_label = weighted_one_hot.sum(dim=0) # [M, C] + return [tensor.cpu() for tensor in aggregate_label.unbind(0)] + + def abduce(self, data_example: ListData) -> Tuple[List[Any], List[Any]]: + """ + Perform abduction and get a soft label distribution aggregated from + all valid candidates that satisfy the underlying rules. + + Parameters + ---------- + data_example : ListData + Data example. + + Returns + ------- + soft_label : List[Any] + Soft label aggregated from the top-k valid candidates. + pseudo_label : List[Any] + Hard pseudo-label revision (the top-1 candidate) that is + consistent with the knowledge base. + """ + max_revision_num = data_example.elements_num("pred_pseudo_label") + max_revision_num = self._get_max_revision_num(self.max_revision, max_revision_num) + candidates, _ = self.kb.abduce_candidates( + pseudo_label=data_example.pred_pseudo_label, + y=data_example.Y, + x=data_example.X, + max_revision_num=max_revision_num, + require_more_revision=self.require_more_revision, + ) + + if len(candidates) == 0: + return [], [] + + confidence_dist_cal = ( + self._confidence_dist if not self.multi_label else self._confidence_dist_multi_label + ) + + candidate_probs = confidence_dist_cal( + data_example.pred_prob, self._candidates_idxs(candidates), self.temperature + ) + topk_candidates, topk_candidates_probs = self._topk(candidates, candidate_probs, self.topK) + aggregated_labels = ( + self.aggregate(topk_candidates, topk_candidates_probs) + if not self.multi_label + else self.multi_label_aggregate(topk_candidates, topk_candidates_probs) + ) + return aggregated_labels, topk_candidates[0] + + def batch_abduce(self, data_examples: ListData) -> List[List[Any]]: + """ + Perform abductive reasoning on the given prediction data examples. + For detailed information, refer to ``abduce``. + """ + abduced_soft_label, abduced_pseudo_label = zip( + *[self.abduce(data_example) for data_example in data_examples] + ) + data_examples.abduced_soft_label = abduced_soft_label + data_examples.abduced_pseudo_label = abduced_pseudo_label + return abduced_soft_label + def __call__(self, data_examples: ListData) -> List[List[Any]]: return self.batch_abduce(data_examples) + + +# ============================================================================= +# Verification Learning +# +# Walks the per-symbol probability lattice in descending joint-probability +# order, collecting the first top_k assignments that satisfy the knowledge +# base. Reference: https://github.com/VerificationLearning/VerificationLearning +# ============================================================================= + + +def enumerate_label_assignments( + pred_prob: np.ndarray, max_iter: int = 10000 +) -> Iterator[Tuple[List[int], float, List[float]]]: + """ + Yield label-index assignments for a single data example in descending + joint-probability order. The walk is a Lawler-style best-first search: + each state is the tuple of per-symbol rank indices, and successors are + generated by advancing any one symbol to its next-best class. + + Parameters + ---------- + pred_prob : np.ndarray + Per-symbol probability matrix with shape ``(num_symbols, num_classes)``. + max_iter : int, optional + Hard cap on the number of yields. Defaults to 10000. + + Yields + ------ + labels : List[int] + Class indices for each symbol. + joint_prob : float + Product of the chosen per-symbol probabilities. + per_symbol_probs : List[float] + The chosen probability for each symbol. + """ + pred_prob = np.asarray(pred_prob, dtype=float) + num_symbols, num_classes = pred_prob.shape + if num_symbols == 0: + return + + sorted_indices = np.argsort(-pred_prob, axis=1) + sorted_probs = np.take_along_axis(pred_prob, sorted_indices, axis=1) + + initial_state = (0,) * num_symbols + initial_prob = float(np.prod(sorted_probs[:, 0])) + + seen = {initial_state} + heap: List[Tuple[float, Tuple[int, ...]]] = [(-initial_prob, initial_state)] + yields = 0 + + while heap and yields < max_iter: + neg_prob, state = heapq.heappop(heap) + joint_prob = -neg_prob + labels = [int(sorted_indices[i, state[i]]) for i in range(num_symbols)] + per_symbol_probs = [float(sorted_probs[i, state[i]]) for i in range(num_symbols)] + yield labels, joint_prob, per_symbol_probs + yields += 1 + + for sym in range(num_symbols): + next_rank = state[sym] + 1 + if next_rank >= num_classes: + continue + new_state = state[:sym] + (next_rank,) + state[sym + 1:] + if new_state in seen: + continue + seen.add(new_state) + current_p = sorted_probs[sym, state[sym]] + next_p = sorted_probs[sym, next_rank] + if current_p <= 0: + new_joint = 0.0 + else: + new_joint = joint_prob * (next_p / current_p) + heapq.heappush(heap, (-new_joint, new_state)) + + +def top_k_satisfying( + pred_prob: np.ndarray, + predicate: Callable[[List[Any]], bool], + top_k: int = 1, + max_iter: int = 10000, + idx_to_label: Optional[dict] = None, +) -> Tuple[List[List[Any]], List[float]]: + """ + Walk label assignments in descending joint-probability order and return + the first ``top_k`` that satisfy ``predicate``. If none is found within + ``max_iter`` iterations the single highest-probability assignment is + returned as a fallback so callers always receive a usable label. + + Parameters + ---------- + pred_prob : np.ndarray + Per-symbol probability matrix with shape ``(num_symbols, num_classes)``. + predicate : Callable[[List[Any]], bool] + Function called on each candidate label sequence; truthy means the + candidate is consistent with the knowledge base. + top_k : int, optional + Maximum number of satisfying candidates to return. Defaults to 1. + max_iter : int, optional + Hard cap on enumeration steps. Defaults to 10000. + idx_to_label : dict, optional + Optional mapping from class index to pseudo-label. When omitted, the + raw class indices are returned. + + Returns + ------- + candidates : List[List[Any]] + Label assignments that satisfy ``predicate`` (or the fallback). + probs : List[float] + Joint probability of each returned candidate. + """ + matches: List[List[Any]] = [] + probs: List[float] = [] + fallback: Optional[Tuple[List[Any], float]] = None + + for labels_idx, joint_prob, _ in enumerate_label_assignments(pred_prob, max_iter): + labels = ( + [idx_to_label[i] for i in labels_idx] if idx_to_label is not None else labels_idx + ) + if fallback is None: + fallback = (labels, joint_prob) + if predicate(labels): + matches.append(labels) + probs.append(joint_prob) + if len(matches) >= top_k: + break + + if not matches and fallback is not None: + matches.append(fallback[0]) + probs.append(fallback[1]) + return matches, probs + + +class VerificationReasoner: + """ + Reasoner used by :class:`~ablkit.bridge.VerificationBridge`. Rather than + picking a single best candidate via a distance function, it enumerates + the top ``top_k`` label assignments that satisfy the knowledge base, + ordered by joint probability. The bridge then trains the model on each + of those candidates. + + Parameters + ---------- + kb : KBBase + The knowledge base used to verify candidates. ``kb.logic_forward`` + must return the reasoning result so it can be compared with each + data example's ``Y``. + top_k : int, optional + Number of satisfying candidates to enumerate per example. + Defaults to 1. + max_iter : int, optional + Maximum number of enumeration steps per example before giving up + and returning the fallback. Defaults to 10000. + idx_to_label : dict, optional + A mapping from base-model index to pseudo-label. If omitted a + default order-based mapping is built from ``kb.pseudo_label_list``. + """ + + def __init__( + self, + kb: KBBase, + top_k: int = 1, + max_iter: int = 10000, + idx_to_label: Optional[dict] = None, + ) -> None: + if top_k < 1: + raise ValueError("top_k must be >= 1.") + if max_iter < 1: + raise ValueError("max_iter must be >= 1.") + self.kb = kb + self.top_k = top_k + self.max_iter = max_iter + if idx_to_label is None: + idx_to_label = dict(enumerate(kb.pseudo_label_list)) + self.idx_to_label = idx_to_label + self.label_to_idx = {label: idx for idx, label in idx_to_label.items()} + + def top_k_candidates( + self, pred_prob: np.ndarray, y: Any + ) -> Tuple[List[List[Any]], List[float]]: + """ + Return up to ``top_k`` label assignments for one data example whose + ``kb.logic_forward`` matches ``y``. + """ + + def predicate(labels: List[Any]) -> bool: + return self.kb.logic_forward(labels) == y + + return top_k_satisfying( + pred_prob, + predicate, + top_k=self.top_k, + max_iter=self.max_iter, + idx_to_label=self.idx_to_label, + ) + + def batch_top_k(self, data_examples) -> List[List[List[Any]]]: + """ + Run :meth:`top_k_candidates` on every example in ``data_examples``. + Stores the result on ``data_examples.top_k_candidates`` and + ``data_examples.top_k_probs``. Returns the list of per-example + candidate lists. + """ + all_candidates: List[List[List[Any]]] = [] + all_probs: List[List[float]] = [] + for data_example in data_examples: + cands, probs = self.top_k_candidates(data_example.pred_prob, data_example.Y) + all_candidates.append(cands) + all_probs.append(probs) + data_examples.top_k_candidates = all_candidates + data_examples.top_k_probs = all_probs + return all_candidates diff --git a/ablkit/utils/__init__.py b/ablkit/utils/__init__.py index 7962745cf..c1214d290 100644 --- a/ablkit/utils/__init__.py +++ b/ablkit/utils/__init__.py @@ -1,13 +1,15 @@ from .cache import Cache, abl_cache from .logger import ABLLogger, print_log from .utils import ( - confidence_dist, avg_confidence_dist, + confidence_dist, flatten, hamming_dist, reform_list, - to_hashable, + rejection_dist, + similarity_dist, tab_data_to_tuple, + to_hashable, ) __all__ = [ @@ -19,6 +21,8 @@ "flatten", "hamming_dist", "reform_list", + "rejection_dist", + "similarity_dist", "to_hashable", "abl_cache", "tab_data_to_tuple", diff --git a/ablkit/utils/utils.py b/ablkit/utils/utils.py index abc1e3561..02c60a4ab 100644 --- a/ablkit/utils/utils.py +++ b/ablkit/utils/utils.py @@ -138,6 +138,79 @@ def avg_confidence_dist(pred_prob: np.ndarray, candidates_idxs: List[List[Any]]) return 1 - np.average(pred_prob[cols, candidates_idxs], axis=1) +def similarity_dist( + pred_embeddings: np.ndarray, candidates_idxs: List[List[Any]] +) -> np.ndarray: + """ + Compute a similarity-based cost for each candidate label assignment. + + For each candidate, the cost is the average cosine similarity between + symbol pairs assigned different labels minus the average between pairs + assigned the same label. Lower values mean the candidate's labeling is + more consistent with the embedding geometry. + + Parameters + ---------- + pred_embeddings : np.ndarray + Embedding matrix for the symbols in a single data example, with shape + ``(num_symbols, embedding_dim)``. + candidates_idxs : List[List[Any]] + Candidate label assignments, each of length ``num_symbols``. + + Returns + ------- + np.ndarray + Cost for each candidate. + """ + phi = np.asarray(pred_embeddings, dtype=float) + norms = np.linalg.norm(phi, axis=1, keepdims=True) + phi_normalized = phi / (norms + 1e-8) + sim = phi_normalized @ phi_normalized.T + num_symbols = sim.shape[0] + triu_i, triu_j = np.triu_indices(num_symbols, k=1) + pair_sims = sim[triu_i, triu_j] + + costs = [] + for cand in candidates_idxs: + labels = np.asarray(cand) + same = labels[triu_i] == labels[triu_j] + intra = pair_sims[same].mean() if same.any() else 0.0 + inter = pair_sims[~same].mean() if (~same).any() else 0.0 + costs.append(inter - intra) + return np.asarray(costs) + + +def rejection_dist( + pred_prob: np.ndarray, candidates_idxs: List[List[Any]], alpha: float = 0.5 +) -> np.ndarray: + """ + Compute a rejection-aware cost that combines model confidence with + candidate complexity. Each candidate's cost is a convex combination of + the standard confidence distance and a normalized length term, so + longer (more complex) candidates are penalized. + + Parameters + ---------- + pred_prob : np.ndarray + Prediction probability distributions for the symbols in a single + data example. + candidates_idxs : List[List[Any]] + Candidate label assignments. + alpha : float, optional + Weight in ``[0, 1]`` for the complexity term. Defaults to 0.5. + + Returns + ------- + np.ndarray + Cost for each candidate. + """ + conf = confidence_dist(pred_prob, candidates_idxs) + complexity = np.array([len(c) for c in candidates_idxs], dtype=float) + if complexity.max() > 0: + complexity = complexity / complexity.max() + return (1 - alpha) * conf + alpha * complexity + + def to_hashable(x: Union[List[Any], Any]) -> Union[Tuple[Any, ...], Any]: """ Convert a nested list to a nested tuple so it is hashable. diff --git a/docs/API/ablkit.bridge.rst b/docs/API/ablkit.bridge.rst index 3cfc12702..7140bc24d 100644 --- a/docs/API/ablkit.bridge.rst +++ b/docs/API/ablkit.bridge.rst @@ -1,7 +1,7 @@ ablkit.bridge -================== +============= .. automodule:: ablkit.bridge :members: :undoc-members: - :show-inheritance: \ No newline at end of file + :show-inheritance: diff --git a/docs/API/ablkit.data.rst b/docs/API/ablkit.data.rst index 32bc5d26d..d935b04f3 100644 --- a/docs/API/ablkit.data.rst +++ b/docs/API/ablkit.data.rst @@ -1,5 +1,5 @@ ablkit.data -=================== +=========== ``structures`` -------------- diff --git a/docs/API/ablkit.learning.rst b/docs/API/ablkit.learning.rst index f26a866e2..0b0c2683f 100644 --- a/docs/API/ablkit.learning.rst +++ b/docs/API/ablkit.learning.rst @@ -1,5 +1,5 @@ ablkit.learning -================== +=============== .. autoclass:: ablkit.learning.ABLModel :members: diff --git a/docs/API/ablkit.reasoning.rst b/docs/API/ablkit.reasoning.rst index 8707fc308..9de28556c 100644 --- a/docs/API/ablkit.reasoning.rst +++ b/docs/API/ablkit.reasoning.rst @@ -1,7 +1,36 @@ ablkit.reasoning -================== +================ -.. automodule:: ablkit.reasoning +.. autoclass:: ablkit.reasoning.KBBase :members: :undoc-members: - :show-inheritance: \ No newline at end of file + :show-inheritance: + +.. autoclass:: ablkit.reasoning.GroundKB + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: ablkit.reasoning.PrologKB + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: ablkit.reasoning.Reasoner + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: ablkit.reasoning.A3BLReasoner + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: ablkit.reasoning.VerificationReasoner + :members: + :undoc-members: + :show-inheritance: + +.. autofunction:: ablkit.reasoning.reasoner.enumerate_label_assignments + +.. autofunction:: ablkit.reasoning.reasoner.top_k_satisfying diff --git a/docs/API/ablkit.utils.rst b/docs/API/ablkit.utils.rst index b963a8963..604024ce0 100644 --- a/docs/API/ablkit.utils.rst +++ b/docs/API/ablkit.utils.rst @@ -1,7 +1,20 @@ ablkit.utils -================== +============ .. automodule:: ablkit.utils :members: + ABLLogger, + print_log, + Cache, + abl_cache, + hamming_dist, + confidence_dist, + avg_confidence_dist, + similarity_dist, + rejection_dist, + flatten, + reform_list, + to_hashable, + tab_data_to_tuple :undoc-members: - :show-inheritance: \ No newline at end of file + :show-inheritance: diff --git a/docs/Examples/BDD-OIA.rst b/docs/Examples/BDD-OIA.rst new file mode 100644 index 000000000..df75278ae --- /dev/null +++ b/docs/Examples/BDD-OIA.rst @@ -0,0 +1,216 @@ +BDD-OIA +======= + +.. raw:: html + +

For detailed code implementation, please view it on GitHub.

+ +Below shows an implementation of `BDD-OIA `__. +The BDD-OIA dataset comprises frames extracted from driving scene videos +that are used for autonomous driving predictions. Each frame is +annotated with 4 binary action labels (:math:`\textsf{move_forward}`, +:math:`\textsf{stop}`, :math:`\textsf{turn_left}`, :math:`\textsf{turn_right}`), +as well as 21 intermediate binary concept labels such as +:math:`\textsf{red_light}` and :math:`\textsf{road_clear}` that explain those +actions. + +The objective is to predict the possible actions for each frame. +During training we use only the action-level supervision together with +a knowledge base that captures the relations between concepts and +actions, e.g., +:math:`\textsf{red_light} \lor \textsf{traffic_sign} \lor \textsf{obstacle} \implies \textsf{stop}`. +The training set contains 16,000 frames; the test set contains 4,500. + +Intuitively, the learning part predicts the 21 binary concept +pseudo-labels from each frame, and the reasoning part uses the +knowledge base to derive the four action labels from those concepts. +When the learning part's predictions conflict with the ground-truth +actions, the reasoner revises the concepts via abductive reasoning, +and those revised concepts are used to further train the learning +part. + +The dataset was preprocessed by `Marconato et al. (2023) `__ +with a pretrained Faster-RCNN on BDD-100k together with the first +module of CBM-AUC `(Sawada & Nakamura, 2022) `__, +yielding a 2048-dimensional visual feature for each frame. + +.. code:: python + + # Import necessary libraries and modules + import os.path as osp + + import numpy as np + import torch + import torch.nn as nn + from torch import optim + + from ablkit.data.evaluation import SymbolAccuracy + from ablkit.learning import MultiLabelABLModel, MultiLabelBasicNN + from ablkit.reasoning import KBBase, Reasoner + from ablkit.utils import ABLLogger, print_log + + from bridge import BDDBridge + from dataset.data_util import get_dataset + from metric import BDDReasoningMetric + from models.nn import ConceptNet + +Working with Data +----------------- + +First, we load the training, validation, and testing splits: + +.. code:: python + + train_data = get_dataset(fname="train.npz", get_pseudo_label=True) + val_data = get_dataset(fname="val.npz", get_pseudo_label=True) + test_data = get_dataset(fname="test.npz", get_pseudo_label=True) + +Each split consists of three components (``X``, ``gt_pseudo_label``, +and ``Y``) with one entry per frame: + +- ``X[i]`` is a list with a single ndarray of shape ``(2048,)``, the + pre-extracted visual feature for the frame. +- ``gt_pseudo_label[i]`` is a list of length 21 holding the binary + concept annotations (``red_light``, ``road_clear``, …). +- ``Y[i]`` is a tuple of length 4 holding the binary action labels + (``move_forward``, ``stop``, ``turn_left``, ``turn_right``). + +During training only ``X`` and ``Y`` are used; ``gt_pseudo_label`` is +held back for evaluation. + +Building the Learning Part +-------------------------- + +To build the learning part we first construct a PyTorch model, +``ConceptNet``, then wrap it in +:class:`~ablkit.learning.MultiLabelBasicNN` to obtain an sklearn-style +base model. ``MultiLabelBasicNN`` is a multi-label variant of +``BasicNN``: the output uses sigmoid activations rather than softmax, +predictions are binary vectors rather than single class indices, and +the dataset is a +:class:`~ablkit.learning.MultiLabelClassificationDataset`. The 21 +outputs therefore correspond to the 21 binary concept labels. + +.. code:: python + + net = ConceptNet() + loss_fn = nn.BCEWithLogitsLoss() + optimizer = optim.Adam(net.parameters(), lr=0.002) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + scheduler = optim.lr_scheduler.OneCycleLR( + optimizer, + max_lr=0.002, + pct_start=0.15, + epochs=2, + steps_per_epoch=int(1 / 0.01) + 1, + ) + + base_model = MultiLabelBasicNN( + net, + loss_fn, + optimizer, + scheduler=scheduler, + device=device, + batch_size=32, + num_epochs=1, + ) + +``MultiLabelBasicNN`` operates on a single frame at a time. To work at +the example level (a frame together with its label set), we wrap the +base model in :class:`~ablkit.learning.MultiLabelABLModel`, an +``ABLModel`` subclass that threshold-binarises the sigmoid +probabilities into per-concept 0/1 pseudo-labels. + +.. code:: python + + model = MultiLabelABLModel(base_model) + +Building the Reasoning Part +--------------------------- + +The knowledge base ``BDDKB`` encodes the rules linking the 21 +concepts to the 4 actions (e.g., ``red_light`` or ``obstacle`` imply +``stop``; ``green_light`` together with ``road_clear`` implies +``move_forward``). It subclasses ``KBBase``; the ``pseudo_label_list`` +parameter is ``[0, 1]`` because each pseudo-label is binary, and the +``logic_forward`` method computes the 4-tuple of action labels from +the 21 concept attributes. + +.. code:: python + + from reasoning.bddkb import BDDKB + + kb = BDDKB() + +Since abductive reasoning is non-deterministic, multiple concept +revisions can be consistent with the ground-truth actions. The +``Reasoner`` picks the revision that minimises a user-supplied +distance function. For BDD-OIA we provide +``multi_label_confidence_dist``, which sums ``-log(p)`` over the +concept-by-concept probabilities so that revisions consistent with the +learning part's per-concept confidence are preferred: + +.. code:: python + + def multi_label_confidence_dist(data_example, candidates, candidates_idxs, reasoning_results): + pred_prob = data_example.pred_prob.T # nc x 1 + pred_prob = np.concatenate([1 - pred_prob, pred_prob], axis=1) # nc x 2 + cols = np.arange(len(candidates_idxs[0]))[None, :] + corr_prob = pred_prob[cols, candidates_idxs] + costs = -np.sum(np.log(corr_prob + 1e-6), axis=1) + return costs + + reasoner = Reasoner( + kb, + dist_func=multi_label_confidence_dist, + max_revision=3, + require_more_revision=3, + ) + +``max_revision`` and ``require_more_revision`` cap how many concept +flips the reasoner explores when searching for a consistent +revision. + +Building Evaluation Metrics +--------------------------- + +We track two metrics. ``SymbolAccuracy`` measures how often the +predicted concepts match the ground-truth concepts, and +``BDDReasoningMetric`` measures the per-action accuracy after +passing the predicted concepts through ``logic_forward``. + +.. code:: python + + metric_list = [ + SymbolAccuracy(prefix="bdd_oia"), + BDDReasoningMetric(kb=kb, prefix="bdd_oia"), + ] + +Bridging Learning and Reasoning +------------------------------- + +Finally we bridge the learning and reasoning parts via ``BDDBridge``, +a thin subclass of ``SimpleBridge`` that handles the +multi-label-specific shape of ``pred_idx`` (a ``[1, nc]`` ndarray per +example). + +.. code:: python + + bridge = BDDBridge(model, reasoner, metric_list) + +Training and testing reuse the standard ``SimpleBridge`` interface: + +.. code:: python + + print_log("Abductive Learning on the BDD_OIA example.", logger="current") + log_dir = ABLLogger.get_current_instance().log_dir + weights_dir = osp.join(log_dir, "weights") + + bridge.train( + train_data, + loops=2, + segment_size=0.01, + save_interval=1, + save_dir=weights_dir, + ) + bridge.test(test_data) diff --git a/docs/Examples/HED.rst b/docs/Examples/HED.rst index d8fbed02a..6c4851445 100644 --- a/docs/Examples/HED.rst +++ b/docs/Examples/HED.rst @@ -179,13 +179,13 @@ sklearn-style interface. .. code:: python # class of symbol may be one of ['0', '1', '+', '='], total of 4 classes - cls = SymbolNet(num_classes=4) + net = SymbolNet(num_classes=4) loss_fn = nn.CrossEntropyLoss() - optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001, weight_decay=1e-4) + optimizer = torch.optim.RMSprop(net.parameters(), lr=0.001, weight_decay=1e-4) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - + base_model = BasicNN( - cls, + net, loss_fn, optimizer, device=device, diff --git a/docs/Examples/HWF.rst b/docs/Examples/HWF.rst index 9ab0c54d4..4b801c1d5 100644 --- a/docs/Examples/HWF.rst +++ b/docs/Examples/HWF.rst @@ -8,7 +8,7 @@ Handwritten Formula (HWF) Below shows an implementation of `Handwritten Formula `__. In this task, handwritten images of decimal formulas and their computed results -are given, alongwith a domain knowledge base containing information on +are given, along with a domain knowledge base containing information on how to compute the decimal formula. The task is to recognize the symbols (which can be digits or operators ‘+’, ‘-’, ‘×’, ‘÷’) of handwritten images and accurately determine their results. @@ -100,7 +100,7 @@ Out: The ith element of X, gt_pseudo_label, and Y together constitute the ith data example. Here we use two of them (the 1001st and the 3001st) as -illstrations: +illustrations: .. code:: python @@ -177,14 +177,14 @@ sklearn-style interface. .. code:: python - # class of symbol may be one of ['1', ..., '9', '+', '-', '*', '/'], total of 14 classes - cls = SymbolNet(num_classes=13, image_size=(45, 45, 1)) + # class of symbol may be one of ['1', ..., '9', '+', '-', '*', '/'], total of 13 classes + net = SymbolNet(num_classes=13, image_size=(45, 45, 1)) loss_fn = nn.CrossEntropyLoss() - optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99)) + optimizer = torch.optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.99)) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - + base_model = BasicNN( - model=cls, + model=net, loss_fn=loss_fn, optimizer=optimizer, device=device, diff --git a/docs/Examples/MNISTAdd.rst b/docs/Examples/MNISTAdd.rst index b90bb86e3..850921958 100644 --- a/docs/Examples/MNISTAdd.rst +++ b/docs/Examples/MNISTAdd.rst @@ -7,7 +7,7 @@ MNIST Addition Below shows an implementation of `MNIST Addition `__. In this task, pairs of -MNIST handwritten images and their sums are given, alongwith a domain +MNIST handwritten images and their sums are given, along with a domain knowledge base containing information on how to perform addition operations. The task is to recognize the digits of handwritten images and accurately determine their sum. @@ -147,14 +147,14 @@ model with a sklearn-style interface. .. code:: python - cls = LeNet5(num_classes=10) + net = LeNet5(num_classes=10) loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1) - optimizer = RMSprop(cls.parameters(), lr=0.001, alpha=0.9) + optimizer = RMSprop(net.parameters(), lr=0.001, alpha=0.9) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, pct_start=0.1, total_steps=100) base_model = BasicNN( - cls, + net, loss_fn, optimizer, scheduler=scheduler, @@ -298,8 +298,10 @@ candidate that has the highest consistency. customized within the ``dist_func`` parameter. In the code above, we employ a consistency measurement based on confidence, which calculates the consistency between the data example and candidates based on the - confidence derived from the predicted probability. In ``examples/mnist_add/main.py``, we - provide options for utilizing other forms of consistency measurement. + confidence derived from the predicted probability. In + ``examples/mnist_add/main.py``, the ``--dist-func`` flag lets you swap in + other predefined options (``hamming``, ``avg_confidence``, ``similarity``, + ``rejection``) without editing the code. Also, during the process of inconsistency minimization, we can leverage `ZOOpt library `__ for acceleration. @@ -368,7 +370,57 @@ Log: abl - INFO - Eval start: loop(val) [2] abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.993 mnist_add/reasoning_accuracy: 0.986 abl - INFO - Test start: - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.991 mnist_add/reasoning_accuracy: 0.980 + abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.991 mnist_add/reasoning_accuracy: 0.980 + + +Command-line options +-------------------- + +In addition to the standard pipeline above, ``examples/mnist_add/main.py`` accepts +several flags that switch in alternative methods. The defaults reproduce the +standard pipeline, so existing usage is unaffected. + +- ``--method {standard,a3bl,verification}``: choose the learning/reasoning + pipeline. + ``standard`` uses ``Reasoner`` / ``SimpleBridge``. + ``a3bl`` uses the ambiguity-aware ``A3BLReasoner`` together with + ``A3BLBridge``. + ``verification`` uses ``VerificationReasoner`` and ``VerificationBridge``: + rather than picking one best candidate by ``dist_func``, it enumerates + the top ``--top-k`` consistent candidates by joint probability and trains + the model once per candidate. + All methods share the same ``BasicNN`` / ``ABLModel`` wrappers. +- ``--dist-func {hamming,confidence,avg_confidence,similarity,rejection}``: + passed straight to the reasoner. ``similarity`` requires the wrapped + PyTorch model to implement ``extract_features(x)`` (the ``LeNet5`` in + this example does); ``BasicNN`` then surfaces those embeddings to the + reasoner automatically. Ignored when ``--method verification``. +- ``--labeled-ratio FLOAT``: fraction in ``(0, 1]`` of training samples that + keep their ground-truth pseudo-labels. Values below ``1.0`` enable the + semi-supervised pipeline (``use_supervised_data=True`` on + ``bridge.train``). Only valid with ``--method standard``. +- ``--top-k INT``: number of consistent candidates the verification + reasoner enumerates per example. Only used with ``--method verification``. + Defaults to ``1``. + +Examples: + +.. code:: bash + + # Standard pipeline (default) + python main.py + + # A3BL with similarity-based consistency + python main.py --method a3bl --dist-func similarity + + # Semi-supervised: keep 30% of pseudo-labels, abduce the rest + python main.py --labeled-ratio 0.3 + + # Rejection-aware reasoning + python main.py --dist-func rejection + + # Verification Learning: train on the top-3 consistent candidates + python main.py --method verification --top-k 3 Environment diff --git a/docs/Intro/Advanced.rst b/docs/Intro/Advanced.rst new file mode 100644 index 000000000..b10ab57b3 --- /dev/null +++ b/docs/Intro/Advanced.rst @@ -0,0 +1,179 @@ +.. _Advanced: + +Advanced Topics +=============== + +The standard ABL pipeline (``BasicNN`` + ``ABLModel`` + ``Reasoner`` + +``SimpleBridge``) covers the majority of tasks. ABLkit also ships a few +drop-in variants for settings where the standard pipeline does not quite +fit. This page collects them in one place. + +The four topics below are independent: pick whichever ones apply to your +task. + +* :ref:`advanced-multilabel`: when each instance can carry **multiple + active labels** (sigmoid + binary indicator vectors) rather than a + single class. +* :ref:`advanced-semisupervised`: when **part of the training set + carries ground-truth pseudo-labels** and the rest must be abduced. +* :ref:`advanced-a3bl`: when many label assignments are consistent with + the knowledge base and we want to aggregate them into a **soft + label** instead of picking the single best one. +* :ref:`advanced-verification`: when we want to train against the + **top-K consistent label assignments** by joint probability rather + than a single best candidate. + + +.. _advanced-multilabel: + +Multi-Label Models +------------------ + +By default ``BasicNN`` and ``ABLModel`` assume a single-label +multi-class setting: softmax over classes, ``argmax`` at prediction +time, one integer label per instance. For tasks where each instance is +described by a *vector* of independent binary attributes (e.g., the 21 +binary concepts in BDD-OIA), ABLkit provides multi-label drop-in +replacements: + +* :class:`~ablkit.learning.MultiLabelBasicNN`: sigmoid output, + threshold at 0.5 for prediction, ``MultiLabelClassificationDataset`` + for training. +* :class:`~ablkit.learning.MultiLabelABLModel`: wraps a multi-label + base model and thresholds per-label probabilities into binary + indicator vectors stored on ``pred_idx``. +* :class:`~ablkit.learning.MultiLabelClassificationDataset`: stores + ``Y`` as a ``FloatTensor`` so it can be fed directly into + ``BCEWithLogitsLoss``. + +Typical usage swaps the standard classes 1-for-1: + +.. code:: python + + import torch.nn as nn + from torch import optim + + from ablkit.learning import MultiLabelABLModel, MultiLabelBasicNN + + net = MyMultiLabelNet() # PyTorch model with num_labels outputs + loss_fn = nn.BCEWithLogitsLoss() + optimizer = optim.Adam(net.parameters(), lr=2e-3) + + base_model = MultiLabelBasicNN(net, loss_fn, optimizer, device="cpu", + batch_size=32, num_epochs=1) + model = MultiLabelABLModel(base_model) + +See the BDD-OIA example for an end-to-end multi-label pipeline. + + +.. _advanced-semisupervised: + +Semi-Supervised Training +------------------------ + +When part of the training set already carries ground-truth +pseudo-labels (and the rest is unlabeled), ``SimpleBridge`` can be +asked to use those labels directly instead of abducing them. + +The mechanism is purely a flag on ``SimpleBridge.train``: + +* Provide a ``train_data`` tuple ``(X, gt_pseudo_label, Y)`` where the + ``gt_pseudo_label`` for unlabeled examples is ``None``. +* Pass ``use_supervised_data=True``. + +Under the hood the bridge calls +``Reasoner.batch_supervised_abduce``, which keeps existing +``gt_pseudo_label`` values verbatim and only abduces a candidate for +the ``None`` entries: + +.. code:: python + + bridge.train( + train_data=(X, pseudo_label_with_some_None, Y), + use_supervised_data=True, + loops=50, + segment_size=0.01, + ) + +The ``--labeled-ratio`` flag in the MNIST Addition example +demonstrates how to mask out a fraction of pseudo-labels and feed the +result through this flow. + + +.. _advanced-a3bl: + +A3BL: Ambiguity-Aware Abductive Learning +---------------------------------------- + +When many label assignments are consistent with the knowledge base +for a given example, picking only the lowest-distance candidate +discards useful signal. A3BL (Ambiguity-Aware Abductive Learning) +keeps the top candidates, weights them by their joint probability, +and trains the model on the resulting *soft label distribution*. + +ABLkit ships two classes: + +* :class:`~ablkit.reasoning.A3BLReasoner`: enumerates valid + candidates, scores them via a softmax over per-symbol probabilities, + and aggregates the top-K into a soft label. +* :class:`~ablkit.bridge.A3BLBridge`: runs the ambiguity-aware + prediction → soft-label-abduction → train loop. + +Minimal wiring: + +.. code:: python + + from ablkit.bridge import A3BLBridge + from ablkit.reasoning import A3BLReasoner + + reasoner = A3BLReasoner(kb, topK=16, temperature=0.2) + bridge = A3BLBridge(model, reasoner, metric_list) + bridge.train(train_data, loops=2, segment_size=0.01) + +Reference: https://github.com/Hao-Yuan-He/A3BL + + +.. _advanced-verification: + +Verification Learning +--------------------- + +Verification Learning replaces the standard "abduce the single best +candidate" step with a top-K enumeration: starting from the most +probable joint label assignment, the search walks the per-symbol +probability lattice in **descending joint-probability order** and +collects the first ``top_k`` candidates that satisfy the knowledge +base. The model is then trained once per candidate per segment. + +ABLkit ships two classes (consolidated in +``ablkit/reasoning/reasoner.py`` and +``ablkit/bridge/verification_bridge.py``): + +* :class:`~ablkit.reasoning.VerificationReasoner`: exposes + ``top_k_candidates(pred_prob, y)`` and the batched variant. +* :class:`~ablkit.bridge.VerificationBridge`: drives the + predict → enumerate → train-per-candidate loop. + +Helpers usable without a reasoner instance: + +* :func:`ablkit.reasoning.reasoner.enumerate_label_assignments`: a + generator over label assignments in descending joint-probability + order. +* :func:`ablkit.reasoning.reasoner.top_k_satisfying`: wraps the + generator with a user predicate and a fallback when nothing matches. + +Minimal wiring: + +.. code:: python + + from ablkit.bridge import VerificationBridge + from ablkit.reasoning import VerificationReasoner + + reasoner = VerificationReasoner(kb, top_k=3, max_iter=10000) + bridge = VerificationBridge(model, reasoner, metric_list) + bridge.train(train_data, loops=2, segment_size=0.01) + +Reference: https://github.com/VerificationLearning/VerificationLearning + +The ``--method verification --top-k K`` flags in the MNIST Addition +example demonstrate the full pipeline. diff --git a/docs/Intro/Basics.rst b/docs/Intro/Basics.rst index 5cb83847b..73af0b69b 100644 --- a/docs/Intro/Basics.rst +++ b/docs/Intro/Basics.rst @@ -10,7 +10,7 @@ Learn the Basics ================ Modules in ABLkit ----------------------- +----------------- ABLkit is an efficient toolkit for `Abductive Learning <../Overview/Abductive-Learning.html>`_ (ABL), a paradigm which integrates machine learning and logical reasoning in a balanced-loop. @@ -50,7 +50,7 @@ learning, and reasoning, facilitating the training and testing of the entire ABL framework. Use ABLkit Step by Step ----------------------------- +----------------------- In a typical ABL process, as illustrated below, data inputs are first predicted by the learning model ``ABLModel.predict``, and the outcomes are pseudo-labels. diff --git a/docs/Intro/Learning.rst b/docs/Intro/Learning.rst index efb1e9444..d460ea387 100644 --- a/docs/Intro/Learning.rst +++ b/docs/Intro/Learning.rst @@ -43,13 +43,13 @@ For a PyTorch-based neural network, we need to encapsulate it within a ``BasicNN .. code:: python # Load a PyTorch-based neural network - cls = torchvision.models.resnet18(pretrained=True) + net = torchvision.models.resnet18(pretrained=True) # loss function and optimizer are used for training - loss_fn = torch.nn.CrossEntropyLoss() - optimizer = torch.optim.Adam(cls.parameters()) + loss_fn = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.Adam(net.parameters()) - base_model = BasicNN(cls, loss_fn, optimizer) + base_model = BasicNN(net, loss_fn, optimizer) BasicNN ^^^^^^^ diff --git a/docs/Intro/Quick-Start.rst b/docs/Intro/Quick-Start.rst index 4563673e5..3eb6252e6 100644 --- a/docs/Intro/Quick-Start.rst +++ b/docs/Intro/Quick-Start.rst @@ -41,7 +41,7 @@ In this example, we build a simple LeNet5 network as the base model. # The 'models' module below is located in 'examples/mnist_add/' from models.nn import LeNet5 - cls = LeNet5(num_classes=10) + net = LeNet5(num_classes=10) To facilitate uniform processing, ABLkit provides the ``BasicNN`` class to convert a PyTorch-based neural network into a format compatible with scikit-learn models. To construct a ``BasicNN`` instance, aside from the network itself, we also need to define a loss function, an optimizer, and the computing device. @@ -51,9 +51,9 @@ To facilitate uniform processing, ABLkit provides the ``BasicNN`` class to conve from ablkit.learning import BasicNN loss_fn = torch.nn.CrossEntropyLoss() - optimizer = torch.optim.RMSprop(cls.parameters(), lr=0.001) + optimizer = torch.optim.RMSprop(net.parameters(), lr=0.001) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - base_model = BasicNN(model=cls, loss_fn=loss_fn, optimizer=optimizer, device=device) + base_model = BasicNN(model=net, loss_fn=loss_fn, optimizer=optimizer, device=device) The base model built above is trained to make predictions on instance-level data (e.g., a single image), while ABL deals with example-level data. To bridge this gap, we wrap the ``base_model`` into an instance of ``ABLModel``. This class serves as a unified wrapper for base models, facilitating the learning part to train, test, and predict on example-level data, (e.g., images that comprise an equation). @@ -109,7 +109,7 @@ ABLkit provides two basic metrics, namely ``SymbolAccuracy`` and ``ReasoningMetr Read more about `building evaluation metrics `_ Bridging Learning and Reasoning ---------------------------------------- +------------------------------- Now, we use ``SimpleBridge`` to combine learning and reasoning in a unified ABL framework. diff --git a/docs/Intro/Reasoning.rst b/docs/Intro/Reasoning.rst index 33ce16319..935cf62af 100644 --- a/docs/Intro/Reasoning.rst +++ b/docs/Intro/Reasoning.rst @@ -8,7 +8,7 @@ Reasoning part -=============== +============== In this section, we will look at how to build the reasoning part, which leverages domain knowledge and performs deductive or abductive reasoning. @@ -318,11 +318,20 @@ specify: used when determining consistency between your prediction and candidate returned from knowledge base. This can be either a user-defined function or one that is predefined. Valid predefined options include - “hamming”, “confidence” and “avg_confidence”. For “hamming”, it directly calculates the Hamming distance between the - predicted pseudo-label in the data example and candidate. For “confidence”, it - calculates the confidence distance between the predicted probabilities in the data - example and each candidate, where the confidence distance is defined as 1 - the product - of prediction probabilities in “confidence” and 1 - the average of prediction probabilities in “avg_confidence”. + “hamming”, “confidence”, “avg_confidence”, “similarity” and “rejection”. + For “hamming”, it directly calculates the Hamming distance between the + predicted pseudo-label in the data example and candidate. For “confidence” and + “avg_confidence”, it calculates the confidence distance between the predicted + probabilities and each candidate, defined as ``1 - product`` and ``1 - average`` + of the candidate's per-symbol probabilities respectively. For “similarity”, + it compares candidates against the geometry of the model's embeddings. + This requires the wrapped PyTorch model to implement + ``extract_features(x)`` (returning, for example, penultimate-layer + activations); ``BasicNN`` then surfaces them via its own + ``extract_features`` method, and ``ABLModel`` automatically stores the + resulting embeddings on ``data_example.embeddings`` for the reasoner. + For “rejection”, it combines the confidence distance with a candidate-complexity + penalty so that shorter candidates are favored when scores are close. Defaults to “confidence”. - ``idx_to_label`` (dict, optional), a mapping from index in the base model to label. If not provided, a default order-based index to label mapping is created. @@ -333,7 +342,7 @@ The main method implemented by ``Reasoner`` is based on the distance function defined in ``dist_func``. MNIST Addition example (cont.) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ As an example, consider these data examples for MNIST Addition: @@ -377,7 +386,7 @@ Out: Specifically, as mentioned before, “confidence” calculates the distance between the data example and candidates based on the confidence derived from the predicted probability. -Take ``example1`` as an example, the ``pred_prob`` in it indicates a higher -confidence that the first label should be "1" rather than "7". Therefore, among the +Take ``example1`` as an example, the ``pred_prob`` in it indicates a higher +confidence that the first label should be "1" rather than "7". Therefore, among the candidates [1,7] and [7,1], it would be closer to [1,7] (as its first label is "1"). diff --git a/docs/index.rst b/docs/index.rst index d398c4d5a..f3ce3914e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -127,6 +127,7 @@ To cite ABLkit, please cite the following paper: `Huang et al., 2024 List[List[Any]]: + pred_idx = data_examples.pred_idx # [ ndarray(1,nc),... ] + pred_pseudo_label = [] + for sub_list in pred_idx: + sub_list = sub_list.squeeze() # 1 x nc -> nc + pred_pseudo_label.append([self.reasoner.idx_to_label[_idx] for _idx in sub_list]) + data_examples.pred_pseudo_label = pred_pseudo_label + return data_examples.pred_pseudo_label + + def pseudo_label_to_idx(self, data_examples: ListData) -> List[List[Any]]: + abduced_pseudo_label = data_examples.abduced_pseudo_label + abduced_idx = [] + for sub_list in abduced_pseudo_label: + sub_list = np.array([self.reasoner.label_to_idx[_lab] for _lab in sub_list]) + abduced_idx.append(sub_list) + data_examples.abduced_idx = abduced_idx + return data_examples.abduced_idx diff --git a/examples/bdd_oia/dataset/data_util.py b/examples/bdd_oia/dataset/data_util.py new file mode 100644 index 000000000..a4a9ee965 --- /dev/null +++ b/examples/bdd_oia/dataset/data_util.py @@ -0,0 +1,19 @@ +import os +import numpy as np + +CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) + + +def get_dataset(fname, get_pseudo_label=True): + fname = os.path.join(CURRENT_DIR, fname) + data = np.load(fname) + X = data["X"] + X = [[emb.astype(np.float32)] for emb in X] + pseudo_label = data["pseudo_label"].astype(int).tolist() if get_pseudo_label else None + Y = data["Y"][:, :4].astype(int).tolist() + Y = [tuple(y) for y in Y] + return X, pseudo_label, Y + + +if __name__ == "__main__": + dataset = get_dataset("val.npz") diff --git a/examples/bdd_oia/main.py b/examples/bdd_oia/main.py new file mode 100644 index 000000000..d720b6493 --- /dev/null +++ b/examples/bdd_oia/main.py @@ -0,0 +1,147 @@ +import argparse +import os.path as osp +import numpy as np +import torch +from torch import optim +import torch.nn as nn + +from ablkit.data.evaluation import SymbolAccuracy +from ablkit.learning import MultiLabelABLModel, MultiLabelBasicNN +from ablkit.reasoning import Reasoner +from ablkit.utils import ABLLogger, print_log + +from models.nn import ConceptNet +from reasoning.bddkb import BDDKB +from dataset.data_util import get_dataset +from bridge import BDDBridge +from metric import BDDReasoningMetric + + +def multi_label_confidence_dist(data_example, candidates, candidates_idxs, reasoning_results): + pred_prob = data_example.pred_prob.T # nc x 1 + pred_prob = np.concatenate([1 - pred_prob, pred_prob], axis=1) # nc x 2 + cols = np.arange(len(candidates_idxs[0]))[None, :] + corr_prob = pred_prob[cols, candidates_idxs] + costs = -np.sum(np.log(corr_prob + 1e-6), axis=1) + return costs + + +def parse_args(): + parser = argparse.ArgumentParser(description="BDD-OIA example") + parser.add_argument( + "--no-cuda", action="store_true", default=False, help="disables CUDA training" + ) + parser.add_argument( + "--epochs", + type=int, + default=1, + help="number of epochs in each learning loop iteration (default : 1)", + ) + parser.add_argument( + "--lr", type=float, default=2e-3, help="base model learning rate (default : 0.002)" + ) + parser.add_argument( + "--batch-size", type=int, default=32, help="base model batch size (default : 32)" + ) + parser.add_argument( + "--loops", type=int, default=2, help="number of loop iterations (default : 2)" + ) + parser.add_argument( + "--segment_size", type=float, default=0.01, help="segment size (default : 0.01)" + ) + parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)") + parser.add_argument( + "--max-revision", type=int, default=3, help="maximum revision in reasoner (default : 3)" + ) + parser.add_argument( + "--require-more-revision", + type=int, + default=3, + help="require more revision in reasoner (default : 3)", + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + # Build logger + print_log("Abductive Learning on the BDD-OIA example.", logger="current") + + # -- Working with Data ------------------------------ + print_log("Working with Data.", logger="current") + train_data = get_dataset(fname="train.npz", get_pseudo_label=True) + val_data = get_dataset(fname="val.npz", get_pseudo_label=True) + test_data = get_dataset(fname="test.npz", get_pseudo_label=True) + + # -- Building the Learning Part --------------------- + print_log("Building the Learning Part.", logger="current") + + # Build necessary components for MultiLabelBasicNN + net = ConceptNet() + loss_fn = nn.BCEWithLogitsLoss() + optimizer = optim.Adam(net.parameters(), lr=args.lr) + use_cuda = not args.no_cuda and torch.cuda.is_available() + device = torch.device("cuda" if use_cuda else "cpu") + scheduler = optim.lr_scheduler.OneCycleLR( + optimizer, + max_lr=args.lr, + pct_start=0.15, + epochs=args.loops, + steps_per_epoch=int(1 / args.segment_size) + 1, + ) + + base_model = MultiLabelBasicNN( + net, + loss_fn, + optimizer, + scheduler=scheduler, + device=device, + batch_size=args.batch_size, + num_epochs=args.epochs, + ) + + model = MultiLabelABLModel(base_model) + + # -- Building the Reasoning Part -------------------- + print_log("Building the Reasoning Part.", logger="current") + + # Build knowledge base + kb = BDDKB() + + # Create reasoner + reasoner = Reasoner( + kb, + dist_func=multi_label_confidence_dist, + max_revision=args.max_revision, + require_more_revision=args.require_more_revision, + ) + + # -- Building Evaluation Metrics -------------------- + print_log("Building Evaluation Metrics.", logger="current") + metric_list = [SymbolAccuracy(prefix="bdd_oia"), BDDReasoningMetric(kb=kb, prefix="bdd_oia")] + + # -- Bridging Learning and Reasoning ---------------- + print_log("Bridge Learning and Reasoning.", logger="current") + bridge = BDDBridge(model, reasoner, metric_list) + + # Retrieve the directory of the Log file and define the directory for saving the model weights. + log_dir = ABLLogger.get_current_instance().log_dir + weights_dir = osp.join(log_dir, "weights") + + # Train and Test + bridge.train( + train_data=train_data, + val_data=val_data, + loops=args.loops, + segment_size=args.segment_size, + save_interval=args.save_interval, + save_dir=weights_dir, + ) + bridge.test(test_data) + + +if __name__ == "__main__": + main() diff --git a/examples/bdd_oia/metric.py b/examples/bdd_oia/metric.py new file mode 100644 index 000000000..e2042ecfc --- /dev/null +++ b/examples/bdd_oia/metric.py @@ -0,0 +1,27 @@ +from typing import Optional + +from ablkit.reasoning import KBBase +from ablkit.data import BaseMetric, ListData + + +class BDDReasoningMetric(BaseMetric): + def __init__(self, kb: KBBase, prefix: Optional[str] = None) -> None: + super().__init__(prefix) + self.kb = kb + + def process(self, data_examples: ListData) -> None: + pred_pseudo_label_list = data_examples.pred_pseudo_label + y_list = data_examples.Y + x_list = data_examples.X + for pred_pseudo_label, y, x in zip(pred_pseudo_label_list, y_list, x_list): + pred_y = self.kb.logic_forward( + pred_pseudo_label, *(x,) if self.kb._num_args == 2 else () + ) + for py, yy in zip(pred_y, y): + self.results.append(int(py == yy)) + + def compute_metrics(self) -> dict: + results = self.results + metrics = dict() + metrics["reasoning_accuracy"] = sum(results) / len(results) + return metrics diff --git a/examples/bdd_oia/models/__init__.py b/examples/bdd_oia/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/bdd_oia/models/nn.py b/examples/bdd_oia/models/nn.py new file mode 100644 index 000000000..8ff817335 --- /dev/null +++ b/examples/bdd_oia/models/nn.py @@ -0,0 +1,24 @@ +from torch import nn + + +class SimpleNet(nn.Module): + def __init__(self, num_features=2048, num_concepts=21): + super(SimpleNet, self).__init__() + self.fc = nn.Linear(num_features, num_concepts) + + def forward(self, x): + return self.fc(x) + + +class ConceptNet(nn.Module): + def __init__(self, num_features=2048, num_concepts=21): + super(ConceptNet, self).__init__() + intermidate_dim = 256 + self.fc = nn.Sequential( + nn.Linear(num_features, intermidate_dim), + nn.SiLU(), + nn.Linear(intermidate_dim, num_concepts), + ) + + def forward(self, x): + return self.fc(x) diff --git a/examples/bdd_oia/reasoning/__init__.py b/examples/bdd_oia/reasoning/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/bdd_oia/reasoning/bddkb.py b/examples/bdd_oia/reasoning/bddkb.py new file mode 100644 index 000000000..02e3d6a4b --- /dev/null +++ b/examples/bdd_oia/reasoning/bddkb.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +from ablkit.reasoning import KBBase + + +class BDDKB(KBBase): + def __init__(self, pseudo_label_list=None): + if pseudo_label_list is None: + pseudo_label_list = [0, 1] + super().__init__(pseudo_label_list) + + def logic_forward(self, attrs): + """ + Abduction space + (0, 1, 0, 0) 610812 + (0, 1, 0, 1) 75012 + (0, 1, 1, 0) 75012 + (0, 1, 1, 1) 9212 + (1, 0, 0, 0) 12996 + (1, 0, 0, 1) 1596 + (1, 0, 1, 0) 1596 + (1, 0, 1, 1) 196 + """ + if len(attrs) != 21: + raise ValueError( + f"BDDKB.logic_forward expects exactly 21 concept attributes, got {len(attrs)}." + ) + ( + green_light, + follow, + road_clear, + red_light, + traffic_sign, + car, + person, + rider, + other_obstacle, + left_lane, + left_green_light, + left_follow, + no_left_lane, + left_obstacle, + left_solid_line, + right_lane, + right_green_light, + right_follow, + no_right_lane, + right_obstacle, + right_solid_line, + ) = attrs + + illegal_return = (0, 0, 0, 0) + if red_light == green_light == 1: + return illegal_return + obstacle = car or person or rider or other_obstacle + if road_clear == obstacle: + return illegal_return + move_forward = green_light or follow or road_clear + stop = red_light or traffic_sign or obstacle + if stop: + move_forward = 0 + + can_turn_left = left_lane or left_green_light or left_follow + cannot_turn_left = no_left_lane or left_obstacle or left_solid_line + turn_left = can_turn_left and int(not cannot_turn_left) + + can_turn_right = right_lane or right_green_light or right_follow + cannot_turn_right = no_right_lane or right_obstacle or right_solid_line + turn_right = can_turn_right and int(not cannot_turn_right) + + return move_forward, stop, turn_left, turn_right diff --git a/examples/bdd_oia/requirements.txt b/examples/bdd_oia/requirements.txt new file mode 100644 index 000000000..a1d6490d8 --- /dev/null +++ b/examples/bdd_oia/requirements.txt @@ -0,0 +1 @@ +ablkit diff --git a/examples/hed/bridge.py b/examples/hed/bridge.py index 82ab1779b..08605b9c3 100644 --- a/examples/hed/bridge.py +++ b/examples/hed/bridge.py @@ -26,7 +26,7 @@ def __init__( ) -> None: super().__init__(model, reasoner, metric_list) - def pretrain(self, weights_dir): + def pretrain(self, weights_dir: str) -> None: if not os.path.exists(os.path.join(weights_dir, "pretrain_weights.pth")): print_log("Pretrain Start", logger="current") @@ -64,7 +64,7 @@ def pretrain(self, weights_dir): self.model.load(load_path=os.path.join(weights_dir, "pretrain_weights.pth")) - def select_mapping_and_abduce(self, data_examples: ListData): + def select_mapping_and_abduce(self, data_examples: ListData) -> List[List[Any]]: candidate_mappings = gen_mappings([0, 1, 2, 3], ["+", "=", 0, 1]) mapping_score = [] abduced_pseudo_label_list = [] @@ -87,11 +87,13 @@ def select_mapping_and_abduce(self, data_examples: ListData): return data_examples.abduced_pseudo_label - def abduce_pseudo_label(self, data_examples: ListData): + def abduce_pseudo_label(self, data_examples: ListData) -> List[List[Any]]: self.reasoner.abduce(data_examples) return data_examples.abduced_pseudo_label - def check_training_impact(self, filtered_data_examples, data_examples): + def check_training_impact( + self, filtered_data_examples: ListData, data_examples: ListData + ) -> bool: character_accuracy = self.model.valid(filtered_data_examples) revisible_ratio = len(filtered_data_examples.X) / len(data_examples.X) log_string = ( @@ -104,7 +106,7 @@ def check_training_impact(self, filtered_data_examples, data_examples): return True return False - def check_rule_quality(self, rule, val_data, equation_len): + def check_rule_quality(self, rule: Any, val_data: Any, equation_len: int) -> bool: val_X_true = self.data_preprocess(val_data[1], equation_len) val_X_false = self.data_preprocess(val_data[0], equation_len) @@ -121,7 +123,7 @@ def check_rule_quality(self, rule, val_data, equation_len): return True return False - def calc_consistent_ratio(self, data_examples, rule): + def calc_consistent_ratio(self, data_examples: ListData, rule: Any) -> float: self.predict(data_examples) pred_pseudo_label = self.idx_to_pseudo_label(data_examples) consistent_num = sum( @@ -129,7 +131,9 @@ def calc_consistent_ratio(self, data_examples, rule): ) return consistent_num / len(data_examples.X) - def get_rules_from_data(self, data_examples, samples_per_rule, samples_num): + def get_rules_from_data( + self, data_examples: ListData, samples_per_rule: int, samples_num: int + ) -> List[Any]: rules = [] sampler = InfiniteSampler(len(data_examples), batch_size=samples_per_rule) @@ -159,7 +163,7 @@ def get_rules_from_data(self, data_examples, samples_per_rule, samples_num): return rules @staticmethod - def filter_empty(data_examples: ListData): + def filter_empty(data_examples: ListData) -> ListData: consistent_dix = [ i for i in range(len(data_examples.abduced_pseudo_label)) @@ -168,7 +172,7 @@ def filter_empty(data_examples: ListData): return data_examples[consistent_dix] @staticmethod - def select_rules(rule_dict): + def select_rules(rule_dict: dict) -> List[Any]: add_nums_dict = {} for r in list(rule_dict): add_nums = str(r.split("]")[0].split("[")[1]) + str( @@ -185,7 +189,7 @@ def select_rules(rule_dict): add_nums_dict[add_nums] = r return list(rule_dict) - def data_preprocess(self, data, equation_len) -> ListData: + def data_preprocess(self, data: Any, equation_len: int) -> ListData: data_examples = ListData() data_examples.X = data[equation_len] + data[equation_len + 1] data_examples.gt_pseudo_label = None @@ -193,7 +197,15 @@ def data_preprocess(self, data, equation_len) -> ListData: return data_examples - def train(self, train_data, val_data, segment_size=10, min_len=5, max_len=8, save_dir="./"): + def train( + self, + train_data: Any, + val_data: Any, + segment_size: int = 10, + min_len: int = 5, + max_len: int = 8, + save_dir: str = "./", + ) -> None: for equation_len in range(min_len, max_len): print_log( f"============== equation_len: {equation_len}-{equation_len + 1} ================", diff --git a/examples/hed/datasets/get_dataset.py b/examples/hed/datasets/get_dataset.py index cea07bc38..f16701bb1 100644 --- a/examples/hed/datasets/get_dataset.py +++ b/examples/hed/datasets/get_dataset.py @@ -10,6 +10,8 @@ import numpy as np from torchvision.transforms import transforms +from ablkit.utils import print_log + CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) @@ -79,10 +81,10 @@ def get_dataset(dataset="mnist", train=True): data_dir = CURRENT_DIR + "/mnist_images" if not os.path.exists(data_dir): - print("Dataset not exist, downloading it...") + print_log("Dataset not present, downloading it...", logger="current") url = "https://drive.google.com/u/0/uc?id=1W2AUn_fnXa4XkgLk4d17K3bEgpae8GMg&export=download" download_and_unzip(url, os.path.join(CURRENT_DIR, "HED.zip")) - print("Download and extraction complete.") + print_log("Download and extraction complete.", logger="current") if train: file = os.path.join(data_dir, "expr_train.json") diff --git a/examples/hed/main.py b/examples/hed/main.py index c1d7e76c8..e25e645a2 100644 --- a/examples/hed/main.py +++ b/examples/hed/main.py @@ -61,15 +61,15 @@ def main(): print_log("Building the Learning Part.", logger="current") # Build necessary components for BasicNN - cls = SymbolNet(num_classes=4) + net = SymbolNet(num_classes=4) loss_fn = nn.CrossEntropyLoss() - optimizer = torch.optim.RMSprop(cls.parameters(), lr=args.lr, weight_decay=args.weight_decay) + optimizer = torch.optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=args.weight_decay) use_cuda = not args.no_cuda and torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") # Build BasicNN base_model = BasicNN( - cls, + net, loss_fn, optimizer, device=device, diff --git a/examples/hed/models/__init__.py b/examples/hed/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/hed/reasoning/reasoning.py b/examples/hed/reasoning/reasoning.py index 1d27763f8..22293a6d5 100644 --- a/examples/hed/reasoning/reasoning.py +++ b/examples/hed/reasoning/reasoning.py @@ -1,8 +1,10 @@ import math import os +from typing import Any, List, Optional, Union import numpy as np +from ablkit.data.structures import ListData from ablkit.reasoning import PrologKB, Reasoner from ablkit.utils import reform_list @@ -11,37 +13,44 @@ class HedKB(PrologKB): def __init__( - self, pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "learn_add.pl") - ): + self, + pseudo_label_list: List[Any] = [1, 0, "+", "="], + pl_file: str = os.path.join(CURRENT_DIR, "learn_add.pl"), + ) -> None: pl_file = pl_file.replace("\\", "/") super().__init__(pseudo_label_list, pl_file) - self.learned_rules = {} + self.learned_rules: dict = {} - def consist_rule(self, exs, rules): + def consist_rule(self, exs: Any, rules: Any) -> bool: rules = str(rules).replace("'", "") return len(list(self.prolog.query("eval_inst_feature(%s, %s)." % (exs, rules)))) != 0 - def abduce_rules(self, pred_res): + def abduce_rules(self, pred_res: Any) -> Optional[List[Any]]: prolog_result = list(self.prolog.query("consistent_inst_feature(%s, X)." % pred_res)) if len(prolog_result) == 0: return None prolog_rules = prolog_result[0]["X"] - rules = [rule.value for rule in prolog_rules] - return rules + return [rule.value for rule in prolog_rules] class HedReasoner(Reasoner): - def revise_at_idx(self, data_example): + def revise_at_idx(self, data_example: ListData) -> Any: revision_idx = np.where(np.array(data_example.flatten("revision_flag")) != 0)[0] candidate = self.kb.revise_at_idx( data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx ) return candidate - def zoopt_budget(self, symbol_num): + def zoopt_budget(self, symbol_num: int) -> int: return 200 - def zoopt_score(self, symbol_num, data_example, sol, get_score=True): + def zoopt_score( + self, + symbol_num: int, + data_example: ListData, + sol: Any, + get_score: bool = True, + ) -> Union[float, List[int]]: revision_flag = reform_list( list(sol.get_x().astype(np.int32)), data_example.pred_pseudo_label ) @@ -82,7 +91,7 @@ def zoopt_score(self, symbol_num, data_example, sol, get_score=True): else: return max_consistent_idxs - def abduce(self, data_example): + def abduce(self, data_example: ListData) -> List[List[Any]]: symbol_num = data_example.elements_num("pred_pseudo_label") max_revision_num = self._get_max_revision_num(self.max_revision, symbol_num) @@ -98,5 +107,5 @@ def abduce(self, data_example): data_example.abduced_pseudo_label = abduced_pseudo_label return abduced_pseudo_label - def abduce_rules(self, pred_res): + def abduce_rules(self, pred_res: Any) -> Optional[List[Any]]: return self.kb.abduce_rules(pred_res) diff --git a/examples/hwf/datasets/get_dataset.py b/examples/hwf/datasets/get_dataset.py index 0700d3ceb..4109910a0 100644 --- a/examples/hwf/datasets/get_dataset.py +++ b/examples/hwf/datasets/get_dataset.py @@ -6,6 +6,8 @@ from PIL import Image from torchvision.transforms import transforms +from ablkit.utils import print_log + CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1,))]) @@ -30,10 +32,10 @@ def get_dataset(train=True, get_pseudo_label=False): data_dir = CURRENT_DIR + "/data" if not os.path.exists(data_dir): - print("Dataset not exist, downloading it...") + print_log("Dataset not present, downloading it...", logger="current") url = "https://drive.google.com/u/0/uc?id=1t52OE2Wdm5GdShX1jD2Wy8phCllk0r8I&export=download" download_and_unzip(url, os.path.join(CURRENT_DIR, "HWF.zip")) - print("Download and extraction complete.") + print_log("Download and extraction complete.", logger="current") if train: file = os.path.join(data_dir, "expr_train.json") diff --git a/examples/hwf/main.py b/examples/hwf/main.py index 6ea03804c..45c427f44 100644 --- a/examples/hwf/main.py +++ b/examples/hwf/main.py @@ -1,5 +1,8 @@ import argparse +import ast +import operator import os.path as osp +from typing import List import numpy as np import torch @@ -14,56 +17,80 @@ from datasets import get_dataset from models.nn import SymbolNet +DIGITS = {"1", "2", "3", "4", "5", "6", "7", "8", "9"} +OPERATORS = {"+", "-", "*", "/"} +PSEUDO_LABEL_LIST = sorted(DIGITS) + sorted(OPERATORS) + + +def _is_well_formed(formula: List[str]) -> bool: + """Return True iff ``formula`` alternates digit-operator-digit and has odd length.""" + if len(formula) % 2 == 0: + return False + for i, sym in enumerate(formula): + expected = DIGITS if i % 2 == 0 else OPERATORS + if sym not in expected: + return False + return True + + +_SAFE_BINOPS = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, +} + + +def _safe_eval(node: ast.AST) -> float: + """Evaluate an arithmetic AST restricted to the four basic operators.""" + if isinstance(node, ast.Expression): + return _safe_eval(node.body) + if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)): + return node.value + if isinstance(node, ast.BinOp) and type(node.op) in _SAFE_BINOPS: + return _SAFE_BINOPS[type(node.op)](_safe_eval(node.left), _safe_eval(node.right)) + raise ValueError(f"unsupported AST node: {type(node).__name__}") + + +def _evaluate(formula: List[str]) -> float: + """Evaluate a well-formed digit/operator formula with proper precedence.""" + try: + tree = ast.parse("".join(formula), mode="eval") + return _safe_eval(tree) + except (SyntaxError, ZeroDivisionError, ValueError): + return np.inf + + +def _hwf_logic_forward(formula: List[str]) -> float: + """Shared ``logic_forward`` implementation for both ``HwfKB`` variants.""" + if not _is_well_formed(formula): + return np.inf + return _evaluate(formula) + class HwfKB(KBBase): def __init__( self, - pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"], - max_err=1e-10, - ): + pseudo_label_list: List[str] = PSEUDO_LABEL_LIST, + max_err: float = 1e-10, + ) -> None: super().__init__(pseudo_label_list, max_err) - def _valid_candidate(self, formula): - if len(formula) % 2 == 0: - return False - for i in range(len(formula)): - if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: - return False - if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]: - return False - return True - - # Implement the deduction function - def logic_forward(self, formula): - if not self._valid_candidate(formula): - return np.inf - return eval("".join(formula)) + def logic_forward(self, formula: List[str]) -> float: + return _hwf_logic_forward(formula) class HwfGroundKB(GroundKB): def __init__( self, - pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", "+", "-", "*", "/"], - GKB_len_list=[1, 3, 5, 7], - max_err=1e-10, - ): + pseudo_label_list: List[str] = PSEUDO_LABEL_LIST, + GKB_len_list: List[int] = [1, 3, 5, 7], + max_err: float = 1e-10, + ) -> None: super().__init__(pseudo_label_list, GKB_len_list, max_err) - def _valid_candidate(self, formula): - if len(formula) % 2 == 0: - return False - for i in range(len(formula)): - if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: - return False - if i % 2 != 0 and formula[i] not in ["+", "-", "*", "/"]: - return False - return True - - # Implement the deduction function - def logic_forward(self, formula): - if not self._valid_candidate(formula): - return np.inf - return eval("".join(formula)) + def logic_forward(self, formula: List[str]) -> float: + return _hwf_logic_forward(formula) def main(): @@ -81,7 +108,7 @@ def main(): "--label-smoothing", type=float, default=0.2, - help="label smoothing in cross entropy loss (default : 0.2)" + help="label smoothing in cross entropy loss (default : 0.2)", ) parser.add_argument( "--lr", type=float, default=1e-3, help="base model learning rate (default : 0.001)" @@ -130,15 +157,20 @@ def main(): print_log("Building the Learning Part.", logger="current") # Build necessary components for BasicNN - cls = SymbolNet(num_classes=13, image_size=(45, 45, 1)) + net = SymbolNet(num_classes=13, image_size=(45, 45, 1)) loss_fn = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) - optimizer = torch.optim.Adam(cls.parameters(), lr=args.lr) + optimizer = torch.optim.Adam(net.parameters(), lr=args.lr) use_cuda = not args.no_cuda and torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") # Build BasicNN base_model = BasicNN( - cls, loss_fn, optimizer, device=device, batch_size=args.batch_size, num_epochs=args.epochs, + net, + loss_fn, + optimizer, + device=device, + batch_size=args.batch_size, + num_epochs=args.epochs, ) # Build ABLModel diff --git a/examples/hwf/models/__init__.py b/examples/hwf/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/mnist_add/main.py b/examples/mnist_add/main.py index c992c4e5e..41f4434a5 100644 --- a/examples/mnist_add/main.py +++ b/examples/mnist_add/main.py @@ -1,14 +1,22 @@ import argparse import os.path as osp +import random import torch from torch import nn from torch.optim import RMSprop, lr_scheduler -from ablkit.bridge import SimpleBridge +from ablkit.bridge import A3BLBridge, SimpleBridge, VerificationBridge from ablkit.data.evaluation import ReasoningMetric, SymbolAccuracy from ablkit.learning import ABLModel, BasicNN -from ablkit.reasoning import GroundKB, KBBase, PrologKB, Reasoner +from ablkit.reasoning import ( + A3BLReasoner, + GroundKB, + KBBase, + PrologKB, + Reasoner, + VerificationReasoner, +) from ablkit.utils import ABLLogger, print_log from datasets import get_dataset @@ -31,7 +39,10 @@ def logic_forward(self, nums): return sum(nums) -def main(): +DIST_FUNCS = ["hamming", "confidence", "avg_confidence", "similarity", "rejection"] + + +def parse_args(): parser = argparse.ArgumentParser(description="MNIST Addition example") parser.add_argument( "--no-cuda", action="store_true", default=False, help="disables CUDA training" @@ -59,7 +70,7 @@ def main(): "--loops", type=int, default=2, help="number of loop iterations (default : 2)" ) parser.add_argument( - "--segment_size", type=int, default=0.01, help="segment size (default : 0.01)" + "--segment_size", type=float, default=0.01, help="segment size (default : 0.01)" ) parser.add_argument("--save_interval", type=int, default=1, help="save interval (default : 1)") parser.add_argument( @@ -78,37 +89,78 @@ def main(): kb_type.add_argument( "--ground", action="store_true", default=False, help="use GroundKB (default: False)" ) + parser.add_argument( + "--method", + choices=["standard", "a3bl", "verification"], + default="standard", + help="learning/reasoning pipeline to use (default: standard)", + ) + parser.add_argument( + "--dist-func", + choices=DIST_FUNCS, + default="confidence", + help="distance function used by the reasoner (default: confidence)", + ) + parser.add_argument( + "--labeled-ratio", + type=float, + default=1.0, + help=( + "fraction in (0, 1] of training samples that keep their ground-truth pseudo-labels. " + "Values below 1.0 enable the semi-supervised pipeline (default: 1.0)" + ), + ) + parser.add_argument( + "--top-k", + type=int, + default=1, + help=( + "number of consistent candidates the verification reasoner enumerates per example. " + "Only used when --method verification (default: 1)" + ), + ) + parser.add_argument( + "--seed", type=int, default=0, help="random seed for semi-supervised split (default: 0)" + ) args = parser.parse_args() + if not (0.0 < args.labeled_ratio <= 1.0): + parser.error("--labeled-ratio must be in (0, 1].") + if args.method == "a3bl" and args.labeled_ratio < 1.0: + parser.error("--method a3bl does not support --labeled-ratio < 1.0.") + if args.method == "a3bl" and args.dist_func == "rejection": + parser.error("--method a3bl is not compatible with --dist-func rejection.") + if args.method == "verification" and args.labeled_ratio < 1.0: + parser.error("--method verification does not support --labeled-ratio < 1.0.") + if args.top_k < 1: + parser.error("--top-k must be >= 1.") + return args - # Build logger - print_log("Abductive Learning on the MNIST Addition example.", logger="current") - # -- Working with Data ------------------------------ - print_log("Working with Data.", logger="current") - train_data = get_dataset(train=True, get_pseudo_label=True) - test_data = get_dataset(train=False, get_pseudo_label=True) +def build_kb(args): + if args.prolog: + return PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl") + if args.ground: + return AddGroundKB() + return AddKB() - # -- Building the Learning Part --------------------- - print_log("Building the Learning Part.", logger="current") - # Build necessary components for BasicNN - cls = LeNet5(num_classes=10) +def build_base_model(args): + net = LeNet5(num_classes=10) loss_fn = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) - optimizer = RMSprop(cls.parameters(), lr=args.lr, alpha=args.alpha) + optimizer = RMSprop(net.parameters(), lr=args.lr, alpha=args.alpha) use_cuda = not args.no_cuda and torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") + train_passes_per_segment = args.top_k if args.method == "verification" else 1 scheduler = lr_scheduler.OneCycleLR( optimizer, max_lr=args.lr, pct_start=0.15, epochs=args.loops, - steps_per_epoch=int(1 / args.segment_size), + steps_per_epoch=int(1 / args.segment_size) * train_passes_per_segment, ) - - # Build BasicNN - base_model = BasicNN( - cls, + return BasicNN( + net, loss_fn, optimizer, scheduler=scheduler, @@ -117,45 +169,88 @@ def main(): num_epochs=args.epochs, ) - # Build ABLModel + +def mask_pseudo_labels(pseudo_label, ratio, seed): + """ + Randomly null out ``(1 - ratio)`` of the entries in ``pseudo_label`` so the + semi-supervised pipeline treats them as unlabeled. The remaining entries + keep their ground-truth values and are used directly during abduction. + """ + rng = random.Random(seed) + n = len(pseudo_label) + keep = set(rng.sample(range(n), int(round(ratio * n)))) + return [pseudo_label[i] if i in keep else None for i in range(n)] + + +def main(): + args = parse_args() + + print_log("Abductive Learning on the MNIST Addition example.", logger="current") + + print_log("Working with Data.", logger="current") + train_data = get_dataset(train=True, get_pseudo_label=True) + test_data = get_dataset(train=False, get_pseudo_label=True) + + val_data = None + if args.labeled_ratio < 1.0: + X, pseudo_label, Y = train_data + val_data = train_data + train_data = (X, mask_pseudo_labels(pseudo_label, args.labeled_ratio, args.seed), Y) + print_log( + f"Semi-supervised: keeping {args.labeled_ratio:.0%} of pseudo-labels.", + logger="current", + ) + + print_log("Building the Learning Part.", logger="current") + base_model = build_base_model(args) model = ABLModel(base_model) - # -- Building the Reasoning Part -------------------- print_log("Building the Reasoning Part.", logger="current") - - # Build knowledge base - if args.prolog: - kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl") - elif args.ground: - kb = AddGroundKB() + kb = build_kb(args) + if args.method == "verification": + reasoner = VerificationReasoner(kb, top_k=args.top_k) + elif args.method == "a3bl": + reasoner = A3BLReasoner( + kb, + dist_func=args.dist_func, + max_revision=args.max_revision, + require_more_revision=args.require_more_revision, + ) else: - kb = AddKB() - - # Create reasoner - reasoner = Reasoner( - kb, max_revision=args.max_revision, require_more_revision=args.require_more_revision - ) + reasoner = Reasoner( + kb, + dist_func=args.dist_func, + max_revision=args.max_revision, + require_more_revision=args.require_more_revision, + ) - # -- Building Evaluation Metrics -------------------- print_log("Building Evaluation Metrics.", logger="current") metric_list = [SymbolAccuracy(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")] - # -- Bridging Learning and Reasoning ---------------- print_log("Bridge Learning and Reasoning.", logger="current") - bridge = SimpleBridge(model, reasoner, metric_list) + if args.method == "verification": + bridge_cls = VerificationBridge + elif args.method == "a3bl": + bridge_cls = A3BLBridge + else: + bridge_cls = SimpleBridge + bridge = bridge_cls(model, reasoner, metric_list) - # Retrieve the directory of the Log file and define the directory for saving the model weights. log_dir = ABLLogger.get_current_instance().log_dir weights_dir = osp.join(log_dir, "weights") - # Train and Test - bridge.train( - train_data, + train_kwargs = dict( loops=args.loops, segment_size=args.segment_size, save_interval=args.save_interval, save_dir=weights_dir, ) + if args.method == "standard" and args.labeled_ratio < 1.0: + train_kwargs["use_supervised_data"] = True + if val_data is not None and args.method != "a3bl": + train_kwargs["val_data"] = val_data + + bridge.train(train_data, **train_kwargs) bridge.test(test_data) diff --git a/examples/mnist_add/models/__init__.py b/examples/mnist_add/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/mnist_add/models/nn.py b/examples/mnist_add/models/nn.py index 93f5cda11..f1019df1e 100644 --- a/examples/mnist_add/models/nn.py +++ b/examples/mnist_add/models/nn.py @@ -1,4 +1,5 @@ from torch import nn +import torch.nn.functional as F class LeNet5(nn.Module): @@ -20,6 +21,9 @@ def __init__(self, num_classes=10, image_size=(28, 28, 1)): nn.ReLU(), nn.Linear(84, num_classes), ) + + def extract_features(self, x): + return self.encoder(x) def forward(self, x): x = self.encoder(x) diff --git a/examples/zoo/kb.py b/examples/zoo/kb.py index 9a08077f1..d39d98536 100644 --- a/examples/zoo/kb.py +++ b/examples/zoo/kb.py @@ -1,5 +1,5 @@ import openml -from z3 import If, Implies, Int, Not, Solver, Sum, sat # noqa: F401 +from z3 import If, Implies, Int, Not, Solver, Sum, sat from ablkit.reasoning import KBBase @@ -10,84 +10,71 @@ def __init__(self): self.solver = Solver() - # Load information of Zoo dataset dataset = openml.datasets.get_dataset( dataset_id=62, download_data=False, download_qualities=False, download_features_meta_data=False, ) - X, y, categorical_indicator, attribute_names = dataset.get_data( + _, y, _, attribute_names = dataset.get_data( target=dataset.default_target_attribute ) self.attribute_names = attribute_names self.target_names = y.cat.categories.tolist() - # print("Attribute names are: ", self.attribute_names) - # print("Target names are: ", self.target_names) - # self.attribute_names = ["hair", "feathers", "eggs", "milk", "airborne", "aquatic", "predator", "toothed", "backbone", "breathes", "venomous", "fins", "legs", "tail", "domestic", "catsize"] # noqa: E501 - # self.target_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "invertebrate"] # noqa: E501 - # Define variables - for name in self.attribute_names + self.target_names: - exec( - f"globals()['{name}'] = Int('{name}')" - ) # or use dict to create var and modify rules - # Define rules + # Create a Z3 Int variable for every attribute and target, keyed by name. + self.vars = {name: Int(name) for name in self.attribute_names + self.target_names} + v = self.vars + rules = [ - Implies(milk == 1, mammal == 1), - Implies(mammal == 1, milk == 1), - Implies(mammal == 1, backbone == 1), - Implies(mammal == 1, breathes == 1), - Implies(feathers == 1, bird == 1), - Implies(bird == 1, feathers == 1), - Implies(bird == 1, eggs == 1), - Implies(bird == 1, backbone == 1), - Implies(bird == 1, breathes == 1), - Implies(bird == 1, legs == 2), - Implies(bird == 1, tail == 1), - Implies(reptile == 1, backbone == 1), - Implies(reptile == 1, breathes == 1), - Implies(reptile == 1, tail == 1), - Implies(fish == 1, aquatic == 1), - Implies(fish == 1, toothed == 1), - Implies(fish == 1, backbone == 1), - Implies(fish == 1, Not(breathes == 1)), - Implies(fish == 1, fins == 1), - Implies(fish == 1, legs == 0), - Implies(fish == 1, tail == 1), - Implies(amphibian == 1, eggs == 1), - Implies(amphibian == 1, aquatic == 1), - Implies(amphibian == 1, backbone == 1), - Implies(amphibian == 1, breathes == 1), - Implies(amphibian == 1, legs == 4), - Implies(insect == 1, eggs == 1), - Implies(insect == 1, Not(backbone == 1)), - Implies(insect == 1, legs == 6), - Implies(invertebrate == 1, Not(backbone == 1)), + Implies(v["milk"] == 1, v["mammal"] == 1), + Implies(v["mammal"] == 1, v["milk"] == 1), + Implies(v["mammal"] == 1, v["backbone"] == 1), + Implies(v["mammal"] == 1, v["breathes"] == 1), + Implies(v["feathers"] == 1, v["bird"] == 1), + Implies(v["bird"] == 1, v["feathers"] == 1), + Implies(v["bird"] == 1, v["eggs"] == 1), + Implies(v["bird"] == 1, v["backbone"] == 1), + Implies(v["bird"] == 1, v["breathes"] == 1), + Implies(v["bird"] == 1, v["legs"] == 2), + Implies(v["bird"] == 1, v["tail"] == 1), + Implies(v["reptile"] == 1, v["backbone"] == 1), + Implies(v["reptile"] == 1, v["breathes"] == 1), + Implies(v["reptile"] == 1, v["tail"] == 1), + Implies(v["fish"] == 1, v["aquatic"] == 1), + Implies(v["fish"] == 1, v["toothed"] == 1), + Implies(v["fish"] == 1, v["backbone"] == 1), + Implies(v["fish"] == 1, Not(v["breathes"] == 1)), + Implies(v["fish"] == 1, v["fins"] == 1), + Implies(v["fish"] == 1, v["legs"] == 0), + Implies(v["fish"] == 1, v["tail"] == 1), + Implies(v["amphibian"] == 1, v["eggs"] == 1), + Implies(v["amphibian"] == 1, v["aquatic"] == 1), + Implies(v["amphibian"] == 1, v["backbone"] == 1), + Implies(v["amphibian"] == 1, v["breathes"] == 1), + Implies(v["amphibian"] == 1, v["legs"] == 4), + Implies(v["insect"] == 1, v["eggs"] == 1), + Implies(v["insect"] == 1, Not(v["backbone"] == 1)), + Implies(v["insect"] == 1, v["legs"] == 6), + Implies(v["invertebrate"] == 1, Not(v["backbone"] == 1)), ] - # Define weights and sum of violated weights self.weights = {rule: 1 for rule in rules} self.total_violation_weight = Sum( [If(Not(rule), self.weights[rule], 0) for rule in self.weights] ) def logic_forward(self, pseudo_label, data_point): - attribute_names, target_names = self.attribute_names, self.target_names - solver = self.solver - total_violation_weight = self.total_violation_weight pseudo_label, data_point = pseudo_label[0], data_point[0] - self.solver.reset() - for name, value in zip(attribute_names, data_point): - solver.add(eval(f"{name} == {value}")) - for cate, name in zip(self.pseudo_label_list, target_names): - value = 1 if (cate == pseudo_label) else 0 - solver.add(eval(f"{name} == {value}")) - if solver.check() == sat: - model = solver.model() - total_weight = model.evaluate(total_violation_weight) - return total_weight.as_long() - else: - # No solution found - return 1e10 + for name, value in zip(self.attribute_names, data_point): + self.solver.add(self.vars[name] == value) + for cate, name in zip(self.pseudo_label_list, self.target_names): + value = 1 if cate == pseudo_label else 0 + self.solver.add(self.vars[name] == value) + + if self.solver.check() == sat: + model = self.solver.model() + return model.evaluate(self.total_violation_weight).as_long() + # No solution found + return 1e10 diff --git a/examples/zoo/main.py b/examples/zoo/main.py index 64f8109a7..41d3f4103 100644 --- a/examples/zoo/main.py +++ b/examples/zoo/main.py @@ -14,7 +14,7 @@ from kb import ZooKB -def consitency(data_example, candidates, candidate_idxs, reasoning_results): +def consistency(data_example, candidates, candidate_idxs, reasoning_results): pred_prob = data_example.pred_prob model_scores = avg_confidence_dist(pred_prob, candidate_idxs) rule_scores = np.array(reasoning_results) @@ -57,7 +57,7 @@ def main(): kb = ZooKB() # Create reasoner - reasoner = Reasoner(kb, dist_func=consitency) + reasoner = Reasoner(kb, dist_func=consistency) # -- Building Evaluation Metrics -------------------- print_log("Building Evaluation Metrics.", logger="current") diff --git a/pyproject.toml b/pyproject.toml index 887dd04ff..9a6d5094b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ authors = [ ] description = "Abductive Learning (ABL) toolkit" readme = "README.md" -requires-python = ">=3.7.0" +requires-python = ">=3.8.0" license = {text = "MIT LICENSE"} keywords = ["machine-learning", "framework", "abductive-learning", "neuro-symbolic"] classifiers = [ @@ -24,7 +24,6 @@ classifiers = [ "Operating System :: POSIX :: Linux", "Operating System :: Microsoft :: Windows", "Operating System :: MacOS :: MacOS X", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", @@ -45,6 +44,7 @@ Issues = "https://github.com/AbductiveLearning/ABLkit/issues" [project.optional-dependencies] test = [ + "pytest", "pytest-cov", "black==22.10.0", ] \ No newline at end of file diff --git a/tests/test_reasoning.py b/tests/test_reasoning.py index fa96dcee9..70b9c96a5 100644 --- a/tests/test_reasoning.py +++ b/tests/test_reasoning.py @@ -114,11 +114,10 @@ def test_reasoner_init(self, reasoner_instance): def test_invalid_predefined_dist_func(self, kb_add): with pytest.raises(NotImplementedError) as excinfo: Reasoner(kb_add, "invalid_dist_func") - assert ( - 'Valid options for predefined dist_func include "hamming", "confidence" ' - + 'and "avg_confidence"' - in str(excinfo.value) - ) + message = str(excinfo.value) + for option in ("hamming", "confidence", "avg_confidence", "similarity", "rejection"): + assert option in message + assert "invalid_dist_func" in message def random_dist(self, data_example, candidates, candidate_idxs, reasoning_results): cost_list = [np.random.rand() for _ in candidates]