From c42935b1ec7b01431ded2b5d59506f12b60f3260 Mon Sep 17 00:00:00 2001 From: Yiyun Liu Date: Tue, 2 Jun 2026 16:52:28 -0400 Subject: [PATCH] Run predicators envs in the browser via Pyodide + pybullet WASM MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Split utils.py into a Pyodide-safe lite half and a slim CPython wrapper, repoint env-load-path modules at the lite variant, and add a proof-of-concept that boots predicators inside Pyodide and resets each of the 18 envs in the dropdown. Architecture: - predicators/utils_lite.py is the bulk of the old utils.py minus imports of torch, scipy.stats.beta, imageio, and the pretrained-model SDKs. predicators/utils.py keeps those 11 helpers (DelayDistribution subclasses, save_video/save_images, the four beta_bernoulli_* functions, create_llm_by_name/create_vlm_by_name) and re-exports the lite module via `from utils_lite import *`, so CPython callers see no API change. - structs.py imports utils_lite directly (not utils) to break the cycle that triggers when utils_lite is loaded first under Pyodide. torch and pretrained_model_interface are lazy imports inside the two make_*_process methods that need them (via a shared `_torch_and_delay` helper). - 119 env-path modules (predicators/{envs, ground_truth_models, pybullet_helpers, perception, execution_monitoring}) flipped from `from predicators import utils` to `from predicators import utils_lite as utils` so they don't pull torch into Pyodide. - envs/__init__.py and ground_truth_models/__init__.py call import_submodules(tolerate_import_errors=True) — catches plain ImportError too, not just ModuleNotFoundError — so optional-dep failures (torch / gym_sokoban / gymnasium-robotics) skip the submodule with a warning instead of aborting auto-discovery. - base_env.py defers `from predicators.pretrained_model_interface import OpenAILLM` to inside the one method that uses it. - setup.py: splits heavy ML/LLM deps into an `[ml]` extra; the base install is the env-runtime slim set that's installable under Pyodide via micropip. `[develop]` includes `[ml]` plus formatters. Strict version pins (numpy/matplotlib/pillow/pyyaml) are qualified with `sys_platform != 'emscripten'` so they don't conflict with Pyodide's pre-loaded packages. - predicators/third_party/{__init__.py, fast_downward_translator/__init__.py} — empty package markers so `find_packages` actually ships the third_party tree in the built wheel (the env path actually imports `predicators.third_party.fast_downward_translator.translate`). - Replace `assert not params` with `assert len(params) == 0` in 22 ground-truth option policies; numpy ndarray params trigger "truth value of empty array is ambiguous" otherwise. Browser POC (web/): - web/app/{index.html, main.js} drive a Three.js renderer (via urdf-loader) over Pyodide+pybullet-wasm. JSON.stringify-quoted option names; mesh primitives with no mesh_url are skipped (they were rendering as 1m box stand-ins for in-memory vertex meshes like domino's top triangle). - web/app/setup.py is the Pyodide-side bridge: per-env CFG overrides for envs that need legacy options, multi-link visual transforms, body/color diffs, begin_option/step_option protocol. begin_option wraps the NSRT grounding in try/except so a sampler raising KeyError or AssertionError returns {error: ...} instead of crashing across the Pyodide → JS boundary; step_option also catches the broad Exception after OptionExecutionFailure so a policy raising mid-rollout doesn't leave _active_option armed. Also guards `get_gt_options` with try/except in 3 call sites so envs without registered options (the domino fan/ramp/stairs variants) load cleanly with an empty option list. - web/app/bundle.sh builds the predicators wheel, the gym 0.26 shim wheel, and packs predicators/envs/assets/ into a tarball (skipping the 32 MB tar rebuild when no asset is newer than the existing tarball). The pybullet WASM wheel is fetched from BasisResearch/pybullet-pyodide. - All 18 envs in the dropdown reach `action_dim=...` when smoked headlessly in Chromium via web/app/browser_smoke.mjs: ants, balance, barrier, blocks, boil, circuit, coffee, cover, domino, domino_fan, domino_fan_ramp, domino_fan_ramp_stairs, fan, float, grow, laser, magic_bin, switch. CI (.github/workflows/web.yml): - One `import-check` job. Boots Pyodide in Node, installs the three wheels, asserts every env in web/app/index.html's dropdown registers as a non-abstract BaseEnv subclass. ~30 s wall. No Chromium, no env construction, no asset extraction. - import_check.mjs wraps the Python import in try/except writing the traceback to stderr (which Pyodide's stderr handler captures), and Node-level unhandledRejection/uncaughtException handlers print the stack before exit — so a WASM abort surfaces a real message instead of just `pyodide.asm.js:8`. CI (.github/workflows/predicators.yml): - unit-tests / static-type-checking / lint steps install `[ml]` so they get the heavy deps slim utils.py imports at module top. yapf / isort / docformatter stay slim. mypy.ini: add `web` to the top-level exclude so mypy doesn't bail out on "Duplicate module named setup" (./setup.py vs ./web/app/setup.py). .predicators_pylintrc: add `web` to ignore-paths — the browser bridge monkey-patches pybullet, accesses env internals, and uses in-function imports; those checks don't help the POC. Co-Authored-By: Claude Opus 4.7 --- .github/workflows/predicators.yml | 6 +- .github/workflows/web.yml | 61 + .predicators_pylintrc | 2 +- mypy.ini | 2 +- predicators/envs/__init__.py | 7 +- predicators/envs/ball_and_cup_sticky_table.py | 2 +- predicators/envs/base_env.py | 5 +- predicators/envs/blocks.py | 2 +- predicators/envs/burger.py | 2 +- predicators/envs/cluttered_table.py | 2 +- predicators/envs/coffee.py | 2 +- predicators/envs/cover.py | 2 +- predicators/envs/doors.py | 4 +- predicators/envs/exit_garage.py | 4 +- predicators/envs/grid_row.py | 2 +- predicators/envs/gymnasium_wrapper.py | 5 +- predicators/envs/kitchen.py | 2 +- predicators/envs/narrow_passage.py | 4 +- predicators/envs/noisy_button.py | 2 +- predicators/envs/painting.py | 3 +- predicators/envs/pddl_env.py | 2 +- predicators/envs/playroom.py | 2 +- predicators/envs/pybullet_ants.py | 2 +- predicators/envs/pybullet_balance.py | 2 +- predicators/envs/pybullet_barrier.py | 2 +- predicators/envs/pybullet_blocks.py | 2 +- predicators/envs/pybullet_boil.py | 2 +- predicators/envs/pybullet_circuit.py | 2 +- predicators/envs/pybullet_coffee.py | 2 +- predicators/envs/pybullet_cover.py | 2 +- .../components/domino_component.py | 2 +- .../components/grid_component.py | 2 +- .../envs/pybullet_domino/composed_env.py | 2 +- .../task_generators/domino_task_generator.py | 2 +- predicators/envs/pybullet_env.py | 4 +- predicators/envs/pybullet_fan.py | 2 +- predicators/envs/pybullet_float.py | 2 +- predicators/envs/pybullet_grow.py | 2 +- predicators/envs/pybullet_laser.py | 2 +- predicators/envs/pybullet_magic_bin.py | 2 +- predicators/envs/pybullet_switch.py | 2 +- predicators/envs/sandwich.py | 4 +- predicators/envs/satellites.py | 2 +- predicators/envs/screws.py | 2 +- predicators/envs/sokoban.py | 2 +- predicators/envs/stick_button.py | 4 +- predicators/envs/sticky_table.py | 2 +- predicators/envs/touch_point.py | 2 +- predicators/execution_monitoring/__init__.py | 2 +- .../expected_atoms_monitor.py | 2 +- predicators/ground_truth_models/__init__.py | 7 +- predicators/ground_truth_models/ants/nsrts.py | 2 +- .../ground_truth_models/ants/options.py | 4 +- .../ground_truth_models/balance/nsrts.py | 2 +- .../ground_truth_models/balance/options.py | 4 +- .../ball_and_cup_sticky_table/options.py | 2 +- .../ground_truth_models/blocks/nsrts.py | 2 +- .../ground_truth_models/blocks/options.py | 4 +- predicators/ground_truth_models/boil/nsrts.py | 2 +- .../ground_truth_models/boil/options.py | 2 +- .../boil/options_legacy.py | 10 +- .../ground_truth_models/burger/nsrts.py | 2 +- .../ground_truth_models/burger/options.py | 2 +- .../ground_truth_models/circuit/nsrts.py | 2 +- .../ground_truth_models/circuit/options.py | 8 +- .../cluttered_table/nsrts.py | 2 +- .../cluttered_table/options.py | 2 +- .../ground_truth_models/coffee/nsrts.py | 2 +- .../ground_truth_models/coffee/options.py | 2 +- .../coffee/options_legacy.py | 2 +- .../ground_truth_models/cover/nsrts.py | 2 +- .../ground_truth_models/cover/options.py | 2 +- .../ground_truth_models/domino/nsrts.py | 2 +- .../domino/options_legacy.py | 10 +- .../ground_truth_models/domino/predicates.py | 2 +- .../ground_truth_models/domino/types.py | 2 +- .../ground_truth_models/doors/nsrts.py | 2 +- .../ground_truth_models/doors/options.py | 4 +- .../exit_garage/options.py | 2 +- predicators/ground_truth_models/fan/nsrts.py | 2 +- .../ground_truth_models/fan/options.py | 2 +- .../ground_truth_models/fan/options_legacy.py | 4 +- .../ground_truth_models/float/nsrts.py | 2 +- .../ground_truth_models/float/options.py | 6 +- .../ground_truth_models/grid_row/nsrts.py | 2 +- .../ground_truth_models/grid_row/options.py | 2 +- predicators/ground_truth_models/grow/nsrts.py | 2 +- .../grow/options_legacy.py | 2 +- .../ice_tea_making/options.py | 2 +- .../ground_truth_models/laser/nsrts.py | 2 +- .../ground_truth_models/laser/options.py | 8 +- .../narrow_passage/options.py | 2 +- .../noisy_button/options.py | 2 +- .../ground_truth_models/painting/nsrts.py | 2 +- .../ground_truth_models/painting/options.py | 2 +- .../ground_truth_models/pddl_env/nsrts.py | 4 +- .../ground_truth_models/pddl_env/options.py | 2 +- .../ground_truth_models/playroom/options.py | 2 +- .../repeated_nextto/nsrts.py | 2 +- .../repeated_nextto/options.py | 2 +- .../ground_truth_models/sandwich/nsrts.py | 2 +- .../ground_truth_models/sandwich/options.py | 2 +- .../ground_truth_models/satellites/nsrts.py | 2 +- .../ground_truth_models/satellites/options.py | 2 +- .../ground_truth_models/screws/nsrts.py | 2 +- .../ground_truth_models/screws/options.py | 2 +- .../skill_factories/base.py | 2 +- .../skill_factories/wait.py | 2 +- .../ground_truth_models/sokoban/nsrts.py | 2 +- .../ground_truth_models/sokoban/options.py | 2 +- .../ground_truth_models/stick_button/nsrts.py | 2 +- .../stick_button/options.py | 2 +- .../sticky_table/options.py | 2 +- .../ground_truth_models/tools/nsrts.py | 2 +- .../ground_truth_models/tools/options.py | 6 +- .../ground_truth_models/touch_point/nsrts.py | 2 +- predicators/perception/__init__.py | 2 +- predicators/perception/sokoban_perceiver.py | 2 +- predicators/pybullet_helpers/controllers.py | 2 +- predicators/pybullet_helpers/ikfast/load.py | 2 +- .../pybullet_helpers/motion_planning.py | 2 +- predicators/pybullet_helpers/objects.py | 6 +- predicators/pybullet_helpers/robots/fetch.py | 2 +- predicators/pybullet_helpers/robots/panda.py | 2 +- predicators/structs.py | 41 +- predicators/third_party/__init__.py | 0 .../fast_downward_translator/__init__.py | 0 predicators/utils.py | 5108 +---------------- predicators/utils_lite.py | 5069 ++++++++++++++++ setup.py | 107 +- web/.gitignore | 8 + web/README.md | 121 + web/app/browser_smoke.mjs | 141 + web/app/bundle.sh | 48 + web/app/deploy.sh | 91 + web/app/gym_shim_setup.py | 104 + web/app/import_check.mjs | 109 + web/app/index.html | 107 + web/app/main.js | 913 +++ web/app/node_smoke.mjs | 244 + web/app/serve.sh | 10 + web/app/setup.py | 658 +++ web/app/sweep_smoke.mjs | 107 + web/package-lock.json | 971 ++++ web/package.json | 11 + 145 files changed, 9111 insertions(+), 5249 deletions(-) create mode 100644 .github/workflows/web.yml create mode 100644 predicators/third_party/__init__.py create mode 100644 predicators/third_party/fast_downward_translator/__init__.py create mode 100644 predicators/utils_lite.py create mode 100644 web/.gitignore create mode 100644 web/README.md create mode 100644 web/app/browser_smoke.mjs create mode 100755 web/app/bundle.sh create mode 100755 web/app/deploy.sh create mode 100644 web/app/gym_shim_setup.py create mode 100644 web/app/import_check.mjs create mode 100644 web/app/index.html create mode 100644 web/app/main.js create mode 100644 web/app/node_smoke.mjs create mode 100755 web/app/serve.sh create mode 100644 web/app/setup.py create mode 100644 web/app/sweep_smoke.mjs create mode 100644 web/package-lock.json create mode 100644 web/package.json diff --git a/.github/workflows/predicators.yml b/.github/workflows/predicators.yml index a1bf9df24c..1529d4e73d 100644 --- a/.github/workflows/predicators.yml +++ b/.github/workflows/predicators.yml @@ -24,7 +24,7 @@ jobs: restore-keys: | pip-${{ matrix.python-version }}- - run: | - pip install -e . + pip install -e .[ml] pip install pytest-cov==2.12.1 pytest-split - name: Pytest (group ${{ matrix.group }}/8) run: | @@ -84,7 +84,7 @@ jobs: pip-${{ matrix.python-version }}- - name: Install dependencies run: | - pip install -e . + pip install -e .[ml] pip install mypy==1.8.0 - name: Mypy run: | @@ -109,7 +109,7 @@ jobs: pip-${{ matrix.python-version }}- - name: Install dependencies run: | - pip install -e . + pip install -e .[ml] pip install pytest-pylint==0.18.0 - name: Pylint run: | diff --git a/.github/workflows/web.yml b/.github/workflows/web.yml new file mode 100644 index 0000000000..beb8d9a85e --- /dev/null +++ b/.github/workflows/web.yml @@ -0,0 +1,61 @@ +name: web + +on: [push] + +jobs: + import-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.10.14 + uses: actions/setup-python@v5 + with: + python-version: "3.10.14" + + - name: Cache pip packages + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: pip-3.10.14-${{ hashFiles('setup.py') }} + restore-keys: | + pip-3.10.14- + + - name: Set up Node 24 + uses: actions/setup-node@v4 + with: + node-version: "24" + + - name: Install predicators + build tools + run: | + pip install -e . + pip install build + + - name: Build predicators + gym shim wheels + run: | + PYTHON=$(which python) ./web/app/bundle.sh + + # Pre-built pybullet WASM wheel from BasisResearch/pybullet-pyodide. + # That repo's CI builds the wheel as a release asset; we just + # download it. Override `PYBULLET_WASM_WHEEL_URL` via a repo + # variable to bump the version without editing this file. + - name: Download pybullet WASM wheel + env: + WHEEL_URL: ${{ vars.PYBULLET_WASM_WHEEL_URL || 'https://github.com/BasisResearch/pybullet-pyodide/releases/latest/download/pybullet-3.2.7-cp313-cp313-pyemscripten_2025_0_wasm32.whl' }} + run: | + curl -L --fail -o web/wheels/pybullet-3.2.7-cp313-cp313-pyemscripten_2025_0_wasm32.whl \ + "$WHEEL_URL" + ls -lh web/wheels/ + + - name: Install Node deps (pyodide + puppeteer-core) + working-directory: web + run: | + npm ci + + # The cheap gate: boot Pyodide in Node, install the three wheels, + # and verify every env named in the dropdown shows up as a + # registered BaseEnv subclass. No Chromium, no env construction, + # no asset extraction. ~30 s wall, vs. ~12 min for per-env smoke. + - name: Pyodide import check + run: | + node web/app/import_check.mjs diff --git a/.predicators_pylintrc b/.predicators_pylintrc index 826e18ba08..057c82a78e 100644 --- a/.predicators_pylintrc +++ b/.predicators_pylintrc @@ -10,7 +10,7 @@ extension-pkg-whitelist=numpy,pybullet,torch,tensorflow,pyrealsense2 ignore=CVS # Add paths to the blacklist. -ignore-paths=predicators/envs/assets,predicators/third_party,venv +ignore-paths=predicators/envs/assets,predicators/third_party,venv,web # Add files or directories matching the regex patterns to the blacklist. The # regex matches against base names, not paths. diff --git a/mypy.ini b/mypy.ini index 6f189e2e9d..a4e2b868d5 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2,7 +2,7 @@ strict_equality = True disallow_untyped_calls = True warn_unreachable = True -exclude = (predicators/envs/assets|venv|prompts|logs) +exclude = (predicators/envs/assets|venv|prompts|logs|web) [mypy-predicators.*] disallow_untyped_defs = True diff --git a/predicators/envs/__init__.py b/predicators/envs/__init__.py index 2510edd608..43f45e1e72 100644 --- a/predicators/envs/__init__.py +++ b/predicators/envs/__init__.py @@ -3,14 +3,15 @@ import logging from typing import Any -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.base_env import BaseEnv __all__ = ["BaseEnv"] _MOST_RECENT_ENV_INSTANCE = {} -# Find the subclasses. -utils.import_submodules(__path__, __name__) +# Find the subclasses. Tolerate missing optional deps so constrained +# runtimes (e.g. Pyodide) can still load the envs that don't need them. +utils.import_submodules(__path__, __name__, tolerate_import_errors=True) def create_new_env(name: str, diff --git a/predicators/envs/ball_and_cup_sticky_table.py b/predicators/envs/ball_and_cup_sticky_table.py index 79f1c14b28..ae5dcf82d9 100644 --- a/predicators/envs/ball_and_cup_sticky_table.py +++ b/predicators/envs/ball_and_cup_sticky_table.py @@ -6,7 +6,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ diff --git a/predicators/envs/base_env.py b/predicators/envs/base_env.py index a88eae29eb..8196485244 100644 --- a/predicators/envs/base_env.py +++ b/predicators/envs/base_env.py @@ -10,8 +10,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils -from predicators.pretrained_model_interface import OpenAILLM +from predicators import utils_lite as utils from predicators.settings import CFG from predicators.structs import Action, DefaultEnvironmentTask, \ EnvironmentTask, GroundAtom, Object, Observation, Predicate, State, Task, \ @@ -319,6 +318,8 @@ def _parse_language_goal_from_json( object_names = set(id_to_obj) prompt_prefix = self._get_language_goal_prompt_prefix(object_names) prompt = prompt_prefix + f"\n# {language_goal}" + from predicators.pretrained_model_interface import \ + OpenAILLM # pylint: disable=import-outside-toplevel llm = OpenAILLM(CFG.llm_model_name) responses = llm.sample_completions(prompt, None, diff --git a/predicators/envs/blocks.py b/predicators/envs/blocks.py index e2320df7c3..48009a0232 100644 --- a/predicators/envs/blocks.py +++ b/predicators/envs/blocks.py @@ -19,7 +19,7 @@ from gym.spaces import Box from matplotlib import patches -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.settings import CFG from predicators.structs import Action, Array, EnvironmentTask, GroundAtom, \ diff --git a/predicators/envs/burger.py b/predicators/envs/burger.py index 7d6d5f8b6a..c513a2b716 100644 --- a/predicators/envs/burger.py +++ b/predicators/envs/burger.py @@ -18,7 +18,7 @@ from gym.spaces import Box from PIL import Image -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.settings import CFG from predicators.structs import Action, DefaultEnvironmentTask, \ diff --git a/predicators/envs/cluttered_table.py b/predicators/envs/cluttered_table.py index 17477f5801..3eab20fce3 100644 --- a/predicators/envs/cluttered_table.py +++ b/predicators/envs/cluttered_table.py @@ -11,7 +11,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.settings import CFG from predicators.structs import Action, Array, EnvironmentTask, GroundAtom, \ diff --git a/predicators/envs/coffee.py b/predicators/envs/coffee.py index 5e179714d5..0911b228fa 100644 --- a/predicators/envs/coffee.py +++ b/predicators/envs/coffee.py @@ -8,7 +8,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.pybullet_helpers.objects import \ sample_collision_free_2d_positions diff --git a/predicators/envs/cover.py b/predicators/envs/cover.py index 6bb1e48cfb..ef8fe5472d 100644 --- a/predicators/envs/cover.py +++ b/predicators/envs/cover.py @@ -11,7 +11,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.settings import CFG from predicators.structs import Action, Array, EnvironmentTask, GroundAtom, \ diff --git a/predicators/envs/doors.py b/predicators/envs/doors.py index 1f3c02615a..c8e9eb250d 100644 --- a/predicators/envs/doors.py +++ b/predicators/envs/doors.py @@ -10,12 +10,12 @@ from gym.spaces import Box from numpy.typing import NDArray -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.settings import CFG from predicators.structs import Action, Array, EnvironmentTask, GroundAtom, \ Object, Predicate, State, Type -from predicators.utils import Rectangle, StateWithCache, _Geom2D +from predicators.utils_lite import Rectangle, StateWithCache, _Geom2D class DoorsEnv(BaseEnv): diff --git a/predicators/envs/exit_garage.py b/predicators/envs/exit_garage.py index 5d7930d822..651f6a94df 100644 --- a/predicators/envs/exit_garage.py +++ b/predicators/envs/exit_garage.py @@ -9,12 +9,12 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ Predicate, State, Type -from predicators.utils import _Geom2D +from predicators.utils_lite import _Geom2D class ExitGarageEnv(BaseEnv): diff --git a/predicators/envs/grid_row.py b/predicators/envs/grid_row.py index a8486569dd..208f9671d6 100644 --- a/predicators/envs/grid_row.py +++ b/predicators/envs/grid_row.py @@ -7,7 +7,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ diff --git a/predicators/envs/gymnasium_wrapper.py b/predicators/envs/gymnasium_wrapper.py index 3829e33cda..d4116307fd 100644 --- a/predicators/envs/gymnasium_wrapper.py +++ b/predicators/envs/gymnasium_wrapper.py @@ -44,7 +44,8 @@ def _ensure_cfg_initialized() -> None: required fields like ``seed`` are missing and ``BaseEnv.__init__`` would crash. """ - from predicators import utils # pylint: disable=import-outside-toplevel + from predicators import \ + utils_lite as utils # pylint: disable=import-outside-toplevel from predicators.settings import \ CFG # pylint: disable=import-outside-toplevel if not hasattr(CFG, "seed"): @@ -88,7 +89,7 @@ def __init__( _ensure_cfg_initialized() if cfg_overrides: from predicators import \ - utils # pylint: disable=import-outside-toplevel + utils_lite as utils # pylint: disable=import-outside-toplevel utils.update_config(cfg_overrides) resolved_cls = _resolve_cls(env_cls) self._env = resolved_cls(use_gui=use_gui, **env_kwargs) diff --git a/predicators/envs/kitchen.py b/predicators/envs/kitchen.py index 564af32370..257d88077f 100644 --- a/predicators/envs/kitchen.py +++ b/predicators/envs/kitchen.py @@ -17,7 +17,7 @@ _MJKITCHEN_IMPORTED = True except (ImportError, RuntimeError): _MJKITCHEN_IMPORTED = False -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, Image, Object, \ diff --git a/predicators/envs/narrow_passage.py b/predicators/envs/narrow_passage.py index a6f776f900..dd93403221 100644 --- a/predicators/envs/narrow_passage.py +++ b/predicators/envs/narrow_passage.py @@ -7,12 +7,12 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ Predicate, State, Type -from predicators.utils import _Geom2D +from predicators.utils_lite import _Geom2D class NarrowPassageEnv(BaseEnv): diff --git a/predicators/envs/noisy_button.py b/predicators/envs/noisy_button.py index 8a481a3eb6..39ec67d4d6 100644 --- a/predicators/envs/noisy_button.py +++ b/predicators/envs/noisy_button.py @@ -17,7 +17,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ diff --git a/predicators/envs/painting.py b/predicators/envs/painting.py index 7039a3666d..12383b4750 100644 --- a/predicators/envs/painting.py +++ b/predicators/envs/painting.py @@ -20,7 +20,8 @@ from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ Predicate, State, Type -from predicators.utils import EnvironmentFailure, HumanDemonstrationFailure +from predicators.utils_lite import EnvironmentFailure, \ + HumanDemonstrationFailure class PaintingEnv(BaseEnv): diff --git a/predicators/envs/pddl_env.py b/predicators/envs/pddl_env.py index dd826e719e..309b85754e 100644 --- a/predicators/envs/pddl_env.py +++ b/predicators/envs/pddl_env.py @@ -19,7 +19,7 @@ from pyperplan.pddl.pddl import Domain as PyperplanDomain from pyperplan.pddl.pddl import Type as PyperplanType -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.envs.pddl_procedural_generation import \ create_blocks_pddl_generator, create_delivery_pddl_generator, \ diff --git a/predicators/envs/playroom.py b/predicators/envs/playroom.py index 77e6913d08..b457a581b1 100644 --- a/predicators/envs/playroom.py +++ b/predicators/envs/playroom.py @@ -8,7 +8,7 @@ from gym.spaces import Box from matplotlib import patches -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.blocks import BlocksEnv, BlocksEnvClear from predicators.settings import CFG from predicators.structs import Action, Array, EnvironmentTask, GroundAtom, \ diff --git a/predicators/envs/pybullet_ants.py b/predicators/envs/pybullet_ants.py index 795903f25d..4b93fed637 100644 --- a/predicators/envs/pybullet_ants.py +++ b/predicators/envs/pybullet_ants.py @@ -4,7 +4,7 @@ import numpy as np import pybullet as p -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.objects import create_object, \ create_pybullet_block, sample_collision_free_2d_positions, update_object diff --git a/predicators/envs/pybullet_balance.py b/predicators/envs/pybullet_balance.py index 947f21b846..e00f1126b9 100644 --- a/predicators/envs/pybullet_balance.py +++ b/predicators/envs/pybullet_balance.py @@ -22,7 +22,7 @@ from predicators.settings import CFG from predicators.structs import Action, Array, ConceptPredicate, \ EnvironmentTask, GroundAtom, NSPredicate, Object, Predicate, State, Type -from predicators.utils import VLMQuery +from predicators.utils_lite import VLMQuery class PyBulletBalanceEnv(PyBulletEnv): diff --git a/predicators/envs/pybullet_barrier.py b/predicators/envs/pybullet_barrier.py index d5b45f6f95..a30f01b3e3 100644 --- a/predicators/envs/pybullet_barrier.py +++ b/predicators/envs/pybullet_barrier.py @@ -14,7 +14,7 @@ import numpy as np import pybullet as p -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion from predicators.pybullet_helpers.objects import create_object, \ diff --git a/predicators/envs/pybullet_blocks.py b/predicators/envs/pybullet_blocks.py index 26ed5b0adc..6f47bd92d5 100644 --- a/predicators/envs/pybullet_blocks.py +++ b/predicators/envs/pybullet_blocks.py @@ -7,7 +7,7 @@ import numpy as np import pybullet as p -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.blocks import BlocksEnv from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion diff --git a/predicators/envs/pybullet_boil.py b/predicators/envs/pybullet_boil.py index 7dd5cc8860..a5e9131bfe 100644 --- a/predicators/envs/pybullet_boil.py +++ b/predicators/envs/pybullet_boil.py @@ -8,7 +8,7 @@ import numpy as np import pybullet as p -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers import retry_pybullet_call from predicators.pybullet_helpers.geometry import Pose3D, Quaternion diff --git a/predicators/envs/pybullet_circuit.py b/predicators/envs/pybullet_circuit.py index 9e0e1ae8e5..645a234e7f 100644 --- a/predicators/envs/pybullet_circuit.py +++ b/predicators/envs/pybullet_circuit.py @@ -16,7 +16,7 @@ import numpy as np import pybullet as p -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion from predicators.pybullet_helpers.objects import create_object diff --git a/predicators/envs/pybullet_coffee.py b/predicators/envs/pybullet_coffee.py index 364318200f..315d9445b6 100644 --- a/predicators/envs/pybullet_coffee.py +++ b/predicators/envs/pybullet_coffee.py @@ -37,7 +37,7 @@ import numpy as np import pybullet as p -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.coffee import CoffeeEnv from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion diff --git a/predicators/envs/pybullet_cover.py b/predicators/envs/pybullet_cover.py index e51addbe52..2f648608a5 100644 --- a/predicators/envs/pybullet_cover.py +++ b/predicators/envs/pybullet_cover.py @@ -11,7 +11,7 @@ import numpy as np import pybullet as p -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.cover import CoverEnv from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion diff --git a/predicators/envs/pybullet_domino/components/domino_component.py b/predicators/envs/pybullet_domino/components/domino_component.py index 8375ffba30..56dcd3f533 100644 --- a/predicators/envs/pybullet_domino/components/domino_component.py +++ b/predicators/envs/pybullet_domino/components/domino_component.py @@ -15,7 +15,7 @@ import numpy as np import pybullet as p -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_domino.components.base_component import \ DominoEnvComponent from predicators.pybullet_helpers.geometry import Pose3D, Quaternion diff --git a/predicators/envs/pybullet_domino/components/grid_component.py b/predicators/envs/pybullet_domino/components/grid_component.py index 29f98576ea..fd5c2413e8 100644 --- a/predicators/envs/pybullet_domino/components/grid_component.py +++ b/predicators/envs/pybullet_domino/components/grid_component.py @@ -12,7 +12,7 @@ import numpy as np import pybullet as p -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_domino.components.base_component import \ DominoEnvComponent from predicators.settings import CFG diff --git a/predicators/envs/pybullet_domino/composed_env.py b/predicators/envs/pybullet_domino/composed_env.py index f27a0759d6..75b2eb6087 100644 --- a/predicators/envs/pybullet_domino/composed_env.py +++ b/predicators/envs/pybullet_domino/composed_env.py @@ -9,7 +9,7 @@ import numpy as np import pybullet as p -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_domino.components.ball_component import \ BallComponent from predicators.envs.pybullet_domino.components.base_component import \ diff --git a/predicators/envs/pybullet_domino/task_generators/domino_task_generator.py b/predicators/envs/pybullet_domino/task_generators/domino_task_generator.py index e8b3655b98..e7ca93c270 100644 --- a/predicators/envs/pybullet_domino/task_generators/domino_task_generator.py +++ b/predicators/envs/pybullet_domino/task_generators/domino_task_generator.py @@ -4,7 +4,7 @@ import numpy as np -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_domino.components.domino_component import \ DominoComponent, PlacementResult from predicators.envs.pybullet_domino.task_generators.base_generator import \ diff --git a/predicators/envs/pybullet_env.py b/predicators/envs/pybullet_env.py index b589783d9d..f8b42481b2 100644 --- a/predicators/envs/pybullet_env.py +++ b/predicators/envs/pybullet_env.py @@ -43,7 +43,7 @@ from gym.spaces import Box from PIL import Image -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.pybullet_helpers import retry_pybullet_call from predicators.pybullet_helpers.camera import create_gui_connection @@ -56,7 +56,7 @@ from predicators.settings import CFG from predicators.structs import Action, Array, EnvironmentTask, Mask, Object, \ Observation, State, Video -from predicators.utils import PyBulletState +from predicators.utils_lite import PyBulletState class PyBulletEnv(BaseEnv): diff --git a/predicators/envs/pybullet_fan.py b/predicators/envs/pybullet_fan.py index cfc0745994..efff5766e0 100644 --- a/predicators/envs/pybullet_fan.py +++ b/predicators/envs/pybullet_fan.py @@ -5,7 +5,7 @@ import numpy as np import pybullet as p -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion from predicators.pybullet_helpers.objects import create_object, \ diff --git a/predicators/envs/pybullet_float.py b/predicators/envs/pybullet_float.py index 88be815747..fa26deae93 100644 --- a/predicators/envs/pybullet_float.py +++ b/predicators/envs/pybullet_float.py @@ -12,7 +12,7 @@ import numpy as np import pybullet as p -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion from predicators.pybullet_helpers.objects import create_object, \ diff --git a/predicators/envs/pybullet_grow.py b/predicators/envs/pybullet_grow.py index d2fc483fe5..ee18e17da6 100644 --- a/predicators/envs/pybullet_grow.py +++ b/predicators/envs/pybullet_grow.py @@ -12,7 +12,7 @@ import numpy as np import pybullet as p -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_coffee import PyBulletCoffeeEnv from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion diff --git a/predicators/envs/pybullet_laser.py b/predicators/envs/pybullet_laser.py index 86f8427f04..84bc450f2e 100644 --- a/predicators/envs/pybullet_laser.py +++ b/predicators/envs/pybullet_laser.py @@ -19,7 +19,7 @@ import numpy as np import pybullet as p -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion from predicators.pybullet_helpers.objects import create_object, update_object diff --git a/predicators/envs/pybullet_magic_bin.py b/predicators/envs/pybullet_magic_bin.py index b235022d31..459b386288 100644 --- a/predicators/envs/pybullet_magic_bin.py +++ b/predicators/envs/pybullet_magic_bin.py @@ -15,7 +15,7 @@ import numpy as np import pybullet as p -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion from predicators.pybullet_helpers.objects import create_object, \ diff --git a/predicators/envs/pybullet_switch.py b/predicators/envs/pybullet_switch.py index bca7b23d83..4795214150 100644 --- a/predicators/envs/pybullet_switch.py +++ b/predicators/envs/pybullet_switch.py @@ -16,7 +16,7 @@ import numpy as np import pybullet as p -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_env import PyBulletEnv from predicators.pybullet_helpers.geometry import Pose3D, Quaternion from predicators.pybullet_helpers.objects import create_object diff --git a/predicators/envs/sandwich.py b/predicators/envs/sandwich.py index bada028ae3..67d854200d 100644 --- a/predicators/envs/sandwich.py +++ b/predicators/envs/sandwich.py @@ -8,12 +8,12 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.settings import CFG from predicators.structs import RGBA, Action, EnvironmentTask, GroundAtom, \ Object, Predicate, State, Type -from predicators.utils import _Geom2D +from predicators.utils_lite import _Geom2D class SandwichEnv(BaseEnv): diff --git a/predicators/envs/satellites.py b/predicators/envs/satellites.py index 3374e8bd8a..c26254766b 100644 --- a/predicators/envs/satellites.py +++ b/predicators/envs/satellites.py @@ -26,7 +26,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ diff --git a/predicators/envs/screws.py b/predicators/envs/screws.py index 4dfcba53c7..242f924635 100644 --- a/predicators/envs/screws.py +++ b/predicators/envs/screws.py @@ -7,7 +7,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ diff --git a/predicators/envs/sokoban.py b/predicators/envs/sokoban.py index 0f7c353d5b..23dfb10536 100644 --- a/predicators/envs/sokoban.py +++ b/predicators/envs/sokoban.py @@ -7,7 +7,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, Image, Object, \ diff --git a/predicators/envs/stick_button.py b/predicators/envs/stick_button.py index 8eb491d8d3..1a0846bd89 100644 --- a/predicators/envs/stick_button.py +++ b/predicators/envs/stick_button.py @@ -8,12 +8,12 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ Predicate, State, Type -from predicators.utils import _Geom2D +from predicators.utils_lite import _Geom2D class StickButtonEnv(BaseEnv): diff --git a/predicators/envs/sticky_table.py b/predicators/envs/sticky_table.py index 6021781882..919dc65ce5 100644 --- a/predicators/envs/sticky_table.py +++ b/predicators/envs/sticky_table.py @@ -8,7 +8,7 @@ from gym.spaces import Box from matplotlib.patches import Wedge -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ diff --git a/predicators/envs/touch_point.py b/predicators/envs/touch_point.py index 19cef7c82f..b5438af039 100644 --- a/predicators/envs/touch_point.py +++ b/predicators/envs/touch_point.py @@ -8,7 +8,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv from predicators.settings import CFG from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ diff --git a/predicators/execution_monitoring/__init__.py b/predicators/execution_monitoring/__init__.py index d73bf94ca3..614714a01e 100644 --- a/predicators/execution_monitoring/__init__.py +++ b/predicators/execution_monitoring/__init__.py @@ -1,6 +1,6 @@ """Handle creation of execution monitors.""" -from predicators import utils +from predicators import utils_lite as utils from predicators.execution_monitoring.base_execution_monitor import \ BaseExecutionMonitor diff --git a/predicators/execution_monitoring/expected_atoms_monitor.py b/predicators/execution_monitoring/expected_atoms_monitor.py index eecd7bc06b..20b569e40b 100644 --- a/predicators/execution_monitoring/expected_atoms_monitor.py +++ b/predicators/execution_monitoring/expected_atoms_monitor.py @@ -3,7 +3,7 @@ import logging -from predicators import utils +from predicators import utils_lite as utils from predicators.execution_monitoring.base_execution_monitor import \ BaseExecutionMonitor from predicators.settings import CFG diff --git a/predicators/ground_truth_models/__init__.py b/predicators/ground_truth_models/__init__.py index 54b6155d9b..069c1da29a 100644 --- a/predicators/ground_truth_models/__init__.py +++ b/predicators/ground_truth_models/__init__.py @@ -6,7 +6,7 @@ from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import BaseEnv, get_or_create_env from predicators.settings import CFG from predicators.structs import NSRT, CausalProcess, EndogenousProcess, \ @@ -401,5 +401,6 @@ def _get_options_by_names(env_name: str, return [name_to_option[name] for name in names] -# Find the factories. -utils.import_submodules(__path__, __name__) +# Find the factories. Tolerate missing optional deps (e.g. torch under +# Pyodide) — factories whose import fails are skipped with a warning. +utils.import_submodules(__path__, __name__, tolerate_import_errors=True) diff --git a/predicators/ground_truth_models/ants/nsrts.py b/predicators/ground_truth_models/ants/nsrts.py index 63b77d0b03..53b8491938 100644 --- a/predicators/ground_truth_models/ants/nsrts.py +++ b/predicators/ground_truth_models/ants/nsrts.py @@ -7,7 +7,7 @@ from predicators.ground_truth_models import GroundTruthNSRTFactory from predicators.structs import NSRT, Array, GroundAtom, LiftedAtom, Object, \ ParameterizedOption, Predicate, State, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class PyBulletAntsGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/ants/options.py b/predicators/ground_truth_models/ants/options.py index 82845ae3a6..29bebd9cde 100644 --- a/predicators/ground_truth_models/ants/options.py +++ b/predicators/ground_truth_models/ants/options.py @@ -6,7 +6,7 @@ import pybullet as p from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_ants import PyBulletAntsEnv from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.pybullet_helpers.controllers import \ @@ -200,7 +200,7 @@ def _create_ants_move_to_above_block_option( def _get_current_and_target_pose_and_finger_status( state: State, objects: Sequence[Object], params: Array) -> Tuple[Pose, Pose, str]: - assert not params + assert len(params) == 0 robot, block = objects # Current current_position = (state.get(robot, "x"), state.get(robot, "y"), diff --git a/predicators/ground_truth_models/balance/nsrts.py b/predicators/ground_truth_models/balance/nsrts.py index 2b051e40c0..331c44b076 100644 --- a/predicators/ground_truth_models/balance/nsrts.py +++ b/predicators/ground_truth_models/balance/nsrts.py @@ -7,7 +7,7 @@ from predicators.ground_truth_models import GroundTruthNSRTFactory from predicators.structs import NSRT, Array, GroundAtom, LiftedAtom, Object, \ ParameterizedOption, Predicate, State, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class BalanceGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/balance/options.py b/predicators/ground_truth_models/balance/options.py index a94dd5aa68..dc17d54210 100644 --- a/predicators/ground_truth_models/balance/options.py +++ b/predicators/ground_truth_models/balance/options.py @@ -7,7 +7,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_balance import PyBulletBalanceEnv from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.pybullet_helpers.controllers import \ @@ -246,7 +246,7 @@ def _create_blocks_move_to_above_block_option( def _get_current_and_target_pose_and_finger_status( state: State, objects: Sequence[Object], params: Array) -> Tuple[Pose, Pose, str]: - assert not params + assert len(params) == 0 robot, block = objects current_position = (state.get(robot, "x"), state.get(robot, "y"), state.get(robot, "z")) diff --git a/predicators/ground_truth_models/ball_and_cup_sticky_table/options.py b/predicators/ground_truth_models/ball_and_cup_sticky_table/options.py index bc8d4c40f4..77ea6f6ad2 100644 --- a/predicators/ground_truth_models/ball_and_cup_sticky_table/options.py +++ b/predicators/ground_truth_models/ball_and_cup_sticky_table/options.py @@ -5,7 +5,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.ball_and_cup_sticky_table import BallAndCupStickyTableEnv from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.structs import Action, Array, Object, ParameterizedOption, \ diff --git a/predicators/ground_truth_models/blocks/nsrts.py b/predicators/ground_truth_models/blocks/nsrts.py index 5e421eaafb..a1eb6f8c40 100644 --- a/predicators/ground_truth_models/blocks/nsrts.py +++ b/predicators/ground_truth_models/blocks/nsrts.py @@ -7,7 +7,7 @@ from predicators.ground_truth_models import GroundTruthNSRTFactory from predicators.structs import NSRT, Array, GroundAtom, LiftedAtom, Object, \ ParameterizedOption, Predicate, State, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class BlocksGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/blocks/options.py b/predicators/ground_truth_models/blocks/options.py index 594a67b9ff..7fb11bcaa4 100644 --- a/predicators/ground_truth_models/blocks/options.py +++ b/predicators/ground_truth_models/blocks/options.py @@ -5,7 +5,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.blocks import BlocksEnv from predicators.envs.pybullet_blocks import PyBulletBlocksEnv from predicators.ground_truth_models import GroundTruthOptionFactory @@ -301,7 +301,7 @@ def _create_blocks_move_to_above_block_option( def _get_current_and_target_pose_and_finger_status( state: State, objects: Sequence[Object], params: Array) -> Tuple[Pose, Pose, str]: - assert not params + assert len(params) == 0 robot, block = objects current_position = (state.get(robot, "pose_x"), state.get(robot, "pose_y"), diff --git a/predicators/ground_truth_models/boil/nsrts.py b/predicators/ground_truth_models/boil/nsrts.py index b39628dc3e..e1956bdd6c 100644 --- a/predicators/ground_truth_models/boil/nsrts.py +++ b/predicators/ground_truth_models/boil/nsrts.py @@ -6,7 +6,7 @@ from predicators.settings import CFG from predicators.structs import NSRT, LiftedAtom, ParameterizedOption, \ Predicate, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class PyBulletBoilGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/boil/options.py b/predicators/ground_truth_models/boil/options.py index 59b2ccd487..31f2b4445d 100644 --- a/predicators/ground_truth_models/boil/options.py +++ b/predicators/ground_truth_models/boil/options.py @@ -7,7 +7,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_boil import PyBulletBoilEnv from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.ground_truth_models.skill_factories import SkillConfig, \ diff --git a/predicators/ground_truth_models/boil/options_legacy.py b/predicators/ground_truth_models/boil/options_legacy.py index a0a5d3d456..04ba20d436 100644 --- a/predicators/ground_truth_models/boil/options_legacy.py +++ b/predicators/ground_truth_models/boil/options_legacy.py @@ -9,7 +9,7 @@ import pybullet as p from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_boil import PyBulletBoilEnv from predicators.pybullet_helpers.controllers import \ create_change_fingers_option, create_move_end_effector_to_pose_option @@ -519,7 +519,7 @@ def _create_boil_move_to_above_placing_option( def _get_current_and_target_pose_and_finger_status( state: State, objects: Sequence[Object], params: Array) -> Tuple[Pose, Pose, str]: - assert not params + assert len(params) == 0 if len(objects) == 2: robot, burner = objects # Current @@ -597,7 +597,7 @@ def _create_boil_move_to_above_jug_option( def _get_current_and_target_pose_and_finger_status( state: State, objects: Sequence[Object], params: Array) -> Tuple[Pose, Pose, str]: - assert not params + assert len(params) == 0 robot, jug = objects # Current current_position = (state.get(robot, "x"), state.get(robot, "y"), @@ -645,7 +645,7 @@ def _create_boil_move_to_push_switch_option( def _get_current_and_target_pose_and_finger_status( state: State, objects: Sequence[Object], params: Array) -> \ Tuple[Pose, Pose, str]: - assert not params + assert len(params) == 0 robot, obj = objects switch = next( ( @@ -689,7 +689,7 @@ def _create_boil_move_to_init_option( def _get_current_and_target_pose_and_finger_status( state: State, objects: Sequence[Object], params: Array) -> Tuple[Pose, Pose, str]: - assert not params + assert len(params) == 0 robot = objects[0] current_position = (state.get(robot, "x"), state.get(robot, "y"), state.get(robot, "z")) diff --git a/predicators/ground_truth_models/burger/nsrts.py b/predicators/ground_truth_models/burger/nsrts.py index b69b476f8c..a2290ffb86 100644 --- a/predicators/ground_truth_models/burger/nsrts.py +++ b/predicators/ground_truth_models/burger/nsrts.py @@ -5,7 +5,7 @@ from predicators.ground_truth_models import GroundTruthNSRTFactory from predicators.structs import NSRT, LiftedAtom, ParameterizedOption, \ Predicate, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class BurgerGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/burger/options.py b/predicators/ground_truth_models/burger/options.py index 4f450ac323..99390cd3dc 100644 --- a/predicators/ground_truth_models/burger/options.py +++ b/predicators/ground_truth_models/burger/options.py @@ -5,7 +5,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.burger import BurgerEnv from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.structs import Action, Array, Object, ParameterizedOption, \ diff --git a/predicators/ground_truth_models/circuit/nsrts.py b/predicators/ground_truth_models/circuit/nsrts.py index 9d711d2856..48f148304c 100644 --- a/predicators/ground_truth_models/circuit/nsrts.py +++ b/predicators/ground_truth_models/circuit/nsrts.py @@ -5,7 +5,7 @@ from predicators.ground_truth_models import GroundTruthNSRTFactory from predicators.structs import NSRT, LiftedAtom, ParameterizedOption, \ Predicate, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class PyBulletCircuitGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/circuit/options.py b/predicators/ground_truth_models/circuit/options.py index 4cbda1db69..96c13a3e74 100644 --- a/predicators/ground_truth_models/circuit/options.py +++ b/predicators/ground_truth_models/circuit/options.py @@ -7,7 +7,7 @@ import pybullet as p from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_circuit import PyBulletCircuitEnv from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.pybullet_helpers.controllers import \ @@ -169,7 +169,7 @@ def _create_circuit_move_to_push_switch_option( def _get_current_and_target_pose_and_finger_status( state: State, objects: Sequence[Object], params: Array) -> \ Tuple[Pose, Pose, str]: - assert not params + assert len(params) == 0 robot, switch = objects current_position = (state.get(robot, "x"), state.get(robot, "y"), state.get(robot, "z")) @@ -212,7 +212,7 @@ def _create_circuit_move_to_above_wire_option( def _get_current_and_target_pose_and_finger_status( state: State, objects: Sequence[Object], params: Array) -> Tuple[Pose, Pose, str]: - assert not params + assert len(params) == 0 robot, snap = objects current_position = (state.get(robot, "x"), state.get(robot, "y"), state.get(robot, "z")) @@ -249,7 +249,7 @@ def _create_circuit_move_to_above_two_snaps_option( def _get_current_and_target_pose_and_finger_status( state: State, objects: Sequence[Object], params: Array) -> Tuple[Pose, Pose, str]: - assert not params + assert len(params) == 0 robot, wire, light, battery = objects rx = state.get(robot, "x") ry = state.get(robot, "y") diff --git a/predicators/ground_truth_models/cluttered_table/nsrts.py b/predicators/ground_truth_models/cluttered_table/nsrts.py index 4f043c5db8..b2a220939f 100644 --- a/predicators/ground_truth_models/cluttered_table/nsrts.py +++ b/predicators/ground_truth_models/cluttered_table/nsrts.py @@ -8,7 +8,7 @@ from predicators.settings import CFG from predicators.structs import NSRT, Array, GroundAtom, LiftedAtom, Object, \ ParameterizedOption, Predicate, State, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class ClutteredTableGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/cluttered_table/options.py b/predicators/ground_truth_models/cluttered_table/options.py index 404fae5b74..cf93cc3e14 100644 --- a/predicators/ground_truth_models/cluttered_table/options.py +++ b/predicators/ground_truth_models/cluttered_table/options.py @@ -5,7 +5,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.structs import Action, Array, Object, ParameterizedOption, \ Predicate, State, Type diff --git a/predicators/ground_truth_models/coffee/nsrts.py b/predicators/ground_truth_models/coffee/nsrts.py index da359d75d5..e37f5fdaef 100644 --- a/predicators/ground_truth_models/coffee/nsrts.py +++ b/predicators/ground_truth_models/coffee/nsrts.py @@ -8,7 +8,7 @@ from predicators.settings import CFG from predicators.structs import NSRT, Array, GroundAtom, LiftedAtom, Object, \ ParameterizedOption, Predicate, State, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class CoffeeGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/coffee/options.py b/predicators/ground_truth_models/coffee/options.py index 4669b0b5cc..f1a53e96aa 100644 --- a/predicators/ground_truth_models/coffee/options.py +++ b/predicators/ground_truth_models/coffee/options.py @@ -8,7 +8,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.coffee import CoffeeEnv from predicators.envs.pybullet_coffee import PyBulletCoffeeEnv from predicators.ground_truth_models import GroundTruthOptionFactory diff --git a/predicators/ground_truth_models/coffee/options_legacy.py b/predicators/ground_truth_models/coffee/options_legacy.py index 06f772b01d..28a2b9eb54 100644 --- a/predicators/ground_truth_models/coffee/options_legacy.py +++ b/predicators/ground_truth_models/coffee/options_legacy.py @@ -8,7 +8,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_coffee import PyBulletCoffeeEnv from predicators.settings import CFG from predicators.structs import Action, Array, Object, ParameterizedOption, \ diff --git a/predicators/ground_truth_models/cover/nsrts.py b/predicators/ground_truth_models/cover/nsrts.py index e410988c98..e17328e7b0 100644 --- a/predicators/ground_truth_models/cover/nsrts.py +++ b/predicators/ground_truth_models/cover/nsrts.py @@ -5,7 +5,7 @@ import numpy as np -from predicators import utils +from predicators import utils_lite as utils from predicators.ground_truth_models import GroundTruthNSRTFactory from predicators.settings import CFG from predicators.structs import NSRT, Array, GroundAtom, LiftedAtom, Object, \ diff --git a/predicators/ground_truth_models/cover/options.py b/predicators/ground_truth_models/cover/options.py index 73a6fde551..65cd63dcf8 100644 --- a/predicators/ground_truth_models/cover/options.py +++ b/predicators/ground_truth_models/cover/options.py @@ -5,7 +5,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.cover import CoverMultistepOptions from predicators.envs.pybullet_cover import PyBulletCoverEnv from predicators.ground_truth_models import GroundTruthOptionFactory diff --git a/predicators/ground_truth_models/domino/nsrts.py b/predicators/ground_truth_models/domino/nsrts.py index 0781036ed7..f2ea97691b 100644 --- a/predicators/ground_truth_models/domino/nsrts.py +++ b/predicators/ground_truth_models/domino/nsrts.py @@ -6,7 +6,7 @@ from predicators.settings import CFG from predicators.structs import NSRT, LiftedAtom, ParameterizedOption, \ Predicate, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class PyBulletDominoGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/domino/options_legacy.py b/predicators/ground_truth_models/domino/options_legacy.py index b2ba9e4a7c..21f33ab159 100644 --- a/predicators/ground_truth_models/domino/options_legacy.py +++ b/predicators/ground_truth_models/domino/options_legacy.py @@ -10,7 +10,7 @@ import pybullet as p from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_domino import PyBulletDominoEnv from predicators.envs.pybullet_domino.components.domino_component import \ DominoComponent @@ -335,7 +335,7 @@ def _create_domino_move_to_push_domino_option( def _get_current_and_target_pose_and_finger_status( state: State, objects: Sequence[Object], params: Array) -> \ Tuple[Pose, Pose, str]: - assert not params + assert len(params) == 0 robot, domino = objects current_position = (state.get(robot, "x"), state.get(robot, "y"), state.get(robot, "z")) @@ -384,7 +384,7 @@ def _create_domino_move_to_push_start_block_option( def _get_current_and_target_pose_and_finger_status( state: State, objects: Sequence[Object], params: Array) -> \ Tuple[Pose, Pose, str]: - assert not params + assert len(params) == 0 robot, = objects domino = cls._find_start_block(state, domino_type) current_position = (state.get(robot, "x"), state.get(robot, "y"), @@ -425,7 +425,7 @@ def _create_domino_move_to_domino_option( def _get_current_and_target_pose_and_finger_status( state: State, objects: Sequence[Object], params: Array) -> \ Tuple[Pose, Pose, str]: - assert not params + assert len(params) == 0 robot, domino = objects current_position = (state.get(robot, "x"), state.get(robot, "y"), state.get(robot, "z")) @@ -465,7 +465,7 @@ def _create_domino_place_option(cls, name: str, z_func: Callable[[float], def _get_current_and_target_pose_and_finger_status( state: State, objects: Sequence[Object], params: Array) -> \ Tuple[Pose, Pose, str]: - assert not params + assert len(params) == 0 robot, domino_f, domino_b, tgt_pos, rotation = objects current_position = (state.get(robot, "x"), state.get(robot, "y"), state.get(robot, "z")) diff --git a/predicators/ground_truth_models/domino/predicates.py b/predicators/ground_truth_models/domino/predicates.py index 430a63407d..b6e673f960 100644 --- a/predicators/ground_truth_models/domino/predicates.py +++ b/predicators/ground_truth_models/domino/predicates.py @@ -4,7 +4,7 @@ import numpy as np -from predicators import utils +from predicators import utils_lite as utils from predicators.ground_truth_models import GroundTruthPredicateFactory from predicators.structs import DerivedPredicate, GroundAtom, Object, \ Predicate, State, Type diff --git a/predicators/ground_truth_models/domino/types.py b/predicators/ground_truth_models/domino/types.py index 767705c667..e395ceab7b 100644 --- a/predicators/ground_truth_models/domino/types.py +++ b/predicators/ground_truth_models/domino/types.py @@ -10,7 +10,7 @@ PyBulletDominoComposedEnv from predicators.ground_truth_models import GroundTruthTypeFactory from predicators.structs import Object, Task, Type -from predicators.utils import PyBulletState +from predicators.utils_lite import PyBulletState class PyBulletDominoGroundTruthTypeFactory(GroundTruthTypeFactory): diff --git a/predicators/ground_truth_models/doors/nsrts.py b/predicators/ground_truth_models/doors/nsrts.py index 8d26cce23f..bb842298d5 100644 --- a/predicators/ground_truth_models/doors/nsrts.py +++ b/predicators/ground_truth_models/doors/nsrts.py @@ -9,7 +9,7 @@ from predicators.ground_truth_models import GroundTruthNSRTFactory from predicators.structs import NSRT, Array, GroundAtom, LiftedAtom, Object, \ ParameterizedOption, Predicate, State, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class DoorsGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/doors/options.py b/predicators/ground_truth_models/doors/options.py index d243e17df0..7b46cd7297 100644 --- a/predicators/ground_truth_models/doors/options.py +++ b/predicators/ground_truth_models/doors/options.py @@ -5,14 +5,14 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.doors import DoorKnobsEnv, DoorsEnv from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.settings import CFG from predicators.structs import Action, Array, Object, \ ParameterizedInitiable, ParameterizedOption, ParameterizedPolicy, \ Predicate, State, Type -from predicators.utils import Rectangle, SingletonParameterizedOption, \ +from predicators.utils_lite import Rectangle, SingletonParameterizedOption, \ StateWithCache diff --git a/predicators/ground_truth_models/exit_garage/options.py b/predicators/ground_truth_models/exit_garage/options.py index 5da18fbb8e..c8deaed949 100644 --- a/predicators/ground_truth_models/exit_garage/options.py +++ b/predicators/ground_truth_models/exit_garage/options.py @@ -6,7 +6,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.exit_garage import ExitGarageEnv from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.settings import CFG diff --git a/predicators/ground_truth_models/fan/nsrts.py b/predicators/ground_truth_models/fan/nsrts.py index 62ca7cdd70..0cedc3fa7a 100644 --- a/predicators/ground_truth_models/fan/nsrts.py +++ b/predicators/ground_truth_models/fan/nsrts.py @@ -6,7 +6,7 @@ from predicators.settings import CFG from predicators.structs import NSRT, LiftedAtom, ParameterizedOption, \ Predicate, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class PyBulletFanGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/fan/options.py b/predicators/ground_truth_models/fan/options.py index 9c929b16bb..2060205d99 100644 --- a/predicators/ground_truth_models/fan/options.py +++ b/predicators/ground_truth_models/fan/options.py @@ -6,7 +6,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_fan import PyBulletFanEnv from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.ground_truth_models.skill_factories import SkillConfig, \ diff --git a/predicators/ground_truth_models/fan/options_legacy.py b/predicators/ground_truth_models/fan/options_legacy.py index 5fedcb542a..329e393d34 100644 --- a/predicators/ground_truth_models/fan/options_legacy.py +++ b/predicators/ground_truth_models/fan/options_legacy.py @@ -8,7 +8,7 @@ import pybullet as p from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_fan import PyBulletFanEnv from predicators.pybullet_helpers.controllers import \ create_change_fingers_option, create_move_end_effector_to_pose_option @@ -260,7 +260,7 @@ def _create_fan_move_to_push_switch_option( def _get_current_and_target_pose_and_finger_status( state: State, objects: Sequence[Object], params: Array) -> \ Tuple[Pose, Pose, str]: - assert not params + assert len(params) == 0 if CFG.fan_known_controls_relation: robot, fan = objects switch = [switch for switch in state.get_objects(switch_type) diff --git a/predicators/ground_truth_models/float/nsrts.py b/predicators/ground_truth_models/float/nsrts.py index cbab8bd28d..38e9f2d18c 100644 --- a/predicators/ground_truth_models/float/nsrts.py +++ b/predicators/ground_truth_models/float/nsrts.py @@ -5,7 +5,7 @@ from predicators.ground_truth_models import GroundTruthNSRTFactory from predicators.structs import NSRT, LiftedAtom, ParameterizedOption, \ Predicate, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class PyBulletFloatGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/float/options.py b/predicators/ground_truth_models/float/options.py index ec32a36ec1..dcd258fd81 100644 --- a/predicators/ground_truth_models/float/options.py +++ b/predicators/ground_truth_models/float/options.py @@ -7,7 +7,7 @@ import pybullet as p from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_env import PyBulletEnv from predicators.envs.pybullet_float import PyBulletFloatEnv from predicators.ground_truth_models import GroundTruthOptionFactory @@ -142,7 +142,7 @@ def _create_float_move_to_above_block_option( def _get_current_and_target_pose_and_finger_status( state: State, objects: Sequence[Object], params: Array) -> Tuple[Pose, Pose, str]: - assert not params + assert len(params) == 0 robot, block = objects current_position = (state.get(robot, "x"), state.get(robot, "y"), state.get(robot, "z")) @@ -183,7 +183,7 @@ def _create_float_move_to_above_vessel_option( def _get_current_and_target_pose_and_finger_status( state: State, objects: Sequence[Object], params: Array) -> Tuple[Pose, Pose, str]: - assert not params + assert len(params) == 0 robot, vessel = objects rx = state.get(robot, "x") ry = state.get(robot, "y") diff --git a/predicators/ground_truth_models/grid_row/nsrts.py b/predicators/ground_truth_models/grid_row/nsrts.py index 7e3c08a791..22338468c6 100644 --- a/predicators/ground_truth_models/grid_row/nsrts.py +++ b/predicators/ground_truth_models/grid_row/nsrts.py @@ -7,7 +7,7 @@ from predicators.ground_truth_models import GroundTruthNSRTFactory from predicators.structs import NSRT, Array, GroundAtom, LiftedAtom, Object, \ ParameterizedOption, Predicate, State, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class GridRowGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/grid_row/options.py b/predicators/ground_truth_models/grid_row/options.py index 07d33357c1..44fde8e2ec 100644 --- a/predicators/ground_truth_models/grid_row/options.py +++ b/predicators/ground_truth_models/grid_row/options.py @@ -5,7 +5,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.structs import Action, Array, Object, ParameterizedOption, \ Predicate, State, Type diff --git a/predicators/ground_truth_models/grow/nsrts.py b/predicators/ground_truth_models/grow/nsrts.py index 4bec374486..33c59c3f1d 100644 --- a/predicators/ground_truth_models/grow/nsrts.py +++ b/predicators/ground_truth_models/grow/nsrts.py @@ -7,7 +7,7 @@ from predicators.ground_truth_models import GroundTruthNSRTFactory from predicators.structs import NSRT, Array, GroundAtom, LiftedAtom, Object, \ ParameterizedOption, Predicate, State, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class PyBulletGrowGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/grow/options_legacy.py b/predicators/ground_truth_models/grow/options_legacy.py index 7012b1b50a..e1b43dc308 100644 --- a/predicators/ground_truth_models/grow/options_legacy.py +++ b/predicators/ground_truth_models/grow/options_legacy.py @@ -8,7 +8,7 @@ import pybullet as p from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_coffee import PyBulletCoffeeEnv from predicators.envs.pybullet_grow import PyBulletGrowEnv from predicators.ground_truth_models.coffee.options import \ diff --git a/predicators/ground_truth_models/ice_tea_making/options.py b/predicators/ground_truth_models/ice_tea_making/options.py index 8a845949eb..4f726dd74b 100644 --- a/predicators/ground_truth_models/ice_tea_making/options.py +++ b/predicators/ground_truth_models/ice_tea_making/options.py @@ -4,7 +4,7 @@ from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.structs import Action, Array, Object, ParameterizedOption, \ ParameterizedPolicy, Predicate, State, Type diff --git a/predicators/ground_truth_models/laser/nsrts.py b/predicators/ground_truth_models/laser/nsrts.py index 7b18e7a2ff..d72c6084de 100644 --- a/predicators/ground_truth_models/laser/nsrts.py +++ b/predicators/ground_truth_models/laser/nsrts.py @@ -5,7 +5,7 @@ from predicators.ground_truth_models import GroundTruthNSRTFactory from predicators.structs import NSRT, LiftedAtom, ParameterizedOption, \ Predicate, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class PyBulletLaserGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/laser/options.py b/predicators/ground_truth_models/laser/options.py index cce717624a..d8cbdf3e65 100644 --- a/predicators/ground_truth_models/laser/options.py +++ b/predicators/ground_truth_models/laser/options.py @@ -8,7 +8,7 @@ import pybullet as p from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pybullet_env import PyBulletEnv from predicators.envs.pybullet_laser import PyBulletLaserEnv from predicators.ground_truth_models import GroundTruthOptionFactory @@ -171,7 +171,7 @@ def _create_laser_move_to_push_switch_option( def _get_current_and_target_pose_and_finger_status( state: State, objects: Sequence[Object], params: Array) -> \ Tuple[Pose, Pose, str]: - assert not params + assert len(params) == 0 robot, switch = objects current_position = (state.get(robot, "x"), state.get(robot, "y"), state.get(robot, "z")) @@ -217,7 +217,7 @@ def _create_laser_move_to_above_mirror_option( def _get_current_and_target_pose_and_finger_status( state: State, objects: Sequence[Object], params: Array) -> Tuple[Pose, Pose, str]: - assert not params + assert len(params) == 0 robot, mirror = objects current_position = (state.get(robot, "x"), state.get(robot, "y"), state.get(robot, "z")) @@ -260,7 +260,7 @@ def _create_laser_move_to_above_position_option( def _get_current_and_target_pose_and_finger_status( state: State, objects: Sequence[Object], params: Array) -> Tuple[Pose, Pose, str]: - assert not params + assert len(params) == 0 robot, = objects rx = state.get(robot, "x") ry = state.get(robot, "y") diff --git a/predicators/ground_truth_models/narrow_passage/options.py b/predicators/ground_truth_models/narrow_passage/options.py index e1b1097f88..ddaaf26ec3 100644 --- a/predicators/ground_truth_models/narrow_passage/options.py +++ b/predicators/ground_truth_models/narrow_passage/options.py @@ -6,7 +6,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.narrow_passage import NarrowPassageEnv from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.settings import CFG diff --git a/predicators/ground_truth_models/noisy_button/options.py b/predicators/ground_truth_models/noisy_button/options.py index ff62e690d3..9c333bef80 100644 --- a/predicators/ground_truth_models/noisy_button/options.py +++ b/predicators/ground_truth_models/noisy_button/options.py @@ -4,7 +4,7 @@ from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.structs import Action, Array, Object, ParameterizedOption, \ Predicate, State, Type diff --git a/predicators/ground_truth_models/painting/nsrts.py b/predicators/ground_truth_models/painting/nsrts.py index bc8f71e3ed..66ffb537e4 100644 --- a/predicators/ground_truth_models/painting/nsrts.py +++ b/predicators/ground_truth_models/painting/nsrts.py @@ -9,7 +9,7 @@ from predicators.ground_truth_models import GroundTruthNSRTFactory from predicators.structs import NSRT, Array, GroundAtom, LiftedAtom, Object, \ ParameterizedOption, Predicate, State, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class PaintingGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/painting/options.py b/predicators/ground_truth_models/painting/options.py index 2123855efc..2d7a8cef83 100644 --- a/predicators/ground_truth_models/painting/options.py +++ b/predicators/ground_truth_models/painting/options.py @@ -5,7 +5,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.painting import PaintingEnv from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.structs import Action, Array, Object, ParameterizedOption, \ diff --git a/predicators/ground_truth_models/pddl_env/nsrts.py b/predicators/ground_truth_models/pddl_env/nsrts.py index 2a3889ca6e..272e55b82a 100644 --- a/predicators/ground_truth_models/pddl_env/nsrts.py +++ b/predicators/ground_truth_models/pddl_env/nsrts.py @@ -2,12 +2,12 @@ from typing import Dict, Set -from predicators import utils +from predicators import utils_lite as utils from predicators.envs import get_or_create_env from predicators.envs.pddl_env import _PDDLEnv from predicators.ground_truth_models import GroundTruthNSRTFactory from predicators.structs import NSRT, ParameterizedOption, Predicate, Type -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class PDDLEnvGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/pddl_env/options.py b/predicators/ground_truth_models/pddl_env/options.py index c02b401704..94b02fa50b 100644 --- a/predicators/ground_truth_models/pddl_env/options.py +++ b/predicators/ground_truth_models/pddl_env/options.py @@ -6,7 +6,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.pddl_env import _parse_pddl_domain, _PDDLEnv from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.structs import Action, Array, Object, ParameterizedOption, \ diff --git a/predicators/ground_truth_models/playroom/options.py b/predicators/ground_truth_models/playroom/options.py index bab3d20938..6b0bcc2159 100644 --- a/predicators/ground_truth_models/playroom/options.py +++ b/predicators/ground_truth_models/playroom/options.py @@ -5,7 +5,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.playroom import PlayroomEnv from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.structs import Action, Array, Object, \ diff --git a/predicators/ground_truth_models/repeated_nextto/nsrts.py b/predicators/ground_truth_models/repeated_nextto/nsrts.py index 6b0ac1e8ec..3a7f0b36c5 100644 --- a/predicators/ground_truth_models/repeated_nextto/nsrts.py +++ b/predicators/ground_truth_models/repeated_nextto/nsrts.py @@ -7,7 +7,7 @@ from predicators.ground_truth_models import GroundTruthNSRTFactory from predicators.structs import NSRT, DummyOption, LiftedAtom, \ ParameterizedOption, Predicate, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class RepeatedNextToGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/repeated_nextto/options.py b/predicators/ground_truth_models/repeated_nextto/options.py index a6eddae739..f314aa1a7e 100644 --- a/predicators/ground_truth_models/repeated_nextto/options.py +++ b/predicators/ground_truth_models/repeated_nextto/options.py @@ -5,7 +5,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.repeated_nextto import RepeatedNextToEnv from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.structs import Action, Array, Object, ParameterizedOption, \ diff --git a/predicators/ground_truth_models/sandwich/nsrts.py b/predicators/ground_truth_models/sandwich/nsrts.py index 27e2f64c34..9a727daf59 100644 --- a/predicators/ground_truth_models/sandwich/nsrts.py +++ b/predicators/ground_truth_models/sandwich/nsrts.py @@ -5,7 +5,7 @@ from predicators.ground_truth_models import GroundTruthNSRTFactory from predicators.structs import NSRT, LiftedAtom, ParameterizedOption, \ Predicate, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class SandwichGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/sandwich/options.py b/predicators/ground_truth_models/sandwich/options.py index 156a4489ed..ada64797c1 100644 --- a/predicators/ground_truth_models/sandwich/options.py +++ b/predicators/ground_truth_models/sandwich/options.py @@ -5,7 +5,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.sandwich import SandwichEnv from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.structs import Action, Array, Object, ParameterizedOption, \ diff --git a/predicators/ground_truth_models/satellites/nsrts.py b/predicators/ground_truth_models/satellites/nsrts.py index 5be6a76165..3bec6bf14e 100644 --- a/predicators/ground_truth_models/satellites/nsrts.py +++ b/predicators/ground_truth_models/satellites/nsrts.py @@ -8,7 +8,7 @@ from predicators.ground_truth_models import GroundTruthNSRTFactory from predicators.structs import NSRT, Array, GroundAtom, LiftedAtom, Object, \ ParameterizedOption, Predicate, State, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class SatellitesGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/satellites/options.py b/predicators/ground_truth_models/satellites/options.py index 828b5d06c3..3c567814f6 100644 --- a/predicators/ground_truth_models/satellites/options.py +++ b/predicators/ground_truth_models/satellites/options.py @@ -5,7 +5,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.structs import Action, Array, Object, ParameterizedOption, \ ParameterizedPolicy, Predicate, State, Type diff --git a/predicators/ground_truth_models/screws/nsrts.py b/predicators/ground_truth_models/screws/nsrts.py index 62646df553..68f6bb3fed 100644 --- a/predicators/ground_truth_models/screws/nsrts.py +++ b/predicators/ground_truth_models/screws/nsrts.py @@ -5,7 +5,7 @@ from predicators.ground_truth_models import GroundTruthNSRTFactory from predicators.structs import NSRT, LiftedAtom, ParameterizedOption, \ Predicate, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class ScrewsGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/screws/options.py b/predicators/ground_truth_models/screws/options.py index df6e7d37a7..19eaa6e5c3 100644 --- a/predicators/ground_truth_models/screws/options.py +++ b/predicators/ground_truth_models/screws/options.py @@ -5,7 +5,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.screws import ScrewsEnv from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.structs import Action, Array, Object, ParameterizedOption, \ diff --git a/predicators/ground_truth_models/skill_factories/base.py b/predicators/ground_truth_models/skill_factories/base.py index d4f17d86bd..13c2a26bcf 100644 --- a/predicators/ground_truth_models/skill_factories/base.py +++ b/predicators/ground_truth_models/skill_factories/base.py @@ -16,7 +16,7 @@ import pybullet as p from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.pybullet_helpers.controllers import \ get_change_fingers_action, get_move_end_effector_to_pose_action from predicators.pybullet_helpers.geometry import Pose diff --git a/predicators/ground_truth_models/skill_factories/wait.py b/predicators/ground_truth_models/skill_factories/wait.py index b2c7c00422..34815ef1e5 100644 --- a/predicators/ground_truth_models/skill_factories/wait.py +++ b/predicators/ground_truth_models/skill_factories/wait.py @@ -19,7 +19,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.ground_truth_models.skill_factories.base import SkillConfig from predicators.structs import Action, Array, Object, ParameterizedOption, \ State, Type diff --git a/predicators/ground_truth_models/sokoban/nsrts.py b/predicators/ground_truth_models/sokoban/nsrts.py index 0e54f3fd0b..7831be8ae0 100644 --- a/predicators/ground_truth_models/sokoban/nsrts.py +++ b/predicators/ground_truth_models/sokoban/nsrts.py @@ -5,7 +5,7 @@ from predicators.ground_truth_models import GroundTruthNSRTFactory from predicators.structs import NSRT, LiftedAtom, ParameterizedOption, \ Predicate, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class SokobanGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/sokoban/options.py b/predicators/ground_truth_models/sokoban/options.py index 4e5b0654df..898b92ef23 100644 --- a/predicators/ground_truth_models/sokoban/options.py +++ b/predicators/ground_truth_models/sokoban/options.py @@ -6,7 +6,7 @@ from gym.spaces import Box from gym_sokoban.envs.sokoban_env import ACTION_LOOKUP as SOKOBAN_ACTION_LOOKUP -from predicators import utils +from predicators import utils_lite as utils from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.structs import Action, Array, Object, ParameterizedOption, \ ParameterizedPolicy, Predicate, State, Type diff --git a/predicators/ground_truth_models/stick_button/nsrts.py b/predicators/ground_truth_models/stick_button/nsrts.py index 8a76fa5ff6..cfca584ea3 100644 --- a/predicators/ground_truth_models/stick_button/nsrts.py +++ b/predicators/ground_truth_models/stick_button/nsrts.py @@ -7,7 +7,7 @@ from predicators.ground_truth_models import GroundTruthNSRTFactory from predicators.structs import NSRT, Array, GroundAtom, LiftedAtom, Object, \ ParameterizedOption, Predicate, State, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class StickButtonGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/stick_button/options.py b/predicators/ground_truth_models/stick_button/options.py index a69c1c62c3..bea4c5cfff 100644 --- a/predicators/ground_truth_models/stick_button/options.py +++ b/predicators/ground_truth_models/stick_button/options.py @@ -5,7 +5,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.stick_button import StickButtonEnv from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.settings import CFG diff --git a/predicators/ground_truth_models/sticky_table/options.py b/predicators/ground_truth_models/sticky_table/options.py index 26ea19ea2f..146ce2b566 100644 --- a/predicators/ground_truth_models/sticky_table/options.py +++ b/predicators/ground_truth_models/sticky_table/options.py @@ -5,7 +5,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.sticky_table import StickyTableEnv from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.structs import Action, Array, Object, ParameterizedOption, \ diff --git a/predicators/ground_truth_models/tools/nsrts.py b/predicators/ground_truth_models/tools/nsrts.py index 2247939957..263f06b382 100644 --- a/predicators/ground_truth_models/tools/nsrts.py +++ b/predicators/ground_truth_models/tools/nsrts.py @@ -8,7 +8,7 @@ from predicators.ground_truth_models import GroundTruthNSRTFactory from predicators.structs import NSRT, Array, GroundAtom, LiftedAtom, Object, \ ParameterizedOption, Predicate, State, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class ToolsGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/ground_truth_models/tools/options.py b/predicators/ground_truth_models/tools/options.py index dfc30e3ab4..f705c73ee4 100644 --- a/predicators/ground_truth_models/tools/options.py +++ b/predicators/ground_truth_models/tools/options.py @@ -5,7 +5,7 @@ import numpy as np from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.tools import ToolsEnv from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.structs import Action, Array, Object, ParameterizedOption, \ @@ -128,7 +128,7 @@ def _create_pick_policy(cls) -> ParameterizedPolicy: def policy(state: State, memory: Dict, objects: Sequence[Object], params: Array) -> Action: del memory # unused - assert not params + assert len(params) == 0 _, item_or_tool = objects pose_x = state.get(item_or_tool, "pose_x") pose_y = state.get(item_or_tool, "pose_y") @@ -153,7 +153,7 @@ def _create_fasten_policy(cls) -> ParameterizedPolicy: def policy(state: State, memory: Dict, objects: Sequence[Object], params: Array) -> Action: del memory # unused - assert not params + assert len(params) == 0 if len(objects) == 3: # Note that the FastenScrewByHand option has only 3 parameters, # while all other Fasten options have 4 parameters. diff --git a/predicators/ground_truth_models/touch_point/nsrts.py b/predicators/ground_truth_models/touch_point/nsrts.py index bcae926c63..fed488929c 100644 --- a/predicators/ground_truth_models/touch_point/nsrts.py +++ b/predicators/ground_truth_models/touch_point/nsrts.py @@ -8,7 +8,7 @@ from predicators.settings import CFG from predicators.structs import NSRT, Array, GroundAtom, LiftedAtom, Object, \ ParameterizedOption, Predicate, State, Type, Variable -from predicators.utils import null_sampler +from predicators.utils_lite import null_sampler class TouchPointGroundTruthNSRTFactory(GroundTruthNSRTFactory): diff --git a/predicators/perception/__init__.py b/predicators/perception/__init__.py index 7984f70738..a1f1ccf13c 100644 --- a/predicators/perception/__init__.py +++ b/predicators/perception/__init__.py @@ -1,6 +1,6 @@ """Handle creation of perceivers.""" -from predicators import utils +from predicators import utils_lite as utils from predicators.perception.base_perceiver import BasePerceiver __all__ = ["BasePerceiver"] diff --git a/predicators/perception/sokoban_perceiver.py b/predicators/perception/sokoban_perceiver.py index e633d4c802..60a671c3d4 100644 --- a/predicators/perception/sokoban_perceiver.py +++ b/predicators/perception/sokoban_perceiver.py @@ -4,7 +4,7 @@ import numpy as np -from predicators import utils +from predicators import utils_lite as utils from predicators.envs.sokoban import SokobanEnv from predicators.perception.base_perceiver import BasePerceiver from predicators.structs import EnvironmentTask, GroundAtom, Object, \ diff --git a/predicators/pybullet_helpers/controllers.py b/predicators/pybullet_helpers/controllers.py index 6e3bc92925..73cc76f085 100644 --- a/predicators/pybullet_helpers/controllers.py +++ b/predicators/pybullet_helpers/controllers.py @@ -5,7 +5,7 @@ import pybullet as p from gym.spaces import Box -from predicators import utils +from predicators import utils_lite as utils from predicators.pybullet_helpers.geometry import Pose from predicators.pybullet_helpers.inverse_kinematics import \ InverseKinematicsError diff --git a/predicators/pybullet_helpers/ikfast/load.py b/predicators/pybullet_helpers/ikfast/load.py index 10d3074115..fc878341ec 100644 --- a/predicators/pybullet_helpers/ikfast/load.py +++ b/predicators/pybullet_helpers/ikfast/load.py @@ -10,7 +10,7 @@ from types import ModuleType from predicators.pybullet_helpers.ikfast import IKFastInfo -from predicators.utils import get_third_party_path +from predicators.utils_lite import get_third_party_path def install_ikfast_module(ikfast_dir: str) -> None: diff --git a/predicators/pybullet_helpers/motion_planning.py b/predicators/pybullet_helpers/motion_planning.py index b2749b10ee..546445f840 100644 --- a/predicators/pybullet_helpers/motion_planning.py +++ b/predicators/pybullet_helpers/motion_planning.py @@ -7,7 +7,7 @@ import pybullet as p from numpy.typing import NDArray -from predicators import utils +from predicators import utils_lite as utils from predicators.pybullet_helpers.joint import JointPositions from predicators.pybullet_helpers.link import get_link_state from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot diff --git a/predicators/pybullet_helpers/objects.py b/predicators/pybullet_helpers/objects.py index 883a9133ca..ee4b76d44a 100644 --- a/predicators/pybullet_helpers/objects.py +++ b/predicators/pybullet_helpers/objects.py @@ -4,10 +4,10 @@ import numpy as np import pybullet as p -from predicators import utils +from predicators import utils_lite as utils from predicators.pybullet_helpers import retry_pybullet_call from predicators.pybullet_helpers.geometry import Pose3D, Quaternion -from predicators.utils import _Geom2D +from predicators.utils_lite import _Geom2D # import numpy as np default_orn: Quaternion = (0.0, 0.0, 0.0, 1.0) @@ -113,7 +113,7 @@ def sample_collision_free_2d_positions( List[Tuple[float, float]]: A list of (x, y) positions for the shapes, guaranteed to be collision-free. """ - from predicators.utils import Circle, Rectangle \ + from predicators.utils_lite import Circle, Rectangle \ # pylint: disable=import-outside-toplevel def create_geom(px: float, py: float) -> _Geom2D: diff --git a/predicators/pybullet_helpers/robots/fetch.py b/predicators/pybullet_helpers/robots/fetch.py index c5e6b4b913..81d7924482 100644 --- a/predicators/pybullet_helpers/robots/fetch.py +++ b/predicators/pybullet_helpers/robots/fetch.py @@ -1,6 +1,6 @@ """Fetch Robotics Mobile Manipulator (Fetch).""" -from predicators import utils +from predicators import utils_lite as utils from predicators.pybullet_helpers.robots.single_arm import \ SingleArmPyBulletRobot diff --git a/predicators/pybullet_helpers/robots/panda.py b/predicators/pybullet_helpers/robots/panda.py index 879d3b2846..c01a1b6111 100644 --- a/predicators/pybullet_helpers/robots/panda.py +++ b/predicators/pybullet_helpers/robots/panda.py @@ -1,7 +1,7 @@ """Franka Emika Panda robot.""" from typing import Optional -from predicators import utils +from predicators import utils_lite as utils from predicators.pybullet_helpers.ikfast import IKFastInfo from predicators.pybullet_helpers.robots.single_arm import \ SingleArmPyBulletRobot diff --git a/predicators/structs.py b/predicators/structs.py index 608597f5d7..6bff13a2e3 100644 --- a/predicators/structs.py +++ b/predicators/structs.py @@ -15,24 +15,49 @@ cast if TYPE_CHECKING: - from predicators.utils import VLMQuery, VLMState + # `torch` and `predicators.pretrained_model_interface` are heavy + # modules unavailable under Pyodide. Type annotations referencing + # them are deferred to strings via `from __future__ import + # annotations`; the few runtime call sites use lazy local imports. + import torch # pylint: disable=import-outside-toplevel + from torch import Tensor # pylint: disable=import-outside-toplevel + + import predicators.pretrained_model_interface + from predicators.utils_lite import VLMQuery, VLMState # pylint: disable=wrong-import-position import numpy as np import PIL.Image -import torch from gym.spaces import Box from numpy.typing import NDArray from tabulate import tabulate -from torch import Tensor -import predicators.pretrained_model_interface -import predicators.utils as utils # pylint: disable=consider-using-from-import +# structs is a foundational module imported during utils_lite's own +# load. Going through utils (the heavy variant) would create a cycle: +# utils_lite -> structs -> utils -> utils_lite (incomplete). Importing +# utils_lite directly here breaks the cycle, and CPython callers still +# see every `utils.X` symbol via the full `utils` module's wildcard +# re-export of utils_lite. +import predicators.utils_lite as utils # pylint: disable=consider-using-from-import,ungrouped-imports from predicators.settings import CFG # pylint: enable=wrong-import-position +def _torch_and_delay() -> Tuple[Any, Any]: + """Lazy-load torch and DiscreteGaussianDelay for the two process- + construction methods below. + + Both are heavy and unsafe to import under Pyodide. + """ + import torch # pylint: disable=import-outside-toplevel + + from predicators.utils import DiscreteGaussianDelay \ + # pylint: disable=import-outside-toplevel + + return torch, DiscreteGaussianDelay + + @dataclass(frozen=True, order=True) class Type: """Struct defining a type. @@ -1252,6 +1277,7 @@ def make_endogenous_process( process_rng: Optional[np.random.Generator] = None, ) -> EndogenousProcess: """Make a CausalProcess out of this STRIPSOperator object.""" + torch, DiscreteGaussianDelay = _torch_and_delay() assert option is not None and option_vars is not None and \ sampler is not None if process_delay_params is None: @@ -1271,7 +1297,7 @@ def make_endogenous_process( add_effects=self.add_effects if option.name != "Wait" else set(), delete_effects=self.delete_effects if option.name != "Wait" else set(), - delay_distribution=utils.DiscreteGaussianDelay( + delay_distribution=DiscreteGaussianDelay( torch.tensor(process_delay_params[0]), torch.tensor(process_delay_params[1])), strength=process_strength, # type: ignore[arg-type] @@ -1287,12 +1313,13 @@ def make_exogenous_process( _process_rng: Optional[np.random.Generator] = None ) -> ExogenousProcess: """Make an ExogenousProcess out of this STRIPSOperator object.""" + torch, DiscreteGaussianDelay = _torch_and_delay() if process_delay_params is None: process_delay_params = torch.tensor([1, 1 ]) # type: ignore[assignment] if process_strength is None: process_strength = torch.tensor(1.0) # type: ignore[assignment] - dist = utils.DiscreteGaussianDelay(torch.tensor(1), torch.tensor(1)) + dist = DiscreteGaussianDelay(torch.tensor(1), torch.tensor(1)) proc = ExogenousProcess( self.name, diff --git a/predicators/third_party/__init__.py b/predicators/third_party/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/predicators/third_party/fast_downward_translator/__init__.py b/predicators/third_party/fast_downward_translator/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/predicators/utils.py b/predicators/utils.py index 3a48b3e11e..f6e5270afd 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -1,4206 +1,82 @@ -"""General utility methods.""" +"""General utility methods (full version, including ML-dependent helpers). -from __future__ import annotations - -import abc -import contextlib -import copy -import datetime -import functools -import gc -import heapq as hq -import importlib -import io -import itertools -import logging -import os -import pkgutil -import re -import subprocess -import sys -import time -from argparse import ArgumentParser -from collections import defaultdict, namedtuple -from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass, field -from functools import cached_property -from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Dict, \ - FrozenSet, Generator, Generic, Hashable, Iterable, Iterator, List, \ - Optional, Sequence, Set, Tuple -from typing import Type as TypingType -from typing import TypeVar, Union, cast - -import colorlog -import dill as pkl -import imageio -import matplotlib -import matplotlib.pyplot as plt -import numpy as np -import pathos.multiprocessing as mp -import PIL.Image -import torch -from gym.spaces import Box -from matplotlib import patches -from numpy.typing import NDArray -from PIL import ImageDraw, ImageFont -from pyperplan.heuristics.heuristic_base import \ - Heuristic as _PyperplanBaseHeuristic -from pyperplan.planner import HEURISTICS as _PYPERPLAN_HEURISTICS -from scipy.stats import beta as BetaRV - -from predicators.args import create_arg_parser -from predicators.image_patch_wrapper import ImagePatch -from predicators.pretrained_model_interface import GoogleGeminiLLM, \ - GoogleGeminiVLM, LargeLanguageModel, OpenAILLM, OpenAIVLM, OpenRouterLLM, \ - OpenRouterVLM, VisionLanguageModel -from predicators.pybullet_helpers.joint import JointPositions -from predicators.settings import CFG, GlobalSettings -from predicators.structs import NSRT, Action, Array, AtomOptionTrajectory, \ - CausalProcess, DelayDistribution, DerivedPredicate, DummyOption, \ - EntToEntSub, GroundAtom, GroundAtomTrajectory, \ - GroundNSRTOrSTRIPSOperator, Image, LDLRule, LiftedAtom, \ - LiftedDecisionList, LiftedOrGroundAtom, LowLevelTrajectory, Mask, \ - Metrics, NSRTOrSTRIPSOperator, Object, ObjectOrVariable, Observation, \ - OptionSpec, ParameterizedOption, Predicate, Segment, State, \ - STRIPSOperator, Task, Type, Variable, VarToObjSub, Video, VLMPredicate, \ - _GroundEndogenousProcess, _GroundLDLRule, _GroundNSRT, \ - _GroundSTRIPSOperator, _Option, _TypedEntity -from predicators.third_party.fast_downward_translator.translate import \ - main as downward_translate - -if TYPE_CHECKING: - from predicators.envs import BaseEnv - -matplotlib.use("Agg") - -# Unpickling CUDA models errs out if the device isn't recognized because of -# an unusual name, including in supercloud, but we can set it manually -if "CUDA_VISIBLE_DEVICES" in os.environ: # pragma: no cover - cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"].split(",") - if len(cuda_visible_devices) and cuda_visible_devices[0] != "0": - cuda_visible_devices[0] = "0" - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(cuda_visible_devices) - - -def count_positives_for_ops( - strips_ops: List[STRIPSOperator], - option_specs: List[OptionSpec], - segments: List[Segment], - max_groundings: Optional[int] = None, -) -> Tuple[int, int, List[Set[int]], List[Set[int]]]: - """Returns num true positives, num false positives, and for each strips op, - lists of segment indices that contribute true or false positives. - - The lists of segment indices are useful only for debugging; they are - otherwise redundant with num_true_positives/num_false_positives. - """ - assert len(strips_ops) == len(option_specs) - num_true_positives = 0 - num_false_positives = 0 - # The following two lists are just useful for debugging. - true_positive_idxs: List[Set[int]] = [set() for _ in strips_ops] - false_positive_idxs: List[Set[int]] = [set() for _ in strips_ops] - for seg_idx, segment in enumerate(segments): - objects = set(segment.states[0]) - segment_option = segment.get_option() - option_objects = segment_option.objects - covered_by_some_op = False - # Ground only the operators with a matching option spec. - for op_idx, (op, - option_spec) in enumerate(zip(strips_ops, option_specs)): - # If the parameterized options are different, not relevant. - if option_spec[0] != segment_option.parent: - continue - option_vars = option_spec[1] - assert len(option_vars) == len(option_objects) - option_var_to_obj = dict(zip(option_vars, option_objects)) - # We want to get all ground operators whose corresponding - # substitution is consistent with the option vars for this - # segment. So, determine all of the operator variables - # that are not in the option vars, and consider all - # groundings of them. - for grounding_idx, ground_op in enumerate( - all_ground_operators_given_partial(op, objects, - option_var_to_obj)): - if max_groundings is not None and \ - grounding_idx > max_groundings: - break - # Check the ground_op against the segment. - if not ground_op.preconditions.issubset(segment.init_atoms): - continue - if ground_op.add_effects == segment.add_effects and \ - ground_op.delete_effects == segment.delete_effects: - covered_by_some_op = True - true_positive_idxs[op_idx].add(seg_idx) - else: - false_positive_idxs[op_idx].add(seg_idx) - num_false_positives += 1 - if covered_by_some_op: - num_true_positives += 1 - return num_true_positives, num_false_positives, \ - true_positive_idxs, false_positive_idxs - - -def count_branching_factor(strips_ops: List[STRIPSOperator], - segments: List[Segment]) -> int: - """Returns the total branching factor for all states in the segments.""" - total_branching_factor = 0 - for segment in segments: - atoms = segment.init_atoms - objects = set(segment.states[0]) - ground_ops = { - ground_op - for op in strips_ops - for ground_op in all_ground_operators(op, objects) - } - for _ in get_applicable_operators(ground_ops, atoms): - total_branching_factor += 1 - return total_branching_factor - - -def segment_trajectory_to_start_end_state_sequence( - seg_traj: List[Segment]) -> List[State]: - """Convert a trajectory of segments into a trajectory of states, made up of - only the initial/final states of the segments. - - The length of the return value will always be one greater than the - length of the given seg_traj. - """ - assert len(seg_traj) >= 1 - states = [] - for i, seg in enumerate(seg_traj): - states.append(seg.states[0]) - if i < len(seg_traj) - 1: - assert seg.states[-1].allclose(seg_traj[i + 1].states[0]) - states.append(seg_traj[-1].states[-1]) - assert len(states) == len(seg_traj) + 1 - return states - - -def segment_trajectory_to_atoms_sequence( - seg_traj: List[Segment]) -> List[Set[GroundAtom]]: - """Convert a trajectory of segments into a trajectory of ground atoms. - - The length of the return value will always be one greater than the - length of the given seg_traj. - """ - assert len(seg_traj) >= 1 - atoms_seq = [] - for i, seg in enumerate(seg_traj): - atoms_seq.append(seg.init_atoms) - if i < len(seg_traj) - 1: - assert seg.final_atoms == seg_traj[i + 1].init_atoms - atoms_seq.append(seg_traj[-1].final_atoms) - assert len(atoms_seq) == len(seg_traj) + 1 - return atoms_seq - - -def num_options_in_action_sequence(actions: Sequence[Action]) -> int: - """Given a sequence of actions with options included, get the number of - options that are encountered.""" - num_options = 0 - last_option = None - for action in actions: - current_option = action.get_option() - if not current_option is last_option: - last_option = current_option - num_options += 1 - return num_options - - -def entropy(p: float) -> float: - """Entropy of a Bernoulli variable with parameter p.""" - assert 0.0 <= p <= 1.0 - if p in {0.0, 1.0}: - return 0.0 - return -(p * np.log2(p) + (1 - p) * np.log2(1 - p)) - - -def create_state_from_dict(data: Dict[Object, Dict[str, float]], - simulator_state: Optional[Any] = None) -> State: - """Small utility to generate a state from a dictionary `data` of individual - feature values for each object. - - A simulator_state for the outputted State may optionally be - provided. - """ - state_dict = {} - for obj, obj_data in data.items(): - obj_vec = [] - for feat in obj.type.feature_names: - obj_vec.append(obj_data[feat]) - state_dict[obj] = np.array(obj_vec) - return State(state_dict, simulator_state) - - -def create_json_dict_from_ground_atoms( - ground_atoms: Collection[GroundAtom]) -> Dict[str, List[List[str]]]: - """Saves a set of ground atoms in a JSON-compatible dict. - - Helper for creating the goal dict in create_json_dict_from_task(). - """ - predicate_to_argument_lists = defaultdict(list) - for atom in sorted(ground_atoms): - argument_list = [o.name for o in atom.objects] - predicate_to_argument_lists[atom.predicate.name].append(argument_list) - return dict(predicate_to_argument_lists) - - -def create_json_dict_from_task(task: Task) -> Dict[str, Any]: - """Create a JSON-compatible dict from a task. - - The format of the dict is: - - { - "objects": { - : - } - "init": { - : { - : - } - } - "goal": { - : [ - [] - ] - } - } - - The dict can be loaded with BaseEnv._load_task_from_json(). This is - helpful for testing and designing standalone tasks. - """ - object_dict = {o.name: o.type.name for o in task.init} - init_dict = { - o.name: dict(zip(o.type.feature_names, task.init.data[o])) - for o in task.init - } - goal_dict = create_json_dict_from_ground_atoms(task.goal) - return {"objects": object_dict, "init": init_dict, "goal": goal_dict} - - -def construct_active_sampler_input(state: State, objects: Sequence[Object], - params: Array, - param_option: ParameterizedOption) -> Array: - """Helper function for active sampler learning and explorer.""" - - assert not CFG.sampler_learning_use_goals - sampler_input_lst = [1.0] # start with bias term - if CFG.active_sampler_learning_feature_selection == "all": - for obj in objects: - sampler_input_lst.extend(state[obj]) - sampler_input_lst.extend(params) - - else: - assert CFG.active_sampler_learning_feature_selection == "oracle" - if CFG.env == "bumpy_cover": - if param_option.name == "Pick": - # In this case, the x-data should be - # [block_bumpy, relative_pick_loc] - assert len(objects) == 1 - block = objects[0] - block_pos = state[block][3] - block_bumpy = state[block][5] - sampler_input_lst.append(block_bumpy) - assert len(params) == 1 - sampler_input_lst.append(params[0] - block_pos) - else: - assert param_option.name == "Place" - assert len(objects) == 2 - block, target = objects - target_pos = state[target][3] - grasp = state[block][4] - target_width = state[target][2] - sampler_input_lst.extend([grasp, target_width]) - assert len(params) == 1 - sampler_input_lst.append(params[0] - target_pos) - elif CFG.env == "ball_and_cup_sticky_table": - if "PlaceCup" in param_option.name and "Table" in param_option.name: - _, _, _, table = objects - table_y = state.get(table, "y") - table_x = state.get(table, "x") - sticky = state.get(table, "sticky") - sticky_region_x = state.get(table, "sticky_region_x_offset") - sticky_region_y = state.get(table, "sticky_region_y_offset") - sticky_region_radius = state.get(table, "sticky_region_radius") - table_radius = state.get(table, "radius") - _, _, _, param_x, param_y = params - sampler_input_lst.append(table_radius) - sampler_input_lst.append(sticky) - sampler_input_lst.append(sticky_region_x) - sampler_input_lst.append(sticky_region_y) - sampler_input_lst.append(sticky_region_radius) - sampler_input_lst.append(table_x) - sampler_input_lst.append(table_y) - sampler_input_lst.append(param_x) - sampler_input_lst.append(param_y) - else: # Use all features. - for obj in objects: - sampler_input_lst.extend(state[obj]) - sampler_input_lst.extend(params) - else: - raise NotImplementedError("Oracle feature selection not " - f"implemented for {CFG.env}") - - return np.array(sampler_input_lst) - - -class _Geom2D(abc.ABC): - """A 2D shape that contains some points.""" - - @abc.abstractmethod - def plot(self, ax: plt.Axes, **kwargs: Any) -> None: - """Plot the shape on a given pyplot axis.""" - raise NotImplementedError("Override me!") - - @abc.abstractmethod - def contains_point(self, x: float, y: float) -> bool: - """Checks if a point is contained in the shape.""" - raise NotImplementedError("Override me!") - - @abc.abstractmethod - def sample_random_point(self, - rng: np.random.Generator) -> Tuple[float, float]: - """Samples a random point inside the 2D shape.""" - raise NotImplementedError("Override me!") - - def intersects(self, other: _Geom2D) -> bool: - """Checks if this shape intersects with another one.""" - return geom2ds_intersect(self, other) - - -@dataclass(frozen=True) -class LineSegment(_Geom2D): - """A helper class for visualizing and collision checking line segments.""" - x1: float - y1: float - x2: float - y2: float - - def plot(self, ax: plt.Axes, **kwargs: Any) -> None: - ax.plot([self.x1, self.x2], [self.y1, self.y2], **kwargs) - - def contains_point(self, x: float, y: float) -> bool: - # https://stackoverflow.com/questions/328107 - a = (self.x1, self.y1) - b = (self.x2, self.y2) - c = (x, y) - # Need to use an epsilon for numerical stability. But we are checking - # if the distance from a to b is (approximately) equal to the distance - # from a to c and the distance from c to b. - eps = 1e-6 - - def _dist(p: Tuple[float, float], q: Tuple[float, float]) -> float: - return np.sqrt((p[0] - q[0])**2 + (p[1] - q[1])**2) - - return -eps < _dist(a, c) + _dist(c, b) - _dist(a, b) < eps - - def sample_random_point(self, - rng: np.random.Generator) -> Tuple[float, float]: - line_slope = (self.y2 - self.y1) / (self.x2 - self.x1) - y_intercept = self.y2 - (line_slope * self.x2) - random_x_point = rng.uniform(self.x1, self.x2) - random_y_point_on_line = line_slope * random_x_point + y_intercept - assert self.contains_point(random_x_point, random_y_point_on_line) - return (random_x_point, random_y_point_on_line) - - -@dataclass(frozen=True) -class Circle(_Geom2D): - """A helper class for visualizing and collision checking circles.""" - x: float - y: float - radius: float - - def plot(self, ax: plt.Axes, **kwargs: Any) -> None: - patch = patches.Circle((self.x, self.y), self.radius, **kwargs) - ax.add_patch(patch) - - def contains_point(self, x: float, y: float) -> bool: - return (x - self.x)**2 + (y - self.y)**2 <= self.radius**2 - - def contains_circle(self, other_circle: Circle) -> bool: - """Check whether this circle wholly contains another one.""" - dist_between_centers = np.sqrt((other_circle.x - self.x)**2 + - (other_circle.y - self.y)**2) - return (dist_between_centers + other_circle.radius) <= self.radius - - def sample_random_point(self, - rng: np.random.Generator) -> Tuple[float, float]: - rand_mag = rng.uniform(0, self.radius) - rand_theta = rng.uniform(0, 2 * np.pi) - x_point = self.x + rand_mag * np.cos(rand_theta) - y_point = self.y + rand_mag * np.sin(rand_theta) - assert self.contains_point(x_point, y_point) - return (x_point, y_point) - - -@dataclass(frozen=True) -class Triangle(_Geom2D): - """A helper class for visualizing and collision checking triangles.""" - x1: float - y1: float - x2: float - y2: float - x3: float - y3: float - - def plot(self, ax: plt.Axes, **kwargs: Any) -> None: - patch = patches.Polygon( - [[self.x1, self.y1], [self.x2, self.y2], [self.x3, self.y3]], - **kwargs) - ax.add_patch(patch) - - def __post_init__(self) -> None: - dist1 = np.sqrt((self.x1 - self.x2)**2 + (self.y1 - self.y2)**2) - dist2 = np.sqrt((self.x2 - self.x3)**2 + (self.y2 - self.y3)**2) - dist3 = np.sqrt((self.x3 - self.x1)**2 + (self.y3 - self.y1)**2) - dists = sorted([dist1, dist2, dist3]) - assert dists[0] + dists[1] >= dists[2] - if dists[0] + dists[1] == dists[2]: - raise ValueError("Degenerate triangle!") - - def contains_point(self, x: float, y: float) -> bool: - # Adapted from https://stackoverflow.com/questions/2049582/. - sign1 = ((x - self.x2) * (self.y1 - self.y2) - (self.x1 - self.x2) * - (y - self.y2)) > 0 - sign2 = ((x - self.x3) * (self.y2 - self.y3) - (self.x2 - self.x3) * - (y - self.y3)) > 0 - sign3 = ((x - self.x1) * (self.y3 - self.y1) - (self.x3 - self.x1) * - (y - self.y1)) > 0 - has_neg = (not sign1) or (not sign2) or (not sign3) - has_pos = sign1 or sign2 or sign3 - return not has_neg or not has_pos - - def sample_random_point(self, - rng: np.random.Generator) -> Tuple[float, float]: - a = np.array([self.x2 - self.x1, self.y2 - self.y1]) - b = np.array([self.x3 - self.x1, self.y3 - self.y1]) - u1 = rng.uniform(0, 1) - u2 = rng.uniform(0, 1) - if u1 + u2 > 1.0: - u1 = 1 - u1 - u2 = 1 - u2 - point_in_triangle = (u1 * a + u2 * b) + np.array([self.x1, self.y1]) - assert self.contains_point(point_in_triangle[0], point_in_triangle[1]) - return (point_in_triangle[0], point_in_triangle[1]) - - -@dataclass(frozen=True) -class Rectangle(_Geom2D): - """A helper class for visualizing and collision checking rectangles. - - Following the convention in plt.Rectangle, the origin is at the - bottom left corner, and rotation is anti-clockwise about that point. - - Unlike plt.Rectangle, the angle is in radians. - """ - x: float - y: float - width: float - height: float - theta: float # in radians, between -np.pi and np.pi - - def __post_init__(self) -> None: - assert -np.pi <= self.theta <= np.pi, "Expecting angle in [-pi, pi]." - - @staticmethod - def from_center(center_x: float, center_y: float, width: float, - height: float, rotation_about_center: float) -> Rectangle: - """Create a rectangle given an (x, y) for the center, with theta - rotating about that center point.""" - x = center_x - width / 2 - y = center_y - height / 2 - norm_rect = Rectangle(x, y, width, height, 0.0) - assert np.isclose(norm_rect.center[0], center_x) - assert np.isclose(norm_rect.center[1], center_y) - return norm_rect.rotate_about_point(center_x, center_y, - rotation_about_center) - - @functools.cached_property - def rotation_matrix(self) -> NDArray[np.float64]: - """Get the rotation matrix.""" - return np.array([[np.cos(self.theta), -np.sin(self.theta)], - [np.sin(self.theta), - np.cos(self.theta)]]) - - @functools.cached_property - def inverse_rotation_matrix(self) -> NDArray[np.float64]: - """Get the inverse rotation matrix.""" - return np.array([[np.cos(self.theta), - np.sin(self.theta)], - [-np.sin(self.theta), - np.cos(self.theta)]]) - - @functools.cached_property - def vertices(self) -> List[Tuple[float, float]]: - """Get the four vertices for the rectangle.""" - scale_matrix = np.array([ - [self.width, 0], - [0, self.height], - ]) - translate_vector = np.array([self.x, self.y]) - vertices = np.array([ - (0, 0), - (0, 1), - (1, 1), - (1, 0), - ]) - vertices = vertices @ scale_matrix.T - vertices = vertices @ self.rotation_matrix.T - vertices = translate_vector + vertices - # Convert to a list of tuples. Slightly complicated to appease both - # type checking and linting. - return list(map(lambda p: (p[0], p[1]), vertices)) - - @functools.cached_property - def line_segments(self) -> List[LineSegment]: - """Get the four line segments for the rectangle.""" - vs = list(zip(self.vertices, self.vertices[1:] + [self.vertices[0]])) - line_segments = [] - for ((x1, y1), (x2, y2)) in vs: - line_segments.append(LineSegment(x1, y1, x2, y2)) - return line_segments - - @functools.cached_property - def center(self) -> Tuple[float, float]: - """Get the point at the center of the rectangle.""" - x, y = np.mean(self.vertices, axis=0) - return (x, y) - - @functools.cached_property - def circumscribed_circle(self) -> Circle: - """Returns x, y, radius.""" - x, y = self.center - radius = np.sqrt((self.width / 2)**2 + (self.height / 2)**2) - return Circle(x, y, radius) - - def contains_point(self, x: float, y: float) -> bool: - # First invert translation, then invert rotation. - rx, ry = np.array([x - self.x, y - self.y - ]) @ self.inverse_rotation_matrix.T - return 0 <= rx <= self.width and \ - 0 <= ry <= self.height - - def sample_random_point(self, - rng: np.random.Generator) -> Tuple[float, float]: - rand_width = rng.uniform(0, self.width) - rand_height = rng.uniform(0, self.height) - # First rotate, then translate. - rx, ry = np.array([rand_width, rand_height]) @ self.rotation_matrix.T - x = rx + self.x - y = ry + self.y - assert self.contains_point(x, y) - return (x, y) - - def rotate_about_point(self, x: float, y: float, rot: float) -> Rectangle: - """Create a new rectangle that is this rectangle, but rotated CCW by - the given rotation (in radians), relative to the (x, y) origin. - - Rotates the vertices first, then uses them to recompute the new - theta. - """ - vertices = np.array(self.vertices) - origin = np.array([x, y]) - # Translate the vertices so that they become the "origin". - vertices = vertices - origin - # Rotate. - rotate_matrix = np.array([[np.cos(rot), -np.sin(rot)], - [np.sin(rot), np.cos(rot)]]) - vertices = vertices @ rotate_matrix.T - # Translate the vertices back. - vertices = vertices + origin - # Recompute theta. - (lx, ly), _, _, (rx, ry) = vertices - theta = np.arctan2(ry - ly, rx - lx) - rect = Rectangle(lx, ly, self.width, self.height, theta) - assert np.allclose(rect.vertices, vertices) - return rect - - def plot(self, ax: plt.Axes, **kwargs: Any) -> None: - angle = self.theta * 180 / np.pi - patch = patches.Rectangle((self.x, self.y), - self.width, - self.height, - angle=angle, - **kwargs) - ax.add_patch(patch) - - -def line_segments_intersect(seg1: LineSegment, seg2: LineSegment) -> bool: - """Checks if two line segments intersect. - - This method, which works by checking relative orientation, allows - for collinearity, and only checks if each segment straddles the line - containing the other. - """ - - def _subtract(a: Tuple[float, float], b: Tuple[float, float]) \ - -> Tuple[float, float]: - x1, y1 = a - x2, y2 = b - return (x1 - x2), (y1 - y2) - - def _cross_product(a: Tuple[float, float], b: Tuple[float, float]) \ - -> float: - x1, y1 = b - x2, y2 = a - return x1 * y2 - x2 * y1 - - def _direction(a: Tuple[float, float], b: Tuple[float, float], - c: Tuple[float, float]) -> float: - return _cross_product(_subtract(a, c), _subtract(a, b)) - - p1 = (seg1.x1, seg1.y1) - p2 = (seg1.x2, seg1.y2) - p3 = (seg2.x1, seg2.y1) - p4 = (seg2.x2, seg2.y2) - d1 = _direction(p3, p4, p1) - d2 = _direction(p3, p4, p2) - d3 = _direction(p1, p2, p3) - d4 = _direction(p1, p2, p4) - - return ((d2 < 0 < d1) or (d1 < 0 < d2)) and ((d4 < 0 < d3) or - (d3 < 0 < d4)) - - -def circles_intersect(circ1: Circle, circ2: Circle) -> bool: - """Checks if two circles intersect.""" - x1, y1, r1 = circ1.x, circ1.y, circ1.radius - x2, y2, r2 = circ2.x, circ2.y, circ2.radius - return (x1 - x2)**2 + (y1 - y2)**2 < (r1 + r2)**2 - - -def rectangles_intersect(rect1: Rectangle, rect2: Rectangle) -> bool: - """Checks if two rectangles intersect.""" - # Optimization: if the circumscribed circles don't intersect, then - # the rectangles also don't intersect. - if not circles_intersect(rect1.circumscribed_circle, - rect2.circumscribed_circle): - return False - # Case 1: line segments intersect. - if any( - line_segments_intersect(seg1, seg2) for seg1 in rect1.line_segments - for seg2 in rect2.line_segments): - return True - # Case 2: rect1 inside rect2. - if rect1.contains_point(rect2.center[0], rect2.center[1]): - return True - # Case 3: rect2 inside rect1. - if rect2.contains_point(rect1.center[0], rect1.center[1]): - return True - # Not intersecting. - return False - - -def line_segment_intersects_circle(seg: LineSegment, - circ: Circle, - ax: Optional[plt.Axes] = None) -> bool: - """Checks if a line segment intersects a circle. - - If ax is not None, a diagram is plotted on the axis to illustrate - the computations, which is useful for checking correctness. - """ - # First check if the end points of the segment are in the circle. - if circ.contains_point(seg.x1, seg.y1): - return True - if circ.contains_point(seg.x2, seg.y2): - return True - # Project the circle radius onto the extended line. - c = (circ.x, circ.y) - # Project (a, c) onto (a, b). - a = (seg.x1, seg.y1) - b = (seg.x2, seg.y2) - ba = np.subtract(b, a) - ca = np.subtract(c, a) - da = ba * np.dot(ca, ba) / np.dot(ba, ba) - # The point on the extended line that is the closest to the center. - d = dx, dy = (a[0] + da[0], a[1] + da[1]) - # Optionally plot the important points. - if ax is not None: - circ.plot(ax, color="red", alpha=0.5) - seg.plot(ax, color="black", linewidth=2) - ax.annotate("A", a) - ax.annotate("B", b) - ax.annotate("C", c) - ax.annotate("D", d) - # Check if the point is on the line. If it's not, there is no intersection, - # because we already checked that the circle does not contain the end - # points of the line segment. - if not seg.contains_point(dx, dy): - return False - # So d is on the segment. Check if it's in the circle. - return circ.contains_point(dx, dy) - - -def line_segment_intersects_rectangle(seg: LineSegment, - rect: Rectangle) -> bool: - """Checks if a line segment intersects a rectangle.""" - # Case 1: one of the end points of the segment is in the rectangle. - if rect.contains_point(seg.x1, seg.y1) or \ - rect.contains_point(seg.x2, seg.y2): - return True - # Case 2: the segment intersects with one of the rectangle sides. - return any(line_segments_intersect(s, seg) for s in rect.line_segments) - - -def rectangle_intersects_circle(rect: Rectangle, circ: Circle) -> bool: - """Checks if a rectangle intersects a circle.""" - # Optimization: if the circumscribed circle of the rectangle doesn't - # intersect with the circle, then there can't be an intersection. - if not circles_intersect(rect.circumscribed_circle, circ): - return False - # Case 1: the circle's center is in the rectangle. - if rect.contains_point(circ.x, circ.y): - return True - # Case 2: one of the sides of the rectangle intersects the circle. - for seg in rect.line_segments: - if line_segment_intersects_circle(seg, circ): - return True - return False - - -def geom2ds_intersect(geom1: _Geom2D, geom2: _Geom2D) -> bool: - """Check if two 2D bodies intersect.""" - if isinstance(geom1, LineSegment) and isinstance(geom2, LineSegment): - return line_segments_intersect(geom1, geom2) - if isinstance(geom1, LineSegment) and isinstance(geom2, Circle): - return line_segment_intersects_circle(geom1, geom2) - if isinstance(geom1, LineSegment) and isinstance(geom2, Rectangle): - return line_segment_intersects_rectangle(geom1, geom2) - if isinstance(geom1, Rectangle) and isinstance(geom2, LineSegment): - return line_segment_intersects_rectangle(geom2, geom1) - if isinstance(geom1, Circle) and isinstance(geom2, LineSegment): - return line_segment_intersects_circle(geom2, geom1) - if isinstance(geom1, Rectangle) and isinstance(geom2, Rectangle): - return rectangles_intersect(geom1, geom2) - if isinstance(geom1, Rectangle) and isinstance(geom2, Circle): - return rectangle_intersects_circle(geom1, geom2) - if isinstance(geom1, Circle) and isinstance(geom2, Rectangle): - return rectangle_intersects_circle(geom2, geom1) - if isinstance(geom1, Circle) and isinstance(geom2, Circle): - return circles_intersect(geom1, geom2) - raise NotImplementedError("Intersection not implemented for geoms " - f"{geom1} and {geom2}") - - -@functools.lru_cache(maxsize=None) -def unify(atoms1: FrozenSet[LiftedOrGroundAtom], - atoms2: FrozenSet[LiftedOrGroundAtom]) -> Tuple[bool, EntToEntSub]: - """Return whether the given two sets of atoms can be unified. - - Also return the mapping between variables/objects in these atom - sets. This mapping is empty if the first return value is False. - """ - atoms_lst1 = sorted(atoms1) - atoms_lst2 = sorted(atoms2) - - # Terminate quickly if there is a mismatch between predicates - preds1 = [atom.predicate for atom in atoms_lst1] - preds2 = [atom.predicate for atom in atoms_lst2] - if preds1 != preds2: - return False, {} - - # Terminate quickly if there is a mismatch between numbers - num1 = len({o for atom in atoms_lst1 for o in atom.entities}) - num2 = len({o for atom in atoms_lst2 for o in atom.entities}) - if num1 != num2: - return False, {} - - # Try to get lucky with a one-to-one mapping - subs12: EntToEntSub = {} - subs21 = {} - success = True - for atom1, atom2 in zip(atoms_lst1, atoms_lst2): - if not success: - break - for v1, v2 in zip(atom1.entities, atom2.entities): - if v1 in subs12 and subs12[v1] != v2: - success = False - break - if v2 in subs21: - success = False - break - subs12[v1] = v2 - subs21[v2] = v1 - if success: - return True, subs12 - - # If all else fails, use search - solved, sub = find_substitution(atoms_lst1, atoms_lst2) - rev_sub = {v: k for k, v in sub.items()} - return solved, rev_sub - - -@functools.lru_cache(maxsize=None) -def unify_preconds_effects_options( - preconds1: FrozenSet[LiftedOrGroundAtom], - preconds2: FrozenSet[LiftedOrGroundAtom], - add_effects1: FrozenSet[LiftedOrGroundAtom], - add_effects2: FrozenSet[LiftedOrGroundAtom], - delete_effects1: FrozenSet[LiftedOrGroundAtom], - delete_effects2: FrozenSet[LiftedOrGroundAtom], - param_option1: ParameterizedOption, param_option2: ParameterizedOption, - option_args1: Tuple[_TypedEntity, ...], - option_args2: Tuple[_TypedEntity, ...]) -> Tuple[bool, EntToEntSub]: - """Wrapper around unify() that handles option arguments, preconditions, add - effects, and delete effects. - - Changes predicate names so that all are treated differently by - unify(). - """ - if param_option1 != param_option2: - # Can't unify if the parameterized options are different. - return False, {} - opt_arg_pred1 = Predicate("OPT-ARGS", [a.type for a in option_args1], - _classifier=lambda s, o: False) # dummy - f_option_args1 = frozenset({GroundAtom(opt_arg_pred1, option_args1)}) - new_preconds1 = wrap_atom_predicates(preconds1, "PRE-") - f_new_preconds1 = frozenset(new_preconds1) - new_add_effects1 = wrap_atom_predicates(add_effects1, "ADD-") - f_new_add_effects1 = frozenset(new_add_effects1) - new_delete_effects1 = wrap_atom_predicates(delete_effects1, "DEL-") - f_new_delete_effects1 = frozenset(new_delete_effects1) - - opt_arg_pred2 = Predicate("OPT-ARGS", [a.type for a in option_args2], - _classifier=lambda s, o: False) # dummy - f_option_args2 = frozenset({LiftedAtom(opt_arg_pred2, option_args2)}) - new_preconds2 = wrap_atom_predicates(preconds2, "PRE-") - f_new_preconds2 = frozenset(new_preconds2) - new_add_effects2 = wrap_atom_predicates(add_effects2, "ADD-") - f_new_add_effects2 = frozenset(new_add_effects2) - new_delete_effects2 = wrap_atom_predicates(delete_effects2, "DEL-") - f_new_delete_effects2 = frozenset(new_delete_effects2) - - all_atoms1 = (f_option_args1 | f_new_preconds1 | f_new_add_effects1 - | f_new_delete_effects1) - all_atoms2 = (f_option_args2 | f_new_preconds2 | f_new_add_effects2 - | f_new_delete_effects2) - return unify(all_atoms1, all_atoms2) - - -def wrap_predicate(predicate: Predicate, prefix: str) -> Predicate: - """Return a new predicate which adds the given prefix string to the name. - - NOTE: the classifier is removed. - """ - new_predicate = Predicate(prefix + predicate.name, - predicate.types, - _classifier=lambda s, o: False) # dummy - return new_predicate - - -def wrap_atom_predicates(atoms: Collection[LiftedOrGroundAtom], - prefix: str) -> Set[LiftedOrGroundAtom]: - """Return a new set of atoms which adds the given prefix string to the name - of every atom's predicate. - - NOTE: all the classifiers are removed. - """ - new_atoms = set() - for atom in atoms: - new_predicate = wrap_predicate(atom.predicate, prefix) - new_atoms.add(atom.__class__(new_predicate, atom.entities)) - return new_atoms - - -class LinearChainParameterizedOption(ParameterizedOption): - """A parameterized option implemented via a sequence of "child" - parameterized options. - - This class is meant to help ParameterizedOption manual design. - - The children are executed in order starting with the first in the sequence - and transitioning when the terminal function of each child is hit. - - The children are assumed to chain together, so the initiable of the next - child should always be True when the previous child terminates. If this - is not the case, an AssertionError is raised. - - The children must all have the same types and params_space, which in turn - become the types and params_space for this ParameterizedOption. - - The LinearChainParameterizedOption has memory, which stores the current - child index. - """ - - def __init__(self, name: str, - children: Sequence[ParameterizedOption]) -> None: - assert len(children) > 0 - self._children = children - - # Make sure that the types and params spaces are consistent. - types = children[0].types - params_space = children[0].params_space - for i in range(1, len(self._children)): - child = self._children[i] - assert types == child.types - assert np.allclose(params_space.low, child.params_space.low) - assert np.allclose(params_space.high, child.params_space.high) - - super().__init__(name, - types, - params_space, - policy=self._policy, - initiable=self._initiable, - terminal=self._terminal) - - def _initiable(self, state: State, memory: Dict, objects: Sequence[Object], - params: Array) -> bool: - # Initialize the current child to the first one. - memory["current_child_index"] = 0 - # Create memory dicts for each child to avoid key collisions. One - # example of a failure that arises without this is when using - # multiple SingletonParameterizedOption instances, each of those - # options would be referencing the same start_state in memory. - memory["child_memory"] = [{} for _ in self._children] - current_child = self._children[0] - child_memory = memory["child_memory"][0] - return current_child.initiable(state, child_memory, objects, params) - - def _policy(self, state: State, memory: Dict, objects: Sequence[Object], - params: Array) -> Action: - # Check if the current child has terminated. - current_index = memory["current_child_index"] - current_child = self._children[current_index] - child_memory = memory["child_memory"][current_index] - if current_child.terminal(state, child_memory, objects, params): - # Move on to the next child. - current_index += 1 - memory["current_child_index"] = current_index - current_child = self._children[current_index] - child_memory = memory["child_memory"][current_index] - assert current_child.initiable(state, child_memory, objects, - params) - # logging.debug(f"Executing {current_child.name}") - return current_child.policy(state, child_memory, objects, params) - - def _terminal(self, state: State, memory: Dict, objects: Sequence[Object], - params: Array) -> bool: - # Check if the last child has terminated. - current_index = memory["current_child_index"] - if current_index < len(self._children) - 1: - return False - current_child = self._children[current_index] - child_memory = memory["child_memory"][current_index] - return current_child.terminal(state, child_memory, objects, params) - - -class SingletonParameterizedOption(ParameterizedOption): - """A parameterized option that takes a single action and stops. - - For convenience: - * Initiable defaults to always True. - * Types defaults to []. - * Params space defaults to Box(0, 1, (0, )). - """ - - def __init__( - self, - name: str, - policy: Callable[[State, Dict, Sequence[Object], Array], Action], - types: Optional[Sequence[Type]] = None, - params_space: Optional[Box] = None, - initiable: Optional[Callable[[State, Dict, Sequence[Object], Array], - bool]] = None - ) -> None: - if types is None: - types = [] - if params_space is None: - params_space = Box(0, 1, (0, )) - if initiable is None: - initiable = lambda _1, _2, _3, _4: True - - # Wrap the given initiable so that we can track whether the action - # has been executed yet. - def _initiable(state: State, memory: Dict, objects: Sequence[Object], - params: Array) -> bool: - if "start_state" in memory: - assert state.allclose(memory["start_state"]) - # Always update the memory dict due to the "is" check in _terminal. - memory["start_state"] = state - assert initiable is not None - return initiable(state, memory, objects, params) - - def _terminal(state: State, memory: Dict, objects: Sequence[Object], - params: Array) -> bool: - del objects, params # unused - assert "start_state" in memory, \ - "Must call initiable() before terminal()." - return state is not memory["start_state"] - - super().__init__(name, - types, - params_space, - policy=policy, - initiable=_initiable, - terminal=_terminal) - - -class PyBulletState(State): - """A PyBullet state that stores the robot joint positions in addition to - the features that are exposed in the object-centric state.""" - - @property - def joint_positions(self) -> JointPositions: - """Expose the current joints state in the simulator_state.""" - # if the simulator state is an array - if isinstance(self.simulator_state, Dict): - jp = self.simulator_state["joint_positions"] - else: - jp = self.simulator_state - return cast(JointPositions, jp) - - @property - def state_image(self) -> PIL.Image.Image: - """Expose the current image state in the simulator_state.""" - assert isinstance(self.simulator_state, Dict) - return self.simulator_state["unlabeled_image"] - - @property - def labeled_image(self) -> Optional[PIL.Image.Image]: - """Expose the current image state in the simulator_state.""" - assert isinstance(self.simulator_state, Dict) - return self.simulator_state.get("images") - - @property - def obj_mask_dict(self) -> Optional[Dict[Object, Mask]]: - """Expose the current object masks in the simulator_state.""" - assert isinstance(self.simulator_state, Dict) - return self.simulator_state.get("obj_mask_dict") - - def allclose(self, other: State) -> bool: - # Ignores the simulator state. - return State(self.data).allclose(State(other.data)) - - def copy(self) -> PyBulletState: - copied = super().copy() - state_dict_copy = copied.data - # simulator_state_copy = list(self.joint_positions) - simulator_state_copy = copied.simulator_state - # Forward the hidden blocks `super().copy()` deep-copied: `latent` - # (agent belief) and `privileged` (env-hidden ground truth). Both - # are dropped if not passed explicitly, since this rebuilds the - # PyBulletState rather than returning `copied`. - return PyBulletState(state_dict_copy, - simulator_state_copy, - latent=copied.latent, - privileged=copied.privileged) - - def get_obj_mask(self, obj: Object) -> Mask: - """Return the mask for the object.""" - assert self.obj_mask_dict is not None - mask = self.obj_mask_dict.get(obj) - assert mask is not None - return mask - - def label_all_objects(self) -> None: - """Label all objects in the simulator state.""" - state_ip = ImagePatch(self) - obj_mask_dict = self.obj_mask_dict - assert obj_mask_dict is not None - state_ip.label_all_objects(obj_mask_dict) - assert isinstance(self.simulator_state, Dict) - self.simulator_state["images"] = state_ip.cropped_image_in_PIL - - def add_images_and_masks(self, unlabeled_image: PIL.Image.Image, - masks: Dict[Object, Mask]) -> None: - """Add the unlabeled image and object masks to the simulator state.""" - assert isinstance(self.simulator_state, Dict) - self.simulator_state["unlabeled_image"] = unlabeled_image - self.simulator_state["obj_mask_dict"] = masks - self.label_all_objects() - - -BoundingBox = namedtuple('BoundingBox', 'left lower right upper') - - -@dataclass -class VLMState(PyBulletState): - """PyBulletState extended with VLM/visual perception capabilities.""" - state_image: PIL.Image.Image = None # type: ignore[assignment] - obj_mask_dict: Dict[Object, Mask] = field(default_factory=dict) - labeled_image: Optional[PIL.Image.Image] = None # type: ignore[assignment] - option_history: Optional[List[str]] = None - bbox_features: Dict[Object, np.ndarray] = field( - default_factory=lambda: defaultdict(lambda: np.zeros(4))) - prev_state: Optional[VLMState] = None - - def __hash__(self) -> int: - data_tuple = tuple((k, tuple(v)) for k, v in sorted(self.data.items())) - if self.simulator_state is not None: - data_tuple += tuple(self.simulator_state) - return hash(data_tuple) - - def evaluate_simple_assertion( - self, assertion: str, image: Tuple[BoundingBox, - Sequence[Object]]) -> VLMQuery: - """Given an assertion and an image, queries a VLM and returns whether - the assertion is true or false.""" - bbox, objs = image - return VLMQuery(assertion, bbox, list(objs)) - - def generate_previous_option_message(self) -> str: - """Generate the message for the previous option.""" - assert self.option_history is not None - msg = "Evaluate the truth value of the following assertions in the "\ - "current state as depicted by the image" - if CFG.nsp_pred_include_prev_image_in_prompt and \ - self.prev_state is not None: - msg += " labeled with 'curr. state'" - if CFG.nsp_pred_include_state_str_in_prompt: - msg += " and the information below" - - msg += ".\n" - - if CFG.nsp_pred_include_state_str_in_prompt: - msg += "We have the object positions and the robot's "\ - "proprioception:\n" - msg += self.dict_str(indent=2, - object_features=False, - use_object_id=True, - position_proprio_features=True) - msg += "\n" - - if len(self.option_history) == 0: - msg += "For context, this is at the beginning of a task, before "\ - "the robot has done anything.\n" - else: - msg += "For context, the state is right after the robot has"\ - " successfully executed the action "\ - f"{self.option_history[-1]}." - if CFG.nsp_pred_include_state_str_in_prompt: - if self.prev_state is not None: - msg += " The object position and robot proprioception "\ - "before executing the action is:\n" - msg += self.prev_state.dict_str( - indent=2, - object_features=False, - use_object_id=True, - position_proprio_features=True) - msg += "\n" - if CFG.nsp_pred_include_prev_image_in_prompt: - msg += " The state before executing the action is depicted"\ - " by the image labeled with 'prev. state'." - msg += " Please carefully examine the images depicting the "\ - "'prev. state' and 'curr. state' before making a judgment." - msg += "\n" - msg += "The assertions to evaluate are:" - return msg - - def add_bbox_features(self) -> None: - """Add the features about the bounding box to the objects.""" - for obj, mask in self.obj_mask_dict.items(): - bbox = mask_to_bbox(mask) - for name, value in bbox._asdict().items(): - self.set(obj, f"bbox_{name}", value) - - def set(self, obj: Object, feature_name: str, feature_val: Any) -> None: - """Set the value of an object feature by name.""" - idx = obj.type.feature_names.index(feature_name) - standard_feature_len = len(self.data[obj]) - if idx >= standard_feature_len: - self.bbox_features[obj][idx - standard_feature_len] = feature_val - else: - self.data[obj][idx] = feature_val - - def get(self, obj: Object, feature_name: str) -> Any: - idx = obj.type.feature_names.index(feature_name) - standard_feature_len = len(self.data[obj]) - if idx >= standard_feature_len: - return self.bbox_features[obj][idx - standard_feature_len] - return self.data[obj][idx] - - def dict_str( # type: ignore[override] - self, - indent: int = 0, - object_features: bool = True, - num_decimal_points: int = 2, - use_object_id: bool = False, - position_proprio_features: bool = False) -> str: - """Return a dictionary representation of the state.""" - state_dict = {} - for obj in self: - obj_dict = {} - for attribute, value in zip( - obj.type.feature_names, - np.concatenate([self[obj], self.bbox_features[obj]]) - if self.bbox_features else self[obj]): - if (position_proprio_features and attribute - in ["rot", "fingers"]) or (object_features - and attribute not in [ - "is_heavy", - ]): - if isinstance(value, (float, int, np.float32)): - value = round(float(value), 1) - obj_dict[attribute] = value - - if use_object_id: - obj_name = obj.id_name - else: - obj_name = obj.name - state_dict[f"{obj_name}:{obj.type.name}"] = obj_dict - - spaces = " " * indent - dict_str = spaces + "{" - n_keys = len(state_dict.keys()) - for i, (key, value) in enumerate(state_dict.items()): - value_str = ', '.join(f"'{k}': {v}" for k, v in value.items()) - if value_str == "": - content_str = f"'{key}'" - else: - content_str = f"'{key}': {{{value_str}}}" - if i == 0: - dict_str += f"{content_str},\n" - elif i == n_keys - 1: - dict_str += spaces + f" {content_str}" - else: - dict_str += spaces + f" {content_str},\n" - dict_str += "}" - return dict_str - - def __eq__(self, other: object) -> bool: - assert isinstance(other, VLMState) - if len(self.data) != len(other.data): - return False - for key, value in self.data.items(): - if key not in other.data or not np.array_equal( - value, other.data[key]): - return False - return self.simulator_state == other.simulator_state - - def label_all_objects(self) -> None: - state_ip = ImagePatch(self) - state_ip.label_all_objects(self.obj_mask_dict) - self.labeled_image = state_ip.cropped_image_in_PIL - - def copy(self) -> VLMState: - pybullet_state_copy = super().copy() - state_image_copy = copy.copy(self.state_image) - obj_mask_copy = copy.deepcopy(self.obj_mask_dict) - labeled_image_copy = copy.copy(self.labeled_image) - option_history_copy = copy.copy(self.option_history) - bbox_features_copy = copy.deepcopy(self.bbox_features) - prev_state_copy = self.prev_state.copy() if self.prev_state else None - # Use kwargs for the VLM-specific fields so positional shifts in - # the base `State` dataclass (e.g. the `latent` block added for - # the recurrent partial-observability approach) don't reorder - # this call. - return VLMState( - data=pybullet_state_copy.data, - simulator_state=pybullet_state_copy.simulator_state, - latent=pybullet_state_copy.latent, - privileged=pybullet_state_copy.privileged, - state_image=state_image_copy, - obj_mask_dict=obj_mask_copy, - labeled_image=labeled_image_copy, - option_history=option_history_copy, - bbox_features=bbox_features_copy, - prev_state=prev_state_copy, - ) - - def get_obj_mask(self, obj: Object) -> Mask: - """Return the mask for the object.""" - return self.obj_mask_dict[obj] - - def get_obj_bbox(self, obj: Object) -> BoundingBox: - """Get the bounding box of the object in the state image.""" - mask = self.get_obj_mask(obj) - return mask_to_bbox(mask) - - def crop_to_objects( # pylint: disable=missing-function-docstring - self, - objects: Sequence[Object], - left_margin: int = 30, - lower_margin: int = 30, - right_margin: int = 30, - top_margin: int = 30) -> Tuple[BoundingBox, Sequence[Object]]: - bboxes = [self.get_obj_bbox(obj) for obj in objects] - bbox = smallest_bbox_from_bboxes(bboxes) - return (BoundingBox( - max(bbox.left - left_margin, 0), max(bbox.lower - lower_margin, 0), - min(bbox.right + right_margin, self.state_image.width), - min(bbox.upper + top_margin, self.state_image.height)), objects) - - -@dataclass -class VLMQuery: - """A class to represent a query to a VLM.""" - query_str: str - attention_box: BoundingBox - attn_objects: Optional[List[Object]] = None - ground_atom: Optional[GroundAtom] = None - - -def mask_to_bbox(mask: Mask) -> BoundingBox: - """Return the bounding box of the mask.""" - y_indices, x_indices = np.where(mask) - height = mask.shape[0] - - # Get the bounding box - try: - left = x_indices.min() - right = x_indices.max() - lower = height - (y_indices.max() + 1) - upper = height - (y_indices.min() + 1) - except ValueError: - left, lower, right, upper = 0, 0, 0, 0 - # If the mask is empty, return a bounding box with all zeros - - return BoundingBox(left, lower, right, upper) - - -def smallest_bbox_from_bboxes(bboxes: Sequence[BoundingBox]) -> BoundingBox: - """Return the smallest bounding box that contains all the given - bounding.""" - - # Initialize the bounding box coordinates - left, lower, right, upper = np.inf, np.inf, -np.inf, -np.inf - # Iterate over all masks - for bbox in bboxes: - # Update the bounding box - left = min(left, bbox.left) - lower = min(lower, bbox.lower) - right = max(right, bbox.right) - upper = max(upper, bbox.upper) - return BoundingBox(left, lower, right, upper) - - -class StateWithCache(State): - """A state with a cache stored in the simulator state that is ignored for - state equality checks. - - The cache is deliberately not copied. - """ - - @property - def cache(self) -> Dict[str, Dict]: - """Expose the cache in the simulator_state.""" - return cast(Dict[str, Dict], self.simulator_state) - - def allclose(self, other: State) -> bool: - # Ignores the simulator state. - return State(self.data).allclose(State(other.data)) - - def copy(self) -> State: - copied = super().copy() - # The cache (simulator_state) is deliberately shared, not copied; - # forward the hidden latent/privileged blocks so they survive. - return StateWithCache(copied.data, - self.cache, - latent=copied.latent, - privileged=copied.privileged) - - -class LoggingMonitor(abc.ABC): - """Observes states and actions during environment interaction.""" - - @abc.abstractmethod - def reset(self, train_or_test: str, task_idx: int) -> None: - """Called when the monitor starts a new episode.""" - raise NotImplementedError("Override me!") - - @abc.abstractmethod - def observe(self, obs: Observation, action: Optional[Action]) -> None: - """Record an observation and the action that is about to be taken. - - On the last timestep of a trajectory, no action is taken, so - action is None. - """ - raise NotImplementedError("Override me!") - - -def run_policy( - policy: Callable[[State], Action], - env: BaseEnv, - train_or_test: str, - task_idx: int, - termination_function: Callable[[State], bool], - max_num_steps: int, - do_env_reset: bool = True, - exceptions_to_break_on: Optional[Set[TypingType[Exception]]] = None, - monitor: Optional[LoggingMonitor] = None -) -> Tuple[LowLevelTrajectory, Metrics]: - """Execute a policy starting from the initial state of a train or test task - in the environment. The task's goal is not used. - - Note that the environment internal state is updated. - - Terminates when any of these conditions hold: - (1) the termination_function returns True - (2) max_num_steps is reached - (3) policy() or step() raise an exception of type in exceptions_to_break_on - - Note that in the case where the exception is raised in step, we exclude the - last action from the returned trajectory to maintain the invariant that - the trajectory states are of length one greater than the actions. - - NOTE: this may be deprecated in the future in favor of run_episode defined - in cogman.py. Ideally, we should consolidate both run_policy and - run_policy_with_simulator below into run_episode. - """ - if do_env_reset: - env.reset(train_or_test, task_idx) - if monitor is not None: - monitor.reset(train_or_test, task_idx) - obs = env.get_observation() - assert isinstance(obs, State) - state = obs - states = [state] - actions: List[Action] = [] - metrics: Metrics = defaultdict(float) - metrics["policy_call_time"] = 0.0 - exception_raised_in_step = False - if not termination_function(state): - for _ in range(max_num_steps): - monitor_observed = False - exception_raised_in_step = False - try: - start_time = time.perf_counter() - act = policy(state) - metrics["policy_call_time"] += time.perf_counter() - start_time - except Exception as e: # pylint: disable=broad-except - if not CFG.video_not_break_on_exception: - if exceptions_to_break_on is not None and \ - type(e) in exceptions_to_break_on: - if monitor_observed: - exception_raised_in_step = True - break - raise e - if monitor is not None and not monitor_observed: - monitor.observe(state, None) - monitor_observed = True - else: - if monitor is not None and not monitor_observed: - monitor.observe(state, act) - monitor_observed = True - - try: - # Note: it's important to call monitor.observe() before - # env.step(), because the monitor may use the environment's - # internal state. - state = env.step(act) - actions.append(act) - states.append(state) - except Exception as e: - if exceptions_to_break_on is not None and \ - type(e) in exceptions_to_break_on: - if monitor_observed: - exception_raised_in_step = True - break - raise e - if termination_function(state): - break - if monitor is not None and not exception_raised_in_step: - monitor.observe(state, None) - traj = LowLevelTrajectory(states, actions) - return traj, metrics - - -def run_policy_with_simulator( - policy: Callable[[State], Action], - simulator: Callable[[State, Action], State], - init_state: State, - termination_function: Callable[[State], bool], - max_num_steps: int, - exceptions_to_break_on: Optional[Set[TypingType[Exception]]] = None, - monitor: Optional[LoggingMonitor] = None) -> LowLevelTrajectory: - """Execute a policy from a given initial state, using a simulator. - - *** This function should not be used with any core code, because we want - to avoid the assumption of a simulator when possible. *** - - This is similar to run_policy, with three major differences: - (1) The initial state `init_state` can be any state, not just the initial - state of a train or test task. (2) A simulator (function that takes state - as input) is assumed. (3) Metrics are not returned. - - Note that the environment internal state is NOT updated. - - Terminates when any of these conditions hold: - (1) the termination_function returns True - (2) max_num_steps is reached - (3) policy() or step() raise an exception of type in exceptions_to_break_on - - Note that in the case where the exception is raised in step, we exclude the - last action from the returned trajectory to maintain the invariant that - the trajectory states are of length one greater than the actions. - """ - state = init_state - states = [state] - actions: List[Action] = [] - exception_raised_in_step = False - if not termination_function(state): - for i in range(max_num_steps): - # logging.debug(f"State: {state.pretty_str()}") - monitor_observed = False - exception_raised_in_step = False - try: - act = policy(state) - # logging.debug(f"Action: {act}") - if monitor is not None: - monitor.observe(state, act) - monitor_observed = True - state = simulator(state, act) - actions.append(act) - states.append(state) - except Exception as e: - logging.debug(f"Exception during running policy: {e}") - if exceptions_to_break_on is not None and \ - type(e) in exceptions_to_break_on: - if monitor_observed: - exception_raised_in_step = True - break - if monitor is not None and not monitor_observed: - monitor.observe(state, None) - raise e - if termination_function(state): - break - logging.debug(f"Ran {i + 1} steps") - if monitor is not None and not exception_raised_in_step: - monitor.observe(state, None) - traj = LowLevelTrajectory(states, actions) - return traj - - -class ExceptionWithInfo(Exception): - """An exception with an optional info dictionary that is initially - empty.""" - - def __init__(self, message: str, info: Optional[Dict] = None) -> None: - super().__init__(message) - if info is None: - info = {} - assert isinstance(info, dict) - self.info = info - - -class OptionExecutionFailure(ExceptionWithInfo): - """An exception raised by an option policy in the course of execution.""" - - -class OptionTimeoutFailure(OptionExecutionFailure): - """A special kind of option execution failure due to an exceeded budget.""" - - -class RequestActPolicyFailure(ExceptionWithInfo): - """An exception raised by an acting policy in a request when it fails to - produce an action, which terminates the interaction.""" - - -class HumanDemonstrationFailure(ExceptionWithInfo): - """An exception raised when CFG.demonstrator == "human" and the human gives - a bad input.""" - - -class EnvironmentFailure(ExceptionWithInfo): - """Exception raised when any type of failure occurs in an environment. - - The info dictionary must contain a key "offending_objects", which - maps to a set of objects responsible for the failure. - """ - - def __repr__(self) -> str: - return f"{super().__repr__()}: {self.info}" - - def __str__(self) -> str: - return repr(self) - - -def check_wait_target_atoms( - option: _Option, - state: State, - abstract_function: Callable[[State], Set[GroundAtom]], -) -> Optional[bool]: - """Check if a Wait option's target atoms are satisfied. - - Returns True if targets are met (Wait should terminate), False if - not yet met, or None if no targets were specified (caller should - fall back to any-atom-change behaviour). - """ - pos = option.memory.get("wait_target_atoms", set()) - neg = option.memory.get("wait_target_neg_atoms", set()) - if not pos and not neg: - return None - cur_atoms = abstract_function(state) - return pos.issubset(cur_atoms) and neg.isdisjoint(cur_atoms) - - -def parse_wait_target_annotations( - line: str, - predicates: Collection[Predicate], - objects: Collection[Object], -) -> Tuple[Set[GroundAtom], Set[GroundAtom]]: - """Parse ``-> {Pred(...), NOT Pred(...)}`` from a plan line. - - Returns ``(positive_atoms, negative_atoms)`` where positive atoms - must become TRUE and negative atoms must become FALSE for the Wait - to terminate. - """ - pred_map = {p.name: p for p in predicates} - obj_map = {o.name: o for o in objects} - - sg_match = re.search(r'->\s*\{([^}]*)\}', line) - if not sg_match: - return set(), set() - - pos_atoms: Set[GroundAtom] = set() - neg_atoms: Set[GroundAtom] = set() - atom_re = re.compile(r'(NOT\s+)?(\w+)\(([^)]*)\)') - - for m in atom_re.finditer(sg_match.group(1)): - is_neg = m.group(1) is not None - pred_name = m.group(2) - obj_names = [n.strip().split(':')[0] for n in m.group(3).split(',')] - - if pred_name not in pred_map: - logging.warning("Unknown predicate in Wait target: %s", pred_name) - continue - pred = pred_map[pred_name] - try: - objs = [obj_map[n] for n in obj_names] - except KeyError as e: - logging.warning("Unknown object in Wait target: %s", e) - continue - if len(objs) != len(pred.types): - logging.warning("Arity mismatch for %s: expected %d, got %d", - pred_name, len(pred.types), len(objs)) - continue - atom = GroundAtom(pred, objs) - if is_neg: - neg_atoms.add(atom) - else: - pos_atoms.add(atom) - - return pos_atoms, neg_atoms - - -def inject_wait_targets_for_option( - option: _Option, - step_idx: int, - atoms_sequence: Sequence[Set[GroundAtom]], -) -> None: - """Inject Wait target atoms into a single option from atoms_sequence. - - Computes the expected atom delta from ``atoms_sequence[step_idx]`` - to ``atoms_sequence[step_idx + 1]`` and stores it in the option's - memory so that execution terminates on specific atoms rather than - any noisy change. No-op for non-Wait options or out-of-bounds - indices. - """ - if option.name != "Wait": - return - if step_idx + 1 >= len(atoms_sequence): - return - before = atoms_sequence[step_idx] - after = atoms_sequence[step_idx + 1] - target_pos = after - before - target_neg = before - after - if target_pos: - option.memory["wait_target_atoms"] = target_pos - if target_neg: - option.memory["wait_target_neg_atoms"] = target_neg - - -def strip_wait_annotations(text: str) -> str: - """Remove ``-> {...}`` annotations from plan text lines.""" - return re.sub(r'\s*->\s*\{[^}]*\}', '', text) - - -def _format_wait_target_debug( - state: State, target_atoms: Set[GroundAtom], - abstract_function: Callable[[State], Set[GroundAtom]]) -> str: - """Format state details for debugging why Wait has not terminated.""" - cur_atoms = abstract_function(state) - missing_targets = target_atoms - cur_atoms - target_objects = sorted( - { - ent - for atom in target_atoms - for ent in atom.entities if isinstance(ent, Object) - }, - key=lambda o: o.name) - object_details = [] - for obj in target_objects: - feature_values = [] - for feature_name in obj.type.feature_names: - value = state.get(obj, feature_name) - if isinstance(value, float): - value_str = f"{value:.4f}" - else: - value_str = str(value) - feature_values.append(f"{feature_name}={value_str}") - object_details.append(f"{obj}: " + ", ".join(feature_values)) - details = [ - f"Targets: {sorted(target_atoms)}", - f"Missing: {sorted(missing_targets)}", - f"cur_atoms: {sorted(cur_atoms)}", - ] - if object_details: - details.append(f"target_objects: {'; '.join(object_details)}") - return "; ".join(details) - - -def option_policy_to_policy( - option_policy: Callable[[State], _Option], - max_option_steps: Optional[int] = None, - raise_error_on_repeated_state: bool = False, - abstract_function: Optional[Callable[[State], Set[GroundAtom]]] = None -) -> Callable[[State], Action]: - """Create a policy that executes a policy over options.""" - cur_option = DummyOption - num_cur_option_steps = 0 - last_state: Optional[State] = None - - def _policy(state: State) -> Action: - nonlocal cur_option, num_cur_option_steps, last_state - - if cur_option is DummyOption: - last_option: Optional[_Option] = None - else: - last_option = cur_option - - if max_option_steps is not None and \ - num_cur_option_steps >= max_option_steps: - raise OptionTimeoutFailure( - "Exceeded max option steps.", - info={"last_failed_option": last_option}) - - if last_state is not None and \ - raise_error_on_repeated_state and state.allclose(last_state): - raise OptionTimeoutFailure( - "Encountered repeated state.", - info={"last_failed_option": last_option}) - # logging for debugging - # if last_state is not None: - # cur_atoms = abstract_function(state) - # prev_atoms = abstract_function(last_state) - # logging.debug(f"Prev atoms: {sorted(prev_atoms)}") - # logging.info(f"Add atoms: {sorted(cur_atoms-prev_atoms)} " - # f"Del atoms: {sorted(prev_atoms-cur_atoms)}") - - # whether the noop option should terminate - wait_terminate = False - if CFG.wait_option_terminate_on_atom_change \ - and cur_option.name == "Wait": - assert abstract_function is not None - assert last_state is not None - target_atoms = cur_option.memory.get("wait_target_atoms") - result = check_wait_target_atoms(cur_option, state, - abstract_function) - if result is True: - cur_atoms = abstract_function(state) - logging.debug("Wait terminating: target atoms satisfied. " - f"Targets: {target_atoms}, " - f"cur_atoms: {sorted(cur_atoms)}, " - f"num_option_steps={num_cur_option_steps}") - wait_terminate = True - elif result is False: - assert target_atoms is not None - if num_cur_option_steps <= 1 or num_cur_option_steps % 25 == 0: - wait_debug = _format_wait_target_debug( - state, target_atoms, abstract_function) - logging.debug( - "Wait continuing: target atoms not yet satisfied. " - "%s, num_option_steps=%d", wait_debug, - num_cur_option_steps) - elif result is None: - # No targets specified: fall back to any-atom-change - cur_atoms = abstract_function(state) - prev_atoms = abstract_function(last_state) - if cur_atoms != prev_atoms: - logging.debug(f"Wait terminating due to atom change: " - f"Add: {sorted(cur_atoms-prev_atoms)} " - f"Del: {sorted(prev_atoms-cur_atoms)}") - wait_terminate = True - - last_state = state - - option_terminal = cur_option is not DummyOption and \ - cur_option.terminal(state) - if wait_terminate or cur_option is DummyOption or option_terminal: - if cur_option is not DummyOption: - if wait_terminate: - reason = "atom change during Wait" - elif option_terminal: - reason = "option self-terminated" - else: - reason = "unknown" - logging.info(f"[{cur_option.name}] Terminated: {reason} " - f"(after {num_cur_option_steps} steps)\n") - try: - cur_option = option_policy(state) - except OptionExecutionFailure as e: - e.info["last_failed_option"] = last_option - raise e - if not cur_option.initiable(state): - raise OptionExecutionFailure( - "Unsound option policy.", - info={"last_failed_option": last_option}) - logging.debug(f"[option_policy] Started option {cur_option.name}, " - f"initiable=True") - num_cur_option_steps = 0 - - num_cur_option_steps += 1 - - return cur_option.policy(state) - - return _policy - - -def option_plan_to_policy( - plan: Sequence[_Option], - max_option_steps: Optional[int] = None, - raise_error_on_repeated_state: bool = False, - abstract_function: Optional[Callable[[State], Set[GroundAtom]]] = None -) -> Callable[[State], Action]: - """Create a policy that executes a sequence of options in order.""" - queue = list(plan) # don't modify plan, just in case - total_options = len(queue) - - def _option_policy(state: State) -> _Option: - del state # not used - if not queue: - logging.info("Option plan exhausted after %d options.", - total_options) - raise OptionExecutionFailure("Option plan exhausted!") - option = queue.pop(0) - option_num = total_options - len(queue) - next_option = None if not queue else queue[0].simple_str() - logging.info("Executing option %d/%d: %s (remaining=%d, next=%s)", - option_num, total_options, option.simple_str(), - len(queue), next_option) - return option - - return option_policy_to_policy( - _option_policy, - max_option_steps=max_option_steps, - raise_error_on_repeated_state=raise_error_on_repeated_state, - abstract_function=abstract_function) - - -def nsrt_plan_to_greedy_option_policy( - nsrt_plan: Sequence[_GroundNSRT], - goal: Set[GroundAtom], - rng: np.random.Generator, - necessary_atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None -) -> Callable[[State], _Option]: - """Greedily execute an NSRT plan, assuming downward refinability and that - any sample will work. - - If an option is not initiable or if the plan runs out, an - OptionExecutionFailure is raised. - """ - cur_nsrt: Optional[_GroundNSRT] = None - nsrt_queue = list(nsrt_plan) - if necessary_atoms_seq is None: - empty_atoms: Set[GroundAtom] = set() - necessary_atoms_seq = [empty_atoms for _ in range(len(nsrt_plan) + 1)] - assert len(necessary_atoms_seq) == len(nsrt_plan) + 1 - necessary_atoms_queue = list(necessary_atoms_seq) - - def _option_policy(state: State) -> _Option: - nonlocal cur_nsrt - if not nsrt_queue: - raise OptionExecutionFailure("NSRT plan exhausted.") - expected_atoms = necessary_atoms_queue.pop(0) - if not all(a.holds(state) for a in expected_atoms): - raise OptionExecutionFailure( - "Executing the NSRT failed to achieve the necessary atoms.") - cur_nsrt = nsrt_queue.pop(0) - cur_option = cur_nsrt.sample_option(state, goal, rng) - logging.debug(f"Using option {cur_option.name}{cur_option.objects}" - f"{cur_option.params} from NSRT plan.") - return cur_option - - return _option_policy - - -def nsrt_plan_to_greedy_policy( - nsrt_plan: Sequence[_GroundNSRT], - goal: Set[GroundAtom], - rng: np.random.Generator, - necessary_atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, - abstract_function: Optional[Callable[[State], Set[GroundAtom]]] = None -) -> Callable[[State], Action]: - """Greedily execute an NSRT plan, assuming downward refinability and that - any sample will work. - - If an option is not initiable or if the plan runs out, an - OptionExecutionFailure is raised. - """ - option_policy = nsrt_plan_to_greedy_option_policy( - nsrt_plan, goal, rng, necessary_atoms_seq=necessary_atoms_seq) - return option_policy_to_policy(option_policy, - abstract_function=abstract_function) - - -def process_plan_to_greedy_option_policy( - process_plan: Sequence[_GroundEndogenousProcess], - goal: Set[GroundAtom], - rng: np.random.Generator, - necessary_atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, - atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, -) -> Callable[[State], _Option]: - """Greedily execute a process plan, assuming downward refinability and that - any sample will work. - - If an option is not initiable or if the plan runs out, an - OptionExecutionFailure is raised. - """ - cur_process: Optional[_GroundEndogenousProcess] = None - process_queue = list(process_plan) - if necessary_atoms_seq is None: - empty_atoms: Set[GroundAtom] = set() - necessary_atoms_seq = [ - empty_atoms for _ in range(len(process_plan) + 1) - ] - assert len(necessary_atoms_seq) == len(process_plan) + 1 - necessary_atoms_queue = list(necessary_atoms_seq) - step_idx = 0 - - def _option_policy(state: State) -> _Option: - nonlocal cur_process, step_idx - if not process_queue: - raise OptionExecutionFailure("Process plan exhausted.") - expected_atoms = necessary_atoms_queue.pop(0) - if not all(a.holds(state) for a in expected_atoms): - raise OptionExecutionFailure( - "Executing the process failed to achieve the necessary atoms.") - cur_process = process_queue.pop(0) - cur_option = cur_process.sample_option(state, goal, rng) - if atoms_seq is not None: - inject_wait_targets_for_option(cur_option, step_idx, atoms_seq) - step_idx += 1 - logging.debug(f"Using option {cur_option.name}{cur_option.objects}" - f"{cur_option.params} from process plan.") - return cur_option - - return _option_policy - - -def process_plan_to_greedy_policy( - process_plan: Sequence[_GroundEndogenousProcess], - goal: Set[GroundAtom], - rng: np.random.Generator, - necessary_atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, - abstract_function: Optional[Callable[[State], Set[GroundAtom]]] = None, - atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, -) -> Callable[[State], Action]: - """Convert a process plan to a greedy policy.""" - option_policy = process_plan_to_greedy_option_policy( - process_plan, - goal, - rng, - necessary_atoms_seq=necessary_atoms_seq, - atoms_seq=atoms_seq) - return option_policy_to_policy(option_policy, - abstract_function=abstract_function) - - -def sample_applicable_option(param_options: List[ParameterizedOption], - state: State, - rng: np.random.Generator) -> Optional[_Option]: - """Sample an applicable option.""" - for _ in range(CFG.random_options_max_tries): - param_opt = param_options[rng.choice(len(param_options))] - objs = get_random_object_combination(list(state), param_opt.types, rng) - if objs is None: - continue - params = param_opt.params_space.sample() - opt = param_opt.ground(objs, params) - if opt.initiable(state): - return opt - return None - - -def create_random_option_policy( - options: Collection[ParameterizedOption], rng: np.random.Generator, - fallback_policy: Callable[[State], - Action]) -> Callable[[State], Action]: - """Create a policy that executes random initiable options. - - If no applicable option can be found, query the fallback policy. - """ - sorted_options = sorted(options, key=lambda o: o.name) - cur_option = DummyOption - - def _policy(state: State) -> Action: - nonlocal cur_option - if cur_option is DummyOption or cur_option.terminal(state): - cur_option = DummyOption - sample = sample_applicable_option(sorted_options, state, rng) - if sample is not None: - cur_option = sample - else: - return fallback_policy(state) - act = cur_option.policy(state) - return act - - return _policy - - -def sample_applicable_ground_nsrt( - state: State, ground_nsrts: Sequence[_GroundNSRT], - predicates: Set[Predicate], - rng: np.random.Generator) -> Optional[_GroundNSRT]: - """Choose uniformly among the ground NSRTs that are applicable in the - state.""" - atoms = abstract(state, predicates) - applicable_nsrts = sorted(get_applicable_operators(ground_nsrts, atoms)) - if len(applicable_nsrts) == 0: - return None - idx = rng.choice(len(applicable_nsrts)) - return applicable_nsrts[idx] # type: ignore[return-value] - - -def action_arrs_to_policy( - action_arrs: Sequence[Array]) -> Callable[[State], Action]: - """Create a policy that executes action arrays in sequence.""" - - queue = list(action_arrs) # don't modify original, just in case - - def _policy(s: State) -> Action: - del s # unused - return Action(queue.pop(0)) - - return _policy - - -def _get_entity_combinations( - entities: Collection[ObjectOrVariable], - types: Sequence[Type]) -> Iterator[List[ObjectOrVariable]]: - """Get all combinations of entities satisfying the given types sequence.""" - sorted_entities = sorted(entities) - choices = [] - for vt in types: - this_choices = [] - for ent in sorted_entities: - if ent.is_instance(vt): - this_choices.append(ent) - choices.append(this_choices) - for choice in itertools.product(*choices): - yield list(choice) - - -def get_object_combinations(objects: Collection[Object], - types: Sequence[Type]) -> Iterator[List[Object]]: - """Get all combinations of objects satisfying the given types sequence.""" - return _get_entity_combinations(objects, types) - - -def get_variable_combinations( - variables: Collection[Variable], - types: Sequence[Type]) -> Iterator[List[Variable]]: - """Get all combinations of variables satisfying the given types - sequence.""" - return _get_entity_combinations(variables, types) - - -def get_all_ground_atoms_for_predicate( - predicate: Predicate, objects: Collection[Object]) -> Set[GroundAtom]: - """Get all groundings of the predicate given objects. - - Note: we don't want lru_cache() on this function because we might want - to call it with stripped predicates, and we wouldn't want it to return - cached values. - """ - ground_atoms = set() - for args in get_object_combinations(objects, predicate.types): - ground_atom = GroundAtom(predicate, args) - ground_atoms.add(ground_atom) - return ground_atoms - - -def get_all_lifted_atoms_for_predicate( - predicate: Predicate, - variables: FrozenSet[Variable]) -> Set[LiftedAtom]: - """Get all groundings of the predicate given variables. - - Note: we don't want lru_cache() on this function because we might want - to call it with stripped predicates, and we wouldn't want it to return - cached values. - """ - lifted_atoms = set() - for args in get_variable_combinations(variables, predicate.types): - lifted_atom = LiftedAtom(predicate, args) - lifted_atoms.add(lifted_atom) - return lifted_atoms - - -def get_random_object_combination( - objects: Collection[Object], types: Sequence[Type], - rng: np.random.Generator) -> Optional[List[Object]]: - """Get a random list of objects from the given collection that satisfy the - given sequence of types. - - Duplicates are always allowed. If a particular type has no object, - return None. - """ - types_to_objs = defaultdict(list) - for obj in objects: - types_to_objs[obj.type].append(obj) - result = [] - for t in types: - t_objs = types_to_objs[t] - if not t_objs: - return None - result.append(t_objs[rng.choice(len(t_objs))]) - return result - - -def find_substitution( - super_atoms: Collection[LiftedOrGroundAtom], - sub_atoms: Collection[LiftedOrGroundAtom], - allow_redundant: bool = False, -) -> Tuple[bool, EntToEntSub]: - """Find a substitution from the entities in super_atoms to the entities in - sub_atoms s.t. sub_atoms is a subset of super_atoms. - - If allow_redundant is True, then multiple entities in sub_atoms can - refer to the same single entity in super_atoms. - - If no substitution exists, return (False, {}). - """ - super_entities_by_type: Dict[Type, List[_TypedEntity]] = defaultdict(list) - super_pred_to_tuples = defaultdict(set) - for atom in super_atoms: - for obj in atom.entities: - if obj not in super_entities_by_type[obj.type]: - super_entities_by_type[obj.type].append(obj) - super_pred_to_tuples[atom.predicate].add(tuple(atom.entities)) - sub_variables = sorted({e for atom in sub_atoms for e in atom.entities}) - return _find_substitution_helper(sub_atoms, super_entities_by_type, - sub_variables, super_pred_to_tuples, {}, - allow_redundant) - - -def _find_substitution_helper( - sub_atoms: Collection[LiftedOrGroundAtom], - super_entities_by_type: Dict[Type, List[_TypedEntity]], - remaining_sub_variables: List[_TypedEntity], - super_pred_to_tuples: Dict[Predicate, - Set[Tuple[_TypedEntity, - ...]]], partial_sub: EntToEntSub, - allow_redundant: bool) -> Tuple[bool, EntToEntSub]: - """Helper for find_substitution.""" - # Base case: check if all assigned - if not remaining_sub_variables: - return True, partial_sub - # Find next variable to assign - remaining_sub_variables = remaining_sub_variables.copy() - next_sub_var = remaining_sub_variables.pop(0) - # Consider possible assignments - for super_obj in super_entities_by_type[next_sub_var.type]: - if not allow_redundant and super_obj in partial_sub.values(): - continue - new_sub = partial_sub.copy() - new_sub[next_sub_var] = super_obj - # Check if consistent - if not _substitution_consistent(new_sub, super_pred_to_tuples, - sub_atoms): - continue - # Backtracking search - solved, final_sub = _find_substitution_helper(sub_atoms, - super_entities_by_type, - remaining_sub_variables, - super_pred_to_tuples, - new_sub, allow_redundant) - if solved: - return solved, final_sub - # Failure - return False, {} - - -def _substitution_consistent( - partial_sub: EntToEntSub, - super_pred_to_tuples: Dict[Predicate, Set[Tuple[_TypedEntity, ...]]], - sub_atoms: Collection[LiftedOrGroundAtom]) -> bool: - """Helper for _find_substitution_helper.""" - for sub_atom in sub_atoms: - if not set(sub_atom.entities).issubset(partial_sub.keys()): - continue - substituted_vars = tuple(partial_sub[e] for e in sub_atom.entities) - if substituted_vars not in super_pred_to_tuples[sub_atom.predicate]: - return False - return True - - -def create_new_variables( - types: Sequence[Type], - existing_vars: Optional[Collection[Variable]] = None, - var_prefix: str = "?x", -) -> List[Variable]: - """Create new variables of the given types, avoiding name collisions with - existing variables. - - By convention, all new variables are of the form - . - """ - pre_len = len(var_prefix) - existing_var_nums = set() - if existing_vars: - for v in existing_vars: - if v.name.startswith(var_prefix) and v.name[pre_len:].isdigit(): - existing_var_nums.add(int(v.name[pre_len:])) - if existing_var_nums: - counter = itertools.count(max(existing_var_nums) + 1) - else: - counter = itertools.count(0) - new_vars = [] - for t in types: - new_var_name = f"{var_prefix}{next(counter)}" - new_var = Variable(new_var_name, t) - new_vars.append(new_var) - return new_vars - - -def param_option_to_nsrt(param_option: ParameterizedOption, - nsrts: Set[NSRT]) -> NSRT: - """If options and NSRTs are 1:1, then map an option to an NSRT.""" - nsrt_matches = [n for n in nsrts if n.option == param_option] - assert len(nsrt_matches) == 1 - nsrt = nsrt_matches[0] - return nsrt - - -def option_to_ground_nsrt(option: _Option, nsrts: Set[NSRT]) -> _GroundNSRT: - """If options and NSRTs are 1:1, then map an option to an NSRT.""" - nsrt = param_option_to_nsrt(option.parent, nsrts) - return nsrt.ground(option.objects) - - -_S = TypeVar("_S", bound=Hashable) # state in heuristic search -_A = TypeVar("_A") # action in heuristic search - - -@dataclass(frozen=True) -class _HeuristicSearchNode(Generic[_S, _A]): - state: _S - edge_cost: float - cumulative_cost: float - parent: Optional[_HeuristicSearchNode[_S, _A]] = None - action: Optional[_A] = None - - -def _run_heuristic_search( - initial_state: _S, - check_goal: Callable[[_S], bool], - get_successors: Callable[[_S], Iterator[Tuple[_A, _S, float]]], - get_priority: Callable[[_HeuristicSearchNode[_S, _A]], Any], - max_expansions: int = 10000000, - max_evals: int = 10000000, - timeout: int = 10000000, - lazy_expansion: bool = False) -> Tuple[List[_S], List[_A]]: - """A generic heuristic search implementation. - - Depending on get_priority, can implement A*, GBFS, or UCS. - - If no goal is found, returns the state with the best priority. - """ - queue: List[Tuple[Any, int, _HeuristicSearchNode[_S, _A]]] = [] - state_to_best_path_cost: Dict[_S, float] = \ - defaultdict(lambda: float("inf")) - - root_node: _HeuristicSearchNode[_S, _A] = _HeuristicSearchNode( - initial_state, 0, 0) - root_priority = get_priority(root_node) - best_node = root_node - best_node_priority = root_priority - tiebreak = itertools.count() - hq.heappush(queue, (root_priority, next(tiebreak), root_node)) - num_expansions = 0 - num_evals = 1 - start_time = time.perf_counter() - - while len(queue) > 0 and time.perf_counter() - start_time < timeout and \ - num_expansions < max_expansions and num_evals < max_evals: - _, _, node = hq.heappop(queue) - # If we already found a better path here, don't bother. - if state_to_best_path_cost[node.state] < node.cumulative_cost: - continue - # If the goal holds, return. - if check_goal(node.state): - return _finish_plan(node) - num_expansions += 1 - # Generate successors. - for action, child_state, cost in get_successors(node.state): - if time.perf_counter() - start_time >= timeout: - break - child_path_cost = node.cumulative_cost + cost - # If we already found a better path to this child, don't bother. - if state_to_best_path_cost[child_state] <= child_path_cost: - continue - # Add new node. - child_node = _HeuristicSearchNode(state=child_state, - edge_cost=cost, - cumulative_cost=child_path_cost, - parent=node, - action=action) - priority = get_priority(child_node) - num_evals += 1 - hq.heappush(queue, (priority, next(tiebreak), child_node)) - state_to_best_path_cost[child_state] = child_path_cost - if priority < best_node_priority: - best_node_priority = priority - best_node = child_node - # Optimization: if we've found a better child, immediately - # explore the child without expanding the rest of the children. - # Accomplish this by putting the parent node back on the queue. - if lazy_expansion: - hq.heappush(queue, (priority, next(tiebreak), node)) - break - if num_evals >= max_evals: - break - - # Did not find path to goal; return best path seen. - return _finish_plan(best_node) - - -def _finish_plan( - node: _HeuristicSearchNode[_S, _A]) -> Tuple[List[_S], List[_A]]: - """Helper for _run_heuristic_search and run_hill_climbing.""" - rev_state_sequence: List[_S] = [] - rev_action_sequence: List[_A] = [] - - while node.parent is not None: - action = cast(_A, node.action) - rev_action_sequence.append(action) - rev_state_sequence.append(node.state) - node = node.parent - rev_state_sequence.append(node.state) - - return rev_state_sequence[::-1], rev_action_sequence[::-1] - - -def run_gbfs(initial_state: _S, - check_goal: Callable[[_S], bool], - get_successors: Callable[[_S], Iterator[Tuple[_A, _S, float]]], - heuristic: Callable[[_S], float], - max_expansions: int = 10000000, - max_evals: int = 10000000, - timeout: int = 10000000, - lazy_expansion: bool = False) -> Tuple[List[_S], List[_A]]: - """Greedy best-first search.""" - get_priority = lambda n: heuristic(n.state) - return _run_heuristic_search(initial_state, check_goal, get_successors, - get_priority, max_expansions, max_evals, - timeout, lazy_expansion) - - -def run_astar(initial_state: _S, - check_goal: Callable[[_S], bool], - get_successors: Callable[[_S], Iterator[Tuple[_A, _S, float]]], - heuristic: Callable[[_S], float], - max_expansions: int = 10000000, - max_evals: int = 10000000, - timeout: int = 10000000, - lazy_expansion: bool = False) -> Tuple[List[_S], List[_A]]: - """A* search.""" - get_priority = lambda n: heuristic(n.state) + n.cumulative_cost - return _run_heuristic_search(initial_state, check_goal, get_successors, - get_priority, max_expansions, max_evals, - timeout, lazy_expansion) - - -def run_hill_climbing( - initial_state: _S, - check_goal: Callable[[_S], bool], - get_successors: Callable[[_S], Iterator[Tuple[_A, _S, float]]], - heuristic: Callable[[_S], float], - early_termination_heuristic_thresh: Optional[float] = None, - enforced_depth: int = 0, - exhaustive_lookahead: bool = False, - parallelize: bool = False, - verbose: bool = True, - timeout: float = float('inf') -) -> Tuple[List[_S], List[_A], List[float]]: - """Enforced hill climbing local search. - - For each node, this search looks for an improvement up to `enforced_depth`. - If `exhaustive_lookahead` is False (default), for each node, the best child - node is always selected, if that child is - an improvement over the node. If no children improve on the node, look - at the children's children, etc., up to enforced_depth, where enforced_depth - 0 corresponds to simple hill climbing. Terminate when no improvement can - be found. early_termination_heuristic_thresh allows for searching until - heuristic reaches a specified value. - Let b be the branching factor, d be the enforced_depth, this has time - complxity of O(b^{d+1}). - If True, it searches the entire horizon up to the - enforced depth and picks the best overall improvement. - - Lower heuristic is better. - """ - assert enforced_depth >= 0 - cur_node: _HeuristicSearchNode[_S, _A] = _HeuristicSearchNode( - initial_state, 0, 0) - last_heuristic = heuristic(cur_node.state) - heuristics = [last_heuristic] - # visited = {initial_state} # <--- deleted for exhaustive_lookahead - if verbose: - logging.info(f"\n\nStarting hill climbing at state {cur_node.state} " - f"with heuristic {last_heuristic}") - start_time = time.perf_counter() - while True: - visited = {cur_node.state} # <--- added for exhaustive_lookahead - - # Stops when heuristic reaches specified value. - if early_termination_heuristic_thresh is not None \ - and last_heuristic <= early_termination_heuristic_thresh: - break - - if check_goal(cur_node.state): - if verbose: - logging.info("\nTerminating hill climbing, achieved goal") - break - best_heuristic = float("inf") - best_child_node = None - current_depth_nodes = [cur_node] - all_best_heuristics = [] - for depth in range(0, enforced_depth + 1): - if verbose: - logging.info(f"Searching for an improvement at depth {depth}") - # This is a list to ensure determinism. Note that duplicates are - # filtered out in the `child_state in visited` check. - successors_at_depth = [] - for parent in current_depth_nodes: - for action, child_state, cost in get_successors(parent.state): - # Raise error if timeout gets hit. - if time.perf_counter() - start_time > timeout: - raise TimeoutError() - if child_state in visited: - continue - visited.add(child_state) - child_path_cost = parent.cumulative_cost + cost - child_node = _HeuristicSearchNode( - state=child_state, - edge_cost=cost, - cumulative_cost=child_path_cost, - parent=parent, - action=action) - successors_at_depth.append(child_node) - if parallelize: - continue # heuristic computation is parallelized later - child_heuristic = heuristic(child_node.state) - if child_heuristic < best_heuristic: - best_heuristic = child_heuristic - best_child_node = child_node - if parallelize: - # Parallelize the expensive part (heuristic computation). - num_cpus = mp.cpu_count() - fn = lambda n: (heuristic(n.state), n) - with mp.Pool(processes=num_cpus) as p: - for child_heuristic, child_node in p.map( - fn, successors_at_depth): - if child_heuristic < best_heuristic: - best_heuristic = child_heuristic - best_child_node = child_node - all_best_heuristics.append(best_heuristic) - - if not exhaustive_lookahead and last_heuristic > best_heuristic: - # Some improvement found. - if verbose: - logging.info(f"Found an improvement at depth {depth}") - break - # Continue on to the next depth. - current_depth_nodes = successors_at_depth - if not current_depth_nodes: - if verbose: - logging.info( - f"No more successors to explore at depth {depth}.") - break # No need to search deeper if there are no more nodes. - - if verbose: - if exhaustive_lookahead: - logging.info(f"Finished depth {depth}. " - f"Best heuristic so far: {best_heuristic}") - elif last_heuristic <= best_heuristic: - logging.info(f"No improvement found at depth {depth}") - - if best_child_node is None: - if verbose: - logging.info("\nTerminating hill climbing, no more successors") - break - if last_heuristic <= best_heuristic: - if verbose: - logging.info( - "\nTerminating hill climbing, could not improve score") - break - heuristics.extend(all_best_heuristics) - cur_node = best_child_node - last_heuristic = best_heuristic - if verbose: - logging.info(f"\nHill climbing reached new state {cur_node.state} " - f"with heuristic {last_heuristic}") - - states, actions = _finish_plan(cur_node) - # The number of heuristics might not match the plan length perfectly now, - # so we should regenerate them from the final plan. - final_heuristics = [heuristic(s) for s in states] - assert len(states) == len(final_heuristics) - return states, actions, final_heuristics - - -def run_policy_guided_astar( - initial_state: _S, - check_goal: Callable[[_S], bool], - get_valid_actions: Callable[[_S], Iterator[Tuple[_A, float]]], - get_next_state: Callable[[_S, _A], _S], - heuristic: Callable[[_S], float], - policy: Callable[[_S], Optional[_A]], - num_rollout_steps: int, - rollout_step_cost: float, - max_expansions: int = 10000000, - max_evals: int = 10000000, - timeout: int = 10000000, - lazy_expansion: bool = False) -> Tuple[List[_S], List[_A]]: - """Perform A* search, but at each node, roll out a given policy for a given - number of timesteps, creating new successors at each step. - - Stop the rollout prematurely if the policy returns None. - - Note that unlike the other search functions, which take get_successors as - input, this function takes get_valid_actions and get_next_state as two - separate inputs. This is necessary because we need to anticipate the next - state conditioned on the action output by the policy. - - The get_valid_actions generates (action, cost) tuples. For policy-generated - transitions, the costs are ignored, and rollout_step_cost is used instead. - """ - - # Create a new successor function that rolls out the policy first. - # A successor here means: from this state, if you take this sequence of - # actions in order, you'll end up at this final state. - def get_successors(state: _S) -> Iterator[Tuple[List[_A], _S, float]]: - # Get policy-based successors. - policy_state = state - policy_action_seq = [] - policy_cost = 0.0 - for _ in range(num_rollout_steps): - action = policy(policy_state) - valid_actions = {a for a, _ in get_valid_actions(policy_state)} - if action is None or action not in valid_actions: - break - policy_state = get_next_state(policy_state, action) - policy_action_seq.append(action) - policy_cost += rollout_step_cost - yield (list(policy_action_seq), policy_state, policy_cost) - - # Get primitive successors. - for action, cost in get_valid_actions(state): - next_state = get_next_state(state, action) - yield ([action], next_state, cost) - - _, action_subseqs = run_astar(initial_state=initial_state, - check_goal=check_goal, - get_successors=get_successors, - heuristic=heuristic, - max_expansions=max_expansions, - max_evals=max_evals, - timeout=timeout, - lazy_expansion=lazy_expansion) - - # The states are "jumpy", so we need to reconstruct the dense state - # sequence from the action subsequences. We also need to construct a - # flat action sequence. - state = initial_state - state_seq = [state] - action_seq = [] - for action_subseq in action_subseqs: - for action in action_subseq: - action_seq.append(action) - state = get_next_state(state, action) - state_seq.append(state) - - return state_seq, action_seq - - -_RRTState = TypeVar("_RRTState") - - -class RRT(Generic[_RRTState]): - """Rapidly-exploring random tree.""" - - def __init__(self, sample_fn: Callable[[_RRTState], _RRTState], - extend_fn: Callable[[_RRTState, _RRTState], - Iterator[_RRTState]], - collision_fn: Callable[[_RRTState], bool], - distance_fn: Callable[[_RRTState, _RRTState], - float], rng: np.random.Generator, - num_attempts: int, num_iters: int, smooth_amt: int): - self._sample_fn = sample_fn - self._extend_fn = extend_fn - self._collision_fn = collision_fn - self._distance_fn = distance_fn - self._rng = rng - self._num_attempts = num_attempts - self._num_iters = num_iters - self._smooth_amt = smooth_amt - - def query(self, - pt1: _RRTState, - pt2: _RRTState, - sample_goal_eps: float = 0.0) -> Optional[List[_RRTState]]: - """Query the RRT, to get a collision-free path from pt1 to pt2. - - If none is found, returns None. - """ - if self._collision_fn(pt1) or self._collision_fn(pt2): - return None - direct_path = self._try_direct_path(pt1, pt2) - if direct_path is not None: - return direct_path - for _ in range(self._num_attempts): - path = self._rrt_connect(pt1, - goal_sampler=lambda: pt2, - sample_goal_eps=sample_goal_eps) - if path is not None: - return self._smooth_path(path) - return None - - def query_to_goal_fn( - self, - start: _RRTState, - goal_sampler: Callable[[], _RRTState], - goal_fn: Callable[[_RRTState], bool], - sample_goal_eps: float = 0.0) -> Optional[List[_RRTState]]: - """Query the RRT, to get a collision-free path from start to a point - such that goal_fn(point) is True. Uses goal_sampler to sample a target - for a direct path or with probability sample_goal_eps. - - If none is found, returns None. - """ - if self._collision_fn(start): - return None - direct_path = self._try_direct_path(start, goal_sampler()) - if direct_path is not None: - return direct_path - for _ in range(self._num_attempts): - path = self._rrt_connect(start, - goal_sampler, - goal_fn, - sample_goal_eps=sample_goal_eps) - if path is not None: - return self._smooth_path(path) - return None - - def _try_direct_path(self, pt1: _RRTState, - pt2: _RRTState) -> Optional[List[_RRTState]]: - path = [pt1] - for newpt in self._extend_fn(pt1, pt2): - if self._collision_fn(newpt): - return None - path.append(newpt) - return path - - def _rrt_connect( - self, - pt1: _RRTState, - goal_sampler: Callable[[], _RRTState], - goal_fn: Optional[Callable[[_RRTState], bool]] = None, - sample_goal_eps: float = 0.0, - ) -> Optional[List[_RRTState]]: - root = _RRTNode(pt1) - nodes = [root] - - for _ in range(self._num_iters): - # Sample the goal with a small probability, otherwise randomly - # choose a point. - sample_goal = self._rng.random() < sample_goal_eps - samp = goal_sampler() if sample_goal else self._sample_fn(pt1) - min_key = functools.partial(self._get_pt_dist_to_node, samp) - nearest = min(nodes, key=min_key) - reached_goal = False - for newpt in self._extend_fn(nearest.data, samp): - if self._collision_fn(newpt): - break - nearest = _RRTNode(newpt, parent=nearest) - nodes.append(nearest) - else: - reached_goal = sample_goal - # Check goal_fn if defined - if reached_goal or goal_fn is not None and goal_fn(nearest.data): - path = nearest.path_from_root() - return [node.data for node in path] - return None - - def _get_pt_dist_to_node(self, pt: _RRTState, - node: _RRTNode[_RRTState]) -> float: - return self._distance_fn(pt, node.data) - - def _smooth_path(self, path: List[_RRTState]) -> List[_RRTState]: - assert len(path) > 2 - for _ in range(self._smooth_amt): - i = self._rng.integers(0, len(path) - 1) - j = self._rng.integers(0, len(path) - 1) - if abs(i - j) <= 1: - continue - if j < i: - i, j = j, i - shortcut = list(self._extend_fn(path[i], path[j])) - if len(shortcut) < j - i and \ - all(not self._collision_fn(pt) for pt in shortcut): - path = path[:i + 1] + shortcut + path[j + 1:] - return path - - -class BiRRT(RRT[_RRTState]): - """Bidirectional rapidly-exploring random tree.""" - - def query_to_goal_fn( - self, - start: _RRTState, - goal_sampler: Callable[[], _RRTState], - goal_fn: Callable[[_RRTState], bool], - sample_goal_eps: float = 0.0) -> Optional[List[_RRTState]]: - raise NotImplementedError("Can't query to goal function using BiRRT") - - def _rrt_connect( - self, - pt1: _RRTState, - goal_sampler: Callable[[], _RRTState], - goal_fn: Optional[Callable[[_RRTState], bool]] = None, - sample_goal_eps: float = 0.0, - ) -> Optional[List[_RRTState]]: - # goal_fn and sample_goal_eps are unused - pt2 = goal_sampler() - root1, root2 = _RRTNode(pt1), _RRTNode(pt2) - nodes1, nodes2 = [root1], [root2] - - for _ in range(self._num_iters): - if len(nodes1) > len(nodes2): - nodes1, nodes2 = nodes2, nodes1 - samp = self._sample_fn(pt1) - min_key1 = functools.partial(self._get_pt_dist_to_node, samp) - nearest1 = min(nodes1, key=min_key1) - for newpt in self._extend_fn(nearest1.data, samp): - if self._collision_fn(newpt): - break - nearest1 = _RRTNode(newpt, parent=nearest1) - nodes1.append(nearest1) - min_key2 = functools.partial(self._get_pt_dist_to_node, - nearest1.data) - nearest2 = min(nodes2, key=min_key2) - for newpt in self._extend_fn(nearest2.data, nearest1.data): - if self._collision_fn(newpt): - break - nearest2 = _RRTNode(newpt, parent=nearest2) - nodes2.append(nearest2) - else: - path1 = nearest1.path_from_root() - path2 = nearest2.path_from_root() - # This is a tricky case to cover. - if path1[0] != root1: # pragma: no cover - path1, path2 = path2, path1 - assert path1[0] == root1 - path = path1[:-1] + path2[::-1] - return [node.data for node in path] - return None - - -class _RRTNode(Generic[_RRTState]): - """A node for RRT.""" - - def __init__(self, - data: _RRTState, - parent: Optional[_RRTNode[_RRTState]] = None) -> None: - self.data = data - self.parent = parent - - def path_from_root(self) -> List[_RRTNode[_RRTState]]: - """Return the path from the root to this node.""" - sequence = [] - node: Optional[_RRTNode[_RRTState]] = self - while node is not None: - sequence.append(node) - node = node.parent - return sequence[::-1] - - -def strip_predicate(predicate: Predicate) -> Predicate: - """Remove the classifier from the given predicate to make a new Predicate. - - Implement this by replacing the classifier with one that errors. - """ - - def _stripped_classifier(state: State, objects: Sequence[Object]) -> bool: - raise Exception("Stripped classifier should never be called!") - - return Predicate(predicate.name, predicate.types, _stripped_classifier) - - -def strip_task(task: Task, included_predicates: Set[Predicate]) -> Task: - """Create a new task where any excluded goal predicates have their - classifiers removed.""" - stripped_goal: Set[GroundAtom] = set() - for atom in task.goal: - if atom.predicate in included_predicates: - stripped_goal.add(atom) - continue - stripped_pred = strip_predicate(atom.predicate) - stripped_atom = GroundAtom(stripped_pred, atom.objects) - stripped_goal.add(stripped_atom) - return Task(task.init, - stripped_goal, - alt_goal=task.alt_goal, - goal_nl=task.goal_nl) - - -def create_vlm_predicate( - name: str, types: Sequence[Type], - get_vlm_query_str: Callable[[Sequence[Object]], str]) -> VLMPredicate: - """Simple function that creates VLMPredicates with dummy classifiers, which - is the most-common way these need to be created.""" - - def _stripped_classifier( - state: State, - objects: Sequence[Object]) -> bool: # pragma: no cover. - raise Exception("VLM predicate classifier should never be called!") - - return VLMPredicate(name, types, _stripped_classifier, - get_vlm_query_str) # type: ignore[arg-type] - - -def create_llm_by_name( - model_name: str) -> LargeLanguageModel: # pragma: no cover - """Create particular llm using a provided name.""" - if CFG.pretrained_model_service_provider == "openai": - return OpenAILLM(model_name) - if CFG.pretrained_model_service_provider == "google": - return GoogleGeminiLLM(model_name) - if CFG.pretrained_model_service_provider == "openrouter": - return OpenRouterLLM(model_name) - raise ValueError(f"Unknown pretrained model service provider: " - f"{CFG.pretrained_model_service_provider}") - - -def create_vlm_by_name( - model_name: str) -> VisionLanguageModel: # pragma: no cover - """Create particular vlm using a provided name.""" - if CFG.pretrained_model_service_provider == "openai": - return OpenAIVLM(model_name) - if CFG.pretrained_model_service_provider == "google": - return GoogleGeminiVLM(model_name) - if CFG.pretrained_model_service_provider == "openrouter": - return OpenRouterVLM(model_name) - raise ValueError(f"Unknown pretrained model service provider: " - f"{CFG.pretrained_model_service_provider}") - - -def parse_model_output_into_option_plan( - model_prediction: str, objects: Collection[Object], - types: Collection[Type], options: Collection[ParameterizedOption], - parse_continuous_params: bool -) -> List[Tuple[ParameterizedOption, Sequence[Object], Sequence[float]]]: - """Assuming text for an option plan that is predicted as text by a large - model, parse it into a sequence of ParameterizedOptions coupled with a list - of objects and continuous parameters that will be used to ground the - ParameterizedOption. - - We assume the model's output is such that each line is formatted as - option_name(obj0:type0, obj1:type1,...)[continuous_param0, - continuous_param1, ...]. - """ - option_plan: List[Tuple[ParameterizedOption, Sequence[Object], - Sequence[float]]] = [] - # Setup dictionaries enabling us to easily map names to specific - # Python objects during parsing. - option_name_to_option = {op.name: op for op in options} - type_name_to_type = {typ.name: typ for typ in types} - obj_name_to_obj = {o.name: o for o in objects} - options_str_list = model_prediction.split('\n') - for option_str in options_str_list: - option_str_stripped = option_str.strip() - option_name = option_str_stripped.split('(')[0] - # Skip empty option strs. - if not option_str: - continue - if option_name not in option_name_to_option.keys() or \ - "(" not in option_str: - if option_plan: - # Already found some options; stop on first non-option line. - logging.info( - f"Line {option_str} output by model doesn't " - "contain a valid option name. Terminating option plan " - "parsing.") - break - # Skip preamble lines (analysis text before the plan starts). - continue - if parse_continuous_params and "[" not in option_str: - logging.info( - f"Line {option_str} output by model doesn't contain a " - "'[' and is thus improperly formatted.") - break - option = option_name_to_option[option_name] - # Now that we have the option, we need to parse out the objects - # along with specified types. - try: - start_index = option_str_stripped.index('(') + 1 - end_index = option_str_stripped.index(')', start_index) - except ValueError: - logging.info( - f"Line {option_str} output by model is improperly formatted.") - break - typed_objects_str_list = option_str_stripped[ - start_index:end_index].split(',') - objs_list = [] - continuous_params_list = [] - malformed = False - for i, type_object_string in enumerate(typed_objects_str_list): - object_type_str_list = type_object_string.strip().split(':') - # We expect this list to be [object_name, type_name]. - if len(object_type_str_list) != 2: - logging.info(f"Line {option_str} output by model has a " - "malformed object-type list.") - malformed = True - break - object_name = object_type_str_list[0] - type_name = object_type_str_list[1] - if object_name not in obj_name_to_obj.keys(): - logging.info(f"Line {option_str} output by model has an " - "invalid object name.") - malformed = True - break - obj = obj_name_to_obj[object_name] - # Check that the type of this object agrees - # with what's expected given the ParameterizedOption. - if type_name not in type_name_to_type: - logging.info(f"Line {option_str} output by model has an " - "invalid type name.") - malformed = True - break - try: - if option.types[i] not in type_name_to_type[ - type_name].get_ancestors(): - logging.info( - f"Line {option_str} output by model has an " - "invalid type that doesn't agree with the option" - f"{option}") - malformed = True - break - except IndexError: - # In this case, there's more supplied arguments than the - # option has. - logging.info(f"Line {option_str} output by model has an " - "too many object arguments for option" - f"{option}") - malformed = True - break - objs_list.append(obj) - # The types of the objects match, but we haven't yet checked if - # all arguments of the option have an associated object. - if len(objs_list) != len(option.types): - malformed = True - # Now, we attempt to parse out the continuous parameters. - if parse_continuous_params: - params_str_list = option_str_stripped.split('[')[1].strip( - ']').split(',') - for i, continuous_params_str in enumerate(params_str_list): - stripped_continuous_param_str = continuous_params_str.strip() - if len(stripped_continuous_param_str) == 0: - continue - try: - curr_cont_param = float(stripped_continuous_param_str) - except ValueError: - logging.info(f"Line {option_str} output by model has an " - "invalid continouous parameter that can't be" - "converted to a float.") - malformed = True - break - continuous_params_list.append(curr_cont_param) - if len(continuous_params_list) != option.params_space.shape[0]: - logging.info(f"Line {option_str} output by model has " - "invalid continouous parameter(s) that don't " - f"agree with {option}{option.params_space}.") - malformed = True - break - if not malformed: - option_plan.append((option, objs_list, continuous_params_list)) - return option_plan - - -def get_prompt_for_vlm_state_labelling( - prompt_type: str, atoms_list: List[str], label_history: List[str], - imgs_history: List[List[PIL.Image.Image]], - cropped_imgs_history: List[List[PIL.Image.Image]], - skill_history: List[_Option]) -> Tuple[str, List[PIL.Image.Image]]: - """Prompt for labelling atom values in a trajectory. - - Note that all our prompts are saved as separate txt files under the - 'vlm_input_data_prompts/atom_labelling' folder. - """ - # Load the pre-specified prompt. - filepath_prefix = get_path_to_predicators_root() + \ - "/predicators/datasets/vlm_input_data_prompts/atom_labelling/" - try: - with open(filepath_prefix + prompt_type + ".txt", - "r", - encoding="utf-8") as f: - prompt = f.read() - except FileNotFoundError: - raise ValueError("Unknown VLM prompting option " + f"{prompt_type}") - # The prompt ends with a section for 'Predicates', so list these. - for atom_str in atoms_list: - prompt += f"\n{atom_str}" - - if "img_option_diffs" in prompt_type: - # In this case, we need to load the 'per_scene_naive' prompt as well - # for the first timestep. - with open(filepath_prefix + "per_scene_naive.txt", - "r", - encoding="utf-8") as f: - init_prompt = f.read() - for atom_str in atoms_list: - init_prompt += f"\n{atom_str}" - if len(label_history) == 0: - return (init_prompt, imgs_history[0]) - # Now, we use actual difference-based prompting for the second timestep - # and beyond. - curr_prompt = prompt[:] - curr_prompt_imgs = [imgs_history[-2][0], imgs_history[-1][0]] - if CFG.vlm_include_cropped_images: - if CFG.env in ["burger", "burger_no_move"]: # pragma: no cover - curr_prompt_imgs.extend( - [cropped_imgs_history[-1][1], cropped_imgs_history[-1][0]]) - else: - raise NotImplementedError( - f"Cropped images not implemented for {CFG.env}.") - curr_prompt += "\n\nSkill executed between states: " - skill_name = skill_history[-1].name + str(skill_history[-1].objects) - curr_prompt += skill_name - if "label_history" in prompt_type: - curr_prompt += "\n\nPredicate values in the first scene, " \ - "before the skill was executed: \n" - curr_prompt += label_history[-1] - return (curr_prompt, curr_prompt_imgs) - # NOTE: we rip out only the first image from each trajectory - # which is fine for most domains, but will be problematic for - # situations in which there is more than one image per state. - return (prompt, imgs_history[-1]) - - -def query_vlm_for_atom_vals( - vlm_atoms: Collection[GroundAtom], - state: State, - vlm: Optional[VisionLanguageModel] = None) -> Set[GroundAtom]: - """Given a set of ground atoms, queries a VLM and gets the subset of these - atoms that are true.""" - # Short-circuit this function in the case where there are no atoms that - # need be labelled. - if len(vlm_atoms) == 0: - return set() - true_atoms: Set[GroundAtom] = set() - # Get quantities necessary to construct prompt to query VLM. - if state.simulator_state is None: - return true_atoms - assert state.simulator_state is not None - assert isinstance(state.simulator_state["images"], List) - curr_state_imgs = state.simulator_state["images"] - vlm_atoms = sorted(vlm_atoms) - atom_queries_list = [atom.get_vlm_query_str() for atom in vlm_atoms] - prev_states_imgs_history = [] - prev_state_cropped_imgs_history: List[List[PIL.Image.Image]] = [] - if "state_history" in state.simulator_state: # pragma: no cover - prev_states = state.simulator_state["state_history"] - prev_states_imgs_history = [ - s.simulator_state["images"] for s in prev_states - ] - if "cropped_images" in prev_states[0].simulator_state: - prev_states_imgs_history = [ - s.simulator_state["cropped_images"] for s in prev_states - ] - images_history = prev_states_imgs_history + [curr_state_imgs] - skill_history = [] - if "skill_history" in state.simulator_state: # pragma: no cover - skill_history = state.simulator_state["skill_history"] - label_history = [] - if "vlm_label_history" in state.simulator_state: # pragma: no cover - label_history = state.simulator_state["vlm_label_history"] - vlm_query_str, imgs = get_prompt_for_vlm_state_labelling( - CFG.vlm_test_time_atom_label_prompt_type, atom_queries_list, - label_history, images_history, prev_state_cropped_imgs_history, - skill_history) - # Query VLM. - if vlm is None: - vlm = create_vlm_by_name(CFG.vlm_model_name) # pragma: no cover. - if CFG.env in ["pybullet_coffee"]: - vlm_input_imgs = list(imgs) # type: ignore - else: - vlm_input_imgs = \ - [PIL.Image.fromarray(img_arr) for img_arr in imgs] # type: ignore - vlm_output = vlm.sample_completions(vlm_query_str, - vlm_input_imgs, - 0.0, - seed=CFG.seed, - num_completions=1) - assert len(vlm_output) == 1 - vlm_output_str = vlm_output[0] - all_vlm_responses = vlm_output_str.strip().split("\n") - # NOTE: this assumption is likely too brittle; if this is breaking, feel - # free to remove/adjust this and change the below parsing loop accordingly! - if len(atom_queries_list) != len(all_vlm_responses): - return set() - for i, (atom_query, curr_vlm_output_line) in enumerate( - zip(atom_queries_list, all_vlm_responses)): - try: - assert atom_query + ":" in curr_vlm_output_line - assert "." in curr_vlm_output_line - value = curr_vlm_output_line.split(': ')[-1].strip('.').lower() - if value == "true": - true_atoms.add(vlm_atoms[i]) - except AssertionError: # pragma: no cover - continue - return true_atoms - - -def abstract(state: State, - preds: Collection[Predicate], - vlm: Optional[VisionLanguageModel] = None) -> Set[GroundAtom]: - """Get the atomic representation of the given state (i.e., a set of ground - atoms), using the given set of predicates. - - Duplicate arguments in predicates are allowed. Latent-aware - classifiers (`agent_sim_recurrent_predicate_invention`) read their - latent from `state.latent` via `Predicate.holds` — abstract itself - does nothing extra to support them. - """ - # Start by pulling out all VLM predicates. - vlm_preds = set(pred for pred in preds if isinstance(pred, VLMPredicate)) - derived_preds, primitive_preds = set(), set() - for pred in preds: - if isinstance(pred, DerivedPredicate): - derived_preds.add(pred) - else: - primitive_preds.add(pred) - - # Next, classify all non-VLM predicates. - atoms = set() - for pred in primitive_preds: - if pred not in vlm_preds: - for choice in get_object_combinations(list(state), pred.types): - if pred.holds(state, choice): - atoms.add(GroundAtom(pred, choice)) - if len(vlm_preds) > 0: - # Now, aggregate all the VLM predicates and make a single call to a - # VLM to get their values. - vlm_atoms = set() - for pred in vlm_preds: - for choice in get_object_combinations(list(state), pred.types): - vlm_atoms.add(GroundAtom(pred, choice)) - true_vlm_atoms = query_vlm_for_atom_vals(vlm_atoms, state, vlm) - atoms |= true_vlm_atoms - - # Evaluate derived predicates. - if len(derived_preds) > 0: - try: - atoms |= abstract_with_derived_predicates(atoms, derived_preds, - list(state)) - except PredicateEvaluationError as e: - raise e - # buggy_pred = e.pred - # # logging.debug(f"preds before {buggy_pred} is removed: {preds}") - # cnpt_preds.remove(buggy_pred) - # # logging.debug(f"preds after {buggy_pred} is removed: {preds}") - # return abstract(state, prim_preds | cnpt_preds, vlm, - # return_valid_preds) - return atoms - - -def all_ground_operators( - operator: STRIPSOperator, - objects: Collection[Object]) -> Iterator[_GroundSTRIPSOperator]: - """Get all possible groundings of the given operator with the given - objects.""" - types = [p.type for p in operator.parameters] - for choice in get_object_combinations(objects, types): - yield operator.ground(tuple(choice)) - - -def all_ground_operators_given_partial( - operator: STRIPSOperator, objects: Collection[Object], - sub: VarToObjSub) -> Iterator[_GroundSTRIPSOperator]: - """Get all possible groundings of the given operator with the given objects - such that the parameters are consistent with the given substitution.""" - assert set(sub).issubset(set(operator.parameters)) - types = [p.type for p in operator.parameters if p not in sub] - for choice in get_object_combinations(objects, types): - # Complete the choice with the args that are determined from the sub. - choice_lst = list(choice) - choice_lst.reverse() - completed_choice = [] - for p in operator.parameters: - if p in sub: - completed_choice.append(sub[p]) - else: - completed_choice.append(choice_lst.pop()) - assert not choice_lst - ground_op = operator.ground(tuple(completed_choice)) - yield ground_op - - -def all_ground_nsrts(nsrt: Union[NSRT, CausalProcess], - objects: Collection[Object]) -> Iterator[_GroundNSRT]: - """Get all possible groundings of the given NSRT with the given objects.""" - types = [p.type for p in nsrt.parameters] - for choice in get_object_combinations(objects, types): - # only return if there are no repeated arguments - if CFG.no_repeated_arguments_in_grounding: - if len(choice) == len(set(choice)): - yield nsrt.ground(tuple(choice)) # type: ignore[misc] - else: - yield nsrt.ground(tuple(choice)) # type: ignore[misc] - - -def all_ground_nsrts_fd_translator( - nsrts: Set[NSRT], objects: Collection[Object], - predicates: Set[Predicate], types: Set[Type], - init_atoms: Set[GroundAtom], - goal: Set[GroundAtom]) -> Iterator[_GroundNSRT]: - """Get all possible groundings of the given set of NSRTs with the given - objects, using Fast Downward's translator for efficiency.""" - nsrt_name_to_nsrt = {nsrt.name.lower(): nsrt for nsrt in nsrts} - obj_name_to_obj = {obj.name.lower(): obj for obj in objects} - dom_str = create_pddl_domain(nsrts, predicates, types, "mydomain") - prob_str = create_pddl_problem(objects, init_atoms, goal, "mydomain", - "myproblem") - with nostdout(): - sas_task = downward_translate(dom_str, prob_str) # type: ignore - for operator in sas_task.operators: - split_name = operator.name[1:-1].split() # strip out ( and ) - nsrt = nsrt_name_to_nsrt[split_name[0]] - objs = [obj_name_to_obj[name] for name in split_name[1:]] - yield nsrt.ground(objs) - - -def all_possible_ground_atoms(state: State, - preds: Set[Predicate]) -> List[GroundAtom]: - """Get a sorted list of all possible ground atoms in a state given the - predicates. - - Ignores the predicates' classifiers. - """ - objects = frozenset(state) - ground_atoms = set() - for pred in preds: - ground_atoms |= get_all_ground_atoms_for_predicate(pred, objects) - return sorted(ground_atoms) - - -def all_ground_ldl_rules( - rule: LDLRule, - objects: Collection[Object], - static_predicates: Optional[Collection[Predicate]] = None, - init_atoms: Optional[Collection[GroundAtom]] = None -) -> List[_GroundLDLRule]: - """Get all possible groundings of the given rule with the given objects. - - If provided, use the static predicates and init_atoms to avoid - grounding rules that will never have satisfied preconditions in any - state. - """ - if static_predicates is None: - static_predicates = set() - if init_atoms is None: - init_atoms = set() - return _cached_all_ground_ldl_rules(rule, frozenset(objects), - frozenset(static_predicates), - frozenset(init_atoms)) - - -@functools.lru_cache(maxsize=None) -def _cached_all_ground_ldl_rules( - rule: LDLRule, objects: FrozenSet[Object], - static_predicates: FrozenSet[Predicate], - init_atoms: FrozenSet[GroundAtom]) -> List[_GroundLDLRule]: - """Helper for all_ground_ldl_rules() that caches the outputs.""" - ground_rules = [] - # Use static preconds to reduce the map of parameters to possible objects. - # For example, if IsBall(?x) is a positive state precondition, then only - # the objects that appear in init_atoms with IsBall could bind to ?x. - # For now, we just check unary static predicates, since that covers the - # common case where such predicates are used in place of types. - # Create map from each param to unary static predicates. - param_to_pos_preds: Dict[Variable, Set[Predicate]] = { - p: set() - for p in rule.parameters - } - param_to_neg_preds: Dict[Variable, Set[Predicate]] = { - p: set() - for p in rule.parameters - } - for (preconditions, param_to_preds) in [ - (rule.pos_state_preconditions, param_to_pos_preds), - (rule.neg_state_preconditions, param_to_neg_preds), - ]: - for atom in preconditions: - pred = atom.predicate - if pred in static_predicates and pred.arity == 1: - param = atom.variables[0] - param_to_preds[param].add(pred) - # Create the param choices, filtering based on the unary static atoms. - param_choices = [] # list of lists of possible objects for each param - # Preprocess the atom sets for faster lookups. - init_atom_tups = {(a.predicate, tuple(a.objects)) for a in init_atoms} - for param in rule.parameters: - choices = [] - for obj in objects: - # Types must match, as usual. - if obj.type != param.type: - continue - # Check the static conditions. - binding_valid = True - for pred in param_to_pos_preds[param]: - if (pred, (obj, )) not in init_atom_tups: - binding_valid = False - break - for pred in param_to_neg_preds[param]: - if (pred, (obj, )) in init_atom_tups: - binding_valid = False - break - if binding_valid: - choices.append(obj) - # Must be sorted for consistency with other grounding code. - param_choices.append(sorted(choices)) - for choice in itertools.product(*param_choices): - ground_rule = rule.ground(choice) - ground_rules.append(ground_rule) - return ground_rules - - -def parse_ldl_from_str(ldl_str: str, types: Collection[Type], - predicates: Collection[Predicate], - nsrts: Collection[NSRT]) -> LiftedDecisionList: - """Parse a lifted decision list from a string representation of it.""" - parser = _LDLParser(types, predicates, nsrts) - return parser.parse(ldl_str) - - -class _LDLParser: - """Parser for lifted decision lists from strings.""" - - def __init__(self, types: Collection[Type], - predicates: Collection[Predicate], - nsrts: Collection[NSRT]) -> None: - self._nsrt_name_to_nsrt = {nsrt.name.lower(): nsrt for nsrt in nsrts} - self._type_name_to_type = {t.name.lower(): t for t in types} - self._predicate_name_to_predicate = { - p.name.lower(): p - for p in predicates - } - - def parse(self, ldl_str: str) -> LiftedDecisionList: - """Run parsing.""" - ldl_str = ldl_str.lower() # ignore case during parsing - rules = [] - rule_matches = re.finditer(r"\(:rule", ldl_str) - for start in rule_matches: - rule_str = find_balanced_expression(ldl_str, start.start()) - rule = self._parse_rule(rule_str) - rules.append(rule) - return LiftedDecisionList(rules) - - def _parse_rule(self, rule_str: str) -> LDLRule: - rule_pattern = r"\(:rule(.*):parameters(.*):preconditions(.*)" + \ - r":goals(.*):action(.*)\)" - match_result = re.match(rule_pattern, rule_str, re.DOTALL) - assert match_result is not None - # Remove white spaces. - matches = [m.strip().rstrip() for m in match_result.groups()] - # Unpack the matches. - rule_name, params_str, preconds_str, goals_str, nsrt_str = matches - # Handle the parameters. - assert "?" in params_str, "Assuming all rules have parameters." - variable_name_to_variable = {} - assert params_str.endswith(")") - for param_str in params_str[:-1].split("?")[1:]: - param_name, param_type_str = param_str.split("-") - param_name = param_name.strip() - param_type_str = param_type_str.strip() - variable_name = "?" + param_name - param_type = self._type_name_to_type[param_type_str] - variable = Variable(variable_name, param_type) - variable_name_to_variable[variable_name] = variable - # Handle the preconditions. - pos_preconds, neg_preconds = self._parse_lifted_atoms( - preconds_str, variable_name_to_variable) - # Handle the goals. - pos_goals, neg_goals = self._parse_lifted_atoms( - goals_str, variable_name_to_variable) - assert not neg_goals, "Negative LDL goals not currently supported" - # Handle the NSRT. - nsrt = self._parse_into_nsrt(nsrt_str, variable_name_to_variable) - # Finalize the rule. - params = sorted(variable_name_to_variable.values()) - return LDLRule(rule_name, params, pos_preconds, neg_preconds, - pos_goals, nsrt) - - def _parse_lifted_atoms( - self, atoms_str: str, variable_name_to_variable: Dict[str, Variable] - ) -> Tuple[Set[LiftedAtom], Set[LiftedAtom]]: - """Parse the given string (representing either preconditions or - effects) into a set of positive lifted atoms and a set of negative - lifted atoms. - - Check against params to make sure typing is correct. - """ - assert atoms_str[0] == "(" - assert atoms_str[-1] == ")" - - # Handle conjunctions. - if atoms_str.startswith("(and") and atoms_str[4] in (" ", "\n", "("): - clauses = find_all_balanced_expressions(atoms_str[4:-1].strip()) - pos_atoms, neg_atoms = set(), set() - for clause in clauses: - clause_pos_atoms, clause_neg_atoms = self._parse_lifted_atoms( - clause, variable_name_to_variable) - pos_atoms |= clause_pos_atoms - neg_atoms |= clause_neg_atoms - return pos_atoms, neg_atoms - - # Handle negations. - if atoms_str.startswith("(not") and atoms_str[4] in (" ", "\n", "("): - # Only contains a single literal inside not. - split_strs = atoms_str[4:-1].strip()[1:-1].strip().split() - pred = self._predicate_name_to_predicate[split_strs[0]] - args = [variable_name_to_variable[arg] for arg in split_strs[1:]] - lifted_atom = LiftedAtom(pred, args) - return set(), {lifted_atom} - - # Handle single positive atoms. - split_strs = atoms_str[1:-1].split() - # Empty conjunction. - if not split_strs: - return set(), set() - pred = self._predicate_name_to_predicate[split_strs[0]] - args = [variable_name_to_variable[arg] for arg in split_strs[1:]] - lifted_atom = LiftedAtom(pred, args) - return {lifted_atom}, set() - - def _parse_into_nsrt( - self, nsrt_str: str, - variable_name_to_variable: Dict[str, Variable]) -> NSRT: - """Parse the given string into an NSRT.""" - assert nsrt_str[0] == "(" - assert nsrt_str[-1] == ")" - nsrt_str = nsrt_str[1:-1].split()[0] - nsrt = self._nsrt_name_to_nsrt[nsrt_str] - # Validate parameters. - variables = variable_name_to_variable.values() - for v in nsrt.parameters: - assert v in variables, f"NSRT parameter {v} missing from LDL rule" - return nsrt - - -_T = TypeVar("_T") # element of a set - - -def sample_subsets(universe: Sequence[_T], num_samples: int, min_set_size: int, - max_set_size: int, - rng: np.random.Generator) -> Iterator[Set[_T]]: - """Sample multiple subsets from a universe.""" - assert min_set_size <= max_set_size - assert max_set_size <= len(universe), "Not enough elements in universe" - for _ in range(num_samples): - set_size = rng.integers(min_set_size, max_set_size + 1) - idxs = rng.choice(np.arange(len(universe)), - size=set_size, - replace=False) - sample = {universe[i] for i in idxs} - yield sample - - -def create_dataset_filename_str( - saving_ground_atoms: bool, - online_learning_cycle: Optional[int] = None) -> Tuple[str, str]: - """Generate strings to be used for the filename for a dataset file that is - about to be saved. - - Returns a tuple of strings where the first element is the dataset - filename itself and the second is a template string used to generate - it. If saving_ground_atoms is True, then we will name the file with - a "_ground_atoms" suffix. - """ - # Setup the dataset filename for saving/loading GroundAtoms. - regex = r"(\d+)" - suffix_str = "" - suffix_str += f"__{online_learning_cycle}" - if saving_ground_atoms: - suffix_str += "__ground_atoms" - suffix_str += ".data" - dataset_fname_template = ( - f"{CFG.env}__{CFG.offline_data_method}__{CFG.demonstrator}__" - f"{regex}__{CFG.included_options}__{CFG.seed}" + suffix_str) - dataset_fname = os.path.join( - CFG.data_dir, - dataset_fname_template.replace(regex, str(CFG.num_train_tasks))) - return dataset_fname, dataset_fname_template - - -def create_ground_atom_dataset( - trajectories: Sequence[LowLevelTrajectory], - predicates: Set[Predicate]) -> List[GroundAtomTrajectory]: - """Apply all predicates to all trajectories in the dataset.""" - ground_atom_dataset = [] - for traj in trajectories: - atoms = [abstract(s, predicates) for s in traj.states] - ground_atom_dataset.append((traj, atoms)) - return ground_atom_dataset - - -def create_ground_atom_option_dataset( - trajectories: List[LowLevelTrajectory], - predicates: Set[Predicate]) -> List[AtomOptionTrajectory]: - """Apply all predicates to all trajectories in the dataset and also - annotate with options (HLA).""" - ground_atom_option_dataset = [] - for traj in trajectories: - # Note: this is currently just based on the current states. - # We may want to extend this to state history in the future. - atoms = [abstract(s, predicates) for s in traj.states] - options = [a.get_option() for a in traj.actions] - ground_atom_option_dataset.append( - AtomOptionTrajectory( - traj.states, atoms, options, traj.is_demo, - traj.train_task_idx if traj.is_demo else None)) - return ground_atom_option_dataset - - -def prune_ground_atom_dataset( - ground_atom_dataset: List[GroundAtomTrajectory], - kept_predicates: Collection[Predicate]) -> List[GroundAtomTrajectory]: - """Create a new ground atom dataset by keeping only some predicates.""" - new_ground_atom_dataset = [] - for traj, atoms in ground_atom_dataset: - assert len(traj.states) == len(atoms) - kept_atoms = [{a - for a in sa if a.predicate in kept_predicates} - for sa in atoms] - new_ground_atom_dataset.append((traj, kept_atoms)) - return new_ground_atom_dataset - - -def load_ground_atom_dataset( - dataset_fname: str, - trajectories: List[LowLevelTrajectory]) -> List[GroundAtomTrajectory]: - """Load a previously-saved ground atom dataset. - - Note importantly that we only save atoms themselves, we don't save - the low-level trajectory information that's necessary to make - GroundAtomTrajectories given series of ground atoms (that info can - be saved separately, in case one wants to just load trajectories and - not also load ground atoms). Thus, this function needs to take these - trajectories as input. - """ - os.makedirs(CFG.data_dir, exist_ok=True) - # Check that the dataset file was previously saved. - ground_atom_dataset_atoms: Optional[List[List[Set[GroundAtom]]]] = [] - if os.path.exists(dataset_fname): - # Load the ground atoms dataset. - with open(dataset_fname, "rb") as f: - ground_atom_dataset_atoms = pkl.load(f) - assert ground_atom_dataset_atoms is not None - assert len(trajectories) == len(ground_atom_dataset_atoms) - logging.info("\n\nLOADED GROUND ATOM DATASET") - - # The saved ground atom dataset consists only of sequences - # of sets of GroundAtoms, we need to recombine this with - # the LowLevelTrajectories to create a GroundAtomTrajectory. - ground_atom_dataset = [] - for i, traj in enumerate(trajectories): - ground_atom_seq = ground_atom_dataset_atoms[i] - ground_atom_dataset.append( - (traj, [set(atoms) for atoms in ground_atom_seq])) - else: - raise ValueError(f"Cannot load ground atoms: {dataset_fname}") - return ground_atom_dataset - - -def save_ground_atom_dataset(ground_atom_dataset: List[GroundAtomTrajectory], - dataset_fname: str) -> None: - """Saves a given ground atom dataset so it can be loaded in the future.""" - # Save ground atoms dataset to file. Note that a - # GroundAtomTrajectory contains a normal LowLevelTrajectory and a - # list of sets of GroundAtoms, so we only save the list of - # GroundAtoms (the LowLevelTrajectories are saved separately). - ground_atom_dataset_to_pkl = [] - for gt_traj in ground_atom_dataset: - trajectory = [] - for ground_atom_set in gt_traj[1]: - trajectory.append(ground_atom_set) - ground_atom_dataset_to_pkl.append(trajectory) - with open(dataset_fname, "wb") as f: - pkl.dump(ground_atom_dataset_to_pkl, f) - - -def merge_ground_atom_datasets( - gad1: List[GroundAtomTrajectory], - gad2: List[GroundAtomTrajectory]) -> List[GroundAtomTrajectory]: - """Merges two ground atom datasets sharing the same underlying low-level - trajectory via the union of ground atoms at each state.""" - assert len(gad1) == len( - gad2), "Ground atom datasets must be of the same length to merge them." - merged_ground_atom_dataset = [] - for ground_atom_traj1, ground_atom_traj2 in zip(gad1, gad2): - ll_traj1, ga_list1 = ground_atom_traj1 - ll_traj2, ga_list2 = ground_atom_traj2 - assert ll_traj1 == ll_traj2, "Ground atom trajectories must share " \ - "the same low-level trajectory to be able to merge them." - merged_ga_list = [ga1 | ga2 for ga1, ga2 in zip(ga_list1, ga_list2)] - merged_ground_atom_dataset.append((ll_traj1, merged_ga_list)) - return merged_ground_atom_dataset - - -def extract_preds_and_types( - ops: Collection[NSRTOrSTRIPSOperator] -) -> Tuple[Dict[str, Predicate], Dict[str, Type]]: - """Extract the predicates and types used in the given operators.""" - preds = {} - types = {} - for op in ops: - for atom in op.preconditions | op.add_effects | op.delete_effects: - for var_type in atom.predicate.types: - types[var_type.name] = var_type - preds[atom.predicate.name] = atom.predicate - return preds, types - - -def get_static_preds(ops: Collection[NSRTOrSTRIPSOperator], - predicates: Collection[Predicate]) -> Set[Predicate]: - """Get the subset of predicates from the given set that are static with - respect to the given lifted operators.""" - static_preds = set() - for pred in predicates: - # This predicate is not static if it appears in any op's effects. - if any( - any(atom.predicate == pred for atom in op.add_effects) or any( - atom.predicate == pred for atom in op.delete_effects) - for op in ops): - continue - static_preds.add(pred) - return static_preds - - -def get_static_atoms(ground_ops: Collection[GroundNSRTOrSTRIPSOperator], - atoms: Collection[GroundAtom]) -> Set[GroundAtom]: - """Get the subset of atoms from the given set that are static with respect - to the given ground operators. - - Note that this can include MORE than simply the set of atoms whose - predicates are static, because now we have ground operators. - """ - static_atoms = set() - for atom in atoms: - # This atom is not static if it appears in any op's effects. - if any( - any(atom == eff for eff in op.add_effects) or any( - atom == eff for eff in op.delete_effects) - for op in ground_ops): - continue - static_atoms.add(atom) - return static_atoms - - -def get_reachable_atoms(ground_ops: Collection[GroundNSRTOrSTRIPSOperator], - atoms: Collection[GroundAtom]) -> Set[GroundAtom]: - """Get all atoms that are reachable from the init atoms.""" - reachables = set(atoms) - while True: - fixed_point_reached = True - for op in ground_ops: - if op.preconditions.issubset(reachables): - for new_reachable_atom in op.add_effects - reachables: - fixed_point_reached = False - reachables.add(new_reachable_atom) - if fixed_point_reached: - break - return reachables - - -def get_applicable_operators( - ground_ops: Collection[Union[GroundNSRTOrSTRIPSOperator, - _GroundEndogenousProcess]], - atoms: Collection[GroundAtom] -) -> Iterator[Union[GroundNSRTOrSTRIPSOperator, _GroundEndogenousProcess]]: - """Iterate over ground operators whose preconditions are satisfied. - - Note: the order may be nondeterministic. Users should be invariant. - """ - for op in ground_ops: - if isinstance(op, (_GroundNSRT, _GroundSTRIPSOperator)): - applicable = op.preconditions.issubset(atoms) - elif isinstance(op, _GroundEndogenousProcess): - applicable = op.condition_at_start.issubset(atoms) - - if applicable: - yield op - - -def apply_operator(op: GroundNSRTOrSTRIPSOperator, - atoms: Set[GroundAtom]) -> Set[GroundAtom]: - """Get a next set of atoms given a current set and a ground operator.""" - # Note that we are removing the ignore effects before the - # application of the operator, because if the ignore effect - # appears in the effects, we still know that the effects - # will be true, so we don't want to remove them. - new_atoms = {a for a in atoms if a.predicate not in op.ignore_effects} - for atom in op.delete_effects: - new_atoms.discard(atom) - for atom in op.add_effects: - new_atoms.add(atom) - return new_atoms +This is the kitchen-sink utils module. It re-exports the env-safe +subset from :mod:`predicators.utils_lite` (so existing callers that +use ``utils.X`` keep working unchanged) and then defines helpers that +depend on heavy libraries (torch, scipy, imageio, the pretrained-model +SDKs). - -def compute_necessary_atoms_seq( - skeleton: List[_GroundNSRT], atoms_seq: List[Set[GroundAtom]], - goal: Set[GroundAtom]) -> List[Set[GroundAtom]]: - """Given a skeleton and a corresponding atoms sequence, return a - 'necessary' atoms sequence that includes only the necessary image at each - step.""" - necessary_atoms_seq = [set(goal)] - necessary_image = set(goal) - for t in range(len(atoms_seq) - 2, -1, -1): - curr_nsrt = skeleton[t] - necessary_image -= set(curr_nsrt.add_effects) - necessary_image |= set(curr_nsrt.preconditions) - necessary_atoms_seq = [set(necessary_image)] + necessary_atoms_seq - return necessary_atoms_seq - - -def get_successors_from_ground_ops( - atoms: Set[GroundAtom], - ground_ops: Collection[GroundNSRTOrSTRIPSOperator], - unique: bool = True) -> Iterator[Set[GroundAtom]]: - """Get all next atoms from ground operators. - - If unique is true, only yield each unique successor once. - """ - seen_successors = set() - for ground_op in get_applicable_operators(ground_ops, atoms): - next_atoms = apply_operator(ground_op, atoms) # type: ignore[type-var] - if unique: - frozen_next_atoms = frozenset(next_atoms) - if frozen_next_atoms in seen_successors: - continue - seen_successors.add(frozen_next_atoms) - yield next_atoms - - -def ops_and_specs_to_dummy_nsrts( - strips_ops: Sequence[STRIPSOperator], - option_specs: Sequence[OptionSpec]) -> Set[NSRT]: - """Create NSRTs from strips operators and option specs with dummy - samplers.""" - assert len(strips_ops) == len(option_specs) - nsrts = set() - for op, (param_option, option_vars) in zip(strips_ops, option_specs): - nsrt = op.make_nsrt( - param_option, - option_vars, # dummy sampler - lambda s, g, rng, o: np.zeros(1, dtype=np.float32)) - nsrts.add(nsrt) - return nsrts - - -# Note: create separate `heuristics.py` module if we need to add new -# heuristics in the future. - - -def create_task_planning_heuristic( - heuristic_name: str, - init_atoms: Set[GroundAtom], - goal: Set[GroundAtom], - ground_ops: Collection[GroundNSRTOrSTRIPSOperator], - predicates: Collection[Predicate], - objects: Collection[Object], -) -> _TaskPlanningHeuristic: - """Create a task planning heuristic that consumes ground atoms and - estimates the cost-to-go.""" - if heuristic_name in _PYPERPLAN_HEURISTICS: - return _create_pyperplan_heuristic(heuristic_name, init_atoms, goal, - ground_ops, predicates, objects) - if heuristic_name == GoalCountHeuristic.HEURISTIC_NAME: - return GoalCountHeuristic(heuristic_name, init_atoms, goal, ground_ops) - raise ValueError(f"Unrecognized heuristic name: {heuristic_name}.") - - -@dataclass(frozen=True) -class _TaskPlanningHeuristic: - """A task planning heuristic.""" - name: str - init_atoms: Collection[GroundAtom] - goal: Set[GroundAtom] - ground_ops: Collection[Union[_GroundNSRT, _GroundSTRIPSOperator]] - - def __call__(self, atoms: Collection[GroundAtom]) -> float: - raise NotImplementedError("Override me!") - - -class GoalCountHeuristic(_TaskPlanningHeuristic): - """The number of goal atoms that are not in the current state.""" - HEURISTIC_NAME: ClassVar[str] = "goal_count" - - def __call__(self, atoms: Collection[GroundAtom]) -> float: - return len(self.goal.difference(atoms)) - - -############################### Pyperplan Glue ############################### - - -def _create_pyperplan_heuristic( - heuristic_name: str, - init_atoms: Set[GroundAtom], - goal: Set[GroundAtom], - ground_ops: Collection[GroundNSRTOrSTRIPSOperator], - predicates: Collection[Predicate], - objects: Collection[Object], -) -> _PyperplanHeuristicWrapper: - """Create a pyperplan heuristic that inherits from - _TaskPlanningHeuristic.""" - assert heuristic_name in _PYPERPLAN_HEURISTICS - static_atoms = get_static_atoms(ground_ops, init_atoms) - pyperplan_heuristic_cls = _PYPERPLAN_HEURISTICS[heuristic_name] - pyperplan_task = _create_pyperplan_task(init_atoms, goal, ground_ops, - predicates, objects, static_atoms) - pyperplan_heuristic = pyperplan_heuristic_cls(pyperplan_task) - pyperplan_goal = _atoms_to_pyperplan_facts(goal - static_atoms) - return _PyperplanHeuristicWrapper(heuristic_name, init_atoms, goal, - ground_ops, static_atoms, - pyperplan_heuristic, pyperplan_goal) - - -_PyperplanFacts = FrozenSet[str] - - -@dataclass(frozen=True) -class _PyperplanNode: - """Container glue for pyperplan heuristics.""" - state: _PyperplanFacts - goal: _PyperplanFacts - - -@dataclass(frozen=True) -class _PyperplanOperator: - """Container glue for pyperplan heuristics.""" - name: str - preconditions: _PyperplanFacts - add_effects: _PyperplanFacts - del_effects: _PyperplanFacts - - -@dataclass(frozen=True) -class _PyperplanTask: - """Container glue for pyperplan heuristics.""" - facts: _PyperplanFacts - initial_state: _PyperplanFacts - goals: _PyperplanFacts - operators: Collection[_PyperplanOperator] - - -@dataclass(frozen=True) -class _PyperplanHeuristicWrapper(_TaskPlanningHeuristic): - """A light wrapper around pyperplan's heuristics.""" - _static_atoms: Set[GroundAtom] - _pyperplan_heuristic: _PyperplanBaseHeuristic - _pyperplan_goal: _PyperplanFacts - - def __call__(self, atoms: Collection[GroundAtom]) -> float: - # Note: filtering out static atoms. - pyperplan_facts = _atoms_to_pyperplan_facts(set(atoms) \ - - self._static_atoms) - return self._evaluate(pyperplan_facts, self._pyperplan_goal, - self._pyperplan_heuristic) - - @staticmethod - @functools.lru_cache(maxsize=None) - def _evaluate(pyperplan_facts: _PyperplanFacts, - pyperplan_goal: _PyperplanFacts, - pyperplan_heuristic: _PyperplanBaseHeuristic) -> float: - pyperplan_node = _PyperplanNode(pyperplan_facts, pyperplan_goal) - logging.disable(logging.DEBUG) - result = pyperplan_heuristic(pyperplan_node) - logging.disable(logging.NOTSET) - return result - - -def _create_pyperplan_task( - init_atoms: Set[GroundAtom], - goal: Set[GroundAtom], - ground_ops: Collection[GroundNSRTOrSTRIPSOperator], - predicates: Collection[Predicate], - objects: Collection[Object], - static_atoms: Set[GroundAtom], -) -> _PyperplanTask: - """Helper glue for pyperplan heuristics.""" - all_atoms = set() - for predicate in predicates: - all_atoms.update( - get_all_ground_atoms_for_predicate(predicate, frozenset(objects))) - # Note: removing static atoms. - pyperplan_facts = _atoms_to_pyperplan_facts(all_atoms - static_atoms) - pyperplan_state = _atoms_to_pyperplan_facts(init_atoms - static_atoms) - pyperplan_goal = _atoms_to_pyperplan_facts(goal - static_atoms) - pyperplan_operators = set() - for op in ground_ops: - # Note: the pyperplan operator must include the objects, because hFF - # uses the operator name in constructing the relaxed plan, and the - # relaxed plan is a set. If we instead just used op.name, there would - # be a very nasty bug where two ground operators in the relaxed plan - # that have different objects are counted as just one. - name = op.name + "-".join(o.name for o in op.objects) - pyperplan_operator = _PyperplanOperator( - name, - # Note: removing static atoms from preconditions. - _atoms_to_pyperplan_facts(op.preconditions - static_atoms), - _atoms_to_pyperplan_facts(op.add_effects), - _atoms_to_pyperplan_facts(op.delete_effects)) - pyperplan_operators.add(pyperplan_operator) - return _PyperplanTask(pyperplan_facts, pyperplan_state, pyperplan_goal, - pyperplan_operators) - - -@functools.lru_cache(maxsize=None) -def _atom_to_pyperplan_fact(atom: GroundAtom) -> str: - """Convert atom to tuple for interface with pyperplan.""" - arg_str = " ".join(o.name for o in atom.objects) - return f"({atom.predicate.name} {arg_str})" - - -def _atoms_to_pyperplan_facts( - atoms: Collection[GroundAtom]) -> _PyperplanFacts: - """Light wrapper around _atom_to_pyperplan_fact() that operates on a - collection of atoms.""" - return frozenset({_atom_to_pyperplan_fact(atom) for atom in atoms}) - - -############################## End Pyperplan Glue ############################## - - -def create_pddl_types_str(types: Collection[Type]) -> str: - """Create a PDDL-style types string that handles hierarchy correctly.""" - # Case 1: no type hierarchy. - if all(t.parent is None for t in types): - types_str = " ".join(t.name for t in sorted(types)) - # Case 2: type hierarchy. - else: - parent_to_children_types: Dict[Type, - List[Type]] = {t: [] - for t in types} - for t in sorted(types): - if t.parent: - parent_to_children_types[t.parent].append(t) - types_str = "" - for parent_type in sorted(parent_to_children_types): - child_types = parent_to_children_types[parent_type] - if not child_types: - # Special case: type has no children and also does not appear - # as a child of another type. - is_child_type = any( - parent_type in children - for children in parent_to_children_types.values()) - if not is_child_type: - types_str += f"\n {parent_type.name}" - # Otherwise, the type will appear as a child elsewhere. - else: - child_type_str = " ".join(t.name for t in child_types) - types_str += f"\n {child_type_str} - {parent_type.name}" - return types_str - - -def create_pddl_domain(operators: Collection[NSRTOrSTRIPSOperator], - predicates: Collection[Predicate], - types: Collection[Type], domain_name: str) -> str: - """Create a PDDL domain str from STRIPSOperators or NSRTs.""" - # Sort everything to ensure determinism. - preds_lst = sorted(predicates) - types_str = create_pddl_types_str(types) - ops_lst = sorted(operators) - preds_str = "\n ".join(pred.pddl_str() for pred in preds_lst) - ops_strs = "\n\n ".join(op.pddl_str() for op in ops_lst) - return f"""(define (domain {domain_name}) - (:requirements :typing) - (:types {types_str}) - - (:predicates\n {preds_str} - ) - - {ops_strs} -)""" - - -def create_pddl_problem(objects: Collection[Object], - init_atoms: Collection[GroundAtom], - goal: Set[GroundAtom], domain_name: str, - problem_name: str) -> str: - """Create a PDDL problem str.""" - # Sort everything to ensure determinism. - objects_lst = sorted(objects) - init_atoms_lst = sorted(init_atoms) - goal_lst = sorted(goal) - objects_str = "\n ".join(f"{o.name} - {o.type.name}" - for o in objects_lst) - init_str = "\n ".join(atom.pddl_str() for atom in init_atoms_lst) - goal_str = "\n ".join(atom.pddl_str() for atom in goal_lst) - return f"""(define (problem {problem_name}) (:domain {domain_name}) - (:objects\n {objects_str} - ) - (:init\n {init_str} - ) - (:goal (and {goal_str})) -) +Browser / Pyodide-targeted code should import +:mod:`predicators.utils_lite` directly instead of this module. """ +from __future__ import annotations -@functools.lru_cache(maxsize=None) -def get_failure_predicate(option: ParameterizedOption, - idxs: Tuple[int]) -> Predicate: - """Create a Failure predicate for a parameterized option.""" - idx_str = ",".join(map(str, idxs)) - arg_types = [option.types[i] for i in idxs] - return Predicate(f"{option.name}Failed_arg{idx_str}", - arg_types, - _classifier=lambda s, o: False) - - -def _get_idxs_to_failure_predicate( - option: ParameterizedOption, - max_arity: int = 1) -> Dict[Tuple[int, ...], Predicate]: - """Helper for get_all_failure_predicates() and get_failure_atoms().""" - idxs_to_failure_predicate: Dict[Tuple[int, ...], Predicate] = {} - num_types = len(option.types) - max_num_idxs = min(max_arity, num_types) - all_idxs = list(range(num_types)) - for arity in range(1, max_num_idxs + 1): - for idxs in itertools.combinations(all_idxs, arity): - pred = get_failure_predicate(option, idxs) - idxs_to_failure_predicate[idxs] = pred - return idxs_to_failure_predicate - - -def get_all_failure_predicates(options: Set[ParameterizedOption], - max_arity: int = 1) -> Set[Predicate]: - """Get all possible failure predicates.""" - failure_preds: Set[Predicate] = set() - for param_opt in options: - preds = _get_idxs_to_failure_predicate(param_opt, max_arity=max_arity) - failure_preds.update(preds.values()) - return failure_preds - - -def get_failure_atoms(failed_options: Collection[_Option], - max_arity: int = 1) -> Set[GroundAtom]: - """Get ground failure atoms for the collection of failure options.""" - failure_atoms: Set[GroundAtom] = set() - failed_option_specs = {(o.parent, tuple(o.objects)) - for o in failed_options} - for (param_opt, objs) in failed_option_specs: - preds = _get_idxs_to_failure_predicate(param_opt, max_arity=max_arity) - for idxs, pred in preds.items(): - obj_for_idxs = [objs[i] for i in idxs] - failure_atom = GroundAtom(pred, obj_for_idxs) - failure_atoms.add(failure_atom) - return failure_atoms - - -@dataclass -class VideoMonitor(LoggingMonitor): - """A monitor that renders each state and action encountered. - - The render_fn is generally env.render. Note that the state is unused - because the environment should use its current internal state to - render. - """ - _render_fn: Callable[[Optional[Action], Optional[str]], Video] - _video: Video = field(init=False, default_factory=list) - - def reset(self, train_or_test: str, task_idx: int) -> None: - self._video = [] - - def observe(self, obs: Observation, action: Optional[Action]) -> None: - del obs # unused - self._video.extend(self._render_fn(action, None)) - - def get_video(self) -> Video: - """Return the video.""" - return self._video - - -@dataclass -class SimulateVideoMonitor(LoggingMonitor): - """A monitor that calls render_state on each state and action seen. - - This monitor is meant for use with run_policy_with_simulator, as - opposed to VideoMonitor, which is meant for use with run_policy. - """ - _task: Task - _render_state_fn: Callable[[State, Task, Optional[Action]], Video] - _video: Video = field(init=False, default_factory=list) - - def reset(self, train_or_test: str, task_idx: int) -> None: - self._video = [] +import logging +import os +from concurrent.futures import ThreadPoolExecutor +from functools import cached_property +from typing import Any, List, Sequence, Tuple, Union - def observe(self, obs: Observation, action: Optional[Action]) -> None: - assert isinstance(obs, State) - self._video.extend(self._render_state_fn(obs, self._task, action)) +import imageio +import numpy as np +import torch +from scipy.stats import beta as BetaRV - def get_video(self) -> Video: - """Return the video.""" - return self._video +# isort: off +# The wildcard re-export of utils_lite has to come before the explicit +# imports of `predicators.structs` and +# `predicators.pretrained_model_interface` below: utils_lite imports +# structs itself, and bringing utils_lite up first avoids a half-loaded +# structs module when utils.py is the first thing imported in a CPython +# process. +# pylint: disable=wildcard-import,unused-wildcard-import +from predicators.utils_lite import * # noqa: F401, F403 +# pylint: enable=wildcard-import,unused-wildcard-import + +# Underscore-prefixed names aren't re-exported by `import *` per +# Python's defaults. Surface the private utils symbols that external +# callers (planning.py, planning_with_processes.py, tests/test_utils.py) +# rely on explicitly. +# pylint: disable=unused-import +from predicators.utils_lite import ( # noqa: F401 + _abstract_with_derived_predicates, _Geom2D, _PyperplanHeuristicWrapper, + _TaskPlanningHeuristic, +) +# pylint: enable=unused-import +from predicators.pretrained_model_interface import GoogleGeminiLLM, \ + GoogleGeminiVLM, LargeLanguageModel, OpenAILLM, OpenAIVLM, \ + OpenRouterLLM, OpenRouterVLM, VisionLanguageModel +from predicators.settings import CFG +from predicators.structs import DelayDistribution, Video +# isort: on -def create_video_from_partial_refinements( - partial_refinements: Sequence[Tuple[Sequence[_GroundNSRT], - Sequence[_Option]]], - env: BaseEnv, - train_or_test: str, - task_idx: int, - max_num_steps: int, -) -> Video: - """Create a video from a list of skeletons and partial refinements. - Note that the environment internal state is updated. - """ - # Right now, the video is created by finding the longest partial - # refinement. One could also implement an "all_skeletons" mode - # that would create one video per skeleton. - if CFG.failure_video_mode == "longest_only": - # Visualize only the overall longest failed plan. - _, plan = max(partial_refinements, key=lambda x: len(x[1])) - policy = option_plan_to_policy(plan) - video: Video = [] - logging.debug("reset env for create video") - state = env.reset(train_or_test, task_idx) - # logging.debug(f"{pformat(state.pretty_str())}") - for _i in range(max_num_steps): - # logging.debug(f"state: {state.pretty_str()}") - try: - act = policy(state) - # logging.debug(f"act: {act}") - except OptionExecutionFailure: - video.extend(env.render()) - if not CFG.video_not_break_on_exception: - break - else: - video.extend(env.render(act)) - # logging.debug("Finished rendering.") - try: - state = env.step(act) - except EnvironmentFailure: - break - return video - raise NotImplementedError("Unrecognized failure video mode: " - f"{CFG.failure_video_mode}.") +def create_llm_by_name( + model_name: str) -> LargeLanguageModel: # pragma: no cover + """Create particular llm using a provided name.""" + if CFG.pretrained_model_service_provider == "openai": + return OpenAILLM(model_name) + if CFG.pretrained_model_service_provider == "google": + return GoogleGeminiLLM(model_name) + if CFG.pretrained_model_service_provider == "openrouter": + return OpenRouterLLM(model_name) + raise ValueError(f"Unknown pretrained model service provider: " + f"{CFG.pretrained_model_service_provider}") -def fig2data(fig: matplotlib.figure.Figure, dpi: int) -> Image: - """Convert matplotlib figure into Image.""" - fig.set_dpi(dpi) - fig.canvas.draw() - data = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8).copy() - data = data.reshape(fig.canvas.get_width_height()[::-1] + (4, )) - data[..., [0, 1, 2, 3]] = data[..., [1, 2, 3, 0]] - return data +def create_vlm_by_name( + model_name: str) -> VisionLanguageModel: # pragma: no cover + """Create particular vlm using a provided name.""" + if CFG.pretrained_model_service_provider == "openai": + return OpenAIVLM(model_name) + if CFG.pretrained_model_service_provider == "google": + return GoogleGeminiVLM(model_name) + if CFG.pretrained_model_service_provider == "openrouter": + return OpenRouterVLM(model_name) + raise ValueError(f"Unknown pretrained model service provider: " + f"{CFG.pretrained_model_service_provider}") def save_video(outfile: str, video: Video) -> None: @@ -4240,272 +116,6 @@ def save_images(outfile_prefix: str, video: Video) -> None: return save_images_parallel(outfile_prefix, video) -def get_env_asset_path(asset_name: str, assert_exists: bool = True) -> str: - """Return the absolute path to env asset.""" - dir_path = os.path.dirname(os.path.realpath(__file__)) - asset_dir_path = os.path.join(dir_path, "envs", "assets") - path = os.path.join(asset_dir_path, asset_name) - if assert_exists: - assert os.path.exists(path), f"Env asset not found: {asset_name}." - return path - - -def get_third_party_path() -> str: - """Return the absolute path to the third party directory.""" - third_party_dir_path = os.path.join(get_path_to_predicators_root(), - "predicators/third_party") - return third_party_dir_path - - -def get_path_to_predicators_root() -> str: - """Return the absolute path to the predicators root directory. - - Specifically, this returns something that looks like: - '/predicators'. Note there is no '/' at the end. - """ - module_path = Path(__file__) - predicators_dir = module_path.parent.parent - return str(predicators_dir) - - -def import_submodules(path: List[str], name: str) -> None: - """Load all submodules on the given path. - - Useful for finding subclasses of an abstract base class - automatically. - """ - if not TYPE_CHECKING: - for _, module_name, _ in pkgutil.walk_packages(path): - if "__init__" not in module_name: - # Important! We use an absolute import here to avoid issues - # with isinstance checking when using relative imports. - importlib.import_module(f"{name}.{module_name}") - - -def update_config(args: Dict[str, Any]) -> None: - """Args is a dictionary of new arguments to add to the config CFG.""" - parser = create_arg_parser() - update_config_with_parser(parser, args) - - -def update_config_with_parser(parser: ArgumentParser, args: Dict[str, - Any]) -> None: - """Helper function for update_config() that accepts a parser argument.""" - arg_specific_settings = GlobalSettings.get_arg_specific_settings(args) - # Only override attributes, don't create new ones. - allowed_args = set(CFG.__dict__) | set(arg_specific_settings) - # Unfortunately, can't figure out any other way to do this. - for parser_action in parser._actions: # pylint: disable=protected-access - allowed_args.add(parser_action.dest) - for k in args: - if k not in allowed_args: - raise ValueError(f"Unrecognized arg: {k}") - for k in ("env", "approach", "seed", "experiment_id"): - if k not in args and hasattr(CFG, k): - # For env, approach, seed, and experiment_id, if we don't - # pass in a value and this key is already in the - # configuration dict, add the current value to args. - args[k] = getattr(CFG, k) - for d in [arg_specific_settings, args]: - for k, v in d.items(): - setattr(CFG, k, v) - - -def reset_config(args: Optional[Dict[str, Any]] = None, - default_seed: int = 123, - default_render_state_dpi: int = 10) -> None: - """Reset to the default CFG, overriding with anything in args. - - This utility is meant for use in testing only. - """ - parser = create_arg_parser() - reset_config_with_parser(parser, args, default_seed, - default_render_state_dpi) - - -def reset_config_with_parser(parser: ArgumentParser, - args: Optional[Dict[str, Any]] = None, - default_seed: int = 123, - default_render_state_dpi: int = 10) -> None: - """Helper function for reset_config that accepts a parser argument.""" - default_args = parser.parse_args([ - "--env", - "default env placeholder", - "--seed", - str(default_seed), - "--approach", - "default approach placeholder", - ]) - arg_dict = { - k: v - for k, v in GlobalSettings.__dict__.items() if not k.startswith("_") - } - arg_dict.update(vars(default_args)) - if args is not None: - arg_dict.update(args) - if args is None or "render_state_dpi" not in args: - # By default, use a small value for the rendering DPI, to avoid - # expensive rendering during testing. - arg_dict["render_state_dpi"] = default_render_state_dpi - update_config_with_parser(parser, arg_dict) - - -def get_config_path_str(experiment_id: Optional[str] = None) -> str: - """Get a filename prefix for configuration based on the current CFG. - - If experiment_id is supplied, it is used in place of - CFG.experiment_id. - """ - if experiment_id is None: - experiment_id = CFG.experiment_id - if CFG.use_counterfactual_dataset_path_name: - return f"{CFG.env}__{CFG.seed}__{CFG.experiment_id}__query" - return (f"{CFG.env}__{CFG.approach}__{CFG.seed}__" - f"{CFG.excluded_predicates}__" - f"{CFG.included_options}__{experiment_id}") - - -def get_approach_save_path_str() -> str: - """Get a path for saving approaches.""" - os.makedirs(CFG.approach_dir, exist_ok=True) - return f"{CFG.approach_dir}/{get_config_path_str()}.saved" - - -def get_approach_load_path_str() -> str: - """Get a path for loading approaches.""" - if not CFG.load_experiment_id: - experiment_id = CFG.experiment_id - else: - experiment_id = CFG.load_experiment_id - return f"{CFG.approach_dir}/{get_config_path_str(experiment_id)}.saved" - - -def parse_args(env_required: bool = True, - approach_required: bool = True, - seed_required: bool = True) -> Dict[str, Any]: - """Parses command line arguments.""" - parser = create_arg_parser(env_required=env_required, - approach_required=approach_required, - seed_required=seed_required) - return parse_args_with_parser(parser) - - -def parse_args_with_parser(parser: ArgumentParser) -> Dict[str, Any]: - """Helper function for parse_args that accepts a parser argument.""" - args, overrides = parser.parse_known_args() - arg_dict = vars(args) - if len(overrides) == 0: - return arg_dict - # Update initial settings to make sure we're overriding - # existing flags only - update_config_with_parser(parser, arg_dict) - # Override global settings - assert len(overrides) >= 2 - assert len(overrides) % 2 == 0 - for flag, value in zip(overrides[:-1:2], overrides[1::2]): - assert flag.startswith("--") - setting_name = flag[2:] - if setting_name not in CFG.__dict__: - raise ValueError(f"Unrecognized flag: {setting_name}") - arg_dict[setting_name] = string_to_python_object(value) - return arg_dict - - -def string_to_python_object(value: str) -> Any: - """Return the Python object corresponding to the given string value.""" - if value in ("None", "none"): - return None - if value in ("True", "true"): - return True - if value in ("False", "false"): - return False - if value.isdigit() or value.startswith("lambda"): - return eval(value) - try: - return float(value) - except ValueError: - pass - if value.startswith("["): - assert value.endswith("]") - inner_strs = value[1:-1].split(",") - return [string_to_python_object(s) for s in inner_strs] - if value.startswith("("): - assert value.endswith(")") - inner_strs = value[1:-1].split(",") - return tuple(string_to_python_object(s) for s in inner_strs) - return value - - -def flush_cache() -> None: - """Clear all lru caches.""" - gc.collect() - _lru_type = functools._lru_cache_wrapper # pylint: disable=protected-access - wrappers = [] - for a in gc.get_objects(): - try: - if isinstance(a, _lru_type): - wrappers.append(a) - except Exception: # pylint: disable=broad-except - continue - - for wrapper in wrappers: - wrapper.cache_clear() - - -def parse_config_excluded_predicates( - env: BaseEnv) -> Tuple[Set[Predicate], Set[Predicate]]: - """Parse the CFG.excluded_predicates string, given an environment. - - Return a tuple of (included predicate set, excluded predicate set). - """ - if CFG.excluded_predicates: - if CFG.excluded_predicates == "all": - excluded_names = { - pred.name - for pred in env.predicates if pred not in env.goal_predicates - } - logging.info(f"All non-goal predicates excluded: {excluded_names}") - included = env.goal_predicates - else: - excluded_names = set(CFG.excluded_predicates.split(",")) - assert excluded_names.issubset( - {pred.name for pred in env.predicates}), \ - "Unrecognized predicate in excluded_predicates!" - included = { - pred - for pred in env.predicates if pred.name not in excluded_names - } - if CFG.offline_data_method != "demo+ground_atoms": - if CFG.allow_exclude_goal_predicates: - if not env.goal_predicates.issubset(included): - logging.info("Note: excluding goal predicates!") - else: - assert env.goal_predicates.issubset(included), \ - "Can't exclude a goal predicate!" - else: - excluded_names = set() - included = env.predicates - excluded = {pred for pred in env.predicates if pred.name in excluded_names} - return included, excluded - - -def replace_goals_with_agent_specific_goals( - included_predicates: Set[Predicate], - excluded_predicates: Set[Predicate], env: BaseEnv) -> Set[Predicate]: - """Replace original goal predicates with agent-specific goal predicates if - the environment defines them.""" - preds = included_predicates - env.goal_predicates \ - | env.agent_goal_predicates - excluded_predicates - return preds - - -def null_sampler(state: State, goal: Set[GroundAtom], rng: np.random.Generator, - objs: Sequence[Object]) -> Array: - """A sampler for an NSRT with no continuous parameters.""" - del state, goal, rng, objs # unused - return np.array([], dtype=np.float32) # no continuous parameters - - class ConstantDelay(DelayDistribution): """ConstantDelay class.""" @@ -4687,188 +297,6 @@ def _str(self) -> str: return f"DiscreteGaussianDelay({self.mu:.4f}, {self.sigma:.4f})" -@functools.lru_cache(maxsize=None) -def get_git_commit_hash() -> str: - """Return the hash of the current git commit.""" - out = subprocess.check_output(["git", "rev-parse", "HEAD"]) - return out.decode("ascii").strip() - - -def get_all_subclasses(cls: Any) -> Set[Any]: - """Get all subclasses of the given class.""" - return set(cls.__subclasses__()).union( - [s for c in cls.__subclasses__() for s in get_all_subclasses(c)]) - - -class _DummyFile(io.StringIO): - """Dummy file object used by nostdout().""" - - def write(self, _: Any) -> int: - """Mock write() method.""" - return 0 - - def flush(self) -> None: - """Mock flush() method.""" - - -@contextlib.contextmanager -def nostdout() -> Generator[None, None, None]: - """Suppress output for a block of code. - - To use, wrap code in the statement `with utils.nostdout():`. Note - that calls to the logging library, which this codebase uses - primarily, are unaffected. So, this utility is mostly helpful when - calling third-party code. - """ - save_stdout = sys.stdout - sys.stdout = _DummyFile() - yield - sys.stdout = save_stdout - - -def query_ldl( - ldl: LiftedDecisionList, - atoms: Set[GroundAtom], - objects: Set[Object], - goal: Set[GroundAtom], - static_predicates: Optional[Set[Predicate]] = None, - init_atoms: Optional[Collection[GroundAtom]] = None -) -> Optional[_GroundNSRT]: - """Queries a lifted decision list representing a goal-conditioned policy. - - Given an abstract state and goal, the rules are grounded in order. The - first applicable ground rule is used to return a ground NSRT. - - If static_predicates is provided, it is used to avoid grounding rules with - nonsense preconditions like IsBall(robot). - - If no rule is applicable, returns None. - """ - for rule in ldl.rules: - for ground_rule in all_ground_ldl_rules( - rule, - objects, - static_predicates=static_predicates, - init_atoms=init_atoms): - if ground_rule.pos_state_preconditions.issubset(atoms) and \ - not ground_rule.neg_state_preconditions & atoms and \ - ground_rule.goal_preconditions.issubset(goal): - return ground_rule.ground_nsrt - return None - - -def generate_random_string(length: int, alphabet: Sequence[str], - rng: np.random.Generator) -> str: - """Generates a random string of the given length using the provided set of - characters (alphabet).""" - assert all(len(c) == 1 for c in alphabet) - return "".join(rng.choice(alphabet, size=length)) - - -def find_balanced_expression(s: str, index: int) -> str: - """Find balanced expression in string starting from given index.""" - assert s[index] == "(" - start_index = index - balance = 1 - while balance != 0: - index += 1 - symbol = s[index] - if symbol == "(": - balance += 1 - elif symbol == ")": - balance -= 1 - return s[start_index:index + 1] - - -def find_all_balanced_expressions(s: str) -> List[str]: - """Return a list of all balanced expressions in a string, starting from the - beginning.""" - assert s[0] == "(" - assert s[-1] == ")" - exprs = [] - index = 0 - start_index = index - balance = 1 - while index < len(s) - 1: - index += 1 - if balance == 0: - exprs.append(s[start_index:index]) - # Jump to next "(". - while True: - if s[index] == "(": - break - index += 1 - start_index = index - balance = 1 - continue - symbol = s[index] - if symbol == "(": - balance += 1 - elif symbol == ")": - balance -= 1 - assert balance == 0 - exprs.append(s[start_index:index + 1]) - return exprs - - -def range_intersection(lb1: float, ub1: float, lb2: float, ub2: float) -> bool: - """Given upper and lower bounds for two ranges, returns True iff the ranges - intersect.""" - return (lb1 <= lb2 <= ub1) or (lb2 <= lb1 <= ub2) - - -def compute_abs_range_given_two_ranges(lb1: float, ub1: float, lb2: float, - ub2: float) -> Tuple[float, float]: - """Given upper and lower bounds of two feature ranges, returns the upper. - - and lower bound of |f1 - f2|. - """ - # Now, we must compute the upper and lower bounds of - # the expression |t1.f1 - t2.f2|. If the intervals - # [lb1, ub1] and [lb2, ub2] overlap, then the lower - # bound of the expression is just 0. Otherwise, if - # lb2 > ub1, the lower bound is |ub1 - lb2|, and if - # ub2 < lb1, the lower bound is |lb1 - ub2|. - if range_intersection(lb1, ub1, lb2, ub2): - lb = 0.0 - else: - lb = min(abs(lb2 - ub1), abs(lb1 - ub2)) - # The upper bound for the expression can be - # computed in a similar fashion. - ub = max(abs(ub2 - lb1), abs(ub1 - lb2)) - return (lb, ub) - - -def roundrobin(iterables: Sequence[Iterator]) -> Iterator: - """roundrobin(['ABC...', 'D...', 'EF...']) --> A D E B F C...""" - # Recipe credited to George Sakkis, code adapted slightly from - # from https://docs.python.org/3/library/itertools.html - num_active = len(iterables) - nexts = itertools.cycle(iter(it).__next__ for it in iterables) - while num_active: - for nxt in nexts: - yield nxt() - - -def get_task_seed(train_or_test: str, task_idx: int) -> int: - """Parses task seed from CFG.test_env_seed_offset.""" - assert task_idx < CFG.test_env_seed_offset - # SeedSequence generates a sequence of random values given an integer - # "entropy". We use CFG.seed to define the "entropy" and then get the - # n^th generated random value and use that to seed the gym environment. - # This is all to avoid unintentional dependence between experiments - # that are conducted with consecutive random seeds. For example, if - # we used CFG.seed + task_idx to seed the gym environment, there would - # be overlap between experiments when CFG.seed = 1, CFG.seed = 2, etc. - seed_entropy = CFG.seed - if train_or_test == "test": - seed_entropy += CFG.test_env_seed_offset - seed_sequence = np.random.SeedSequence(seed_entropy) - # Need to cast to int because generate_state() returns a numpy int. - task_seed = int(seed_sequence.generate_state(task_idx + 1)[-1]) - return task_seed - - def _beta_bernoulli_posterior_alpha_beta( success_history: List[bool], alpha: float = 1.0, @@ -4918,405 +346,3 @@ def beta_from_mean_and_variance(mean: float, rv = BetaRV(alpha, beta) assert abs(rv.mean() - mean) < 1e-6 return rv - - -def _obs_to_state_pass_through(obs: Observation) -> State: - """Helper for run_ground_nsrt_with_assertions.""" - assert isinstance(obs, State) - return obs - - -def run_ground_nsrt_with_assertions(ground_nsrt: _GroundNSRT, - state: State, - env: BaseEnv, - rng: np.random.Generator, - override_params: Optional[Array] = None, - obs_to_state: Callable[ - [Observation], - State] = _obs_to_state_pass_through, - assert_effects: bool = True, - max_steps: int = 400) -> State: - """Utility for tests. - - NOTE: assumes that the internal state of env corresponds to state. - """ - ground_nsrt_str = f"{ground_nsrt.name}{ground_nsrt.objects}" - for atom in ground_nsrt.preconditions: - assert atom.holds(state), \ - f"Precondition for {ground_nsrt_str} failed: {atom}" - option = ground_nsrt.sample_option(state, set(), rng) - if override_params is not None: - option = option.parent.ground(option.objects, - override_params) # pragma: no cover - assert option.initiable(state) - for _ in range(max_steps): - act = option.policy(state) - obs = env.step(act) - state = obs_to_state(obs) - if option.terminal(state): - break - if assert_effects: - for atom in ground_nsrt.add_effects: - assert atom.holds(state), \ - f"Add effect for {ground_nsrt_str} failed: {atom}" - for atom in ground_nsrt.delete_effects: - assert not atom.holds(state), \ - f"Delete effect for {ground_nsrt_str} failed: {atom}" - return state - - -def get_scaled_default_font( - draw: ImageDraw.ImageDraw, - size: int) -> ImageFont.FreeTypeFont: # pragma: no cover - """Method that modifies the size of some provided PIL ImageDraw font. - - Useful for scaling up font sizes when using PIL to insert text - directly into images. - """ - # Determine the scaling factor - base_font = ImageFont.load_default() - width, height = draw.textbbox((0, 0), "A", font=base_font)[:2] - scale_factor = size / max(width, height) - # Scale the font using the factor - return base_font.font_variant(size=int(scale_factor * # type: ignore - base_font.size)) # type: ignore - - -def add_text_to_draw_img( - draw: ImageDraw.ImageDraw, position: Tuple[int, int], text: str, - font: ImageFont.FreeTypeFont -) -> ImageDraw.ImageDraw: # pragma: no cover - """Method that adds some text with a particular font at a particular pixel - position in an input PIL.ImageDraw.ImageDraw image. - - Returns the modified ImageDraw.ImageDraw with the added text. - """ - text_width, text_height = draw.textbbox((0, 0), text, font=font)[2:] - background_position = (position[0] - 5, position[1] - 5 - ) # Slightly larger than text - background_size = (text_width + 10, text_height + 10) - # Draw the background rectangle - draw.rectangle( - (background_position, (background_position[0] + background_size[0], - background_position[1] + background_size[1])), - fill="black") - # Add the text to the image - draw.text(position, text, fill="red", font=font) - return draw - - -def wrap_angle(angle: float) -> float: - """Wrap an angle in radians to [-pi, pi].""" - return np.arctan2(np.sin(angle), np.cos(angle)) - - -def get_parameterized_option_by_name( - options: Set[ParameterizedOption], - option_name: str) -> Optional[ParameterizedOption]: - """Retrieve an option by its name from a set of options.""" - return next((option for option in options if option.name == option_name), - None) - - -def get_object_by_name(objects: Collection[Object], - name: str) -> Optional[Object]: - """Get an object by its name from a collection of objects. - - Args: - objects: Collection of objects to search through - name: Name of the object to find - - Returns: - The object if found, None otherwise - """ - return next((obj for obj in objects if obj.name == name), None) - - -def configure_logging() -> None: - """Configure logging with colored output.""" - # Create a single formatter instance to be reused - colored_formatter = colorlog.ColoredFormatter( - '%(log_color)s%(levelname)s: %(message)s', - log_colors={ - 'DEBUG': 'cyan', - 'INFO': 'green', - 'WARNING': 'yellow', - 'ERROR': 'red', - 'CRITICAL': 'red,bg_white', - }, - reset=True, - style='%') - # Log to stderr. - colorlog_handler = colorlog.StreamHandler() - colorlog_handler.setFormatter(colored_formatter) - handlers: List[logging.Handler] = [colorlog_handler] - if CFG.log_file: - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - CFG.log_file += (f"{CFG.approach}/{CFG.experiment_id}/" - f"seed{CFG.seed}/run_{timestamp}/") - os.makedirs(CFG.log_file, exist_ok=True) - - # Handler for DEBUG level messages - debug_handler = logging.FileHandler(os.path.join( - CFG.log_file, "debug.log"), - mode='w') - debug_handler.setLevel(logging.DEBUG) - debug_handler.setFormatter(colored_formatter) - handlers.append(debug_handler) - - # Handler for INFO level messages - info_handler = logging.FileHandler(os.path.join( - CFG.log_file, "info.log"), - mode='w') - info_handler.setLevel(logging.INFO) - info_handler.setFormatter(colored_formatter) - handlers.append(info_handler) - - logging.basicConfig(level=CFG.loglevel, - format="%(message)s", - handlers=handlers, - force=True) - logging.getLogger('matplotlib.font_manager').setLevel(logging.ERROR) - logging.getLogger('libpng').setLevel(logging.ERROR) - logging.getLogger('PIL').setLevel(logging.ERROR) - logging.getLogger('openai').setLevel(logging.INFO) - # Used by openai package - logging.getLogger("httpx").setLevel(logging.INFO) - logging.getLogger("httpcore").setLevel(logging.INFO) - - -def log_initial_info(str_args: str) -> None: - """Log initial configuration and setup information.""" - if CFG.log_file: - logging.info(f"Logging to {CFG.log_file}") - logging.info(f"Running command: python {str_args}") - logging.info("Full config:") - logging.info(CFG) - logging.info(f"Git commit hash: {get_git_commit_hash()}") - - -def add_label_to_video(video: Video, - prefix: str, - imgs_dir: str, - save: bool = True) -> Video: - """Add a label to each frame of the video and save the images.""" - os.makedirs(imgs_dir, exist_ok=True) - new_video: Video = [] - for i, img in enumerate(video): - img_name = prefix + f"frame_{i+1}" - labeled_img = add_label_to_image( - img, # type: ignore[arg-type] - img_name, - imgs_dir, - save=save) - new_video.append(labeled_img) # type: ignore[arg-type] - return new_video - - -def add_label_to_image(img: PIL.Image.Image, - s_name: str, - obs_dir: str, - f_suffix: str = ".png", - save: bool = True) -> PIL.Image.Image: - """Add a label to an image and potentially save.""" - img_copy = img.copy() - draw = ImageDraw.Draw(img_copy) - font = ImageFont.load_default().font_variant( # type: ignore[union-attr] - size=50) - - # Get text dimensions - bbox = draw.textbbox((0, 0), s_name, font=font) - text_width = bbox[2] - bbox[0] - text_height = bbox[3] - bbox[1] - - # Calculate position (bottom right with padding) - padding = 10 - x = img_copy.width - text_width - padding - y = img_copy.height - text_height - 2 * padding - - text_color = (0, 0, 0) # black - draw.text((x, y), s_name, fill=text_color, font=font) - - if save: - os.makedirs(obs_dir, exist_ok=True) - img_copy.save(os.path.join(obs_dir, s_name + f_suffix)) - logging.debug(f"Saved Image {s_name}") - return img_copy - - -def load_all_images_from_dir(dir_path: str) -> List[PIL.Image.Image]: - """Load all images from a directory.""" - images = [] - img_paths = sorted(os.listdir(dir_path)) - for file in img_paths: - if file.endswith(('.png', '.jpg')): - images.append(PIL.Image.open(os.path.join(dir_path, file))) - return images - - -def all_subsets(input_set: Iterable[Any]) -> Iterator[Set[Any]]: - """Generates all subsets of a given set. - - Args: - input_set: An iterable (e.g., a list, set, tuple) - from which to generate subsets. - - Yields: - tuple: Each subset as a tuple. - """ - s = list(input_set) # Convert to list to handle various iterable inputs - n = len(s) - for i in range(n + 1): # Iterate from subset size 0 up to n - for subset in itertools.combinations(s, i): - yield set(subset) - - -def add_in_auxiliary_predicates(predicates: Set[Predicate]) -> Set[Predicate]: - """Add auxiliary predicates from derived predicates.""" - - def add_auxiliary(pred: Predicate, preds: Set[Predicate]) -> None: - if isinstance(pred, DerivedPredicate): - if pred.auxiliary_predicates: - preds.update(pred.auxiliary_predicates) - for aux_pred in pred.auxiliary_predicates: - add_auxiliary(aux_pred, preds) - - new_preds = predicates.copy() - for pred in predicates: - add_auxiliary(pred, new_preds) - return new_preds - - -def get_derived_predicates( - predicates: Set[Predicate]) -> Set[DerivedPredicate]: - """Get all derived predicates from a set of predicates.""" - return {pred for pred in predicates if isinstance(pred, DerivedPredicate)} - - -# def abstract_with_derived_predicates(atoms, derived_preds, objects): -# """Compute all derived atoms via layered evaluation (fewer passes). -# Potentially faster than the current implementation.""" -# # Build dependency graph over derived preds -# is_derived = {p for p in derived_preds} -# indeg = {p: 0 for p in derived_preds} -# edges = {p: set() for p in derived_preds} -# for p in derived_preds: -# for aux in getattr(p, "auxiliary_predicates", []): -# # only count deps on other derived preds -# q = next( -# (dp for dp in derived_preds -# if dp.name == aux.name), None) -# if q: -# edges[q].add(p); indeg[p] += 1 - -# # Kahn’s algorithm => layers -# frontier = [p for p in derived_preds if indeg[p] == 0] -# layers: list[list] = [] -# while frontier: -# layer = list(frontier); layers.append(layer); frontier = [] -# for u in layer: -# for v in edges[u]: -# indeg[v] -= 1 -# if indeg[v] == 0: -# frontier.append(v) - -# # Evaluate per layer; state grows monotonically -# state = set(atoms) -# derived_all = set() -# # (Optional) cache object choices per predicate once -# by_type = {} -# for o in objects: -# by_type.setdefault(o.type, []).append(o) -# choices_cache = { -# p: list(itertools.product(*(by_type[t] for t in p.types))) -# for p in derived_preds -# } - -# for layer in layers: -# for p in layer: -# for choice in choices_cache[p]: -# if p.holds(state, choice): -# derived_all.add(GroundAtom(p, choice)) -# state |= derived_all # grow state for next layer - -# return derived_all - - -def abstract_with_derived_predicates( - atoms: Set[GroundAtom], derived_preds: Collection[DerivedPredicate], - objects: Collection[Object]) -> Set[GroundAtom]: - """Compute the fixed point of concept predicate atoms.""" - primitive_atoms = atoms - new_concept_atoms: Set[GroundAtom] = set() - prev_new_concept_atoms: Set[GroundAtom] = set() - counter = 0 - while True: - # All the concept atoms that holds; all the previous atoms - atoms = primitive_atoms | new_concept_atoms - new_concept_atoms = _abstract_with_derived_predicates( - atoms, derived_preds, objects) - # logging.debug(f"ite {counter} concept atoms: {new_concept_atoms}") - converged = new_concept_atoms == prev_new_concept_atoms - if converged: - # logging.debug("converged") - break - prev_new_concept_atoms = new_concept_atoms - counter += 1 - return new_concept_atoms - - -def _abstract_with_derived_predicates( - abs_state: Set[GroundAtom], - derived_preds: Collection[DerivedPredicate], - objects: Collection[Object]) -> Set[GroundAtom]: - """Get the atoms based on the existing atomic state and concept - predicates.""" - atoms: Set[GroundAtom] = set() - for pred in derived_preds: - for choice in get_object_combinations(objects, pred.types): - try: - if pred.holds(abs_state, choice): - atoms.add(GroundAtom(pred, choice)) - except Exception as e: - logging.error(f"Error in evaluating concept predicate {pred}: " - f"{e}") - # raise e - raise PredicateEvaluationError( - f"Error in evaluating concept predicate {pred}: {e}", pred) - return atoms - - -def get_base_supporter_predicates( - root_predicate: DerivedPredicate) -> Set[Predicate]: - """Finds all primitive (non-derived) supporter predicates for a given root - derived predicate by traversing its dependency graph.""" - base_predicates: Set[Predicate] = set() - - # Use a worklist to process predicates in a breadth-first manner. - predicates_to_process: List[Predicate] = list( - root_predicate.auxiliary_predicates or []) - processed_predicates: Set[Predicate] = {root_predicate} - - while predicates_to_process: - pred = predicates_to_process.pop(0) - - if pred in processed_predicates: - continue - processed_predicates.add(pred) - - # If the predicate is derived, add its auxiliaries to the worklist. - if isinstance(pred, DerivedPredicate): - predicates_to_process.extend(pred.auxiliary_predicates or []) - # If it's a primitive predicate, we've found a base supporter. - else: - base_predicates.add(pred) - - return base_predicates - - -class PredicateEvaluationError(Exception): - """PredicateEvaluationError class.""" - - def __init__(self, message: str, pred: Any) -> None: - super().__init__(message) - self.pred = pred diff --git a/predicators/utils_lite.py b/predicators/utils_lite.py new file mode 100644 index 0000000000..7d8f3c0379 --- /dev/null +++ b/predicators/utils_lite.py @@ -0,0 +1,5069 @@ +"""General utility methods (Pyodide-safe subset). + +This module is imported directly by predicators code that needs to +load under Pyodide in the browser POC (envs, pybullet_helpers, etc.). +It deliberately avoids `torch`, `imageio`, the pretrained-model SDKs, +and `scipy.stats.beta` — those helpers live in +:mod:`predicators.utils`, which re-exports every symbol from this +module on top of them. + +CPython callers should keep importing :mod:`predicators.utils`; the +two modules are kept in sync via a wildcard re-export. +""" + +from __future__ import annotations + +import abc +import contextlib +import copy +import datetime +import functools +import gc +import heapq as hq +import importlib +import io +import itertools +import logging +import os +import pkgutil +import re +import subprocess +import sys +import time +from argparse import ArgumentParser +from collections import defaultdict, namedtuple +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Dict, \ + FrozenSet, Generator, Generic, Hashable, Iterable, Iterator, List, \ + Optional, Sequence, Set, Tuple +from typing import Type as TypingType +from typing import TypeVar, Union, cast + +import colorlog +import dill as pkl +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import PIL.Image +from gym.spaces import Box +from matplotlib import patches +from numpy.typing import NDArray +from PIL import ImageDraw, ImageFont +from pyperplan.heuristics.heuristic_base import \ + Heuristic as _PyperplanBaseHeuristic +from pyperplan.planner import HEURISTICS as _PYPERPLAN_HEURISTICS + +from predicators.args import create_arg_parser +from predicators.pybullet_helpers.joint import JointPositions +from predicators.settings import CFG, GlobalSettings +from predicators.structs import NSRT, Action, Array, AtomOptionTrajectory, \ + CausalProcess, DerivedPredicate, DummyOption, EntToEntSub, GroundAtom, \ + GroundAtomTrajectory, GroundNSRTOrSTRIPSOperator, Image, LDLRule, \ + LiftedAtom, LiftedDecisionList, LiftedOrGroundAtom, LowLevelTrajectory, \ + Mask, Metrics, NSRTOrSTRIPSOperator, Object, ObjectOrVariable, \ + Observation, OptionSpec, ParameterizedOption, Predicate, Segment, State, \ + STRIPSOperator, Task, Type, Variable, VarToObjSub, Video, VLMPredicate, \ + _GroundEndogenousProcess, _GroundLDLRule, _GroundNSRT, \ + _GroundSTRIPSOperator, _Option, _TypedEntity +from predicators.third_party.fast_downward_translator.translate import \ + main as downward_translate + +if TYPE_CHECKING: + from predicators.envs import BaseEnv + # Used only in type annotations on query_vlm_for_atom_vals / abstract; + # `from __future__ import annotations` keeps the runtime import out + # of the Pyodide load path. The lazy runtime use of + # create_vlm_by_name lives inside query_vlm_for_atom_vals below. + from predicators.pretrained_model_interface import VisionLanguageModel + +matplotlib.use("Agg") + +# Unpickling CUDA models errs out if the device isn't recognized because of +# an unusual name, including in supercloud, but we can set it manually +if "CUDA_VISIBLE_DEVICES" in os.environ: # pragma: no cover + cuda_visible_devices = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + if len(cuda_visible_devices) and cuda_visible_devices[0] != "0": + cuda_visible_devices[0] = "0" + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(cuda_visible_devices) + + +def count_positives_for_ops( + strips_ops: List[STRIPSOperator], + option_specs: List[OptionSpec], + segments: List[Segment], + max_groundings: Optional[int] = None, +) -> Tuple[int, int, List[Set[int]], List[Set[int]]]: + """Returns num true positives, num false positives, and for each strips op, + lists of segment indices that contribute true or false positives. + + The lists of segment indices are useful only for debugging; they are + otherwise redundant with num_true_positives/num_false_positives. + """ + assert len(strips_ops) == len(option_specs) + num_true_positives = 0 + num_false_positives = 0 + # The following two lists are just useful for debugging. + true_positive_idxs: List[Set[int]] = [set() for _ in strips_ops] + false_positive_idxs: List[Set[int]] = [set() for _ in strips_ops] + for seg_idx, segment in enumerate(segments): + objects = set(segment.states[0]) + segment_option = segment.get_option() + option_objects = segment_option.objects + covered_by_some_op = False + # Ground only the operators with a matching option spec. + for op_idx, (op, + option_spec) in enumerate(zip(strips_ops, option_specs)): + # If the parameterized options are different, not relevant. + if option_spec[0] != segment_option.parent: + continue + option_vars = option_spec[1] + assert len(option_vars) == len(option_objects) + option_var_to_obj = dict(zip(option_vars, option_objects)) + # We want to get all ground operators whose corresponding + # substitution is consistent with the option vars for this + # segment. So, determine all of the operator variables + # that are not in the option vars, and consider all + # groundings of them. + for grounding_idx, ground_op in enumerate( + all_ground_operators_given_partial(op, objects, + option_var_to_obj)): + if max_groundings is not None and \ + grounding_idx > max_groundings: + break + # Check the ground_op against the segment. + if not ground_op.preconditions.issubset(segment.init_atoms): + continue + if ground_op.add_effects == segment.add_effects and \ + ground_op.delete_effects == segment.delete_effects: + covered_by_some_op = True + true_positive_idxs[op_idx].add(seg_idx) + else: + false_positive_idxs[op_idx].add(seg_idx) + num_false_positives += 1 + if covered_by_some_op: + num_true_positives += 1 + return num_true_positives, num_false_positives, \ + true_positive_idxs, false_positive_idxs + + +def count_branching_factor(strips_ops: List[STRIPSOperator], + segments: List[Segment]) -> int: + """Returns the total branching factor for all states in the segments.""" + total_branching_factor = 0 + for segment in segments: + atoms = segment.init_atoms + objects = set(segment.states[0]) + ground_ops = { + ground_op + for op in strips_ops + for ground_op in all_ground_operators(op, objects) + } + for _ in get_applicable_operators(ground_ops, atoms): + total_branching_factor += 1 + return total_branching_factor + + +def segment_trajectory_to_start_end_state_sequence( + seg_traj: List[Segment]) -> List[State]: + """Convert a trajectory of segments into a trajectory of states, made up of + only the initial/final states of the segments. + + The length of the return value will always be one greater than the + length of the given seg_traj. + """ + assert len(seg_traj) >= 1 + states = [] + for i, seg in enumerate(seg_traj): + states.append(seg.states[0]) + if i < len(seg_traj) - 1: + assert seg.states[-1].allclose(seg_traj[i + 1].states[0]) + states.append(seg_traj[-1].states[-1]) + assert len(states) == len(seg_traj) + 1 + return states + + +def segment_trajectory_to_atoms_sequence( + seg_traj: List[Segment]) -> List[Set[GroundAtom]]: + """Convert a trajectory of segments into a trajectory of ground atoms. + + The length of the return value will always be one greater than the + length of the given seg_traj. + """ + assert len(seg_traj) >= 1 + atoms_seq = [] + for i, seg in enumerate(seg_traj): + atoms_seq.append(seg.init_atoms) + if i < len(seg_traj) - 1: + assert seg.final_atoms == seg_traj[i + 1].init_atoms + atoms_seq.append(seg_traj[-1].final_atoms) + assert len(atoms_seq) == len(seg_traj) + 1 + return atoms_seq + + +def num_options_in_action_sequence(actions: Sequence[Action]) -> int: + """Given a sequence of actions with options included, get the number of + options that are encountered.""" + num_options = 0 + last_option = None + for action in actions: + current_option = action.get_option() + if not current_option is last_option: + last_option = current_option + num_options += 1 + return num_options + + +def entropy(p: float) -> float: + """Entropy of a Bernoulli variable with parameter p.""" + assert 0.0 <= p <= 1.0 + if p in {0.0, 1.0}: + return 0.0 + return -(p * np.log2(p) + (1 - p) * np.log2(1 - p)) + + +def create_state_from_dict(data: Dict[Object, Dict[str, float]], + simulator_state: Optional[Any] = None) -> State: + """Small utility to generate a state from a dictionary `data` of individual + feature values for each object. + + A simulator_state for the outputted State may optionally be + provided. + """ + state_dict = {} + for obj, obj_data in data.items(): + obj_vec = [] + for feat in obj.type.feature_names: + obj_vec.append(obj_data[feat]) + state_dict[obj] = np.array(obj_vec) + return State(state_dict, simulator_state) + + +def create_json_dict_from_ground_atoms( + ground_atoms: Collection[GroundAtom]) -> Dict[str, List[List[str]]]: + """Saves a set of ground atoms in a JSON-compatible dict. + + Helper for creating the goal dict in create_json_dict_from_task(). + """ + predicate_to_argument_lists = defaultdict(list) + for atom in sorted(ground_atoms): + argument_list = [o.name for o in atom.objects] + predicate_to_argument_lists[atom.predicate.name].append(argument_list) + return dict(predicate_to_argument_lists) + + +def create_json_dict_from_task(task: Task) -> Dict[str, Any]: + """Create a JSON-compatible dict from a task. + + The format of the dict is: + + { + "objects": { + : + } + "init": { + : { + : + } + } + "goal": { + : [ + [] + ] + } + } + + The dict can be loaded with BaseEnv._load_task_from_json(). This is + helpful for testing and designing standalone tasks. + """ + object_dict = {o.name: o.type.name for o in task.init} + init_dict = { + o.name: dict(zip(o.type.feature_names, task.init.data[o])) + for o in task.init + } + goal_dict = create_json_dict_from_ground_atoms(task.goal) + return {"objects": object_dict, "init": init_dict, "goal": goal_dict} + + +def construct_active_sampler_input(state: State, objects: Sequence[Object], + params: Array, + param_option: ParameterizedOption) -> Array: + """Helper function for active sampler learning and explorer.""" + + assert not CFG.sampler_learning_use_goals + sampler_input_lst = [1.0] # start with bias term + if CFG.active_sampler_learning_feature_selection == "all": + for obj in objects: + sampler_input_lst.extend(state[obj]) + sampler_input_lst.extend(params) + + else: + assert CFG.active_sampler_learning_feature_selection == "oracle" + if CFG.env == "bumpy_cover": + if param_option.name == "Pick": + # In this case, the x-data should be + # [block_bumpy, relative_pick_loc] + assert len(objects) == 1 + block = objects[0] + block_pos = state[block][3] + block_bumpy = state[block][5] + sampler_input_lst.append(block_bumpy) + assert len(params) == 1 + sampler_input_lst.append(params[0] - block_pos) + else: + assert param_option.name == "Place" + assert len(objects) == 2 + block, target = objects + target_pos = state[target][3] + grasp = state[block][4] + target_width = state[target][2] + sampler_input_lst.extend([grasp, target_width]) + assert len(params) == 1 + sampler_input_lst.append(params[0] - target_pos) + elif CFG.env == "ball_and_cup_sticky_table": + if "PlaceCup" in param_option.name and "Table" in param_option.name: + _, _, _, table = objects + table_y = state.get(table, "y") + table_x = state.get(table, "x") + sticky = state.get(table, "sticky") + sticky_region_x = state.get(table, "sticky_region_x_offset") + sticky_region_y = state.get(table, "sticky_region_y_offset") + sticky_region_radius = state.get(table, "sticky_region_radius") + table_radius = state.get(table, "radius") + _, _, _, param_x, param_y = params + sampler_input_lst.append(table_radius) + sampler_input_lst.append(sticky) + sampler_input_lst.append(sticky_region_x) + sampler_input_lst.append(sticky_region_y) + sampler_input_lst.append(sticky_region_radius) + sampler_input_lst.append(table_x) + sampler_input_lst.append(table_y) + sampler_input_lst.append(param_x) + sampler_input_lst.append(param_y) + else: # Use all features. + for obj in objects: + sampler_input_lst.extend(state[obj]) + sampler_input_lst.extend(params) + else: + raise NotImplementedError("Oracle feature selection not " + f"implemented for {CFG.env}") + + return np.array(sampler_input_lst) + + +class _Geom2D(abc.ABC): + """A 2D shape that contains some points.""" + + @abc.abstractmethod + def plot(self, ax: plt.Axes, **kwargs: Any) -> None: + """Plot the shape on a given pyplot axis.""" + raise NotImplementedError("Override me!") + + @abc.abstractmethod + def contains_point(self, x: float, y: float) -> bool: + """Checks if a point is contained in the shape.""" + raise NotImplementedError("Override me!") + + @abc.abstractmethod + def sample_random_point(self, + rng: np.random.Generator) -> Tuple[float, float]: + """Samples a random point inside the 2D shape.""" + raise NotImplementedError("Override me!") + + def intersects(self, other: _Geom2D) -> bool: + """Checks if this shape intersects with another one.""" + return geom2ds_intersect(self, other) + + +@dataclass(frozen=True) +class LineSegment(_Geom2D): + """A helper class for visualizing and collision checking line segments.""" + x1: float + y1: float + x2: float + y2: float + + def plot(self, ax: plt.Axes, **kwargs: Any) -> None: + ax.plot([self.x1, self.x2], [self.y1, self.y2], **kwargs) + + def contains_point(self, x: float, y: float) -> bool: + # https://stackoverflow.com/questions/328107 + a = (self.x1, self.y1) + b = (self.x2, self.y2) + c = (x, y) + # Need to use an epsilon for numerical stability. But we are checking + # if the distance from a to b is (approximately) equal to the distance + # from a to c and the distance from c to b. + eps = 1e-6 + + def _dist(p: Tuple[float, float], q: Tuple[float, float]) -> float: + return np.sqrt((p[0] - q[0])**2 + (p[1] - q[1])**2) + + return -eps < _dist(a, c) + _dist(c, b) - _dist(a, b) < eps + + def sample_random_point(self, + rng: np.random.Generator) -> Tuple[float, float]: + line_slope = (self.y2 - self.y1) / (self.x2 - self.x1) + y_intercept = self.y2 - (line_slope * self.x2) + random_x_point = rng.uniform(self.x1, self.x2) + random_y_point_on_line = line_slope * random_x_point + y_intercept + assert self.contains_point(random_x_point, random_y_point_on_line) + return (random_x_point, random_y_point_on_line) + + +@dataclass(frozen=True) +class Circle(_Geom2D): + """A helper class for visualizing and collision checking circles.""" + x: float + y: float + radius: float + + def plot(self, ax: plt.Axes, **kwargs: Any) -> None: + patch = patches.Circle((self.x, self.y), self.radius, **kwargs) + ax.add_patch(patch) + + def contains_point(self, x: float, y: float) -> bool: + return (x - self.x)**2 + (y - self.y)**2 <= self.radius**2 + + def contains_circle(self, other_circle: Circle) -> bool: + """Check whether this circle wholly contains another one.""" + dist_between_centers = np.sqrt((other_circle.x - self.x)**2 + + (other_circle.y - self.y)**2) + return (dist_between_centers + other_circle.radius) <= self.radius + + def sample_random_point(self, + rng: np.random.Generator) -> Tuple[float, float]: + rand_mag = rng.uniform(0, self.radius) + rand_theta = rng.uniform(0, 2 * np.pi) + x_point = self.x + rand_mag * np.cos(rand_theta) + y_point = self.y + rand_mag * np.sin(rand_theta) + assert self.contains_point(x_point, y_point) + return (x_point, y_point) + + +@dataclass(frozen=True) +class Triangle(_Geom2D): + """A helper class for visualizing and collision checking triangles.""" + x1: float + y1: float + x2: float + y2: float + x3: float + y3: float + + def plot(self, ax: plt.Axes, **kwargs: Any) -> None: + patch = patches.Polygon( + [[self.x1, self.y1], [self.x2, self.y2], [self.x3, self.y3]], + **kwargs) + ax.add_patch(patch) + + def __post_init__(self) -> None: + dist1 = np.sqrt((self.x1 - self.x2)**2 + (self.y1 - self.y2)**2) + dist2 = np.sqrt((self.x2 - self.x3)**2 + (self.y2 - self.y3)**2) + dist3 = np.sqrt((self.x3 - self.x1)**2 + (self.y3 - self.y1)**2) + dists = sorted([dist1, dist2, dist3]) + assert dists[0] + dists[1] >= dists[2] + if dists[0] + dists[1] == dists[2]: + raise ValueError("Degenerate triangle!") + + def contains_point(self, x: float, y: float) -> bool: + # Adapted from https://stackoverflow.com/questions/2049582/. + sign1 = ((x - self.x2) * (self.y1 - self.y2) - (self.x1 - self.x2) * + (y - self.y2)) > 0 + sign2 = ((x - self.x3) * (self.y2 - self.y3) - (self.x2 - self.x3) * + (y - self.y3)) > 0 + sign3 = ((x - self.x1) * (self.y3 - self.y1) - (self.x3 - self.x1) * + (y - self.y1)) > 0 + has_neg = (not sign1) or (not sign2) or (not sign3) + has_pos = sign1 or sign2 or sign3 + return not has_neg or not has_pos + + def sample_random_point(self, + rng: np.random.Generator) -> Tuple[float, float]: + a = np.array([self.x2 - self.x1, self.y2 - self.y1]) + b = np.array([self.x3 - self.x1, self.y3 - self.y1]) + u1 = rng.uniform(0, 1) + u2 = rng.uniform(0, 1) + if u1 + u2 > 1.0: + u1 = 1 - u1 + u2 = 1 - u2 + point_in_triangle = (u1 * a + u2 * b) + np.array([self.x1, self.y1]) + assert self.contains_point(point_in_triangle[0], point_in_triangle[1]) + return (point_in_triangle[0], point_in_triangle[1]) + + +@dataclass(frozen=True) +class Rectangle(_Geom2D): + """A helper class for visualizing and collision checking rectangles. + + Following the convention in plt.Rectangle, the origin is at the + bottom left corner, and rotation is anti-clockwise about that point. + + Unlike plt.Rectangle, the angle is in radians. + """ + x: float + y: float + width: float + height: float + theta: float # in radians, between -np.pi and np.pi + + def __post_init__(self) -> None: + assert -np.pi <= self.theta <= np.pi, "Expecting angle in [-pi, pi]." + + @staticmethod + def from_center(center_x: float, center_y: float, width: float, + height: float, rotation_about_center: float) -> Rectangle: + """Create a rectangle given an (x, y) for the center, with theta + rotating about that center point.""" + x = center_x - width / 2 + y = center_y - height / 2 + norm_rect = Rectangle(x, y, width, height, 0.0) + assert np.isclose(norm_rect.center[0], center_x) + assert np.isclose(norm_rect.center[1], center_y) + return norm_rect.rotate_about_point(center_x, center_y, + rotation_about_center) + + @functools.cached_property + def rotation_matrix(self) -> NDArray[np.float64]: + """Get the rotation matrix.""" + return np.array([[np.cos(self.theta), -np.sin(self.theta)], + [np.sin(self.theta), + np.cos(self.theta)]]) + + @functools.cached_property + def inverse_rotation_matrix(self) -> NDArray[np.float64]: + """Get the inverse rotation matrix.""" + return np.array([[np.cos(self.theta), + np.sin(self.theta)], + [-np.sin(self.theta), + np.cos(self.theta)]]) + + @functools.cached_property + def vertices(self) -> List[Tuple[float, float]]: + """Get the four vertices for the rectangle.""" + scale_matrix = np.array([ + [self.width, 0], + [0, self.height], + ]) + translate_vector = np.array([self.x, self.y]) + vertices = np.array([ + (0, 0), + (0, 1), + (1, 1), + (1, 0), + ]) + vertices = vertices @ scale_matrix.T + vertices = vertices @ self.rotation_matrix.T + vertices = translate_vector + vertices + # Convert to a list of tuples. Slightly complicated to appease both + # type checking and linting. + return list(map(lambda p: (p[0], p[1]), vertices)) + + @functools.cached_property + def line_segments(self) -> List[LineSegment]: + """Get the four line segments for the rectangle.""" + vs = list(zip(self.vertices, self.vertices[1:] + [self.vertices[0]])) + line_segments = [] + for ((x1, y1), (x2, y2)) in vs: + line_segments.append(LineSegment(x1, y1, x2, y2)) + return line_segments + + @functools.cached_property + def center(self) -> Tuple[float, float]: + """Get the point at the center of the rectangle.""" + x, y = np.mean(self.vertices, axis=0) + return (x, y) + + @functools.cached_property + def circumscribed_circle(self) -> Circle: + """Returns x, y, radius.""" + x, y = self.center + radius = np.sqrt((self.width / 2)**2 + (self.height / 2)**2) + return Circle(x, y, radius) + + def contains_point(self, x: float, y: float) -> bool: + # First invert translation, then invert rotation. + rx, ry = np.array([x - self.x, y - self.y + ]) @ self.inverse_rotation_matrix.T + return 0 <= rx <= self.width and \ + 0 <= ry <= self.height + + def sample_random_point(self, + rng: np.random.Generator) -> Tuple[float, float]: + rand_width = rng.uniform(0, self.width) + rand_height = rng.uniform(0, self.height) + # First rotate, then translate. + rx, ry = np.array([rand_width, rand_height]) @ self.rotation_matrix.T + x = rx + self.x + y = ry + self.y + assert self.contains_point(x, y) + return (x, y) + + def rotate_about_point(self, x: float, y: float, rot: float) -> Rectangle: + """Create a new rectangle that is this rectangle, but rotated CCW by + the given rotation (in radians), relative to the (x, y) origin. + + Rotates the vertices first, then uses them to recompute the new + theta. + """ + vertices = np.array(self.vertices) + origin = np.array([x, y]) + # Translate the vertices so that they become the "origin". + vertices = vertices - origin + # Rotate. + rotate_matrix = np.array([[np.cos(rot), -np.sin(rot)], + [np.sin(rot), np.cos(rot)]]) + vertices = vertices @ rotate_matrix.T + # Translate the vertices back. + vertices = vertices + origin + # Recompute theta. + (lx, ly), _, _, (rx, ry) = vertices + theta = np.arctan2(ry - ly, rx - lx) + rect = Rectangle(lx, ly, self.width, self.height, theta) + assert np.allclose(rect.vertices, vertices) + return rect + + def plot(self, ax: plt.Axes, **kwargs: Any) -> None: + angle = self.theta * 180 / np.pi + patch = patches.Rectangle((self.x, self.y), + self.width, + self.height, + angle=angle, + **kwargs) + ax.add_patch(patch) + + +def line_segments_intersect(seg1: LineSegment, seg2: LineSegment) -> bool: + """Checks if two line segments intersect. + + This method, which works by checking relative orientation, allows + for collinearity, and only checks if each segment straddles the line + containing the other. + """ + + def _subtract(a: Tuple[float, float], b: Tuple[float, float]) \ + -> Tuple[float, float]: + x1, y1 = a + x2, y2 = b + return (x1 - x2), (y1 - y2) + + def _cross_product(a: Tuple[float, float], b: Tuple[float, float]) \ + -> float: + x1, y1 = b + x2, y2 = a + return x1 * y2 - x2 * y1 + + def _direction(a: Tuple[float, float], b: Tuple[float, float], + c: Tuple[float, float]) -> float: + return _cross_product(_subtract(a, c), _subtract(a, b)) + + p1 = (seg1.x1, seg1.y1) + p2 = (seg1.x2, seg1.y2) + p3 = (seg2.x1, seg2.y1) + p4 = (seg2.x2, seg2.y2) + d1 = _direction(p3, p4, p1) + d2 = _direction(p3, p4, p2) + d3 = _direction(p1, p2, p3) + d4 = _direction(p1, p2, p4) + + return ((d2 < 0 < d1) or (d1 < 0 < d2)) and ((d4 < 0 < d3) or + (d3 < 0 < d4)) + + +def circles_intersect(circ1: Circle, circ2: Circle) -> bool: + """Checks if two circles intersect.""" + x1, y1, r1 = circ1.x, circ1.y, circ1.radius + x2, y2, r2 = circ2.x, circ2.y, circ2.radius + return (x1 - x2)**2 + (y1 - y2)**2 < (r1 + r2)**2 + + +def rectangles_intersect(rect1: Rectangle, rect2: Rectangle) -> bool: + """Checks if two rectangles intersect.""" + # Optimization: if the circumscribed circles don't intersect, then + # the rectangles also don't intersect. + if not circles_intersect(rect1.circumscribed_circle, + rect2.circumscribed_circle): + return False + # Case 1: line segments intersect. + if any( + line_segments_intersect(seg1, seg2) for seg1 in rect1.line_segments + for seg2 in rect2.line_segments): + return True + # Case 2: rect1 inside rect2. + if rect1.contains_point(rect2.center[0], rect2.center[1]): + return True + # Case 3: rect2 inside rect1. + if rect2.contains_point(rect1.center[0], rect1.center[1]): + return True + # Not intersecting. + return False + + +def line_segment_intersects_circle(seg: LineSegment, + circ: Circle, + ax: Optional[plt.Axes] = None) -> bool: + """Checks if a line segment intersects a circle. + + If ax is not None, a diagram is plotted on the axis to illustrate + the computations, which is useful for checking correctness. + """ + # First check if the end points of the segment are in the circle. + if circ.contains_point(seg.x1, seg.y1): + return True + if circ.contains_point(seg.x2, seg.y2): + return True + # Project the circle radius onto the extended line. + c = (circ.x, circ.y) + # Project (a, c) onto (a, b). + a = (seg.x1, seg.y1) + b = (seg.x2, seg.y2) + ba = np.subtract(b, a) + ca = np.subtract(c, a) + da = ba * np.dot(ca, ba) / np.dot(ba, ba) + # The point on the extended line that is the closest to the center. + d = dx, dy = (a[0] + da[0], a[1] + da[1]) + # Optionally plot the important points. + if ax is not None: + circ.plot(ax, color="red", alpha=0.5) + seg.plot(ax, color="black", linewidth=2) + ax.annotate("A", a) + ax.annotate("B", b) + ax.annotate("C", c) + ax.annotate("D", d) + # Check if the point is on the line. If it's not, there is no intersection, + # because we already checked that the circle does not contain the end + # points of the line segment. + if not seg.contains_point(dx, dy): + return False + # So d is on the segment. Check if it's in the circle. + return circ.contains_point(dx, dy) + + +def line_segment_intersects_rectangle(seg: LineSegment, + rect: Rectangle) -> bool: + """Checks if a line segment intersects a rectangle.""" + # Case 1: one of the end points of the segment is in the rectangle. + if rect.contains_point(seg.x1, seg.y1) or \ + rect.contains_point(seg.x2, seg.y2): + return True + # Case 2: the segment intersects with one of the rectangle sides. + return any(line_segments_intersect(s, seg) for s in rect.line_segments) + + +def rectangle_intersects_circle(rect: Rectangle, circ: Circle) -> bool: + """Checks if a rectangle intersects a circle.""" + # Optimization: if the circumscribed circle of the rectangle doesn't + # intersect with the circle, then there can't be an intersection. + if not circles_intersect(rect.circumscribed_circle, circ): + return False + # Case 1: the circle's center is in the rectangle. + if rect.contains_point(circ.x, circ.y): + return True + # Case 2: one of the sides of the rectangle intersects the circle. + for seg in rect.line_segments: + if line_segment_intersects_circle(seg, circ): + return True + return False + + +def geom2ds_intersect(geom1: _Geom2D, geom2: _Geom2D) -> bool: + """Check if two 2D bodies intersect.""" + if isinstance(geom1, LineSegment) and isinstance(geom2, LineSegment): + return line_segments_intersect(geom1, geom2) + if isinstance(geom1, LineSegment) and isinstance(geom2, Circle): + return line_segment_intersects_circle(geom1, geom2) + if isinstance(geom1, LineSegment) and isinstance(geom2, Rectangle): + return line_segment_intersects_rectangle(geom1, geom2) + if isinstance(geom1, Rectangle) and isinstance(geom2, LineSegment): + return line_segment_intersects_rectangle(geom2, geom1) + if isinstance(geom1, Circle) and isinstance(geom2, LineSegment): + return line_segment_intersects_circle(geom2, geom1) + if isinstance(geom1, Rectangle) and isinstance(geom2, Rectangle): + return rectangles_intersect(geom1, geom2) + if isinstance(geom1, Rectangle) and isinstance(geom2, Circle): + return rectangle_intersects_circle(geom1, geom2) + if isinstance(geom1, Circle) and isinstance(geom2, Rectangle): + return rectangle_intersects_circle(geom2, geom1) + if isinstance(geom1, Circle) and isinstance(geom2, Circle): + return circles_intersect(geom1, geom2) + raise NotImplementedError("Intersection not implemented for geoms " + f"{geom1} and {geom2}") + + +@functools.lru_cache(maxsize=None) +def unify(atoms1: FrozenSet[LiftedOrGroundAtom], + atoms2: FrozenSet[LiftedOrGroundAtom]) -> Tuple[bool, EntToEntSub]: + """Return whether the given two sets of atoms can be unified. + + Also return the mapping between variables/objects in these atom + sets. This mapping is empty if the first return value is False. + """ + atoms_lst1 = sorted(atoms1) + atoms_lst2 = sorted(atoms2) + + # Terminate quickly if there is a mismatch between predicates + preds1 = [atom.predicate for atom in atoms_lst1] + preds2 = [atom.predicate for atom in atoms_lst2] + if preds1 != preds2: + return False, {} + + # Terminate quickly if there is a mismatch between numbers + num1 = len({o for atom in atoms_lst1 for o in atom.entities}) + num2 = len({o for atom in atoms_lst2 for o in atom.entities}) + if num1 != num2: + return False, {} + + # Try to get lucky with a one-to-one mapping + subs12: EntToEntSub = {} + subs21 = {} + success = True + for atom1, atom2 in zip(atoms_lst1, atoms_lst2): + if not success: + break + for v1, v2 in zip(atom1.entities, atom2.entities): + if v1 in subs12 and subs12[v1] != v2: + success = False + break + if v2 in subs21: + success = False + break + subs12[v1] = v2 + subs21[v2] = v1 + if success: + return True, subs12 + + # If all else fails, use search + solved, sub = find_substitution(atoms_lst1, atoms_lst2) + rev_sub = {v: k for k, v in sub.items()} + return solved, rev_sub + + +@functools.lru_cache(maxsize=None) +def unify_preconds_effects_options( + preconds1: FrozenSet[LiftedOrGroundAtom], + preconds2: FrozenSet[LiftedOrGroundAtom], + add_effects1: FrozenSet[LiftedOrGroundAtom], + add_effects2: FrozenSet[LiftedOrGroundAtom], + delete_effects1: FrozenSet[LiftedOrGroundAtom], + delete_effects2: FrozenSet[LiftedOrGroundAtom], + param_option1: ParameterizedOption, param_option2: ParameterizedOption, + option_args1: Tuple[_TypedEntity, ...], + option_args2: Tuple[_TypedEntity, ...]) -> Tuple[bool, EntToEntSub]: + """Wrapper around unify() that handles option arguments, preconditions, add + effects, and delete effects. + + Changes predicate names so that all are treated differently by + unify(). + """ + if param_option1 != param_option2: + # Can't unify if the parameterized options are different. + return False, {} + opt_arg_pred1 = Predicate("OPT-ARGS", [a.type for a in option_args1], + _classifier=lambda s, o: False) # dummy + f_option_args1 = frozenset({GroundAtom(opt_arg_pred1, option_args1)}) + new_preconds1 = wrap_atom_predicates(preconds1, "PRE-") + f_new_preconds1 = frozenset(new_preconds1) + new_add_effects1 = wrap_atom_predicates(add_effects1, "ADD-") + f_new_add_effects1 = frozenset(new_add_effects1) + new_delete_effects1 = wrap_atom_predicates(delete_effects1, "DEL-") + f_new_delete_effects1 = frozenset(new_delete_effects1) + + opt_arg_pred2 = Predicate("OPT-ARGS", [a.type for a in option_args2], + _classifier=lambda s, o: False) # dummy + f_option_args2 = frozenset({LiftedAtom(opt_arg_pred2, option_args2)}) + new_preconds2 = wrap_atom_predicates(preconds2, "PRE-") + f_new_preconds2 = frozenset(new_preconds2) + new_add_effects2 = wrap_atom_predicates(add_effects2, "ADD-") + f_new_add_effects2 = frozenset(new_add_effects2) + new_delete_effects2 = wrap_atom_predicates(delete_effects2, "DEL-") + f_new_delete_effects2 = frozenset(new_delete_effects2) + + all_atoms1 = (f_option_args1 | f_new_preconds1 | f_new_add_effects1 + | f_new_delete_effects1) + all_atoms2 = (f_option_args2 | f_new_preconds2 | f_new_add_effects2 + | f_new_delete_effects2) + return unify(all_atoms1, all_atoms2) + + +def wrap_predicate(predicate: Predicate, prefix: str) -> Predicate: + """Return a new predicate which adds the given prefix string to the name. + + NOTE: the classifier is removed. + """ + new_predicate = Predicate(prefix + predicate.name, + predicate.types, + _classifier=lambda s, o: False) # dummy + return new_predicate + + +def wrap_atom_predicates(atoms: Collection[LiftedOrGroundAtom], + prefix: str) -> Set[LiftedOrGroundAtom]: + """Return a new set of atoms which adds the given prefix string to the name + of every atom's predicate. + + NOTE: all the classifiers are removed. + """ + new_atoms = set() + for atom in atoms: + new_predicate = wrap_predicate(atom.predicate, prefix) + new_atoms.add(atom.__class__(new_predicate, atom.entities)) + return new_atoms + + +class LinearChainParameterizedOption(ParameterizedOption): + """A parameterized option implemented via a sequence of "child" + parameterized options. + + This class is meant to help ParameterizedOption manual design. + + The children are executed in order starting with the first in the sequence + and transitioning when the terminal function of each child is hit. + + The children are assumed to chain together, so the initiable of the next + child should always be True when the previous child terminates. If this + is not the case, an AssertionError is raised. + + The children must all have the same types and params_space, which in turn + become the types and params_space for this ParameterizedOption. + + The LinearChainParameterizedOption has memory, which stores the current + child index. + """ + + def __init__(self, name: str, + children: Sequence[ParameterizedOption]) -> None: + assert len(children) > 0 + self._children = children + + # Make sure that the types and params spaces are consistent. + types = children[0].types + params_space = children[0].params_space + for i in range(1, len(self._children)): + child = self._children[i] + assert types == child.types + assert np.allclose(params_space.low, child.params_space.low) + assert np.allclose(params_space.high, child.params_space.high) + + super().__init__(name, + types, + params_space, + policy=self._policy, + initiable=self._initiable, + terminal=self._terminal) + + def _initiable(self, state: State, memory: Dict, objects: Sequence[Object], + params: Array) -> bool: + # Initialize the current child to the first one. + memory["current_child_index"] = 0 + # Create memory dicts for each child to avoid key collisions. One + # example of a failure that arises without this is when using + # multiple SingletonParameterizedOption instances, each of those + # options would be referencing the same start_state in memory. + memory["child_memory"] = [{} for _ in self._children] + current_child = self._children[0] + child_memory = memory["child_memory"][0] + return current_child.initiable(state, child_memory, objects, params) + + def _policy(self, state: State, memory: Dict, objects: Sequence[Object], + params: Array) -> Action: + # Check if the current child has terminated. + current_index = memory["current_child_index"] + current_child = self._children[current_index] + child_memory = memory["child_memory"][current_index] + if current_child.terminal(state, child_memory, objects, params): + # Move on to the next child. + current_index += 1 + memory["current_child_index"] = current_index + current_child = self._children[current_index] + child_memory = memory["child_memory"][current_index] + assert current_child.initiable(state, child_memory, objects, + params) + # logging.debug(f"Executing {current_child.name}") + return current_child.policy(state, child_memory, objects, params) + + def _terminal(self, state: State, memory: Dict, objects: Sequence[Object], + params: Array) -> bool: + # Check if the last child has terminated. + current_index = memory["current_child_index"] + if current_index < len(self._children) - 1: + return False + current_child = self._children[current_index] + child_memory = memory["child_memory"][current_index] + return current_child.terminal(state, child_memory, objects, params) + + +class SingletonParameterizedOption(ParameterizedOption): + """A parameterized option that takes a single action and stops. + + For convenience: + * Initiable defaults to always True. + * Types defaults to []. + * Params space defaults to Box(0, 1, (0, )). + """ + + def __init__( + self, + name: str, + policy: Callable[[State, Dict, Sequence[Object], Array], Action], + types: Optional[Sequence[Type]] = None, + params_space: Optional[Box] = None, + initiable: Optional[Callable[[State, Dict, Sequence[Object], Array], + bool]] = None + ) -> None: + if types is None: + types = [] + if params_space is None: + params_space = Box(0, 1, (0, )) + if initiable is None: + initiable = lambda _1, _2, _3, _4: True + + # Wrap the given initiable so that we can track whether the action + # has been executed yet. + def _initiable(state: State, memory: Dict, objects: Sequence[Object], + params: Array) -> bool: + if "start_state" in memory: + assert state.allclose(memory["start_state"]) + # Always update the memory dict due to the "is" check in _terminal. + memory["start_state"] = state + assert initiable is not None + return initiable(state, memory, objects, params) + + def _terminal(state: State, memory: Dict, objects: Sequence[Object], + params: Array) -> bool: + del objects, params # unused + assert "start_state" in memory, \ + "Must call initiable() before terminal()." + return state is not memory["start_state"] + + super().__init__(name, + types, + params_space, + policy=policy, + initiable=_initiable, + terminal=_terminal) + + +class PyBulletState(State): + """A PyBullet state that stores the robot joint positions in addition to + the features that are exposed in the object-centric state.""" + + @property + def joint_positions(self) -> JointPositions: + """Expose the current joints state in the simulator_state.""" + # if the simulator state is an array + if isinstance(self.simulator_state, Dict): + jp = self.simulator_state["joint_positions"] + else: + jp = self.simulator_state + return cast(JointPositions, jp) + + @property + def state_image(self) -> PIL.Image.Image: + """Expose the current image state in the simulator_state.""" + assert isinstance(self.simulator_state, Dict) + return self.simulator_state["unlabeled_image"] + + @property + def labeled_image(self) -> Optional[PIL.Image.Image]: + """Expose the current image state in the simulator_state.""" + assert isinstance(self.simulator_state, Dict) + return self.simulator_state.get("images") + + @property + def obj_mask_dict(self) -> Optional[Dict[Object, Mask]]: + """Expose the current object masks in the simulator_state.""" + assert isinstance(self.simulator_state, Dict) + return self.simulator_state.get("obj_mask_dict") + + def allclose(self, other: State) -> bool: + # Ignores the simulator state. + return State(self.data).allclose(State(other.data)) + + def copy(self) -> PyBulletState: + copied = super().copy() + state_dict_copy = copied.data + # simulator_state_copy = list(self.joint_positions) + simulator_state_copy = copied.simulator_state + # Forward the hidden blocks `super().copy()` deep-copied: `latent` + # (agent belief) and `privileged` (env-hidden ground truth). Both + # are dropped if not passed explicitly, since this rebuilds the + # PyBulletState rather than returning `copied`. + return PyBulletState(state_dict_copy, + simulator_state_copy, + latent=copied.latent, + privileged=copied.privileged) + + def get_obj_mask(self, obj: Object) -> Mask: + """Return the mask for the object.""" + assert self.obj_mask_dict is not None + mask = self.obj_mask_dict.get(obj) + assert mask is not None + return mask + + def label_all_objects(self) -> None: + """Label all objects in the simulator state.""" + # Lazy import: image_patch_wrapper pulls in torch/torchvision + # which aren't available in Pyodide. CPython callers see it via + # the deferred import here. + from predicators.image_patch_wrapper import \ + ImagePatch # noqa: PLC0415 # pylint: disable=import-outside-toplevel + state_ip = ImagePatch(self) + obj_mask_dict = self.obj_mask_dict + assert obj_mask_dict is not None + state_ip.label_all_objects(obj_mask_dict) + assert isinstance(self.simulator_state, Dict) + self.simulator_state["images"] = state_ip.cropped_image_in_PIL + + def add_images_and_masks(self, unlabeled_image: PIL.Image.Image, + masks: Dict[Object, Mask]) -> None: + """Add the unlabeled image and object masks to the simulator state.""" + assert isinstance(self.simulator_state, Dict) + self.simulator_state["unlabeled_image"] = unlabeled_image + self.simulator_state["obj_mask_dict"] = masks + self.label_all_objects() + + +BoundingBox = namedtuple('BoundingBox', 'left lower right upper') + + +@dataclass +class VLMState(PyBulletState): + """PyBulletState extended with VLM/visual perception capabilities.""" + state_image: PIL.Image.Image = None # type: ignore[assignment] + obj_mask_dict: Dict[Object, Mask] = field(default_factory=dict) + labeled_image: Optional[PIL.Image.Image] = None # type: ignore[assignment] + option_history: Optional[List[str]] = None + bbox_features: Dict[Object, np.ndarray] = field( + default_factory=lambda: defaultdict(lambda: np.zeros(4))) + prev_state: Optional[VLMState] = None + + def __hash__(self) -> int: + data_tuple = tuple((k, tuple(v)) for k, v in sorted(self.data.items())) + if self.simulator_state is not None: + data_tuple += tuple(self.simulator_state) + return hash(data_tuple) + + def evaluate_simple_assertion( + self, assertion: str, image: Tuple[BoundingBox, + Sequence[Object]]) -> VLMQuery: + """Given an assertion and an image, queries a VLM and returns whether + the assertion is true or false.""" + bbox, objs = image + return VLMQuery(assertion, bbox, list(objs)) + + def generate_previous_option_message(self) -> str: + """Generate the message for the previous option.""" + assert self.option_history is not None + msg = "Evaluate the truth value of the following assertions in the "\ + "current state as depicted by the image" + if CFG.nsp_pred_include_prev_image_in_prompt and \ + self.prev_state is not None: + msg += " labeled with 'curr. state'" + if CFG.nsp_pred_include_state_str_in_prompt: + msg += " and the information below" + + msg += ".\n" + + if CFG.nsp_pred_include_state_str_in_prompt: + msg += "We have the object positions and the robot's "\ + "proprioception:\n" + msg += self.dict_str(indent=2, + object_features=False, + use_object_id=True, + position_proprio_features=True) + msg += "\n" + + if len(self.option_history) == 0: + msg += "For context, this is at the beginning of a task, before "\ + "the robot has done anything.\n" + else: + msg += "For context, the state is right after the robot has"\ + " successfully executed the action "\ + f"{self.option_history[-1]}." + if CFG.nsp_pred_include_state_str_in_prompt: + if self.prev_state is not None: + msg += " The object position and robot proprioception "\ + "before executing the action is:\n" + msg += self.prev_state.dict_str( + indent=2, + object_features=False, + use_object_id=True, + position_proprio_features=True) + msg += "\n" + if CFG.nsp_pred_include_prev_image_in_prompt: + msg += " The state before executing the action is depicted"\ + " by the image labeled with 'prev. state'." + msg += " Please carefully examine the images depicting the "\ + "'prev. state' and 'curr. state' before making a judgment." + msg += "\n" + msg += "The assertions to evaluate are:" + return msg + + def add_bbox_features(self) -> None: + """Add the features about the bounding box to the objects.""" + for obj, mask in self.obj_mask_dict.items(): + bbox = mask_to_bbox(mask) + for name, value in bbox._asdict().items(): + self.set(obj, f"bbox_{name}", value) + + def set(self, obj: Object, feature_name: str, feature_val: Any) -> None: + """Set the value of an object feature by name.""" + idx = obj.type.feature_names.index(feature_name) + standard_feature_len = len(self.data[obj]) + if idx >= standard_feature_len: + self.bbox_features[obj][idx - standard_feature_len] = feature_val + else: + self.data[obj][idx] = feature_val + + def get(self, obj: Object, feature_name: str) -> Any: + idx = obj.type.feature_names.index(feature_name) + standard_feature_len = len(self.data[obj]) + if idx >= standard_feature_len: + return self.bbox_features[obj][idx - standard_feature_len] + return self.data[obj][idx] + + def dict_str( # type: ignore[override] + self, + indent: int = 0, + object_features: bool = True, + num_decimal_points: int = 2, + use_object_id: bool = False, + position_proprio_features: bool = False) -> str: + """Return a dictionary representation of the state.""" + state_dict = {} + for obj in self: + obj_dict = {} + for attribute, value in zip( + obj.type.feature_names, + np.concatenate([self[obj], self.bbox_features[obj]]) + if self.bbox_features else self[obj]): + if (position_proprio_features and attribute + in ["rot", "fingers"]) or (object_features + and attribute not in [ + "is_heavy", + ]): + if isinstance(value, (float, int, np.float32)): + value = round(float(value), 1) + obj_dict[attribute] = value + + if use_object_id: + obj_name = obj.id_name + else: + obj_name = obj.name + state_dict[f"{obj_name}:{obj.type.name}"] = obj_dict + + spaces = " " * indent + dict_str = spaces + "{" + n_keys = len(state_dict.keys()) + for i, (key, value) in enumerate(state_dict.items()): + value_str = ', '.join(f"'{k}': {v}" for k, v in value.items()) + if value_str == "": + content_str = f"'{key}'" + else: + content_str = f"'{key}': {{{value_str}}}" + if i == 0: + dict_str += f"{content_str},\n" + elif i == n_keys - 1: + dict_str += spaces + f" {content_str}" + else: + dict_str += spaces + f" {content_str},\n" + dict_str += "}" + return dict_str + + def __eq__(self, other: object) -> bool: + assert isinstance(other, VLMState) + if len(self.data) != len(other.data): + return False + for key, value in self.data.items(): + if key not in other.data or not np.array_equal( + value, other.data[key]): + return False + return self.simulator_state == other.simulator_state + + def label_all_objects(self) -> None: + from predicators.image_patch_wrapper import \ + ImagePatch # noqa: PLC0415 # pylint: disable=import-outside-toplevel + state_ip = ImagePatch(self) + state_ip.label_all_objects(self.obj_mask_dict) + self.labeled_image = state_ip.cropped_image_in_PIL + + def copy(self) -> VLMState: + pybullet_state_copy = super().copy() + state_image_copy = copy.copy(self.state_image) + obj_mask_copy = copy.deepcopy(self.obj_mask_dict) + labeled_image_copy = copy.copy(self.labeled_image) + option_history_copy = copy.copy(self.option_history) + bbox_features_copy = copy.deepcopy(self.bbox_features) + prev_state_copy = self.prev_state.copy() if self.prev_state else None + # Use kwargs for the VLM-specific fields so positional shifts in + # the base `State` dataclass (e.g. the `latent` block added for + # the recurrent partial-observability approach) don't reorder + # this call. + return VLMState( + data=pybullet_state_copy.data, + simulator_state=pybullet_state_copy.simulator_state, + latent=pybullet_state_copy.latent, + privileged=pybullet_state_copy.privileged, + state_image=state_image_copy, + obj_mask_dict=obj_mask_copy, + labeled_image=labeled_image_copy, + option_history=option_history_copy, + bbox_features=bbox_features_copy, + prev_state=prev_state_copy, + ) + + def get_obj_mask(self, obj: Object) -> Mask: + """Return the mask for the object.""" + return self.obj_mask_dict[obj] + + def get_obj_bbox(self, obj: Object) -> BoundingBox: + """Get the bounding box of the object in the state image.""" + mask = self.get_obj_mask(obj) + return mask_to_bbox(mask) + + def crop_to_objects( # pylint: disable=missing-function-docstring + self, + objects: Sequence[Object], + left_margin: int = 30, + lower_margin: int = 30, + right_margin: int = 30, + top_margin: int = 30) -> Tuple[BoundingBox, Sequence[Object]]: + bboxes = [self.get_obj_bbox(obj) for obj in objects] + bbox = smallest_bbox_from_bboxes(bboxes) + return (BoundingBox( + max(bbox.left - left_margin, 0), max(bbox.lower - lower_margin, 0), + min(bbox.right + right_margin, self.state_image.width), + min(bbox.upper + top_margin, self.state_image.height)), objects) + + +@dataclass +class VLMQuery: + """A class to represent a query to a VLM.""" + query_str: str + attention_box: BoundingBox + attn_objects: Optional[List[Object]] = None + ground_atom: Optional[GroundAtom] = None + + +def mask_to_bbox(mask: Mask) -> BoundingBox: + """Return the bounding box of the mask.""" + y_indices, x_indices = np.where(mask) + height = mask.shape[0] + + # Get the bounding box + try: + left = x_indices.min() + right = x_indices.max() + lower = height - (y_indices.max() + 1) + upper = height - (y_indices.min() + 1) + except ValueError: + left, lower, right, upper = 0, 0, 0, 0 + # If the mask is empty, return a bounding box with all zeros + + return BoundingBox(left, lower, right, upper) + + +def smallest_bbox_from_bboxes(bboxes: Sequence[BoundingBox]) -> BoundingBox: + """Return the smallest bounding box that contains all the given + bounding.""" + + # Initialize the bounding box coordinates + left, lower, right, upper = np.inf, np.inf, -np.inf, -np.inf + # Iterate over all masks + for bbox in bboxes: + # Update the bounding box + left = min(left, bbox.left) + lower = min(lower, bbox.lower) + right = max(right, bbox.right) + upper = max(upper, bbox.upper) + return BoundingBox(left, lower, right, upper) + + +class StateWithCache(State): + """A state with a cache stored in the simulator state that is ignored for + state equality checks. + + The cache is deliberately not copied. + """ + + @property + def cache(self) -> Dict[str, Dict]: + """Expose the cache in the simulator_state.""" + return cast(Dict[str, Dict], self.simulator_state) + + def allclose(self, other: State) -> bool: + # Ignores the simulator state. + return State(self.data).allclose(State(other.data)) + + def copy(self) -> State: + copied = super().copy() + # The cache (simulator_state) is deliberately shared, not copied; + # forward the hidden latent/privileged blocks so they survive. + return StateWithCache(copied.data, + self.cache, + latent=copied.latent, + privileged=copied.privileged) + + +class LoggingMonitor(abc.ABC): + """Observes states and actions during environment interaction.""" + + @abc.abstractmethod + def reset(self, train_or_test: str, task_idx: int) -> None: + """Called when the monitor starts a new episode.""" + raise NotImplementedError("Override me!") + + @abc.abstractmethod + def observe(self, obs: Observation, action: Optional[Action]) -> None: + """Record an observation and the action that is about to be taken. + + On the last timestep of a trajectory, no action is taken, so + action is None. + """ + raise NotImplementedError("Override me!") + + +def run_policy( + policy: Callable[[State], Action], + env: BaseEnv, + train_or_test: str, + task_idx: int, + termination_function: Callable[[State], bool], + max_num_steps: int, + do_env_reset: bool = True, + exceptions_to_break_on: Optional[Set[TypingType[Exception]]] = None, + monitor: Optional[LoggingMonitor] = None +) -> Tuple[LowLevelTrajectory, Metrics]: + """Execute a policy starting from the initial state of a train or test task + in the environment. The task's goal is not used. + + Note that the environment internal state is updated. + + Terminates when any of these conditions hold: + (1) the termination_function returns True + (2) max_num_steps is reached + (3) policy() or step() raise an exception of type in exceptions_to_break_on + + Note that in the case where the exception is raised in step, we exclude the + last action from the returned trajectory to maintain the invariant that + the trajectory states are of length one greater than the actions. + + NOTE: this may be deprecated in the future in favor of run_episode defined + in cogman.py. Ideally, we should consolidate both run_policy and + run_policy_with_simulator below into run_episode. + """ + if do_env_reset: + env.reset(train_or_test, task_idx) + if monitor is not None: + monitor.reset(train_or_test, task_idx) + obs = env.get_observation() + assert isinstance(obs, State) + state = obs + states = [state] + actions: List[Action] = [] + metrics: Metrics = defaultdict(float) + metrics["policy_call_time"] = 0.0 + exception_raised_in_step = False + if not termination_function(state): + for _ in range(max_num_steps): + monitor_observed = False + exception_raised_in_step = False + try: + start_time = time.perf_counter() + act = policy(state) + metrics["policy_call_time"] += time.perf_counter() - start_time + except Exception as e: # pylint: disable=broad-except + if not CFG.video_not_break_on_exception: + if exceptions_to_break_on is not None and \ + type(e) in exceptions_to_break_on: + if monitor_observed: + exception_raised_in_step = True + break + raise e + if monitor is not None and not monitor_observed: + monitor.observe(state, None) + monitor_observed = True + else: + if monitor is not None and not monitor_observed: + monitor.observe(state, act) + monitor_observed = True + + try: + # Note: it's important to call monitor.observe() before + # env.step(), because the monitor may use the environment's + # internal state. + state = env.step(act) + actions.append(act) + states.append(state) + except Exception as e: + if exceptions_to_break_on is not None and \ + type(e) in exceptions_to_break_on: + if monitor_observed: + exception_raised_in_step = True + break + raise e + if termination_function(state): + break + if monitor is not None and not exception_raised_in_step: + monitor.observe(state, None) + traj = LowLevelTrajectory(states, actions) + return traj, metrics + + +def run_policy_with_simulator( + policy: Callable[[State], Action], + simulator: Callable[[State, Action], State], + init_state: State, + termination_function: Callable[[State], bool], + max_num_steps: int, + exceptions_to_break_on: Optional[Set[TypingType[Exception]]] = None, + monitor: Optional[LoggingMonitor] = None) -> LowLevelTrajectory: + """Execute a policy from a given initial state, using a simulator. + + *** This function should not be used with any core code, because we want + to avoid the assumption of a simulator when possible. *** + + This is similar to run_policy, with three major differences: + (1) The initial state `init_state` can be any state, not just the initial + state of a train or test task. (2) A simulator (function that takes state + as input) is assumed. (3) Metrics are not returned. + + Note that the environment internal state is NOT updated. + + Terminates when any of these conditions hold: + (1) the termination_function returns True + (2) max_num_steps is reached + (3) policy() or step() raise an exception of type in exceptions_to_break_on + + Note that in the case where the exception is raised in step, we exclude the + last action from the returned trajectory to maintain the invariant that + the trajectory states are of length one greater than the actions. + """ + state = init_state + states = [state] + actions: List[Action] = [] + exception_raised_in_step = False + if not termination_function(state): + for i in range(max_num_steps): + # logging.debug(f"State: {state.pretty_str()}") + monitor_observed = False + exception_raised_in_step = False + try: + act = policy(state) + # logging.debug(f"Action: {act}") + if monitor is not None: + monitor.observe(state, act) + monitor_observed = True + state = simulator(state, act) + actions.append(act) + states.append(state) + except Exception as e: + logging.debug(f"Exception during running policy: {e}") + if exceptions_to_break_on is not None and \ + type(e) in exceptions_to_break_on: + if monitor_observed: + exception_raised_in_step = True + break + if monitor is not None and not monitor_observed: + monitor.observe(state, None) + raise e + if termination_function(state): + break + logging.debug(f"Ran {i + 1} steps") + if monitor is not None and not exception_raised_in_step: + monitor.observe(state, None) + traj = LowLevelTrajectory(states, actions) + return traj + + +class ExceptionWithInfo(Exception): + """An exception with an optional info dictionary that is initially + empty.""" + + def __init__(self, message: str, info: Optional[Dict] = None) -> None: + super().__init__(message) + if info is None: + info = {} + assert isinstance(info, dict) + self.info = info + + +class OptionExecutionFailure(ExceptionWithInfo): + """An exception raised by an option policy in the course of execution.""" + + +class OptionTimeoutFailure(OptionExecutionFailure): + """A special kind of option execution failure due to an exceeded budget.""" + + +class RequestActPolicyFailure(ExceptionWithInfo): + """An exception raised by an acting policy in a request when it fails to + produce an action, which terminates the interaction.""" + + +class HumanDemonstrationFailure(ExceptionWithInfo): + """An exception raised when CFG.demonstrator == "human" and the human gives + a bad input.""" + + +class EnvironmentFailure(ExceptionWithInfo): + """Exception raised when any type of failure occurs in an environment. + + The info dictionary must contain a key "offending_objects", which + maps to a set of objects responsible for the failure. + """ + + def __repr__(self) -> str: + return f"{super().__repr__()}: {self.info}" + + def __str__(self) -> str: + return repr(self) + + +def check_wait_target_atoms( + option: _Option, + state: State, + abstract_function: Callable[[State], Set[GroundAtom]], +) -> Optional[bool]: + """Check if a Wait option's target atoms are satisfied. + + Returns True if targets are met (Wait should terminate), False if + not yet met, or None if no targets were specified (caller should + fall back to any-atom-change behaviour). + """ + pos = option.memory.get("wait_target_atoms", set()) + neg = option.memory.get("wait_target_neg_atoms", set()) + if not pos and not neg: + return None + cur_atoms = abstract_function(state) + return pos.issubset(cur_atoms) and neg.isdisjoint(cur_atoms) + + +def parse_wait_target_annotations( + line: str, + predicates: Collection[Predicate], + objects: Collection[Object], +) -> Tuple[Set[GroundAtom], Set[GroundAtom]]: + """Parse ``-> {Pred(...), NOT Pred(...)}`` from a plan line. + + Returns ``(positive_atoms, negative_atoms)`` where positive atoms + must become TRUE and negative atoms must become FALSE for the Wait + to terminate. + """ + pred_map = {p.name: p for p in predicates} + obj_map = {o.name: o for o in objects} + + sg_match = re.search(r'->\s*\{([^}]*)\}', line) + if not sg_match: + return set(), set() + + pos_atoms: Set[GroundAtom] = set() + neg_atoms: Set[GroundAtom] = set() + atom_re = re.compile(r'(NOT\s+)?(\w+)\(([^)]*)\)') + + for m in atom_re.finditer(sg_match.group(1)): + is_neg = m.group(1) is not None + pred_name = m.group(2) + obj_names = [n.strip().split(':')[0] for n in m.group(3).split(',')] + + if pred_name not in pred_map: + logging.warning("Unknown predicate in Wait target: %s", pred_name) + continue + pred = pred_map[pred_name] + try: + objs = [obj_map[n] for n in obj_names] + except KeyError as e: + logging.warning("Unknown object in Wait target: %s", e) + continue + if len(objs) != len(pred.types): + logging.warning("Arity mismatch for %s: expected %d, got %d", + pred_name, len(pred.types), len(objs)) + continue + atom = GroundAtom(pred, objs) + if is_neg: + neg_atoms.add(atom) + else: + pos_atoms.add(atom) + + return pos_atoms, neg_atoms + + +def inject_wait_targets_for_option( + option: _Option, + step_idx: int, + atoms_sequence: Sequence[Set[GroundAtom]], +) -> None: + """Inject Wait target atoms into a single option from atoms_sequence. + + Computes the expected atom delta from ``atoms_sequence[step_idx]`` + to ``atoms_sequence[step_idx + 1]`` and stores it in the option's + memory so that execution terminates on specific atoms rather than + any noisy change. No-op for non-Wait options or out-of-bounds + indices. + """ + if option.name != "Wait": + return + if step_idx + 1 >= len(atoms_sequence): + return + before = atoms_sequence[step_idx] + after = atoms_sequence[step_idx + 1] + target_pos = after - before + target_neg = before - after + if target_pos: + option.memory["wait_target_atoms"] = target_pos + if target_neg: + option.memory["wait_target_neg_atoms"] = target_neg + + +def strip_wait_annotations(text: str) -> str: + """Remove ``-> {...}`` annotations from plan text lines.""" + return re.sub(r'\s*->\s*\{[^}]*\}', '', text) + + +def _format_wait_target_debug( + state: State, target_atoms: Set[GroundAtom], + abstract_function: Callable[[State], Set[GroundAtom]]) -> str: + """Format state details for debugging why Wait has not terminated.""" + cur_atoms = abstract_function(state) + missing_targets = target_atoms - cur_atoms + target_objects = sorted( + { + ent + for atom in target_atoms + for ent in atom.entities if isinstance(ent, Object) + }, + key=lambda o: o.name) + object_details = [] + for obj in target_objects: + feature_values = [] + for feature_name in obj.type.feature_names: + value = state.get(obj, feature_name) + if isinstance(value, float): + value_str = f"{value:.4f}" + else: + value_str = str(value) + feature_values.append(f"{feature_name}={value_str}") + object_details.append(f"{obj}: " + ", ".join(feature_values)) + details = [ + f"Targets: {sorted(target_atoms)}", + f"Missing: {sorted(missing_targets)}", + f"cur_atoms: {sorted(cur_atoms)}", + ] + if object_details: + details.append(f"target_objects: {'; '.join(object_details)}") + return "; ".join(details) + + +def option_policy_to_policy( + option_policy: Callable[[State], _Option], + max_option_steps: Optional[int] = None, + raise_error_on_repeated_state: bool = False, + abstract_function: Optional[Callable[[State], Set[GroundAtom]]] = None +) -> Callable[[State], Action]: + """Create a policy that executes a policy over options.""" + cur_option = DummyOption + num_cur_option_steps = 0 + last_state: Optional[State] = None + + def _policy(state: State) -> Action: + nonlocal cur_option, num_cur_option_steps, last_state + + if cur_option is DummyOption: + last_option: Optional[_Option] = None + else: + last_option = cur_option + + if max_option_steps is not None and \ + num_cur_option_steps >= max_option_steps: + raise OptionTimeoutFailure( + "Exceeded max option steps.", + info={"last_failed_option": last_option}) + + if last_state is not None and \ + raise_error_on_repeated_state and state.allclose(last_state): + raise OptionTimeoutFailure( + "Encountered repeated state.", + info={"last_failed_option": last_option}) + # logging for debugging + # if last_state is not None: + # cur_atoms = abstract_function(state) + # prev_atoms = abstract_function(last_state) + # logging.debug(f"Prev atoms: {sorted(prev_atoms)}") + # logging.info(f"Add atoms: {sorted(cur_atoms-prev_atoms)} " + # f"Del atoms: {sorted(prev_atoms-cur_atoms)}") + + # whether the noop option should terminate + wait_terminate = False + if CFG.wait_option_terminate_on_atom_change \ + and cur_option.name == "Wait": + assert abstract_function is not None + assert last_state is not None + target_atoms = cur_option.memory.get("wait_target_atoms") + result = check_wait_target_atoms(cur_option, state, + abstract_function) + if result is True: + cur_atoms = abstract_function(state) + logging.debug("Wait terminating: target atoms satisfied. " + f"Targets: {target_atoms}, " + f"cur_atoms: {sorted(cur_atoms)}, " + f"num_option_steps={num_cur_option_steps}") + wait_terminate = True + elif result is False: + assert target_atoms is not None + if num_cur_option_steps <= 1 or num_cur_option_steps % 25 == 0: + wait_debug = _format_wait_target_debug( + state, target_atoms, abstract_function) + logging.debug( + "Wait continuing: target atoms not yet satisfied. " + "%s, num_option_steps=%d", wait_debug, + num_cur_option_steps) + elif result is None: + # No targets specified: fall back to any-atom-change + cur_atoms = abstract_function(state) + prev_atoms = abstract_function(last_state) + if cur_atoms != prev_atoms: + logging.debug(f"Wait terminating due to atom change: " + f"Add: {sorted(cur_atoms-prev_atoms)} " + f"Del: {sorted(prev_atoms-cur_atoms)}") + wait_terminate = True + + last_state = state + + option_terminal = cur_option is not DummyOption and \ + cur_option.terminal(state) + if wait_terminate or cur_option is DummyOption or option_terminal: + if cur_option is not DummyOption: + if wait_terminate: + reason = "atom change during Wait" + elif option_terminal: + reason = "option self-terminated" + else: + reason = "unknown" + logging.info(f"[{cur_option.name}] Terminated: {reason} " + f"(after {num_cur_option_steps} steps)\n") + try: + cur_option = option_policy(state) + except OptionExecutionFailure as e: + e.info["last_failed_option"] = last_option + raise e + if not cur_option.initiable(state): + raise OptionExecutionFailure( + "Unsound option policy.", + info={"last_failed_option": last_option}) + logging.debug(f"[option_policy] Started option {cur_option.name}, " + f"initiable=True") + num_cur_option_steps = 0 + + num_cur_option_steps += 1 + + return cur_option.policy(state) + + return _policy + + +def option_plan_to_policy( + plan: Sequence[_Option], + max_option_steps: Optional[int] = None, + raise_error_on_repeated_state: bool = False, + abstract_function: Optional[Callable[[State], Set[GroundAtom]]] = None +) -> Callable[[State], Action]: + """Create a policy that executes a sequence of options in order.""" + queue = list(plan) # don't modify plan, just in case + total_options = len(queue) + + def _option_policy(state: State) -> _Option: + del state # not used + if not queue: + logging.info("Option plan exhausted after %d options.", + total_options) + raise OptionExecutionFailure("Option plan exhausted!") + option = queue.pop(0) + option_num = total_options - len(queue) + next_option = None if not queue else queue[0].simple_str() + logging.info("Executing option %d/%d: %s (remaining=%d, next=%s)", + option_num, total_options, option.simple_str(), + len(queue), next_option) + return option + + return option_policy_to_policy( + _option_policy, + max_option_steps=max_option_steps, + raise_error_on_repeated_state=raise_error_on_repeated_state, + abstract_function=abstract_function) + + +def nsrt_plan_to_greedy_option_policy( + nsrt_plan: Sequence[_GroundNSRT], + goal: Set[GroundAtom], + rng: np.random.Generator, + necessary_atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None +) -> Callable[[State], _Option]: + """Greedily execute an NSRT plan, assuming downward refinability and that + any sample will work. + + If an option is not initiable or if the plan runs out, an + OptionExecutionFailure is raised. + """ + cur_nsrt: Optional[_GroundNSRT] = None + nsrt_queue = list(nsrt_plan) + if necessary_atoms_seq is None: + empty_atoms: Set[GroundAtom] = set() + necessary_atoms_seq = [empty_atoms for _ in range(len(nsrt_plan) + 1)] + assert len(necessary_atoms_seq) == len(nsrt_plan) + 1 + necessary_atoms_queue = list(necessary_atoms_seq) + + def _option_policy(state: State) -> _Option: + nonlocal cur_nsrt + if not nsrt_queue: + raise OptionExecutionFailure("NSRT plan exhausted.") + expected_atoms = necessary_atoms_queue.pop(0) + if not all(a.holds(state) for a in expected_atoms): + raise OptionExecutionFailure( + "Executing the NSRT failed to achieve the necessary atoms.") + cur_nsrt = nsrt_queue.pop(0) + cur_option = cur_nsrt.sample_option(state, goal, rng) + logging.debug(f"Using option {cur_option.name}{cur_option.objects}" + f"{cur_option.params} from NSRT plan.") + return cur_option + + return _option_policy + + +def nsrt_plan_to_greedy_policy( + nsrt_plan: Sequence[_GroundNSRT], + goal: Set[GroundAtom], + rng: np.random.Generator, + necessary_atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, + abstract_function: Optional[Callable[[State], Set[GroundAtom]]] = None +) -> Callable[[State], Action]: + """Greedily execute an NSRT plan, assuming downward refinability and that + any sample will work. + + If an option is not initiable or if the plan runs out, an + OptionExecutionFailure is raised. + """ + option_policy = nsrt_plan_to_greedy_option_policy( + nsrt_plan, goal, rng, necessary_atoms_seq=necessary_atoms_seq) + return option_policy_to_policy(option_policy, + abstract_function=abstract_function) + + +def process_plan_to_greedy_option_policy( + process_plan: Sequence[_GroundEndogenousProcess], + goal: Set[GroundAtom], + rng: np.random.Generator, + necessary_atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, + atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, +) -> Callable[[State], _Option]: + """Greedily execute a process plan, assuming downward refinability and that + any sample will work. + + If an option is not initiable or if the plan runs out, an + OptionExecutionFailure is raised. + """ + cur_process: Optional[_GroundEndogenousProcess] = None + process_queue = list(process_plan) + if necessary_atoms_seq is None: + empty_atoms: Set[GroundAtom] = set() + necessary_atoms_seq = [ + empty_atoms for _ in range(len(process_plan) + 1) + ] + assert len(necessary_atoms_seq) == len(process_plan) + 1 + necessary_atoms_queue = list(necessary_atoms_seq) + step_idx = 0 + + def _option_policy(state: State) -> _Option: + nonlocal cur_process, step_idx + if not process_queue: + raise OptionExecutionFailure("Process plan exhausted.") + expected_atoms = necessary_atoms_queue.pop(0) + if not all(a.holds(state) for a in expected_atoms): + raise OptionExecutionFailure( + "Executing the process failed to achieve the necessary atoms.") + cur_process = process_queue.pop(0) + cur_option = cur_process.sample_option(state, goal, rng) + if atoms_seq is not None: + inject_wait_targets_for_option(cur_option, step_idx, atoms_seq) + step_idx += 1 + logging.debug(f"Using option {cur_option.name}{cur_option.objects}" + f"{cur_option.params} from process plan.") + return cur_option + + return _option_policy + + +def process_plan_to_greedy_policy( + process_plan: Sequence[_GroundEndogenousProcess], + goal: Set[GroundAtom], + rng: np.random.Generator, + necessary_atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, + abstract_function: Optional[Callable[[State], Set[GroundAtom]]] = None, + atoms_seq: Optional[Sequence[Set[GroundAtom]]] = None, +) -> Callable[[State], Action]: + """Convert a process plan to a greedy policy.""" + option_policy = process_plan_to_greedy_option_policy( + process_plan, + goal, + rng, + necessary_atoms_seq=necessary_atoms_seq, + atoms_seq=atoms_seq) + return option_policy_to_policy(option_policy, + abstract_function=abstract_function) + + +def sample_applicable_option(param_options: List[ParameterizedOption], + state: State, + rng: np.random.Generator) -> Optional[_Option]: + """Sample an applicable option.""" + for _ in range(CFG.random_options_max_tries): + param_opt = param_options[rng.choice(len(param_options))] + objs = get_random_object_combination(list(state), param_opt.types, rng) + if objs is None: + continue + params = param_opt.params_space.sample() + opt = param_opt.ground(objs, params) + if opt.initiable(state): + return opt + return None + + +def create_random_option_policy( + options: Collection[ParameterizedOption], rng: np.random.Generator, + fallback_policy: Callable[[State], + Action]) -> Callable[[State], Action]: + """Create a policy that executes random initiable options. + + If no applicable option can be found, query the fallback policy. + """ + sorted_options = sorted(options, key=lambda o: o.name) + cur_option = DummyOption + + def _policy(state: State) -> Action: + nonlocal cur_option + if cur_option is DummyOption or cur_option.terminal(state): + cur_option = DummyOption + sample = sample_applicable_option(sorted_options, state, rng) + if sample is not None: + cur_option = sample + else: + return fallback_policy(state) + act = cur_option.policy(state) + return act + + return _policy + + +def sample_applicable_ground_nsrt( + state: State, ground_nsrts: Sequence[_GroundNSRT], + predicates: Set[Predicate], + rng: np.random.Generator) -> Optional[_GroundNSRT]: + """Choose uniformly among the ground NSRTs that are applicable in the + state.""" + atoms = abstract(state, predicates) + applicable_nsrts = sorted(get_applicable_operators(ground_nsrts, atoms)) + if len(applicable_nsrts) == 0: + return None + idx = rng.choice(len(applicable_nsrts)) + return applicable_nsrts[idx] # type: ignore[return-value] + + +def action_arrs_to_policy( + action_arrs: Sequence[Array]) -> Callable[[State], Action]: + """Create a policy that executes action arrays in sequence.""" + + queue = list(action_arrs) # don't modify original, just in case + + def _policy(s: State) -> Action: + del s # unused + return Action(queue.pop(0)) + + return _policy + + +def _get_entity_combinations( + entities: Collection[ObjectOrVariable], + types: Sequence[Type]) -> Iterator[List[ObjectOrVariable]]: + """Get all combinations of entities satisfying the given types sequence.""" + sorted_entities = sorted(entities) + choices = [] + for vt in types: + this_choices = [] + for ent in sorted_entities: + if ent.is_instance(vt): + this_choices.append(ent) + choices.append(this_choices) + for choice in itertools.product(*choices): + yield list(choice) + + +def get_object_combinations(objects: Collection[Object], + types: Sequence[Type]) -> Iterator[List[Object]]: + """Get all combinations of objects satisfying the given types sequence.""" + return _get_entity_combinations(objects, types) + + +def get_variable_combinations( + variables: Collection[Variable], + types: Sequence[Type]) -> Iterator[List[Variable]]: + """Get all combinations of variables satisfying the given types + sequence.""" + return _get_entity_combinations(variables, types) + + +def get_all_ground_atoms_for_predicate( + predicate: Predicate, objects: Collection[Object]) -> Set[GroundAtom]: + """Get all groundings of the predicate given objects. + + Note: we don't want lru_cache() on this function because we might want + to call it with stripped predicates, and we wouldn't want it to return + cached values. + """ + ground_atoms = set() + for args in get_object_combinations(objects, predicate.types): + ground_atom = GroundAtom(predicate, args) + ground_atoms.add(ground_atom) + return ground_atoms + + +def get_all_lifted_atoms_for_predicate( + predicate: Predicate, + variables: FrozenSet[Variable]) -> Set[LiftedAtom]: + """Get all groundings of the predicate given variables. + + Note: we don't want lru_cache() on this function because we might want + to call it with stripped predicates, and we wouldn't want it to return + cached values. + """ + lifted_atoms = set() + for args in get_variable_combinations(variables, predicate.types): + lifted_atom = LiftedAtom(predicate, args) + lifted_atoms.add(lifted_atom) + return lifted_atoms + + +def get_random_object_combination( + objects: Collection[Object], types: Sequence[Type], + rng: np.random.Generator) -> Optional[List[Object]]: + """Get a random list of objects from the given collection that satisfy the + given sequence of types. + + Duplicates are always allowed. If a particular type has no object, + return None. + """ + types_to_objs = defaultdict(list) + for obj in objects: + types_to_objs[obj.type].append(obj) + result = [] + for t in types: + t_objs = types_to_objs[t] + if not t_objs: + return None + result.append(t_objs[rng.choice(len(t_objs))]) + return result + + +def find_substitution( + super_atoms: Collection[LiftedOrGroundAtom], + sub_atoms: Collection[LiftedOrGroundAtom], + allow_redundant: bool = False, +) -> Tuple[bool, EntToEntSub]: + """Find a substitution from the entities in super_atoms to the entities in + sub_atoms s.t. sub_atoms is a subset of super_atoms. + + If allow_redundant is True, then multiple entities in sub_atoms can + refer to the same single entity in super_atoms. + + If no substitution exists, return (False, {}). + """ + super_entities_by_type: Dict[Type, List[_TypedEntity]] = defaultdict(list) + super_pred_to_tuples = defaultdict(set) + for atom in super_atoms: + for obj in atom.entities: + if obj not in super_entities_by_type[obj.type]: + super_entities_by_type[obj.type].append(obj) + super_pred_to_tuples[atom.predicate].add(tuple(atom.entities)) + sub_variables = sorted({e for atom in sub_atoms for e in atom.entities}) + return _find_substitution_helper(sub_atoms, super_entities_by_type, + sub_variables, super_pred_to_tuples, {}, + allow_redundant) + + +def _find_substitution_helper( + sub_atoms: Collection[LiftedOrGroundAtom], + super_entities_by_type: Dict[Type, List[_TypedEntity]], + remaining_sub_variables: List[_TypedEntity], + super_pred_to_tuples: Dict[Predicate, + Set[Tuple[_TypedEntity, + ...]]], partial_sub: EntToEntSub, + allow_redundant: bool) -> Tuple[bool, EntToEntSub]: + """Helper for find_substitution.""" + # Base case: check if all assigned + if not remaining_sub_variables: + return True, partial_sub + # Find next variable to assign + remaining_sub_variables = remaining_sub_variables.copy() + next_sub_var = remaining_sub_variables.pop(0) + # Consider possible assignments + for super_obj in super_entities_by_type[next_sub_var.type]: + if not allow_redundant and super_obj in partial_sub.values(): + continue + new_sub = partial_sub.copy() + new_sub[next_sub_var] = super_obj + # Check if consistent + if not _substitution_consistent(new_sub, super_pred_to_tuples, + sub_atoms): + continue + # Backtracking search + solved, final_sub = _find_substitution_helper(sub_atoms, + super_entities_by_type, + remaining_sub_variables, + super_pred_to_tuples, + new_sub, allow_redundant) + if solved: + return solved, final_sub + # Failure + return False, {} + + +def _substitution_consistent( + partial_sub: EntToEntSub, + super_pred_to_tuples: Dict[Predicate, Set[Tuple[_TypedEntity, ...]]], + sub_atoms: Collection[LiftedOrGroundAtom]) -> bool: + """Helper for _find_substitution_helper.""" + for sub_atom in sub_atoms: + if not set(sub_atom.entities).issubset(partial_sub.keys()): + continue + substituted_vars = tuple(partial_sub[e] for e in sub_atom.entities) + if substituted_vars not in super_pred_to_tuples[sub_atom.predicate]: + return False + return True + + +def create_new_variables( + types: Sequence[Type], + existing_vars: Optional[Collection[Variable]] = None, + var_prefix: str = "?x", +) -> List[Variable]: + """Create new variables of the given types, avoiding name collisions with + existing variables. + + By convention, all new variables are of the form + . + """ + pre_len = len(var_prefix) + existing_var_nums = set() + if existing_vars: + for v in existing_vars: + if v.name.startswith(var_prefix) and v.name[pre_len:].isdigit(): + existing_var_nums.add(int(v.name[pre_len:])) + if existing_var_nums: + counter = itertools.count(max(existing_var_nums) + 1) + else: + counter = itertools.count(0) + new_vars = [] + for t in types: + new_var_name = f"{var_prefix}{next(counter)}" + new_var = Variable(new_var_name, t) + new_vars.append(new_var) + return new_vars + + +def param_option_to_nsrt(param_option: ParameterizedOption, + nsrts: Set[NSRT]) -> NSRT: + """If options and NSRTs are 1:1, then map an option to an NSRT.""" + nsrt_matches = [n for n in nsrts if n.option == param_option] + assert len(nsrt_matches) == 1 + nsrt = nsrt_matches[0] + return nsrt + + +def option_to_ground_nsrt(option: _Option, nsrts: Set[NSRT]) -> _GroundNSRT: + """If options and NSRTs are 1:1, then map an option to an NSRT.""" + nsrt = param_option_to_nsrt(option.parent, nsrts) + return nsrt.ground(option.objects) + + +_S = TypeVar("_S", bound=Hashable) # state in heuristic search +_A = TypeVar("_A") # action in heuristic search + + +@dataclass(frozen=True) +class _HeuristicSearchNode(Generic[_S, _A]): + state: _S + edge_cost: float + cumulative_cost: float + parent: Optional[_HeuristicSearchNode[_S, _A]] = None + action: Optional[_A] = None + + +def _run_heuristic_search( + initial_state: _S, + check_goal: Callable[[_S], bool], + get_successors: Callable[[_S], Iterator[Tuple[_A, _S, float]]], + get_priority: Callable[[_HeuristicSearchNode[_S, _A]], Any], + max_expansions: int = 10000000, + max_evals: int = 10000000, + timeout: int = 10000000, + lazy_expansion: bool = False) -> Tuple[List[_S], List[_A]]: + """A generic heuristic search implementation. + + Depending on get_priority, can implement A*, GBFS, or UCS. + + If no goal is found, returns the state with the best priority. + """ + queue: List[Tuple[Any, int, _HeuristicSearchNode[_S, _A]]] = [] + state_to_best_path_cost: Dict[_S, float] = \ + defaultdict(lambda: float("inf")) + + root_node: _HeuristicSearchNode[_S, _A] = _HeuristicSearchNode( + initial_state, 0, 0) + root_priority = get_priority(root_node) + best_node = root_node + best_node_priority = root_priority + tiebreak = itertools.count() + hq.heappush(queue, (root_priority, next(tiebreak), root_node)) + num_expansions = 0 + num_evals = 1 + start_time = time.perf_counter() + + while len(queue) > 0 and time.perf_counter() - start_time < timeout and \ + num_expansions < max_expansions and num_evals < max_evals: + _, _, node = hq.heappop(queue) + # If we already found a better path here, don't bother. + if state_to_best_path_cost[node.state] < node.cumulative_cost: + continue + # If the goal holds, return. + if check_goal(node.state): + return _finish_plan(node) + num_expansions += 1 + # Generate successors. + for action, child_state, cost in get_successors(node.state): + if time.perf_counter() - start_time >= timeout: + break + child_path_cost = node.cumulative_cost + cost + # If we already found a better path to this child, don't bother. + if state_to_best_path_cost[child_state] <= child_path_cost: + continue + # Add new node. + child_node = _HeuristicSearchNode(state=child_state, + edge_cost=cost, + cumulative_cost=child_path_cost, + parent=node, + action=action) + priority = get_priority(child_node) + num_evals += 1 + hq.heappush(queue, (priority, next(tiebreak), child_node)) + state_to_best_path_cost[child_state] = child_path_cost + if priority < best_node_priority: + best_node_priority = priority + best_node = child_node + # Optimization: if we've found a better child, immediately + # explore the child without expanding the rest of the children. + # Accomplish this by putting the parent node back on the queue. + if lazy_expansion: + hq.heappush(queue, (priority, next(tiebreak), node)) + break + if num_evals >= max_evals: + break + + # Did not find path to goal; return best path seen. + return _finish_plan(best_node) + + +def _finish_plan( + node: _HeuristicSearchNode[_S, _A]) -> Tuple[List[_S], List[_A]]: + """Helper for _run_heuristic_search and run_hill_climbing.""" + rev_state_sequence: List[_S] = [] + rev_action_sequence: List[_A] = [] + + while node.parent is not None: + action = cast(_A, node.action) + rev_action_sequence.append(action) + rev_state_sequence.append(node.state) + node = node.parent + rev_state_sequence.append(node.state) + + return rev_state_sequence[::-1], rev_action_sequence[::-1] + + +def run_gbfs(initial_state: _S, + check_goal: Callable[[_S], bool], + get_successors: Callable[[_S], Iterator[Tuple[_A, _S, float]]], + heuristic: Callable[[_S], float], + max_expansions: int = 10000000, + max_evals: int = 10000000, + timeout: int = 10000000, + lazy_expansion: bool = False) -> Tuple[List[_S], List[_A]]: + """Greedy best-first search.""" + get_priority = lambda n: heuristic(n.state) + return _run_heuristic_search(initial_state, check_goal, get_successors, + get_priority, max_expansions, max_evals, + timeout, lazy_expansion) + + +def run_astar(initial_state: _S, + check_goal: Callable[[_S], bool], + get_successors: Callable[[_S], Iterator[Tuple[_A, _S, float]]], + heuristic: Callable[[_S], float], + max_expansions: int = 10000000, + max_evals: int = 10000000, + timeout: int = 10000000, + lazy_expansion: bool = False) -> Tuple[List[_S], List[_A]]: + """A* search.""" + get_priority = lambda n: heuristic(n.state) + n.cumulative_cost + return _run_heuristic_search(initial_state, check_goal, get_successors, + get_priority, max_expansions, max_evals, + timeout, lazy_expansion) + + +def run_hill_climbing( + initial_state: _S, + check_goal: Callable[[_S], bool], + get_successors: Callable[[_S], Iterator[Tuple[_A, _S, float]]], + heuristic: Callable[[_S], float], + early_termination_heuristic_thresh: Optional[float] = None, + enforced_depth: int = 0, + exhaustive_lookahead: bool = False, + parallelize: bool = False, + verbose: bool = True, + timeout: float = float('inf') +) -> Tuple[List[_S], List[_A], List[float]]: + """Enforced hill climbing local search. + + For each node, this search looks for an improvement up to `enforced_depth`. + If `exhaustive_lookahead` is False (default), for each node, the best child + node is always selected, if that child is + an improvement over the node. If no children improve on the node, look + at the children's children, etc., up to enforced_depth, where enforced_depth + 0 corresponds to simple hill climbing. Terminate when no improvement can + be found. early_termination_heuristic_thresh allows for searching until + heuristic reaches a specified value. + Let b be the branching factor, d be the enforced_depth, this has time + complxity of O(b^{d+1}). + If True, it searches the entire horizon up to the + enforced depth and picks the best overall improvement. + + Lower heuristic is better. + """ + assert enforced_depth >= 0 + cur_node: _HeuristicSearchNode[_S, _A] = _HeuristicSearchNode( + initial_state, 0, 0) + last_heuristic = heuristic(cur_node.state) + heuristics = [last_heuristic] + # visited = {initial_state} # <--- deleted for exhaustive_lookahead + if verbose: + logging.info(f"\n\nStarting hill climbing at state {cur_node.state} " + f"with heuristic {last_heuristic}") + start_time = time.perf_counter() + while True: + visited = {cur_node.state} # <--- added for exhaustive_lookahead + + # Stops when heuristic reaches specified value. + if early_termination_heuristic_thresh is not None \ + and last_heuristic <= early_termination_heuristic_thresh: + break + + if check_goal(cur_node.state): + if verbose: + logging.info("\nTerminating hill climbing, achieved goal") + break + best_heuristic = float("inf") + best_child_node = None + current_depth_nodes = [cur_node] + all_best_heuristics = [] + for depth in range(0, enforced_depth + 1): + if verbose: + logging.info(f"Searching for an improvement at depth {depth}") + # This is a list to ensure determinism. Note that duplicates are + # filtered out in the `child_state in visited` check. + successors_at_depth = [] + for parent in current_depth_nodes: + for action, child_state, cost in get_successors(parent.state): + # Raise error if timeout gets hit. + if time.perf_counter() - start_time > timeout: + raise TimeoutError() + if child_state in visited: + continue + visited.add(child_state) + child_path_cost = parent.cumulative_cost + cost + child_node = _HeuristicSearchNode( + state=child_state, + edge_cost=cost, + cumulative_cost=child_path_cost, + parent=parent, + action=action) + successors_at_depth.append(child_node) + if parallelize: + continue # heuristic computation is parallelized later + child_heuristic = heuristic(child_node.state) + if child_heuristic < best_heuristic: + best_heuristic = child_heuristic + best_child_node = child_node + if parallelize: + # Parallelize the expensive part (heuristic computation). + # pathos is lazy-imported because it isn't available under + # Pyodide; the browser POC never sets parallelize=True. + import pathos.multiprocessing as mp # noqa: PLC0415 # pylint: disable=import-outside-toplevel + num_cpus = mp.cpu_count() + fn = lambda n: (heuristic(n.state), n) + with mp.Pool(processes=num_cpus) as p: + for child_heuristic, child_node in p.map( + fn, successors_at_depth): + if child_heuristic < best_heuristic: + best_heuristic = child_heuristic + best_child_node = child_node + all_best_heuristics.append(best_heuristic) + + if not exhaustive_lookahead and last_heuristic > best_heuristic: + # Some improvement found. + if verbose: + logging.info(f"Found an improvement at depth {depth}") + break + # Continue on to the next depth. + current_depth_nodes = successors_at_depth + if not current_depth_nodes: + if verbose: + logging.info( + f"No more successors to explore at depth {depth}.") + break # No need to search deeper if there are no more nodes. + + if verbose: + if exhaustive_lookahead: + logging.info(f"Finished depth {depth}. " + f"Best heuristic so far: {best_heuristic}") + elif last_heuristic <= best_heuristic: + logging.info(f"No improvement found at depth {depth}") + + if best_child_node is None: + if verbose: + logging.info("\nTerminating hill climbing, no more successors") + break + if last_heuristic <= best_heuristic: + if verbose: + logging.info( + "\nTerminating hill climbing, could not improve score") + break + heuristics.extend(all_best_heuristics) + cur_node = best_child_node + last_heuristic = best_heuristic + if verbose: + logging.info(f"\nHill climbing reached new state {cur_node.state} " + f"with heuristic {last_heuristic}") + + states, actions = _finish_plan(cur_node) + # The number of heuristics might not match the plan length perfectly now, + # so we should regenerate them from the final plan. + final_heuristics = [heuristic(s) for s in states] + assert len(states) == len(final_heuristics) + return states, actions, final_heuristics + + +def run_policy_guided_astar( + initial_state: _S, + check_goal: Callable[[_S], bool], + get_valid_actions: Callable[[_S], Iterator[Tuple[_A, float]]], + get_next_state: Callable[[_S, _A], _S], + heuristic: Callable[[_S], float], + policy: Callable[[_S], Optional[_A]], + num_rollout_steps: int, + rollout_step_cost: float, + max_expansions: int = 10000000, + max_evals: int = 10000000, + timeout: int = 10000000, + lazy_expansion: bool = False) -> Tuple[List[_S], List[_A]]: + """Perform A* search, but at each node, roll out a given policy for a given + number of timesteps, creating new successors at each step. + + Stop the rollout prematurely if the policy returns None. + + Note that unlike the other search functions, which take get_successors as + input, this function takes get_valid_actions and get_next_state as two + separate inputs. This is necessary because we need to anticipate the next + state conditioned on the action output by the policy. + + The get_valid_actions generates (action, cost) tuples. For policy-generated + transitions, the costs are ignored, and rollout_step_cost is used instead. + """ + + # Create a new successor function that rolls out the policy first. + # A successor here means: from this state, if you take this sequence of + # actions in order, you'll end up at this final state. + def get_successors(state: _S) -> Iterator[Tuple[List[_A], _S, float]]: + # Get policy-based successors. + policy_state = state + policy_action_seq = [] + policy_cost = 0.0 + for _ in range(num_rollout_steps): + action = policy(policy_state) + valid_actions = {a for a, _ in get_valid_actions(policy_state)} + if action is None or action not in valid_actions: + break + policy_state = get_next_state(policy_state, action) + policy_action_seq.append(action) + policy_cost += rollout_step_cost + yield (list(policy_action_seq), policy_state, policy_cost) + + # Get primitive successors. + for action, cost in get_valid_actions(state): + next_state = get_next_state(state, action) + yield ([action], next_state, cost) + + _, action_subseqs = run_astar(initial_state=initial_state, + check_goal=check_goal, + get_successors=get_successors, + heuristic=heuristic, + max_expansions=max_expansions, + max_evals=max_evals, + timeout=timeout, + lazy_expansion=lazy_expansion) + + # The states are "jumpy", so we need to reconstruct the dense state + # sequence from the action subsequences. We also need to construct a + # flat action sequence. + state = initial_state + state_seq = [state] + action_seq = [] + for action_subseq in action_subseqs: + for action in action_subseq: + action_seq.append(action) + state = get_next_state(state, action) + state_seq.append(state) + + return state_seq, action_seq + + +_RRTState = TypeVar("_RRTState") + + +class RRT(Generic[_RRTState]): + """Rapidly-exploring random tree.""" + + def __init__(self, sample_fn: Callable[[_RRTState], _RRTState], + extend_fn: Callable[[_RRTState, _RRTState], + Iterator[_RRTState]], + collision_fn: Callable[[_RRTState], bool], + distance_fn: Callable[[_RRTState, _RRTState], + float], rng: np.random.Generator, + num_attempts: int, num_iters: int, smooth_amt: int): + self._sample_fn = sample_fn + self._extend_fn = extend_fn + self._collision_fn = collision_fn + self._distance_fn = distance_fn + self._rng = rng + self._num_attempts = num_attempts + self._num_iters = num_iters + self._smooth_amt = smooth_amt + + def query(self, + pt1: _RRTState, + pt2: _RRTState, + sample_goal_eps: float = 0.0) -> Optional[List[_RRTState]]: + """Query the RRT, to get a collision-free path from pt1 to pt2. + + If none is found, returns None. + """ + if self._collision_fn(pt1) or self._collision_fn(pt2): + return None + direct_path = self._try_direct_path(pt1, pt2) + if direct_path is not None: + return direct_path + for _ in range(self._num_attempts): + path = self._rrt_connect(pt1, + goal_sampler=lambda: pt2, + sample_goal_eps=sample_goal_eps) + if path is not None: + return self._smooth_path(path) + return None + + def query_to_goal_fn( + self, + start: _RRTState, + goal_sampler: Callable[[], _RRTState], + goal_fn: Callable[[_RRTState], bool], + sample_goal_eps: float = 0.0) -> Optional[List[_RRTState]]: + """Query the RRT, to get a collision-free path from start to a point + such that goal_fn(point) is True. Uses goal_sampler to sample a target + for a direct path or with probability sample_goal_eps. + + If none is found, returns None. + """ + if self._collision_fn(start): + return None + direct_path = self._try_direct_path(start, goal_sampler()) + if direct_path is not None: + return direct_path + for _ in range(self._num_attempts): + path = self._rrt_connect(start, + goal_sampler, + goal_fn, + sample_goal_eps=sample_goal_eps) + if path is not None: + return self._smooth_path(path) + return None + + def _try_direct_path(self, pt1: _RRTState, + pt2: _RRTState) -> Optional[List[_RRTState]]: + path = [pt1] + for newpt in self._extend_fn(pt1, pt2): + if self._collision_fn(newpt): + return None + path.append(newpt) + return path + + def _rrt_connect( + self, + pt1: _RRTState, + goal_sampler: Callable[[], _RRTState], + goal_fn: Optional[Callable[[_RRTState], bool]] = None, + sample_goal_eps: float = 0.0, + ) -> Optional[List[_RRTState]]: + root = _RRTNode(pt1) + nodes = [root] + + for _ in range(self._num_iters): + # Sample the goal with a small probability, otherwise randomly + # choose a point. + sample_goal = self._rng.random() < sample_goal_eps + samp = goal_sampler() if sample_goal else self._sample_fn(pt1) + min_key = functools.partial(self._get_pt_dist_to_node, samp) + nearest = min(nodes, key=min_key) + reached_goal = False + for newpt in self._extend_fn(nearest.data, samp): + if self._collision_fn(newpt): + break + nearest = _RRTNode(newpt, parent=nearest) + nodes.append(nearest) + else: + reached_goal = sample_goal + # Check goal_fn if defined + if reached_goal or goal_fn is not None and goal_fn(nearest.data): + path = nearest.path_from_root() + return [node.data for node in path] + return None + + def _get_pt_dist_to_node(self, pt: _RRTState, + node: _RRTNode[_RRTState]) -> float: + return self._distance_fn(pt, node.data) + + def _smooth_path(self, path: List[_RRTState]) -> List[_RRTState]: + assert len(path) > 2 + for _ in range(self._smooth_amt): + i = self._rng.integers(0, len(path) - 1) + j = self._rng.integers(0, len(path) - 1) + if abs(i - j) <= 1: + continue + if j < i: + i, j = j, i + shortcut = list(self._extend_fn(path[i], path[j])) + if len(shortcut) < j - i and \ + all(not self._collision_fn(pt) for pt in shortcut): + path = path[:i + 1] + shortcut + path[j + 1:] + return path + + +class BiRRT(RRT[_RRTState]): + """Bidirectional rapidly-exploring random tree.""" + + def query_to_goal_fn( + self, + start: _RRTState, + goal_sampler: Callable[[], _RRTState], + goal_fn: Callable[[_RRTState], bool], + sample_goal_eps: float = 0.0) -> Optional[List[_RRTState]]: + raise NotImplementedError("Can't query to goal function using BiRRT") + + def _rrt_connect( + self, + pt1: _RRTState, + goal_sampler: Callable[[], _RRTState], + goal_fn: Optional[Callable[[_RRTState], bool]] = None, + sample_goal_eps: float = 0.0, + ) -> Optional[List[_RRTState]]: + # goal_fn and sample_goal_eps are unused + pt2 = goal_sampler() + root1, root2 = _RRTNode(pt1), _RRTNode(pt2) + nodes1, nodes2 = [root1], [root2] + + for _ in range(self._num_iters): + if len(nodes1) > len(nodes2): + nodes1, nodes2 = nodes2, nodes1 + samp = self._sample_fn(pt1) + min_key1 = functools.partial(self._get_pt_dist_to_node, samp) + nearest1 = min(nodes1, key=min_key1) + for newpt in self._extend_fn(nearest1.data, samp): + if self._collision_fn(newpt): + break + nearest1 = _RRTNode(newpt, parent=nearest1) + nodes1.append(nearest1) + min_key2 = functools.partial(self._get_pt_dist_to_node, + nearest1.data) + nearest2 = min(nodes2, key=min_key2) + for newpt in self._extend_fn(nearest2.data, nearest1.data): + if self._collision_fn(newpt): + break + nearest2 = _RRTNode(newpt, parent=nearest2) + nodes2.append(nearest2) + else: + path1 = nearest1.path_from_root() + path2 = nearest2.path_from_root() + # This is a tricky case to cover. + if path1[0] != root1: # pragma: no cover + path1, path2 = path2, path1 + assert path1[0] == root1 + path = path1[:-1] + path2[::-1] + return [node.data for node in path] + return None + + +class _RRTNode(Generic[_RRTState]): + """A node for RRT.""" + + def __init__(self, + data: _RRTState, + parent: Optional[_RRTNode[_RRTState]] = None) -> None: + self.data = data + self.parent = parent + + def path_from_root(self) -> List[_RRTNode[_RRTState]]: + """Return the path from the root to this node.""" + sequence = [] + node: Optional[_RRTNode[_RRTState]] = self + while node is not None: + sequence.append(node) + node = node.parent + return sequence[::-1] + + +def strip_predicate(predicate: Predicate) -> Predicate: + """Remove the classifier from the given predicate to make a new Predicate. + + Implement this by replacing the classifier with one that errors. + """ + + def _stripped_classifier(state: State, objects: Sequence[Object]) -> bool: + raise Exception("Stripped classifier should never be called!") + + return Predicate(predicate.name, predicate.types, _stripped_classifier) + + +def strip_task(task: Task, included_predicates: Set[Predicate]) -> Task: + """Create a new task where any excluded goal predicates have their + classifiers removed.""" + stripped_goal: Set[GroundAtom] = set() + for atom in task.goal: + if atom.predicate in included_predicates: + stripped_goal.add(atom) + continue + stripped_pred = strip_predicate(atom.predicate) + stripped_atom = GroundAtom(stripped_pred, atom.objects) + stripped_goal.add(stripped_atom) + return Task(task.init, + stripped_goal, + alt_goal=task.alt_goal, + goal_nl=task.goal_nl) + + +def create_vlm_predicate( + name: str, types: Sequence[Type], + get_vlm_query_str: Callable[[Sequence[Object]], str]) -> VLMPredicate: + """Simple function that creates VLMPredicates with dummy classifiers, which + is the most-common way these need to be created.""" + + def _stripped_classifier( + state: State, + objects: Sequence[Object]) -> bool: # pragma: no cover. + raise Exception("VLM predicate classifier should never be called!") + + return VLMPredicate(name, types, _stripped_classifier, + get_vlm_query_str) # type: ignore[arg-type] + + +def parse_model_output_into_option_plan( + model_prediction: str, objects: Collection[Object], + types: Collection[Type], options: Collection[ParameterizedOption], + parse_continuous_params: bool +) -> List[Tuple[ParameterizedOption, Sequence[Object], Sequence[float]]]: + """Assuming text for an option plan that is predicted as text by a large + model, parse it into a sequence of ParameterizedOptions coupled with a list + of objects and continuous parameters that will be used to ground the + ParameterizedOption. + + We assume the model's output is such that each line is formatted as + option_name(obj0:type0, obj1:type1,...)[continuous_param0, + continuous_param1, ...]. + """ + option_plan: List[Tuple[ParameterizedOption, Sequence[Object], + Sequence[float]]] = [] + # Setup dictionaries enabling us to easily map names to specific + # Python objects during parsing. + option_name_to_option = {op.name: op for op in options} + type_name_to_type = {typ.name: typ for typ in types} + obj_name_to_obj = {o.name: o for o in objects} + options_str_list = model_prediction.split('\n') + for option_str in options_str_list: + option_str_stripped = option_str.strip() + option_name = option_str_stripped.split('(')[0] + # Skip empty option strs. + if not option_str: + continue + if option_name not in option_name_to_option.keys() or \ + "(" not in option_str: + if option_plan: + # Already found some options; stop on first non-option line. + logging.info( + f"Line {option_str} output by model doesn't " + "contain a valid option name. Terminating option plan " + "parsing.") + break + # Skip preamble lines (analysis text before the plan starts). + continue + if parse_continuous_params and "[" not in option_str: + logging.info( + f"Line {option_str} output by model doesn't contain a " + "'[' and is thus improperly formatted.") + break + option = option_name_to_option[option_name] + # Now that we have the option, we need to parse out the objects + # along with specified types. + try: + start_index = option_str_stripped.index('(') + 1 + end_index = option_str_stripped.index(')', start_index) + except ValueError: + logging.info( + f"Line {option_str} output by model is improperly formatted.") + break + typed_objects_str_list = option_str_stripped[ + start_index:end_index].split(',') + objs_list = [] + continuous_params_list = [] + malformed = False + for i, type_object_string in enumerate(typed_objects_str_list): + object_type_str_list = type_object_string.strip().split(':') + # We expect this list to be [object_name, type_name]. + if len(object_type_str_list) != 2: + logging.info(f"Line {option_str} output by model has a " + "malformed object-type list.") + malformed = True + break + object_name = object_type_str_list[0] + type_name = object_type_str_list[1] + if object_name not in obj_name_to_obj.keys(): + logging.info(f"Line {option_str} output by model has an " + "invalid object name.") + malformed = True + break + obj = obj_name_to_obj[object_name] + # Check that the type of this object agrees + # with what's expected given the ParameterizedOption. + if type_name not in type_name_to_type: + logging.info(f"Line {option_str} output by model has an " + "invalid type name.") + malformed = True + break + try: + if option.types[i] not in type_name_to_type[ + type_name].get_ancestors(): + logging.info( + f"Line {option_str} output by model has an " + "invalid type that doesn't agree with the option" + f"{option}") + malformed = True + break + except IndexError: + # In this case, there's more supplied arguments than the + # option has. + logging.info(f"Line {option_str} output by model has an " + "too many object arguments for option" + f"{option}") + malformed = True + break + objs_list.append(obj) + # The types of the objects match, but we haven't yet checked if + # all arguments of the option have an associated object. + if len(objs_list) != len(option.types): + malformed = True + # Now, we attempt to parse out the continuous parameters. + if parse_continuous_params: + params_str_list = option_str_stripped.split('[')[1].strip( + ']').split(',') + for i, continuous_params_str in enumerate(params_str_list): + stripped_continuous_param_str = continuous_params_str.strip() + if len(stripped_continuous_param_str) == 0: + continue + try: + curr_cont_param = float(stripped_continuous_param_str) + except ValueError: + logging.info(f"Line {option_str} output by model has an " + "invalid continouous parameter that can't be" + "converted to a float.") + malformed = True + break + continuous_params_list.append(curr_cont_param) + if len(continuous_params_list) != option.params_space.shape[0]: + logging.info(f"Line {option_str} output by model has " + "invalid continouous parameter(s) that don't " + f"agree with {option}{option.params_space}.") + malformed = True + break + if not malformed: + option_plan.append((option, objs_list, continuous_params_list)) + return option_plan + + +def get_prompt_for_vlm_state_labelling( + prompt_type: str, atoms_list: List[str], label_history: List[str], + imgs_history: List[List[PIL.Image.Image]], + cropped_imgs_history: List[List[PIL.Image.Image]], + skill_history: List[_Option]) -> Tuple[str, List[PIL.Image.Image]]: + """Prompt for labelling atom values in a trajectory. + + Note that all our prompts are saved as separate txt files under the + 'vlm_input_data_prompts/atom_labelling' folder. + """ + # Load the pre-specified prompt. + filepath_prefix = get_path_to_predicators_root() + \ + "/predicators/datasets/vlm_input_data_prompts/atom_labelling/" + try: + with open(filepath_prefix + prompt_type + ".txt", + "r", + encoding="utf-8") as f: + prompt = f.read() + except FileNotFoundError: + raise ValueError("Unknown VLM prompting option " + f"{prompt_type}") + # The prompt ends with a section for 'Predicates', so list these. + for atom_str in atoms_list: + prompt += f"\n{atom_str}" + + if "img_option_diffs" in prompt_type: + # In this case, we need to load the 'per_scene_naive' prompt as well + # for the first timestep. + with open(filepath_prefix + "per_scene_naive.txt", + "r", + encoding="utf-8") as f: + init_prompt = f.read() + for atom_str in atoms_list: + init_prompt += f"\n{atom_str}" + if len(label_history) == 0: + return (init_prompt, imgs_history[0]) + # Now, we use actual difference-based prompting for the second timestep + # and beyond. + curr_prompt = prompt[:] + curr_prompt_imgs = [imgs_history[-2][0], imgs_history[-1][0]] + if CFG.vlm_include_cropped_images: + if CFG.env in ["burger", "burger_no_move"]: # pragma: no cover + curr_prompt_imgs.extend( + [cropped_imgs_history[-1][1], cropped_imgs_history[-1][0]]) + else: + raise NotImplementedError( + f"Cropped images not implemented for {CFG.env}.") + curr_prompt += "\n\nSkill executed between states: " + skill_name = skill_history[-1].name + str(skill_history[-1].objects) + curr_prompt += skill_name + if "label_history" in prompt_type: + curr_prompt += "\n\nPredicate values in the first scene, " \ + "before the skill was executed: \n" + curr_prompt += label_history[-1] + return (curr_prompt, curr_prompt_imgs) + # NOTE: we rip out only the first image from each trajectory + # which is fine for most domains, but will be problematic for + # situations in which there is more than one image per state. + return (prompt, imgs_history[-1]) + + +def query_vlm_for_atom_vals( + vlm_atoms: Collection[GroundAtom], + state: State, + vlm: Optional[VisionLanguageModel] = None) -> Set[GroundAtom]: + """Given a set of ground atoms, queries a VLM and gets the subset of these + atoms that are true.""" + # Short-circuit this function in the case where there are no atoms that + # need be labelled. + if len(vlm_atoms) == 0: + return set() + true_atoms: Set[GroundAtom] = set() + # Get quantities necessary to construct prompt to query VLM. + if state.simulator_state is None: + return true_atoms + assert state.simulator_state is not None + assert isinstance(state.simulator_state["images"], List) + curr_state_imgs = state.simulator_state["images"] + vlm_atoms = sorted(vlm_atoms) + atom_queries_list = [atom.get_vlm_query_str() for atom in vlm_atoms] + prev_states_imgs_history = [] + prev_state_cropped_imgs_history: List[List[PIL.Image.Image]] = [] + if "state_history" in state.simulator_state: # pragma: no cover + prev_states = state.simulator_state["state_history"] + prev_states_imgs_history = [ + s.simulator_state["images"] for s in prev_states + ] + if "cropped_images" in prev_states[0].simulator_state: + prev_states_imgs_history = [ + s.simulator_state["cropped_images"] for s in prev_states + ] + images_history = prev_states_imgs_history + [curr_state_imgs] + skill_history = [] + if "skill_history" in state.simulator_state: # pragma: no cover + skill_history = state.simulator_state["skill_history"] + label_history = [] + if "vlm_label_history" in state.simulator_state: # pragma: no cover + label_history = state.simulator_state["vlm_label_history"] + vlm_query_str, imgs = get_prompt_for_vlm_state_labelling( + CFG.vlm_test_time_atom_label_prompt_type, atom_queries_list, + label_history, images_history, prev_state_cropped_imgs_history, + skill_history) + # Query VLM. + if vlm is None: + # Lazy import: create_vlm_by_name lives in the full utils.py + # (heavy pretrained-model deps), not utils_lite. CPython callers + # see it via `from predicators import utils`; Pyodide callers + # never hit this path (no VLM available). + from predicators.utils import \ + create_vlm_by_name # noqa: PLC0415 # pylint: disable=import-outside-toplevel + vlm = create_vlm_by_name(CFG.vlm_model_name) # pragma: no cover. + if CFG.env in ["pybullet_coffee"]: + vlm_input_imgs = list(imgs) # type: ignore + else: + vlm_input_imgs = \ + [PIL.Image.fromarray(img_arr) for img_arr in imgs] # type: ignore + vlm_output = vlm.sample_completions(vlm_query_str, + vlm_input_imgs, + 0.0, + seed=CFG.seed, + num_completions=1) + assert len(vlm_output) == 1 + vlm_output_str = vlm_output[0] + all_vlm_responses = vlm_output_str.strip().split("\n") + # NOTE: this assumption is likely too brittle; if this is breaking, feel + # free to remove/adjust this and change the below parsing loop accordingly! + if len(atom_queries_list) != len(all_vlm_responses): + return set() + for i, (atom_query, curr_vlm_output_line) in enumerate( + zip(atom_queries_list, all_vlm_responses)): + try: + assert atom_query + ":" in curr_vlm_output_line + assert "." in curr_vlm_output_line + value = curr_vlm_output_line.split(': ')[-1].strip('.').lower() + if value == "true": + true_atoms.add(vlm_atoms[i]) + except AssertionError: # pragma: no cover + continue + return true_atoms + + +def abstract(state: State, + preds: Collection[Predicate], + vlm: Optional[VisionLanguageModel] = None) -> Set[GroundAtom]: + """Get the atomic representation of the given state (i.e., a set of ground + atoms), using the given set of predicates. + + Duplicate arguments in predicates are allowed. Latent-aware + classifiers (`agent_sim_recurrent_predicate_invention`) read their + latent from `state.latent` via `Predicate.holds` — abstract itself + does nothing extra to support them. + """ + # Start by pulling out all VLM predicates. + vlm_preds = set(pred for pred in preds if isinstance(pred, VLMPredicate)) + derived_preds, primitive_preds = set(), set() + for pred in preds: + if isinstance(pred, DerivedPredicate): + derived_preds.add(pred) + else: + primitive_preds.add(pred) + + # Next, classify all non-VLM predicates. + atoms = set() + for pred in primitive_preds: + if pred not in vlm_preds: + for choice in get_object_combinations(list(state), pred.types): + if pred.holds(state, choice): + atoms.add(GroundAtom(pred, choice)) + if len(vlm_preds) > 0: + # Now, aggregate all the VLM predicates and make a single call to a + # VLM to get their values. + vlm_atoms = set() + for pred in vlm_preds: + for choice in get_object_combinations(list(state), pred.types): + vlm_atoms.add(GroundAtom(pred, choice)) + true_vlm_atoms = query_vlm_for_atom_vals(vlm_atoms, state, vlm) + atoms |= true_vlm_atoms + + # Evaluate derived predicates. + if len(derived_preds) > 0: + try: + atoms |= abstract_with_derived_predicates(atoms, derived_preds, + list(state)) + except PredicateEvaluationError as e: + raise e + # buggy_pred = e.pred + # # logging.debug(f"preds before {buggy_pred} is removed: {preds}") + # cnpt_preds.remove(buggy_pred) + # # logging.debug(f"preds after {buggy_pred} is removed: {preds}") + # return abstract(state, prim_preds | cnpt_preds, vlm, + # return_valid_preds) + return atoms + + +def all_ground_operators( + operator: STRIPSOperator, + objects: Collection[Object]) -> Iterator[_GroundSTRIPSOperator]: + """Get all possible groundings of the given operator with the given + objects.""" + types = [p.type for p in operator.parameters] + for choice in get_object_combinations(objects, types): + yield operator.ground(tuple(choice)) + + +def all_ground_operators_given_partial( + operator: STRIPSOperator, objects: Collection[Object], + sub: VarToObjSub) -> Iterator[_GroundSTRIPSOperator]: + """Get all possible groundings of the given operator with the given objects + such that the parameters are consistent with the given substitution.""" + assert set(sub).issubset(set(operator.parameters)) + types = [p.type for p in operator.parameters if p not in sub] + for choice in get_object_combinations(objects, types): + # Complete the choice with the args that are determined from the sub. + choice_lst = list(choice) + choice_lst.reverse() + completed_choice = [] + for p in operator.parameters: + if p in sub: + completed_choice.append(sub[p]) + else: + completed_choice.append(choice_lst.pop()) + assert not choice_lst + ground_op = operator.ground(tuple(completed_choice)) + yield ground_op + + +def all_ground_nsrts(nsrt: Union[NSRT, CausalProcess], + objects: Collection[Object]) -> Iterator[_GroundNSRT]: + """Get all possible groundings of the given NSRT with the given objects.""" + types = [p.type for p in nsrt.parameters] + for choice in get_object_combinations(objects, types): + # only return if there are no repeated arguments + if CFG.no_repeated_arguments_in_grounding: + if len(choice) == len(set(choice)): + yield nsrt.ground(tuple(choice)) # type: ignore[misc] + else: + yield nsrt.ground(tuple(choice)) # type: ignore[misc] + + +def all_ground_nsrts_fd_translator( + nsrts: Set[NSRT], objects: Collection[Object], + predicates: Set[Predicate], types: Set[Type], + init_atoms: Set[GroundAtom], + goal: Set[GroundAtom]) -> Iterator[_GroundNSRT]: + """Get all possible groundings of the given set of NSRTs with the given + objects, using Fast Downward's translator for efficiency.""" + nsrt_name_to_nsrt = {nsrt.name.lower(): nsrt for nsrt in nsrts} + obj_name_to_obj = {obj.name.lower(): obj for obj in objects} + dom_str = create_pddl_domain(nsrts, predicates, types, "mydomain") + prob_str = create_pddl_problem(objects, init_atoms, goal, "mydomain", + "myproblem") + with nostdout(): + sas_task = downward_translate(dom_str, prob_str) # type: ignore + for operator in sas_task.operators: + split_name = operator.name[1:-1].split() # strip out ( and ) + nsrt = nsrt_name_to_nsrt[split_name[0]] + objs = [obj_name_to_obj[name] for name in split_name[1:]] + yield nsrt.ground(objs) + + +def all_possible_ground_atoms(state: State, + preds: Set[Predicate]) -> List[GroundAtom]: + """Get a sorted list of all possible ground atoms in a state given the + predicates. + + Ignores the predicates' classifiers. + """ + objects = frozenset(state) + ground_atoms = set() + for pred in preds: + ground_atoms |= get_all_ground_atoms_for_predicate(pred, objects) + return sorted(ground_atoms) + + +def all_ground_ldl_rules( + rule: LDLRule, + objects: Collection[Object], + static_predicates: Optional[Collection[Predicate]] = None, + init_atoms: Optional[Collection[GroundAtom]] = None +) -> List[_GroundLDLRule]: + """Get all possible groundings of the given rule with the given objects. + + If provided, use the static predicates and init_atoms to avoid + grounding rules that will never have satisfied preconditions in any + state. + """ + if static_predicates is None: + static_predicates = set() + if init_atoms is None: + init_atoms = set() + return _cached_all_ground_ldl_rules(rule, frozenset(objects), + frozenset(static_predicates), + frozenset(init_atoms)) + + +@functools.lru_cache(maxsize=None) +def _cached_all_ground_ldl_rules( + rule: LDLRule, objects: FrozenSet[Object], + static_predicates: FrozenSet[Predicate], + init_atoms: FrozenSet[GroundAtom]) -> List[_GroundLDLRule]: + """Helper for all_ground_ldl_rules() that caches the outputs.""" + ground_rules = [] + # Use static preconds to reduce the map of parameters to possible objects. + # For example, if IsBall(?x) is a positive state precondition, then only + # the objects that appear in init_atoms with IsBall could bind to ?x. + # For now, we just check unary static predicates, since that covers the + # common case where such predicates are used in place of types. + # Create map from each param to unary static predicates. + param_to_pos_preds: Dict[Variable, Set[Predicate]] = { + p: set() + for p in rule.parameters + } + param_to_neg_preds: Dict[Variable, Set[Predicate]] = { + p: set() + for p in rule.parameters + } + for (preconditions, param_to_preds) in [ + (rule.pos_state_preconditions, param_to_pos_preds), + (rule.neg_state_preconditions, param_to_neg_preds), + ]: + for atom in preconditions: + pred = atom.predicate + if pred in static_predicates and pred.arity == 1: + param = atom.variables[0] + param_to_preds[param].add(pred) + # Create the param choices, filtering based on the unary static atoms. + param_choices = [] # list of lists of possible objects for each param + # Preprocess the atom sets for faster lookups. + init_atom_tups = {(a.predicate, tuple(a.objects)) for a in init_atoms} + for param in rule.parameters: + choices = [] + for obj in objects: + # Types must match, as usual. + if obj.type != param.type: + continue + # Check the static conditions. + binding_valid = True + for pred in param_to_pos_preds[param]: + if (pred, (obj, )) not in init_atom_tups: + binding_valid = False + break + for pred in param_to_neg_preds[param]: + if (pred, (obj, )) in init_atom_tups: + binding_valid = False + break + if binding_valid: + choices.append(obj) + # Must be sorted for consistency with other grounding code. + param_choices.append(sorted(choices)) + for choice in itertools.product(*param_choices): + ground_rule = rule.ground(choice) + ground_rules.append(ground_rule) + return ground_rules + + +def parse_ldl_from_str(ldl_str: str, types: Collection[Type], + predicates: Collection[Predicate], + nsrts: Collection[NSRT]) -> LiftedDecisionList: + """Parse a lifted decision list from a string representation of it.""" + parser = _LDLParser(types, predicates, nsrts) + return parser.parse(ldl_str) + + +class _LDLParser: + """Parser for lifted decision lists from strings.""" + + def __init__(self, types: Collection[Type], + predicates: Collection[Predicate], + nsrts: Collection[NSRT]) -> None: + self._nsrt_name_to_nsrt = {nsrt.name.lower(): nsrt for nsrt in nsrts} + self._type_name_to_type = {t.name.lower(): t for t in types} + self._predicate_name_to_predicate = { + p.name.lower(): p + for p in predicates + } + + def parse(self, ldl_str: str) -> LiftedDecisionList: + """Run parsing.""" + ldl_str = ldl_str.lower() # ignore case during parsing + rules = [] + rule_matches = re.finditer(r"\(:rule", ldl_str) + for start in rule_matches: + rule_str = find_balanced_expression(ldl_str, start.start()) + rule = self._parse_rule(rule_str) + rules.append(rule) + return LiftedDecisionList(rules) + + def _parse_rule(self, rule_str: str) -> LDLRule: + rule_pattern = r"\(:rule(.*):parameters(.*):preconditions(.*)" + \ + r":goals(.*):action(.*)\)" + match_result = re.match(rule_pattern, rule_str, re.DOTALL) + assert match_result is not None + # Remove white spaces. + matches = [m.strip().rstrip() for m in match_result.groups()] + # Unpack the matches. + rule_name, params_str, preconds_str, goals_str, nsrt_str = matches + # Handle the parameters. + assert "?" in params_str, "Assuming all rules have parameters." + variable_name_to_variable = {} + assert params_str.endswith(")") + for param_str in params_str[:-1].split("?")[1:]: + param_name, param_type_str = param_str.split("-") + param_name = param_name.strip() + param_type_str = param_type_str.strip() + variable_name = "?" + param_name + param_type = self._type_name_to_type[param_type_str] + variable = Variable(variable_name, param_type) + variable_name_to_variable[variable_name] = variable + # Handle the preconditions. + pos_preconds, neg_preconds = self._parse_lifted_atoms( + preconds_str, variable_name_to_variable) + # Handle the goals. + pos_goals, neg_goals = self._parse_lifted_atoms( + goals_str, variable_name_to_variable) + assert not neg_goals, "Negative LDL goals not currently supported" + # Handle the NSRT. + nsrt = self._parse_into_nsrt(nsrt_str, variable_name_to_variable) + # Finalize the rule. + params = sorted(variable_name_to_variable.values()) + return LDLRule(rule_name, params, pos_preconds, neg_preconds, + pos_goals, nsrt) + + def _parse_lifted_atoms( + self, atoms_str: str, variable_name_to_variable: Dict[str, Variable] + ) -> Tuple[Set[LiftedAtom], Set[LiftedAtom]]: + """Parse the given string (representing either preconditions or + effects) into a set of positive lifted atoms and a set of negative + lifted atoms. + + Check against params to make sure typing is correct. + """ + assert atoms_str[0] == "(" + assert atoms_str[-1] == ")" + + # Handle conjunctions. + if atoms_str.startswith("(and") and atoms_str[4] in (" ", "\n", "("): + clauses = find_all_balanced_expressions(atoms_str[4:-1].strip()) + pos_atoms, neg_atoms = set(), set() + for clause in clauses: + clause_pos_atoms, clause_neg_atoms = self._parse_lifted_atoms( + clause, variable_name_to_variable) + pos_atoms |= clause_pos_atoms + neg_atoms |= clause_neg_atoms + return pos_atoms, neg_atoms + + # Handle negations. + if atoms_str.startswith("(not") and atoms_str[4] in (" ", "\n", "("): + # Only contains a single literal inside not. + split_strs = atoms_str[4:-1].strip()[1:-1].strip().split() + pred = self._predicate_name_to_predicate[split_strs[0]] + args = [variable_name_to_variable[arg] for arg in split_strs[1:]] + lifted_atom = LiftedAtom(pred, args) + return set(), {lifted_atom} + + # Handle single positive atoms. + split_strs = atoms_str[1:-1].split() + # Empty conjunction. + if not split_strs: + return set(), set() + pred = self._predicate_name_to_predicate[split_strs[0]] + args = [variable_name_to_variable[arg] for arg in split_strs[1:]] + lifted_atom = LiftedAtom(pred, args) + return {lifted_atom}, set() + + def _parse_into_nsrt( + self, nsrt_str: str, + variable_name_to_variable: Dict[str, Variable]) -> NSRT: + """Parse the given string into an NSRT.""" + assert nsrt_str[0] == "(" + assert nsrt_str[-1] == ")" + nsrt_str = nsrt_str[1:-1].split()[0] + nsrt = self._nsrt_name_to_nsrt[nsrt_str] + # Validate parameters. + variables = variable_name_to_variable.values() + for v in nsrt.parameters: + assert v in variables, f"NSRT parameter {v} missing from LDL rule" + return nsrt + + +_T = TypeVar("_T") # element of a set + + +def sample_subsets(universe: Sequence[_T], num_samples: int, min_set_size: int, + max_set_size: int, + rng: np.random.Generator) -> Iterator[Set[_T]]: + """Sample multiple subsets from a universe.""" + assert min_set_size <= max_set_size + assert max_set_size <= len(universe), "Not enough elements in universe" + for _ in range(num_samples): + set_size = rng.integers(min_set_size, max_set_size + 1) + idxs = rng.choice(np.arange(len(universe)), + size=set_size, + replace=False) + sample = {universe[i] for i in idxs} + yield sample + + +def create_dataset_filename_str( + saving_ground_atoms: bool, + online_learning_cycle: Optional[int] = None) -> Tuple[str, str]: + """Generate strings to be used for the filename for a dataset file that is + about to be saved. + + Returns a tuple of strings where the first element is the dataset + filename itself and the second is a template string used to generate + it. If saving_ground_atoms is True, then we will name the file with + a "_ground_atoms" suffix. + """ + # Setup the dataset filename for saving/loading GroundAtoms. + regex = r"(\d+)" + suffix_str = "" + suffix_str += f"__{online_learning_cycle}" + if saving_ground_atoms: + suffix_str += "__ground_atoms" + suffix_str += ".data" + dataset_fname_template = ( + f"{CFG.env}__{CFG.offline_data_method}__{CFG.demonstrator}__" + f"{regex}__{CFG.included_options}__{CFG.seed}" + suffix_str) + dataset_fname = os.path.join( + CFG.data_dir, + dataset_fname_template.replace(regex, str(CFG.num_train_tasks))) + return dataset_fname, dataset_fname_template + + +def create_ground_atom_dataset( + trajectories: Sequence[LowLevelTrajectory], + predicates: Set[Predicate]) -> List[GroundAtomTrajectory]: + """Apply all predicates to all trajectories in the dataset.""" + ground_atom_dataset = [] + for traj in trajectories: + atoms = [abstract(s, predicates) for s in traj.states] + ground_atom_dataset.append((traj, atoms)) + return ground_atom_dataset + + +def create_ground_atom_option_dataset( + trajectories: List[LowLevelTrajectory], + predicates: Set[Predicate]) -> List[AtomOptionTrajectory]: + """Apply all predicates to all trajectories in the dataset and also + annotate with options (HLA).""" + ground_atom_option_dataset = [] + for traj in trajectories: + # Note: this is currently just based on the current states. + # We may want to extend this to state history in the future. + atoms = [abstract(s, predicates) for s in traj.states] + options = [a.get_option() for a in traj.actions] + ground_atom_option_dataset.append( + AtomOptionTrajectory( + traj.states, atoms, options, traj.is_demo, + traj.train_task_idx if traj.is_demo else None)) + return ground_atom_option_dataset + + +def prune_ground_atom_dataset( + ground_atom_dataset: List[GroundAtomTrajectory], + kept_predicates: Collection[Predicate]) -> List[GroundAtomTrajectory]: + """Create a new ground atom dataset by keeping only some predicates.""" + new_ground_atom_dataset = [] + for traj, atoms in ground_atom_dataset: + assert len(traj.states) == len(atoms) + kept_atoms = [{a + for a in sa if a.predicate in kept_predicates} + for sa in atoms] + new_ground_atom_dataset.append((traj, kept_atoms)) + return new_ground_atom_dataset + + +def load_ground_atom_dataset( + dataset_fname: str, + trajectories: List[LowLevelTrajectory]) -> List[GroundAtomTrajectory]: + """Load a previously-saved ground atom dataset. + + Note importantly that we only save atoms themselves, we don't save + the low-level trajectory information that's necessary to make + GroundAtomTrajectories given series of ground atoms (that info can + be saved separately, in case one wants to just load trajectories and + not also load ground atoms). Thus, this function needs to take these + trajectories as input. + """ + os.makedirs(CFG.data_dir, exist_ok=True) + # Check that the dataset file was previously saved. + ground_atom_dataset_atoms: Optional[List[List[Set[GroundAtom]]]] = [] + if os.path.exists(dataset_fname): + # Load the ground atoms dataset. + with open(dataset_fname, "rb") as f: + ground_atom_dataset_atoms = pkl.load(f) + assert ground_atom_dataset_atoms is not None + assert len(trajectories) == len(ground_atom_dataset_atoms) + logging.info("\n\nLOADED GROUND ATOM DATASET") + + # The saved ground atom dataset consists only of sequences + # of sets of GroundAtoms, we need to recombine this with + # the LowLevelTrajectories to create a GroundAtomTrajectory. + ground_atom_dataset = [] + for i, traj in enumerate(trajectories): + ground_atom_seq = ground_atom_dataset_atoms[i] + ground_atom_dataset.append( + (traj, [set(atoms) for atoms in ground_atom_seq])) + else: + raise ValueError(f"Cannot load ground atoms: {dataset_fname}") + return ground_atom_dataset + + +def save_ground_atom_dataset(ground_atom_dataset: List[GroundAtomTrajectory], + dataset_fname: str) -> None: + """Saves a given ground atom dataset so it can be loaded in the future.""" + # Save ground atoms dataset to file. Note that a + # GroundAtomTrajectory contains a normal LowLevelTrajectory and a + # list of sets of GroundAtoms, so we only save the list of + # GroundAtoms (the LowLevelTrajectories are saved separately). + ground_atom_dataset_to_pkl = [] + for gt_traj in ground_atom_dataset: + trajectory = [] + for ground_atom_set in gt_traj[1]: + trajectory.append(ground_atom_set) + ground_atom_dataset_to_pkl.append(trajectory) + with open(dataset_fname, "wb") as f: + pkl.dump(ground_atom_dataset_to_pkl, f) + + +def merge_ground_atom_datasets( + gad1: List[GroundAtomTrajectory], + gad2: List[GroundAtomTrajectory]) -> List[GroundAtomTrajectory]: + """Merges two ground atom datasets sharing the same underlying low-level + trajectory via the union of ground atoms at each state.""" + assert len(gad1) == len( + gad2), "Ground atom datasets must be of the same length to merge them." + merged_ground_atom_dataset = [] + for ground_atom_traj1, ground_atom_traj2 in zip(gad1, gad2): + ll_traj1, ga_list1 = ground_atom_traj1 + ll_traj2, ga_list2 = ground_atom_traj2 + assert ll_traj1 == ll_traj2, "Ground atom trajectories must share " \ + "the same low-level trajectory to be able to merge them." + merged_ga_list = [ga1 | ga2 for ga1, ga2 in zip(ga_list1, ga_list2)] + merged_ground_atom_dataset.append((ll_traj1, merged_ga_list)) + return merged_ground_atom_dataset + + +def extract_preds_and_types( + ops: Collection[NSRTOrSTRIPSOperator] +) -> Tuple[Dict[str, Predicate], Dict[str, Type]]: + """Extract the predicates and types used in the given operators.""" + preds = {} + types = {} + for op in ops: + for atom in op.preconditions | op.add_effects | op.delete_effects: + for var_type in atom.predicate.types: + types[var_type.name] = var_type + preds[atom.predicate.name] = atom.predicate + return preds, types + + +def get_static_preds(ops: Collection[NSRTOrSTRIPSOperator], + predicates: Collection[Predicate]) -> Set[Predicate]: + """Get the subset of predicates from the given set that are static with + respect to the given lifted operators.""" + static_preds = set() + for pred in predicates: + # This predicate is not static if it appears in any op's effects. + if any( + any(atom.predicate == pred for atom in op.add_effects) or any( + atom.predicate == pred for atom in op.delete_effects) + for op in ops): + continue + static_preds.add(pred) + return static_preds + + +def get_static_atoms(ground_ops: Collection[GroundNSRTOrSTRIPSOperator], + atoms: Collection[GroundAtom]) -> Set[GroundAtom]: + """Get the subset of atoms from the given set that are static with respect + to the given ground operators. + + Note that this can include MORE than simply the set of atoms whose + predicates are static, because now we have ground operators. + """ + static_atoms = set() + for atom in atoms: + # This atom is not static if it appears in any op's effects. + if any( + any(atom == eff for eff in op.add_effects) or any( + atom == eff for eff in op.delete_effects) + for op in ground_ops): + continue + static_atoms.add(atom) + return static_atoms + + +def get_reachable_atoms(ground_ops: Collection[GroundNSRTOrSTRIPSOperator], + atoms: Collection[GroundAtom]) -> Set[GroundAtom]: + """Get all atoms that are reachable from the init atoms.""" + reachables = set(atoms) + while True: + fixed_point_reached = True + for op in ground_ops: + if op.preconditions.issubset(reachables): + for new_reachable_atom in op.add_effects - reachables: + fixed_point_reached = False + reachables.add(new_reachable_atom) + if fixed_point_reached: + break + return reachables + + +def get_applicable_operators( + ground_ops: Collection[Union[GroundNSRTOrSTRIPSOperator, + _GroundEndogenousProcess]], + atoms: Collection[GroundAtom] +) -> Iterator[Union[GroundNSRTOrSTRIPSOperator, _GroundEndogenousProcess]]: + """Iterate over ground operators whose preconditions are satisfied. + + Note: the order may be nondeterministic. Users should be invariant. + """ + for op in ground_ops: + if isinstance(op, (_GroundNSRT, _GroundSTRIPSOperator)): + applicable = op.preconditions.issubset(atoms) + elif isinstance(op, _GroundEndogenousProcess): + applicable = op.condition_at_start.issubset(atoms) + + if applicable: + yield op + + +def apply_operator(op: GroundNSRTOrSTRIPSOperator, + atoms: Set[GroundAtom]) -> Set[GroundAtom]: + """Get a next set of atoms given a current set and a ground operator.""" + # Note that we are removing the ignore effects before the + # application of the operator, because if the ignore effect + # appears in the effects, we still know that the effects + # will be true, so we don't want to remove them. + new_atoms = {a for a in atoms if a.predicate not in op.ignore_effects} + for atom in op.delete_effects: + new_atoms.discard(atom) + for atom in op.add_effects: + new_atoms.add(atom) + return new_atoms + + +def compute_necessary_atoms_seq( + skeleton: List[_GroundNSRT], atoms_seq: List[Set[GroundAtom]], + goal: Set[GroundAtom]) -> List[Set[GroundAtom]]: + """Given a skeleton and a corresponding atoms sequence, return a + 'necessary' atoms sequence that includes only the necessary image at each + step.""" + necessary_atoms_seq = [set(goal)] + necessary_image = set(goal) + for t in range(len(atoms_seq) - 2, -1, -1): + curr_nsrt = skeleton[t] + necessary_image -= set(curr_nsrt.add_effects) + necessary_image |= set(curr_nsrt.preconditions) + necessary_atoms_seq = [set(necessary_image)] + necessary_atoms_seq + return necessary_atoms_seq + + +def get_successors_from_ground_ops( + atoms: Set[GroundAtom], + ground_ops: Collection[GroundNSRTOrSTRIPSOperator], + unique: bool = True) -> Iterator[Set[GroundAtom]]: + """Get all next atoms from ground operators. + + If unique is true, only yield each unique successor once. + """ + seen_successors = set() + for ground_op in get_applicable_operators(ground_ops, atoms): + next_atoms = apply_operator(ground_op, atoms) # type: ignore[type-var] + if unique: + frozen_next_atoms = frozenset(next_atoms) + if frozen_next_atoms in seen_successors: + continue + seen_successors.add(frozen_next_atoms) + yield next_atoms + + +def ops_and_specs_to_dummy_nsrts( + strips_ops: Sequence[STRIPSOperator], + option_specs: Sequence[OptionSpec]) -> Set[NSRT]: + """Create NSRTs from strips operators and option specs with dummy + samplers.""" + assert len(strips_ops) == len(option_specs) + nsrts = set() + for op, (param_option, option_vars) in zip(strips_ops, option_specs): + nsrt = op.make_nsrt( + param_option, + option_vars, # dummy sampler + lambda s, g, rng, o: np.zeros(1, dtype=np.float32)) + nsrts.add(nsrt) + return nsrts + + +# Note: create separate `heuristics.py` module if we need to add new +# heuristics in the future. + + +def create_task_planning_heuristic( + heuristic_name: str, + init_atoms: Set[GroundAtom], + goal: Set[GroundAtom], + ground_ops: Collection[GroundNSRTOrSTRIPSOperator], + predicates: Collection[Predicate], + objects: Collection[Object], +) -> _TaskPlanningHeuristic: + """Create a task planning heuristic that consumes ground atoms and + estimates the cost-to-go.""" + if heuristic_name in _PYPERPLAN_HEURISTICS: + return _create_pyperplan_heuristic(heuristic_name, init_atoms, goal, + ground_ops, predicates, objects) + if heuristic_name == GoalCountHeuristic.HEURISTIC_NAME: + return GoalCountHeuristic(heuristic_name, init_atoms, goal, ground_ops) + raise ValueError(f"Unrecognized heuristic name: {heuristic_name}.") + + +@dataclass(frozen=True) +class _TaskPlanningHeuristic: + """A task planning heuristic.""" + name: str + init_atoms: Collection[GroundAtom] + goal: Set[GroundAtom] + ground_ops: Collection[Union[_GroundNSRT, _GroundSTRIPSOperator]] + + def __call__(self, atoms: Collection[GroundAtom]) -> float: + raise NotImplementedError("Override me!") + + +class GoalCountHeuristic(_TaskPlanningHeuristic): + """The number of goal atoms that are not in the current state.""" + HEURISTIC_NAME: ClassVar[str] = "goal_count" + + def __call__(self, atoms: Collection[GroundAtom]) -> float: + return len(self.goal.difference(atoms)) + + +############################### Pyperplan Glue ############################### + + +def _create_pyperplan_heuristic( + heuristic_name: str, + init_atoms: Set[GroundAtom], + goal: Set[GroundAtom], + ground_ops: Collection[GroundNSRTOrSTRIPSOperator], + predicates: Collection[Predicate], + objects: Collection[Object], +) -> _PyperplanHeuristicWrapper: + """Create a pyperplan heuristic that inherits from + _TaskPlanningHeuristic.""" + assert heuristic_name in _PYPERPLAN_HEURISTICS + static_atoms = get_static_atoms(ground_ops, init_atoms) + pyperplan_heuristic_cls = _PYPERPLAN_HEURISTICS[heuristic_name] + pyperplan_task = _create_pyperplan_task(init_atoms, goal, ground_ops, + predicates, objects, static_atoms) + pyperplan_heuristic = pyperplan_heuristic_cls(pyperplan_task) + pyperplan_goal = _atoms_to_pyperplan_facts(goal - static_atoms) + return _PyperplanHeuristicWrapper(heuristic_name, init_atoms, goal, + ground_ops, static_atoms, + pyperplan_heuristic, pyperplan_goal) + + +_PyperplanFacts = FrozenSet[str] + + +@dataclass(frozen=True) +class _PyperplanNode: + """Container glue for pyperplan heuristics.""" + state: _PyperplanFacts + goal: _PyperplanFacts + + +@dataclass(frozen=True) +class _PyperplanOperator: + """Container glue for pyperplan heuristics.""" + name: str + preconditions: _PyperplanFacts + add_effects: _PyperplanFacts + del_effects: _PyperplanFacts + + +@dataclass(frozen=True) +class _PyperplanTask: + """Container glue for pyperplan heuristics.""" + facts: _PyperplanFacts + initial_state: _PyperplanFacts + goals: _PyperplanFacts + operators: Collection[_PyperplanOperator] + + +@dataclass(frozen=True) +class _PyperplanHeuristicWrapper(_TaskPlanningHeuristic): + """A light wrapper around pyperplan's heuristics.""" + _static_atoms: Set[GroundAtom] + _pyperplan_heuristic: _PyperplanBaseHeuristic + _pyperplan_goal: _PyperplanFacts + + def __call__(self, atoms: Collection[GroundAtom]) -> float: + # Note: filtering out static atoms. + pyperplan_facts = _atoms_to_pyperplan_facts(set(atoms) \ + - self._static_atoms) + return self._evaluate(pyperplan_facts, self._pyperplan_goal, + self._pyperplan_heuristic) + + @staticmethod + @functools.lru_cache(maxsize=None) + def _evaluate(pyperplan_facts: _PyperplanFacts, + pyperplan_goal: _PyperplanFacts, + pyperplan_heuristic: _PyperplanBaseHeuristic) -> float: + pyperplan_node = _PyperplanNode(pyperplan_facts, pyperplan_goal) + logging.disable(logging.DEBUG) + result = pyperplan_heuristic(pyperplan_node) + logging.disable(logging.NOTSET) + return result + + +def _create_pyperplan_task( + init_atoms: Set[GroundAtom], + goal: Set[GroundAtom], + ground_ops: Collection[GroundNSRTOrSTRIPSOperator], + predicates: Collection[Predicate], + objects: Collection[Object], + static_atoms: Set[GroundAtom], +) -> _PyperplanTask: + """Helper glue for pyperplan heuristics.""" + all_atoms = set() + for predicate in predicates: + all_atoms.update( + get_all_ground_atoms_for_predicate(predicate, frozenset(objects))) + # Note: removing static atoms. + pyperplan_facts = _atoms_to_pyperplan_facts(all_atoms - static_atoms) + pyperplan_state = _atoms_to_pyperplan_facts(init_atoms - static_atoms) + pyperplan_goal = _atoms_to_pyperplan_facts(goal - static_atoms) + pyperplan_operators = set() + for op in ground_ops: + # Note: the pyperplan operator must include the objects, because hFF + # uses the operator name in constructing the relaxed plan, and the + # relaxed plan is a set. If we instead just used op.name, there would + # be a very nasty bug where two ground operators in the relaxed plan + # that have different objects are counted as just one. + name = op.name + "-".join(o.name for o in op.objects) + pyperplan_operator = _PyperplanOperator( + name, + # Note: removing static atoms from preconditions. + _atoms_to_pyperplan_facts(op.preconditions - static_atoms), + _atoms_to_pyperplan_facts(op.add_effects), + _atoms_to_pyperplan_facts(op.delete_effects)) + pyperplan_operators.add(pyperplan_operator) + return _PyperplanTask(pyperplan_facts, pyperplan_state, pyperplan_goal, + pyperplan_operators) + + +@functools.lru_cache(maxsize=None) +def _atom_to_pyperplan_fact(atom: GroundAtom) -> str: + """Convert atom to tuple for interface with pyperplan.""" + arg_str = " ".join(o.name for o in atom.objects) + return f"({atom.predicate.name} {arg_str})" + + +def _atoms_to_pyperplan_facts( + atoms: Collection[GroundAtom]) -> _PyperplanFacts: + """Light wrapper around _atom_to_pyperplan_fact() that operates on a + collection of atoms.""" + return frozenset({_atom_to_pyperplan_fact(atom) for atom in atoms}) + + +############################## End Pyperplan Glue ############################## + + +def create_pddl_types_str(types: Collection[Type]) -> str: + """Create a PDDL-style types string that handles hierarchy correctly.""" + # Case 1: no type hierarchy. + if all(t.parent is None for t in types): + types_str = " ".join(t.name for t in sorted(types)) + # Case 2: type hierarchy. + else: + parent_to_children_types: Dict[Type, + List[Type]] = {t: [] + for t in types} + for t in sorted(types): + if t.parent: + parent_to_children_types[t.parent].append(t) + types_str = "" + for parent_type in sorted(parent_to_children_types): + child_types = parent_to_children_types[parent_type] + if not child_types: + # Special case: type has no children and also does not appear + # as a child of another type. + is_child_type = any( + parent_type in children + for children in parent_to_children_types.values()) + if not is_child_type: + types_str += f"\n {parent_type.name}" + # Otherwise, the type will appear as a child elsewhere. + else: + child_type_str = " ".join(t.name for t in child_types) + types_str += f"\n {child_type_str} - {parent_type.name}" + return types_str + + +def create_pddl_domain(operators: Collection[NSRTOrSTRIPSOperator], + predicates: Collection[Predicate], + types: Collection[Type], domain_name: str) -> str: + """Create a PDDL domain str from STRIPSOperators or NSRTs.""" + # Sort everything to ensure determinism. + preds_lst = sorted(predicates) + types_str = create_pddl_types_str(types) + ops_lst = sorted(operators) + preds_str = "\n ".join(pred.pddl_str() for pred in preds_lst) + ops_strs = "\n\n ".join(op.pddl_str() for op in ops_lst) + return f"""(define (domain {domain_name}) + (:requirements :typing) + (:types {types_str}) + + (:predicates\n {preds_str} + ) + + {ops_strs} +)""" + + +def create_pddl_problem(objects: Collection[Object], + init_atoms: Collection[GroundAtom], + goal: Set[GroundAtom], domain_name: str, + problem_name: str) -> str: + """Create a PDDL problem str.""" + # Sort everything to ensure determinism. + objects_lst = sorted(objects) + init_atoms_lst = sorted(init_atoms) + goal_lst = sorted(goal) + objects_str = "\n ".join(f"{o.name} - {o.type.name}" + for o in objects_lst) + init_str = "\n ".join(atom.pddl_str() for atom in init_atoms_lst) + goal_str = "\n ".join(atom.pddl_str() for atom in goal_lst) + return f"""(define (problem {problem_name}) (:domain {domain_name}) + (:objects\n {objects_str} + ) + (:init\n {init_str} + ) + (:goal (and {goal_str})) +) +""" + + +@functools.lru_cache(maxsize=None) +def get_failure_predicate(option: ParameterizedOption, + idxs: Tuple[int]) -> Predicate: + """Create a Failure predicate for a parameterized option.""" + idx_str = ",".join(map(str, idxs)) + arg_types = [option.types[i] for i in idxs] + return Predicate(f"{option.name}Failed_arg{idx_str}", + arg_types, + _classifier=lambda s, o: False) + + +def _get_idxs_to_failure_predicate( + option: ParameterizedOption, + max_arity: int = 1) -> Dict[Tuple[int, ...], Predicate]: + """Helper for get_all_failure_predicates() and get_failure_atoms().""" + idxs_to_failure_predicate: Dict[Tuple[int, ...], Predicate] = {} + num_types = len(option.types) + max_num_idxs = min(max_arity, num_types) + all_idxs = list(range(num_types)) + for arity in range(1, max_num_idxs + 1): + for idxs in itertools.combinations(all_idxs, arity): + pred = get_failure_predicate(option, idxs) + idxs_to_failure_predicate[idxs] = pred + return idxs_to_failure_predicate + + +def get_all_failure_predicates(options: Set[ParameterizedOption], + max_arity: int = 1) -> Set[Predicate]: + """Get all possible failure predicates.""" + failure_preds: Set[Predicate] = set() + for param_opt in options: + preds = _get_idxs_to_failure_predicate(param_opt, max_arity=max_arity) + failure_preds.update(preds.values()) + return failure_preds + + +def get_failure_atoms(failed_options: Collection[_Option], + max_arity: int = 1) -> Set[GroundAtom]: + """Get ground failure atoms for the collection of failure options.""" + failure_atoms: Set[GroundAtom] = set() + failed_option_specs = {(o.parent, tuple(o.objects)) + for o in failed_options} + for (param_opt, objs) in failed_option_specs: + preds = _get_idxs_to_failure_predicate(param_opt, max_arity=max_arity) + for idxs, pred in preds.items(): + obj_for_idxs = [objs[i] for i in idxs] + failure_atom = GroundAtom(pred, obj_for_idxs) + failure_atoms.add(failure_atom) + return failure_atoms + + +@dataclass +class VideoMonitor(LoggingMonitor): + """A monitor that renders each state and action encountered. + + The render_fn is generally env.render. Note that the state is unused + because the environment should use its current internal state to + render. + """ + _render_fn: Callable[[Optional[Action], Optional[str]], Video] + _video: Video = field(init=False, default_factory=list) + + def reset(self, train_or_test: str, task_idx: int) -> None: + self._video = [] + + def observe(self, obs: Observation, action: Optional[Action]) -> None: + del obs # unused + self._video.extend(self._render_fn(action, None)) + + def get_video(self) -> Video: + """Return the video.""" + return self._video + + +@dataclass +class SimulateVideoMonitor(LoggingMonitor): + """A monitor that calls render_state on each state and action seen. + + This monitor is meant for use with run_policy_with_simulator, as + opposed to VideoMonitor, which is meant for use with run_policy. + """ + _task: Task + _render_state_fn: Callable[[State, Task, Optional[Action]], Video] + _video: Video = field(init=False, default_factory=list) + + def reset(self, train_or_test: str, task_idx: int) -> None: + self._video = [] + + def observe(self, obs: Observation, action: Optional[Action]) -> None: + assert isinstance(obs, State) + self._video.extend(self._render_state_fn(obs, self._task, action)) + + def get_video(self) -> Video: + """Return the video.""" + return self._video + + +def create_video_from_partial_refinements( + partial_refinements: Sequence[Tuple[Sequence[_GroundNSRT], + Sequence[_Option]]], + env: BaseEnv, + train_or_test: str, + task_idx: int, + max_num_steps: int, +) -> Video: + """Create a video from a list of skeletons and partial refinements. + + Note that the environment internal state is updated. + """ + # Right now, the video is created by finding the longest partial + # refinement. One could also implement an "all_skeletons" mode + # that would create one video per skeleton. + if CFG.failure_video_mode == "longest_only": + # Visualize only the overall longest failed plan. + _, plan = max(partial_refinements, key=lambda x: len(x[1])) + policy = option_plan_to_policy(plan) + video: Video = [] + logging.debug("reset env for create video") + state = env.reset(train_or_test, task_idx) + # logging.debug(f"{pformat(state.pretty_str())}") + for _i in range(max_num_steps): + # logging.debug(f"state: {state.pretty_str()}") + try: + act = policy(state) + # logging.debug(f"act: {act}") + except OptionExecutionFailure: + video.extend(env.render()) + if not CFG.video_not_break_on_exception: + break + else: + video.extend(env.render(act)) + # logging.debug("Finished rendering.") + try: + state = env.step(act) + except EnvironmentFailure: + break + return video + raise NotImplementedError("Unrecognized failure video mode: " + f"{CFG.failure_video_mode}.") + + +def fig2data(fig: matplotlib.figure.Figure, dpi: int) -> Image: + """Convert matplotlib figure into Image.""" + fig.set_dpi(dpi) + fig.canvas.draw() + data = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8).copy() + data = data.reshape(fig.canvas.get_width_height()[::-1] + (4, )) + data[..., [0, 1, 2, 3]] = data[..., [1, 2, 3, 0]] + return data + + +def get_env_asset_path(asset_name: str, assert_exists: bool = True) -> str: + """Return the absolute path to env asset.""" + dir_path = os.path.dirname(os.path.realpath(__file__)) + asset_dir_path = os.path.join(dir_path, "envs", "assets") + path = os.path.join(asset_dir_path, asset_name) + if assert_exists: + assert os.path.exists(path), f"Env asset not found: {asset_name}." + return path + + +def get_third_party_path() -> str: + """Return the absolute path to the third party directory.""" + third_party_dir_path = os.path.join(get_path_to_predicators_root(), + "predicators/third_party") + return third_party_dir_path + + +def get_path_to_predicators_root() -> str: + """Return the absolute path to the predicators root directory. + + Specifically, this returns something that looks like: + '/predicators'. Note there is no '/' at the end. + """ + module_path = Path(__file__) + predicators_dir = module_path.parent.parent + return str(predicators_dir) + + +def import_submodules(path: List[str], + name: str, + tolerate_import_errors: bool = False) -> None: + """Load all submodules on the given path. + + Useful for finding subclasses of an abstract base class + automatically. With ``tolerate_import_errors=True``, any submodule + that fails with ``ModuleNotFoundError`` (e.g. an optional third- + party dep is unavailable) is skipped with a warning instead of + aborting the whole load. Use this in contexts where some submodules + are intentionally optional — e.g. predicators running under Pyodide + where heavy deps like torch / gym_sokoban / gymnasium-robotics + aren't installed, but most envs still work. + """ + if not TYPE_CHECKING: + for _, module_name, _ in pkgutil.walk_packages(path): + if "__init__" not in module_name: + # Important! We use an absolute import here to avoid issues + # with isinstance checking when using relative imports. + full_name = f"{name}.{module_name}" + try: + importlib.import_module(full_name) + except ImportError as exc: + if not tolerate_import_errors: + raise + # `cannot import name X from Y` is raised as plain + # ImportError (not ModuleNotFoundError) and is + # exactly the shape that fires under Pyodide when a + # submodule's top-level pulls a heavy symbol from + # the full utils. Skip both, not just MNFE. + name = getattr(exc, "name", None) or str(exc) + logging.warning( + "Skipping %s: missing optional dependency (%s)", + full_name, name) + + +def update_config(args: Dict[str, Any]) -> None: + """Args is a dictionary of new arguments to add to the config CFG.""" + parser = create_arg_parser() + update_config_with_parser(parser, args) + + +def update_config_with_parser(parser: ArgumentParser, args: Dict[str, + Any]) -> None: + """Helper function for update_config() that accepts a parser argument.""" + arg_specific_settings = GlobalSettings.get_arg_specific_settings(args) + # Only override attributes, don't create new ones. + allowed_args = set(CFG.__dict__) | set(arg_specific_settings) + # Unfortunately, can't figure out any other way to do this. + for parser_action in parser._actions: # pylint: disable=protected-access + allowed_args.add(parser_action.dest) + for k in args: + if k not in allowed_args: + raise ValueError(f"Unrecognized arg: {k}") + for k in ("env", "approach", "seed", "experiment_id"): + if k not in args and hasattr(CFG, k): + # For env, approach, seed, and experiment_id, if we don't + # pass in a value and this key is already in the + # configuration dict, add the current value to args. + args[k] = getattr(CFG, k) + for d in [arg_specific_settings, args]: + for k, v in d.items(): + setattr(CFG, k, v) + + +def reset_config(args: Optional[Dict[str, Any]] = None, + default_seed: int = 123, + default_render_state_dpi: int = 10) -> None: + """Reset to the default CFG, overriding with anything in args. + + This utility is meant for use in testing only. + """ + parser = create_arg_parser() + reset_config_with_parser(parser, args, default_seed, + default_render_state_dpi) + + +def reset_config_with_parser(parser: ArgumentParser, + args: Optional[Dict[str, Any]] = None, + default_seed: int = 123, + default_render_state_dpi: int = 10) -> None: + """Helper function for reset_config that accepts a parser argument.""" + default_args = parser.parse_args([ + "--env", + "default env placeholder", + "--seed", + str(default_seed), + "--approach", + "default approach placeholder", + ]) + arg_dict = { + k: v + for k, v in GlobalSettings.__dict__.items() if not k.startswith("_") + } + arg_dict.update(vars(default_args)) + if args is not None: + arg_dict.update(args) + if args is None or "render_state_dpi" not in args: + # By default, use a small value for the rendering DPI, to avoid + # expensive rendering during testing. + arg_dict["render_state_dpi"] = default_render_state_dpi + update_config_with_parser(parser, arg_dict) + + +def get_config_path_str(experiment_id: Optional[str] = None) -> str: + """Get a filename prefix for configuration based on the current CFG. + + If experiment_id is supplied, it is used in place of + CFG.experiment_id. + """ + if experiment_id is None: + experiment_id = CFG.experiment_id + if CFG.use_counterfactual_dataset_path_name: + return f"{CFG.env}__{CFG.seed}__{CFG.experiment_id}__query" + return (f"{CFG.env}__{CFG.approach}__{CFG.seed}__" + f"{CFG.excluded_predicates}__" + f"{CFG.included_options}__{experiment_id}") + + +def get_approach_save_path_str() -> str: + """Get a path for saving approaches.""" + os.makedirs(CFG.approach_dir, exist_ok=True) + return f"{CFG.approach_dir}/{get_config_path_str()}.saved" + + +def get_approach_load_path_str() -> str: + """Get a path for loading approaches.""" + if not CFG.load_experiment_id: + experiment_id = CFG.experiment_id + else: + experiment_id = CFG.load_experiment_id + return f"{CFG.approach_dir}/{get_config_path_str(experiment_id)}.saved" + + +def parse_args(env_required: bool = True, + approach_required: bool = True, + seed_required: bool = True) -> Dict[str, Any]: + """Parses command line arguments.""" + parser = create_arg_parser(env_required=env_required, + approach_required=approach_required, + seed_required=seed_required) + return parse_args_with_parser(parser) + + +def parse_args_with_parser(parser: ArgumentParser) -> Dict[str, Any]: + """Helper function for parse_args that accepts a parser argument.""" + args, overrides = parser.parse_known_args() + arg_dict = vars(args) + if len(overrides) == 0: + return arg_dict + # Update initial settings to make sure we're overriding + # existing flags only + update_config_with_parser(parser, arg_dict) + # Override global settings + assert len(overrides) >= 2 + assert len(overrides) % 2 == 0 + for flag, value in zip(overrides[:-1:2], overrides[1::2]): + assert flag.startswith("--") + setting_name = flag[2:] + if setting_name not in CFG.__dict__: + raise ValueError(f"Unrecognized flag: {setting_name}") + arg_dict[setting_name] = string_to_python_object(value) + return arg_dict + + +def string_to_python_object(value: str) -> Any: + """Return the Python object corresponding to the given string value.""" + if value in ("None", "none"): + return None + if value in ("True", "true"): + return True + if value in ("False", "false"): + return False + if value.isdigit() or value.startswith("lambda"): + return eval(value) + try: + return float(value) + except ValueError: + pass + if value.startswith("["): + assert value.endswith("]") + inner_strs = value[1:-1].split(",") + return [string_to_python_object(s) for s in inner_strs] + if value.startswith("("): + assert value.endswith(")") + inner_strs = value[1:-1].split(",") + return tuple(string_to_python_object(s) for s in inner_strs) + return value + + +def flush_cache() -> None: + """Clear all lru caches.""" + gc.collect() + _lru_type = functools._lru_cache_wrapper # pylint: disable=protected-access + wrappers = [] + for a in gc.get_objects(): + try: + if isinstance(a, _lru_type): + wrappers.append(a) + except Exception: # pylint: disable=broad-except + continue + + for wrapper in wrappers: + wrapper.cache_clear() + + +def parse_config_excluded_predicates( + env: BaseEnv) -> Tuple[Set[Predicate], Set[Predicate]]: + """Parse the CFG.excluded_predicates string, given an environment. + + Return a tuple of (included predicate set, excluded predicate set). + """ + if CFG.excluded_predicates: + if CFG.excluded_predicates == "all": + excluded_names = { + pred.name + for pred in env.predicates if pred not in env.goal_predicates + } + logging.info(f"All non-goal predicates excluded: {excluded_names}") + included = env.goal_predicates + else: + excluded_names = set(CFG.excluded_predicates.split(",")) + assert excluded_names.issubset( + {pred.name for pred in env.predicates}), \ + "Unrecognized predicate in excluded_predicates!" + included = { + pred + for pred in env.predicates if pred.name not in excluded_names + } + if CFG.offline_data_method != "demo+ground_atoms": + if CFG.allow_exclude_goal_predicates: + if not env.goal_predicates.issubset(included): + logging.info("Note: excluding goal predicates!") + else: + assert env.goal_predicates.issubset(included), \ + "Can't exclude a goal predicate!" + else: + excluded_names = set() + included = env.predicates + excluded = {pred for pred in env.predicates if pred.name in excluded_names} + return included, excluded + + +def replace_goals_with_agent_specific_goals( + included_predicates: Set[Predicate], + excluded_predicates: Set[Predicate], env: BaseEnv) -> Set[Predicate]: + """Replace original goal predicates with agent-specific goal predicates if + the environment defines them.""" + preds = included_predicates - env.goal_predicates \ + | env.agent_goal_predicates - excluded_predicates + return preds + + +def null_sampler(state: State, goal: Set[GroundAtom], rng: np.random.Generator, + objs: Sequence[Object]) -> Array: + """A sampler for an NSRT with no continuous parameters.""" + del state, goal, rng, objs # unused + return np.array([], dtype=np.float32) # no continuous parameters + + +def get_git_commit_hash() -> str: + """Return the hash of the current git commit.""" + out = subprocess.check_output(["git", "rev-parse", "HEAD"]) + return out.decode("ascii").strip() + + +def get_all_subclasses(cls: Any) -> Set[Any]: + """Get all subclasses of the given class.""" + return set(cls.__subclasses__()).union( + [s for c in cls.__subclasses__() for s in get_all_subclasses(c)]) + + +class _DummyFile(io.StringIO): + """Dummy file object used by nostdout().""" + + def write(self, _: Any) -> int: + """Mock write() method.""" + return 0 + + def flush(self) -> None: + """Mock flush() method.""" + + +@contextlib.contextmanager +def nostdout() -> Generator[None, None, None]: + """Suppress output for a block of code. + + To use, wrap code in the statement `with utils.nostdout():`. Note + that calls to the logging library, which this codebase uses + primarily, are unaffected. So, this utility is mostly helpful when + calling third-party code. + """ + save_stdout = sys.stdout + sys.stdout = _DummyFile() + yield + sys.stdout = save_stdout + + +def query_ldl( + ldl: LiftedDecisionList, + atoms: Set[GroundAtom], + objects: Set[Object], + goal: Set[GroundAtom], + static_predicates: Optional[Set[Predicate]] = None, + init_atoms: Optional[Collection[GroundAtom]] = None +) -> Optional[_GroundNSRT]: + """Queries a lifted decision list representing a goal-conditioned policy. + + Given an abstract state and goal, the rules are grounded in order. The + first applicable ground rule is used to return a ground NSRT. + + If static_predicates is provided, it is used to avoid grounding rules with + nonsense preconditions like IsBall(robot). + + If no rule is applicable, returns None. + """ + for rule in ldl.rules: + for ground_rule in all_ground_ldl_rules( + rule, + objects, + static_predicates=static_predicates, + init_atoms=init_atoms): + if ground_rule.pos_state_preconditions.issubset(atoms) and \ + not ground_rule.neg_state_preconditions & atoms and \ + ground_rule.goal_preconditions.issubset(goal): + return ground_rule.ground_nsrt + return None + + +def generate_random_string(length: int, alphabet: Sequence[str], + rng: np.random.Generator) -> str: + """Generates a random string of the given length using the provided set of + characters (alphabet).""" + assert all(len(c) == 1 for c in alphabet) + return "".join(rng.choice(alphabet, size=length)) + + +def find_balanced_expression(s: str, index: int) -> str: + """Find balanced expression in string starting from given index.""" + assert s[index] == "(" + start_index = index + balance = 1 + while balance != 0: + index += 1 + symbol = s[index] + if symbol == "(": + balance += 1 + elif symbol == ")": + balance -= 1 + return s[start_index:index + 1] + + +def find_all_balanced_expressions(s: str) -> List[str]: + """Return a list of all balanced expressions in a string, starting from the + beginning.""" + assert s[0] == "(" + assert s[-1] == ")" + exprs = [] + index = 0 + start_index = index + balance = 1 + while index < len(s) - 1: + index += 1 + if balance == 0: + exprs.append(s[start_index:index]) + # Jump to next "(". + while True: + if s[index] == "(": + break + index += 1 + start_index = index + balance = 1 + continue + symbol = s[index] + if symbol == "(": + balance += 1 + elif symbol == ")": + balance -= 1 + assert balance == 0 + exprs.append(s[start_index:index + 1]) + return exprs + + +def range_intersection(lb1: float, ub1: float, lb2: float, ub2: float) -> bool: + """Given upper and lower bounds for two ranges, returns True iff the ranges + intersect.""" + return (lb1 <= lb2 <= ub1) or (lb2 <= lb1 <= ub2) + + +def compute_abs_range_given_two_ranges(lb1: float, ub1: float, lb2: float, + ub2: float) -> Tuple[float, float]: + """Given upper and lower bounds of two feature ranges, returns the upper. + + and lower bound of |f1 - f2|. + """ + # Now, we must compute the upper and lower bounds of + # the expression |t1.f1 - t2.f2|. If the intervals + # [lb1, ub1] and [lb2, ub2] overlap, then the lower + # bound of the expression is just 0. Otherwise, if + # lb2 > ub1, the lower bound is |ub1 - lb2|, and if + # ub2 < lb1, the lower bound is |lb1 - ub2|. + if range_intersection(lb1, ub1, lb2, ub2): + lb = 0.0 + else: + lb = min(abs(lb2 - ub1), abs(lb1 - ub2)) + # The upper bound for the expression can be + # computed in a similar fashion. + ub = max(abs(ub2 - lb1), abs(ub1 - lb2)) + return (lb, ub) + + +def roundrobin(iterables: Sequence[Iterator]) -> Iterator: + """roundrobin(['ABC...', 'D...', 'EF...']) --> A D E B F C...""" + # Recipe credited to George Sakkis, code adapted slightly from + # from https://docs.python.org/3/library/itertools.html + num_active = len(iterables) + nexts = itertools.cycle(iter(it).__next__ for it in iterables) + while num_active: + for nxt in nexts: + yield nxt() + + +def get_task_seed(train_or_test: str, task_idx: int) -> int: + """Parses task seed from CFG.test_env_seed_offset.""" + assert task_idx < CFG.test_env_seed_offset + # SeedSequence generates a sequence of random values given an integer + # "entropy". We use CFG.seed to define the "entropy" and then get the + # n^th generated random value and use that to seed the gym environment. + # This is all to avoid unintentional dependence between experiments + # that are conducted with consecutive random seeds. For example, if + # we used CFG.seed + task_idx to seed the gym environment, there would + # be overlap between experiments when CFG.seed = 1, CFG.seed = 2, etc. + seed_entropy = CFG.seed + if train_or_test == "test": + seed_entropy += CFG.test_env_seed_offset + seed_sequence = np.random.SeedSequence(seed_entropy) + # Need to cast to int because generate_state() returns a numpy int. + task_seed = int(seed_sequence.generate_state(task_idx + 1)[-1]) + return task_seed + + +def _obs_to_state_pass_through(obs: Observation) -> State: + """Helper for run_ground_nsrt_with_assertions.""" + assert isinstance(obs, State) + return obs + + +def run_ground_nsrt_with_assertions(ground_nsrt: _GroundNSRT, + state: State, + env: BaseEnv, + rng: np.random.Generator, + override_params: Optional[Array] = None, + obs_to_state: Callable[ + [Observation], + State] = _obs_to_state_pass_through, + assert_effects: bool = True, + max_steps: int = 400) -> State: + """Utility for tests. + + NOTE: assumes that the internal state of env corresponds to state. + """ + ground_nsrt_str = f"{ground_nsrt.name}{ground_nsrt.objects}" + for atom in ground_nsrt.preconditions: + assert atom.holds(state), \ + f"Precondition for {ground_nsrt_str} failed: {atom}" + option = ground_nsrt.sample_option(state, set(), rng) + if override_params is not None: + option = option.parent.ground(option.objects, + override_params) # pragma: no cover + assert option.initiable(state) + for _ in range(max_steps): + act = option.policy(state) + obs = env.step(act) + state = obs_to_state(obs) + if option.terminal(state): + break + if assert_effects: + for atom in ground_nsrt.add_effects: + assert atom.holds(state), \ + f"Add effect for {ground_nsrt_str} failed: {atom}" + for atom in ground_nsrt.delete_effects: + assert not atom.holds(state), \ + f"Delete effect for {ground_nsrt_str} failed: {atom}" + return state + + +def get_scaled_default_font( + draw: ImageDraw.ImageDraw, + size: int) -> ImageFont.FreeTypeFont: # pragma: no cover + """Method that modifies the size of some provided PIL ImageDraw font. + + Useful for scaling up font sizes when using PIL to insert text + directly into images. + """ + # Determine the scaling factor + base_font = ImageFont.load_default() + width, height = draw.textbbox((0, 0), "A", font=base_font)[:2] + scale_factor = size / max(width, height) + # Scale the font using the factor + return base_font.font_variant(size=int(scale_factor * # type: ignore + base_font.size)) # type: ignore + + +def add_text_to_draw_img( + draw: ImageDraw.ImageDraw, position: Tuple[int, int], text: str, + font: ImageFont.FreeTypeFont +) -> ImageDraw.ImageDraw: # pragma: no cover + """Method that adds some text with a particular font at a particular pixel + position in an input PIL.ImageDraw.ImageDraw image. + + Returns the modified ImageDraw.ImageDraw with the added text. + """ + text_width, text_height = draw.textbbox((0, 0), text, font=font)[2:] + background_position = (position[0] - 5, position[1] - 5 + ) # Slightly larger than text + background_size = (text_width + 10, text_height + 10) + # Draw the background rectangle + draw.rectangle( + (background_position, (background_position[0] + background_size[0], + background_position[1] + background_size[1])), + fill="black") + # Add the text to the image + draw.text(position, text, fill="red", font=font) + return draw + + +def wrap_angle(angle: float) -> float: + """Wrap an angle in radians to [-pi, pi].""" + return np.arctan2(np.sin(angle), np.cos(angle)) + + +def get_parameterized_option_by_name( + options: Set[ParameterizedOption], + option_name: str) -> Optional[ParameterizedOption]: + """Retrieve an option by its name from a set of options.""" + return next((option for option in options if option.name == option_name), + None) + + +def get_object_by_name(objects: Collection[Object], + name: str) -> Optional[Object]: + """Get an object by its name from a collection of objects. + + Args: + objects: Collection of objects to search through + name: Name of the object to find + + Returns: + The object if found, None otherwise + """ + return next((obj for obj in objects if obj.name == name), None) + + +def configure_logging() -> None: + """Configure logging with colored output.""" + # Create a single formatter instance to be reused + colored_formatter = colorlog.ColoredFormatter( + '%(log_color)s%(levelname)s: %(message)s', + log_colors={ + 'DEBUG': 'cyan', + 'INFO': 'green', + 'WARNING': 'yellow', + 'ERROR': 'red', + 'CRITICAL': 'red,bg_white', + }, + reset=True, + style='%') + # Log to stderr. + colorlog_handler = colorlog.StreamHandler() + colorlog_handler.setFormatter(colored_formatter) + handlers: List[logging.Handler] = [colorlog_handler] + if CFG.log_file: + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + CFG.log_file += (f"{CFG.approach}/{CFG.experiment_id}/" + f"seed{CFG.seed}/run_{timestamp}/") + os.makedirs(CFG.log_file, exist_ok=True) + + # Handler for DEBUG level messages + debug_handler = logging.FileHandler(os.path.join( + CFG.log_file, "debug.log"), + mode='w') + debug_handler.setLevel(logging.DEBUG) + debug_handler.setFormatter(colored_formatter) + handlers.append(debug_handler) + + # Handler for INFO level messages + info_handler = logging.FileHandler(os.path.join( + CFG.log_file, "info.log"), + mode='w') + info_handler.setLevel(logging.INFO) + info_handler.setFormatter(colored_formatter) + handlers.append(info_handler) + + logging.basicConfig(level=CFG.loglevel, + format="%(message)s", + handlers=handlers, + force=True) + logging.getLogger('matplotlib.font_manager').setLevel(logging.ERROR) + logging.getLogger('libpng').setLevel(logging.ERROR) + logging.getLogger('PIL').setLevel(logging.ERROR) + logging.getLogger('openai').setLevel(logging.INFO) + # Used by openai package + logging.getLogger("httpx").setLevel(logging.INFO) + logging.getLogger("httpcore").setLevel(logging.INFO) + + +def log_initial_info(str_args: str) -> None: + """Log initial configuration and setup information.""" + if CFG.log_file: + logging.info(f"Logging to {CFG.log_file}") + logging.info(f"Running command: python {str_args}") + logging.info("Full config:") + logging.info(CFG) + logging.info(f"Git commit hash: {get_git_commit_hash()}") + + +def add_label_to_video(video: Video, + prefix: str, + imgs_dir: str, + save: bool = True) -> Video: + """Add a label to each frame of the video and save the images.""" + os.makedirs(imgs_dir, exist_ok=True) + new_video: Video = [] + for i, img in enumerate(video): + img_name = prefix + f"frame_{i+1}" + labeled_img = add_label_to_image( + img, # type: ignore[arg-type] + img_name, + imgs_dir, + save=save) + new_video.append(labeled_img) # type: ignore[arg-type] + return new_video + + +def add_label_to_image(img: PIL.Image.Image, + s_name: str, + obs_dir: str, + f_suffix: str = ".png", + save: bool = True) -> PIL.Image.Image: + """Add a label to an image and potentially save.""" + img_copy = img.copy() + draw = ImageDraw.Draw(img_copy) + font = ImageFont.load_default().font_variant( # type: ignore[union-attr] + size=50) + + # Get text dimensions + bbox = draw.textbbox((0, 0), s_name, font=font) + text_width = bbox[2] - bbox[0] + text_height = bbox[3] - bbox[1] + + # Calculate position (bottom right with padding) + padding = 10 + x = img_copy.width - text_width - padding + y = img_copy.height - text_height - 2 * padding + + text_color = (0, 0, 0) # black + draw.text((x, y), s_name, fill=text_color, font=font) + + if save: + os.makedirs(obs_dir, exist_ok=True) + img_copy.save(os.path.join(obs_dir, s_name + f_suffix)) + logging.debug(f"Saved Image {s_name}") + return img_copy + + +def load_all_images_from_dir(dir_path: str) -> List[PIL.Image.Image]: + """Load all images from a directory.""" + images = [] + img_paths = sorted(os.listdir(dir_path)) + for file in img_paths: + if file.endswith(('.png', '.jpg')): + images.append(PIL.Image.open(os.path.join(dir_path, file))) + return images + + +def all_subsets(input_set: Iterable[Any]) -> Iterator[Set[Any]]: + """Generates all subsets of a given set. + + Args: + input_set: An iterable (e.g., a list, set, tuple) + from which to generate subsets. + + Yields: + tuple: Each subset as a tuple. + """ + s = list(input_set) # Convert to list to handle various iterable inputs + n = len(s) + for i in range(n + 1): # Iterate from subset size 0 up to n + for subset in itertools.combinations(s, i): + yield set(subset) + + +def add_in_auxiliary_predicates(predicates: Set[Predicate]) -> Set[Predicate]: + """Add auxiliary predicates from derived predicates.""" + + def add_auxiliary(pred: Predicate, preds: Set[Predicate]) -> None: + if isinstance(pred, DerivedPredicate): + if pred.auxiliary_predicates: + preds.update(pred.auxiliary_predicates) + for aux_pred in pred.auxiliary_predicates: + add_auxiliary(aux_pred, preds) + + new_preds = predicates.copy() + for pred in predicates: + add_auxiliary(pred, new_preds) + return new_preds + + +def get_derived_predicates( + predicates: Set[Predicate]) -> Set[DerivedPredicate]: + """Get all derived predicates from a set of predicates.""" + return {pred for pred in predicates if isinstance(pred, DerivedPredicate)} + + +# def abstract_with_derived_predicates(atoms, derived_preds, objects): +# """Compute all derived atoms via layered evaluation (fewer passes). +# Potentially faster than the current implementation.""" +# # Build dependency graph over derived preds +# is_derived = {p for p in derived_preds} +# indeg = {p: 0 for p in derived_preds} +# edges = {p: set() for p in derived_preds} +# for p in derived_preds: +# for aux in getattr(p, "auxiliary_predicates", []): +# # only count deps on other derived preds +# q = next( +# (dp for dp in derived_preds +# if dp.name == aux.name), None) +# if q: +# edges[q].add(p); indeg[p] += 1 + +# # Kahn’s algorithm => layers +# frontier = [p for p in derived_preds if indeg[p] == 0] +# layers: list[list] = [] +# while frontier: +# layer = list(frontier); layers.append(layer); frontier = [] +# for u in layer: +# for v in edges[u]: +# indeg[v] -= 1 +# if indeg[v] == 0: +# frontier.append(v) + +# # Evaluate per layer; state grows monotonically +# state = set(atoms) +# derived_all = set() +# # (Optional) cache object choices per predicate once +# by_type = {} +# for o in objects: +# by_type.setdefault(o.type, []).append(o) +# choices_cache = { +# p: list(itertools.product(*(by_type[t] for t in p.types))) +# for p in derived_preds +# } + +# for layer in layers: +# for p in layer: +# for choice in choices_cache[p]: +# if p.holds(state, choice): +# derived_all.add(GroundAtom(p, choice)) +# state |= derived_all # grow state for next layer + +# return derived_all + + +def abstract_with_derived_predicates( + atoms: Set[GroundAtom], derived_preds: Collection[DerivedPredicate], + objects: Collection[Object]) -> Set[GroundAtom]: + """Compute the fixed point of concept predicate atoms.""" + primitive_atoms = atoms + new_concept_atoms: Set[GroundAtom] = set() + prev_new_concept_atoms: Set[GroundAtom] = set() + counter = 0 + while True: + # All the concept atoms that holds; all the previous atoms + atoms = primitive_atoms | new_concept_atoms + new_concept_atoms = _abstract_with_derived_predicates( + atoms, derived_preds, objects) + # logging.debug(f"ite {counter} concept atoms: {new_concept_atoms}") + converged = new_concept_atoms == prev_new_concept_atoms + if converged: + # logging.debug("converged") + break + prev_new_concept_atoms = new_concept_atoms + counter += 1 + return new_concept_atoms + + +def _abstract_with_derived_predicates( + abs_state: Set[GroundAtom], + derived_preds: Collection[DerivedPredicate], + objects: Collection[Object]) -> Set[GroundAtom]: + """Get the atoms based on the existing atomic state and concept + predicates.""" + atoms: Set[GroundAtom] = set() + for pred in derived_preds: + for choice in get_object_combinations(objects, pred.types): + try: + if pred.holds(abs_state, choice): + atoms.add(GroundAtom(pred, choice)) + except Exception as e: + logging.error(f"Error in evaluating concept predicate {pred}: " + f"{e}") + # raise e + raise PredicateEvaluationError( + f"Error in evaluating concept predicate {pred}: {e}", pred) + return atoms + + +def get_base_supporter_predicates( + root_predicate: DerivedPredicate) -> Set[Predicate]: + """Finds all primitive (non-derived) supporter predicates for a given root + derived predicate by traversing its dependency graph.""" + base_predicates: Set[Predicate] = set() + + # Use a worklist to process predicates in a breadth-first manner. + predicates_to_process: List[Predicate] = list( + root_predicate.auxiliary_predicates or []) + processed_predicates: Set[Predicate] = {root_predicate} + + while predicates_to_process: + pred = predicates_to_process.pop(0) + + if pred in processed_predicates: + continue + processed_predicates.add(pred) + + # If the predicate is derived, add its auxiliaries to the worklist. + if isinstance(pred, DerivedPredicate): + predicates_to_process.extend(pred.auxiliary_predicates or []) + # If it's a primitive predicate, we've found a base supporter. + else: + base_predicates.add(pred) + + return base_predicates + + +class PredicateEvaluationError(Exception): + """PredicateEvaluationError class.""" + + def __init__(self, message: str, pred: Any) -> None: + super().__init__(message) + self.pred = pred diff --git a/setup.py b/setup.py index 4ed6783274..189cad8a14 100644 --- a/setup.py +++ b/setup.py @@ -1,56 +1,87 @@ """Setup script.""" from setuptools import find_packages, setup +# Heavy ML / LLM / planning deps. Required for `python predicators/main.py` +# (learning approaches, NSRT pipeline, LLM-backed tools) but NOT for using +# the PyBullet envs directly (via the gymnasium_wrapper or env modules). +# `pip install predicators` is the slim env-runtime install; learners +# should `pip install predicators[ml]` (or `[develop]`, which is a +# superset including formatters/test tools). +ML_DEPS = [ + "torch>=2.2.0", + "torchvision>=0.17.0", + "scipy==1.9.3", + "scikit-learn>=1.1.3", + "pandas==1.5.1", + "seaborn==0.12.1", + "imageio==2.22.2", + "imageio-ffmpeg", + "openai==1.19.0", + "google-generativeai", + "tenacity", + "httpx==0.28.1", + "claude-agent-sdk>=0.1.73", + "nest_asyncio", + "ImageHash", + "pathos", + "psutil", + "slack_bolt", + "emcee", + "lisdf", + "requests", + "smepy@git+https://github.com/sebdumancic/structure_mapping.git", + "pg3@git+https://github.com/tomsilver/pg3.git", + "gym_sokoban@git+https://github.com/Learning-and-Intelligent-Systems/gym-sokoban.git", # pylint: disable=line-too-long +] + +# Dev tooling (formatters, linters, pytest plugins). Always paired with +# ML_DEPS in the `develop` extra so contributors get a fully working +# checkout from a single `pip install -e .[develop]`. +DEV_TOOLING = [ + "pytest-cov==2.12.1", + "pytest-pylint==0.18.0", + "yapf==0.32.0", + "docformatter==1.4", + "isort==5.10.1", + "mypy-extensions==1.0.0", +] + setup( name="predicators", version="0.1.0", packages=find_packages(include=["predicators", "predicators.*"]), install_requires=[ - "numpy==1.23.5", - "pytest==7.1.3", - "mypy==1.8.0", + # PyBullet env-runtime. Mirrors master's pins on the deps that + # are still in the base install so CPython CI stays + # predictable. Pyodide already ships newer versions of + # numpy / matplotlib / pillow / pyyaml that conflict with + # these exact pins, so qualify those with + # `sys_platform != 'emscripten'` — under Pyodide micropip + # then just keeps the already-loaded copy. + "numpy==1.23.5; sys_platform != 'emscripten'", + "matplotlib==3.6.2; sys_platform != 'emscripten'", + "pillow==10.3.0; sys_platform != 'emscripten'", + "pyyaml==6.0; sys_platform != 'emscripten'", "gym==0.26.2", "gymnasium>=0.28.0", - "matplotlib==3.6.2", - "imageio==2.22.2", - "imageio-ffmpeg", - "pandas==1.5.1", - "torch>=2.2.0", - "torchvision>=0.17.0", - "scipy==1.9.3", + # On Pyodide (sys_platform == "emscripten") the pybullet wheel + # is loaded from a local emfs path by the browser bootstrap; + # PyPI has no Pyodide-targeted pybullet. + "pybullet-arm64>=3.2.8; sys_platform != 'emscripten'", + "pyperplan", "tabulate==0.9.0", "dill==0.3.5.1", - "pyperplan", - "pathos", - "pillow==10.3.0", - "requests", - "slack_bolt", - "pybullet-arm64>=3.2.8", - "scikit-learn>=1.1.3", + "colorlog", + "types-PyYAML", "graphlib-backport", - "openai==1.19.0", - "pyyaml==6.0", + # Test/lint tooling stays in base for now (CLAUDE.md docs the + # runtime install as including them). + "pytest==7.1.3", + "mypy==1.8.0", "pylint==2.14.5", - "types-PyYAML", - "lisdf", - "seaborn==0.12.1", - "smepy@git+https://github.com/sebdumancic/structure_mapping.git", - "pg3@git+https://github.com/tomsilver/pg3.git", - "gym_sokoban@git+https://github.com/Learning-and-Intelligent-Systems/gym-sokoban.git", # pylint: disable=line-too-long - "ImageHash", - "google-generativeai", - "tenacity", - "httpx==0.28.1", - "colorlog", - "psutil", - "claude-agent-sdk>=0.1.73", - "nest_asyncio", - "emcee", ], include_package_data=True, extras_require={ - "develop": [ - "pytest-cov==2.12.1", "pytest-pylint==0.18.0", "yapf==0.32.0", - "docformatter==1.4", "isort==5.10.1", "mypy-extensions==1.0.0" - ] + "ml": ML_DEPS, + "develop": ML_DEPS + DEV_TOOLING, }) diff --git a/web/.gitignore b/web/.gitignore new file mode 100644 index 0000000000..c252a74442 --- /dev/null +++ b/web/.gitignore @@ -0,0 +1,8 @@ +__pycache__/ +*.pyc +node_modules +.venv +predicators_assets +wheels/ +predicators-min.tar.gz +envs/ diff --git a/web/README.md b/web/README.md new file mode 100644 index 0000000000..18261e2e1f --- /dev/null +++ b/web/README.md @@ -0,0 +1,121 @@ +# predicators in the browser (Pyodide POC) + +A proof-of-concept that runs a PyBullet environment from `predicators` +inside a web browser via Pyodide + a WASM build of `pybullet`. No +Python server, no native binaries — the entire stack ships as static +files to the user's browser, where Pyodide loads it. + +## What works today + +- Booting Pyodide in the page. +- Installing the WASM `pybullet` wheel + the pure-Python `predicators` + wheel (with a tiny `gym.spaces.Box` shim). +- Unpacking the `predicators/envs/assets/` tarball into Pyodide's + emulated FS. +- Constructing a PyBullet env (`pybullet_blocks` etc.), rendering a + 320×240 RGBA frame through `p.ER_TINY_RENDERER`, and drawing it + onto a ``. +- **Option-level interaction:** after Reset, the option ` + + + + + + + + + + + + + + + + + + + + + + + +
+ + + +

+  
+
+  
+
+
+ + + + diff --git a/web/app/main.js b/web/app/main.js new file mode 100644 index 0000000000..ae2261ca4e --- /dev/null +++ b/web/app/main.js @@ -0,0 +1,913 @@ +// predicators-in-the-browser bootstrap. +// +// Pyodide runs predicators + pybullet headless; Three.js renders the +// scene client-side from a manifest extracted by the Python bridge. +// urdf-loader handles real URDF parsing (Fetch robot meshes, plane, +// table, etc.); primitive boxes/spheres are built directly from +// THREE.{Box,Sphere}Geometry. Each option execution returns a list of +// {body_id: {pos, orn, joints}} snapshots that we replay onto the +// scene via requestAnimationFrame. + +import { loadPyodide } from "https://cdn.jsdelivr.net/pyodide/v0.29.4/full/pyodide.mjs"; +import * as THREE from "three"; +import { OrbitControls } from "three/addons/controls/OrbitControls.js"; +import { OBJLoader } from "three/addons/loaders/OBJLoader.js"; +import { STLLoader } from "three/addons/loaders/STLLoader.js"; +import { ColladaLoader } from "three/addons/loaders/ColladaLoader.js"; +import URDFLoader from "urdf-loader"; + +const statusEl = document.getElementById("status"); +const logEl = document.getElementById("log"); +const infoEl = document.getElementById("info"); +const envSelect = document.getElementById("env-select"); +const bootBtn = document.getElementById("boot-env"); +const sceneHost = document.getElementById("scene-host"); + +const WHEELS_BASE = "../wheels"; +const PYBULLET_WHEEL = "pybullet-3.2.7-cp313-cp313-pyemscripten_2025_0_wasm32.whl"; +const PREDICATORS_WHEEL = "predicators-0.1.0-py3-none-any.whl"; +const GYM_SHIM_WHEEL = "gym-0.26.2-py3-none-any.whl"; +const ASSETS_TARBALL = "assets.tar.gz"; + +const t0 = performance.now(); +function log(msg) { + const t = ((performance.now() - t0) / 1000).toFixed(1); + const line = `[${t}s] ${msg}`; + console.log(line); + logEl.textContent += line + "\n"; + logEl.scrollTop = logEl.scrollHeight; +} +function setStatus(s) { statusEl.textContent = s; } + +let pyodide = null; + +async function fetchBytes(path) { + const r = await fetch(path); + if (!r.ok) throw new Error(`fetch ${path}: ${r.status}`); + return new Uint8Array(await r.arrayBuffer()); +} + +// ---------------------------------------------------------------------------- +// Three.js scene setup +// ---------------------------------------------------------------------------- + +const scene = new THREE.Scene(); +scene.background = new THREE.Color(0x1a2030); + +// World convention: pybullet uses Z-up. So does this scene. +scene.up = new THREE.Vector3(0, 0, 1); + +const camera = new THREE.PerspectiveCamera(45, 1, 0.01, 100); +camera.up.set(0, 0, 1); +camera.position.set(1.7, -1.4, 1.3); +camera.lookAt(0.75, 0.75, 0.5); + +const renderer = new THREE.WebGLRenderer({ antialias: true, alpha: false }); +renderer.setPixelRatio(window.devicePixelRatio); +renderer.shadowMap.enabled = true; +renderer.shadowMap.type = THREE.PCFSoftShadowMap; +renderer.outputColorSpace = THREE.SRGBColorSpace; +// ACES filmic tone mapping makes PBR materials render with the +// roll-off you'd expect in a game engine — without it, MeshStandard's +// linear output looks washed-out/flat on solid-colored blocks. +renderer.toneMapping = THREE.ACESFilmicToneMapping; +renderer.toneMappingExposure = 1.1; +sceneHost.appendChild(renderer.domElement); + +const controls = new OrbitControls(camera, renderer.domElement); +controls.target.set(0.75, 0.75, 0.5); +controls.update(); + +// Lighting rig (game-engine style): +// - DirectionalLight ("sun"): key light + shadow caster, aimed at the +// workspace centroid. Initialize sun.target now so the light vector +// is sensible before applyEnvCamera (which retargets per-env) runs. +// - AmbientLight: small global lift so deep shadows don't crush. +const WORKSPACE_CENTER = new THREE.Vector3(0.75, 0.75, 0.5); +scene.add(new THREE.AmbientLight(0xffffff, 0.25)); +const sun = new THREE.DirectionalLight(0xfff0d8, 1.5); +sun.position.set(WORKSPACE_CENTER.x + 2, + WORKSPACE_CENTER.y - 2, + WORKSPACE_CENTER.z + 3); +sun.target.position.copy(WORKSPACE_CENTER); +sun.castShadow = true; +sun.shadow.mapSize.set(2048, 2048); +sun.shadow.camera.left = -3; sun.shadow.camera.right = 3; +sun.shadow.camera.top = 3; sun.shadow.camera.bottom = -3; +sun.shadow.camera.near = 0.1; sun.shadow.camera.far = 10; +sun.shadow.bias = -0.0005; +sun.shadow.normalBias = 0.02; +scene.add(sun); +scene.add(sun.target); + +// "Workbench lamp": warm orange SpotLight hanging above the table, +// casting a soft pool on whatever's on the workspace. Distance-decay +// keeps the orange localized so it doesn't tint the floor or robot. +const lampPos = new THREE.Vector3( + WORKSPACE_CENTER.x, WORKSPACE_CENTER.y, WORKSPACE_CENTER.z + 0.9); +const lamp = new THREE.SpotLight(0xff7a25, 8.0, + /* distance */ 1.6, /* angle */ Math.PI / 4.2, + /* penumbra */ 0.55, /* decay */ 1.8); +lamp.position.copy(lampPos); +lamp.target.position.copy(WORKSPACE_CENTER); +lamp.castShadow = true; +lamp.shadow.mapSize.set(1024, 1024); +lamp.shadow.camera.near = 0.05; +lamp.shadow.camera.far = 2.0; +lamp.shadow.bias = -0.0005; +scene.add(lamp); +scene.add(lamp.target); + +// Visible "bulb": small emissive sphere at the lamp position so the +// orange has a visual source. Doesn't itself cast light — purely +// decorative. +const bulb = new THREE.Mesh( + new THREE.SphereGeometry(0.035, 16, 16), + new THREE.MeshStandardMaterial({ + color: 0xff7a25, emissive: 0xff7a25, emissiveIntensity: 2.0, + roughness: 0.4, metalness: 0.0, + }), +); +bulb.position.copy(lampPos); +scene.add(bulb); + +function resize() { + const w = sceneHost.clientWidth || 1; + const h = sceneHost.clientHeight || 1; + renderer.setSize(w, h, false); + camera.aspect = w / h; + camera.updateProjectionMatrix(); +} +window.addEventListener("resize", resize); +new ResizeObserver(resize).observe(sceneHost); +resize(); + +// Position the Three.js camera using the env's pybullet camera +// params. PyBullet uses (yaw, pitch, distance) around a target with +// Z-up, yaw measured from +X axis, pitch negative looking down. +function applyEnvCamera(cam) { + const target = new THREE.Vector3(cam.target[0], cam.target[1], cam.target[2]); + const yawRad = THREE.MathUtils.degToRad(cam.yaw); + const pitchRad = THREE.MathUtils.degToRad(cam.pitch); + // PyBullet's `computeViewMatrixFromYawPitchRoll` constructs the + // forward vector (where the camera looks) as + // forward = (-cos(p)sin(y), cos(p)cos(y), sin(p)) + // and places the camera at `target - forward * distance`. + const cp = Math.cos(pitchRad), sp = Math.sin(pitchRad); + const cy = Math.cos(yawRad), sy = Math.sin(yawRad); + const forward = new THREE.Vector3(-cp * sy, cp * cy, sp); + // Bump the camera back a bit. The env-author's distance is tuned + // for pybullet's 320x240 TinyRenderer view; in Three.js with our + // larger viewport we want more of the robot in frame. + const offset = forward.clone().multiplyScalar(-cam.distance * 1.7); + camera.position.copy(target).add(offset); + camera.fov = cam.fov; + camera.updateProjectionMatrix(); + camera.lookAt(target); + controls.target.copy(target); + controls.update(); + // Adjust shadow camera so it covers the workspace. + const m = Math.max(cam.distance * 1.5, 2.0); + sun.shadow.camera.left = -m; sun.shadow.camera.right = m; + sun.shadow.camera.top = m; sun.shadow.camera.bottom = -m; + sun.shadow.camera.far = cam.distance * 8; + sun.shadow.camera.updateProjectionMatrix(); + // Reposition the sun relative to the target for nicer shadows. + sun.position.set(target.x + 2, target.y - 2, target.z + 3); + sun.target.position.copy(target); + sun.target.updateMatrixWorld(); +} + +// Fit the camera to the scene by computing a bounding box over all +// rendered objects. Called after the manifest finishes loading. +function fitCameraToScene() { + const box = new THREE.Box3(); + let hasContents = false; + // Predicators stashes unused objects at world coords like (10,10). + // Skip anything more than 5 m from the workspace centroid; the + // 0.75/0.75 origin is roughly where the Fetch base + table sit. + const STASH_THRESHOLD = 5.0; + const tmp = new THREE.Vector3(); + for (const [, b] of bodyMap) { + if (b.root.userData.isGround) continue; + b.root.updateMatrixWorld(true); + const objBox = new THREE.Box3().setFromObject(b.root); + if (objBox.isEmpty()) continue; + objBox.getCenter(tmp); + if (Math.abs(tmp.x) > STASH_THRESHOLD || Math.abs(tmp.y) > STASH_THRESHOLD + || Math.abs(tmp.z) > STASH_THRESHOLD) { + log(` skipping body (stashed at ${tmp.x.toFixed(1)},${tmp.y.toFixed(1)},${tmp.z.toFixed(1)})`); + continue; + } + if (!hasContents) { + box.copy(objBox); + hasContents = true; + } else { + box.union(objBox); + } + } + if (!hasContents) { log("fitCameraToScene: no content bodies"); return; } + const center = box.getCenter(new THREE.Vector3()); + const size = box.getSize(new THREE.Vector3()); + const radius = Math.max(size.x, size.y, size.z) * 0.5; + log(`fitCameraToScene: center=(${center.x.toFixed(2)},${center.y.toFixed(2)},${center.z.toFixed(2)}) size=(${size.x.toFixed(2)},${size.y.toFixed(2)},${size.z.toFixed(2)}) bodies=${bodyMap.size}`); + // Place camera at a 45-deg azimuth angle, elevated, distance ~3x radius. + const distance = Math.max(radius * 3.0, 1.0); + const dir = new THREE.Vector3(1, -1, 0.9).normalize(); + camera.position.copy(center).addScaledVector(dir, distance); + camera.lookAt(center); + controls.target.copy(center); + controls.update(); + // Adjust shadow camera extents to cover the scene + a margin. + const m = Math.max(radius * 2.0, 2.0); + sun.shadow.camera.left = -m; sun.shadow.camera.right = m; + sun.shadow.camera.top = m; sun.shadow.camera.bottom = -m; + sun.shadow.camera.far = distance * 4; + sun.shadow.camera.updateProjectionMatrix(); +} + +// Per-body state: body_id -> { root: THREE.Object3D, joints: {name: URDFJoint or null} } +let bodyMap = new Map(); +// Expose for ad-hoc browser-console debugging: window.predBodies(), etc. +window.predBodies = () => Array.from(bodyMap.entries()).map(([id, b]) => { + b.root.updateMatrixWorld(true); + const box = new THREE.Box3().setFromObject(b.root); + const c = box.getCenter(new THREE.Vector3()); + const s = box.getSize(new THREE.Vector3()); + let meshCount = 0; + b.root.traverse((o) => { if (o.isMesh) meshCount++; }); + return { id, kind: b.kind, meshCount, + center: [c.x.toFixed(2), c.y.toFixed(2), c.z.toFixed(2)], + size: [s.x.toFixed(2), s.y.toFixed(2), s.z.toFixed(2)] }; +}); +window.predScene = scene; + +function clearScene() { + for (const [, b] of bodyMap) { + scene.remove(b.root); + b.root.traverse((o) => { + if (o.geometry) o.geometry.dispose(); + if (o.material) { + const mats = Array.isArray(o.material) ? o.material : [o.material]; + for (const m of mats) m.dispose(); + } + }); + } + bodyMap.clear(); +} + +const urdfLoader = new URDFLoader(); +// urdf-loader needs the absolute base URL so package:// resolution works. +urdfLoader.packages = (url) => url; +urdfLoader.parseVisual = true; +urdfLoader.parseCollision = false; +// Default loader only handles .stl / .dae. Some envs reference .obj +// (e.g. plane.urdf -> plane.obj) — plug in OBJLoader explicitly. +urdfLoader.loadMeshCb = (path, manager, done) => { + if (/\.obj$/i.test(path)) { + new OBJLoader(manager).load(path, (obj) => done(obj), + undefined, (err) => done(null, err)); + } else if (/\.stl$/i.test(path)) { + new STLLoader(manager).load(path, (geom) => { + done(new THREE.Mesh(geom, new THREE.MeshStandardMaterial({ + color: 0xb8c0cc, roughness: 0.6, metalness: 0.1, + }))); + }, undefined, (err) => done(null, err)); + } else if (/\.dae$/i.test(path)) { + new ColladaLoader(manager).load(path, (dae) => { + // URDF is Z-up; ColladaLoader unhelpfully rotates Z-up Collada + // assets to Y-up. Clear that so meshes stay aligned with their + // URDF link frames. + dae.scene.rotation.set(0, 0, 0); + // Blender-exported DAEs include leftover Camera/Lamp nodes at + // distant world positions; they're invisible but inflate any + // bbox computation. Drop them. + const stash = []; + dae.scene.traverse((o) => { + if (o.isCamera || o.isLight) stash.push(o); + }); + for (const o of stash) o.parent?.remove(o); + done(dae.scene); + }, undefined, (err) => done(null, err)); + } else { + console.warn(`URDFLoader: no loader for ${path}`); + done(null, new Error(`no loader for ${path}`)); + } +}; + +// Some predicators URDFs (fetch_description/robots/fetch.urdf) use +// undefined XML namespace prefixes like `` inside +// `` blocks. Chromium's DOMParser tolerates them; Firefox's +// (and the spec-strict path) aborts and returns a . We +// don't render gazebo plugins anyway, so just strip them. Returns +// the cleaned URDF text. +function sanitizeUrdf(text) { + return text.replace(//g, ""); +} + +function loadUrdfBody(entry) { + return new Promise((resolve) => { + // Fetch the URDF ourselves so we can sanitize before parsing, + // then hand the cleaned text to urdf-loader. No JS-side + // caching: clone + LoadingManager interactions made it too + // brittle (cloned-before-meshes-load, concurrent waiters + // racing on a single onLoad slot, etc). The browser still + // disk-caches the URDF + mesh fetches across env switches, so + // re-visits aren't full cold loads. + const workingPath = entry.url.substring(0, entry.url.lastIndexOf("/") + 1); + fetch(entry.url).then((r) => { + if (!r.ok) throw new Error(`HTTP ${r.status}`); + return r.text(); + }).then((text) => { + const cleaned = sanitizeUrdf(text); + urdfLoader.workingPath = workingPath; + const robot = urdfLoader.parse(cleaned); + urdfLoader.workingPath = ""; + if (!robot) { + log(`URDF returned null for ${entry.url} — using placeholder`); + makePlaceholder(entry); + return resolve(); + } + robot.up.set(0, 0, 1); + robot.rotation.order = "ZYX"; + // pybullet loadURDF supports a globalScaling kwarg (e.g. + // pybullet_coffee passes 0.09 for kettle.urdf). urdf-loader + // doesn't know about it, so apply it client-side. + if (entry.scale && entry.scale !== 1.0) { + robot.scale.setScalar(entry.scale); + } + robot.traverse((o) => { + if (o.isMesh) { + o.castShadow = true; + o.receiveShadow = true; + } + }); + // Apply per-link RGBA from pybullet (overrides the URDF's + // parsed material colors — e.g. grow's cup/jug URDFs have no + // blocks, so urdf-loader renders them white; the + // env tints them via p.changeVisualShape). + if (entry.link_colors && robot.links) { + for (const [linkName, rgba] of Object.entries(entry.link_colors)) { + const link = robot.links[linkName]; + if (!link) continue; + link.traverse((o) => { + if (!o.isMesh || !o.material) return; + // Clone so we don't mutate a shared default material. + const mat = o.material.clone(); + mat.color?.setRGB?.(rgba[0], rgba[1], rgba[2]); + mat.opacity = rgba[3]; + mat.transparent = rgba[3] < 1; + o.material = mat; + }); + } + } + scene.add(robot); + const joints = {}; + for (const jn of entry.joint_names) { + joints[jn] = robot.joints?.[jn] || null; + } + bodyMap.set(entry.body_id, { root: robot, joints, kind: "urdf" }); + resolve(); + }).catch((err) => { + log(`URDF load failed (${entry.url}): ${err?.message || err}`); + makePlaceholder(entry); + resolve(); + }); + }); +} + +function makePlaceholder(entry) { + log(`PLACEHOLDER for body ${entry.body_id} (${entry.name}, kind=${entry.kind}, url=${entry.url || ''})`); + const root = new THREE.Group(); + let mesh; + if (entry.name && entry.name.toLowerCase().includes("plane")) { + // Real-looking ground plane — receives shadows, not pink. + // Kept modest size (5x5) so it doesn't blow up the scene bbox + // used by fitCameraToScene(). + const geom = new THREE.PlaneGeometry(5, 5); + const mat = new THREE.MeshStandardMaterial({ + color: 0xdde2eb, roughness: 0.95, metalness: 0.0, + }); + mesh = new THREE.Mesh(geom, mat); + mesh.receiveShadow = true; + root.userData.isGround = true; + } else { + mesh = new THREE.Mesh( + new THREE.BoxGeometry(0.1, 0.1, 0.1), + new THREE.MeshStandardMaterial({ color: 0xff00ff, roughness: 0.6 }) + ); + } + root.add(mesh); + scene.add(root); + bodyMap.set(entry.body_id, { root, joints: {}, kind: "placeholder" }); +} + +function makePrimitive(entry) { + // entry.shapes is a list of {geom, dims, mesh_url, local_pos, + // local_orn, rgba, link}. For multi-shape primitives we group them + // under one parent. + const root = new THREE.Group(); + for (const s of entry.shapes) { + let geom = null; + const dims = s.dims; + // pybullet's getVisualShapeData returns FULL extents for BOX + // (not the halfExtents passed to createCollisionShape) — verified + // against pybullet_ants where food_half_extents=(0.03,0.03,0.03) + // gets reported as dims=(0.06,0.06,0.06). So no *2 here. + switch (s.geom) { + case "box": + geom = new THREE.BoxGeometry(dims[0], dims[1], dims[2]); + break; + case "sphere": + geom = new THREE.SphereGeometry(dims[0], 24, 16); + break; + case "cylinder": + // pybullet cylinder dims: [length, radius, _]. + geom = new THREE.CylinderGeometry(dims[1], dims[1], dims[0], 24); + break; + case "plane": + geom = new THREE.PlaneGeometry(5, 5); + break; + case "mesh": + // In-memory vertex meshes (e.g. domino's top triangle, created + // via createVisualShape(GEOM_MESH, vertices=[...])) have no + // mesh_url and getVisualShapeData reports dims=(1,1,1). Drawing + // a 1m box stand-in dwarfs the scene, so skip them. + if (!s.mesh_url) continue; + geom = new THREE.BoxGeometry(dims[0], dims[1], dims[2]); + break; + default: + continue; + } + const rgba = s.rgba; + const mat = new THREE.MeshStandardMaterial({ + color: new THREE.Color(rgba[0], rgba[1], rgba[2]), + opacity: rgba[3], + transparent: rgba[3] < 0.999, + roughness: 0.55, metalness: 0.05, + }); + const mesh = new THREE.Mesh(geom, mat); + mesh.castShadow = true; + mesh.receiveShadow = true; + // Local visual frame offset (from URDF visual origin). + mesh.position.fromArray(s.local_pos); + mesh.quaternion.fromArray(s.local_orn); + root.add(mesh); + } + scene.add(root); + bodyMap.set(entry.body_id, { root, joints: {}, kind: "primitive" }); +} + +async function buildSceneFromManifest(manifest) { + clearScene(); + const urdfPromises = []; + for (const entry of manifest) { + try { + if (entry.kind === "urdf") { + urdfPromises.push(loadUrdfBody(entry)); + } else { + makePrimitive(entry); + } + } catch (e) { + log(`Skipping body ${entry.body_id} (${entry.name}): ${e.message}`); + makePlaceholder(entry); + } + } + await Promise.all(urdfPromises); +} + +function applyFrame(frame) { + for (const [idStr, state] of Object.entries(frame)) { + const id = Number(idStr); + const b = bodyMap.get(id); + if (!b) continue; + b.root.position.fromArray(state.pos); + b.root.quaternion.fromArray(state.orn); + if (b.kind === "urdf") { + for (const [jname, angle] of Object.entries(state.joints)) { + const j = b.joints[jname]; + if (j && typeof j.setJointValue === "function") { + j.setJointValue(angle); + } + } + } + } +} + +let rafRunning = false; +function renderLoop() { + controls.update(); + renderer.render(scene, camera); + if (rafRunning) requestAnimationFrame(renderLoop); +} +function startRenderLoop() { + if (rafRunning) return; + rafRunning = true; + requestAnimationFrame(renderLoop); +} + +// ---------------------------------------------------------------------------- +// Pyodide bootstrap +// ---------------------------------------------------------------------------- + +async function boot() { + setStatus("Loading Pyodide runtime…"); + pyodide = await loadPyodide({ + indexURL: "https://cdn.jsdelivr.net/pyodide/v0.29.4/full/", + stdout: log, stderr: log, + }); + log("Pyodide loaded"); + + setStatus("Loading base packages (numpy, matplotlib, pillow)…"); + await pyodide.loadPackage(["micropip", "numpy", "matplotlib", "pillow"]); + + setStatus("Staging wheels into Pyodide FS…"); + pyodide.FS.writeFile(`/tmp/${PYBULLET_WHEEL}`, await fetchBytes(`${WHEELS_BASE}/${PYBULLET_WHEEL}`)); + pyodide.FS.writeFile(`/tmp/${GYM_SHIM_WHEEL}`, await fetchBytes(`${WHEELS_BASE}/${GYM_SHIM_WHEEL}`)); + pyodide.FS.writeFile(`/tmp/${PREDICATORS_WHEEL}`, await fetchBytes(`${WHEELS_BASE}/${PREDICATORS_WHEEL}`)); + + setStatus("Installing wheels…"); + await pyodide.runPythonAsync(` +import micropip +# Pybullet + gym shim first; these are local emfs wheels that wouldn't +# resolve from PyPI (no Pyodide-targeted pybullet on PyPI; gym 0.26.2 +# has no pure-Python wheel — see web/app/gym_shim_setup.py). +await micropip.install("emfs:/tmp/${PYBULLET_WHEEL}") +await micropip.install("emfs:/tmp/${GYM_SHIM_WHEEL}") +# predicators' install_requires is now the env-runtime slim set +# (matches the dep audit in setup.py). Let micropip pull each dep +# transitively. keep_going=True so platform-specific dead-ends +# (pybullet-arm64 on PyPI, version pins that conflict with the +# Pyodide-shipped numpy/matplotlib/pillow) get skipped instead of +# aborting the install. +await micropip.install("emfs:/tmp/${PREDICATORS_WHEEL}", + deps=True, keep_going=True) +print("predicators installed") +`); + + // Probe the tarball size up-front so the status reflects reality + // rather than a stale hard-coded number. Falls back gracefully if + // the server doesn't return Content-Length. + const assetsUrl = `${WHEELS_BASE}/${ASSETS_TARBALL}`; + let assetSizeLabel = ""; + try { + const head = await fetch(assetsUrl, { method: "HEAD" }); + const len = parseInt(head.headers.get("content-length") || "", 10); + if (Number.isFinite(len) && len > 0) { + assetSizeLabel = ` (${(len / 1024 / 1024).toFixed(1)} MB)`; + } + } catch { /* fall through, show no size */ } + setStatus(`Fetching + unpacking env assets${assetSizeLabel}…`); + const assetsBuf = await fetchBytes(assetsUrl); + pyodide.FS.mkdirTree("/lib/python3.13/site-packages/predicators/envs"); + await pyodide.runPythonAsync( + `import os; os.chdir("/lib/python3.13/site-packages/predicators/envs")`); + pyodide.unpackArchive(assetsBuf, "tar.gz"); + log("Assets unpacked"); + + setStatus("Loading bridge…"); + const setupSrc = await (await fetch("./setup.py")).text(); + pyodide.FS.writeFile("/setup.py", setupSrc); + await pyodide.runPythonAsync("exec(open('/setup.py').read(), globals())"); + + setStatus("Ready. Pick an env from the dropdown."); + envSelect.disabled = false; + // Reset button stays disabled until the user picks an env, so it + // can't be ambiguous about "start" vs "reset". + bootBtn.disabled = true; + startRenderLoop(); +} + +// ---------------------------------------------------------------------------- +// Env reset + option execution +// ---------------------------------------------------------------------------- + +const optionRow = document.getElementById("option-row"); +const optionSelect = document.getElementById("option-select"); +const optionArgs = document.getElementById("option-args"); +const executeBtn = document.getElementById("execute-option"); + +let currentOptions = []; +let currentObjects = []; + +async function resetEnv() { + const envName = envSelect.value; + if (!envName) return; // placeholder option selected + bootBtn.disabled = false; + setStatus(`Constructing ${envName}…`); + const t = (() => { const s = performance.now(); return () => ((performance.now() - s) / 1000).toFixed(2); }); + try { + const tBridge = t(); + const outProxy = await pyodide.runPythonAsync( + `bridge.reset("${envName}")`); + const dtBridge = tBridge(); + const info = outProxy.toJs({ dict_converter: Object.fromEntries }); + outProxy.destroy(); + infoEl.textContent = `task=${info.task_idx} objects=${info.num_objects} action_dim=${info.action_dim} bodies=${info.manifest.length}`; + log(`Reset ${envName} -> ${info.manifest.length} bodies (bridge.reset: ${dtBridge}s)`); + + setStatus("Building Three.js scene from manifest…"); + const tScene = t(); + await buildSceneFromManifest(info.manifest); + log(` buildSceneFromManifest: ${tScene()}s`); + // Initial pose snapshot. + const stateProxy = await pyodide.runPythonAsync(`bridge.get_all_body_states()`); + applyFrame(stateProxy.toJs({ dict_converter: Object.fromEntries })); + stateProxy.destroy(); + if (info.camera) { + applyEnvCamera(info.camera); + } else { + fitCameraToScene(); + } + + const optsProxy = await pyodide.runPythonAsync(`bridge.list_options()`); + const objsProxy = await pyodide.runPythonAsync(`bridge.list_objects()`); + currentOptions = optsProxy.toJs({ dict_converter: Object.fromEntries }); + currentObjects = objsProxy.toJs({ dict_converter: Object.fromEntries }); + optsProxy.destroy(); objsProxy.destroy(); + populateOptionPicker(); + + setStatus("Env ready. Drag canvas to orbit. Pick an option to execute."); + optionRow.style.display = ""; + } catch (e) { + setStatus("Reset failed — see log."); + log("ERROR: " + (e.message || e)); + } +} + +function populateOptionPicker() { + optionSelect.innerHTML = ""; + for (const opt of currentOptions) { + const o = document.createElement("option"); + o.value = opt.name; + o.textContent = `${opt.name}(${opt.type_names.join(", ")})`; + optionSelect.appendChild(o); + } + optionSelect.addEventListener("change", renderOptionArgs); + renderOptionArgs(); +} + +function renderOptionArgs() { + optionArgs.innerHTML = ""; + const opt = currentOptions.find((o) => o.name === optionSelect.value); + if (!opt) return; + for (const tname of opt.type_names) { + const sel = document.createElement("select"); + sel.className = "opt-arg"; + sel.dataset.typeName = tname; + for (const obj of currentObjects.filter((o) => o.type_name === tname)) { + const o = document.createElement("option"); + o.value = obj.name; + o.textContent = obj.name; + sel.appendChild(o); + } + optionArgs.appendChild(sel); + } +} + +// Sim-step consumption rate at the visual layer (steps per second of +// wall-clock playback). Each pybullet step moves the EE by up to +// pybullet_max_vel_norm (~5cm), so 12 steps/s ≈ a 0.6 m/s max EE +// speed — close to a natural collaborative-robot pace. Decoupled +// from pybullet's own sim_steps_per_action (which controls physics +// integration); this just paces how fast we *render* completed +// steps. Override via ?simRate=N. +const SIM_RATE_HZ = (() => { + const q = new URLSearchParams(window.location.search).get("simRate"); + const n = q ? Number(q) : NaN; + return Number.isFinite(n) && n > 0 ? n : 12; +})(); + +// Scratch quaternions reused across interpolated frames to avoid +// allocating per-body per-rAF. +const _qPrev = new THREE.Quaternion(); +const _qCurr = new THREE.Quaternion(); + +function _lerp(a, b, t) { return a + (b - a) * t; } + +// Apply a body state interpolated between two pybullet-step snapshots. +// `t` is in [0, 1]; t=0 is prev, t=1 is curr. Used by the rAF tick to +// paint smooth motion between sim steps when SIM_RATE_HZ < rAF rate. +function applyInterpolatedFrame(prev, curr, t) { + for (const [idStr, currState] of Object.entries(curr)) { + const id = Number(idStr); + const b = bodyMap.get(id); + if (!b) continue; + const prevState = prev?.[idStr]; + if (!prevState) { + // New body this step — paint at curr without interpolation. + b.root.position.fromArray(currState.pos); + b.root.quaternion.fromArray(currState.orn); + } else { + b.root.position.set( + _lerp(prevState.pos[0], currState.pos[0], t), + _lerp(prevState.pos[1], currState.pos[1], t), + _lerp(prevState.pos[2], currState.pos[2], t), + ); + _qPrev.fromArray(prevState.orn); + _qCurr.fromArray(currState.orn); + _qPrev.slerp(_qCurr, t); + b.root.quaternion.copy(_qPrev); + } + if (b.kind === "urdf") { + for (const [jname, currAngle] of Object.entries(currState.joints)) { + const j = b.joints[jname]; + if (!j || typeof j.setJointValue !== "function") continue; + const prevAngle = prevState?.joints?.[jname] ?? currAngle; + j.setJointValue(_lerp(prevAngle, currAngle, t)); + } + } + } +} + +async function executeOption() { + const name = optionSelect.value; + const args = Array.from(optionArgs.querySelectorAll("select.opt-arg")) + .map((s) => s.value); + setStatus(`Executing ${name}(${args.join(", ")})…`); + try { + // Arm the option: bridge stashes the grounded option + initial + // state. Returns the initial-frame body states so JS can paint + // step 0 before any pybullet step has happened. + // JSON.stringify(name) yields a Python-safe string literal too + // (Python and JSON share the "..." escape syntax for common + // payloads), so an option name with " or \ doesn't break the + // injected Python source. + const argList = JSON.stringify(args); + const initProxy = await pyodide.runPythonAsync( + `bridge.begin_option(${JSON.stringify(name)}, ${argList})`); + const init = initProxy.toJs({ dict_converter: Object.fromEntries }); + initProxy.destroy(); + if (init.error) { + setStatus(`${name}: ${init.error}`); + log(`${name}: ${init.error}`); + return; + } + applyFrame(init.initial_frame); + + // Fixed-sim-rate, variable-render-rate loop (the standard + // game-loop pattern): we advance a virtual `simCursor` at + // SIM_RATE_HZ per wall-clock second. Each rAF tick: + // 1) advance simCursor by dt * SIM_RATE_HZ + // 2) while simCursor has crossed an integer, pull next + // pybullet step into `currFrame` (and shift the old curr + // into `prevFrame`) + // 3) render the interpolated state between prevFrame and + // currFrame at the fractional cursor position + // + // This keeps the render synced to the monitor (rAF) while the + // arm moves at SIM_RATE_HZ-per-second of sim steps regardless + // of how fast pybullet itself can step. + let prevFrame = init.initial_frame; + let currFrame = init.initial_frame; + let simCursor = 0; // float, advances at SIM_RATE_HZ per second + let lastTimestamp = null; + let lastSteps = 0; + let pythonDone = false; + let finalError = null; + let finalColorUpdates = null; + + const stepCall = `bridge.step_option()`; + + function pullOneSimStep() { + let r; + try { + const proxy = pyodide.runPython(stepCall); + r = proxy.toJs({ dict_converter: Object.fromEntries }); + proxy.destroy(); + } catch (e) { + log("ERROR during step: " + (e.message || e)); + pythonDone = true; + return; + } + if (r.added_bodies?.length || r.removed_body_ids?.length) { + reconcileBodies(r.added_bodies || [], + r.removed_body_ids || []); + } + prevFrame = currFrame; + currFrame = r.frame; + lastSteps = r.steps; + if (r.done) { + pythonDone = true; + finalError = r.error; + finalColorUpdates = r.color_updates; + } + } + + await new Promise((resolve) => { + function tick(timestamp) { + if (lastTimestamp === null) lastTimestamp = timestamp; + const dt = Math.min(0.1, (timestamp - lastTimestamp) / 1000); + lastTimestamp = timestamp; + simCursor += dt * SIM_RATE_HZ; + + // Consume integer sim steps while we have visual budget for + // them. Cap the while-loop so a long stall (tab backgrounded, + // etc.) doesn't blow through the whole option in one tick. + let pulled = 0; + while (simCursor >= 1 && !pythonDone && pulled < 4) { + pullOneSimStep(); + simCursor -= 1; + pulled += 1; + } + if (pythonDone && simCursor > 1) simCursor = 1; + + // Render at the fractional cursor between prev and curr. + applyInterpolatedFrame(prevFrame, currFrame, + Math.min(1, Math.max(0, simCursor))); + + if (lastSteps && lastSteps % 30 === 0) { + setStatus(`Executing ${name}(${args.join(", ")}) — step ${lastSteps}`); + } + + if (pythonDone && simCursor >= 1) { + // Land exactly on the last sim state, then finalize. + applyInterpolatedFrame(prevFrame, currFrame, 1); + if (finalColorUpdates) applyColorUpdates(finalColorUpdates); + resolve(); + } else { + requestAnimationFrame(tick); + } + } + requestAnimationFrame(tick); + }); + + if (finalError) { + setStatus(`${name}: ${finalError}`); + log(`${name}(${args.join(", ")}) stopped after ${lastSteps} steps: ${finalError}`); + } else { + setStatus(`${name} done in ${lastSteps} steps.`); + log(`Executed ${name}(${args.join(", ")}) -> ${lastSteps} steps`); + } + } catch (e) { + setStatus("Execute failed — see log."); + log("ERROR: " + (e.message || e)); + } +} + +function applyColorUpdates(updates) { + for (const [bidStr, linkMap] of Object.entries(updates)) { + const bid = Number(bidStr); + const b = bodyMap.get(bid); + if (!b) continue; + if (b.kind === "primitive") { + // Primitive bodies are a flat group: one child Mesh per shape, + // in the same order the manifest produced them. We don't track + // which mesh maps to which link, but our envs all repaint with + // a single uniform color per body. Push that color onto every + // child mesh. + const rgba = Object.values(linkMap)[0]; + if (!rgba) continue; + b.root.traverse((o) => { + if (!o.isMesh || !o.material) return; + o.material.color?.setRGB?.(rgba[0], rgba[1], rgba[2]); + o.material.opacity = rgba[3]; + o.material.transparent = rgba[3] < 1; + o.material.needsUpdate = true; + }); + } else if (b.kind === "urdf") { + for (const [linkIdxStr, rgba] of Object.entries(linkMap)) { + // urdf-loader keys links by name, not index. linkIdx -1 is the + // base. We don't have a robust index→name map, so just paint + // every mesh in the body — matches the envs that recolor a + // whole URDF (e.g. grow's cup/jug tints). + void linkIdxStr; + b.root.traverse((o) => { + if (!o.isMesh || !o.material) return; + o.material.color?.setRGB?.(rgba[0], rgba[1], rgba[2]); + o.material.opacity = rgba[3]; + o.material.transparent = rgba[3] < 1; + o.material.needsUpdate = true; + }); + } + } + } +} + +async function reconcileBodies(added, removedIds) { + for (const bid of removedIds) { + const b = bodyMap.get(bid); + if (!b) continue; + scene.remove(b.root); + bodyMap.delete(bid); + } + const urdfPromises = []; + for (const entry of added) { + if (entry.kind === "urdf") { + urdfPromises.push(loadUrdfBody(entry)); + } else { + makePrimitive(entry); + } + } + await Promise.all(urdfPromises); +} + +executeBtn.addEventListener("click", executeOption); +bootBtn.addEventListener("click", resetEnv); +envSelect.addEventListener("change", resetEnv); + +boot().catch((e) => { + setStatus("Boot failed — see log."); + log("FATAL: " + (e.message || e)); + console.error(e); +}); diff --git a/web/app/node_smoke.mjs b/web/app/node_smoke.mjs new file mode 100644 index 0000000000..51e4fffe67 --- /dev/null +++ b/web/app/node_smoke.mjs @@ -0,0 +1,244 @@ +// Node-side smoke test: mirrors what main.js does in the browser, so we +// can iterate on the Pyodide bridge without spinning up a real browser. +// +// Run from web/app/ via: node --experimental-fetch node_smoke.mjs +// (or with regular node if your version has fetch built in). + +import { loadPyodide } from "../node_modules/pyodide/pyodide.mjs"; +import { readFileSync } from "node:fs"; +import { resolve, dirname } from "node:path"; +import { fileURLToPath } from "node:url"; + +const HERE = dirname(fileURLToPath(import.meta.url)); +const WHEELS = resolve(HERE, "../wheels"); + +const t0 = Date.now(); +const log = (...a) => console.log(`[${((Date.now() - t0) / 1000).toFixed(1)}s]`, ...a); + +log("Loading Pyodide…"); +const pyodide = await loadPyodide({ stdout: log, stderr: log }); +log("Pyodide ready"); + +await pyodide.loadPackage(["micropip", "numpy", "matplotlib", "pillow"]); +log("Base packages loaded"); + +// Pyodide's micropip can install from local files via file:// URIs, +// but it's easier to read the wheel bytes and hand it to micropip. +function readWheel(name) { + return readFileSync(resolve(WHEELS, name)); +} + +// Pyodide's micropip parses the wheel filename out of the URL, so we +// must keep the canonical names rather than abbreviating. +const PYBULLET_WHEEL = "pybullet-3.2.7-cp313-cp313-pyemscripten_2025_0_wasm32.whl"; +const PREDICATORS_WHEEL = "predicators-0.1.0-py3-none-any.whl"; +const GYM_SHIM_WHEEL = "gym-0.26.2-py3-none-any.whl"; +pyodide.FS.writeFile(`/tmp/${PYBULLET_WHEEL}`, readWheel(PYBULLET_WHEEL)); +pyodide.FS.writeFile(`/tmp/${PREDICATORS_WHEEL}`, readWheel(PREDICATORS_WHEEL)); +pyodide.FS.writeFile(`/tmp/${GYM_SHIM_WHEEL}`, readWheel(GYM_SHIM_WHEEL)); + +try { + await pyodide.runPythonAsync(` +import sys, traceback +print("=== start install block ===", flush=True) +try: + import micropip + print("micropip loaded", flush=True) + await micropip.install("emfs:/tmp/${PYBULLET_WHEEL}") + print("pybullet installed", flush=True) + # Try to import pybullet right away to see if it loads. + try: + import pybullet as p + print("pybullet imported; version =", getattr(p, "__version__", "?"), flush=True) + cid = p.connect(p.DIRECT) + print("pybullet connected, cid =", cid, flush=True) + except Exception as e: + print("pybullet import/connect FAILED:", type(e).__name__, e, flush=True) + traceback.print_exc() + raise + # Install our gym shim first so micropip doesn't try the real gym. + await micropip.install("emfs:/tmp/${GYM_SHIM_WHEEL}") + print("gym shim installed", flush=True) + # The predicators wheel's install_requires is now the env-runtime + # slim set, so let micropip do the resolution. keep_going=True + # skips platform-specific dead-ends (pybullet-arm64 on PyPI, + # numpy/matplotlib/pillow version pins vs. Pyodide-shipped). + await micropip.install("emfs:/tmp/${PREDICATORS_WHEEL}", + deps=True, keep_going=True) + print("predicators installed", flush=True) +except Exception as e: + print("INSTALL ERROR:", type(e).__name__, e, flush=True) + traceback.print_exc() + raise +`); +} catch (e) { + log("install threw: " + (e.message || e)); + process.exit(1); +} +log("Wheels installed"); + +// Mount the env asset dir at the path predicators expects so that +// `os.path.exists(envs/assets/urdf/plane.urdf)` etc. work without +// baking 141 MB of meshes into the wheel. +const ASSET_SRC = resolve(HERE, "../../predicators/envs/assets"); +const ASSET_DEST = "/lib/python3.13/site-packages/predicators/envs/assets"; +try { pyodide.FS.rmdir(ASSET_DEST); } catch {} +pyodide.FS.mkdirTree(ASSET_DEST); +pyodide.FS.mount(pyodide.FS.filesystems.NODEFS, { root: ASSET_SRC }, ASSET_DEST); +log(`Mounted assets: ${ASSET_SRC} -> ${ASSET_DEST}`); + +const setupSrc = readFileSync(resolve(HERE, "setup.py"), "utf8"); +pyodide.FS.writeFile("/setup.py", setupSrc); +try { + await pyodide.runPythonAsync(` +import traceback +try: + exec(open('/setup.py').read(), globals()) + print('setup.py loaded ok', flush=True) +except SystemExit as e: + print('SystemExit during setup:', e, flush=True) +except Exception as e: + print('SETUP ERROR:', type(e).__name__, e, flush=True) + traceback.print_exc() + raise +`); +} catch (e) { + log("setup threw: " + (e.message || e)); + process.exit(1); +} +log("Bridge ready"); + +log("Trying bridge.reset…"); +try { + await pyodide.runPythonAsync(` +import traceback +try: + info = bridge.reset("pybullet_coffee") + print(f"MANIFEST ({len(info['manifest'])} bodies):", flush=True) + for e in info['manifest']: + kind = e.get('kind', '?') + url = e.get('url', '') + shapes = e.get('shapes', []) + print(f" body_id={e['body_id']} name={e['name']!r:20s} kind={kind} " + f"url={url or ''}", flush=True) + for s in shapes: + print(f" link={s['link']:>2} geom={s['geom']:8s} " + f"dims={s['dims']} rgba={[round(c,2) for c in s['rgba']]} " + f"local_pos={[round(c,3) for c in s['local_pos']]}", + flush=True) + import sys; sys.exit(0) + + # Inspect grow PickJug behaviour in detail. + import pybullet as p + from predicators.ground_truth_models import get_gt_options, get_gt_nsrts + env = bridge.env + state = env._current_observation + print("---- pre-pick state ----", flush=True) + for o in sorted(env._objects, key=lambda o: o.name): + feats = {f: state.get(o, f) for f in o.type.feature_names + if f in {'x','y','z','rot','fingers','tilt','wrist'}} + print(f" {o.name:12s} {o.type.name:10s} {feats}", flush=True) + + options = {o.name: o for o in get_gt_options('pybullet_grow')} + nsrts = {n.name: n for n in get_gt_nsrts('pybullet_grow', env.predicates, set(options.values()))} + print("NSRTs:", list(nsrts.keys()), flush=True) + + opt = options['PickJug'] + print(f"PickJug params_space: low={opt.params_space.low} high={opt.params_space.high}", flush=True) + name_to_obj = {o.name: o for o in env._objects} + chosen = [name_to_obj['robot'], name_to_obj['jug1']] + import numpy as np + rng = np.random.default_rng(0) + g = bridge._sample_ground_option(opt, chosen, state, rng) + print(f"Grounded PickJug params: {g.params}", flush=True) + print(f"Initiable: {g.initiable(state)}", flush=True) + # Run PickJug step-by-step and print every N steps. + import logging; logging.basicConfig(level=logging.DEBUG) + for step in range(200): + if g.terminal(state): + print(f"TERMINATED at step={step}", flush=True) + break + act = g.policy(state) + state = env.step(act) + r = name_to_obj['robot'] + j = name_to_obj['jug1'] + mem = g.memory if hasattr(g, 'memory') else {} + phase_idx = mem.get('phase_idx', '?') + print(f" step={step:3d} phase_idx={phase_idx} " + f"ee=({state.get(r,'x'):.3f},{state.get(r,'y'):.3f},{state.get(r,'z'):.3f}) " + f"fingers={state.get(r,'fingers'):.3f} " + f"jug=({state.get(j,'x'):.3f},{state.get(j,'y'):.3f},{state.get(j,'z'):.3f})", flush=True) + else: + print("Did not terminate within 200 steps", flush=True) + r = name_to_obj['robot']; j = name_to_obj['jug1'] + print(f"POST-PickJug ee=({state.get(r,'x'):.3f},{state.get(r,'y'):.3f},{state.get(r,'z'):.3f}) " + f"fingers={state.get(r,'fingers'):.3f} " + f"jug=({state.get(j,'x'):.3f},{state.get(j,'y'):.3f},{state.get(j,'z'):.3f})", flush=True) + # Holding predicate check + Holding = next((pp for pp in env.predicates if pp.name == 'Holding'), None) + if Holding is not None: + print(f"Holding(robot, jug1)? {Holding.holds(state, [r, j])}", flush=True) + + # Dump colors so we can pick a matching cup for pour. + print("Colors:", flush=True) + for o in env._objects: + if o.type.name in {'cup', 'jug'}: + cr = state.get(o, 'r'); cg = state.get(o, 'g'); cb = state.get(o, 'b') + print(f" {o.name}: ({cr:.2f}, {cg:.2f}, {cb:.2f})", flush=True) + + # Find a cup that matches jug1's color, fall back to cup0. + jug_obj = name_to_obj['jug1'] + jr, jg, jb = state.get(jug_obj,'r'), state.get(jug_obj,'g'), state.get(jug_obj,'b') + matching = None + for o in env._objects: + if o.type.name == 'cup': + if abs(state.get(o,'r')-jr)<0.01 and abs(state.get(o,'g')-jg)<0.01 and abs(state.get(o,'b')-jb)<0.01: + matching = o.name + break + if matching is None: + matching = 'cup0' + print(f"Pouring into matching cup: {matching}", flush=True) + + # Continue with Pour(robot, jug1, matching) + Place + for op_name, obj_names in [ + ('Pour', ['robot', 'jug1', matching]), + ('Place', ['robot', 'jug1']), + ]: + opt = options[op_name] + chosen = [name_to_obj[n] for n in obj_names] + g = bridge._sample_ground_option(opt, chosen, state, rng) + print(f"=== Try {op_name}{tuple(obj_names)} params={list(g.params)} initiable={g.initiable(state)} ===", flush=True) + if not g.initiable(state): + continue + for step in range(400): + if g.terminal(state): + print(f"{op_name} TERMINATED at step={step}", flush=True) + break + try: + act = g.policy(state) + except Exception as e: + print(f"{op_name} policy threw at step {step}: {type(e).__name__}: {e}", flush=True) + break + state = env.step(act) + if step % 20 == 0 or step < 3: + cup_obj = name_to_obj.get(matching if op_name == 'Pour' else 'cup0') + growth = state.get(cup_obj, 'growth') if cup_obj else -1 + mem = g.memory if hasattr(g, 'memory') else {} + phase_idx = mem.get('phase_idx', '?') + print(f" {op_name} step={step:3d} phase={phase_idx} " + f"ee=({state.get(r,'x'):.3f},{state.get(r,'y'):.3f},{state.get(r,'z'):.3f}) " + f"tilt={state.get(r,'tilt'):.2f} " + f"jug=({state.get(j,'x'):.3f},{state.get(j,'y'):.3f},{state.get(j,'z'):.3f}) " + f"growth={growth:.3f}", flush=True) + else: + print(f"{op_name} did not terminate within 400 steps", flush=True) +except Exception as e: + print("BRIDGE ERROR:", type(e).__name__, e, flush=True) + traceback.print_exc() +`); +} catch (e) { + log("bridge call threw: " + (e.message || e)); +} + +log("DONE"); +process.exit(0); diff --git a/web/app/serve.sh b/web/app/serve.sh new file mode 100755 index 0000000000..244875c04a --- /dev/null +++ b/web/app/serve.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +# Tiny local server. Pyodide needs HTTP (not file://) and the right MIME +# types for .wasm / .whl / .mjs. Python's http.server does the right +# thing for our purposes (.whl is just a zip; .mjs/.wasm get correct +# MIME thanks to Python 3.13's defaults). +# +# Run from this directory: ./serve.sh then open http://localhost:8080/ +set -euo pipefail +cd "$(dirname "$0")/.." # cd into web/ so /app and /wheels are siblings +python -m http.server 8080 diff --git a/web/app/setup.py b/web/app/setup.py new file mode 100644 index 0000000000..d6e1ebee5c --- /dev/null +++ b/web/app/setup.py @@ -0,0 +1,658 @@ +"""Pyodide-side bridge for the predicators browser demo. + +Exposes a minimal `bridge` object that the JS layer can call: +- bridge.reset(env_name) -> {task_idx, num_objects, action_dim, manifest} +- bridge.list_options() / list_objects() +- bridge.begin_option(name, object_names, ...) -> {initial_frame, error} +- bridge.step_option() -> {frame, steps, done, error, + added_bodies, removed_body_ids, color_updates} + where each frame is a {body_id: {pos, orn, joints: {name: rad}}} snapshot. + JS drives step_option() inside a requestAnimationFrame loop until + done — one pybullet step per browser-frame so the renderer paints + each step instead of seeing the whole option batched at the end. + +Rendering is done client-side by Three.js + urdf-loader. The bridge +extracts (a) a scene manifest (URDF refs for URDF-loaded bodies + +primitive descriptors for createMultiBody bodies) and (b) per-step +body states so the JS side can drive a real WebGL scene. +""" + +import os + +import numpy as np +import pybullet as p + +from predicators import utils_lite as _utils +from predicators.envs import create_new_env +from predicators.settings import CFG + +# Pybullet GEOM_* constants: +# 2 SPHERE, 3 BOX, 4 CYLINDER, 5 MESH, 6 PLANE, 7 CAPSULE +_GEOM_NAMES = { + 2: "sphere", + 3: "box", + 4: "cylinder", + 5: "mesh", + 6: "plane", + 7: "capsule" +} + +# Pyodide FS prefix where the assets tarball is unpacked at boot. We +# strip this so the JS side can fetch via the dev server's static +# `./predicators_assets/` symlink. +_ASSET_FS_PREFIX = "/lib/python3.13/site-packages/predicators/envs/assets/" +_ASSET_URL_PREFIX = "../predicators_assets/" + + +def _asset_url(path): + """Translate a Pyodide-FS asset path to a server-relative URL.""" + if not path: + return None + s = path.decode() if isinstance(path, (bytes, bytearray)) else str(path) + if s.startswith(_ASSET_FS_PREFIX): + return _ASSET_URL_PREFIX + s[len(_ASSET_FS_PREFIX):] + # Some envs store mesh paths relative to a setAdditionalSearchPath + # call. If the path exists under assets/, anchor it there. + if not os.path.isabs(s): + return _ASSET_URL_PREFIX + s + return None # off-tree path we can't serve + + +def _ensure_cfg(): + _utils.reset_config({ + "env": "pybullet_blocks", + "seed": 0, + "num_train_tasks": 1, + "num_test_tasks": 1, + "approach": "oracle", + }) + + +# Env-specific CFG overrides that turn on the full demo content. The +# README pitches coffee as "plug in, brew, pour, serve" — that path +# only exists when has_plug=True; with the predicators default the +# powercord, plug, and socket aren't created at all. +_ENV_CFG_OVERRIDES = { + "pybullet_coffee": { + "coffee_machine_has_plug": True + }, + # Grow's default `grow_use_skill_factories=True` selects a + # half-finished skill-factory Pour (terminal triggers as soon as + # the EE reaches the tilt pose, before any growth happens) and a + # 4-d skill-factory Place that needs a jug-handle-offset-aware + # sampler that nsrts.py doesn't provide. The upstream demo + # configs (scripts/configs/predicatorv3/random_actions_pybullet + # .yaml, ExoPredicator/causal_predicator.yaml) route through the + # legacy path; we mirror that minus weak_pour. We *don't* set + # `grow_weak_pour_terminate_condition` because that exits Pour + # the moment the jug is above + tilted (fine for the random- + # actions demo, but no visible growth) — instead we keep the + # default terminal (Grown.holds(cup)) so the plant fully grows. + "pybullet_grow": { + "grow_use_skill_factories": False, + "grow_place_option_no_sampler": True, + }, + # Same pattern as grow: nsrts.py looks up specific option names + # (PlaceOnBurner, PlaceUnderFaucet, etc.) that only exist on the + # legacy code path. Default skill-factory path produces a single + # generic Place option, so nsrts construction KeyErrors. + "pybullet_boil": { + "boil_use_skill_factories": False + }, +} + + +class _Bridge: + + def __init__(self): + self.env = None + self.task = None + # body_id -> {url, scale}. Filled in by the loadURDF + # monkey-patch during env construction; reset() rebuilds it + # for the freshly-constructed env. + self._urdf_map = {} + # NSRTs for the current env, used by execute_option to get + # state-aware param samples instead of uniform-over-the-box + # samples (the latter misses grasps / drops blocks off-tower). + self._nsrts = set() + # Body ids known to the JS side. Some envs (grow) spawn new + # bodies mid-execution (e.g. growing plants); execute_option + # diffs against this set to surface additions/removals so JS + # can mount/unmount the corresponding meshes. + self._known_body_ids = set() + # Per-option rollout state. begin_option arms this; step_option + # advances it one pybullet step at a time and clears it on done. + # Splitting the rollout this way keeps the main JS thread alive + # between sim steps so the renderer can draw each frame instead + # of seeing the whole option as one batched update at the end. + self._active_option = None + # body_id -> {link_idx (-1 for base): rgba}. Snapshotted on + # reset and again after each execute_option so envs that + # repaint via p.changeVisualShape (balance button, coffee + # button + plate, etc.) get their new colors reflected in JS. + self._known_link_colors = {} + + def reset(self, env_name): + cfg = { + "env": env_name, + "seed": 0, + "num_train_tasks": 1, + "num_test_tasks": 1, + "approach": "oracle", + } + cfg.update(_ENV_CFG_OVERRIDES.get(env_name, {})) + _utils.reset_config(cfg) + + # Disconnect any prior env's pybullet client so its state + # doesn't shadow the new env's queries. Several predicators + # envs (e.g. pybullet_circuit._get_joint_id) call + # p.getNumJoints without an explicit physicsClientId; with + # two clients alive at once pybullet picks one and the new + # env reads garbage. + if self.env is not None: + try: + p.disconnect(physicsClientId=self.env._physics_client_id + ) # noqa: SLF001 + except p.error: + pass + self.env = None + + self._urdf_map = {} + # PyBullet-wasm doesn't implement p.createConstraint (e.g. + # coffee chains cord segments with JOINT_POINT2POINT). Without + # this shim, the env-reset path crashes during cord creation. + # We don't need physical chaining for visualization — let the + # call fail silently and the segments just hover in place. + if not getattr(p.createConstraint, "_bridge_safe", False): + _orig_constraint = p.createConstraint + + def _safe_create_constraint(*args, **kwargs): + try: + return _orig_constraint(*args, **kwargs) + except p.error: + return -1 + + _safe_create_constraint._bridge_safe = True # noqa: SLF001 + p.createConstraint = _safe_create_constraint + if not getattr(p.loadURDF, "_bridge_wrapped", False): + orig_load = p.loadURDF + + def _tracked_load(*args, **kwargs): + bid = orig_load(*args, **kwargs) + # loadURDF signature: (fileName, basePosition, + # baseOrientation, useMaximalCoordinates, + # useFixedBase, flags, globalScaling, + # physicsClientId). + url = _asset_url(args[0]) if args else None + if url is None and "fileName" in kwargs: + url = _asset_url(kwargs["fileName"]) + scale = kwargs.get("globalScaling", None) + if scale is None and len(args) >= 7: + scale = args[6] + if url is not None and isinstance(bid, int) and bid >= 0: + bridge._urdf_map[bid] = { + "url": url, + "scale": float(scale) if scale is not None else 1.0, + } + return bid + + _tracked_load._bridge_wrapped = True # noqa: SLF001 + _tracked_load._orig = orig_load # noqa: SLF001 + p.loadURDF = _tracked_load + + self.env = create_new_env(env_name, do_cache=False, use_gui=False) + self.task = self.env.reset("test", 0) + env = self.env + + # Cache GT NSRTs so execute_option can call gnsrt.sample_option + # (state-aware sampler) instead of uniformly sampling the + # params box. Mirrors the default human_option_control path. + from predicators.ground_truth_models import get_gt_nsrts, \ + get_gt_options + try: + env_options = get_gt_options(env_name) + except NotImplementedError: + env_options = set() + try: + self._nsrts = get_gt_nsrts(env_name, env.predicates, env_options) + except NotImplementedError: + self._nsrts = set() + manifest = self.get_scene_manifest() + self._known_body_ids = {e["body_id"] for e in manifest} + cid = self.env._physics_client_id # noqa: SLF001 + self._known_link_colors = { + bid: self._snapshot_link_colors(bid, cid) + for bid in self._known_body_ids + } + return { + "task_idx": 0, + "num_objects": len(env._objects), # noqa: SLF001 + "action_dim": int(env.action_space.shape[0]), + "manifest": manifest, + # Env-author-defined camera (Three.js translates pybullet + # yaw/pitch/distance into a position). + "camera": { + "target": list(env._camera_target), # noqa: SLF001 + "distance": float(env._camera_distance), # noqa: SLF001 + "yaw": float(env._camera_yaw), # noqa: SLF001 + "pitch": float(env._camera_pitch), # noqa: SLF001 + "fov": float(env._camera_fov), # noqa: SLF001 + }, + } + + # -- Scene manifest + state ------------------------------------ + + def _describe_body(self, body_id, cid): + """Build one manifest entry for a single body id.""" + try: + info = p.getBodyInfo(body_id, physicsClientId=cid) + except p.error: + return None + base_name = info[0].decode() if info and info[0] else "" + body_name = info[1].decode() if info and len(info) > 1 else "" + entry = { + "body_id": body_id, + "name": body_name or base_name or f"body_{body_id}", + } + joint_names = [] + num_joints = p.getNumJoints(body_id, physicsClientId=cid) + for j in range(num_joints): + jinfo = p.getJointInfo(body_id, j, physicsClientId=cid) + jname = jinfo[1].decode() if jinfo[1] else f"joint_{j}" + jtype = int(jinfo[2]) + if jtype != 4: # skip FIXED + joint_names.append(jname) + entry["joint_names"] = joint_names + if body_id in self._urdf_map: + entry["kind"] = "urdf" + entry["url"] = self._urdf_map[body_id]["url"] + entry["scale"] = self._urdf_map[body_id]["scale"] + entry["link_colors"] = self._collect_link_colors( + body_id, cid, base_name) + else: + vis = p.getVisualShapeData(body_id, physicsClientId=cid) + # Cache link-in-base transforms so multi-link primitives + # (e.g. coffee's machine = base + top + dispense, the + # button + lightbar) render their children at the right + # offset. getVisualShapeData reports visual offsets within + # the link, not within the body root — without composing + # in the link's own pose, all child-link shapes pile up + # on the base. + base_pos, base_orn = p.getBasePositionAndOrientation( + body_id, physicsClientId=cid) + inv_base = p.invertTransform(base_pos, base_orn) + link_pose_in_base_cache = {-1: ((0, 0, 0), (0, 0, 0, 1))} + shapes = [] + for v in vis: + # (uniqueId, linkIdx, geomType, dims, meshFile, + # localPos, localOrn, rgba, textureId) + geom = _GEOM_NAMES.get(int(v[2]), "unknown") + dims = list(v[3]) + mesh_url = _asset_url(v[4]) + rgba = list(v[7]) if v[7] is not None else [1, 1, 1, 1] + visual_pos = list(v[5]) + visual_orn = list(v[6]) + link_idx = int(v[1]) + if link_idx not in link_pose_in_base_cache: + ls = p.getLinkState(body_id, link_idx, physicsClientId=cid) + # worldLinkFramePosition (4) / Orientation (5) is + # the URDF link frame pose in world. Convert to + # base frame. + link_pose_in_base_cache[link_idx] = p.multiplyTransforms( + inv_base[0], inv_base[1], ls[4], ls[5]) + link_pos, link_orn = link_pose_in_base_cache[link_idx] + shape_pos, shape_orn = p.multiplyTransforms( + link_pos, link_orn, visual_pos, visual_orn) + shapes.append({ + "link": link_idx, + "geom": geom, + "dims": dims, + "mesh_url": mesh_url, + "local_pos": list(shape_pos), + "local_orn": list(shape_orn), + "rgba": rgba, + }) + entry["kind"] = "primitive" + entry["shapes"] = shapes + return entry + + def _snapshot_link_colors(self, body_id, cid): + """Return {link_idx: [r,g,b,a]} for the body's current visuals.""" + out = {} + try: + vis = p.getVisualShapeData(body_id, physicsClientId=cid) + except p.error: + return out + for v in vis: + link_idx = int(v[1]) + rgba = list(v[7]) if v[7] is not None else [1, 1, 1, 1] + # Multi-shape links can repeat; first one wins. envs that + # tint a multi-link body recolor every link uniformly, so + # this matches what JS would have rendered originally. + if link_idx not in out: + out[link_idx] = rgba + return out + + def _current_body_ids(self, cid): + """Return the list of live body IDs. + + Uses ``getBodyUniqueId`` rather than ``range(getNumBodies)`` + because pybullet may not reuse IDs after ``removeBody`` — grow + recreates the liquid block every pour tick, so the live ID set + drifts upward. + """ + return [ + p.getBodyUniqueId(i, physicsClientId=cid) + for i in range(p.getNumBodies(physicsClientId=cid)) + ] + + def get_scene_manifest(self): + """Walk all bodies in the current physics client and describe them as + Three.js-buildable entries.""" + cid = self.env._physics_client_id # noqa: SLF001 + entries = [] + for body_id in self._current_body_ids(cid): + entry = self._describe_body(body_id, cid) + if entry is not None: + entries.append(entry) + return entries + + def _collect_link_colors(self, body_id, cid, base_link_name): + """Return {link_name: [r,g,b,a]} for one URDF body. + + Link -1 is pybullet's base-link convention; its name is the + URDF's root name (returned by p.getBodyInfo). Child links + (linkIdx >= 0) are named in jointInfo[12]. We keep just one RGBA + per link — multi-visual links in our envs all get tinted to the + same color via create_object. + """ + link_idx_to_name = {-1: base_link_name} + for j in range(p.getNumJoints(body_id, physicsClientId=cid)): + jinfo = p.getJointInfo(body_id, j, physicsClientId=cid) + link_idx_to_name[j] = (jinfo[12].decode() + if jinfo[12] else f"link_{j}") + out = {} + for v in p.getVisualShapeData(body_id, physicsClientId=cid): + link_idx = int(v[1]) + name = link_idx_to_name.get(link_idx, f"link_{link_idx}") + if name in out: + continue # keep first rgba per link + out[name] = list(v[7]) if v[7] is not None else [1, 1, 1, 1] + return out + + def get_body_state(self, body_id): + """Return base pose + joint angles for a single body.""" + cid = self.env._physics_client_id # noqa: SLF001 + pos, orn = p.getBasePositionAndOrientation(body_id, + physicsClientId=cid) + joints = {} + nj = p.getNumJoints(body_id, physicsClientId=cid) + for j in range(nj): + jinfo = p.getJointInfo(body_id, j, physicsClientId=cid) + jname = jinfo[1].decode() if jinfo[1] else f"joint_{j}" + jtype = int(jinfo[2]) + if jtype == 4: # FIXED + continue + jstate = p.getJointState(body_id, j, physicsClientId=cid) + joints[jname] = float(jstate[0]) + return {"pos": list(pos), "orn": list(orn), "joints": joints} + + def get_all_body_states(self): + cid = self.env._physics_client_id # noqa: SLF001 + return { + body_id: self.get_body_state(body_id) + for body_id in self._current_body_ids(cid) + } + + # -- Option-level introspection --------------------------------- + def list_options(self): + from predicators.ground_truth_models import get_gt_options + if self.env is None: + return [] + try: + options = get_gt_options(self.env.get_name()) + except NotImplementedError: + return [] + return [{ + "name": opt.name, + "type_names": [t.name for t in opt.types], + "params_dim": int(opt.params_space.shape[0]), + } for opt in sorted(options, key=lambda o: o.name)] + + def list_objects(self): + if self.env is None: + return [] + return [ + { + "name": o.name, + "type_name": o.type.name + } for o in sorted( + self.env._objects, # noqa: SLF001 + key=lambda o: o.name) + ] + + def _sample_ground_option(self, opt, chosen, state, rng): + """Ground `opt(chosen)` using the matching NSRT's state-aware sampler. + + Falls back to a uniform sample if no NSRT matches (e.g. envs + without a process/NSRT factory). + """ + goal = self.env._current_task.goal # noqa: SLF001 + objects = set(self.env._objects) # noqa: SLF001 + chosen_tuple = tuple(chosen) + # ParameterizedOption.__eq__ compares by name (structs.py:1037). + # get_gt_options() builds fresh option instances each call, so + # `nsrt.option is opt` is always False — use name equality. + param_dim = int(opt.params_space.shape[0]) + for nsrt in self._nsrts: + if nsrt.option.name != opt.name: + continue + for gnsrt in _utils.all_ground_nsrts(nsrt, objects): + if tuple(gnsrt.option_objs) != chosen_tuple: + continue + # Some upstream NSRTs use `null_sampler` even when the + # paired option (e.g. grow.PickJug) has a non-empty + # params_space — the matching state-aware sampler + # (e.g. _pick_sampler -> [0.0]) lives on the *process* + # factory, which we can't load in Pyodide (torch). + # When the dim doesn't match, fall back to the low + # bound (params_space[0] for the skill-factory options + # is the "natural default", e.g. grasp_z_offset=0.0). + params = gnsrt._sampler( + state, + goal, + rng, # noqa: SLF001 + gnsrt.objects) + if len(params) != param_dim: + params = np.asarray(opt.params_space.low, dtype=np.float32) + else: + params = np.clip(params, opt.params_space.low, + opt.params_space.high) + return opt.ground(chosen, params) + break + low = opt.params_space.low + high = opt.params_space.high + params = rng.uniform(low, high).astype(np.float32) + return opt.ground(chosen, params) + + def begin_option(self, + option_name, + object_names, + params=None, + max_steps=1000): + """Ground an option and stash the rollout state so JS can drive it one + sim step at a time via :meth:`step_option`. + + Returns ``{initial_frame, error}``. On error (unknown option, + not initiable, etc.) the option is *not* armed and step_option + will be a no-op. + """ + from predicators.ground_truth_models import get_gt_options + try: + options = {o.name: o for o in get_gt_options(self.env.get_name())} + except NotImplementedError: + options = {} + if option_name not in options: + return { + "initial_frame": {}, + "error": f"Unknown option: {option_name}" + } + opt = options[option_name] + + name_to_obj = {o.name: o for o in self.env._objects} # noqa: SLF001 + try: + chosen = [name_to_obj[n] for n in object_names] + except KeyError as e: + return { + "initial_frame": {}, + "error": f"Unknown object: {e.args[0]}" + } + for o, t in zip(chosen, opt.types): + if not o.is_instance(t): + return { + "initial_frame": {}, + "error": (f"Object {o.name} of type {o.type.name} " + f"doesn't match expected type {t.name} " + f"for option {opt.name}") + } + + state = self.env._current_observation # noqa: SLF001 + try: + if params is None: + rng = np.random.default_rng(0) + ground_opt = self._sample_ground_option( + opt, chosen, state, rng) + else: + params_arr = np.asarray(params, dtype=np.float32) + ground_opt = opt.ground(chosen, params_arr) + except Exception as e: # pylint: disable=broad-except + # NSRT samplers and option grounding can raise a variety of + # errors (KeyError from a missing object, ValueError from a + # params_space mismatch, AssertionError from a sampler + # precondition). Convert to the JSON error shape rather + # than crashing across the Pyodide → JS boundary. + return { + "initial_frame": {}, + "error": f"Grounding failed: {type(e).__name__}: {e}" + } + + if not ground_opt.initiable(state): + return { + "initial_frame": {}, + "error": (f"{option_name}({','.join(object_names)}) " + "is not initiable in the current state.") + } + + # Snapshot body IDs at option start so step_option can diff + # per-step additions/removals (e.g. grow's plant spawning + # mid-Pour). Colors are snapshotted lazily at end-of-option + # since envs only repaint on terminal predicates and the + # per-step pybullet calls aren't free. + cid = self.env._physics_client_id # noqa: SLF001 + prev_ids = set(self._current_body_ids(cid)) + self._active_option = { + "ground_opt": ground_opt, + "state": state, + "steps": 0, + "max_steps": int(max_steps), + "prev_ids": prev_ids, + "name": option_name, + "args": list(object_names), + } + return {"initial_frame": self.get_all_body_states(), "error": None} + + def step_option(self): + """Run one pybullet step of the currently-armed option. + + Returns ``{frame, steps, done, error, added_bodies, + removed_body_ids, color_updates}``. Caller should loop + ``requestAnimationFrame`` over this until ``done`` is True. + ``color_updates`` is populated only on the terminal step. + """ + ao = self._active_option + if ao is None: + return { + "frame": {}, + "steps": 0, + "done": True, + "error": "no option in flight", + "added_bodies": [], + "removed_body_ids": [], + "color_updates": {} + } + + error_msg = None + done = False + try: + act = ao["ground_opt"].policy(ao["state"]) + ao["state"] = self.env.step(act) + ao["steps"] += 1 + if ao["ground_opt"].terminal(ao["state"]): + done = True + except _utils.OptionExecutionFailure as e: + # Same surface as human_option_control_approach + # (predicators/approaches/human_option_control_approach + # .py:104): bare reason, not class name + repr. + error_msg = str(e.args[0]) if e.args else "Option failed." + done = True + except Exception as e: # pylint: disable=broad-except + # Any other exception from policy/step/terminal needs to + # land on the JS side as a clean done+error rather than + # crashing the bridge mid-rollout (leaving _active_option + # armed for the next step_option). Common offenders: + # IndexError from a sampler-produced bad state, ValueError + # from numpy shape mismatches. + error_msg = f"{type(e).__name__}: {e}" + done = True + if ao["steps"] >= ao["max_steps"]: + done = True + + cid = self.env._physics_client_id # noqa: SLF001 + current_ids = set(self._current_body_ids(cid)) + added_ids = sorted(current_ids - ao["prev_ids"]) + removed_ids = sorted(ao["prev_ids"] - current_ids) + added = [ + e for e in (self._describe_body(bid, cid) for bid in added_ids) + if e is not None + ] + ao["prev_ids"] = current_ids + + # Color diff is only computed at end-of-option — repainting via + # p.changeVisualShape almost always happens on a terminal + # predicate firing (balance button on Balance, coffee button on + # Brew), and per-step _snapshot_link_colors is ~50ms otherwise. + color_updates = {} + if done: + new_colors = {} + for bid in current_ids: + now = self._snapshot_link_colors(bid, cid) + new_colors[bid] = now + if bid in added_ids: + continue # mounted with current color already + prev = self._known_link_colors.get(bid, {}) + changed = { + link_idx: rgba + for link_idx, rgba in now.items() + if prev.get(link_idx) != rgba + } + if changed: + color_updates[bid] = changed + self._known_link_colors = new_colors + + self._known_body_ids = current_ids + result = { + "frame": self.get_all_body_states(), + "steps": ao["steps"], + "done": done, + "error": error_msg, + "added_bodies": added, + "removed_body_ids": removed_ids, + "color_updates": color_updates, + } + if done: + self._active_option = None + return result + + +bridge = _Bridge() +print("predicators bridge ready") diff --git a/web/app/sweep_smoke.mjs b/web/app/sweep_smoke.mjs new file mode 100644 index 0000000000..58df0af17a --- /dev/null +++ b/web/app/sweep_smoke.mjs @@ -0,0 +1,107 @@ +// Boots a headless browser + Pyodide once, then resets the bridge for +// each env in the dropdown in turn. Much faster than running +// browser_smoke.mjs per env (Pyodide cold-boot is the dominant cost). +// +// web/app/serve.sh & +// node web/app/sweep_smoke.mjs +// +// Exits 0 iff every env reaches `action_dim=...` (i.e. the bridge +// returned a non-error reset). Saves /tmp/predicators_sweep_.png +// after each successful reset. + +import puppeteer from "puppeteer-core"; +import { readFileSync } from "node:fs"; +import { resolve, dirname } from "node:path"; +import { fileURLToPath } from "node:url"; + +const URL = process.env.URL || "http://localhost:8765/app/"; +const CHROMIUM = process.env.CHROMIUM || "/usr/bin/chromium"; +const PER_ENV_TIMEOUT_MS = Number(process.env.PER_ENV_TIMEOUT_MS || 60000); + +// Parse the dropdown from index.html so the env list stays in lockstep +// with what users see in the UI. Strips the trailing " (no options)" +// label hints — we only need the option `value`. +const HERE = dirname(fileURLToPath(import.meta.url)); +const HTML = readFileSync(resolve(HERE, "index.html"), "utf8"); +const ENVS = Array.from(HTML.matchAll(/value="(pybullet_[a-z_]+)"/g)) + .map((m) => m[1]); +if (!ENVS.length) { + console.error("No pybullet envs found in index.html dropdown"); + process.exit(2); +} +console.log(`Sweeping ${ENVS.length} envs:`, ENVS.join(" ")); + +const browser = await puppeteer.launch({ + executablePath: CHROMIUM, + headless: "new", + args: [ + "--no-sandbox", "--disable-dev-shm-usage", + "--use-gl=swiftshader", "--enable-unsafe-swiftshader", + "--ignore-gpu-blocklist", + ], +}); + +const page = await browser.newPage(); +page.on("pageerror", (err) => console.error("[pageerror]", err.message)); +// Surface Python tracebacks (logged via console.log from Pyodide). +page.on("console", (msg) => { + const t = msg.text(); + if (t.includes("ERROR") || t.includes("Traceback") + || t.includes("ModuleNotFound") || t.includes("FATAL")) { + console.log("[page]", msg.type(), t); + } +}); + +await page.goto(URL, { waitUntil: "load" }); + +console.log("Waiting for Pyodide bridge…"); +await page.waitForFunction( + () => !document.getElementById("env-select").disabled, + { timeout: 180000 }, +); +console.log("Bridge ready."); + +const results = []; +for (const env of ENVS) { + process.stdout.write(` ${env.padEnd(36)}`); + // Reset the info text so the next waitForFunction polls *this* run. + await page.evaluate(() => { + document.getElementById("info").textContent = ""; + }); + await page.select("#env-select", env); + await page.click("#boot-env"); + + let info = null; + let ok = false; + try { + await page.waitForFunction( + () => document.getElementById("info").textContent.includes("action_dim="), + { timeout: PER_ENV_TIMEOUT_MS }, + ); + info = await page.$eval("#info", (el) => el.textContent); + ok = true; + } catch (e) { + info = await page.$eval("#status", (el) => el.textContent) + .catch(() => ""); + } + results.push({ env, ok, info }); + console.log(ok ? `OK ${info}` : `FAIL ${info}`); + + if (ok) { + await page.screenshot({ + path: `/tmp/predicators_sweep_${env}.png`, + fullPage: false, + }).catch(() => {}); + } +} + +await browser.close(); + +const failed = results.filter((r) => !r.ok); +console.log(""); +console.log(`=== summary: ${results.length - failed.length}/${results.length} OK ===`); +if (failed.length) { + for (const r of failed) console.log(` FAIL ${r.env}: ${r.info}`); + process.exit(1); +} +process.exit(0); diff --git a/web/package-lock.json b/web/package-lock.json new file mode 100644 index 0000000000..61853bdf47 --- /dev/null +++ b/web/package-lock.json @@ -0,0 +1,971 @@ +{ + "name": "predicators-web", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "predicators-web", + "dependencies": { + "puppeteer-core": "^24.0.0", + "pyodide": "^0.29.4", + "three": "^0.184.0", + "urdf-loader": "^0.12.7" + } + }, + "node_modules/@puppeteer/browsers": { + "version": "2.13.2", + "resolved": "https://registry.npmjs.org/@puppeteer/browsers/-/browsers-2.13.2.tgz", + "integrity": "sha512-5EUZSUIc37H6aIXyWO0Z4y8NlF8NnjgmqeQgOGiswAU7pY0HOo16ho4+alIWmSfdZnjqBRawMsP3I5YqLSn6kw==", + "license": "Apache-2.0", + "dependencies": { + "debug": "^4.4.3", + "extract-zip": "^2.0.1", + "progress": "^2.0.3", + "proxy-agent": "^6.5.0", + "semver": "^7.7.4", + "tar-fs": "^3.1.1", + "yargs": "^17.7.2" + }, + "bin": { + "browsers": "lib/cjs/main-cli.js" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@tootallnate/quickjs-emscripten": { + "version": "0.23.0", + "resolved": "https://registry.npmjs.org/@tootallnate/quickjs-emscripten/-/quickjs-emscripten-0.23.0.tgz", + "integrity": "sha512-C5Mc6rdnsaJDjO3UpGW/CQTHtCKaYlScZTly4JIu97Jxo/odCiH0ITnDXSJPTOrEKk/ycSZ0AOgTmkDtkOsvIA==", + "license": "MIT" + }, + "node_modules/@types/emscripten": { + "version": "1.41.5", + "resolved": "https://registry.npmjs.org/@types/emscripten/-/emscripten-1.41.5.tgz", + "integrity": "sha512-cMQm7pxu6BxtHyqJ7mQZ2kXWV5SLmugybFdHCBbJ5eHzOo6VhBckEgAT3//rP5FwPHNPeEiq4SmQ5ucBwsOo4Q==", + "license": "MIT" + }, + "node_modules/@types/node": { + "version": "25.8.0", + "resolved": "https://registry.npmjs.org/@types/node/-/node-25.8.0.tgz", + "integrity": "sha512-TCFSk8IZh+iLX1xtksoBVtdmgL+1IX0fC9BeU4QqFSuNdN/K+HUlhqOzEmSYYpZUVsLYcPqc9KX+60iDuninSQ==", + "license": "MIT", + "optional": true, + "dependencies": { + "undici-types": ">=7.24.0 <7.24.7" + } + }, + "node_modules/@types/yauzl": { + "version": "2.10.3", + "resolved": "https://registry.npmjs.org/@types/yauzl/-/yauzl-2.10.3.tgz", + "integrity": "sha512-oJoftv0LSuaDZE3Le4DbKX+KS9G36NzOeSap90UIK0yMA/NhKJhqlSGtNDORNRaIbQfzjXDrQa0ytJ6mNRGz/Q==", + "license": "MIT", + "optional": true, + "dependencies": { + "@types/node": "*" + } + }, + "node_modules/agent-base": { + "version": "7.1.4", + "resolved": "https://registry.npmjs.org/agent-base/-/agent-base-7.1.4.tgz", + "integrity": "sha512-MnA+YT8fwfJPgBx3m60MNqakm30XOkyIoH1y6huTQvC0PwZG7ki8NacLBcrPbNoo8vEZy7Jpuk7+jMO+CUovTQ==", + "license": "MIT", + "engines": { + "node": ">= 14" + } + }, + "node_modules/ansi-regex": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", + "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/ansi-styles": { + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", + "integrity": "sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==", + "license": "MIT", + "dependencies": { + "color-convert": "^2.0.1" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, + "node_modules/ast-types": { + "version": "0.13.4", + "resolved": "https://registry.npmjs.org/ast-types/-/ast-types-0.13.4.tgz", + "integrity": "sha512-x1FCFnFifvYDDzTaLII71vG5uvDwgtmDTEVWAxrgeiR8VjMONcCXJx7E+USjDtHlwFmt9MysbqgF9b9Vjr6w+w==", + "license": "MIT", + "dependencies": { + "tslib": "^2.0.1" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/b4a": { + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/b4a/-/b4a-1.8.1.tgz", + "integrity": "sha512-aiqre1Nr0B/6DgE2N5vwTc+2/oQZ4Wh1t4NznYY4E00y8LCt6NqdRv81so00oo27D8MVKTpUa/MwUUtBLXCoDw==", + "license": "Apache-2.0", + "peerDependencies": { + "react-native-b4a": "*" + }, + "peerDependenciesMeta": { + "react-native-b4a": { + "optional": true + } + } + }, + "node_modules/bare-events": { + "version": "2.8.3", + "resolved": "https://registry.npmjs.org/bare-events/-/bare-events-2.8.3.tgz", + "integrity": "sha512-HdUm8EMQBLaJvGUdidNNbqpA1kYkwNcb+MYxkxCLAPJGQzlv9J0C24h8V65Z4c5GLd/JEALDvpFCQgpLJqc0zw==", + "license": "Apache-2.0", + "peerDependencies": { + "bare-abort-controller": "*" + }, + "peerDependenciesMeta": { + "bare-abort-controller": { + "optional": true + } + } + }, + "node_modules/bare-fs": { + "version": "4.7.1", + "resolved": "https://registry.npmjs.org/bare-fs/-/bare-fs-4.7.1.tgz", + "integrity": "sha512-WDRsyVN52eAx/lBamKD6uyw8H4228h/x0sGGGegOamM2cd7Pag88GfMQalobXI+HaEUxpCkbKQUDOQqt9wawRw==", + "license": "Apache-2.0", + "dependencies": { + "bare-events": "^2.5.4", + "bare-path": "^3.0.0", + "bare-stream": "^2.6.4", + "bare-url": "^2.2.2", + "fast-fifo": "^1.3.2" + }, + "engines": { + "bare": ">=1.16.0" + }, + "peerDependencies": { + "bare-buffer": "*" + }, + "peerDependenciesMeta": { + "bare-buffer": { + "optional": true + } + } + }, + "node_modules/bare-os": { + "version": "3.9.1", + "resolved": "https://registry.npmjs.org/bare-os/-/bare-os-3.9.1.tgz", + "integrity": "sha512-6M5XjcnsygQNPMCMPXSK379xrJFiZ/AEMNBmFEmQW8d/789VQATvriyi5r0HYTL9TkQ26rn3kgdTG3aisbrXkQ==", + "license": "Apache-2.0", + "engines": { + "bare": ">=1.14.0" + } + }, + "node_modules/bare-path": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/bare-path/-/bare-path-3.0.0.tgz", + "integrity": "sha512-tyfW2cQcB5NN8Saijrhqn0Zh7AnFNsnczRcuWODH0eYAXBsJ5gVxAUuNr7tsHSC6IZ77cA0SitzT+s47kot8Mw==", + "license": "Apache-2.0", + "dependencies": { + "bare-os": "^3.0.1" + } + }, + "node_modules/bare-stream": { + "version": "2.13.1", + "resolved": "https://registry.npmjs.org/bare-stream/-/bare-stream-2.13.1.tgz", + "integrity": "sha512-Vp0cnjYyrEC4whYTymQ+YZi6pBpfiICZO3cfRG8sy67ZNWe951urv1x4eW1BKNngw3U+3fPYb5JQvHbCtxH7Ow==", + "license": "Apache-2.0", + "dependencies": { + "streamx": "^2.25.0", + "teex": "^1.0.1" + }, + "peerDependencies": { + "bare-abort-controller": "*", + "bare-buffer": "*", + "bare-events": "*" + }, + "peerDependenciesMeta": { + "bare-abort-controller": { + "optional": true + }, + "bare-buffer": { + "optional": true + }, + "bare-events": { + "optional": true + } + } + }, + "node_modules/bare-url": { + "version": "2.4.3", + "resolved": "https://registry.npmjs.org/bare-url/-/bare-url-2.4.3.tgz", + "integrity": "sha512-Kccpc7ACfXaxfeInfqKcZtW4pT5YBn1mesc4sCsun6sRwtbJ4h+sNOaksUpYEJUKfN65YWC6Bw2OJEFiKxq8nQ==", + "license": "Apache-2.0", + "dependencies": { + "bare-path": "^3.0.0" + } + }, + "node_modules/basic-ftp": { + "version": "5.3.1", + "resolved": "https://registry.npmjs.org/basic-ftp/-/basic-ftp-5.3.1.tgz", + "integrity": "sha512-bopVNp6ugyA150DDuZfPFdt1KZ5a94ZDiwX4hMgZDzF+GttD80lEy8kj98kbyhLXnPvhtIo93mdnLIjpCAeeOw==", + "license": "MIT", + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/buffer-crc32": { + "version": "0.2.13", + "resolved": "https://registry.npmjs.org/buffer-crc32/-/buffer-crc32-0.2.13.tgz", + "integrity": "sha512-VO9Ht/+p3SN7SKWqcrgEzjGbRSJYTx+Q1pTQC0wrWqHx0vpJraQ6GtHx8tvcg1rlK1byhU5gccxgOgj7B0TDkQ==", + "license": "MIT", + "engines": { + "node": "*" + } + }, + "node_modules/chromium-bidi": { + "version": "14.0.0", + "resolved": "https://registry.npmjs.org/chromium-bidi/-/chromium-bidi-14.0.0.tgz", + "integrity": "sha512-9gYlLtS6tStdRWzrtXaTMnqcM4dudNegMXJxkR0I/CXObHalYeYcAMPrL19eroNZHtJ8DQmu1E+ZNOYu/IXMXw==", + "license": "Apache-2.0", + "dependencies": { + "mitt": "^3.0.1", + "zod": "^3.24.1" + }, + "peerDependencies": { + "devtools-protocol": "*" + } + }, + "node_modules/cliui": { + "version": "8.0.1", + "resolved": "https://registry.npmjs.org/cliui/-/cliui-8.0.1.tgz", + "integrity": "sha512-BSeNnyus75C4//NQ9gQt1/csTXyo/8Sb+afLAkzAptFuMsod9HFokGNudZpi/oQV73hnVK+sR+5PVRMd+Dr7YQ==", + "license": "ISC", + "dependencies": { + "string-width": "^4.2.0", + "strip-ansi": "^6.0.1", + "wrap-ansi": "^7.0.0" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/color-convert": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", + "integrity": "sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==", + "license": "MIT", + "dependencies": { + "color-name": "~1.1.4" + }, + "engines": { + "node": ">=7.0.0" + } + }, + "node_modules/color-name": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz", + "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", + "license": "MIT" + }, + "node_modules/data-uri-to-buffer": { + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/data-uri-to-buffer/-/data-uri-to-buffer-6.0.2.tgz", + "integrity": "sha512-7hvf7/GW8e86rW0ptuwS3OcBGDjIi6SZva7hCyWC0yYry2cOPmLIjXAUHI6DK2HsnwJd9ifmt57i8eV2n4YNpw==", + "license": "MIT", + "engines": { + "node": ">= 14" + } + }, + "node_modules/debug": { + "version": "4.4.3", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.3.tgz", + "integrity": "sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA==", + "license": "MIT", + "dependencies": { + "ms": "^2.1.3" + }, + "engines": { + "node": ">=6.0" + }, + "peerDependenciesMeta": { + "supports-color": { + "optional": true + } + } + }, + "node_modules/degenerator": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/degenerator/-/degenerator-5.0.1.tgz", + "integrity": "sha512-TllpMR/t0M5sqCXfj85i4XaAzxmS5tVA16dqvdkMwGmzI+dXLXnw3J+3Vdv7VKw+ThlTMboK6i9rnZ6Nntj5CQ==", + "license": "MIT", + "dependencies": { + "ast-types": "^0.13.4", + "escodegen": "^2.1.0", + "esprima": "^4.0.1" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/devtools-protocol": { + "version": "0.0.1608973", + "resolved": "https://registry.npmjs.org/devtools-protocol/-/devtools-protocol-0.0.1608973.tgz", + "integrity": "sha512-Tpm17fxYzt+J7VrGdc1k8YdRqS3YV7se/M6KeemEqvUbq/n7At1rWVuXMxQgpWkdwSdIEKYbU//Bve+Shm4YNQ==", + "license": "BSD-3-Clause" + }, + "node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "license": "MIT" + }, + "node_modules/end-of-stream": { + "version": "1.4.5", + "resolved": "https://registry.npmjs.org/end-of-stream/-/end-of-stream-1.4.5.tgz", + "integrity": "sha512-ooEGc6HP26xXq/N+GCGOT0JKCLDGrq2bQUZrQ7gyrJiZANJ/8YDTxTpQBXGMn+WbIQXNVpyWymm7KYVICQnyOg==", + "license": "MIT", + "dependencies": { + "once": "^1.4.0" + } + }, + "node_modules/escalade": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", + "integrity": "sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==", + "license": "MIT", + "engines": { + "node": ">=6" + } + }, + "node_modules/escodegen": { + "version": "2.1.0", + "resolved": "https://registry.npmjs.org/escodegen/-/escodegen-2.1.0.tgz", + "integrity": "sha512-2NlIDTwUWJN0mRPQOdtQBzbUHvdGY2P1VXSyU83Q3xKxM7WHX2Ql8dKq782Q9TgQUNOLEzEYu9bzLNj1q88I5w==", + "license": "BSD-2-Clause", + "dependencies": { + "esprima": "^4.0.1", + "estraverse": "^5.2.0", + "esutils": "^2.0.2" + }, + "bin": { + "escodegen": "bin/escodegen.js", + "esgenerate": "bin/esgenerate.js" + }, + "engines": { + "node": ">=6.0" + }, + "optionalDependencies": { + "source-map": "~0.6.1" + } + }, + "node_modules/esprima": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/esprima/-/esprima-4.0.1.tgz", + "integrity": "sha512-eGuFFw7Upda+g4p+QHvnW0RyTX/SVeJBDM/gCtMARO0cLuT2HcEKnTPvhjV6aGeqrCB/sbNop0Kszm0jsaWU4A==", + "license": "BSD-2-Clause", + "bin": { + "esparse": "bin/esparse.js", + "esvalidate": "bin/esvalidate.js" + }, + "engines": { + "node": ">=4" + } + }, + "node_modules/estraverse": { + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/estraverse/-/estraverse-5.3.0.tgz", + "integrity": "sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==", + "license": "BSD-2-Clause", + "engines": { + "node": ">=4.0" + } + }, + "node_modules/esutils": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", + "integrity": "sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==", + "license": "BSD-2-Clause", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/events-universal": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/events-universal/-/events-universal-1.0.1.tgz", + "integrity": "sha512-LUd5euvbMLpwOF8m6ivPCbhQeSiYVNb8Vs0fQ8QjXo0JTkEHpz8pxdQf0gStltaPpw0Cca8b39KxvK9cfKRiAw==", + "license": "Apache-2.0", + "dependencies": { + "bare-events": "^2.7.0" + } + }, + "node_modules/extract-zip": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/extract-zip/-/extract-zip-2.0.1.tgz", + "integrity": "sha512-GDhU9ntwuKyGXdZBUgTIe+vXnWj0fppUEtMDL0+idd5Sta8TGpHssn/eusA9mrPr9qNDym6SxAYZjNvCn/9RBg==", + "license": "BSD-2-Clause", + "dependencies": { + "debug": "^4.1.1", + "get-stream": "^5.1.0", + "yauzl": "^2.10.0" + }, + "bin": { + "extract-zip": "cli.js" + }, + "engines": { + "node": ">= 10.17.0" + }, + "optionalDependencies": { + "@types/yauzl": "^2.9.1" + } + }, + "node_modules/fast-fifo": { + "version": "1.3.2", + "resolved": "https://registry.npmjs.org/fast-fifo/-/fast-fifo-1.3.2.tgz", + "integrity": "sha512-/d9sfos4yxzpwkDkuN7k2SqFKtYNmCTzgfEpz82x34IM9/zc8KGxQoXg1liNC/izpRM/MBdt44Nmx41ZWqk+FQ==", + "license": "MIT" + }, + "node_modules/fd-slicer": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/fd-slicer/-/fd-slicer-1.1.0.tgz", + "integrity": "sha512-cE1qsB/VwyQozZ+q1dGxR8LBYNZeofhEdUNGSMbQD3Gw2lAzX9Zb3uIU6Ebc/Fmyjo9AWWfnn0AUCHqtevs/8g==", + "license": "MIT", + "dependencies": { + "pend": "~1.2.0" + } + }, + "node_modules/get-caller-file": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/get-caller-file/-/get-caller-file-2.0.5.tgz", + "integrity": "sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg==", + "license": "ISC", + "engines": { + "node": "6.* || 8.* || >= 10.*" + } + }, + "node_modules/get-stream": { + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/get-stream/-/get-stream-5.2.0.tgz", + "integrity": "sha512-nBF+F1rAZVCu/p7rjzgA+Yb4lfYXrpl7a6VmJrU8wF9I1CKvP/QwPNZHnOlwbTkY6dvtFIzFMSyQXbLoTQPRpA==", + "license": "MIT", + "dependencies": { + "pump": "^3.0.0" + }, + "engines": { + "node": ">=8" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/get-uri": { + "version": "6.0.5", + "resolved": "https://registry.npmjs.org/get-uri/-/get-uri-6.0.5.tgz", + "integrity": "sha512-b1O07XYq8eRuVzBNgJLstU6FYc1tS6wnMtF1I1D9lE8LxZSOGZ7LhxN54yPP6mGw5f2CkXY2BQUL9Fx41qvcIg==", + "license": "MIT", + "dependencies": { + "basic-ftp": "^5.0.2", + "data-uri-to-buffer": "^6.0.2", + "debug": "^4.3.4" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/http-proxy-agent": { + "version": "7.0.2", + "resolved": "https://registry.npmjs.org/http-proxy-agent/-/http-proxy-agent-7.0.2.tgz", + "integrity": "sha512-T1gkAiYYDWYx3V5Bmyu7HcfcvL7mUrTWiM6yOfa3PIphViJ/gFPbvidQ+veqSOHci/PxBcDabeUNCzpOODJZig==", + "license": "MIT", + "dependencies": { + "agent-base": "^7.1.0", + "debug": "^4.3.4" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/https-proxy-agent": { + "version": "7.0.6", + "resolved": "https://registry.npmjs.org/https-proxy-agent/-/https-proxy-agent-7.0.6.tgz", + "integrity": "sha512-vK9P5/iUfdl95AI+JVyUuIcVtd4ofvtrOr3HNtM2yxC9bnMbEdp3x01OhQNnjb8IJYi38VlTE3mBXwcfvywuSw==", + "license": "MIT", + "dependencies": { + "agent-base": "^7.1.2", + "debug": "4" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/ip-address": { + "version": "10.2.0", + "resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.2.0.tgz", + "integrity": "sha512-/+S6j4E9AHvW9SWMSEY9Xfy66O5PWvVEJ08O0y5JGyEKQpojb0K0GKpz/v5HJ/G0vi3D2sjGK78119oXZeE0qA==", + "license": "MIT", + "engines": { + "node": ">= 12" + } + }, + "node_modules/is-fullwidth-code-point": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz", + "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==", + "license": "MIT", + "engines": { + "node": ">=8" + } + }, + "node_modules/lru-cache": { + "version": "7.18.3", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-7.18.3.tgz", + "integrity": "sha512-jumlc0BIUrS3qJGgIkWZsyfAM7NCWiBcCDhnd+3NNM5KbBmLTgHVfWBcg6W+rLUsIpzpERPsvwUP7CckAQSOoA==", + "license": "ISC", + "engines": { + "node": ">=12" + } + }, + "node_modules/mitt": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/mitt/-/mitt-3.0.1.tgz", + "integrity": "sha512-vKivATfr97l2/QBCYAkXYDbrIWPM2IIKEl7YPhjCvKlG3kE2gm+uBo6nEXK3M5/Ffh/FLpKExzOQ3JJoJGFKBw==", + "license": "MIT" + }, + "node_modules/ms": { + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", + "license": "MIT" + }, + "node_modules/netmask": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/netmask/-/netmask-2.1.1.tgz", + "integrity": "sha512-eonl3sLUha+S1GzTPxychyhnUzKyeQkZ7jLjKrBagJgPla13F+uQ71HgpFefyHgqrjEbCPkDArxYsjY8/+gLKA==", + "license": "MIT", + "engines": { + "node": ">= 0.4.0" + } + }, + "node_modules/once": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", + "integrity": "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==", + "license": "ISC", + "dependencies": { + "wrappy": "1" + } + }, + "node_modules/pac-proxy-agent": { + "version": "7.2.0", + "resolved": "https://registry.npmjs.org/pac-proxy-agent/-/pac-proxy-agent-7.2.0.tgz", + "integrity": "sha512-TEB8ESquiLMc0lV8vcd5Ql/JAKAoyzHFXaStwjkzpOpC5Yv+pIzLfHvjTSdf3vpa2bMiUQrg9i6276yn8666aA==", + "license": "MIT", + "dependencies": { + "@tootallnate/quickjs-emscripten": "^0.23.0", + "agent-base": "^7.1.2", + "debug": "^4.3.4", + "get-uri": "^6.0.1", + "http-proxy-agent": "^7.0.0", + "https-proxy-agent": "^7.0.6", + "pac-resolver": "^7.0.1", + "socks-proxy-agent": "^8.0.5" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/pac-resolver": { + "version": "7.0.1", + "resolved": "https://registry.npmjs.org/pac-resolver/-/pac-resolver-7.0.1.tgz", + "integrity": "sha512-5NPgf87AT2STgwa2ntRMr45jTKrYBGkVU36yT0ig/n/GMAa3oPqhZfIQ2kMEimReg0+t9kZViDVZ83qfVUlckg==", + "license": "MIT", + "dependencies": { + "degenerator": "^5.0.0", + "netmask": "^2.0.2" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/pend": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/pend/-/pend-1.2.0.tgz", + "integrity": "sha512-F3asv42UuXchdzt+xXqfW1OGlVBe+mxa2mqI0pg5yAHZPvFmY3Y6drSf/GQ1A86WgWEN9Kzh/WrgKa6iGcHXLg==", + "license": "MIT" + }, + "node_modules/progress": { + "version": "2.0.3", + "resolved": "https://registry.npmjs.org/progress/-/progress-2.0.3.tgz", + "integrity": "sha512-7PiHtLll5LdnKIMw100I+8xJXR5gW2QwWYkT6iJva0bXitZKa/XMrSbdmg3r2Xnaidz9Qumd0VPaMrZlF9V9sA==", + "license": "MIT", + "engines": { + "node": ">=0.4.0" + } + }, + "node_modules/proxy-agent": { + "version": "6.5.0", + "resolved": "https://registry.npmjs.org/proxy-agent/-/proxy-agent-6.5.0.tgz", + "integrity": "sha512-TmatMXdr2KlRiA2CyDu8GqR8EjahTG3aY3nXjdzFyoZbmB8hrBsTyMezhULIXKnC0jpfjlmiZ3+EaCzoInSu/A==", + "license": "MIT", + "dependencies": { + "agent-base": "^7.1.2", + "debug": "^4.3.4", + "http-proxy-agent": "^7.0.1", + "https-proxy-agent": "^7.0.6", + "lru-cache": "^7.14.1", + "pac-proxy-agent": "^7.1.0", + "proxy-from-env": "^1.1.0", + "socks-proxy-agent": "^8.0.5" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/proxy-from-env": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-1.1.0.tgz", + "integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==", + "license": "MIT" + }, + "node_modules/pump": { + "version": "3.0.4", + "resolved": "https://registry.npmjs.org/pump/-/pump-3.0.4.tgz", + "integrity": "sha512-VS7sjc6KR7e1ukRFhQSY5LM2uBWAUPiOPa/A3mkKmiMwSmRFUITt0xuj+/lesgnCv+dPIEYlkzrcyXgquIHMcA==", + "license": "MIT", + "dependencies": { + "end-of-stream": "^1.1.0", + "once": "^1.3.1" + } + }, + "node_modules/puppeteer-core": { + "version": "24.43.1", + "resolved": "https://registry.npmjs.org/puppeteer-core/-/puppeteer-core-24.43.1.tgz", + "integrity": "sha512-T5ScUMAsmhdNbgDR41AGESYeS6V9MSgetkSnVhhW+gXvzC42VesKCn5ld87gAZDJ6vLHL9GkRvY9WtQWSnwFbw==", + "license": "Apache-2.0", + "dependencies": { + "@puppeteer/browsers": "2.13.2", + "chromium-bidi": "14.0.0", + "debug": "^4.4.3", + "devtools-protocol": "0.0.1608973", + "typed-query-selector": "^2.12.2", + "webdriver-bidi-protocol": "0.4.1", + "ws": "^8.20.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/pyodide": { + "version": "0.29.4", + "resolved": "https://registry.npmjs.org/pyodide/-/pyodide-0.29.4.tgz", + "integrity": "sha512-tCseTsqU3kSxZIjkue5zXxTMNEwrKZwOIIEQRBA/VzHxFN1hoCxe4w41phfCdHd9it9RcCNQb5K/Re0InqMgvA==", + "license": "MPL-2.0", + "dependencies": { + "@types/emscripten": "^1.41.4", + "ws": "^8.5.0" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/require-directory": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/require-directory/-/require-directory-2.1.1.tgz", + "integrity": "sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/semver": { + "version": "7.8.0", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.8.0.tgz", + "integrity": "sha512-AcM7dV/5ul4EekoQ29Agm5vri8JNqRyj39o0qpX6vDF2GZrtutZl5RwgD1XnZjiTAfncsJhMI48QQH3sN87YNA==", + "license": "ISC", + "bin": { + "semver": "bin/semver.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/smart-buffer": { + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/smart-buffer/-/smart-buffer-4.2.0.tgz", + "integrity": "sha512-94hK0Hh8rPqQl2xXc3HsaBoOXKV20MToPkcXvwbISWLEs+64sBq5kFgn2kJDHb1Pry9yrP0dxrCI9RRci7RXKg==", + "license": "MIT", + "engines": { + "node": ">= 6.0.0", + "npm": ">= 3.0.0" + } + }, + "node_modules/socks": { + "version": "2.8.9", + "resolved": "https://registry.npmjs.org/socks/-/socks-2.8.9.tgz", + "integrity": "sha512-LJhUYUvItdQ0LkJTmPeaEObWXAqFyfmP85x0tch/ez9cahmhlBBLbIqDFnvBnUJGagb0JbIQrkBs1wJ+yRYpEw==", + "license": "MIT", + "dependencies": { + "ip-address": "^10.1.1", + "smart-buffer": "^4.2.0" + }, + "engines": { + "node": ">= 10.0.0", + "npm": ">= 3.0.0" + } + }, + "node_modules/socks-proxy-agent": { + "version": "8.0.5", + "resolved": "https://registry.npmjs.org/socks-proxy-agent/-/socks-proxy-agent-8.0.5.tgz", + "integrity": "sha512-HehCEsotFqbPW9sJ8WVYB6UbmIMv7kUUORIF2Nncq4VQvBfNBLibW9YZR5dlYCSUhwcD628pRllm7n+E+YTzJw==", + "license": "MIT", + "dependencies": { + "agent-base": "^7.1.2", + "debug": "^4.3.4", + "socks": "^2.8.3" + }, + "engines": { + "node": ">= 14" + } + }, + "node_modules/source-map": { + "version": "0.6.1", + "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", + "integrity": "sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==", + "license": "BSD-3-Clause", + "optional": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/streamx": { + "version": "2.25.0", + "resolved": "https://registry.npmjs.org/streamx/-/streamx-2.25.0.tgz", + "integrity": "sha512-0nQuG6jf1w+wddNEEXCF4nTg3LtufWINB5eFEN+5TNZW7KWJp6x87+JFL43vaAUPyCfH1wID+mNVyW6OHtFamg==", + "license": "MIT", + "dependencies": { + "events-universal": "^1.0.0", + "fast-fifo": "^1.3.2", + "text-decoder": "^1.1.0" + } + }, + "node_modules/string-width": { + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "license": "MIT", + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/tar-fs": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/tar-fs/-/tar-fs-3.1.2.tgz", + "integrity": "sha512-QGxxTxxyleAdyM3kpFs14ymbYmNFrfY+pHj7Z8FgtbZ7w2//VAgLMac7sT6nRpIHjppXO2AwwEOg0bPFVRcmXw==", + "license": "MIT", + "dependencies": { + "pump": "^3.0.0", + "tar-stream": "^3.1.5" + }, + "optionalDependencies": { + "bare-fs": "^4.0.1", + "bare-path": "^3.0.0" + } + }, + "node_modules/tar-stream": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/tar-stream/-/tar-stream-3.2.0.tgz", + "integrity": "sha512-ojzvCvVaNp6aOTFmG7jaRD0meowIAuPc3cMMhSgKiVWws1GyHbGd/xvnyuRKcKlMpt3qvxx6r0hreCNITP9hIg==", + "license": "MIT", + "dependencies": { + "b4a": "^1.6.4", + "bare-fs": "^4.5.5", + "fast-fifo": "^1.2.0", + "streamx": "^2.15.0" + } + }, + "node_modules/teex": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/teex/-/teex-1.0.1.tgz", + "integrity": "sha512-eYE6iEI62Ni1H8oIa7KlDU6uQBtqr4Eajni3wX7rpfXD8ysFx8z0+dri+KWEPWpBsxXfxu58x/0jvTVT1ekOSg==", + "license": "MIT", + "dependencies": { + "streamx": "^2.12.5" + } + }, + "node_modules/text-decoder": { + "version": "1.2.7", + "resolved": "https://registry.npmjs.org/text-decoder/-/text-decoder-1.2.7.tgz", + "integrity": "sha512-vlLytXkeP4xvEq2otHeJfSQIRyWxo/oZGEbXrtEEF9Hnmrdly59sUbzZ/QgyWuLYHctCHxFF4tRQZNQ9k60ExQ==", + "license": "Apache-2.0", + "dependencies": { + "b4a": "^1.6.4" + } + }, + "node_modules/three": { + "version": "0.184.0", + "resolved": "https://registry.npmjs.org/three/-/three-0.184.0.tgz", + "integrity": "sha512-wtTRjG92pM5eUg/KuUnHsqSAlPM296brTOcLgMRqEeylYTh/CdtvKUvCyyCQTzFuStieWxvZb8mVTMvdPyUpxg==", + "license": "MIT" + }, + "node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "license": "0BSD" + }, + "node_modules/typed-query-selector": { + "version": "2.12.2", + "resolved": "https://registry.npmjs.org/typed-query-selector/-/typed-query-selector-2.12.2.tgz", + "integrity": "sha512-EOPFbyIub4ngnEdqi2yOcNeDLaX/0jcE1JoAXQDDMIthap7FoN795lc/SHfIq2d416VufXpM8z/lD+WRm2gfOQ==", + "license": "MIT" + }, + "node_modules/undici-types": { + "version": "7.24.6", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-7.24.6.tgz", + "integrity": "sha512-WRNW+sJgj5OBN4/0JpHFqtqzhpbnV0GuB+OozA9gCL7a993SmU+1JBZCzLNxYsbMfIeDL+lTsphD5jN5N+n0zg==", + "license": "MIT", + "optional": true + }, + "node_modules/urdf-loader": { + "version": "0.12.7", + "resolved": "https://registry.npmjs.org/urdf-loader/-/urdf-loader-0.12.7.tgz", + "integrity": "sha512-9RQDhWt4x9K6R3uHjv68NLwX+JDm3A2XjhksrMgq6xl9NJcwqV21+htHU0cbG619hnqH1d4mUMFuqHo22R8DMA==", + "license": "Apache-2.0", + "peerDependencies": { + "three": ">=0.152.0" + } + }, + "node_modules/webdriver-bidi-protocol": { + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/webdriver-bidi-protocol/-/webdriver-bidi-protocol-0.4.1.tgz", + "integrity": "sha512-ARrjNjtWRRs2w4Tk7nqrf2gBI0QXWuOmMCx2hU+1jUt6d00MjMxURrhxhGbrsoiZKJrhTSTzbIrc554iKI10qw==", + "license": "Apache-2.0" + }, + "node_modules/wrap-ansi": { + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz", + "integrity": "sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==", + "license": "MIT", + "dependencies": { + "ansi-styles": "^4.0.0", + "string-width": "^4.1.0", + "strip-ansi": "^6.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/wrap-ansi?sponsor=1" + } + }, + "node_modules/wrappy": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", + "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==", + "license": "ISC" + }, + "node_modules/ws": { + "version": "8.20.1", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.20.1.tgz", + "integrity": "sha512-It4dO0K5v//JtTXuPkfEOaI3uUN87iYPnqo/ZzqCoG3g8uhA66QUMs/SrM0YK7/NAu+r4LMh/9dq2A7k+rHs+w==", + "license": "MIT", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": ">=5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, + "node_modules/y18n": { + "version": "5.0.8", + "resolved": "https://registry.npmjs.org/y18n/-/y18n-5.0.8.tgz", + "integrity": "sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA==", + "license": "ISC", + "engines": { + "node": ">=10" + } + }, + "node_modules/yargs": { + "version": "17.7.2", + "resolved": "https://registry.npmjs.org/yargs/-/yargs-17.7.2.tgz", + "integrity": "sha512-7dSzzRQ++CKnNI/krKnYRV7JKKPUXMEh61soaHKg9mrWEhzFWhFnxPxGl+69cD1Ou63C13NUPCnmIcrvqCuM6w==", + "license": "MIT", + "dependencies": { + "cliui": "^8.0.1", + "escalade": "^3.1.1", + "get-caller-file": "^2.0.5", + "require-directory": "^2.1.1", + "string-width": "^4.2.3", + "y18n": "^5.0.5", + "yargs-parser": "^21.1.1" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/yargs-parser": { + "version": "21.1.1", + "resolved": "https://registry.npmjs.org/yargs-parser/-/yargs-parser-21.1.1.tgz", + "integrity": "sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw==", + "license": "ISC", + "engines": { + "node": ">=12" + } + }, + "node_modules/yauzl": { + "version": "2.10.0", + "resolved": "https://registry.npmjs.org/yauzl/-/yauzl-2.10.0.tgz", + "integrity": "sha512-p4a9I6X6nu6IhoGmBqAcbJy1mlC4j27vEPZX9F4L4/vZT3Lyq1VkFHw/V/PUcB9Buo+DG3iHkT0x3Qya58zc3g==", + "license": "MIT", + "dependencies": { + "buffer-crc32": "~0.2.3", + "fd-slicer": "~1.1.0" + } + }, + "node_modules/zod": { + "version": "3.25.76", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz", + "integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } + } + } +} diff --git a/web/package.json b/web/package.json new file mode 100644 index 0000000000..768eef9d30 --- /dev/null +++ b/web/package.json @@ -0,0 +1,11 @@ +{ + "name": "predicators-web", + "private": true, + "type": "module", + "dependencies": { + "puppeteer-core": "^24.0.0", + "pyodide": "^0.29.4", + "three": "^0.184.0", + "urdf-loader": "^0.12.7" + } +}