diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
new file mode 100644
index 00000000..d11c7f22
--- /dev/null
+++ b/.github/workflows/lint.yml
@@ -0,0 +1,34 @@
+name: Lint
+
+on:
+  push:
+    branches: [main, development]
+  pull_request:
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+
+      - name: Install ruff
+        run: pip install ruff
+
+      - name: ruff check
+        run: ruff check .
+
+      - name: ruff format check
+        run: ruff format --check .
+
+      # Pin to match the pre-commit mirrors-mypy rev and the pyproject dev
+      # group so CI, local commits, and `uv sync` never run different mypys.
+      - name: Install mypy
+        run: pip install mypy==2.1.0
+
+      - name: mypy
+        run: mypy .
diff --git a/.gitignore b/.gitignore
index cce868ec..935470d1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -145,3 +145,6 @@ electron/
 /stage_definitions_for_review.txt
 gently/ui/tui/node_modules/
 gently/ui/tui/dist/
+
+# Runtime storage accidentally created on Linux when GENTLY_STORAGE_PATH="D:/" resolves literally
+D:/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..88476e0c
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,16 @@
+repos:
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.15.17 # run `pre-commit autoupdate` to pin to latest
+    hooks:
+      - id: ruff
+        args: [--fix]
+      - id: ruff-format
+
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    # Keep in sync with the mypy pin in pyproject.toml dev group and
+    # .github/workflows/lint.yml. Run `pre-commit autoupdate` to bump.
+    rev: v2.1.0
+    hooks:
+      - id: mypy
+        pass_filenames: false
+        args: ["."]
diff --git a/.python-version b/.python-version
new file mode 100644
index 00000000..2c073331
--- /dev/null
+++ b/.python-version
@@ -0,0 +1 @@
+3.11
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 30669817..0097853a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -394,6 +394,52 @@ Net: ~3,500 lines removed across P6–P7.
 
 ---
 
+## v0.22.0
+
+File-based storage, a redesigned web UI, new hardware, and the tooling to
+keep it all honest.
+
+**File-based storage (Gently3)**
+Retired the SQLite databases. All state now lives as human-browsable files
+under `D:\Gently3\` — sessions, embryos, volumes, projections, traces,
+campaigns, learnings, agent memory, all YAML/JSONL/TIFF.
+- `FileStore` replaces `GentlyStore`; `FileContextStore` replaces the
+  `agent_mind.db` `ContextStore`. Drop-in API replacements.
+- A root `gently.yaml` manifest documents the layout for humans and agents.
+- YAML parses are cached in `FileContextStore` — fixes slow Plans/campaign
+  loading.
+
+**Web UI redesign**
+- Agent chat became a docked, sliding side panel (overlay + pin-to-dock)
+  instead of owning the screen.
+- Added a Home landing tab; the chat no longer auto-runs the startup wizard.
+- Login is non-blocking — a "Continue in view-only" escape hatch.
+- Recent images aggregate across previous sessions.
+
+**Hardware**
+- Integrated the ACUITYnano temperature controller (config, web control,
+  SDKs) with a live HiveMQ cloud SIM for hardware-free testing.
+- Added the SPIM-head F-drive device, hard limits, and focus/align plans.
+- Room-light toggle and a device-layer terminal UI.
+
+**Agent + perception**
+- Integrated the agent with perception: pull tool, prompt context, event
+  bridge, wake-router.
+- Live acquisition control with observable, permissioned autonomy and a
+  refreshed prompt.
+- Retired napari from the agent; added web-chat autocomplete and pruned
+  dead tools.
+
+**Tooling and environment**
+- Added ruff lint/format tooling and fixed all violations.
+- Adopted incremental mypy typing — config, CI, pre-commit wiring, and a
+  documented policy in `CONTRIBUTING.md`; pinned mypy to 2.1.0.
+- Switched environment setup to uv with an offline/UI-only launch path;
+  pinned pymmcore to device-interface 70.
+- Relicensed and updated the author list.
+
+---
+
 ## Notes on how we think about this
 
 Things we've learned building this, roughly in order:
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 00000000..b936732d
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,72 @@
+# Contributing to Gently
+
+## Code quality toolchain
+
+This project uses [ruff](https://docs.astral.sh/ruff/) for linting and formatting, and
+[mypy](https://mypy-lang.org/) for type checking, enforced automatically before every
+commit via [pre-commit](https://pre-commit.com/).
+
+### First-time setup
+
+Install the dev dependencies (includes ruff, mypy, and pre-commit):
+
+```bash
+uv sync
+```
+
+Then install the pre-commit hooks:
+
+```bash
+pre-commit install
+```
+
+From this point on, ruff runs on staged files and mypy runs across the whole
+project whenever you `git commit`.
+
+### Running manually
+
+To check all files at once (useful before opening a PR):
+
+```bash
+pre-commit run --all-files
+```
+
+Or run the tools directly:
+
+```bash
+ruff check . # lint
+ruff format . # format in-place
+```
+
+### Keeping hooks up to date
+
+To update hook versions to their latest releases:
+
+```bash
+pre-commit autoupdate
+```
+
+### Type checking
+
+Run mypy the same way pre-commit and CI do:
+
+```bash
+mypy .
+```
+
+The codebase is being typed incrementally (see issue #46). Modules with
+pre-existing errors are listed in the `[[tool.mypy.overrides]]` block in
+`pyproject.toml` with `ignore_errors = true`, so `mypy .` passes today even
+though not every module is fully typed yet.
+
+Policy for working with this list:
+
+- **New modules** must pass `mypy .` cleanly — do not add them to the
+  overrides list.
+- **PRs that substantively touch a module on the overrides list** should fix
+  that module's type errors and remove it from the list as part of the
+  change.
+
+### CI
+
+Every pull request runs the lint job (`.github/workflows/lint.yml`), which checks ruff lint and formatting and runs mypy across the entire project. Fix any failures locally with `pre-commit run --all-files` before pushing.
diff --git a/LICENSE b/LICENSE
index 486ead3a..f288702d 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,21 +1,674 @@
-MIT License
-
-Copyright (c) 2023 Kesavan Subburam
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+                    GNU GENERAL PUBLIC LICENSE
+                       Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc.
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+                            Preamble
+
+  The GNU General Public License is a free, copyleft license for
+software and other kinds of works.
+
+  The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+the GNU General Public License is intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users. We, the Free Software Foundation, use the
+GNU General Public License for most of our software; it applies also to
+any other work released this way by its authors. You can apply it to
+your programs, too.
+
+  When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+  To protect your rights, we need to prevent others from denying you
+these rights or asking you to surrender the rights. Therefore, you have
+certain responsibilities if you distribute copies of the software, or if
+you modify it: responsibilities to respect the freedom of others.
+
+  For example, if you distribute copies of such a program, whether
+gratis or for a fee, you must pass on to the recipients the same
+freedoms that you received. You must make sure that they, too, receive
+or can get the source code. And you must show them these terms so they
+know their rights.
+
+  Developers that use the GNU GPL protect your rights with two steps:
+(1) assert copyright on the software, and (2) offer you this License
+giving you legal permission to copy, distribute and/or modify it.
+
+  For the developers' and authors' protection, the GPL clearly explains
+that there is no warranty for this free software. For both users' and
+authors' sake, the GPL requires that modified versions be marked as
+changed, so that their problems will not be attributed erroneously to
+authors of previous versions.
+
+  Some devices are designed to deny users access to install or run
+modified versions of the software inside them, although the manufacturer
+can do so. This is fundamentally incompatible with the aim of
+protecting users' freedom to change the software. The systematic
+pattern of such abuse occurs in the area of products for individuals to
+use, which is precisely where it is most unacceptable. Therefore, we
+have designed this version of the GPL to prohibit the practice for those
+products. If such problems arise substantially in other domains, we
+stand ready to extend this provision to those domains in future versions
+of the GPL, as needed to protect the freedom of users.
+
+  Finally, every program is threatened constantly by software patents.
+States should not allow patents to restrict development and use of
+software on general-purpose computers, but in those that do, we wish to
+avoid the special danger that patents applied to a free program could
+make it effectively proprietary. To prevent this, the GPL assures that
+patents cannot be used to render the program non-free.
+
+  The precise terms and conditions for copying, distribution and
+modification follow.
+
+                       TERMS AND CONDITIONS
+
+  0. Definitions.
+
+  "This License" refers to version 3 of the GNU General Public License.
+
+  "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+  "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+  To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+  A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+  To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+  To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+  An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+  1. Source Code.
+
+  The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+  A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+  The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+  The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+  The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+  The Corresponding Source for a work in source code form is that
+same work.
+
+  2. Basic Permissions.
+
+  All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+  You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+  Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+  3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+  No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+  When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+  4. Conveying Verbatim Copies.
+
+  You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+  You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+  5. Conveying Modified Source Versions.
+
+  You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+    a) The work must carry prominent notices stating that you modified
+    it, and giving a relevant date.
+
+    b) The work must carry prominent notices stating that it is
+    released under this License and any conditions added under section
+    7. This requirement modifies the requirement in section 4 to
+    "keep intact all notices".
+
+    c) You must license the entire work, as a whole, under this
+    License to anyone who comes into possession of a copy. This
+    License will therefore apply, along with any applicable section 7
+    additional terms, to the whole of the work, and all its parts,
+    regardless of how they are packaged. This License gives no
+    permission to license the work in any other way, but it does not
+    invalidate such permission if you have separately received it.
+
+    d) If the work has interactive user interfaces, each must display
+    Appropriate Legal Notices; however, if the Program has interactive
+    interfaces that do not display Appropriate Legal Notices, your
+    work need not make them do so.
+
+  A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+  6. Conveying Non-Source Forms.
+
+  You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+    a) Convey the object code in, or embodied in, a physical product
+    (including a physical distribution medium), accompanied by the
+    Corresponding Source fixed on a durable physical medium
+    customarily used for software interchange.
+
+    b) Convey the object code in, or embodied in, a physical product
+    (including a physical distribution medium), accompanied by a
+    written offer, valid for at least three years and valid for as
+    long as you offer spare parts or customer support for that product
+    model, to give anyone who possesses the object code either (1) a
+    copy of the Corresponding Source for all the software in the
+    product that is covered by this License, on a durable physical
+    medium customarily used for software interchange, for a price no
+    more than your reasonable cost of physically performing this
+    conveying of source, or (2) access to copy the
+    Corresponding Source from a network server at no charge.
+
+    c) Convey individual copies of the object code with a copy of the
+    written offer to provide the Corresponding Source. This
+    alternative is allowed only occasionally and noncommercially, and
+    only if you received the object code with such an offer, in accord
+    with subsection 6b.
+
+    d) Convey the object code by offering access from a designated
+    place (gratis or for a charge), and offer equivalent access to the
+    Corresponding Source in the same way through the same place at no
+    further charge. You need not require recipients to copy the
+    Corresponding Source along with the object code. If the place to
+    copy the object code is a network server, the Corresponding Source
+    may be on a different server (operated by you or a third party)
+    that supports equivalent copying facilities, provided you maintain
+    clear directions next to the object code saying where to find the
+    Corresponding Source. Regardless of what server hosts the
+    Corresponding Source, you remain obligated to ensure that it is
+    available for as long as needed to satisfy these requirements.
+
+    e) Convey the object code using peer-to-peer transmission, provided
+    you inform other peers where the object code and Corresponding
+    Source of the work are being offered to the general public at no
+    charge under subsection 6d.
+
+  A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+  A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+  "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+  If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+  The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+  Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+  7. Additional Terms.
+
+  "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+  When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+  Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+    a) Disclaiming warranty or limiting liability differently from the
+    terms of sections 15 and 16 of this License; or
+
+    b) Requiring preservation of specified reasonable legal notices or
+    author attributions in that material or in the Appropriate Legal
+    Notices displayed by works containing it; or
+
+    c) Prohibiting misrepresentation of the origin of that material, or
+    requiring that modified versions of such material be marked in
+    reasonable ways as different from the original version; or
+
+    d) Limiting the use for publicity purposes of names of licensors or
+    authors of the material; or
+
+    e) Declining to grant rights under trademark law for use of some
+    trade names, trademarks, or service marks; or
+
+    f) Requiring indemnification of licensors and authors of that
+    material by anyone who conveys the material (or modified versions of
+    it) with contractual assumptions of liability to the recipient, for
+    any liability that these contractual assumptions directly impose on
+    those licensors and authors.
+
+  All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+  If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+  Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+  8. Termination.
+
+  You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+  However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+  Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+  Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+  9. Acceptance Not Required for Having Copies.
+
+  You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+  10. Automatic Licensing of Downstream Recipients.
+
+  Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+  An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+  You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+  11. Patents.
+
+  A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+  A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+  Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+  In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+ + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. 
You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. 
The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. 
THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. 
+ + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. 
diff --git a/README.md b/README.md index 5dca56db..78039e70 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Agentic harness for microscopy. -**Status**: v0.11.0 — actively developed at Shroff Lab, Janelia. +**Status**: 0.22.0 — actively developed at Shroff Lab, Janelia. ![Safety Architecture](docs/images/safety_architecture.png) @@ -67,47 +67,162 @@ Currently, the sample abstraction is the `Embryo` object for *C. elegans* work. ### Prerequisites -- Python 3.11+ -- [Node.js](https://nodejs.org/) 18+ (for the Ink TUI) -- An `ANTHROPIC_API_KEY` environment variable +- Python 3.10+ +- An `ANTHROPIC_API_KEY` — either exported in your shell + (`export ANTHROPIC_API_KEY=your-key`) or placed in a `.env` file in the + project root (`ANTHROPIC_API_KEY=your-key`), which is loaded automatically + on launch and is gitignored. *(Not required if you launch with `--no-api` + to browse the UI only — see Launch below.)* +- *(Optional)* `GENTLY_STORAGE_PATH` — where sessions and data live (default `D:/Gently3`) + +Gently is **web-first**: the agent is driven from an in-page chat in your +browser. There is no TUI to build (Node.js is only needed for the paper +diagrams, not the app). ### Setup +This project uses [uv](https://docs.astral.sh/uv/) for environment and +dependency management. If you don't have it yet, install it following the +[uv installation guide](https://docs.astral.sh/uv/getting-started/installation/) +(e.g. `curl -LsSf https://astral.sh/uv/install.sh | sh` on macOS/Linux). + +Gently depends on **`gently-perception`** (the VLM perception harness, repo +`gently-project/gently-perception`), which is not published to PyPI. 
For development, it is +installed as an **editable sibling clone**, so clone both repos side by side: + ```bash -# Clone and install Python dependencies -git clone https://github.com/pskeshu/gently.git +git clone git@github.com:gently-project/gently.git +git clone git@github.com:gently-project/gently-perception.git + +# Layout: +# / +# gently/ <- you run commands from here +# gently-perception/ <- editable, resolved via [tool.uv.sources] + cd gently -pip install -r requirements.txt +uv sync # base env (add --extra ... for torch etc., see below) +``` + +> The `git@github.com:` URLs use **SSH**, which needs an +> [SSH key configured with GitHub](https://docs.github.com/en/authentication/connecting-to-github-with-ssh). +> If you don't use SSH, clone over HTTPS instead +> (`https://github.com/gently-project/.git`). + +`[tool.uv.sources]` resolves `gently-perception` to the sibling as an editable +install, so your perception edits are live immediately and survive `uv sync`. If +the sibling isn't cloned, `uv sync` fails by design — clone it first. + +You get a `.venv` in the project directory with the runtime + dev dependencies +pinned in `uv.lock`. Activate it with `source .venv/bin/activate`, or prefix +commands with `uv run` (e.g. `uv run python ...`) to use it without activating. + +#### Optional extras -# Build the TUI (one-time, rebuild after TUI code changes) -cd gently/tui -npm install -npm run build -cd ../.. +PyTorch is **not** in the base install — the CUDA build is machine-specific, so +it lives in mutually-exclusive extras wired to the right PyTorch index: + +```bash +# Device-layer accessories (microscope computer): BLE/serial/MQTT transports +uv sync --extra device + +# PyTorch (needed for SAM detection and the ML pipeline) +# NOTE: the GPU and CPU builds are mutually exclusive, so they can't be combined. +uv sync --extra torch-gpu # CUDA 11.8 build (GPU box, e.g. 
the microscope PC) +uv sync --extra torch-cpu # CPU-only build (dev laptop / CI) +``` + +#### Running tests + +```bash +uv run pytest ``` ### Launch +> The commands below use `uv run` so they work without activating the env. If you've activated it first (`source .venv/bin/activate`), the `uv run` prefix isn't necessary. + +To verify the install, you can start gently without an API key or hardware. The +web UI boots and is browsable, though the agent itself (chat, perception, plan +mode) stays disabled until you add a key: + +```bash +uv run python launch_gently.py --offline --no-api +``` + +For the full launch: + ```bash -# 1. Start the device layer (hardware control + SAM detection) -python start_device_layer.py +# 1. Device layer (hardware control + SAM detection) — separate process, own terminal +uv run python start_device_layer.py + +# 2. Agent + web UI (starts the in-process server and opens your browser) +uv run python launch_gently.py + +# Run without hardware (development / review) +uv run python launch_gently.py --offline -# 2. 
Launch the agent -python launch_gently.py +# UI-only — boot the web UI with no API key (chat/perception disabled) +uv run python launch_gently.py --no-api -# Or launch without hardware (for development / review) -python launch_gently.py --offline +# Don't auto-open a browser — open the printed URL yourself +uv run python launch_gently.py --no-browser # Resume a previous session -python launch_gently.py --resume # interactive picker -python launch_gently.py --resume latest # most recent session -python launch_gently.py --resume # specific session +uv run python launch_gently.py --resume # interactive picker +uv run python launch_gently.py --resume latest # most recent session +uv run python launch_gently.py --resume # specific session # Verbose / debug logging -python launch_gently.py -v # INFO level -python launch_gently.py --debug # DEBUG level +uv run python launch_gently.py -v # INFO level +uv run python launch_gently.py --debug # DEBUG level +``` + +The launcher prints a banner with the URL (default `http://localhost:8080`), +device status, storage path, and log location. Open that URL in any browser on +the LAN. + +### First sign-in (accounts) + +**Viewing is open** — the dashboard loads read-only for anyone, no login. +Signing in *elevates* you to control (driving hardware, taking the +single-operator lock); it isn't a gate on the page. + +On the **first run**, Gently creates one `admin` account and prints a one-time +random password in the startup banner: + +``` +First-run admin account created — sign in at the URL above: + username: admin + password: ``` +- **Save it now** — the password is printed to the console once and never + written to the log (only a PBKDF2 hash is stored). +- After signing in, add accounts (roles `viewer` / `operator` / `admin`) via the + admin-only `POST /api/auth/users`. +- **Lost it?** There's no reset command yet — delete + `/auth/users.yaml` and restart to bootstrap a fresh + `admin` (this clears all accounts). 
+- **Just trying it locally?** `GENTLY_NO_AUTH=1` disables accounts entirely + (legacy mode: localhost gets control, remote callers need `X-Gently-Token`). + +Accounts live under `/auth/` (`users.yaml` + `secret.key`), +outside the repo. + +## Make your first plan + +You don't need a microscope to try the core loop — **plan mode is pure agent reasoning and works under `--offline`**. The path from launch to an inspectable plan: + +1. **Open the agent chat.** Click **Agent** in the header (or press `Ctrl`/`Cmd`+`J`). New here? The **Home** tab's *Start an experiment* button runs a short setup wizard (also available anytime via `/wizard` — it sets the organism, the campaign, and what you're trying to learn). +2. **Enter plan mode** — type `/plan` in the chat. The agent switches from *operator* to *scientific collaborator*: it won't touch hardware, it helps you design an experiment. +3. **Describe what you want, in plain language.** For example: + > *"Follow GFP-tagged embryos from bean stage through elongation, imaging every 10 minutes, with a no-laser control — three embryos per condition."* + + The agent drafts a **campaign**: a sequence of typed **plan items** — imaging 📷, bench 🧪, genetics 🧬, analysis 📊, decision points 🚦 — each with concrete specs (strain, interval, laser power, Z-slices, target window, success criteria). Keep replying to refine it; `/plan status` shows progress and `/plan exit` returns to run mode. +4. **Inspect it in the plan viewer.** Open the **Plans** tab. Your campaign appears as a card — click it to open the **plan document**. Each item shows its status (○ planned · ◑ in progress · ● done) and specs; click one to see full details in the inspector. Switch layouts (document / board / graph / timeline) from the view controls, and browse plan **versions** as it evolves. (Typing `/campaign` in chat lists campaigns too.) 
+ +That's the loop: **talk → plan → inspect.** With hardware connected (drop `--offline` and start the device layer), the same campaign drives acquisition — and perception events can wake the agent to adjust it as the embryos develop. + ## Guides | Guide | Audience | What you'll learn | @@ -125,7 +240,7 @@ Four layers with strict downward-only dependencies. The **harness** (reusable ag gently/ ├── core/ # Layer 1: Foundation — zero domain knowledge │ ├── event_bus.py # Async pub/sub messaging -│ ├── store.py # GentlyStore (SQLite + files) +│ ├── file_store.py # FileStore (file-based: YAML / JSONL / TIF) │ ├── imaging.py # Projection, normalization, encoding │ └── coordinates.py # Pixel/stage transforms │ @@ -220,4 +335,6 @@ These papers provide theoretical background for gently's approach: ## License -See [LICENSE](LICENSE) file. +Copyright © 2026 Howard Hughes Medical Institute. + +Gently is licensed under the GNU General Public License v3.0 or later (GPL-3.0-or-later) — see the [LICENSE](LICENSE) file. 
diff --git a/benchmarks/agent/evaluator.py b/benchmarks/agent/evaluator.py index 190b2c8b..3442e79f 100644 --- a/benchmarks/agent/evaluator.py +++ b/benchmarks/agent/evaluator.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any logger = logging.getLogger(__name__) @@ -17,17 +17,18 @@ @dataclass class EvalResult: """Result of evaluating a single test case""" + test_id: str query: str - expected_tool: Union[str, List[str]] - actual_tool: Optional[str] + expected_tool: str | list[str] + actual_tool: str | None tool_correct: bool params_correct: bool - param_errors: List[str] = field(default_factory=list) + param_errors: list[str] = field(default_factory=list) input_tokens: int = 0 output_tokens: int = 0 latency_ms: float = 0 - error: Optional[str] = None + error: str | None = None @property def passed(self) -> bool: @@ -37,6 +38,7 @@ def passed(self) -> bool: @dataclass class BenchmarkReport: """Summary report for a benchmark run""" + timestamp: str num_cases: int num_passed: int @@ -45,10 +47,10 @@ class BenchmarkReport: total_input_tokens: int total_output_tokens: int avg_latency_ms: float - results: List[EvalResult] - metadata: Dict[str, Any] = field(default_factory=dict) + results: list[EvalResult] + metadata: dict[str, Any] = field(default_factory=dict) - def to_dict(self) -> Dict: + def to_dict(self) -> dict: return { "timestamp": self.timestamp, "summary": { @@ -61,7 +63,10 @@ def to_dict(self) -> Dict: "tokens": { "total_input": self.total_input_tokens, "total_output": self.total_output_tokens, - "avg_per_query": (self.total_input_tokens + self.total_output_tokens) / self.num_cases if self.num_cases > 0 else 0, + "avg_per_query": (self.total_input_tokens + self.total_output_tokens) + / self.num_cases + if self.num_cases > 0 + else 0, }, "latency": { "avg_ms": self.avg_latency_ms, @@ -92,7 +97,7 @@ class AgentEvaluator: 
print(f"Tool accuracy: {report.tool_accuracy:.1%}") """ - def __init__(self, test_cases_path: Optional[Path] = None): + def __init__(self, test_cases_path: Path | None = None): """ Parameters ---------- @@ -111,8 +116,8 @@ def __init__(self, test_cases_path: Optional[Path] = None): async def run_benchmark( self, agent, - tags: Optional[List[str]] = None, - max_cases: Optional[int] = None, + tags: list[str] | None = None, + max_cases: int | None = None, ) -> BenchmarkReport: """ Run benchmark against agent @@ -162,7 +167,7 @@ async def run_benchmark( metadata={"version": self.version, "tags": tags}, ) - async def _evaluate_case(self, agent, case: Dict) -> EvalResult: + async def _evaluate_case(self, agent, case: dict) -> EvalResult: """Evaluate a single test case""" test_id = case["id"] query = case["query"] @@ -172,6 +177,7 @@ async def _evaluate_case(self, agent, case: Dict) -> EvalResult: try: # Get tool call from agent (without executing) import time + start = time.perf_counter() tool_call = await self._get_tool_call(agent, query) @@ -211,7 +217,9 @@ async def _evaluate_case(self, agent, case: Dict) -> EvalResult: elif key not in actual_params: param_errors.append(f"missing param: {key}") elif actual_params[key] != expected_value: - param_errors.append(f"{key}: expected {expected_value}, got {actual_params[key]}") + param_errors.append( + f"{key}: expected {expected_value}, got {actual_params[key]}" + ) return EvalResult( test_id=test_id, @@ -238,7 +246,7 @@ async def _evaluate_case(self, agent, case: Dict) -> EvalResult: error=str(e), ) - async def _get_tool_call(self, agent, query: str) -> Optional[Dict]: + async def _get_tool_call(self, agent, query: str) -> dict | None: """ Get the tool call Claude would make for a query @@ -248,7 +256,7 @@ async def _get_tool_call(self, agent, query: str) -> Optional[Dict]: return await agent.get_tool_call(query) -def compare_reports(before: BenchmarkReport, after: BenchmarkReport) -> Dict: +def compare_reports(before: 
BenchmarkReport, after: BenchmarkReport) -> dict: """ Compare two benchmark reports @@ -271,8 +279,8 @@ def compare_reports(before: BenchmarkReport, after: BenchmarkReport) -> Dict: "tokens": { "before": before.total_input_tokens + before.total_output_tokens, "after": after.total_input_tokens + after.total_output_tokens, - "delta": (after.total_input_tokens + after.total_output_tokens) - - (before.total_input_tokens + before.total_output_tokens), + "delta": (after.total_input_tokens + after.total_output_tokens) + - (before.total_input_tokens + before.total_output_tokens), }, "latency_ms": { "before": before.avg_latency_ms, @@ -280,11 +288,13 @@ def compare_reports(before: BenchmarkReport, after: BenchmarkReport) -> Dict: "delta": after.avg_latency_ms - before.avg_latency_ms, }, "regressions": [ - r.test_id for r in after.results + r.test_id + for r in after.results if not r.passed and any(br.test_id == r.test_id and br.passed for br in before.results) ], "improvements": [ - r.test_id for r in after.results + r.test_id + for r in after.results if r.passed and any(br.test_id == r.test_id and not br.passed for br in before.results) ], } diff --git a/benchmarks/perception/__init__.py b/benchmarks/perception/__init__.py index c811d881..f26bbea4 100644 --- a/benchmarks/perception/__init__.py +++ b/benchmarks/perception/__init__.py @@ -5,9 +5,9 @@ """ from .ground_truth import GroundTruth -from .testset import OfflineTestset, TestCase -from .runner import PerceptionBenchmark, BenchmarkConfig, EmbryoResult from .metrics import PerceptionMetrics +from .runner import BenchmarkConfig, EmbryoResult, PerceptionBenchmark +from .testset import OfflineTestset, TestCase __all__ = [ "GroundTruth", diff --git a/benchmarks/perception/ground_truth.py b/benchmarks/perception/ground_truth.py index d3ce2a0a..71187596 100644 --- a/benchmarks/perception/ground_truth.py +++ b/benchmarks/perception/ground_truth.py @@ -7,11 +7,18 @@ import json from dataclasses import dataclass, field from 
pathlib import Path -from typing import Dict, List, Optional - # Stage progression order -STAGE_ORDER = ["early", "bean", "comma", "1.5fold", "2fold", "pretzel", "hatching", "hatched"] +STAGE_ORDER = [ + "early", + "bean", + "comma", + "1.5fold", + "2fold", + "pretzel", + "hatching", + "hatched", +] @dataclass @@ -24,14 +31,14 @@ class GroundTruth: """ # {embryo_id: {stage: start_timepoint}} - transitions: Dict[str, Dict[str, int]] = field(default_factory=dict) + transitions: dict[str, dict[str, int]] = field(default_factory=dict) # Metadata - session_id: Optional[str] = None - annotator: Optional[str] = None - notes: Optional[str] = None + session_id: str | None = None + annotator: str | None = None + notes: str | None = None - def get_stage_at(self, embryo_id: str, timepoint: int) -> Optional[str]: + def get_stage_at(self, embryo_id: str, timepoint: int) -> str | None: """ Get the ground truth stage for a given embryo at a given timepoint. @@ -63,15 +70,13 @@ def get_stage_at(self, embryo_id: str, timepoint: int) -> Optional[str]: return current_stage - def get_transition_timepoint( - self, embryo_id: str, stage: str - ) -> Optional[int]: + def get_transition_timepoint(self, embryo_id: str, stage: str) -> int | None: """Get the timepoint when a stage starts for a given embryo.""" if embryo_id not in self.transitions: return None return self.transitions[embryo_id].get(stage) - def get_stages_for_embryo(self, embryo_id: str) -> List[str]: + def get_stages_for_embryo(self, embryo_id: str) -> list[str]: """Get list of stages (in order) for a given embryo.""" if embryo_id not in self.transitions: return [] @@ -79,10 +84,7 @@ def get_stages_for_embryo(self, embryo_id: str) -> List[str]: embryo_transitions = self.transitions[embryo_id] # Sort by start timepoint - sorted_stages = sorted( - embryo_transitions.keys(), - key=lambda s: embryo_transitions[s] - ) + sorted_stages = sorted(embryo_transitions.keys(), key=lambda s: embryo_transitions[s]) return sorted_stages def 
get_timepoint_range(self, embryo_id: str) -> tuple: @@ -101,11 +103,11 @@ def get_timepoint_range(self, embryo_id: str) -> tuple: return (min(starts), max(starts)) @property - def embryo_ids(self) -> List[str]: + def embryo_ids(self) -> list[str]: """Get list of all embryo IDs with ground truth.""" return list(self.transitions.keys()) - def to_dict(self) -> Dict: + def to_dict(self) -> dict: """Serialize to dictionary for JSON storage.""" return { "session_id": self.session_id, @@ -115,7 +117,7 @@ def to_dict(self) -> Dict: } @classmethod - def from_dict(cls, data: Dict) -> "GroundTruth": + def from_dict(cls, data: dict) -> "GroundTruth": """Load from dictionary.""" return cls( transitions=data.get("transitions", {}), @@ -127,7 +129,7 @@ def from_dict(cls, data: Dict) -> "GroundTruth": @classmethod def from_json(cls, path: Path) -> "GroundTruth": """Load ground truth from JSON file.""" - with open(path, "r") as f: + with open(path) as f: data = json.load(f) return cls.from_dict(data) @@ -139,9 +141,9 @@ def save_json(self, path: Path) -> None: def create_ground_truth_from_email_format( - annotations: Dict[str, str], - session_id: Optional[str] = None, - annotator: Optional[str] = None, + annotations: dict[str, str], + session_id: str | None = None, + annotator: str | None = None, ) -> GroundTruth: """ Create GroundTruth from email-style annotations. 
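The `transitions` mapping in `GroundTruth` stores, per embryo, the timepoint at which each stage begins. The lookup that `get_stage_at` describes can be sketched as below, assuming the stage in force at a timepoint is the one with the latest start at or before it. The function name `stage_at` and the sample data are hypothetical, and the real method body is only partly visible in this diff, so treat this as an illustration rather than the project's implementation.

```python
from __future__ import annotations

# Hypothetical sample data in the documented shape:
# {embryo_id: {stage: start_timepoint}}
TRANSITIONS: dict[str, dict[str, int]] = {
    "embryo_1": {"early": 0, "bean": 12, "comma": 20},
}


def stage_at(
    transitions: dict[str, dict[str, int]], embryo_id: str, timepoint: int
) -> str | None:
    """Return the stage whose start is the latest one at or before `timepoint`."""
    embryo = transitions.get(embryo_id)
    if not embryo:
        # Unknown embryo, or no annotated transitions.
        return None
    current = None
    # Walk the stages in order of their start timepoint.
    for stage, start in sorted(embryo.items(), key=lambda kv: kv[1]):
        if start > timepoint:
            break
        current = stage
    return current
```

Under these assumptions, a timepoint that falls between two annotated starts resolves to the earlier stage, and a timepoint before the first annotated start resolves to `None`.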
diff --git a/benchmarks/perception/live_viewer.py b/benchmarks/perception/live_viewer.py index a6ea2456..7b0f7af7 100644 --- a/benchmarks/perception/live_viewer.py +++ b/benchmarks/perception/live_viewer.py @@ -13,22 +13,19 @@ import argparse import asyncio -import base64 import json import logging import sys import webbrowser -from dataclasses import asdict from datetime import datetime, timedelta from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any # FastAPI and websockets try: + import uvicorn from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.responses import HTMLResponse - from fastapi.staticfiles import StaticFiles - import uvicorn except ImportError: print("Please install: pip install fastapi uvicorn websockets") sys.exit(1) @@ -39,8 +36,8 @@ logger = logging.getLogger(__name__) # Global state for websocket connections -connected_clients: List[WebSocket] = [] -benchmark_state: Dict[str, Any] = { +connected_clients: list[WebSocket] = [] +benchmark_state: dict[str, Any] = { "status": "idle", "current_embryo": None, "current_timepoint": None, @@ -52,7 +49,7 @@ pause_event: asyncio.Event = None # Will be initialized on startup -async def broadcast(message: Dict): +async def broadcast(message: dict): """Broadcast message to all connected clients.""" if not connected_clients: return @@ -396,7 +393,10 @@ async def broadcast(message: Dict):

Perception Benchmark Live Viewer

- +
Connecting...
@@ -555,8 +555,10 @@ async def broadcast(message: Dict): // Display three-view combined image (XY+YZ+XZ orthogonal projections) container.innerHTML = `
-
THREE-VIEW (XY | YZ / XZ)
- Three-View T${timepoint} +
+ THREE-VIEW (XY | YZ / XZ)
+ Three-View T${timepoint}
`; // Also get embryoId and groundTruth from stored data if not provided embryoId = embryoId || imgData.embryoId; @@ -628,10 +630,13 @@ async def broadcast(message: Dict): // Update header document.querySelector('.trace-section h2').innerHTML = - `Reasoning Trace - T${timepoint} `; + `Reasoning Trace - T${timepoint} `; if (traceSteps.length === 0) { - list.innerHTML = '
No trace recorded for T' + timepoint + '
'; + list.innerHTML = '
No trace recorded for T' + + timepoint + '
'; return; } @@ -685,12 +690,16 @@ async def broadcast(message: Dict): 'verified' : ''; html += ` -
+
T${pred.timepoint}
-
${pred.predicted}
-
${pred.ground_truth}
+
${pred.predicted}
+
${pred.ground_truth}
${(pred.confidence * 100).toFixed(0)}%
-
${pred.phase_count > 1 ? pred.phase_count + '-phase' : ''}${verifiedBadge}
+
${pred.phase_count > 1 + ? pred.phase_count + '-phase' : ''}${verifiedBadge}
`; } @@ -749,10 +758,14 @@ async def websocket_endpoint(websocket: WebSocket): connected_clients.append(websocket) # Send current state - await websocket.send_text(json.dumps({ - "type": "status", - "status": benchmark_state["status"], - })) + await websocket.send_text( + json.dumps( + { + "type": "status", + "status": benchmark_state["status"], + } + ) + ) try: while True: @@ -789,8 +802,8 @@ def __init__( embryo_id: str, enable_verification: bool = True, start_timepoint: int = 0, - max_timepoints: Optional[int] = None, - trace_dir: Optional[Path] = None, + max_timepoints: int | None = None, + trace_dir: Path | None = None, ): self.testset = testset self.embryo_id = embryo_id @@ -798,7 +811,7 @@ def __init__( self.start_timepoint = start_timepoint self.max_timepoints = max_timepoints - self.predictions: List[Dict] = [] + self.predictions: list[dict] = [] self.correct_count = 0 self.adjacent_count = 0 self.verified_count = 0 @@ -808,7 +821,7 @@ def __init__( self.trace_dir = trace_dir or Path("benchmarks/results/traces") self.run_dir = self.trace_dir / f"{self.run_id}_{embryo_id}" self.run_dir.mkdir(parents=True, exist_ok=True) - self.traces: Dict[int, List[Dict]] = {} # timepoint -> trace steps + self.traces: dict[int, list[dict]] = {} # timepoint -> trace steps logger.info(f"Trace persistence enabled: {self.run_dir}") @@ -861,23 +874,23 @@ async def run(self): benchmark_state["current_timepoint"] = test_case.timepoint # Send image(s) - await broadcast({ - "type": "image", - "embryo_id": self.embryo_id, - "timepoint": test_case.timepoint, - "ground_truth": test_case.ground_truth_stage, - "image": test_case.image_b64, # Combined for backward compat - "top_image": test_case.top_image_b64, - "side_image": test_case.side_image_b64, - }) + await broadcast( + { + "type": "image", + "embryo_id": self.embryo_id, + "timepoint": test_case.timepoint, + "ground_truth": test_case.ground_truth_stage, + "image": test_case.image_b64, # Combined for backward compat + "top_image": 
test_case.top_image_b64, + "side_image": test_case.side_image_b64, + } + ) # Clear trace for new prediction await broadcast({"type": "clear_trace"}) # Run perception with trace streaming - result = await self._run_perception_with_streaming( - engine, session, test_case - ) + result = await self._run_perception_with_streaming(engine, session, test_case) # Check accuracy is_correct = result.stage == test_case.ground_truth_stage @@ -915,23 +928,29 @@ async def run(self): await broadcast(pred_msg) # Save trace for this timepoint - self._save_timepoint_trace(test_case.timepoint, result, test_case, is_correct, is_adjacent) + self._save_timepoint_trace( + test_case.timepoint, result, test_case, is_correct, is_adjacent + ) # Send updated stats total = len(self.predictions) - await broadcast({ - "type": "stats", - "accuracy": self.correct_count / total if total > 0 else None, - "adjacent": self.adjacent_count / total if total > 0 else None, - "total": total, - "verified": self.verified_count, - }) + await broadcast( + { + "type": "stats", + "accuracy": self.correct_count / total if total > 0 else None, + "adjacent": self.adjacent_count / total if total > 0 else None, + "total": total, + "verified": self.verified_count, + } + ) # Add observation to session with simulated timestamp # Typical diSPIM acquisition interval is ~4 minutes per timepoint - simulated_timestamp = datetime.now() - timedelta( - minutes=(self.max_timepoints or 100) * 4 - ) + timedelta(minutes=test_case.timepoint * 4) + simulated_timestamp = ( + datetime.now() + - timedelta(minutes=(self.max_timepoints or 100) * 4) + + timedelta(minutes=test_case.timepoint * 4) + ) session.add_observation( timepoint=test_case.timepoint, @@ -953,7 +972,6 @@ async def run(self): async def _run_perception_with_streaming(self, engine, session, test_case): """Run perception and stream trace steps.""" - from gently.harness.perception.session import ReasoningStep # We need to hook into the reasoning trace # For now, run 
perception and stream the trace after @@ -985,7 +1003,9 @@ async def _run_perception_with_streaming(self, engine, session, test_case): return result - def _save_timepoint_trace(self, timepoint: int, result, test_case, is_correct: bool, is_adjacent: bool): + def _save_timepoint_trace( + self, timepoint: int, result, test_case, is_correct: bool, is_adjacent: bool + ): """Save trace for a single timepoint to disk.""" trace_data = { "timepoint": timepoint, @@ -1035,7 +1055,7 @@ async def run_benchmark_background( embryo_id: str, enable_verification: bool, start_timepoint: int = 0, - max_timepoints: Optional[int] = None, + max_timepoints: int | None = None, ): """Run benchmark in background after server starts.""" print("[DEBUG] run_benchmark_background starting", flush=True) @@ -1049,7 +1069,10 @@ async def run_benchmark_background( # Load data ground_truth = GroundTruth.from_json(ground_truth_path) - print(f"[DEBUG] Loaded ground truth: {len(ground_truth.transitions)} embryos", flush=True) + print( + f"[DEBUG] Loaded ground truth: {len(ground_truth.transitions)} embryos", + flush=True, + ) testset = OfflineTestset( session_path=session_path, ground_truth=ground_truth, @@ -1071,6 +1094,7 @@ async def run_benchmark_background( except Exception as e: print(f"[DEBUG] ERROR in run_benchmark_background: {e}", flush=True) import traceback + traceback.print_exc() @@ -1142,9 +1166,9 @@ def main(): trace_dir = Path("benchmarks/results/traces") run_id = datetime.now().strftime("%Y%m%d_%H%M%S") - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print("Perception Benchmark Live Viewer") - print(f"{'='*60}") + print(f"{'=' * 60}") print(f"Session: {session_path}") print(f"Embryo: {args.embryo}") print(f"Start timepoint: T{args.start_timepoint}") @@ -1152,7 +1176,7 @@ def main(): print(f"Verification: {'disabled' if args.no_verification else 'enabled'}") print(f"Traces: {trace_dir / f'{run_id}_{args.embryo}'}") print(f"URL: http://localhost:{args.port}") - print(f"{'='*60}\n") + 
print(f"{'=' * 60}\n") # Open browser if not args.no_browser: @@ -1162,14 +1186,16 @@ def main(): @app.on_event("startup") async def startup_event(): print("[DEBUG] Startup event fired", flush=True) - asyncio.create_task(run_benchmark_background( - session_path=session_path, - ground_truth_path=gt_path, - embryo_id=args.embryo, - enable_verification=not args.no_verification, - start_timepoint=args.start_timepoint, - max_timepoints=args.max_timepoints, - )) + asyncio.create_task( + run_benchmark_background( + session_path=session_path, + ground_truth_path=gt_path, + embryo_id=args.embryo, + enable_verification=not args.no_verification, + start_timepoint=args.start_timepoint, + max_timepoints=args.max_timepoints, + ) + ) print("[DEBUG] Background task created", flush=True) # Run server diff --git a/benchmarks/perception/metrics.py b/benchmarks/perception/metrics.py index f69c054f..f0578ea1 100644 --- a/benchmarks/perception/metrics.py +++ b/benchmarks/perception/metrics.py @@ -6,14 +6,23 @@ from collections import defaultdict from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from .runner import BenchmarkReport # Stage order for metrics -STAGE_ORDER = ["early", "bean", "comma", "1.5fold", "2fold", "pretzel", "hatching", "hatched"] +STAGE_ORDER = [ + "early", + "bean", + "comma", + "1.5fold", + "2fold", + "pretzel", + "hatching", + "hatched", +] @dataclass @@ -25,30 +34,30 @@ class PerceptionMetrics: adjacent_accuracy: float = 0.0 # Within 1 stage # Per-stage accuracy - stage_accuracy: Dict[str, float] = field(default_factory=dict) - stage_counts: Dict[str, int] = field(default_factory=dict) + stage_accuracy: dict[str, float] = field(default_factory=dict) + stage_counts: dict[str, int] = field(default_factory=dict) # Confusion matrix: confusion[gt_stage][pred_stage] = count - confusion_matrix: Dict[str, Dict[str, int]] = field(default_factory=dict) + 
confusion_matrix: dict[str, dict[str, int]] = field(default_factory=dict) # Confidence calibration mean_confidence: float = 0.0 confidence_when_correct: float = 0.0 confidence_when_wrong: float = 0.0 - calibration_bins: List[Tuple[float, float, int]] = field(default_factory=list) + calibration_bins: list[tuple[float, float, int]] = field(default_factory=list) # (confidence_bin_center, accuracy_in_bin, count) expected_calibration_error: float = 0.0 # ECE # Temporal metrics backward_transitions: int = 0 # Errors where stage went backward - stage_transition_delay: Dict[str, float] = field(default_factory=dict) + stage_transition_delay: dict[str, float] = field(default_factory=dict) # How many timepoints after GT transition until prediction caught up # Tool usage total_tool_calls: int = 0 tool_call_rate: float = 0.0 # Avg tool calls per prediction - tool_use_by_stage: Dict[str, float] = field(default_factory=dict) + tool_use_by_stage: dict[str, float] = field(default_factory=dict) # When tools were used vs not accuracy_with_tools: float = 0.0 @@ -59,7 +68,7 @@ class PerceptionMetrics: transitional_rate: float = 0.0 transitional_accuracy: float = 0.0 # Accuracy when marked transitional - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "accuracy": self.accuracy, "adjacent_accuracy": self.adjacent_accuracy, @@ -102,8 +111,7 @@ def compute_metrics(report: "BenchmarkReport") -> PerceptionMetrics: # Collect all predictions all_preds = [ - p for r in report.embryo_results for p in r.predictions - if p.ground_truth_stage is not None + p for r in report.embryo_results for p in r.predictions if p.ground_truth_stage is not None ] if not all_preds: @@ -134,9 +142,7 @@ def compute_metrics(report: "BenchmarkReport") -> PerceptionMetrics: for p in all_preds: confusion[p.ground_truth_stage][p.predicted_stage] += 1 - metrics.confusion_matrix = { - gt: dict(preds) for gt, preds in confusion.items() - } + metrics.confusion_matrix = {gt: dict(preds) 
for gt, preds in confusion.items()} # Confidence statistics confidences = [p.confidence for p in all_preds] @@ -156,10 +162,7 @@ def compute_metrics(report: "BenchmarkReport") -> PerceptionMetrics: bin_high = (i + 1) / num_bins bin_center = (bin_low + bin_high) / 2 - bin_preds = [ - p for p in all_preds - if bin_low <= p.confidence < bin_high - ] + bin_preds = [p for p in all_preds if bin_low <= p.confidence < bin_high] if bin_preds: bin_accuracy = sum(1 for p in bin_preds if p.is_correct) / len(bin_preds) @@ -207,7 +210,9 @@ def compute_metrics(report: "BenchmarkReport") -> PerceptionMetrics: if with_tools: metrics.accuracy_with_tools = sum(1 for p in with_tools if p.is_correct) / len(with_tools) if without_tools: - metrics.accuracy_without_tools = sum(1 for p in without_tools if p.is_correct) / len(without_tools) + metrics.accuracy_without_tools = sum(1 for p in without_tools if p.is_correct) / len( + without_tools + ) # Transitional observations transitional_preds = [p for p in all_preds if p.is_transitional] @@ -215,16 +220,16 @@ def compute_metrics(report: "BenchmarkReport") -> PerceptionMetrics: metrics.transitional_rate = len(transitional_preds) / len(all_preds) if transitional_preds: - metrics.transitional_accuracy = sum( - 1 for p in transitional_preds if p.is_correct - ) / len(transitional_preds) + metrics.transitional_accuracy = sum(1 for p in transitional_preds if p.is_correct) / len( + transitional_preds + ) return metrics def format_confusion_matrix( - confusion: Dict[str, Dict[str, int]], - stages: Optional[List[str]] = None, + confusion: dict[str, dict[str, int]], + stages: list[str] | None = None, ) -> str: """Format confusion matrix as ASCII table.""" if stages is None: @@ -278,34 +283,38 @@ def format_metrics_summary(metrics: PerceptionMetrics) -> str: count = metrics.stage_counts[stage] lines.append(f" {stage:>10}: {acc:.1%} (n={count})") - lines.extend([ - "", - "CONFIDENCE CALIBRATION", - f" Mean confidence: {metrics.mean_confidence:.2f}", - 
f" Confidence (correct): {metrics.confidence_when_correct:.2f}", - f" Confidence (wrong): {metrics.confidence_when_wrong:.2f}", - f" Expected Cal. Error: {metrics.expected_calibration_error:.3f}", - "", - "TOOL USAGE", - f" Total tool calls: {metrics.total_tool_calls}", - f" Avg calls per pred: {metrics.tool_call_rate:.2f}", - f" Accuracy with tools: {metrics.accuracy_with_tools:.1%}", - f" Accuracy without tools: {metrics.accuracy_without_tools:.1%}", - "", - "TEMPORAL", - f" Backward transitions: {metrics.backward_transitions}", - "", - "TRANSITIONAL OBSERVATIONS", - f" Count: {metrics.transitional_count}", - f" Rate: {metrics.transitional_rate:.1%}", - f" Accuracy: {metrics.transitional_accuracy:.1%}", - ]) + lines.extend( + [ + "", + "CONFIDENCE CALIBRATION", + f" Mean confidence: {metrics.mean_confidence:.2f}", + f" Confidence (correct): {metrics.confidence_when_correct:.2f}", + f" Confidence (wrong): {metrics.confidence_when_wrong:.2f}", + f" Expected Cal. Error: {metrics.expected_calibration_error:.3f}", + "", + "TOOL USAGE", + f" Total tool calls: {metrics.total_tool_calls}", + f" Avg calls per pred: {metrics.tool_call_rate:.2f}", + f" Accuracy with tools: {metrics.accuracy_with_tools:.1%}", + f" Accuracy without tools: {metrics.accuracy_without_tools:.1%}", + "", + "TEMPORAL", + f" Backward transitions: {metrics.backward_transitions}", + "", + "TRANSITIONAL OBSERVATIONS", + f" Count: {metrics.transitional_count}", + f" Rate: {metrics.transitional_rate:.1%}", + f" Accuracy: {metrics.transitional_accuracy:.1%}", + ] + ) if metrics.confusion_matrix: - lines.extend([ - "", - "CONFUSION MATRIX", - format_confusion_matrix(metrics.confusion_matrix), - ]) + lines.extend( + [ + "", + "CONFUSION MATRIX", + format_confusion_matrix(metrics.confusion_matrix), + ] + ) return "\n".join(lines) diff --git a/benchmarks/perception/runner.py b/benchmarks/perception/runner.py index 33d93fe3..70ac32fd 100644 --- a/benchmarks/perception/runner.py +++ 
b/benchmarks/perception/runner.py @@ -12,11 +12,11 @@ from dataclasses import dataclass, field from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any from .ground_truth import GroundTruth -from .testset import OfflineTestset, TestCase from .metrics import PerceptionMetrics, compute_metrics +from .testset import OfflineTestset logger = logging.getLogger(__name__) @@ -39,20 +39,20 @@ class BenchmarkConfig: # Test settings start_timepoint: int = 0 - max_timepoints_per_embryo: Optional[int] = None - embryo_ids: Optional[List[str]] = None # None = all + max_timepoints_per_embryo: int | None = None + embryo_ids: list[str] | None = None # None = all # Ablation toggles include_temporal_context: bool = True include_previous_observations: bool = True # Custom system prompt override - system_prompt_override: Optional[str] = None + system_prompt_override: str | None = None # Metadata description: str = "" - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "model": self.model, "temperature": self.temperature, @@ -78,20 +78,20 @@ class PredictionResult: timepoint: int predicted_stage: str - ground_truth_stage: Optional[str] + ground_truth_stage: str | None confidence: float is_transitional: bool - transition_between: Optional[List[str]] + transition_between: list[str] | None reasoning: str - reasoning_trace: Optional[Dict[str, Any]] # Serialized ReasoningTrace + reasoning_trace: dict[str, Any] | None # Serialized ReasoningTrace tool_calls: int - tools_used: List[str] + tools_used: list[str] # Multi-phase verification fields verification_triggered: bool = False phase_count: int = 1 - verification_result: Optional[Dict[str, Any]] = None - candidate_stages: Optional[List[Dict[str, Any]]] = None + verification_result: dict[str, Any] | None = None + candidate_stages: list[dict[str, Any]] | None = None @property def is_correct(self) -> bool: @@ -113,7 +113,7 @@ def 
is_adjacent_correct(self) -> bool: except ValueError: return False - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "timepoint": self.timepoint, "predicted_stage": self.predicted_stage, @@ -139,9 +139,9 @@ class EmbryoResult: """Results for a single embryo run.""" embryo_id: str - predictions: List[PredictionResult] = field(default_factory=list) + predictions: list[PredictionResult] = field(default_factory=list) duration_seconds: float = 0.0 - error: Optional[str] = None + error: str | None = None @property def accuracy(self) -> float: @@ -159,7 +159,7 @@ def adjacent_accuracy(self) -> float: correct = sum(1 for p in self.predictions if p.is_adjacent_correct) return correct / len(self.predictions) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "embryo_id": self.embryo_id, "predictions": [p.to_dict() for p in self.predictions], @@ -175,11 +175,11 @@ class BenchmarkReport: """Complete benchmark report.""" config: BenchmarkConfig - embryo_results: List[EmbryoResult] = field(default_factory=list) - metrics: Optional[PerceptionMetrics] = None + embryo_results: list[EmbryoResult] = field(default_factory=list) + metrics: PerceptionMetrics | None = None started_at: datetime = field(default_factory=datetime.now) - completed_at: Optional[datetime] = None - session_id: Optional[str] = None + completed_at: datetime | None = None + session_id: str | None = None @property def total_predictions(self) -> int: @@ -192,7 +192,7 @@ def overall_accuracy(self) -> float: return 0.0 return sum(1 for p in all_preds if p.is_correct) / len(all_preds) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "config": self.config.to_dict(), "embryo_results": [r.to_dict() for r in self.embryo_results], @@ -222,7 +222,7 @@ def __init__( self, testset: OfflineTestset, config: BenchmarkConfig, - engine: Optional[Any] = None, # PerceptionEngine + engine: Any | None = None, # PerceptionEngine 
): """ Parameters @@ -246,7 +246,6 @@ async def _get_engine(self): # Lazy import to avoid circular dependencies import anthropic from gently.harness.perception.engine import PerceptionEngine - from gently.harness.perception.example_store import ExampleStore client = anthropic.Anthropic() @@ -465,7 +464,8 @@ async def main(): help="Description for this benchmark run", ) parser.add_argument( - "-v", "--verbose", + "-v", + "--verbose", action="store_true", help="Verbose logging", ) @@ -481,6 +481,7 @@ async def main(): # The perception engine reads stage definitions etc. from the active # organism module, which is normally loaded by launch_gently.py. from gently.organisms import load_organism + load_organism("celegans") # Find session path diff --git a/benchmarks/perception/testset.py b/benchmarks/perception/testset.py index 658cc011..075eccdb 100644 --- a/benchmarks/perception/testset.py +++ b/benchmarks/perception/testset.py @@ -6,10 +6,10 @@ import base64 import io +from collections.abc import Iterator from dataclasses import dataclass from datetime import datetime from pathlib import Path -from typing import Iterator, List, Optional, Tuple, Dict import numpy as np @@ -26,10 +26,12 @@ def _ensure_dependencies(): if tifffile is None: import tifffile as _tifffile + tifffile = _tifffile if PIL_Image is None: from PIL import Image as _Image + PIL_Image = _Image @@ -40,23 +42,25 @@ class TestCase: embryo_id: str timepoint: int image_b64: str # Combined view (for backward compatibility) - top_image_b64: Optional[str] # TOP view only - side_image_b64: Optional[str] # SIDE view only - volume: Optional[np.ndarray] - ground_truth_stage: Optional[str] - acquired_at: Optional[datetime] = None + top_image_b64: str | None # TOP view only + side_image_b64: str | None # SIDE view only + volume: np.ndarray | None + ground_truth_stage: str | None + acquired_at: datetime | None = None def _discover_volumes( - session_dir: Path, embryo_id: Optional[str] = None -) -> Dict[str, 
List[Tuple[datetime, Path]]]: + session_dir: Path, embryo_id: str | None = None +) -> dict[str, list[tuple[datetime, Path]]]: """Discover volume files (with parsed acquisition timestamps) in a session directory.""" if not session_dir.exists(): return {} tif_files = ( - list(session_dir.glob("*.tif")) + list(session_dir.glob("*.tiff")) - + list(session_dir.glob("**/*.tif")) + list(session_dir.glob("**/*.tiff")) + list(session_dir.glob("*.tif")) + + list(session_dir.glob("*.tiff")) + + list(session_dir.glob("**/*.tif")) + + list(session_dir.glob("**/*.tiff")) ) # Deduplicate (flat + recursive may overlap) tif_files = list({f.resolve(): f for f in tif_files}.values()) @@ -89,6 +93,7 @@ def _discover_volumes( def _load_volume(path: Path) -> np.ndarray: """Load a volume from TIFF file.""" from gently.core.imaging import load_volume + return load_volume(path) @@ -122,9 +127,9 @@ def _create_three_view_image(volume: np.ndarray, max_dim: int = 1500) -> str: _ensure_dependencies() from gently.core.imaging import ( - projection_three_view, - compute_crop_bounds, apply_crop_bounds, + compute_crop_bounds, + projection_three_view, ) # Auto-crop to embryo region @@ -150,7 +155,7 @@ def _create_three_view_image(volume: np.ndarray, max_dim: int = 1500) -> str: return base64.b64encode(buffer.getvalue()).decode("utf-8") -def _create_separate_view_images(volume: np.ndarray, max_dim: int = 1000) -> Tuple[str, str]: +def _create_separate_view_images(volume: np.ndarray, max_dim: int = 1000) -> tuple[str, str]: """Create separate TOP and SIDE view images from volume, return base64 tuple. 
Parameters @@ -244,7 +249,7 @@ def __init__( self._embryo_volumes = _discover_volumes(self.session_path) @property - def embryo_ids(self) -> List[str]: + def embryo_ids(self) -> list[str]: """Get list of embryo IDs with both volumes and ground truth.""" gt_embryos = set(self.ground_truth.embryo_ids) vol_embryos = set(self._embryo_volumes.keys()) @@ -258,7 +263,7 @@ def iter_embryo( self, embryo_id: str, start_timepoint: int = 0, - end_timepoint: Optional[int] = None, + end_timepoint: int | None = None, ) -> Iterator[TestCase]: """ Iterate through timepoints for an embryo sequentially. @@ -318,7 +323,7 @@ def iter_embryo( acquired_at=acquired_at, ) - def iter_all(self) -> Iterator[Tuple[str, Iterator[TestCase]]]: + def iter_all(self) -> Iterator[tuple[str, Iterator[TestCase]]]: """ Iterate through all embryos in the testset. diff --git a/benchmarks/perception/trace_viewer.py b/benchmarks/perception/trace_viewer.py index 5a3ced2d..69ea5fa6 100644 --- a/benchmarks/perception/trace_viewer.py +++ b/benchmarks/perception/trace_viewer.py @@ -7,12 +7,9 @@ import argparse import json import sys -from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Optional - -from .metrics import format_metrics_summary, PerceptionMetrics +from .metrics import PerceptionMetrics, format_metrics_summary HTML_TEMPLATE = """ @@ -293,16 +290,16 @@ """ -def generate_embryo_section(embryo_id: str, predictions: List[Dict]) -> str: +def generate_embryo_section(embryo_id: str, predictions: list[dict]) -> str: """Generate HTML for one embryo's predictions.""" rows = [ '
', - '
Timepoint
', - '
Predicted
', - '
Ground Truth
', - '
Confidence
', - '
Details
', - '
', + "
Timepoint
", + "
Predicted
", + "
Ground Truth
", + "
Confidence
", + "
Details
", + "
", ] for i, pred in enumerate(predictions): @@ -320,15 +317,15 @@ def generate_embryo_section(embryo_id: str, predictions: List[Dict]) -> str: row_id = f"{embryo_id}-{timepoint}" rows.append(f'
') - rows.append(f'
T{timepoint}
') + rows.append(f"
T{timepoint}
") rows.append(f'
{pred_stage}
') rows.append(f'
{gt_stage}
') - rows.append(f'''
+ rows.append(f"""
{confidence:.0%} -
''') +
""") # Details column with expand button tool_calls = pred.get("tool_calls", 0) @@ -351,31 +348,31 @@ def generate_embryo_section(embryo_id: str, predictions: List[Dict]) -> str: if phase_count > 1: badges += f'{phase_count}-phase' - rows.append(f'''
+ rows.append(f"""
{details_str} {badges} [show trace] -
''') - rows.append('
') +
""") + rows.append("") # Reasoning trace (hidden by default) trace_html = format_reasoning_trace(pred.get("reasoning_trace")) - rows.append(f'''
+ rows.append(f"""
Reasoning: {reasoning} {trace_html} -
''') +
""") - return f''' + return f"""

{embryo_id}

{"".join(rows)}
- ''' + """ -def format_reasoning_trace(trace: Optional[Dict]) -> str: +def format_reasoning_trace(trace: dict | None) -> str: """Format reasoning trace as HTML.""" if not trace: return "" @@ -393,56 +390,66 @@ def format_reasoning_trace(trace: Optional[Dict]) -> str: if step_type == "tool_call": tool_name = step.get("tool_name", "") tool_input = step.get("tool_input", {}) - html_parts.append(f''' + html_parts.append(f"""
Tool Call: {tool_name}
Input: {json.dumps(tool_input, indent=2)}
- ''') + """) elif step_type == "tool_result": summary = step.get("tool_result_summary", content) - html_parts.append(f''' + html_parts.append(f"""
Tool Result: {summary}
- ''') + """) elif step_type == "final_decision": - html_parts.append(f''' + html_parts.append(f"""
Final Decision:
{content[:500]}...
- ''') + """) elif step_type == "verification_requested": - html_parts.append(f''' + html_parts.append(f"""
Verification Requested:
{content}
- ''') + """) elif step_type == "verification_subagent": tool_input = step.get("tool_input", {}) summary = step.get("tool_result_summary", content) - html_parts.append(f''' + html_parts.append(f"""
- Subagent: {tool_input.get("stage_a", "?")} vs {tool_input.get("stage_b", "?")}
+ Subagent: {tool_input.get("stage_a", "?")} vs + {tool_input.get("stage_b", "?")}
Result: {summary}
- ''') + """) elif step_type == "verification_result": - html_parts.append(f''' + html_parts.append(f"""
Verification Result:
{content}
- ''') + """) return "".join(html_parts) -def generate_confusion_matrix_html(confusion: Dict[str, Dict[str, int]]) -> str: +def generate_confusion_matrix_html(confusion: dict[str, dict[str, int]]) -> str: """Generate HTML table for confusion matrix.""" - stages = ["early", "bean", "comma", "1.5fold", "2fold", "pretzel", "hatching", "hatched"] + stages = [ + "early", + "bean", + "comma", + "1.5fold", + "2fold", + "pretzel", + "hatching", + "hatched", + ] # Filter to stages present in data present = set() @@ -478,7 +485,7 @@ def generate_confusion_matrix_html(confusion: Dict[str, Dict[str, int]]) -> str: return "".join(rows) -def generate_html_report(report_data: Dict) -> str: +def generate_html_report(report_data: dict) -> str: """Generate complete HTML report from benchmark data.""" # Extract summary metrics metrics = report_data.get("metrics", {}) @@ -493,9 +500,7 @@ def generate_html_report(report_data: Dict) -> str: # Generate embryo sections embryo_sections = [] for er in embryo_results: - embryo_sections.append( - generate_embryo_section(er["embryo_id"], er["predictions"]) - ) + embryo_sections.append(generate_embryo_section(er["embryo_id"], er["predictions"])) # Generate confusion matrix confusion = metrics.get("confusion_matrix", {}) @@ -569,8 +574,7 @@ def main(): # Filter embryo if specified if args.embryo: report_data["embryo_results"] = [ - er for er in report_data.get("embryo_results", []) - if er["embryo_id"] == args.embryo + er for er in report_data.get("embryo_results", []) if er["embryo_id"] == args.embryo ] # Generate HTML diff --git a/benchmarks/runner.py b/benchmarks/runner.py index 4d6d847a..7dd31236 100644 --- a/benchmarks/runner.py +++ b/benchmarks/runner.py @@ -13,8 +13,6 @@ import json import logging import sys -from datetime import datetime -from pathlib import Path logging.basicConfig(level=logging.INFO, format="%(message)s") logger = logging.getLogger(__name__) @@ -66,7 +64,6 @@ async def run_agent_benchmark(args): def 
compare_reports(args): """Compare two benchmark reports""" - from .agent.evaluator import BenchmarkReport, compare_reports as _compare with open(args.before) as f: before_data = json.load(f) @@ -101,7 +98,10 @@ def compare_reports(args): delta_str = f"+{fmt.format(delta)}" if delta > 0 else fmt.format(delta) status = "improved" if delta > 0 else ("regressed" if delta < 0 else "unchanged") - logger.info(f" {name}: {fmt.format(before_val)} -> {fmt.format(after_val)} ({delta_str}) [{status}]") + logger.info( + f" {name}: {fmt.format(before_val)} -> {fmt.format(after_val)}" + f" ({delta_str}) [{status}]" + ) # Token comparison before_tokens = before_data.get("tokens", {}) @@ -111,7 +111,7 @@ def compare_reports(args): after_total = after_tokens.get("total_input", 0) + after_tokens.get("total_output", 0) token_delta = after_total - before_total - logger.info(f"\nTokens:") + logger.info("\nTokens:") logger.info(f" Total: {before_total:,} -> {after_total:,} ({token_delta:+,})") return 0 diff --git a/config/config.yml b/config/config.yml index e59e9f6d..9c05c274 100644 --- a/config/config.yml +++ b/config/config.yml @@ -1,4 +1,38 @@ organism: "celegans" hardware: "dispim" mmconfig: "MMConfig_tracking_screening.cfg" -mmdirectory: "C:/Program Files/Micro-Manager-1.4" \ No newline at end of file +mmdirectory: "C:/Program Files/Micro-Manager-1.4" + +# SwitchBot Bot — physical button-pusher mounted on the diSPIM room light +# switch. Talks BLE direct (no SwitchBot Hub / cloud). Plans address it by +# name, e.g. `bps.mv(room_light, 'on')`. Remove this block to skip +# registration; the device layer is tolerant of either state. +switchbot: + name: room_light + address: "EC:6F:04:06:5B:23" + timeout: 20.0 + +# ACUITYnano Precision Thermal Controller (Peltier/TEC, 0.0–99.9 °C). When this +# block is present the device layer registers a `temperature` device and the +# Devices tab shows a setpoint control; plans can also block on it via +# `bps.mv(temperature, 20.0)`. 
Remove/comment the block to skip registration. +# +# backend: mqtt talks to the controller over the vendor's MQTT bridge +# (acuitynano_precision_thermalizer_api). With no broker/port/user/password +# keys it uses the vendor package's embedded HiveMQ Cloud defaults — set them +# here only to point at a different broker. The vendor package must be +# installed on the device-layer machine (not on PyPI). +temperature: + name: temperature + backend: serial # serial | mqtt | mock — USB serial is the working + # link on this machine (MQTT cloud is firewalled) + com_port: "COM8" + baud_rate: 115200 + stabilize_timeout: 600 # seconds to wait for "[ SYSTEM LOCKED ]" on a blocking set + feedback_peltier: false # true = control off the peltier sensor instead of water + # MQTT alternative (vendor HiveMQ Cloud defaults; needs outbound TLS :8883): + # backend: mqtt + # broker: "your-broker.example.com" # optional — overrides the embedded default + # port: 8883 + # user: "username" + # password: "secret" \ No newline at end of file diff --git a/diagnostics/benchmark_gentlystore_fps.py b/diagnostics/benchmark_gentlystore_fps.py index 43187b9e..3e8188a7 100644 --- a/diagnostics/benchmark_gentlystore_fps.py +++ b/diagnostics/benchmark_gentlystore_fps.py @@ -22,23 +22,22 @@ import argparse import shutil import statistics + +# Add gently to path +import sys import tempfile import time from dataclasses import dataclass, field from datetime import datetime from pathlib import Path -from typing import List, Optional import numpy as np -# Add gently to path -import sys GENTLY_ROOT = Path(__file__).resolve().parent.parent if str(GENTLY_ROOT) not in sys.path: sys.path.insert(0, str(GENTLY_ROOT)) -from gently.core.file_store import FileStore - +from gently.core.file_store import FileStore # noqa: E402 # --------------------------------------------------------------------------- # Default parameters -- typical diSPIM volume dimensions @@ -60,9 +59,9 @@ class BenchmarkResult: approach: str 
num_slices: int volume_shape: tuple - timings: List[float] = field(default_factory=list) - sizes_mb: List[float] = field(default_factory=list) - errors: List[str] = field(default_factory=list) + timings: list[float] = field(default_factory=list) + sizes_mb: list[float] = field(default_factory=list) + errors: list[str] = field(default_factory=list) @property def mean(self) -> float: @@ -130,7 +129,7 @@ def generate_synthetic_volume( z = np.linspace(0, 1, num_slices)[:, None, None] y = np.linspace(0, 1, height)[None, :, None] x = np.linspace(0, 1, width)[None, None, :] - vol = (z * 0.3 + y * 0.3 + x * 0.4) + vol = z * 0.3 + y * 0.3 + x * 0.4 if np.issubdtype(dtype, np.integer): info = np.iinfo(dtype) vol = (vol * info.max).astype(dtype) @@ -149,14 +148,18 @@ def generate_synthetic_volume( cx = np.random.randint(width // 4, 3 * width // 4) sz, sy, sx = 3, 15, 15 - z_idx = np.clip(np.arange(cz - sz*2, cz + sz*2), 0, num_slices - 1) - y_idx = np.clip(np.arange(cy - sy*2, cy + sy*2), 0, height - 1) - x_idx = np.clip(np.arange(cx - sx*2, cx + sx*2), 0, width - 1) + z_idx = np.clip(np.arange(cz - sz * 2, cz + sz * 2), 0, num_slices - 1) + y_idx = np.clip(np.arange(cy - sy * 2, cy + sy * 2), 0, height - 1) + x_idx = np.clip(np.arange(cx - sx * 2, cx + sx * 2), 0, width - 1) - zz, yy, xx = np.meshgrid(z_idx, y_idx, x_idx, indexing='ij') + zz, yy, xx = np.meshgrid(z_idx, y_idx, x_idx, indexing="ij") d2 = ((zz - cz) / sz) ** 2 + ((yy - cy) / sy) ** 2 + ((xx - cx) / sx) ** 2 blob = np.exp(-d2 / 2) - vol[z_idx[0]:z_idx[-1]+1, y_idx[0]:y_idx[-1]+1, x_idx[0]:x_idx[-1]+1] += blob + vol[ + z_idx[0] : z_idx[-1] + 1, + y_idx[0] : y_idx[-1] + 1, + x_idx[0] : x_idx[-1] + 1, + ] += blob # Add background noise vol += np.random.random(shape) * 0.1 @@ -178,7 +181,7 @@ def generate_synthetic_volume( def benchmark_raw_tiff_write( volume: np.ndarray, output_dir: Path, - compression: Optional[str] = "zlib", + compression: str | None = "zlib", ) -> tuple[float, float]: """ Benchmark raw tifffile 
write (no FileStore). @@ -255,7 +258,7 @@ def benchmark_register_volume( # Main benchmark sweep # --------------------------------------------------------------------------- def run_benchmark_sweep( - slices_list: List[int], + slices_list: list[int], width: int, height: int, num_repeats: int, @@ -265,9 +268,9 @@ def run_benchmark_sweep( run_put_volume: bool = True, run_register: bool = True, skip_projection: bool = False, -) -> List[BenchmarkResult]: +) -> list[BenchmarkResult]: """Run the full benchmark sweep.""" - results: List[BenchmarkResult] = [] + results: list[BenchmarkResult] = [] # Create temporary directory for benchmark temp_dir = Path(tempfile.mkdtemp(prefix="gently_benchmark_")) @@ -291,10 +294,12 @@ def run_benchmark_sweep( timepoint = 0 for config_idx, num_slices in enumerate(slices_list): - print(f"\n{'='*60}") - print(f"Config {config_idx + 1}/{total_configs}: " - f"slices={num_slices}, shape=({num_slices}, {height}, {width})") - print(f"{'='*60}") + print(f"\n{'=' * 60}") + print( + f"Config {config_idx + 1}/{total_configs}: " + f"slices={num_slices}, shape=({num_slices}, {height}, {width})" + ) + print(f"{'=' * 60}") volume_shape = (num_slices, height, width) @@ -307,22 +312,22 @@ def run_benchmark_sweep( # --- Raw TIFF write (baseline) --- if run_raw: res_raw = BenchmarkResult("raw_tiff_zlib", num_slices, volume_shape) - print(f"\n[Raw TIFF zlib]") + print("\n[Raw TIFF zlib]") # Warmup for w in range(num_warmup): - print(f" Warmup {w+1}/{num_warmup}...", end=" ", flush=True) + print(f" Warmup {w + 1}/{num_warmup}...", end=" ", flush=True) dur, size = benchmark_raw_tiff_write(volume, raw_dir, "zlib") print(f"{dur:.3f}s, {size:.1f} MB") # Timed repeats for r in range(num_repeats): - print(f" Repeat {r+1}/{num_repeats}...", end=" ", flush=True) + print(f" Repeat {r + 1}/{num_repeats}...", end=" ", flush=True) try: dur, size = benchmark_raw_tiff_write(volume, raw_dir, "zlib") res_raw.timings.append(dur) res_raw.sizes_mb.append(size) - 
print(f"{dur:.3f}s, {size:.1f} MB ({1/dur:.1f} vol/s)") + print(f"{dur:.3f}s, {size:.1f} MB ({1 / dur:.1f} vol/s)") except Exception as e: print(f"ERROR: {e}") res_raw.errors.append(str(e)) @@ -333,26 +338,38 @@ def run_benchmark_sweep( # --- put_volume (full pipeline) --- if run_put_volume: res_put = BenchmarkResult("put_volume", num_slices, volume_shape) - print(f"\n[FileStore.put_volume]") + print("\n[FileStore.put_volume]") # Warmup for w in range(num_warmup): embryo_id = f"embryo_{w % NUM_EMBRYOS}" - print(f" Warmup {w+1}/{num_warmup} ({embryo_id})...", end=" ", flush=True) - dur, size = benchmark_put_volume(store, session_id, embryo_id, timepoint, volume) + print( + f" Warmup {w + 1}/{num_warmup} ({embryo_id})...", + end=" ", + flush=True, + ) + dur, size = benchmark_put_volume( + store, session_id, embryo_id, timepoint, volume + ) timepoint += 1 print(f"{dur:.3f}s, {size:.1f} MB") # Timed repeats for r in range(num_repeats): embryo_id = f"embryo_{r % NUM_EMBRYOS}" - print(f" Repeat {r+1}/{num_repeats} ({embryo_id})...", end=" ", flush=True) + print( + f" Repeat {r + 1}/{num_repeats} ({embryo_id})...", + end=" ", + flush=True, + ) try: - dur, size = benchmark_put_volume(store, session_id, embryo_id, timepoint, volume) + dur, size = benchmark_put_volume( + store, session_id, embryo_id, timepoint, volume + ) timepoint += 1 res_put.timings.append(dur) res_put.sizes_mb.append(size) - print(f"{dur:.3f}s, {size:.1f} MB ({1/dur:.1f} vol/s)") + print(f"{dur:.3f}s, {size:.1f} MB ({1 / dur:.1f} vol/s)") except Exception as e: print(f"ERROR: {e}") res_put.errors.append(str(e)) @@ -363,26 +380,38 @@ def run_benchmark_sweep( # --- register_volume (zero-copy path) --- if run_register: res_reg = BenchmarkResult("register_volume", num_slices, volume_shape) - print(f"\n[FileStore.register_volume]") + print("\n[FileStore.register_volume]") # Warmup for w in range(num_warmup): embryo_id = f"embryo_{w % NUM_EMBRYOS}" - print(f" Warmup {w+1}/{num_warmup} ({embryo_id})...", end=" 
", flush=True) - dur, size = benchmark_register_volume(store, session_id, embryo_id, timepoint, volume) + print( + f" Warmup {w + 1}/{num_warmup} ({embryo_id})...", + end=" ", + flush=True, + ) + dur, size = benchmark_register_volume( + store, session_id, embryo_id, timepoint, volume + ) timepoint += 1 print(f"{dur:.3f}s, {size:.1f} MB") # Timed repeats for r in range(num_repeats): embryo_id = f"embryo_{r % NUM_EMBRYOS}" - print(f" Repeat {r+1}/{num_repeats} ({embryo_id})...", end=" ", flush=True) + print( + f" Repeat {r + 1}/{num_repeats} ({embryo_id})...", + end=" ", + flush=True, + ) try: - dur, size = benchmark_register_volume(store, session_id, embryo_id, timepoint, volume) + dur, size = benchmark_register_volume( + store, session_id, embryo_id, timepoint, volume + ) timepoint += 1 res_reg.timings.append(dur) res_reg.sizes_mb.append(size) - print(f"{dur:.3f}s, {size:.1f} MB ({1/dur:.1f} vol/s)") + print(f"{dur:.3f}s, {size:.1f} MB ({1 / dur:.1f} vol/s)") except Exception as e: print(f"ERROR: {e}") res_reg.errors.append(str(e)) @@ -410,20 +439,24 @@ def _print_single_result(res: BenchmarkResult): if not res.timings: print(f" -> {res.approach}: NO SUCCESSFUL RUNS ({len(res.errors)} errors)") return - print(f" -> {res.approach}: {res.vol_per_sec:.2f} vol/s, " - f"mean={res.mean:.3f}s, std={res.std:.3f}s, " - f"avg_size={res.avg_size_mb:.1f} MB") + print( + f" -> {res.approach}: {res.vol_per_sec:.2f} vol/s, " + f"mean={res.mean:.3f}s, std={res.std:.3f}s, " + f"avg_size={res.avg_size_mb:.1f} MB" + ) -def print_results_table(results: List[BenchmarkResult]): +def print_results_table(results: list[BenchmarkResult]): """Print formatted ASCII results table.""" if not results: print("No results to display.") return - header = (f"{'Slices':>6} | {'Approach':>18} | {'Vol/s':>7} | " - f"{'Mean(s)':>7} | {'Std(s)':>6} | {'Min(s)':>6} | " - f"{'Max(s)':>6} | {'Size(MB)':>8} | {'MB/s':>7}") + header = ( + f"{'Slices':>6} | {'Approach':>18} | {'Vol/s':>7} | " + f"{'Mean(s)':>7} 
| {'Std(s)':>6} | {'Min(s)':>6} | " + f"{'Max(s)':>6} | {'Size(MB)':>8} | {'MB/s':>7}" + ) sep = "-" * len(header) print(f"\n{sep}") @@ -434,19 +467,23 @@ def print_results_table(results: List[BenchmarkResult]): for r in results: if r.timings: - print(f"{r.num_slices:>6} | {r.approach:>18} | " - f"{r.vol_per_sec:>7.2f} | {r.mean:>7.3f} | {r.std:>6.3f} | " - f"{r.min_t:>6.3f} | {r.max_t:>6.3f} | {r.avg_size_mb:>8.1f} | " - f"{r.mb_per_sec:>7.1f}") + print( + f"{r.num_slices:>6} | {r.approach:>18} | " + f"{r.vol_per_sec:>7.2f} | {r.mean:>7.3f} | {r.std:>6.3f} | " + f"{r.min_t:>6.3f} | {r.max_t:>6.3f} | {r.avg_size_mb:>8.1f} | " + f"{r.mb_per_sec:>7.1f}" + ) else: - print(f"{r.num_slices:>6} | {r.approach:>18} | " - f"{'FAIL':>7} | {'---':>7} | {'---':>6} | " - f"{'---':>6} | {'---':>6} | {'---':>8} | {'---':>7}") + print( + f"{r.num_slices:>6} | {r.approach:>18} | " + f"{'FAIL':>7} | {'---':>7} | {'---':>6} | " + f"{'---':>6} | {'---':>6} | {'---':>8} | {'---':>7}" + ) print(sep) -def print_overhead_analysis(results: List[BenchmarkResult]): +def print_overhead_analysis(results: list[BenchmarkResult]): """Print overhead analysis comparing FileStore to raw TIFF.""" from collections import defaultdict @@ -454,9 +491,11 @@ def print_overhead_analysis(results: List[BenchmarkResult]): for r in results: groups[r.num_slices][r.approach] = r - has_data = [(k, v) for k, v in groups.items() - if "raw_tiff_zlib" in v and - ("put_volume" in v or "register_volume" in v)] + has_data = [ + (k, v) + for k, v in groups.items() + if "raw_tiff_zlib" in v and ("put_volume" in v or "register_volume" in v) + ] if not has_data: return @@ -487,7 +526,7 @@ def print_overhead_analysis(results: List[BenchmarkResult]): print(sep) -def save_results_csv(results: List[BenchmarkResult], path: Path, run_params: dict): +def save_results_csv(results: list[BenchmarkResult], path: Path, run_params: dict): """Save results to CSV.""" import csv import json @@ -508,30 +547,44 @@ def 
save_results_csv(results: List[BenchmarkResult], path: Path, run_params: dic writer.writerow([]) # Summary table - writer.writerow([ - "slices", "approach", "vol_per_sec", "mean_s", "std_s", - "min_s", "max_s", "avg_size_mb", "mb_per_sec", "num_repeats", "errors", - ]) + writer.writerow( + [ + "slices", + "approach", + "vol_per_sec", + "mean_s", + "std_s", + "min_s", + "max_s", + "avg_size_mb", + "mb_per_sec", + "num_repeats", + "errors", + ] + ) for r in results: - writer.writerow([ - r.num_slices, r.approach, - f"{r.vol_per_sec:.4f}" if r.timings else "", - f"{r.mean:.6f}" if r.timings else "", - f"{r.std:.6f}" if r.timings else "", - f"{r.min_t:.6f}" if r.timings else "", - f"{r.max_t:.6f}" if r.timings else "", - f"{r.avg_size_mb:.2f}" if r.sizes_mb else "", - f"{r.mb_per_sec:.2f}" if r.timings else "", - len(r.timings), - "; ".join(r.errors) if r.errors else "", - ]) + writer.writerow( + [ + r.num_slices, + r.approach, + f"{r.vol_per_sec:.4f}" if r.timings else "", + f"{r.mean:.6f}" if r.timings else "", + f"{r.std:.6f}" if r.timings else "", + f"{r.min_t:.6f}" if r.timings else "", + f"{r.max_t:.6f}" if r.timings else "", + f"{r.avg_size_mb:.2f}" if r.sizes_mb else "", + f"{r.mb_per_sec:.2f}" if r.timings else "", + len(r.timings), + "; ".join(r.errors) if r.errors else "", + ] + ) # Per-volume timings writer.writerow([]) writer.writerow(["# Per-volume timings"]) writer.writerow(["slices", "approach", "repeat", "elapsed_s", "size_mb"]) for r in results: - for i, (t, s) in enumerate(zip(r.timings, r.sizes_mb)): + for i, (t, s) in enumerate(zip(r.timings, r.sizes_mb, strict=False)): writer.writerow([r.num_slices, r.approach, i + 1, f"{t:.6f}", f"{s:.2f}"]) print(f"\nResults saved to: {path}") @@ -541,24 +594,45 @@ def save_results_csv(results: List[BenchmarkResult], path: Path, run_params: dic # Main # --------------------------------------------------------------------------- def main(): - parser = argparse.ArgumentParser( - description="Benchmark FileStore 
volume storage throughput" + parser = argparse.ArgumentParser(description="Benchmark FileStore volume storage throughput") + parser.add_argument( + "--slices", + type=int, + nargs="+", + default=DEFAULT_SLICES, + help=f"Slice counts to test (default: {DEFAULT_SLICES})", + ) + parser.add_argument( + "--width", + type=int, + default=DEFAULT_WIDTH, + help=f"Image width (default: {DEFAULT_WIDTH})", + ) + parser.add_argument( + "--height", + type=int, + default=DEFAULT_HEIGHT, + help=f"Image height (default: {DEFAULT_HEIGHT})", + ) + parser.add_argument( + "--repeats", + type=int, + default=NUM_REPEATS, + help=f"Number of timed repeats (default: {NUM_REPEATS})", + ) + parser.add_argument( + "--warmup", + type=int, + default=NUM_WARMUP, + help=f"Number of warmup runs (default: {NUM_WARMUP})", + ) + parser.add_argument( + "--pattern", + choices=["noise", "gradient", "embryo"], + default="embryo", + help="Volume pattern: noise, gradient, or embryo (default: embryo)", ) - parser.add_argument("--slices", type=int, nargs="+", default=DEFAULT_SLICES, - help=f"Slice counts to test (default: {DEFAULT_SLICES})") - parser.add_argument("--width", type=int, default=DEFAULT_WIDTH, - help=f"Image width (default: {DEFAULT_WIDTH})") - parser.add_argument("--height", type=int, default=DEFAULT_HEIGHT, - help=f"Image height (default: {DEFAULT_HEIGHT})") - parser.add_argument("--repeats", type=int, default=NUM_REPEATS, - help=f"Number of timed repeats (default: {NUM_REPEATS})") - parser.add_argument("--warmup", type=int, default=NUM_WARMUP, - help=f"Number of warmup runs (default: {NUM_WARMUP})") - parser.add_argument("--pattern", choices=["noise", "gradient", "embryo"], - default="embryo", - help="Volume pattern: noise, gradient, or embryo (default: embryo)") - parser.add_argument("--save", action="store_true", - help="Save results to CSV") + parser.add_argument("--save", action="store_true", help="Save results to CSV") args = parser.parse_args() print("FileStore Volume Storage Benchmark") 
@@ -566,7 +640,7 @@ def main(): print(f" Dimensions: {args.width} x {args.height}") print(f" Pattern: {args.pattern}") print(f" Repeats: {args.repeats} (+ {args.warmup} warmup)") - print(f" Approaches: raw_tiff_zlib / put_volume / register_volume") + print(" Approaches: raw_tiff_zlib / put_volume / register_volume") results = run_benchmark_sweep( slices_list=args.slices, @@ -581,7 +655,10 @@ def main(): print_overhead_analysis(results) if args.save: - csv_path = Path("results") / f"benchmark_gentlystore_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv" + csv_path = ( + Path("results") + / f"benchmark_gentlystore_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv" + ) run_params = { "slices": args.slices, "width": args.width, diff --git a/diagnostics/benchmark_volume_fps.py b/diagnostics/benchmark_volume_fps.py index ff7b2e38..381e3697 100644 --- a/diagnostics/benchmark_volume_fps.py +++ b/diagnostics/benchmark_volume_fps.py @@ -15,20 +15,21 @@ """ import os +import statistics import sys import time -import statistics -from pathlib import Path from dataclasses import dataclass, field -from typing import List, Optional, Tuple from datetime import datetime +from pathlib import Path -import yaml import numpy as np import pymmcore +import yaml # Add dispim-control to path for ophyd device imports -DISPIM_CONTROL_DIR = Path(__file__).resolve().parent.parent / "UsersdispimDocumentsGitHubdispim-control" +DISPIM_CONTROL_DIR = ( + Path(__file__).resolve().parent.parent / "UsersdispimDocumentsGitHubdispim-control" +) if str(DISPIM_CONTROL_DIR) not in sys.path: sys.path.insert(0, str(DISPIM_CONTROL_DIR)) @@ -41,10 +42,10 @@ PIEZO_DEVICE = "PiezoStage:P:34" # Default scan parameters (same as run_multi_embryo_volumes.py) -DEFAULT_GALVO_AMPLITUDE = 0.5 # degrees -DEFAULT_GALVO_CENTER = 0.0 # degrees -DEFAULT_PIEZO_AMPLITUDE = 25.0 # um -DEFAULT_PIEZO_CENTER = 50.0 # um +DEFAULT_GALVO_AMPLITUDE = 0.5 # degrees +DEFAULT_GALVO_CENTER = 0.0 # degrees +DEFAULT_PIEZO_AMPLITUDE = 25.0 # um 
+DEFAULT_PIEZO_CENTER = 50.0 # um DEFAULT_LASER_CONFIG = "488 and 561" DEFAULT_CAMERA_ROI = (128, 896, 2048, 512) # (x, y, width, height) @@ -69,10 +70,30 @@ # Simulated embryo calibration profiles for round-robin reconfig test. # Each entry represents a different embryo with distinct galvo/piezo settings. EMBRYO_PROFILES = [ - {"galvo_amplitude": 0.50, "galvo_center": 0.00, "piezo_amplitude": 25.0, "piezo_center": 50.0}, - {"galvo_amplitude": 0.45, "galvo_center": 0.12, "piezo_amplitude": 22.5, "piezo_center": 55.0}, - {"galvo_amplitude": 0.55, "galvo_center": -0.08, "piezo_amplitude": 27.5, "piezo_center": 45.0}, - {"galvo_amplitude": 0.48, "galvo_center": 0.05, "piezo_amplitude": 24.0, "piezo_center": 52.0}, + { + "galvo_amplitude": 0.50, + "galvo_center": 0.00, + "piezo_amplitude": 25.0, + "piezo_center": 50.0, + }, + { + "galvo_amplitude": 0.45, + "galvo_center": 0.12, + "piezo_amplitude": 22.5, + "piezo_center": 55.0, + }, + { + "galvo_amplitude": 0.55, + "galvo_center": -0.08, + "piezo_amplitude": 27.5, + "piezo_center": 45.0, + }, + { + "galvo_amplitude": 0.48, + "galvo_center": 0.05, + "piezo_amplitude": 24.0, + "piezo_center": 52.0, + }, ] @@ -84,9 +105,9 @@ class BenchmarkResult: approach: str num_slices: int exposure_ms: float - timings: List[float] = field(default_factory=list) - image_counts: List[int] = field(default_factory=list) - errors: List[str] = field(default_factory=list) + timings: list[float] = field(default_factory=list) + image_counts: list[int] = field(default_factory=list) + errors: list[str] = field(default_factory=list) @property def mean(self) -> float: @@ -116,13 +137,13 @@ def total_images(self) -> int: # --------------------------------------------------------------------------- # Config loading # --------------------------------------------------------------------------- -def load_config(path: str) -> Tuple[str, str]: +def load_config(path: str) -> tuple[str, str]: """Read config.yml and return (mm_dir, config_file).""" cfg_path 
= Path(path) if not cfg_path.exists(): raise FileNotFoundError(f"Config file not found: {cfg_path}") - with open(cfg_path, "r") as f: + with open(cfg_path) as f: cfg = yaml.safe_load(f) mm_dir = cfg["mmdirectory"] @@ -236,8 +257,8 @@ def configure_hardware_raw(core: pymmcore.CMMCore, num_slices: int, exposure_ms: def acquire_volume_raw( core: pymmcore.CMMCore, num_slices: int, - save_dir: Optional[Path] = None, -) -> Tuple[int, float]: + save_dir: Path | None = None, +) -> tuple[int, float]: """ Trigger SPIM and collect images from the circular buffer. @@ -329,10 +350,10 @@ def create_ophyd_devices(core: pymmcore.CMMCore): which imports a theme module that may not be available). """ from dispim_control.devices import ( - DiSPIMScanner, DiSPIMCamera, - DiSPIMPiezo, DiSPIMLaserControl, + DiSPIMPiezo, + DiSPIMScanner, DiSPIMVolumeScanner, ) @@ -356,8 +377,8 @@ def acquire_volume_ophyd( volume_scanner, num_slices: int, exposure_ms: float, - save_dir: Optional[Path] = None, -) -> Tuple[int, float]: + save_dir: Path | None = None, +) -> tuple[int, float]: """ Configure + trigger via ophyd VolumeScanner. @@ -392,8 +413,9 @@ def acquire_volume_ophyd( # --------------------------------------------------------------------------- # Ophyd burst approach -- configure once, skip per-volume reset # --------------------------------------------------------------------------- -def configure_ophyd_burst(volume_scanner, num_slices: int, exposure_ms: float, - core: pymmcore.CMMCore): +def configure_ophyd_burst( + volume_scanner, num_slices: int, exposure_ms: float, core: pymmcore.CMMCore +): """ Configure ophyd devices once for a burst of volumes. 
@@ -417,8 +439,9 @@ def configure_ophyd_burst(volume_scanner, num_slices: int, exposure_ms: float, time.sleep(0.1) -def acquire_volume_ophyd_burst(volume_scanner, num_slices: int, - core: pymmcore.CMMCore) -> Tuple[int, float]: +def acquire_volume_ophyd_burst( + volume_scanner, num_slices: int, core: pymmcore.CMMCore +) -> tuple[int, float]: """ Acquire one volume using ophyd devices but without per-volume reset. @@ -494,8 +517,9 @@ def cleanup_ophyd_burst(volume_scanner, core: pymmcore.CMMCore): # --------------------------------------------------------------------------- # Burst reconfig approach -- reconfigure galvo/piezo per volume (round-robin) # --------------------------------------------------------------------------- -def configure_burst_reconfig(volume_scanner, num_slices: int, exposure_ms: float, - core: pymmcore.CMMCore): +def configure_burst_reconfig( + volume_scanner, num_slices: int, exposure_ms: float, core: pymmcore.CMMCore +): """ One-time setup for burst_reconfig: camera, scanner X-axis & timing, lasers. @@ -519,9 +543,11 @@ def configure_burst_reconfig(volume_scanner, num_slices: int, exposure_ms: float def acquire_volume_burst_reconfig( - volume_scanner, num_slices: int, profile_idx: int, + volume_scanner, + num_slices: int, + profile_idx: int, core: pymmcore.CMMCore, -) -> Tuple[int, float]: +) -> tuple[int, float]: """ Acquire one volume after reconfiguring galvo/piezo for a specific embryo. 
@@ -535,7 +561,6 @@ def acquire_volume_burst_reconfig( profile = EMBRYO_PROFILES[profile_idx % len(EMBRYO_PROFILES)] camera_name = volume_scanner.camera.name scanner = volume_scanner.scanner - piezo = volume_scanner.piezo core.clearCircularBuffer() @@ -548,16 +573,12 @@ def acquire_volume_burst_reconfig( t0 = time.perf_counter() # Reconfigure galvo Y-axis for this embryo - core.setProperty(GALVO_DEVICE, "SingleAxisYAmplitude(deg)", - float(profile["galvo_amplitude"])) - core.setProperty(GALVO_DEVICE, "SingleAxisYOffset(deg)", - float(profile["galvo_center"])) + core.setProperty(GALVO_DEVICE, "SingleAxisYAmplitude(deg)", float(profile["galvo_amplitude"])) + core.setProperty(GALVO_DEVICE, "SingleAxisYOffset(deg)", float(profile["galvo_center"])) # Reconfigure piezo for this embryo - core.setProperty(PIEZO_DEVICE, "SingleAxisAmplitude(um)", - float(profile["piezo_amplitude"])) - core.setProperty(PIEZO_DEVICE, "SingleAxisOffset(um)", - float(profile["piezo_center"])) + core.setProperty(PIEZO_DEVICE, "SingleAxisAmplitude(um)", float(profile["piezo_amplitude"])) + core.setProperty(PIEZO_DEVICE, "SingleAxisOffset(um)", float(profile["piezo_center"])) core.setProperty(PIEZO_DEVICE, "SPIMState", "Armed") time.sleep(0.3) @@ -595,9 +616,11 @@ def acquire_volume_burst_reconfig( # Burst reconfig with waitForDevice -- replaces time.sleep() with MMCore API # --------------------------------------------------------------------------- def acquire_volume_burst_reconfig_wfd( - volume_scanner, num_slices: int, profile_idx: int, + volume_scanner, + num_slices: int, + profile_idx: int, core: pymmcore.CMMCore, -) -> Tuple[int, float]: +) -> tuple[int, float]: """ Same as burst_reconfig but uses core.waitForDevice() instead of time.sleep() to wait for hardware readiness. 
@@ -608,7 +631,6 @@ def acquire_volume_burst_reconfig_wfd( """ profile = EMBRYO_PROFILES[profile_idx % len(EMBRYO_PROFILES)] camera_name = volume_scanner.camera.name - scanner = volume_scanner.scanner core.clearCircularBuffer() @@ -621,17 +643,13 @@ def acquire_volume_burst_reconfig_wfd( t0 = time.perf_counter() # Reconfigure galvo Y-axis for this embryo - core.setProperty(GALVO_DEVICE, "SingleAxisYAmplitude(deg)", - float(profile["galvo_amplitude"])) - core.setProperty(GALVO_DEVICE, "SingleAxisYOffset(deg)", - float(profile["galvo_center"])) + core.setProperty(GALVO_DEVICE, "SingleAxisYAmplitude(deg)", float(profile["galvo_amplitude"])) + core.setProperty(GALVO_DEVICE, "SingleAxisYOffset(deg)", float(profile["galvo_center"])) core.waitForDevice(GALVO_DEVICE) # Reconfigure piezo for this embryo - core.setProperty(PIEZO_DEVICE, "SingleAxisAmplitude(um)", - float(profile["piezo_amplitude"])) - core.setProperty(PIEZO_DEVICE, "SingleAxisOffset(um)", - float(profile["piezo_center"])) + core.setProperty(PIEZO_DEVICE, "SingleAxisAmplitude(um)", float(profile["piezo_amplitude"])) + core.setProperty(PIEZO_DEVICE, "SingleAxisOffset(um)", float(profile["piezo_center"])) core.setProperty(PIEZO_DEVICE, "SPIMState", "Armed") core.waitForDevice(PIEZO_DEVICE) @@ -686,16 +704,16 @@ def _save_volume(volume: np.ndarray, save_dir: Path, approach: str, num_slices: # --------------------------------------------------------------------------- def run_benchmark_sweep( core: pymmcore.CMMCore, - slices_list: List[int], - exposures_list: List[float], + slices_list: list[int], + exposures_list: list[float], num_repeats: int, num_warmup: int, run_raw: bool, run_ophyd: bool, - save_dir: Optional[Path] = None, -) -> List[BenchmarkResult]: + save_dir: Path | None = None, +) -> list[BenchmarkResult]: """Run the full parameter sweep and return results.""" - results: List[BenchmarkResult] = [] + results: list[BenchmarkResult] = [] volume_scanner = None if run_ophyd: @@ -709,15 +727,17 @@ def 
run_benchmark_sweep( for num_slices in slices_list: for exposure_ms in exposures_list: config_idx += 1 - print(f"\n{'='*60}") - print(f"Config {config_idx}/{total_configs}: " - f"slices={num_slices}, exposure={exposure_ms}ms") - print(f"{'='*60}") + print(f"\n{'=' * 60}") + print( + f"Config {config_idx}/{total_configs}: " + f"slices={num_slices}, exposure={exposure_ms}ms" + ) + print(f"{'=' * 60}") # --- Raw MMCore --- if run_raw: res_raw = BenchmarkResult("raw", num_slices, exposure_ms) - print(f"\n[Raw MMCore] Configuring hardware...") + print("\n[Raw MMCore] Configuring hardware...") try: configure_hardware_raw(core, num_slices, exposure_ms) except Exception as e: @@ -728,7 +748,7 @@ def run_benchmark_sweep( # Warm-up for w in range(num_warmup): - print(f" Warm-up {w+1}/{num_warmup}...", end=" ", flush=True) + print(f" Warm-up {w + 1}/{num_warmup}...", end=" ", flush=True) try: cnt, dur = acquire_volume_raw(core, num_slices) print(f"{cnt} imgs, {dur:.3f}s") @@ -737,15 +757,15 @@ def run_benchmark_sweep( # Timed repeats for r in range(num_repeats): - print(f" Repeat {r+1}/{num_repeats}...", end=" ", flush=True) + print(f" Repeat {r + 1}/{num_repeats}...", end=" ", flush=True) try: cnt, dur = acquire_volume_raw(core, num_slices, save_dir=save_dir) res_raw.timings.append(dur) res_raw.image_counts.append(cnt) - print(f"{cnt} imgs, {dur:.3f}s ({1.0/dur:.1f} vol/s)") + print(f"{cnt} imgs, {dur:.3f}s ({1.0 / dur:.1f} vol/s)") except Exception as e: print(f"ERROR: {e}") - res_raw.errors.append(f"repeat {r+1}: {e}") + res_raw.errors.append(f"repeat {r + 1}: {e}") # Cleanup after raw batch (lasers off) cleanup_raw(core) @@ -758,27 +778,34 @@ def run_benchmark_sweep( # Warm-up for w in range(num_warmup): - print(f" [Ophyd] Warm-up {w+1}/{num_warmup}...", end=" ", flush=True) + print( + f" [Ophyd] Warm-up {w + 1}/{num_warmup}...", + end=" ", + flush=True, + ) try: - cnt, dur = acquire_volume_ophyd( - volume_scanner, num_slices, exposure_ms) + cnt, dur = 
acquire_volume_ophyd(volume_scanner, num_slices, exposure_ms) print(f"{cnt} imgs, {dur:.3f}s") except Exception as e: print(f"ERROR: {e}") # Timed repeats for r in range(num_repeats): - print(f" [Ophyd] Repeat {r+1}/{num_repeats}...", end=" ", flush=True) + print( + f" [Ophyd] Repeat {r + 1}/{num_repeats}...", + end=" ", + flush=True, + ) try: cnt, dur = acquire_volume_ophyd( - volume_scanner, num_slices, exposure_ms, - save_dir=save_dir) + volume_scanner, num_slices, exposure_ms, save_dir=save_dir + ) res_ophyd.timings.append(dur) res_ophyd.image_counts.append(cnt) - print(f"{cnt} imgs, {dur:.3f}s ({1.0/dur:.1f} vol/s)") + print(f"{cnt} imgs, {dur:.3f}s ({1.0 / dur:.1f} vol/s)") except Exception as e: print(f"ERROR: {e}") - res_ophyd.errors.append(f"repeat {r+1}: {e}") + res_ophyd.errors.append(f"repeat {r + 1}: {e}") results.append(res_ophyd) _print_single_result(res_ophyd) @@ -786,10 +813,9 @@ def run_benchmark_sweep( # --- Ophyd burst (configure once, no per-volume reset) --- if run_ophyd and volume_scanner is not None: res_burst = BenchmarkResult("ophyd_burst", num_slices, exposure_ms) - print(f"\n[Ophyd Burst] Configuring once...") + print("\n[Ophyd Burst] Configuring once...") try: - configure_ophyd_burst(volume_scanner, num_slices, - exposure_ms, core) + configure_ophyd_burst(volume_scanner, num_slices, exposure_ms, core) except Exception as e: print(f" ERROR configuring ophyd burst: {e}") res_burst.errors.append(f"configure: {e}") @@ -798,26 +824,24 @@ def run_benchmark_sweep( # Warm-up for w in range(num_warmup): - print(f" Warm-up {w+1}/{num_warmup}...", end=" ", flush=True) + print(f" Warm-up {w + 1}/{num_warmup}...", end=" ", flush=True) try: - cnt, dur = acquire_volume_ophyd_burst( - volume_scanner, num_slices, core) + cnt, dur = acquire_volume_ophyd_burst(volume_scanner, num_slices, core) print(f"{cnt} imgs, {dur:.3f}s") except Exception as e: print(f"ERROR: {e}") # Timed repeats for r in range(num_repeats): - print(f" Repeat {r+1}/{num_repeats}...", 
end=" ", flush=True) + print(f" Repeat {r + 1}/{num_repeats}...", end=" ", flush=True) try: - cnt, dur = acquire_volume_ophyd_burst( - volume_scanner, num_slices, core) + cnt, dur = acquire_volume_ophyd_burst(volume_scanner, num_slices, core) res_burst.timings.append(dur) res_burst.image_counts.append(cnt) - print(f"{cnt} imgs, {dur:.3f}s ({1.0/dur:.1f} vol/s)") + print(f"{cnt} imgs, {dur:.3f}s ({1.0 / dur:.1f} vol/s)") except Exception as e: print(f"ERROR: {e}") - res_burst.errors.append(f"repeat {r+1}: {e}") + res_burst.errors.append(f"repeat {r + 1}: {e}") cleanup_ophyd_burst(volume_scanner, core) results.append(res_burst) @@ -826,11 +850,12 @@ def run_benchmark_sweep( # --- Burst reconfig (round-robin galvo/piezo per volume) --- if run_ophyd and volume_scanner is not None: res_reconfig = BenchmarkResult("burst_reconfig", num_slices, exposure_ms) - print(f"\n[Burst Reconfig] Configuring once, " - f"cycling {len(EMBRYO_PROFILES)} embryo profiles...") + print( + f"\n[Burst Reconfig] Configuring once, " + f"cycling {len(EMBRYO_PROFILES)} embryo profiles..." 
+ ) try: - configure_burst_reconfig(volume_scanner, num_slices, - exposure_ms, core) + configure_burst_reconfig(volume_scanner, num_slices, exposure_ms, core) except Exception as e: print(f" ERROR configuring burst reconfig: {e}") res_reconfig.errors.append(f"configure: {e}") @@ -839,12 +864,15 @@ def run_benchmark_sweep( # Warm-up (cycle through profiles) for w in range(num_warmup): - print(f" Warm-up {w+1}/{num_warmup} " - f"(profile {w % len(EMBRYO_PROFILES)})...", - end=" ", flush=True) + print( + f" Warm-up {w + 1}/{num_warmup} (profile {w % len(EMBRYO_PROFILES)})...", + end=" ", + flush=True, + ) try: cnt, dur = acquire_volume_burst_reconfig( - volume_scanner, num_slices, w, core) + volume_scanner, num_slices, w, core + ) print(f"{cnt} imgs, {dur:.3f}s") except Exception as e: print(f"ERROR: {e}") @@ -852,17 +880,21 @@ def run_benchmark_sweep( # Timed repeats (cycle through profiles) for r in range(num_repeats): pidx = r % len(EMBRYO_PROFILES) - print(f" Repeat {r+1}/{num_repeats} " - f"(profile {pidx})...", end=" ", flush=True) + print( + f" Repeat {r + 1}/{num_repeats} (profile {pidx})...", + end=" ", + flush=True, + ) try: cnt, dur = acquire_volume_burst_reconfig( - volume_scanner, num_slices, r, core) + volume_scanner, num_slices, r, core + ) res_reconfig.timings.append(dur) res_reconfig.image_counts.append(cnt) - print(f"{cnt} imgs, {dur:.3f}s ({1.0/dur:.1f} vol/s)") + print(f"{cnt} imgs, {dur:.3f}s ({1.0 / dur:.1f} vol/s)") except Exception as e: print(f"ERROR: {e}") - res_reconfig.errors.append(f"repeat {r+1}: {e}") + res_reconfig.errors.append(f"repeat {r + 1}: {e}") cleanup_ophyd_burst(volume_scanner, core) results.append(res_reconfig) @@ -871,11 +903,12 @@ def run_benchmark_sweep( # --- Burst reconfig with waitForDevice (no time.sleep) --- if run_ophyd and volume_scanner is not None: res_wfd = BenchmarkResult("reconfig_wfd", num_slices, exposure_ms) - print(f"\n[Reconfig WFD] Configuring once, " - f"using waitForDevice() instead of time.sleep()...") 
+ print( + "\n[Reconfig WFD] Configuring once, " + "using waitForDevice() instead of time.sleep()..." + ) try: - configure_burst_reconfig(volume_scanner, num_slices, - exposure_ms, core) + configure_burst_reconfig(volume_scanner, num_slices, exposure_ms, core) except Exception as e: print(f" ERROR configuring reconfig_wfd: {e}") res_wfd.errors.append(f"configure: {e}") @@ -884,12 +917,15 @@ def run_benchmark_sweep( # Warm-up (cycle through profiles) for w in range(num_warmup): - print(f" Warm-up {w+1}/{num_warmup} " - f"(profile {w % len(EMBRYO_PROFILES)})...", - end=" ", flush=True) + print( + f" Warm-up {w + 1}/{num_warmup} (profile {w % len(EMBRYO_PROFILES)})...", + end=" ", + flush=True, + ) try: cnt, dur = acquire_volume_burst_reconfig_wfd( - volume_scanner, num_slices, w, core) + volume_scanner, num_slices, w, core + ) print(f"{cnt} imgs, {dur:.3f}s") except Exception as e: print(f"ERROR: {e}") @@ -897,17 +933,21 @@ def run_benchmark_sweep( # Timed repeats (cycle through profiles) for r in range(num_repeats): pidx = r % len(EMBRYO_PROFILES) - print(f" Repeat {r+1}/{num_repeats} " - f"(profile {pidx})...", end=" ", flush=True) + print( + f" Repeat {r + 1}/{num_repeats} (profile {pidx})...", + end=" ", + flush=True, + ) try: cnt, dur = acquire_volume_burst_reconfig_wfd( - volume_scanner, num_slices, r, core) + volume_scanner, num_slices, r, core + ) res_wfd.timings.append(dur) res_wfd.image_counts.append(cnt) - print(f"{cnt} imgs, {dur:.3f}s ({1.0/dur:.1f} vol/s)") + print(f"{cnt} imgs, {dur:.3f}s ({1.0 / dur:.1f} vol/s)") except Exception as e: print(f"ERROR: {e}") - res_wfd.errors.append(f"repeat {r+1}: {e}") + res_wfd.errors.append(f"repeat {r + 1}: {e}") cleanup_ophyd_burst(volume_scanner, core) results.append(res_wfd) @@ -922,22 +962,25 @@ def run_benchmark_sweep( def _print_single_result(res: BenchmarkResult): """Print a single intermediate result.""" if not res.timings: - print(f" -> {res.approach}: NO SUCCESSFUL RUNS " - f"({len(res.errors)} errors)") + 
print(f" -> {res.approach}: NO SUCCESSFUL RUNS ({len(res.errors)} errors)") return - print(f" -> {res.approach}: {res.vol_per_sec:.2f} vol/s, " - f"mean={res.mean:.3f}s, std={res.std:.3f}s") + print( + f" -> {res.approach}: {res.vol_per_sec:.2f} vol/s, " + f"mean={res.mean:.3f}s, std={res.std:.3f}s" + ) -def print_results_table(results: List[BenchmarkResult]): +def print_results_table(results: list[BenchmarkResult]): """Print formatted ASCII results table.""" if not results: print("No results to display.") return - header = (f"{'Slices':>6} | {'Exp(ms)':>7} | {'Approach':>14} | " - f"{'Vol/s':>7} | {'Mean(s)':>7} | {'Std(s)':>7} | " - f"{'Min(s)':>7} | {'Max(s)':>7} | {'Images':>6}") + header = ( + f"{'Slices':>6} | {'Exp(ms)':>7} | {'Approach':>14} | " + f"{'Vol/s':>7} | {'Mean(s)':>7} | {'Std(s)':>7} | " + f"{'Min(s)':>7} | {'Max(s)':>7} | {'Images':>6}" + ) sep = "-" * len(header) print(f"\n{sep}") @@ -948,35 +991,43 @@ def print_results_table(results: List[BenchmarkResult]): for r in results: if r.timings: - print(f"{r.num_slices:>6} | {r.exposure_ms:>7.1f} | {r.approach:>14} | " - f"{r.vol_per_sec:>7.2f} | {r.mean:>7.3f} | {r.std:>7.3f} | " - f"{r.min_t:>7.3f} | {r.max_t:>7.3f} | {r.total_images:>6}") + print( + f"{r.num_slices:>6} | {r.exposure_ms:>7.1f} | {r.approach:>14} | " + f"{r.vol_per_sec:>7.2f} | {r.mean:>7.3f} | {r.std:>7.3f} | " + f"{r.min_t:>7.3f} | {r.max_t:>7.3f} | {r.total_images:>6}" + ) else: - print(f"{r.num_slices:>6} | {r.exposure_ms:>7.1f} | {r.approach:>14} | " - f"{'FAIL':>7} | {'---':>7} | {'---':>7} | " - f"{'---':>7} | {'---':>7} | {'---':>6}") + print( + f"{r.num_slices:>6} | {r.exposure_ms:>7.1f} | {r.approach:>14} | " + f"{'FAIL':>7} | {'---':>7} | {'---':>7} | " + f"{'---':>7} | {'---':>7} | {'---':>6}" + ) print(sep) -def print_summary(results: List[BenchmarkResult]): +def print_summary(results: list[BenchmarkResult]): """Print overhead analysis comparing ophyd and ophyd_burst vs raw.""" from collections import defaultdict + 
groups = defaultdict(dict) for r in results: groups[(r.num_slices, r.exposure_ms)][r.approach] = r # Need at least raw + one ophyd variant - has_data = [(k, v) for k, v in groups.items() - if "raw" in v and ("ophyd" in v or "ophyd_burst" in v - or "burst_reconfig" in v - or "reconfig_wfd" in v)] + has_data = [ + (k, v) + for k, v in groups.items() + if "raw" in v + and ("ophyd" in v or "ophyd_burst" in v or "burst_reconfig" in v or "reconfig_wfd" in v) + ] if not has_data: return print() - header = (f"{'Slices':>6} | {'Exp(ms)':>7} | " - f"{'Approach':>14} | {'vs Raw(ms)':>10} | {'vs Raw(%)':>9}") + header = ( + f"{'Slices':>6} | {'Exp(ms)':>7} | {'Approach':>14} | {'vs Raw(ms)':>10} | {'vs Raw(%)':>9}" + ) sep = "-" * len(header) print(sep) @@ -995,14 +1046,15 @@ def print_summary(results: List[BenchmarkResult]): other_mean = approaches[label].mean overhead_ms = (other_mean - raw_mean) * 1000.0 overhead_pct = ((other_mean - raw_mean) / raw_mean) * 100.0 - print(f"{ns:>6} | {exp:>7.1f} | {label:>14} | " - f"{overhead_ms:>+10.1f} | {overhead_pct:>+8.1f}%") + print( + f"{ns:>6} | {exp:>7.1f} | {label:>14} | " + f"{overhead_ms:>+10.1f} | {overhead_pct:>+8.1f}%" + ) print(sep) -def save_results_csv(results: List[BenchmarkResult], path: Path, - run_params: dict): +def save_results_csv(results: list[BenchmarkResult], path: Path, run_params: dict): """Write benchmark results to a CSV file with full metadata.""" import csv import json @@ -1031,37 +1083,63 @@ def save_results_csv(results: List[BenchmarkResult], path: Path, writer.writerow([]) # --- Summary table --- - writer.writerow([ - "slices", "exposure_ms", "approach", - "vol_per_sec", "mean_s", "std_s", "min_s", "max_s", - "total_images", "num_repeats", "errors", - ]) + writer.writerow( + [ + "slices", + "exposure_ms", + "approach", + "vol_per_sec", + "mean_s", + "std_s", + "min_s", + "max_s", + "total_images", + "num_repeats", + "errors", + ] + ) for r in results: - writer.writerow([ - r.num_slices, r.exposure_ms, 
r.approach, - f"{r.vol_per_sec:.4f}" if r.timings else "", - f"{r.mean:.6f}" if r.timings else "", - f"{r.std:.6f}" if r.timings else "", - f"{r.min_t:.6f}" if r.timings else "", - f"{r.max_t:.6f}" if r.timings else "", - r.total_images, - len(r.timings), - "; ".join(r.errors) if r.errors else "", - ]) + writer.writerow( + [ + r.num_slices, + r.exposure_ms, + r.approach, + f"{r.vol_per_sec:.4f}" if r.timings else "", + f"{r.mean:.6f}" if r.timings else "", + f"{r.std:.6f}" if r.timings else "", + f"{r.min_t:.6f}" if r.timings else "", + f"{r.max_t:.6f}" if r.timings else "", + r.total_images, + len(r.timings), + "; ".join(r.errors) if r.errors else "", + ] + ) # --- Per-volume raw timings --- writer.writerow([]) writer.writerow(["# Per-volume timings (seconds)"]) - writer.writerow([ - "slices", "exposure_ms", "approach", - "repeat", "elapsed_s", "image_count", - ]) + writer.writerow( + [ + "slices", + "exposure_ms", + "approach", + "repeat", + "elapsed_s", + "image_count", + ] + ) for r in results: - for i, (t, cnt) in enumerate(zip(r.timings, r.image_counts)): - writer.writerow([ - r.num_slices, r.exposure_ms, r.approach, - i + 1, f"{t:.6f}", cnt, - ]) + for i, (t, cnt) in enumerate(zip(r.timings, r.image_counts, strict=False)): + writer.writerow( + [ + r.num_slices, + r.exposure_ms, + r.approach, + i + 1, + f"{t:.6f}", + cnt, + ] + ) print(f"\nResults saved to: {path}") @@ -1076,8 +1154,10 @@ def main(): print(f" Slices series: {SLICES_SERIES}") print(f" Exposure: {EXPOSURE_MS} ms") print(f" Repeats: {NUM_REPEATS} (+ {NUM_WARMUP} warmup)") - print(f" Approaches: raw / ophyd / ophyd_burst / burst_reconfig / reconfig_wfd") - print(f" Embryo profiles: {len(EMBRYO_PROFILES)} (for burst_reconfig & reconfig_wfd round-robin)") + print(" Approaches: raw / ophyd / ophyd_burst / burst_reconfig / reconfig_wfd") + print( + f" Embryo profiles: {len(EMBRYO_PROFILES)} (for burst_reconfig & reconfig_wfd round-robin)" + ) # Load config and initialize mm_dir, config_file = 
load_config(config_path) @@ -1099,7 +1179,9 @@ def main(): print_summary(results) # Save CSV - csv_path = Path("results") / f"benchmark_volume_fps_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv" + csv_path = ( + Path("results") / f"benchmark_volume_fps_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv" + ) run_params = { "slices": SLICES_SERIES, "exposure_ms": EXPOSURE_MS, diff --git a/diagnostics/measure_centering_error.py b/diagnostics/measure_centering_error.py index 0f159c79..c8c22ef4 100644 --- a/diagnostics/measure_centering_error.py +++ b/diagnostics/measure_centering_error.py @@ -6,8 +6,9 @@ """ import json -import numpy as np + import matplotlib.pyplot as plt +import numpy as np from PIL import Image # Load the image after moving to embryo @@ -21,6 +22,7 @@ # Store clicked position clicked_pos = [None, None] + def on_click(event): if event.xdata is not None and event.ydata is not None: clicked_pos[0] = event.xdata @@ -47,48 +49,56 @@ def on_click(event): "error_y_direction": "BELOW" if error_y > 0 else "ABOVE", "error_x_um": float(error_x_um), "error_y_um": float(error_y_um), - "um_per_pixel": um_per_pixel + "um_per_pixel": um_per_pixel, } # Write to file - with open(OUTPUT_FILE, 'w') as f: + with open(OUTPUT_FILE, "w") as f: json.dump(result, f, indent=2) - print(f"\n{'='*50}") + print(f"\n{'=' * 50}") print(f"CLICKED POSITION: ({clicked_pos[0]:.1f}, {clicked_pos[1]:.1f})") print(f"CENTER POSITION: ({CENTER_X}, {CENTER_Y})") - print(f"{'='*50}") + print(f"{'=' * 50}") print(f"ERROR X: {error_x:+.1f} pixels ({result['error_x_direction']} of center)") print(f"ERROR Y: {error_y:+.1f} pixels ({result['error_y_direction']} center)") - print(f"{'='*50}") + print(f"{'=' * 50}") print(f"ERROR X: {error_x_um:+.1f} um") print(f"ERROR Y: {error_y_um:+.1f} um") - print(f"{'='*50}") + print(f"{'=' * 50}") print(f"\nSaved to: {OUTPUT_FILE}") # Update the plot with clicked marker - ax.plot(clicked_pos[0], clicked_pos[1], 'go', markersize=15, markeredgewidth=3, - 
markerfacecolor='none', label='Actual embryo position') + ax.plot( + clicked_pos[0], + clicked_pos[1], + "go", + markersize=15, + markeredgewidth=3, + markerfacecolor="none", + label="Actual embryo position", + ) ax.legend() fig.canvas.draw() + # Load image img = np.array(Image.open(IMAGE_PATH)) # Create figure fig, ax = plt.subplots(figsize=(12, 12)) -ax.imshow(img, cmap='gray') +ax.imshow(img, cmap="gray") # Draw crosshairs at center -ax.axhline(CENTER_Y, color='red', linestyle='--', alpha=0.7, linewidth=1, label='Center') -ax.axvline(CENTER_X, color='red', linestyle='--', alpha=0.7, linewidth=1) -ax.plot(CENTER_X, CENTER_Y, 'r+', markersize=30, markeredgewidth=2) +ax.axhline(CENTER_Y, color="red", linestyle="--", alpha=0.7, linewidth=1, label="Center") +ax.axvline(CENTER_X, color="red", linestyle="--", alpha=0.7, linewidth=1) +ax.plot(CENTER_X, CENTER_Y, "r+", markersize=30, markeredgewidth=2) ax.set_title("Click on where EMBRYO 3 actually is\n(Red crosshairs = center where it SHOULD be)") ax.legend() # Connect click event -fig.canvas.mpl_connect('button_press_event', on_click) +fig.canvas.mpl_connect("button_press_event", on_click) print("\nClick on where embryo_3 actually appears in the image.") print("The red crosshairs show the center (where it should be).\n") diff --git a/diagnostics/plot_benchmark_results.py b/diagnostics/plot_benchmark_results.py index 2eba25de..c3b5cdf1 100644 --- a/diagnostics/plot_benchmark_results.py +++ b/diagnostics/plot_benchmark_results.py @@ -10,14 +10,13 @@ python diagnostics/plot_benchmark_results.py results/benchmark_volume_fps_20260127_123405.csv """ -import sys import csv -from pathlib import Path +import sys from collections import defaultdict +from pathlib import Path -import numpy as np import matplotlib.pyplot as plt -import matplotlib.ticker as ticker +import numpy as np # --------------------------------------------------------------------------- @@ -29,7 +28,7 @@ def parse_benchmark_csv(path: Path) -> dict: summary = 
[] per_volume = [] - with open(path, "r") as f: + with open(path) as f: reader = csv.reader(f) section = "metadata" @@ -55,27 +54,31 @@ def parse_benchmark_csv(path: Path) -> dict: continue if section == "summary": - summary.append({ - "slices": int(row[0]), - "exposure_ms": float(row[1]), - "approach": row[2], - "vol_per_sec": float(row[3]) if row[3] else None, - "mean_s": float(row[4]) if row[4] else None, - "std_s": float(row[5]) if row[5] else None, - "min_s": float(row[6]) if row[6] else None, - "max_s": float(row[7]) if row[7] else None, - "total_images": int(row[8]) if row[8] else 0, - "num_repeats": int(row[9]) if row[9] else 0, - }) + summary.append( + { + "slices": int(row[0]), + "exposure_ms": float(row[1]), + "approach": row[2], + "vol_per_sec": float(row[3]) if row[3] else None, + "mean_s": float(row[4]) if row[4] else None, + "std_s": float(row[5]) if row[5] else None, + "min_s": float(row[6]) if row[6] else None, + "max_s": float(row[7]) if row[7] else None, + "total_images": int(row[8]) if row[8] else 0, + "num_repeats": int(row[9]) if row[9] else 0, + } + ) elif section == "per_volume": - per_volume.append({ - "slices": int(row[0]), - "exposure_ms": float(row[1]), - "approach": row[2], - "repeat": int(row[3]), - "elapsed_s": float(row[4]), - "image_count": int(row[5]), - }) + per_volume.append( + { + "slices": int(row[0]), + "exposure_ms": float(row[1]), + "approach": row[2], + "repeat": int(row[3]), + "elapsed_s": float(row[4]), + "image_count": int(row[5]), + } + ) return {"metadata": metadata, "summary": summary, "per_volume": per_volume} @@ -102,18 +105,18 @@ def plot_throughput(data: dict, output_dir: Path): # Colors and labels color_map = { - "raw": "#2563eb", - "ophyd": "#dc2626", - "ophyd_burst": "#16a34a", - "burst_reconfig": "#ea580c", - "reconfig_wfd": "#7c3aed", + "raw": "#2563eb", + "ophyd": "#dc2626", + "ophyd_burst": "#16a34a", + "burst_reconfig": "#ea580c", + "reconfig_wfd": "#7c3aed", } label_map = { - "raw": "Raw MMCore", - 
"ophyd": "Ophyd (full)", - "ophyd_burst": "Ophyd burst", - "burst_reconfig": "Reconfig (sleep)", - "reconfig_wfd": "Reconfig (waitForDevice)", + "raw": "Raw MMCore", + "ophyd": "Ophyd (full)", + "ophyd_burst": "Ophyd burst", + "burst_reconfig": "Reconfig (sleep)", + "reconfig_wfd": "Reconfig (waitForDevice)", } fig, ax = plt.subplots(figsize=(10, 5.5)) @@ -124,16 +127,27 @@ def plot_throughput(data: dict, output_dir: Path): for i, approach in enumerate(approaches): vals = [vps[s].get(approach, 0) for s in slices_set] offset = (i - n / 2 + 0.5) * width - bars = ax.bar(x + offset, vals, width * 0.92, - label=label_map.get(approach, approach), - color=color_map.get(approach, "#888"), - edgecolor="white", linewidth=0.5) + bars = ax.bar( + x + offset, + vals, + width * 0.92, + label=label_map.get(approach, approach), + color=color_map.get(approach, "#888"), + edgecolor="white", + linewidth=0.5, + ) # Value labels on bars - for bar, v in zip(bars, vals): + for bar, v in zip(bars, vals, strict=False): if v > 0: - ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02, - f"{v:.2f}", ha="center", va="bottom", fontsize=7, - fontweight="bold") + ax.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() + 0.02, + f"{v:.2f}", + ha="center", + va="bottom", + fontsize=7, + fontweight="bold", + ) ax.set_xlabel("Slices per volume", fontsize=11) ax.set_ylabel("Volumes per second", fontsize=11) @@ -149,7 +163,7 @@ def plot_throughput(data: dict, output_dir: Path): fig.tight_layout() fig.savefig(output_dir / "benchmark_throughput.png", dpi=180) plt.close(fig) - print(f" Saved: benchmark_throughput.png") + print(" Saved: benchmark_throughput.png") # --------------------------------------------------------------------------- @@ -167,14 +181,14 @@ def plot_overhead(data: dict, output_dir: Path): compare = ["ophyd", "burst_reconfig", "reconfig_wfd"] color_map = { - "ophyd": "#dc2626", - "burst_reconfig": "#ea580c", - "reconfig_wfd": "#7c3aed", + "ophyd": "#dc2626", + 
"burst_reconfig": "#ea580c", + "reconfig_wfd": "#7c3aed", } label_map = { - "ophyd": "Ophyd (full teardown/setup)", - "burst_reconfig": "Reconfig (time.sleep)", - "reconfig_wfd": "Reconfig (waitForDevice)", + "ophyd": "Ophyd (full teardown/setup)", + "burst_reconfig": "Reconfig (time.sleep)", + "reconfig_wfd": "Reconfig (waitForDevice)", } overhead = defaultdict(dict) @@ -192,15 +206,26 @@ def plot_overhead(data: dict, output_dir: Path): for i, approach in enumerate(compare): vals = [overhead[s].get(approach, 0) for s in slices_set] offset = (i - n / 2 + 0.5) * width - bars = ax.bar(x + offset, vals, width * 0.92, - label=label_map.get(approach, approach), - color=color_map.get(approach, "#888"), - edgecolor="white", linewidth=0.5) - for bar, v in zip(bars, vals): + bars = ax.bar( + x + offset, + vals, + width * 0.92, + label=label_map.get(approach, approach), + color=color_map.get(approach, "#888"), + edgecolor="white", + linewidth=0.5, + ) + for bar, v in zip(bars, vals, strict=False): if v > 0: - ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 8, - f"{v:.0f}", ha="center", va="bottom", fontsize=8, - fontweight="bold") + ax.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height() + 8, + f"{v:.0f}", + ha="center", + va="bottom", + fontsize=8, + fontweight="bold", + ) ax.set_xlabel("Slices per volume", fontsize=11) ax.set_ylabel("Overhead vs raw MMCore (ms)", fontsize=11) @@ -216,7 +241,7 @@ def plot_overhead(data: dict, output_dir: Path): fig.tight_layout() fig.savefig(output_dir / "benchmark_overhead.png", dpi=180) plt.close(fig) - print(f" Saved: benchmark_overhead.png") + print(" Saved: benchmark_overhead.png") # --------------------------------------------------------------------------- @@ -243,20 +268,35 @@ def plot_wfd_savings(data: dict, output_dir: Path): x = np.arange(len(slices_set)) width = 0.32 - for i, (approach, label, color) in enumerate([ - ("burst_reconfig", "sleep()", "#ea580c"), - ("reconfig_wfd", "waitForDevice()", 
"#7c3aed"), - ]): + for i, (approach, label, color) in enumerate( + [ + ("burst_reconfig", "sleep()", "#ea580c"), + ("reconfig_wfd", "waitForDevice()", "#7c3aed"), + ] + ): acq_times = [raw_means.get(s, 0) for s in slices_set] overheads = [means[s].get(approach, 0) - raw_means.get(s, 0) for s in slices_set] offset = (i - 0.5) * width - ax.bar(x + offset, acq_times, width * 0.92, - color="#93c5fd", edgecolor="white", linewidth=0.5, - label="Acquisition time" if i == 0 else None) - ax.bar(x + offset, overheads, width * 0.92, - bottom=acq_times, color=color, edgecolor="white", linewidth=0.5, - label=f"Overhead ({label})") + ax.bar( + x + offset, + acq_times, + width * 0.92, + color="#93c5fd", + edgecolor="white", + linewidth=0.5, + label="Acquisition time" if i == 0 else None, + ) + ax.bar( + x + offset, + overheads, + width * 0.92, + bottom=acq_times, + color=color, + edgecolor="white", + linewidth=0.5, + label=f"Overhead ({label})", + ) ax.set_xlabel("Slices per volume", fontsize=11) ax.set_ylabel("Total time per volume (s)", fontsize=11) @@ -282,27 +322,53 @@ def plot_wfd_savings(data: dict, output_dir: Path): savings.append(so - wo) bar_width = 0.55 - bars_sleep = ax2.barh(x + 0.15, sleep_overhead, bar_width * 0.48, - color="#ea580c", label="time.sleep() overhead") - bars_wfd = ax2.barh(x - 0.15, wfd_overhead, bar_width * 0.48, - color="#7c3aed", label="waitForDevice() overhead") - - for bar, val, sav in zip(bars_sleep, sleep_overhead, savings): - ax2.text(bar.get_width() + 8, bar.get_y() + bar.get_height() / 2, - f"{val:.0f}ms", va="center", fontsize=9, color="#ea580c", - fontweight="bold") - for bar, val in zip(bars_wfd, wfd_overhead): - ax2.text(bar.get_width() + 8, bar.get_y() + bar.get_height() / 2, - f"{val:.0f}ms", va="center", fontsize=9, color="#7c3aed", - fontweight="bold") + bars_sleep = ax2.barh( + x + 0.15, + sleep_overhead, + bar_width * 0.48, + color="#ea580c", + label="time.sleep() overhead", + ) + bars_wfd = ax2.barh( + x - 0.15, + wfd_overhead, + 
bar_width * 0.48, + color="#7c3aed", + label="waitForDevice() overhead", + ) + + for bar, val, _sav in zip(bars_sleep, sleep_overhead, savings, strict=False): + ax2.text( + bar.get_width() + 8, + bar.get_y() + bar.get_height() / 2, + f"{val:.0f}ms", + va="center", + fontsize=9, + color="#ea580c", + fontweight="bold", + ) + for bar, val in zip(bars_wfd, wfd_overhead, strict=False): + ax2.text( + bar.get_width() + 8, + bar.get_y() + bar.get_height() / 2, + f"{val:.0f}ms", + va="center", + fontsize=9, + color="#7c3aed", + fontweight="bold", + ) # Add savings annotation - for i, (s, sav) in enumerate(zip(slices_set, savings)): - ax2.annotate(f"-{sav:.0f}ms", - xy=(sleep_overhead[i], i + 0.15), - xytext=(sleep_overhead[i] + 60, i + 0.35), - fontsize=8.5, fontweight="bold", color="#166534", - arrowprops=dict(arrowstyle="->", color="#166534", lw=1.2)) + for i, (_s, sav) in enumerate(zip(slices_set, savings, strict=False)): + ax2.annotate( + f"-{sav:.0f}ms", + xy=(sleep_overhead[i], i + 0.15), + xytext=(sleep_overhead[i] + 60, i + 0.35), + fontsize=8.5, + fontweight="bold", + color="#166534", + arrowprops=dict(arrowstyle="->", color="#166534", lw=1.2), + ) ax2.set_yticks(x) ax2.set_yticklabels([f"{s} slices" for s in slices_set]) @@ -316,7 +382,7 @@ def plot_wfd_savings(data: dict, output_dir: Path): fig.tight_layout() fig.savefig(output_dir / "benchmark_wfd_savings.png", dpi=180) plt.close(fig) - print(f" Saved: benchmark_wfd_savings.png") + print(" Saved: benchmark_wfd_savings.png") # --------------------------------------------------------------------------- @@ -328,18 +394,18 @@ def plot_consistency(data: dict, output_dir: Path): approaches_order = ["raw", "ophyd_burst", "reconfig_wfd", "burst_reconfig", "ophyd"] label_map = { - "raw": "Raw\nMMCore", - "ophyd": "Ophyd\n(full)", - "ophyd_burst": "Ophyd\nburst", - "burst_reconfig": "Reconfig\n(sleep)", - "reconfig_wfd": "Reconfig\n(wfd)", + "raw": "Raw\nMMCore", + "ophyd": "Ophyd\n(full)", + "ophyd_burst": 
"Ophyd\nburst", + "burst_reconfig": "Reconfig\n(sleep)", + "reconfig_wfd": "Reconfig\n(wfd)", } color_map = { - "raw": "#2563eb", - "ophyd": "#dc2626", - "ophyd_burst": "#16a34a", - "burst_reconfig": "#ea580c", - "reconfig_wfd": "#7c3aed", + "raw": "#2563eb", + "ophyd": "#dc2626", + "ophyd_burst": "#16a34a", + "burst_reconfig": "#ea580c", + "reconfig_wfd": "#7c3aed", } slices_set = sorted(set(r["slices"] for r in per_volume)) @@ -348,21 +414,28 @@ def plot_consistency(data: dict, output_dir: Path): if len(slices_set) == 1: axes = [axes] - for ax, ns in zip(axes, slices_set): + for ax, ns in zip(axes, slices_set, strict=False): box_data = [] labels = [] colors = [] for approach in approaches_order: - timings = [r["elapsed_s"] for r in per_volume - if r["slices"] == ns and r["approach"] == approach] + timings = [ + r["elapsed_s"] + for r in per_volume + if r["slices"] == ns and r["approach"] == approach + ] if timings: box_data.append(timings) labels.append(label_map.get(approach, approach)) colors.append(color_map.get(approach, "#888")) - bp = ax.boxplot(box_data, patch_artist=True, widths=0.55, - medianprops=dict(color="black", linewidth=1.5)) - for patch, c in zip(bp["boxes"], colors): + bp = ax.boxplot( + box_data, + patch_artist=True, + widths=0.55, + medianprops=dict(color="black", linewidth=1.5), + ) + for patch, c in zip(bp["boxes"], colors, strict=False): patch.set_facecolor(c) patch.set_alpha(0.7) @@ -377,7 +450,7 @@ def plot_consistency(data: dict, output_dir: Path): fig.tight_layout() fig.savefig(output_dir / "benchmark_consistency.png", dpi=180, bbox_inches="tight") plt.close(fig) - print(f" Saved: benchmark_consistency.png") + print(" Saved: benchmark_consistency.png") # --------------------------------------------------------------------------- diff --git a/diagnostics/run_multi_embryo_volumes.py b/diagnostics/run_multi_embryo_volumes.py index d52aeada..90af3c68 100644 --- a/diagnostics/run_multi_embryo_volumes.py +++ 
b/diagnostics/run_multi_embryo_volumes.py @@ -9,14 +9,15 @@ python run_multi_embryo_volumes.py """ -import time import json -import numpy as np -from pathlib import Path +import time from datetime import datetime, timedelta -from client import get_mmc -import tifffile +from pathlib import Path + +import numpy as np import rpyc +import tifffile +from client import get_mmc from tqdm import tqdm # Device configuration @@ -36,9 +37,11 @@ def load_database(): """Load embryo database.""" if not DATABASE_FILE.exists(): - raise FileNotFoundError(f"Database not found: {DATABASE_FILE}\nRun multi_embryo_calibration.py first!") + raise FileNotFoundError( + f"Database not found: {DATABASE_FILE}\nRun multi_embryo_calibration.py first!" + ) - with open(DATABASE_FILE, 'r') as f: + with open(DATABASE_FILE) as f: return json.load(f) @@ -58,8 +61,8 @@ def move_to_embryo(embryo_data): embryo_data : dict Embryo information from database """ - target_x = embryo_data['stage_position_after_centering_um']['x'] - target_y = embryo_data['stage_position_after_centering_um']['y'] + target_x = embryo_data["stage_position_after_centering_um"]["x"] + target_y = embryo_data["stage_position_after_centering_um"]["y"] print(f" Moving to embryo position: ({target_x:.2f}, {target_y:.2f}) µm") @@ -94,7 +97,7 @@ def configure_hardware_for_volume(calibration, num_slices): # Stop any existing sequence acquisition (from previous embryo or calibration) if core.isSequenceRunning(): - print(f" Stopping previous sequence...") + print(" Stopping previous sequence...") core.stopSequenceAcquisition() time.sleep(0.5) @@ -105,14 +108,14 @@ def configure_hardware_for_volume(calibration, num_slices): try: core.setProperty(GALVO_DEVICE, "SPIMState", "Idle") time.sleep(0.2) - except: + except Exception: pass # Extract calibration parameters - slope = calibration['slope_um_per_deg'] - offset = calibration['offset_um'] - galvo_top = calibration.get('edge_top_deg', calibration['galvo_top_deg']) - galvo_bottom = 
calibration.get('edge_bottom_deg', calibration['galvo_bottom_deg']) + slope = calibration["slope_um_per_deg"] + offset = calibration["offset_um"] + galvo_top = calibration.get("edge_top_deg", calibration["galvo_top_deg"]) + galvo_bottom = calibration.get("edge_bottom_deg", calibration["galvo_bottom_deg"]) # Calculate galvo parameters galvo_center = (galvo_top + galvo_bottom) / 2.0 @@ -126,8 +129,14 @@ def configure_hardware_for_volume(calibration, num_slices): piezo_range = piezo_bottom - piezo_top piezo_amplitude = piezo_range / 2.0 - print(f" Galvo: center={galvo_center:+.4f}°, amplitude=±{galvo_amplitude:.4f}° (range: {galvo_range:.4f}°)") - print(f" Piezo: center={piezo_center:.1f}µm, amplitude=±{piezo_amplitude:.1f}µm (range: {piezo_range:.1f}µm)") + print( + f" Galvo: center={galvo_center:+.4f}°, amplitude=±{galvo_amplitude:.4f}°" + f" (range: {galvo_range:.4f}°)" + ) + print( + f" Piezo: center={piezo_center:.1f}µm, amplitude=±{piezo_amplitude:.1f}µm" + f" (range: {piezo_range:.1f}µm)" + ) # System startup core.setConfig("System", "Startup") @@ -189,13 +198,13 @@ def configure_hardware_for_volume(calibration, num_slices): core.setProperty(PIEZO_DEVICE, "SPIMState", "Armed") time.sleep(0.3) - print(f" ✓ Hardware configured for hardware-triggered acquisition") + print(" ✓ Hardware configured for hardware-triggered acquisition") return { - 'galvo_center': galvo_center, - 'galvo_amplitude': galvo_amplitude, - 'piezo_center': piezo_center, - 'piezo_amplitude': piezo_amplitude + "galvo_center": galvo_center, + "galvo_amplitude": galvo_amplitude, + "piezo_center": piezo_center, + "piezo_amplitude": piezo_amplitude, } @@ -238,7 +247,7 @@ def acquire_volume_for_embryo(embryo_id, calibration, num_slices=50): # Trigger SPIM state machine core.setProperty(GALVO_DEVICE, "SPIMState", "Running") - print(f" ✓ SPIM triggered") + print(" ✓ SPIM triggered") # Collect images images = [] @@ -296,52 +305,59 @@ def save_volume(volume, embryo_id, embryo_number, output_dir): def 
main(): """Main multi-embryo volume acquisition workflow.""" - print(f"{'='*70}") + print(f"{'=' * 70}") print("MULTI-EMBRYO VOLUME ACQUISITION") - print(f"{'='*70}") + print(f"{'=' * 70}") try: # Load database - print(f"\n{'='*70}") + print(f"\n{'=' * 70}") print("LOADING DATABASE") - print(f"{'='*70}") + print(f"{'=' * 70}") database = load_database() - embryos = database.get('embryos', {}) + embryos = database.get("embryos", {}) num_embryos = len(embryos) print(f" Database: {DATABASE_FILE}") print(f" Found {num_embryos} embryo(s)") if num_embryos == 0: - print(f"\n ⚠ No embryos in database!") - print(f" Run multi_embryo_calibration.py first.") + print("\n ⚠ No embryos in database!") + print(" Run multi_embryo_calibration.py first.") return # List embryos - print(f"\n Embryos:") + print("\n Embryos:") for emb_id, emb_data in embryos.items(): - emb_num = emb_data.get('embryo_number', '?') - pos = emb_data['stage_position_after_centering_um'] + emb_num = emb_data.get("embryo_number", "?") + pos = emb_data["stage_position_after_centering_um"] print(f" {emb_id} (#{emb_num}): ({pos['x']:.1f}, {pos['y']:.1f}) µm") # Acquisition parameters - print(f"\n{'='*70}") + print(f"\n{'=' * 70}") print("ACQUISITION PARAMETERS") - print(f"{'='*70}") + print(f"{'=' * 70}") - num_slices = int(input(f" Number of slices per volume (default 50): ").strip() or "50") + num_slices = int(input(" Number of slices per volume (default 50): ").strip() or "50") print(f" ✓ Will acquire {num_slices} slices per embryo") # Timelapse parameters - num_timepoints = int(input(f" Number of timepoints (default 1 for single acquisition): ").strip() or "1") + num_timepoints = int( + input(" Number of timepoints (default 1 for single acquisition): ").strip() or "1" + ) interval_minutes = 0 if num_timepoints > 1: - interval_minutes = float(input(f" Interval between timepoints in minutes (e.g., 2): ").strip() or "2") + interval_minutes = float( + input(" Interval between timepoints in minutes (e.g., 2): 
").strip() or "2" + ) total_duration_hours = (num_timepoints - 1) * interval_minutes / 60.0 - print(f" ✓ Timelapse: {num_timepoints} timepoints every {interval_minutes} min ({total_duration_hours:.1f} hours total)") + print( + f" ✓ Timelapse: {num_timepoints} timepoints every {interval_minutes} min" + f" ({total_duration_hours:.1f} hours total)" + ) else: - print(f" ✓ Single acquisition (no timelapse)") + print(" ✓ Single acquisition (no timelapse)") # Create output directory session_dir = OUTPUT_DIR / datetime.now().strftime("%Y%m%d_%H%M%S") @@ -358,8 +374,11 @@ def main(): desc="Timepoints", unit="tp", position=0, - colour='green', - bar_format='{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]' + colour="green", + bar_format=( + "{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt}" + " [{elapsed}<{remaining}, {rate_fmt}]" + ), ) for timepoint in range(num_timepoints): @@ -367,13 +386,15 @@ def main(): elapsed_hours = (timepoint_start_time - session_start_time) / 3600.0 # Update timepoint progress bar - timepoint_pbar.set_description(f"Timepoint {timepoint+1}/{num_timepoints} (Elapsed: {elapsed_hours:.1f}h)") + timepoint_pbar.set_description( + f"Timepoint {timepoint + 1}/{num_timepoints} (Elapsed: {elapsed_hours:.1f}h)" + ) - print(f"\n{'='*70}") + print(f"\n{'=' * 70}") print(f"TIMEPOINT {timepoint + 1}/{num_timepoints}") if num_timepoints > 1: print(f"Elapsed: {elapsed_hours:.2f} hours") - print(f"{'='*70}") + print(f"{'=' * 70}") # Acquire volume for each embryo timepoint_results = [] @@ -385,51 +406,53 @@ def main(): unit="embryo", position=1, leave=False, - colour='cyan' + colour="cyan", ) for idx, (emb_id, emb_data) in enumerate(embryos.items(), 1): - emb_num = emb_data.get('embryo_number', idx) + emb_num = emb_data.get("embryo_number", idx) embryo_pbar.set_description(f" Embryo {emb_num} (t{timepoint:04d})") print(f"\n[Embryo {idx}/{num_embryos}] {emb_id} (Embryo #{emb_num})") - print(f"{'─'*70}") + print(f"{'─' * 
70}") # Move to embryo move_to_embryo(emb_data) # Configure hardware - calibration = emb_data['calibration'] + calibration = emb_data["calibration"] configure_hardware_for_volume(calibration, num_slices) # Acquire volume volume = acquire_volume_for_embryo(emb_id, calibration, num_slices) if volume is None: - print(f" ✗ Failed to acquire volume") - timepoint_results.append({ - 'embryo_id': emb_id, - 'timepoint': timepoint, - 'success': False - }) + print(" ✗ Failed to acquire volume") + timepoint_results.append( + {"embryo_id": emb_id, "timepoint": timepoint, "success": False} + ) embryo_pbar.update(1) continue # Save volume with timepoint in filename timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - filename = session_dir / f"{emb_id}_embryo{emb_num:03d}_t{timepoint:04d}_{timestamp}.tif" + filename = ( + session_dir / f"{emb_id}_embryo{emb_num:03d}_t{timepoint:04d}_{timestamp}.tif" + ) tifffile.imwrite(filename, volume) print(f" ✓ Saved: {filename.name}") - timepoint_results.append({ - 'embryo_id': emb_id, - 'embryo_number': emb_num, - 'timepoint': timepoint, - 'success': True, - 'filename': str(filename), - 'shape': volume.shape - }) + timepoint_results.append( + { + "embryo_id": emb_id, + "embryo_number": emb_num, + "timepoint": timepoint, + "success": True, + "filename": str(filename), + "shape": volume.shape, + } + ) print(f" ✓ Complete: {volume.shape}") embryo_pbar.update(1) @@ -447,10 +470,10 @@ def main(): if wait_time > 0: next_timepoint_time = datetime.now() + timedelta(seconds=wait_time) - print(f"\n{'─'*70}") - print(f"Waiting {wait_time/60:.1f} minutes until next timepoint...") + print(f"\n{'─' * 70}") + print(f"Waiting {wait_time / 60:.1f} minutes until next timepoint...") print(f"Next timepoint at: {next_timepoint_time.strftime('%H:%M:%S')}") - print(f"{'─'*70}") + print(f"{'─' * 70}") # Progress bar for waiting wait_pbar = tqdm( @@ -459,7 +482,7 @@ def main(): unit="s", position=1, leave=False, - colour='yellow' + colour="yellow", ) for _ in 
range(int(wait_time)): time.sleep(1) @@ -469,70 +492,81 @@ def main(): # Sleep remaining fractional seconds time.sleep(wait_time - int(wait_time)) else: - print(f"\n{'─'*70}") - print(f"⚠ Warning: Acquisition took {timepoint_duration/60:.1f} min (longer than {interval_minutes} min interval)") - print(f"Proceeding immediately to next timepoint...") - print(f"{'─'*70}") + print(f"\n{'─' * 70}") + print( + f"⚠ Warning: Acquisition took {timepoint_duration / 60:.1f} min" + f" (longer than {interval_minutes} min interval)" + ) + print("Proceeding immediately to next timepoint...") + print(f"{'─' * 70}") timepoint_pbar.close() # Cleanup - print(f"\n{'='*70}") + print(f"\n{'=' * 70}") print("CLEANUP") - print(f"{'='*70}") + print(f"{'=' * 70}") core.setConfig("Laser", "ALL OFF") - print(f" ✓ Lasers OFF") + print(" ✓ Lasers OFF") # Summary - print(f"\n{'='*70}") + print(f"\n{'=' * 70}") print("ACQUISITION COMPLETE") - print(f"{'='*70}") + print(f"{'=' * 70}") total_duration = time.time() - session_start_time - successful = sum(1 for r in all_results if r['success']) + successful = sum(1 for r in all_results if r["success"]) total_acquisitions = num_embryos * num_timepoints - print(f"\n Session duration: {total_duration/3600:.2f} hours") + print(f"\n Session duration: {total_duration / 3600:.2f} hours") print(f" Timepoints: {num_timepoints}") print(f" Embryos per timepoint: {num_embryos}") print(f" Successful acquisitions: {successful}/{total_acquisitions}") print(f" Output directory: {session_dir}") - print(f"\n Results:") + print("\n Results:") for result in all_results: - if result['success']: - t = result.get('timepoint', 0) - print(f" ✓ {result['embryo_id']} t{t:04d}: {result['shape']} → {Path(result['filename']).name}") + if result["success"]: + t = result.get("timepoint", 0) + print( + f" ✓ {result['embryo_id']} t{t:04d}: {result['shape']}" + f" → {Path(result['filename']).name}" + ) else: - t = result.get('timepoint', 0) + t = result.get("timepoint", 0) print(f" ✗ 
{result['embryo_id']} t{t:04d}: Failed") # Save acquisition log log_file = session_dir / "acquisition_log.json" - with open(log_file, 'w') as f: - json.dump({ - 'timestamp': datetime.now().isoformat(), - 'session_duration_hours': total_duration / 3600.0, - 'num_embryos': num_embryos, - 'num_slices': num_slices, - 'num_timepoints': num_timepoints, - 'interval_minutes': interval_minutes, - 'total_acquisitions': total_acquisitions, - 'successful_acquisitions': successful, - 'results': all_results - }, f, indent=2) + with open(log_file, "w") as f: + json.dump( + { + "timestamp": datetime.now().isoformat(), + "session_duration_hours": total_duration / 3600.0, + "num_embryos": num_embryos, + "num_slices": num_slices, + "num_timepoints": num_timepoints, + "interval_minutes": interval_minutes, + "total_acquisitions": total_acquisitions, + "successful_acquisitions": successful, + "results": all_results, + }, + f, + indent=2, + ) print(f"\n ✓ Log saved: {log_file}") - print(f"\n{'='*70}\n") + print(f"\n{'=' * 70}\n") except KeyboardInterrupt: - print(f"\n\nInterrupted\n") + print("\n\nInterrupted\n") except Exception as e: - print(f"\n{'='*70}") + print(f"\n{'=' * 70}") print("ERROR") - print(f"{'='*70}") + print(f"{'=' * 70}") print(f"Error: {e}") import traceback + traceback.print_exc() diff --git a/diagnostics/segment_embryo_nuclei.py b/diagnostics/segment_embryo_nuclei.py index a99538c8..b9c0405f 100644 --- a/diagnostics/segment_embryo_nuclei.py +++ b/diagnostics/segment_embryo_nuclei.py @@ -3,11 +3,12 @@ Run with venv_cv: venv_cv/Scripts/python segment_embryo_nuclei.py """ +from pathlib import Path + +import napari import numpy as np import tifffile -from pathlib import Path from cellpose import models -import napari def load_volume(tiff_path: Path) -> np.ndarray: @@ -39,10 +40,10 @@ def segment_nuclei_3d(volume: np.ndarray, diameter: float = 30.0) -> np.ndarray: masks, flows, styles = model.eval( vol_norm, diameter=diameter, - do_3D=False, # 2D per slice (fast!) 
+ do_3D=False, # 2D per slice (fast!) z_axis=0, stitch_threshold=0.5, # stitch 2D masks into 3D - batch_size=64, # larger batch for speed + batch_size=64, # larger batch for speed ) n_nuclei = len(np.unique(masks)) - 1 @@ -76,6 +77,7 @@ def main(): # Downsample by factor of 2 from scipy.ndimage import zoom + volume = zoom(volume, (1, 0.5, 0.5), order=1) print(f" Downsampled 2x: {volume.shape}") @@ -85,7 +87,7 @@ def main(): # Visualize in Napari print("Opening Napari viewer...") viewer = napari.Viewer() - viewer.add_image(volume, name="Volume", colormap='gray') + viewer.add_image(volume, name="Volume", colormap="gray") viewer.add_labels(masks, name="Nuclei segmentation") napari.run() diff --git a/diagnostics/spim_hardware_triggering_reference.py b/diagnostics/spim_hardware_triggering_reference.py index cab5cfb1..87a9f7a4 100644 --- a/diagnostics/spim_hardware_triggering_reference.py +++ b/diagnostics/spim_hardware_triggering_reference.py @@ -70,9 +70,11 @@ """ import time + import numpy as np from client import get_mmc + def configure_camera_for_hardware_trigger(core, camera_name, exposure_ms): """ Configure Hamamatsu camera for external edge triggering in light sheet mode. @@ -132,8 +134,15 @@ def configure_camera_for_hardware_trigger(core, camera_name, exposure_ms): raise Exception(f"Failed to set TRIGGER ACTIVE to EDGE (got: {trigger_active})") -def configure_spim_scanner(core, scanner_name, num_slices, slice_step_um, - scan_duration_ms, camera_duration_ms, laser_duration_ms): +def configure_spim_scanner( + core, + scanner_name, + num_slices, + slice_step_um, + scan_duration_ms, + camera_duration_ms, + laser_duration_ms, +): """ Configure ASI Tiger scanner for SPIM state machine operation. 
@@ -180,7 +189,7 @@ def configure_spim_scanner(core, scanner_name, num_slices, slice_step_um, core.setProperty(scanner_name, "SingleAxisYPattern", "1 - Triangle") core.setProperty(scanner_name, "SingleAxisYMode", "3 - Enabled with axes synced") - print(f" X-axis (light sheet): Amplitude=2.0°, Pattern=Triangle, Mode=Synced") + print(" X-axis (light sheet): Amplitude=2.0°, Pattern=Triangle, Mode=Synced") print(f" Y-axis (slice step): Amplitude={y_amplitude:.4f}°, Pattern=Triangle, Mode=Synced") print(f" (Calculated for {num_slices} slices × {slice_step_um} μm steps)") @@ -206,7 +215,7 @@ def configure_spim_scanner(core, scanner_name, num_slices, slice_step_um, core.setProperty(scanner_name, "SPIMDelayBeforeScan(ms)", 0.0) core.setProperty(scanner_name, "SPIMDelayBeforeCamera(ms)", 0.5) - print(f" SPIM State Machine:") + print(" SPIM State Machine:") print(f" NumSlices: {num_slices}") print(f" ScanDuration: {scan_duration_ms} ms (total time per slice)") print(f" CameraDuration: {camera_duration_ms} ms (TTL trigger pulse width)") @@ -214,10 +223,16 @@ def configure_spim_scanner(core, scanner_name, num_slices, slice_step_um, # Verify critical timing relationships if camera_duration_ms > scan_duration_ms: - raise Exception(f"CameraDuration ({camera_duration_ms}ms) must be <= ScanDuration ({scan_duration_ms}ms)") + raise Exception( + f"CameraDuration ({camera_duration_ms}ms) must be <=" + f" ScanDuration ({scan_duration_ms}ms)" + ) if laser_duration_ms > camera_duration_ms: - raise Exception(f"LaserDuration ({laser_duration_ms}ms) must be <= CameraDuration ({camera_duration_ms}ms)") + raise Exception( + f"LaserDuration ({laser_duration_ms}ms) must be <=" + f" CameraDuration ({camera_duration_ms}ms)" + ) def arm_spim_state_machine(core, scanner_name): @@ -256,8 +271,9 @@ def trigger_spim_acquisition(core, scanner_name): print(f" SPIMState: {state}") -def acquire_spim_volume(core, camera_name, scanner_name, num_slices, - scan_duration_ms, timeout_extra_sec=5.0): +def 
acquire_spim_volume( + core, camera_name, scanner_name, num_slices, scan_duration_ms, timeout_extra_sec=5.0 +): """ Perform hardware-triggered SPIM volume acquisition. @@ -337,7 +353,10 @@ def acquire_spim_volume(core, camera_name, scanner_name, num_slices, count = core.getRemainingImageCount() seq_running = core.isSequenceRunning(camera_name) spim_state = core.getProperty(scanner_name, "SPIMState") - print(f" t={elapsed:.1f}s: images={count}/{num_slices}, seq={seq_running}, SPIM={spim_state}") + print( + f" t={elapsed:.1f}s: images={count}/{num_slices}," + f" seq={seq_running}, SPIM={spim_state}" + ) last_print_time = time.time() time.sleep(0.01) @@ -350,13 +369,16 @@ def acquire_spim_volume(core, camera_name, scanner_name, num_slices, print(" Retrieving images from buffer...") import rpyc + images = [] for i in range(count): img = core.popNextImage() img = rpyc.classic.obtain(img) # Transfer from remote to local images.append(img) - print(f" Image {i+1}/{count}: shape={img.shape}, dtype={img.dtype}, " - f"range=[{img.min()}, {img.max()}], mean={img.mean():.1f}") + print( + f" Image {i + 1}/{count}: shape={img.shape}, dtype={img.dtype}, " + f"range=[{img.min()}, {img.max()}], mean={img.mean():.1f}" + ) # Convert to 3D numpy array volume = np.array(images) @@ -380,9 +402,9 @@ def main(): camera_duration_ms = 155.0 # TTL pulse width (should be ~= exposure) laser_duration_ms = 154.0 # Laser on time (slightly less than camera) - print("="*80) + print("=" * 80) print("ASI diSPIM HARDWARE-TRIGGERED VOLUME ACQUISITION") - print("="*80) + print("=" * 80) try: # Apply system configuration @@ -413,8 +435,13 @@ def main(): # Configure SPIM scanner print() configure_spim_scanner( - core, scanner_name, num_slices, slice_step_um, - scan_duration_ms, camera_duration_ms, laser_duration_ms + core, + scanner_name, + num_slices, + slice_step_um, + scan_duration_ms, + camera_duration_ms, + laser_duration_ms, ) # Arm SPIM state machine @@ -423,43 +450,50 @@ def main(): # Acquire 
volume print() - volume = acquire_spim_volume( - core, camera_name, scanner_name, num_slices, scan_duration_ms - ) + volume = acquire_spim_volume(core, camera_name, scanner_name, num_slices, scan_duration_ms) # Save volume print("\nSaving volume...") from PIL import Image as PILImage + img_list = [PILImage.fromarray(img.astype(np.uint16)) for img in volume] - img_list[0].save('spim_hardware_triggered_volume.tif', - save_all=True, append_images=img_list[1:]) + img_list[0].save( + "spim_hardware_triggered_volume.tif", + save_all=True, + append_images=img_list[1:], + ) print(f" Saved {len(volume)}-slice volume to: spim_hardware_triggered_volume.tif") print(f" Volume shape: {volume.shape} (slices, height, width)") # Display in napari (optional) try: import napari + print("\nDisplaying in napari...") viewer = napari.Viewer() - viewer.add_image(volume, name='SPIM Volume', colormap='gray', - contrast_limits=[np.percentile(volume, 1), - np.percentile(volume, 99)]) - viewer.dims.axis_labels = ['Z', 'Y', 'X'] + viewer.add_image( + volume, + name="SPIM Volume", + colormap="gray", + contrast_limits=[np.percentile(volume, 1), np.percentile(volume, 99)], + ) + viewer.dims.axis_labels = ["Z", "Y", "X"] print(" Close napari window to continue...") napari.run() except ImportError: print(" (napari not available, skipping visualization)") - print("\n" + "="*80) + print("\n" + "=" * 80) print("ACQUISITION COMPLETE!") - print("="*80) + print("=" * 80) except Exception as e: - print("\n" + "="*80) + print("\n" + "=" * 80) print("ACQUISITION FAILED") - print("="*80) + print("=" * 80) print(f"Error: {e}") import traceback + traceback.print_exc() finally: @@ -469,26 +503,26 @@ def main(): if core.isSequenceRunning(camera_name): core.stopSequenceAcquisition(camera_name) print(" Stopped camera sequence") - except: + except Exception: pass try: core.setProperty(scanner_name, "SPIMState", "Idle") print(" Reset SPIM to Idle") - except: + except Exception: pass try: # Reset camera to internal 
triggering for live mode core.setProperty(camera_name, "TRIGGER SOURCE", "INTERNAL") print(" Reset camera to internal triggering") - except: + except Exception: pass try: core.setConfig("Laser", "ALL OFF") print(" Lasers OFF") - except: + except Exception: pass diff --git a/diagnostics/switchbot_webgui.py b/diagnostics/switchbot_webgui.py new file mode 100644 index 00000000..5df81015 --- /dev/null +++ b/diagnostics/switchbot_webgui.py @@ -0,0 +1,379 @@ +#!/usr/bin/env python3 +""" +Temporary web GUI to play with the SwitchBot Bot that switches the diSPIM room +light (on for bottom-camera/brightfield imaging, off otherwise). + +This is a TEST TOOL, not part of the production device layer. It drives the Bot +directly over BLE using the same command protocol as +``gently.hardware.switchbot.SwitchBot`` (same command bytes + GATT UUIDs), but +over a single *persistent* connection so the buttons feel snappy and the morse +blinker is fast — the device-layer class is connect-per-command (~1-2 s each), +which is fine for a plan step but hopeless for blinking. + +Features: ON / OFF / PRESS buttons, and a morse-code blinker (blinks the real +room light + mirrors the pattern on screen). The Bot is a mechanical switch +pusher, so each toggle is a ~0.5-1 s servo move — morse is deliberately slow. + +Run: + .venv/bin/python diagnostics/switchbot_webgui.py + # then open http://127.0.0.1:8765 + + .venv/bin/python diagnostics/switchbot_webgui.py --address EC:6F:04:06:5B:23 --port 8765 +""" + +from __future__ import annotations + +import argparse +import asyncio +import logging +from contextlib import asynccontextmanager + +import uvicorn +from fastapi import FastAPI +from fastapi.responses import HTMLResponse, JSONResponse +from pydantic import BaseModel + +# Reuse the device-layer device's protocol definitions (single source of truth). 
+from gently.hardware.switchbot import _COMMANDS, _CTRL_CHAR + +logger = logging.getLogger("switchbot_webgui") + +DEFAULT_ADDRESS = "EC:6F:04:06:5B:23" + +# ITU morse, letters + digits. Unsupported characters are skipped. +MORSE = { + "A": ".-", + "B": "-...", + "C": "-.-.", + "D": "-..", + "E": ".", + "F": "..-.", + "G": "--.", + "H": "....", + "I": "..", + "J": ".---", + "K": "-.-", + "L": ".-..", + "M": "--", + "N": "-.", + "O": "---", + "P": ".--.", + "Q": "--.-", + "R": ".-.", + "S": "...", + "T": "-", + "U": "..-", + "V": "...-", + "W": ".--", + "X": "-..-", + "Y": "-.--", + "Z": "--..", + "0": "-----", + "1": ".----", + "2": "..---", + "3": "...--", + "4": "....-", + "5": ".....", + "6": "-....", + "7": "--...", + "8": "---..", + "9": "----.", +} + + +class Bot: + """A single persistent BLE connection to the Bot, with serialized access.""" + + def __init__(self, address: str): + self.address = address + self._client = None + self._lock = asyncio.Lock() + self._morse_task: asyncio.Task | None = None + self.state = "unknown" + self.busy = False + + async def _ensure(self): + from bleak import BleakClient + + if self._client is not None and self._client.is_connected: + return + self._client = BleakClient(self.address, timeout=20) + await self._client.connect() + logger.info("connected to %s", self.address) + + async def _write(self, action: str): + """Write one command, reconnecting once if the link dropped.""" + from bleak.exc import BleakError + + for attempt in (1, 2): + try: + await self._ensure() + await self._client.write_gatt_char(_CTRL_CHAR, _COMMANDS[action], response=True) + if action in ("on", "off"): + self.state = action + return + except (BleakError, OSError, asyncio.TimeoutError) as exc: + logger.warning("write %s attempt %d failed: %s", action, attempt, exc) + self._client = None # force reconnect + if attempt == 2: + raise + + async def _cancel_morse(self): + task = self._morse_task + if task and not task.done(): + task.cancel() + await 
asyncio.gather(task, return_exceptions=True) + self._morse_task = None + + async def command(self, action: str) -> str: + """ON/OFF/PRESS. Interrupts any running morse (manual override).""" + await self._cancel_morse() + async with self._lock: + await self._write(action) + return self.state + + def schedule(self, text: str, unit: float): + """Build an on/off timeline [(state, seconds), ...] for a message.""" + seq = [("off", round(unit, 3))] # settle to a known state first + for ch in text.upper(): + if ch == " ": + seq.append(("off", round(unit * 7, 3))) + continue + code = MORSE.get(ch) + if not code: + continue + for sym in code: + seq.append(("on", round(unit * (3 if sym == "-" else 1), 3))) + seq.append(("off", round(unit, 3))) # intra-letter gap + st, _ = seq[-1] + seq[-1] = (st, round(unit * 3, 3)) # upgrade to inter-letter gap + return seq + + async def start_morse(self, text: str, unit: float): + await self._cancel_morse() + seq = self.schedule(text, unit) + if len(seq) <= 1: + return None + restore = self.state + self._morse_task = asyncio.create_task(self._play(seq, restore)) + return seq + + async def _play(self, seq, restore: str): + async with self._lock: + self.busy = True + try: + for state, dur in seq: + await self._write(state) + await asyncio.sleep(dur) + await self._write(restore if restore in ("on", "off") else "off") + finally: + self.busy = False + + async def stop(self): + await self._cancel_morse() + async with self._lock: + await self._write("off") + return self.state + + async def disconnect(self): + await self._cancel_morse() + if self._client is not None and self._client.is_connected: + await self._client.disconnect() + + +BOT: Bot | None = None + + +@asynccontextmanager +async def lifespan(app: FastAPI): + yield + if BOT is not None: + await BOT.disconnect() + + +app = FastAPI(lifespan=lifespan) + + +class MorseReq(BaseModel): + text: str + unit: float = 1.5 + + +@app.get("/", response_class=HTMLResponse) +async def index(): + return 
PAGE.replace("__ADDRESS__", BOT.address) + + +@app.get("/status") +async def status(): + return {"state": BOT.state, "busy": BOT.busy, "address": BOT.address} + + +@app.post("/cmd/{action}") +async def cmd(action: str): + if action not in _COMMANDS: + return JSONResponse({"error": f"unknown action {action!r}"}, status_code=400) + try: + state = await BOT.command(action) + except Exception as exc: + return JSONResponse({"error": str(exc)}, status_code=502) + return {"state": state} + + +@app.post("/morse") +async def morse(req: MorseReq): + unit = max(0.3, min(4.0, req.unit)) + text = req.text[:40] + try: + seq = await BOT.start_morse(text, unit) + except Exception as exc: + return JSONResponse({"error": str(exc)}, status_code=502) + if seq is None: + return JSONResponse({"error": "nothing sendable in that text"}, status_code=400) + seconds = round(sum(d for _, d in seq), 1) + return {"schedule": seq, "unit": unit, "seconds": seconds} + + +@app.post("/stop") +async def stop(): + try: + state = await BOT.stop() + except Exception as exc: + return JSONResponse({"error": str(exc)}, status_code=502) + return {"state": state} + + +PAGE = """ + + +diSPIM Room Light + +
+[page markup not recoverable: heading "diSPIM Room Light", subtitle "SwitchBot Bot · __ADDRESS__", a control labelled "fast" / "slow" with a "0.7s" readout]
+ +""" + + +def main(): + ap = argparse.ArgumentParser(description="Temporary SwitchBot room-light web GUI") + ap.add_argument("--address", default=DEFAULT_ADDRESS, help="Bot BLE MAC address") + ap.add_argument("--port", type=int, default=8765) + ap.add_argument("--host", default="127.0.0.1", help="bind host (default: localhost only)") + args = ap.parse_args() + + logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") + global BOT + BOT = Bot(args.address) + print(f"\n diSPIM Room Light GUI → http://{args.host}:{args.port}\n Bot: {args.address}\n") + uvicorn.run(app, host=args.host, port=args.port, log_level="warning") + + +if __name__ == "__main__": + main() diff --git a/docs/CLOSED_LOOP_PARADIGM.md b/docs/CLOSED_LOOP_PARADIGM.md new file mode 100644 index 00000000..8f1cf80d --- /dev/null +++ b/docs/CLOSED_LOOP_PARADIGM.md @@ -0,0 +1,536 @@ +# Closed-Loop Paradigm: Notes on the Discussion + +This document captures the design conversation that produced everything on the +`paradigm/closed-loop` branch: the schema split, the Map-as-embryo-home work, +the operator-action vocabulary, the eval substrate (capture / replay / +decisions / shadow), and the trajectory the system is on. It is a +distillation, not a transcript — a future-self / new-collaborator reference +for *why* this code looks the way it does and *where it is going*. + +--- + +## 1. The Original Friction + +The conversation started from a small, concrete observation by the operator: + +> "It feels awkward that the operator has to go between the chat in the TUI +> and the viz server… or even to chat about detecting embryos." + +That awkwardness is a symptom, not a defect. It surfaces a deeper design +question: **what is the orchestrator (the agent) actually for?** Today the +orchestrator does at least four jobs at once, and one of them — *tool router* +— is the one creating the friction. + +### The four orchestrator roles + +| Role | Description | Replaceable by a button? 
| +| --- | --- | --- | +| 1. Tool router | "Detect embryos" → `detect_embryos()` call | **Yes** — this is the friction surface | +| 2. Workflow runner | Timelapses, multi-embryo plans, perception loops | No | +| 3. Domain reasoner | Knows microscopy, embryos, safety constraints | No | +| 4. Session memory | Coherent narrative of what happened and why | No | + +Routing a single click through chat for a routine action is the system +fighting against its own users. Routing a multi-step scientific decision +through Claude is using the right tool for the right job. The paradigm here +is: **shrink role 1 to its essentials, keep roles 2–4 first-class, and let +the UI carry the rest.** + +--- + +## 2. Web ↔ Chat Reconciliation Patterns + +Four ways to relate the web UI and the chat orchestrator. Each has a +distinct world model property: + +### A. Chat-only intent (the old default) + +Every action originates in chat. The web is observation + delegated subtasks +(e.g. the marking canvas is a delegation the orchestrator triggers). + +* Cleanest record. +* Worst friction. +* Orchestrator's world model is "complete" because every change passes + through it. + +### B. Two parallel command surfaces + +Operator clicks in web, web acts directly; orchestrator finds out by polling +state or doesn't find out at all. + +* Lowest friction. +* Orchestrator's world model **drifts from reality** — fatal for role 4 + (session memory) and dangerous for role 3 (safety reasoning). + +### C. Web acts, orchestrator subscribes *(the chosen direction)* + +Operator clicks → web performs the action **and** publishes an event +(`OPERATOR_*`) → orchestrator's session memory ingests it. + +* Chat log shows only human conversation. +* Orchestrator's working context shows chat + events as a single timeline. +* Phase 7 (operator events vocabulary + reactive candidate) is the first + installment of this pattern. + +### D. Cross-pattern hybrid + +Different action classes use different patterns. 
Heavy / novel / composite +actions use chat (A); routine / clickable / contextual actions use web (C). +This is what the system actually drifts toward; pattern C is the substrate +that makes it possible. + +The orchestrator's job shifts from being **a funnel for action** to being **a +brain that knows what's happening on every surface**. + +--- + +## 3. The "Turn" is Wrong; the "Decision Moment" is Right + +Chat-AI literature reasons in *turns* (user message → assistant response). +That model imports an assumption that does not hold here: the human is at +the keyboard continuously. In a microscopy experiment running 12+ hours, the +human checks in once, twice, maybe ten times. The agent is autonomous in +between. + +The right unit is a **decision moment**, triggered by: + +1. **User message** — rare, interrupting (classic chat turn). +2. **Critical event** — error, safety violation, lost focus, perception + anomaly. Wake immediately; decide to act / abort / escalate. +3. **Phase boundary** — between timepoints, between embryos. Built-in + checkpoint: review state, decide whether to continue. +4. **Periodic checkpoint** — every N minutes if nothing else happened. + Catches slow drifts. + +Between moments the agent is asleep. Plans execute autonomously. Events +accumulate on the bus and in the world model. When the next decision +moment fires, the agent reads: + +* The trigger (why am I waking up?) +* The world snapshot (NOW state) +* The events digest (what happened since last wake) +* The conversation history (which might be hours old and less relevant + than usual) + +This is closer to a **supervisory controller** than a chat partner. The +conversation history matters less than usual; what matters more is the +**flight log** (events) plus the **current state snapshot**. 
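The four triggers and the wake-time reading list above can be pictured as one small data structure. The following is a minimal sketch, not code from the branch: `Trigger`, `WakeContext`, and `render_wake_note` are hypothetical names, and the snapshot keys in the usage note are invented for illustration.

```python
from dataclasses import dataclass, field
from enum import Enum


class Trigger(Enum):
    """Hypothetical labels for the four decision-moment triggers."""

    USER_MESSAGE = "user_message"
    CRITICAL_EVENT = "critical_event"
    PHASE_BOUNDARY = "phase_boundary"
    PERIODIC_CHECKPOINT = "periodic_checkpoint"


@dataclass
class WakeContext:
    """What the agent reads when a decision moment fires (assumed shape)."""

    trigger: Trigger                      # why am I waking up?
    world_snapshot: dict                  # NOW state
    events_digest: list[str] = field(default_factory=list)      # since last wake
    conversation_tail: list[str] = field(default_factory=list)  # possibly hours old


def render_wake_note(ctx: WakeContext) -> str:
    """Format the trigger, snapshot, and digest as a plain-text note."""
    lines = [f"Wake reason: {ctx.trigger.value}", "World snapshot:"]
    for key in sorted(ctx.world_snapshot):
        lines.append(f"  {key}: {ctx.world_snapshot[key]}")
    lines.append("Since last wake:")
    # An empty digest is stated explicitly rather than silently omitted.
    lines.extend(f"  - {event}" for event in ctx.events_digest or ["(no events)"])
    return "\n".join(lines)
```

A phase-boundary wake might then be rendered with `render_wake_note(WakeContext(Trigger.PHASE_BOUNDARY, {"plan": "timelapse"}, ["calibration completed for embryo 2"]))`. Keeping the trigger as the first line reflects the supervisory framing: the reason for waking matters more than the chat history.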
+ +### Trigger model — concrete + +A small router (in code, not Claude) sits between the bus and the brain: + +``` + user input ─┐ + event bus ──┼─► wake-router ──► (compose context) ──► claude.messages.create + schedule ──┘ +``` + +The router's responsibilities: +* Subscribe to a whitelist of "wake-worthy" event types. +* Hold a debounce / coalescing buffer (so a burst of events becomes one + wake). +* Keep a heartbeat schedule (every N minutes if no other trigger fired). +* On wake, package: trigger, world snapshot, events digest, recent + conversation tail. +* Surface the package to the brain. + +The brain stays the brain (Claude). The router is cheap, deterministic, +debuggable code. It's the **meta-orchestrator** the operator mentioned — +**not as another LLM**, but as a control surface. + +### Phase boundaries: hand-back vs subscribe + +Two designs for letting the brain look in mid-plan: + +* **Plan hands control back** at well-known points (between embryos, every + 5 timepoints). Cheaper, predictable, slightly less reactive. +* **Plan never pauses; brain subscribes to plan events** ("perception + complete for embryo 3"). More reactive, more plan-coupling. + +The first one composes better with the supervisory-controller framing and +is the recommended starting point. + +### Idle ticks + +If 30 min pass with no event and no user, should the agent wake to verify +everything's OK? Default to **yes — periodic ticks with a high action +threshold.** Most ticks should result in the agent doing nothing. The +purpose is catching slow drifts (focus, sample state, hardware +degradation) that don't trigger their own events. + +--- + +## 4. World Model — Tiered, Not Monolithic + +A common mistake is "summarise everything every turn." Better is a tiered +model where different freshness/density tiers carry different cadence +costs. + +### Tier 1 — World snapshot + +Structured, ~30 lines, computed from in-memory state (not events), every +wake. 
+ +Includes: live stage XY/Z, current session id, embryo list with +calibration state, current plan, acquisition status, recent operator +actions (one-line summary). + +Cheap to build, fresh every time. Already mostly present in the +codebase — `agent.experiment.get_summary()` plus the cached +`DEVICE_STATE_UPDATE` payload is 80% of this. + +### Tier 2 — Recent-events digest + +Hand-written formatter over the events bus, filtered to wake-worthy types, +inserted as a system note at each wake. + +Shape: `"Since last response: operator added embryo 4 via Map at 14:32; calibration completed for embryo 2; one perception trace pending."` + +Hand-written because LLM summarisation here adds latency, cost, and +non-determinism for low value. Events are already structured. + +### Tier 3 — Pull tools + +For when reasoning needs depth: `get_recent_perceptions(embryo_id, n=5)`, +`get_session_timeline()`, `get_learnings(campaign_id)`, etc. The agent +calls these when it wants the detail. + +### Tier 4 — Optional LLM summariser + +Reserved for genuinely natural-language streams that resist rule-based +compression: accumulated CV reasoning chains, narrative observations, +cross-session learnings. Use a smaller, faster Claude model (Haiku is the +natural fit). Run lazily, when a tier-3 tool asks for "summarise the last +30 min for embryo 3." + +### Why this shape + +Decision moments are **rare** in autonomous mode. Token budget per wake +can be generous (it's mostly idle compute). What matters more than budget +is **cadence of waking** — saving 200 tokens per turn doesn't help if +you're waking up at the wrong moments. + +--- + +## 5. Testing — Where Most Projects Fail + +You cannot iterate on agent architecture without a way to compare +architectures. Microscopy makes this hard: + +* Physical, non-deterministic, non-replayable in the trivial sense. +* "Correct" is fuzzy — biological judgements rarely have ground truth. +* Slow feedback (a timelapse takes hours). 
+* Can't always reset to a clean state (samples are consumed). + +Five testing primitives, ranked by payoff per unit work (this ordering +informed Phase 6's build order): + +### 5.1 Event replay *(built — Phase 6a/6b)* + +Capture the full event stream during a real run. Offline, replay it +through any candidate architecture. Diff its decisions against +production's. **Foundation** — without it, every change to the +orchestrator is a flight test. + +### 5.2 Shadow mode *(built — Phase 6d)* + +During a real experiment, candidate architectures run alongside +production. They see the same events but their decisions are *logged, +not enacted*. Unique value over pure replay: shadow agents experience +real temporal cadence, so timing-sensitive things (drift, races) are +caught. + +### 5.3 Synthetic event sequences + +Hand-crafted streams: cascading errors, ambiguous perception, +contradictory operator actions, focus drift, network drop mid-acquisition. +Stress / chaos testing. The orchestrator is correct if it doesn't do +something catastrophic — much easier to score than biological +correctness. + +Trivially built on top of 5.1 — write a `jsonl` by hand, replay it. + +### 5.4 Decision-level micro-benchmarks + +Specific judgements — "given this perception result and these recent +observations, should the agent re-focus?" — captured as +(input → expected decision) pairs labelled by a biologist. Regression +suite. Cheap with biologist time, expensive to bootstrap, very valuable +once you have a few hundred. + +### 5.5 Multi-agent A/B in production + +Two embryos in the same dish, one supervised by architecture A and one +by B (both honouring the firmware fence). Compare biological outcomes. +Slow (one timelapse per data point), but the **only thing that measures +biological correctness end-to-end.** + +--- + +## 6. Embryo Schema: Coarse vs Fine + +Foundational and quietly important. 
Each embryo carries: + +* `position_coarse` — set by bottom-camera detection or manual Map + placement. Always present. +* `position_fine` — set later by SPIM-objective alignment (workflow not + yet built). Initially `{}`. +* `stage_position` — a *derived property*: `fine if fine else coarse`. + Downstream motion / perception keeps reading this and stays agnostic + about which calibration stage we're in. + +This is the seed for a broader idea: **measurements have provenance and +calibration state**. The same embryo at the same nominal XY can have +different "true" positions depending on which sensor sighted it. Encode +that explicitly so any downstream decision can ask *"how confident is +this position?"* without needing to know the whole calibration history. + +When the operator drags an embryo on the Map, the PUT clears `fine` — +overriding the sighting invalidates any SPIM-derived fine alignment +derived from the old coarse. `OPERATOR_EDITED_EMBRYO` carries +`fine_position_invalidated` so the candidate / future controller can +schedule a re-alignment without inferring it. + +--- + +## 7. The Map as Collaborative World Model + +The Devices > Map page is more than visualisation. It is the **first +collaborative surface** between operator and orchestrator: both can read +the embryo list; both can mutate it. The orchestrator subscribes; the +operator clicks. + +Visual semantics matter: + +* Coarse-only embryo → outlined ring + number. *Provisional.* +* SPIM-fine-calibrated → filled disc + number. *Committed.* + +Calibration state is then visible at a glance across the slide — the +operator can scan and see "embryo 3 still needs alignment" without +opening anything. + +The pick-up / drop interaction (Phase 5) deliberately rejects +click-to-add: the Map is a schematic, not a satellite view. Adding a +new sighting without a visual reference is guessing. New embryos go +through the bottom-camera marking canvas. The Map is for **editing what +already exists**. 
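Editing in place ties back to the coarse/fine schema in §6: a drop on the Map rewrites the coarse position and clears the fine one. The sketch below illustrates that behaviour under stated assumptions. The `Embryo` class, the `operator_move` method, and the event payload keys other than `fine_position_invalidated` are hypothetical, and the real classes on the branch may be shaped differently.

```python
from dataclasses import dataclass, field


@dataclass
class Embryo:
    """Illustrative embryo record with coarse/fine positions (assumed shape)."""

    embryo_id: int
    position_coarse: dict                               # always present
    position_fine: dict = field(default_factory=dict)   # initially {}

    @property
    def stage_position(self) -> dict:
        # Derived: fine if fine else coarse. Callers stay agnostic
        # about which calibration stage the embryo is in.
        return self.position_fine if self.position_fine else self.position_coarse

    def operator_move(self, new_coarse: dict) -> dict:
        """Apply a Map drag: overwrite coarse, clear fine, return an event payload."""
        invalidated = bool(self.position_fine)
        self.position_coarse = dict(new_coarse)
        self.position_fine = {}  # the old fine alignment no longer applies
        return {
            "type": "OPERATOR_EDITED_EMBRYO",
            "embryo_id": self.embryo_id,            # hypothetical payload key
            "fine_position_invalidated": invalidated,
        }
```

Because `stage_position` is a property rather than a stored field, it cannot drift out of sync with the two underlying positions, which is the point of treating it as derived.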
+ +### Future arc + +* **Annotations beyond position**: operator marks "this is the control", + "this one is dead, skip", "I think this is in 2-cell stage". These + become first-class scientific observations through additional + `OPERATOR_*` events. +* **Satellite tile**: render the live bottom-camera frame as an overlay + on the Map at the current stage XY, scaled by um_per_pixel. Inside + that tile, click-to-add becomes meaningful (you can see what you're + picking). Outside, the Map stays schematic. + +--- + +## 8. Revolutionary Trajectories + +Some of these are reasonable extensions; some are genuinely new. + +### 8.1 Plans-as-goals, not scripts + +Operator specifies "characterise gut development for these four +embryos." Orchestrator translates this into a continuously adapted +imaging plan that changes based on what perception sees mid-run. The +plan isn't a fixed script handed to Bluesky — it's a negotiation the +orchestrator keeps in flight, with the world model as the substrate +for adaptation. + +Requires: tier-1 + tier-2 world model, decent perception loop, a way +to express goals as predicates over the world model. + +### 8.2 Compounding cross-session learning + +`agent/learnings/` already exists. Today it's barely used. With replay ++ shadow, an architecture that proposes priors ("embryos at 3-fold +typically need slower piezo") becomes **A/B testable across sessions**. +Improvement gets *measurable*, which is the unlock — most "smart +microscopy" today is shallow because it has no measurement loop. + +The right framing: each session is a **trial**, the orchestrator is the +**experimenter**, the world model is what carries learning between +trials. + +### 8.3 Collaborative world model + +The Map (operator edits embryos) is the first instance. Extend +everywhere: + +* Operator annotates morphology on the Map → orchestrator updates + hypothesis space. +* Operator marks a focus failure → orchestrator marks the calibration + region as untrustworthy. 
+* Operator confirms a perception → orchestrator increases confidence in + the perception predicate for similar inputs. + +The point is making the operator's tacit knowledge **first-class data** +that the system can reason about, not just record. + +### 8.4 Reverse-mode microscopy + +"I want to know X — plan the imaging that answers X." The orchestrator +translates scientific goals into imaging plans. This is the +plans-as-goals idea taken to its conclusion: the operator describes +intent in scientific terms, the orchestrator owns the imaging strategy. + +Tractable only once 8.1 and the goal language are built. + +### 8.5 Continuous shadow / always-on critic + +Run the production orchestrator + a shadow candidate continuously, and +log all decision divergences. Over weeks, the divergence log becomes a +**dataset of disagreements**. Each disagreement is either: + +* Production was right, candidate was wrong → candidate needs a fix. +* Candidate was right, production was wrong → consider promotion or + investigate why production picked differently. +* Both were defensible → annotate the case. + +Free with the eval substrate (Phase 6); the only addition is a +divergence collator. + +--- + +## 9. Concretely Built Today (`paradigm/closed-loop` branch) + +| # | Commit | What | +| --- | --- | --- | +| 1 | `3e410581` | Schema split: `position_coarse` / `position_fine` / derived `stage_position`. | +| 2 | `617e54c9` | `ExperimentState.notify_embryos_changed()` observer → `EMBRYOS_UPDATE` broadcast. | +| 3 | `144d9fc9` | Map render layer — lavender rings (coarse) / discs (fine) / numbers. | +| 4 | `4fbb9edf` | `detect_embryos` flows through web Marking canvas; `auth.py` + `require_control`. | +| 5 | `8f6553e1` | Map pick-up / drop / Delete to edit embryos in place (control-gated PUT/DELETE). | +| 6 | `808fe813` | Side-fix: re-enable XY joystick at device-layer boot. | +| 7 | `f7a13d69` | Side-fix: image-anchored crosshair + scroll-to-zoom in camera panel. 
| +| 8 | `d69cc219` | `gently/eval/`: event capture / replay / shadow / decision log scaffolding. | +| 9 | `75d7c9db` | Production decision capture wired through `ConversationManager.call_claude`. | +| 10 | `0a97563e` | `OPERATOR_*` event vocabulary + `ReactiveCandidate` (first real shadow). | + +### Per-session disk shape now + +`D:\Gently3\sessions\{id}\` + +* `events.jsonl` — captured event bus, telemetry-filtered. +* `decisions.jsonl` — every Claude turn (success + error). +* `interaction_log.jsonl` — pre-existing chat-shaped interactions. +* `timeline.jsonl` — pre-existing session timeline. +* Plus everything from the legacy FileStore layout. + +### Eval CLI + +`python scripts/replay_session.py {session_id_prefix} [--histogram] [--candidate {name}] [--real-time] [--time-scale N]` + +--- + +## 10. What is *Not* Done Yet + +These are the natural follow-ups; sketched as future-self breadcrumbs. + +### Near-term (days) + +* **Tier-1 world snapshot** as a system-prompt section the brain sees + on every wake. Build the snapshot from `agent.experiment` plus the + last cached `DEVICE_STATE_UPDATE`. ~30 lines of formatted prose, every + wake. +* **Tier-2 events digest** — hand-written formatter that reads the + bus's recent meaningful events (or the captured jsonl tail) and + produces a one-paragraph "since last response" note. +* **Snapshot ingest into the brain's prompt** — `_update_system_prompt` + already takes a `context_summary`; route tier-1 + tier-2 through it. + +### Medium-term (weeks) + +* **Wake-router** — the code-level scheduler from §3. Currently the + brain only wakes on user message. Add: event-driven wake (subscribe + to wake-worthy events), periodic-tick wake (heartbeat), debounce / + coalesce buffer. +* **More operator events** — `OPERATOR_ANNOTATED_EMBRYO` ("this is the + control", "skip, looks dead"), `OPERATOR_STARTED_TIMELAPSE`, + `OPERATOR_INTERRUPTED_PLAN`, `OPERATOR_TOGGLED_CAMERA`. 
Whatever the + Map / web UI lets the operator do should publish a typed event. +* **SPIM-fine alignment workflow** — populate `position_fine`. Tool + + per-embryo state transition. Triggers `EMBRYOS_UPDATE` and a new + `FINE_ALIGNMENT_COMPLETED` event the orchestrator can react to. +* **Continuous-shadow harness** — extend `ShadowRunner` to run a + candidate alongside production in the live agent process (not just + during replay). Collect divergences into a per-session + `divergences.jsonl`. + +### Longer arc (months) + +* **A goal expression language** — predicates over the world model that + let the operator say "image until 4-fold" or "follow the cell + divisions in embryo 3 at high resolution." This is the substrate for + §8.1 (plans-as-goals). +* **LLM-driven candidates** — once the rule-based `ReactiveCandidate` + proves the substrate, add Claude-driven candidates (Haiku for cheap, + Opus for thinking). Use the snapshot+digest as their input. +* **Cross-session learning loop** — wire the `learnings/` store into + the world model as priors. Add a learning-write surface (a tool the + orchestrator can call when it notices a pattern). Use shadow A/B to + validate that learnings improve decisions. +* **Goal-driven evaluation** — once goals exist, "did the experiment + achieve its goal" becomes a measurable end-to-end success rate. The + ultimate metric is this, not turn-level decision diffs. + +--- + +## 11. Principles That Surface Throughout + +A few recurring design priors worth naming: + +1. **Distill, don't dump.** Structured summaries beat raw logs in + prompts. Hand-written formatters beat LLM summarisers for + structured data. LLMs for prose, code for structure. +2. **Pull beats push when uncertain.** Default to tools the agent + queries on demand, not data shoved into every prompt. Push only + what's universally relevant (the world snapshot). +3. 
**Same shape for production and shadow.** If production writes a + Decision with these fields, shadow candidates write Decisions with + the same fields. Diff is then trivial. +4. **Events carry intent; state carries position.** `EMBRYOS_UPDATE` + is state (the embryo list now). `OPERATOR_EDITED_EMBRYO` is intent + (a human just did this). Both exist; they answer different + questions. +5. **The brain doesn't move hardware.** All hardware action goes + through tools that go through the device layer that goes through + ophyd that goes through MMCore. Shadow candidates are constructively + prevented from acting. Layers are not negotiable. +6. **No SaveCardSettings.** Firmware persistent state silently inherits + between sessions; if it ever gets out of sync with code it's a + debugging nightmare. Apply firmware config every boot, code wins. +7. **Localhost is the diSPIM box. Remote is view-only by default.** + Auth surface stays tiny and explicit. Token upgrade is the seam, + not user accounts. + +--- + +## 12. Open Questions (Worth Revisiting Later) + +* **Continuous vs episodic shadows.** Continuous always-on shadow + captures divergence over time but multiplies cost (multiple LLM + candidates running). Episodic shadow at decision moments is cheaper + but misses timing-sensitive cases. Hybrid? +* **Is the conversation history the right substrate at all?** With + decision moments hours apart, prior chat may be more distracting + than useful. Maybe the brain shouldn't see chat history beyond N + hours; the world model + events digest are the durable memory and + chat is just for the active dialogue. +* **How much should the operator know about the orchestrator's plan?** + Today the operator drives by asking. With autonomous mode, the + orchestrator runs experiments largely on its own. Should there be a + permanent "what is the orchestrator thinking right now" surface + visible on the Map? An always-on intent display? 
+* **Failure semantics.** If a candidate would have made a different + decision than production, and production's decision led to a bad + outcome, the candidate "wins." How do we score? Define "bad outcome" + rigorously enough that it can be measured? + +These are not blockers. They are notes for the next iteration of this +document, after a few weeks of running on the substrate built here. diff --git a/docs/EVAL.md b/docs/EVAL.md new file mode 100644 index 00000000..73f35b69 --- /dev/null +++ b/docs/EVAL.md @@ -0,0 +1,187 @@ + + +> **Status:** design + intended usage for the `gently/eval/` capture/replay substrate and the +> proposed offline replay harness for testing agentic orchestrator patterns. Grounded in the +> code as of the 0.22 epoch; the harness itself is a work-in-progress (see the incremental plan). + +# Testing agentic orchestrator patterns offline (replay harness) + +## Goal + +We want to iterate on the agent's design — its realtime reasoning and the wake-router that +turns developmental events into autonomous turns — **without booking a live microscope run**. +Concretely: take a recorded session, simulate the microscope conditions from its on-disk +artifacts (captured events, recorded volumes, recorded perception traces), drive the *real* +`WakeRouter -> run_wake_turn -> Claude` loop offline, and observe/diff what the agent decides. +This lets us tune wake triggers, coalescing/throttling, prompt construction, and tool policy on +a laptop, replayed at a controllable clock (e.g. 10x), instead of waiting hours for embryos to +develop on a live rig. + +## What's already in place (reuse) + +A real, tested replay/eval substrate shipped in the 0.22 epoch (`gently/eval/`), plus production +capture wiring. None of this is hypothetical — it's on disk and runnable today: + +- **`gently/eval/event_capture.py`** — `EventCapture` wildcard-subscribes the bus and appends every + `Event.to_dict()` to `{session_dir}/events.jsonl`. 
Auto-wired into **every** live session by + `gently/app/agent.py` `_init_event_capture()` (line ~506, called at agent init). Skips only + `_NO_HISTORY_TYPES` (`DEVICE_STATE_UPDATE`, `BOTTOM_CAMERA_FRAME`, `LOG_RECORD`); `DETECTOR_EVALUATED` + and lifecycle events are **not** skipped. +- **`gently/eval/event_replay.py`** — `EventReplay(path).replay(target_bus, real_time=, time_scale=, on_event=)` + republishes each event via `target.publish_event(ev)`, **preserving the original `Event.timestamp`** + (not re-stamping `now()`). `real_time=True` sleeps `(ev.timestamp - prev)/time_scale` between events, + so cadence is reproducible. `event_types()` gives a pre-flight histogram. +- **`gently/eval/shadow.py` + `candidates.py`** — `ShadowRunner` + `OrchestratorCandidate` host + sandboxed rule-based candidates (e.g. `ReactiveCandidate`) that may *only* write a `DecisionLog`. +- **`gently/eval/decision_log.py`** — `Decision`/`DecisionLog` + `prompt_hash()` (sha256[:16] over + system prompt + messages) for apples-to-apples A/B diffing. +- **`scripts/replay_session.py`** — working CLI: resolves a session by id-prefix via + `FileStore.list_sessions`, prints `--histogram`, or replays `events.jsonl` into a **fresh** `EventBus()` + with an optional `NoOpCandidate`. +- **Recorded perception inputs/outputs** persist via `FileStore` (`gently/core/file_store.py`): + `embryos/{id}/volumes/t{NNNN}.tif` + `.meta.yaml`, `projections/t{NNNN}.jpg`, + `predictions.jsonl`, and `traces/t{NNNN}.json` (verbatim `predicted_stage`/`reasoning`/`raw_response`/`stability`). + Verified on disk: session `68e7dc33` has 9 embryos, 56 predictions on embryo_001, volume + `t0001.tif` shape `[50,512,2048]` uint16. +- **`timeline.jsonl`** (durable, predates eval) carries 256 `detection/evaluated` records on `68e7dc33` + with exactly the fields `WakeRouter._is_wake_worthy` reads (`embryo_id`, `timepoint`, + `detector_name`, `stage`, `reasoning`) — a fallback event source for pre-eval-scaffold sessions. 
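
As a rough illustration of the pacing rule described for `EventReplay` above (sleep `(ev.timestamp - prev)/time_scale` between events), here is a minimal standalone sketch. The helper `replay_delays` and the sample timestamps are hypothetical; this is not the code in `gently/eval/event_replay.py`, only the arithmetic it is described as performing.

```python
from datetime import datetime, timedelta


def replay_delays(timestamps, time_scale=1.0):
    """Return the sleep (in seconds) before each event in a time-scaled replay.

    Hypothetical helper, not the real EventReplay: the first event is
    published immediately, each later one waits (ts - prev) / time_scale.
    """
    if time_scale <= 0:
        raise ValueError("time_scale must be positive")
    delays = []
    prev = None
    for ts in timestamps:
        gap = 0.0 if prev is None else (ts - prev).total_seconds()
        # Clamp negative gaps (out-of-order records) to zero.
        delays.append(max(gap, 0.0) / time_scale)
        prev = ts
    return delays


# Made-up timestamps: events at +0 s, +30 s and +150 s, replayed at 10x.
t0 = datetime(2025, 1, 1, 12, 0, 0)
stamps = [t0, t0 + timedelta(seconds=30), t0 + timedelta(seconds=150)]
delays = replay_delays(stamps, time_scale=10.0)
print(delays)  # [0.0, 3.0, 12.0]
```

Only the sleeps are scaled here; the recorded timestamps themselves are left untouched, which matches the "preserve the original `Event.timestamp`" behaviour described above.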
+ +### The one central wiring gap + +The real agent subscribes its `WakeRouter` to the **global singleton** bus +(`gently/app/agent.py:126` `self._event_bus = get_event_bus()`; WakeRouter built with that same bus). +But `scripts/replay_session.py:124` replays into a **fresh** `EventBus()` that the agent never sees. +**So today's replay reaches shadow candidates but never the real WakeRouter/agent.** Bridging this — +either `set_event_bus(replay_bus)` before constructing the agent, or replaying into `get_event_bus()` +directly — is the core seam to build. + +## Approaches, compared + +### (A) Event-stream replay into the agent's bus — *recommended first* +Publish recorded `DETECTOR_EVALUATED` + critical events (`HATCHING_DETECTED`, `EMBRYO_TERMINATED`, +`ERROR_OCCURRED`, …) onto the bus the agent's `WakeRouter` is subscribed to, on a controllable clock. + +- **Reuses:** `EventReplay`, `EventCapture` output, the entire real `WakeRouter` (`_is_wake_worthy` + filter at wake_router.py:129, coalesce `COALESCE_WINDOW=20s`, throttle `MIN_WAKE_INTERVAL=120s`, + `_flush -> agent.run_wake_turn`). +- **Fidelity:** Exercises the *real* wake path end-to-end: filtering, transition gate, coalescing, + throttling, prompt build, and a real Claude turn (`run_wake_turn -> handle_message_stream`, gated on + `agent.mode=='run'`). Highest leverage for the least new code. +- **Effort:** Medium. Needs (1) the bus bridge above; (2) a running asyncio loop so the async dispatch + + `loop.call_later` coalesce timers fire (`EventReplay.replay` is a blocking `time.sleep` loop — run it + in a thread or port it to `await asyncio.sleep`, and call `bus.set_event_loop(loop)`); (3) a stub + client so any tools the woken agent calls don't hit hardware (autonomous mode already refuses + irreversible tools via `_autonomous_active`). 
+- **Can't catch:** Anything depending on *fresh* perception of new pixels — the wake note embeds + `build_perception_snapshot(agent.perceiver, ...)`, which reads **live** Perceiver state, so this + approach needs (B) to make that snapshot reflect the replayed timepoint rather than empty state. +- **Blocker today:** No recorded session yet contains `DETECTOR_EVALUATED` (verified: all 20 captured + `events.jsonl` hold only `STATUS_CHANGED`/`EMBRYO_DETECTED`/`EMBRYOS_UPDATE`). Either capture one fresh + perception-driven session, or synthesize `DETECTOR_EVALUATED` events from `traces/`+`timeline.jsonl`. + +### (B) Perceiver stub — feed recorded traces +Replace `agent.perceiver`/`orchestrator.perceiver` with a duck-typed stub whose `__call__(...)` returns +`.stage`/`.reasoning` from `traces/t{NNNN}.json`, and whose `get_session(embryo_id)` returns an object +with `.stability`/`.summary()` matching what `build_perception_snapshot` reads +(`current_stage`/`stability`/`temporal`/`stage_sequence`). + +- **Reuses:** All downstream code in `_run_perception` (DETECTOR_EVALUATED emit, trace write, + `store_prediction`, `_check_interval_rules`) is pure local code; the Perceiver is the *only* external + VLM dependency. `perceiver` is already an optional ctor arg (timelapse.py:71). +- **Fidelity:** Reproduces recorded perception verbatim — no VLM spend, deterministic. Makes (A)'s wake + snapshot reflect the replayed timepoint. +- **Effort:** Low-medium (one stub class). +- **Can't catch:** Perception on *new* conditions — it only echoes what the recorded run already saw. + Also the stub interface was inferred from call sites (`templates.py` `build_perception_snapshot`, + `timelapse.py` `_run_perception`), not from `gently_perception` source — **verify against the + installed package** before relying on it. 
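
A minimal sketch of what such a stub could look like, under stated assumptions: the class names `RecordedPerceiver` and `RecordedResult`, the `(embryo_id, timepoint)` call signature, and the `embryos/{id}/traces/` directory layout are guesses for illustration and have not been checked against `gently_perception`. The `get_session(...)` half of the interface is omitted. The trace contents in the demo are made up.

```python
import json
import tempfile
from dataclasses import dataclass
from pathlib import Path


@dataclass
class RecordedResult:
    """Hypothetical result shape: only the two attributes named above."""

    stage: str
    reasoning: str


class RecordedPerceiver:
    """Duck-typed stand-in that echoes recorded traces instead of calling a VLM."""

    def __init__(self, session_dir):
        self.session_dir = Path(session_dir)

    def trace_path(self, embryo_id, timepoint):
        # Assumed layout: embryos/{id}/traces/t{NNNN}.json
        return (
            self.session_dir / "embryos" / embryo_id / "traces" / f"t{timepoint:04d}.json"
        )

    def __call__(self, embryo_id, timepoint):
        # The real call signature is an assumption; verify before relying on it.
        text = self.trace_path(embryo_id, timepoint).read_text(encoding="utf-8")
        trace = json.loads(text)
        return RecordedResult(stage=trace["predicted_stage"], reasoning=trace["reasoning"])


# Demo against a throwaway trace file with invented contents.
with tempfile.TemporaryDirectory() as tmp:
    stub = RecordedPerceiver(tmp)
    path = stub.trace_path("embryo_001", 1)
    path.parent.mkdir(parents=True)
    path.write_text(
        json.dumps({"predicted_stage": "comma", "reasoning": "tail bud visible"}),
        encoding="utf-8",
    )
    result = stub("embryo_001", 1)
    print(result.stage, "-", result.reasoning)
```

Because the stub only reads files, it is deterministic and costs no VLM calls, at the price of the perception-novelty limit noted above.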
+ +### (C) Full offline re-feed through the timelapse loop +Inject a fake `microscope_client` whose `acquire_volume(...)` returns +`{'success': True, 'volume': }` keyed by `(embryo_id, timepoint)`, +plus `move_to_position`, etc. `_has_microscope()` (`return self.client is not None`) then gates the +orchestrator *on*, driving the entire per-timepoint loop (acquire -> callback -> `_run_perception`). + +- **Reuses:** The whole `TimelapseOrchestrator`; `client` is a single ctor arg accessed only via named + async methods. +- **Fidelity:** Highest — exercises scheduling, acquisition callbacks, perception, event emission, and + wakes as one system. +- **Effort:** High. The orchestrator schedules entirely off `datetime.now()`/`asyncio.sleep` + (`_pick_next_due`, `_reschedule`, the acquire loop) and ignores `Event.timestamp` — a faithful + time-scaled run needs an **injectable clock** threaded through both the orchestrator and the + WakeRouter's wall-clock timers. There is also no helper to join recorded TIFFs to the event stream. +- **Can't catch:** Same perception-novelty limit as (B); plus device-state/camera-frame telemetry is + absent from `events.jsonl` and must be sourced from disk or synthesized. + +### (D) Shadow mode — score candidates / replay captured turns +Keep the existing `ShadowRunner` path: replay captured events into a bus with rule-based candidates +attached, diff `decisions.jsonl` (production) vs `replay-decisions-*.jsonl` (candidate) via `prompt_hash`. + +- **Reuses:** Fully built already (`scripts/replay_session.py --candidate`). +- **Fidelity:** Tests *alternative* (non-LLM) orchestrator architectures, not the production Claude agent. +- **Effort:** None (exists). +- **Can't catch:** The real agent's reasoning. 
Also: production only writes `Decision`s for **user turns** + (`conversation.py:343` hardcodes `trigger=DecisionTrigger.USER_MESSAGE`); wake turns via + `call_claude_stream` aren't logged as decisions, so there's currently no production wake-decision row to + diff against. + +## Honest fidelity limits + +- **Recorded perception ≠ new perception.** Approaches B/C echo `traces/`; they cannot evaluate the + Perceiver on conditions the original run didn't encounter. Genuinely testing perception requires a live + (or freshly captured) run. +- **LLM nondeterminism.** `run_wake_turn` makes a real Claude call; the same replayed input can yield + different tool calls run-to-run. `prompt_hash` isolates *input* identity but not *output* determinism — + diffs are about distributions/policy, not exact equality. +- **Clock vs coalesce/throttle.** `WakeRouter` uses real wall-clock `loop.call_later(COALESCE_WINDOW=20s)` + and `loop.time()`-based `MIN_WAKE_INTERVAL=120s` — these are **not** scaled by `time_scale`. A fast + replay collapses bursts into one wake; a high `time_scale` shrinks inter-event sleeps below the fixed + 20s window, again collapsing wakes. These tunables (currently module-level constants in + `wake_router.py:33-35`) must be parameterized/injectable for faithful timed replay. +- **Wall-clock reads break replay.** `TimelapseOrchestrator` (`timelapse.py`) drives scheduling off + `datetime.now()`/`asyncio.sleep` and never consults `Event.timestamp`; perception stamps + `timestamp=datetime.now()`. Any *new* events the woken agent emits use `publish()` (fresh `now()`), + intermixing replayed-historical and live-now timestamps on the same bus — a consistency hazard for + downstream diffing. +- **Telemetry gaps.** `EventCapture` skips `DEVICE_STATE_UPDATE`/`BOTTOM_CAMERA_FRAME`/`LOG_RECORD`, so a + replay can't reconstruct live device readouts or frames from `events.jsonl` (re-capture with + `EventCapture(path, skip=set())` or synthesize). 
+- **Data availability (verified).** No single recorded session yet combines a full timelapse with + non-trivial capture: `68e7dc33` has 9 embryos + volumes/traces but **no** `events.jsonl`; the newest + sessions have `events.jsonl` but 0 embryos and empty `decisions.jsonl`. All 20 captured `events.jsonl` + contain only setup-phase events — **zero** `DETECTOR_EVALUATED`. + +## Concrete incremental plan + +**Step 0 — Generate one good input stream (unblocks everything).** Either (a) run a single fresh +perception-driven session (live or with a stub client) after the eval-scaffold commit so +`events.jsonl` + non-empty `decisions.jsonl` coexist with volumes/traces; or (b) write a tiny +`synthesize_events.py` that emits `DETECTOR_EVALUATED` events from `68e7dc33`'s `traces/`+`timeline.jsonl` +into a synthetic `events.jsonl`. Validate with `python scripts/replay_session.py --histogram`. + +**Step 1 (smallest useful) — Bus bridge + offline driver skeleton.** New script +`scripts/replay_into_agent.py`: construct a `GentlyAgent` with a stub microscope client, call +`set_event_bus(replay_bus)` (or replay into `get_event_bus()`), set `agent.mode='run'` and +`wake_router.set_mode('ask')`, run an asyncio loop, `bus.set_event_loop(loop)`, and run +`EventReplay(...).replay(bus, real_time=True, time_scale=N)` in a thread. First milestone: a recorded +`DETECTOR_EVALUATED` actually fires `_on_event -> _flush -> run_wake_turn`. Reuses `EventReplay`, +`WakeRouter`, `run_wake_turn` unchanged. + +**Step 2 — Perceiver stub (B).** Add a `RecordedPerceiver` reading `traces/t{NNNN}.json`, injected via +`agent.perceiver`. Verify its `summary()`/result shape against the installed `gently_perception`. Now the +wake prompt's `build_perception_snapshot` reflects the replayed timepoint. 
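
The threading arrangement from Step 1 (a blocking replay loop in a worker thread, with events handed to a running asyncio loop) might look roughly like the sketch below. Everything here is a stand-in: `blocking_replay` and `deliver` are hypothetical names, the events are invented tuples, and the real `EventBus`, `EventReplay`, and `WakeRouter` are not involved. It only shows the standard `loop.call_soon_threadsafe` hand-off between a sleeping thread and the loop.

```python
import asyncio
import threading
import time


def blocking_replay(events, loop, deliver, time_scale):
    """Stand-in for a blocking replay loop (hypothetical, not EventReplay)."""
    prev = None
    for ts, name in events:
        if prev is not None:
            time.sleep((ts - prev) / time_scale)
        prev = ts
        # Hand the event to the loop thread; do not touch loop state from here.
        loop.call_soon_threadsafe(deliver, name)


async def main():
    loop = asyncio.get_running_loop()
    received = []
    done = asyncio.Event()
    # Invented (seconds, event type) pairs.
    events = [
        (0.0, "DETECTOR_EVALUATED"),
        (2.0, "DETECTOR_EVALUATED"),
        (4.0, "HATCHING_DETECTED"),
    ]

    def deliver(name):
        # Runs on the event-loop thread, so timers and tasks can be scheduled safely.
        received.append(name)
        if len(received) == len(events):
            done.set()

    worker = threading.Thread(
        target=blocking_replay, args=(events, loop, deliver, 100.0), daemon=True
    )
    worker.start()
    await asyncio.wait_for(done.wait(), timeout=5.0)
    worker.join()
    return received


received = asyncio.run(main())
print(received)
```

Keeping the sleeps in the worker thread leaves the loop free to run its own timers (such as coalesce windows) while the replay is paused between events.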
+ +**Step 3 — Injectable clock + parameterized tunables.** Thread a clock/`now()` provider through +`TimelapseOrchestrator` and make `COALESCE_WINDOW`/`MIN_WAKE_INTERVAL` injectable on `WakeRouter` so a +time-scaled replay reproduces the live wake set. Also scale or virtualize the loop timers. + +**Step 4 — Capture wake decisions + fix trigger labels.** Add `DecisionLog` capture to the +`call_claude_stream` (wake) path and emit `DecisionTrigger.EVENT` for wake turns (today `conversation.py` +only writes `USER_MESSAGE` decisions from `call_claude`). This makes replayed autonomous turns diffable. + +**Step 5 — Optional full re-feed (C).** Add a `RecordedMicroscopeClient` whose `acquire_volume` loads +`volumes/t{NNNN}.tif`, gating the orchestrator on for end-to-end loop testing. + +**Step 6 — Write `docs/EVAL.md`** (referenced as TODO in `gently/eval/__init__.py`) documenting the +replay workflow and fidelity tiers. diff --git a/docs/TOOLS.md b/docs/TOOLS.md index 6d9e020c..feafbfbe 100644 --- a/docs/TOOLS.md +++ b/docs/TOOLS.md @@ -13,7 +13,7 @@ Source: `gently/agent/tools/` (run mode) and `gently/agent/plan_mode/tools/` (pl |------|-------------| | `acquire_volume` | Acquire a single 3D lightsheet volume for a specific embryo with calibration data | | `capture_lightsheet` | Capture a single 2D lightsheet fluorescence image at specified piezo/galvo position | -| `batch_lightsheet` | Capture lightsheet images from ALL embryos and display as a stack | +| `batch_lightsheet` | Capture lightsheet images from ALL embryos and show them in the web UI viewer | ### Analysis (`analysis_tools.py`) @@ -29,22 +29,13 @@ Source: `gently/agent/tools/` (run mode) and `gently/agent/plan_mode/tools/` (pl | `calibrate_embryo` | Run full piezo-galvo calibration for a specific embryo using Claude vision | | `calibrate_all_embryos` | Run piezo-galvo calibration for all detected embryos sequentially | -### Data (`data_tools.py`) - -| Tool | Description | -|------|-------------| -| 
`list_runs` | List recent Bluesky runs from Databroker | -| `get_run_data` | Get data from a specific Bluesky run | -| `get_run_image` | Get an image from a Bluesky run for analysis | -| `search_runs` | Search Databroker runs by metadata criteria | - ### Detection (`detection_tools.py`) | Tool | Description | |------|-------------| | `detect_embryos` | Automatically detect embryos using brightness detection and SAM segmentation | | `manual_mark_embryos` | Open interactive window to manually mark embryos by clicking | -| `edit_embryos` | Open napari editor to add/remove/move embryo positions | +| `edit_embryos` | Add/remove/move embryo positions in the web map view | | `show_detected_embryos` | Capture fresh image and display all tracked embryos with labeled bounding boxes | ### Detectors (`detector_tools.py`) @@ -145,7 +136,7 @@ Source: `gently/agent/tools/` (run mode) and `gently/agent/plan_mode/tools/` (pl | Tool | Description | |------|-------------| | `view_image` | Capture and display current bottom camera widefield image | -| `view_volume` | Open a volume in napari for 3D visualization | +| `view_volume` | Open a volume in the in-browser 3D viewer | | `list_volumes` | List available volumes for an embryo or all embryos | --- diff --git a/docs/guides/capabilities.md b/docs/guides/capabilities.md index d3b097c7..846d662c 100644 --- a/docs/guides/capabilities.md +++ b/docs/guides/capabilities.md @@ -133,7 +133,6 @@ This design means experimental AI code — perception systems, coding agents, no | **Analysis** | analyze_volume, classify_embryo_stage | No | | **Experiment** | get_experiment_summary, query_embryo_status | No | | **Session** | list_sessions, import_embryos_from_session | No | -| **Data** | list_runs, get_run_data, search_runs | No | | **Planning** | create_campaign, propose_plan, search_literature | No | | **Research** | search_literature, read_paper, search_strains | No | diff --git a/docs/guides/try-offline.md b/docs/guides/try-offline.md index 
50ca6e65..e4341291 100644 --- a/docs/guides/try-offline.md +++ b/docs/guides/try-offline.md @@ -4,29 +4,38 @@ Get the agent running in 10 minutes — no microscope needed. ## Prerequisites -- **Python 3.11+** -- **Node.js 18+** (for the terminal UI) +- **Python 3.10+** - An **Anthropic API key** (`ANTHROPIC_API_KEY` environment variable) +Gently is web-first — the agent runs in your browser, so there's no terminal UI to build (no Node.js needed for the app). + ## Install ```bash git clone https://github.com/pskeshu/gently.git cd gently -pip install -r requirements.txt +``` -# Build the TUI (one-time) -cd gently/tui -npm install -npm run build -cd ../.. +Create an environment and install — **either path works**: + +```bash +# venv + pip +python -m venv .venv +source .venv/bin/activate # Windows: .venv\Scripts\activate +pip install -e . +``` + +```bash +# or uv (https://docs.astral.sh/uv/) +uv venv +uv pip install -e . ``` ## Launch ```bash -export ANTHROPIC_API_KEY=sk-ant-... -python launch_gently.py --offline +export ANTHROPIC_API_KEY=sk-ant-... # Windows: set ANTHROPIC_API_KEY=sk-ant-... +python launch_gently.py --offline # uv (no activate): uv run python launch_gently.py --offline ``` The `--offline` flag skips the hardware connection. The full agent launches — conversation, perception, plan mode, memory — just without microscope control. diff --git a/examples/README.md b/examples/README.md index c538fbc0..cb6f6b06 100644 --- a/examples/README.md +++ b/examples/README.md @@ -5,11 +5,9 @@ Working examples of Gently's Bluesky plan system and visualization pipeline. 
| Example | Description | |---------|-------------| | `example_dispim_workflows.py` | Complete DiSPIM workflows: atomic plans, autofocus, two-point calibration, embryo detection, multi-embryo acquisition | -| `example_napari_visualization.py` | Real-time napari visualization: focus sweeps, embryo detection, dual-sided DiSPIM, custom configurations | ## Requirements ```bash pip install gently[device] # Bluesky + Ophyd for hardware plans -pip install napari[all] # For visualization examples ``` diff --git a/examples/example_dispim_workflows.py b/examples/example_dispim_workflows.py index 133555c3..e81e7574 100644 --- a/examples/example_dispim_workflows.py +++ b/examples/example_dispim_workflows.py @@ -9,7 +9,7 @@ This example shows: 1. Device-agnostic atomic plans (focus_sweep) 2. Autofocus functionality for precise positioning -3. Two-point calibration for coordinate mapping +3. Two-point calibration for coordinate mapping 4. Embryo detection with bottom camera 5. Complete multi-embryo acquisition workflows @@ -17,34 +17,25 @@ """ import logging + import numpy as np from bluesky import RunEngine from bluesky.callbacks import LiveTable # Import gently components from gently import ( - # Device classes - create_dispim_system, - DiSPIMSystem, - - # Plan functions - focus_sweep, # Atomic plan - dispim_piezo_autofocus, # Autofocus functionality - dispim_two_point_calibration, # Calibration plan - full_dispim_workflow, # Complete workflow - # Configuration AutofocusConfig, CalibrationConfig, - + FitFunction, # Analysis utilities - FocusAlgorithm, - FitFunction + FocusAlgorithm, # Complete workflow ) # Import optional napari visualization try: - import napari + import napari # noqa: F401 + NAPARI_AVAILABLE = True except ImportError: NAPARI_AVAILABLE = False @@ -53,46 +44,46 @@ def setup_dispim_session(): """ Setup DiSPIM session with devices and RunEngine - + In practice, this would use actual MM installation paths: system = create_dispim_system("/path/to/micromanager", 
"/path/to/config.cfg") """ print("Setting up DiSPIM session...") - + # Create RunEngine RE = RunEngine({}) - + # For demonstration, create a mock system # In practice: system = create_dispim_system(mm_dir, config_file) print(" [Note: Using mock system for demonstration]") print(" [In practice: system = create_dispim_system(mm_dir, config_file)]") system = None # Would be actual DiSPIMSystem - + # Setup live callbacks - live_table = LiveTable(['piezo_a_position', 'galvo_a_position', 'camera_a_image']) + live_table = LiveTable(["piezo_a_position", "galvo_a_position", "camera_a_image"]) RE.subscribe(live_table) - + return RE, system def setup_napari_visualization(RE): """Setup optional napari visualization for real-time image display""" - print("\n" + "="*60) + print("\n" + "=" * 60) print("OPTIONAL: NAPARI VISUALIZATION SETUP") - print("="*60) - + print("=" * 60) + if not NAPARI_AVAILABLE: print("\n⚠ Napari not available - skipping visualization setup") print(" To enable real-time image visualization:") print(" pip install napari[all]") print(" Then restart and run this example again.") return None - + print("\n✅ Napari available - setting up real-time visualization") - + # Setup napari callback for all DiSPIM experiments - napari_callback = enable_full_visualization(RE) - + napari_callback = enable_full_visualization(RE) # noqa: F821 + if napari_callback.enabled: print(" ✅ Napari callback enabled and subscribed to RunEngine") print(" ✅ Real-time image display will show:") @@ -100,71 +91,71 @@ def setup_napari_visualization(RE): print(" - Embryo detection scan images") print(" - Dual-channel DiSPIM data (green/magenta)") print(" - Individual camera acquisitions") - + print(f"\n Napari viewer: '{napari_callback.viewer.title}'") print(f" - Focus sweeps: {napari_callback.show_focus_sweeps}") print(f" - Embryo detection: {napari_callback.show_embryo_detection}") print(f" - Dual channels: {napari_callback.dual_channel_mode}") - + print("\n 💡 The napari window is now open 
- you can:") print(" - Adjust layer visibility and colors") print(" - Explore 3D image stacks") print(" - Take screenshots and export movies") print(" - View images in real-time during experiments") - + else: print(" ❌ Napari callback failed to initialize") return None - + return napari_callback def demonstrate_atomic_plans(RE, light_sheet): """Demonstrate device-agnostic atomic plans""" - print("\n" + "="*60) + print("\n" + "=" * 60) print("1. ATOMIC PLANS - Device-Agnostic Building Blocks") - print("="*60) - + print("=" * 60) + print("\nThe foundation: focus_sweep(positioner, positions, detector)") print("This atomic plan works with ANY positioner and detector:") - + # Define scan positions positions = np.linspace(-10, 10, 11) # 11 positions from -10 to +10 μm - + print(f"\n Positions to scan: {positions}") print(f" Number of positions: {len(positions)}") - + if light_sheet is not None: # Example 1: Piezo focus sweep print("\n Example 1: Piezo focus sweep") print(" focus_sweep(light_sheet.piezo, positions, light_sheet.camera)") # RE(focus_sweep(light_sheet.piezo, positions, light_sheet.camera)) - - # Example 2: Galvo focus sweep + + # Example 2: Galvo focus sweep print("\n Example 2: Galvo focus sweep") print(" focus_sweep(light_sheet.galvo, positions, light_sheet.camera)") # RE(focus_sweep(light_sheet.galvo, positions, light_sheet.camera)) - + # Example 3: XY stage sweep (device-agnostic) print("\n Example 3: XY stage sweep (device-agnostic!)") print(" focus_sweep(xy_stage.x, positions, light_sheet.camera)") # RE(focus_sweep(xy_stage.x, positions, light_sheet.camera)) else: print(" [Would execute: RE(focus_sweep(device, positions, detector))]") - + print("\nKey insight: Same atomic plan, different devices!") print("This is the power of device-agnostic Bluesky plans.") def demonstrate_autofocus_functionality(RE, light_sheet): """Demonstrate autofocus functionality for precise positioning""" - print("\n" + "="*60) + print("\n" + "=" * 60) print("2. 
AUTOFOCUS FUNCTIONALITY - Precise Device Positioning") - print("="*60) - + print("=" * 60) + print("\nAutofocus builds on the atomic focus_sweep plan:") print("dispim_piezo_autofocus() = focus_sweep() + analysis + validation") - + # Create autofocus configuration config = AutofocusConfig( num_positions=21, @@ -172,251 +163,276 @@ def demonstrate_autofocus_functionality(RE, light_sheet): algorithm=FocusAlgorithm.VOLATH.value, fit_function=FitFunction.GAUSSIAN.value, minimum_r_squared=0.75, - center_at_current=True + center_at_current=True, ) - - print(f"\nAutofocus Configuration:") + + print("\nAutofocus Configuration:") print(f" Positions: {config.num_positions}") print(f" Step size: {config.step_size_um} μm") print(f" Algorithm: {config.algorithm}") print(f" Fit function: {config.fit_function}") print(f" Min R²: {config.minimum_r_squared}") - + if light_sheet is not None: - print(f"\nExecuting autofocus:") - print(f" RE(dispim_piezo_autofocus(light_sheet, config))") + print("\nExecuting autofocus:") + print(" RE(dispim_piezo_autofocus(light_sheet, config))") # RE(dispim_piezo_autofocus(light_sheet, config)) - + if NAPARI_AVAILABLE: - print(f"\n 📺 With napari visualization:") - print(f" - Images stream to napari in real-time") - print(f" - 3D focus stack builds as positions are scanned") - print(f" - Can see focus quality at each position") - print(f" - Visual feedback on autofocus progress") + print("\n 📺 With napari visualization:") + print(" - Images stream to napari in real-time") + print(" - 3D focus stack builds as positions are scanned") + print(" - Can see focus quality at each position") + print(" - Visual feedback on autofocus progress") else: - print(f"\n[Would execute: RE(dispim_piezo_autofocus(light_sheet, config))]") - - print(f"\nAutofocus workflow:") - print(f" 1. bps.stage(light_sheet) # Save current state") - print(f" 2. focus_sweep(piezo, positions, camera) # Atomic plan") - print(f" 3. 
analyze_focus_stack(positions, images) # Find best position") - print(f" 4. bps.mv(piezo, best_position) # Move to focus") - print(f" 5. bps.unstage(light_sheet) # Restore if failed") - - print(f"\nThis enables precise, automated positioning for experiments!") + print("\n[Would execute: RE(dispim_piezo_autofocus(light_sheet, config))]") + + print("\nAutofocus workflow:") + print(" 1. bps.stage(light_sheet) # Save current state") + print(" 2. focus_sweep(piezo, positions, camera) # Atomic plan") + print(" 3. analyze_focus_stack(positions, images) # Find best position") + print(" 4. bps.mv(piezo, best_position) # Move to focus") + print(" 5. bps.unstage(light_sheet) # Restore if failed") + + print("\nThis enables precise, automated positioning for experiments!") def demonstrate_calibration_workflow(RE, light_sheet): """Demonstrate calibration workflows for coordinate mapping""" - print("\n" + "="*60) + print("\n" + "=" * 60) print("3. CALIBRATION WORKFLOW - Coordinate System Mapping") - print("="*60) - + print("=" * 60) + print("\nTwo-point calibration uses autofocus at each calibration point:") print("dispim_two_point_calibration() = move + autofocus + move + autofocus + fit") - + # Create calibration configuration autofocus_config = AutofocusConfig( num_positions=11, # Faster for calibration step_size_um=1.0, - algorithm=FocusAlgorithm.VOLATH.value + algorithm=FocusAlgorithm.VOLATH.value, ) - + cal_config = CalibrationConfig( point1_um=25.0, point2_um=75.0, autofocus_each_point=True, - autofocus_config=autofocus_config + autofocus_config=autofocus_config, ) - - print(f"\nCalibration Configuration:") + + print("\nCalibration Configuration:") print(f" Point 1: {cal_config.point1_um} μm") - print(f" Point 2: {cal_config.point2_um} μm") + print(f" Point 2: {cal_config.point2_um} μm") print(f" Autofocus at each point: {cal_config.autofocus_each_point}") - + if light_sheet is not None: - print(f"\nExecuting calibration:") - print(f" 
RE(dispim_two_point_calibration(light_sheet, cal_config))") + print("\nExecuting calibration:") + print(" RE(dispim_two_point_calibration(light_sheet, cal_config))") # RE(dispim_two_point_calibration(light_sheet, cal_config)) else: - print(f"\n[Would execute: RE(dispim_two_point_calibration(light_sheet, cal_config))]") - - print(f"\nCalibration workflow:") - print(f" 1. bps.mv(piezo, point1) # Move to first point") - print(f" 2. dispim_galvo_autofocus() # Focus galvo (uses atomic plans)") - print(f" 3. bps.trigger_and_read([...]) # Record positions") - print(f" 4. bps.mv(piezo, point2) # Move to second point") - print(f" 5. dispim_galvo_autofocus() # Focus galvo again") - print(f" 6. bps.trigger_and_read([...]) # Record positions") - print(f" 7. calculate_linear_fit() # Determine calibration") - - print(f"\nCalibration enables coordinate transformations between devices!") + print("\n[Would execute: RE(dispim_two_point_calibration(light_sheet, cal_config))]") + + print("\nCalibration workflow:") + print(" 1. bps.mv(piezo, point1) # Move to first point") + print(" 2. dispim_galvo_autofocus() # Focus galvo (uses atomic plans)") + print(" 3. bps.trigger_and_read([...]) # Record positions") + print(" 4. bps.mv(piezo, point2) # Move to second point") + print(" 5. dispim_galvo_autofocus() # Focus galvo again") + print(" 6. bps.trigger_and_read([...]) # Record positions") + print(" 7. calculate_linear_fit() # Determine calibration") + + print("\nCalibration enables coordinate transformations between devices!") def demonstrate_embryo_detection_workflow(RE, dispim_system): """Demonstrate embryo detection with bottom camera""" - print("\n" + "="*60) + print("\n" + "=" * 60) print("4. 
EMBRYO DETECTION WORKFLOW - Automated Sample Finding") - print("="*60) - + print("=" * 60) + print("\nEmbryo detection uses bottom camera for automated sample finding:") print("find_embryos_with_bottom_camera() = XY scan + image analysis + position recording") - + # Create detection configuration detection_config = { - 'scan_area': { - 'x_start': -1000, 'x_stop': 1000, # μm - 'y_start': -1000, 'y_stop': 1000, # μm - 'step_size': 200 # μm between positions + "scan_area": { + "x_start": -1000, + "x_stop": 1000, # μm + "y_start": -1000, + "y_stop": 1000, # μm + "step_size": 200, # μm between positions + }, + "detection": { + "min_size_pixels": 50, + "max_size_pixels": 500, + "brightness_threshold": 0.3, + "circularity_threshold": 0.7, }, - 'detection': { - 'min_size_pixels': 50, - 'max_size_pixels': 500, - 'brightness_threshold': 0.3, - 'circularity_threshold': 0.7 + "safety": { + "z_position_um": 0.0, # Safe Z position during XY scan + "max_scan_time_minutes": 10, }, - 'safety': { - 'z_position_um': 0.0, # Safe Z position during XY scan - 'max_scan_time_minutes': 10 - } } - - print(f"\nDetection Configuration:") - print(f" Scan area: {detection_config['scan_area']['x_start']} to {detection_config['scan_area']['x_stop']} μm (X)") - print(f" {detection_config['scan_area']['y_start']} to {detection_config['scan_area']['y_stop']} μm (Y)") + + print("\nDetection Configuration:") + print( + f" Scan area: {detection_config['scan_area']['x_start']} to" + f" {detection_config['scan_area']['x_stop']} μm (X)" + ) + print( + f" {detection_config['scan_area']['y_start']} to" + f" {detection_config['scan_area']['y_stop']} μm (Y)" + ) print(f" Step size: {detection_config['scan_area']['step_size']} μm") - print(f" Detection thresholds: size {detection_config['detection']['min_size_pixels']}-{detection_config['detection']['max_size_pixels']} pixels") - + print( + f" Detection thresholds: size" + f" {detection_config['detection']['min_size_pixels']}" + 
f"-{detection_config['detection']['max_size_pixels']} pixels" + ) + if dispim_system is not None: - print(f"\nExecuting embryo detection:") - print(f" RE(find_embryos_with_bottom_camera(dispim_system, detection_config))") + print("\nExecuting embryo detection:") + print(" RE(find_embryos_with_bottom_camera(dispim_system, detection_config))") # RE(find_embryos_with_bottom_camera(dispim_system, detection_config)) - + if NAPARI_AVAILABLE: - print(f"\n 📺 With napari visualization:") - print(f" - Each XY position shows in napari as it's acquired") - print(f" - See scan progress across the sample area") - print(f" - Detected embryos can be highlighted in real-time") - print(f" - Build up a mosaic view of the scanned region") + print("\n 📺 With napari visualization:") + print(" - Each XY position shows in napari as it's acquired") + print(" - See scan progress across the sample area") + print(" - Detected embryos can be highlighted in real-time") + print(" - Build up a mosaic view of the scanned region") else: - print(f"\n[Would execute: RE(find_embryos_with_bottom_camera(dispim_system, detection_config))]") - - print(f"\nDetection workflow:") - print(f" 1. bps.mv(xy_stage.z, safe_z_position) # Move to safe Z") - print(f" 2. XY grid scan with bottom camera # Scan entire area") - print(f" 3. Analyze images for embryo features # Find circular objects") - print(f" 4. Record embryo positions in stage coords # Store locations") - print(f" 5. Convert to light sheet coordinates # Transform coords") - - print(f"\nAutomated detection finds all samples for batch processing!") + print( + "\n[Would execute:" + " RE(find_embryos_with_bottom_camera(dispim_system, detection_config))]" + ) + + print("\nDetection workflow:") + print(" 1. bps.mv(xy_stage.z, safe_z_position) # Move to safe Z") + print(" 2. XY grid scan with bottom camera # Scan entire area") + print(" 3. Analyze images for embryo features # Find circular objects") + print(" 4. 
Record embryo positions in stage coords # Store locations") + print(" 5. Convert to light sheet coordinates # Transform coords") + + print("\nAutomated detection finds all samples for batch processing!") def demonstrate_complete_workflow(RE, dispim_system): """Demonstrate complete multi-embryo acquisition workflow""" - print("\n" + "="*60) + print("\n" + "=" * 60) print("5. COMPLETE WORKFLOW - Multi-Embryo Light Sheet Acquisition") - print("="*60) - + print("=" * 60) + print("\nThe complete workflow combines all components:") print("full_dispim_workflow() = calibration + embryo_detection + acquisition") - + # Create complete workflow configuration workflow_config = { - 'system_setup': { - 'center_devices': True, - 'run_calibration': True, - 'validate_hardware': True + "system_setup": { + "center_devices": True, + "run_calibration": True, + "validate_hardware": True, }, - 'calibration': { - 'point1_um': 25.0, - 'point2_um': 75.0, - 'autofocus_each_point': True, - 'autofocus_config': { - 'num_positions': 11, - 'step_size_um': 1.0, - 'algorithm': 'volath' - } + "calibration": { + "point1_um": 25.0, + "point2_um": 75.0, + "autofocus_each_point": True, + "autofocus_config": { + "num_positions": 11, + "step_size_um": 1.0, + "algorithm": "volath", + }, }, - 'embryo_detection': { - 'x_start': -1000, 'x_stop': 1000, - 'y_start': -1000, 'y_stop': 1000, - 'step_size': 200, - 'detection_thresholds': { - 'min_size': 50, 'max_size': 500, - 'brightness': 0.3, 'circularity': 0.7 - } + "embryo_detection": { + "x_start": -1000, + "x_stop": 1000, + "y_start": -1000, + "y_stop": 1000, + "step_size": 200, + "detection_thresholds": { + "min_size": 50, + "max_size": 500, + "brightness": 0.3, + "circularity": 0.7, + }, }, - 'acquisition': { - 'autofocus_config': { - 'num_positions': 21, - 'step_size_um': 0.5, - 'algorithm': 'volath' + "acquisition": { + "autofocus_config": { + "num_positions": 21, + "step_size_um": 0.5, + "algorithm": "volath", }, - 'z_stack': { - 'range_um': 50, # ±25 μm 
around focus - 'step_size_um': 1.0 + "z_stack": { + "range_um": 50, # ±25 μm around focus + "step_size_um": 1.0, }, - 'dual_sided': True, - 'time_points': 1 - } + "dual_sided": True, + "time_points": 1, + }, } - - print(f"\nWorkflow Configuration:") - print(f" Calibration: {workflow_config['calibration']['point1_um']} to {workflow_config['calibration']['point2_um']} μm") - print(f" Detection area: {workflow_config['embryo_detection']['x_start']} to {workflow_config['embryo_detection']['x_stop']} μm") - print(f" Z-stack range: ±{workflow_config['acquisition']['z_stack']['range_um']//2} μm") + + print("\nWorkflow Configuration:") + print( + f" Calibration: {workflow_config['calibration']['point1_um']} to" + f" {workflow_config['calibration']['point2_um']} μm" + ) + print( + f" Detection area: {workflow_config['embryo_detection']['x_start']} to" + f" {workflow_config['embryo_detection']['x_stop']} μm" + ) + print(f" Z-stack range: ±{workflow_config['acquisition']['z_stack']['range_um'] // 2} μm") print(f" Dual-sided: {workflow_config['acquisition']['dual_sided']}") - + if dispim_system is not None: - print(f"\nExecuting complete workflow:") - print(f" RE(full_dispim_workflow(dispim_system, workflow_config))") + print("\nExecuting complete workflow:") + print(" RE(full_dispim_workflow(dispim_system, workflow_config))") # RE(full_dispim_workflow(dispim_system, workflow_config)) else: - print(f"\n[Would execute: RE(full_dispim_workflow(dispim_system, workflow_config))]") - - print(f"\nComplete workflow stages:") - print(f" 1. System initialization and hardware validation") - print(f" 2. Two-point calibration (with autofocus)") - print(f" 3. Embryo detection with bottom camera") - print(f" 4. For each detected embryo:") - print(f" a. Move to embryo position") - print(f" b. Autofocus both sides") - print(f" c. Acquire dual-sided Z-stack") - print(f" d. 
Save data with metadata") - - print(f"\nAutomated, high-throughput DiSPIM experiments!") + print("\n[Would execute: RE(full_dispim_workflow(dispim_system, workflow_config))]") + + print("\nComplete workflow stages:") + print(" 1. System initialization and hardware validation") + print(" 2. Two-point calibration (with autofocus)") + print(" 3. Embryo detection with bottom camera") + print(" 4. For each detected embryo:") + print(" a. Move to embryo position") + print(" b. Autofocus both sides") + print(" c. Acquire dual-sided Z-stack") + print(" d. Save data with metadata") + + print("\nAutomated, high-throughput DiSPIM experiments!") def demonstrate_extensibility(RE): """Demonstrate how the atomic approach enables easy extension""" - print("\n" + "="*60) + print("\n" + "=" * 60) print("6. EXTENSIBILITY - Easy Addition of New Capabilities") - print("="*60) - + print("=" * 60) + print("\nBecause plans are device-agnostic, new capabilities are easy:") - - print(f"\nNew hardware? Same plans work:") - print(f" focus_sweep(new_positioner, positions, new_detector)") - print(f" dispim_piezo_autofocus(new_light_sheet, config)") - - print(f"\nNew algorithms? Just swap the analysis:") - print(f" config.algorithm = 'new_algorithm'") - print(f" Same dispim_piezo_autofocus() plan!") - - print(f"\nNew workflows? 
Compose existing plans:") - print(f" def adaptive_autofocus_with_ai(light_sheet, ai_callback):") - print(f" yield from dispim_piezo_autofocus(light_sheet, config)") - print(f" decision = ai_callback(result)") - print(f" if decision.refine:") - print(f" yield from dispim_galvo_autofocus(light_sheet, refined_config)") - - print(f"\nIntegration with other systems:") - print(f" def combined_microscopy_workflow(dispim, confocal, shared_stage):") - print(f" yield from focus_sweep(shared_stage.z, positions, dispim.camera)") - print(f" yield from focus_sweep(shared_stage.z, positions, confocal.camera)") - print(f" # Same atomic plan, different systems!") - - print(f"\nThe atomic approach scales naturally!") + + print("\nNew hardware? Same plans work:") + print(" focus_sweep(new_positioner, positions, new_detector)") + print(" dispim_piezo_autofocus(new_light_sheet, config)") + + print("\nNew algorithms? Just swap the analysis:") + print(" config.algorithm = 'new_algorithm'") + print(" Same dispim_piezo_autofocus() plan!") + + print("\nNew workflows? 
Compose existing plans:") + print(" def adaptive_autofocus_with_ai(light_sheet, ai_callback):") + print(" yield from dispim_piezo_autofocus(light_sheet, config)") + print(" decision = ai_callback(result)") + print(" if decision.refine:") + print(" yield from dispim_galvo_autofocus(light_sheet, refined_config)") + + print("\nIntegration with other systems:") + print(" def combined_microscopy_workflow(dispim, confocal, shared_stage):") + print(" yield from focus_sweep(shared_stage.z, positions, dispim.camera)") + print(" yield from focus_sweep(shared_stage.z, positions, confocal.camera)") + print(" # Same atomic plan, different systems!") + + print("\nThe atomic approach scales naturally!") def main(): @@ -429,17 +445,17 @@ def main(): print() print("Key concepts:") print(" - Atomic plans work with ANY compatible devices") - print(" - Autofocus enables precise positioning") + print(" - Autofocus enables precise positioning") print(" - Complex workflows compose atomic plans") print(" - Easy extensibility through device agnosticism") - + # Setup session RE, system = setup_dispim_session() - light_sheet = getattr(system, 'side_a', None) if system else None - + light_sheet = getattr(system, "side_a", None) if system else None + # Setup optional napari visualization - napari_callback = setup_napari_visualization(RE) - + setup_napari_visualization(RE) + # Run demonstrations demonstrate_atomic_plans(RE, light_sheet) demonstrate_autofocus_functionality(RE, light_sheet) @@ -447,50 +463,50 @@ def main(): demonstrate_embryo_detection_workflow(RE, system) demonstrate_complete_workflow(RE, system) demonstrate_extensibility(RE) - + # Summary - print("\n" + "="*60) + print("\n" + "=" * 60) print("SUMMARY - Complete DiSPIM Functionality") - print("="*60) - - print(f"\n✓ Created proper Ophyd devices (DiSPIMPiezo, DiSPIMCamera, etc.)") - print(f"✓ Built device-agnostic atomic plans (focus_sweep, move_and_acquire)") - print(f"✓ Implemented autofocus for precise positioning") - 
print(f"✓ Added calibration workflows for coordinate mapping") - print(f"✓ Created embryo detection for automated sample finding") - print(f"✓ Integrated complete multi-embryo acquisition workflows") + print("=" * 60) + + print("\n✓ Created proper Ophyd devices (DiSPIMPiezo, DiSPIMCamera, etc.)") + print("✓ Built device-agnostic atomic plans (focus_sweep, move_and_acquire)") + print("✓ Implemented autofocus for precise positioning") + print("✓ Added calibration workflows for coordinate mapping") + print("✓ Created embryo detection for automated sample finding") + print("✓ Integrated complete multi-embryo acquisition workflows") if NAPARI_AVAILABLE: - print(f"✓ Enabled real-time image visualization with napari") - - print(f"\nKey benefits:") - print(f" 1. Device-agnostic - plans work with any compatible hardware") - print(f" 2. Composable - atomic plans build into complex workflows") - print(f" 3. Extensible - easy to add new capabilities") - print(f" 4. Reliable - proper Bluesky integration with error handling") - - print(f"\nNext steps:") - print(f" 1. Test with real DiSPIM hardware using safety protocols") + print("✓ Enabled real-time image visualization with napari") + + print("\nKey benefits:") + print(" 1. Device-agnostic - plans work with any compatible hardware") + print(" 2. Composable - atomic plans build into complex workflows") + print(" 3. Extensible - easy to add new capabilities") + print(" 4. Reliable - proper Bluesky integration with error handling") + + print("\nNext steps:") + print(" 1. Test with real DiSPIM hardware using safety protocols") if not NAPARI_AVAILABLE: - print(f" 2. Install napari for real-time visualization: pip install napari[all]") - print(f" 3. Add image analysis for embryo detection") - print(f" 4. Integrate with VLM for intelligent workflows") - print(f" 5. Extend to other microscopy systems") + print(" 2. Install napari for real-time visualization: pip install napari[all]") + print(" 3. 
Add image analysis for embryo detection") + print(" 4. Integrate with VLM for intelligent workflows") + print(" 5. Extend to other microscopy systems") else: - print(f" 2. Add image analysis for embryo detection") - print(f" 3. Integrate with VLM for intelligent workflows") - print(f" 4. Extend to other microscopy systems") - - print(f"\nThe transformation is complete:") - print(f" 635-line Java monolith → Composable Bluesky atomic plans") - print(f" Device-specific code → Device-agnostic interfaces") - print(f" Rigid workflows → Flexible, extensible compositions") - - print(f"\nGently DiSPIM: Where atomic plans meet experimental flexibility! 🔬") + print(" 2. Add image analysis for embryo detection") + print(" 3. Integrate with VLM for intelligent workflows") + print(" 4. Extend to other microscopy systems") + + print("\nThe transformation is complete:") + print(" 635-line Java monolith → Composable Bluesky atomic plans") + print(" Device-specific code → Device-agnostic interfaces") + print(" Rigid workflows → Flexible, extensible compositions") + + print("\nGently DiSPIM: Where atomic plans meet experimental flexibility! 🔬") if __name__ == "__main__": # Setup logging logging.basicConfig(level=logging.INFO) - + # Run demonstration - main() \ No newline at end of file + main() diff --git a/examples/example_napari_visualization.py b/examples/example_napari_visualization.py deleted file mode 100644 index cfbad522..00000000 --- a/examples/example_napari_visualization.py +++ /dev/null @@ -1,522 +0,0 @@ -#!/usr/bin/env python -""" -DiSPIM Napari Visualization Examples -=================================== - -Demonstrates real-time image visualization for DiSPIM experiments using napari. -Shows different visualization patterns for various experiment types. - -This example shows: -1. Basic napari setup with Bluesky RunEngine -2. Focus sweep visualization (3D image stacks) -3. Embryo detection visualization (2D image sequences) -4. 
Dual-sided DiSPIM visualization (multi-channel) -5. Custom visualization configurations -6. Integration with complete DiSPIM workflows - -Requirements: - pip install napari[all] - # or with specific backend: pip install napari[pyqt5] -""" - -import logging -import numpy as np -from bluesky import RunEngine -from bluesky.callbacks import LiveTable - -# Import gently components -from gently import ( - # Device classes - create_dispim_system, - DiSPIMSystem, - - # Plan functions - focus_sweep, - dispim_piezo_autofocus, - find_embryos_with_bottom_camera, - full_dispim_workflow, - - # Configuration classes - AutofocusConfig, - CalibrationConfig, - - # Analysis utilities - FocusAlgorithm, - FitFunction -) - -# Import visualization utilities -from gently.ui.web import ( - EmbryoMarker, - mark_embryos_napari, - generate_focus_curve_plot, - generate_calibration_summary_plot, - generate_edge_detection_plot, -) - -# Napari availability check -try: - import napari - NAPARI_AVAILABLE = True -except ImportError: - NAPARI_AVAILABLE = False - - -def check_napari_installation(): - """Check if napari is available and provide installation instructions""" - if not NAPARI_AVAILABLE: - print("❌ Napari not available!") - print("\nTo enable image visualization, install napari:") - print(" pip install napari[all]") - print("\nOr with specific backend:") - print(" pip install napari[pyqt5]") - print(" # or napari[pyside2]") - print("\nAfter installation, restart and run this example again.") - return False - - print("✅ Napari is available - image visualization enabled!") - return True - - -def setup_demo_system(): - """Setup demo DiSPIM system for visualization examples""" - print("Setting up demo DiSPIM system...") - - # Create RunEngine - RE = RunEngine({}) - - # For demonstration, we'll use mock system - # In practice: system = create_dispim_system("/path/to/micromanager", "config.cfg") - print(" [Note: Using mock system for demonstration]") - system = None # Would be actual DiSPIMSystem 
- light_sheet = None # Would be system.side_a - - return RE, system, light_sheet - - -def demonstrate_basic_napari_setup(RE): - """Demonstrate basic napari visualization setup""" - print("\n" + "="*60) - print("1. BASIC NAPARI SETUP - Real-time Image Visualization") - print("="*60) - - print("\nSetting up napari for DiSPIM visualization:") - - # Create napari callback with default settings - napari_callback = setup_napari_callback() - - if not napari_callback.enabled: - print(" ❌ Napari callback disabled (napari not available)") - return None - - # Subscribe to RunEngine - RE.subscribe(napari_callback) - - print(" ✅ Napari callback created and subscribed") - print(f" ✅ Viewer title: {napari_callback.viewer.title}") - print(f" ✅ Focus sweeps: {napari_callback.show_focus_sweeps}") - print(f" ✅ Embryo detection: {napari_callback.show_embryo_detection}") - print(f" ✅ Dual channel: {napari_callback.dual_channel_mode}") - - print("\nBasic usage pattern:") - print(" RE = RunEngine({})") - print(" napari_callback = setup_napari_callback()") - print(" RE.subscribe(napari_callback)") - print(" # Now any plan with images will display in napari!") - - return napari_callback - - -def demonstrate_focus_sweep_visualization(RE, light_sheet, napari_callback): - """Demonstrate focus sweep visualization""" - print("\n" + "="*60) - print("2. 
FOCUS SWEEP VISUALIZATION - 3D Image Stacks") - print("="*60) - - if not napari_callback or not napari_callback.enabled: - print(" ⚠ Skipping - napari not available") - return - - print("\nFocus sweep creates 3D image stacks visualized in real-time:") - - # Configure autofocus - config = AutofocusConfig( - num_positions=15, # Fewer positions for faster demo - step_size_um=1.0, - algorithm=FocusAlgorithm.VOLATH.value, - fit_function=FitFunction.GAUSSIAN.value - ) - - print(f"\nAutofocus configuration:") - print(f" Positions: {config.num_positions}") - print(f" Step size: {config.step_size_um} μm") - print(f" Total range: ±{config.num_positions * config.step_size_um / 2} μm") - - if light_sheet is not None: - print(f"\nExecuting autofocus with napari visualization:") - print(f" RE(dispim_piezo_autofocus(light_sheet, config))") - - # This would display images in napari as they're acquired - # RE(dispim_piezo_autofocus(light_sheet, config)) - - print(f"\nNapari display:") - print(f" ✅ Images stream to napari as they're acquired") - print(f" ✅ 3D stack builds up in real-time") - print(f" ✅ Can scrub through Z positions") - print(f" ✅ Focus curve visible as image stack") - - else: - print(f"\n[Would execute: RE(dispim_piezo_autofocus(light_sheet, config))]") - print(f"\nExpected napari behavior:") - print(f" - New layer: 'Focus Sweep (Side A)'") - print(f" - Green colormap for side A data") - print(f" - 3D stack: shape (15, height, width)") - print(f" - Real-time updates as images acquired") - - print(f"\nVisualization features:") - print(f" - Real-time focus quality assessment") - print(f" - Immediate feedback on scan progress") - print(f" - Visual validation of focus curve") - - -def demonstrate_embryo_detection_visualization(RE, system, napari_callback): - """Demonstrate embryo detection visualization""" - print("\n" + "="*60) - print("3. 
EMBRYO DETECTION VISUALIZATION - 2D Image Sequences") - print("="*60) - - if not napari_callback or not napari_callback.enabled: - print(" ⚠ Skipping - napari not available") - return - - print("\nEmbryo detection creates sequences of 2D images from XY scanning:") - - # Configure embryo detection - detection_config = { - 'scan_area': { - 'x_start': -500, 'x_stop': 500, # Smaller area for demo - 'y_start': -500, 'y_stop': 500, - 'step_size': 100 # μm between positions - }, - 'detection': { - 'min_size_pixels': 50, - 'max_size_pixels': 500, - 'brightness_threshold': 0.3 - } - } - - print(f"\nDetection configuration:") - print(f" Scan area: {detection_config['scan_area']['x_start']} to {detection_config['scan_area']['x_stop']} μm") - print(f" Step size: {detection_config['scan_area']['step_size']} μm") - print(f" Grid size: 11x11 = 121 positions") - - if system is not None: - print(f"\nExecuting embryo detection with napari visualization:") - print(f" RE(find_embryos_with_bottom_camera(system, detection_config))") - - # This would display images in napari as XY scan progresses - # RE(find_embryos_with_bottom_camera(system, detection_config)) - - print(f"\nNapari display:") - print(f" ✅ Each XY position shows in napari immediately") - print(f" ✅ Can see scan progress across sample") - print(f" ✅ Potential embryos highlighted as found") - print(f" ✅ Final mosaic view of scanned area") - - else: - print(f"\n[Would execute: RE(find_embryos_with_bottom_camera(system, detection_config))]") - print(f"\nExpected napari behavior:") - print(f" - New layer: 'Embryo Detection (Side A)'") - print(f" - Updates with each XY position") - print(f" - 121 total images in sequence") - print(f" - Detected embryos marked/highlighted") - - print(f"\nVisualization benefits:") - print(f" - Real-time quality control of scan") - print(f" - Immediate feedback on embryo locations") - print(f" - Visual verification of detection algorithm") - - -def demonstrate_dual_channel_visualization(RE, 
system, napari_callback): - """Demonstrate dual-sided DiSPIM visualization""" - print("\n" + "="*60) - print("4. DUAL-CHANNEL VISUALIZATION - Multi-Camera Display") - print("="*60) - - if not napari_callback or not napari_callback.enabled: - print(" ⚠ Skipping - napari not available") - return - - print("\nDual-sided DiSPIM generates images from two cameras simultaneously:") - - if system is not None: - print(f"\nSimulating dual-sided acquisition:") - print(f" # Both sides acquire simultaneously") - print(f" side_a_image = system.side_a.camera.read()") - print(f" side_b_image = system.side_b.camera.read()") - - print(f"\nNapari display:") - print(f" ✅ Side A: Green channel") - print(f" ✅ Side B: Magenta channel") - print(f" ✅ Additive blending for overlay") - print(f" ✅ Separate layers for independent control") - print(f" ✅ Synchronized updates") - - else: - print(f"\n[Would show both camera feeds simultaneously]") - - print(f"\nColor scheme:") - print(f" - Side A (illumination from left): Green") - print(f" - Side B (illumination from right): Magenta") - print(f" - Overlaid: Shows complementary information") - - print(f"\nVisualization advantages:") - print(f" - Compare image quality from both sides") - print(f" - See complementary sample information") - print(f" - Identify optimal viewing angle") - print(f" - Real-time feedback for dual-sided experiments") - - -def demonstrate_custom_visualization_configs(RE): - """Demonstrate custom visualization configurations""" - print("\n" + "="*60) - print("5. CUSTOM CONFIGURATIONS - Tailored Visualization") - print("="*60) - - if not NAPARI_AVAILABLE: - print(" ⚠ Skipping - napari not available") - return - - print("\nCustom configurations for different experiment needs:") - - # Configuration 1: Focus-only visualization - print(f"\n1. 
Focus-Only Configuration:") - print(f" config = {{'show_focus_sweeps': True, 'show_embryo_detection': False}}") - print(f" napari_callback = setup_napari_callback(config)") - - focus_config = { - 'show_focus_sweeps': True, - 'show_embryo_detection': False, - 'show_single_images': False, - 'update_interval': 0.05 # Faster updates - } - - print(f" - Only shows focus sweep experiments") - print(f" - Faster update rate (0.05s)") - print(f" - Optimized for autofocus development") - - # Configuration 2: High-throughput visualization - print(f"\n2. High-Throughput Configuration:") - print(f" config = {{'show_single_images': False, 'update_interval': 1.0}}") - - throughput_config = { - 'show_focus_sweeps': True, - 'show_embryo_detection': True, - 'show_single_images': False, # Skip individual images - 'update_interval': 1.0 # Slower updates for performance - } - - print(f" - Skip individual images to reduce overhead") - print(f" - Slower update rate (1.0s) for performance") - print(f" - Better for automated, high-throughput experiments") - - # Configuration 3: Development/debugging - print(f"\n3. Development/Debugging Configuration:") - print(f" config = {{'show_single_images': True, 'update_interval': 0.01}}") - - debug_config = { - 'show_focus_sweeps': True, - 'show_embryo_detection': True, - 'show_single_images': True, - 'update_interval': 0.01 # Very fast updates - } - - print(f" - Show every image for detailed inspection") - print(f" - Very fast updates (0.01s)") - print(f" - Maximum detail for troubleshooting") - - print(f"\nUsage pattern:") - print(f" # Choose configuration for your needs") - print(f" config = focus_config # or throughput_config, debug_config") - print(f" napari_callback = setup_napari_callback(config)") - print(f" RE.subscribe(napari_callback)") - - -def demonstrate_convenience_functions(RE): - """Demonstrate convenience functions for common patterns""" - print("\n" + "="*60) - print("6. 
CONVENIENCE FUNCTIONS - Common Usage Patterns") - print("="*60) - - if not NAPARI_AVAILABLE: - print(" ⚠ Skipping - napari not available") - return - - print("\nConvenience functions for common visualization needs:") - - print(f"\n1. Focus Sweep Only:") - print(f" from gently.visualization import enable_focus_sweep_visualization") - print(f" enable_focus_sweep_visualization(RE)") - print(f" # Optimized for autofocus experiments") - - print(f"\n2. Embryo Detection Only:") - print(f" from gently.visualization import enable_embryo_detection_visualization") - print(f" enable_embryo_detection_visualization(RE)") - print(f" # Optimized for sample detection") - - print(f"\n3. Full Visualization:") - print(f" from gently.visualization import enable_full_visualization") - print(f" enable_full_visualization(RE)") - print(f" # Shows everything - good for general use") - - print(f"\n4. Custom Viewer:") - print(f" from gently.visualization import create_napari_viewer") - print(f" viewer = create_napari_viewer('My DiSPIM Experiment')") - print(f" callback = setup_napari_callback(viewer=viewer)") - print(f" # Use your own configured viewer") - - print(f"\nBenefits:") - print(f" - One-line setup for common patterns") - print(f" - Pre-configured for specific experiment types") - print(f" - Easy to integrate into existing workflows") - - -def demonstrate_complete_workflow_visualization(RE, system): - """Demonstrate visualization with complete DiSPIM workflow""" - print("\n" + "="*60) - print("7. 
COMPLETE WORKFLOW VISUALIZATION - Full Experiment") - print("="*60) - - if not NAPARI_AVAILABLE: - print(" ⚠ Skipping - napari not available") - return - - print("\nVisualization during complete multi-embryo workflow:") - - # Setup full visualization - print(f"\nSetting up comprehensive visualization:") - print(f" napari_callback = enable_full_visualization(RE)") - print(f" # Will show all stages of the workflow") - - # Complete workflow configuration - workflow_config = { - 'system_setup': { - 'center_devices': True, - 'run_calibration': True - }, - 'calibration': { - 'point1_um': 25.0, - 'point2_um': 75.0, - 'autofocus_each_point': True - }, - 'embryo_detection': { - 'x_start': -1000, 'x_stop': 1000, - 'y_start': -1000, 'y_stop': 1000, - 'step_size': 200 - }, - 'acquisition': { - 'z_stack': {'range_um': 50, 'step_size_um': 1.0}, - 'dual_sided': True, - 'time_points': 3 - } - } - - print(f"\nWorkflow stages with visualization:") - - if system is not None: - print(f"\n RE(full_dispim_workflow(system, workflow_config))") - print(f"\n Expected napari display sequence:") - - print(f" 1. Calibration stage:") - print(f" - Focus sweeps at calibration points") - print(f" - Real-time focus quality assessment") - - print(f" 2. Embryo detection stage:") - print(f" - XY scan images streaming in") - print(f" - Detected embryo positions highlighted") - - print(f" 3. 
Multi-embryo acquisition:") - print(f" - Focus sweeps for each embryo") - print(f" - Z-stack acquisitions (dual-channel)") - print(f" - Time series progression") - - print(f"\nVisualization benefits for complete workflow:") - print(f" ✅ Monitor entire experiment progress") - print(f" ✅ Quality control at each stage") - print(f" ✅ Early detection of issues") - print(f" ✅ Real-time data assessment") - print(f" ✅ Immediate feedback on results") - - -def main(): - """Main napari visualization demonstration""" - print("DiSPIM Napari Visualization Examples") - print("=" * 60) - print() - print("This example demonstrates real-time image visualization") - print("for DiSPIM experiments using napari and Bluesky callbacks.") - print() - - # Check napari installation - if not check_napari_installation(): - return - - # Setup demo system - RE, system, light_sheet = setup_demo_system() - - # Run demonstrations - napari_callback = demonstrate_basic_napari_setup(RE) - demonstrate_focus_sweep_visualization(RE, light_sheet, napari_callback) - demonstrate_embryo_detection_visualization(RE, system, napari_callback) - demonstrate_dual_channel_visualization(RE, system, napari_callback) - demonstrate_custom_visualization_configs(RE) - demonstrate_convenience_functions(RE) - demonstrate_complete_workflow_visualization(RE, system) - - # Summary - print("\n" + "="*60) - print("SUMMARY - Napari Visualization Integration") - print("="*60) - - print(f"\n✅ Napari integration complete:") - print(f" - Real-time image streaming from Bluesky plans") - print(f" - Automatic 3D stack visualization for focus sweeps") - print(f" - 2D image sequences for embryo detection") - print(f" - Dual-channel support for two-sided DiSPIM") - print(f" - Configurable visualization options") - - print(f"\n✅ Key benefits:") - print(f" - Immediate visual feedback during experiments") - print(f" - Quality control and error detection") - print(f" - Interactive data exploration") - print(f" - Non-intrusive - works with 
existing plans") - print(f" - Optional - graceful fallback if napari not available") - - print(f"\n✅ Usage patterns:") - print(f" - Basic: setup_napari_callback() → RE.subscribe()") - print(f" - Custom: setup_napari_callback(config) for specific needs") - print(f" - Convenience: enable_focus_sweep_visualization(RE)") - print(f" - Integration: Works with all existing DiSPIM plans") - - print(f"\nNext steps:") - print(f" 1. Install napari: pip install napari[all]") - print(f" 2. Add visualization to your DiSPIM experiments") - print(f" 3. Customize configurations for your needs") - print(f" 4. Enjoy real-time image feedback!") - - if napari_callback and napari_callback.enabled: - print(f"\nNapari viewer is open - explore the interface!") - print(f" - Layer controls for each image type") - print(f" - Color/brightness adjustments") - print(f" - 3D visualization controls") - print(f" - Screenshot and movie export options") - - print(f"\nGently DiSPIM + Napari: Real-time microscopy visualization! 
🔬✨") - - -if __name__ == "__main__": - # Setup logging - logging.basicConfig(level=logging.INFO) - - # Run demonstration - main() \ No newline at end of file diff --git a/experimental/data_reasoning/__init__.py b/experimental/data_reasoning/__init__.py index 0450d5a1..8b06fd31 100644 --- a/experimental/data_reasoning/__init__.py +++ b/experimental/data_reasoning/__init__.py @@ -8,16 +8,16 @@ - GapPlanner: create plan items to fill annotation gaps """ +from .assessment import DataAssessmentEngine +from .coverage import CoverageAnalyzer +from .gap_planner import GapPlanner from .models import ( CoverageReport, DataQualityReport, NetworkDataInventory, SessionSummary, ) -from .assessment import DataAssessmentEngine -from .coverage import CoverageAnalyzer from .quality import QualityAnalyzer -from .gap_planner import GapPlanner __all__ = [ "CoverageReport", diff --git a/experimental/data_reasoning/assessment.py b/experimental/data_reasoning/assessment.py index 874d8953..e3ef9c2e 100644 --- a/experimental/data_reasoning/assessment.py +++ b/experimental/data_reasoning/assessment.py @@ -7,7 +7,6 @@ """ import logging -from typing import Optional from .models import NetworkDataInventory, SessionSummary @@ -66,16 +65,18 @@ def inventory_local(self) -> list: except Exception: pass - sessions.append(SessionSummary( - session_id=sid, - session_name=sname, - embryo_count=embryo_count, - volume_count=vol_count, - annotated_embryos=annotated, - ground_truth_count=gt_count, - stages_covered=sorted(stages), - is_remote=False, - )) + sessions.append( + SessionSummary( + session_id=sid, + session_name=sname, + embryo_count=embryo_count, + volume_count=vol_count, + annotated_embryos=annotated, + ground_truth_count=gt_count, + stages_covered=sorted(stages), + is_remote=False, + ) + ) except Exception as e: logger.error(f"Local inventory failed: {e}") @@ -98,6 +99,7 @@ async def inventory_remote(self) -> tuple: try: # Build a PeerInfo-like object for PeerClient from ..mesh.models import 
PeerInfo + peer = PeerInfo( instance_id=peer_entry.instance_id, hostname=peer_entry.hostname, @@ -111,14 +113,16 @@ async def inventory_remote(self) -> tuple: peers_failed += 1 continue for s in peer_sessions: - sessions.append(SessionSummary( - session_id=s.get("session_id", ""), - session_name=s.get("name", ""), - source_peer=peer_entry.instance_id, - embryo_count=s.get("embryo_count", 0), - volume_count=s.get("volume_count", 0), - is_remote=True, - )) + sessions.append( + SessionSummary( + session_id=s.get("session_id", ""), + session_name=s.get("name", ""), + source_peer=peer_entry.instance_id, + embryo_count=s.get("embryo_count", 0), + volume_count=s.get("volume_count", 0), + is_remote=True, + ) + ) except Exception as e: logger.debug(f"Remote inventory failed for {peer_entry.hostname}: {e}") peers_failed += 1 diff --git a/experimental/data_reasoning/coverage.py b/experimental/data_reasoning/coverage.py index ac9c1f0d..4a449729 100644 --- a/experimental/data_reasoning/coverage.py +++ b/experimental/data_reasoning/coverage.py @@ -3,7 +3,6 @@ """ import logging -from typing import Dict, List, Optional from .models import CoverageReport @@ -11,9 +10,21 @@ # Known C. elegans embryonic stages (common ordering) KNOWN_STAGES = [ - "early", "2-cell", "4-cell", "8-cell", "16-cell", - "32-cell", "64-cell", "gastrulation", "bean", "comma", - "1.5-fold", "2-fold", "pretzel", "3-fold", "hatching", + "early", + "2-cell", + "4-cell", + "8-cell", + "16-cell", + "32-cell", + "64-cell", + "gastrulation", + "bean", + "comma", + "1.5-fold", + "2-fold", + "pretzel", + "3-fold", + "hatching", ] # Minimum recommended samples per stage for training @@ -32,7 +43,7 @@ class CoverageAnalyzer: def __init__(self, gently_store=None): self._store = gently_store - def analyze(self, session_ids: Optional[List[str]] = None) -> CoverageReport: + def analyze(self, session_ids: list[str] | None = None) -> CoverageReport: """Analyze annotation coverage across specified sessions (or all). 
Parameters @@ -49,7 +60,7 @@ def analyze(self, session_ids: Optional[List[str]] = None) -> CoverageReport: total_embryos = 0 annotated_embryos = 0 - stage_counts: Dict[str, int] = {} + stage_counts: dict[str, int] = {} try: sessions = self._store.list_sessions() @@ -89,7 +100,11 @@ def analyze(self, session_ids: Optional[List[str]] = None) -> CoverageReport: # Find gaps and generate recommendations gaps = self._find_gaps(stage_counts) recommendations = self._generate_recommendations( - total_embryos, annotated_embryos, coverage_pct, stage_counts, gaps, + total_embryos, + annotated_embryos, + coverage_pct, + stage_counts, + gaps, ) return CoverageReport( @@ -104,7 +119,7 @@ def analyze(self, session_ids: Optional[List[str]] = None) -> CoverageReport: def analyze_from_inventory(self, inventory) -> CoverageReport: """Build a coverage report from a NetworkDataInventory.""" - stage_counts: Dict[str, int] = {} + stage_counts: dict[str, int] = {} total_embryos = inventory.total_embryos annotated_embryos = inventory.total_annotated @@ -118,7 +133,11 @@ def analyze_from_inventory(self, inventory) -> CoverageReport: imbalance_ratio = (max(counts) / min(counts)) if counts and min(counts) > 0 else 0.0 gaps = self._find_gaps(stage_counts) recommendations = self._generate_recommendations( - total_embryos, annotated_embryos, coverage_pct, stage_counts, gaps, + total_embryos, + annotated_embryos, + coverage_pct, + stage_counts, + gaps, ) return CoverageReport( @@ -131,7 +150,7 @@ def analyze_from_inventory(self, inventory) -> CoverageReport: recommendations=recommendations, ) - def _find_gaps(self, stage_counts: Dict[str, int]) -> List[str]: + def _find_gaps(self, stage_counts: dict[str, int]) -> list[str]: """Identify underrepresented stages.""" gaps = [] if not stage_counts: @@ -142,7 +161,9 @@ def _find_gaps(self, stage_counts: Dict[str, int]) -> List[str]: # Stages with too few samples for stage, count in stage_counts.items(): if count < MIN_SAMPLES_PER_STAGE: - 
gaps.append(f"{stage} underrepresented ({count} samples, need {MIN_SAMPLES_PER_STAGE})") + gaps.append( + f"{stage} underrepresented ({count} samples, need {MIN_SAMPLES_PER_STAGE})" + ) elif count < avg * 0.5: gaps.append(f"{stage} below average ({count} vs avg {avg:.0f})") @@ -159,9 +180,9 @@ def _generate_recommendations( total_embryos: int, annotated: int, coverage_pct: float, - stage_counts: Dict[str, int], - gaps: List[str], - ) -> List[str]: + stage_counts: dict[str, int], + gaps: list[str], + ) -> list[str]: """Generate actionable recommendations.""" recs = [] diff --git a/experimental/data_reasoning/gap_planner.py b/experimental/data_reasoning/gap_planner.py index 4540d7a1..0a7f6bc6 100644 --- a/experimental/data_reasoning/gap_planner.py +++ b/experimental/data_reasoning/gap_planner.py @@ -3,7 +3,6 @@ """ import logging -from typing import Optional from .models import CoverageReport @@ -70,6 +69,7 @@ def plan_annotation_campaign( # Create items for completely missing stages from .coverage import KNOWN_STAGES + present_stages = set(coverage_report.stage_counts.keys()) for stage in KNOWN_STAGES: if stage not in present_stages: diff --git a/experimental/data_reasoning/models.py b/experimental/data_reasoning/models.py index 9e6f5ef9..8044f0b3 100644 --- a/experimental/data_reasoning/models.py +++ b/experimental/data_reasoning/models.py @@ -3,12 +3,13 @@ """ from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import Any @dataclass class SessionSummary: """Summary of a single imaging session's data.""" + session_id: str = "" session_name: str = "" source_peer: str = "" # instance_id of peer (empty = local) @@ -16,11 +17,11 @@ class SessionSummary: volume_count: int = 0 annotated_embryos: int = 0 ground_truth_count: int = 0 - stages_covered: List[str] = field(default_factory=list) + stages_covered: list[str] = field(default_factory=list) total_size_gb: float = 0.0 is_remote: bool = False - def to_dict(self) -> 
Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "session_id": self.session_id, "session_name": self.session_name, @@ -38,8 +39,9 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class NetworkDataInventory: """Aggregated inventory of all data across the mesh.""" - local_sessions: List[SessionSummary] = field(default_factory=list) - remote_sessions: List[SessionSummary] = field(default_factory=list) + + local_sessions: list[SessionSummary] = field(default_factory=list) + remote_sessions: list[SessionSummary] = field(default_factory=list) total_embryos: int = 0 total_volumes: int = 0 total_annotated: int = 0 @@ -48,10 +50,10 @@ class NetworkDataInventory: peers_failed: int = 0 @property - def all_sessions(self) -> List[SessionSummary]: + def all_sessions(self) -> list[SessionSummary]: return self.local_sessions + self.remote_sessions - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "local_sessions": [s.to_dict() for s in self.local_sessions], "remote_sessions": [s.to_dict() for s in self.remote_sessions], @@ -67,15 +69,16 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class CoverageReport: """Annotation coverage analysis across network data.""" + total_embryos: int = 0 annotated_embryos: int = 0 coverage_pct: float = 0.0 - stage_counts: Dict[str, int] = field(default_factory=dict) + stage_counts: dict[str, int] = field(default_factory=dict) imbalance_ratio: float = 0.0 - gaps: List[str] = field(default_factory=list) - recommendations: List[str] = field(default_factory=list) + gaps: list[str] = field(default_factory=list) + recommendations: list[str] = field(default_factory=list) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "total_embryos": self.total_embryos, "annotated_embryos": self.annotated_embryos, @@ -90,13 +93,14 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class DataQualityReport: """Data quality validation results.""" + total_volumes_checked: int = 0 
readable_volumes: int = 0 missing_projections: int = 0 inconsistent_annotations: int = 0 - issues: List[str] = field(default_factory=list) + issues: list[str] = field(default_factory=list) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "total_volumes_checked": self.total_volumes_checked, "readable_volumes": self.readable_volumes, diff --git a/experimental/data_reasoning/quality.py b/experimental/data_reasoning/quality.py index 86afcf6e..6596fe75 100644 --- a/experimental/data_reasoning/quality.py +++ b/experimental/data_reasoning/quality.py @@ -7,7 +7,6 @@ import logging from pathlib import Path -from typing import Optional from .models import DataQualityReport @@ -26,7 +25,9 @@ class QualityAnalyzer: def __init__(self, gently_store=None): self._store = gently_store - def analyze(self, session_ids: Optional[list] = None, check_files: bool = False) -> DataQualityReport: + def analyze( + self, session_ids: list | None = None, check_files: bool = False + ) -> DataQualityReport: """Run quality checks on local data. 
Parameters @@ -62,7 +63,11 @@ def analyze(self, session_ids: Optional[list] = None, check_files: bool = False) for vol in vols: # Check volume file exists if requested if check_files: - vpath = vol.file_path if hasattr(vol, "file_path") else vol.get("file_path", "") + vpath = ( + vol.file_path + if hasattr(vol, "file_path") + else vol.get("file_path", "") + ) if vpath and Path(vpath).exists(): report.readable_volumes += 1 elif vpath: @@ -87,7 +92,11 @@ def analyze(self, session_ids: Optional[list] = None, check_files: bool = False) # Check for overlapping timepoint ranges ranges = [] for gt in gts: - start = gt.start_tp if hasattr(gt, "start_tp") else gt.get("start_tp", 0) + start = ( + gt.start_tp + if hasattr(gt, "start_tp") + else gt.get("start_tp", 0) + ) end = gt.end_tp if hasattr(gt, "end_tp") else gt.get("end_tp", 0) if start and end and start > end: report.inconsistent_annotations += 1 diff --git a/experimental/ml_agent/agent.py b/experimental/ml_agent/agent.py index 4e31e697..69e398e9 100644 --- a/experimental/ml_agent/agent.py +++ b/experimental/ml_agent/agent.py @@ -5,14 +5,12 @@ execution loop, communicating back via the event bus. 
""" -import asyncio import json import logging -from typing import Any, Dict, List, Optional from gently.core.event_bus import EventType, get_event_bus from gently.harness.tools.registry import ToolRegistry -from .prompt import build_ml_system_prompt + from .tools import register_ml_tools logger = logging.getLogger(__name__) @@ -42,8 +40,8 @@ def __init__( self._peer_client = peer_client self._registry = ToolRegistry() self._running = False - self._task: Optional[str] = None - self._campaign_id: Optional[str] = None + self._task: str | None = None + self._campaign_id: str | None = None # Register ML tools on our private registry register_ml_tools( @@ -96,15 +94,16 @@ async def run(self, task: str, campaign_id: str = ""): # Step 2: Check coverage bus.publish( EventType.ML_SUBAGENT_STATUS, - {"status": "checking_coverage", "detail": "Analyzing annotation coverage..."}, + { + "status": "checking_coverage", + "detail": "Analyzing annotation coverage...", + }, source="ml_subagent", ) - coverage_result = await self._registry.execute( - "check_annotation_coverage", {} - ) + coverage_result = await self._registry.execute("check_annotation_coverage", {}) coverage = json.loads(coverage_result) - total_annotated = inventory.get("total_annotated", 0) + inventory.get("total_annotated", 0) total_gt = inventory.get("total_ground_truth", 0) # Step 3: Check if we have enough data @@ -134,7 +133,7 @@ async def run(self, task: str, campaign_id: str = ""): # Determine VRAM vram = 24.0 # default A5000 - local_sessions = inventory.get("local_sessions", []) + inventory.get("local_sessions", []) arch_result = await self._registry.execute( "select_architecture", diff --git a/experimental/ml_agent/prompt.py b/experimental/ml_agent/prompt.py index 74ee5247..cc301743 100644 --- a/experimental/ml_agent/prompt.py +++ b/experimental/ml_agent/prompt.py @@ -2,12 +2,10 @@ System prompt for the ML subagent. 
""" -from typing import Dict, List - def build_ml_system_prompt( - architecture_registry: Dict, - hardware_info: Dict, + architecture_registry: dict, + hardware_info: dict, available_data_summary: str = "", ) -> str: """Build the system prompt for the ML subagent. @@ -35,8 +33,10 @@ def build_ml_system_prompt( gpu_lines = [] gpus = hardware_info.get("gpus", []) for g in gpus: - gpu_lines.append(f"- GPU {g.get('device_index', 0)}: {g.get('name', 'unknown')} " - f"({g.get('vram_gb', 0)}GB VRAM)") + gpu_lines.append( + f"- GPU {g.get('device_index', 0)}: {g.get('name', 'unknown')} " + f"({g.get('vram_gb', 0)}GB VRAM)" + ) gpu_section = "\n".join(gpu_lines) if gpu_lines else "No GPUs detected." return f"""You are the ML Training Subagent for Gently, a microscopy automation system. @@ -50,14 +50,15 @@ def build_ml_system_prompt( ## Hardware {gpu_section} -CPU cores: {hardware_info.get('cpu_cores', 0)} -RAM: {hardware_info.get('ram_gb', 0)}GB +CPU cores: {hardware_info.get("cpu_cores", 0)} +RAM: {hardware_info.get("ram_gb", 0)}GB ## Available Data {available_data_summary or "Run inventory_datasets to discover available data."} ## Workflow -1. **Assess data**: Use inventory_datasets and check_annotation_coverage to understand what's available +1. **Assess data**: Use inventory_datasets and check_annotation_coverage to understand + what's available 2. **Check readiness**: If coverage is insufficient, report gaps and suggest annotation campaigns 3. **Select architecture**: Reason over the registry, dataset size, and hardware constraints 4. 
**Configure training**: Set hyperparameters appropriate for the data and architecture diff --git a/experimental/ml_agent/tools.py b/experimental/ml_agent/tools.py index 096d5d27..488ca04c 100644 --- a/experimental/ml_agent/tools.py +++ b/experimental/ml_agent/tools.py @@ -8,16 +8,15 @@ import json import logging from datetime import datetime -from pathlib import Path -from typing import Any, Dict, List, Optional from gently.harness.tools.registry import ToolCategory, ToolParameter logger = logging.getLogger(__name__) -def register_ml_tools(registry, context_store=None, gently_store=None, - verse_map=None, peer_client=None): +def register_ml_tools( + registry, context_store=None, gently_store=None, verse_map=None, peer_client=None +): """Register ML tools on a tool registry. Parameters @@ -50,6 +49,7 @@ def register_ml_tools(registry, context_store=None, gently_store=None, ) async def inventory_datasets(include_remote: bool = True, **kwargs) -> str: from gently.data_reasoning.assessment import DataAssessmentEngine + engine = DataAssessmentEngine( gently_store=gently_store, peer_client=peer_client, @@ -76,6 +76,7 @@ async def inventory_datasets(include_remote: bool = True, **kwargs) -> str: ) async def check_annotation_coverage(session_ids: str = "", **kwargs) -> str: from gently.data_reasoning.coverage import CoverageAnalyzer + analyzer = CoverageAnalyzer(gently_store=gently_store) sid_list = [s.strip() for s in session_ids.split(",") if s.strip()] or None report = analyzer.analyze(session_ids=sid_list) @@ -105,6 +106,7 @@ async def check_annotation_coverage(session_ids: str = "", **kwargs) -> str: ) async def select_architecture(dataset_size: int, vram_gb: float, **kwargs) -> str: from gently.ml.architectures import get_suitable_architectures + results = get_suitable_architectures(dataset_size, vram_gb) return json.dumps(results, indent=2) @@ -116,20 +118,46 @@ async def select_architecture(dataset_size: int, vram_gb: float, **kwargs) -> st ), 
category=ToolCategory.ML, parameters=[ - ToolParameter(name="campaign_id", type="string", - description="Campaign this pipeline belongs to", required=True), - ToolParameter(name="name", type="string", - description="Pipeline name", required=True), - ToolParameter(name="architecture", type="string", - description="Model architecture ID", required=True), - ToolParameter(name="num_classes", type="integer", - description="Number of output classes", required=True), - ToolParameter(name="batch_size", type="integer", - description="Training batch size", required=False, default=32), - ToolParameter(name="epochs", type="integer", - description="Number of training epochs", required=False, default=50), - ToolParameter(name="learning_rate", type="number", - description="Learning rate", required=False, default=1e-4), + ToolParameter( + name="campaign_id", + type="string", + description="Campaign this pipeline belongs to", + required=True, + ), + ToolParameter(name="name", type="string", description="Pipeline name", required=True), + ToolParameter( + name="architecture", + type="string", + description="Model architecture ID", + required=True, + ), + ToolParameter( + name="num_classes", + type="integer", + description="Number of output classes", + required=True, + ), + ToolParameter( + name="batch_size", + type="integer", + description="Training batch size", + required=False, + default=32, + ), + ToolParameter( + name="epochs", + type="integer", + description="Number of training epochs", + required=False, + default=50, + ), + ToolParameter( + name="learning_rate", + type="number", + description="Learning rate", + required=False, + default=1e-4, + ), ], ) async def configure_training( @@ -170,8 +198,12 @@ async def configure_training( ), category=ToolCategory.ML, parameters=[ - ToolParameter(name="pipeline_id", type="string", - description="Pipeline ID to train", required=True), + ToolParameter( + name="pipeline_id", + type="string", + description="Pipeline ID to train", + 
required=True, + ), ], ) async def start_local_training(pipeline_id: str, **kwargs) -> str: @@ -198,6 +230,7 @@ async def start_local_training(pipeline_id: str, **kwargs) -> str: return "Error: No data store available" from gently.ml.data_loader import build_labels_from_store + labels = build_labels_from_store(gently_store) if not labels.get("samples"): @@ -205,14 +238,15 @@ async def start_local_training(pipeline_id: str, **kwargs) -> str: # Write labels file from gently.settings import settings + run_dir = settings.storage.base_path / "ml_runs" / run_data["id"] run_dir.mkdir(parents=True, exist_ok=True) labels_file = run_dir / "labels.json" labels_file.write_text(json.dumps(labels, indent=2)) # Start trainer + from gently.ml.models import ModelConfig, TrainingConfig, TrainingRun from gently.ml.trainer import LocalTrainer - from gently.ml.models import TrainingRun, ModelConfig, TrainingConfig trainer = LocalTrainer(run_dir) run = TrainingRun( @@ -228,22 +262,30 @@ async def start_local_training(pipeline_id: str, **kwargs) -> str: ) context_store.update_training_run( - run_data["id"], status="training", started_at=datetime.now().isoformat(), + run_data["id"], + status="training", + started_at=datetime.now().isoformat(), ) - return json.dumps({ - "status": "training_started", - "run_id": run_data["id"], - "pipeline_id": pipeline_id, - }) + return json.dumps( + { + "status": "training_started", + "run_id": run_data["id"], + "pipeline_id": pipeline_id, + } + ) @registry.register( name="get_ml_status", description="Get the status of ML pipelines and training runs.", category=ToolCategory.ML, parameters=[ - ToolParameter(name="pipeline_id", type="string", - description="Pipeline ID (empty = list all)", required=False), + ToolParameter( + name="pipeline_id", + type="string", + description="Pipeline ID (empty = list all)", + required=False, + ), ], ) async def get_ml_status(pipeline_id: str = "", **kwargs) -> str: @@ -265,11 +307,19 @@ async def get_ml_status(pipeline_id: 
str = "", **kwargs) -> str: ), category=ToolCategory.ML, parameters=[ - ToolParameter(name="campaign_id", type="string", - description="Campaign to add items to", required=True), - ToolParameter(name="target_per_stage", type="integer", - description="Target annotations per stage (default 50)", - required=False, default=50), + ToolParameter( + name="campaign_id", + type="string", + description="Campaign to add items to", + required=True, + ), + ToolParameter( + name="target_per_stage", + type="integer", + description="Target annotations per stage (default 50)", + required=False, + default=50, + ), ], ) async def plan_annotation_campaign( @@ -290,8 +340,11 @@ async def plan_annotation_campaign( target_per_stage=target_per_stage, ) - return json.dumps({ - "created_plan_items": len(created_ids), - "item_ids": created_ids, - "coverage_before": report.to_dict(), - }, indent=2) + return json.dumps( + { + "created_plan_items": len(created_ids), + "item_ids": created_ids, + "coverage_before": report.to_dict(), + }, + indent=2, + ) diff --git a/experimental/test_data_reasoning.py b/experimental/test_data_reasoning.py index 14525d95..56164fcf 100644 --- a/experimental/test_data_reasoning.py +++ b/experimental/test_data_reasoning.py @@ -2,17 +2,15 @@ Tests for data reasoning engine — coverage, assessment, gap planning. 
""" -from unittest.mock import MagicMock, AsyncMock -import pytest +from unittest.mock import MagicMock +import pytest +from gently.data_reasoning.assessment import DataAssessmentEngine +from gently.data_reasoning.coverage import CoverageAnalyzer +from gently.data_reasoning.gap_planner import GapPlanner from gently.data_reasoning.models import ( CoverageReport, - NetworkDataInventory, - SessionSummary, ) -from gently.data_reasoning.coverage import CoverageAnalyzer -from gently.data_reasoning.assessment import DataAssessmentEngine -from gently.data_reasoning.gap_planner import GapPlanner def _make_mock_store(sessions=None, embryos_per_session=None, gt_per_embryo=None): @@ -41,10 +39,12 @@ def list_embryos(sid): e.nickname = eid result.append(e) return result + store.list_embryos.side_effect = list_embryos def list_volumes(sid, eid): return [MagicMock(timepoint=i) for i in range(3)] + store.list_volumes.side_effect = list_volumes def get_ground_truth(sid, eid): @@ -57,6 +57,7 @@ def get_ground_truth(sid, eid): gt.end_tp = 10 result.append(gt) return result + store.get_ground_truth.side_effect = get_ground_truth return store @@ -101,8 +102,8 @@ def test_imbalance_detection(self): embryos_per_session={"s1": ["e1", "e2", "e3"]}, gt_per_embryo={ ("s1", "e1"): ["early", "early", "early"], # 3 early - ("s1", "e2"): ["comma"], # 1 comma - ("s1", "e3"): ["early"], # 1 more early + ("s1", "e2"): ["comma"], # 1 comma + ("s1", "e3"): ["early"], # 1 more early }, ) analyzer = CoverageAnalyzer(gently_store=store) diff --git a/gently/__init__.py b/gently/__init__.py index be9985bb..29b0b3b4 100644 --- a/gently/__init__.py +++ b/gently/__init__.py @@ -7,54 +7,58 @@ # Main entry point from .gently import Gently, create_gently +from .harness.memory.store import ( + ContextStore, +) +# legacy SQLite store (kept for backward compat) # Harness (framework) -from .harness.tools.registry import tool, ToolRegistry, ToolCategory, get_tool_registry -from .harness.memory.store import 
ContextStore # legacy SQLite store (kept for backward compat) +from .harness.tools.registry import ToolCategory, ToolRegistry, get_tool_registry, tool + try: from .harness.memory.file_store import FileContextStore except ImportError: FileContextStore = None -from .harness.memory.interface import AgentMemory - # Core infrastructure from .core import ( EventBus, EventType, get_event_bus, ) - -# Core utilities -from .core.store import GentlyStore # legacy SQLite store (kept for backward compat) +from .core.coordinates import ( + DEFAULT_OBJECTIVE_MAG, + DEFAULT_PIXEL_SIZE_UM, + get_um_per_pixel, + pixel_displacement_to_stage_movement, + pixel_to_stage_position, + stage_to_pixel_position, +) from .core.file_store import FileStore from .core.imaging import ( - normalize_to_uint8, + clip_and_project, + generate_jpeg_projection, image_to_base64, + normalize_to_uint8, projection_three_view, render_volume_view, - clip_and_project, - generate_jpeg_projection, -) -from .core.coordinates import ( - pixel_to_stage_position, - stage_to_pixel_position, - pixel_displacement_to_stage_movement, - get_um_per_pixel, - DEFAULT_PIXEL_SIZE_UM, - DEFAULT_OBJECTIVE_MAG, ) +# Core utilities +from .core.store import GentlyStore # legacy SQLite store (kept for backward compat) +from .harness.memory.interface import AgentMemory + # Analysis utilities try: from .analysis.core import ( + FitFunction, + FocusAlgorithm, FocusAnalysisConfig, FocusResult, - FocusAlgorithm, - FitFunction, - calculate_focus_score, analyze_focus_stack, + calculate_focus_score, fit_focus_curve, ) + _ANALYSIS_AVAILABLE = True except ImportError: _ANALYSIS_AVAILABLE = False @@ -66,22 +70,22 @@ # Visualization (web map view replaces the retired napari marker) try: from .ui.web import ( - mark_embryos_web, - get_visualization_server, - generate_focus_curve_plot, generate_calibration_summary_plot, generate_edge_detection_plot, + generate_focus_curve_plot, + get_visualization_server, + mark_embryos_web, ) + 
_VISUALIZATION_AVAILABLE = True except ImportError: _VISUALIZATION_AVAILABLE = False -__version__ = "0.20.0" +__version__ = "0.22.0.dev0" __all__ = [ # Main entry point "Gently", "create_gently", - # Harness "tool", "ToolRegistry", @@ -89,7 +93,6 @@ "get_tool_registry", "ContextStore", # legacy SQLite store (backward compat) "AgentMemory", - # Core infrastructure "EventBus", "EventType", @@ -97,7 +100,6 @@ "GentlyStore", # legacy SQLite store (backward compat) "FileStore", "FileContextStore", - # Imaging "normalize_to_uint8", "image_to_base64", @@ -105,7 +107,6 @@ "render_volume_view", "clip_and_project", "generate_jpeg_projection", - # Coordinates "pixel_to_stage_position", "stage_to_pixel_position", @@ -113,7 +114,6 @@ "get_um_per_pixel", "DEFAULT_PIXEL_SIZE_UM", "DEFAULT_OBJECTIVE_MAG", - # Analysis "FocusAnalysisConfig", "FocusResult", @@ -122,4 +122,10 @@ "calculate_focus_score", "analyze_focus_stack", "fit_focus_curve", + # Visualization + "generate_calibration_summary_plot", + "generate_edge_detection_plot", + "generate_focus_curve_plot", + "get_visualization_server", + "mark_embryos_web", ] diff --git a/gently/agent/__init__.py b/gently/agent/__init__.py index 9ac0e133..6ef11680 100644 --- a/gently/agent/__init__.py +++ b/gently/agent/__init__.py @@ -6,38 +6,49 @@ This module re-exports from the new locations for backward compatibility. 
""" +from gently_perception import Perceiver, PerceptionOutput +from gently_perception import Session as PerceptionSession + from gently.app.agent import MicroscopyAgent -from gently.harness.state import EmbryoState, ExperimentState, ImageRecord from gently.harness.orchestration.plan_synthesis import PlanSynthesizer, PlanValidator -from gently_perception import Perceiver, PerceptionOutput, Session as PerceptionSession +from gently.harness.state import EmbryoState, ExperimentState, ImageRecord + try: from gently.hardware.dispim.device_factory import create_devices_from_mmcore except ImportError: create_devices_from_mmcore = None -from gently.app.queue_server_client import QueueServerClient -from gently.harness.tools.registry import ToolRegistry, get_tool_registry, tool, ToolCategory -from gently.app.benchmark import run_benchmark, BenchmarkResults, print_benchmark_results - # Import tools package to register all tools -from gently.app import tools +from gently.app import tools # noqa: F401 +from gently.app.benchmark import ( + BenchmarkResults, + print_benchmark_results, + run_benchmark, +) +from gently.app.queue_server_client import QueueServerClient +from gently.harness.tools.registry import ( + ToolCategory, + ToolRegistry, + get_tool_registry, + tool, +) __all__ = [ - 'MicroscopyAgent', - 'QueueServerClient', - 'EmbryoState', - 'ExperimentState', - 'ImageRecord', - 'PlanSynthesizer', - 'PlanValidator', - 'Perceiver', - 'PerceptionOutput', - 'PerceptionSession', - 'create_devices_from_mmcore', - 'ToolRegistry', - 'get_tool_registry', - 'tool', - 'ToolCategory', - 'run_benchmark', - 'BenchmarkResults', - 'print_benchmark_results', + "MicroscopyAgent", + "QueueServerClient", + "EmbryoState", + "ExperimentState", + "ImageRecord", + "PlanSynthesizer", + "PlanValidator", + "Perceiver", + "PerceptionOutput", + "PerceptionSession", + "create_devices_from_mmcore", + "ToolRegistry", + "get_tool_registry", + "tool", + "ToolCategory", + "run_benchmark", + 
"BenchmarkResults", + "print_benchmark_results", ] diff --git a/gently/analysis/__init__.py b/gently/analysis/__init__.py index a6772a4b..08c6e228 100644 --- a/gently/analysis/__init__.py +++ b/gently/analysis/__init__.py @@ -11,8 +11,8 @@ """ from .pipeline import ( - AnalysisStep, AnalysisResult, + AnalysisStep, Pipeline, PipelineBuilder, StepType, @@ -20,14 +20,13 @@ create_hatching_detection_pipeline, create_morphology_analysis_pipeline, ) - from .steps import ( - VLMStep, - SAMStep, + BlobDetectionStep, MaxProjectionStep, - ThresholdStep, MorphologyStep, - BlobDetectionStep, + SAMStep, + ThresholdStep, + VLMStep, ) __all__ = [ diff --git a/gently/analysis/core.py b/gently/analysis/core.py index 6c95df07..d5f93b57 100644 --- a/gently/analysis/core.py +++ b/gently/analysis/core.py @@ -10,19 +10,20 @@ """ import logging -from typing import Dict, List, Optional, Tuple, Any, Union from dataclasses import dataclass from enum import Enum +from typing import Any + import numpy as np -from scipy import optimize, stats +from scipy import optimize from scipy.ndimage import gaussian_filter, sobel -import warnings from ..exceptions import FocusFitError class FocusAlgorithm(Enum): """Focus scoring algorithms available""" + VOLATH = "volath" GRADIENT = "gradient" VARIANCE = "variance" @@ -32,11 +33,12 @@ class FocusAlgorithm(Enum): # FFT Bandpass parameters (from ASI diSPIM OughtaFocus implementation) # These define the spatial frequency band analyzed for focus quality FFT_LOWER_CUTOFF = 0.025 # 2.5% of max frequency - filters DC and low spatial frequencies -FFT_UPPER_CUTOFF = 0.14 # 14% of max frequency - filters high-frequency noise +FFT_UPPER_CUTOFF = 0.14 # 14% of max frequency - filters high-frequency noise class FitFunction(Enum): """Curve fitting functions available""" + GAUSSIAN = "gaussian" PARABOLIC = "parabolic" NONE = "none" @@ -45,6 +47,7 @@ class FitFunction(Enum): @dataclass class FocusAnalysisConfig: """Configuration for focus analysis operations""" + 
algorithm: str = FocusAlgorithm.VOLATH.value fit_function: str = FitFunction.GAUSSIAN.value minimum_r_squared: float = 0.75 @@ -56,14 +59,15 @@ class FocusAnalysisConfig: @dataclass class FocusResult: """Result of focus analysis""" + success: bool best_position: float best_score: float r_squared: float - fit_params: Optional[np.ndarray] = None - all_positions: Optional[np.ndarray] = None - all_scores: Optional[np.ndarray] = None - error_message: Optional[str] = None + fit_params: np.ndarray | None = None + all_positions: np.ndarray | None = None + all_scores: np.ndarray | None = None + error_message: str | None = None class AdaptiveSweepState: @@ -82,15 +86,15 @@ class AdaptiveSweepState: # Early stopping thresholds DECLINE_THRESHOLD: float = 0.70 # Stop if score drops below 70% of max - MIN_DECLINE_COUNT: int = 3 # Require N consecutive declines - MIN_POINTS_FOR_FIT: int = 5 # Minimum points before attempting fit - MIN_POINTS_PAST_PEAK: int = 3 # Need N points past peak for robust fit + MIN_DECLINE_COUNT: int = 3 # Require N consecutive declines + MIN_POINTS_FOR_FIT: int = 5 # Minimum points before attempting fit + MIN_POINTS_PAST_PEAK: int = 3 # Need N points past peak for robust fit STABILITY_THRESHOLD_UM: float = 0.5 # Peak position stability threshold HIGH_CONFIDENCE_R2: float = 0.90 # R² threshold for early exit def __init__(self): - self.positions: List[float] = [] - self.scores: List[float] = [] + self.positions: list[float] = [] + self.scores: list[float] = [] # Peak detection state self.running_max_score: float = 0.0 @@ -103,9 +107,9 @@ def __init__(self): self.current_r_squared: float = 0.0 self.fit_stable: bool = False self.last_fit_position: float = 0.0 - self.fit_history: List[Dict[str, float]] = [] + self.fit_history: list[dict[str, float]] = [] - def add_point(self, position: float, score: float) -> Dict[str, Any]: + def add_point(self, position: float, score: float) -> dict[str, Any]: """ Add new measurement and compute early stopping decision. 
@@ -127,9 +131,9 @@ def add_point(self, position: float, score: float) -> Dict[str, Any]: self.scores.append(score) result = { - 'should_stop': False, - 'reason': None, - 'confidence': 0.0, + "should_stop": False, + "reason": None, + "confidence": 0.0, } # Update running max @@ -147,44 +151,49 @@ def add_point(self, position: float, score: float) -> Dict[str, Any]: points_past_peak = len(self.positions) - self.running_max_idx - 1 if self.decline_count >= self.MIN_DECLINE_COUNT: self.peak_detected = True - result['confidence'] = 0.7 + result["confidence"] = 0.7 # Continue a few more points past detected peak for robust fitting if points_past_peak >= self.MIN_POINTS_PAST_PEAK: - result['should_stop'] = True - result['reason'] = 'peak_passed' + result["should_stop"] = True + result["reason"] = "peak_passed" # Confidence-based early exit (if we have enough points) if len(self.positions) >= self.MIN_POINTS_FOR_FIT: fit_result = self._attempt_fit() if fit_result: - self.current_r_squared = fit_result['r_squared'] - new_position = fit_result['peak_position'] + self.current_r_squared = fit_result["r_squared"] + new_position = fit_result["peak_position"] # Track fit history for stability check - self.fit_history.append({ - 'position': new_position, - 'r_squared': fit_result['r_squared'], - }) + self.fit_history.append( + { + "position": new_position, + "r_squared": fit_result["r_squared"], + } + ) # Check stability (position change across recent fits) if len(self.fit_history) >= 3: - recent_positions = [f['position'] for f in self.fit_history[-3:]] + recent_positions = [f["position"] for f in self.fit_history[-3:]] position_range = max(recent_positions) - min(recent_positions) self.fit_stable = position_range < self.STABILITY_THRESHOLD_UM self.last_fit_position = new_position # High confidence early exit - if (self.current_r_squared >= self.HIGH_CONFIDENCE_R2 and - self.fit_stable and len(self.positions) >= 7): - result['should_stop'] = True - result['reason'] = 
'high_confidence_fit' - result['confidence'] = self.current_r_squared + if ( + self.current_r_squared >= self.HIGH_CONFIDENCE_R2 + and self.fit_stable + and len(self.positions) >= 7 + ): + result["should_stop"] = True + result["reason"] = "high_confidence_fit" + result["confidence"] = self.current_r_squared return result - def _attempt_fit(self) -> Optional[Dict[str, Any]]: + def _attempt_fit(self) -> dict[str, Any] | None: """ Attempt Gaussian fit on current data. @@ -201,19 +210,17 @@ def _attempt_fit(self) -> Optional[Dict[str, Any]]: scores = np.array(self.scores) # Use existing fit_focus_curve - _, _, params, r_squared = fit_focus_curve( - positions, scores, FitFunction.GAUSSIAN.value - ) + _, _, params, r_squared = fit_focus_curve(positions, scores, FitFunction.GAUSSIAN.value) return { - 'peak_position': float(params[1]), # mu - 'r_squared': r_squared, - 'params': params, + "peak_position": float(params[1]), # mu + "r_squared": r_squared, + "params": params, } except Exception: return None - def get_best_position(self) -> Tuple[float, float]: + def get_best_position(self) -> tuple[float, float]: """ Get best focus position from current data. 
@@ -227,8 +234,8 @@ def get_best_position(self) -> Tuple[float, float]: # Try fit first fit_result = self._attempt_fit() - if fit_result and fit_result['r_squared'] >= 0.5: - return (fit_result['peak_position'], fit_result['r_squared']) + if fit_result and fit_result["r_squared"] >= 0.5: + return (fit_result["peak_position"], fit_result["r_squared"]) # Fall back to max score position return (self.running_max_position, 0.0) @@ -248,9 +255,12 @@ def reset(self): self.fit_history = [] -def calculate_focus_score(image: np.ndarray, algorithm: str = FocusAlgorithm.VOLATH.value, - roi: Optional[Tuple[int, int, int, int]] = None, - config: Optional[FocusAnalysisConfig] = None) -> float: +def calculate_focus_score( + image: np.ndarray, + algorithm: str = FocusAlgorithm.VOLATH.value, + roi: tuple[int, int, int, int] | None = None, + config: FocusAnalysisConfig | None = None, +) -> float: """ Calculate focus score for an image using specified algorithm @@ -285,12 +295,12 @@ def calculate_focus_score(image: np.ndarray, algorithm: str = FocusAlgorithm.VOL # Apply ROI if specified if roi is not None: x, y, w, h = roi - image = image[y:y+h, x:x+w] + image = image[y : y + h, x : x + w] # Crop edges to avoid boundary effects if config.edge_crop > 0: crop = config.edge_crop - if image.shape[0] > 2*crop and image.shape[1] > 2*crop: + if image.shape[0] > 2 * crop and image.shape[1] > 2 * crop: image = image[crop:-crop, crop:-crop] # Convert to float for calculations @@ -324,7 +334,7 @@ def _volath_focus_score(image: np.ndarray) -> float: shifted = np.roll(image, 1, axis=1) product_sum = np.sum(image * shifted) - return product_sum - (mean_val ** 2) * image.size + return product_sum - (mean_val**2) * image.size except Exception as e: logging.getLogger(__name__).error(f"Volath focus score failed: {e}") @@ -362,9 +372,11 @@ def _variance_focus_score(image: np.ndarray) -> float: return 0.0 -def _fft_bandpass_focus_score(image: np.ndarray, - lower_cutoff: float = FFT_LOWER_CUTOFF, - 
upper_cutoff: float = FFT_UPPER_CUTOFF) -> float: +def _fft_bandpass_focus_score( + image: np.ndarray, + lower_cutoff: float = FFT_LOWER_CUTOFF, + upper_cutoff: float = FFT_UPPER_CUTOFF, +) -> float: """ FFT bandpass focus measure (ASI diSPIM OughtaFocus algorithm). @@ -408,7 +420,7 @@ def _fft_bandpass_focus_score(image: np.ndarray, # Create distance map from center (DC component) y, x = np.ogrid[:h, :w] - distance_from_center = np.sqrt((x - cx)**2 + (y - cy)**2) + distance_from_center = np.sqrt((x - cx) ** 2 + (y - cy) ** 2) # Maximum frequency (corner distance) max_freq = np.sqrt(cx**2 + cy**2) @@ -417,7 +429,9 @@ def _fft_bandpass_focus_score(image: np.ndarray, normalized_distance = distance_from_center / max_freq # Create bandpass mask: keep frequencies in [lower_cutoff, upper_cutoff] - bandpass_mask = (normalized_distance >= lower_cutoff) & (normalized_distance <= upper_cutoff) + bandpass_mask = (normalized_distance >= lower_cutoff) & ( + normalized_distance <= upper_cutoff + ) # Apply mask and compute mean power in band masked_power = power_spectrum * bandpass_mask @@ -434,8 +448,9 @@ def _fft_bandpass_focus_score(image: np.ndarray, return 0.0 -def analyze_focus_stack(positions: List[float], images: List[np.ndarray], - config: FocusAnalysisConfig) -> FocusResult: +def analyze_focus_stack( + positions: list[float], images: list[np.ndarray], config: FocusAnalysisConfig +) -> FocusResult: """ Analyze a complete focus stack to find best focus position @@ -459,14 +474,20 @@ def analyze_focus_stack(positions: List[float], images: List[np.ndarray], try: if len(positions) != len(images): return FocusResult( - success=False, best_position=0.0, best_score=0.0, r_squared=0.0, - error_message="Positions and images length mismatch" + success=False, + best_position=0.0, + best_score=0.0, + r_squared=0.0, + error_message="Positions and images length mismatch", ) if len(positions) < 3: return FocusResult( - success=False, best_position=0.0, best_score=0.0, r_squared=0.0, - 
error_message="Need at least 3 data points for analysis" + success=False, + best_position=0.0, + best_score=0.0, + r_squared=0.0, + error_message="Need at least 3 data points for analysis", ) # Calculate focus scores for all images @@ -495,7 +516,7 @@ def analyze_focus_stack(positions: List[float], images: List[np.ndarray], best_position = positions[best_idx] best_score = scores[best_idx] - except (FocusFitError, Exception) as fit_error: + except (FocusFitError, Exception): # Fallback to highest measured score best_idx = np.argmax(scores) best_position = positions[best_idx] @@ -510,18 +531,24 @@ def analyze_focus_stack(positions: List[float], images: List[np.ndarray], r_squared=float(r_squared), fit_params=fit_params, all_positions=positions, - all_scores=scores + all_scores=scores, ) except Exception as e: return FocusResult( - success=False, best_position=0.0, best_score=0.0, r_squared=0.0, - error_message=str(e) + success=False, + best_position=0.0, + best_score=0.0, + r_squared=0.0, + error_message=str(e), ) -def fit_focus_curve(positions: np.ndarray, scores: np.ndarray, - fit_function: str = FitFunction.GAUSSIAN.value) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]: +def fit_focus_curve( + positions: np.ndarray, + scores: np.ndarray, + fit_function: str = FitFunction.GAUSSIAN.value, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, float]: """ Fit a curve to focus score data @@ -545,7 +572,7 @@ def fit_focus_curve(positions: np.ndarray, scores: np.ndarray, # Create high-resolution position array for smooth curve pos_range = np.max(positions) - np.min(positions) pos_center = (np.max(positions) + np.min(positions)) / 2 - fitted_positions = np.linspace(pos_center - pos_range*0.6, pos_center + pos_range*0.6, 100) + fitted_positions = np.linspace(pos_center - pos_range * 0.6, pos_center + pos_range * 0.6, 100) try: if fit_function == FitFunction.GAUSSIAN.value: @@ -558,7 +585,7 @@ def fit_focus_curve(positions: np.ndarray, scores: np.ndarray, # Generate fitted 
curve if fit_function == FitFunction.GAUSSIAN.value: a, mu, sigma, c = fit_params - fitted_scores = a * np.exp(-((fitted_positions - mu) ** 2) / (2 * sigma ** 2)) + c + fitted_scores = a * np.exp(-((fitted_positions - mu) ** 2) / (2 * sigma**2)) + c else: # parabolic a, b, c = fit_params fitted_scores = a * fitted_positions**2 + b * fitted_positions + c @@ -572,10 +599,11 @@ def fit_focus_curve(positions: np.ndarray, scores: np.ndarray, raise -def _fit_gaussian(positions: np.ndarray, scores: np.ndarray) -> Tuple[np.ndarray, float]: +def _fit_gaussian(positions: np.ndarray, scores: np.ndarray) -> tuple[np.ndarray, float]: """Fit Gaussian curve to focus data""" + def gaussian(x, a, mu, sigma, c): - return a * np.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) + c + return a * np.exp(-((x - mu) ** 2) / (2 * sigma**2)) + c # Initial parameter estimates a_init = np.max(scores) - np.min(scores) @@ -588,7 +616,7 @@ def gaussian(x, a, mu, sigma, c): # Fit with bounds to ensure physical parameters bounds = ( [0, np.min(positions), 0.1, 0], # Lower bounds - [np.inf, np.max(positions), np.inf, np.inf] # Upper bounds + [np.inf, np.max(positions), np.inf, np.inf], # Upper bounds ) popt, pcov = optimize.curve_fit(gaussian, positions, scores, p0=p0, bounds=bounds, maxfev=1000) @@ -602,7 +630,7 @@ def gaussian(x, a, mu, sigma, c): return popt, r_squared -def _fit_parabolic(positions: np.ndarray, scores: np.ndarray) -> Tuple[np.ndarray, float]: +def _fit_parabolic(positions: np.ndarray, scores: np.ndarray) -> tuple[np.ndarray, float]: """Fit parabolic curve to focus data""" # Fit quadratic polynomial: y = ax^2 + bx + c coeffs = np.polyfit(positions, scores, 2) @@ -617,9 +645,9 @@ def _fit_parabolic(positions: np.ndarray, scores: np.ndarray) -> Tuple[np.ndarra def create_focus_montage( - images: List[np.ndarray], - labels: Optional[List[str]] = None, - offsets: Optional[List[float]] = None, + images: list[np.ndarray], + labels: list[str] | None = None, + offsets: list[float] | None = None, 
normalize: bool = True, gap: int = 4, ) -> np.ndarray: @@ -659,7 +687,7 @@ def create_focus_montage( # Default labels: A, B, C, ... if labels is None: - labels = [chr(ord('A') + i) for i in range(len(images))] + labels = [chr(ord("A") + i) for i in range(len(images))] # Ensure all images are 2D and same size processed = [] @@ -689,19 +717,22 @@ def create_focus_montage( # Create montage canvas n = len(processed) total_width = w * n + gap * (n - 1) - montage = np.ones((h + 30, total_width), dtype=np.uint8) * 255 # White background, extra space for labels + montage = ( + np.ones((h + 30, total_width), dtype=np.uint8) * 255 + ) # White background, extra space for labels # Place images with gaps - for i, (img, label) in enumerate(zip(processed, labels)): + for i, (img, label) in enumerate(zip(processed, labels, strict=False)): x_start = i * (w + gap) # Resize if needed to match reference dimensions if img.shape != (h, w): from scipy.ndimage import zoom + zoom_factors = (h / img.shape[0], w / img.shape[1]) img = zoom(img, zoom_factors, order=1).astype(np.uint8) - montage[:h, x_start:x_start + w] = img + montage[:h, x_start : x_start + w] = img # Add label text (simple pixel drawing for letter) # Position label at top-left of each image @@ -738,4 +769,4 @@ def _draw_label(image: np.ndarray, text: str, x: int, y: int, small: bool = Fals image[y:y_end, x:x_end] = 40 # Dark gray background # Note: For proper text rendering, the caller should use PIL or OpenCV - # This placeholder ensures the montage structure works \ No newline at end of file + # This placeholder ensures the montage structure works diff --git a/gently/analysis/focus.py b/gently/analysis/focus.py index 82bd578c..1983abf8 100644 --- a/gently/analysis/focus.py +++ b/gently/analysis/focus.py @@ -10,43 +10,46 @@ """ import logging -import numpy as np -from typing import List, Tuple, Optional, Dict, Any from dataclasses import dataclass +import numpy as np + logger = logging.getLogger(__name__) # Import from core 
analysis module -from .core import ( - calculate_focus_score, analyze_focus_stack, fit_focus_curve, - FocusAnalysisConfig, FocusResult +from ..detection import get_embryo_focus_roi # noqa: E402 +from .core import ( # noqa: E402 + FocusAnalysisConfig, + analyze_focus_stack, + calculate_focus_score, ) -from ..detection import get_embryo_focus_roi @dataclass class FocusDataPoint: """Single focus measurement data point""" + position: float score: float image: np.ndarray - roi: Optional[Tuple[int, int, int, int]] = None + roi: tuple[int, int, int, int] | None = None @dataclass class FocusSweepResult: """Result of a focus sweep with analysis""" + success: bool best_position: float best_score: float - all_data: List[FocusDataPoint] + all_data: list[FocusDataPoint] r_squared: float = 0.0 - error_message: Optional[str] = None + error_message: str | None = None -def score_single_image(image: np.ndarray, - config: FocusAnalysisConfig, - detect_roi: bool = True) -> Tuple[float, Optional[Tuple[int, int, int, int]]]: +def score_single_image( + image: np.ndarray, config: FocusAnalysisConfig, detect_roi: bool = True +) -> tuple[float, tuple[int, int, int, int] | None]: """ Score a single image for focus quality @@ -74,10 +77,12 @@ def score_single_image(image: np.ndarray, return score, roi -def find_best_focus_position(positions: List[float], - scores: List[float], - images: List[np.ndarray], - config: FocusAnalysisConfig) -> float: +def find_best_focus_position( + positions: list[float], + scores: list[float], + images: list[np.ndarray], + config: FocusAnalysisConfig, +) -> float: """ Find the best focus position from sweep data @@ -122,8 +127,9 @@ def find_best_focus_position(positions: List[float], return positions[np.argmax(scores)] -def analyze_focus_sweep(sweep_data: List[FocusDataPoint], - config: FocusAnalysisConfig) -> FocusSweepResult: +def analyze_focus_sweep( + sweep_data: list[FocusDataPoint], config: FocusAnalysisConfig +) -> FocusSweepResult: """ Analyze a complete 
focus sweep @@ -147,7 +153,7 @@ def analyze_focus_sweep(sweep_data: List[FocusDataPoint], best_position=0.0, best_score=0.0, all_data=sweep_data, - error_message="Insufficient data points for analysis" + error_message="Insufficient data points for analysis", ) try: @@ -176,7 +182,7 @@ def analyze_focus_sweep(sweep_data: List[FocusDataPoint], best_position=best_position, best_score=best_score, all_data=sweep_data, - r_squared=r_squared + r_squared=r_squared, ) except Exception as e: @@ -185,14 +191,13 @@ def analyze_focus_sweep(sweep_data: List[FocusDataPoint], best_position=0.0, best_score=0.0, all_data=sweep_data, - error_message=str(e) + error_message=str(e), ) -def create_focus_positions(center: float, - range_um: float, - num_steps: int, - limits: Tuple[float, float]) -> List[float]: +def create_focus_positions( + center: float, range_um: float, num_steps: int, limits: tuple[float, float] +) -> list[float]: """ Create focus sweep positions with limit checking @@ -215,11 +220,7 @@ def create_focus_positions(center: float, List of valid positions within limits """ # Generate positions - positions = np.linspace( - center - range_um/2, - center + range_um/2, - num_steps - ) + positions = np.linspace(center - range_um / 2, center + range_um / 2, num_steps) # Filter to limits min_pos, max_pos = limits @@ -240,22 +241,24 @@ def print_focus_summary(result: FocusSweepResult, scan_type: str = "focus") -> N Type of scan ("coarse", "fine", etc.) 
""" if result.success: - logger.info(f"{scan_type.capitalize()} analysis: " - f"best position {result.best_position:.2f} um " - f"(score: {result.best_score:.1f}, R2: {result.r_squared:.3f})") + logger.info( + f"{scan_type.capitalize()} analysis: " + f"best position {result.best_position:.2f} um " + f"(score: {result.best_score:.1f}, R2: {result.r_squared:.3f})" + ) else: logger.warning(f"{scan_type.capitalize()} analysis failed: {result.error_message}") # Convenience functions for common operations -def quick_focus_score(image: np.ndarray, algorithm: str = 'gradient') -> float: +def quick_focus_score(image: np.ndarray, algorithm: str = "gradient") -> float: """Quick focus scoring with default parameters""" config = FocusAnalysisConfig(algorithm=algorithm) score, _ = score_single_image(image, config, detect_roi=True) return score -def is_good_focus_curve(scores: List[float], threshold: float = 0.1) -> bool: +def is_good_focus_curve(scores: list[float], threshold: float = 0.1) -> bool: """ Check if focus curve has sufficient variation for analysis @@ -281,4 +284,4 @@ def is_good_focus_curve(scores: List[float], threshold: float = 0.1) -> bool: return False coefficient_of_variation = std_dev / mean_score - return coefficient_of_variation >= threshold \ No newline at end of file + return coefficient_of_variation >= threshold diff --git a/gently/analysis/pipeline.py b/gently/analysis/pipeline.py index 1edd4230..8e6b1d8c 100644 --- a/gently/analysis/pipeline.py +++ b/gently/analysis/pipeline.py @@ -4,7 +4,6 @@ Provides the base classes and execution engine for composable analysis pipelines. 
""" -import asyncio import logging import time import uuid @@ -12,25 +11,24 @@ from dataclasses import dataclass, field from datetime import datetime from enum import Enum, auto -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any from ..settings import settings -import numpy as np - logger = logging.getLogger(__name__) class StepType(Enum): """Types of analysis steps""" - VLM = auto() # Vision Language Model (Claude) - SAM = auto() # Segment Anything Model - CLASSICAL = auto() # Classical CV (OpenCV, scikit-image) - PROJECTION = auto() # Dimension reduction (max proj, etc.) - THRESHOLD = auto() # Thresholding operations - MORPHOLOGY = auto() # Morphological operations - DETECTION = auto() # Object detection - CUSTOM = auto() # Custom step + + VLM = auto() # Vision Language Model (Claude) + SAM = auto() # Segment Anything Model + CLASSICAL = auto() # Classical CV (OpenCV, scikit-image) + PROJECTION = auto() # Dimension reduction (max proj, etc.) + THRESHOLD = auto() # Thresholding operations + MORPHOLOGY = auto() # Morphological operations + DETECTION = auto() # Object detection + CUSTOM = auto() # Custom step @dataclass @@ -44,37 +42,38 @@ class AnalysisResult: - Metadata about the analysis - The actual result data """ + uid: str = field(default_factory=lambda: str(uuid.uuid4())) step_name: str = "" step_type: StepType = StepType.CUSTOM - parent_uid: Optional[str] = None + parent_uid: str | None = None timestamp: datetime = field(default_factory=datetime.now) # Result data data: Any = None # Primary result (image, mask, dict, etc.) 
- metadata: Dict[str, Any] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) # Execution info duration_ms: float = 0.0 success: bool = True - error: Optional[str] = None + error: str | None = None def __str__(self) -> str: status = "+" if self.success else "x" return f"{status} {self.step_name} ({self.step_type.name}) [{self.uid[:8]}]" - def to_dict(self) -> Dict: + def to_dict(self) -> dict: """Serialize to dictionary""" return { - 'uid': self.uid, - 'step_name': self.step_name, - 'step_type': self.step_type.name, - 'parent_uid': self.parent_uid, - 'timestamp': self.timestamp.isoformat(), - 'metadata': self.metadata, - 'duration_ms': self.duration_ms, - 'success': self.success, - 'error': self.error, + "uid": self.uid, + "step_name": self.step_name, + "step_type": self.step_type.name, + "parent_uid": self.parent_uid, + "timestamp": self.timestamp.isoformat(), + "metadata": self.metadata, + "duration_ms": self.duration_ms, + "success": self.success, + "error": self.error, } @@ -91,9 +90,9 @@ class AnalysisStep(ABC): def __init__( self, - name: Optional[str] = None, + name: str | None = None, step_type: StepType = StepType.CUSTOM, - config: Optional[Dict] = None, + config: dict | None = None, ): self.name = name or self.__class__.__name__ self.step_type = step_type @@ -103,7 +102,7 @@ def __init__( async def execute( self, input_data: Any, - context: Optional[Dict] = None, + context: dict | None = None, ) -> AnalysisResult: """ Execute the analysis step @@ -125,7 +124,7 @@ async def execute( async def __call__( self, input_data: Any, - context: Optional[Dict] = None, + context: dict | None = None, ) -> AnalysisResult: """Allow calling step directly""" return await self.execute(input_data, context) @@ -148,15 +147,15 @@ class Pipeline: def __init__( self, name: str = "pipeline", - steps: Optional[List[AnalysisStep]] = None, + steps: list[AnalysisStep] | None = None, store_intermediate: bool = False, ): self.name = name - 
self.steps: List[AnalysisStep] = steps or [] + self.steps: list[AnalysisStep] = steps or [] self.store_intermediate = store_intermediate self._data_store = None - def add_step(self, step: AnalysisStep) -> 'Pipeline': + def add_step(self, step: AnalysisStep) -> "Pipeline": """Add a step to the pipeline (fluent)""" self.steps.append(step) return self @@ -168,7 +167,7 @@ def set_data_store(self, store): async def execute( self, input_data: Any, - context: Optional[Dict] = None, + context: dict | None = None, ) -> AnalysisResult: """ Execute the full pipeline @@ -187,8 +186,8 @@ async def execute( """ context = context or {} current_data = input_data - parent_uid = context.get('input_uid') - results: List[AnalysisResult] = [] + parent_uid = context.get("input_uid") + results: list[AnalysisResult] = [] start_time = time.time() @@ -210,10 +209,10 @@ async def execute( data=result.data, data_type="analysis", metadata={ - 'step_name': result.step_name, - 'step_type': result.step_type.name, - 'pipeline': self.name, - 'step_index': i, + "step_name": result.step_name, + "step_type": result.step_type.name, + "pipeline": self.name, + "step_index": i, **result.metadata, }, parent_uid=parent_uid, @@ -228,7 +227,7 @@ async def execute( # Pass result data to next step current_data = result.data - logger.debug(f"Step {i+1}/{len(self.steps)}: {result}") + logger.debug(f"Step {i + 1}/{len(self.steps)}: {result}") except Exception as e: logger.error(f"Pipeline step {step.name} failed: {e}") @@ -238,7 +237,7 @@ async def execute( parent_uid=parent_uid, success=False, error=str(e), - metadata={'failed_at_step': i, 'pipeline': self.name}, + metadata={"failed_at_step": i, "pipeline": self.name}, ) # Create final result @@ -246,14 +245,14 @@ async def execute( final_result = AnalysisResult( step_name=self.name, step_type=StepType.CUSTOM, - parent_uid=context.get('input_uid'), + parent_uid=context.get("input_uid"), data=current_data, duration_ms=total_duration, success=True, metadata={ - 
'pipeline': self.name, - 'num_steps': len(self.steps), - 'step_results': [r.to_dict() for r in results], + "pipeline": self.name, + "num_steps": len(self.steps), + "step_results": [r.to_dict() for r in results], }, ) @@ -264,8 +263,8 @@ async def execute( data=final_result.data, data_type="analysis", metadata={ - 'pipeline': self.name, - 'final': True, + "pipeline": self.name, + "final": True, **final_result.metadata, }, parent_uid=parent_uid, @@ -279,7 +278,7 @@ async def execute( async def __call__( self, input_data: Any, - context: Optional[Dict] = None, + context: dict | None = None, ) -> AnalysisResult: """Allow calling pipeline directly""" return await self.execute(input_data, context) @@ -305,28 +304,30 @@ class PipelineBuilder: def __init__(self, name: str = "pipeline"): self.name = name - self._steps: List[AnalysisStep] = [] + self._steps: list[AnalysisStep] = [] self._store_intermediate = False self._data_store = None - def add(self, step: AnalysisStep) -> 'PipelineBuilder': + def add(self, step: AnalysisStep) -> "PipelineBuilder": """Add a custom step""" self._steps.append(step) return self - def max_projection(self, axis: int = 0) -> 'PipelineBuilder': + def max_projection(self, axis: int = 0) -> "PipelineBuilder": """Add max projection step""" from .steps import MaxProjectionStep + self._steps.append(MaxProjectionStep(axis=axis)) return self def threshold( self, method: str = "otsu", - value: Optional[float] = None, - ) -> 'PipelineBuilder': + value: float | None = None, + ) -> "PipelineBuilder": """Add threshold step""" from .steps import ThresholdStep + self._steps.append(ThresholdStep(method=method, value=value)) return self @@ -334,9 +335,10 @@ def morphology( self, operation: str = "open", kernel_size: int = 3, - ) -> 'PipelineBuilder': + ) -> "PipelineBuilder": """Add morphological operation step""" from .steps import MorphologyStep + self._steps.append(MorphologyStep(operation=operation, kernel_size=kernel_size)) return self @@ -345,23 +347,27 
@@ def blob_detection( min_sigma: float = 10, max_sigma: float = 50, threshold: float = 0.1, - ) -> 'PipelineBuilder': + ) -> "PipelineBuilder": """Add blob detection step""" from .steps import BlobDetectionStep - self._steps.append(BlobDetectionStep( - min_sigma=min_sigma, - max_sigma=max_sigma, - threshold=threshold, - )) + + self._steps.append( + BlobDetectionStep( + min_sigma=min_sigma, + max_sigma=max_sigma, + threshold=threshold, + ) + ) return self def sam_segment( self, - prompt: Optional[str] = None, - points: Optional[List] = None, - ) -> 'PipelineBuilder': + prompt: str | None = None, + points: list | None = None, + ) -> "PipelineBuilder": """Add SAM segmentation step""" from .steps import SAMStep + self._steps.append(SAMStep(prompt=prompt, points=points)) return self @@ -370,13 +376,14 @@ def vlm_analyze( prompt: str, model: str = settings.models.perception, max_tokens: int = 1024, - ) -> 'PipelineBuilder': + ) -> "PipelineBuilder": """Add VLM analysis step""" from .steps import VLMStep + self._steps.append(VLMStep(prompt=prompt, model=model, max_tokens=max_tokens)) return self - def store_intermediate(self, store=None) -> 'PipelineBuilder': + def store_intermediate(self, store=None) -> "PipelineBuilder": """Enable storing intermediate results""" self._store_intermediate = True self._data_store = store @@ -398,6 +405,7 @@ def build(self) -> Pipeline: # Pre-built pipelines for common analysis tasks # ============================================================================= + def create_embryo_detection_pipeline( use_sam: bool = True, use_vlm_verification: bool = False, @@ -434,7 +442,7 @@ def create_embryo_detection_pipeline( if use_vlm_verification: builder.vlm_analyze( prompt="Verify these are C. elegans embryos. " - "Count the number of valid embryos and note any false positives." + "Count the number of valid embryos and note any false positives." 
) return builder.build() diff --git a/gently/analysis/steps.py b/gently/analysis/steps.py index 5ed10533..24577eb4 100644 --- a/gently/analysis/steps.py +++ b/gently/analysis/steps.py @@ -10,12 +10,12 @@ import asyncio import logging import os -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import numpy as np from ..settings import settings -from .pipeline import AnalysisStep, AnalysisResult, StepType +from .pipeline import AnalysisResult, AnalysisStep, StepType logger = logging.getLogger(__name__) @@ -24,6 +24,7 @@ # Image Processing Steps # ============================================================================= + class MaxProjectionStep(AnalysisStep): """ Create max intensity projection along an axis @@ -38,7 +39,7 @@ def __init__(self, axis: int = 0, name: str = "max_projection"): async def execute( self, input_data: Any, - context: Optional[Dict] = None, + context: dict | None = None, ) -> AnalysisResult: """Execute max projection""" if not isinstance(input_data, np.ndarray): @@ -56,7 +57,7 @@ async def execute( step_type=self.step_type, data=input_data, success=True, - metadata={'already_2d': True}, + metadata={"already_2d": True}, ) # Perform projection @@ -68,9 +69,9 @@ async def execute( data=projection, success=True, metadata={ - 'input_shape': list(input_data.shape), - 'output_shape': list(projection.shape), - 'axis': self.axis, + "input_shape": list(input_data.shape), + "output_shape": list(projection.shape), + "axis": self.axis, }, ) @@ -89,7 +90,7 @@ class ThresholdStep(AnalysisStep): def __init__( self, method: str = "otsu", - value: Optional[float] = None, + value: float | None = None, name: str = "threshold", ): super().__init__(name=name, step_type=StepType.THRESHOLD) @@ -99,7 +100,7 @@ def __init__( async def execute( self, input_data: Any, - context: Optional[Dict] = None, + context: dict | None = None, ) -> AnalysisResult: """Apply threshold""" if not isinstance(input_data, np.ndarray): @@ -117,6 +118,7 @@ 
async def execute( # Otsu's method try: from skimage.filters import threshold_otsu + threshold_value = threshold_otsu(image) except ImportError: # Fallback to simple percentile @@ -151,9 +153,9 @@ async def execute( data=binary.astype(np.uint8) * 255, success=True, metadata={ - 'method': self.method, - 'threshold_value': float(threshold_value), - 'pixels_above': int(np.sum(binary)), + "method": self.method, + "threshold_value": float(threshold_value), + "pixels_above": int(np.sum(binary)), }, ) @@ -184,7 +186,7 @@ def __init__( async def execute( self, input_data: Any, - context: Optional[Dict] = None, + context: dict | None = None, ) -> AnalysisResult: """Apply morphological operation""" if not isinstance(input_data, np.ndarray): @@ -197,6 +199,7 @@ async def execute( try: import cv2 + kernel = np.ones((self.kernel_size, self.kernel_size), np.uint8) if self.operation == "erode": @@ -219,25 +222,38 @@ async def execute( # Fallback without OpenCV using scipy try: from scipy import ndimage + kernel = np.ones((self.kernel_size, self.kernel_size)) if self.operation in ("erode", "open"): - result = ndimage.binary_erosion( - input_data > 0, structure=kernel, iterations=self.iterations - ).astype(np.uint8) * 255 + result = ( + ndimage.binary_erosion( + input_data > 0, structure=kernel, iterations=self.iterations + ).astype(np.uint8) + * 255 + ) if self.operation == "open": - result = ndimage.binary_dilation( - result > 0, structure=kernel, iterations=self.iterations - ).astype(np.uint8) * 255 + result = ( + ndimage.binary_dilation( + result > 0, structure=kernel, iterations=self.iterations + ).astype(np.uint8) + * 255 + ) elif self.operation in ("dilate", "close"): - result = ndimage.binary_dilation( - input_data > 0, structure=kernel, iterations=self.iterations - ).astype(np.uint8) * 255 + result = ( + ndimage.binary_dilation( + input_data > 0, structure=kernel, iterations=self.iterations + ).astype(np.uint8) + * 255 + ) if self.operation == "close": - result = 
ndimage.binary_erosion( - result > 0, structure=kernel, iterations=self.iterations - ).astype(np.uint8) * 255 + result = ( + ndimage.binary_erosion( + result > 0, structure=kernel, iterations=self.iterations + ).astype(np.uint8) + * 255 + ) except ImportError: return AnalysisResult( @@ -253,9 +269,9 @@ async def execute( data=result, success=True, metadata={ - 'operation': self.operation, - 'kernel_size': self.kernel_size, - 'iterations': self.iterations, + "operation": self.operation, + "kernel_size": self.kernel_size, + "iterations": self.iterations, }, ) @@ -284,7 +300,7 @@ def __init__( async def execute( self, input_data: Any, - context: Optional[Dict] = None, + context: dict | None = None, ) -> AnalysisResult: """Detect blobs""" if not isinstance(input_data, np.ndarray): @@ -316,25 +332,27 @@ async def execute( detections = [] for blob in blobs: y, x, sigma = blob - detections.append({ - 'x': float(x), - 'y': float(y), - 'radius': float(sigma * np.sqrt(2)), - }) + detections.append( + { + "x": float(x), + "y": float(y), + "radius": float(sigma * np.sqrt(2)), + } + ) return AnalysisResult( step_name=self.name, step_type=self.step_type, data={ - 'detections': detections, - 'count': len(detections), - 'image': input_data, # Keep original for next step + "detections": detections, + "count": len(detections), + "image": input_data, # Keep original for next step }, success=True, metadata={ - 'num_blobs': len(detections), - 'min_sigma': self.min_sigma, - 'max_sigma': self.max_sigma, + "num_blobs": len(detections), + "min_sigma": self.min_sigma, + "max_sigma": self.max_sigma, }, ) @@ -351,6 +369,7 @@ async def execute( # VLM (Vision Language Model) Step # ============================================================================= + class VLMStep(AnalysisStep): """ Analyze image using Claude Vision API @@ -364,7 +383,7 @@ def __init__( model: str = settings.models.perception, max_tokens: int = 1024, name: str = "vlm_analysis", - api_key: Optional[str] = None, + 
api_key: str | None = None, ): super().__init__(name=name, step_type=StepType.VLM) self.prompt = prompt @@ -374,21 +393,22 @@ def __init__( def _encode_image(self, image: np.ndarray) -> str: """Encode image to base64 JPEG.""" - from gently.core.imaging import normalize_to_uint8, image_to_base64 + from gently.core.imaging import image_to_base64, normalize_to_uint8 + img = normalize_to_uint8(image, method="minmax") return image_to_base64(img, format="JPEG", quality=85) async def execute( self, input_data: Any, - context: Optional[Dict] = None, + context: dict | None = None, ) -> AnalysisResult: """Analyze with Claude Vision""" context = context or {} # Get image from input (handle dict from previous step) if isinstance(input_data, dict): - image = input_data.get('image', input_data.get('data')) + image = input_data.get("image", input_data.get("data")) else: image = input_data @@ -404,7 +424,7 @@ async def execute( import anthropic # Get API key - api_key = self.api_key or context.get('api_key') or os.getenv("ANTHROPIC_API_KEY") + api_key = self.api_key or context.get("api_key") or os.getenv("ANTHROPIC_API_KEY") if not api_key: return AnalysisResult( step_name=self.name, @@ -424,12 +444,12 @@ async def execute( "type": "base64", "media_type": "image/jpeg", "data": b64_image, - } + }, }, { "type": "text", "text": self.prompt, - } + }, ] # Call Claude @@ -438,7 +458,7 @@ async def execute( client.messages.create, model=self.model, max_tokens=self.max_tokens, - messages=[{"role": "user", "content": content}] + messages=[{"role": "user", "content": content}], ) result_text = response.content[0].text @@ -447,15 +467,15 @@ async def execute( step_name=self.name, step_type=self.step_type, data={ - 'analysis': result_text, - 'image': image, # Pass through for next step + "analysis": result_text, + "image": image, # Pass through for next step }, success=True, metadata={ - 'model': self.model, - 'prompt': self.prompt[:100] + "..." 
if len(self.prompt) > 100 else self.prompt, - 'input_tokens': response.usage.input_tokens, - 'output_tokens': response.usage.output_tokens, + "model": self.model, + "prompt": self.prompt[:100] + "..." if len(self.prompt) > 100 else self.prompt, + "input_tokens": response.usage.input_tokens, + "output_tokens": response.usage.output_tokens, }, ) @@ -473,6 +493,7 @@ async def execute( # SAM (Segment Anything Model) Step # ============================================================================= + class SAMStep(AnalysisStep): """ Segment image using SAM (Segment Anything Model) @@ -485,9 +506,9 @@ class SAMStep(AnalysisStep): def __init__( self, - prompt: Optional[str] = None, - points: Optional[List[Tuple[int, int]]] = None, - model_path: Optional[str] = None, + prompt: str | None = None, + points: list[tuple[int, int]] | None = None, + model_path: str | None = None, min_area: int = 1000, name: str = "sam_segmentation", ): @@ -505,8 +526,12 @@ def _load_sam(self): return try: - from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator import torch + from segment_anything import ( + SamAutomaticMaskGenerator, + SamPredictor, + sam_model_registry, + ) device = "cuda" if torch.cuda.is_available() else "cpu" @@ -528,12 +553,12 @@ def _load_sam(self): async def execute( self, input_data: Any, - context: Optional[Dict] = None, + context: dict | None = None, ) -> AnalysisResult: """Run SAM segmentation""" # Get image from input if isinstance(input_data, dict): - image = input_data.get('image', input_data.get('data')) + image = input_data.get("image", input_data.get("data")) else: image = input_data @@ -565,28 +590,29 @@ async def execute( if image_rgb.max() <= 1.0: image_rgb = (image_rgb * 255).astype(np.uint8) else: - image_rgb = ((image_rgb - image_rgb.min()) / - (image_rgb.max() - image_rgb.min()) * 255).astype(np.uint8) + image_rgb = ( + (image_rgb - image_rgb.min()) / (image_rgb.max() - image_rgb.min()) * 255 + ).astype(np.uint8) # Run 
in thread (SAM is CPU/GPU intensive) masks = await asyncio.to_thread(self._run_sam, image_rgb) # Filter by area - filtered_masks = [m for m in masks if m['area'] >= self.min_area] + filtered_masks = [m for m in masks if m["area"] >= self.min_area] return AnalysisResult( step_name=self.name, step_type=self.step_type, data={ - 'masks': filtered_masks, - 'count': len(filtered_masks), - 'image': image, + "masks": filtered_masks, + "count": len(filtered_masks), + "image": image, }, success=True, metadata={ - 'total_masks': len(masks), - 'filtered_masks': len(filtered_masks), - 'min_area': self.min_area, + "total_masks": len(masks), + "filtered_masks": len(filtered_masks), + "min_area": self.min_area, }, ) @@ -599,7 +625,7 @@ async def execute( error=str(e), ) - def _run_sam(self, image_rgb: np.ndarray) -> List[Dict]: + def _run_sam(self, image_rgb: np.ndarray) -> list[dict]: """Run SAM (blocking, called in thread)""" if self.points: # Point-prompted segmentation @@ -615,9 +641,9 @@ def _run_sam(self, image_rgb: np.ndarray) -> List[Dict]: return [ { - 'segmentation': masks[i], - 'area': int(np.sum(masks[i])), - 'score': float(scores[i]), + "segmentation": masks[i], + "area": int(np.sum(masks[i])), + "score": float(scores[i]), } for i in range(len(masks)) ] @@ -643,25 +669,27 @@ async def _fallback_segment(self, image: np.ndarray) -> AnalysisResult: for region in regions: if region.area >= self.min_area: mask = labeled == region.label - masks.append({ - 'segmentation': mask, - 'area': region.area, - 'centroid': region.centroid, - 'bbox': region.bbox, - }) + masks.append( + { + "segmentation": mask, + "area": region.area, + "centroid": region.centroid, + "bbox": region.bbox, + } + ) return AnalysisResult( step_name=self.name, step_type=self.step_type, data={ - 'masks': masks, - 'count': len(masks), - 'image': image, + "masks": masks, + "count": len(masks), + "image": image, }, success=True, metadata={ - 'method': 'fallback_threshold', - 'num_masks': len(masks), + "method": 
"fallback_threshold", + "num_masks": len(masks), }, ) diff --git a/gently/app/agent.py b/gently/app/agent.py index 1bd8c64a..4a5f6602 100644 --- a/gently/app/agent.py +++ b/gently/app/agent.py @@ -14,37 +14,50 @@ import asyncio import logging import os -from typing import Dict, List, Optional, Callable, Any, TYPE_CHECKING +from collections.abc import Callable from datetime import datetime from pathlib import Path +from typing import TYPE_CHECKING, Any import anthropic import numpy as np -from ..exceptions import StorageError, AgentError +from ..exceptions import StorageError from ..settings import settings if TYPE_CHECKING: from ..ui.web.server import VisualizationServer +from gently_perception import Perceiver -logger = logging.getLogger(__name__) - -from ..harness.state import ExperimentState, EmbryoState, ImageRecord -from ..harness.orchestration.plan_synthesis import PlanSynthesizer, PlanLibrary, PlanValidator +from ..core import EventType, emit, get_event_bus +from ..core.file_store import FileStore +from ..harness.conversation import ConversationManager +from ..harness.orchestration.plan_synthesis import ( + PlanLibrary, + PlanSynthesizer, + PlanValidator, +) +from ..harness.prompts.manager import PromptManager +from ..harness.session.interaction_logger import InteractionLogger +from ..harness.session.manager import SessionManager +from ..harness.session.timeline import TimelineManager +from ..harness.state import ExperimentState from ..harness.tools.registry import get_tool_registry -from gently_perception import Perceiver # Import tools package to trigger @tool decorator registration from . 
import tools as _tools # noqa: F401 -from ..harness.session.interaction_logger import InteractionLogger from .orchestration.timelapse import TimelapseOrchestrator -from ..harness.session.timeline import TimelineManager -from ..core import EventType, get_event_bus, emit -from ..core.file_store import FileStore -from ..harness.conversation import ConversationManager -from ..harness.session.manager import SessionManager -from ..harness.prompts.manager import PromptManager +logger = logging.getLogger(__name__) + +# Shown when the agent is launched in UI-only mode (--no-api). The web UI is +# fully browsable, but anything that would call Claude is disabled. +_NO_API_NOTICE = ( + "The agent is running in **UI-only mode** (`--no-api`), so it can't " + "respond — no Anthropic API calls are made. You can browse the interface, " + "view saved sessions, and explore the UI. To enable chat, perception, and " + "plan generation, restart without `--no-api` and with `ANTHROPIC_API_KEY` set." +) class MicroscopyAgent: @@ -61,12 +74,13 @@ class MicroscopyAgent: def __init__( self, - api_key: Optional[str] = None, + api_key: str | None = None, storage_path: Path = Path("./experiment_data"), model: str = settings.models.main, microscope_client=None, - session_id: Optional[str] = None, - store: FileStore = None, + session_id: str | None = None, + store: FileStore | None = None, + no_api: bool = False, ): """ Parameters @@ -83,14 +97,27 @@ def __init__( Session ID to resume. If None, creates new session. store : FileStore Unified file-based data store. Required. + no_api : bool + UI-only mode: skip Anthropic API calls entirely. The full agent and + its sub-components are still constructed (so the web UI boots), but + message handling short-circuits with a clear notice instead of + calling Claude. Useful for browsing the UI without an API key. """ if store is None: raise ValueError("FileStore is required. Pass store=FileStore(path) to agent.") + # UI-only mode: no real API calls. 
We still build the client object (so + # all sub-components that hold a reference work), but fall back to a + # placeholder key so construction never fails when no key is set, and + # the message entry points refuse to call Claude. + self.api_enabled = not no_api + # API client with interleaved thinking support self.claude = anthropic.Anthropic( - api_key=api_key or os.getenv("ANTHROPIC_API_KEY"), - default_headers={"anthropic-beta": "interleaved-thinking-2025-05-14"} + api_key=api_key + or os.getenv("ANTHROPIC_API_KEY") + or ("no-api-mode" if no_api else None), + default_headers={"anthropic-beta": "interleaved-thinking-2025-05-14"}, ) self.model = model @@ -102,7 +129,7 @@ def __init__( self.mode: str = "run" # Context store (agent's mind — set via set_context_store) - self.context_store: Optional[Any] = None + self.context_store: Any | None = None # Experiment state self.experiment = ExperimentState() @@ -118,13 +145,17 @@ def __init__( # Plan synthesis self.plan_synthesizer = PlanSynthesizer( - plan_library=PlanLibrary(), - validator=PlanValidator() + plan_library=PlanLibrary(), validator=PlanValidator() ) # Event bus for async messaging (must be before perception manager) self._event_bus = get_event_bus() + # Broadcast the embryo list whenever it mutates. Hooked through the + # state object's observer so add/remove/nickname/restore all publish + # without each call site having to remember. + self.experiment.on_embryos_changed = self._publish_embryos_update + # Perception system (gently-perception harness) self.perceiver = Perceiver() @@ -135,20 +166,49 @@ def __init__( self.client = self.microscope # Callbacks - self.on_message_callback: Optional[Callable] = None - self.choice_handler: Optional[Callable] = None + self.on_message_callback: Callable | None = None + self.choice_handler: Callable | None = None + + # Serializes conversation turns: user turns and autonomous wake turns + # must not interleave on the shared conversation_history. 
+ self._turn_lock = asyncio.Lock() + + # Autonomy backstop: while a wake turn runs, _autonomous_active is True + # and the registry refuses these irreversible tools (they require a + # human). User turns are unaffected. _wake_choice_factory is set by the + # web bridge so ASK-mode wake turns can round-trip an approval picker. + self._autonomous_active = False + self._autonomous_blocked_tools = frozenset( + { + "set_laser_power", + "remove_embryo", + "stop_timelapse", + } + ) + self._wake_choice_factory = None + self._wake_choice_discard = None # Interaction logger for structured logging (research data collection) - self.interaction_logger: Optional[InteractionLogger] = None + self.interaction_logger: InteractionLogger | None = None + + # Event capture — durable log of every EventBus event during this + # session. Substrate for offline replay / shadow-mode A/B of + # candidate orchestrator architectures. + self.event_capture = None + + # Decision log — what production decided at each turn (tool calls, + # response text, prompt hash). Pairs with event capture so a + # candidate replay can be diffed against production turn-by-turn. 
+ self.decision_log = None # Timelapse orchestrator (initialized when microscope connected) - self.timelapse_orchestrator: Optional[TimelapseOrchestrator] = None + self.timelapse_orchestrator: TimelapseOrchestrator | None = None # Timeline manager for tracking events - self.timeline_manager: Optional[TimelineManager] = None + self.timeline_manager: TimelineManager | None = None # Visualization server for real-time feedback - self.viz_server: Optional["VisualizationServer"] = None + self.viz_server: VisualizationServer | None = None # Device-state monitor (bridges device-layer SSE → EventBus) self.device_state_monitor = None @@ -166,10 +226,10 @@ def __init__( ) # Wire tool execution context self.conversation._tool_context = { - 'agent': self, - 'client': getattr(self, 'microscope', None), - 'microscope': getattr(self, 'microscope', None), - 'databroker': getattr(self, 'databroker', None), + "agent": self, + "client": getattr(self, "microscope", None), + "microscope": getattr(self, "microscope", None), + "databroker": getattr(self, "databroker", None), } # Session manager (persistence) @@ -189,20 +249,36 @@ def __init__( success, history = self.sessions._resume_session(session_id, self.experiment) if success: self.conversation.conversation_history = history - self._emit_event(EventType.SESSION_RESTORED, { - 'session_id': session_id, - 'embryo_count': len(self.experiment.embryos), - 'message_count': len(self.conversation.conversation_history), - }) + self._emit_event( + EventType.SESSION_RESTORED, + { + "session_id": session_id, + "embryo_count": len(self.experiment.embryos), + "message_count": len(self.conversation.conversation_history), + }, + ) else: self.sessions.create_session() - self._emit_event(EventType.SESSION_STARTED, { - 'session_id': self.sessions.session_id, - }) + self._emit_event( + EventType.SESSION_STARTED, + { + "session_id": self.sessions.session_id, + }, + ) # Initialize interaction logger (for research data collection) 
self._init_interaction_logger() + # Start event capture into the session folder so offline replay / + # shadow-mode testing has a durable input stream. Filters out the + # high-volume telemetry types (DEVICE_STATE_UPDATE / BOTTOM_CAMERA_FRAME) + # by default so a long timelapse doesn't bury the meaningful events. + self._init_event_capture() + + # Open the per-session production decision log and hand it to the + # conversation manager so each Claude round-trip is captured. + self._init_decision_log() + # Wire interaction logger and choice handler to conversation manager self.conversation.interaction_logger = self.interaction_logger self.conversation.choice_handler = self.choice_handler @@ -216,6 +292,17 @@ def __init__( # Subscribe to CV result events for EmbryoState integration self._subscribe_to_cv_events() + # Decision-moment wake-router (opt-in, default OFF). Wakes the agent on + # wake-worthy perception/lifecycle events so it can adapt acquisition + # autonomously; enabled via the set_autonomy tool. 
+ try: + from gently.app.wake_router import WakeRouter + + self.wake_router = WakeRouter(self, self._event_bus) + except Exception: + logger.exception("Failed to init wake-router") + self.wake_router = None + # Build initial system prompt self._update_system_prompt() @@ -227,7 +314,7 @@ def session_id(self) -> str: return self.sessions.session_id @property - def _session_id(self) -> Optional[str]: + def _session_id(self) -> str | None: """Internal session ID (backward compat).""" return self.sessions._session_id @@ -236,7 +323,7 @@ def _session_id(self, value): self.sessions._session_id = value @property - def conversation_history(self) -> List[Dict]: + def conversation_history(self) -> list[dict]: """Get conversation history.""" return self.conversation.conversation_history @@ -281,6 +368,7 @@ def set_context_store(self, context_store) -> None: self.prompts.context_store = context_store # Create agent memory harness from ..harness.memory.interface import AgentMemory + self.memory = AgentMemory(context_store, session_id=self.session_id) self.prompts.memory = self.memory @@ -290,8 +378,13 @@ def enter_plan_mode(self) -> str: return "Already in plan mode." self.mode = "plan" import gently.harness.plan_mode.tools # noqa: F401 + self._update_system_prompt() - emit(EventType.STATUS_CHANGED, {"field": "agent_mode", "value": "plan"}, source="agent") + emit( + EventType.STATUS_CHANGED, + {"field": "agent_mode", "value": "plan"}, + source="agent", + ) logger.info("Entered plan mode") return "Switched to plan mode. I'm now your experimental design collaborator." @@ -318,7 +411,7 @@ def enter_resolution_mode(self) -> str: logger.info("Entered resolution mode") return "Resolution mode active. Determining what this session is for." - def exit_resolution_mode(self, outcome: str = None) -> str: + def exit_resolution_mode(self, outcome: str | None = None) -> str: """Leave resolution mode for run mode. 
Called by resolution tools (attach_session_to_plan, @@ -369,7 +462,8 @@ def exit_plan_mode(self) -> str: if item and self.session_id and self.context_store: try: self.context_store.link_session_campaign( - self.session_id, item.campaign_id, + self.session_id, + item.campaign_id, ) except Exception: pass @@ -386,19 +480,27 @@ def exit_plan_mode(self) -> str: self.prompts.invalidate_context_cache() self._update_system_prompt() - emit(EventType.STATUS_CHANGED, {"field": "agent_mode", "value": "run"}, source="agent") + emit( + EventType.STATUS_CHANGED, + {"field": "agent_mode", "value": "run"}, + source="agent", + ) logger.info("Exited plan mode") return result # ===== Prompt & System Prompt ===== - def _update_system_prompt(self, context_summary: str = None): + def _update_system_prompt(self, context_summary: str | None = None): """Rebuild system prompt via PromptManager.""" self.system_prompt = self.prompts.update_system_prompt( - self.experiment, self.client, self.mode, context_summary + self.experiment, + self.client, + self.mode, + context_summary, + perceiver=getattr(self, "perceiver", None), ) - def _get_active_plan_summary(self) -> Optional[str]: + def _get_active_plan_summary(self) -> str | None: """Delegation shim for agent bridge access.""" return self.prompts.get_active_plan_summary() @@ -428,7 +530,7 @@ def _auto_save(self): self.experiment, self.conversation.conversation_history, self.system_prompt ) - def list_sessions(self) -> List[Dict]: + def list_sessions(self) -> list[dict]: """List available sessions.""" return self.sessions.list_sessions() @@ -452,6 +554,81 @@ def _init_interaction_logger(self): logging.getLogger(__name__).warning(f"Failed to init interaction logger: {e}") self.interaction_logger = None + def _init_event_capture(self): + """Open the per-session events.jsonl capture. + + Resolves the session folder via FileStore._session_dir so the log + sits next to session.yaml / interaction_log.jsonl. 
Silent no-op + when the session folder can't be resolved (e.g. test harness with + a stripped-down agent) — replay just won't have a log to read. + """ + from gently.eval import EventCapture + + try: + session_dir = None + sid = self.session_id + if self.store is not None and sid: + session_dir = self.store._session_dir(sid) + if session_dir is None: + logging.getLogger(__name__).debug( + "EventCapture: no session dir for %s — skipping", sid + ) + return + path = session_dir / "events.jsonl" + self.event_capture = EventCapture(path) + self.event_capture.start(self._event_bus) + except Exception: + logging.getLogger(__name__).exception("Failed to init event capture") + self.event_capture = None + + def stop_event_capture(self): + """Flush + close the events.jsonl. Idempotent; safe at shutdown.""" + if self.event_capture is not None: + try: + self.event_capture.stop() + except Exception: + logging.getLogger(__name__).exception("EventCapture stop failed") + self.event_capture = None + + def _init_decision_log(self): + """Open the per-session decisions.jsonl and wire it into conversation. + + Each call to ConversationManager.call_claude writes one Decision + row (success or error) describing what production decided for the + user turn. Shadow candidates write their own rows into separate + logs and the two are diffed offline. 
+ """ + from gently.eval import DecisionLog + + try: + session_dir = None + sid = self.session_id + if self.store is not None and sid: + session_dir = self.store._session_dir(sid) + if session_dir is None: + logging.getLogger(__name__).debug( + "DecisionLog: no session dir for %s — skipping", sid + ) + return + path = session_dir / "decisions.jsonl" + self.decision_log = DecisionLog(path) + self.decision_log.open() + self.conversation.decision_log = self.decision_log + except Exception: + logging.getLogger(__name__).exception("Failed to init decision log") + self.decision_log = None + + def stop_decision_log(self): + """Flush + close the decisions.jsonl. Idempotent; safe at shutdown.""" + if self.decision_log is not None: + try: + self.decision_log.close() + except Exception: + logging.getLogger(__name__).exception("DecisionLog close failed") + self.decision_log = None + if hasattr(self, "conversation") and self.conversation is not None: + self.conversation.decision_log = None + def _init_timelapse_orchestrator(self): """Initialize the timelapse orchestrator if microscope is connected.""" if not self._has_microscope(): @@ -529,18 +706,71 @@ def on_stage_detected(event): embryo_id = data.get("embryo_id") if embryo_id and embryo_id in self.experiment.embryos: embryo = self.experiment.embryos[embryo_id] - embryo.add_cv_result("stage_classification", { - "stage": data.get("stage"), - "confidence": data.get("confidence"), - "nuclei_count": data.get("nuclei_count"), - "timepoint": data.get("timepoint"), - }) + embryo.add_cv_result( + "stage_classification", + { + "stage": data.get("stage"), + "confidence": data.get("confidence"), + "nuclei_count": data.get("nuclei_count"), + "timepoint": data.get("timepoint"), + }, + ) except Exception as e: logger.warning(f"Error handling stage detected event: {e}") unsub = self._event_bus.subscribe(EventType.STAGE_DETECTED, on_stage_detected) self._cv_subscriptions.append(unsub) + def on_perception(event): + # Bridge the perception 
loop's DETECTOR_EVALUATED into EmbryoState so + # the prompt/display developmental stage reflects the live Perceiver. + # (The STAGE_DETECTED wiring above is never emitted by the perception + # path — this closes that long-standing gap.) Record only on an + # actual stage CHANGE to keep cv_analyses a clean transition log and + # avoid per-timepoint disk/cache churn; live stability/timing is read + # straight from the Perceiver by the prompt snapshot + pull tool. + try: + data = event.data + if data.get("skipped") or data.get("detector_name") != "perception": + return # ignore recheck-skips and role=test pseudo-stages + embryo_id = data.get("embryo_id") + stage = data.get("stage") + # 'no_object' is an empty-field sentinel, not a developmental + # stage — don't mirror it into latest_developmental_stage. + if ( + not stage + or stage == "no_object" + or not embryo_id + or embryo_id not in self.experiment.embryos + ): + return + embryo = self.experiment.embryos[embryo_id] + if stage == getattr(embryo, "latest_developmental_stage", None): + return # steady state — nothing new to mirror + embryo.add_cv_result( + "stage_classification", + { + "stage": stage, + "timepoint": data.get("timepoint"), + "stability": data.get("stability"), + "temporal_analysis": data.get("temporal_analysis"), + "detector_name": "perception", + }, + ) + self.invalidate_context_cache() + self._auto_save() + logger.info( + "Perception: %s -> stage %s (t%s)", + embryo_id, + stage, + data.get("timepoint"), + ) + except Exception as e: + logger.warning(f"Error handling perception event: {e}") + + unsub = self._event_bus.subscribe(EventType.DETECTOR_EVALUATED, on_perception) + self._cv_subscriptions.append(unsub) + logger.debug("Subscribed to CV result events") except Exception as e: @@ -549,7 +779,9 @@ def on_stage_detected(event): # ===== Visualization Server Methods ===== - async def start_viz_server(self, port: int = settings.network.viz_port, ssl_certfile=None, ssl_keyfile=None): + async def 
start_viz_server( + self, port: int = settings.network.viz_port, ssl_certfile=None, ssl_keyfile=None + ): """Start the visualization server for real-time feedback.""" if self.viz_server is not None: logger.info("Visualization server already running") @@ -599,6 +831,7 @@ async def start_viz_server(self, port: int = settings.network.viz_port, ssl_cert if self.microscope is not None and self.device_state_monitor is None: try: from .device_state_monitor import DeviceStateMonitor + self.device_state_monitor = DeviceStateMonitor(self.microscope) await self.device_state_monitor.start() logger.info("Device-state monitor started") @@ -612,6 +845,7 @@ async def start_viz_server(self, port: int = settings.network.viz_port, ssl_cert if self.microscope is not None and self.bottom_camera_monitor is None: try: from .bottom_camera_monitor import BottomCameraStreamMonitor + self.bottom_camera_monitor = BottomCameraStreamMonitor(self.microscope) logger.info("Bottom-camera monitor ready (not started)") except Exception as e: @@ -642,16 +876,14 @@ def push_viz( array: np.ndarray, uid: str, data_type: str = "image", - metadata: Optional[Dict[str, Any]] = None, + metadata: dict[str, Any] | None = None, ): """Non-blocking push of image to visualization server.""" if self.viz_server is None: return try: - asyncio.create_task( - self.viz_server.push_image(array, uid, data_type, metadata or {}) - ) + asyncio.create_task(self.viz_server.push_image(array, uid, data_type, metadata or {})) except RuntimeError: pass except Exception as e: @@ -670,7 +902,7 @@ def _has_microscope(self) -> bool: """ return self.client is not None - def _emit_event(self, event_type: EventType, data: Optional[Dict] = None): + def _emit_event(self, event_type: EventType, data: dict | None = None): """Emit an event to the event bus.""" self._event_bus.publish( event_type=event_type, @@ -678,13 +910,46 @@ def _emit_event(self, event_type: EventType, data: Optional[Dict] = None): source="agent", ) + def 
_publish_embryos_update(self) -> None: + """Broadcast the current embryo list as an EMBRYOS_UPDATE event. + + Wired into ExperimentState.on_embryos_changed at agent init so every + add / remove / restore / nickname change snaps a fresh full-list + snapshot onto the bus. The viz server's wildcard subscription forwards + it straight to connected browsers — that's how the Devices > Map page + learns about embryos without a poll loop. + """ + if self._event_bus is None: + return + try: + embryos = [e.to_dict() for e in self.experiment.embryos.values()] + except Exception: + logger.exception("Failed to serialise embryos for EMBRYOS_UPDATE") + return + payload = { + "embryos": embryos, + "count": len(embryos), + "session_id": getattr(self, "session_id", None), + } + try: + self._event_bus.publish( + event_type=EventType.EMBRYOS_UPDATE, + data=payload, + source="agent.experiment", + ) + except Exception: + logger.exception("Failed to publish EMBRYOS_UPDATE") + def _mark_significant_action(self, action_type: str): """Mark that a significant action occurred (triggers auto-save).""" self._auto_save() - self._emit_event(EventType.SESSION_SAVED, { - 'session_id': self.sessions._session_id, - 'action_type': action_type, - }) + self._emit_event( + EventType.SESSION_SAVED, + { + "session_id": self.sessions._session_id, + "action_type": action_type, + }, + ) # ===== Public Message API ===== @@ -703,11 +968,17 @@ async def handle_message(self, user_message: str) -> str: Response from agent """ if quick_response := self.conversation.try_quick_response( - user_message, self.experiment, self.mode, - self.enter_plan_mode, self.exit_plan_mode, + user_message, + self.experiment, + self.mode, + self.enter_plan_mode, + self.exit_plan_mode, ): return quick_response + if not self.api_enabled: + return _NO_API_NOTICE + # Update system prompt with current state and context awareness context_summary = await self.prompts.get_cached_context_summary( self.experiment, self.timelapse_orchestrator, 
self.timeline_manager @@ -715,10 +986,7 @@ async def handle_message(self, user_message: str) -> str: self._update_system_prompt(context_summary) # Add user message to history - self.conversation.conversation_history.append({ - "role": "user", - "content": user_message - }) + self.conversation.conversation_history.append({"role": "user", "content": user_message}) tools = self._get_tools_for_mode() cached_prompt = self._get_cached_system_prompt() @@ -745,43 +1013,172 @@ async def handle_message_stream(self, user_message: str): Chunks with 'type' and data """ if quick_response := self.conversation.try_quick_response( - user_message, self.experiment, self.mode, - self.enter_plan_mode, self.exit_plan_mode, + user_message, + self.experiment, + self.mode, + self.enter_plan_mode, + self.exit_plan_mode, ): - yield {'type': 'text', 'text': quick_response} + yield {"type": "text", "text": quick_response} return - context_summary = await self.prompts.get_cached_context_summary( - self.experiment, self.timelapse_orchestrator, self.timeline_manager - ) - self._update_system_prompt(context_summary) + if not self.api_enabled: + yield {"type": "text", "text": _NO_API_NOTICE} + return - self.conversation.conversation_history.append({ - "role": "user", - "content": user_message - }) + # Hold the turn-lock for the whole streamed turn so an autonomous wake + # turn cannot interleave on the shared conversation_history. 
+ lock = getattr(self, "_turn_lock", None) + acquired = False + if lock is not None: + await lock.acquire() + acquired = True + try: + context_summary = await self.prompts.get_cached_context_summary( + self.experiment, self.timelapse_orchestrator, self.timeline_manager + ) + self._update_system_prompt(context_summary) - tools = self._get_tools_for_mode() - cached_prompt = self._get_cached_system_prompt() + self.conversation.conversation_history.append({"role": "user", "content": user_message}) - inner_gen = self.conversation.call_claude_stream( - cached_prompt, tools, - tool_label_fn=self.conversation.tool_label, - auto_save_fn=self._auto_save, - ) - sent_value = None + tools = self._get_tools_for_mode() + cached_prompt = self._get_cached_system_prompt() + + inner_gen = self.conversation.call_claude_stream( + cached_prompt, + tools, + tool_label_fn=self.conversation.tool_label, + auto_save_fn=self._auto_save, + ) + sent_value = None + + try: + while True: + if sent_value is None: + chunk = await inner_gen.__anext__() + else: + chunk = await inner_gen.asend(sent_value) + sent_value = yield chunk + except StopAsyncIteration: + return + finally: + if acquired: + lock.release() + + async def run_wake_turn( + self, wake_note: str, trigger: str | None = None, interactive: bool = False + ): + """Drive one autonomous (no-user) turn for the wake-router. + + Runs through the normal streaming pipeline (so it acquires the turn-lock + and is recorded to conversation history / auto-saved). Brackets the turn + with an 'autonomous_start' (carrying the wake trigger) and a synthesized + 'stream_end' so it streams to the web chat distinctly. Sets + _autonomous_active so the registry backstop refuses irreversible tools. + When interactive (ASK mode) a choice_request round-trips through the + operator; otherwise it is auto-cancelled. Run mode only. 
+ """ + if self.mode != "run": + logger.info("Wake turn skipped — agent not in run mode (mode=%s)", self.mode) + return "" + async def _emit(chunk): + cb = getattr(self, "on_message_callback", None) + if cb is None: + return + try: + res = cb(chunk) + if asyncio.iscoroutine(res): + await res + except Exception: + logger.debug("on_message_callback failed for wake chunk", exc_info=True) + + await _emit({"type": "autonomous_start", "trigger": trigger or ""}) + text_parts = [] + self._autonomous_active = True + agen = self.handle_message_stream(wake_note) + sent_value = None try: while True: - if sent_value is None: - chunk = await inner_gen.__anext__() - else: - chunk = await inner_gen.asend(sent_value) - sent_value = yield chunk - except StopAsyncIteration: - return + try: + if sent_value is None: + chunk = await agen.__anext__() + else: + chunk = await agen.asend(sent_value) + sent_value = None + except StopAsyncIteration: + break + ctype = chunk.get("type") if isinstance(chunk, dict) else None + if ctype == "text": + text_parts.append(chunk.get("text", "")) + if ctype == "choice_request": + # Resolve via the operator (ASK) or auto-cancel (AUTO). + sent_value = await self._resolve_wake_choice(chunk, _emit, interactive) + continue # don't re-emit the raw choice_request + await _emit(chunk) + except Exception: + logger.exception("run_wake_turn error") + finally: + self._autonomous_active = False + try: + # Release the turn-lock even if a picker hung / timed out. + await agen.aclose() + except Exception: + pass + await _emit({"type": "stream_end"}) + summary = "".join(text_parts).strip() + if summary: + logger.info("Autonomous wake turn result: %s", summary[:500]) + return summary + + async def _resolve_wake_choice(self, chunk, emit, interactive): + """Resolve a choice_request raised during a wake turn. + + AUTO (or no operator channel) -> 'cancelled'. 
ASK -> register a future via + the web choice-factory, broadcast the picker to clients, and await the + operator's selection (timeout -> 'skip' so an unanswered picker can't hold + the turn-lock forever).""" + choice_data = chunk.get("choice_data", {}) if isinstance(chunk, dict) else {} + factory = getattr(self, "_wake_choice_factory", None) + if not interactive or factory is None: + logger.info( + "Wake picker auto-cancelled (interactive=%s, channel=%s)", + interactive, + factory is not None, + ) + return "cancelled" + try: + future = factory(choice_data) # registers future + sets request_id + except Exception: + logger.exception("wake choice factory failed") + return "cancelled" + request_id = choice_data.get("request_id", "") + await emit({**chunk, "origin": "wake", "request_id": request_id}) + from gently.app.wake_router import ASK_TIMEOUT_SEC - async def get_tool_call(self, user_message: str) -> Optional[Dict]: + try: + selected = await asyncio.wait_for(future, timeout=ASK_TIMEOUT_SEC) + except asyncio.TimeoutError: + logger.info("Wake ASK timed out (%.0fs) -> skip", ASK_TIMEOUT_SEC) + selected = "skip" + except asyncio.CancelledError: + # The picker future was cancelled (e.g. the operator disconnected) — + # treat as a cancelled proposal so the turn finishes cleanly. + logger.info("Wake ASK future cancelled -> cancelled") + selected = "cancelled" + except Exception: + selected = "cancelled" + finally: + # Don't leak the future in the router-scoped registry on timeout/cancel. 
+ discard = getattr(self, "_wake_choice_discard", None) + if discard is not None and request_id: + try: + discard(request_id) + except Exception: + pass + return selected or "skip" + + async def get_tool_call(self, user_message: str) -> dict | None: """Dry-run tool call (for benchmarking).""" context_summary = await self.prompts.get_cached_context_summary( self.experiment, self.timelapse_orchestrator, self.timeline_manager @@ -793,25 +1190,25 @@ async def get_tool_call(self, user_message: str) -> Optional[Dict]: # === Experiment Management Methods === - def load_embryos_from_database(self, database: Dict): + def load_embryos_from_database(self, database: dict): """Load embryos from calibration database.""" - if 'embryos' not in database: + if "embryos" not in database: return - for embryo_id, embryo_data in database['embryos'].items(): - position = embryo_data.get('stage_position_after_centering_um', {}) - calibration = embryo_data.get('calibration', {}) + for embryo_id, embryo_data in database["embryos"].items(): + position = embryo_data.get("stage_position_after_centering_um", {}) + calibration = embryo_data.get("calibration", {}) self.experiment.add_embryo( embryo_id=embryo_id, position=position, calibration=calibration, - uid=embryo_data.get('uid'), + uid=embryo_data.get("uid"), ) self._update_system_prompt() - def import_embryos_from_session(self, session_id: str, clear_existing: bool = False) -> Dict: + def import_embryos_from_session(self, session_id: str, clear_existing: bool = False) -> dict: """ Import embryos from another session into the current experiment. @@ -847,11 +1244,16 @@ def import_embryos_from_session(self, session_id: str, clear_existing: bool = Fa eid, ) src_role = "unassigned" + coarse = row.get("position_coarse") or {} + fine = row.get("position_fine") or {} embryo_states[eid] = { - "stage_position": { - "x": row.get("position_x"), - "y": row.get("position_y"), - }, + # stage_position remains for legacy consumers of this snapshot. 
+ # It carries the resolved (fine ?? coarse) view; add_embryo() + # downstream treats it as coarse, but the explicit + # position_fine field below will override that on restore. + "stage_position": dict(fine) if fine else dict(coarse), + "position_coarse": dict(coarse), + "position_fine": dict(fine), "calibration": row.get("calibration") or {}, "uid": row.get("embryo_uid"), "user_label": row.get("nickname"), @@ -862,18 +1264,19 @@ def import_embryos_from_session(self, session_id: str, clear_existing: bool = Fa if not embryo_states: session_data = self.store.load_session_snapshot(session_id) if session_data: - embryo_states = session_data.get('embryo_states', {}) + embryo_states = session_data.get("embryo_states", {}) if not embryo_states: return { - 'success': False, - 'error': "No embryos found in session", - 'imported': [], - 'skipped': [], + "success": False, + "error": "No embryos found in session", + "imported": [], + "skipped": [], } if clear_existing: self.experiment.embryos.clear() + self.experiment.notify_embryos_changed() imported = [] skipped = [] @@ -885,26 +1288,33 @@ def import_embryos_from_session(self, session_id: str, clear_existing: bool = Fa continue try: - position = embryo_data.get('stage_position', {}) - calibration = embryo_data.get('calibration', {}) - source_uid = embryo_data.get('uid') or f"{session_id}_{embryo_id}" + # Prefer explicit coarse/fine when the snapshot has them + # (FileStore path); fall back to flat stage_position for the + # legacy JSON-snapshot path which only carries the resolved view. 
+ position_coarse = embryo_data.get("position_coarse") + position_fine = embryo_data.get("position_fine") + if position_coarse is None and position_fine is None: + position_coarse = embryo_data.get("stage_position", {}) + calibration = embryo_data.get("calibration", {}) + source_uid = embryo_data.get("uid") or f"{session_id}_{embryo_id}" self.experiment.add_embryo( embryo_id=embryo_id, - position=position, + position=position_coarse or {}, + position_fine=position_fine or {}, calibration=calibration, - user_label=embryo_data.get('user_label'), + user_label=embryo_data.get("user_label"), uid=source_uid, - role=embryo_data.get('role') or 'unassigned', + role=embryo_data.get("role") or "unassigned", ) embryo = self.experiment.embryos[embryo_id] - embryo.nickname = embryo_data.get('nickname') - embryo.interval_seconds = embryo_data.get('interval_seconds') - embryo.num_slices = embryo_data.get('num_slices', 50) - embryo.exposure_ms = embryo_data.get('exposure_ms', 10.0) - embryo.priority = embryo_data.get('priority', 'normal') - embryo.acquisition_mode = embryo_data.get('acquisition_mode', 'volume') + embryo.nickname = embryo_data.get("nickname") + embryo.interval_seconds = embryo_data.get("interval_seconds") + embryo.num_slices = embryo_data.get("num_slices", 50) + embryo.exposure_ms = embryo_data.get("exposure_ms", 10.0) + embryo.priority = embryo_data.get("priority", "normal") + embryo.acquisition_mode = embryo_data.get("acquisition_mode", "volume") # Light budget import. Prefer fields already on embryo_data # (future schema may persist these directly on embryo.yaml); @@ -915,27 +1325,22 @@ def import_embryos_from_session(self, session_id: str, clear_existing: bool = Fa # removed and the import should fail loudly if dose is # missing. 
dose = self._compute_imported_dose(session_id, embryo_id) - embryo.exposure_count = ( - embryo_data.get('exposure_count') - or dose['exposure_count'] - ) + embryo.exposure_count = embryo_data.get("exposure_count") or dose["exposure_count"] embryo.total_exposure_ms = ( - embryo_data.get('total_exposure_ms') - or dose['total_exposure_ms'] + embryo_data.get("total_exposure_ms") or dose["total_exposure_ms"] ) embryo.timepoints_acquired = ( - embryo_data.get('timepoints_acquired') - or dose['exposure_count'] + embryo_data.get("timepoints_acquired") or dose["exposure_count"] ) - last_imaged_str = embryo_data.get('last_imaged') + last_imaged_str = embryo_data.get("last_imaged") if last_imaged_str: try: embryo.last_imaged = datetime.fromisoformat(last_imaged_str) except (ValueError, TypeError): - embryo.last_imaged = dose['last_imaged'] + embryo.last_imaged = dose["last_imaged"] else: - embryo.last_imaged = dose['last_imaged'] + embryo.last_imaged = dose["last_imaged"] imported.append(embryo_id) @@ -946,14 +1351,14 @@ def import_embryos_from_session(self, session_id: str, clear_existing: bool = Fa self._mark_significant_action("embryo_import") return { - 'success': len(imported) > 0, - 'imported': imported, - 'skipped': skipped, - 'errors': errors, - 'source_session': session_id, + "success": len(imported) > 0, + "imported": imported, + "skipped": skipped, + "errors": errors, + "source_session": session_id, } - def _compute_imported_dose(self, source_session_id: str, embryo_id: str) -> Dict: + def _compute_imported_dose(self, source_session_id: str, embryo_id: str) -> dict: """Reconstruct an embryo's realized 488 nm photodose from the source session's per-volume meta files. @@ -968,66 +1373,68 @@ def _compute_imported_dose(self, source_session_id: str, embryo_id: str) -> Dict TODO: replace with reading a persisted ``dose:`` block from embryo.yaml once dose-tracking is first-class. 
""" - import yaml from datetime import datetime from pathlib import Path + import yaml + result = { - 'exposure_count': 0, - 'total_exposure_ms': 0.0, - 'last_imaged': None, + "exposure_count": 0, + "total_exposure_ms": 0.0, + "last_imaged": None, } if not self.store: return result # FileStore exposes _session_dir(session_id) → resolved Path. - session_dir_fn = getattr(self.store, '_session_dir', None) + session_dir_fn = getattr(self.store, "_session_dir", None) sd = session_dir_fn(source_session_id) if callable(session_dir_fn) else None if sd is None: return result - vols_dir = Path(sd) / 'embryos' / embryo_id / 'volumes' + vols_dir = Path(sd) / "embryos" / embryo_id / "volumes" if not vols_dir.is_dir(): return result latest = None - for meta_path in sorted(vols_dir.glob('*.meta.yaml')): + for meta_path in sorted(vols_dir.glob("*.meta.yaml")): try: doc = yaml.safe_load(meta_path.read_text()) or {} except Exception: continue - md = doc.get('metadata') or {} - num_slices = md.get('num_slices') + md = doc.get("metadata") or {} + num_slices = md.get("num_slices") if num_slices is None: - shape = doc.get('shape') or [] + shape = doc.get("shape") or [] num_slices = shape[0] if shape else 0 - exposure_ms = md.get('exposure_ms') or 0.0 + exposure_ms = md.get("exposure_ms") or 0.0 try: - result['total_exposure_ms'] += float(num_slices) * float(exposure_ms) + result["total_exposure_ms"] += float(num_slices) * float(exposure_ms) except (TypeError, ValueError): pass - result['exposure_count'] += 1 - acq = doc.get('acquired_at') + result["exposure_count"] += 1 + acq = doc.get("acquired_at") if acq and (latest is None or acq > latest): latest = acq if latest: try: - result['last_imaged'] = datetime.fromisoformat(latest) + result["last_imaged"] = datetime.fromisoformat(latest) except (ValueError, TypeError): pass return result - async def on_volume_acquired(self, embryo_id: str, timepoint: int, - volume_data, volume_path=None): + async def on_volume_acquired( + self, embryo_id: 
str, timepoint: int, volume_data, volume_path=None + ): """Callback when a volume is acquired.""" embryo = self.experiment.embryos.get(embryo_id) if not embryo: return - if hasattr(volume_data, 'read_volume'): + if hasattr(volume_data, "read_volume"): volume = volume_data.read_volume() else: volume = volume_data @@ -1036,9 +1443,10 @@ async def on_volume_acquired(self, embryo_id: str, timepoint: int, if self.store and self.session_id: try: self.store.register_embryo( - self.session_id, embryo_id, - position_x=embryo.stage_position.get('x') if embryo.stage_position else None, - position_y=embryo.stage_position.get('y') if embryo.stage_position else None, + self.session_id, + embryo_id, + position_coarse=embryo.position_coarse or None, + position_fine=embryo.position_fine or None, calibration=embryo.calibration, role=embryo.role, ) @@ -1053,14 +1461,19 @@ async def on_volume_acquired(self, embryo_id: str, timepoint: int, } if volume_path is not None: stored_path = self.store.register_volume( - self.session_id, embryo_id, timepoint, + self.session_id, + embryo_id, + timepoint, incoming_path=Path(volume_path), metadata=acq_metadata, volume_data=volume, ) else: stored_path = self.store.put_volume( - self.session_id, embryo_id, timepoint, volume, + self.session_id, + embryo_id, + timepoint, + volume, metadata=acq_metadata, ) except StorageError: @@ -1075,9 +1488,9 @@ async def on_volume_acquired(self, embryo_id: str, timepoint: int, if self.viz_server and volume is not None: try: from gently.core.imaging import ( - projection_three_view, - compute_crop_bounds, apply_crop_bounds, + compute_crop_bounds, + projection_three_view, ) view_a = volume[0] if volume.ndim == 4 else volume @@ -1085,14 +1498,18 @@ async def on_volume_acquired(self, embryo_id: str, timepoint: int, if view_a.ndim == 3: z_depth, height, width = view_a.shape if width > height * 2: - view_a = view_a[:, :, :width // 2] + view_a = view_a[:, :, : width // 2] bounds = compute_crop_bounds(view_a) cropped = 
apply_crop_bounds(view_a, bounds) three_view_img, _ = projection_three_view(cropped) else: three_view_img = view_a.astype(np.float32) if three_view_img.max() > three_view_img.min(): - three_view_img = (three_view_img - three_view_img.min()) / (three_view_img.max() - three_view_img.min()) * 255 + three_view_img = ( + (three_view_img - three_view_img.min()) + / (three_view_img.max() - three_view_img.min()) + * 255 + ) three_view_img = three_view_img.astype(np.uint8) self.push_viz( @@ -1100,29 +1517,32 @@ async def on_volume_acquired(self, embryo_id: str, timepoint: int, uid=projection_uid, data_type="volume_projection", metadata={ - 'embryo_id': embryo_id, - 'timepoint': timepoint, - 'shape': list(volume.shape), - 'projection_uid': projection_uid, - 'volume_uid': volume_uid, - 'projection_type': 'three_view', - } + "embryo_id": embryo_id, + "timepoint": timepoint, + "shape": list(volume.shape), + "projection_uid": projection_uid, + "volume_uid": volume_uid, + "projection_type": "three_view", + }, ) except Exception as e: logger.warning(f"Failed to push to viz: {e}") - self._emit_event(EventType.VOLUME_ACQUIRED, { - 'embryo_id': embryo_id, - 'timepoint': timepoint, - 'volume_uid': volume_uid, - 'projection_uid': projection_uid, - 'volume_path': str(stored_path) if stored_path else None, - 'shape': list(volume.shape), - }) + self._emit_event( + EventType.VOLUME_ACQUIRED, + { + "embryo_id": embryo_id, + "timepoint": timepoint, + "volume_uid": volume_uid, + "projection_uid": projection_uid, + "volume_path": str(stored_path) if stored_path else None, + "shape": list(volume.shape), + }, + ) return { - 'volume_uid': volume_uid, - 'projection_uid': projection_uid, + "volume_uid": volume_uid, + "projection_uid": projection_uid, } def should_stop_experiment(self) -> bool: @@ -1131,22 +1551,31 @@ def should_stop_experiment(self) -> bool: return False return all(e.should_skip for e in self.experiment.embryos.values()) - def get_embryo_acquisition_order(self) -> List[str]: + def 
get_embryo_acquisition_order(self) -> list[str]: """Get embryo acquisition order based on priority.""" - high = [e.id for e in self.experiment.embryos.values() if e.priority == "high" and not e.should_skip] - normal = [e.id for e in self.experiment.embryos.values() if e.priority == "normal" and not e.should_skip] - low = [e.id for e in self.experiment.embryos.values() if e.priority == "low" and not e.should_skip] + high = [ + e.id + for e in self.experiment.embryos.values() + if e.priority == "high" and not e.should_skip + ] + normal = [ + e.id + for e in self.experiment.embryos.values() + if e.priority == "normal" and not e.should_skip + ] + low = [ + e.id + for e in self.experiment.embryos.values() + if e.priority == "low" and not e.should_skip + ] return high + normal + low - def decide_parameters(self, embryo_id: str, timepoint: int) -> Dict: + def decide_parameters(self, embryo_id: str, timepoint: int) -> dict: """Get current acquisition parameters for embryo.""" embryo = self.experiment.embryos.get(embryo_id) if not embryo: - return {'num_slices': 50, 'exposure_ms': 10.0} - return { - 'num_slices': embryo.num_slices, - 'exposure_ms': embryo.exposure_ms - } + return {"num_slices": 50, "exposure_ms": 10.0} + return {"num_slices": embryo.num_slices, "exposure_ms": embryo.exposure_ms} def decide_next_interval(self, timepoint: int) -> float: """Decide interval until next timepoint.""" @@ -1171,8 +1600,9 @@ async def check_blank_image( logger.warning(f"[BLANK_CHECK] {embryo_id}: Numerical check indicates blank image") return True - import io import base64 + import io + from PIL import Image if max_proj.max() > 0: @@ -1182,10 +1612,11 @@ async def check_blank_image( img = Image.fromarray(normalized) buffer = io.BytesIO() - img.save(buffer, format='PNG') + img.save(buffer, format="PNG") b64_image = base64.b64encode(buffer.getvalue()).decode() - prompt = """Look at this microscopy image. Is this a VALID microscopy image or a BLANK/CORRUPTED image? 
+ prompt = """Look at this microscopy image. Is this a VALID microscopy image or a +BLANK/CORRUPTED image? A BLANK or CORRUPTED image shows: - Mostly uniform gray/black with no structure @@ -1203,20 +1634,22 @@ async def check_blank_image( self.claude.messages.create, model=settings.models.fast, max_tokens=10, - messages=[{ - "role": "user", - "content": [ - {"type": "text", "text": prompt}, - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": b64_image - } - } - ] - }] + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": b64_image, + }, + }, + ], + } + ], ) result = response.content[0].text.strip().upper() @@ -1227,7 +1660,11 @@ async def check_blank_image( return is_blank - except (anthropic.APIConnectionError, anthropic.RateLimitError, anthropic.APIStatusError) as e: + except ( + anthropic.APIConnectionError, + anthropic.RateLimitError, + anthropic.APIStatusError, + ) as e: logger.error(f"[BLANK_CHECK] Claude API error for {embryo_id}: {e}") return False except Exception as e: diff --git a/gently/app/benchmark.py b/gently/app/benchmark.py index 5724f548..9263d7fe 100644 --- a/gently/app/benchmark.py +++ b/gently/app/benchmark.py @@ -15,7 +15,6 @@ /benchmark --volumes 10 --slices 50 """ -import asyncio import csv import logging import shutil @@ -25,7 +24,7 @@ from dataclasses import dataclass, field from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING if TYPE_CHECKING: from .agent import MicroscopyAgent @@ -36,6 +35,7 @@ @dataclass class VolumeTiming: """Timing breakdown for a single volume acquisition.""" + volume_idx: int embryo_id: str timepoint: int @@ -51,13 +51,14 @@ class VolumeTiming: volume_shape: tuple = () file_size_mb: float = 0.0 success: bool = True - error: Optional[str] = None + error: 
str | None = None @dataclass class BenchmarkResults: """Aggregate benchmark results.""" - timings: List[VolumeTiming] = field(default_factory=list) + + timings: list[VolumeTiming] = field(default_factory=list) num_embryos: int = 1 num_slices: int = 50 exposure_ms: float = 10.0 @@ -65,14 +66,14 @@ class BenchmarkResults: completed_at: str = "" @property - def successful(self) -> List[VolumeTiming]: + def successful(self) -> list[VolumeTiming]: return [t for t in self.timings if t.success] @property - def failed(self) -> List[VolumeTiming]: + def failed(self) -> list[VolumeTiming]: return [t for t in self.timings if not t.success] - def _stat(self, values: List[float]) -> Dict[str, float]: + def _stat(self, values: list[float]) -> dict[str, float]: if not values: return {"mean": 0, "std": 0, "min": 0, "max": 0} return { @@ -83,19 +84,19 @@ def _stat(self, values: List[float]) -> Dict[str, float]: } @property - def acquisition_stats(self) -> Dict[str, float]: + def acquisition_stats(self) -> dict[str, float]: return self._stat([t.acquisition_time for t in self.successful]) @property - def storage_stats(self) -> Dict[str, float]: + def storage_stats(self) -> dict[str, float]: return self._stat([t.storage_time for t in self.successful]) @property - def viz_push_stats(self) -> Dict[str, float]: + def viz_push_stats(self) -> dict[str, float]: return self._stat([t.viz_push_time for t in self.successful]) @property - def total_stats(self) -> Dict[str, float]: + def total_stats(self) -> dict[str, float]: return self._stat([t.total_time for t in self.successful]) @property @@ -137,10 +138,10 @@ async def run_benchmark( exposure_ms: float = 10.0, warmup: int = 1, # Legacy parameters (ignored, kept for API compat with bridge) - n_volumes: int = None, - n_slices: int = None, - n_warmup: int = None, - progress_fn: Optional[callable] = None, + n_volumes: int | None = None, + n_slices: int | None = None, + n_warmup: int | None = None, + progress_fn: callable | None = None, ) -> 
BenchmarkResults: """ Run end-to-end volume acquisition benchmark. @@ -172,7 +173,7 @@ async def run_benchmark( started_at=datetime.now().isoformat(), ) - viz_server = getattr(agent, 'viz_server', None) + viz_server = getattr(agent, "viz_server", None) has_viz = viz_server is not None temp_dir = Path(tempfile.mkdtemp(prefix="gently_benchmark_")) @@ -187,10 +188,13 @@ async def run_benchmark( logger.info( "Benchmark: %d volumes + %d warmup, %d slices, %.0f ms exposure", - num_volumes, warmup, num_slices, exposure_ms, + num_volumes, + warmup, + num_slices, + exposure_ms, ) - timepoints: Dict[str, int] = {eid: 0 for eid in embryo_ids} + timepoints: dict[str, int] = {eid: 0 for eid in embryo_ids} total_iterations = warmup + num_volumes for i in range(total_iterations): @@ -221,18 +225,16 @@ async def run_benchmark( try: if embryo and embryo.calibration: cal = embryo.calibration - galvo_amp = cal.get('galvo_amplitude', 0.5) - galvo_center = cal.get('galvo_center', 0.0) - piezo_amp = cal.get('piezo_amplitude', 25.0) - piezo_center = cal.get('piezo_center', 50.0) + galvo_amp = cal.get("galvo_amplitude", 0.5) + galvo_center = cal.get("galvo_center", 0.0) + piezo_amp = cal.get("piezo_amplitude", 25.0) + piezo_center = cal.get("piezo_center", 50.0) else: galvo_amp, galvo_center = 0.5, 0.0 piezo_amp, piezo_center = 25.0, 50.0 if embryo and embryo.position: - await agent.client.move_stage( - embryo.position['x'], embryo.position['y'] - ) + await agent.client.move_stage(embryo.position["x"], embryo.position["y"]) # Stage 1: Acquisition t0 = time.perf_counter() @@ -247,28 +249,34 @@ async def run_benchmark( t1 = time.perf_counter() timing.acquisition_time = t1 - t0 - if not result.get('success'): + if not result.get("success"): timing.success = False - timing.error = result.get('error', 'Acquisition failed') + timing.error = result.get("error", "Acquisition failed") if not is_warmup: results.timings.append(timing) continue - volume = result.get('volume') - volume_path = 
result.get('volume_path') + volume = result.get("volume") + volume_path = result.get("volume_path") timing.volume_shape = volume.shape if volume is not None else () # Stage 2: Storage t2 = time.perf_counter() if volume_path: canonical_path = temp_store.register_volume( - benchmark_session, embryo_id, tp, Path(volume_path), + benchmark_session, + embryo_id, + tp, + Path(volume_path), volume_data=volume, ) timing.file_size_mb = canonical_path.stat().st_size / (1024 * 1024) elif volume is not None: canonical_path = temp_store.put_volume( - benchmark_session, embryo_id, tp, volume, + benchmark_session, + embryo_id, + tp, + volume, ) timing.file_size_mb = canonical_path.stat().st_size / (1024 * 1024) t3 = time.perf_counter() @@ -282,7 +290,11 @@ async def run_benchmark( proj, uid=f"benchmark_{embryo_id}_t{tp:04d}", data_type="benchmark", - metadata={"embryo_id": embryo_id, "timepoint": tp, "benchmark": True}, + metadata={ + "embryo_id": embryo_id, + "timepoint": tp, + "benchmark": True, + }, ) t5 = time.perf_counter() timing.viz_push_time = t5 - t4 @@ -355,21 +367,34 @@ def save_benchmark_csv(results: BenchmarkResults, path: Path): writer.writerow(["# std_total_s", f"{total['std']:.6f}"]) writer.writerow([]) - writer.writerow([ - "volume_idx", "embryo_id", "timepoint", - "acquisition_s", "storage_s", "viz_push_s", "total_s", - "file_size_mb", "success", "error" - ]) + writer.writerow( + [ + "volume_idx", + "embryo_id", + "timepoint", + "acquisition_s", + "storage_s", + "viz_push_s", + "total_s", + "file_size_mb", + "success", + "error", + ] + ) for t in results.timings: - writer.writerow([ - t.volume_idx, t.embryo_id, t.timepoint, - f"{t.acquisition_time:.6f}", - f"{t.storage_time:.6f}", - f"{t.viz_push_time:.6f}", - f"{t.total_time:.6f}", - f"{t.file_size_mb:.2f}", - t.success, - t.error or "", - ]) + writer.writerow( + [ + t.volume_idx, + t.embryo_id, + t.timepoint, + f"{t.acquisition_time:.6f}", + f"{t.storage_time:.6f}", + f"{t.viz_push_time:.6f}", + 
f"{t.total_time:.6f}", + f"{t.file_size_mb:.2f}", + t.success, + t.error or "", + ] + ) return path diff --git a/gently/app/bottom_camera_monitor.py b/gently/app/bottom_camera_monitor.py index 6ec6443e..1671e718 100644 --- a/gently/app/bottom_camera_monitor.py +++ b/gently/app/bottom_camera_monitor.py @@ -14,7 +14,7 @@ import asyncio import logging -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from gently.core.event_bus import EventType, get_event_bus from gently.core.service import Service @@ -35,15 +35,15 @@ class BottomCameraStreamMonitor(Service): def __init__( self, - microscope: "DiSPIMMicroscope", + microscope: DiSPIMMicroscope, reconnect_delay_sec: float = 2.0, ): super().__init__(name="bottom-camera-monitor", service_type="bridge") self.microscope = microscope self.reconnect_delay_sec = reconnect_delay_sec - self._task: Optional[asyncio.Task] = None + self._task: asyncio.Task | None = None self._stop_requested = False - self._last_frame_ts: Optional[float] = None + self._last_frame_ts: float | None = None @property def running(self) -> bool: @@ -89,7 +89,8 @@ async def _run(self): except Exception as exc: logger.warning( "BottomCameraStreamMonitor: stream ended (%s) — reconnecting in %.1fs", - exc, self.reconnect_delay_sec, + exc, + self.reconnect_delay_sec, ) if self._stop_requested: break diff --git a/gently/app/calibration/__init__.py b/gently/app/calibration/__init__.py index 9365824e..683b152c 100644 --- a/gently/app/calibration/__init__.py +++ b/gently/app/calibration/__init__.py @@ -19,9 +19,9 @@ """ from .base import CalibrationData, CalibrationPipeline, aggregate_calibrations -from .two_point import TwoPointCalibration from .edge_roi import EdgeRoiCalibration from .registry import CALIBRATION_REGISTRY, get_calibration_pipeline +from .two_point import TwoPointCalibration __all__ = [ "CalibrationData", diff --git a/gently/app/calibration/base.py b/gently/app/calibration/base.py index ae3e72dc..3b8259d7 100644 --- 
a/gently/app/calibration/base.py +++ b/gently/app/calibration/base.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any import numpy as np @@ -19,11 +19,12 @@ class CalibrationData: ``edge_bbox`` for edge ROI). Merged ``CalibrationData`` from multiple pipelines is what gets passed to detector context as ``calibration``. """ + pipeline_name: str captured_at: datetime = field(default_factory=datetime.now) - source_embryo_ids: List[str] = field(default_factory=list) - payload: Dict[str, Any] = field(default_factory=dict) - notes: Optional[str] = None + source_embryo_ids: list[str] = field(default_factory=list) + payload: dict[str, Any] = field(default_factory=dict) + notes: str | None = None class CalibrationPipeline(ABC): @@ -41,8 +42,8 @@ class CalibrationPipeline(ABC): @abstractmethod def capture( self, - source_volumes: Dict[str, np.ndarray], - context: Dict[str, Any], + source_volumes: dict[str, np.ndarray], + context: dict[str, Any], ) -> CalibrationData: """Compute calibration data from one or more source embryo volumes.""" ... @@ -57,8 +58,8 @@ def apply( def aggregate_calibrations( - pipelines_data: List[CalibrationData], -) -> Dict[str, Any]: + pipelines_data: list[CalibrationData], +) -> dict[str, Any]: """Merge multiple CalibrationData payloads into the single dict the detector context expects. @@ -67,7 +68,7 @@ def aggregate_calibrations( single pipeline is the pipeline's responsibility (typically via pixel-wise median, see TwoPointCalibration). 
""" - merged: Dict[str, Any] = {} + merged: dict[str, Any] = {} for cal in pipelines_data: merged.update(cal.payload) return merged diff --git a/gently/app/calibration/edge_roi.py b/gently/app/calibration/edge_roi.py index 68a0cb58..25aa907a 100644 --- a/gently/app/calibration/edge_roi.py +++ b/gently/app/calibration/edge_roi.py @@ -11,7 +11,7 @@ dopaminergic detector uses). """ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import numpy as np @@ -28,8 +28,8 @@ def __init__(self, padding_px: int = 10): def capture( self, - source_volumes: Dict[str, Any], - context: Dict[str, Any], + source_volumes: dict[str, Any], + context: dict[str, Any], ) -> CalibrationData: """``source_volumes`` here is expected to be a flat ``{embryo_id: volume_ndarray}`` dict (no dark/flat split). @@ -39,10 +39,10 @@ def capture( elsewhere and we just want to reuse its output. """ precomputed = context.get("embryo_bboxes") if context else None - bboxes: List[Tuple[int, int, int, int]] = [] + bboxes: list[tuple[int, int, int, int]] = [] if precomputed: - for eid, bb in precomputed.items(): + for _eid, bb in precomputed.items(): if bb is not None and len(bb) == 4: bboxes.append(tuple(map(int, bb))) else: @@ -65,8 +65,10 @@ def capture( xs1 = np.array([b[2] for b in bboxes]) ys1 = np.array([b[3] for b in bboxes]) agg = ( - int(np.median(xs0)), int(np.median(ys0)), - int(np.median(xs1)), int(np.median(ys1)), + int(np.median(xs0)), + int(np.median(ys0)), + int(np.median(xs1)), + int(np.median(ys1)), ) return CalibrationData( @@ -77,7 +79,7 @@ def capture( ) -def _bbox_from_volume(vol, padding: int = 10) -> Optional[Tuple[int, int, int, int]]: +def _bbox_from_volume(vol, padding: int = 10) -> tuple[int, int, int, int] | None: """Cheap thresholding-based bbox fallback when SAM hasn't run. 
Max-projects, thresholds at 25th percentile + delta, returns the diff --git a/gently/app/calibration/registry.py b/gently/app/calibration/registry.py index 5953110f..cadfa857 100644 --- a/gently/app/calibration/registry.py +++ b/gently/app/calibration/registry.py @@ -1,30 +1,31 @@ """Calibration pipeline registry — name → factory.""" -from typing import Callable, Dict, Optional +from collections.abc import Callable from .base import CalibrationPipeline - CalibrationFactory = Callable[..., CalibrationPipeline] def _make_two_point(**kw) -> CalibrationPipeline: from .two_point import TwoPointCalibration + return TwoPointCalibration(**kw) def _make_edge_roi(**kw) -> CalibrationPipeline: from .edge_roi import EdgeRoiCalibration + return EdgeRoiCalibration(**kw) -CALIBRATION_REGISTRY: Dict[str, CalibrationFactory] = { +CALIBRATION_REGISTRY: dict[str, CalibrationFactory] = { "two_point": _make_two_point, "edge_roi": _make_edge_roi, } -def get_calibration_pipeline(name: str, **kw) -> Optional[CalibrationPipeline]: +def get_calibration_pipeline(name: str, **kw) -> CalibrationPipeline | None: factory = CALIBRATION_REGISTRY.get(name) if factory is None: return None diff --git a/gently/app/calibration/two_point.py b/gently/app/calibration/two_point.py index edd31263..bea11c1e 100644 --- a/gently/app/calibration/two_point.py +++ b/gently/app/calibration/two_point.py @@ -12,7 +12,7 @@ calibration embryos uses pixel-wise **median** for robustness. 
""" -from typing import Any, Dict, List +from typing import Any import numpy as np @@ -44,8 +44,8 @@ def __init__(self, dark_source: str = "dark", flat_source: str = "flat"): def capture( self, - source_volumes: Dict[str, Any], - context: Dict[str, Any], + source_volumes: dict[str, Any], + context: dict[str, Any], ) -> CalibrationData: darks = source_volumes.get(self.dark_source, {}) or {} flats = source_volumes.get(self.flat_source, {}) or {} @@ -53,7 +53,7 @@ def capture( dark_proj = _aggregate_projections(list(darks.values())) flat_proj = _aggregate_projections(list(flats.values())) - payload: Dict[str, Any] = {} + payload: dict[str, Any] = {} if dark_proj is not None: payload["dark"] = dark_proj if flat_proj is not None: @@ -72,7 +72,7 @@ def capture( ) -def _aggregate_projections(volumes: List[np.ndarray]): +def _aggregate_projections(volumes: list[np.ndarray]): """Max-project each volume to 2D, then median across embryos.""" if not volumes: return None diff --git a/gently/app/detectors/__init__.py b/gently/app/detectors/__init__.py index 04db8775..bb54301c 100644 --- a/gently/app/detectors/__init__.py +++ b/gently/app/detectors/__init__.py @@ -18,9 +18,9 @@ """ from .base import Detector, DetectorResult +from .blank_image import BlankImageDetector from .dopaminergic_signal import DopaminergicSignalDetector from .hatching import HatchingDetector -from .blank_image import BlankImageDetector from .perception_proxy import PerceptionProxy from .registry import DETECTOR_REGISTRY, get_detector diff --git a/gently/app/detectors/base.py b/gently/app/detectors/base.py index a25caaca..0179061f 100644 --- a/gently/app/detectors/base.py +++ b/gently/app/detectors/base.py @@ -11,7 +11,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any import numpy as np @@ -26,18 +26,19 @@ class DetectorResult: structure_quality: "GOOD"}`` for the dopaminergic 
detector, ``{stage: "pretzel"}`` for the perception proxy). """ + detector_name: str embryo_id: str timepoint: int - findings: Dict[str, Any] = field(default_factory=dict) - confidence: Optional[float] = None - reasoning: Optional[str] = None - raw_response: Optional[str] = None + findings: dict[str, Any] = field(default_factory=dict) + confidence: float | None = None + reasoning: str | None = None + raw_response: str | None = None timestamp: datetime = field(default_factory=datetime.now) - elapsed_ms: Optional[float] = None - error: Optional[str] = None + elapsed_ms: float | None = None + error: str | None = None - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "detector_name": self.detector_name, "embryo_id": self.embryo_id, @@ -76,7 +77,7 @@ class Detector(ABC): async def run( self, volume: np.ndarray, - context: Dict[str, Any], + context: dict[str, Any], ) -> DetectorResult: """Observe the volume and return a structured result.""" ... diff --git a/gently/app/detectors/blank_image.py b/gently/app/detectors/blank_image.py index f8a6bcd3..8973c0cd 100644 --- a/gently/app/detectors/blank_image.py +++ b/gently/app/detectors/blank_image.py @@ -11,7 +11,7 @@ import io import logging import time -from typing import Any, Dict, Optional +from typing import Any import numpy as np @@ -20,7 +20,8 @@ logger = logging.getLogger(__name__) -_BLANK_PROMPT = """Look at this microscopy image. Is this a VALID microscopy image or a BLANK/CORRUPTED image? +_BLANK_PROMPT = """Look at this microscopy image. Is this a VALID microscopy image or a +BLANK/CORRUPTED image? 
A BLANK or CORRUPTED image shows: - Mostly uniform gray/black with no structure @@ -40,18 +41,19 @@ class BlankImageDetector(Detector): name = "blank_image" - def __init__(self, claude_client=None, model: Optional[str] = None): + def __init__(self, claude_client=None, model: str | None = None): self._claude = claude_client self._model = model async def run( self, volume: np.ndarray, - context: Dict[str, Any], + context: dict[str, Any], ) -> DetectorResult: - from gently.settings import settings - from PIL import Image as PILImage import anthropic + from PIL import Image as PILImage + + from gently.settings import settings embryo_id = context.get("embryo_id", "?") timepoint = int(context.get("timepoint", 0)) @@ -61,15 +63,20 @@ async def run( vol = np.squeeze(volume) if volume is not None else None if vol is None or vol.size == 0: return DetectorResult( - detector_name=self.name, embryo_id=embryo_id, timepoint=timepoint, - findings={"is_blank": True}, reasoning="Empty volume", + detector_name=self.name, + embryo_id=embryo_id, + timepoint=timepoint, + findings={"is_blank": True}, + reasoning="Empty volume", elapsed_ms=(time.time() - start) * 1000, ) max_proj = np.max(vol, axis=0) if vol.ndim == 3 else vol if np.std(max_proj) < 1.0 or np.max(max_proj) < 10: return DetectorResult( - detector_name=self.name, embryo_id=embryo_id, timepoint=timepoint, + detector_name=self.name, + embryo_id=embryo_id, + timepoint=timepoint, findings={"is_blank": True}, reasoning="Numerical check (low std / max)", elapsed_ms=(time.time() - start) * 1000, @@ -78,7 +85,9 @@ async def run( claude = self._claude or context.get("claude") if claude is None: return DetectorResult( - detector_name=self.name, embryo_id=embryo_id, timepoint=timepoint, + detector_name=self.name, + embryo_id=embryo_id, + timepoint=timepoint, findings={"is_blank": False}, reasoning="No Claude client; deferred to numerical check (passed)", elapsed_ms=(time.time() - start) * 1000, @@ -87,8 +96,11 @@ async def run( # 
Normalize, encode, ask Claude if max_proj.max() == 0: return DetectorResult( - detector_name=self.name, embryo_id=embryo_id, timepoint=timepoint, - findings={"is_blank": True}, reasoning="Max projection is all zeros", + detector_name=self.name, + embryo_id=embryo_id, + timepoint=timepoint, + findings={"is_blank": True}, + reasoning="Max projection is all zeros", elapsed_ms=(time.time() - start) * 1000, ) normalized = (max_proj / max_proj.max() * 255).astype(np.uint8) @@ -101,38 +113,53 @@ async def run( claude.messages.create, model=self._model or settings.models.fast, max_tokens=10, - messages=[{ - "role": "user", - "content": [ - {"type": "text", "text": _BLANK_PROMPT}, - {"type": "image", "source": { - "type": "base64", - "media_type": "image/png", - "data": b64_image, - }}, - ], - }], + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": _BLANK_PROMPT}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": b64_image, + }, + }, + ], + } + ], ) text = (response.content[0].text if response.content else "").strip().upper() is_blank = "BLANK" in text return DetectorResult( - detector_name=self.name, embryo_id=embryo_id, timepoint=timepoint, + detector_name=self.name, + embryo_id=embryo_id, + timepoint=timepoint, findings={"is_blank": is_blank}, reasoning=text, raw_response=text, elapsed_ms=(time.time() - start) * 1000, ) - except (anthropic.APIConnectionError, anthropic.RateLimitError, anthropic.APIStatusError) as e: + except ( + anthropic.APIConnectionError, + anthropic.RateLimitError, + anthropic.APIStatusError, + ) as e: logger.error("[%s] Claude API error for %s: %s", self.name, embryo_id, e) return DetectorResult( - detector_name=self.name, embryo_id=embryo_id, timepoint=timepoint, + detector_name=self.name, + embryo_id=embryo_id, + timepoint=timepoint, error=f"API error: {e}", elapsed_ms=(time.time() - start) * 1000, ) except Exception as e: logger.exception("[%s] unexpected error", self.name) 
return DetectorResult( - detector_name=self.name, embryo_id=embryo_id, timepoint=timepoint, + detector_name=self.name, + embryo_id=embryo_id, + timepoint=timepoint, error=str(e), elapsed_ms=(time.time() - start) * 1000, ) diff --git a/gently/app/detectors/dopaminergic_signal.py b/gently/app/detectors/dopaminergic_signal.py index 4d919446..7fb09408 100644 --- a/gently/app/detectors/dopaminergic_signal.py +++ b/gently/app/detectors/dopaminergic_signal.py @@ -23,7 +23,7 @@ import logging import re import time -from typing import Any, Dict, Optional, Tuple +from typing import Any import numpy as np @@ -32,58 +32,81 @@ logger = logging.getLogger(__name__) -_PERCEIVER_PROMPT = """You are looking at the max projection of a volume of C. elegans imaged with a 488 nm light-sheet microscope. The embryo is expressing a fluorophore that lights up when certain neurons are born. +_PERCEIVER_PROMPT = """You are looking at the max projection of a volume of C. elegans imaged +with a 488 nm light-sheet microscope. The embryo is expressing a fluorophore that lights up +when certain neurons are born. -We are trying to image the birth of these neurons and their continued existence. Your description will be read by a classifier that decides how to guide the imaging: it will speed up imaging once the neuronal structures appear, and trigger a 1-minute burst of fast acquisitions once they look stably bright. Be specific so the classifier has something concrete to act on. +We are trying to image the birth of these neurons and their continued existence. Your +description will be read by a classifier that decides how to guide the imaging: it will speed +up imaging once the neuronal structures appear, and trigger a 1-minute burst of fast +acquisitions once they look stably bright. Be specific so the classifier has something +concrete to act on. You may initially see a faint outline of the embryo — autofluorescence from the body. 
-Eventually, you may see puncta-like structures in the embryo region of the image. If there are any puncta outside the embryo region, ignore them — those are likely gut granules. +Eventually, you may see puncta-like structures in the embryo region of the image. If there +are any puncta outside the embryo region, ignore them — those are likely gut granules. -The nerve cells, as they begin to express, will first appear as a faint blob, then a brighter blob, then a further-brighter blob, and will start to emit thread-like structures from them — the nerve body. These are what we want to image. +The nerve cells, as they begin to express, will first appear as a faint blob, then a brighter +blob, then a further-brighter blob, and will start to emit thread-like structures from them — +the nerve body. These are what we want to image. -The embryo may also eventually hatch. This can look like the embryo structure disappearing entirely from the field of view. Mention it if you see this. +The embryo may also eventually hatch. This can look like the embryo structure disappearing +entirely from the field of view. Mention it if you see this. Describe what you see in a few sentences of plain prose. """ -_CLASSIFIER_PROMPT = """You are reading a microscopist's description of an image of a C. elegans embryo expressing a dopaminergic-neuron reporter (dat-1::mNeonGreen). +_CLASSIFIER_PROMPT = """You are reading a microscopist's description of an image of a C. elegans +embryo expressing a dopaminergic-neuron reporter (dat-1::mNeonGreen). -Classify the description against the rubric below. Output ONLY a JSON object — no prose, no markdown fences. Your output drives the timelapse orchestrator's next imaging decision. +Classify the description against the rubric below. Output ONLY a JSON object — no prose, no +markdown fences. Your output drives the timelapse orchestrator's next imaging decision. 
The orchestrator can take these actions based on your output: - speed_up — accelerate imaging cadence to 1-minute intervals (when neurons begin to appear). -- burst — fire a 1-minute burst of fast acquisitions (when neurons are stably bright and well-resolved). +- burst — fire a 1-minute burst of fast acquisitions (when neurons are stably bright and + well-resolved). - ramp_down_power — step the 488 nm laser down (when signal saturates the camera). - stop — stop imaging this embryo (when it has hatched / left the field). -- none — keep the base cadence, take no action (when the description is uncertain or shows only background). +- none — keep the base cadence, take no action (when the description is uncertain or shows + only background). Schema: { - "intensity_level": "NONE" | "WEAK" | "MEDIUM" | "STRONG" | "SATURATING" | "UNCERTAIN", # signal strength — is the signal absent, faint, average, or stronger? - "structure_quality": "NONE" | "PARTIAL" | "GOOD" | "UNCERTAIN", # are the neuronal structures absent, emerging, fully emerged with good signal, or can't tell? - "has_hatched": true | false, # has the embryo hatched or left the field of view? — drives the stop action + "intensity_level": "NONE" | "WEAK" | "MEDIUM" | "STRONG" | "SATURATING" | "UNCERTAIN", + # signal strength — is the signal absent, faint, average, or stronger? + "structure_quality": "NONE" | "PARTIAL" | "GOOD" | "UNCERTAIN", + # are the neuronal structures absent, emerging, fully emerged with good signal, or can't tell? + "has_hatched": true | false, + # has the embryo hatched or left the field of view? — drives the stop action "reasoning": "one short sentence quoting which words drove your choice" } intensity_level rubric (drives speed_up / ramp_down_power): - NONE: description says no puncta in the embryo / blank / nothing above background. → none -- WEAK: description mentions 1 dim spot, OR signal explicitly described as "barely visible" / "could be noise" / "very faint". 
→ none (could be noise) +- WEAK: description mentions 1 dim spot, OR signal explicitly described as "barely visible" / + "could be noise" / "very faint". → none (could be noise) - MEDIUM: description mentions 2+ clearly discrete bright spots above background. → speed_up - STRONG: description mentions multiple bright, well-resolved spots. → speed_up - SATURATING: description explicitly says the signal saturates the camera. → ramp_down_power -- UNCERTAIN: description hedges, doesn't address signal, or you cannot tell from the prose alone. → none +- UNCERTAIN: description hedges, doesn't address signal, or you cannot tell from the prose + alone. → none structure_quality rubric (drives burst): - NONE: no puncta described. → none - PARTIAL: puncta present but no neurites / no curved connecting traces. → none (not yet stable) -- GOOD: description mentions curved/elongated traces between puncta or recognizable neurite structure. → burst +- GOOD: description mentions curved/elongated traces between puncta or recognizable neurite + structure. → burst - UNCERTAIN: description is silent on structure or ambiguous. → none -has_hatched: true ONLY if the description explicitly says the embryo has hatched, OR the embryo structure has disappeared from the field of view. Default false. → stop +has_hatched: true ONLY if the description explicitly says the embryo has hatched, OR the +embryo structure has disappeared from the field of view. Default false. → stop -When the description is ambiguous, choose UNCERTAIN over guessing. When between two adjacent levels, choose the more conservative (lower) one — false negatives (missing onset by a timepoint) are cheap, false positives (burning photodose on noise) are expensive. +When the description is ambiguous, choose UNCERTAIN over guessing. When between two adjacent +levels, choose the more conservative (lower) one — false negatives (missing onset by a +timepoint) are cheap, false positives (burning photodose on noise) are expensive. 
Description to classify: --- @@ -100,10 +123,10 @@ class DopaminergicSignalDetector(Detector): def __init__( self, claude_client=None, - perceiver_model: Optional[str] = None, - classifier_model: Optional[str] = None, + perceiver_model: str | None = None, + classifier_model: str | None = None, # Back-compat: callers passing the old single-model kwarg. - model: Optional[str] = None, + model: str | None = None, ): self._claude = claude_client self._perceiver_model = perceiver_model or model @@ -113,11 +136,12 @@ def __init__( async def run( self, volume: np.ndarray, - context: Dict[str, Any], + context: dict[str, Any], ) -> DetectorResult: - from gently.settings import settings import anthropic + from gently.settings import settings + embryo_id = context.get("embryo_id", "?") timepoint = int(context.get("timepoint", 0)) start = time.time() @@ -155,8 +179,11 @@ async def run( detector_name=self.name, embryo_id=embryo_id, timepoint=timepoint, - findings={"intensity_level": "NONE", "structure_quality": "NONE", - "has_hatched": False}, + findings={ + "intensity_level": "NONE", + "structure_quality": "NONE", + "has_hatched": False, + }, reasoning="Empty / unreadable volume", elapsed_ms=(time.time() - start) * 1000, ) @@ -164,13 +191,17 @@ async def run( # Stage 1: Perceiver (image → prose) perceiver_model = self._perceiver_model or settings.models.perception description, perceiver_raw = await self._call_perceiver( - claude, perceiver_model, b64_image, + claude, + perceiver_model, + b64_image, ) # Stage 2: Classifier (prose → findings) classifier_model = self._classifier_model or settings.models.main findings, classifier_raw, parse_err = await self._call_classifier( - claude, classifier_model, description, + claude, + classifier_model, + description, ) return DetectorResult( @@ -190,7 +221,11 @@ async def run( error=parse_err, ) - except (anthropic.APIConnectionError, anthropic.RateLimitError, anthropic.APIStatusError) as e: + except ( + anthropic.APIConnectionError, + 
anthropic.RateLimitError, + anthropic.APIStatusError, + ) as e: logger.error("[%s] Claude API error for %s: %s", self.name, embryo_id, e) return DetectorResult( detector_name=self.name, @@ -210,35 +245,43 @@ async def run( ) async def _call_perceiver( - self, claude, model: str, b64_image: str, - ) -> Tuple[str, str]: + self, + claude, + model: str, + b64_image: str, + ) -> tuple[str, str]: """Stage 1: image → free prose description. Stateless — each timepoint is evaluated independently.""" response = await asyncio.to_thread( claude.messages.create, model=model, max_tokens=400, - messages=[{ - "role": "user", - "content": [ - {"type": "text", "text": _PERCEIVER_PROMPT}, - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": b64_image, + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": _PERCEIVER_PROMPT}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": b64_image, + }, }, - }, - ], - }], + ], + } + ], ) raw = response.content[0].text if response.content else "" return raw.strip(), raw async def _call_classifier( - self, claude, model: str, description: str, - ) -> Tuple[Dict[str, Any], str, Optional[str]]: + self, + claude, + model: str, + description: str, + ) -> tuple[dict[str, Any], str, str | None]: """Stage 2: description prose → structured findings.""" prompt = _CLASSIFIER_PROMPT.replace("{DESCRIPTION}", description or "(no description)") response = await asyncio.to_thread( @@ -252,7 +295,7 @@ async def _call_classifier( return findings, raw, parse_err -def _volume_to_b64(volume: np.ndarray, calibration: Optional[Dict] = None) -> Optional[str]: +def _volume_to_b64(volume: np.ndarray, calibration: dict | None = None) -> str | None: """Max-project a volume (with optional dark/flat correction + edge ROI) and return a base64-encoded PNG. 
@@ -298,9 +341,9 @@ def _volume_to_b64(volume: np.ndarray, calibration: Optional[Dict] = None) -> Op dark = calibration.get("dark") flat = calibration.get("flat") if dark is not None and flat is not None and dark.shape == proj.shape: - denom = (flat.astype(np.float32) - dark.astype(np.float32)) + denom = flat.astype(np.float32) - dark.astype(np.float32) denom[denom <= 0] = 1.0 - proj = ((proj.astype(np.float32) - dark.astype(np.float32)) / denom * 255.0) + proj = (proj.astype(np.float32) - dark.astype(np.float32)) / denom * 255.0 proj = np.clip(proj, 0, 255).astype(np.uint8) calibrated = True @@ -389,7 +432,7 @@ def _parse_response(raw: str) -> tuple: return dict(_DEFAULT_FINDINGS), f"could not parse response: {raw[:120]!r}" -def _normalize_findings(d: Dict[str, Any]) -> Dict[str, Any]: +def _normalize_findings(d: dict[str, Any]) -> dict[str, Any]: """Coerce a parsed JSON dict to the expected schema, filling defaults for missing keys and validating enum values. UNCERTAIN is a valid value for both intensity_level and structure_quality and downstream diff --git a/gently/app/detectors/hatching.py b/gently/app/detectors/hatching.py index df4e7c00..62dd31cb 100644 --- a/gently/app/detectors/hatching.py +++ b/gently/app/detectors/hatching.py @@ -10,7 +10,7 @@ import asyncio import logging import time -from typing import Any, Dict, Optional +from typing import Any import numpy as np @@ -20,7 +20,8 @@ logger = logging.getLogger(__name__) -_HATCHING_PROMPT = """You are observing a C. elegans embryo on a microscope. Decide whether the embryo has HATCHED. +_HATCHING_PROMPT = """You are observing a C. elegans embryo on a microscope. Decide whether +the embryo has HATCHED. 
A HATCHED embryo: - Has visibly broken out of the eggshell @@ -48,19 +49,22 @@ class HatchingDetector(Detector): name = "hatching" - def __init__(self, claude_client=None, model: Optional[str] = None): + def __init__(self, claude_client=None, model: str | None = None): self._claude = claude_client self._model = model async def run( self, volume: np.ndarray, - context: Dict[str, Any], + context: dict[str, Any], ) -> DetectorResult: - from gently.settings import settings - import json, re + import json + import re + import anthropic + from gently.settings import settings + embryo_id = context.get("embryo_id", "?") timepoint = int(context.get("timepoint", 0)) start = time.time() @@ -90,17 +94,22 @@ async def run( claude.messages.create, model=self._model or settings.models.fast, max_tokens=200, - messages=[{ - "role": "user", - "content": [ - {"type": "text", "text": _HATCHING_PROMPT}, - {"type": "image", "source": { - "type": "base64", - "media_type": "image/png", - "data": b64_image, - }}, - ], - }], + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": _HATCHING_PROMPT}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": b64_image, + }, + }, + ], + } + ], ) raw = response.content[0].text if response.content else "" @@ -131,7 +140,11 @@ async def run( error=err, ) - except (anthropic.APIConnectionError, anthropic.RateLimitError, anthropic.APIStatusError) as e: + except ( + anthropic.APIConnectionError, + anthropic.RateLimitError, + anthropic.APIStatusError, + ) as e: logger.error("[%s] Claude API error for %s: %s", self.name, embryo_id, e) return DetectorResult( detector_name=self.name, diff --git a/gently/app/detectors/perception_proxy.py b/gently/app/detectors/perception_proxy.py index 42962715..8cb501c6 100644 --- a/gently/app/detectors/perception_proxy.py +++ b/gently/app/detectors/perception_proxy.py @@ -10,7 +10,7 @@ import logging import time -from typing import Any, Dict, Optional +from 
typing import Any import numpy as np @@ -35,7 +35,7 @@ def perceiver(self): async def run( self, volume: np.ndarray, - context: Dict[str, Any], + context: dict[str, Any], ) -> DetectorResult: embryo_id = context.get("embryo_id", "?") timepoint = int(context.get("timepoint", 0)) @@ -66,8 +66,12 @@ async def run( try: from datetime import datetime + result = await perceiver( - embryo_id, timepoint, b64_image, datetime.now().isoformat(), + embryo_id, + timepoint, + b64_image, + datetime.now().isoformat(), ) except Exception as e: logger.exception("[%s] perceiver error for %s", self.name, embryo_id) diff --git a/gently/app/detectors/registry.py b/gently/app/detectors/registry.py index 35613623..219144f8 100644 --- a/gently/app/detectors/registry.py +++ b/gently/app/detectors/registry.py @@ -6,36 +6,39 @@ instance here for each acquired volume. """ -from typing import Any, Callable, Dict, Optional +from collections.abc import Callable from .base import Detector - # Factory signature: (claude_client=None, perceiver=None) -> Detector DetectorFactory = Callable[..., Detector] def _make_dopaminergic(*, claude_client=None, perceiver=None, **_) -> Detector: from .dopaminergic_signal import DopaminergicSignalDetector + return DopaminergicSignalDetector(claude_client=claude_client) def _make_hatching(*, claude_client=None, perceiver=None, **_) -> Detector: from .hatching import HatchingDetector + return HatchingDetector(claude_client=claude_client) def _make_blank(*, claude_client=None, perceiver=None, **_) -> Detector: from .blank_image import BlankImageDetector + return BlankImageDetector(claude_client=claude_client) def _make_perception(*, claude_client=None, perceiver=None, **_) -> Detector: from .perception_proxy import PerceptionProxy + return PerceptionProxy(perceiver=perceiver) -DETECTOR_REGISTRY: Dict[str, DetectorFactory] = { +DETECTOR_REGISTRY: dict[str, DetectorFactory] = { "dopaminergic_signal": _make_dopaminergic, "hatching": _make_hatching, "blank_image": 
_make_blank, @@ -44,11 +47,11 @@ def _make_perception(*, claude_client=None, perceiver=None, **_) -> Detector: def get_detector( - name: Optional[str], + name: str | None, *, claude_client=None, perceiver=None, -) -> Optional[Detector]: +) -> Detector | None: """Return a Detector instance for ``name``, or None if unknown. Unknown / None names return None so the orchestrator can choose diff --git a/gently/app/developmental_tracker.py b/gently/app/developmental_tracker.py index 64c47773..ef69f4d5 100644 --- a/gently/app/developmental_tracker.py +++ b/gently/app/developmental_tracker.py @@ -6,12 +6,12 @@ """ from gently.organisms.celegans.developmental_tracker import ( # noqa: F401 - DevelopmentalStage, + STAGE_CLASSIFICATION_PROMPT, STAGE_TIMING_20C, TIME_TO_HATCHING, TIMING_VARIABILITY, + DevelopmentalStage, + DevelopmentalTracker, HatchingPrediction, StageClassification, - STAGE_CLASSIFICATION_PROMPT, - DevelopmentalTracker, ) diff --git a/gently/app/device_state_monitor.py b/gently/app/device_state_monitor.py index beff9a54..8231a9c2 100644 --- a/gently/app/device_state_monitor.py +++ b/gently/app/device_state_monitor.py @@ -14,9 +14,11 @@ Watchdog -------- -The SSE iterator can silently stall in the agent process — most reliably -when a Qt window (napari) freezes the asyncio loop synchronously during a -tool call, but in principle any half-open TCP path can cause it. aiohttp's +The SSE iterator can silently stall in the agent process whenever a +half-open TCP path or a long synchronous tool call wedges the asyncio loop. +(Historically the worst offender was a Qt window — napari — blocking the +loop during a tool call; that path is gone now that all visualization is +in-browser, but the watchdog stays for general robustness.) aiohttp's async iterator won't raise on a stalled socket; the ``async for`` just waits forever. 
To recover, a sibling watchdog task tracks the timestamp of the last received event; if no event arrives within ``stale_timeout_sec`` @@ -30,7 +32,7 @@ import asyncio import logging import time -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from gently.core.event_bus import EventType, get_event_bus from gently.core.service import Service @@ -67,7 +69,7 @@ class DeviceStateMonitor(Service): def __init__( self, - microscope: "DiSPIMMicroscope", + microscope: DiSPIMMicroscope, reconnect_delay_sec: float = 3.0, stale_timeout_sec: float = DEFAULT_STALE_TIMEOUT_SEC, watchdog_interval_sec: float = DEFAULT_WATCHDOG_INTERVAL_SEC, @@ -77,14 +79,14 @@ def __init__( self.reconnect_delay_sec = reconnect_delay_sec self.stale_timeout_sec = stale_timeout_sec self.watchdog_interval_sec = watchdog_interval_sec - self._task: Optional[asyncio.Task] = None - self._watchdog_task: Optional[asyncio.Task] = None + self._task: asyncio.Task | None = None + self._watchdog_task: asyncio.Task | None = None self._stop_requested = False # Monotonic timestamp of the last successfully-received event. The # watchdog reads this; the reader writes it under no lock because # asyncio is single-threaded (datetime/float assignment is atomic # at the bytecode level on CPython). - self._last_event_at: Optional[float] = None + self._last_event_at: float | None = None # Counts of staleness-triggered reconnects, useful for diagnostics. 
self._watchdog_kicks: int = 0 # Set by the watchdog right before it cancels the reader task, so @@ -100,7 +102,8 @@ async def on_start(self): self._last_event_at = time.monotonic() self._task = asyncio.create_task(self._run(), name="device-state-monitor") self._watchdog_task = asyncio.create_task( - self._watchdog(), name="device-state-watchdog", + self._watchdog(), + name="device-state-watchdog", ) async def on_stop(self): @@ -159,7 +162,8 @@ async def _run(self): except Exception as exc: logger.debug( "DeviceStateMonitor: stream ended (%s) — reconnecting in %.1fs", - exc, self.reconnect_delay_sec, + exc, + self.reconnect_delay_sec, ) if self._stop_requested: break @@ -196,7 +200,9 @@ async def _watchdog(self): logger.warning( "DeviceStateMonitor: stale stream (%.1fs since last event > " "%.1fs threshold) — forcing reconnect (kick #%d)", - age, self.stale_timeout_sec, self._watchdog_kicks, + age, + self.stale_timeout_sec, + self._watchdog_kicks, ) # Reset the timer FIRST so we don't trigger again before the # reader has a chance to reconnect and publish. 
diff --git a/gently/app/orchestration/exclusive.py b/gently/app/orchestration/exclusive.py index 44dc8140..db9804bd 100644 --- a/gently/app/orchestration/exclusive.py +++ b/gently/app/orchestration/exclusive.py @@ -16,11 +16,10 @@ import asyncio import logging from abc import ABC, abstractmethod -from collections import deque from dataclasses import dataclass, field from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import numpy as np @@ -30,14 +29,15 @@ @dataclass class ExclusiveResult: """Outcome of an ExclusiveAcquisition.run().""" + success: bool target_embryo_id: str request_id: str frames_captured: int = 0 duration_s: float = 0.0 - output_path: Optional[str] = None - extra: Dict[str, Any] = field(default_factory=dict) - error: Optional[str] = None + output_path: str | None = None + extra: dict[str, Any] = field(default_factory=dict) + error: str | None = None class ExclusiveAcquisition(ABC): @@ -51,17 +51,17 @@ class ExclusiveAcquisition(ABC): #: Human-readable kind name, for events / UI / persistence. kind: str = "exclusive" - def __init__(self, target_embryo_id: str, request_id: Optional[str] = None): + def __init__(self, target_embryo_id: str, request_id: str | None = None): self.target_embryo_id = target_embryo_id self.request_id = request_id or _make_request_id(self.kind, target_embryo_id) @abstractmethod - async def run(self, orchestrator) -> ExclusiveResult: - ... + async def run(self, orchestrator) -> ExclusiveResult: ... 
def _make_request_id(kind: str, embryo_id: str) -> str: import uuid + return f"{kind}_{embryo_id}_{uuid.uuid4().hex[:8]}" @@ -97,7 +97,7 @@ def __init__( frames: int = 60, mode: str = "1hz", num_slices: int = 1, - request_id: Optional[str] = None, + request_id: str | None = None, ): super().__init__(target_embryo_id=target_embryo_id, request_id=request_id) self.frames = frames @@ -118,12 +118,15 @@ async def run(self, orchestrator) -> ExclusiveResult: # Phase 10 dedicated burst events (Phase 7 originally rode # STATUS_CHANGED; now we have first-class types). - orchestrator._emit_event(EventType.BURST_START, { - "embryo_id": self.target_embryo_id, - "request_id": self.request_id, - "frames": self.frames, - "mode": self.mode, - }) + orchestrator._emit_event( + EventType.BURST_START, + { + "embryo_id": self.target_embryo_id, + "request_id": self.request_id, + "frames": self.frames, + "mode": self.mode, + }, + ) # Hardware kwargs from the embryo's calibration (mirrors _acquire_embryo) cal = embryo.calibration or {} @@ -171,7 +174,7 @@ async def run(self, orchestrator) -> ExclusiveResult: pass frames_data = result.get("frames") or [] - frames_captured: List[np.ndarray] = [ + frames_captured: list[np.ndarray] = [ np.asarray(f["volume"]) for f in frames_data if f.get("volume") is not None ] @@ -213,14 +216,17 @@ async def run(self, orchestrator) -> ExclusiveResult: success = bool(result.get("success")) and len(frames_captured) > 0 - orchestrator._emit_event(EventType.BURST_COMPLETE, { - "embryo_id": self.target_embryo_id, - "request_id": self.request_id, - "frames_captured": len(frames_captured), - "duration_s": duration_s, - "sustained_hz": sustained_hz, - "mp4_path": mp4_path, - }) + orchestrator._emit_event( + EventType.BURST_COMPLETE, + { + "embryo_id": self.target_embryo_id, + "request_id": self.request_id, + "frames_captured": len(frames_captured), + "duration_s": duration_s, + "sustained_hz": sustained_hz, + "mp4_path": mp4_path, + }, + ) return ExclusiveResult( 
success=success, @@ -248,16 +254,19 @@ async def _progress_ticker(self, orchestrator, frame_event_type): await asyncio.sleep(tick_interval) elapsed = (datetime.now() - start).total_seconds() approx_idx = min(self.frames - 1, int(elapsed / target_dt)) - orchestrator._emit_event(frame_event_type, { - "embryo_id": self.target_embryo_id, - "request_id": self.request_id, - "frame_idx": approx_idx, - "total_frames": self.frames, - "approximate": True, - }) + orchestrator._emit_event( + frame_event_type, + { + "embryo_id": self.target_embryo_id, + "request_id": self.request_id, + "frame_idx": approx_idx, + "total_frames": self.frames, + "approximate": True, + }, + ) -def _resolve_burst_dir(orchestrator, embryo_id: str, request_id: str) -> Optional[Path]: +def _resolve_burst_dir(orchestrator, embryo_id: str, request_id: str) -> Path | None: """Return ``bursts/{request_id}/`` under the embryo's session folder. Uses ``FileStore._session_dir`` so the short session_id resolves to the full @@ -269,7 +278,7 @@ def _resolve_burst_dir(orchestrator, embryo_id: str, request_id: str) -> Optiona sid = getattr(orchestrator, "_session_id", None) if store is None or sid is None: return None - session_dir: Optional[Path] = None + session_dir: Path | None = None for attr in ("_session_dir", "session_dir"): fn = getattr(store, attr, None) if callable(fn): @@ -281,7 +290,10 @@ def _resolve_burst_dir(orchestrator, embryo_id: str, request_id: str) -> Optiona break if session_dir is None: # Last-resort fallback: previous behaviour (will write to the shadow folder). 
- logger.warning("FileStore has no session_dir resolver; falling back to root/sessions/%s", sid) + logger.warning( + "FileStore has no session_dir resolver; falling back to root/sessions/%s", + sid, + ) session_dir = store.root / "sessions" / sid burst_dir = session_dir / "embryos" / embryo_id / "bursts" / request_id @@ -297,7 +309,7 @@ def _persist_burst_to_disk( request_id: str, mode: str, frames_requested: int, - frames_data: List[Dict[str, Any]], + frames_data: list[dict[str, Any]], loop_start: datetime, duration_s: float, sustained_hz: float, @@ -305,8 +317,8 @@ def _persist_burst_to_disk( galvo_center: float, piezo_amplitude: float, piezo_center: float, - laser_power_488_pct: Optional[float], -) -> Optional[Path]: + laser_power_488_pct: float | None, +) -> Path | None: """Save per-frame TIFFs + meta + projections + a burst.yaml manifest. Best-effort: any per-frame failure is logged and skipped, the rest still @@ -336,7 +348,7 @@ def _persist_burst_to_disk( pos = getattr(embryo, "stage_position", {}) or {} sid = getattr(orchestrator, "_session_id", None) - saved_frames: List[Dict[str, Any]] = [] + saved_frames: list[dict[str, Any]] = [] for i, fr in enumerate(frames_data, start=1): vol = fr.get("volume") if vol is None: @@ -367,6 +379,7 @@ def _persist_burst_to_disk( # Projection via the same helper used for regular volumes. 
try: from gently.core.imaging import generate_jpeg_projection + generate_jpeg_projection(arr, proj_path) except Exception as exc: logger.debug("[%s] burst frame %d projection failed: %s", embryo_id, i, exc) @@ -380,8 +393,12 @@ def _persist_burst_to_disk( "dtype": str(arr.dtype), "acquired_at": acquired_at, "metadata": { - "num_slices": int(getattr(embryo, "num_slices", 1)) if hasattr(embryo, "num_slices") else None, - "exposure_ms": float(getattr(embryo, "exposure_ms", 0.0)) if hasattr(embryo, "exposure_ms") else None, + "num_slices": int(getattr(embryo, "num_slices", 1)) + if hasattr(embryo, "num_slices") + else None, + "exposure_ms": float(getattr(embryo, "exposure_ms", 0.0)) + if hasattr(embryo, "exposure_ms") + else None, "acquisition_mode": "burst", "burst_mode": mode, "laser_power_488_pct": laser_power_488_pct, @@ -394,12 +411,14 @@ def _persist_burst_to_disk( _yaml.safe_dump(meta, f, sort_keys=False) except Exception as exc: logger.debug("[%s] burst frame %d meta write failed: %s", embryo_id, i, exc) - saved_frames.append({ - "frame_index": i, - "tif": tif_path.name, - "projection": f"projections/{proj_path.name}", - "acquired_at": acquired_at, - }) + saved_frames.append( + { + "frame_index": i, + "tif": tif_path.name, + "projection": f"projections/{proj_path.name}", + "acquired_at": acquired_at, + } + ) manifest = { "request_id": request_id, @@ -428,18 +447,23 @@ def _persist_burst_to_disk( except Exception as exc: logger.warning("[%s] burst manifest write failed: %s", embryo_id, exc) - logger.info("[%s] persisted %d/%d burst frames -> %s", - embryo_id, len(saved_frames), frames_requested, burst_dir) + logger.info( + "[%s] persisted %d/%d burst frames -> %s", + embryo_id, + len(saved_frames), + frames_requested, + burst_dir, + ) return burst_dir async def _maybe_generate_mp4( *, - burst_dir: Optional[Path], + burst_dir: Path | None, embryo_id: str, request_id: str, - frames: List[np.ndarray], -) -> Optional[str]: + frames: list[np.ndarray], +) -> str | 
None: """Best-effort MP4 generation using OpenCV's VideoWriter. Mirrors the codec-fallback pattern in :mod:`gently.app.video_maker` @@ -461,7 +485,7 @@ async def _maybe_generate_mp4( # Reduce 3D frames to 2D max-projections, normalize to uint8, and # convert to 3-channel BGR for VideoWriter. - proj_frames: List[np.ndarray] = [] + proj_frames: list[np.ndarray] = [] for f in frames: v = np.squeeze(f) if v.ndim == 4: @@ -483,10 +507,10 @@ async def _maybe_generate_mp4( height, width = proj_frames[0].shape[:2] codecs = ( - ('mp4v', '.mp4'), - ('avc1', '.mp4'), - ('XVID', '.avi'), - ('MJPG', '.avi'), + ("mp4v", ".mp4"), + ("avc1", ".mp4"), + ("XVID", ".avi"), + ("MJPG", ".avi"), ) writer = None chosen_codec = None @@ -511,7 +535,9 @@ async def _maybe_generate_mp4( writer.release() logger.info( "Wrote burst movie: %s (%d frames, codec=%s)", - chosen_path, len(proj_frames), chosen_codec, + chosen_path, + len(proj_frames), + chosen_codec, ) return str(chosen_path) except Exception as e: diff --git a/gently/app/orchestration/monitoring_modes.py b/gently/app/orchestration/monitoring_modes.py index 90f5e555..f84b034e 100644 --- a/gently/app/orchestration/monitoring_modes.py +++ b/gently/app/orchestration/monitoring_modes.py @@ -18,8 +18,8 @@ Activate a mode via ``TimelapseOrchestrator.enable_monitoring_mode(name)``. """ +from collections.abc import Callable from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Optional @dataclass @@ -31,11 +31,12 @@ class MonitoringMode: the active modes so the UI / persistence layer can show what anticipation logic is in play. 
""" + name: str = "" description: str = "" - applies_to_roles: List[str] = field(default_factory=list) + applies_to_roles: list[str] = field(default_factory=list) - def activate(self, orchestrator, embryo_ids: Optional[List[str]] = None): + def activate(self, orchestrator, embryo_ids: list[str] | None = None): """Install the mode's rules / detectors on the orchestrator.""" raise NotImplementedError @@ -54,6 +55,7 @@ class ExpressionMonitoringMode(MonitoringMode): Burst on good structure is handled by Phase 7's BurstAcquisition. """ + fast_interval: float = 60.0 onset_confirm_timepoints: int = 2 rampdown_step_pct: float = 1.0 @@ -80,7 +82,7 @@ def __post_init__(self): if not self.applies_to_roles: self.applies_to_roles = ["test"] - def activate(self, orchestrator, embryo_ids: Optional[List[str]] = None): + def activate(self, orchestrator, embryo_ids: list[str] | None = None): orchestrator.add_test_onset_speedup( fast_interval=self.fast_interval, confirm_timepoints=self.onset_confirm_timepoints, @@ -108,6 +110,7 @@ class PreTerminalMonitoringMode(MonitoringMode): Wraps ``enable_pre_hatching_speedup`` declaratively. Uses the organism's PRE_TERMINAL_SPEEDUP_STAGE as the trigger. """ + fast_interval: float = 30.0 def __post_init__(self): @@ -119,7 +122,7 @@ def __post_init__(self): f"to {self.fast_interval}s on detection." ) - def activate(self, orchestrator, embryo_ids: Optional[List[str]] = None): + def activate(self, orchestrator, embryo_ids: list[str] | None = None): orchestrator.add_pre_terminal_speedup(fast_interval=self.fast_interval) @@ -133,20 +136,20 @@ def __post_init__(self): if not self.description: self.description = "No active anticipation; standard timelapse cadence." - def activate(self, orchestrator, embryo_ids: Optional[List[str]] = None): + def activate(self, orchestrator, embryo_ids: list[str] | None = None): pass # Public registry. ``enable_monitoring_mode(name)`` on the orchestrator # looks up by key. 
-MONITORING_MODES: Dict[str, Callable[[], MonitoringMode]] = { +MONITORING_MODES: dict[str, Callable[[], MonitoringMode]] = { "idle": IdleMode, "expression_monitoring": ExpressionMonitoringMode, "pre_terminal_monitoring": PreTerminalMonitoringMode, } -def get_monitoring_mode(name: str) -> Optional[MonitoringMode]: +def get_monitoring_mode(name: str) -> MonitoringMode | None: """Return an instance of the named monitoring mode, or None if unknown.""" factory = MONITORING_MODES.get(name) if factory is None: diff --git a/gently/app/orchestration/timelapse.py b/gently/app/orchestration/timelapse.py index 28ae86b8..f4a2692a 100644 --- a/gently/app/orchestration/timelapse.py +++ b/gently/app/orchestration/timelapse.py @@ -11,37 +11,38 @@ import asyncio import json import logging +import traceback from collections import deque +from collections.abc import Callable from datetime import datetime, timedelta from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING -import traceback +from typing import TYPE_CHECKING, Any, Optional import numpy as np from gently.core import EventType, get_event_bus from gently.core.imaging import ( - projection_three_view, - compute_crop_bounds, apply_crop_bounds, + compute_crop_bounds, image_to_base64, normalize_to_uint8, + projection_three_view, ) -from gently.settings import settings from gently.harness.error_log import GlobalErrorLog +from gently.harness.state import EmbryoState from gently.organisms import get_organism +from gently.settings import settings # Re-export models for backward compatibility from .timelapse_models import ( - StopConditionType, + BurstRule, IntervalRule, PowerRule, - BurstRule, StopCondition, - TimelapseStatus, + StopConditionType, TimelapseState, + TimelapseStatus, ) -from gently.harness.state import EmbryoState if TYPE_CHECKING: from gently.core.file_store import FileStore @@ -69,8 +70,8 @@ def __init__( microscope_client, experiment_state, perceiver=None, - 
on_volume_callback: Optional[Callable] = None, - session_id: Optional[str] = None, + on_volume_callback: Callable | None = None, + session_id: str | None = None, store: Optional["FileStore"] = None, claude_client=None, ): @@ -103,11 +104,11 @@ def __init__( # Trace file storage (writes JSON files to disk) self._session_id = session_id - self._trace_dir: Optional[Path] = None + self._trace_dir: Path | None = None # Unified store for perception persistence self._store = store - self._perception_run_id: Optional[int] = None + self._perception_run_id: int | None = None # Event bus for status updates self._event_bus = get_event_bus() @@ -116,20 +117,20 @@ def __init__( # Holds references to EmbryoState objects from self.experiment.embryos # for embryos currently active in this timelapse. Same objects, no # duplication of state. - self._embryo_states: Dict[str, EmbryoState] = {} + self._embryo_states: dict[str, EmbryoState] = {} self._status = TimelapseStatus.IDLE - self._started_at: Optional[datetime] = None + self._started_at: datetime | None = None self._total_timepoints = 0 self._current_round = 0 - self._error_message: Optional[str] = None + self._error_message: str | None = None # Round-based scheduling (global timing for all embryos) self._base_interval_seconds: float = 120.0 self._total_pause_duration: timedelta = timedelta(0) - self._pause_start: Optional[datetime] = None + self._pause_start: datetime | None = None # Control - self._acquisition_task: Optional[asyncio.Task] = None + self._acquisition_task: asyncio.Task | None = None self._stop_requested = False # In-flight perception tasks. Acquisition fires perception tasks @@ -138,59 +139,63 @@ def __init__( # not VLM API latency. Tasks are tracked here so (a) stop() can # await them before returning and (b) we can cap concurrency if # the Claude API rate limit becomes a problem. 
- self._perception_tasks: Set[asyncio.Task] = set() + self._perception_tasks: set[asyncio.Task] = set() # Interval adjustment rules - self._interval_rules: List[IntervalRule] = [] - self._applied_rules: Dict[str, Set[str]] = {} # embryo_id -> set of applied rule names - self._interval_rule_consecutive: Dict[str, Dict[str, int]] = {} # embryo_id -> {rule_name: consecutive matches} + self._interval_rules: list[IntervalRule] = [] + self._applied_rules: dict[str, set[str]] = {} # embryo_id -> set of applied rule names + self._interval_rule_consecutive: dict[ + str, dict[str, int] + ] = {} # embryo_id -> {rule_name: consecutive matches} # Phase 5 reactive control: per-line power rules (sticky-downward # ramp on saturation, etc). Evaluated alongside _interval_rules in # _check_adaptive_rules / _check_interval_rules. - self._power_rules: List[PowerRule] = [] - self._power_rule_consecutive: Dict[str, Dict[str, int]] = {} # embryo_id -> {rule_name: count} + self._power_rules: list[PowerRule] = [] + self._power_rule_consecutive: dict[ + str, dict[str, int] + ] = {} # embryo_id -> {rule_name: count} # Auto-burst rules: fire queue_burst() once per embryo on a # structure/intensity predicate match. One-time semantics are # enforced by queue_burst via _burst_applied. - self._burst_rules: List[Any] = [] # List[BurstRule] without forward import - self._burst_rule_consecutive: Dict[str, Dict[str, int]] = {} + self._burst_rules: list[Any] = [] # List[BurstRule] without forward import + self._burst_rule_consecutive: dict[str, dict[str, int]] = {} # Active monitoring modes (declarative anticipation bundles). - self._active_monitoring_modes: List[Any] = [] + self._active_monitoring_modes: list[Any] = [] # Phase 6: calibration data from CalibrationEmbryos. Populated by # ``run_calibration_pipelines()``; consumed by _run_detector as # detector context["calibration"]. 
- self._calibration_data: Optional[Dict[str, Any]] = None + self._calibration_data: dict[str, Any] | None = None # Phase 8: per-role photodose budgets. ``base_dose_budget_ms`` is the # ceiling for a 1× role (e.g. test). Other roles get this scaled by # their EmbryoRole.photodose_budget_multiplier (e.g. calibration # gets 10×). None = no enforcement. See ``set_photodose_budget``. - self._dose_budget_base_ms: Optional[float] = None + self._dose_budget_base_ms: float | None = None # Embryos that have hit the budget — set to paused once, emit a # STATUS_CHANGED, then leave alone (don't spam). - self._dose_budget_exceeded: Set[str] = set() + self._dose_budget_exceeded: set[str] = set() # Async cadence state (Phase 4). Single-burst exclusion handled by # _burst_in_progress: while set to an embryo_id, _run_loop skips # all other embryos and runs only the burst executor (Phase 7). - self._burst_in_progress: Optional[str] = None + self._burst_in_progress: str | None = None # Phase 7: exclusive acquisitions (bursts, etc) — FIFO queue. self._exclusive_queue: deque = deque() # Track which embryos have already had a burst applied # (one-time per embryo by default). - self._burst_applied: Set[str] = set() + self._burst_applied: set[str] = set() # Global error log for cross-embryo hardware error correlation self.global_error_log = GlobalErrorLog() async def start( self, - embryo_ids: List[str], + embryo_ids: list[str] | None = None, stop_condition: str = "manual", base_interval_seconds: float = 120.0, condition_value: Any = None, @@ -218,7 +223,10 @@ async def start( Status message """ if self._status == TimelapseStatus.RUNNING: - return "Timelapse already running. Use stop() first or modify_embryo() to change parameters." + return ( + "Timelapse already running. Use stop() first or modify_embryo() to change" + " parameters." 
+ ) # Parse stop condition stop_cond = self._parse_stop_condition(stop_condition, condition_value) @@ -286,7 +294,10 @@ async def start( method="vlm_stage_classification", model_name=settings.models.perception, source="live", - config={"stop_condition": stop_condition, "interval": base_interval_seconds}, + config={ + "stop_condition": stop_condition, + "interval": base_interval_seconds, + }, ) logger.info(f"Created perception run {self._perception_run_id} in FileStore") except Exception as e: @@ -309,11 +320,14 @@ async def start( self._acquisition_task = asyncio.create_task(self._run_loop()) # Emit event - self._emit_event(EventType.ACQUISITION_STARTED, { - 'embryo_ids': embryo_ids, - 'stop_condition': stop_condition, - 'interval_seconds': base_interval_seconds, - }) + self._emit_event( + EventType.ACQUISITION_STARTED, + { + "embryo_ids": embryo_ids, + "stop_condition": stop_condition, + "interval_seconds": base_interval_seconds, + }, + ) logger.info(f"Started timelapse for {len(embryo_ids)} embryos") @@ -324,11 +338,7 @@ async def start( f"Use get_timelapse_status to monitor progress." 
) - def _parse_stop_condition( - self, - condition: str, - value: Any = None - ) -> StopCondition: + def _parse_stop_condition(self, condition: str, value: Any = None) -> StopCondition: """Parse stop condition string into StopCondition object""" condition = condition.lower().strip() @@ -399,31 +409,38 @@ def _is_eligible(self, embryo) -> bool: if self._dose_budget_base_ms is not None and embryo.id not in self._dose_budget_exceeded: from gently.harness.roles import REGISTRY as _ROLE_REGISTRY + role_def = _ROLE_REGISTRY.get(getattr(embryo, "role", "test")) mult = role_def.photodose_budget_multiplier if role_def else 1.0 budget = self._dose_budget_base_ms * mult if embryo.total_exposure_ms > budget: embryo.cadence_phase = "paused" self._dose_budget_exceeded.add(embryo.id) - self._emit_event(EventType.STATUS_CHANGED, { - "embryo_id": embryo.id, - "change": "photodose_budget_exceeded", - "role": embryo.role, - "total_exposure_ms": embryo.total_exposure_ms, - "budget_ms": budget, - "multiplier": mult, - }) + self._emit_event( + EventType.STATUS_CHANGED, + { + "embryo_id": embryo.id, + "change": "photodose_budget_exceeded", + "role": embryo.role, + "total_exposure_ms": embryo.total_exposure_ms, + "budget_ms": budget, + "multiplier": mult, + }, + ) logger.warning( "[%s] photodose budget exceeded: %.0f ms > %.0f ms " "(role=%s, mult=%.1fx). Pausing.", - embryo.id, embryo.total_exposure_ms, budget, - embryo.role, mult, + embryo.id, + embryo.total_exposure_ms, + budget, + embryo.role, + mult, ) return False return True - def set_photodose_budget(self, base_dose_budget_ms: Optional[float]) -> str: + def set_photodose_budget(self, base_dose_budget_ms: float | None) -> str: """Set the per-role photodose ceiling. Each embryo's ``total_exposure_ms`` is checked against @@ -436,10 +453,7 @@ def set_photodose_budget(self, base_dose_budget_ms: Optional[float]) -> str: self._dose_budget_exceeded.clear() if base_dose_budget_ms is None: return "Photodose budget enforcement disabled." 
- return ( - f"Photodose budget set: {base_dose_budget_ms:.0f} ms base " - f"(scaled per role)." - ) + return f"Photodose budget set: {base_dose_budget_ms:.0f} ms base (scaled per role)." def _pick_next_due(self) -> tuple: """Pick the most-overdue eligible embryo. @@ -486,10 +500,10 @@ def transition_cadence( self, embryo, *, - new_phase: Optional[str] = None, - new_interval_seconds: Optional[float] = None, + new_phase: str | None = None, + new_interval_seconds: float | None = None, reschedule: bool = True, - reason: Optional[str] = None, + reason: str | None = None, ) -> None: """Public API for cadence transitions (Phase 5 / agent tools use this). @@ -514,23 +528,25 @@ def transition_cadence( if reschedule: self._reschedule(embryo) - self._emit_event(EventType.EMBRYO_CADENCE_CHANGED, { - "embryo_id": embryo.id, - "old_phase": old_phase, - "new_phase": embryo.cadence_phase, - "old_interval_s": old_interval, - "new_interval_s": embryo.interval_seconds, - "next_due_at": embryo.next_due_at.isoformat() if embryo.next_due_at else None, - "reason": reason, - }) + self._emit_event( + EventType.EMBRYO_CADENCE_CHANGED, + { + "embryo_id": embryo.id, + "old_phase": old_phase, + "new_phase": embryo.cadence_phase, + "old_interval_s": old_interval, + "new_interval_s": embryo.interval_seconds, + "next_due_at": embryo.next_due_at.isoformat() if embryo.next_due_at else None, + "reason": reason, + }, + ) async def _finalize_timelapse(self): """Drain perception tasks, log trace count, emit completion event.""" if self._perception_tasks: pending = list(self._perception_tasks) logger.info( - f"Draining {len(pending)} perception task(s) before " - f"completing timelapse..." + f"Draining {len(pending)} perception task(s) before completing timelapse..." 
) _done, still_pending = await asyncio.wait(pending, timeout=60.0) if still_pending: @@ -549,13 +565,17 @@ async def _finalize_timelapse(self): self._finalize_perception_run("completed") - self._emit_event(EventType.ACQUISITION_COMPLETED, { - "total_timepoints": self._total_timepoints, - "duration_minutes": ( - (datetime.now() - self._started_at).total_seconds() / 60 - if self._started_at else 0 - ), - }) + self._emit_event( + EventType.ACQUISITION_COMPLETED, + { + "total_timepoints": self._total_timepoints, + "duration_minutes": ( + (datetime.now() - self._started_at).total_seconds() / 60 + if self._started_at + else 0 + ), + }, + ) async def _run_loop(self): """Async per-embryo acquisition loop (Phase 4). @@ -591,7 +611,10 @@ async def _run_loop(self): for eid, e in self._embryo_states.items(): if eid == next_op.target_embryo_id: continue - if getattr(e, "cadence_phase", "normal") not in ("burst", "paused"): + if getattr(e, "cadence_phase", "normal") not in ( + "burst", + "paused", + ): e.cadence_phase = "paused" paused_ids.append(eid) # Bursting embryo's phase reflects what it's doing. @@ -605,7 +628,9 @@ async def _run_loop(self): except Exception as e: logger.error( "Exclusive op %s failed: %s", - next_op.request_id, e, exc_info=True, + next_op.request_id, + e, + exc_info=True, ) finally: self._burst_in_progress = None @@ -631,6 +656,7 @@ async def _run_loop(self): # interval; rely on it to reschedule the embryo. if target_emb is not None and target_emb.cadence_phase == "burst": from gently.harness.roles import REGISTRY as _ROLE_REGISTRY + role_def = _ROLE_REGISTRY.get(target_emb.role) new_interval = ( role_def.default_cadence_seconds @@ -647,7 +673,8 @@ async def _run_loop(self): # All embryos complete? 
active_count = sum( - 1 for e in self._embryo_states.values() + 1 + for e in self._embryo_states.values() if not e.is_complete and not e.should_skip ) if active_count == 0: @@ -705,11 +732,14 @@ async def _run_loop(self): self._status = TimelapseStatus.FAILED self._error_message = str(e) self._finalize_perception_run("failed", error_message=str(e)) - self._emit_event(EventType.ACQUISITION_FAILED, { - 'error': str(e), - }) + self._emit_event( + EventType.ACQUISITION_FAILED, + { + "error": str(e), + }, + ) - async def _acquire_embryo(self, embryo_state: EmbryoState, round_time: datetime = None): + async def _acquire_embryo(self, embryo_state: EmbryoState, round_time: datetime | None = None): """Acquire a single volume for one embryo Parameters @@ -730,20 +760,20 @@ async def _acquire_embryo(self, embryo_state: EmbryoState, round_time: datetime try: # Move to embryo position pos = embryo.stage_position - if pos and pos.get('x') is not None: - await self.client.move_to_position(pos['x'], pos['y']) + if pos and pos.get("x") is not None: + await self.client.move_to_position(pos["x"], pos["y"]) # Get calibration parameters cal = embryo.calibration or {} - galvo_amplitude = cal.get('galvo_amplitude', 0.5) - galvo_center = cal.get('galvo_center', 0.0) - piezo_amplitude = cal.get('piezo_amplitude', 25.0) - piezo_center = cal.get('piezo_center', 50.0) + galvo_amplitude = cal.get("galvo_amplitude", 0.5) + galvo_center = cal.get("galvo_center", 0.0) + piezo_amplitude = cal.get("piezo_amplitude", 25.0) + piezo_center = cal.get("piezo_center", 50.0) # Acquire based on mode (volume or snap) - acquisition_mode = getattr(embryo, 'acquisition_mode', 'volume') + acquisition_mode = getattr(embryo, "acquisition_mode", "volume") - if acquisition_mode == 'snap': + if acquisition_mode == "snap": # Single 2D lightsheet image result = await self.client.capture_lightsheet_image( piezo_position=piezo_center, @@ -760,12 +790,12 @@ async def _acquire_embryo(self, embryo_state: EmbryoState, 
round_time: datetime galvo_center=galvo_center, piezo_amplitude=piezo_amplitude, piezo_center=piezo_center, - laser_power_488_pct=getattr(embryo, 'laser_power_488_pct', None), + laser_power_488_pct=getattr(embryo, "laser_power_488_pct", None), ) num_frames = embryo.num_slices exposure_ms = embryo.exposure_ms - if result.get('success'): + if result.get("success"): # Update state embryo_state.timepoints_acquired += 1 embryo_state.error_count = 0 @@ -780,7 +810,7 @@ async def _acquire_embryo(self, embryo_state: EmbryoState, round_time: datetime embryo.record_exposure( exposure_ms=exposure_ms, num_frames=num_frames, - timestamp=acquisition_timestamp + timestamp=acquisition_timestamp, ) # Note: VOLUME_ACQUIRED event is emitted by the callback (agent.on_volume_acquired) @@ -791,17 +821,21 @@ async def _acquire_embryo(self, embryo_state: EmbryoState, round_time: datetime volume_uids = None # Track UIDs from storage for perception events if self.on_volume_callback: # Get data - 'volume' for volume mode, 'image' for snap mode - data = result.get('volume') if acquisition_mode == 'volume' else result.get('image') + data = ( + result.get("volume") + if acquisition_mode == "volume" + else result.get("image") + ) if data is not None: # Ensure data is numpy array if not isinstance(data, np.ndarray): data = np.array(data) # For snap mode (2D), add Z dimension so store_volume works - if acquisition_mode == 'snap' and data.ndim == 2: + if acquisition_mode == "snap" and data.ndim == 2: data = data[np.newaxis, ...] 
# Add Z dimension: (Y,X) -> (1,Y,X) volume_data = data # Pass volume_path if available (zero-copy from device) - volume_path = result.get('volume_path') + volume_path = result.get("volume_path") # Callback may return UIDs from storage callback_result = await self.on_volume_callback( embryo_id, @@ -837,13 +871,11 @@ async def _acquire_embryo(self, embryo_state: EmbryoState, round_time: datetime # perception - we'll re-check inside _run_perception too) await self._check_stop_condition(embryo_state) - logger.debug( - f"Acquired t={embryo_state.timepoints_acquired} for {embryo_id}" - ) + logger.debug(f"Acquired t={embryo_state.timepoints_acquired} for {embryo_id}") else: embryo_state.error_count += 1 - embryo_state.last_error = result.get('error', 'Unknown error') + embryo_state.last_error = result.get("error", "Unknown error") # Log to global error log for cross-embryo correlation self.global_error_log.log_error( @@ -851,7 +883,7 @@ async def _acquire_embryo(self, embryo_state: EmbryoState, round_time: datetime embryo_id=embryo_id, timepoint=embryo_state.timepoints_acquired, error_type="acquisition", - message=embryo_state.last_error + message=embryo_state.last_error, ) # Stop after too many errors @@ -874,7 +906,7 @@ async def _acquire_embryo(self, embryo_state: EmbryoState, round_time: datetime timepoint=embryo_state.timepoints_acquired, error_type="acquisition_exception", message=str(e), - exception=e + exception=e, ) async def _check_stop_condition(self, embryo_state: EmbryoState): @@ -889,10 +921,13 @@ async def _check_stop_condition(self, embryo_state: EmbryoState): # from the role definition), the canonical signature of an # embryo that hatched / drifted out of the FOV. 
from gently.harness.roles import REGISTRY as _ROLE_REGISTRY + role = _ROLE_REGISTRY.get(getattr(embryo_state, "role", "test")) - if (role is not None - and role.no_object_consecutive_terminal is not None - and embryo_state.consecutive_no_object >= role.no_object_consecutive_terminal): + if ( + role is not None + and role.no_object_consecutive_terminal is not None + and embryo_state.consecutive_no_object >= role.no_object_consecutive_terminal + ): embryo_state.is_complete = True embryo_state.completion_reason = ( f"no_object x {embryo_state.consecutive_no_object} consecutive " @@ -900,6 +935,14 @@ async def _check_stop_condition(self, embryo_state: EmbryoState): f"likely hatched / out of FOV)" ) logger.info(f"Embryo {embryo_state.id} stopped: {embryo_state.completion_reason}") + self._emit_event( + EventType.EMBRYO_TERMINATED, + { + "embryo_id": embryo_state.id, + "completion_reason": embryo_state.completion_reason, + "timepoints_acquired": embryo_state.timepoints_acquired, + }, + ) return # Check all conditions (primary + additional) with OR logic @@ -909,13 +952,19 @@ async def _check_stop_condition(self, embryo_state: EmbryoState): embryo_state.is_complete = True embryo_state.completion_reason = reason logger.info(f"Embryo {embryo_state.id} stopped: {reason}") + self._emit_event( + EventType.EMBRYO_TERMINATED, + { + "embryo_id": embryo_state.id, + "completion_reason": reason, + "timepoints_acquired": embryo_state.timepoints_acquired, + }, + ) return # Stop on first matching condition def _evaluate_single_condition( - self, - cond: StopCondition, - embryo_state: EmbryoState - ) -> Optional[str]: + self, cond: StopCondition, embryo_state: EmbryoState + ) -> str | None: """ Evaluate a single stop condition. @@ -939,7 +988,8 @@ def _evaluate_single_condition( # Stop when every role='test' embryo in the active timelapse # has hatched (via Claude detector setting hatching_status). 
test_states = [ - e for e in self._embryo_states.values() + e + for e in self._embryo_states.values() if getattr(e, "role", "test") == "test" and not e.should_skip ] if not test_states: @@ -950,14 +1000,12 @@ def _evaluate_single_condition( if embryo_state.detection_triggered_at is None: embryo_state.detection_triggered_at = embryo_state.timepoints_acquired embryo_state.detection_type = "all_test_hatched" - tps_since = ( - embryo_state.timepoints_acquired - - embryo_state.detection_triggered_at - ) + tps_since = embryo_state.timepoints_acquired - embryo_state.detection_triggered_at if tps_since >= cond.confirm_timepoints: - return ( - f"all test embryos hatched" - + (f" (+{cond.confirm_timepoints} confirm)" if cond.confirm_timepoints > 0 else "") + return "all test embryos hatched" + ( + f" (+{cond.confirm_timepoints} confirm)" + if cond.confirm_timepoints > 0 + else "" ) return None @@ -970,9 +1018,11 @@ def _evaluate_single_condition( if elapsed_hours >= cond.value: return f"reached {cond.value}h duration" - elif cond.condition_type in (StopConditionType.STAGE_BASED, - StopConditionType.HATCHING, - StopConditionType.COMMA_STAGE): + elif cond.condition_type in ( + StopConditionType.STAGE_BASED, + StopConditionType.HATCHING, + StopConditionType.COMMA_STAGE, + ): # Generic stage-based stop: check if current stage is in target set target = cond.target_stages or set() @@ -989,7 +1039,8 @@ def _evaluate_single_condition( logger.info( f"Target stage '{current_stage}' detected for {embryo_state.id} " f"at t{embryo_state.timepoints_acquired}, " - f"will acquire {cond.confirm_timepoints} more confirmation timepoints" + f"will acquire {cond.confirm_timepoints} more" + " confirmation timepoints" ) # Check if we've acquired enough confirmation timepoints @@ -1009,9 +1060,11 @@ def _evaluate_single_condition( # reports 'hatched'. Session has no is_complete() — we # check terminal-stage membership directly. 
organism = get_organism() - if (current_stage - and current_stage in organism.TERMINAL_STAGES - and target & organism.TERMINAL_STAGES): + if ( + current_stage + and current_stage in organism.TERMINAL_STAGES + and target & organism.TERMINAL_STAGES + ): return f"terminal stage '{current_stage}' reached (perception)" # Fallback: check legacy hatching_status (for manual marking) @@ -1019,7 +1072,7 @@ def _evaluate_single_condition( if target & organism.TERMINAL_STAGES: embryo = self.experiment.embryos.get(embryo_state.id) if embryo: - hatched_via_status = embryo.hatching_status.get('hatched', False) + hatched_via_status = embryo.hatching_status.get("hatched", False) if hatched_via_status: return "terminal stage detected (manual)" @@ -1044,7 +1097,8 @@ def get_status(self) -> TimelapseState: if self._status == TimelapseStatus.RUNNING and self._embryo_states: due_times = [ - e.next_due_at for e in self._embryo_states.values() + e.next_due_at + for e in self._embryo_states.values() if not e.is_complete and not e.should_skip and getattr(e, "cadence_phase", "normal") != "paused" @@ -1052,9 +1106,7 @@ def get_status(self) -> TimelapseState: ] if due_times: next_round_time = min(due_times) - seconds_until_next = max( - 0, (next_round_time - datetime.now()).total_seconds() - ) + seconds_until_next = max(0, (next_round_time - datetime.now()).total_seconds()) return TimelapseState( status=self._status, @@ -1071,7 +1123,7 @@ def get_status(self) -> TimelapseState: async def add_embryo( self, embryo_id: str, - stop_condition: Optional[str] = None, + stop_condition: str | None = None, condition_value: Any = None, ) -> str: """ @@ -1115,6 +1167,7 @@ async def add_embryo( # Add embryo to the timelapse: set runtime fields on its EmbryoState # and register the reference in _embryo_states. 
from gently.harness.roles import REGISTRY as ROLE_REGISTRY + embryo = self.experiment.embryos[embryo_id] embryo.stop_condition = stop_cond embryo.is_complete = False @@ -1134,8 +1187,7 @@ async def add_embryo( elif embryo.interval_seconds is None: role_def = ROLE_REGISTRY.get(embryo.role) embryo.interval_seconds = ( - role_def.default_cadence_seconds if role_def is not None - else 300.0 + role_def.default_cadence_seconds if role_def is not None else 300.0 ) embryo.cadence_phase = "normal" embryo.next_due_at = datetime.now() # picked up on next loop tick @@ -1232,7 +1284,9 @@ def modify_interval(self, new_interval_seconds: float) -> str: logger.info( "Broadcast interval: base %ss -> %ss; %d embryos rescheduled", - old_base, new_interval_seconds, len(changed), + old_base, + new_interval_seconds, + len(changed), ) return ( f"Interval changed to {new_interval_seconds}s across " @@ -1242,7 +1296,7 @@ def modify_interval(self, new_interval_seconds: float) -> str: async def modify_embryo( self, embryo_id: str, - stop_condition: Optional[str] = None, + stop_condition: str | None = None, condition_value: Any = None, ) -> str: """ @@ -1277,7 +1331,10 @@ async def modify_embryo( changes.append(f"stop condition: {stop_condition}") if not changes: - return f"No changes specified for {embryo_id}. Note: use modify_interval() to change acquisition interval." + return ( + f"No changes specified for {embryo_id}." + " Note: use modify_interval() to change acquisition interval." + ) return f"Modified {embryo_id}: {', '.join(changes)}" @@ -1372,10 +1429,13 @@ async def stop(self, reason: str = "user_request") -> str: "total_timepoints": self._total_timepoints, "embryo_count": len(self._embryo_states), }, - source="timelapse_orchestrator" + source="timelapse_orchestrator", ) - return f"Timelapse stopped (reason: {reason}). Acquired {self._total_timepoints} total timepoints." + return ( + f"Timelapse stopped (reason: {reason})." + f" Acquired {self._total_timepoints} total timepoints." 
+ ) async def pause(self) -> str: """Pause the timelapse""" @@ -1417,7 +1477,7 @@ def add_speedup_on_stage( self, stage_name: str, new_interval_seconds: float = 30.0, - embryo_ids: Optional[List[str]] = None, + embryo_ids: list[str] | None = None, ): """ Add a rule to speed up imaging when a stage is reached @@ -1476,7 +1536,7 @@ def add_test_onset_speedup( *, fast_interval: float = 60.0, confirm_timepoints: int = 2, - embryo_ids: Optional[List[str]] = None, + embryo_ids: list[str] | None = None, ): """Install the canonical 'TestEmbryo signal-onset → fast cadence' rule. @@ -1489,7 +1549,8 @@ def add_test_onset_speedup( """ if embryo_ids is None: embryo_ids = [ - eid for eid, e in self._embryo_states.items() + eid + for eid, e in self._embryo_states.items() if getattr(e, "role", "test") == "test" ] rule = IntervalRule( @@ -1502,9 +1563,10 @@ def add_test_onset_speedup( ) self.add_interval_rule(rule) logger.info( - "Test-onset speedup installed: %ss on signal onset for %s " - "(confirm_timepoints=%d)", - fast_interval, embryo_ids or "all test embryos", confirm_timepoints, + "Test-onset speedup installed: %ss on signal onset for %s (confirm_timepoints=%d)", + fast_interval, + embryo_ids or "all test embryos", + confirm_timepoints, ) def add_burst_rule(self, rule): @@ -1519,7 +1581,7 @@ def add_test_burst_on_good_structure( mode: str = "1hz", num_slices: int = 1, confirm_timepoints: int = 2, - embryo_ids: Optional[List[str]] = None, + embryo_ids: list[str] | None = None, ): """Install the canonical 'TestEmbryo stably-bright structure → burst' rule. 
@@ -1532,7 +1594,8 @@ def add_test_burst_on_good_structure( """ if embryo_ids is None: embryo_ids = [ - eid for eid, e in self._embryo_states.items() + eid + for eid, e in self._embryo_states.items() if getattr(e, "role", "test") == "test" ] rule = BurstRule( @@ -1550,7 +1613,10 @@ def add_test_burst_on_good_structure( logger.info( "Test-burst rule installed: %d frames @ %s on stable structure for %s " "(confirm_timepoints=%d)", - frames, mode, embryo_ids or "all test embryos", confirm_timepoints, + frames, + mode, + embryo_ids or "all test embryos", + confirm_timepoints, ) def add_test_saturation_rampdown( @@ -1561,7 +1627,7 @@ def add_test_saturation_rampdown( floor_pct: float = 2.0, ceiling_pct: float = 6.0, confirm_timepoints: int = 0, - embryo_ids: Optional[List[str]] = None, + embryo_ids: list[str] | None = None, ): """Install the canonical 'TestEmbryo saturation → step laser down' rule. @@ -1573,7 +1639,8 @@ def add_test_saturation_rampdown( """ if embryo_ids is None: embryo_ids = [ - eid for eid, e in self._embryo_states.items() + eid + for eid, e in self._embryo_states.items() if getattr(e, "role", "test") == "test" ] rule = PowerRule( @@ -1592,14 +1659,17 @@ def add_test_saturation_rampdown( self.add_power_rule(rule) logger.info( "Test-saturation rampdown installed: %dnm step=%.2f%% floor=%.2f%% for %s", - wavelength, step_pct, floor_pct, embryo_ids or "all test embryos", + wavelength, + step_pct, + floor_pct, + embryo_ids or "all test embryos", ) # ------------------------------------------------------------------ # Phase 9: overnight persistence (timelapse.yaml) # ------------------------------------------------------------------ - def _session_storage_dir(self) -> Optional[Path]: + def _session_storage_dir(self) -> Path | None: """Resolve the FileStore-indexed folder for this session. 
Falls back to ``/sessions/`` (the legacy bare-id @@ -1614,7 +1684,7 @@ def _session_storage_dir(self) -> Optional[Path]: sd = self._store.root / "sessions" / self._session_id return sd - def save_state(self) -> Optional[Path]: + def save_state(self) -> Path | None: """Write orchestrator runtime state to ``timelapse.yaml``. Captures per-embryo cadence state, installed rules, burst state, @@ -1629,6 +1699,7 @@ def save_state(self) -> Optional[Path]: return None try: import yaml + path = sd / "timelapse.yaml" path.parent.mkdir(parents=True, exist_ok=True) doc = self._serialize_runtime_state() @@ -1655,6 +1726,7 @@ def load_state(self) -> str: return "No session — cannot load state." try: import yaml + candidates = [] sd = self._store._session_dir(self._session_id) if sd is not None: @@ -1665,7 +1737,7 @@ def load_state(self) -> str: path = next((p for p in candidates if p.exists()), None) if path is None: return f"No timelapse.yaml at {candidates[0]}" - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: doc = yaml.safe_load(f) or {} except Exception as e: return f"Failed to read timelapse.yaml: {e}" @@ -1682,8 +1754,9 @@ def load_state(self) -> str: f"{len(self._exclusive_queue)} queued bursts" ) - def _serialize_runtime_state(self) -> Dict[str, Any]: + def _serialize_runtime_state(self) -> dict[str, Any]: """Build the timelapse.yaml document.""" + def _iso(dt): return dt.isoformat() if dt is not None else None @@ -1737,7 +1810,8 @@ def _ser_stop_condition(sc): def _ser_interval_rule(r): return { - "name": r.name, "trigger_detector": r.trigger_detector, + "name": r.name, + "trigger_detector": r.trigger_detector, "trigger_stage": r.trigger_stage, "new_interval_seconds": r.new_interval_seconds, "applies_to": list(r.applies_to) if r.applies_to else None, @@ -1747,12 +1821,14 @@ def _ser_interval_rule(r): def _ser_power_rule(r): return { - "name": r.name, "wavelength": r.wavelength, + "name": r.name, + "wavelength": r.wavelength, 
"trigger_detector": r.trigger_detector, "trigger_intensity_levels": list(r.trigger_intensity_levels or []), "trigger_stage": r.trigger_stage, "step_pct": r.step_pct, - "floor_pct": r.floor_pct, "ceiling_pct": r.ceiling_pct, + "floor_pct": r.floor_pct, + "ceiling_pct": r.ceiling_pct, "direction": r.direction, "applies_to": list(r.applies_to) if r.applies_to else None, "confirm_timepoints": r.confirm_timepoints, @@ -1803,10 +1879,10 @@ def _ser_burst_rule(r): "embryos": embryos, } - def _apply_runtime_state(self, doc: Dict[str, Any]) -> None: + def _apply_runtime_state(self, doc: dict[str, Any]) -> None: """Restore orchestrator state from a parsed timelapse.yaml dict.""" - from .timelapse_models import IntervalRule, PowerRule, BurstRule, StopCondition from .exclusive import BurstAcquisition + from .timelapse_models import BurstRule, IntervalRule, PowerRule, StopCondition def _parse_dt(s): if not s: @@ -1816,7 +1892,9 @@ def _parse_dt(s): except (TypeError, ValueError): return None - self._base_interval_seconds = float(doc.get("base_interval_seconds", self._base_interval_seconds)) + self._base_interval_seconds = float( + doc.get("base_interval_seconds", self._base_interval_seconds) + ) self._current_round = int(doc.get("current_round", self._current_round)) self._total_timepoints = int(doc.get("total_timepoints", self._total_timepoints)) self._dose_budget_base_ms = doc.get("dose_budget_base_ms") @@ -1827,62 +1905,70 @@ def _parse_dt(s): self._started_at = started self._interval_rules = [] - for r in (doc.get("interval_rules") or []): - self._interval_rules.append(IntervalRule( - name=r["name"], - trigger_detector=r.get("trigger_detector"), - trigger_stage=r.get("trigger_stage"), - new_interval_seconds=float(r.get("new_interval_seconds", 30.0)), - applies_to=r.get("applies_to"), - confirm_timepoints=int(r.get("confirm_timepoints", 0)), - one_time=bool(r.get("one_time", True)), - )) + for r in doc.get("interval_rules") or []: + self._interval_rules.append( + 
IntervalRule( + name=r["name"], + trigger_detector=r.get("trigger_detector"), + trigger_stage=r.get("trigger_stage"), + new_interval_seconds=float(r.get("new_interval_seconds", 30.0)), + applies_to=r.get("applies_to"), + confirm_timepoints=int(r.get("confirm_timepoints", 0)), + one_time=bool(r.get("one_time", True)), + ) + ) self._power_rules = [] - for r in (doc.get("power_rules") or []): - self._power_rules.append(PowerRule( - name=r["name"], - wavelength=int(r.get("wavelength", 488)), - trigger_detector=r.get("trigger_detector"), - trigger_intensity_levels=r.get("trigger_intensity_levels") or None, - trigger_stage=r.get("trigger_stage"), - step_pct=float(r.get("step_pct", 1.0)), - floor_pct=float(r.get("floor_pct", 2.0)), - ceiling_pct=float(r.get("ceiling_pct", 6.0)), - direction=r.get("direction", "down"), - applies_to=r.get("applies_to"), - confirm_timepoints=int(r.get("confirm_timepoints", 0)), - one_time=bool(r.get("one_time", False)), - )) + for r in doc.get("power_rules") or []: + self._power_rules.append( + PowerRule( + name=r["name"], + wavelength=int(r.get("wavelength", 488)), + trigger_detector=r.get("trigger_detector"), + trigger_intensity_levels=r.get("trigger_intensity_levels") or None, + trigger_stage=r.get("trigger_stage"), + step_pct=float(r.get("step_pct", 1.0)), + floor_pct=float(r.get("floor_pct", 2.0)), + ceiling_pct=float(r.get("ceiling_pct", 6.0)), + direction=r.get("direction", "down"), + applies_to=r.get("applies_to"), + confirm_timepoints=int(r.get("confirm_timepoints", 0)), + one_time=bool(r.get("one_time", False)), + ) + ) self._burst_rules = [] - for r in (doc.get("burst_rules") or []): - self._burst_rules.append(BurstRule( - name=r["name"], - trigger_detector=r.get("trigger_detector"), - trigger_intensity_levels=r.get("trigger_intensity_levels") or None, - trigger_structure_qualities=r.get("trigger_structure_qualities") or None, - frames=int(r.get("frames", 60)), - mode=r.get("mode", "1hz"), - num_slices=int(r.get("num_slices", 1)), 
- applies_to=r.get("applies_to"), - confirm_timepoints=int(r.get("confirm_timepoints", 0)), - )) + for r in doc.get("burst_rules") or []: + self._burst_rules.append( + BurstRule( + name=r["name"], + trigger_detector=r.get("trigger_detector"), + trigger_intensity_levels=r.get("trigger_intensity_levels") or None, + trigger_structure_qualities=r.get("trigger_structure_qualities") or None, + frames=int(r.get("frames", 60)), + mode=r.get("mode", "1hz"), + num_slices=int(r.get("num_slices", 1)), + applies_to=r.get("applies_to"), + confirm_timepoints=int(r.get("confirm_timepoints", 0)), + ) + ) self._applied_rules = { eid: set(names) for eid, names in (doc.get("applied_rules") or {}).items() } self._exclusive_queue.clear() - for op_doc in (doc.get("exclusive_queue") or []): + for op_doc in doc.get("exclusive_queue") or []: if op_doc.get("kind") == "burst": - self._exclusive_queue.append(BurstAcquisition( - target_embryo_id=op_doc["target_embryo_id"], - frames=int(op_doc.get("frames", 60)), - mode=op_doc.get("mode", "1hz"), - num_slices=int(op_doc.get("num_slices", 1)), - request_id=op_doc.get("request_id"), - )) + self._exclusive_queue.append( + BurstAcquisition( + target_embryo_id=op_doc["target_embryo_id"], + frames=int(op_doc.get("frames", 60)), + mode=op_doc.get("mode", "1hz"), + num_slices=int(op_doc.get("num_slices", 1)), + request_id=op_doc.get("request_id"), + ) + ) def _deser_stop_condition(d): """Rebuild a StopCondition from the dict written by _ser_stop_condition.""" @@ -1895,17 +1981,19 @@ def _deser_stop_condition(d): target_stages=set(d.get("target_stages") or []) or None, confirm_timepoints=int(d.get("confirm_timepoints") or 0), ) - for ad in (d.get("additional") or []): - primary.add_condition(StopCondition( - condition_type=StopConditionType(ad["condition_type"]), - value=ad.get("value"), - target_stages=set(ad.get("target_stages") or []) or None, - confirm_timepoints=int(ad.get("confirm_timepoints") or 0), - )) + for ad in d.get("additional") or []: + 
primary.add_condition( + StopCondition( + condition_type=StopConditionType(ad["condition_type"]), + value=ad.get("value"), + target_stages=set(ad.get("target_stages") or []) or None, + confirm_timepoints=int(ad.get("confirm_timepoints") or 0), + ) + ) return primary except Exception: # Fall back to spec-string parse if shape changed across versions. - spec = (d.get("spec") if isinstance(d, dict) else None) + spec = d.get("spec") if isinstance(d, dict) else None if isinstance(spec, str) and spec: try: return StopCondition.parse(spec) @@ -1920,10 +2008,17 @@ def _deser_stop_condition(d): if embryo is None: continue for attr in ( - "cadence_phase", "interval_seconds", "laser_power_488_pct", - "total_exposure_ms", "timepoints_acquired", "is_complete", - "completion_reason", "should_skip", "skip_reason", - "detection_triggered_at", "detection_type", + "cadence_phase", + "interval_seconds", + "laser_power_488_pct", + "total_exposure_ms", + "timepoints_acquired", + "is_complete", + "completion_reason", + "should_skip", + "skip_reason", + "detection_triggered_at", + "detection_type", "no_object_since_timepoint", ): if attr in ed: @@ -1975,6 +2070,7 @@ def queue_burst( If True, queue even if this embryo has already had a burst. """ from .exclusive import BurstAcquisition + if embryo_id not in self._embryo_states: return f"Embryo '{embryo_id}' not in active timelapse." 
if not force and embryo_id in self._burst_applied: @@ -1992,18 +2088,24 @@ def queue_burst( ) self._exclusive_queue.append(op) logger.info( - "Queued burst for %s: frames=%d mode=%s num_slices=%d request_id=%s " - "(queue depth=%d)", - embryo_id, frames, mode, num_slices, op.request_id, + "Queued burst for %s: frames=%d mode=%s num_slices=%d request_id=%s (queue depth=%d)", + embryo_id, + frames, + mode, + num_slices, + op.request_id, len(self._exclusive_queue), ) - self._emit_event(EventType.BURST_QUEUED, { - "embryo_id": embryo_id, - "request_id": op.request_id, - "position_in_queue": len(self._exclusive_queue), - "frames": frames, - "mode": mode, - }) + self._emit_event( + EventType.BURST_QUEUED, + { + "embryo_id": embryo_id, + "request_id": op.request_id, + "position_in_queue": len(self._exclusive_queue), + "frames": frames, + "mode": mode, + }, + ) return ( f"Burst queued for {embryo_id} (request_id={op.request_id}, " f"frames={frames}, mode={mode}, queue_depth={len(self._exclusive_queue)})" @@ -2016,9 +2118,9 @@ def queue_burst( def run_calibration_pipelines( self, *, - pipelines: Optional[List[str]] = None, - source_volumes: Optional[Dict[str, Any]] = None, - embryo_bboxes: Optional[Dict[str, Any]] = None, + pipelines: list[str] | None = None, + source_volumes: dict[str, Any] | None = None, + embryo_bboxes: dict[str, Any] | None = None, ) -> str: """Run the named calibration pipelines on CalibrationEmbryo volumes and merge the result into ``self._calibration_data``. @@ -2032,7 +2134,8 @@ def run_calibration_pipelines( ``edge_bbox`` out and applies them as preprocessing. 
""" from gently.app.calibration import ( - get_calibration_pipeline, aggregate_calibrations, + aggregate_calibrations, + get_calibration_pipeline, ) if pipelines is None: @@ -2052,7 +2155,9 @@ def run_calibration_pipelines( captured.append(data) logger.info( "Calibration pipeline '%s' captured: keys=%s notes=%s", - pname, sorted(data.payload.keys()), data.notes, + pname, + sorted(data.payload.keys()), + data.notes, ) except Exception as e: logger.warning("Calibration pipeline '%s' failed: %s", pname, e) @@ -2069,7 +2174,7 @@ def run_calibration_pipelines( if sd is not None: try: import yaml - from pathlib import Path + cal_dir = sd / "calibration" cal_dir.mkdir(parents=True, exist_ok=True) manifest = { @@ -2088,6 +2193,7 @@ def run_calibration_pipelines( yaml.safe_dump(manifest, f, sort_keys=False) # Save heavy arrays as .npy import numpy as _np + for c in captured: for key, value in c.payload.items(): if isinstance(value, _np.ndarray): @@ -2119,6 +2225,7 @@ def enable_expression_monitoring( for the declarative form. """ from .monitoring_modes import ExpressionMonitoringMode + mode = ExpressionMonitoringMode( name="expression_monitoring", description="", @@ -2135,7 +2242,7 @@ def enable_monitoring_mode( self, name: str, *, - embryo_ids: Optional[List[str]] = None, + embryo_ids: list[str] | None = None, **mode_kwargs, ) -> str: """Activate a named MonitoringMode from the registry. @@ -2147,11 +2254,11 @@ def enable_monitoring_mode( constructor (e.g. ``fast_interval=30.0``). """ from .monitoring_modes import MONITORING_MODES + factory = MONITORING_MODES.get(name) if factory is None: return ( - f"Unknown monitoring mode: {name!r}. " - f"Available: {sorted(MONITORING_MODES.keys())}" + f"Unknown monitoring mode: {name!r}. 
Available: {sorted(MONITORING_MODES.keys())}" ) mode = factory(**mode_kwargs) if mode_kwargs else factory() mode.activate(self, embryo_ids=embryo_ids) @@ -2161,10 +2268,10 @@ def enable_monitoring_mode( def _check_interval_rules( self, embryo_id: str, - detector_name: Optional[str] = None, - stage: Optional[str] = None, - intensity_level: Optional[str] = None, - structure_quality: Optional[str] = None, + detector_name: str | None = None, + stage: str | None = None, + intensity_level: str | None = None, + structure_quality: str | None = None, ): """ Evaluate all adaptive rules (interval + power) against a fresh @@ -2228,36 +2335,46 @@ def _check_interval_rules( if prule.wavelength == 488: estate.laser_power_488_pct = new_pct # (Future wavelengths: extend EmbryoState similarly.) - self._emit_event(EventType.POWER_RAMP_STEP, { - "embryo_id": embryo_id, - "rule": prule.name, - "wavelength": prule.wavelength, - "old_pct": current, - "new_pct": new_pct, - "direction": prule.direction, - "intensity_level": intensity_level, - }) - # Discrete trigger-fired event so the strategy view can show - # rule firings without inferring them from the power-step event. - self._emit_event(EventType.TRIGGER_FIRED, { - "embryo_id": embryo_id, - "rule_name": prule.name, - "rule_kind": "power", - "trigger_detector": prule.trigger_detector, - "trigger_stage": prule.trigger_stage, - "trigger_intensity_level": intensity_level, - "applied": { + self._emit_event( + EventType.POWER_RAMP_STEP, + { + "embryo_id": embryo_id, + "rule": prule.name, "wavelength": prule.wavelength, "old_pct": current, "new_pct": new_pct, "direction": prule.direction, + "intensity_level": intensity_level, }, - }) + ) + # Discrete trigger-fired event so the strategy view can show + # rule firings without inferring them from the power-step event. 
+ self._emit_event( + EventType.TRIGGER_FIRED, + { + "embryo_id": embryo_id, + "rule_name": prule.name, + "rule_kind": "power", + "trigger_detector": prule.trigger_detector, + "trigger_stage": prule.trigger_stage, + "trigger_intensity_level": intensity_level, + "applied": { + "wavelength": prule.wavelength, + "old_pct": current, + "new_pct": new_pct, + "direction": prule.direction, + }, + }, + ) logger.info( - "Applied PowerRule '%s' on %s: %dnm %.2f%% -> %.2f%% " - "(direction=%s, intensity=%s)", - prule.name, embryo_id, prule.wavelength, - current, new_pct, prule.direction, intensity_level, + "Applied PowerRule '%s' on %s: %dnm %.2f%% -> %.2f%% (direction=%s, intensity=%s)", + prule.name, + embryo_id, + prule.wavelength, + current, + new_pct, + prule.direction, + intensity_level, ) if prule.one_time: @@ -2303,24 +2420,30 @@ def _check_interval_rules( new_interval_seconds=rule.new_interval_seconds, reason=f"rule:{rule.name}", ) - self._emit_event(EventType.TRIGGER_FIRED, { - "embryo_id": embryo_id, - "rule_name": rule.name, - "rule_kind": "interval", - "trigger_detector": rule.trigger_detector, - "trigger_stage": rule.trigger_stage, - "trigger_intensity_level": None, - "applied": { - "old_interval_s": old_interval, - "new_interval_s": rule.new_interval_seconds, - "one_time": rule.one_time, - "confirm_timepoints": rule.confirm_timepoints, + self._emit_event( + EventType.TRIGGER_FIRED, + { + "embryo_id": embryo_id, + "rule_name": rule.name, + "rule_kind": "interval", + "trigger_detector": rule.trigger_detector, + "trigger_stage": rule.trigger_stage, + "trigger_intensity_level": None, + "applied": { + "old_interval_s": old_interval, + "new_interval_s": rule.new_interval_seconds, + "one_time": rule.one_time, + "confirm_timepoints": rule.confirm_timepoints, + }, }, - }) + ) logger.info( "Applied interval rule '%s' on %s: %ss -> %ss (confirm=%d)", - rule.name, embryo_id, old_interval, - rule.new_interval_seconds, rule.confirm_timepoints, + rule.name, + embryo_id, + 
old_interval, + rule.new_interval_seconds, + rule.confirm_timepoints, ) if rule.one_time: @@ -2360,27 +2483,34 @@ def _check_interval_rules( mode=brule.mode, num_slices=brule.num_slices, ) - self._emit_event(EventType.TRIGGER_FIRED, { - "embryo_id": embryo_id, - "rule_name": brule.name, - "rule_kind": "burst", - "trigger_detector": brule.trigger_detector, - "trigger_intensity_level": intensity_level, - "trigger_structure_quality": structure_quality, - "applied": { - "frames": brule.frames, - "mode": brule.mode, - "num_slices": brule.num_slices, - "confirm_timepoints": brule.confirm_timepoints, - "queue_result": result, + self._emit_event( + EventType.TRIGGER_FIRED, + { + "embryo_id": embryo_id, + "rule_name": brule.name, + "rule_kind": "burst", + "trigger_detector": brule.trigger_detector, + "trigger_intensity_level": intensity_level, + "trigger_structure_quality": structure_quality, + "applied": { + "frames": brule.frames, + "mode": brule.mode, + "num_slices": brule.num_slices, + "confirm_timepoints": brule.confirm_timepoints, + "queue_result": result, + }, }, - }) + ) logger.info( "Applied burst rule '%s' on %s (intensity=%s structure=%s): %s", - brule.name, embryo_id, intensity_level, structure_quality, result, + brule.name, + embryo_id, + intensity_level, + structure_quality, + result, ) - def _finalize_perception_run(self, status: str = "completed", error_message: str = None): + def _finalize_perception_run(self, status: str = "completed", error_message: str | None = None): """Mark the perception run as finished in FileStore.""" if self._store and self._perception_run_id is not None: try: @@ -2393,7 +2523,7 @@ def _finalize_perception_run(self, status: str = "completed", error_message: str except Exception as e: logger.warning(f"Failed to finalize perception run: {e}") - def _emit_event(self, event_type: EventType, data: Dict): + def _emit_event(self, event_type: EventType, data: dict): """Emit event to event bus""" self._event_bus.publish( 
event_type=event_type, @@ -2426,7 +2556,7 @@ def _volume_to_b64(self, volume) -> tuple: if view_a.ndim == 3: z_depth, height, width = view_a.shape if width > height * 2: - view_a = view_a[:, :, :width // 2] + view_a = view_a[:, :, : width // 2] bounds = compute_crop_bounds(view_a) cropped = apply_crop_bounds(view_a, bounds) three_view_img, _ = projection_three_view(cropped) @@ -2445,7 +2575,7 @@ async def _run_detector( volume, embryo_state: EmbryoState, detector_name: str, - volume_uids: dict = None, + volume_uids: dict | None = None, ): """Run a role-declared Detector (Phase 2) and persist + emit results. @@ -2465,7 +2595,8 @@ async def _run_detector( if detector is None: logger.warning( "Unknown detector '%s' for embryo %s — skipping detection", - detector_name, embryo_id, + detector_name, + embryo_id, ) return @@ -2493,18 +2624,20 @@ async def _run_detector( # Mirror onto EmbryoState for the agent's prompt + rule machinery. try: - embryo_state.cv_analyses.setdefault(detector_name, []).append({ - "timepoint": timepoint, - "intensity_level": intensity_level, - "structure_quality": structure_quality, - "has_hatched": has_hatched, - "reasoning": result.reasoning, - }) + embryo_state.cv_analyses.setdefault(detector_name, []).append( + { + "timepoint": timepoint, + "intensity_level": intensity_level, + "structure_quality": structure_quality, + "has_hatched": has_hatched, + "reasoning": result.reasoning, + } + ) # Keep a rolling cap so cv_analyses doesn't grow unbounded. if len(embryo_state.cv_analyses[detector_name]) > 200: - embryo_state.cv_analyses[detector_name] = ( - embryo_state.cv_analyses[detector_name][-200:] - ) + embryo_state.cv_analyses[detector_name] = embryo_state.cv_analyses[detector_name][ + -200: + ] # If detector flagged hatched, update legacy field too. 
if has_hatched: embryo_state.hatching_status = { @@ -2552,7 +2685,9 @@ async def _run_detector( elif has_hatched: pseudo_stage = "hatched" else: - pseudo_stage = "no_object" if intensity_level == "NONE" else (intensity_level or "unknown") + pseudo_stage = ( + "no_object" if intensity_level == "NONE" else (intensity_level or "unknown") + ) # Track consecutive no_object across roles — drives the role-based # terminal stop in _check_stop_condition. @@ -2607,22 +2742,28 @@ async def _run_detector( event_data["projection_uid"] = volume_uids.get("projection_uid") self._emit_event(EventType.DETECTOR_EVALUATED, event_data) # Dedicated Phase 10 event for the per-detector findings stream. - self._emit_event(EventType.CLAUDE_DETECTOR_RESULT, { - "embryo_id": embryo_id, - "timepoint": timepoint, - "detector_name": detector_name, - "findings": findings, - "reasoning": result.reasoning, - "description": description, - }) - - if has_hatched: - self._emit_event(EventType.HATCHING_DETECTED, { + self._emit_event( + EventType.CLAUDE_DETECTOR_RESULT, + { "embryo_id": embryo_id, "timepoint": timepoint, "detector_name": detector_name, - "stage": "hatched", - }) + "findings": findings, + "reasoning": result.reasoning, + "description": description, + }, + ) + + if has_hatched: + self._emit_event( + EventType.HATCHING_DETECTED, + { + "embryo_id": embryo_id, + "timepoint": timepoint, + "detector_name": detector_name, + "stage": "hatched", + }, + ) # Drive Phase 5 reactive rules — pass both the pseudo-stage and # the detailed findings so rules can match either. 
@@ -2636,8 +2777,12 @@ async def _run_detector( logger.info( "[%s] T%d: detector=%s intensity=%s structure=%s hatched=%s", - embryo_id, timepoint, detector_name, - intensity_level, structure_quality, has_hatched, + embryo_id, + timepoint, + detector_name, + intensity_level, + structure_quality, + has_hatched, ) async def _run_perception( @@ -2646,7 +2791,7 @@ async def _run_perception( timepoint: int, volume, embryo_state: EmbryoState, - volume_uids: dict = None, + volume_uids: dict | None = None, ): """Run the per-role detector on the acquired volume and emit results. @@ -2662,7 +2807,6 @@ async def _run_perception( # ad-hoc detector path; otherwise fall through to the original # Perceiver-based flow below. from gently.harness.roles import REGISTRY as ROLE_REGISTRY - from gently.app.detectors import get_detector role_def = ROLE_REGISTRY.get(getattr(embryo_state, "role", "test")) detector_name = role_def.detector_name if role_def else None @@ -2686,12 +2830,20 @@ async def _run_perception( (timepoints_since // self.NO_OBJECT_RECHECK_INTERVAL + 1) * self.NO_OBJECT_RECHECK_INTERVAL ) - self._emit_event(EventType.DETECTOR_EVALUATED, { - 'embryo_id': embryo_id, 'timepoint': timepoint, - 'detector_name': 'perception', 'stage': 'no_object', - 'reasoning': f"Skipped (empty field). Rechecking in {next_recheck - timepoint} timepoints.", - 'skipped': True, - }) + self._emit_event( + EventType.DETECTOR_EVALUATED, + { + "embryo_id": embryo_id, + "timepoint": timepoint, + "detector_name": "perception", + "stage": "no_object", + "reasoning": ( + f"Skipped (empty field). Rechecking in" + f" {next_recheck - timepoint} timepoints." 
+ ), + "skipped": True, + }, + ) return else: logger.info(f"Rechecking no_object embryo {embryo_id} at t={timepoint}") @@ -2718,7 +2870,9 @@ async def _run_perception( logger.info(f"Embryo {embryo_id} marked as no_object at t={timepoint}") else: if embryo_state.no_object_since_timepoint is not None: - logger.info(f"Embryo {embryo_id} object found at t={timepoint}, resuming perception") + logger.info( + f"Embryo {embryo_id} object found at t={timepoint}, resuming perception" + ) embryo_state.no_object_since_timepoint = None embryo_state.consecutive_no_object = 0 @@ -2752,28 +2906,35 @@ async def _run_perception( # Build and emit perception event event_data = { - 'embryo_id': embryo_id, 'timepoint': timepoint, - 'detector_name': 'perception', - 'stage': result.stage, - 'reasoning': result.reasoning, + "embryo_id": embryo_id, + "timepoint": timepoint, + "detector_name": "perception", + "stage": result.stage, + "reasoning": result.reasoning, } if volume_uids: - event_data['volume_uid'] = volume_uids.get('volume_uid') - event_data['projection_uid'] = volume_uids.get('projection_uid') + event_data["volume_uid"] = volume_uids.get("volume_uid") + event_data["projection_uid"] = volume_uids.get("projection_uid") if session: - event_data['stability'] = session.stability + event_data["stability"] = session.stability summary = session.summary() - if summary.get('temporal'): + if summary.get("temporal"): from dataclasses import asdict - event_data['temporal_analysis'] = asdict(summary['temporal']) + + event_data["temporal_analysis"] = asdict(summary["temporal"]) self._emit_event(EventType.DETECTOR_EVALUATED, event_data) if result.stage in ("hatching", "hatched"): - self._emit_event(EventType.HATCHING_DETECTED, { - 'embryo_id': embryo_id, 'timepoint': timepoint, - 'detector_name': 'hatching', 'stage': result.stage, - }) + self._emit_event( + EventType.HATCHING_DETECTED, + { + "embryo_id": embryo_id, + "timepoint": timepoint, + "detector_name": "hatching", + "stage": result.stage, 
+ }, + ) # Check interval rules based on stage self._check_interval_rules(embryo_id=embryo_id, stage=result.stage) @@ -2790,7 +2951,7 @@ def _write_trace_file(self, embryo_id: str, timepoint: int, trace_data: dict) -> """Write perception trace to JSON file.""" filename = f"{embryo_id}_T{timepoint:04d}.json" file_path = self._trace_dir / filename - with open(file_path, 'w', encoding='utf-8') as f: + with open(file_path, "w", encoding="utf-8") as f: json.dump(trace_data, f, indent=2, ensure_ascii=False) logger.debug(f"Wrote trace: {file_path.name}") return file_path diff --git a/gently/app/orchestration/timelapse_models.py b/gently/app/orchestration/timelapse_models.py index c3d46fb5..262d2197 100644 --- a/gently/app/orchestration/timelapse_models.py +++ b/gently/app/orchestration/timelapse_models.py @@ -8,18 +8,19 @@ from dataclasses import dataclass, field from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional, Set +from typing import Any from gently.organisms import get_organism class StopConditionType(Enum): """Types of stop conditions for embryo acquisition""" - MANUAL = "manual" # Stop only when user says - STAGE_BASED = "stage_based" # Stop when any of target stages reached + + MANUAL = "manual" # Stop only when user says + STAGE_BASED = "stage_based" # Stop when any of target stages reached FIXED_TIMEPOINTS = "fixed_timepoints" # Stop after N timepoints - DURATION = "duration" # Stop after X hours - ALL_TERMINAL = "all_terminal" # Stop when all embryos reach terminal stage + DURATION = "duration" # Stop after X hours + ALL_TERMINAL = "all_terminal" # Stop when all embryos reach terminal stage # Phase 8: stop when every role='test' embryo has hatched # (via the dopaminergic detector setting hatching_status.hatched=True). ALL_TEST_HATCHED = "all_test_hatched" @@ -38,19 +39,20 @@ class IntervalRule: ``applies_to`` filters which embryos this rule listens to — it is not a fan-out target list. 
""" + name: str - trigger_detector: Optional[str] = None # Detector name that triggers this rule - trigger_stage: Optional[str] = None # Stage name that triggers (comma, pretzel, etc.) - new_interval_seconds: float = 30.0 # New interval when triggered - applies_to: Optional[List[str]] = None # Embryo IDs this rule listens to (None = all) - confirm_timepoints: int = 0 # require N consecutive trigger matches before firing - one_time: bool = True # Only apply once per embryo + trigger_detector: str | None = None # Detector name that triggers this rule + trigger_stage: str | None = None # Stage name that triggers (comma, pretzel, etc.) + new_interval_seconds: float = 30.0 # New interval when triggered + applies_to: list[str] | None = None # Embryo IDs this rule listens to (None = all) + confirm_timepoints: int = 0 # require N consecutive trigger matches before firing + one_time: bool = True # Only apply once per embryo def matches( self, embryo_id: str, - detector_name: Optional[str] = None, - stage: Optional[str] = None, + detector_name: str | None = None, + stage: str | None = None, ) -> bool: """Check if this rule should trigger""" if self.applies_to and embryo_id not in self.applies_to: @@ -77,25 +79,26 @@ class PowerRule: The hard safety limit at DiSPIMLightSource.POWER_LIMITS_PCT is the bottom-line bound; this is the soft control on top. """ + name: str wavelength: int = 488 - trigger_detector: Optional[str] = None - trigger_intensity_levels: Optional[List[str]] = None # e.g. 
["SATURATING"] - trigger_stage: Optional[str] = None - step_pct: float = 1.0 # how much to change per firing - floor_pct: float = 2.0 # never go below - ceiling_pct: float = 6.0 # never go above - direction: str = "down" # "down" (sticky-downward) or "up" - applies_to: Optional[List[str]] = None - confirm_timepoints: int = 0 # require N consecutive firings before applying - one_time: bool = False # default: fire repeatedly for ramps + trigger_detector: str | None = None + trigger_intensity_levels: list[str] | None = None # e.g. ["SATURATING"] + trigger_stage: str | None = None + step_pct: float = 1.0 # how much to change per firing + floor_pct: float = 2.0 # never go below + ceiling_pct: float = 6.0 # never go above + direction: str = "down" # "down" (sticky-downward) or "up" + applies_to: list[str] | None = None + confirm_timepoints: int = 0 # require N consecutive firings before applying + one_time: bool = False # default: fire repeatedly for ramps def matches( self, embryo_id: str, - detector_name: Optional[str] = None, - stage: Optional[str] = None, - intensity_level: Optional[str] = None, + detector_name: str | None = None, + stage: str | None = None, + intensity_level: str | None = None, ) -> bool: if self.applies_to and embryo_id not in self.applies_to: return False @@ -127,22 +130,23 @@ class BurstRule: downstream by ``queue_burst`` (which gates on ``_burst_applied``), so this rule doesn't need its own ``one_time`` flag. 
""" + name: str - trigger_detector: Optional[str] = None - trigger_intensity_levels: Optional[List[str]] = None # AND-combined predicate - trigger_structure_qualities: Optional[List[str]] = None # AND-combined predicate + trigger_detector: str | None = None + trigger_intensity_levels: list[str] | None = None # AND-combined predicate + trigger_structure_qualities: list[str] | None = None # AND-combined predicate frames: int = 60 - mode: str = "1hz" # "1hz" | "asap" + mode: str = "1hz" # "1hz" | "asap" num_slices: int = 1 - applies_to: Optional[List[str]] = None # listen-filter + applies_to: list[str] | None = None # listen-filter confirm_timepoints: int = 0 # require N consecutive matches before firing def matches( self, embryo_id: str, - detector_name: Optional[str] = None, - intensity_level: Optional[str] = None, - structure_quality: Optional[str] = None, + detector_name: str | None = None, + intensity_level: str | None = None, + structure_quality: str | None = None, ) -> bool: if self.applies_to and embryo_id not in self.applies_to: return False @@ -173,29 +177,33 @@ class StopCondition: specifies how many additional timepoints to acquire after detection before actually stopping - useful to verify the detection is real. """ + condition_type: StopConditionType value: Any = None # e.g., number of timepoints, hours, etc. 
- target_stages: Optional[Set[str]] = None # Stages that satisfy STAGE_BASED condition + target_stages: set[str] | None = None # Stages that satisfy STAGE_BASED condition confirm_timepoints: int = 0 # Extra timepoints to acquire after detection - additional_conditions: List['StopCondition'] = field(default_factory=list) + additional_conditions: list["StopCondition"] = field(default_factory=list) - def add_condition(self, condition: 'StopCondition') -> None: + def add_condition(self, condition: "StopCondition") -> None: """Add another stop condition (OR logic).""" self.additional_conditions.append(condition) - def all_conditions(self) -> List['StopCondition']: + def all_conditions(self) -> list["StopCondition"]: """Get all conditions including self (flattened).""" return [self] + self.additional_conditions def describe(self) -> str: """Human-readable description of the stop condition(s).""" - def _describe_single(cond: 'StopCondition') -> str: + + def _describe_single(cond: "StopCondition") -> str: confirm_suffix = f"+{cond.confirm_timepoints}tp" if cond.confirm_timepoints > 0 else "" if cond.condition_type == StopConditionType.MANUAL: return "manual" - elif cond.condition_type in (StopConditionType.STAGE_BASED, - StopConditionType.HATCHING, - StopConditionType.COMMA_STAGE): + elif cond.condition_type in ( + StopConditionType.STAGE_BASED, + StopConditionType.HATCHING, + StopConditionType.COMMA_STAGE, + ): stages_str = ",".join(sorted(cond.target_stages)) if cond.target_stages else "?" 
return f"stages({stages_str}){confirm_suffix}" elif cond.condition_type == StopConditionType.FIXED_TIMEPOINTS: @@ -214,7 +222,7 @@ def _describe_single(cond: 'StopCondition') -> str: return " OR ".join(descriptions) @classmethod - def until_hatching(cls, confirm_timepoints: int = 0) -> 'StopCondition': + def until_hatching(cls, confirm_timepoints: int = 0) -> "StopCondition": """Stop when hatching is detected (backward-compatible convenience method).""" organism = get_organism() return cls( @@ -224,7 +232,7 @@ def until_hatching(cls, confirm_timepoints: int = 0) -> 'StopCondition': ) @classmethod - def until_comma(cls, confirm_timepoints: int = 0) -> 'StopCondition': + def until_comma(cls, confirm_timepoints: int = 0) -> "StopCondition": """Stop when comma stage is detected (backward-compatible convenience method).""" organism = get_organism() return cls( @@ -234,19 +242,19 @@ def until_comma(cls, confirm_timepoints: int = 0) -> 'StopCondition': ) @classmethod - def fixed_timepoints(cls, n: int) -> 'StopCondition': + def fixed_timepoints(cls, n: int) -> "StopCondition": return cls(StopConditionType.FIXED_TIMEPOINTS, value=n) @classmethod - def duration_hours(cls, hours: float) -> 'StopCondition': + def duration_hours(cls, hours: float) -> "StopCondition": return cls(StopConditionType.DURATION, value=hours) @classmethod - def manual(cls) -> 'StopCondition': + def manual(cls) -> "StopCondition": return cls(StopConditionType.MANUAL) @classmethod - def all_test_hatched(cls, confirm_timepoints: int = 0) -> 'StopCondition': + def all_test_hatched(cls, confirm_timepoints: int = 0) -> "StopCondition": """Stop when EVERY role='test' embryo's ``hatching_status.hatched`` flag is True (set by the dopaminergic detector / Phase 2 path).""" return cls( @@ -255,7 +263,7 @@ def all_test_hatched(cls, confirm_timepoints: int = 0) -> 'StopCondition': ) @classmethod - def composite(cls, *conditions: 'StopCondition') -> 'StopCondition': + def composite(cls, *conditions: "StopCondition") 
-> "StopCondition": """Create a composite stop condition from multiple conditions (OR logic).""" if not conditions: return cls.manual() @@ -265,7 +273,7 @@ def composite(cls, *conditions: 'StopCondition') -> 'StopCondition': return primary @classmethod - def parse(cls, spec: str) -> 'StopCondition': + def parse(cls, spec: str) -> "StopCondition": """ Parse a stop condition specification string. @@ -278,29 +286,30 @@ def parse(cls, spec: str) -> 'StopCondition': spec : str Specification like "hatching", "duration:10", "hatching|duration:10" """ - def _parse_single(s: str) -> 'StopCondition': + + def _parse_single(s: str) -> "StopCondition": s = s.strip().lower() # Check for confirmation timepoints suffix: "hatching+3" or "comma+5" confirm_timepoints = 0 - if '+' in s: - base, confirm_str = s.rsplit('+', 1) + if "+" in s: + base, confirm_str = s.rsplit("+", 1) try: confirm_timepoints = int(confirm_str) s = base except ValueError: pass - if s == 'manual': + if s == "manual": return cls.manual() - elif s in ('all_test_hatched', 'test_hatched'): + elif s in ("all_test_hatched", "test_hatched"): return cls.all_test_hatched(confirm_timepoints=confirm_timepoints) - elif s.startswith('timepoints:'): - n = int(s.split(':')[1]) + elif s.startswith("timepoints:"): + n = int(s.split(":")[1]) return cls.fixed_timepoints(n) - elif s.startswith('duration:'): - hours_str = s.split(':')[1] - if hours_str.endswith('h'): + elif s.startswith("duration:"): + hours_str = s.split(":")[1] + if hours_str.endswith("h"): hours_str = hours_str[:-1] hours = float(hours_str) return cls.duration_hours(hours) @@ -319,7 +328,7 @@ def _parse_single(s: str) -> 'StopCondition': f"{', '.join(organism.STOP_CONDITIONS.keys())}" ) - parts = spec.split('|') + parts = spec.split("|") conditions = [_parse_single(p) for p in parts] return cls.composite(*conditions) @@ -334,6 +343,7 @@ def _parse_single(s: str) -> 'StopCondition': class TimelapseStatus(Enum): """Overall timelapse status""" + IDLE = "idle" 
RUNNING = "running" PAUSED = "paused" @@ -344,41 +354,44 @@ class TimelapseStatus(Enum): @dataclass class TimelapseState: """Current state of the timelapse""" + status: TimelapseStatus - started_at: Optional[datetime] + started_at: datetime | None # Dict of embryo_id -> EmbryoState reference (from agent.experiment.embryos). # Typed Any to avoid importing harness/ from models/ (dependency direction). - embryos: Dict[str, Any] + embryos: dict[str, Any] total_timepoints: int = 0 current_round: int = 0 interval_seconds: float = 120.0 - next_round_time: Optional[datetime] = None - seconds_until_next_round: Optional[float] = None - error_message: Optional[str] = None + next_round_time: datetime | None = None + seconds_until_next_round: float | None = None + error_message: str | None = None - def to_dict(self) -> Dict: + def to_dict(self) -> dict: """Serialize for display""" active = [e for e in self.embryos.values() if not e.is_complete] completed = [e for e in self.embryos.values() if e.is_complete] return { - 'status': self.status.value, - 'started_at': self.started_at.isoformat() if self.started_at else None, - 'duration_minutes': (datetime.now() - self.started_at).total_seconds() / 60 if self.started_at else 0, - 'total_timepoints': self.total_timepoints, - 'current_round': self.current_round, - 'interval_seconds': self.interval_seconds, - 'next_round_time': self.next_round_time.isoformat() if self.next_round_time else None, - 'seconds_until_next_round': self.seconds_until_next_round, - 'active_embryos': len(active), - 'completed_embryos': len(completed), - 'embryo_details': { + "status": self.status.value, + "started_at": self.started_at.isoformat() if self.started_at else None, + "duration_minutes": (datetime.now() - self.started_at).total_seconds() / 60 + if self.started_at + else 0, + "total_timepoints": self.total_timepoints, + "current_round": self.current_round, + "interval_seconds": self.interval_seconds, + "next_round_time": self.next_round_time.isoformat() 
if self.next_round_time else None, + "seconds_until_next_round": self.seconds_until_next_round, + "active_embryos": len(active), + "completed_embryos": len(completed), + "embryo_details": { eid: { - 'timepoints': e.timepoints_acquired, - 'is_complete': e.is_complete, - 'completion_reason': e.completion_reason, + "timepoints": e.timepoints_acquired, + "is_complete": e.is_complete, + "completion_reason": e.completion_reason, } for eid, e in self.embryos.items() }, - 'error': self.error_message, + "error": self.error_message, } diff --git a/gently/app/queue_server_client.py b/gently/app/queue_server_client.py index 82697639..fbe177b3 100644 --- a/gently/app/queue_server_client.py +++ b/gently/app/queue_server_client.py @@ -1,4 +1,9 @@ """Backward-compatibility shim — client moved to gently.hardware.dispim.client.""" -from gently.hardware.dispim.client import DiSPIMMicroscope, QueueServerClient, create_queue_server_client -__all__ = ['DiSPIMMicroscope', 'QueueServerClient', 'create_queue_server_client'] +from gently.hardware.dispim.client import ( + DiSPIMMicroscope, + QueueServerClient, + create_queue_server_client, +) + +__all__ = ["DiSPIMMicroscope", "QueueServerClient", "create_queue_server_client"] diff --git a/gently/app/theme.py b/gently/app/theme.py index f8fc96ce..ddb8d8a6 100644 --- a/gently/app/theme.py +++ b/gently/app/theme.py @@ -6,12 +6,12 @@ """ from dataclasses import dataclass -from typing import Dict @dataclass class Theme: """Theme definition with colors and text indicators.""" + name: str color_mode: str # "dark" or "light" @@ -44,7 +44,7 @@ class Theme: icon_system: str = "System" -THEMES: Dict[str, Theme] = { +THEMES: dict[str, Theme] = { "vibrant": Theme( name="Vibrant", color_mode="dark", @@ -193,6 +193,6 @@ def set_theme(name: str) -> None: raise ValueError(f"Unknown theme: '{name}'. 
Available: {available}") -def list_themes() -> Dict[str, Theme]: +def list_themes() -> dict[str, Theme]: """Get all available themes.""" return THEMES.copy() diff --git a/gently/app/tools/__init__.py b/gently/app/tools/__init__.py index 765b3d33..d8ab70db 100644 --- a/gently/app/tools/__init__.py +++ b/gently/app/tools/__init__.py @@ -6,34 +6,45 @@ """ # Import all tool modules to register their tools -from . import experiment_tools -from . import stage_tools -from . import led_tools -from . import light_source_tools -from . import calibration_tools -from . import acquisition_tools -from . import volume_tools -from . import analysis_tools -from . import data_tools -from . import timelapse_tools -from . import session_tools -from . import focus_tools -from . import interaction_tools -from . import detection_tools -from . import plan_execution_tools -from . import memory_tools -from . import resolution_tools -from gently.harness.plan_mode.tools import lab_context as _lab_context # query_lab_history in run mode - -# Import tool registry utilities -from gently.harness.tools.registry import get_tool_registry, ToolCategory +from gently.harness.plan_mode.tools import ( + lab_context as _lab_context, # noqa: F401 +) +# query_lab_history in run mode # Re-export helper utilities for convenience from gently.harness.tools.helpers import ( - require_agent, get_embryo_or_error, require_microscope, - require_interaction_logger, require_developmental_tracker, - require_timelapse_orchestrator, require_databroker, - get_timestamp_string, format_duration + format_duration, + get_embryo_or_error, + get_timestamp_string, + require_agent, + require_databroker, + require_developmental_tracker, + require_interaction_logger, + require_microscope, + require_timelapse_orchestrator, +) + +# Import tool registry utilities +from gently.harness.tools.registry import ToolCategory, get_tool_registry + +from . 
import ( + acquisition_tools, + analysis_tools, + calibration_tools, + detection_tools, + experiment_tools, + focus_tools, + interaction_tools, + led_tools, + light_source_tools, + memory_tools, + plan_execution_tools, + resolution_tools, + session_tools, + stage_tools, + temperature_tools, + timelapse_tools, + volume_tools, ) @@ -48,3 +59,36 @@ def register_all_tools(): # Auto-register on import _registered = register_all_tools() + +__all__ = [ + # Helper utilities + "format_duration", + "get_embryo_or_error", + "get_timestamp_string", + "require_agent", + "require_databroker", + "require_developmental_tracker", + "require_interaction_logger", + "require_microscope", + "require_timelapse_orchestrator", + # Tool registry + "ToolCategory", + # Tool modules (imported for registration side effects) + "acquisition_tools", + "analysis_tools", + "calibration_tools", + "detection_tools", + "experiment_tools", + "focus_tools", + "interaction_tools", + "led_tools", + "light_source_tools", + "memory_tools", + "plan_execution_tools", + "resolution_tools", + "session_tools", + "stage_tools", + "temperature_tools", + "timelapse_tools", + "volume_tools", +] diff --git a/gently/app/tools/acquisition_tools.py b/gently/app/tools/acquisition_tools.py index 2500fe41..8efcef79 100644 --- a/gently/app/tools/acquisition_tools.py +++ b/gently/app/tools/acquisition_tools.py @@ -4,44 +4,52 @@ Tools for acquiring lightsheet volumes and images from the microscope. 
""" -import logging -from typing import Dict, Optional import asyncio +import logging import numpy as np -logger = logging.getLogger(__name__) +from gently.harness.tools.helpers import ctx_get, get_embryo_or_error +from gently.harness.tools.registry import ToolCategory, ToolExample, tool -from gently.harness.tools.registry import tool, ToolCategory, ToolExample -from gently.harness.tools.helpers import get_embryo_or_error +logger = logging.getLogger(__name__) @tool( name="acquire_volume", - description="""Acquire a single 3D lightsheet volume for a specific embryo. Moves to embryo position and uses its calibration data. -Use when user wants a full 3D stack of an embryo (e.g., "acquire volume of embryo 1", "take a 3D image"). -Embryo must be calibrated first. Default 50 slices at 10ms exposure takes ~2.5 seconds. Turns laser on during acquisition. - -The z_buffer_um parameter can override the calibrated Z range to add more empty space above/below the embryo. -This is useful for segmentation without needing to recalibrate. Set to None to use calibrated range.""", + description="""Acquire a single 3D lightsheet volume for a specific embryo. Moves to embryo +position and uses its calibration data. +Use when user wants a full 3D stack of an embryo (e.g., "acquire volume of embryo 1", +"take a 3D image"). Embryo must be calibrated first. Default 50 slices at 10ms exposure +takes ~2.5 seconds. Turns laser on during acquisition. + +The z_buffer_um parameter can override the calibrated Z range to add more empty space +above/below the embryo. This is useful for segmentation without needing to recalibrate. 
+Set to None to use calibrated range.""", category=ToolCategory.HARDWARE, requires_microscope=True, examples=[ ToolExample("Acquire volume of embryo 1", {"embryo_id": "embryo_1"}), - ToolExample("Take a 3D image of embryo 2 with 80 slices", {"embryo_id": "embryo_2", "num_slices": 80}), - ToolExample("Acquire with more Z padding", {"embryo_id": "embryo_1", "z_buffer_um": 20.0}), + ToolExample( + "Take a 3D image of embryo 2 with 80 slices", + {"embryo_id": "embryo_2", "num_slices": 80}, + ), + ToolExample( + "Acquire with more Z padding", + {"embryo_id": "embryo_1", "z_buffer_um": 20.0}, + ), ], ) async def acquire_volume( embryo_id: str, num_slices: int = 50, exposure_ms: float = 10.0, - z_buffer_um: float = None, - context: Dict = None + z_buffer_um: float | None = None, + context: dict | None = None, ) -> str: """Acquire single volume - moves to embryo first, uses calibration""" - agent = context.get('agent') - client = context.get('client') + agent = ctx_get(context, "agent") + client = ctx_get(context, "client") if not agent: return "Error: No agent context" @@ -53,22 +61,22 @@ async def acquire_volume( try: # Move to embryo position first pos = embryo.stage_position - if pos and pos.get('x') is not None and pos.get('y') is not None: - await client.move_to_position(pos['x'], pos['y']) + if pos and pos.get("x") is not None and pos.get("y") is not None: + await client.move_to_position(pos["x"], pos["y"]) # Get calibration parameters (use defaults if not calibrated) cal = embryo.calibration or {} - galvo_amplitude = cal.get('galvo_amplitude', 0.5) - galvo_center = cal.get('galvo_center', 0.0) - piezo_amplitude = cal.get('piezo_amplitude', 25.0) - piezo_center = cal.get('piezo_center', 50.0) + galvo_amplitude = cal.get("galvo_amplitude", 0.5) + galvo_center = cal.get("galvo_center", 0.0) + piezo_amplitude = cal.get("piezo_amplitude", 25.0) + piezo_center = cal.get("piezo_center", 50.0) # Override Z range if z_buffer_um is specified z_buffer_applied = None if 
z_buffer_um is not None and cal: # Get the original embryo extent from calibration - calibrated_buffer = cal.get('z_buffer_um', 5.0) # Old default was 5µm - slope = cal.get('slope_um_per_deg', 100.0) + calibrated_buffer = cal.get("z_buffer_um", 5.0) # Old default was 5µm + slope = cal.get("slope_um_per_deg", 100.0) # Calculate additional buffer needed additional_buffer_um = z_buffer_um - calibrated_buffer @@ -90,8 +98,8 @@ async def acquire_volume( laser_power_488_pct=embryo.laser_power_488_pct, ) - if result.get('success'): - volume = result.get('volume') + if result.get("success"): + volume = result.get("volume") timepoint = embryo.timepoints_acquired # Current timepoint (0-indexed) # Increment timepoints acquired @@ -105,10 +113,13 @@ async def acquire_volume( if agent.store and agent.session_id: try: from pathlib import Path as _Path + pos = embryo.stage_position or {} agent.store.register_embryo( - agent.session_id, embryo_id, - position_x=pos.get('x'), position_y=pos.get('y'), + agent.session_id, + embryo_id, + position_x=pos.get("x"), + position_y=pos.get("y"), calibration=embryo.calibration, role=embryo.role, ) @@ -126,17 +137,22 @@ async def acquire_volume( "piezo_center": piezo_center, }, } - volume_path_ref = result.get('volume_path') + volume_path_ref = result.get("volume_path") if volume_path_ref is not None: saved_path = agent.store.register_volume( - agent.session_id, embryo_id, timepoint, + agent.session_id, + embryo_id, + timepoint, incoming_path=_Path(volume_path_ref), metadata=acq_metadata, volume_data=volume, ) elif volume is not None: saved_path = agent.store.put_volume( - agent.session_id, embryo_id, timepoint, volume, + agent.session_id, + embryo_id, + timepoint, + volume, metadata=acq_metadata, ) except Exception as store_err: @@ -158,23 +174,29 @@ async def acquire_volume( uid=f"volume_{session_prefix}{embryo_id}_t{timepoint:04d}", data_type="volume_projection", metadata={ - 'embryo_id': embryo_id, - 'timepoint': timepoint, - 'shape': 
list(volume.shape) if hasattr(volume, 'shape') else None, - 'num_slices': num_slices, - 'exposure_ms': exposure_ms, - } + "embryo_id": embryo_id, + "timepoint": timepoint, + "shape": list(volume.shape) if hasattr(volume, "shape") else None, + "num_slices": num_slices, + "exposure_ms": exposure_ms, + }, ) except Exception as viz_err: logger.warning("Failed to push volume to viz: %s", viz_err) # Build response - shape_str = str(result.get('shape', 'unknown')) + shape_str = str(result.get("shape", "unknown")) z_info = f" (z_buffer: {z_buffer_applied}\u00b5m)" if z_buffer_applied else "" if saved_path: - return f"Acquired volume for {embryo.id}{z_info}\nShape: {shape_str}\nSaved: {saved_path}" + return ( + f"Acquired volume for {embryo.id}{z_info}\nShape: {shape_str}" + f"\nSaved: {saved_path}" + ) else: - return f"Acquired volume for {embryo.id}{z_info}\nShape: {shape_str}\n(Volume not saved to disk)" + return ( + f"Acquired volume for {embryo.id}{z_info}\nShape: {shape_str}" + "\n(Volume not saved to disk)" + ) else: return f"Acquisition failed: {result.get('error', 'Unknown error')}" @@ -184,12 +206,15 @@ async def acquire_volume( @tool( name="capture_lightsheet", - description="""Capture a single 2D lightsheet fluorescence image at specified piezo/galvo position. Uses 50ms exposure by default. -Use when user says "take a lightsheet image", "lightsheet snap", or wants to see fluorescence at a specific Z position. -This is a COMPLETE action - do NOT follow up with acquire_volume unless user explicitly asks for a 3D volume. + description="""Capture a single 2D lightsheet fluorescence image at specified piezo/galvo +position. Uses 50ms exposure by default. +Use when user says "take a lightsheet image", "lightsheet snap", or wants to see fluorescence +at a specific Z position. This is a COMPLETE action - do NOT follow up with acquire_volume +unless user explicitly asks for a 3D volume. -IMPORTANT: Always pass embryo_id when capturing for an embryo. 
This ensures the image is captured at the correct -focus position from the embryo's focus_history (set by fine_focus). Without embryo_id, focus may be incorrect. +IMPORTANT: Always pass embryo_id when capturing for an embryo. This ensures the image is +captured at the correct focus position from the embryo's focus_history (set by fine_focus). +Without embryo_id, focus may be incorrect. The piezo position is determined by priority: 1. Explicit piezo_position parameter (if provided) @@ -201,20 +226,26 @@ async def acquire_volume( requires_microscope=True, examples=[ ToolExample("Take a lightsheet image of embryo 1", {"embryo_id": "embryo_1"}), - ToolExample("Lightsheet snap at specific piezo", {"embryo_id": "embryo_1", "piezo_position": 50.0}), - ToolExample("Capture at different galvo", {"embryo_id": "embryo_1", "galvo_position": 0.5}), + ToolExample( + "Lightsheet snap at specific piezo", + {"embryo_id": "embryo_1", "piezo_position": 50.0}, + ), + ToolExample( + "Capture at different galvo", + {"embryo_id": "embryo_1", "galvo_position": 0.5}, + ), ], ) async def capture_lightsheet( - piezo_position: float = None, + piezo_position: float | None = None, galvo_position: float = 0.0, - embryo_id: str = None, + embryo_id: str | None = None, show: bool = True, - context: Dict = None + context: dict | None = None, ) -> str: """Capture and optionally display a single lightsheet image""" - client = context.get('client') - agent = context.get('agent') + client = ctx_get(context, "client") + agent = ctx_get(context, "agent") try: embryo = None @@ -224,8 +255,7 @@ async def capture_lightsheet( if embryo and embryo.stage_position: # Move stage to embryo's position await client.move_to_position( - x=embryo.stage_position['x'], - y=embryo.stage_position['y'] + x=embryo.stage_position["x"], y=embryo.stage_position["y"] ) # Determine piezo position for best focus @@ -234,17 +264,17 @@ async def capture_lightsheet( if piezo_position is None: # Check embryo's focus history first 
(from fine_focus) if embryo and embryo.focus_history: - # Try interpolation if we have 2+ points - fit = embryo.get_piezo_galvo_fit() - if fit is not None: - slope, intercept = fit - piezo_position = slope * galvo_position + intercept - focus_source = "interpolated" - else: - # Single point or exact match - piezo_position = embryo.get_focus_at_galvo(galvo_position) - if piezo_position is not None: - focus_source = "focus_history" + # Try interpolation if we have 2+ points + fit = embryo.get_piezo_galvo_fit() + if fit is not None: + slope, intercept = fit + piezo_position = slope * galvo_position + intercept + focus_source = "interpolated" + else: + # Single point or exact match + piezo_position = embryo.get_focus_at_galvo(galvo_position) + if piezo_position is not None: + focus_source = "focus_history" # Fall back to hardware query (unreliable) if piezo_position is None: @@ -252,13 +282,12 @@ async def capture_lightsheet( focus_source = "hardware_query" result = await client.capture_lightsheet_image( - piezo_position=piezo_position, - galvo_position=galvo_position + piezo_position=piezo_position, galvo_position=galvo_position ) - if result.get('success'): - image = result.get('image') - run_uid = result.get('run_uid', 'unknown') + if result.get("success"): + image = result.get("image") + run_uid = result.get("run_uid", "unknown") # Update embryo's last_imaged and exposure tracking if specified if embryo: @@ -278,21 +307,32 @@ async def capture_lightsheet( # Display the image from datetime import datetime from pathlib import Path + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") save_path = f"lightsheet_captures/lightsheet_{timestamp}.jpg" Path("lightsheet_captures").mkdir(exist_ok=True) - view_result = await client.view_image( + await client.view_image( image=image, title=f"Lightsheet: piezo={piezo_position:.2f}um, galvo={galvo_position}V", save_path=save_path, - show=True + show=True, + ) + return ( + f"\u2713 Captured lightsheet at 
piezo={piezo_position:.2f}\u03bcm," + f" galvo={galvo_position}V{focus_info}\nSaved to: {save_path}" ) - return f"\u2713 Captured lightsheet at piezo={piezo_position:.2f}\u03bcm, galvo={galvo_position}V{focus_info}\nSaved to: {save_path}" elif image is None: - return f"\u2713 Lightsheet captured at piezo={piezo_position:.2f}\u03bcm, galvo={galvo_position}V{focus_info} (image not displayed)\nRun UID: {run_uid}" + return ( + f"\u2713 Lightsheet captured at piezo={piezo_position:.2f}\u03bcm," + f" galvo={galvo_position}V{focus_info} (image not displayed)" + f"\nRun UID: {run_uid}" + ) else: - return f"\u2713 Captured lightsheet at piezo={piezo_position:.2f}\u03bcm, galvo={galvo_position}V{focus_info}" + return ( + f"\u2713 Captured lightsheet at piezo={piezo_position:.2f}\u03bcm," + f" galvo={galvo_position}V{focus_info}" + ) else: return f"Failed: {result.get('error', 'Unknown error')}" @@ -302,10 +342,12 @@ async def capture_lightsheet( @tool( name="batch_lightsheet", - description="""Capture lightsheet images from ALL embryos and display them together in a single napari viewer. -Use when user says "lightsheet all embryos", "capture all embryos", "show me all embryos in lightsheet". -Moves to each embryo, captures a lightsheet image, then opens napari with all images as separate layers. -Much more efficient than capturing one at a time.""", + description="""Capture lightsheet images from ALL embryos and show them together in the web UI. +Use when user says "lightsheet all embryos", "capture all embryos", "show me all embryos in +lightsheet". +Moves to each embryo, captures a lightsheet image, saves it, and pushes it to the +web viewer (live image strip) for everyone watching. 
Much more efficient than +capturing one at a time.""", category=ToolCategory.HARDWARE, requires_microscope=True, examples=[ @@ -313,13 +355,10 @@ async def capture_lightsheet( ToolExample("Capture all embryos", {}), ], ) -async def batch_lightsheet( - galvo_position: float = 0.0, - context: Dict = None -) -> str: - """Capture lightsheet images from all embryos and show in single napari viewer""" - agent = context.get('agent') - client = context.get('client') +async def batch_lightsheet(galvo_position: float = 0.0, context: dict | None = None) -> str: + """Capture lightsheet images from all embryos and show them in the web UI""" + agent = ctx_get(context, "agent") + client = ctx_get(context, "client") if not agent or not client: return "Error: Agent or microscope not available" @@ -333,8 +372,7 @@ async def batch_lightsheet( errors = [] active_embryos = [ - (eid, emb) for eid, emb in agent.experiment.embryos.items() - if not emb.should_skip + (eid, emb) for eid, emb in agent.experiment.embryos.items() if not emb.should_skip ] if not active_embryos: @@ -346,8 +384,8 @@ async def batch_lightsheet( try: # Move to embryo position if embryo.stage_position: - x = embryo.stage_position.get('x', 0) - y = embryo.stage_position.get('y', 0) + x = embryo.stage_position.get("x", 0) + y = embryo.stage_position.get("y", 0) logger.info("Moving to %s at (%.1f, %.1f)...", embryo_id, x, y) await client.move_to_position(x, y) # Wait for stage to settle @@ -359,24 +397,28 @@ async def batch_lightsheet( if embryo.calibration: # Get piezo center - if embryo.calibration.get('piezo_center'): - piezo_position = embryo.calibration['piezo_center'] - elif embryo.calibration.get('focus_position'): - piezo_position = embryo.calibration['focus_position'] + if embryo.calibration.get("piezo_center"): + piezo_position = embryo.calibration["piezo_center"] + elif embryo.calibration.get("focus_position"): + piezo_position = embryo.calibration["focus_position"] # Get galvo center (critical for light 
sheet alignment) - if embryo.calibration.get('galvo_center'): - embryo_galvo = embryo.calibration['galvo_center'] + if embryo.calibration.get("galvo_center"): + embryo_galvo = embryo.calibration["galvo_center"] # Capture lightsheet - logger.info("Capturing %s at piezo=%.1f um, galvo=%.2f...", embryo_id, piezo_position, embryo_galvo) + logger.info( + "Capturing %s at piezo=%.1f um, galvo=%.2f...", + embryo_id, + piezo_position, + embryo_galvo, + ) result = await client.capture_lightsheet_image( - piezo_position=piezo_position, - galvo_position=embryo_galvo + piezo_position=piezo_position, galvo_position=embryo_galvo ) - if result.get('success') and result.get('image') is not None: - images.append(result['image']) + if result.get("success") and result.get("image") is not None: + images.append(result["image"]) embryo_ids.append(embryo_id) # Track light exposure (default 50ms) embryo.record_exposure(exposure_ms=50.0, num_frames=1) @@ -392,43 +434,40 @@ async def batch_lightsheet( # Save images from datetime import datetime from pathlib import Path + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") save_dir = Path("lightsheet_captures") / f"batch_{timestamp}" save_dir.mkdir(parents=True, exist_ok=True) - for i, (img, eid) in enumerate(zip(images, embryo_ids)): + for _i, (img, eid) in enumerate(zip(images, embryo_ids, strict=False)): import tifffile + save_path = save_dir / f"{eid}.tiff" tifffile.imwrite(str(save_path), img) logger.info("Saved %d images to %s", len(images), save_dir) - # Open single napari viewer with all images as a stack - import napari - import numpy as np - logger.info("Opening napari with %d embryo images as stack...", len(images)) - - # Stack images into a single array for slider navigation - image_stack = np.stack(images, axis=0) - - viewer = napari.Viewer(title=f"Batch Lightsheet - {len(images)} embryos") - - # Add as single stack with slider (grayscale) - viewer.add_image( - image_stack, - name='Embryos', - colormap='gray', - ) - - # Print 
embryo ID mapping for reference - logger.info("Slider index -> Embryo ID:") - for i, eid in enumerate(embryo_ids): - logger.info(" %d: %s", i, eid) - - napari.run() + # Push each captured image to the web UI \u2014 no blocking desktop window. + # They appear in the live viewer / recent strip for everyone watching. + pushed = 0 + if agent.viz_server is not None: + for img, eid in zip(images, embryo_ids, strict=False): + uid = f"batch_lightsheet_{eid}_{timestamp}" + agent.push_viz( + img, + uid, + "image", + {"embryo_id": eid, "source": "batch_lightsheet", "label": eid}, + ) + pushed += 1 + logger.info("Pushed %d batch-lightsheet images to the web UI", pushed) # Summary summary = f"\u2713 Captured {len(images)} embryos: {', '.join(embryo_ids)}" + if pushed: + summary += f"\nShowing {pushed} image(s) in the web UI viewer." + elif agent.viz_server is None: + summary += "\n(Web UI not running \u2014 images saved to disk only.)" if errors: summary += f"\n\u26a0 Errors: {'; '.join(errors)}" summary += f"\nSaved to: {save_dir}" diff --git a/gently/app/tools/analysis_tools.py b/gently/app/tools/analysis_tools.py index 75422e77..1e897b61 100644 --- a/gently/app/tools/analysis_tools.py +++ b/gently/app/tools/analysis_tools.py @@ -4,10 +4,8 @@ Tools for analyzing embryo images using Claude Vision. 
""" -from typing import Dict, Optional - -from gently.harness.tools.registry import tool, ToolCategory -from gently.harness.tools.helpers import require_agent, get_embryo_or_error +from gently.harness.tools.helpers import get_embryo_or_error, require_agent +from gently.harness.tools.registry import ToolCategory, ToolExample, tool @tool( @@ -19,8 +17,8 @@ async def analyze_volume( embryo_id: str, analysis_prompt: str, use_recent_context: bool = False, - timepoint: Optional[int] = None, - context: Dict = None + timepoint: int | None = None, + context: dict | None = None, ) -> str: """Analyze embryo volume with Claude Vision""" agent, err = require_agent(context) @@ -36,19 +34,107 @@ async def analyze_volume( embryo_id=embryo.id, prompt=analysis_prompt, use_context=use_recent_context, - timepoint=timepoint + timepoint=timepoint, ) return result except Exception as e: return f"Error analyzing volume: {str(e)}" +@tool( + name="get_recent_perceptions", + description="""Get the latest perception state for one embryo or all embryos: +current developmental stage, how many consecutive observations it has held that stage +(stability), a possible-arrest signal, the recent stage trajectory, and the +perceiver's reasoning. Source: the LIVE perception loop (reads accumulated state, +does not trigger a fresh capture). +Use when the user asks "what stage is embryo X", "is anything stuck/arrested", +"how are the embryos developing", or before deciding whether to adapt acquisition.""", + category=ToolCategory.ANALYSIS, + examples=[ + ToolExample("What stage is embryo_1 at?", {"embryo_id": "embryo_1"}), + ToolExample("How is everything developing?", {}), + ToolExample("Is anything arrested?", {}), + ], +) +def get_recent_perceptions( + embryo_id: str | None = None, + n: int = 5, + context: dict | None = None, +) -> str: + """Read live per-embryo perception state from the perception sessions. 
+ + All reads here (get_session / summary / attribute access) are synchronous and + side-effect-free — they never trigger a VLM call. + """ + agent, err = require_agent(context) + if err: + return err + + perceiver = getattr(agent, "perceiver", None) + if perceiver is None: + return "Perception system not available." + + def _one(eid: str) -> str: + try: + session = perceiver.get_session(eid) + except Exception as e: + return f"{eid}: perception read failed ({e})" + if session is None or not getattr(session, "current_stage", None): + return f"{eid}: no perceptions recorded yet" + summary = session.summary() + lines = [ + f"{eid}: stage={summary.get('current_stage')} " + f"(stable for {summary.get('stability', 0)} obs, " + f"{summary.get('observation_count', 0)} total)" + ] + seq = summary.get("stage_sequence") or [] + if seq: + lines.append(f" trajectory: {' -> '.join(seq)}") + temporal = summary.get("temporal") # TemporalContext dataclass or None + if temporal is not None: + tmin = getattr(temporal, "time_in_stage_min", 0.0) + exp = getattr(temporal, "expected_duration_min", None) + seg = f" time in stage: {tmin:.0f} min" + if exp: + seg += ( + f" (expected ~{exp:.0f} min, {getattr(temporal, 'overtime_ratio', 0.0):.1f}x)" + ) + lines.append(seg) + if getattr(temporal, "is_potentially_arrested", False): + lines.append(" ** potentially ARRESTED **") + observations = getattr(session, "observations", None) or [] + if observations and n > 0: + recent = observations[-n:] + lines.append(f" recent observations (last {len(recent)}):") + for o in recent: + reason = (getattr(o, "reasoning", "") or "").strip().replace("\n", " ") + if len(reason) > 160: + reason = reason[:159] + "…" + lines.append( + f" t{getattr(o, 'timepoint', '?')}: {getattr(o, 'stage', '?')} - {reason}" + ) + return "\n".join(lines) + + if embryo_id: + return _one(embryo_id) + + embryos = getattr(agent.experiment, "embryos", {}) or {} + if not embryos: + return "No embryos in the experiment." 
+ out = ["Perception state (all embryos):", ""] + for eid in sorted(embryos): + out.append(_one(eid)) + out.append("") + return "\n".join(out).rstrip() + + @tool( name="get_detection_summary", description="Get summary of all detections across all embryos", category=ToolCategory.DETECTION, ) -def get_detection_summary(context: Dict) -> str: +def get_detection_summary(context: dict) -> str: """Get detection summary""" agent, err = require_agent(context) if err: @@ -62,7 +148,10 @@ def get_detection_summary(context: Dict) -> str: for det_name, results in embryo.detection_results.items(): if results: latest = results[-1] - lines.append(f" - {det_name}: {latest.get('detected', False)} at t={latest.get('timepoint', '?')}") + lines.append( + f" - {det_name}: {latest.get('detected', False)}" + f" at t={latest.get('timepoint', '?')}" + ) lines.append("") if len(lines) == 2: diff --git a/gently/app/tools/calibration_tools.py b/gently/app/tools/calibration_tools.py index 963dc2d0..3db84e79 100644 --- a/gently/app/tools/calibration_tools.py +++ b/gently/app/tools/calibration_tools.py @@ -6,25 +6,25 @@ """ import logging -from typing import Dict, List, Optional, Tuple from datetime import datetime -import json -import asyncio logger = logging.getLogger(__name__) -import numpy as np +import numpy as np # noqa: E402 -from gently.harness.tools.registry import tool, ToolCategory, ToolExample -from gently.harness.tools.helpers import get_embryo_or_error -from gently.harness.state import CalibrationPrior -from gently.analysis.core import AdaptiveSweepState, FitFunction, fit_focus_curve -from gently.ui.web.plots import ( - generate_focus_curve_plot, +from gently.analysis.core import AdaptiveSweepState, FitFunction, fit_focus_curve # noqa: E402 +from gently.harness.state import CalibrationPrior # noqa: E402 +from gently.harness.tools.helpers import ctx_get, get_embryo_or_error # noqa: E402 +from gently.harness.tools.registry import ToolCategory, ToolExample, tool # noqa: E402 +from 
gently.ui.web.plots import ( # noqa: E402 generate_calibration_summary_plot, - generate_edge_detection_plot, + generate_focus_curve_plot, +) + +from .hardware_common import ( # noqa: E402 + select_best_view, + select_view_and_crop_roi, ) -from .hardware_common import select_best_view, crop_to_embryo_roi, select_view_and_crop_roi async def _adaptive_focus_sweep( @@ -37,7 +37,7 @@ async def _adaptive_focus_sweep( session_prior: CalibrationPrior, select_best_view, calculate_focus_score, -) -> Tuple[Dict, int]: +) -> tuple[dict, int]: """ Adaptive focus sweep with early stopping. @@ -79,7 +79,7 @@ async def _adaptive_focus_sweep( # Adaptive parameters based on prior confidence SPARSE_STEP = 5.0 # µm - larger steps for survey - DENSE_STEP = 0.5 # µm - fine steps for refinement + DENSE_STEP = 0.5 # µm - fine steps for refinement DENSE_RANGE = 3.0 # µm - narrow window around peak MIN_R_SQUARED = 0.75 @@ -90,28 +90,29 @@ async def _adaptive_focus_sweep( sparse_range = 15.0 # µm sweep range logger.info("=== ADAPTIVE %s FOCUS SWEEP at galvo=%.3f deg ===", galvo_name.upper(), galvo_pos) - logger.info("Using range: +/-%.1f um (prior: %d calibrations)", sparse_range, session_prior.num_calibrations) + logger.info( + "Using range: +/-%.1f um (prior: %d calibrations)", + sparse_range, + session_prior.num_calibrations, + ) # --- PHASE 1: SPARSE SURVEY --- logger.info("Phase 1: Sparse survey +/-%.1f um, %.1f um steps...", sparse_range, SPARSE_STEP) sparse_state = AdaptiveSweepState() sparse_positions = np.arange( - expected_piezo - sparse_range, - expected_piezo + sparse_range + SPARSE_STEP, - SPARSE_STEP + expected_piezo - sparse_range, expected_piezo + sparse_range + SPARSE_STEP, SPARSE_STEP ) for piezo in sparse_positions: result = await client.capture_lightsheet_image( - piezo_position=float(piezo), - galvo_position=float(galvo_pos) + piezo_position=float(piezo), galvo_position=float(galvo_pos) ) - if result.get('success'): + if result.get("success"): total_exposures += 1 - if 
result.get('success') and result.get('image') is not None: - img = select_best_view(result['image']) - score = calculate_focus_score(img, algorithm='fft_bandpass') + if result.get("success") and result.get("image") is not None: + img = select_best_view(result["image"]) + score = calculate_focus_score(img, algorithm="fft_bandpass") # Add to adaptive state and check for early stopping decision = sparse_state.add_point(float(piezo), float(score)) @@ -123,43 +124,52 @@ async def _adaptive_focus_sweep( uid=f"focus_sparse_{embryo_id}_{galvo_name}_{piezo:.1f}", data_type="focus_sweep", metadata={ - 'embryo_id': embryo_id, - 'sweep': 'sparse', - 'galvo_name': galvo_name, - 'galvo': float(galvo_pos), - 'piezo': float(piezo), - 'score': float(score), - 'peak_detected': sparse_state.peak_detected, - } + "embryo_id": embryo_id, + "sweep": "sparse", + "galvo_name": galvo_name, + "galvo": float(galvo_pos), + "piezo": float(piezo), + "score": float(score), + "peak_detected": sparse_state.peak_detected, + }, ) - if decision['should_stop']: - logger.info("Early stop: %s (confidence: %.2f)", decision['reason'], decision['confidence']) + if decision["should_stop"]: + logger.info( + "Early stop: %s (confidence: %.2f)", decision["reason"], decision["confidence"] + ) break sparse_best, sparse_r2 = sparse_state.get_best_position() - logger.info("Sparse best: %.1f um (R2=%.3f, %d points)", sparse_best, sparse_r2, len(sparse_state.positions)) + logger.info( + "Sparse best: %.1f um (R2=%.3f, %d points)", + sparse_best, + sparse_r2, + len(sparse_state.positions), + ) # --- PHASE 2: DENSE REFINEMENT --- - logger.info("Phase 2: Dense refinement +/-%.1f um around %.1f um, %.1f um steps...", DENSE_RANGE, sparse_best, DENSE_STEP) + logger.info( + "Phase 2: Dense refinement +/-%.1f um around %.1f um, %.1f um steps...", + DENSE_RANGE, + sparse_best, + DENSE_STEP, + ) dense_state = AdaptiveSweepState() dense_positions = np.arange( - sparse_best - DENSE_RANGE, - sparse_best + DENSE_RANGE + 
DENSE_STEP, - DENSE_STEP + sparse_best - DENSE_RANGE, sparse_best + DENSE_RANGE + DENSE_STEP, DENSE_STEP ) for piezo in dense_positions: result = await client.capture_lightsheet_image( - piezo_position=float(piezo), - galvo_position=float(galvo_pos) + piezo_position=float(piezo), galvo_position=float(galvo_pos) ) - if result.get('success'): + if result.get("success"): total_exposures += 1 - if result.get('success') and result.get('image') is not None: - img = select_best_view(result['image']) - score = calculate_focus_score(img, algorithm='fft_bandpass') + if result.get("success") and result.get("image") is not None: + img = select_best_view(result["image"]) + score = calculate_focus_score(img, algorithm="fft_bandpass") # Add to adaptive state and check for early stopping decision = dense_state.add_point(float(piezo), float(score)) @@ -173,18 +183,20 @@ async def _adaptive_focus_sweep( uid=f"focus_dense_{embryo_id}_{galvo_name}_{piezo:.1f}", data_type="focus_sweep", metadata={ - 'embryo_id': embryo_id, - 'sweep': 'dense', - 'galvo_name': galvo_name, - 'galvo': float(galvo_pos), - 'piezo': float(piezo), - 'score': float(score), - 'r_squared': dense_state.current_r_squared, - } + "embryo_id": embryo_id, + "sweep": "dense", + "galvo_name": galvo_name, + "galvo": float(galvo_pos), + "piezo": float(piezo), + "score": float(score), + "r_squared": dense_state.current_r_squared, + }, ) - if decision['should_stop']: - logger.info("Early stop: %s (confidence: %.2f)", decision['reason'], decision['confidence']) + if decision["should_stop"]: + logger.info( + "Early stop: %s (confidence: %.2f)", decision["reason"], decision["confidence"] + ) break # Get final result @@ -196,9 +208,7 @@ async def _adaptive_focus_sweep( try: positions = np.array(dense_state.positions) scores = np.array(dense_state.scores) - _, _, params, r_squared = fit_focus_curve( - positions, scores, FitFunction.GAUSSIAN.value - ) + _, _, params, r_squared = fit_focus_curve(positions, scores, 
FitFunction.GAUSSIAN.value) if r_squared >= 0.5: best_piezo = float(params[1]) best_piezo = max(min(best_piezo, positions.max()), positions.min()) @@ -214,15 +224,20 @@ async def _adaptive_focus_sweep( fit_quality = "fallback" logger.info("Best focus: piezo=%.2f um (R2=%.3f, %s)", best_piezo, r_squared, fit_quality) - logger.info("Total exposures: %d (sparse: %d, dense: %d)", total_exposures, len(sparse_state.positions), len(dense_state.positions)) + logger.info( + "Total exposures: %d (sparse: %d, dense: %d)", + total_exposures, + len(sparse_state.positions), + len(dense_state.positions), + ) # Build result dict result_dict = { - 'galvo': galvo_pos, - 'piezo': best_piezo, - 'max_score': float(max(dense_state.scores)) if dense_state.scores else 0.0, - 'r_squared': r_squared, - 'fit_params': None, # Will be set by focus curve plot generation + "galvo": galvo_pos, + "piezo": best_piezo, + "max_score": float(max(dense_state.scores)) if dense_state.scores else 0.0, + "r_squared": r_squared, + "fit_params": None, # Will be set by focus curve plot generation } # Push focus curve plot @@ -232,7 +247,7 @@ async def _adaptive_focus_sweep( scores = np.array(dense_state.scores) try: _, _, fit_params, _ = fit_focus_curve(positions, scores, FitFunction.GAUSSIAN.value) - result_dict['fit_params'] = fit_params + result_dict["fit_params"] = fit_params except Exception: fit_params = None @@ -242,21 +257,21 @@ async def _adaptive_focus_sweep( best_position=best_piezo, fit_params=fit_params, r_squared=r_squared, - title=f'{embryo_id} - {galvo_name.upper()} Focus Curve (Adaptive)', + title=f"{embryo_id} - {galvo_name.upper()} Focus Curve (Adaptive)", ) agent.push_viz( array=plot_img, uid=f"focus_curve_{embryo_id}_{galvo_name}", data_type="focus_plot", metadata={ - 'embryo_id': embryo_id, - 'galvo_name': galvo_name, - 'galvo': float(galvo_pos), - 'best_piezo': best_piezo, - 'r_squared': r_squared, - 'adaptive': True, - 'exposures': total_exposures, - } + "embryo_id": embryo_id, + 
"galvo_name": galvo_name, + "galvo": float(galvo_pos), + "best_piezo": best_piezo, + "r_squared": r_squared, + "adaptive": True, + "exposures": total_exposures, + }, ) except Exception as plot_err: logger.warning("Failed to generate focus plot: %s", plot_err) @@ -273,7 +288,7 @@ async def _fine_focus_sweep( expected_piezo: float, select_best_view, calculate_focus_score, -) -> Tuple[Dict, int]: +) -> tuple[dict, int]: """ Fine-only focus sweep - assumes heuristic is close. @@ -305,31 +320,33 @@ async def _fine_focus_sweep( (result_dict, total_exposures) result_dict has 'galvo', 'piezo', 'max_score', 'r_squared', 'fit_params' """ - FINE_RANGE = 5.0 # ±5µm around expected - FINE_STEP = 0.5 # 0.5µm steps + FINE_RANGE = 5.0 # ±5µm around expected + FINE_STEP = 0.5 # 0.5µm steps MIN_R_SQUARED = 0.75 total_exposures = 0 logger.info("=== FINE-ONLY %s FOCUS SWEEP at galvo=%.3f deg ===", galvo_name.upper(), galvo_pos) - logger.info("Sweeping +/-%.1f um around heuristic (%.1f um), %.1f um steps...", FINE_RANGE, expected_piezo, FINE_STEP) + logger.info( + "Sweeping +/-%.1f um around heuristic (%.1f um), %.1f um steps...", + FINE_RANGE, + expected_piezo, + FINE_STEP, + ) positions = np.arange( - expected_piezo - FINE_RANGE, - expected_piezo + FINE_RANGE + FINE_STEP, - FINE_STEP + expected_piezo - FINE_RANGE, expected_piezo + FINE_RANGE + FINE_STEP, FINE_STEP ) piezo_scores = [] for piezo in positions: result = await client.capture_lightsheet_image( - piezo_position=float(piezo), - galvo_position=float(galvo_pos) + piezo_position=float(piezo), galvo_position=float(galvo_pos) ) - if result.get('success'): + if result.get("success"): total_exposures += 1 - if result.get('success') and result.get('image') is not None: - img = select_best_view(result['image']) - score = calculate_focus_score(img, algorithm='fft_bandpass') + if result.get("success") and result.get("image") is not None: + img = select_best_view(result["image"]) + score = calculate_focus_score(img, 
algorithm="fft_bandpass") piezo_scores.append((float(piezo), float(score))) logger.debug("piezo=%.1f: score=%.2e", piezo, score) @@ -341,13 +358,13 @@ async def _fine_focus_sweep( uid=f"focus_fine_{embryo_id}_{galvo_name}_{piezo:.1f}", data_type="focus_sweep", metadata={ - 'embryo_id': embryo_id, - 'sweep': 'fine_only', - 'galvo_name': galvo_name, - 'galvo': float(galvo_pos), - 'piezo': float(piezo), - 'score': float(score), - } + "embryo_id": embryo_id, + "sweep": "fine_only", + "galvo_name": galvo_name, + "galvo": float(galvo_pos), + "piezo": float(piezo), + "score": float(score), + }, ) # Fit Gaussian to find peak @@ -392,11 +409,11 @@ async def _fine_focus_sweep( # Build result dict result_dict = { - 'galvo': galvo_pos, - 'piezo': best_piezo, - 'max_score': max_score, - 'r_squared': r_squared, - 'fit_params': fit_params, + "galvo": galvo_pos, + "piezo": best_piezo, + "max_score": max_score, + "r_squared": r_squared, + "fit_params": fit_params, } # Push focus curve plot @@ -411,21 +428,21 @@ async def _fine_focus_sweep( best_position=best_piezo, fit_params=fit_params, r_squared=r_squared, - title=f'{embryo_id} - {galvo_name.upper()} Focus Curve (Fine-Only)', + title=f"{embryo_id} - {galvo_name.upper()} Focus Curve (Fine-Only)", ) agent.push_viz( array=plot_img, uid=f"focus_curve_{embryo_id}_{galvo_name}", data_type="focus_plot", metadata={ - 'embryo_id': embryo_id, - 'galvo_name': galvo_name, - 'galvo': float(galvo_pos), - 'best_piezo': best_piezo, - 'r_squared': r_squared, - 'fine_only': True, - 'exposures': total_exposures, - } + "embryo_id": embryo_id, + "galvo_name": galvo_name, + "galvo": float(galvo_pos), + "best_piezo": best_piezo, + "r_squared": r_squared, + "fine_only": True, + "exposures": total_exposures, + }, ) except Exception as plot_err: logger.warning("Failed to generate focus plot: %s", plot_err) @@ -437,14 +454,15 @@ async def _fine_focus_sweep( # FAST CALIBRATION ALGORITHM (Vision-Guided) # 
============================================================================ + async def hybrid_focus_selection( - images: List[np.ndarray], - offsets: List[float], + images: list[np.ndarray], + offsets: list[float], claude_vision, agent, embryo_id: str, - fft_confidence_threshold: float = 0.85 -) -> Tuple[int, str, float]: + fft_confidence_threshold: float = 0.85, +) -> tuple[int, str, float | None]: """ Two-stage focus selection: FFT first, Vision if ambiguous. @@ -475,7 +493,7 @@ async def hybrid_focus_selection( from gently.analysis.core import calculate_focus_score # Stage 1: FFT scoring (instant) - scores = [calculate_focus_score(img, 'fft_bandpass') for img in images] + scores = [calculate_focus_score(img, "fft_bandpass") for img in images] max_score = max(scores) best_idx = scores.index(max_score) @@ -484,27 +502,29 @@ async def hybrid_focus_selection( sorted_ratios = sorted(score_ratios, reverse=True) second_best_ratio = sorted_ratios[1] if len(sorted_ratios) > 1 else 0 - confidence_ratio = 1.0 / second_best_ratio if second_best_ratio > 0 else float('inf') + confidence_ratio = 1.0 / second_best_ratio if second_best_ratio > 0 else float("inf") if second_best_ratio < fft_confidence_threshold: # FFT is confident - best is >15% better than second logger.info("FFT confident: position %d (ratio %.2f)", best_idx, confidence_ratio) - return best_idx, 'fft', confidence_ratio + return best_idx, "fft", confidence_ratio # Stage 2: FFT ambiguous - ask Vision logger.info("FFT ambiguous (ratio %.2f), consulting Vision...", confidence_ratio) import tempfile from pathlib import Path + from PIL import Image + from gently.analysis.core import create_focus_montage # Create montage - labels = [chr(ord('A') + i) for i in range(len(images))] + labels = [chr(ord("A") + i) for i in range(len(images))] montage = create_focus_montage(images, labels=labels, offsets=offsets) # Save to temp file - with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: + with 
tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: montage_path = Path(f.name) Image.fromarray(montage).save(montage_path) @@ -522,15 +542,15 @@ async def hybrid_focus_selection( uid=f"focus_montage_{embryo_id}", data_type="focus_montage", metadata={ - 'embryo_id': embryo_id, - 'offsets': offsets, - 'fft_scores': scores, - 'vision_pick': vision_label, - 'reasoning': reasoning, - } + "embryo_id": embryo_id, + "offsets": offsets, + "fft_scores": scores, + "vision_pick": vision_label, + "reasoning": reasoning, + }, ) - return vision_idx, 'vision', None + return vision_idx, "vision", None finally: # Clean up temp file @@ -548,8 +568,8 @@ async def binary_edge_search( embryo_id: str, piezo_heuristic: float, max_range: float = 0.25, - num_iterations: int = 4 -) -> Tuple[float, int]: + num_iterations: int = 4, +) -> tuple[float, int]: """ Binary search for embryo edge in 4 steps. @@ -582,9 +602,10 @@ async def binary_edge_search( """ import tempfile from pathlib import Path + from PIL import Image - sign = -1 if direction == 'top' else 1 + sign = -1 if direction == "top" else 1 low, high = 0.0, max_range * sign last_visible = 0.0 exposures = 0 @@ -600,23 +621,22 @@ async def binary_edge_search( # Capture image result = await client.capture_lightsheet_image( - piezo_position=float(piezo), - galvo_position=float(mid) + piezo_position=float(piezo), galvo_position=float(mid) ) exposures += 1 - if not result.get('success') or result.get('image') is None: + if not result.get("success") or result.get("image") is None: # Assume not visible on failure high = mid if sign > 0 else low low = mid if sign < 0 else high continue - img = select_best_view(result['image']) + img = select_best_view(result["image"]) # Check visibility with Vision img_norm = ((img - img.min()) / (img.max() - img.min() + 1e-8) * 255).astype(np.uint8) - with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: temp_path 
= Path(f.name) Image.fromarray(img_norm).save(temp_path) @@ -628,7 +648,9 @@ async def binary_edge_search( except Exception: pass - logger.debug("iter %d: galvo=%.3f deg, visible=%s, features=%s", i + 1, mid, visible, feature_score) + logger.debug( + "iter %d: galvo=%.3f deg, visible=%s, features=%s", i + 1, mid, visible, feature_score + ) if visible: last_visible = mid @@ -648,10 +670,8 @@ async def binary_edge_search( async def fast_calibrate_embryo( - embryo_id: str, - context: Dict, - z_buffer_um: float = 25.0 -) -> Tuple[bool, str, int]: + embryo_id: str, context: dict, z_buffer_um: float = 25.0 +) -> tuple[bool, str, int]: """ Vision-guided fast calibration using session slope. @@ -676,14 +696,14 @@ async def fast_calibrate_embryo( """ from gently.hardware.dispim.claude_client import AsyncClaudeClient - agent = context.get('agent') - client = context.get('client') + agent = ctx_get(context, "agent") + client = ctx_get(context, "client") if not agent: return False, "Error: No agent context", 0 - if not client or not getattr(client, 'is_connected', False): - return False, f"Error: Not connected to microscope server", 0 + if not client or not getattr(client, "is_connected", False): + return False, "Error: Not connected to microscope server", 0 embryo, err = get_embryo_or_error(agent, embryo_id) if err: @@ -715,12 +735,12 @@ async def fast_calibrate_embryo( piezo_heuristic = session_prior.offset_um if session_prior.num_calibrations > 0 else 0.0 galvo_top, exp_top = await binary_edge_search( - client, claude_vision, 'top', agent, embryo_id, piezo_heuristic + client, claude_vision, "top", agent, embryo_id, piezo_heuristic ) total_exposures += exp_top galvo_bottom, exp_bottom = await binary_edge_search( - client, claude_vision, 'bottom', agent, embryo_id, piezo_heuristic + client, claude_vision, "bottom", agent, embryo_id, piezo_heuristic ) total_exposures += exp_bottom @@ -728,7 +748,11 @@ async def fast_calibrate_embryo( # Validate edges if galvo_top >= 
galvo_bottom: - return False, f"Invalid edges: top {galvo_top:.3f}° >= bottom {galvo_bottom:.3f}°", total_exposures + return ( + False, + f"Invalid edges: top {galvo_top:.3f}° >= bottom {galvo_bottom:.3f}°", + total_exposures, + ) galvo_center = (galvo_top + galvo_bottom) / 2 galvo_extent = galvo_bottom - galvo_top @@ -737,7 +761,11 @@ async def fast_calibrate_embryo( logger.info("Step 2: Focus at center (galvo=%.3f deg)...", galvo_center) # Use session slope or heuristic - slope = session_prior.slope_um_per_deg if session_prior.is_ready_for_fast_calibration() else HEURISTIC_SLOPE + slope = ( + session_prior.slope_um_per_deg + if session_prior.is_ready_for_fast_calibration() + else HEURISTIC_SLOPE + ) piezo_expected = slope * galvo_center + session_prior.offset_um # Capture 3-point focus grid (±2µm) @@ -746,13 +774,12 @@ async def fast_calibrate_embryo( for offset in focus_offsets: result = await client.capture_lightsheet_image( - piezo_position=float(piezo_expected + offset), - galvo_position=float(galvo_center) + piezo_position=float(piezo_expected + offset), galvo_position=float(galvo_center) ) total_exposures += 1 - if result.get('success') and result.get('image') is not None: - img = select_best_view(result['image']) + if result.get("success") and result.get("image") is not None: + img = select_best_view(result["image"]) focus_images.append(img) else: focus_images.append(np.zeros((100, 100), dtype=np.uint8)) @@ -782,12 +809,12 @@ async def fast_calibrate_embryo( for ext_offset in extend_offsets: result = await client.capture_lightsheet_image( piezo_position=float(piezo_expected + ext_offset), - galvo_position=float(galvo_center) + galvo_position=float(galvo_center), ) total_exposures += 1 - if result.get('success') and result.get('image') is not None: - img = select_best_view(result['image']) + if result.get("success") and result.get("image") is not None: + img = select_best_view(result["image"]) focus_images.append(img) focus_offsets.append(ext_offset) @@ 
-814,12 +841,12 @@ async def fast_calibrate_embryo( for offset in [-2.0, 0.0, 2.0]: result = await client.capture_lightsheet_image( piezo_position=float(piezo_expected_second + offset), - galvo_position=float(galvo_second) + galvo_position=float(galvo_second), ) total_exposures += 1 - if result.get('success') and result.get('image') is not None: - img = select_best_view(result['image']) + if result.get("success") and result.get("image") is not None: + img = select_best_view(result["image"]) focus_images_2.append(img) else: focus_images_2.append(np.zeros((100, 100), dtype=np.uint8)) @@ -835,7 +862,9 @@ async def fast_calibrate_embryo( # Lock session slope session_prior.lock_session_slope(calibrated_slope, 0.85, embryo_id) - logger.info("Session slope locked: %.2f um/deg (bootstrap embryo: %s)", calibrated_slope, embryo_id) + logger.info( + "Session slope locked: %.2f um/deg (bootstrap embryo: %s)", calibrated_slope, embryo_id + ) else: # Fast mode - use session slope, calculate offset calibrated_slope = session_prior.slope_um_per_deg @@ -843,18 +872,22 @@ async def fast_calibrate_embryo( # --- STEP 6: Store Calibration --- embryo.calibration = { - 'slope_um_per_deg': calibrated_slope, - 'offset_um': calibrated_offset, - 'galvo_top_deg': galvo_top, - 'galvo_bottom_deg': galvo_bottom, - 'galvo_calib_top_deg': galvo_center, - 'galvo_calib_bottom_deg': galvo_center + 0.3 * galvo_extent if is_bootstrap else galvo_center, - 'piezo_calib_top_um': piezo_center, - 'piezo_calib_bottom_um': piezo_center if not is_bootstrap else piezo_center + calibrated_slope * 0.3 * galvo_extent, - 'r_squared': 0.85, # Assumed for Vision-based selection - 'method': 'fast_vision_guided', - 'bootstrap': is_bootstrap, - 'timestamp': datetime.now().isoformat(), + "slope_um_per_deg": calibrated_slope, + "offset_um": calibrated_offset, + "galvo_top_deg": galvo_top, + "galvo_bottom_deg": galvo_bottom, + "galvo_calib_top_deg": galvo_center, + "galvo_calib_bottom_deg": galvo_center + 0.3 * 
galvo_extent + if is_bootstrap + else galvo_center, + "piezo_calib_top_um": piezo_center, + "piezo_calib_bottom_um": piezo_center + if not is_bootstrap + else piezo_center + calibrated_slope * 0.3 * galvo_extent, + "r_squared": 0.85, # Assumed for Vision-based selection + "method": "fast_vision_guided", + "bootstrap": is_bootstrap, + "timestamp": datetime.now().isoformat(), } # Calculate volume parameters @@ -862,17 +895,17 @@ async def fast_calibrate_embryo( total_range_um = extent_um + 2 * z_buffer_um recommended_slices = max(30, min(150, int(total_range_um / 0.5))) # 0.5µm per slice - embryo.calibration['volume_params'] = { - 'piezo_start_um': calibrated_offset + calibrated_slope * galvo_top - z_buffer_um, - 'piezo_end_um': calibrated_offset + calibrated_slope * galvo_bottom + z_buffer_um, - 'total_range_um': total_range_um, - 'recommended_slices': recommended_slices, + embryo.calibration["volume_params"] = { + "piezo_start_um": calibrated_offset + calibrated_slope * galvo_top - z_buffer_um, + "piezo_end_um": calibrated_offset + calibrated_slope * galvo_bottom + z_buffer_um, + "total_range_um": total_range_um, + "recommended_slices": recommended_slices, } agent._save_state() msg = f"""\u2713 Fast calibration complete for {embryo_id} - Mode: {'BOOTSTRAP' if is_bootstrap else 'FAST'} + Mode: {"BOOTSTRAP" if is_bootstrap else "FAST"} Slope: {calibrated_slope:.2f} \u00b5m/\u00b0 Offset: {calibrated_offset:.2f} \u00b5m Edges: {galvo_top:.3f}\u00b0 to {galvo_bottom:.3f}\u00b0 ({galvo_extent:.3f}\u00b0 extent) @@ -900,7 +933,8 @@ async def fast_calibrate_embryo( Use after detection to prepare an embryo for volume acquisition. Takes ~2-4 minutes per embryo. The z_buffer_um parameter controls how much empty space is captured above and below the embryo. -Default is 25\u00b5m. Increase for more context (useful for segmentation), decrease for faster acquisition.""", +Default is 25\u00b5m. 
Increase for more context (useful for segmentation), decrease for faster +acquisition.""", category=ToolCategory.CALIBRATION, requires_microscope=True, examples=[ @@ -912,15 +946,15 @@ async def fast_calibrate_embryo( async def calibrate_embryo( embryo_id: str, skip_edge_detection: bool = False, - galvo_top: float = None, - galvo_bottom: float = None, + galvo_top: float | None = None, + galvo_bottom: float | None = None, edge_step: float = 0.05, edge_max_range: float = 0.5, edge_tolerance_deg: float = 0.20, inset_fraction: float = 0.4, z_buffer_um: float = 25.0, use_v04_plan: bool = False, - context: Dict = None + context: dict | None = None, ) -> str: """Run piezo-galvo calibration with Claude vision edge detection. @@ -956,35 +990,35 @@ async def calibrate_embryo( not on the agent side. If you need it, wire it through the queue server plan-submission API. Until then, the surgical path IS the v0.4.0 path. """ - import numpy as np import tempfile from pathlib import Path + + import numpy as np from PIL import Image - from gently.analysis.core import calculate_focus_score, fit_focus_curve, FitFunction + + from gently.analysis.core import calculate_focus_score from gently.hardware.dispim.claude_client import AsyncClaudeClient - agent = context.get('agent') - client = context.get('client') + agent = ctx_get(context, "agent") + client = ctx_get(context, "client") if not agent: return "Error: No agent context" - if not client or not getattr(client, 'is_connected', False): + if not client or not getattr(client, "is_connected", False): return f"Error: Not connected to microscope server. Cannot calibrate {embryo_id}." if use_v04_plan: - # Escape hatch hook. Not wired yet - delegating to the real Bluesky - # plan requires a RunEngine and device objects that live on the device - # layer, so the caller would have to submit the plan through the queue - # server. 
Since the surgical path already mirrors v0.4.0's behavior, - # this is a placeholder for a future hardware-regression follow-up. - raise NotImplementedError( - "use_v04_plan=True is not wired yet. The default surgical path " - "in calibrate_embryo already replicates the v0.4.0 calibration " - "plan's behavior (edge detection + inset + wide adaptive sweep). " - "If that path regresses on hardware, wire this branch to submit " - "gently.hardware.dispim.plans.calibration.calibrate_embryo_piezo_galvo " - "through the queue server's plan-submission API." + # Escape hatch reserved for a future hardware-regression follow-up. It is + # intentionally unwired (delegating to the real Bluesky plan needs a + # RunEngine + device objects that live on the device layer). Return a + # clear message instead of raising, so a model that sets this flag gets a + # graceful answer rather than a hard NotImplementedError — the default + # surgical path already mirrors v0.4.0 behavior. + return ( + "use_v04_plan is not available: the default calibration path already " + "replicates the v0.4.0 plan (edge detection + inset + wide adaptive " + "sweep). Re-run calibrate_embryo without use_v04_plan." 
) logger.info("calibration path: surgical (v0.4.0-equivalent inset + adaptive sweep)") @@ -1000,12 +1034,22 @@ async def calibrate_embryo( if session_prior.num_calibrations > 0 and session_prior.r_squared_mean >= 0.75: HEURISTIC_SLOPE = session_prior.slope_um_per_deg HEURISTIC_OFFSET = session_prior.offset_um - logger.info("Using session prior: %.1f um/deg, offset %.1f um", HEURISTIC_SLOPE, HEURISTIC_OFFSET) - logger.info("Prior from %d embryo(s), R2=%.3f", session_prior.num_calibrations, session_prior.r_squared_mean) - elif embryo.calibration and embryo.calibration.get('slope_um_per_deg'): - HEURISTIC_SLOPE = embryo.calibration['slope_um_per_deg'] - HEURISTIC_OFFSET = embryo.calibration.get('offset_um', 0.0) - logger.info("Using previous embryo calibration: %.1f um/deg, offset %.1f um", HEURISTIC_SLOPE, HEURISTIC_OFFSET) + logger.info( + "Using session prior: %.1f um/deg, offset %.1f um", HEURISTIC_SLOPE, HEURISTIC_OFFSET + ) + logger.info( + "Prior from %d embryo(s), R2=%.3f", + session_prior.num_calibrations, + session_prior.r_squared_mean, + ) + elif embryo.calibration and embryo.calibration.get("slope_um_per_deg"): + HEURISTIC_SLOPE = embryo.calibration["slope_um_per_deg"] + HEURISTIC_OFFSET = embryo.calibration.get("offset_um", 0.0) + logger.info( + "Using previous embryo calibration: %.1f um/deg, offset %.1f um", + HEURISTIC_SLOPE, + HEURISTIC_OFFSET, + ) else: HEURISTIC_SLOPE = 100.0 # Default empirical value HEURISTIC_OFFSET = 0.0 @@ -1017,9 +1061,9 @@ async def calibrate_embryo( try: # First move to embryo position pos = embryo.stage_position - if pos and pos.get('x') is not None and pos.get('y') is not None: + if pos and pos.get("x") is not None and pos.get("y") is not None: logger.info("Moving to %s position...", embryo_id) - await client.move_to_position(pos['x'], pos['y']) + await client.move_to_position(pos["x"], pos["y"]) # Initialize Claude client for vision claude_vision = AsyncClaudeClient() @@ -1034,37 +1078,48 @@ async def 
check_embryo_at_position(galvo_pos: float) -> bool: nonlocal total_exposures piezo_pos = HEURISTIC_SLOPE * galvo_pos + HEURISTIC_OFFSET # Track light sheet result = await client.capture_lightsheet_image( - piezo_position=float(piezo_pos), - galvo_position=float(galvo_pos) + piezo_position=float(piezo_pos), galvo_position=float(galvo_pos) ) - if result.get('success'): + if result.get("success"): total_exposures += 1 - if not result.get('success') or result.get('image') is None: + if not result.get("success") or result.get("image") is None: return False - img = result['image'] + img = result["image"] # Select best view from dual-view image img_view = select_best_view(img) # Save to temp file for Claude - with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f: + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: temp_path = Path(f.name) # Normalize and save - img_norm = ((img_view - img_view.min()) / (img_view.max() - img_view.min() + 1e-10) * 255).astype(np.uint8) + img_norm = ( + (img_view - img_view.min()) / (img_view.max() - img_view.min() + 1e-10) * 255 + ).astype(np.uint8) Image.fromarray(img_norm).save(temp_path) try: # Claude now returns (visible, feature_score, description) - visible, feature_score, description = await claude_vision.detect_embryo_presence(temp_path) - logger.debug("galvo=%+.3f deg: %s (features=%s/10) - %s...", galvo_pos, 'VISIBLE' if visible else 'EMPTY', feature_score, description[:40]) + visible, feature_score, description = await claude_vision.detect_embryo_presence( + temp_path + ) + logger.debug( + "galvo=%+.3f deg: %s (features=%s/10) - %s...", + galvo_pos, + "VISIBLE" if visible else "EMPTY", + feature_score, + description[:40], + ) # Record for optimal focus position selection - edge_detection_data.append({ - 'galvo': float(galvo_pos), - 'piezo': float(piezo_pos), - 'visible': visible, - 'feature_score': feature_score, - }) + edge_detection_data.append( + { + "galvo": float(galvo_pos), + "piezo": 
float(piezo_pos), + "visible": visible, + "feature_score": feature_score, + } + ) # Push edge detection image to viz server if agent.viz_server: @@ -1073,12 +1128,12 @@ async def check_embryo_at_position(galvo_pos: float) -> bool: uid=f"edge_detect_{embryo_id}_{galvo_pos:.3f}", data_type="edge_detection", metadata={ - 'embryo_id': embryo_id, - 'galvo': float(galvo_pos), - 'piezo': float(piezo_pos), - 'visible': visible, - 'feature_score': feature_score, - } + "embryo_id": embryo_id, + "galvo": float(galvo_pos), + "piezo": float(piezo_pos), + "visible": visible, + "feature_score": feature_score, + }, ) return visible @@ -1090,14 +1145,18 @@ async def check_embryo_at_position(galvo_pos: float) -> bool: # Use provided galvo positions or defaults detected_top = galvo_top if galvo_top is not None else -0.15 detected_bottom = galvo_bottom if galvo_bottom is not None else 0.15 - logger.info("Skipping edge detection, using galvo range: %.3f deg to %.3f deg", detected_top, detected_bottom) + logger.info( + "Skipping edge detection, using galvo range: %.3f deg to %.3f deg", + detected_top, + detected_bottom, + ) else: logger.info("Phase 1: Detecting embryo Z extent with Claude vision...") # Detect TOP edge (sweep from center toward negative) logger.info("Detecting TOP edge (sweeping galvo toward negative)...") detected_top = 0.0 - for galvo in np.arange(0.0, -edge_max_range - edge_step/2, -edge_step): + for galvo in np.arange(0.0, -edge_max_range - edge_step / 2, -edge_step): visible = await check_embryo_at_position(galvo) if visible: detected_top = galvo @@ -1108,7 +1167,7 @@ async def check_embryo_at_position(galvo_pos: float) -> bool: # Detect BOTTOM edge (sweep from center toward positive) logger.info("Detecting BOTTOM edge (sweeping galvo toward positive)...") detected_bottom = 0.0 - for galvo in np.arange(0.0, edge_max_range + edge_step/2, edge_step): + for galvo in np.arange(0.0, edge_max_range + edge_step / 2, edge_step): visible = await 
check_embryo_at_position(galvo) if visible: detected_bottom = galvo @@ -1116,8 +1175,13 @@ async def check_embryo_at_position(galvo_pos: float) -> bool: logger.info("Embryo disappeared at galvo=%.3f deg", galvo) break - logger.info("Detected embryo extent: top=%.3f deg, bottom=%.3f deg, range=%.3f deg (~%.0f um)", - detected_top, detected_bottom, detected_bottom - detected_top, (detected_bottom - detected_top) * 100) + logger.info( + "Detected embryo extent: top=%.3f deg, bottom=%.3f deg, range=%.3f deg (~%.0f um)", + detected_top, + detected_bottom, + detected_bottom - detected_top, + (detected_bottom - detected_top) * 100, + ) # === PHASE 2: COMPUTE CALIBRATION POSITIONS (v0.4.0 inset formula) === # Extend the detected edges outward by `edge_tolerance_deg` to get the @@ -1135,11 +1199,22 @@ async def check_embryo_at_position(galvo_pos: float) -> bool: calib_bottom = scan_bottom - inset_amount # Store on the embryo so the viz server's "scan top"/"scan bot" dashed # lines match what the calibration actually scanned. 
- detected_range = detected_bottom - detected_top - logger.info("Scan boundaries (edges +/- %.3f tolerance): %.3f to %.3f deg (range %.3f)", - edge_tolerance_deg, scan_top, scan_bottom, scan_range) - logger.info("Calibration positions (%.0f%% inset from scan boundary): top=%.3f deg, bottom=%.3f deg (%.3f deg apart)", - inset_fraction * 100, calib_top, calib_bottom, calib_bottom - calib_top) + detected_bottom - detected_top + logger.info( + "Scan boundaries (edges +/- %.3f tolerance): %.3f to %.3f deg (range %.3f)", + edge_tolerance_deg, + scan_top, + scan_bottom, + scan_range, + ) + logger.info( + "Calibration positions (%.0f%% inset from scan boundary):" + " top=%.3f deg, bottom=%.3f deg (%.3f deg apart)", + inset_fraction * 100, + calib_top, + calib_bottom, + calib_bottom - calib_top, + ) # === PHASE 3: ADAPTIVE FOCUS SWEEPS AT CALIBRATION POSITIONS === logger.info("Phase 3: Adaptive focus sweeps at calibration positions...") @@ -1168,14 +1243,16 @@ async def check_embryo_at_position(galvo_pos: float) -> bool: total_exposures += sweep_exposures # Check for sweep failure - if result_dict['r_squared'] < 0.5: - logger.warning("Low confidence for %s (R2=%.3f)", galvo_name, result_dict['r_squared']) + if result_dict["r_squared"] < 0.5: + logger.warning( + "Low confidence for %s (R2=%.3f)", galvo_name, result_dict["r_squared"] + ) # === PHASE 4: CALCULATE 2-POINT LINEAR CALIBRATION === - g_top = results['top']['galvo'] - p_top = results['top']['piezo'] - g_bottom = results['bottom']['galvo'] - p_bottom = results['bottom']['piezo'] + g_top = results["top"]["galvo"] + p_top = results["top"]["piezo"] + g_bottom = results["bottom"]["galvo"] + p_bottom = results["bottom"]["piezo"] slope = (p_bottom - p_top) / (g_bottom - g_top) offset = p_top - slope * g_top @@ -1200,26 +1277,26 @@ async def check_embryo_at_position(galvo_pos: float) -> bool: # Store calibration embryo.calibration = { - 'slope_um_per_deg': slope, - 'offset_um': offset, - 'galvo_top_deg': detected_top, - 
'galvo_bottom_deg': detected_bottom, - 'galvo_calib_top_deg': g_top, - 'galvo_calib_bottom_deg': g_bottom, - 'piezo_top_um': p_top, - 'piezo_bottom_um': p_bottom, + "slope_um_per_deg": slope, + "offset_um": offset, + "galvo_top_deg": detected_top, + "galvo_bottom_deg": detected_bottom, + "galvo_calib_top_deg": g_top, + "galvo_calib_bottom_deg": g_bottom, + "piezo_top_um": p_top, + "piezo_bottom_um": p_bottom, # Volume acquisition parameters - 'galvo_center': galvo_center, - 'galvo_amplitude': galvo_amplitude, - 'piezo_center': piezo_center, - 'piezo_amplitude': piezo_amplitude, - 'z_buffer_um': z_buffer_um, - 'r_squared_top': results['top']['r_squared'], - 'r_squared_bottom': results['bottom']['r_squared'], + "galvo_center": galvo_center, + "galvo_amplitude": galvo_amplitude, + "piezo_center": piezo_center, + "piezo_amplitude": piezo_amplitude, + "z_buffer_um": z_buffer_um, + "r_squared_top": results["top"]["r_squared"], + "r_squared_bottom": results["bottom"]["r_squared"], } # Update session prior for cross-embryo learning - avg_r_squared = (results['top']['r_squared'] + results['bottom']['r_squared']) / 2 + avg_r_squared = (results["top"]["r_squared"] + results["bottom"]["r_squared"]) / 2 extent_deg = detected_bottom - detected_top session_prior.update_from_calibration( slope=slope, @@ -1227,17 +1304,21 @@ async def check_embryo_at_position(galvo_pos: float) -> bool: r_squared=avg_r_squared, extent_deg=extent_deg, ) - logger.info("Updated session prior (now %d embryo(s), avg R2=%.3f)", session_prior.num_calibrations, session_prior.r_squared_mean) + logger.info( + "Updated session prior (now %d embryo(s), avg R2=%.3f)", + session_prior.num_calibrations, + session_prior.r_squared_mean, + ) # Add to focus history - for name in ['top', 'bottom']: + for name in ["top", "bottom"]: embryo.add_focus_datapoint( - galvo=results[name]['galvo'], - piezo=results[name]['piezo'], - score=results[name]['max_score'], - r_squared=results[name]['r_squared'], - method='calibration', 
- algorithm='fft_bandpass', + galvo=results[name]["galvo"], + piezo=results[name]["piezo"], + score=results[name]["max_score"], + r_squared=results[name]["r_squared"], + method="calibration", + algorithm="fft_bandpass", ) # Record total light exposure from calibration (50ms default exposure) @@ -1255,24 +1336,24 @@ async def check_embryo_at_position(galvo_pos: float) -> bool: piezo_bottom=p_bottom, slope=slope, offset=offset, - r_squared_top=results['top']['r_squared'], - r_squared_bottom=results['bottom']['r_squared'], + r_squared_top=results["top"]["r_squared"], + r_squared_bottom=results["bottom"]["r_squared"], ) agent.push_viz( array=summary_plot, uid=f"calibration_summary_{embryo_id}", data_type="calibration_summary", metadata={ - 'embryo_id': embryo_id, - 'slope': slope, - 'offset': offset, - 'galvo_top': g_top, - 'galvo_bottom': g_bottom, - 'piezo_top': p_top, - 'piezo_bottom': p_bottom, - 'r_squared_top': results['top']['r_squared'], - 'r_squared_bottom': results['bottom']['r_squared'], - } + "embryo_id": embryo_id, + "slope": slope, + "offset": offset, + "galvo_top": g_top, + "galvo_bottom": g_bottom, + "piezo_top": p_top, + "piezo_bottom": p_bottom, + "r_squared_top": results["top"]["r_squared"], + "r_squared_bottom": results["bottom"]["r_squared"], + }, ) except Exception as plot_err: logger.warning("Failed to generate calibration summary plot: %s", plot_err) @@ -1294,6 +1375,7 @@ async def check_embryo_at_position(galvo_pos: float) -> bool: except Exception as e: import traceback + return f"Error calibrating embryo: {str(e)}\n{traceback.format_exc()}" @@ -1314,13 +1396,13 @@ async def check_embryo_at_position(galvo_pos: float) -> bool: ], ) async def calibrate_all_embryos( - embryo_ids: List[str] = None, + embryo_ids: list[str] | None = None, skip_edge_detection: bool = False, z_buffer_um: float = 25.0, - context: Dict = None + context: dict | None = None, ) -> str: """Calibrate all embryos sequentially with Claude vision""" - agent = 
context.get('agent') + agent = ctx_get(context, "agent") if not agent: return "Error: No agent context" @@ -1344,10 +1426,10 @@ async def calibrate_all_embryos( embryo_id=eid, skip_edge_detection=skip_edge_detection, z_buffer_um=z_buffer_um, - context=context + context=context, ) # Get first two lines of result - lines = result.split('\n') + lines = result.split("\n") summary = lines[0] if len(lines) == 1 else f"{lines[0]} {lines[1]}" results.append(f"{eid}: {summary}") @@ -1393,15 +1475,26 @@ def _format_quality(cal: dict) -> str: @tool( name="apply_calibration_to_embryos", - description="""Copy one embryo's calibration onto one or more target embryos. Useful when one embryo has a strong piezo-galvo fit and others can borrow it as-is. - -**Quality metric is R²**, NOT galvo extent. The right "source" is the embryo with the highest ``min(r_squared_top, r_squared_bottom)`` — both ends of the galvo sweep need a clean Gaussian fit for the calibration to hold up at the volume edges. Wider extent just means a bigger embryo; it does not imply better calibration. - -**Auto-pick by quality**: pass ``source_embryo_id="auto"`` to let the tool pick the calibration with the highest min-R² across all currently-calibrated embryos. The response includes the per-embryo R² ranking so the agent can narrate the choice. - -Pass ``target_embryo_ids=None`` (or omit it) to apply to ALL other embryos in the experiment that are not skipped. - -Caveats — calibration is position-dependent: piezo-galvo slope drifts across the XY field, and embryos may sit at slightly different Z depths. The agent should warn the user about this when applying broadly. Best practice is still to calibrate each embryo individually; this tool is for "good enough" propagation or when individual calibration would burn too much light dose.""", + description="""Copy one embryo's calibration onto one or more target embryos. Useful when +one embryo has a strong piezo-galvo fit and others can borrow it as-is. 
+ +**Quality metric is R²**, NOT galvo extent. The right "source" is the embryo with the +highest ``min(r_squared_top, r_squared_bottom)`` — both ends of the galvo sweep need a clean +Gaussian fit for the calibration to hold up at the volume edges. Wider extent just means a +bigger embryo; it does not imply better calibration. + +**Auto-pick by quality**: pass ``source_embryo_id="auto"`` to let the tool pick the +calibration with the highest min-R² across all currently-calibrated embryos. The response +includes the per-embryo R² ranking so the agent can narrate the choice. + +Pass ``target_embryo_ids=None`` (or omit it) to apply to ALL other embryos in the +experiment that are not skipped. + +Caveats — calibration is position-dependent: piezo-galvo slope drifts across the XY field, +and embryos may sit at slightly different Z depths. The agent should warn the user about this +when applying broadly. Best practice is still to calibrate each embryo individually; this +tool is for "good enough" propagation or when individual calibration would burn too much +light dose.""", category=ToolCategory.HARDWARE, requires_microscope=False, examples=[ @@ -1417,9 +1510,9 @@ def _format_quality(cal: dict) -> str: ) def apply_calibration_to_embryos( source_embryo_id: str, - target_embryo_ids: Optional[List[str]] = None, + target_embryo_ids: list[str] | None = None, overwrite_existing: bool = True, - context: Dict = None, + context: dict | None = None, ) -> str: """Broadcast one embryo's calibration to others. @@ -1429,6 +1522,7 @@ def apply_calibration_to_embryos( to narrate the choice (and the user can audit). 
""" from gently.harness.tools.helpers import require_agent + agent, err = require_agent(context) if err: return err @@ -1449,7 +1543,7 @@ def apply_calibration_to_embryos( ranked.sort(key=lambda t: t[1], reverse=True) source_embryo_id = ranked[0][0] ranking_lines.append("Ranking by min(R²_top, R²_bot):") - for eid, score, cal in ranked: + for eid, _score, cal in ranked: mark = " ← chosen" if eid == source_embryo_id else "" ranking_lines.append(f" {eid}: {_format_quality(cal)}{mark}") @@ -1457,14 +1551,12 @@ def apply_calibration_to_embryos( return f"Source embryo '{source_embryo_id}' not found." source = agent.experiment.embryos[source_embryo_id] if not source.calibration: - return ( - f"{source_embryo_id} has no calibration data to copy. " - f"Run calibrate_embryo first." - ) + return f"{source_embryo_id} has no calibration data to copy. Run calibrate_embryo first." if target_embryo_ids is None: target_embryo_ids = [ - eid for eid, e in agent.experiment.embryos.items() + eid + for eid, e in agent.experiment.embryos.items() if eid != source_embryo_id and not e.should_skip ] @@ -1482,6 +1574,7 @@ def apply_calibration_to_embryos( continue # Deep-ish copy so subsequent mutations don't alias the source. import copy + tgt.calibration = copy.deepcopy(source.calibration) applied.append(tid) diff --git a/gently/app/tools/data_tools.py b/gently/app/tools/data_tools.py deleted file mode 100644 index 9435a433..00000000 --- a/gently/app/tools/data_tools.py +++ /dev/null @@ -1,203 +0,0 @@ -""" -Databroker Tools - -Tools for querying and retrieving data from Bluesky/Databroker. 
-""" - -from typing import Dict, List - -from gently.harness.tools.registry import tool, ToolCategory -from gently.harness.tools.helpers import require_agent - - -@tool( - name="list_runs", - description="List recent Bluesky runs from Databroker", - category=ToolCategory.DATA, -) -def list_runs( - limit: int = 10, - embryo_id: str = None, - plan_name: str = None, - context: Dict = None -) -> str: - """List recent runs""" - agent = context.get('agent') - - if not agent or not agent.databroker: - return "Databroker not available" - - try: - db = agent.databroker - - query = {} - if embryo_id: - query['embryo_id'] = embryo_id - if plan_name: - query['plan_name'] = plan_name - - runs = list(db(**query))[:limit] - - if not runs: - return "No runs found" - - lines = [f"Recent runs ({len(runs)}):", ""] - - for run_uid in runs: - run = db[run_uid] - start = run.metadata.get('start', {}) - lines.append(f"* {run_uid[:8]}...") - lines.append(f" Plan: {start.get('plan_name', 'unknown')}") - lines.append(f" Time: {start.get('time', 'unknown')}") - if 'embryo_id' in start: - lines.append(f" Embryo: {start['embryo_id']}") - lines.append("") - - return "\n".join(lines) - - except Exception as e: - return f"Error listing runs: {str(e)}" - - -@tool( - name="get_run_data", - description="Get data from a specific Bluesky run", - category=ToolCategory.DATA, -) -def get_run_data( - run_id: str, - data_keys: List[str] = None, - stream: str = "primary", - context: Dict = None -) -> str: - """Get run data""" - agent = context.get('agent') - - if not agent or not agent.databroker: - return "Databroker not available" - - try: - db = agent.databroker - - if run_id.startswith('-'): - run = db[int(run_id)] - else: - run = db[run_id] - - data = run.primary.read() - - if data_keys: - data = {k: data[k] for k in data_keys if k in data} - - lines = [f"Run: {run.metadata['start']['uid'][:8]}...", ""] - lines.append(f"Available keys: {list(data.keys())}") - - for key, values in data.items(): - shape 
= values.shape if hasattr(values, 'shape') else 'scalar' - lines.append(f" {key}: shape={shape}") - - return "\n".join(lines) - - except Exception as e: - return f"Error getting run data: {str(e)}" - - -@tool( - name="get_run_image", - description="Get an image from a Bluesky run for analysis", - category=ToolCategory.DATA, -) -async def get_run_image( - run_id: str, - detector: str = None, - analyze: bool = False, - analysis_prompt: str = None, - context: Dict = None -) -> str: - """Get run image""" - agent = context.get('agent') - - if not agent or not agent.databroker: - return "Databroker not available" - - try: - db = agent.databroker - - if run_id.startswith('-'): - run = db[int(run_id)] - else: - run = db[run_id] - - data = run.primary.read() - - if not detector: - for key in ['bottom_camera', 'camera', 'detector']: - if key in data: - detector = key - break - - if detector not in data: - return f"Detector '{detector}' not found. Available: {list(data.keys())}" - - image = data[detector] - shape = image.shape if hasattr(image, 'shape') else 'unknown' - - result = f"Retrieved image from {detector}\nShape: {shape}" - - if analyze and analysis_prompt: - analysis = await agent._analyze_image_with_vision( - image=image, - prompt=analysis_prompt - ) - result += f"\n\nAnalysis:\n{analysis}" - - return result - - except Exception as e: - return f"Error getting image: {str(e)}" - - -@tool( - name="search_runs", - description="Search Databroker runs by metadata criteria", - category=ToolCategory.DATA, -) -def search_runs( - since: str = None, - until: str = None, - metadata: Dict = None, - limit: int = 20, - context: Dict = None -) -> str: - """Search runs""" - agent = context.get('agent') - - if not agent or not agent.databroker: - return "Databroker not available" - - try: - db = agent.databroker - - query = metadata or {} - - if since: - query['since'] = since - if until: - query['until'] = until - - runs = list(db(**query))[:limit] - - if not runs: - return "No 
matching runs found" - - lines = [f"Found {len(runs)} runs:", ""] - - for run_uid in runs: - run = db[run_uid] - start = run.metadata.get('start', {}) - lines.append(f"* {run_uid[:8]}: {start.get('plan_name', 'unknown')}") - - return "\n".join(lines) - - except Exception as e: - return f"Error searching runs: {str(e)}" diff --git a/gently/app/tools/detection_tools.py b/gently/app/tools/detection_tools.py index eda06d81..fd7d826b 100644 --- a/gently/app/tools/detection_tools.py +++ b/gently/app/tools/detection_tools.py @@ -9,32 +9,31 @@ """ import uuid -from typing import Dict, List, Optional, Tuple from datetime import datetime from pathlib import Path import numpy as np -from gently.harness.tools.registry import tool, ToolCategory, ToolExample -from gently.harness.tools.helpers import require_agent from gently.core.coordinates import ( + DEFAULT_OBJECTIVE_MAG, + DEFAULT_PIXEL_SIZE_UM, + get_um_per_pixel, pixel_to_stage_position, stage_to_pixel_position, - get_um_per_pixel, - DEFAULT_PIXEL_SIZE_UM, - DEFAULT_OBJECTIVE_MAG, ) +from gently.harness.tools.helpers import ctx_get +from gently.harness.tools.registry import ToolCategory, ToolExample, tool async def _route_to_map_view( agent, image: np.ndarray, - initial_markers: List[Dict], - stage_position: Tuple[float, float], + initial_markers: list[dict], + stage_position: tuple[float, float], default_role: str = "test", pixel_size_um: float = DEFAULT_PIXEL_SIZE_UM, - timeout: Optional[float] = None, -) -> Tuple[Optional[List[Dict]], Optional[str]]: + timeout: float | None = None, +) -> tuple[list[dict] | None, str | None]: """Hand off image + markers to the web map view; await user-edited result. Returns ``(marked, None)`` on success or ``(None, error_message)`` if the @@ -42,8 +41,7 @@ async def _route_to_map_view( """ if getattr(agent, "viz_server", None) is None: return None, ( - "Map view requires the web visualization server. " - "Start it with start_viz_server first." 
+ "Map view requires the web visualization server. Start it with start_viz_server first." ) from gently.ui.web.embryo_marker import mark_embryos_web @@ -51,7 +49,7 @@ async def _route_to_map_view( marked = await mark_embryos_web( viz_server=agent.viz_server, image=image, - initial_stage_position=tuple(stage_position), + initial_stage_position=stage_position, pixel_size_um=pixel_size_um, initial_markers=initial_markers, default_role=default_role, @@ -73,28 +71,38 @@ def _next_embryo_number(experiment) -> int: def _stage_from_pixel( - pixel_x: float, pixel_y: float, - image_shape: Tuple[int, int], - current_stage: Tuple[float, float], + pixel_x: float, + pixel_y: float, + image_shape: tuple[int, int], + current_stage: tuple[float, float], pixel_size_um: float = DEFAULT_PIXEL_SIZE_UM, objective_mag: float = DEFAULT_OBJECTIVE_MAG, -) -> Tuple[float, float]: +) -> tuple[float, float]: """Convert a pixel position in the bottom-cam image to a stage XY.""" h, w = image_shape[:2] um_per_px = get_um_per_pixel(pixel_size_um, objective_mag) return pixel_to_stage_position( - pixel_x, pixel_y, - w / 2, h / 2, - current_stage[0], current_stage[1], + pixel_x, + pixel_y, + w / 2, + h / 2, + current_stage[0], + current_stage[1], um_per_px, ) @tool( name="detect_embryos", - description="""Automatically detect embryos in the current field of view using brightness detection + SAM segmentation, then hand off to the web map view for editing and role assignment. - -Use when user says "find embryos", "detect embryos", or at the start of an experiment to locate samples. Captures a bottom camera image, runs SAM detection, and opens the web map view with SAM markers pre-placed. User adds/removes markers, cycles each marker's role (Test / Calibration / unassigned), and presses Done. 
The confirmed embryos are registered with their roles in the experiment.""", + description="""Automatically detect embryos in the current field of view using brightness +detection + SAM segmentation, then hand off to the web map view for editing and role +assignment. + +Use when user says "find embryos", "detect embryos", or at the start of an experiment to +locate samples. Captures a bottom camera image, runs SAM detection, and opens the web map +view with SAM markers pre-placed. User adds/removes markers, cycles each marker's role +(Test / Calibration / unassigned), and presses Done. The confirmed embryos are registered +with their roles in the experiment.""", category=ToolCategory.DETECTION, requires_microscope=True, examples=[ @@ -106,23 +114,26 @@ async def detect_embryos( auto_calibrate: bool = False, min_confidence: float = 0.7, use_claude_review: bool = False, - exposure_ms: float = None, + exposure_ms: float | None = None, brightness_percentile: float = 99.0, min_area: int = 5000, max_area: int = 150000, default_role: str = "test", - context: Dict = None, + context: dict | None = None, ) -> str: """Detect embryos via SAM + edit/assign roles in the web map view.""" - agent = context.get('agent') - client = context.get('client') + agent = ctx_get(context, "agent") + client = ctx_get(context, "client") if not agent: return "Error: No agent context" if not client: return "Error: Microscope not connected. Cannot detect embryos in offline mode." if not client.has_sam: - return "Error: SAM server not connected. Embryo detection requires the SAM segmentation server." + return ( + "Error: SAM server not connected." + " Embryo detection requires the SAM segmentation server." 
+ ) try: result = await client.detect_embryos( @@ -134,15 +145,18 @@ async def detect_embryos( max_area=max_area, ) - if not result.get('success'): + if not result.get("success"): return f"Detection failed: {result.get('error', 'Unknown error')}" - sam_embryos = result.get('embryos', []) - image = result.get('image') - stage_pos = tuple(result.get('stage_position', [0.0, 0.0])) + sam_embryos = result.get("embryos", []) + image = result.get("image") + stage_pos = tuple(result.get("stage_position", [0.0, 0.0])) if image is None: - return f"Detection ran but no image was returned for the map view (got {len(sam_embryos)} SAM detections)." + return ( + f"Detection ran but no image was returned for the map view" + f" (got {len(sam_embryos)} SAM detections)." + ) # Hand off SAM detections as editable initial markers in the map view. initial_markers = [ @@ -184,8 +198,10 @@ async def detect_embryos( next_num += 1 stage_x, stage_y = _stage_from_pixel( - m["pixel_x"], m["pixel_y"], - image.shape, stage_pos, + m["pixel_x"], + m["pixel_y"], + image.shape, + stage_pos, ) position = {"x": stage_x, "y": stage_y} @@ -204,26 +220,55 @@ async def detect_embryos( ) added.append((emb_id, m.get("role", default_role))) - role_counts = {} + # OPERATOR_MARKED_EMBRYOS — operator confirmed via the web canvas. + # This is the intent signal eval/shadow listeners hook for ReactiveCandidate. 
+ if added: + bus = getattr(agent, "_event_bus", None) + if bus is not None: + from gently.core.event_bus import EventType + + try: + bus.publish( + event_type=EventType.OPERATOR_MARKED_EMBRYOS, + data={ + "embryo_ids": [eid for eid, _ in added], + "count": len(added), + "stage_origin": list(stage_pos), + "pre_edit_count": len(sam_embryos), + }, + source="detect_embryos:web-editor", + ) + except Exception: + pass + + role_counts: dict[str, int] = {} for _, r in added: role_counts[r] = role_counts.get(r, 0) + 1 role_summary = ", ".join(f"{n} {r}" for r, n in sorted(role_counts.items())) if auto_calibrate and added: - return f"Detected & registered {len(added)} embryos ({role_summary}). Starting calibration..." + return ( + f"Detected & registered {len(added)} embryos ({role_summary})." + " Starting calibration..." + ) return f"Detection complete: {len(added)} embryo(s) ({role_summary})." except Exception as e: import traceback + traceback.print_exc() return f"Error detecting embryos: {str(e)}" @tool( name="manual_mark_embryos", - description="""Capture a bottom-camera image and open the web map view for manual embryo marking. User clicks to add markers and cycles each marker's role (Test / Calibration / unassigned). + description="""Capture a bottom-camera image and open the web map view for manual embryo +marking. User clicks to add markers and cycles each marker's role (Test / Calibration / +unassigned). -Use when automatic detection missed embryos, or when the user wants to add embryos manually (e.g., "let me mark embryos", "I'll click on them"). Newly marked embryos are registered with the role assigned in the map view; existing embryos remain untouched.""", +Use when automatic detection missed embryos, or when the user wants to add embryos manually +(e.g., "let me mark embryos", "I'll click on them"). 
Newly marked embryos are registered with +the role assigned in the map view; existing embryos remain untouched.""", category=ToolCategory.DETECTION, requires_microscope=True, examples=[ @@ -232,13 +277,13 @@ async def detect_embryos( ], ) async def manual_mark_embryos( - exposure_ms: float = None, + exposure_ms: float | None = None, default_role: str = "test", - context: Dict = None, + context: dict | None = None, ) -> str: """Manual marking via the web map view.""" - agent = context.get('agent') - client = context.get('client') + agent = ctx_get(context, "agent") + client = ctx_get(context, "client") if not agent: return "Error: No agent context" @@ -265,19 +310,24 @@ async def manual_mark_embryos( pos = emb.stage_position or {} sx, sy = pos.get("x", 0), pos.get("y", 0) px, py = stage_to_pixel_position( - stage_x=sx, stage_y=sy, - current_stage_x=stage_pos[0], current_stage_y=stage_pos[1], - image_center_x=w / 2, image_center_y=h / 2, + stage_x=sx, + stage_y=sy, + current_stage_x=stage_pos[0], + current_stage_y=stage_pos[1], + image_center_x=w / 2, + image_center_y=h / 2, um_per_pixel=um_per_px, ) - initial_markers.append({ - "pixel_x": px, - "pixel_y": py, - "role": emb.role, - "source": "existing", - "embryo_id": embryo_id, - "confidence": emb.detection_confidence, - }) + initial_markers.append( + { + "pixel_x": px, + "pixel_y": py, + "role": emb.role, + "source": "existing", + "embryo_id": embryo_id, + "confidence": emb.detection_confidence, + } + ) marked, err = await _route_to_map_view( agent=agent, @@ -301,7 +351,10 @@ async def manual_mark_embryos( added_ids, updated_ids = [], [] for m in marked: stage_x, stage_y = _stage_from_pixel( - m["pixel_x"], m["pixel_y"], image.shape, stage_pos, + m["pixel_x"], + m["pixel_y"], + image.shape, + stage_pos, ) pos = {"x": stage_x, "y": stage_y} existing_id = m.get("embryo_id") @@ -325,10 +378,7 @@ async def manual_mark_embryos( # Removals: any existing embryo NOT mentioned in marked is dropped. 
seen_ids = {m.get("embryo_id") for m in marked if m.get("embryo_id")} - removed_ids = [ - eid for eid in existing_ids - if eid not in seen_ids and eid not in added_ids - ] + removed_ids = [eid for eid in existing_ids if eid not in seen_ids and eid not in added_ids] for rid in removed_ids: agent.experiment.embryos.pop(rid, None) @@ -343,15 +393,20 @@ async def manual_mark_embryos( except Exception as e: import traceback + traceback.print_exc() return f"Error: {str(e)}" @tool( name="edit_embryos", - description="""Capture a fresh bottom-camera image and open the web map view to edit current embryos — add, remove, move, or re-label Test/Calibration. Same surface as manual_mark_embryos; this tool exists to match user intent when they say "edit" rather than "mark". + description="""Capture a fresh bottom-camera image and open the web map view to edit +current embryos — add, remove, move, or re-label Test/Calibration. Same surface as +manual_mark_embryos; this tool exists to match user intent when they say "edit" rather than +"mark". -Use when user wants to adjust existing detection results (e.g., "edit embryos", "remove embryo_3", "swap roles", "fix detection").""", +Use when user wants to adjust existing detection results (e.g., "edit embryos", +"remove embryo_3", "swap roles", "fix detection").""", category=ToolCategory.DETECTION, requires_microscope=True, examples=[ @@ -361,19 +416,20 @@ async def manual_mark_embryos( ], ) async def edit_embryos( - exposure_ms: float = None, + exposure_ms: float | None = None, default_role: str = "test", - context: Dict = None, + context: dict | None = None, ) -> str: """Edit existing embryos via the web map view.""" - agent = context.get('agent') + agent = ctx_get(context, "agent") if not agent: return "Error: No agent context" if not agent.experiment.embryos: return "No embryos to edit. Run detect_embryos or manual_mark_embryos first." 
# Same flow as manual_mark_embryos: pre-populate with existing markers, - # let user edit, reconcile. + # let user edit, reconcile. notify_embryos_changed is fired by + # manual_mark_embryos / add_embryo internally. return await manual_mark_embryos( exposure_ms=exposure_ms, default_role=default_role, @@ -383,9 +439,11 @@ async def edit_embryos( @tool( name="show_detected_embryos", - description="""Capture a fresh image and display all tracked embryos with labeled bounding boxes. Shows embryo IDs at their positions. -Use when user wants to see where embryos are visually (e.g., "show me the embryos", "display embryo positions"). -Captures a new bottom camera image and overlays all active (non-skipped) embryo positions. Image is saved to detection_results/.""", + description="""Capture a fresh image and display all tracked embryos with labeled bounding +boxes. Shows embryo IDs at their positions. +Use when user wants to see where embryos are visually (e.g., "show me the embryos", +"display embryo positions"). Captures a new bottom camera image and overlays all active +(non-skipped) embryo positions. 
Image is saved to detection_results/.""", category=ToolCategory.DETECTION, requires_microscope=True, examples=[ @@ -393,13 +451,10 @@ async def edit_embryos( ToolExample("Display embryo positions", {}), ], ) -async def show_detected_embryos( - save_to_file: bool = True, - context: Dict = None -) -> str: +async def show_detected_embryos(save_to_file: bool = True, context: dict | None = None) -> str: """Show detected embryos visualization using experiment.embryos as source of truth""" - agent = context.get('agent') - client = context.get('client') + agent = ctx_get(context, "agent") + client = ctx_get(context, "client") if not agent: return "Error: No agent context" @@ -412,21 +467,21 @@ async def show_detected_embryos( try: snap = await client.capture_bottom_image() - image = snap['image'] + image = snap["image"] if image is None or image.shape == (100, 100): return "Failed to capture image for visualization." current_stage = await client.get_stage_position() # Archive the bottom camera image with metadata - if snap.get('image_path') and agent.store and agent.session_id: + if snap.get("image_path") and agent.store and agent.session_id: try: from gently.harness.tools.helpers import build_snapshot_metadata - meta = build_snapshot_metadata( - current_stage, image.shape, agent.experiment) + + meta = build_snapshot_metadata(current_stage, image.shape, agent.experiment) agent.store.register_snapshot( - agent.session_id, "bottom_camera", snap['image_path'], - metadata=meta) + agent.session_id, "bottom_camera", snap["image_path"], metadata=meta + ) except Exception: pass @@ -440,8 +495,8 @@ async def show_detected_embryos( for embryo_id, embryo_state in agent.experiment.embryos.items(): pos = embryo_state.stage_position or {} - stage_x = pos.get('x', current_stage[0]) - stage_y = pos.get('y', current_stage[1]) + stage_x = pos.get("x", current_stage[0]) + stage_y = pos.get("y", current_stage[1]) # Convert stage to pixel using centralized function pixel_x, pixel_y = 
stage_to_pixel_position( @@ -451,17 +506,19 @@ async def show_detected_embryos( current_stage_y=current_stage[1], image_center_x=image_center_x, image_center_y=image_center_y, - um_per_pixel=um_per_pixel + um_per_pixel=um_per_pixel, ) - embryos.append({ - 'embryo_id': embryo_id, - 'pixel_x': pixel_x, - 'pixel_y': pixel_y, - 'stage_x_um': stage_x, - 'stage_y_um': stage_y, - 'confidence': embryo_state.detection_confidence, - }) + embryos.append( + { + "embryo_id": embryo_id, + "pixel_x": pixel_x, + "pixel_y": pixel_y, + "stage_x_um": stage_x, + "stage_y_um": stage_y, + "confidence": embryo_state.detection_confidence, + } + ) if not embryos: return "No embryos to display." @@ -475,13 +532,13 @@ async def show_detected_embryos( embryos=embryos, title=f"Embryos ({len(embryos)})", save_path=save_path, - show=True + show=True, ) - if view_result.get('success'): - embryo_ids = [e.get('embryo_id', '?') for e in embryos] + if view_result.get("success"): + embryo_ids = [e.get("embryo_id", "?") for e in embryos] return f"Showing {len(embryos)} embryos: {', '.join(embryo_ids)}\nSaved to: {save_path}" - elif view_result.get('error'): + elif view_result.get("error"): return f"Display error: {view_result.get('error')}" else: return f"Visualization complete. Check {save_path}" diff --git a/gently/app/tools/experiment_tools.py b/gently/app/tools/experiment_tools.py index e0da0b8c..2b006348 100644 --- a/gently/app/tools/experiment_tools.py +++ b/gently/app/tools/experiment_tools.py @@ -4,25 +4,25 @@ Tools for managing experiments and tracking embryo states. 
""" -from typing import Dict -from datetime import datetime import json +from datetime import datetime -from gently.harness.tools.registry import tool, ToolCategory, ToolExample -from gently.harness.tools.helpers import require_agent, get_embryo_or_error +from gently.harness.tools.helpers import get_embryo_or_error, require_agent +from gently.harness.tools.registry import ToolCategory, ToolExample, tool @tool( name="get_current_time", - description="""Get the current date and time. Use this when you need to know what time it is now, -for example when the user says "image until 4pm" or "run for the next 2 hours" and you need to calculate durations.""", + description="""Get the current date and time. Use this when you need to know what time it is +now, for example when the user says "image until 4pm" or "run for the next 2 hours" and you need +to calculate durations.""", category=ToolCategory.UTILITY, examples=[ ToolExample("What time is it?", {}), ToolExample("Image from now to 4pm", {}), ], ) -def get_current_time(context: Dict) -> str: +def get_current_time(context: dict) -> str: """Get current date and time""" now = datetime.now() return f"Current time: {now.strftime('%Y-%m-%d %H:%M:%S')} ({now.strftime('%I:%M %p')})" @@ -30,10 +30,13 @@ def get_current_time(context: Dict) -> str: @tool( name="get_experiment_summary", - description="""Get a comprehensive summary of the current experiment including all embryos, their XY stage positions, calibration status, and imaging history. -Use this tool when the user asks about embryo locations, experiment status, how many embryos exist, or wants an overview. -This is the primary tool for answering questions like "where are the embryos?" or "what's the current status?" -Returns all embryo IDs with their coordinates - no parameters needed.""", + description="""Get a comprehensive summary of the current experiment including all embryos, +their XY stage positions, calibration status, and imaging history. 
+Use this tool when the user asks about embryo locations, experiment status, how many embryos +exist, or wants an overview. +This is the primary tool for answering questions like "where are the embryos?" or +"what's the current status?" Returns all embryo IDs with their coordinates - no parameters +needed.""", category=ToolCategory.EXPERIMENT, examples=[ ToolExample("Where are all the embryos?", {}), @@ -41,7 +44,7 @@ def get_current_time(context: Dict) -> str: ToolExample("How many embryos do we have?", {}), ], ) -def get_experiment_summary(context: Dict) -> str: +def get_experiment_summary(context: dict) -> str: """Get full experiment summary""" agent, err = require_agent(context) if err: @@ -51,17 +54,19 @@ def get_experiment_summary(context: Dict) -> str: @tool( name="query_embryo_status", - description="""Query detailed status of a specific embryo including position, calibration data, imaging history, and detection results. -Use this when the user asks about a specific embryo by ID or number (e.g., "how is embryo 3?", "check embryo_1"). -Returns JSON with stage_position, piezo_center, galvo_center, timepoints_acquired, and detection_results. -The embryo_id can be like "embryo_1", "embryo_3", etc.""", + description="""Query detailed status of a specific embryo including position, calibration data, +imaging history, and detection results. +Use this when the user asks about a specific embryo by ID or number (e.g., "how is embryo 3?", +"check embryo_1"). Returns JSON with stage_position, piezo_center, galvo_center, +timepoints_acquired, and detection_results. 
The embryo_id can be like "embryo_1", +"embryo_3", etc.""", category=ToolCategory.EMBRYO, examples=[ ToolExample("What's happening with embryo 1?", {"embryo_id": "embryo_1"}), ToolExample("Check on embryo 3", {"embryo_id": "embryo_3"}), ], ) -def query_embryo_status(embryo_id: str, context: Dict) -> str: +def query_embryo_status(embryo_id: str, context: dict) -> str: """Query embryo status""" agent, err = require_agent(context) if err: @@ -76,15 +81,20 @@ def query_embryo_status(embryo_id: str, context: Dict) -> str: @tool( name="skip_embryo", - description="""Mark an embryo to be skipped in future timelapse acquisitions. The embryo remains in the experiment but won't be imaged. -Use when user wants to temporarily stop imaging an embryo (e.g., "skip embryo 2", "stop imaging embryo_3"). -Requires a reason to document why the embryo is being skipped. Can be resumed later with resume_embryo.""", + description="""Mark an embryo to be skipped in future timelapse acquisitions. The embryo +remains in the experiment but won't be imaged. +Use when user wants to temporarily stop imaging an embryo (e.g., "skip embryo 2", +"stop imaging embryo_3"). Requires a reason to document why the embryo is being skipped. +Can be resumed later with resume_embryo.""", category=ToolCategory.EMBRYO, examples=[ - ToolExample("Skip embryo 2, it's dead", {"embryo_id": "embryo_2", "reason": "embryo dead"}), + ToolExample( + "Skip embryo 2, it's dead", + {"embryo_id": "embryo_2", "reason": "embryo dead"}, + ), ], ) -def skip_embryo(embryo_id: str, reason: str, context: Dict) -> str: +def skip_embryo(embryo_id: str, reason: str, context: dict) -> str: """Skip embryo in future acquisitions""" agent, err = require_agent(context) if err: @@ -102,7 +112,8 @@ def skip_embryo(embryo_id: str, reason: str, context: Dict) -> str: @tool( name="remove_embryo", - description="""Permanently remove an embryo from the experiment. This is irreversible - use for false detections or debris. 
+ description="""Permanently remove an embryo from the experiment. This is irreversible +- use for false detections or debris. Use when user says "remove embryo X", "delete embryo X", or "that's not an embryo". Unlike skip_embryo, this completely removes the embryo from tracking. Use carefully.""", category=ToolCategory.EMBRYO, @@ -110,7 +121,7 @@ def skip_embryo(embryo_id: str, reason: str, context: Dict) -> str: ToolExample("Remove embryo 4, it's a false positive", {"embryo_id": "embryo_4"}), ], ) -def remove_embryo(embryo_id: str, context: Dict) -> str: +def remove_embryo(embryo_id: str, context: dict) -> str: """Remove embryo from experiment completely""" agent, err = require_agent(context) if err: @@ -129,14 +140,16 @@ def remove_embryo(embryo_id: str, context: Dict) -> str: @tool( name="resume_embryo", - description="""Resume imaging a previously skipped embryo. Clears the skip flag so the embryo will be included in future acquisitions. -Use when user wants to start imaging an embryo again after it was skipped (e.g., "resume embryo 2", "start imaging embryo_3 again").""", + description="""Resume imaging a previously skipped embryo. Clears the skip flag so the embryo +will be included in future acquisitions. +Use when user wants to start imaging an embryo again after it was skipped (e.g., +"resume embryo 2", "start imaging embryo_3 again").""", category=ToolCategory.EMBRYO, examples=[ ToolExample("Resume imaging embryo 2", {"embryo_id": "embryo_2"}), ], ) -def resume_embryo(embryo_id: str, context: Dict) -> str: +def resume_embryo(embryo_id: str, context: dict) -> str: """Resume skipped embryo""" agent, err = require_agent(context) if err: @@ -155,14 +168,15 @@ def resume_embryo(embryo_id: str, context: Dict) -> str: @tool( name="assign_nickname", description="""Assign a memorable nickname to an embryo for easier reference in conversation. 
-Use when you notice distinguishing characteristics or the user wants to name an embryo (e.g., "call embryo 1 speedy", "nickname embryo_2 as the fast one"). +Use when you notice distinguishing characteristics or the user wants to name an embryo +(e.g., "call embryo 1 speedy", "nickname embryo_2 as the fast one"). Nicknames make conversation more natural - you can then refer to embryos by nickname.""", category=ToolCategory.EMBRYO, examples=[ ToolExample("Call embryo 1 speedy", {"embryo_id": "embryo_1", "nickname": "speedy"}), ], ) -def assign_nickname(embryo_id: str, nickname: str, context: Dict) -> str: +def assign_nickname(embryo_id: str, nickname: str, context: dict) -> str: """Assign nickname to embryo""" agent, err = require_agent(context) if err: @@ -183,24 +197,45 @@ def assign_nickname(embryo_id: str, nickname: str, context: Dict) -> str: @tool( name="modify_parameters", - description="""Modify acquisition parameters for a specific embryo. Supported keys: interval_seconds, num_slices, exposure_ms, priority, acquisition_mode, laser_power_488_pct. -Use when user wants to adjust imaging for one embryo (e.g., "image embryo 2 faster", "use snap mode for embryo_1", "drop 488 to 3% for embryo_3"). -acquisition_mode can be "volume" (full 3D stack, default) or "snap" (single 2D lightsheet image - faster, less light exposure). -laser_power_488_pct is hard-limited at the device layer (currently 2-6% — see DiSPIMLightSource.POWER_LIMITS_PCT). Out-of-range values are rejected at the tool boundary AND at the device. -Requires a reason to document why parameters are being changed. Changes take effect at the next acquisition.""", + description="""Modify acquisition parameters for a specific embryo. Supported keys: +interval_seconds, num_slices, exposure_ms, priority, acquisition_mode, laser_power_488_pct. +Use when user wants to adjust imaging for one embryo (e.g., "image embryo 2 faster", +"use snap mode for embryo_1", "drop 488 to 3% for embryo_3"). 
+acquisition_mode can be "volume" (full 3D stack, default) or "snap" (single 2D lightsheet +image - faster, less light exposure). +laser_power_488_pct is hard-limited at the device layer (currently 2-6% — see +DiSPIMLightSource.POWER_LIMITS_PCT). Out-of-range values are rejected at the tool boundary +AND at the device. Requires a reason to document why parameters are being changed. +Changes take effect at the next acquisition.""", category=ToolCategory.EMBRYO, examples=[ - ToolExample("Image embryo 2 every 30 seconds", {"embryo_id": "embryo_2", "changes": {"interval_seconds": 30}, "reason": "pre-hatching monitoring"}), - ToolExample("Use snap mode for embryo 1", {"embryo_id": "embryo_1", "changes": {"acquisition_mode": "snap"}, "reason": "reduce light exposure"}), - ToolExample("Drop 488 power on embryo 3 to 3%", {"embryo_id": "embryo_3", "changes": {"laser_power_488_pct": 3.0}, "reason": "signal saturating"}), + ToolExample( + "Image embryo 2 every 30 seconds", + { + "embryo_id": "embryo_2", + "changes": {"interval_seconds": 30}, + "reason": "pre-hatching monitoring", + }, + ), + ToolExample( + "Use snap mode for embryo 1", + { + "embryo_id": "embryo_1", + "changes": {"acquisition_mode": "snap"}, + "reason": "reduce light exposure", + }, + ), + ToolExample( + "Drop 488 power on embryo 3 to 3%", + { + "embryo_id": "embryo_3", + "changes": {"laser_power_488_pct": 3.0}, + "reason": "signal saturating", + }, + ), ], ) -def modify_parameters( - embryo_id: str, - changes: Dict, - reason: str, - context: Dict -) -> str: +def modify_parameters(embryo_id: str, changes: dict, reason: str, context: dict) -> str: """Modify embryo acquisition parameters""" agent, err = require_agent(context) if err: @@ -211,34 +246,35 @@ def modify_parameters( return err old_params = { - 'interval_seconds': embryo.interval_seconds, - 'num_slices': embryo.num_slices, - 'exposure_ms': embryo.exposure_ms, - 'priority': embryo.priority, - 'acquisition_mode': embryo.acquisition_mode, - 
'laser_power_488_pct': embryo.laser_power_488_pct, + "interval_seconds": embryo.interval_seconds, + "num_slices": embryo.num_slices, + "exposure_ms": embryo.exposure_ms, + "priority": embryo.priority, + "acquisition_mode": embryo.acquisition_mode, + "laser_power_488_pct": embryo.laser_power_488_pct, } - if 'interval_seconds' in changes: - embryo.interval_seconds = changes['interval_seconds'] - if 'num_slices' in changes: - embryo.num_slices = changes['num_slices'] - if 'exposure_ms' in changes: - embryo.exposure_ms = changes['exposure_ms'] - if 'priority' in changes: - embryo.priority = changes['priority'] - if 'acquisition_mode' in changes: - mode = changes['acquisition_mode'] - if mode in ('volume', 'snap'): + if "interval_seconds" in changes: + embryo.interval_seconds = changes["interval_seconds"] + if "num_slices" in changes: + embryo.num_slices = changes["num_slices"] + if "exposure_ms" in changes: + embryo.exposure_ms = changes["exposure_ms"] + if "priority" in changes: + embryo.priority = changes["priority"] + if "acquisition_mode" in changes: + mode = changes["acquisition_mode"] + if mode in ("volume", "snap"): embryo.acquisition_mode = mode else: return f"Invalid acquisition_mode '{mode}'. Use 'volume' or 'snap'." - if 'laser_power_488_pct' in changes: + if "laser_power_488_pct" in changes: # Soft-validate at the tool layer so the agent gets a clean error # without round-tripping to the device. Hard limit is enforced # at DiSPIMLightSource.set_power_pct regardless. 
from gently.hardware.dispim.devices.optical import DiSPIMLightSource - pct = changes['laser_power_488_pct'] + + pct = changes["laser_power_488_pct"] lo, hi = DiSPIMLightSource.POWER_LIMITS_PCT.get(488, (0.0, 100.0)) if pct is not None and not (lo <= pct <= hi): return ( @@ -248,32 +284,43 @@ def modify_parameters( ) embryo.laser_power_488_pct = pct - return (f"Modified {embryo_id} parameters:\n" - f"Reason: {reason}\n\n" - f"Changes:\n{json.dumps(changes, indent=2)}\n\n" - f"Previous: {json.dumps(old_params, indent=2)}") + return ( + f"Modified {embryo_id} parameters:\n" + f"Reason: {reason}\n\n" + f"Changes:\n{json.dumps(changes, indent=2)}\n\n" + f"Previous: {json.dumps(old_params, indent=2)}" + ) @tool( name="assign_embryo_roles", - description="""Assign experimental roles (test / calibration / unassigned) to one or more embryos. -Use when the user has marked embryos in the map view and is classifying which are biological subjects (test) vs reference/calibration samples (calibration). -Roles drive cadence policy, detector selection, photodose budget, and UI color. Pass a dict mapping embryo_id -> role name. + description="""Assign experimental roles (test / calibration / unassigned) to one or more +embryos. Use when the user has marked embryos in the map view and is classifying which are +biological subjects (test) vs reference/calibration samples (calibration). +Roles drive cadence policy, detector selection, photodose budget, and UI color. Pass a dict +mapping embryo_id -> role name. 
Available roles come from gently.harness.roles.REGISTRY: 'test', 'calibration', 'unassigned'.""", category=ToolCategory.EMBRYO, examples=[ ToolExample("Mark embryo 1 as calibration", {"roles": {"embryo_1": "calibration"}}), ToolExample( "Embryos 1-2 are calibration, 3-5 are test", - {"roles": {"embryo_1": "calibration", "embryo_2": "calibration", - "embryo_3": "test", "embryo_4": "test", "embryo_5": "test"}}, + { + "roles": { + "embryo_1": "calibration", + "embryo_2": "calibration", + "embryo_3": "test", + "embryo_4": "test", + "embryo_5": "test", + } + }, ), ], ) -def assign_embryo_roles(roles: Dict[str, str], context: Dict) -> str: +def assign_embryo_roles(roles: dict[str, str], context: dict) -> str: """Assign roles to embryos. Validates against gently.harness.roles.REGISTRY.""" - from gently.harness.roles import get_role, list_roles from gently.core import EventType, get_event_bus + from gently.harness.roles import get_role, list_roles agent, err = require_agent(context) if err: @@ -294,10 +341,7 @@ def assign_embryo_roles(roles: Dict[str, str], context: Dict) -> str: except KeyError: invalid_roles.append((eid, role_name)) if invalid_roles: - return ( - f"Invalid role(s): {invalid_roles}. " - f"Available roles: {list_roles()}" - ) + return f"Invalid role(s): {invalid_roles}. 
Available roles: {list_roles()}" # Apply event_bus = get_event_bus() @@ -314,7 +358,8 @@ def assign_embryo_roles(roles: Dict[str, str], context: Dict) -> str: if getattr(agent, "store", None) and getattr(agent, "session_id", None): pos = embryo.stage_position or {} agent.store.register_embryo( - agent.session_id, eid, + agent.session_id, + eid, position_x=pos.get("x"), position_y=pos.get("y"), calibration=embryo.calibration, diff --git a/gently/app/tools/focus_tools.py b/gently/app/tools/focus_tools.py index 50b24ab7..11769eb4 100644 --- a/gently/app/tools/focus_tools.py +++ b/gently/app/tools/focus_tools.py @@ -11,34 +11,36 @@ - Predicting optimal focus for future acquisitions """ -from typing import Dict, List, Optional from datetime import datetime -import numpy as np -from gently.harness.tools.registry import tool, ToolCategory, ToolExample -from gently.harness.tools.helpers import get_embryo_or_error +import numpy as np # Import focus analysis functions from core from gently.analysis.core import ( + FitFunction, + FocusAnalysisConfig, calculate_focus_score, fit_focus_curve, - FocusAlgorithm, - FocusAnalysisConfig, - FitFunction, ) +from gently.harness.tools.helpers import ctx_get +from gently.harness.tools.registry import ToolCategory, ToolExample, tool @tool( name="fine_focus", - description="""Perform fine focus adjustment by scanning piezo positions and finding optimal focus using image analysis. -Sweeps the piezo through a range of positions, captures lightsheet images at each position, calculates focus scores -using FFT bandpass or gradient algorithm, fits a Gaussian curve, and optionally moves to the best focus position. - -Use when user says "focus", "fine focus", "adjust focus", "find best focus", or after moving to an embryo position. -Default sweep is ±3μm around 4μm with 1μm steps (7 positions). Algorithm options: 'fft_bandpass' (default, best for lightsheet) or 'gradient'. 
- -If embryo_id is provided, logs the focus measurement to the embryo's focus_history for drift tracking and future reference. -Returns the optimal piezo position and fit quality (R²). Higher R² indicates more reliable focus detection.""", + description="""Perform fine focus adjustment by scanning piezo positions and finding +optimal focus using image analysis. +Sweeps the piezo through a range of positions, captures lightsheet images at each position, +calculates focus scores using FFT bandpass or gradient algorithm, fits a Gaussian curve, +and optionally moves to the best focus position. + +Use when user says "focus", "fine focus", "adjust focus", "find best focus", or after +moving to an embryo position. Default sweep is ±3μm around 4μm with 1μm steps (7 +positions). Algorithm options: 'fft_bandpass' (default, best for lightsheet) or 'gradient'. + +If embryo_id is provided, logs the focus measurement to the embryo's focus_history for +drift tracking and future reference. Returns the optimal piezo position and fit quality (R²). +Higher R² indicates more reliable focus detection.""", category=ToolCategory.CALIBRATION, requires_microscope=True, examples=[ @@ -51,12 +53,12 @@ async def fine_focus( range_um: float = 3.0, step_um: float = 1.0, - center_um: Optional[float] = 4.0, + center_um: float | None = 4.0, algorithm: str = "fft_bandpass", move_to_best: bool = True, galvo_position: float = 0.0, - embryo_id: Optional[str] = None, - context: Dict = None + embryo_id: str | None = None, + context: dict | None = None, ) -> str: """ Perform fine focus sweep to find optimal piezo position. @@ -76,18 +78,19 @@ async def fine_focus( galvo_position : float Galvo position to use during sweep (default: 0.0) embryo_id : str, optional - Embryo to associate this focus measurement with. If provided, logs to embryo's focus_history. + Embryo to associate this focus measurement with. If provided, logs to embryo's + focus_history. 
context : dict Execution context with client and agent """ - client = context.get('client') - agent = context.get('agent') + client = ctx_get(context, "client") + agent = ctx_get(context, "agent") if not client: return "Error: No microscope client connected" # Validate algorithm - valid_algorithms = ['fft_bandpass', 'gradient', 'volath', 'variance'] + valid_algorithms = ["fft_bandpass", "gradient", "volath", "variance"] if algorithm not in valid_algorithms: return f"Error: Unknown algorithm '{algorithm}'. Valid options: {valid_algorithms}" @@ -104,14 +107,13 @@ async def fine_focus( images = [] captured_positions = [] - for i, pos in enumerate(positions): + for _i, pos in enumerate(positions): result = await client.capture_lightsheet_image( - piezo_position=float(pos), - galvo_position=float(galvo_position) + piezo_position=float(pos), galvo_position=float(galvo_position) ) - if result.get('success') and result.get('image') is not None: - images.append(result['image']) + if result.get("success") and result.get("image") is not None: + images.append(result["image"]) captured_positions.append(pos) if len(images) < 3: @@ -121,7 +123,7 @@ async def fine_focus( scores = [] config = FocusAnalysisConfig(algorithm=algorithm) - for i, img in enumerate(images): + for _i, img in enumerate(images): score = calculate_focus_score(img, algorithm=algorithm, config=config) scores.append(score) @@ -146,14 +148,17 @@ async def fine_focus( fit_quality = "good" if r_squared >= 0.75 else "moderate" # Check if fitted peak is within sweep range - if best_position < captured_positions.min() or best_position > captured_positions.max(): + if ( + best_position < captured_positions.min() + or best_position > captured_positions.max() + ): best_position = best_measured_position fit_quality = "fallback (peak outside range)" else: best_position = best_measured_position fit_quality = "poor" - except Exception as e: + except Exception: best_position = best_measured_position r_squared = 0.0 
fit_quality = "failed" @@ -162,7 +167,7 @@ async def fine_focus( if move_to_best: await client.capture_lightsheet_image( piezo_position=float(best_position), - galvo_position=float(galvo_position) + galvo_position=float(galvo_position), ) # Log focus datapoint to embryo's focus_history if embryo_id provided @@ -175,7 +180,7 @@ async def fine_focus( piezo=best_position, score=float(best_measured_score), r_squared=float(r_squared), - method='fine_focus', + method="fine_focus", algorithm=algorithm, ) logged_to_embryo = True @@ -185,11 +190,12 @@ async def fine_focus( # Build result message result_lines = [ - f"✓ Fine focus complete", + "✓ Fine focus complete", f" Optimal position: {best_position:.2f} μm", f" Fit quality: {fit_quality} (R²={r_squared:.3f})", f" Algorithm: {algorithm}", - f" Sweep: {captured_positions.min():.1f} to {captured_positions.max():.1f} μm ({len(captured_positions)} positions)", + f" Sweep: {captured_positions.min():.1f} to {captured_positions.max():.1f} μm" + f" ({len(captured_positions)} positions)", ] if move_to_best: @@ -199,21 +205,25 @@ async def fine_focus( result_lines.append(f" Logged to: {embryo_id} focus history") # Add score statistics - score_range = scores.max() - scores.min() + scores.max() - scores.min() score_cv = np.std(scores) / np.mean(scores) if np.mean(scores) > 0 else 0 - result_lines.append(f" Score variation: {score_cv:.1%} (higher is better for focus detection)") + result_lines.append( + f" Score variation: {score_cv:.1%} (higher is better for focus detection)" + ) return "\n".join(result_lines) except Exception as e: import traceback + return f"Error during fine focus: {str(e)}\n{traceback.format_exc()}" @tool( name="get_focus_score", description="""Calculate focus score for the current lightsheet image without moving the piezo. -Captures a single lightsheet image and returns its focus quality score using the specified algorithm. 
+Captures a single lightsheet image and returns its focus quality score using the specified +algorithm. Use to check focus quality at current position or compare different positions manually. If piezo_position is not specified, uses CURRENT position (preserves focus after fine_focus). Algorithm options: 'fft_bandpass' (default), 'gradient', 'volath', 'variance'.""", @@ -225,10 +235,10 @@ async def fine_focus( ], ) async def get_focus_score( - piezo_position: float = None, + piezo_position: float | None = None, galvo_position: float = 0.0, algorithm: str = "fft_bandpass", - context: Dict = None + context: dict | None = None, ) -> str: """ Get focus score for current or specified position. @@ -244,12 +254,12 @@ async def get_focus_score( context : dict Execution context """ - client = context.get('client') + client = ctx_get(context, "client") if not client: return "Error: No microscope client connected" - valid_algorithms = ['fft_bandpass', 'gradient', 'volath', 'variance'] + valid_algorithms = ["fft_bandpass", "gradient", "volath", "variance"] if algorithm not in valid_algorithms: return f"Error: Unknown algorithm '{algorithm}'. Valid options: {valid_algorithms}" @@ -260,14 +270,13 @@ async def get_focus_score( # Capture image result = await client.capture_lightsheet_image( - piezo_position=float(piezo_position), - galvo_position=float(galvo_position) + piezo_position=float(piezo_position), galvo_position=float(galvo_position) ) - if not result.get('success') or result.get('image') is None: + if not result.get("success") or result.get("image") is None: return f"Error: Failed to capture image at piezo={piezo_position}μm" - image = result['image'] + image = result["image"] # Calculate focus score config = FocusAnalysisConfig(algorithm=algorithm) @@ -286,7 +295,8 @@ async def get_focus_score( @tool( name="get_focus_history", - description="""Get the focus history for an embryo showing all piezo-galvo measurements over time. 
+ description="""Get the focus history for an embryo showing all piezo-galvo measurements +over time. Shows drift rate, piezo-galvo fit, and individual measurements. Use to understand how focus has changed during a timelapse and whether recalibration is needed.""", category=ToolCategory.ANALYSIS, @@ -295,10 +305,7 @@ async def get_focus_score( ToolExample("Check focus history for embryo 2", {"embryo_id": "embryo_2"}), ], ) -async def get_focus_history( - embryo_id: str, - context: Dict = None -) -> str: +async def get_focus_history(embryo_id: str, context: dict | None = None) -> str: """ Get focus measurement history for an embryo. @@ -309,7 +316,7 @@ async def get_focus_history( context : dict Execution context with agent """ - agent = context.get('agent') + agent = ctx_get(context, "agent") if not agent: return "Error: No agent context available" diff --git a/gently/app/tools/hardware_common.py b/gently/app/tools/hardware_common.py index aa48c2b9..60e07534 100644 --- a/gently/app/tools/hardware_common.py +++ b/gently/app/tools/hardware_common.py @@ -1,4 +1,5 @@ """Shared helpers for hardware tools.""" + import numpy as np diff --git a/gently/app/tools/interaction_tools.py b/gently/app/tools/interaction_tools.py index ba52c7d5..cf3f6f45 100644 --- a/gently/app/tools/interaction_tools.py +++ b/gently/app/tools/interaction_tools.py @@ -6,10 +6,8 @@ """ import json -from typing import Dict, List, Optional - -from gently.harness.tools.registry import tool, ToolCategory, ToolExample +from gently.harness.tools.registry import ToolCategory, ToolExample, tool # Special marker for CLI to detect choice responses CHOICE_RESPONSE_TYPE = "_user_choice_request" @@ -33,28 +31,34 @@ category=ToolCategory.EXPERIMENT, requires_microscope=False, examples=[ - ToolExample("Ask which session", { - "question": "Which session to import from?", - "options": [ - {"id": "abc123", "label": "Today's session (4 embryos)"}, - {"id": "def456", "label": "Yesterday (2 embryos)"}, - ] - }), - 
ToolExample("Yes/No confirmation", { - "question": "Start the timelapse?", - "options": [ - {"id": "yes", "label": "Yes, start now"}, - {"id": "no", "label": "No, cancel"}, - ] - }), + ToolExample( + "Ask which session", + { + "question": "Which session to import from?", + "options": [ + {"id": "abc123", "label": "Today's session (4 embryos)"}, + {"id": "def456", "label": "Yesterday (2 embryos)"}, + ], + }, + ), + ToolExample( + "Yes/No confirmation", + { + "question": "Start the timelapse?", + "options": [ + {"id": "yes", "label": "Yes, start now"}, + {"id": "no", "label": "No, cancel"}, + ], + }, + ), ], ) async def ask_user_choice( question: str, - options: List[Dict[str, str]], + options: list[dict[str, str]], allow_multiple: bool = False, - default_id: Optional[str] = None, - context: Dict = None + default_id: str | None = None, + context: dict | None = None, ) -> str: """ Present user with selectable options. @@ -88,7 +92,7 @@ async def ask_user_choice( return "Error: Must provide at least one option" for i, opt in enumerate(options): - if 'id' not in opt or 'label' not in opt: + if "id" not in opt or "label" not in opt: return f"Error: Option {i} missing required 'id' or 'label' field" # Return special format that CLI will intercept and render @@ -103,7 +107,7 @@ async def ask_user_choice( return json.dumps(choice_request) -def parse_choice_response(response: str) -> Optional[Dict]: +def parse_choice_response(response: str) -> dict | None: """ Parse a tool response to check if it's a choice request. 
@@ -128,7 +132,8 @@ def parse_choice_response(response: str) -> Optional[Dict]: # Helper functions for common choice patterns -def yes_no_options(yes_label: str = "Yes", no_label: str = "No") -> List[Dict[str, str]]: + +def yes_no_options(yes_label: str = "Yes", no_label: str = "No") -> list[dict[str, str]]: """Generate standard Yes/No options""" return [ {"id": "yes", "label": yes_label}, @@ -137,10 +142,8 @@ def yes_no_options(yes_label: str = "Yes", no_label: str = "No") -> List[Dict[st def yes_no_cancel_options( - yes_label: str = "Yes", - no_label: str = "No", - cancel_label: str = "Cancel" -) -> List[Dict[str, str]]: + yes_label: str = "Yes", no_label: str = "No", cancel_label: str = "Cancel" +) -> list[dict[str, str]]: """Generate Yes/No/Cancel options""" return [ {"id": "yes", "label": yes_label}, @@ -149,7 +152,7 @@ def yes_no_cancel_options( ] -def embryo_options(agent) -> List[Dict[str, str]]: +def embryo_options(agent) -> list[dict[str, str]]: """Generate options from available embryos""" options = [] for eid, embryo in agent.experiment.embryos.items(): @@ -168,23 +171,19 @@ def embryo_options(agent) -> List[Dict[str, str]]: return options -def session_options(sessions: List[Dict]) -> List[Dict[str, str]]: +def session_options(sessions: list[dict]) -> list[dict[str, str]]: """Generate options from available sessions""" options = [] for session in sessions: - sid = session.get('session_id', session.get('id', 'unknown')) - embryo_count = session.get('embryo_count', 0) - message_count = session.get('message_count', 0) - last_active = session.get('last_active', 'unknown') + sid = session.get("session_id", session.get("id", "unknown")) + embryo_count = session.get("embryo_count", 0) + message_count = session.get("message_count", 0) + last_active = session.get("last_active", "unknown") label = f"{sid[:8]} - {embryo_count} embryos, {message_count} messages" - if last_active != 'unknown': + if last_active != "unknown": label += f" (last: {last_active})" - 
options.append({ - "id": sid, - "label": label, - "description": f"Session {sid}" - }) + options.append({"id": sid, "label": label, "description": f"Session {sid}"}) return options diff --git a/gently/app/tools/led_tools.py b/gently/app/tools/led_tools.py index d2c3a9e6..df02a8ae 100644 --- a/gently/app/tools/led_tools.py +++ b/gently/app/tools/led_tools.py @@ -4,9 +4,8 @@ Tools for controlling microscope LED illumination. """ -from typing import Dict - -from gently.harness.tools.registry import tool, ToolCategory +from gently.harness.tools.helpers import ctx_get +from gently.harness.tools.registry import ToolCategory, tool @tool( @@ -15,13 +14,13 @@ category=ToolCategory.HARDWARE, requires_microscope=True, ) -async def set_led(state: str, context: Dict) -> str: +async def set_led(state: str, context: dict) -> str: """Set LED state""" - client = context.get('client') + client = ctx_get(context, "client") try: result = await client.set_led(state) - if result.get('success'): + if result.get("success"): return f"LED set to '{state}'" else: return f"Error setting LED: {result.get('error', 'Unknown error')}" @@ -35,21 +34,23 @@ async def set_led(state: str, context: Dict) -> str: category=ToolCategory.HARDWARE, requires_microscope=True, ) -async def get_led_status(context: Dict) -> str: +async def get_led_status(context: dict) -> str: """Get LED status""" - client = context.get('client') + client = ctx_get(context, "client") try: result = await client.get_led_status() - if result.get('success'): - current = result.get('current_state', 'unknown') - available = result.get('available_configs', []) - group = result.get('group_name', 'unknown') - - return (f"LED Status:\n" - f" Current state: {current}\n" - f" ConfigGroup: {group}\n" - f" Available configs: {available}") + if result.get("success"): + current = result.get("current_state", "unknown") + available = result.get("available_configs", []) + group = result.get("group_name", "unknown") + + return ( + f"LED Status:\n" 
+ f" Current state: {current}\n" + f" ConfigGroup: {group}\n" + f" Available configs: {available}" + ) else: return f"Error getting LED status: {result.get('error', 'Unknown error')}" except Exception as e: diff --git a/gently/app/tools/light_source_tools.py b/gently/app/tools/light_source_tools.py index c0093a87..a190b929 100644 --- a/gently/app/tools/light_source_tools.py +++ b/gently/app/tools/light_source_tools.py @@ -14,21 +14,25 @@ ``modify_parameters(embryo_id, {"laser_power_488_pct": ...}, ...)``. """ -from typing import Dict - -from gently.harness.tools.registry import tool, ToolCategory, ToolExample -from gently.harness.tools.helpers import require_agent +from gently.harness.tools.helpers import ctx_get, require_agent +from gently.harness.tools.registry import ToolCategory, ToolExample, tool @tool( name="set_laser_power", - description="""Set per-line laser power % directly (not tied to any embryo). Submits a Bluesky plan via the queue server so the change is traceable. + description="""Set per-line laser power % directly (not tied to any embryo). Submits a +Bluesky plan via the queue server so the change is traceable. -Hard-limited at the device layer (DiSPIMLightSource.POWER_LIMITS_PCT). 488 is constrained to 2-6% by default. Out-of-range values are rejected at the device layer (ValueError); the tool returns the error. +Hard-limited at the device layer (DiSPIMLightSource.POWER_LIMITS_PCT). 488 is constrained +to 2-6% by default. Out-of-range values are rejected at the device layer (ValueError); the +tool returns the error. -After setting, the tool reads back the actual setpoint and includes it in the response so the agent can verify. +After setting, the tool reads back the actual setpoint and includes it in the response so +the agent can verify. -Use for: pre-experiment setup, ad-hoc inspection, calibration. 
For experiment-scoped per-embryo changes during a timelapse, use modify_parameters with laser_power_488_pct instead.""", +Use for: pre-experiment setup, ad-hoc inspection, calibration. For experiment-scoped +per-embryo changes during a timelapse, use modify_parameters with laser_power_488_pct +instead.""", category=ToolCategory.HARDWARE, requires_microscope=True, examples=[ @@ -39,13 +43,13 @@ async def set_laser_power( wavelength: int, pct: float, - context: Dict = None, + context: dict | None = None, ) -> str: """Set laser power and read back the actual setpoint.""" agent, err = require_agent(context) if err: return err - client = context.get("client") + client = ctx_get(context, "client") if not client: return "Error: Microscope not connected." @@ -66,16 +70,14 @@ async def set_laser_power( actual = None if actual is not None: - return ( - f"{wavelength}nm power set to {pct}% " - f"(readback: {actual:.4f}%)" - ) + return f"{wavelength}nm power set to {pct}% (readback: {actual:.4f}%)" return f"{wavelength}nm power set to {pct}% (readback unavailable)" @tool( name="get_laser_power", - description="""Read the current per-line laser power % from the device. Useful to verify state before/after a change, or to spot-check the current illumination during a long run.""", + description="""Read the current per-line laser power % from the device. Useful to verify +state before/after a change, or to spot-check the current illumination during a long run.""", category=ToolCategory.HARDWARE, requires_microscope=True, examples=[ @@ -85,13 +87,13 @@ async def set_laser_power( ) async def get_laser_power( wavelength: int, - context: Dict = None, + context: dict | None = None, ) -> str: """Read current laser power %.""" agent, err = require_agent(context) if err: return err - client = context.get("client") + client = ctx_get(context, "client") if not client: return "Error: Microscope not connected." 
diff --git a/gently/app/tools/memory_tools.py b/gently/app/tools/memory_tools.py index a8d0804c..d58d3e5d 100644 --- a/gently/app/tools/memory_tools.py +++ b/gently/app/tools/memory_tools.py @@ -4,12 +4,10 @@ Thin wrappers around AgentMemory. Available in both run and plan modes. """ -from typing import Dict, Optional +from gently.harness.tools.registry import ToolCategory, ToolExample, tool -from gently.harness.tools.registry import tool, ToolCategory, ToolExample - -def _get_memory(context: Dict): +def _get_memory(context: dict | None): """Extract AgentMemory from tool context.""" agent = context.get("agent") if context else None if not agent or not hasattr(agent, "memory") or not agent.memory: @@ -38,7 +36,7 @@ def _get_memory(context: Dict): ) async def recall_campaigns( status: str = "active", - context: Dict = None, + context: dict | None = None, ) -> str: """List campaigns filtered by status.""" memory = _get_memory(context) @@ -66,9 +64,9 @@ async def recall_campaigns( ], ) async def recall_learnings( - query: str = None, + query: str | None = None, limit: int = 20, - context: Dict = None, + context: dict | None = None, ) -> str: """Search or list learnings.""" memory = _get_memory(context) @@ -80,8 +78,7 @@ async def recall_learnings( @tool( name="recall_observations", description=( - "Search or list observations from past sessions. Can filter by " - "keyword or embryo ID." + "Search or list observations from past sessions. Can filter by keyword or embryo ID." 
), category=ToolCategory.UTILITY, examples=[ @@ -96,10 +93,10 @@ async def recall_learnings( ], ) async def recall_observations( - query: str = None, - embryo_id: str = None, + query: str | None = None, + embryo_id: str | None = None, limit: int = 20, - context: Dict = None, + context: dict | None = None, ) -> str: """Search or list observations.""" memory = _get_memory(context) @@ -124,8 +121,8 @@ async def recall_observations( ], ) async def recall_context( - campaign_id: str = None, - context: Dict = None, + campaign_id: str | None = None, + context: dict | None = None, ) -> str: """Full context snapshot.""" memory = _get_memory(context) diff --git a/gently/app/tools/plan_execution_tools.py b/gently/app/tools/plan_execution_tools.py index 1d10b18e..f07e8af0 100644 --- a/gently/app/tools/plan_execution_tools.py +++ b/gently/app/tools/plan_execution_tools.py @@ -7,13 +7,11 @@ """ import logging -from typing import Dict, List, Optional -from gently.harness.tools.registry import tool, ToolCategory, ToolExample from gently.harness.tools.helpers import ( require_agent, - require_timelapse_orchestrator, ) +from gently.harness.tools.registry import ToolCategory, ToolExample, tool logger = logging.getLogger(__name__) @@ -22,6 +20,7 @@ # execute_plan_item # --------------------------------------------------------------------------- + @tool( name="execute_plan_item", description=( @@ -41,8 +40,8 @@ ) async def execute_plan_item( item_ref: str, - embryo_ids: List[str] = None, - context: Dict = None, + embryo_ids: list[str] | None = None, + context: dict | None = None, ) -> str: """Execute a planned imaging item.""" agent, err = require_agent(context) @@ -58,7 +57,7 @@ async def execute_plan_item( if not item: return f"Plan item '{item_ref}' not found" - from gently.harness.memory.model import PlanItemType, PlanItemStatus + from gently.harness.memory.model import PlanItemStatus, PlanItemType # 2. 
Verify type and status if item.type != PlanItemType.IMAGING: @@ -77,7 +76,7 @@ async def execute_plan_item( if not spec: return f"Plan item '{item.title}' has no imaging spec" - actions: List[str] = [] + actions: list[str] = [] # 4. Configure acquisition params on the experiment state experiment = getattr(agent, "experiment", None) @@ -113,7 +112,7 @@ async def execute_plan_item( if session_id: try: cs.link_session_campaign(session_id, item.campaign_id) - actions.append(f"session linked to campaign") + actions.append("session linked to campaign") except Exception: pass @@ -123,7 +122,7 @@ async def execute_plan_item( status=PlanItemStatus.IN_PROGRESS, session_id=session_id, ) - actions.append(f"plan item status → in_progress") + actions.append("plan item status → in_progress") # 8. Start timelapse via orchestrator orchestrator = getattr(agent, "timelapse_orchestrator", None) @@ -131,7 +130,7 @@ async def execute_plan_item( try: stop_cond = spec.stop_condition or "manual" interval = spec.interval_s or 120 - result = await orchestrator.start( + await orchestrator.start( embryo_ids=embryo_ids, stop_condition=stop_cond, base_interval_seconds=interval, @@ -151,11 +150,11 @@ async def execute_plan_item( logger.error( "Failed to install adaptive interval rule " "(stage=%s, new_interval=%s): %s", - stage_key, new_interval, e, - ) - actions.append( - f"interval rule FAILED: {stage_key} → {new_interval}s ({e})" + stage_key, + new_interval, + e, ) + actions.append(f"interval rule FAILED: {stage_key} → {new_interval}s ({e})") except Exception as e: actions.append(f"timelapse start error: {e}") @@ -179,6 +178,7 @@ async def execute_plan_item( # complete_current_plan_item # --------------------------------------------------------------------------- + @tool( name="complete_current_plan_item", description=( @@ -198,8 +198,8 @@ async def execute_plan_item( ) async def complete_current_plan_item( item_ref: str, - outcome: str = None, - context: Dict = None, + outcome: str | None = 
None, + context: dict | None = None, ) -> str: """Complete a plan item and report newly unblocked items.""" agent, err = require_agent(context) @@ -253,12 +253,13 @@ async def complete_current_plan_item( # Auto-link helper (used by timelapse_tools.py, Feature 3) # --------------------------------------------------------------------------- + def try_auto_link_plan_item( cs, session_id: str, stop_condition: str, interval_seconds: float, -) -> Optional[str]: +) -> str | None: """ Best-effort auto-link: find a planned imaging item that matches the current timelapse parameters and link the session. diff --git a/gently/app/tools/resolution_tools.py b/gently/app/tools/resolution_tools.py index cd659b79..c3a152da 100644 --- a/gently/app/tools/resolution_tools.py +++ b/gently/app/tools/resolution_tools.py @@ -20,10 +20,11 @@ """ import logging -from typing import Dict, List, Optional +from types import SimpleNamespace +from typing import Any -from gently.harness.tools.registry import tool, ToolCategory, ToolExample from gently.harness.tools.helpers import require_agent +from gently.harness.tools.registry import ToolCategory, ToolExample, tool logger = logging.getLogger(__name__) @@ -43,7 +44,7 @@ def _exit_resolution_if_active(agent, outcome: str) -> None: logger.warning(f"exit_resolution_mode failed: {e}") -def _set_active_plan_item(agent, plan_item_id: Optional[str]) -> None: +def _set_active_plan_item(agent, plan_item_id: str | None) -> None: """Update both copies of the active plan item — ExperimentState (persisted) and AgentMemory (in-memory awareness).""" try: @@ -58,7 +59,7 @@ def _set_active_plan_item(agent, plan_item_id: Optional[str]) -> None: pass -def _short(text: Optional[str], n: int = 80) -> str: +def _short(text: str | None, n: int = 80) -> str: if not text: return "" return text if len(text) <= n else text[: n - 1] + "…" @@ -93,7 +94,7 @@ def _short(text: Optional[str], n: int = 80) -> str: async def attach_session_to_plan( plan_item_id: str, rationale: str 
= "", - context: Dict = None, + context: dict | None = None, ) -> str: """Attach the current session to a plan item.""" agent, err = require_agent(context) @@ -161,7 +162,7 @@ async def attach_session_to_plan( ) async def mark_session_standalone( description: str, - context: Dict = None, + context: dict | None = None, ) -> str: """Mark the session as a standalone (non-plan) run.""" agent, err = require_agent(context) @@ -213,7 +214,7 @@ async def mark_session_standalone( ) async def detach_session_from_plan( reason: str = "", - context: Dict = None, + context: dict | None = None, ) -> str: """Detach the current session from its plan item.""" agent, err = require_agent(context) @@ -277,7 +278,7 @@ async def mark_plan_item_status( plan_item_id: str, status: str, notes: str = "", - context: Dict = None, + context: dict | None = None, ) -> str: """Update a plan item's status.""" agent, err = require_agent(context) @@ -303,10 +304,7 @@ async def mark_plan_item_status( } key = status.strip().lower() if key not in status_map: - return ( - f"Unknown status '{status}'. Valid: " - f"{', '.join(status_map.keys())}." - ) + return f"Unknown status '{status}'. Valid: {', '.join(status_map.keys())}." target = status_map[key] try: @@ -327,7 +325,7 @@ async def mark_plan_item_status( # --------------------------------------------------------------------------- -def _apply_spec_to_embryo(embryo, spec) -> List[str]: +def _apply_spec_to_embryo(embryo, spec) -> list[str]: """Write per-embryo acquisition fields from an ImagingSpec. 
Returns a list of human-readable changes made.""" applied = [] @@ -363,8 +361,8 @@ def _apply_spec_to_embryo(embryo, spec) -> List[str]: ) async def apply_plan_acquisition_spec( plan_item_id: str, - overrides: Dict = None, - context: Dict = None, + overrides: dict | None = None, + context: dict | None = None, ) -> str: """Apply a plan's ImagingSpec to the experiment.""" agent, err = require_agent(context) @@ -386,26 +384,29 @@ async def apply_plan_acquisition_spec( overrides = overrides or {} # Apply per-embryo to anything we already have - applied_per_embryo: List[str] = [] + applied_per_embryo: list[str] = [] embryo_count = 0 experiment = getattr(agent, "experiment", None) if experiment and experiment.embryos: for embryo in experiment.embryos.values(): # Build an "effective" spec respecting overrides — copy then # zero out any field the caller asked us to skip. - class _Filtered: - pass - eff = _Filtered() + eff = SimpleNamespace() eff.num_slices = ( - None if "num_slices" in overrides and overrides["num_slices"] is None + None + if "num_slices" in overrides and overrides["num_slices"] is None else (overrides.get("num_slices") if "num_slices" in overrides else spec.num_slices) ) eff.exposure_ms = ( - None if "exposure_ms" in overrides and overrides["exposure_ms"] is None - else (overrides.get("exposure_ms") if "exposure_ms" in overrides else spec.exposure_ms) + None + if "exposure_ms" in overrides and overrides["exposure_ms"] is None + else ( + overrides.get("exposure_ms") if "exposure_ms" in overrides else spec.exposure_ms + ) ) eff.interval_s = ( - None if "interval_s" in overrides and overrides["interval_s"] is None + None + if "interval_s" in overrides and overrides["interval_s"] is None else (overrides.get("interval_s") if "interval_s" in overrides else spec.interval_s) ) changes = _apply_spec_to_embryo(embryo, eff) @@ -428,13 +429,15 @@ class _Filtered: "stop_condition": spec.stop_condition, "detectors": list(spec.detectors) if spec.detectors else None, 
"success_criteria": spec.success_criteria, - "adaptive_intervals": dict(spec.adaptive_intervals) if spec.adaptive_intervals else None, + "adaptive_intervals": dict(spec.adaptive_intervals) + if spec.adaptive_intervals + else None, } except Exception: pass # Build a narratable summary for the agent to quote - parts: List[str] = [] + parts: list[str] = [] if spec.strain: parts.append(f"strain={spec.strain}") if spec.temperature_c is not None: @@ -508,7 +511,7 @@ class _Filtered: async def recall_sibling_sessions( identifier: str, limit: int = 10, - context: Dict = None, + context: dict | None = None, ) -> str: """Return sessions sharing the given plan item's campaign or the campaign itself.""" agent, err = require_agent(context) @@ -520,7 +523,7 @@ async def recall_sibling_sessions( return "Error: context store unavailable." # Try plan-item-first; fall back to treating identifier as campaign id - campaign_id: Optional[str] = None + campaign_id: str | None = None plan_item = cs.resolve_plan_item(identifier) if hasattr(cs, "resolve_plan_item") else None if plan_item: campaign_id = plan_item.campaign_id @@ -536,10 +539,11 @@ async def recall_sibling_sessions( return f"Could not resolve '{identifier}' to a plan item or campaign." 
# Get the campaign tree and walk plan items, collecting session ids - sessions: List[Dict] = [] + sessions: list[dict] = [] try: items = cs.get_plan_items( - campaign_id=campaign_id, include_children=True, + campaign_id=campaign_id, + include_children=True, ) except Exception as e: return f"Could not list plan items: {e}" @@ -549,19 +553,23 @@ async def recall_sibling_sessions( sid = getattr(item, "session_id", None) if not sid: continue - meta = {} + meta: dict[str, Any] = {} if file_store is not None: try: - meta = file_store.get_session(sid) or {} + meta = dict(file_store.get_session(sid) or {}) except Exception: meta = {} - sessions.append({ - "plan_item_title": item.title, - "plan_item_status": item.status.value if hasattr(item.status, "value") else str(item.status), - "session_id": sid, - "last_active": meta.get("last_active"), - "name": meta.get("name"), - }) + sessions.append( + { + "plan_item_title": item.title, + "plan_item_status": item.status.value + if hasattr(item.status, "value") + else str(item.status), + "session_id": sid, + "last_active": meta.get("last_active"), + "name": meta.get("name"), + } + ) if not sessions: return f"No sessions yet for campaign {campaign_id[:8]}." 
@@ -596,7 +604,7 @@ async def recall_sibling_sessions( ) async def summarize_campaign_history( campaign_id: str, - context: Dict = None, + context: dict | None = None, ) -> str: """Compact campaign-progress summary for resolution-mode reasoning.""" agent, err = require_agent(context) @@ -625,8 +633,8 @@ async def summarize_campaign_history( status = cs.get_plan_status(campaign.id) lines.append( f"Plan items: {status['completed']}/{status['total']} done" - f"{', ' + str(status['in_progress']) + ' in progress' if status.get('in_progress') else ''}" - f"{', ' + str(status['blocked']) + ' blocked' if status.get('blocked') else ''}" + + (f", {status['in_progress']} in progress" if status.get("in_progress") else "") + + (f", {status['blocked']} blocked" if status.get("blocked") else "") ) next_actions = status.get("next_actions") or [] if next_actions: @@ -656,7 +664,7 @@ async def summarize_campaign_history( ], ) async def list_imaging_candidates( - context: Dict = None, + context: dict | None = None, ) -> str: """Full deterministic listing of unblocked imaging plan items.""" agent, err = require_agent(context) diff --git a/gently/app/tools/session_tools.py b/gently/app/tools/session_tools.py index 30f4fb58..5490c8bc 100644 --- a/gently/app/tools/session_tools.py +++ b/gently/app/tools/session_tools.py @@ -4,24 +4,23 @@ Tools for session statistics, interaction logging, and experiment comparison. 
""" -from typing import Dict, List - -from gently.harness.tools.registry import tool, ToolCategory, ToolExample from gently.harness.tools.helpers import ( - require_agent, get_embryo_or_error, - require_interaction_logger + get_embryo_or_error, + require_agent, + require_interaction_logger, ) +from gently.harness.tools.registry import ToolCategory, ToolExample, tool @tool( name="assess_image_quality", - description="Assess image quality metrics (focus, brightness, noise) and suggest parameter adjustments", + description=( + "Assess image quality metrics (focus, brightness, noise) and suggest parameter adjustments" + ), category=ToolCategory.ANALYSIS, ) async def assess_image_quality( - embryo_id: str = None, - suggest_parameters: bool = True, - context: Dict = None + embryo_id: str | None = None, suggest_parameters: bool = True, context: dict | None = None ) -> str: """Assess image quality and suggest improvements""" agent, err = require_agent(context) @@ -80,20 +79,22 @@ async def assess_image_quality( response = agent.claude.messages.create( model=agent.model, max_tokens=800, - messages=[{ - "role": "user", - "content": [ - {"type": "text", "text": quality_prompt}, - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/jpeg", - "data": image_b64 - } - } - ] - }] + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": quality_prompt}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": image_b64, + }, + }, + ], + } + ], ) assessment = response.content[0].text @@ -136,10 +137,13 @@ async def assess_image_quality( @tool( name="get_session_stats", - description="Get statistics for the current agent session including interactions, corrections, and tool usage", + description=( + "Get statistics for the current agent session including interactions," + " corrections, and tool usage" + ), category=ToolCategory.DATA, ) -def get_session_stats(context: Dict = None) -> str: +def 
get_session_stats(context: dict | None = None) -> str: """Get session statistics from interaction logger""" agent, err = require_agent(context) if err: @@ -169,8 +173,7 @@ def get_session_stats(context: Dict = None) -> str: category=ToolCategory.ANALYSIS, ) def compare_embryo_development( - embryo_ids: List[str] = None, - context: Dict = None + embryo_ids: list[str] | None = None, context: dict | None = None ) -> str: """Compare embryo development""" agent, err = require_agent(context) @@ -188,14 +191,16 @@ def compare_embryo_development( lines = ["Embryo Development Comparison:", ""] - lines.append(f"{'Embryo':<15} {'Timepoints':<12} {'Stage':<15} {'Hatching Est.':<15} {'Status'}") + lines.append( + f"{'Embryo':<15} {'Timepoints':<12} {'Stage':<15} {'Hatching Est.':<15} {'Status'}" + ) lines.append("-" * 70) for embryo in embryos: stage = "unknown" hatching_est = "N/A" - if hasattr(agent, 'developmental_tracker') and agent.developmental_tracker: + if hasattr(agent, "developmental_tracker") and agent.developmental_tracker: current = agent.developmental_tracker.get_current_stage(embryo.id) if current: stage = current.stage.value @@ -205,17 +210,18 @@ def compare_embryo_development( if embryo.should_skip: status = f"skipped ({embryo.skip_reason})" - elif embryo.hatching_status and embryo.hatching_status.get('detected'): + elif embryo.hatching_status and embryo.hatching_status.get("detected"): status = "HATCHED" else: status = "active" lines.append( - f"{embryo.id:<15} {embryo.timepoints_acquired:<12} {stage:<15} {hatching_est:<15} {status}" + f"{embryo.id:<15} {embryo.timepoints_acquired:<12} {stage:<15}" + f" {hatching_est:<15} {status}" ) active = sum(1 for e in embryos if not e.should_skip) - hatched = sum(1 for e in embryos if e.hatching_status and e.hatching_status.get('detected')) + hatched = sum(1 for e in embryos if e.hatching_status and e.hatching_status.get("detected")) lines.append("") lines.append(f"Summary: {active} active, {hatched} hatched, 
{len(embryos) - active} skipped") @@ -225,13 +231,12 @@ def compare_embryo_development( @tool( name="analyze_corrections", - description="Analyze user corrections from interaction logs to identify patterns in agent mistakes", + description=( + "Analyze user corrections from interaction logs to identify patterns in agent mistakes" + ), category=ToolCategory.DATA, ) -def analyze_corrections( - limit: int = 50, - context: Dict = None -) -> str: +def analyze_corrections(limit: int = 50, context: dict | None = None) -> str: """Analyze correction patterns""" agent, err = require_agent(context) if err: @@ -252,12 +257,13 @@ def analyze_corrections( return f"No corrections detected in {len(interactions)} interactions." lines = [ - f"Correction Analysis ({len(corrections)} corrections in {len(interactions)} interactions):", - "" + f"Correction Analysis ({len(corrections)} corrections in" + f" {len(interactions)} interactions):", + "", ] - indicator_counts = {} - tool_corrections = {} + indicator_counts: dict[str, int] = {} + tool_corrections: dict[str, int] = {} for corr in corrections[:limit]: for indicator in corr.correction_indicators: @@ -272,8 +278,8 @@ def analyze_corrections( lines.append("") lines.append("Tools frequently followed by corrections:") - for tool, count in sorted(tool_corrections.items(), key=lambda x: -x[1])[:5]: - lines.append(f" {tool}: {count} times") + for tool_name, count in sorted(tool_corrections.items(), key=lambda x: -x[1])[:5]: + lines.append(f" {tool_name}: {count} times") lines.append("") lines.append("Recent correction examples:") @@ -290,10 +296,7 @@ def analyze_corrections( description="Export interaction logs for external analysis", category=ToolCategory.DATA, ) -def export_interaction_log( - format: str = "summary", - context: Dict = None -) -> str: +def export_interaction_log(format: str = "summary", context: dict | None = None) -> str: """Export interaction log""" agent, err = require_agent(context) if err: @@ -341,20 +344,24 @@ def 
export_interaction_log( @tool( name="import_embryos_from_session", - description="""Import embryos (positions, calibration, settings) from another session into the current experiment. -Use when user wants to start a fresh session but keep embryo positions from a previous session (e.g., "import embryos from last session", "load embryos from session X"). -This imports positions and calibration data but NOT conversation history or detection results - it's a fresh start with known embryos. + description="""Import embryos (positions, calibration, settings) from another session into +the current experiment. +Use when user wants to start a fresh session but keep embryo positions from a previous +session (e.g., "import embryos from last session", "load embryos from session X"). +This imports positions and calibration data but NOT conversation history or detection +results - it's a fresh start with known embryos. Use list_sessions or /sessions first to find the session_id to import from.""", category=ToolCategory.DATA, examples=[ ToolExample("Import embryos from session abc123", {"session_id": "abc123"}), - ToolExample("Load embryos from previous session, replacing current ones", {"session_id": "abc123", "clear_existing": True}), + ToolExample( + "Load embryos from previous session, replacing current ones", + {"session_id": "abc123", "clear_existing": True}, + ), ], ) def import_embryos_from_session( - session_id: str, - clear_existing: bool = False, - context: Dict = None + session_id: str, clear_existing: bool = False, context: dict | None = None ) -> str: """ Import embryos from another session. 
@@ -372,12 +379,9 @@ def import_embryos_from_session( if err: return err - result = agent.import_embryos_from_session( - session_id=session_id, - clear_existing=clear_existing - ) + result = agent.import_embryos_from_session(session_id=session_id, clear_existing=clear_existing) - if not result.get('success'): + if not result.get("success"): return f"Import failed: {result.get('error', 'Unknown error')}" lines = [ @@ -385,16 +389,16 @@ def import_embryos_from_session( f" Imported: {len(result['imported'])} embryo(s)", ] - if result['imported']: + if result["imported"]: lines.append(f" {', '.join(result['imported'])}") - if result['skipped']: + if result["skipped"]: lines.append(f" Skipped (already exist): {len(result['skipped'])}") lines.append(f" {', '.join(result['skipped'])}") - if result.get('errors'): + if result.get("errors"): lines.append(f" Errors: {len(result['errors'])}") - for err in result['errors']: + for err in result["errors"]: lines.append(f" - {err}") return "\n".join(lines) @@ -402,19 +406,18 @@ def import_embryos_from_session( @tool( name="list_sessions", - description="""List available sessions with their IDs, embryo counts, message counts, and last active times. -Use when user asks "show sessions", "what sessions exist", or needs to pick a session to resume or import from. -Returns ALL sessions — do NOT filter by embryo count. Sessions are valuable for conversation history too.""", + description="""List available sessions with their IDs, embryo counts, message counts, +and last active times. +Use when user asks "show sessions", "what sessions exist", or needs to pick a session to +resume or import from. Returns ALL sessions — do NOT filter by embryo count. 
Sessions are +valuable for conversation history too.""", category=ToolCategory.DATA, examples=[ ToolExample("Show available sessions", {}), ToolExample("List recent sessions", {"limit": 5}), ], ) -def list_sessions( - limit: int = 20, - context: Dict = None -) -> str: +def list_sessions(limit: int = 20, context: dict | None = None) -> str: """ List available sessions. @@ -442,15 +445,16 @@ def list_sessions( lines.append("-" * 80) for s in sessions: - session_id = s.get('session_id', 'unknown')[:38] - embryo_count = s.get('embryo_count', 0) - msg_count = s.get('message_count', 0) - last_active = s.get('last_active', '') + session_id = s.get("session_id", "unknown")[:38] + embryo_count = s.get("embryo_count", 0) + msg_count = s.get("message_count", 0) + last_active = s.get("last_active", "") if last_active: # Format datetime string try: from datetime import datetime - dt = datetime.fromisoformat(last_active.replace('Z', '+00:00')) + + dt = datetime.fromisoformat(last_active.replace("Z", "+00:00")) last_active = dt.strftime("%Y-%m-%d %H:%M") except Exception: last_active = last_active[:16] @@ -462,6 +466,9 @@ def list_sessions( lines.append("") lines.append("To resume a session (full history + state): /resume ") - lines.append("To import only embryo positions into current session: import_embryos_from_session(session_id)") + lines.append( + "To import only embryo positions into current session:" + " import_embryos_from_session(session_id)" + ) return "\n".join(lines) diff --git a/gently/app/tools/stage_tools.py b/gently/app/tools/stage_tools.py index 3bf9ec49..c4e9ab55 100644 --- a/gently/app/tools/stage_tools.py +++ b/gently/app/tools/stage_tools.py @@ -4,17 +4,17 @@ Tools for controlling microscope XY stage movement. 
""" -from typing import Dict - -from gently.harness.tools.registry import tool, ToolCategory, ToolExample -from gently.harness.tools.helpers import get_embryo_or_error +from gently.harness.tools.helpers import ctx_get, get_embryo_or_error +from gently.harness.tools.registry import ToolCategory, ToolExample, tool @tool( name="move_to_embryo", - description="""Move the XY stage to a specific embryo's stored position. The embryo must have been detected and have a valid stage_position. + description="""Move the XY stage to a specific embryo's stored position. The embryo must +have been detected and have a valid stage_position. Use when user says "go to embryo X", "move to embryo X", or before imaging a specific embryo. -This only moves XY - piezo/galvo are controlled separately during acquisition. Movement takes ~0.5 seconds.""", +This only moves XY - piezo/galvo are controlled separately during acquisition. Movement +takes ~0.5 seconds.""", category=ToolCategory.MOVEMENT, requires_microscope=True, examples=[ @@ -22,10 +22,10 @@ ToolExample("Move to embryo 3", {"embryo_id": "embryo_3"}), ], ) -async def move_to_embryo(embryo_id: str, context: Dict) -> str: +async def move_to_embryo(embryo_id: str, context: dict) -> str: """Move stage to embryo position""" - agent = context.get('agent') - client = context.get('client') + agent = ctx_get(context, "agent") + client = ctx_get(context, "client") if not agent: return "Error: No agent context" @@ -38,22 +38,25 @@ async def move_to_embryo(embryo_id: str, context: Dict) -> str: return f"Embryo '{embryo_id}' has no stored position. Run calibration first." 
try: - x = embryo.stage_position.get('x', 0) - y = embryo.stage_position.get('y', 0) + x = embryo.stage_position.get("x", 0) + y = embryo.stage_position.get("y", 0) await client.move_to_position(x, y) return f"Moved to {embryo_id}\nPosition: ({x:.2f}, {y:.2f}) um" except Exception as e: import traceback + return f"Error moving to embryo: {str(e)}\n{traceback.format_exc()}" @tool( name="get_stage_position", - description="""Get the current XY stage position in micrometers. Returns the real-time position from the hardware. -Use when user asks "where is the stage?", "current position?", or when you need to know the microscope's current location. -This reads from hardware - different from embryo stored positions which are in the experiment data.""", + description="""Get the current XY stage position in micrometers. Returns the real-time +position from the hardware. +Use when user asks "where is the stage?", "current position?", or when you need to know +the microscope's current location. This reads from hardware - different from embryo stored +positions which are in the experiment data.""", category=ToolCategory.HARDWARE, requires_microscope=True, examples=[ @@ -61,9 +64,9 @@ async def move_to_embryo(embryo_id: str, context: Dict) -> str: ToolExample("Current XY position?", {}), ], ) -async def get_stage_position(context: Dict) -> str: +async def get_stage_position(context: dict) -> str: """Get current stage position""" - client = context.get('client') + client = ctx_get(context, "client") if not client: return "Error: No microscope client connected" @@ -79,7 +82,8 @@ async def get_stage_position(context: Dict) -> str: @tool( name="move_stage", description="""Move the XY stage to specific coordinates in micrometers. -Use when user wants to move to arbitrary coordinates (e.g., "move to x=1000, y=500", "move stage to 1200, -600"). +Use when user wants to move to arbitrary coordinates (e.g., "move to x=1000, y=500", +"move stage to 1200, -600"). 
For moving to a specific embryo, use move_to_embryo instead.""", category=ToolCategory.HARDWARE, requires_microscope=True, @@ -88,13 +92,9 @@ async def get_stage_position(context: Dict) -> str: ToolExample("Move stage to coordinates 1200, -600", {"x": 1200, "y": -600}), ], ) -async def move_stage( - x: float, - y: float, - context: Dict = None -) -> str: +async def move_stage(x: float, y: float, context: dict | None = None) -> str: """Move stage to arbitrary XY coordinates""" - client = context.get('client') + client = ctx_get(context, "client") if not client: return "Error: No microscope client connected" diff --git a/gently/app/tools/temperature_tools.py b/gently/app/tools/temperature_tools.py new file mode 100644 index 00000000..d9fd303c --- /dev/null +++ b/gently/app/tools/temperature_tools.py @@ -0,0 +1,79 @@ +""" +Temperature Control Tools + +Agent tools for the ACUITYnano thermal controller. Temperature drives C. elegans +development rate, so these let the agent hold or shift the sample temperature as +part of closed-loop experiments. +""" + +from gently.harness.tools.helpers import ctx_get +from gently.harness.tools.registry import ToolCategory, ToolExample, tool + + +@tool( + name="set_temperature", + description=( + "Set the sample temperature setpoint in Celsius (0.0-99.9). The thermal " + "controller ramps toward the target and this returns immediately — poll " + "get_temperature until the state reads '[ SYSTEM LOCKED ]' before imaging. " + "Temperature controls C. elegans development rate (~15 C slow, 20 C standard, " + "25 C fast)." + ), + category=ToolCategory.HARDWARE, + requires_microscope=True, + examples=[ + ToolExample("Hold the sample at 20 degrees", {"target_c": 20.0}), + ToolExample("Warm the embryos to 25 C to speed development", {"target_c": 25.0}), + ], +) +async def set_temperature(target_c: float, context: dict) -> str: + """Command the thermal controller to a target temperature. 
+ + Parameters + ---------- + target_c : float + Target temperature in degrees Celsius (0.0-99.9). + """ + client = ctx_get(context, "client") + try: + result = await client.set_temperature(target_c) + if result.get("success"): + return ( + f"Commanded {target_c} C. Currently {result.get('temperature_c')} C, " + f"state {result.get('state')!r}. Ramping — call get_temperature to confirm lock." + ) + return f"Error setting temperature: {result.get('error', 'unknown error')}" + except Exception as e: + return f"Error setting temperature: {e}" + + +@tool( + name="get_temperature", + description=( + "Read the current sample temperature, target setpoint, and lock state from the " + "thermal controller. Use to confirm the sample has stabilized at the setpoint " + "('[ SYSTEM LOCKED ]') before acquiring." + ), + category=ToolCategory.HARDWARE, + requires_microscope=True, + examples=[ + ToolExample("What's the current temperature?"), + ToolExample("Has the sample reached temperature yet?"), + ], +) +async def get_temperature(context: dict) -> str: + """Read current temperature, setpoint, and lock state.""" + client = ctx_get(context, "client") + try: + r = await client.get_temperature() + if r.get("success"): + msg = ( + f"Temperature {r.get('temperature_c')} C " + f"(setpoint {r.get('setpoint_c')} C, state {r.get('state')!r}" + ) + if r.get("peltier_c") is not None: + msg += f", peltier {r.get('peltier_c')} C" + return msg + ")" + return f"Error reading temperature: {r.get('error', 'unknown error')}" + except Exception as e: + return f"Error reading temperature: {e}" diff --git a/gently/app/tools/timelapse_tools.py b/gently/app/tools/timelapse_tools.py index fad0ff95..d5ec3d76 100644 --- a/gently/app/tools/timelapse_tools.py +++ b/gently/app/tools/timelapse_tools.py @@ -4,13 +4,14 @@ Tools for managing adaptive timelapse acquisitions. 
""" -from typing import Dict, List, Optional - -from gently.harness.tools.registry import tool, ToolCategory from gently.harness.tools.helpers import ( - require_agent, get_embryo_or_error, - require_timelapse_orchestrator, require_developmental_tracker + ctx_get, + get_embryo_or_error, + require_agent, + require_developmental_tracker, + require_timelapse_orchestrator, ) +from gently.harness.tools.registry import ToolCategory, ToolExample, tool @tool( @@ -20,13 +21,13 @@ ) async def generate_bluesky_plan( goal: str, - embryo_ids: List[str], + embryo_ids: list[str], plan_type: str = "adaptive_timelapse", - parameters: Dict = None, - context: Dict = None + parameters: dict | None = None, + context: dict | None = None, ) -> str: """Generate Bluesky plan""" - agent = context.get('agent') + agent = ctx_get(context, "agent") if not agent: return "Error: No agent context" @@ -36,7 +37,7 @@ async def generate_bluesky_plan( goal=goal, embryo_ids=embryo_ids, plan_type=plan_type, - parameters=parameters or {} + parameters=parameters or {}, ) return result @@ -47,20 +48,23 @@ async def generate_bluesky_plan( @tool( name="start_adaptive_timelapse", description=( - "Start an adaptive timelapse that runs in the background. Agent remains responsive while acquisition continues. " - "Pass `monitoring_mode='expression_monitoring'` for fluorescent-reporter experiments to install reactive cadence + power rules at startup. " - "Other monitoring_mode values: 'pre_terminal_monitoring' (hatching-timing experiments), 'idle' (plain imaging, no reactive rules)." + "Start an adaptive timelapse that runs in the background. Agent remains responsive" + " while acquisition continues. " + "Pass `monitoring_mode='expression_monitoring'` for fluorescent-reporter experiments" + " to install reactive cadence + power rules at startup. " + "Other monitoring_mode values: 'pre_terminal_monitoring' (hatching-timing" + " experiments), 'idle' (plain imaging, no reactive rules)." 
), category=ToolCategory.EXPERIMENT, requires_microscope=True, ) async def start_adaptive_timelapse( - embryo_ids: List[str] = None, + embryo_ids: list[str] | None = None, stop_condition: str = "manual", interval_seconds: float = 120.0, - condition_value: int = None, - monitoring_mode: Optional[str] = None, - context: Dict = None + condition_value: int | None = None, + monitoring_mode: str | None = None, + context: dict | None = None, ) -> str: """Start adaptive timelapse in background""" agent, err = require_agent(context) @@ -84,10 +88,14 @@ async def start_adaptive_timelapse( cs = getattr(agent, "context_store", None) if cs: from .plan_execution_tools import try_auto_link_plan_item + session_id = getattr(agent, "session_id", None) if session_id: linked = try_auto_link_plan_item( - cs, session_id, stop_condition, interval_seconds, + cs, + session_id, + stop_condition, + interval_seconds, ) if linked: result += f"\n(Auto-linked to plan item: '{linked}')" @@ -114,7 +122,7 @@ async def start_adaptive_timelapse( description="Get current status of the running timelapse including per-embryo progress", category=ToolCategory.EXPERIMENT, ) -def get_timelapse_status(context: Dict = None) -> str: +def get_timelapse_status(context: dict | None = None) -> str: """Get timelapse status""" agent, err = require_agent(context) if err: @@ -127,12 +135,9 @@ def get_timelapse_status(context: Dict = None) -> str: state = orchestrator.get_status() status_dict = state.to_dict() - lines = [ - f"Timelapse Status: {status_dict['status'].upper()}", - "" - ] + lines = [f"Timelapse Status: {status_dict['status'].upper()}", ""] - if status_dict['started_at']: + if status_dict["started_at"]: lines.append(f"Started: {status_dict['started_at']}") lines.append(f"Duration: {status_dict['duration_minutes']:.1f} minutes") lines.append(f"Total timepoints acquired: {status_dict['total_timepoints']}") @@ -142,20 +147,19 @@ def get_timelapse_status(context: Dict = None) -> str: lines.append(f"Completed 
embryos: {status_dict['completed_embryos']}") lines.append("") - if status_dict['next_embryo']: - lines.append(f"Next acquisition: {status_dict['next_embryo']} in {status_dict['next_acquisition_in_seconds']:.0f}s") + if status_dict.get("seconds_until_next_round") is not None: + lines.append(f"Next acquisition in {status_dict['seconds_until_next_round']:.0f}s") lines.append("") - if status_dict['embryo_details']: + if status_dict["embryo_details"]: lines.append("Embryo Details:") - for eid, details in status_dict['embryo_details'].items(): - status_marker = "[done]" if details['is_complete'] else "[active]" - lines.append(f" {status_marker} {eid}: t={details['timepoints']} " - f"(interval={details['interval_seconds']}s)") - if details['is_complete']: + for eid, details in status_dict["embryo_details"].items(): + status_marker = "[done]" if details["is_complete"] else "[active]" + lines.append(f" {status_marker} {eid}: t={details['timepoints']}") + if details["is_complete"]: lines.append(f" Completed: {details['completion_reason']}") - if status_dict['error']: + if status_dict["error"]: lines.append("") lines.append(f"Error: {status_dict['error']}") @@ -164,15 +168,18 @@ def get_timelapse_status(context: Dict = None) -> str: @tool( name="modify_timelapse_embryo", - description="Modify parameters for a specific embryo during a running timelapse. Note: interval is now global - use modify_timelapse_interval to change it.", + description=( + "Modify parameters for a specific embryo during a running timelapse." + " Note: interval is now global - use modify_timelapse_interval to change it." 
+ ), category=ToolCategory.EXPERIMENT, requires_microscope=True, ) async def modify_timelapse_embryo( embryo_id: str, - stop_condition: str = None, - condition_value: int = None, - context: Dict = None + stop_condition: str | None = None, + condition_value: int | None = None, + context: dict | None = None, ) -> str: """Modify embryo parameters during timelapse (stop condition only - interval is global)""" agent, err = require_agent(context) @@ -196,15 +203,18 @@ async def modify_timelapse_embryo( @tool( name="add_embryo_to_timelapse", - description="Add an embryo to an already running timelapse. The embryo will use the global interval and join on the next round.", + description=( + "Add an embryo to an already running timelapse. The embryo will use the global" + " interval and join on the next round." + ), category=ToolCategory.EXPERIMENT, requires_microscope=True, ) async def add_embryo_to_timelapse( embryo_id: str, - stop_condition: str = None, - condition_value: int = None, - context: Dict = None + stop_condition: str | None = None, + condition_value: int | None = None, + context: dict | None = None, ) -> str: """Add an embryo to a running timelapse (uses global interval)""" agent, err = require_agent(context) @@ -233,9 +243,7 @@ async def add_embryo_to_timelapse( requires_microscope=True, ) async def stop_timelapse_embryo( - embryo_id: str, - reason: str = "user_request", - context: Dict = None + embryo_id: str, reason: str = "user_request", context: dict | None = None ) -> str: """Stop imaging a specific embryo""" agent, err = require_agent(context) @@ -259,10 +267,7 @@ async def stop_timelapse_embryo( category=ToolCategory.EXPERIMENT, requires_microscope=True, ) -async def stop_timelapse( - reason: str = "user_request", - context: Dict = None -) -> str: +async def stop_timelapse(reason: str = "user_request", context: dict | None = None) -> str: """Stop entire timelapse""" agent, err = require_agent(context) if err: @@ -285,7 +290,7 @@ async def 
stop_timelapse( category=ToolCategory.EXPERIMENT, requires_microscope=True, ) -async def pause_timelapse(context: Dict = None) -> str: +async def pause_timelapse(context: dict | None = None) -> str: """Pause timelapse""" agent, err = require_agent(context) if err: @@ -308,7 +313,7 @@ async def pause_timelapse(context: Dict = None) -> str: category=ToolCategory.EXPERIMENT, requires_microscope=True, ) -async def resume_timelapse(context: Dict = None) -> str: +async def resume_timelapse(context: dict | None = None) -> str: """Resume timelapse""" agent, err = require_agent(context) if err: @@ -327,14 +332,13 @@ async def resume_timelapse(context: Dict = None) -> str: @tool( name="add_stop_condition", - description="Add an additional stop condition to a running timelapse (OR logic). E.g., add 'hatching' condition to a timelapse running with a duration limit.", + description=( + "Add an additional stop condition to a running timelapse (OR logic). E.g., add" + " 'hatching' condition to a timelapse running with a duration limit." + ), category=ToolCategory.EXPERIMENT, ) -def add_stop_condition( - embryo_id: str, - condition: str, - context: Dict = None -) -> str: +def add_stop_condition(embryo_id: str, condition: str, context: dict | None = None) -> str: """ Add an additional stop condition to an embryo in a running timelapse. 
@@ -391,22 +395,22 @@ def add_stop_condition( # Get updated description new_desc = embryo_state.stop_condition.describe() - return ( - f"Added stop condition '{condition}' to {embryo_id}\n" - f"Stop conditions: {new_desc}" - ) + return f"Added stop condition '{condition}' to {embryo_id}\nStop conditions: {new_desc}" @tool( name="add_interval_speedup_rule", - description="Add a rule to automatically speed up imaging when a developmental stage is reached (e.g., 'speed up to 30s when 3fold stage detected')", + description=( + "Add a rule to automatically speed up imaging when a developmental stage is" + " reached (e.g., 'speed up to 30s when 3fold stage detected')" + ), category=ToolCategory.EXPERIMENT, ) def add_interval_speedup_rule( trigger_stage: str, new_interval_seconds: float = 30.0, - embryo_ids: List[str] = None, - context: Dict = None + embryo_ids: list[str] | None = None, + context: dict | None = None, ) -> str: """Add interval speedup rule based on developmental stage""" agent, err = require_agent(context) @@ -423,7 +427,10 @@ def add_interval_speedup_rule( embryo_ids=embryo_ids, ) - msg = f"Added interval rule: speed up to {new_interval_seconds}s when '{trigger_stage}' stage is reached" + msg = ( + f"Added interval rule: speed up to {new_interval_seconds}s when" + f" '{trigger_stage}' stage is reached" + ) if embryo_ids: msg += f" (for embryos: {', '.join(embryo_ids)})" @@ -432,12 +439,14 @@ def add_interval_speedup_rule( @tool( name="enable_pre_hatching_speedup", - description="Enable automatic speedup when embryos approach hatching (triggers when 3-fold stage is detected by the perception system)", + description=( + "Enable automatic speedup when embryos approach hatching (triggers when 3-fold" + " stage is detected by the perception system)" + ), category=ToolCategory.EXPERIMENT, ) def enable_pre_hatching_speedup( - fast_interval_seconds: float = 30.0, - context: Dict = None + fast_interval_seconds: float = 30.0, context: dict | None = None ) -> str: 
"""Enable pre-hatching speedup based on developmental stage""" agent, err = require_agent(context) @@ -452,13 +461,15 @@ def enable_pre_hatching_speedup( orchestrator.add_pre_terminal_speedup(fast_interval_seconds) from gently.organisms import get_organism + organism = get_organism() trigger_stage = organism.PRE_TERMINAL_SPEEDUP_STAGE return ( f"Enabled pre-hatching speedup:\n" f" - Perception system will detect developmental stages\n" - f" - When {trigger_stage} stage detected, interval will change to {fast_interval_seconds}s\n" + f" - When {trigger_stage} stage detected, interval will change to" + f" {fast_interval_seconds}s\n" f" - This helps capture hatching at high temporal resolution" ) @@ -468,10 +479,7 @@ def enable_pre_hatching_speedup( description="Use Claude Vision to classify the current developmental stage of an embryo", category=ToolCategory.ANALYSIS, ) -async def classify_embryo_stage( - embryo_id: str, - context: Dict = None -) -> str: +async def classify_embryo_stage(embryo_id: str, context: dict | None = None) -> str: """Classify embryo stage""" agent, err = require_agent(context) if err: @@ -487,8 +495,9 @@ async def classify_embryo_stage( latest = embryo.recent_images[-1] # Initialize tracker if needed - if not hasattr(agent, 'developmental_tracker') or agent.developmental_tracker is None: + if not hasattr(agent, "developmental_tracker") or agent.developmental_tracker is None: from ..developmental_tracker import DevelopmentalTracker + agent.developmental_tracker = DevelopmentalTracker( claude_client=agent.claude, model=agent.model, @@ -496,10 +505,12 @@ async def classify_embryo_stage( recent = [] for img in embryo.recent_images[-5:]: - recent.append({ - 'timepoint': img.timepoint, - 'b64_image': img.max_projection_b64, - }) + recent.append( + { + "timepoint": img.timepoint, + "b64_image": img.max_projection_b64, + } + ) result = agent.developmental_tracker.classify_stage( image_b64=latest.max_projection_b64, @@ -517,7 +528,10 @@ async def 
classify_embryo_stage( if result.predicted_minutes_to_hatching is not None: hours = result.predicted_minutes_to_hatching / 60 - lines.append(f" Predicted time to hatching: ~{hours:.1f} hours ({result.predicted_minutes_to_hatching} min)") + lines.append( + f" Predicted time to hatching: ~{hours:.1f} hours" + f" ({result.predicted_minutes_to_hatching} min)" + ) return "\n".join(lines) @@ -527,22 +541,45 @@ async def classify_embryo_stage( description="Get the developmental stage progression history for an embryo", category=ToolCategory.ANALYSIS, ) -def get_stage_history( - embryo_id: str, - context: Dict = None -) -> str: +def get_stage_history(embryo_id: str, context: dict | None = None) -> str: """Get stage history""" agent, err = require_agent(context) if err: return err + # Prefer the live perception session (the orchestrator's Perceiver, which the + # agent shares). The DevelopmentalTracker below is only populated by manual + # classify_embryo_stage calls, so it is usually empty in autonomous runs. 
+ perceiver = getattr(agent, "perceiver", None) + session = perceiver.get_session(embryo_id) if perceiver else None + if session is not None and getattr(session, "current_stage", None): + s = session.summary() + lines = [ + f"Stage progression for {embryo_id} (live perception):", + f" Current stage: {s.get('current_stage')} (stable for {s.get('stability', 0)} obs)", + f" Observations: {s.get('observation_count', 0)}", + ] + seq = s.get("stage_sequence") or [] + if seq: + lines.append(f" Trajectory: {' -> '.join(seq)}") + t = s.get("temporal") # TemporalContext dataclass or None + if t is not None: + exp = getattr(t, "expected_duration_min", None) + seg = f" Time in current stage: {getattr(t, 'time_in_stage_min', 0.0):.0f} min" + if exp: + seg += f" (expected ~{exp:.0f} min)" + lines.append(seg) + if getattr(t, "is_potentially_arrested", False): + lines.append(" ** potentially ARRESTED **") + return "\n".join(lines) + tracker, err = require_developmental_tracker(agent) if err: return err summary = tracker.get_progression_summary(embryo_id) - if summary['observations'] == 0: + if summary["observations"] == 0: return f"No stage classifications for {embryo_id}. Use classify_embryo_stage first." lines = [ @@ -552,28 +589,98 @@ def get_stage_history( f" Stages observed: {', '.join(summary['stages_observed'])}", ] - if summary['predicted_minutes_to_hatching'] is not None: - hours = summary['predicted_minutes_to_hatching'] / 60 + if summary["predicted_minutes_to_hatching"] is not None: + hours = summary["predicted_minutes_to_hatching"] / 60 lines.append(f" Predicted time to hatching: ~{hours:.1f} hours") return "\n".join(lines) +def _perceiver_hatching_estimate(session) -> float | None: + """Estimate minutes until the 'hatching' stage from the perception session. + + Uses gently_perception's own organism stage ordering + typical durations, so + no DevelopmentalStage enum mapping is needed. 
Returns None when unknown + (no_object / off-vocabulary stage), 0.0 when already hatching/hatched. + """ + try: + from gently_perception.organism import CELEGANS + except Exception: + return None + stage = getattr(session, "current_stage", None) + if not stage or stage == "no_object": + return None + stages = list(CELEGANS.stages) + durations = dict(CELEGANS.stage_durations) + if stage in ("hatching", "hatched"): + return 0.0 + if stage not in stages or "hatching" not in stages: + return None + idx = stages.index(stage) + target = stages.index("hatching") + if idx >= target: + return 0.0 + # Remaining time in the current stage (expected minus already-elapsed). + elapsed = 0.0 + t = session.summary().get("temporal") + if t is not None: + elapsed = getattr(t, "time_in_stage_min", 0.0) or 0.0 + remaining = max(0.0, durations.get(stage, 0.0) - elapsed) + # Plus the full expected duration of each stage between current and hatching. + for s in stages[idx + 1 : target]: + remaining += durations.get(s, 0.0) + return remaining + + @tool( name="predict_hatching", - description="Predict time-to-hatching for an embryo with confidence intervals based on developmental stage", + description=( + "Predict time-to-hatching for an embryo with confidence intervals based on" + " developmental stage" + ), category=ToolCategory.ANALYSIS, ) def predict_hatching( - embryo_id: str = None, - all_embryos: bool = False, - context: Dict = None + embryo_id: str | None = None, all_embryos: bool = False, context: dict | None = None ) -> str: """Predict hatching time with confidence intervals""" agent, err = require_agent(context) if err: return err + # Prefer the live perception session; the DevelopmentalTracker is usually + # empty in autonomous runs (only manual classify_embryo_stage feeds it). 
+ perceiver = getattr(agent, "perceiver", None) + + def _perc_line(eid: str): + session = perceiver.get_session(eid) if perceiver else None + if session is None or not getattr(session, "current_stage", None): + return None + stage = session.current_stage + if stage in ("hatching", "hatched"): + return f" {eid}: stage={stage} (hatching now / already hatched)" + est = _perceiver_hatching_estimate(session) + if est is None: + return f" {eid}: stage={stage} (time-to-hatching unknown)" + return f" {eid}: stage={stage}, ~{est / 60:.1f}h to hatching ({est:.0f} min)" + + if perceiver is not None: + if all_embryos: + ids = list(agent.experiment.embryos.keys()) + perc = [_perc_line(e) for e in ids] + if any(perc): + out = ["Hatching predictions (live perception):", ""] + out += [p for p in perc if p] + missing = [e for e, p in zip(ids, perc, strict=False) if not p] + if missing: + out.append("") + out.append(f"(no perception yet for: {', '.join(missing)})") + return "\n".join(out) + elif embryo_id: + line = _perc_line(embryo_id) + if line: + return f"Hatching prediction for {embryo_id} (live perception):\n{line}" + tracker, err = require_developmental_tracker(agent) if err: return err @@ -590,7 +697,9 @@ def predict_hatching( for eid, pred in predictions.items(): lines.append(f" {eid}:") lines.append(f" Current stage: {pred.current_stage.value}") - lines.append(f" Predicted: {pred.predicted_hours:.1f}h ({pred.predicted_minutes} min)") + lines.append( + f" Predicted: {pred.predicted_hours:.1f}h ({pred.predicted_minutes} min)" + ) lines.append(f" Range: {pred.range_hours[0]:.1f} - {pred.range_hours[1]:.1f}h") lines.append(f" Confidence: {pred.confidence}") lines.append("") @@ -616,7 +725,8 @@ def predict_hatching( lines = [ f"Hatching Prediction for {embryo_id}:", f" Current stage: {pred.current_stage.value}", - f" Predicted time to hatching: {pred.predicted_hours:.1f} hours ({pred.predicted_minutes} min)", + f" Predicted time to hatching: {pred.predicted_hours:.1f} hours" + 
f" ({pred.predicted_minutes} min)", f" Confidence interval: {pred.range_hours[0]:.1f} - {pred.range_hours[1]:.1f} hours", f" Classification confidence: {pred.confidence}", ] @@ -630,6 +740,261 @@ def predict_hatching( return "\n".join(lines) +@tool( + name="set_autonomy", + description="""Set the autonomy mode of the decision-moment wake-router (default OFF). +Modes: + 'off' — never act on its own; only respond to your messages. + 'ask' — on a notable event (stage transition, arrest, hatching, termination, + errors) the agent PROPOSES a change and waits for you to Approve / + Modify / Skip in the chat before acting. + 'auto' — the agent adapts acquisition on its own (still bounded by device + limits; a few irreversible actions always require your confirmation). +You can switch modes mid-run. Use when the user says "enable autopilot/autonomous", +"ask me before changing things", "go fully autonomous", or "turn off autonomy".""", + category=ToolCategory.ANALYSIS, + examples=[ + ToolExample("Ask me before adapting", {"mode": "ask"}), + ToolExample("Go fully autonomous", {"mode": "auto"}), + ToolExample("Turn off autonomy", {"mode": "off"}), + ], +) +def set_autonomy( + mode: str | None = None, enabled: bool | None = None, context: dict | None = None +) -> str: + """Set the wake-router mode (off/ask/auto). `enabled` kept for back-compat.""" + agent, err = require_agent(context) + if err: + return err + router = getattr(agent, "wake_router", None) + if router is None: + return "Autonomy is not available (wake-router failed to initialize)." + if mode is not None: + m = str(mode).strip().lower() + if m not in ("off", "ask", "auto"): + return "mode must be 'off', 'ask', or 'auto'." + router.set_mode(m) + elif enabled is not None: + router.set_enabled(bool(enabled)) + else: + return "Specify mode ('off', 'ask', or 'auto')." + cur = router.mode + if cur == "auto": + return ( + "Autonomy set to AUTO. 
I'll wake on stage transitions, arrest, " + "hatching, termination, and errors and adapt acquisition on my own " + "(irreversible actions still need your okay). Say 'ask mode' or " + "'turn off autonomy' to change." + ) + if cur == "ask": + return ( + "Autonomy set to ASK. On a notable event I'll propose a change and " + "wait for your Approve / Modify / Skip before doing anything." + ) + return "Autonomy OFF. I'll only act when you message me." + + +# --------------------------------------------------------------------------- +# Live cadence / dose modulation — direct knobs for a running timelapse. +# --------------------------------------------------------------------------- + + +@tool( + name="modify_timelapse_interval", + description="""Change the base acquisition interval for ALL embryos on a running +timelapse, effective immediately. +Re-anchors every embryo's next acquisition to now + the new interval and notifies the UI. +Lower interval = more frequent imaging = more photodose; raise it to be gentler. +Use when the user says "image every N minutes/seconds now", "speed up/slow down the whole run". +For a single embryo use set_embryo_cadence instead.""", + category=ToolCategory.EXPERIMENT, + examples=[ + ToolExample("Image every 2 minutes now", {"new_interval_seconds": 120}), + ToolExample("Slow everything down to 10 minutes", {"new_interval_seconds": 600}), + ], +) +def modify_timelapse_interval(new_interval_seconds: float, context: dict | None = None) -> str: + """Globally re-anchor the timelapse interval (live).""" + agent, err = require_agent(context) + if err: + return err + orchestrator, err = require_timelapse_orchestrator(agent) + if err: + return err + return orchestrator.modify_interval(new_interval_seconds) + + +@tool( + name="set_embryo_cadence", + description="""Change ONE embryo's acquisition cadence on a running timelapse, effective +immediately. 
Set new_interval_seconds to re-anchor that embryo's next acquisition to now + +interval (lower = more frequent = more dose). +Set new_phase to 'normal' to resume a paused embryo, or 'paused' to pause it. +NOTE: re-issuing the SAME interval with the SAME phase is a no-op (it won't re-anchor). +Use for per-embryo tuning, e.g. speed up the one that's developing fastest.""", + category=ToolCategory.EXPERIMENT, + examples=[ + ToolExample( + "Image embryo_2 every minute", + {"embryo_id": "embryo_2", "new_interval_seconds": 60}, + ), + ToolExample("Resume embryo_3", {"embryo_id": "embryo_3", "new_phase": "normal"}), + ], +) +def set_embryo_cadence( + embryo_id: str, + new_interval_seconds: float | None = None, + new_phase: str | None = None, + context: dict | None = None, +) -> str: + """Per-embryo cadence change routed through the re-anchoring path.""" + agent, err = require_agent(context) + if err: + return err + orchestrator, err = require_timelapse_orchestrator(agent) + if err: + return err + embryo, err = get_embryo_or_error(agent, embryo_id) + if err: + return err + if new_interval_seconds is None and new_phase is None: + return "Specify new_interval_seconds and/or new_phase." + if new_interval_seconds is not None and new_interval_seconds < 1: + return "Interval must be >= 1 second." + if new_phase is not None and new_phase not in ("normal", "fast", "burst", "paused"): + return "new_phase must be one of: normal, fast, burst, paused." + # Detect the no-op (transition_cadence silently does nothing, and would NOT + # re-anchor next_due_at, if neither interval nor phase actually changes). 
+ cur_interval = getattr(embryo, "interval_seconds", None) + cur_phase = getattr(embryo, "cadence_phase", None) + interval_change = new_interval_seconds is not None and new_interval_seconds != cur_interval + phase_change = new_phase is not None and new_phase != cur_phase + if not interval_change and not phase_change: + shown = f"{cur_interval:.0f}s" if cur_interval is not None else "default" + return f"{embryo.id}: no change (already interval={shown}, phase={cur_phase})." + orchestrator.transition_cadence( + embryo, + new_interval_seconds=new_interval_seconds if interval_change else None, + new_phase=new_phase if phase_change else None, + reason="agent:set_embryo_cadence", + ) + bits = [] + if interval_change: + bits.append(f"interval={new_interval_seconds:.0f}s") + if phase_change: + bits.append(f"phase={new_phase}") + due = getattr(embryo, "next_due_at", None) + tail = f"; next acquisition ~{due.strftime('%H:%M:%S')}" if due else "" + return f"{embryo.id}: {', '.join(bits)}{tail}" + + +@tool( + name="set_photodose_budget", + description="""Set or clear the per-embryo photodose budget (a hard cap on cumulative +laser exposure). base_dose_budget_ms is the ceiling for a 1x-role (test) embryo; +calibration embryos get 10x. +When an embryo's cumulative exposure exceeds its budget it is auto-PAUSED to protect the sample. +Pass null/None to DISABLE the cap. Raising the budget also resumes embryos that were paused +for the old cap. 
+Use to enforce gentleness on precious samples, or to lift the cap when the user okays more dose.""", + category=ToolCategory.EXPERIMENT, + examples=[ + ToolExample("Cap each embryo at 5 seconds of light", {"base_dose_budget_ms": 5000}), + ToolExample("Remove the photodose cap", {"base_dose_budget_ms": None}), + ], +) +def set_photodose_budget( + base_dose_budget_ms: float | None = None, + resume_paused: bool = True, + context: dict | None = None, +) -> str: + """Set/clear the photodose budget; optionally resume budget-paused embryos.""" + agent, err = require_agent(context) + if err: + return err + orchestrator, err = require_timelapse_orchestrator(agent) + if err: + return err + # Capture who was budget-paused BEFORE set_photodose_budget clears the set, + # so we only resume embryos paused for the budget (not manual pauses/bursts). + prev_exceeded = set(getattr(orchestrator, "_dose_budget_exceeded", set()) or set()) + msg = orchestrator.set_photodose_budget(base_dose_budget_ms) + resumed = [] + if resume_paused: + states = getattr(orchestrator, "_embryo_states", {}) or {} + try: + from gently.harness.roles import REGISTRY as ROLE_REGISTRY + except Exception: + ROLE_REGISTRY = {} + for eid in prev_exceeded: + e = states.get(eid) + if e is None or getattr(e, "cadence_phase", None) != "paused": + continue + # Only resume if the embryo is now UNDER the new budget (or the cap + # was disabled); otherwise it would just immediately re-pause. + if base_dose_budget_ms is not None: + rdef = ( + ROLE_REGISTRY.get(getattr(e, "role", "test")) + if hasattr(ROLE_REGISTRY, "get") + else None + ) + mult = getattr(rdef, "photodose_budget_multiplier", 1.0) if rdef else 1.0 + if (getattr(e, "total_exposure_ms", 0.0) or 0.0) > base_dose_budget_ms * mult: + continue + orchestrator.transition_cadence( + e, new_phase="normal", reason="agent:budget change resume" + ) + resumed.append(eid) + if resumed: + msg += f" Resumed: {', '.join(sorted(resumed))}." 
+ return msg + + +@tool( + name="get_photodose_status", + description="""Report each embryo's cumulative light exposure vs its photodose budget, +and which are paused over budget. +Use to reason about gentleness before/after changing the budget, power, or cadence.""", + category=ToolCategory.ANALYSIS, + examples=[ToolExample("How much light has each embryo gotten?", {})], +) +def get_photodose_status(context: dict | None = None) -> str: + """Read-only photodose / budget status across embryos.""" + agent, err = require_agent(context) + if err: + return err + orchestrator, err = require_timelapse_orchestrator(agent) + if err: + return err + base = getattr(orchestrator, "_dose_budget_base_ms", None) + exceeded: set[str] = getattr(orchestrator, "_dose_budget_exceeded", set()) or set() + states = getattr(orchestrator, "_embryo_states", {}) or {} + if base is None: + lines = ["Photodose budget: DISABLED (no cap).", ""] + else: + lines = [f"Photodose budget: {base:.0f} ms base (scaled per role).", ""] + try: + from gently.harness.roles import REGISTRY as ROLE_REGISTRY + except Exception: + ROLE_REGISTRY = {} + for eid in sorted(states): + e = states[eid] + used = getattr(e, "total_exposure_ms", 0.0) or 0.0 + role = getattr(e, "role", "test") + if base is not None: + rdef = ROLE_REGISTRY.get(role) if hasattr(ROLE_REGISTRY, "get") else None + mult = getattr(rdef, "photodose_budget_multiplier", 1.0) if rdef else 1.0 + cap = base * mult + pct = (used / cap * 100.0) if cap else 0.0 + flag = " [PAUSED: over budget]" if eid in exceeded else "" + lines.append(f" {eid} ({role}): {used:.0f}/{cap:.0f} ms ({pct:.0f}%){flag}") + else: + lines.append(f" {eid} ({role}): {used:.0f} ms used") + if len(lines) == 2: + lines.append(" (no embryos)") + return "\n".join(lines) + + # --------------------------------------------------------------------------- # Reactive monitoring modes (Phase 5) — high-level "install canonical # detector → cadence + power reactive rules" entry points. 
Without one of @@ -659,7 +1024,7 @@ def predict_hatching( ) def enable_monitoring_mode( mode_name: str, - context: Dict = None, + context: dict | None = None, ) -> str: """Install a named reactive monitoring mode on the orchestrator.""" agent, err = require_agent(context) @@ -691,8 +1056,8 @@ def enable_monitoring_mode( ) def add_test_onset_speedup( fast_interval: float = 60.0, - embryo_ids: Optional[List[str]] = None, - context: Dict = None, + embryo_ids: list[str] | None = None, + context: dict | None = None, ) -> str: """Install the canonical signal-onset cadence speedup rule.""" agent, err = require_agent(context) @@ -708,9 +1073,7 @@ def add_test_onset_speedup( fast_interval=fast_interval, embryo_ids=embryo_ids, ) - target = ( - ", ".join(embryo_ids) if embryo_ids else "all test-role embryos" - ) + target = ", ".join(embryo_ids) if embryo_ids else "all test-role embryos" return ( f"Installed test-onset speedup: switch to {fast_interval}s interval " f"on signal onset for {target}." @@ -737,8 +1100,8 @@ def add_test_saturation_rampdown( floor_pct: float = 2.0, ceiling_pct: float = 6.0, confirm_timepoints: int = 0, - embryo_ids: Optional[List[str]] = None, - context: Dict = None, + embryo_ids: list[str] | None = None, + context: dict | None = None, ) -> str: """Install the canonical 488 saturation rampdown power rule.""" agent, err = require_agent(context) @@ -757,9 +1120,7 @@ def add_test_saturation_rampdown( confirm_timepoints=confirm_timepoints, embryo_ids=embryo_ids, ) - target = ( - ", ".join(embryo_ids) if embryo_ids else "all test-role embryos" - ) + target = ", ".join(embryo_ids) if embryo_ids else "all test-role embryos" return ( f"Installed 488 saturation rampdown: step={step_pct}%, " f"floor={floor_pct}%, ceiling={ceiling_pct}%, " @@ -792,7 +1153,7 @@ def queue_burst( mode: str = "1hz", num_slices: int = 1, force: bool = False, - context: Dict = None, + context: dict | None = None, ) -> str: """Queue an exclusive burst acquisition for one embryo.""" 
agent, err = require_agent(context) diff --git a/gently/app/tools/volume_tools.py b/gently/app/tools/volume_tools.py index 080e56dc..3a9d3d5c 100644 --- a/gently/app/tools/volume_tools.py +++ b/gently/app/tools/volume_tools.py @@ -5,21 +5,21 @@ """ import logging -from typing import Dict, Optional -from datetime import datetime -logger = logging.getLogger(__name__) +from gently.core.coordinates import get_um_per_pixel, stage_to_pixel_position +from gently.harness.tools.helpers import ctx_get, require_agent, require_microscope +from gently.harness.tools.registry import ToolCategory, ToolExample, tool -from gently.harness.tools.registry import tool, ToolCategory, ToolExample -from gently.harness.tools.helpers import require_agent, get_embryo_or_error -from gently.core.coordinates import stage_to_pixel_position, get_um_per_pixel +logger = logging.getLogger(__name__) @tool( name="view_image", - description="""Capture and display the current bottom camera widefield image. Shows what's visible at the current stage position. -Use when user says "show me the view", "take a picture", "what does it look like?", or to check sample positioning. -This is the widefield/brightfield camera looking up at the sample - good for seeing embryo outlines and overall positioning. + description="""Capture and display the current bottom camera widefield image. Shows what's +visible at the current stage position. +Use when user says "show me the view", "take a picture", "what does it look like?", or to +check sample positioning. This is the widefield/brightfield camera looking up at the sample +- good for seeing embryo outlines and overall positioning. 
Image is automatically saved to camera_captures/ folder.""", category=ToolCategory.HARDWARE, requires_microscope=True, @@ -30,18 +30,20 @@ ) async def view_image( title: str = "Bottom Camera Image", - exposure_ms: float = None, + exposure_ms: float | None = None, show: bool = True, show_embryos: bool = True, - context: Dict = None + context: dict | None = None, ) -> str: """Capture and display bottom camera image with embryo annotations""" - client = context.get('client') - agent = context.get('agent') + client, err = require_microscope(context) + if err: + return err + agent = ctx_get(context, "agent") try: snap = await client.capture_bottom_image(exposure_ms=exposure_ms) - image = snap['image'] + image = snap["image"] if image is None or image.shape == (100, 100): return "Failed to capture image from bottom camera" @@ -50,15 +52,16 @@ async def view_image( stage_pos = await client.get_stage_position() # Archive the bottom camera image with metadata - if snap.get('image_path') and agent and agent.store and agent.session_id: + if snap.get("image_path") and agent and agent.store and agent.session_id: try: from gently.harness.tools.helpers import build_snapshot_metadata + meta = build_snapshot_metadata( - stage_pos, image.shape, - agent.experiment if agent else None) + stage_pos, image.shape, agent.experiment if agent else None + ) agent.store.register_snapshot( - agent.session_id, "bottom_camera", snap['image_path'], - metadata=meta) + agent.session_id, "bottom_camera", snap["image_path"], metadata=meta + ) except Exception: pass @@ -72,8 +75,8 @@ async def view_image( for embryo_id, embryo in agent.experiment.embryos.items(): if embryo.stage_position: # Convert stage position to pixel position using centralized function - emb_x = embryo.stage_position.get('x', 0) - emb_y = embryo.stage_position.get('y', 0) + emb_x = embryo.stage_position.get("x", 0) + emb_y = embryo.stage_position.get("y", 0) pixel_x, pixel_y = stage_to_pixel_position( stage_x=emb_x, stage_y=emb_y, 
@@ -81,34 +84,47 @@ async def view_image( current_stage_y=stage_pos[1], image_center_x=image_center_x, image_center_y=image_center_y, - um_per_pixel=um_per_pixel + um_per_pixel=um_per_pixel, ) - embryo_annotations.append({ - 'embryo_id': embryo_id, - 'pixel_x': pixel_x, - 'pixel_y': pixel_y, - 'label': embryo.user_label or embryo_id - }) + embryo_annotations.append( + { + "embryo_id": embryo_id, + "pixel_x": pixel_x, + "pixel_y": pixel_y, + "label": embryo.user_label or embryo_id, + } + ) if show: from datetime import datetime from pathlib import Path + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") save_path = f"camera_captures/bottom_camera_{timestamp}.jpg" Path("camera_captures").mkdir(exist_ok=True) - view_result = await client.view_image( + await client.view_image( image=image, title=title, save_path=save_path, show=True, - embryo_annotations=embryo_annotations if embryo_annotations else None + embryo_annotations=embryo_annotations if embryo_annotations else None, + ) + num_visible = len( + [ + a + for a in embryo_annotations + if 0 <= a["pixel_x"] < image.shape[1] and 0 <= a["pixel_y"] < image.shape[0] + ] + ) + annotation_msg = ( + f"\nShowing {num_visible} embryo(s) in view" if embryo_annotations else "" + ) + return ( + f"Captured bottom camera image ({image.shape[0]}x{image.shape[1]})" + f"\nSaved to: {save_path}{annotation_msg}" ) - num_visible = len([a for a in embryo_annotations - if 0 <= a['pixel_x'] < image.shape[1] and 0 <= a['pixel_y'] < image.shape[0]]) - annotation_msg = f"\nShowing {num_visible} embryo(s) in view" if embryo_annotations else "" - return f"Captured bottom camera image ({image.shape[0]}x{image.shape[1]})\nSaved to: {save_path}{annotation_msg}" else: return f"Captured bottom camera image ({image.shape[0]}x{image.shape[1]})" @@ -118,118 +134,105 @@ async def view_image( @tool( name="view_volume", - description="""Open a volume in napari for 3D visualization. 
-Can open a volume by file path OR by embryo ID (opens latest volume or specific timepoint). -Use when user says "open volume", "view volume", "show volume in napari", or "look at the 3D data".""", + description="""Open an acquired volume in the in-browser 3D viewer. +Opens by embryo ID \u2014 the latest volume, or a specific timepoint. The volume +appears in the web UI's volume viewer (interactive 3D raymarcher + projections) +for everyone watching the session; nothing pops up on the instrument desktop. +Use when the user says "open volume", "view volume", "show the 3D data", or +"look at timepoint N of embryo X".""", category=ToolCategory.ANALYSIS, requires_microscope=False, examples=[ ToolExample("Open latest volume for embryo 2", {"embryo_id": "embryo_2"}), ToolExample("Open specific timepoint", {"embryo_id": "embryo_2", "timepoint": 5}), - ToolExample("Open volume file", {"file_path": "D:/Gently/volumes/embryo_1_t0001.tif"}), ], ) async def view_volume( - embryo_id: str = None, - timepoint: int = None, - file_path: str = None, - context: Dict = None + embryo_id: str | None = None, + timepoint: int | None = None, + file_path: str | None = None, + context: dict | None = None, ) -> str: - """Open a volume in napari for visualization""" - import napari - import tifffile - import numpy as np + """Open a volume in the browser-based viewer (no blocking desktop window).""" from pathlib import Path agent, err = require_agent(context) if err: return err - volume = None - volume_path = None - title = "Volume Viewer" + session_id = agent.session_id - # Determine which volume to open - if file_path: - # Open from file path - volume_path = Path(file_path) - if not volume_path.exists(): + # file_path is legacy. In-browser viewing is addressed by embryo + timepoint, + # so map a FileStore path (embryos/{embryo_id}/volumes/t{NNNN}.tif) back to + # those when possible. 
+ if file_path and not embryo_id: + p = Path(file_path) + if not p.exists(): return f"Error: File not found: {file_path}" - title = f"Volume: {volume_path.name}" + stem = p.stem # e.g. "t0005" + try: + if stem.startswith("t"): + timepoint = int(stem[1:]) + # .../embryos/{embryo_id}/volumes/t{NNNN}.tif \u2192 embryo dir is parent of "volumes" + embryo_id = p.parent.parent.name + except (ValueError, IndexError): + pass + if not embryo_id or timepoint is None: + return ( + "Volume viewing is now in-browser and addressed by embryo + " + "timepoint. Please specify embryo_id (and optionally timepoint) " + "rather than a raw file path." + ) - elif embryo_id: - # Get volume for embryo from FileStore - session_id = agent.session_id + if not embryo_id: + return "Error: Specify embryo_id (and optionally timepoint)." - if timepoint is not None: - # Try to find specific timepoint via FileStore - volume_path = agent.store.get_volume_path(session_id, embryo_id, timepoint) - if volume_path and volume_path.exists(): - title = f"{embryo_id} - t{timepoint:04d}" - else: - # Check recent_images as fallback - embryo, err = get_embryo_or_error(agent, embryo_id) - if err: - return err - if embryo.recent_images: - matching = [img for img in embryo.recent_images if img.timepoint == timepoint] - if matching: - volume_path = Path(matching[0].volume_path) - title = f"{embryo_id} - t{timepoint:04d}" - - if not volume_path or not volume_path.exists(): - # List available timepoints from store - volumes = agent.store.list_volumes(session_id, embryo_id) - available = sorted([v['timepoint'] for v in volumes]) - return f"Timepoint {timepoint} not found for {embryo_id}. Available: {available}" - else: - # Find latest volume from store + # Resolve the timepoint (specific or latest) and confirm the volume exists. 
+ if timepoint is not None: + volume_path = agent.store.get_volume_path(session_id, embryo_id, timepoint) + if not volume_path or not Path(volume_path).exists(): volumes = agent.store.list_volumes(session_id, embryo_id) - if not volumes: + available = sorted(v["timepoint"] for v in volumes) + if not available: return f"No volumes found for {embryo_id} in session {session_id}" - - # Find highest timepoint - latest = max(volumes, key=lambda v: v['timepoint']) - latest_tp = latest['timepoint'] - volume_path = agent.store.get_volume_path(session_id, embryo_id, latest_tp) - - title = f"{embryo_id} - t{latest_tp:04d}" - + return f"Timepoint {timepoint} not found for {embryo_id}. Available: {available}" else: - return "Error: Specify either embryo_id or file_path" + volumes = agent.store.list_volumes(session_id, embryo_id) + if not volumes: + return f"No volumes found for {embryo_id} in session {session_id}" + timepoint = max(v["timepoint"] for v in volumes) + + # Drive the in-browser viewer \u2014 no blocking Qt/desktop window. + viz = getattr(agent, "viz_server", None) + if viz is None: + return ( + f"Resolved {embryo_id} t{timepoint:04d}, but the web UI isn't running, " + f"so there's nowhere to display it. Start the web UI and try again." 
+ ) - # Load the volume try: - volume = tifffile.imread(str(volume_path)) - logger.info("Loaded volume: %s, dtype=%s", volume.shape, volume.dtype) + n_clients = await viz.open_volume_in_browser(embryo_id, timepoint) except Exception as e: - return f"Error loading volume: {e}" - - # Open in napari - logger.info("Opening napari viewer...") - viewer = napari.Viewer(title=title) - - # Add volume with appropriate settings - viewer.add_image( - volume, - name='Volume', - colormap='gray', - rendering='mip', # Maximum intensity projection for 3D + logger.exception("open_volume_in_browser failed") + return f"Error opening volume in the web viewer: {e}" + + url = f"http://localhost:{getattr(viz, 'port', 8080)}/" + if n_clients <= 0: + return ( + f"Resolved {embryo_id} t{timepoint:04d}, but no browser is connected. " + f"Open {url} and select that embryo/timepoint to view it." + ) + return ( + f"\u2713 Opening {embryo_id} t{timepoint:04d} in the web volume viewer " + f"({n_clients} view(s) connected) \u2014 {url}" ) - # Add scale bar info - viewer.scale_bar.visible = True - viewer.scale_bar.unit = "um" - - napari.run() - - return f"\u2713 Opened volume in napari: {volume_path.name} (shape: {volume.shape})" - @tool( name="list_volumes", description="""List available volumes for an embryo or all embryos. -Shows volume files with timepoints and file sizes. Scans the storage directory for all volumes (not just recent ones in memory). +Shows volume files with timepoints and file sizes. Scans the storage directory for all +volumes (not just recent ones in memory). 
Use to see what data is available before viewing.""", category=ToolCategory.ANALYSIS, requires_microscope=False, @@ -238,10 +241,7 @@ async def view_volume( ToolExample("List all volumes", {}), ], ) -async def list_volumes( - embryo_id: str = None, - context: Dict = None -) -> str: +async def list_volumes(embryo_id: str | None = None, context: dict | None = None) -> str: """List available volumes""" agent, err = require_agent(context) if err: @@ -254,16 +254,16 @@ async def list_volumes( all_volumes_list = agent.store.list_volumes(session_id, embryo_id) # Group by embryo_id - all_volumes = {} # embryo_id -> list of volume records + all_volumes: dict[str, list[dict]] = {} # embryo_id -> list of volume records for vol in all_volumes_list: - eid = vol['embryo_id'] + eid = vol["embryo_id"] if eid not in all_volumes: all_volumes[eid] = [] all_volumes[eid].append(vol) # Sort by timepoint for eid in all_volumes: - all_volumes[eid].sort(key=lambda x: x['timepoint']) + all_volumes[eid].sort(key=lambda x: x["timepoint"]) if embryo_id: # List volumes for specific embryo @@ -276,7 +276,7 @@ async def list_volumes( lines.append("") for vol in volumes: - tp = vol['timepoint'] + tp = vol["timepoint"] path = agent.store.get_volume_path(session_id, embryo_id, tp) if path and path.exists(): size_mb = path.stat().st_size / (1024 * 1024) @@ -290,25 +290,33 @@ async def list_volumes( return f"No volumes found in session {session_id}" total_files = sum(len(v) for v in all_volumes.values()) - lines.append(f"Available volumes: {total_files} file(s) across {len(all_volumes)} embryo(s)") + lines.append( + f"Available volumes: {total_files} file(s) across {len(all_volumes)} embryo(s)" + ) lines.append(f"Session: {session_id}") for eid in sorted(all_volumes.keys()): volumes = all_volumes[eid] - timepoints = [v['timepoint'] for v in volumes] - tp_range = f"t{min(timepoints):04d}-t{max(timepoints):04d}" if len(timepoints) > 1 else f"t{timepoints[0]:04d}" + timepoints = [v["timepoint"] for v in 
volumes] + tp_range = ( + f"t{min(timepoints):04d}-t{max(timepoints):04d}" + if len(timepoints) > 1 + else f"t{timepoints[0]:04d}" + ) # Calculate total size total_size = 0 for vol in volumes: - path = agent.store.get_volume_path(session_id, eid, vol['timepoint']) + path = agent.store.get_volume_path(session_id, eid, vol["timepoint"]) if path and path.exists(): total_size += path.stat().st_size / (1024 * 1024) - lines.append(f"\n{eid}: {len(volumes)} volume(s) [{tp_range}] ({total_size:.1f} MB total)") + lines.append( + f"\n{eid}: {len(volumes)} volume(s) [{tp_range}] ({total_size:.1f} MB total)" + ) # Show last few timepoints for vol in volumes[-3:]: - tp = vol['timepoint'] + tp = vol["timepoint"] path = agent.store.get_volume_path(session_id, eid, tp) if path and path.exists(): size_mb = path.stat().st_size / (1024 * 1024) diff --git a/gently/app/video_maker.py b/gently/app/video_maker.py index 942301cf..f081ae23 100644 --- a/gently/app/video_maker.py +++ b/gently/app/video_maker.py @@ -4,19 +4,16 @@ Creates MP4 videos from max projections of timelapse volumes. """ -import numpy as np -from pathlib import Path -from typing import List, Optional, Dict, Tuple -from datetime import datetime import logging +from datetime import datetime +from pathlib import Path + +import numpy as np logger = logging.getLogger(__name__) -def discover_volumes( - session_dir: Path, - embryo_id: Optional[str] = None -) -> Dict[str, List[Path]]: +def discover_volumes(session_dir: Path, embryo_id: str | None = None) -> dict[str, list[Path]]: """ Discover volume files in a session directory. 
@@ -39,7 +36,7 @@ def discover_volumes( tif_files = list(session_dir.glob("*.tif")) + list(session_dir.glob("*.tiff")) # Group by embryo ID (filename format: embryo_1_20251210_095317.tif) - embryo_volumes: Dict[str, List[Tuple[datetime, Path]]] = {} + embryo_volumes: dict[str, list[tuple[datetime, Path]]] = {} for f in tif_files: parts = f.stem.split("_") @@ -82,16 +79,19 @@ def make_max_projection(volume: np.ndarray) -> np.ndarray: return volume -def normalize_for_video(image: np.ndarray, percentile_low: float = 1, percentile_high: float = 99.5) -> np.ndarray: +def normalize_for_video( + image: np.ndarray, percentile_low: float = 1, percentile_high: float = 99.5 +) -> np.ndarray: """Normalize image to 8-bit for video encoding.""" from gently.core.imaging import normalize_to_uint8 - return normalize_to_uint8(image, method="percentile", p_low=percentile_low, p_high=percentile_high) + + return normalize_to_uint8( + image, method="percentile", p_low=percentile_low, p_high=percentile_high + ) def add_timestamp_overlay( - image: np.ndarray, - timestamp: str, - position: str = "top-left" + image: np.ndarray, timestamp: str, position: str = "top-left" ) -> np.ndarray: """Add timestamp text overlay to image.""" import cv2 @@ -126,7 +126,7 @@ def add_timestamp_overlay( (x - 2, y - text_height - 2), (x + text_width + 2, y + baseline + 2), (0, 0, 0), - -1 + -1, ) # Draw text @@ -136,13 +136,13 @@ def add_timestamp_overlay( def create_timelapse_video( - volume_paths: List[Path], + volume_paths: list[Path], output_path: Path, fps: int = 10, add_timestamps: bool = True, - embryo_id: str = None, - progress_callback=None -) -> Dict: + embryo_id: str | None = None, + progress_callback=None, +) -> dict: """ Create MP4 video from list of volume files. 
@@ -223,21 +223,17 @@ def create_timelapse_video( # Try different codecs in order of preference codecs = [ - ('mp4v', '.mp4'), - ('avc1', '.mp4'), - ('XVID', '.avi'), - ('MJPG', '.avi'), + ("mp4v", ".mp4"), + ("avc1", ".mp4"), + ("XVID", ".avi"), + ("MJPG", ".avi"), ] for codec, ext in codecs: fourcc = cv2.VideoWriter_fourcc(*codec) test_path = output_path.with_suffix(ext) writer = cv2.VideoWriter( - str(test_path), - fourcc, - fps, - (width, height), - isColor=True + str(test_path), fourcc, fps, (width, height), isColor=True ) if writer.isOpened(): output_path = test_path @@ -274,14 +270,14 @@ def create_timelapse_video( "frame_count": frame_count, "duration_seconds": duration, "fps": fps, - "resolution": f"{first_shape[1]}x{first_shape[0]}" if first_shape else "unknown" + "resolution": f"{first_shape[1]}x{first_shape[0]}" if first_shape else "unknown", } except Exception as e: if writer: try: writer.release() - except: + except Exception: pass return {"error": str(e)} @@ -289,11 +285,11 @@ def create_timelapse_video( def make_session_videos( storage_path: Path, session_id: str, - output_dir: Optional[Path] = None, - embryo_ids: Optional[List[str]] = None, + output_dir: Path | None = None, + embryo_ids: list[str] | None = None, fps: int = 10, - progress_callback=None -) -> Dict[str, Dict]: + progress_callback=None, +) -> dict[str, dict]: """ Create videos for all embryos in a session. 
@@ -344,9 +340,9 @@ def make_session_videos( output_path = output_dir / f"{embryo_id}_timelapse.mp4" - def embryo_progress(current, total): + def embryo_progress(current, total, _eid=embryo_id): if progress_callback: - progress_callback(embryo_id, current, total) + progress_callback(_eid, current, total) result = create_timelapse_video( volume_paths=volumes, @@ -354,7 +350,7 @@ def embryo_progress(current, total): fps=fps, add_timestamps=True, embryo_id=embryo_id, - progress_callback=embryo_progress + progress_callback=embryo_progress, ) results[embryo_id] = result diff --git a/gently/app/wake_router.py b/gently/app/wake_router.py new file mode 100644 index 00000000..87f4b577 --- /dev/null +++ b/gently/app/wake_router.py @@ -0,0 +1,281 @@ +"""Decision-moment wake-router for autonomous agent turns. + +Subscribes to wake-worthy perception/lifecycle events and, when enabled, wakes +the conversational agent between user messages so it can re-decide acquisition +(cadence, power, stop conditions) in response to what perception sees — the +closed loop. + +Design (opt-in, default OFF): + * Triggers: critical events (hatching / arrest / embryo-terminated / errors) + plus developmental stage transitions. No periodic heartbeat. + * Debounce: a burst of events inside COALESCE_WINDOW collapses into ONE wake. + * Throttle: non-critical wakes are rate-limited by MIN_WAKE_INTERVAL; critical + events bypass the throttle. + * Serialization: the wake turn runs through the agent's normal streaming + pipeline, which holds the agent turn-lock, so it never races a user turn. + A wake therefore waits for any in-progress user turn — including an open + choice picker — to finish before it runs; "critical bypasses the throttle" + means it skips MIN_WAKE_INTERVAL, not that it preempts an active user turn + (preempting would interleave on the shared conversation history). + +Nothing fires until ``set_enabled(True)`` (e.g. via the set_autonomy tool). 
+""" + +from __future__ import annotations + +import asyncio +import logging + +from gently.core.event_bus import EventType + +logger = logging.getLogger(__name__) + +# Tunables (seconds). +COALESCE_WINDOW = 20.0 # collapse a burst of events into one wake +MIN_WAKE_INTERVAL = 120.0 # throttle non-critical wakes +ASK_TIMEOUT_SEC = 300.0 # ASK mode: how long to wait for operator approval -> Skip + +# Events that always wake immediately (bypass MIN_WAKE_INTERVAL). +CRITICAL_EVENTS = frozenset( + { + EventType.HATCHING_DETECTED, + EventType.EMBRYO_TERMINATED, + EventType.ERROR_OCCURRED, + EventType.ACQUISITION_FAILED, + EventType.ANOMALY_DETECTED, + } +) +# Non-critical events we also inspect (filtered for real transitions / arrest). +WATCH_EVENTS = frozenset({EventType.DETECTOR_EVALUATED}) + + +class WakeRouter: + """Routes wake-worthy events into coalesced, throttled autonomous agent turns.""" + + def __init__(self, agent, bus): + self.agent = agent + self.bus = bus + self.mode = "off" # 'off' | 'ask' | 'auto' + self._loop = None + self._pending = [] # list[(EventType, dict)] + self._flush_handle = None # TimerHandle for the coalesce window + self._last_wake = 0.0 # loop.time() of the last fired wake + self._last_stage = {} # embryo_id -> last stage seen (transition detection) + self._in_flight = False + self._unsubs = [] + self._subscribe() + + # -- public control ------------------------------------------------- + @property + def enabled(self) -> bool: + return self.mode != "off" + + def set_mode(self, mode: str) -> str: + mode = (mode or "off").strip().lower() + if mode not in ("off", "ask", "auto"): + mode = "off" + self.mode = mode + if mode == "off": + self._cancel_flush() + self._pending.clear() + logger.info("Wake-router mode -> %s", mode.upper()) + return self.mode + + def set_enabled(self, enabled: bool) -> bool: + """Back-compat boolean toggle: maps to AUTO / OFF.""" + self.set_mode("auto" if enabled else "off") + return self.enabled + + def 
is_enabled(self) -> bool: + return self.enabled + + def shutdown(self): + self._cancel_flush() + for unsub in self._unsubs: + try: + unsub() + except Exception: + pass + self._unsubs.clear() + + # -- subscription --------------------------------------------------- + def _subscribe(self): + for et in CRITICAL_EVENTS | WATCH_EVENTS: + try: + self._unsubs.append( + self.bus.subscribe(et, lambda e, _et=et: self._on_event(_et, e)) + ) + except Exception: + logger.exception("wake-router failed to subscribe %s", et) + + # -- event intake --------------------------------------------------- + def _on_event(self, event_type, event): + # Synchronous handler (the bus calls it inline). Cheap-filter, then + # schedule a coalesced flush on the running loop. Never raise — the bus + # swallows handler exceptions, so failures would otherwise vanish. + try: + if not self.enabled: + return + if self._loop is None: + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + return # no running loop -> can't schedule a wake; drop + data = getattr(event, "data", None) or {} + if not self._is_wake_worthy(event_type, data): + return + self._pending.append((event_type, data)) + self._schedule_flush(critical=event_type in CRITICAL_EVENTS) + except Exception: + logger.exception("wake-router _on_event error") + + def _is_wake_worthy(self, event_type, data) -> bool: + if event_type in CRITICAL_EVENTS: + return True + if event_type == EventType.DETECTOR_EVALUATED: + if data.get("skipped"): + return False + if data.get("detector_name") != "perception": + return False # role=test pseudo-stages are not developmental + stage = data.get("stage") + if not stage or stage == "no_object": + return False # empty-field sentinel — not a developmental change + ta = data.get("temporal_analysis") or {} + if ta.get("is_potentially_arrested"): + return True + eid = data.get("embryo_id") + last = self._last_stage.get(eid) + self._last_stage[eid] = stage + return stage != last # only a real 
transition wakes + return False + + # -- coalescing / flush -------------------------------------------- + def _schedule_flush(self, critical: bool): + loop = self._loop + if loop is None: + return + delay = 0.0 if critical else COALESCE_WINDOW + if self._flush_handle is None: + self._flush_handle = loop.call_later(delay, self._fire_flush) + elif critical: + # bring a pending window-flush forward + self._flush_handle.cancel() + self._flush_handle = loop.call_later(0.0, self._fire_flush) + + def _cancel_flush(self): + if self._flush_handle is not None: + try: + self._flush_handle.cancel() + except Exception: + pass + self._flush_handle = None + + def _fire_flush(self): + self._flush_handle = None + loop = self._loop + if loop is not None: + asyncio.ensure_future(self._flush(), loop=loop) + + async def _flush(self): + if not self._pending or not self.enabled: + self._pending.clear() + return + # Evaluate the guards BEFORE draining so a deferral can't lose events. + critical = any(et in CRITICAL_EVENTS for et, _ in self._pending) + now = self._loop.time() if self._loop else 0.0 + if self._in_flight or (not critical and (now - self._last_wake) < MIN_WAKE_INTERVAL): + # A wake is already running, or we're inside the non-critical throttle + # window. Keep _pending intact and re-arm so these events — including + # any CRITICAL ones — are retried once the turn finishes / window + # elapses, rather than being dropped. + logger.debug("wake deferred (in_flight=%s critical=%s)", self._in_flight, critical) + # Retry on the coalesce window (not delay 0) so a critical event + # deferred behind an in-flight turn doesn't busy-spin call_later(0). 
+ self._schedule_flush(critical=False) + return + events = self._pending + self._pending = [] + self._in_flight = True + self._last_wake = now + try: + ask = self.mode == "ask" + note, trigger = self._build_wake_note(events, ask=ask) + logger.info( + "Wake-router firing %s turn (%d event(s)): %s", + self.mode.upper(), + len(events), + trigger, + ) + await self.agent.run_wake_turn(note, trigger=trigger, interactive=ask) + except Exception: + logger.exception("wake turn failed") + finally: + self._in_flight = False + # Events that arrived while we were busy (including deferred CRITICAL + # ones) are still in _pending — re-fire promptly rather than waiting + # out another coalesce window. _in_flight is now False so this flush + # will proceed instead of deferring (no busy-spin). + if self._pending and self.enabled: + self._schedule_flush(critical=any(et in CRITICAL_EVENTS for et, _ in self._pending)) + + # -- wake package --------------------------------------------------- + def _build_wake_note(self, events, ask=False): + """Return (note, trigger_str). The note is the agent-facing wake prompt; + trigger_str is the short human-readable reason shown in the chat banner. 
+ When ask=True the note instructs propose-then-confirm instead of acting.""" + from gently.harness.prompts.templates import build_perception_snapshot + + triggers = [] + for et, data in events: + name = getattr(et, "name", str(et)) + eid = data.get("embryo_id", "?") + stage = data.get("stage") + if et == EventType.HATCHING_DETECTED: + triggers.append(f"{eid}: hatching detected") + elif et == EventType.EMBRYO_TERMINATED: + triggers.append(f"{eid}: terminated ({data.get('completion_reason', '?')})") + elif et in ( + EventType.ERROR_OCCURRED, + EventType.ACQUISITION_FAILED, + EventType.ANOMALY_DETECTED, + ): + triggers.append(f"{eid}: {name.lower().replace('_', ' ')}") + elif et == EventType.DETECTOR_EVALUATED: + ta = data.get("temporal_analysis") or {} + if ta.get("is_potentially_arrested"): + triggers.append(f"{eid}: potential arrest at stage {stage}") + else: + triggers.append(f"{eid}: stage -> {stage}") + else: + triggers.append(f"{eid}: {name.lower()}") + triggers = list(dict.fromkeys(triggers)) # dedupe, preserve order + + try: + snap = build_perception_snapshot( + getattr(self.agent, "perceiver", None), + getattr(getattr(self.agent, "experiment", None), "embryos", {}) or {}, + ) + except Exception: + snap = "" + snap = snap or "(no live perception data)" + trigger_str = "; ".join(triggers) + + head = ( + "[AUTONOMOUS WAKE] Something changed while no one was typing.\n\n" + f"What triggered this: {trigger_str}\n\n" + f"{snap}\n\n" + ) + if ask: + tail = ( + "Decide whether any acquisition change is warranted. If so, briefly " + "state your proposed change and WHY, then call ask_user_choice with " + "options Approve / Modify / Skip and act ONLY if the operator approves. " + "If nothing needs doing, say so briefly and take no action (no need to ask)." + ) + else: + tail = ( + "If a change helps (adjust interval/power, add a stop condition, queue a " + "burst, or stop an embryo), do it now using your tools. 
If nothing needs " + "doing, say so briefly and take no action." + ) + return head + tail, trigger_str diff --git a/gently/core/__init__.py b/gently/core/__init__.py index fc0406ff..dd351143 100644 --- a/gently/core/__init__.py +++ b/gently/core/__init__.py @@ -9,43 +9,40 @@ from .event_bus import ( Event, - EventType, EventBus, - get_event_bus, - set_event_bus, + EventType, emit, - on, + get_event_bus, handles, + on, + set_event_bus, ) - from .service import ( Service, - ServiceState, + ServiceClient, ServiceInfo, ServiceRegistry, - ServiceClient, - + ServiceState, get_service_registry, set_service_registry, ) __all__ = [ # Event bus - 'Event', - 'EventType', - 'EventBus', - 'get_event_bus', - 'set_event_bus', - 'emit', - 'on', - 'handles', + "Event", + "EventType", + "EventBus", + "get_event_bus", + "set_event_bus", + "emit", + "on", + "handles", # Service - 'Service', - 'ServiceState', - 'ServiceInfo', - 'ServiceRegistry', - 'ServiceClient', - - 'get_service_registry', - 'set_service_registry', + "Service", + "ServiceState", + "ServiceInfo", + "ServiceRegistry", + "ServiceClient", + "get_service_registry", + "set_service_registry", ] diff --git a/gently/core/coordinates.py b/gently/core/coordinates.py index 944fa219..9948ad97 100644 --- a/gently/core/coordinates.py +++ b/gently/core/coordinates.py @@ -15,8 +15,6 @@ All other files should import from here. """ -from typing import Tuple - # Default optical parameters for bottom detection camera DEFAULT_PIXEL_SIZE_UM = 6.5 DEFAULT_OBJECTIVE_MAG = 10.0 # 10x objective on bottom camera @@ -29,7 +27,7 @@ def get_um_per_pixel( pixel_size_um: float = DEFAULT_PIXEL_SIZE_UM, - objective_mag: float = DEFAULT_OBJECTIVE_MAG + objective_mag: float = DEFAULT_OBJECTIVE_MAG, ) -> float: """ Calculate microns per pixel for the optical system. 
@@ -56,8 +54,8 @@ def pixel_to_stage_position( image_center_y: float, stage_x: float, stage_y: float, - um_per_pixel: float = None -) -> Tuple[float, float]: + um_per_pixel: float | None = None, +) -> tuple[float, float]: """ Convert pixel coordinates to stage position (for embryo POSITION calculation). @@ -101,7 +99,8 @@ def pixel_to_stage_position( dy_pixels = pixel_y - image_center_y # X: NOT inverted (stage coords match image coords for X) - # Y: IS inverted (stage +Y moves embryo down, but image +Y is also down, need to invert for centering) + # Y: IS inverted (stage +Y moves embryo down, but image +Y is also down, + # need to invert for centering) embryo_stage_x = stage_x + dx_pixels * um_per_pixel embryo_stage_y = stage_y - dy_pixels * um_per_pixel @@ -115,8 +114,8 @@ def stage_to_pixel_position( current_stage_y: float, image_center_x: float, image_center_y: float, - um_per_pixel: float = None -) -> Tuple[float, float]: + um_per_pixel: float | None = None, +) -> tuple[float, float]: """ Convert stage position to pixel coordinates (for DISPLAY/visualization). @@ -155,10 +154,8 @@ def stage_to_pixel_position( def pixel_displacement_to_stage_movement( - pixel_displacement_x: float, - pixel_displacement_y: float, - um_per_pixel: float = None -) -> Tuple[float, float]: + pixel_displacement_x: float, pixel_displacement_y: float, um_per_pixel: float | None = None +) -> tuple[float, float]: """ Convert pixel displacement to stage MOVEMENT (for centering an embryo). diff --git a/gently/core/database.py b/gently/core/database.py index 7e2a6255..59737bee 100644 --- a/gently/core/database.py +++ b/gently/core/database.py @@ -11,17 +11,18 @@ analysis tools and workflows. 
""" -import logging import json +import logging from datetime import datetime - -logger = logging.getLogger(__name__) from pathlib import Path -from typing import Dict, List, Optional, Any +from typing import Any + import numpy as np +logger = logging.getLogger(__name__) -def format_timestamp(dt: Optional[datetime] = None) -> str: + +def format_timestamp(dt: datetime | None = None) -> str: """ Format datetime as ISO 8601 string for JSON storage. @@ -68,7 +69,7 @@ def numpy_to_python(obj: Any) -> Any: return obj -def format_embryo_calibration_for_json(calibration_data: Dict) -> Dict: +def format_embryo_calibration_for_json(calibration_data: dict) -> dict: """ Format embryo calibration data from databroker for JSON export. @@ -90,34 +91,34 @@ def format_embryo_calibration_for_json(calibration_data: Dict) -> Dict: # Ensure required fields exist required_fields = [ - 'slope_um_per_deg', - 'offset_um', - 'galvo_top_deg', - 'galvo_bottom_deg', - 'piezo_top_um', - 'piezo_bottom_um', - 'sample_type', - 'timestamp', - 'device_piezo', - 'device_galvo' + "slope_um_per_deg", + "offset_um", + "galvo_top_deg", + "galvo_bottom_deg", + "piezo_top_um", + "piezo_bottom_um", + "sample_type", + "timestamp", + "device_piezo", + "device_galvo", ] for field in required_fields: if field not in calibration_json: # Provide sensible defaults for missing fields - if field == 'sample_type': - calibration_json[field] = 'embryo' - elif field == 'timestamp': + if field == "sample_type": + calibration_json[field] = "embryo" + elif field == "timestamp": calibration_json[field] = format_timestamp() - elif 'device' in field: - calibration_json[field] = 'unknown' + elif "device" in field: + calibration_json[field] = "unknown" else: calibration_json[field] = None return calibration_json -def format_embryo_entry_for_json(embryo_data: Dict) -> Dict: +def format_embryo_entry_for_json(embryo_data: dict) -> dict: """ Format single embryo entry from databroker for JSON export. 
@@ -132,34 +133,31 @@ def format_embryo_entry_for_json(embryo_data: Dict) -> Dict: JSON-compatible embryo entry """ entry = { - 'embryo_number': int(embryo_data.get('embryo_number', 0)), - 'marking_timestamp': embryo_data.get('marking_timestamp', format_timestamp()), - 'bottom_camera_position_pixel': { - 'x': float(embryo_data.get('pixel_x', 0.0)), - 'y': float(embryo_data.get('pixel_y', 0.0)) + "embryo_number": int(embryo_data.get("embryo_number", 0)), + "marking_timestamp": embryo_data.get("marking_timestamp", format_timestamp()), + "bottom_camera_position_pixel": { + "x": float(embryo_data.get("pixel_x", 0.0)), + "y": float(embryo_data.get("pixel_y", 0.0)), }, - 'initial_stage_position_um': { - 'x': float(embryo_data.get('initial_stage_x', 0.0)), - 'y': float(embryo_data.get('initial_stage_y', 0.0)) + "initial_stage_position_um": { + "x": float(embryo_data.get("initial_stage_x", 0.0)), + "y": float(embryo_data.get("initial_stage_y", 0.0)), + }, + "stage_position_after_centering_um": { + "x": float(embryo_data.get("centered_stage_x", 0.0)), + "y": float(embryo_data.get("centered_stage_y", 0.0)), }, - 'stage_position_after_centering_um': { - 'x': float(embryo_data.get('centered_stage_x', 0.0)), - 'y': float(embryo_data.get('centered_stage_y', 0.0)) - } } # Add calibration data if present - if 'calibration' in embryo_data: - entry['calibration'] = format_embryo_calibration_for_json(embryo_data['calibration']) + if "calibration" in embryo_data: + entry["calibration"] = format_embryo_calibration_for_json(embryo_data["calibration"]) return entry def export_multi_embryo_database( - databroker_catalog, - session_uid: str, - output_path: Path, - pretty_print: bool = True + databroker_catalog, session_uid: str, output_path: Path, pretty_print: bool = True ) -> Path: """ Export multi-embryo calibration data from databroker to JSON database file. 
@@ -194,25 +192,25 @@ def export_multi_embryo_database( try: session_run = databroker_catalog[session_uid] except KeyError: - raise KeyError(f"Session UID {session_uid} not found in databroker") + raise KeyError(f"Session UID {session_uid} not found in databroker") from None # Handle different databroker API versions try: # v2 API - session_metadata = session_run.metadata['start'] + session_metadata = session_run.metadata["start"] except (AttributeError, KeyError): # v1 API - session_metadata = session_run['start'] + session_metadata = session_run["start"] # Initialize database structure database = { - 'created': session_metadata.get('time', format_timestamp()), - 'embryos': {}, - 'last_updated': format_timestamp() + "created": session_metadata.get("time", format_timestamp()), + "embryos": {}, + "last_updated": format_timestamp(), } # Get list of embryo run UIDs from session metadata - embryo_uids = session_metadata.get('embryo_runs', []) + embryo_uids = session_metadata.get("embryo_runs", []) if not embryo_uids: logger.warning("No embryo runs found in session %s...", session_uid[:8]) @@ -225,17 +223,19 @@ def export_multi_embryo_database( # Get embryo metadata try: - embryo_metadata = embryo_run.metadata['start'] + embryo_metadata = embryo_run.metadata["start"] except (AttributeError, KeyError): - embryo_metadata = embryo_run['start'] + embryo_metadata = embryo_run["start"] - embryo_id = embryo_metadata.get('embryo_id', f"embryo_{len(database['embryos'])+1:03d}") + embryo_id = embryo_metadata.get( + "embryo_id", f"embryo_{len(database['embryos']) + 1:03d}" + ) # Format embryo entry embryo_entry = format_embryo_entry_for_json(embryo_metadata) # Add to database - database['embryos'][embryo_id] = embryo_entry + database["embryos"][embryo_id] = embryo_entry except Exception as e: logger.warning("Could not export embryo %s...: %s", embryo_uid[:8], e) @@ -243,19 +243,23 @@ def export_multi_embryo_database( # Write to JSON file output_path = Path(output_path) - with 
open(output_path, 'w') as f: + with open(output_path, "w") as f: if pretty_print: json.dump(database, f, indent=2) else: json.dump(database, f) - logger.info("Exported multi-embryo database: File=%s, Embryos=%d, Session=%s...", - output_path, len(database['embryos']), session_uid[:8]) + logger.info( + "Exported multi-embryo database: File=%s, Embryos=%d, Session=%s...", + output_path, + len(database["embryos"]), + session_uid[:8], + ) return output_path -def load_multi_embryo_database(database_path: Path) -> Dict: +def load_multi_embryo_database(database_path: Path) -> dict: """ Load existing multi-embryo database from JSON file. @@ -274,18 +278,18 @@ def load_multi_embryo_database(database_path: Path) -> Dict: if not database_path.exists(): # Return empty database structure return { - 'created': format_timestamp(), - 'embryos': {}, - 'last_updated': format_timestamp() + "created": format_timestamp(), + "embryos": {}, + "last_updated": format_timestamp(), } - with open(database_path, 'r') as f: + with open(database_path) as f: database = json.load(f) return database -def save_multi_embryo_database(database: Dict, database_path: Path): +def save_multi_embryo_database(database: dict, database_path: Path): """ Save multi-embryo database to JSON file. @@ -299,17 +303,13 @@ def save_multi_embryo_database(database: Dict, database_path: Path): database_path = Path(database_path) # Update last_updated timestamp - database['last_updated'] = format_timestamp() + database["last_updated"] = format_timestamp() - with open(database_path, 'w') as f: + with open(database_path, "w") as f: json.dump(database, f, indent=2) -def add_embryo_to_database( - database: Dict, - embryo_id: str, - embryo_data: Dict -) -> Dict: +def add_embryo_to_database(database: dict, embryo_id: str, embryo_data: dict) -> dict: """ Add or update embryo entry in database. 
@@ -331,13 +331,13 @@ def add_embryo_to_database( embryo_entry = format_embryo_entry_for_json(embryo_data) # Add to database - database['embryos'][embryo_id] = embryo_entry - database['last_updated'] = format_timestamp() + database["embryos"][embryo_id] = embryo_entry + database["last_updated"] = format_timestamp() return database -def get_embryo_calibration(database: Dict, embryo_id: str) -> Optional[Dict]: +def get_embryo_calibration(database: dict, embryo_id: str) -> dict | None: """ Get calibration data for specific embryo from database. @@ -353,15 +353,15 @@ def get_embryo_calibration(database: Dict, embryo_id: str) -> Optional[Dict]: dict or None Calibration dictionary, or None if not found """ - embryo_entry = database.get('embryos', {}).get(embryo_id) + embryo_entry = database.get("embryos", {}).get(embryo_id) if embryo_entry is None: return None - return embryo_entry.get('calibration') + return embryo_entry.get("calibration") -def list_embryos(database: Dict) -> List[str]: +def list_embryos(database: dict) -> list[str]: """ List all embryo IDs in database. 
@@ -375,12 +375,9 @@ def list_embryos(database: Dict) -> List[str]: list of str Embryo IDs sorted by embryo number """ - embryos = database.get('embryos', {}) + embryos = database.get("embryos", {}) # Sort by embryo_number - sorted_embryos = sorted( - embryos.items(), - key=lambda x: x[1].get('embryo_number', 0) - ) + sorted_embryos = sorted(embryos.items(), key=lambda x: x[1].get("embryo_number", 0)) return [embryo_id for embryo_id, _ in sorted_embryos] diff --git a/gently/core/event_bus.py b/gently/core/event_bus.py index b086f172..3b98e616 100644 --- a/gently/core/event_bus.py +++ b/gently/core/event_bus.py @@ -10,14 +10,14 @@ import asyncio import logging -import time +import threading +import uuid +from collections import deque +from collections.abc import Callable from dataclasses import dataclass, field from datetime import datetime from enum import Enum, auto -from typing import Any, Callable, Dict, List, Optional, Set, Union -from collections import deque -import threading -import uuid +from typing import Any logger = logging.getLogger(__name__) @@ -44,19 +44,22 @@ class EventType(Enum): EMBRYO_CENTERED = auto() EMBRYO_CALIBRATED = auto() EMBRYO_SKIPPED = auto() + # {embryo_id, completion_reason} - emitted when an embryo's imaging stops + # (any reason: no_object terminal, stop condition met, errors, user removal) + EMBRYO_TERMINATED = auto() # Analysis events ANALYSIS_STARTED = auto() ANALYSIS_COMPLETED = auto() - DETECTOR_EVALUATED = auto() # Emitted for every detector run (all evaluations) + DETECTOR_EVALUATED = auto() # Emitted for every detector run (all evaluations) DETECTION_TRIGGERED = auto() # Emitted only when detected=True (positive detection) HATCHING_DETECTED = auto() # Verification events (multi-strategy verification for detections) - VERIFICATION_STARTED = auto() # Verification round begins for embryo - VERIFICATION_STRATEGY = auto() # Individual strategy result (adversarial, temporal, etc.) 
- VERIFICATION_PROGRESS = auto() # Progress update (e.g., "3/5 strategies complete") - VERIFICATION_COMPLETED = auto() # Final verification result with consensus + VERIFICATION_STARTED = auto() # Verification round begins for embryo + VERIFICATION_STRATEGY = auto() # Individual strategy result (adversarial, temporal, etc.) + VERIFICATION_PROGRESS = auto() # Progress update (e.g., "3/5 strategies complete") + VERIFICATION_COMPLETED = auto() # Final verification result with consensus # CV Subagent events SEGMENTATION_COMPLETED = auto() @@ -80,8 +83,22 @@ class EventType(Enum): STAGE_MOVED = auto() FOCUS_CHANGED = auto() LASER_CHANGED = auto() - DEVICE_STATE_UPDATE = auto() # Periodic device-state snapshot from device layer - BOTTOM_CAMERA_FRAME = auto() # Live JPEG frame from the bottom camera stream + DEVICE_STATE_UPDATE = auto() # Periodic device-state snapshot from device layer + BOTTOM_CAMERA_FRAME = auto() # Live JPEG frame from the bottom camera stream + EMBRYOS_UPDATE = auto() # Full embryo list snapshot from agent.experiment + + # Python logging.LogRecord republished onto the bus so the Events page + # surfaces what would otherwise only land in the terminal. See + # gently/core/log_bridge.py — opt-in handler. + LOG_RECORD = auto() + + # Operator-action events. Distinct from EMBRYOS_UPDATE because they + # carry intent ("a human did this") rather than just state delta. + # Candidate orchestrators can subscribe and reason about what the + # operator just did without having to type it in chat. 
+ OPERATOR_EDITED_EMBRYO = auto() # Map drag/drop -> PUT /api/embryos/{id}/position + OPERATOR_REMOVED_EMBRYO = auto() # Map delete -> DELETE /api/embryos/{id} + OPERATOR_MARKED_EMBRYOS = auto() # Marking canvas "Done" — operator confirmed N positions # System events ERROR_OCCURRED = auto() @@ -89,12 +106,14 @@ class EventType(Enum): STATUS_CHANGED = auto() # Async timelapse — per-embryo cadence transitions (Phase 4) - EMBRYO_CADENCE_CHANGED = auto() # {embryo_id, old_phase, new_phase, old_interval_s, new_interval_s, next_due_at} + EMBRYO_CADENCE_CHANGED = ( + auto() + ) # {embryo_id, old_phase, new_phase, old_interval_s, new_interval_s, next_due_at} # Burst lifecycle (Phase 7 / 10) - BURST_QUEUED = auto() # {embryo_id, request_id, position_in_queue} - BURST_START = auto() # {embryo_id, request_id, frames, mode} - BURST_FRAME = auto() # {embryo_id, request_id, frame_idx, total_frames} + BURST_QUEUED = auto() # {embryo_id, request_id, position_in_queue} + BURST_START = auto() # {embryo_id, request_id, frames, mode} + BURST_FRAME = auto() # {embryo_id, request_id, frame_idx, total_frames} BURST_COMPLETE = auto() # {embryo_id, request_id, mp4_path, sustained_hz, frames_captured} # Reactive control telemetry (Phase 5 / 10) @@ -132,17 +151,17 @@ class EventType(Enum): MESH_SCOPE_DENIED = auto() # Mesh topology events - MESH_PEER_OFFLINE = auto() # peer marked offline in verse map (kept in map) - MESH_PEER_RETURNED = auto() # previously offline peer came back online + MESH_PEER_OFFLINE = auto() # peer marked offline in verse map (kept in map) + MESH_PEER_RETURNED = auto() # previously offline peer came back online # ML pipeline events ML_PIPELINE_CREATED = auto() ML_TRAINING_STARTED = auto() - ML_TRAINING_PROGRESS = auto() # per-epoch updates + ML_TRAINING_PROGRESS = auto() # per-epoch updates ML_TRAINING_COMPLETED = auto() ML_TRAINING_FAILED = auto() ML_EVALUATION_COMPLETED = auto() - ML_SUBAGENT_STATUS = auto() # subagent thinking/planning updates + ML_SUBAGENT_STATUS 
= auto() # subagent thinking/planning updates # Bulk transfer events TRANSFER_STARTED = auto() @@ -154,10 +173,15 @@ class EventType(Enum): # High-volume telemetry events that skip the bounded history deque. These # fire many times per second and would push out events that humans actually # want to inspect later (acquisitions, perceptions, errors). -_NO_HISTORY_TYPES = frozenset({ - EventType.DEVICE_STATE_UPDATE, - EventType.BOTTOM_CAMERA_FRAME, # ~2 Hz JPEG frames — would crowd history out -}) +_NO_HISTORY_TYPES = frozenset( + { + EventType.DEVICE_STATE_UPDATE, + EventType.BOTTOM_CAMERA_FRAME, # ~2 Hz JPEG frames — would crowd history out + EventType.LOG_RECORD, # log lines can hit hundreds/min during + # calibration; durable copy is in the + # gently_*.log file already + } +) @dataclass @@ -180,37 +204,40 @@ class Event: correlation_id : str, optional ID to correlate related events (e.g., request/response) """ + event_type: EventType - data: Dict[str, Any] = field(default_factory=dict) + data: dict[str, Any] = field(default_factory=dict) source: str = "unknown" timestamp: datetime = field(default_factory=datetime.now) event_id: str = field(default_factory=lambda: str(uuid.uuid4())[:8]) - correlation_id: Optional[str] = None + correlation_id: str | None = None def __str__(self) -> str: return f"Event({self.event_type.name}, source={self.source}, id={self.event_id})" - def to_dict(self) -> Dict: + def to_dict(self) -> dict: """Serialize for storage/transmission""" return { - 'event_type': self.event_type.name, - 'data': self.data, - 'source': self.source, - 'timestamp': self.timestamp.isoformat(), - 'event_id': self.event_id, - 'correlation_id': self.correlation_id, + "event_type": self.event_type.name, + "data": self.data, + "source": self.source, + "timestamp": self.timestamp.isoformat(), + "event_id": self.event_id, + "correlation_id": self.correlation_id, } @classmethod - def from_dict(cls, d: Dict) -> 'Event': + def from_dict(cls, d: dict) -> "Event": 
"""Deserialize from dict""" return cls( - event_type=EventType[d['event_type']], - data=d.get('data', {}), - source=d.get('source', 'unknown'), - timestamp=datetime.fromisoformat(d['timestamp']) if 'timestamp' in d else datetime.now(), - event_id=d.get('event_id', str(uuid.uuid4())[:8]), - correlation_id=d.get('correlation_id'), + event_type=EventType[d["event_type"]], + data=d.get("data", {}), + source=d.get("source", "unknown"), + timestamp=datetime.fromisoformat(d["timestamp"]) + if "timestamp" in d + else datetime.now(), + event_id=d.get("event_id", str(uuid.uuid4())[:8]), + correlation_id=d.get("correlation_id"), ) @@ -237,17 +264,17 @@ def __init__(self, history_size: int = 100): history_size : int Number of recent events to keep in history """ - self._handlers: Dict[EventType, List[EventHandler]] = {} - self._async_handlers: Dict[EventType, List[AsyncEventHandler]] = {} - self._wildcard_handlers: List[EventHandler] = [] - self._async_wildcard_handlers: List[AsyncEventHandler] = [] + self._handlers: dict[EventType, list[EventHandler]] = {} + self._async_handlers: dict[EventType, list[AsyncEventHandler]] = {} + self._wildcard_handlers: list[EventHandler] = [] + self._async_wildcard_handlers: list[AsyncEventHandler] = [] self._history: deque = deque(maxlen=history_size) self._lock = threading.RLock() - self._event_loop: Optional[asyncio.AbstractEventLoop] = None + self._event_loop: asyncio.AbstractEventLoop | None = None def subscribe( self, - event_type: Union[EventType, str], + event_type: EventType | str, handler: EventHandler, ) -> Callable[[], None]: """ @@ -287,7 +314,7 @@ def unsubscribe(): def subscribe_async( self, - event_type: Union[EventType, str], + event_type: EventType | str, handler: AsyncEventHandler, ) -> Callable[[], None]: """ @@ -328,9 +355,9 @@ def unsubscribe(): def publish( self, event_type: EventType, - data: Optional[Dict] = None, + data: dict | None = None, source: str = "unknown", - correlation_id: Optional[str] = None, + 
correlation_id: str | None = None, ) -> Event: """ Publish an event to all subscribers @@ -459,10 +486,10 @@ def set_event_loop(self, loop: asyncio.AbstractEventLoop): def get_history( self, - event_type: Optional[EventType] = None, - source: Optional[str] = None, + event_type: EventType | None = None, + source: str | None = None, limit: int = 50, - ) -> List[Event]: + ) -> list[Event]: """ Get recent event history @@ -497,7 +524,7 @@ def clear_history(self): with self._lock: self._history.clear() - def get_handler_count(self, event_type: Optional[EventType] = None) -> int: + def get_handler_count(self, event_type: EventType | None = None) -> int: """Get count of registered handlers""" with self._lock: if event_type: @@ -513,7 +540,7 @@ def get_handler_count(self, event_type: Optional[EventType] = None) -> int: # Global event bus instance -_global_bus: Optional[EventBus] = None +_global_bus: EventBus | None = None def get_event_bus() -> EventBus: @@ -533,20 +560,20 @@ def set_event_bus(bus: EventBus): # Convenience functions for common operations def emit( event_type: EventType, - data: Optional[Dict] = None, + data: dict | None = None, source: str = "unknown", ) -> Event: """Emit an event on the global bus""" return get_event_bus().publish(event_type, data, source) -def on(event_type: Union[EventType, str], handler: EventHandler) -> Callable[[], None]: +def on(event_type: EventType | str, handler: EventHandler) -> Callable[[], None]: """Subscribe to events on the global bus""" return get_event_bus().subscribe(event_type, handler) # Decorator for event handlers -def handles(event_type: Union[EventType, str]): +def handles(event_type: EventType | str): """ Decorator to register a function as an event handler @@ -555,7 +582,9 @@ def handles(event_type: Union[EventType, str]): def on_volume(event): print(f"Got volume: {event.data}") """ + def decorator(func: EventHandler) -> EventHandler: get_event_bus().subscribe(event_type, func) return func + return decorator diff 
--git a/gently/core/file_store.py b/gently/core/file_store.py index 8ca383ea..5a8b7f01 100644 --- a/gently/core/file_store.py +++ b/gently/core/file_store.py @@ -58,7 +58,7 @@ import time from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any import numpy as np import yaml @@ -66,7 +66,6 @@ from .store_types import ( EmbryoInfo, GroundTruthEntry, - PerceptionRunInfo, PredictionInfo, ProjectionInfo, SessionInfo, @@ -81,6 +80,7 @@ # Helpers # --------------------------------------------------------------------------- + def _slugify(text: str, max_len: int = 30) -> str: """Lowercase, replace non-alphanum with hyphens, truncate.""" if not text: @@ -108,21 +108,56 @@ def _sanitize_for_yaml(obj): return obj +def _coarse_from_legacy(record: dict) -> dict | None: + """Extract coarse XY from an embryo.yaml record, accepting either the new + `position_coarse` dict or the legacy flat `position_x` / `position_y` keys. + Returns None if neither shape carries usable values. + """ + coarse = record.get("position_coarse") + if isinstance(coarse, dict) and coarse: + return coarse + px, py = record.get("position_x"), record.get("position_y") + if px is None and py is None: + return None + out = {} + if px is not None: + out["x"] = px + if py is not None: + out["y"] = py + return out or None + + +def _normalize_embryo_record(record: dict | None) -> dict | None: + """Backfill an embryo.yaml dict so callers always see the new schema. + + Adds `position_coarse` derived from legacy `position_x` / `position_y` if + only the legacy fields are present, and ensures `position_fine` exists + (as None) for forward-compat. The original record is not mutated. 
+ """ + if record is None: + return None + out = dict(record) + if out.get("position_coarse") is None: + backfill = _coarse_from_legacy(out) + if backfill is not None: + out["position_coarse"] = backfill + out.setdefault("position_fine", None) + return out + + def _write_yaml(path: Path, data: Any) -> None: """Write YAML atomically: write to a temp file, then rename.""" path.parent.mkdir(parents=True, exist_ok=True) data = _sanitize_for_yaml(data) - fd, tmp = tempfile.mkstemp( - suffix=".tmp", prefix=path.stem, dir=str(path.parent) - ) + fd, tmp = tempfile.mkstemp(suffix=".tmp", prefix=path.stem, dir=str(path.parent)) try: with os.fdopen(fd, "w", encoding="utf-8") as f: - yaml.safe_dump(data, f, default_flow_style=False, sort_keys=False, - allow_unicode=True) - # On Windows, rename over an existing file requires removing it first. - if path.exists(): - path.unlink() - Path(tmp).rename(path) + yaml.safe_dump(data, f, default_flow_style=False, sort_keys=False, allow_unicode=True) + f.flush() + os.fsync(f.fileno()) + # os.replace is atomic and overwrites on Windows — no unlink gap that + # a crash/power-loss could leave the target missing. 
+ os.replace(tmp, path) except BaseException: # Clean up temp file on failure try: @@ -147,13 +182,13 @@ def _read_yaml(path: Path) -> Any: """ if not path.exists(): return None - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: text = f.read() try: return yaml.safe_load(text) except yaml.constructor.ConstructorError as err: marker = str(err) - if 'python/object' not in marker and 'numpy' not in marker: + if "python/object" not in marker and "numpy" not in marker: raise data = yaml.unsafe_load(text) return _sanitize_for_yaml(data) @@ -166,12 +201,12 @@ def _append_jsonl(path: Path, record: dict) -> None: f.write(json.dumps(record, ensure_ascii=False, default=str) + "\n") -def _read_jsonl(path: Path) -> List[dict]: +def _read_jsonl(path: Path) -> list[dict]: """Read all lines from a JSONL file.""" if not path.exists(): return [] records = [] - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: for line in f: line = line.strip() if line: @@ -187,6 +222,7 @@ def _now() -> str: # FileStore # --------------------------------------------------------------------------- + class FileStore: """Pure file-based storage for Gently. 
Drop-in replacement for GentlyStore.""" @@ -207,7 +243,7 @@ def __init__(self, root: Path): # Load session index (session_id -> folder_name) self._index_path = self._root / "sessions" / "_index.yaml" - self._index: Dict[str, str] = _read_yaml(self._index_path) or {} + self._index: dict[str, str] = _read_yaml(self._index_path) or {} # ------------------------------------------------------------------ # Properties @@ -229,7 +265,7 @@ def _save_index(self) -> None: """Persist the session index mapping to disk.""" _write_yaml(self._index_path, self._index) - def _session_dir(self, session_id: str) -> Optional[Path]: + def _session_dir(self, session_id: str) -> Path | None: """Return the session folder path, or None if unknown.""" folder = self._index.get(session_id) if folder is None: @@ -277,7 +313,7 @@ def _generate_projection( embryo_id: str, timepoint: int, volume: np.ndarray, - ) -> Optional[Path]: + ) -> Path | None: """Generate JPEG projection file from volume data.""" from .imaging import generate_jpeg_projection @@ -292,9 +328,9 @@ def _generate_projection( def create_session( self, session_id: str, - name: str = None, - description: str = None, - metadata: dict = None, + name: str | None = None, + description: str | None = None, + metadata: dict | None = None, ) -> str: """Create a new session. 
Returns session_id.""" # If session already exists, return silently (matches INSERT OR IGNORE) @@ -330,7 +366,7 @@ def create_session( logger.info("Created session %s -> %s", session_id, folder_name) return session_id - def get_session(self, session_id: str) -> Optional[SessionInfo]: + def get_session(self, session_id: str) -> SessionInfo | None: """Return session info as dict, or None.""" sd = self._session_dir(session_id) if sd is None or not sd.exists(): @@ -340,7 +376,7 @@ def get_session(self, session_id: str) -> Optional[SessionInfo]: return None return data - def list_sessions(self) -> List[SessionInfo]: + def list_sessions(self) -> list[SessionInfo]: """Return all sessions ordered by last_active descending.""" sessions = [] for sid in self._index: @@ -350,6 +386,21 @@ def list_sessions(self) -> List[SessionInfo]: sessions.sort(key=lambda s: s.get("last_active", ""), reverse=True) return sessions + def recent_session_ids(self, limit: int = 8) -> list[str]: + """Most-recent session IDs by folder-name date prefix, *cheaply*. + + Folder names are ``{YYYYMMDD}_{HHMM}_{slug}_{id8}`` so a reverse lexical + sort of the index orders them newest-first by creation time — no + ``session.yaml`` parse required. This is a creation-recency proxy (a + long-dormant session that was just resumed sorts by its original date), + which is fine for at-a-glance landing views; use ``list_sessions`` when + exact ``last_active`` ordering matters. 
+ """ + items = sorted(self._index.items(), key=lambda kv: kv[1], reverse=True) + if limit and limit > 0: + items = items[:limit] + return [sid for sid, _ in items] + def touch_session(self, session_id: str) -> None: """Update last_active timestamp.""" sd = self._session_dir(session_id) @@ -367,15 +418,13 @@ def save_session_snapshot(self, session_id: str, snapshot: dict) -> None: sd = self._require_session_dir(session_id) path = sd / "conversation.json" # Write atomically via temp file - fd, tmp = tempfile.mkstemp( - suffix=".tmp", prefix="conversation", dir=str(sd) - ) + fd, tmp = tempfile.mkstemp(suffix=".tmp", prefix="conversation", dir=str(sd)) try: with os.fdopen(fd, "w", encoding="utf-8") as f: json.dump(snapshot, f, indent=2, ensure_ascii=False, default=str) - if path.exists(): - path.unlink() - Path(tmp).rename(path) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp, path) except BaseException: try: os.unlink(tmp) @@ -384,7 +433,7 @@ def save_session_snapshot(self, session_id: str, snapshot: dict) -> None: raise self.touch_session(session_id) - def load_session_snapshot(self, session_id: str) -> Optional[dict]: + def load_session_snapshot(self, session_id: str) -> dict | None: """Load conversation.json. 
Returns None if missing.""" sd = self._session_dir(session_id) if sd is None: @@ -392,7 +441,7 @@ def load_session_snapshot(self, session_id: str) -> Optional[dict]: path = sd / "conversation.json" if not path.exists(): return None - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: return json.load(f) # ------------------------------------------------------------------ @@ -428,36 +477,59 @@ def register_embryo( self, session_id: str, embryo_id: str, - embryo_uid: str = None, - nickname: str = None, - position_x: float = None, - position_y: float = None, - calibration: dict = None, - role: str = None, + embryo_uid: str | None = None, + nickname: str | None = None, + position_x: float | None = None, + position_y: float | None = None, + position_coarse: dict | None = None, + position_fine: dict | None = None, + calibration: dict | None = None, + role: str | None = None, ) -> None: """Register or update an embryo in a session. ``role`` is the experimental role key from gently.harness.roles.REGISTRY (e.g. ``"test"``, ``"calibration"``, ``"unassigned"``). Persisted in embryo.yaml. None preserves the existing value on update. + + Position has two stages: coarse (bottom-camera / manual map placement) + and fine (future SPIM-objective alignment). New callers should pass + position_coarse / position_fine as dicts of shape {"x": float, "y": + float}. Legacy callers passing position_x / position_y get folded into + coarse automatically. """ ed = self._embryo_dir(session_id, embryo_id) ed.mkdir(parents=True, exist_ok=True) + # Fold legacy position_x / position_y into coarse if caller used the + # old kwargs and didn't pass coarse explicitly. 
+ if position_coarse is None and (position_x is not None or position_y is not None): + position_coarse = {} + if position_x is not None: + position_coarse["x"] = position_x + if position_y is not None: + position_coarse["y"] = position_y + yaml_path = ed / "embryo.yaml" existing = _read_yaml(yaml_path) if existing is not None: - # Update: COALESCE behaviour -- keep existing values when new ones - # are None, matching the old ON CONFLICT DO UPDATE SET logic. + # COALESCE update — keep existing values when new ones are None. + existing_coarse = _coarse_from_legacy(existing) embryo_data = { "embryo_id": embryo_id, "session_id": session_id, "embryo_uid": embryo_uid if embryo_uid is not None else existing.get("embryo_uid"), "nickname": nickname if nickname is not None else existing.get("nickname"), - "position_x": position_x if position_x is not None else existing.get("position_x"), - "position_y": position_y if position_y is not None else existing.get("position_y"), - "calibration": calibration if calibration is not None else existing.get("calibration"), + "position_coarse": position_coarse + if position_coarse is not None + else existing_coarse, + "position_fine": position_fine + if position_fine is not None + else existing.get("position_fine"), + "calibration": calibration + if calibration is not None + else existing.get("calibration"), "role": role if role is not None else existing.get("role", "test"), "created_at": existing.get("created_at", _now()), } @@ -467,8 +539,8 @@ def register_embryo( "session_id": session_id, "embryo_uid": embryo_uid, "nickname": nickname, - "position_x": position_x, - "position_y": position_y, + "position_coarse": position_coarse, + "position_fine": position_fine, "calibration": calibration, "role": role if role is not None else "test", "created_at": _now(), @@ -476,16 +548,20 @@ def register_embryo( _write_yaml(yaml_path, embryo_data) - def get_embryo(self, session_id: str, embryo_id: str) -> Optional[EmbryoInfo]: - """Read embryo.yaml. 
Returns None if not found.""" + def get_embryo(self, session_id: str, embryo_id: str) -> EmbryoInfo | None: + """Read embryo.yaml. Returns None if not found. + + Backfills position_coarse from legacy position_x / position_y so + callers don't need to know about the old schema. + """ sd = self._session_dir(session_id) if sd is None: return None yaml_path = sd / "embryos" / embryo_id / "embryo.yaml" data = _read_yaml(yaml_path) - return data + return _normalize_embryo_record(data) - def list_embryos(self, session_id: str) -> List[EmbryoInfo]: + def list_embryos(self, session_id: str) -> list[EmbryoInfo]: """List all embryos for a session, sorted by embryo_id.""" sd = self._session_dir(session_id) if sd is None: @@ -500,9 +576,25 @@ def list_embryos(self, session_id: str) -> List[EmbryoInfo]: yaml_path = entry / "embryo.yaml" data = _read_yaml(yaml_path) if data is not None: - result.append(data) + result.append(_normalize_embryo_record(data)) return result + def list_embryo_ids(self, session_id: str) -> list[str]: + """Embryo IDs from directory names only — no ``embryo.yaml`` parse. + + The directory name *is* the embryo_id in this layout (see + ``_embryo_dir`` / ``put_embryo``), so callers that only need the ids + (e.g. enumerating projections) can skip the per-embryo YAML read that + ``list_embryos`` pays. + """ + sd = self._session_dir(session_id) + if sd is None: + return [] + embryos_dir = sd / "embryos" + if not embryos_dir.exists(): + return [] + return [e.name for e in sorted(embryos_dir.iterdir()) if e.is_dir()] + # ================================================================== # Volumes # ================================================================== @@ -513,7 +605,7 @@ def put_volume( embryo_id: str, timepoint: int, volume: np.ndarray, - metadata: dict = None, + metadata: dict | None = None, ) -> Path: """ Write a volume to disk, generate a JPEG projection, write sidecar metadata. 
@@ -554,9 +646,7 @@ def put_volume( # Generate projection self._generate_projection(session_id, embryo_id, timepoint, volume) - logger.debug( - "put_volume: %s/%s t=%d -> %s", session_id, embryo_id, timepoint, vol_path - ) + logger.debug("put_volume: %s/%s t=%d -> %s", session_id, embryo_id, timepoint, vol_path) return vol_path def register_volume( @@ -565,8 +655,8 @@ def register_volume( embryo_id: str, timepoint: int, incoming_path: Path, - metadata: dict = None, - volume_data: np.ndarray = None, + metadata: dict | None = None, + volume_data: np.ndarray | None = None, ) -> Path: """ Zero-copy path: move an existing TIFF to its canonical location. @@ -605,6 +695,7 @@ def register_volume( volume = volume_data else: from .imaging import load_volume + volume = load_volume(canonical) # Write sidecar metadata @@ -625,33 +716,26 @@ def register_volume( logger.debug("register_volume: %s -> %s", incoming_path.name, canonical) return canonical - def get_volume( - self, session_id: str, embryo_id: str, timepoint: int - ) -> Optional[np.ndarray]: + def get_volume(self, session_id: str, embryo_id: str, timepoint: int) -> np.ndarray | None: """Load a volume from disk. 
Returns None if not found.""" path = self.get_volume_path(session_id, embryo_id, timepoint) if path is None or not path.exists(): return None import tifffile + return tifffile.imread(str(path)) - def get_volume_path( - self, session_id: str, embryo_id: str, timepoint: int - ) -> Optional[Path]: + def get_volume_path(self, session_id: str, embryo_id: str, timepoint: int) -> Path | None: """Return the absolute path to a volume TIFF, or None.""" sd = self._session_dir(session_id) if sd is None: return None - vol_path = ( - sd / "embryos" / embryo_id / "volumes" / self._volume_filename(timepoint) - ) + vol_path = sd / "embryos" / embryo_id / "volumes" / self._volume_filename(timepoint) if vol_path.exists(): return vol_path return None - def list_volumes( - self, session_id: str, embryo_id: str = None - ) -> List[VolumeInfo]: + def list_volumes(self, session_id: str, embryo_id: str | None = None) -> list[VolumeInfo]: """List volume metadata by scanning sidecar YAML files on disk.""" sd = self._session_dir(session_id) if sd is None: @@ -665,11 +749,9 @@ def list_volumes( if embryo_id: dirs = [embryos_dir / embryo_id] else: - dirs = sorted( - d for d in embryos_dir.iterdir() if d.is_dir() - ) + dirs = sorted(d for d in embryos_dir.iterdir() if d.is_dir()) - result: List[VolumeInfo] = [] + result: list[VolumeInfo] = [] for edir in dirs: vol_dir = edir / "volumes" if not vol_dir.exists(): @@ -679,9 +761,7 @@ def list_volumes( if data is None: continue # Build a VolumeInfo dict - tif_path = meta_file.parent / meta_file.name.replace( - ".meta.yaml", ".tif" - ) + tif_path = meta_file.parent / meta_file.name.replace(".meta.yaml", ".tif") info: VolumeInfo = { "session_id": data.get("session_id", session_id), "embryo_id": data.get("embryo_id", edir.name), @@ -698,9 +778,7 @@ def list_volumes( result.sort(key=lambda v: (v["embryo_id"], v["timepoint"])) return result - def get_acquisition_params( - self, session_id: str, embryo_id: str = None - ) -> Optional[dict]: + def 
get_acquisition_params(self, session_id: str, embryo_id: str | None = None) -> dict | None: """ Get acquisition parameters from the most recent volume sidecar. @@ -718,24 +796,38 @@ def get_acquisition_params( # Projections # ================================================================== - def get_projection_path( - self, session_id: str, embryo_id: str, timepoint: int - ) -> Optional[Path]: + def get_projection_path(self, session_id: str, embryo_id: str, timepoint: int) -> Path | None: """Return absolute path to the JPEG projection, or None.""" sd = self._session_dir(session_id) if sd is None: return None proj_path = ( - sd / "embryos" / embryo_id / "projections" - / self._projection_filename(timepoint) + sd / "embryos" / embryo_id / "projections" / self._projection_filename(timepoint) ) if proj_path.exists(): return proj_path return None - def get_projection_b64( - self, session_id: str, embryo_id: str, timepoint: int - ) -> Optional[str]: + def list_projection_timepoints(self, session_id: str, embryo_id: str) -> list[int]: + """Cheaply list projection timepoints (glob only, no PIL/meta reads). + + Used to rehydrate the viz image store on resume without paying the + per-file cost of list_projections(). 
+ """ + sd = self._session_dir(session_id) + if sd is None: + return [] + proj_dir = sd / "embryos" / embryo_id / "projections" + if not proj_dir.exists(): + return [] + tps: list[int] = [] + for jpg in proj_dir.glob("t*.jpg"): + m = re.match(r"t(\d+)\.jpg$", jpg.name) + if m: + tps.append(int(m.group(1))) + return sorted(tps) + + def get_projection_b64(self, session_id: str, embryo_id: str, timepoint: int) -> str | None: """Return base64-encoded JPEG projection, or None.""" path = self.get_projection_path(session_id, embryo_id, timepoint) if path is None or not path.exists(): @@ -743,9 +835,7 @@ def get_projection_b64( with open(path, "rb") as f: return base64.b64encode(f.read()).decode("utf-8") - def list_projections( - self, session_id: str, embryo_id: str - ) -> List[ProjectionInfo]: + def list_projections(self, session_id: str, embryo_id: str) -> list[ProjectionInfo]: """List projection info for an embryo by scanning projection files.""" sd = self._session_dir(session_id) if sd is None: @@ -754,7 +844,7 @@ def list_projections( if not proj_dir.exists(): return [] - result: List[ProjectionInfo] = [] + result: list[ProjectionInfo] = [] for jpg in sorted(proj_dir.glob("t*.jpg")): # Extract timepoint from filename t0003.jpg -> 3 match = re.match(r"t(\d+)\.jpg$", jpg.name) @@ -766,6 +856,7 @@ def list_projections( width, height, size_kb = None, None, None try: from PIL import Image as PILImage + img = PILImage.open(str(jpg)) width, height = img.size size_kb = round(jpg.stat().st_size / 1024, 1) @@ -774,10 +865,7 @@ def list_projections( # Use the volume sidecar's acquired_at as the projection created_at # if available; otherwise use the file mtime. 
- meta_path = ( - sd / "embryos" / embryo_id / "volumes" - / self._volume_meta_filename(tp) - ) + meta_path = sd / "embryos" / embryo_id / "volumes" / self._volume_meta_filename(tp) vol_meta = _read_yaml(meta_path) created = ( vol_meta.get("acquired_at", "") @@ -807,7 +895,7 @@ def register_snapshot( session_id: str, source: str, incoming_path: Path, - metadata: dict = None, + metadata: dict | None = None, ) -> Path: """Move a transient TIFF from incoming/ to ``snapshots/``.""" incoming_path = Path(incoming_path) @@ -838,6 +926,7 @@ def register_snapshot( # Read shape for the sidecar try: import tifffile + arr = tifffile.imread(str(canonical)) sidecar["width"] = int(arr.shape[-1]) if arr.ndim >= 2 else None sidecar["height"] = int(arr.shape[-2]) if arr.ndim >= 2 else None @@ -853,9 +942,7 @@ def register_snapshot( logger.debug("register_snapshot: %s -> %s", incoming_path.name, canonical) return canonical - def list_snapshots( - self, session_id: str, source: str = None - ) -> List[Dict[str, Any]]: + def list_snapshots(self, session_id: str, source: str | None = None) -> list[dict[str, Any]]: """List snapshot records for a session, optionally filtered by source.""" sd = self._session_dir(session_id) if sd is None: @@ -902,9 +989,7 @@ def cleanup_incoming(self, max_age_seconds: float = 300) -> int: deleted += 1 logger.debug("cleanup_incoming: deleted %s", f.name) except OSError as e: - logger.warning( - "cleanup_incoming: could not delete %s: %s", f.name, e - ) + logger.warning("cleanup_incoming: could not delete %s: %s", f.name, e) if deleted: logger.info("cleanup_incoming: removed %d stale file(s)", deleted) return deleted @@ -917,7 +1002,7 @@ def _perception_runs_path(self, session_id: str) -> Path: sd = self._require_session_dir(session_id) return sd / "perception_runs.yaml" - def _load_perception_runs(self, session_id: str) -> Dict[int, dict]: + def _load_perception_runs(self, session_id: str) -> dict[int, dict]: """Load perception_runs.yaml as {run_id: 
run_metadata}.""" data = _read_yaml(self._perception_runs_path(session_id)) if data is None: @@ -925,9 +1010,7 @@ def _load_perception_runs(self, session_id: str) -> Dict[int, dict]: # Ensure keys are ints return {int(k): v for k, v in data.items()} - def _save_perception_runs( - self, session_id: str, runs: Dict[int, dict] - ) -> None: + def _save_perception_runs(self, session_id: str, runs: dict[int, dict]) -> None: _write_yaml(self._perception_runs_path(session_id), runs) def create_perception_run( @@ -935,10 +1018,10 @@ def create_perception_run( session_id: str, name: str, method: str, - model_name: str = None, + model_name: str | None = None, trace_type: str = "perception", source: str = "live", - config: dict = None, + config: dict | None = None, ) -> int: """Create a perception run. Returns run_id (auto-increment).""" runs = self._load_perception_runs(session_id) @@ -964,7 +1047,7 @@ def create_perception_run( return run_id def complete_perception_run( - self, run_id: int, status: str = "completed", error_message: str = None + self, run_id: int, status: str = "completed", error_message: str | None = None ) -> None: """Mark a perception run as completed or failed. @@ -988,14 +1071,14 @@ def store_prediction( embryo_id: str, timepoint: int, predicted_stage: str, - confidence: float = None, - reasoning: str = None, + confidence: float | None = None, + reasoning: str | None = None, is_transitional: bool = False, - execution_time_ms: float = None, - trace_data: dict = None, - observed_features: dict = None, - ground_truth_stage: str = None, - is_correct: int = None, + execution_time_ms: float | None = None, + trace_data: dict | None = None, + observed_features: dict | None = None, + ground_truth_stage: str | None = None, + is_correct: int | None = None, ) -> int: """ Append a prediction to predictions.jsonl and optionally write trace JSON. 
@@ -1050,9 +1133,9 @@ def store_prediction( def get_predictions( self, session_id: str, - embryo_id: str = None, - run_id: int = None, - ) -> List[PredictionInfo]: + embryo_id: str | None = None, + run_id: int | None = None, + ) -> list[PredictionInfo]: """Query predictions with optional filters.""" sd = self._session_dir(session_id) if sd is None: @@ -1068,7 +1151,7 @@ def get_predictions( else: dirs = sorted(d for d in embryos_dir.iterdir() if d.is_dir()) - result: List[PredictionInfo] = [] + result: list[PredictionInfo] = [] for edir in dirs: pred_path = edir / "predictions.jsonl" records = _read_jsonl(pred_path) @@ -1091,9 +1174,9 @@ def set_ground_truth( embryo_id: str, stage: str, start_timepoint: int, - end_timepoint: int = None, - annotator: str = None, - notes: str = None, + end_timepoint: int | None = None, + annotator: str | None = None, + notes: str | None = None, ) -> None: """Insert or update a ground-truth annotation.""" ed = self._embryo_dir(session_id, embryo_id) @@ -1117,23 +1200,23 @@ def set_ground_truth( if not found: # Auto-increment id max_id = max((e.get("id", 0) for e in entries), default=0) - entries.append({ - "id": max_id + 1, - "session_id": session_id, - "embryo_id": embryo_id, - "stage": stage, - "start_timepoint": start_timepoint, - "end_timepoint": end_timepoint, - "annotator": annotator, - "notes": notes, - "created_at": now, - }) + entries.append( + { + "id": max_id + 1, + "session_id": session_id, + "embryo_id": embryo_id, + "stage": stage, + "start_timepoint": start_timepoint, + "end_timepoint": end_timepoint, + "annotator": annotator, + "notes": notes, + "created_at": now, + } + ) _write_yaml(gt_path, entries) - def get_ground_truth( - self, session_id: str, embryo_id: str - ) -> List[GroundTruthEntry]: + def get_ground_truth(self, session_id: str, embryo_id: str) -> list[GroundTruthEntry]: """Get ground-truth annotations sorted by start_timepoint.""" sd = self._session_dir(session_id) if sd is None: diff --git 
a/gently/core/gently_manifest.py b/gently/core/gently_manifest.py index 75c26345..7ddeff7b 100644 --- a/gently/core/gently_manifest.py +++ b/gently/core/gently_manifest.py @@ -4,10 +4,11 @@ Called once on first initialization of a new storage root. """ -import yaml from datetime import datetime from pathlib import Path +import yaml + MANIFEST_VERSION = 3 MANIFEST = { @@ -47,7 +48,9 @@ "volumes": "t{NNNN}.tif — zlib-compressed TIFF stacks", "projections": "t{NNNN}.jpg — max-intensity JPEG projections", "traces": "t{NNNN}.json — complete perception record per timepoint", - "predictions.jsonl": "One-line-per-timepoint summary: {timepoint, stage, confidence, timestamp}", + "predictions.jsonl": ( + "One-line-per-timepoint summary: {timepoint, stage, confidence, timestamp}" + ), "ground_truth.yaml": "Human annotations: [{stage, start_timepoint, end_timepoint}]", }, "agent_memory": { @@ -70,7 +73,14 @@ def write_manifest(root: Path): data["created"] = datetime.now().strftime("%Y-%m-%d") with open(manifest_path, "w", encoding="utf-8") as f: - yaml.dump(data, f, default_flow_style=False, sort_keys=False, allow_unicode=True, width=100) + yaml.dump( + data, + f, + default_flow_style=False, + sort_keys=False, + allow_unicode=True, + width=100, + ) def read_manifest(root: Path) -> dict: @@ -78,5 +88,5 @@ def read_manifest(root: Path) -> dict: manifest_path = root / "gently.yaml" if not manifest_path.exists(): return {} - with open(manifest_path, "r", encoding="utf-8") as f: + with open(manifest_path, encoding="utf-8") as f: return yaml.safe_load(f) or {} diff --git a/gently/core/imaging.py b/gently/core/imaging.py index 6e830351..6ff46c14 100644 --- a/gently/core/imaging.py +++ b/gently/core/imaging.py @@ -22,7 +22,6 @@ import io import logging from pathlib import Path -from typing import Optional, Tuple import numpy as np @@ -30,6 +29,7 @@ try: from PIL import Image + PIL_AVAILABLE = True except ImportError: PIL_AVAILABLE = False @@ -246,8 +246,7 @@ def compress_image_for_api( 
image = np.repeat(image, 10, axis=1) img = normalize_to_uint8(image, method="percentile", p_low=1, p_high=99.5) - b64 = image_to_base64(img, format="JPEG", quality=quality, - max_dimension=max_dimension) + b64 = image_to_base64(img, format="JPEG", quality=quality, max_dimension=max_dimension) size_kb = len(base64.b64decode(b64)) / 1024 return b64, size_kb @@ -257,7 +256,7 @@ def generate_jpeg_projection( output_path: Path, max_dimension: int = 1024, quality: int = 90, -) -> Optional[Path]: +) -> Path | None: """ Generate a JPEG max-projection from a volume and write to disk. @@ -282,9 +281,19 @@ def generate_jpeg_projection( return None try: - max_proj = extract_view_a_and_max_project(volume) - normalized = normalize_to_uint8(max_proj, method="percentile", - p_low=1, p_high=99.5) + # Build the three-orthogonal-view layout (the projection we actually + # want — matches what the perceiver sees). For an explicit 4D + # (Views, Z, Y, X) volume, use View A. For a 3D volume, project the + # whole thing — do NOT try to split views by aspect ratio: the embryo + # is often centered and straddles the X midline, so a width-based + # "dual-view" guess slices it in half (the XY-rendered-halfway bug). + vol = np.squeeze(volume) + if vol.ndim == 4: + vol = vol[0] + if vol.ndim == 3: + normalized, _ = projection_three_view(vol) + else: + normalized = normalize_to_uint8(vol, method="percentile", p_low=1, p_high=99.5) pil_image = Image.fromarray(normalized) @@ -340,13 +349,13 @@ def load_volume(path: Path) -> np.ndarray: z_depth, height, width = vol.shape # Extract View A (left half) if dual-view format if width > height * 2: - vol = vol[:, :, :width // 2] + vol = vol[:, :, : width // 2] return vol def compute_crop_bounds( volume: np.ndarray, padding: int = 20, sigma_mult: float = 3.5 -) -> Tuple[int, int, int, int]: +) -> tuple[int, int, int, int]: """Compute crop bounds for 3D volume using center-of-mass of bright pixels. 
Parameters @@ -381,9 +390,7 @@ def compute_crop_bounds( return (y_min, y_max, x_min, x_max) -def apply_crop_bounds( - volume: np.ndarray, bounds: Tuple[int, int, int, int] -) -> np.ndarray: +def apply_crop_bounds(volume: np.ndarray, bounds: tuple[int, int, int, int]) -> np.ndarray: """Apply pre-computed crop bounds to a volume. Parameters @@ -408,8 +415,8 @@ def apply_crop_bounds( def projection_three_view( volume: np.ndarray, - voxel_size: Tuple[float, float, float] = (1.0, 0.1625, 0.1625), -) -> Tuple[np.ndarray, str]: + voxel_size: tuple[float, float, float] = (1.0, 0.1625, 0.1625), +) -> tuple[np.ndarray, str]: """Generate three orthogonal views layout from a 3D volume. Views are scaled to be physically isometric based on voxel dimensions. @@ -484,9 +491,7 @@ def projection_three_view( total_width = top_row.shape[1] if xz_scaled.shape[1] < total_width: - pad = np.zeros( - (xz_scaled.shape[0], total_width - xz_scaled.shape[1]), dtype=np.uint8 - ) + pad = np.zeros((xz_scaled.shape[0], total_width - xz_scaled.shape[1]), dtype=np.uint8) bottom_row = np.concatenate([xz_scaled, pad], axis=1) else: bottom_row = xz_scaled[:, :total_width] @@ -496,9 +501,7 @@ def projection_three_view( return combined, "Three-view: [XY|YZ] top, [XZ] bottom" -def _euler_to_rotation_matrix( - rx: float, ry: float, rz: float -) -> np.ndarray: +def _euler_to_rotation_matrix(rx: float, ry: float, rz: float) -> np.ndarray: """Convert Euler angles (degrees) to a 3x3 rotation matrix. 
Matches BVV's sequential rotation order: Rz * Ry * Rx applied from the @@ -518,11 +521,11 @@ def _euler_to_rotation_matrix( def clip_volume( volume: np.ndarray, - z_range: Optional[Tuple[float, float]] = None, - y_range: Optional[Tuple[float, float]] = None, - x_range: Optional[Tuple[float, float]] = None, - center: Optional[Tuple[float, float, float]] = None, - rotation: Optional[Tuple[float, float, float]] = None, + z_range: tuple[float, float] | None = None, + y_range: tuple[float, float] | None = None, + x_range: tuple[float, float] | None = None, + center: tuple[float, float, float] | None = None, + rotation: tuple[float, float, float] | None = None, ) -> np.ndarray: """Clip a 3D volume using an arbitrarily-oriented 3D box. @@ -632,11 +635,11 @@ def _frac_to_abs(frac_range, dim_size): def clip_and_project( volume: np.ndarray, - z_range: Optional[Tuple[float, float]] = None, - y_range: Optional[Tuple[float, float]] = None, - x_range: Optional[Tuple[float, float]] = None, - center: Optional[Tuple[float, float, float]] = None, - rotation: Optional[Tuple[float, float, float]] = None, + z_range: tuple[float, float] | None = None, + y_range: tuple[float, float] | None = None, + x_range: tuple[float, float] | None = None, + center: tuple[float, float, float] | None = None, + rotation: tuple[float, float, float] | None = None, projection: str = "max", axis: int = 0, max_dimension: int = 800, @@ -690,7 +693,7 @@ def render_volume_view( rotation_x: float = 0, rotation_y: float = 0, threshold: float = 0.2, - voxel_size: Tuple[float, float, float] = (1.0, 0.1625, 0.1625), + voxel_size: tuple[float, float, float] = (1.0, 0.1625, 0.1625), ) -> str: """Render a 3D volume from a specific viewing angle using alpha compositing. 
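Note on the `generate_jpeg_projection` hunk above: the new branch appears to choose its projection input from the rank of the squeezed array rather than from the aspect ratio. The sketch below restates that rule on shape tuples only, with no pixel data and no NumPy. The helper names (`squeeze_shape`, `choose_projection`, `old_width_heuristic`) are hypothetical illustrations and are not functions in `gently.core.imaging`.

```python
# Hypothetical, dependency-free sketch of the view-selection rule.
# It reasons about shapes only; the real code operates on NumPy arrays.


def squeeze_shape(shape):
    """Drop singleton axes, mirroring what np.squeeze does to an array's shape."""
    return tuple(n for n in shape if n != 1)


def choose_projection(shape):
    """Return (strategy, shape_to_project) for a volume of the given shape."""
    s = squeeze_shape(shape)
    if len(s) == 4:
        # Explicit (Views, Z, Y, X) layout: keep View A, i.e. index 0.
        s = s[1:]
    if len(s) == 3:
        # A 3D volume is projected whole; X is never split by aspect ratio.
        return "three_view", s
    # Anything else (e.g. an already-2D image) is only normalized.
    return "normalize_only", s


def old_width_heuristic(shape):
    """The aspect-ratio guess still used by load_volume: halve X when width > 2 * height."""
    z, h, w = shape
    return (z, h, w // 2) if w > 2 * h else shape


if __name__ == "__main__":
    print(choose_projection((2, 50, 512, 1024)))
    print(choose_projection((1, 50, 512, 1200)))
    print(old_width_heuristic((50, 512, 1200)))
```

For a squeezed shape of `(50, 512, 1200)`, the rank rule keeps the full X extent, whereas the width test would keep only 600 columns. That is the case the hunk's comment describes, where a centered embryo straddles the X midline.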
diff --git a/gently/core/log_bridge.py b/gently/core/log_bridge.py new file mode 100644 index 00000000..c0b98836 --- /dev/null +++ b/gently/core/log_bridge.py @@ -0,0 +1,185 @@ +"""Bridge Python logging into the EventBus so the Events page mirrors the +console. + +A small ``LogToBusHandler`` subclasses ``logging.Handler``. Every record it +sees gets published as ``EventType.LOG_RECORD`` with a compact payload the +frontend can render. The handler attaches itself to whichever loggers +``configure_log_bridge`` is told to cover — by default only ``gently`` and +``gently_perception``, which keeps third-party noise (aiohttp access logs, +bluesky state transitions, anthropic SDK chatter) off the page unless the +operator opts in. + +Env-configurable: + GENTLY_LOG_BUS — "on" / "off" (default: on) + GENTLY_LOG_BUS_LEVEL — DEBUG / INFO (default) / WARNING / ERROR + GENTLY_LOG_BUS_INCLUDE_THIRDPARTY — "1"/"true" to include common third- + party loggers (uvicorn, aiohttp, + bluesky, anthropic, httpx, httpcore) + +Re-entrancy is the only real subtlety: if a log call happens inside the +EventBus.publish path (e.g. from the dispatch loop's logger), republishing +it as another LOG_RECORD would loop forever. Guarded with a thread-local +re-entry flag. +""" + +from __future__ import annotations + +import logging +import os +import threading +from collections.abc import Iterable, Sequence + +from .event_bus import EventBus, EventType, get_event_bus + +logger = logging.getLogger(__name__) + + +# Loggers we never want on the Events page — they emit at the wrong layer +# (their own log lines describe bus dispatch / events page WebSocket frames) +# so republishing them would create feedback or infinite churn. +_NEVER_BRIDGE = frozenset( + { + "gently.core.event_bus", + "gently.core.log_bridge", + } +) + +# Loggers that count as "third-party noise" — silenced by default, can be +# opted in with GENTLY_LOG_BUS_INCLUDE_THIRDPARTY=1. 
+_THIRDPARTY_DEFAULTS: Sequence[str] = ( + "uvicorn", + "uvicorn.error", + "uvicorn.access", + "aiohttp", + "aiohttp.access", + "anthropic", + "httpx", + "httpcore", + "bluesky", + "bluesky.RE.state", +) + + +class LogToBusHandler(logging.Handler): + """Publishes each record onto the EventBus as a LOG_RECORD event. + + Per-thread re-entry guard prevents infinite loops when something in + the publish path itself logs. + """ + + def __init__(self, bus: EventBus, *, level: int = logging.INFO): + super().__init__(level=level) + self._bus = bus + self._reentry = threading.local() + + def emit(self, record: logging.LogRecord) -> None: + # Re-entry guard: if a downstream subscriber's handler logs, we + # must not republish that log line. + if getattr(self._reentry, "active", False): + return + # Never bridge our own machinery — those records describe the + # bridge itself, would loop. + if record.name in _NEVER_BRIDGE: + return + self._reentry.active = True + try: + try: + # getMessage() only interpolates the record's args; it does + # not run any Formatter. Exception text is serialised + # separately below, so the frontend can choose how to + # render the message and the structured fields. + message = record.getMessage() + except Exception: + message = "" + + payload = { + "level": int(record.levelno), + "level_name": record.levelname, + "logger": record.name, + "message": message, + "module": record.module, + "func": record.funcName, + "line": record.lineno, + # Wall-clock ms since epoch — frontend uses this for its + # own ordering / display, separate from the EventBus's + # internal timestamp. 
+ "ts_ms": int(record.created * 1000), + } + if record.exc_info: + try: + payload["exc_text"] = logging.Formatter().formatException(record.exc_info) + except Exception: + pass + + self._bus.publish( + event_type=EventType.LOG_RECORD, + data=payload, + source=f"log:{record.name}", + ) + except Exception: + # If we can't publish (shutdown, etc.), drop the record + # silently — the live console + on-disk log still have it. + pass + finally: + self._reentry.active = False + + +def configure_log_bridge( + bus: EventBus | None = None, + *, + loggers: Iterable[str] | None = None, + level: str | None = None, + include_thirdparty: bool | None = None, +) -> LogToBusHandler | None: + """Attach a LogToBusHandler to the requested loggers. + + Returns the handler created by this call, or None if the bridge is disabled. + Repeat calls skip loggers already bridged, so the returned handler may be unattached. + + Parameters honour env-var defaults so the launch script doesn't need + to know the knobs: + GENTLY_LOG_BUS — "off" disables entirely + GENTLY_LOG_BUS_LEVEL — threshold (default INFO) + GENTLY_LOG_BUS_INCLUDE_THIRDPARTY — adds aiohttp/uvicorn/bluesky/etc. + """ + if os.environ.get("GENTLY_LOG_BUS", "on").lower() in ("off", "0", "false", "no"): + return None + + if bus is None: + bus = get_event_bus() + + if level is None: + level = os.environ.get("GENTLY_LOG_BUS_LEVEL", "INFO") + level_int = getattr(logging, level.upper(), logging.INFO) + + if include_thirdparty is None: + env_val = os.environ.get("GENTLY_LOG_BUS_INCLUDE_THIRDPARTY", "") + include_thirdparty = env_val.lower() in ("1", "true", "yes", "on") + + if loggers is None: + loggers = ["gently", "gently_perception"] + if include_thirdparty: + loggers = list(loggers) + list(_THIRDPARTY_DEFAULTS) + + handler = LogToBusHandler(bus, level=level_int) + + attached = [] + for name in loggers: + target = logging.getLogger(name) + # Skip if already attached (idempotency for re-invocation). 
+ if any(isinstance(h, LogToBusHandler) for h in target.handlers): + continue + target.addHandler(handler) + attached.append(name) + + if attached: + # Surface the configuration once at startup — using our own logger + # (which is in _NEVER_BRIDGE) so this announcement itself doesn't + # become a LOG_RECORD event. + logger.info( + "Log bridge active: level=%s, loggers=%s, include_thirdparty=%s", + logging.getLevelName(level_int), + attached, + include_thirdparty, + ) + return handler diff --git a/gently/core/service.py b/gently/core/service.py index a4bd6bc2..21876ee9 100644 --- a/gently/core/service.py +++ b/gently/core/service.py @@ -10,22 +10,22 @@ import asyncio import logging -import time from abc import ABC, abstractmethod from dataclasses import dataclass, field from datetime import datetime from enum import Enum, auto -from typing import Any, Callable, Dict, List, Optional, Type +from typing import Any import aiohttp -from .event_bus import EventType, get_event_bus, Event +from .event_bus import EventType, get_event_bus logger = logging.getLogger(__name__) class ServiceState(Enum): """Service lifecycle states""" + CREATED = auto() STARTING = auto() RUNNING = auto() @@ -37,25 +37,26 @@ class ServiceState(Enum): @dataclass class ServiceInfo: """Information about a registered service""" + name: str service_type: str host: str = "localhost" - port: Optional[int] = None + port: int | None = None state: ServiceState = ServiceState.CREATED - started_at: Optional[datetime] = None - metadata: Dict[str, Any] = field(default_factory=dict) - health_check_url: Optional[str] = None + started_at: datetime | None = None + metadata: dict[str, Any] = field(default_factory=dict) + health_check_url: str | None = None - def to_dict(self) -> Dict: + def to_dict(self) -> dict: return { - 'name': self.name, - 'service_type': self.service_type, - 'host': self.host, - 'port': self.port, - 'state': self.state.name, - 'started_at': self.started_at.isoformat() if self.started_at else 
None, - 'metadata': self.metadata, - 'health_check_url': self.health_check_url, + "name": self.name, + "service_type": self.service_type, + "host": self.host, + "port": self.port, + "state": self.state.name, + "started_at": self.started_at.isoformat() if self.started_at else None, + "metadata": self.metadata, + "health_check_url": self.health_check_url, } @@ -74,16 +75,16 @@ def __init__( name: str, service_type: str = "generic", host: str = "localhost", - port: Optional[int] = None, + port: int | None = None, ): self.name = name self.service_type = service_type self.host = host self.port = port self._state = ServiceState.CREATED - self._started_at: Optional[datetime] = None + self._started_at: datetime | None = None self._event_bus = get_event_bus() - self._metadata: Dict[str, Any] = {} + self._metadata: dict[str, Any] = {} @property def state(self) -> ServiceState: @@ -108,18 +109,18 @@ async def start(self): return self._state = ServiceState.STARTING - self._emit_event(EventType.STATUS_CHANGED, {'state': 'starting'}) + self._emit_event(EventType.STATUS_CHANGED, {"state": "starting"}) try: await self.on_start() self._state = ServiceState.RUNNING self._started_at = datetime.now() - self._emit_event(EventType.STATUS_CHANGED, {'state': 'running'}) + self._emit_event(EventType.STATUS_CHANGED, {"state": "running"}) logger.info(f"Service {self.name} started") except Exception as e: self._state = ServiceState.ERROR - self._emit_event(EventType.ERROR_OCCURRED, {'error': str(e)}) + self._emit_event(EventType.ERROR_OCCURRED, {"error": str(e)}) logger.error(f"Service {self.name} failed to start: {e}") raise @@ -129,27 +130,26 @@ async def stop(self): return self._state = ServiceState.STOPPING - self._emit_event(EventType.STATUS_CHANGED, {'state': 'stopping'}) + self._emit_event(EventType.STATUS_CHANGED, {"state": "stopping"}) try: await self.on_stop() self._state = ServiceState.STOPPED - self._emit_event(EventType.STATUS_CHANGED, {'state': 'stopped'}) + 
self._emit_event(EventType.STATUS_CHANGED, {"state": "stopped"}) logger.info(f"Service {self.name} stopped") except Exception as e: self._state = ServiceState.ERROR logger.error(f"Service {self.name} failed to stop cleanly: {e}") - async def health_check(self) -> Dict: + async def health_check(self) -> dict: """Check service health""" return { - 'name': self.name, - 'state': self._state.name, - 'healthy': self._state == ServiceState.RUNNING, - 'uptime_seconds': ( - (datetime.now() - self._started_at).total_seconds() - if self._started_at else 0 + "name": self.name, + "state": self._state.name, + "healthy": self._state == ServiceState.RUNNING, + "uptime_seconds": ( + (datetime.now() - self._started_at).total_seconds() if self._started_at else 0 ), } @@ -163,11 +163,11 @@ async def on_stop(self): """Called when service stops - implement in subclass""" pass - def _emit_event(self, event_type: EventType, data: Dict): + def _emit_event(self, event_type: EventType, data: dict): """Emit event on the bus""" self._event_bus.publish( event_type=event_type, - data={'service': self.name, **data}, + data={"service": self.name, **data}, source=f"service:{self.name}", ) @@ -183,8 +183,8 @@ class ServiceRegistry: """ def __init__(self): - self._services: Dict[str, Service] = {} - self._service_info: Dict[str, ServiceInfo] = {} + self._services: dict[str, Service] = {} + self._service_info: dict[str, ServiceInfo] = {} def register(self, service: Service): """Register a service""" @@ -205,20 +205,20 @@ def unregister(self, name: str): del self._service_info[name] logger.info(f"Unregistered service: {name}") - def get(self, name: str) -> Optional[Service]: + def get(self, name: str) -> Service | None: """Get service by name""" return self._services.get(name) - def get_info(self, name: str) -> Optional[ServiceInfo]: + def get_info(self, name: str) -> ServiceInfo | None: """Get service info by name""" if name in self._services: return self._services[name].info return 
self._service_info.get(name) - def find_by_type(self, service_type: str) -> List[ServiceInfo]: + def find_by_type(self, service_type: str) -> list[ServiceInfo]: """Find all services of a given type""" results = [] - for name, service in self._services.items(): + for _name, service in self._services.items(): if service.service_type == service_type: results.append(service.info) for name, info in self._service_info.items(): @@ -226,7 +226,7 @@ def find_by_type(self, service_type: str) -> List[ServiceInfo]: results.append(info) return results - def list_all(self) -> List[ServiceInfo]: + def list_all(self) -> list[ServiceInfo]: """List all registered services""" seen = set() results = [] @@ -238,7 +238,7 @@ def list_all(self) -> List[ServiceInfo]: results.append(info) return results - async def health_check_all(self) -> Dict[str, Dict]: + async def health_check_all(self) -> dict[str, dict]: """Check health of all services""" results = {} for name, service in self._services.items(): @@ -246,28 +246,28 @@ async def health_check_all(self) -> Dict[str, Dict]: results[name] = await service.health_check() except Exception as e: results[name] = { - 'name': name, - 'state': 'ERROR', - 'healthy': False, - 'error': str(e), + "name": name, + "state": "ERROR", + "healthy": False, + "error": str(e), } return results async def start_all(self): """Start all registered services""" - for name, service in self._services.items(): + for _name, service in self._services.items(): if service.state == ServiceState.CREATED: await service.start() async def stop_all(self): """Stop all registered services""" - for name, service in self._services.items(): + for _name, service in self._services.items(): if service.state == ServiceState.RUNNING: await service.stop() # Global registry -_global_registry: Optional[ServiceRegistry] = None +_global_registry: ServiceRegistry | None = None def get_service_registry() -> ServiceRegistry: @@ -288,6 +288,7 @@ def set_service_registry(registry: ServiceRegistry): 
# Service Client for Communication # ============================================================================= + class ServiceClient: """ Unified client for communicating with services @@ -298,9 +299,9 @@ class ServiceClient: - Protocol abstraction (HTTP) """ - def __init__(self, registry: Optional[ServiceRegistry] = None): + def __init__(self, registry: ServiceRegistry | None = None): self._registry = registry or get_service_registry() - self._connections: Dict[str, Any] = {} + self._connections: dict[str, Any] = {} async def connect(self, service_name: str) -> Any: """ @@ -336,7 +337,7 @@ async def disconnect(self, service_name: str): conn = self._connections.pop(service_name) if isinstance(conn, dict) and "session" in conn: await conn["session"].close() - elif hasattr(conn, 'close'): + elif hasattr(conn, "close"): if asyncio.iscoroutinefunction(conn.close): await conn.close() else: @@ -347,13 +348,7 @@ async def disconnect_all(self): for name in list(self._connections.keys()): await self.disconnect(name) - async def call( - self, - service_name: str, - method: str, - *args, - **kwargs - ) -> Any: + async def call(self, service_name: str, method: str, *args, **kwargs) -> Any: """ Call a method on a service diff --git a/gently/core/store.py b/gently/core/store.py index 1f347270..6a60a202 100644 --- a/gently/core/store.py +++ b/gently/core/store.py @@ -31,19 +31,23 @@ import json import logging -import os import shutil import sqlite3 from contextlib import contextmanager from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any import numpy as np from .store_types import ( - SessionInfo, EmbryoInfo, VolumeInfo, ProjectionInfo, - PerceptionRunInfo, PredictionInfo, GroundTruthEntry, StoreStats, + EmbryoInfo, + GroundTruthEntry, + PredictionInfo, + ProjectionInfo, + SessionInfo, + StoreStats, + VolumeInfo, ) logger = logging.getLogger(__name__) @@ -194,6 +198,7 @@ # GentlyStore # 
--------------------------------------------------------------------------- + class GentlyStore: """One class for all Gently storage. Owns one SQLite DB + one directory tree.""" @@ -271,8 +276,13 @@ def _abs_path(self, rel_path: str) -> Path: # Sessions # ================================================================== - def create_session(self, session_id: str, name: str = None, - description: str = None, metadata: dict = None) -> str: + def create_session( + self, + session_id: str, + name: str | None = None, + description: str | None = None, + metadata: dict | None = None, + ) -> str: """Create a new session. Returns session_id.""" now = self._now() with self._tx(): @@ -280,13 +290,19 @@ def create_session(self, session_id: str, name: str = None, "INSERT OR IGNORE INTO sessions " "(session_id, name, description, created_at, last_active, metadata) " "VALUES (?, ?, ?, ?, ?, ?)", - (session_id, name, description, now, now, - json.dumps(metadata) if metadata else None), + ( + session_id, + name, + description, + now, + now, + json.dumps(metadata) if metadata else None, + ), ) logger.info(f"Created session {session_id}") return session_id - def get_session(self, session_id: str) -> Optional[SessionInfo]: + def get_session(self, session_id: str) -> SessionInfo | None: """Return session row as dict, or None.""" row = self._conn.execute( "SELECT * FROM sessions WHERE session_id = ?", (session_id,) @@ -297,11 +313,9 @@ def get_session(self, session_id: str) -> Optional[SessionInfo]: self._parse_json_field(d, "metadata") return d - def list_sessions(self) -> List[SessionInfo]: + def list_sessions(self) -> list[SessionInfo]: """Return all sessions ordered by last_active descending.""" - rows = self._conn.execute( - "SELECT * FROM sessions ORDER BY last_active DESC" - ).fetchall() + rows = self._conn.execute("SELECT * FROM sessions ORDER BY last_active DESC").fetchall() result = [] for row in rows: d = dict(row) @@ -325,12 +339,12 @@ def save_session_snapshot(self, 
session_id: str, snapshot: dict): json.dump(snapshot, f, indent=2, ensure_ascii=False, default=str) self.touch_session(session_id) - def load_session_snapshot(self, session_id: str) -> Optional[dict]: + def load_session_snapshot(self, session_id: str) -> dict | None: """Load session snapshot JSON. Returns None if missing.""" path = self.root / "sessions" / f"{session_id}.json" if not path.exists(): return None - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: return json.load(f) # ================================================================== @@ -341,11 +355,11 @@ def register_embryo( self, session_id: str, embryo_id: str, - embryo_uid: str = None, - nickname: str = None, - position_x: float = None, - position_y: float = None, - calibration: dict = None, + embryo_uid: str | None = None, + nickname: str | None = None, + position_x: float | None = None, + position_y: float | None = None, + calibration: dict | None = None, ): """Register or update an embryo in a session.""" now = self._now() @@ -362,11 +376,19 @@ def register_embryo( " position_x = COALESCE(excluded.position_x, embryos.position_x), " " position_y = COALESCE(excluded.position_y, embryos.position_y), " " calibration = COALESCE(excluded.calibration, embryos.calibration)", - (embryo_id, session_id, embryo_uid, nickname, - position_x, position_y, cal_json, now), + ( + embryo_id, + session_id, + embryo_uid, + nickname, + position_x, + position_y, + cal_json, + now, + ), ) - def get_embryo(self, session_id: str, embryo_id: str) -> Optional[EmbryoInfo]: + def get_embryo(self, session_id: str, embryo_id: str) -> EmbryoInfo | None: row = self._conn.execute( "SELECT * FROM embryos WHERE session_id = ? 
AND embryo_id = ?", (session_id, embryo_id), @@ -377,7 +399,7 @@ def get_embryo(self, session_id: str, embryo_id: str) -> Optional[EmbryoInfo]: self._parse_json_field(d, "calibration") return d - def list_embryos(self, session_id: str) -> List[EmbryoInfo]: + def list_embryos(self, session_id: str) -> list[EmbryoInfo]: rows = self._conn.execute( "SELECT * FROM embryos WHERE session_id = ? ORDER BY embryo_id", (session_id,), @@ -411,7 +433,7 @@ def put_volume( embryo_id: str, timepoint: int, volume: np.ndarray, - metadata: dict = None, + metadata: dict | None = None, ) -> Path: """ Write a volume to disk, generate a JPEG projection, insert DB rows. @@ -444,17 +466,23 @@ def put_volume( # Insert DB rows self._insert_volume_row( - session_id, embryo_id, timepoint, vol_path, - volume.shape, str(volume.dtype), metadata, + session_id, + embryo_id, + timepoint, + vol_path, + volume.shape, + str(volume.dtype), + metadata, ) if proj_path is not None: self._insert_projection_row( - session_id, embryo_id, timepoint, proj_path, + session_id, + embryo_id, + timepoint, + proj_path, ) - logger.debug( - f"put_volume: {session_id}/{embryo_id} t={timepoint} -> {vol_path}" - ) + logger.debug(f"put_volume: {session_id}/{embryo_id} t={timepoint} -> {vol_path}") return vol_path def register_volume( @@ -463,8 +491,8 @@ def register_volume( embryo_id: str, timepoint: int, incoming_path: Path, - metadata: dict = None, - volume_data: np.ndarray = None, + metadata: dict | None = None, + volume_data: np.ndarray | None = None, ) -> Path: """ Zero-copy path: rename an existing TIFF to canonical location. 
@@ -510,22 +538,29 @@ def register_volume( volume = volume_data else: from gently.core.imaging import load_volume + volume = load_volume(canonical) proj_path = self._generate_projection(session_id, embryo_id, timepoint, volume) self._insert_volume_row( - session_id, embryo_id, timepoint, canonical, - volume.shape, str(volume.dtype), metadata, + session_id, + embryo_id, + timepoint, + canonical, + volume.shape, + str(volume.dtype), + metadata, ) if proj_path is not None: self._insert_projection_row( - session_id, embryo_id, timepoint, proj_path, + session_id, + embryo_id, + timepoint, + proj_path, ) - logger.debug( - f"register_volume: {incoming_path.name} -> {canonical}" - ) + logger.debug(f"register_volume: {incoming_path.name} -> {canonical}") return canonical # ------------------------------------------------------------------ @@ -540,7 +575,7 @@ def register_snapshot( session_id: str, source: str, incoming_path: Path, - metadata: dict = None, + metadata: dict | None = None, ) -> Path: """Move a transient TIFF from *incoming/* to ``snapshots/{session}/``. 
@@ -577,6 +612,7 @@ def register_snapshot( # Read shape for DB record import tifffile + arr = tifffile.imread(str(canonical)) with self._tx(): @@ -585,7 +621,8 @@ def register_snapshot( "(session_id, source, file_path, width, height, metadata, captured_at) " "VALUES (?, ?, ?, ?, ?, ?, ?)", ( - session_id, source, + session_id, + source, self._rel_path(canonical), arr.shape[-1] if arr.ndim >= 2 else None, arr.shape[-2] if arr.ndim >= 2 else None, @@ -596,14 +633,11 @@ def register_snapshot( logger.debug("register_snapshot: %s -> %s", incoming_path.name, canonical) return canonical - def list_snapshots( - self, session_id: str, source: str = None - ) -> List[Dict[str, Any]]: + def list_snapshots(self, session_id: str, source: str | None = None) -> list[dict[str, Any]]: """List snapshot records for a session, optionally filtered by source.""" if source: rows = self._conn.execute( - "SELECT * FROM snapshots WHERE session_id = ? AND source = ? " - "ORDER BY captured_at", + "SELECT * FROM snapshots WHERE session_id = ? AND source = ? ORDER BY captured_at", (session_id, source), ).fetchall() else: @@ -611,10 +645,10 @@ def list_snapshots( "SELECT * FROM snapshots WHERE session_id = ? 
ORDER BY captured_at", (session_id,), ).fetchall() - cols = [d[0] for d in self._conn.execute( - "SELECT * FROM snapshots LIMIT 0" - ).description or []] - return [dict(zip(cols, row)) for row in rows] + cols = [ + d[0] for d in self._conn.execute("SELECT * FROM snapshots LIMIT 0").description or [] + ] + return [dict(zip(cols, row, strict=False)) for row in rows] # ------------------------------------------------------------------ # Incoming cleanup @@ -652,19 +686,16 @@ def cleanup_incoming(self, max_age_seconds: float = 300) -> int: # Volume retrieval # ------------------------------------------------------------------ - def get_volume( - self, session_id: str, embryo_id: str, timepoint: int - ) -> Optional[np.ndarray]: + def get_volume(self, session_id: str, embryo_id: str, timepoint: int) -> np.ndarray | None: """Load a volume from disk. Returns None if not found.""" path = self.get_volume_path(session_id, embryo_id, timepoint) if path is None or not path.exists(): return None import tifffile + return tifffile.imread(str(path)) - def get_volume_path( - self, session_id: str, embryo_id: str, timepoint: int - ) -> Optional[Path]: + def get_volume_path(self, session_id: str, embryo_id: str, timepoint: int) -> Path | None: """Return the absolute path to a volume, or None.""" row = self._conn.execute( "SELECT file_path FROM volumes " @@ -675,14 +706,11 @@ def get_volume_path( return None return self._abs_path(row["file_path"]) - def list_volumes( - self, session_id: str, embryo_id: str = None - ) -> List[VolumeInfo]: + def list_volumes(self, session_id: str, embryo_id: str | None = None) -> list[VolumeInfo]: """List volume metadata rows for a session (optionally filtered).""" if embryo_id: rows = self._conn.execute( - "SELECT * FROM volumes WHERE session_id = ? AND embryo_id = ? " - "ORDER BY timepoint", + "SELECT * FROM volumes WHERE session_id = ? AND embryo_id = ? 
ORDER BY timepoint", (session_id, embryo_id), ).fetchall() else: @@ -700,9 +728,7 @@ def list_volumes( result.append(d) return result - def get_acquisition_params( - self, session_id: str, embryo_id: str = None - ) -> Optional[dict]: + def get_acquisition_params(self, session_id: str, embryo_id: str | None = None) -> dict | None: """ Get the acquisition parameters used in a session. @@ -742,9 +768,12 @@ def get_acquisition_params( # -- internal helpers -- def _generate_projection( - self, session_id: str, embryo_id: str, timepoint: int, + self, + session_id: str, + embryo_id: str, + timepoint: int, volume: np.ndarray, - ) -> Optional[Path]: + ) -> Path | None: """Generate JPEG projection file from volume data.""" from .imaging import generate_jpeg_projection @@ -753,8 +782,14 @@ def _generate_projection( return generate_jpeg_projection(volume, proj_path) def _insert_volume_row( - self, session_id, embryo_id, timepoint, vol_path, - shape, dtype, metadata, + self, + session_id, + embryo_id, + timepoint, + vol_path, + shape, + dtype, + metadata, ): now = self._now() with self._tx(): @@ -764,7 +799,9 @@ def _insert_volume_row( " acquired_at, metadata) " "VALUES (?, ?, ?, ?, ?, ?, ?, ?)", ( - session_id, embryo_id, timepoint, + session_id, + embryo_id, + timepoint, self._rel_path(vol_path), json.dumps(list(shape)), dtype, @@ -774,7 +811,11 @@ def _insert_volume_row( ) def _insert_projection_row( - self, session_id, embryo_id, timepoint, proj_path, + self, + session_id, + embryo_id, + timepoint, + proj_path, ): from PIL import Image as PILImage @@ -794,9 +835,14 @@ def _insert_projection_row( " size_kb, created_at) " "VALUES (?, ?, ?, ?, ?, ?, ?, ?)", ( - session_id, embryo_id, timepoint, + session_id, + embryo_id, + timepoint, self._rel_path(proj_path), - w, h, size_kb, now, + w, + h, + size_kb, + now, ), ) @@ -804,9 +850,7 @@ def _insert_projection_row( # Projections # ================================================================== - def get_projection_path( - self, 
session_id: str, embryo_id: str, timepoint: int - ) -> Optional[Path]: + def get_projection_path(self, session_id: str, embryo_id: str, timepoint: int) -> Path | None: """Return absolute path to the JPEG projection, or None.""" row = self._conn.execute( "SELECT file_path FROM projections " @@ -817,9 +861,7 @@ def get_projection_path( return None return self._abs_path(row["file_path"]) - def get_projection_b64( - self, session_id: str, embryo_id: str, timepoint: int - ) -> Optional[str]: + def get_projection_b64(self, session_id: str, embryo_id: str, timepoint: int) -> str | None: """Return base64-encoded JPEG projection, or None.""" import base64 @@ -829,12 +871,9 @@ def get_projection_b64( with open(path, "rb") as f: return base64.b64encode(f.read()).decode("utf-8") - def list_projections( - self, session_id: str, embryo_id: str - ) -> List[ProjectionInfo]: + def list_projections(self, session_id: str, embryo_id: str) -> list[ProjectionInfo]: rows = self._conn.execute( - "SELECT * FROM projections " - "WHERE session_id = ? AND embryo_id = ? ORDER BY timepoint", + "SELECT * FROM projections WHERE session_id = ? AND embryo_id = ? ORDER BY timepoint", (session_id, embryo_id), ).fetchall() return [dict(r) for r in rows] @@ -848,10 +887,10 @@ def create_perception_run( session_id: str, name: str, method: str, - model_name: str = None, + model_name: str | None = None, trace_type: str = "perception", source: str = "live", - config: dict = None, + config: dict | None = None, ) -> int: """Create a perception run. 
Returns run_id.""" now = self._now() @@ -862,8 +901,12 @@ def create_perception_run( " source, config, status, created_at) " "VALUES (?, ?, ?, ?, ?, ?, ?, 'running', ?)", ( - session_id, name, method, model_name, - trace_type, source, + session_id, + name, + method, + model_name, + trace_type, + source, json.dumps(config) if config else None, now, ), @@ -877,14 +920,14 @@ def store_prediction( embryo_id: str, timepoint: int, predicted_stage: str, - confidence: float = None, - reasoning: str = None, + confidence: float | None = None, + reasoning: str | None = None, is_transitional: bool = False, - execution_time_ms: float = None, - trace_data: dict = None, - observed_features: dict = None, - ground_truth_stage: str = None, - is_correct: int = None, + execution_time_ms: float | None = None, + trace_data: dict | None = None, + observed_features: dict | None = None, + ground_truth_stage: str | None = None, + is_correct: int | None = None, ) -> int: """ Insert a prediction row. Optionally writes trace JSON file. 
@@ -919,10 +962,17 @@ def store_prediction( " created_at) " "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", ( - run_id, session_id, embryo_id, timepoint, predicted_stage, - confidence, reasoning, + run_id, + session_id, + embryo_id, + timepoint, + predicted_stage, + confidence, + reasoning, 1 if is_transitional else 0, - ground_truth_stage, is_correct, execution_time_ms, + ground_truth_stage, + is_correct, + execution_time_ms, trace_file_rel, json.dumps(observed_features) if observed_features else None, now, @@ -931,7 +981,7 @@ def store_prediction( return cursor.lastrowid def complete_perception_run( - self, run_id: int, status: str = "completed", error_message: str = None + self, run_id: int, status: str = "completed", error_message: str | None = None ): """Mark a perception run as completed or failed.""" now = self._now() @@ -945,9 +995,9 @@ def complete_perception_run( def get_predictions( self, session_id: str, - embryo_id: str = None, - run_id: int = None, - ) -> List[PredictionInfo]: + embryo_id: str | None = None, + run_id: int | None = None, + ) -> list[PredictionInfo]: """Query predictions with optional filters.""" clauses = ["session_id = ?"] params: list = [session_id] @@ -961,8 +1011,7 @@ def get_predictions( where = " AND ".join(clauses) rows = self._conn.execute( - f"SELECT * FROM predictions WHERE {where} " - "ORDER BY timepoint, prediction_id", + f"SELECT * FROM predictions WHERE {where} ORDER BY timepoint, prediction_id", params, ).fetchall() @@ -984,9 +1033,9 @@ def set_ground_truth( embryo_id: str, stage: str, start_timepoint: int, - end_timepoint: int = None, - annotator: str = None, - notes: str = None, + end_timepoint: int | None = None, + annotator: str | None = None, + notes: str | None = None, ): """Insert or update a ground-truth annotation.""" with self._tx(): @@ -1000,11 +1049,18 @@ def set_ground_truth( " end_timepoint = excluded.end_timepoint, " " annotator = excluded.annotator, " " notes = excluded.notes", - (session_id, 
embryo_id, stage, start_timepoint, - end_timepoint, annotator, notes), + ( + session_id, + embryo_id, + stage, + start_timepoint, + end_timepoint, + annotator, + notes, + ), ) - def get_ground_truth(self, session_id: str, embryo_id: str) -> List[GroundTruthEntry]: + def get_ground_truth(self, session_id: str, embryo_id: str) -> list[GroundTruthEntry]: rows = self._conn.execute( "SELECT * FROM ground_truth " "WHERE session_id = ? AND embryo_id = ? ORDER BY start_timepoint", @@ -1018,8 +1074,15 @@ def get_ground_truth(self, session_id: str, embryo_id: str) -> List[GroundTruthE def stats(self) -> StoreStats: """Return counts and disk-usage summary.""" - tables = ["sessions", "embryos", "volumes", "projections", - "perception_runs", "predictions", "ground_truth"] + tables = [ + "sessions", + "embryos", + "volumes", + "projections", + "perception_runs", + "predictions", + "ground_truth", + ] counts = {} for t in tables: counts[t] = self._conn.execute(f"SELECT COUNT(*) FROM {t}").fetchone()[0] diff --git a/gently/core/store_types.py b/gently/core/store_types.py index 34a466c4..aa3eeaff 100644 --- a/gently/core/store_types.py +++ b/gently/core/store_types.py @@ -2,27 +2,36 @@ Use these for type annotations on store methods and their callers. """ -from typing import List, Optional, TypedDict + +from typing import TypedDict class SessionInfo(TypedDict): session_id: str - name: Optional[str] - description: Optional[str] + name: str | None + description: str | None created_at: str last_active: str - metadata: Optional[dict] + metadata: dict | None -class EmbryoInfo(TypedDict): +class EmbryoInfo(TypedDict, total=False): embryo_id: str session_id: str - embryo_uid: Optional[str] - nickname: Optional[str] - position_x: Optional[float] - position_y: Optional[float] - calibration: Optional[dict] - role: Optional[str] # key into gently.harness.roles.REGISTRY + embryo_uid: str | None + nickname: str | None + # Coarse XY (µm) from bottom-camera detection or manual map placement. 
+ # Shape: {"x": float, "y": float}. Always present once the embryo exists. + position_coarse: dict | None + # Fine XY (µm) from SPIM-objective alignment. None until that workflow + # refines the coarse position. Shape: {"x": float, "y": float}. + position_fine: dict | None + # Legacy flat fields. Still accepted on write and surfaced on read for + # callers that haven't migrated; new code should use position_coarse. + position_x: float | None + position_y: float | None + calibration: dict | None + role: str | None # key into gently.harness.roles.REGISTRY created_at: str @@ -31,10 +40,10 @@ class VolumeInfo(TypedDict): embryo_id: str timepoint: int file_path: str - shape: Optional[List[int]] - dtype: Optional[str] + shape: list[int] | None + dtype: str | None acquired_at: str - metadata: Optional[dict] + metadata: dict | None class ProjectionInfo(TypedDict): @@ -42,25 +51,25 @@ class ProjectionInfo(TypedDict): embryo_id: str timepoint: int file_path: str - width: Optional[int] - height: Optional[int] - size_kb: Optional[float] + width: int | None + height: int | None + size_kb: float | None created_at: str class PerceptionRunInfo(TypedDict): run_id: int - session_id: Optional[str] + session_id: str | None name: str perception_method: str - model_name: Optional[str] + model_name: str | None trace_type: str source: str - config: Optional[dict] + config: dict | None status: str created_at: str - completed_at: Optional[str] - error_message: Optional[str] + completed_at: str | None + error_message: str | None class PredictionInfo(TypedDict): @@ -70,14 +79,14 @@ class PredictionInfo(TypedDict): embryo_id: str timepoint: int predicted_stage: str - confidence: Optional[float] - reasoning: Optional[str] + confidence: float | None + reasoning: str | None is_transitional: int - ground_truth_stage: Optional[str] - is_correct: Optional[int] - execution_time_ms: Optional[float] - trace_file: Optional[str] - observed_features: Optional[dict] + ground_truth_stage: str | None + 
is_correct: int | None + execution_time_ms: float | None + trace_file: str | None + observed_features: dict | None created_at: str @@ -87,9 +96,9 @@ class GroundTruthEntry(TypedDict): embryo_id: str stage: str start_timepoint: int - end_timepoint: Optional[int] - annotator: Optional[str] - notes: Optional[str] + end_timepoint: int | None + annotator: str | None + notes: str | None created_at: str diff --git a/gently/dataset/__init__.py b/gently/dataset/__init__.py index f69fff01..9c09aeb8 100644 --- a/gently/dataset/__init__.py +++ b/gently/dataset/__init__.py @@ -40,15 +40,15 @@ stacklevel=2, ) -from .schema import ( - init_database, +from .aggregator import DatasetAggregator # noqa: E402 +from .embryo_dataset import DatasetEmbryoEntry, EmbryoDataset, ImageData # noqa: E402 +from .schema import ( # noqa: E402 + DATABASE_VERSION, get_connection, + init_database, migrate_to_v2, migrate_to_v3, - DATABASE_VERSION, ) -from .aggregator import DatasetAggregator -from .embryo_dataset import EmbryoDataset, DatasetEmbryoEntry, ImageData __all__ = [ "init_database", diff --git a/gently/dataset/aggregator.py b/gently/dataset/aggregator.py index 3008ff78..c48a99c3 100644 --- a/gently/dataset/aggregator.py +++ b/gently/dataset/aggregator.py @@ -13,9 +13,9 @@ import sqlite3 from datetime import datetime from pathlib import Path -from typing import Optional, Iterator, Dict, Any, List +from typing import Any -from .schema import init_database, get_connection, transaction +from .schema import init_database logger = logging.getLogger(__name__) @@ -40,18 +40,21 @@ class DatasetAggregator: def __init__( self, - db_path: Optional[Path] = None, + db_path: Path | None = None, sessions_dir: Path = Path("D:/gently/sessions"), data_dir: Path = Path("D:/gently/data"), images_dir: Path = Path("D:/gently/images"), - ground_truth_dir: Optional[Path] = None, + ground_truth_dir: Path | None = None, ): self.db_path = db_path or Path("D:/gently/dataset.db") self.sessions_dir = sessions_dir 
self.data_dir = data_dir self.images_dir = images_dir - self.ground_truth_dir = ground_truth_dir or Path(__file__).parent.parent.parent / "benchmarks" / "data" / "ground_truth" - self.conn: Optional[sqlite3.Connection] = None + self.ground_truth_dir = ( + ground_truth_dir + or Path(__file__).parent.parent.parent / "benchmarks" / "data" / "ground_truth" + ) + self.conn: sqlite3.Connection | None = None def connect(self) -> sqlite3.Connection: """Initialize and connect to the database.""" @@ -64,7 +67,7 @@ def close(self): self.conn.close() self.conn = None - def aggregate_all(self, incremental: bool = True) -> Dict[str, int]: + def aggregate_all(self, incremental: bool = True) -> dict[str, int]: """ Run full aggregation from all sources. @@ -119,7 +122,7 @@ def aggregate_all(self, incremental: bool = True) -> Dict[str, int]: logger.info(f"Aggregation complete: {stats}") return stats - def aggregate_sessions(self, since: Optional[datetime] = None) -> Dict[str, int]: + def aggregate_sessions(self, since: datetime | None = None) -> dict[str, int]: """ Aggregate session data from JSON files. 
@@ -149,7 +152,7 @@ def aggregate_sessions(self, since: Optional[datetime] = None) -> Dict[str, int] continue try: - with open(session_file, 'r', encoding='utf-8') as f: + with open(session_file, encoding="utf-8") as f: data = json.load(f) result = self._process_session(data) @@ -159,7 +162,9 @@ def aggregate_sessions(self, since: Optional[datetime] = None) -> Dict[str, int] stats["updated"] += 1 # Count embryos - embryos = data.get("embryo_states", {}) or data.get("experiment_data", {}).get("embryos", {}) + embryos = data.get("embryo_states", {}) or data.get("experiment_data", {}).get( + "embryos", {} + ) stats["embryos"] += len(embryos) except Exception as e: @@ -169,7 +174,7 @@ def aggregate_sessions(self, since: Optional[datetime] = None) -> Dict[str, int] log_id, stats["added"] + stats["updated"], stats["added"], - stats["updated"] + stats["updated"], ) except Exception as e: @@ -178,7 +183,7 @@ def aggregate_sessions(self, since: Optional[datetime] = None) -> Dict[str, int] return stats - def _process_session(self, data: Dict[str, Any]) -> str: + def _process_session(self, data: dict[str, Any]) -> str: """ Insert or update a session record. 
@@ -193,54 +198,72 @@ def _process_session(self, data: Dict[str, Any]) -> str: # Check if exists existing = self.conn.execute( - "SELECT session_id FROM sessions WHERE session_id = ?", - (session_id,) + "SELECT session_id FROM sessions WHERE session_id = ?", (session_id,) ).fetchone() # Extract metadata metadata = { - k: v for k, v in data.items() - if k not in ("session_id", "name", "description", "created_at", - "last_active", "conversation", "system_prompt", - "experiment_data", "embryo_states") + k: v + for k, v in data.items() + if k + not in ( + "session_id", + "name", + "description", + "created_at", + "last_active", + "conversation", + "system_prompt", + "experiment_data", + "embryo_states", + ) } if existing: - self.conn.execute(""" + self.conn.execute( + """ UPDATE sessions SET name = ?, description = ?, last_active = ?, metadata_json = ? WHERE session_id = ? - """, ( - data.get("name"), - data.get("description"), - data.get("last_active"), - json.dumps(metadata) if metadata else None, - session_id, - )) + """, + ( + data.get("name"), + data.get("description"), + data.get("last_active"), + json.dumps(metadata) if metadata else None, + session_id, + ), + ) result = "updated" else: - self.conn.execute(""" - INSERT INTO sessions (session_id, name, description, created_at, last_active, metadata_json) + self.conn.execute( + """ + INSERT INTO sessions + (session_id, name, description, created_at, last_active, metadata_json) VALUES (?, ?, ?, ?, ?, ?) 
- """, ( - session_id, - data.get("name"), - data.get("description"), - data.get("created_at"), - data.get("last_active"), - json.dumps(metadata) if metadata else None, - )) + """, + ( + session_id, + data.get("name"), + data.get("description"), + data.get("created_at"), + data.get("last_active"), + json.dumps(metadata) if metadata else None, + ), + ) result = "added" # Process embryos - embryos = data.get("embryo_states", {}) or data.get("experiment_data", {}).get("embryos", {}) + embryos = data.get("embryo_states", {}) or data.get("experiment_data", {}).get( + "embryos", {} + ) for embryo_id, embryo_data in embryos.items(): self._process_embryo(session_id, embryo_id, embryo_data) self.conn.commit() return result - def _process_embryo(self, session_id: str, embryo_id: str, data: Dict[str, Any]): + def _process_embryo(self, session_id: str, embryo_id: str, data: dict[str, Any]): """Insert or update an embryo record.""" stage_pos = data.get("stage_position", {}) calibration = data.get("calibration", {}) @@ -248,23 +271,26 @@ def _process_embryo(self, session_id: str, embryo_id: str, data: Dict[str, Any]) # Get UID from data, or generate backward-compatible UID embryo_uid = data.get("uid") or f"{session_id}_{embryo_id}" - self.conn.execute(""" + self.conn.execute( + """ INSERT OR REPLACE INTO embryos (embryo_id, session_id, nickname, user_label, stage_position_x, stage_position_y, calibration_json, embryo_uid) VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
- """, ( - embryo_id, - session_id, - data.get("nickname"), - data.get("user_label"), - stage_pos.get("x"), - stage_pos.get("y"), - json.dumps(calibration) if calibration else None, - embryo_uid, - )) - - def aggregate_volumes(self, since: Optional[datetime] = None) -> Dict[str, int]: + """, + ( + embryo_id, + session_id, + data.get("nickname"), + data.get("user_label"), + stage_pos.get("x"), + stage_pos.get("y"), + json.dumps(calibration) if calibration else None, + embryo_uid, + ), + ) + + def aggregate_volumes(self, since: datetime | None = None) -> dict[str, int]: """ Aggregate volume data from data directory. @@ -293,7 +319,7 @@ def aggregate_volumes(self, since: Optional[datetime] = None) -> Dict[str, int]: continue try: - with open(meta_file, 'r', encoding='utf-8') as f: + with open(meta_file, encoding="utf-8") as f: data = json.load(f) if self._process_volume(data, meta_file): @@ -306,7 +332,9 @@ def aggregate_volumes(self, since: Optional[datetime] = None) -> Dict[str, int]: stats["skipped"] += 1 self.conn.commit() - self._complete_aggregation_log(log_id, stats["added"] + stats["skipped"], stats["added"], 0) + self._complete_aggregation_log( + log_id, stats["added"] + stats["skipped"], stats["added"], 0 + ) except Exception as e: self._fail_aggregation_log(log_id, str(e)) @@ -339,7 +367,7 @@ def aggregate_volumes(self, since: Optional[datetime] = None) -> Dict[str, int]: return stats - def _process_volume(self, data: Dict[str, Any], meta_file: Path) -> bool: + def _process_volume(self, data: dict[str, Any], meta_file: Path) -> bool: """ Process a volume from its metadata JSON. 
@@ -350,9 +378,7 @@ def _process_volume(self, data: Dict[str, Any], meta_file: Path) -> bool: return False # Check if exists - existing = self.conn.execute( - "SELECT uid FROM volumes WHERE uid = ?", (uid,) - ).fetchone() + existing = self.conn.execute("SELECT uid FROM volumes WHERE uid = ?", (uid,)).fetchone() if existing: return False @@ -370,26 +396,30 @@ def _process_volume(self, data: Dict[str, Any], meta_file: Path) -> bool: if session_id and embryo_id: result = self.conn.execute( "SELECT embryo_uid FROM embryos WHERE session_id = ? AND embryo_id = ?", - (session_id, embryo_id) + (session_id, embryo_id), ).fetchone() embryo_uid = result[0] if result else f"{session_id}_{embryo_id}" - self.conn.execute(""" + self.conn.execute( + """ INSERT INTO volumes - (uid, session_id, embryo_id, timepoint, file_path, shape_json, dtype, timestamp, metadata_json, embryo_uid) + (uid, session_id, embryo_id, timepoint, file_path, shape_json, dtype, + timestamp, metadata_json, embryo_uid) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- """, ( - uid, - session_id, - embryo_id, - metadata.get("timepoint"), - str(tiff_path) if tiff_path.exists() else None, - json.dumps(data.get("shape")), - data.get("dtype"), - data.get("timestamp"), - json.dumps(metadata) if metadata else None, - embryo_uid, - )) + """, + ( + uid, + session_id, + embryo_id, + metadata.get("timepoint"), + str(tiff_path) if tiff_path.exists() else None, + json.dumps(data.get("shape")), + data.get("dtype"), + data.get("timestamp"), + json.dumps(metadata) if metadata else None, + embryo_uid, + ), + ) return True @@ -403,9 +433,7 @@ def _process_volume_from_tiff(self, tif_path: Path, session_id: str) -> bool: uid = f"tiff_{tif_path.stem}" # Check if exists - existing = self.conn.execute( - "SELECT uid FROM volumes WHERE uid = ?", (uid,) - ).fetchone() + existing = self.conn.execute("SELECT uid FROM volumes WHERE uid = ?", (uid,)).fetchone() if existing: return False @@ -422,7 +450,10 @@ def _process_volume_from_tiff(self, tif_path: Path, session_id: str) -> bool: try: date_str = parts[2] time_str = parts[3] - timestamp_str = f"{date_str[:4]}-{date_str[4:6]}-{date_str[6:8]}T{time_str[:2]}:{time_str[2:4]}:{time_str[4:6]}" + timestamp_str = ( + f"{date_str[:4]}-{date_str[4:6]}-{date_str[6:8]}" + f"T{time_str[:2]}:{time_str[2:4]}:{time_str[4:6]}" + ) except Exception: timestamp_str = datetime.fromtimestamp(tif_path.stat().st_mtime).isoformat() @@ -434,27 +465,30 @@ def _process_volume_from_tiff(self, tif_path: Path, session_id: str) -> bool: if session_id and embryo_id: result = self.conn.execute( "SELECT embryo_uid FROM embryos WHERE session_id = ? AND embryo_id = ?", - (session_id, embryo_id) + (session_id, embryo_id), ).fetchone() embryo_uid = result[0] if result else f"{session_id}_{embryo_id}" - self.conn.execute(""" + self.conn.execute( + """ INSERT INTO volumes (uid, session_id, embryo_id, file_path, timestamp, metadata_json, embryo_uid) VALUES (?, ?, ?, ?, ?, ?, ?) 
- """, ( - uid, - session_id, - embryo_id, - str(tif_path), - timestamp_str, - json.dumps({"source": "images_dir"}), - embryo_uid, - )) + """, + ( + uid, + session_id, + embryo_id, + str(tif_path), + timestamp_str, + json.dumps({"source": "images_dir"}), + embryo_uid, + ), + ) return True - def aggregate_images(self, since: Optional[datetime] = None) -> Dict[str, int]: + def aggregate_images(self, since: datetime | None = None) -> dict[str, int]: """ Aggregate image projection data. @@ -484,7 +518,7 @@ def aggregate_images(self, since: Optional[datetime] = None) -> Dict[str, int]: continue try: - with open(meta_file, 'r', encoding='utf-8') as f: + with open(meta_file, encoding="utf-8") as f: data = json.load(f) if self._process_image(data): @@ -497,7 +531,9 @@ def aggregate_images(self, since: Optional[datetime] = None) -> Dict[str, int]: stats["skipped"] += 1 self.conn.commit() - self._complete_aggregation_log(log_id, stats["added"] + stats["skipped"], stats["added"], 0) + self._complete_aggregation_log( + log_id, stats["added"] + stats["skipped"], stats["added"], 0 + ) except Exception as e: self._fail_aggregation_log(log_id, str(e)) @@ -505,7 +541,7 @@ def aggregate_images(self, since: Optional[datetime] = None) -> Dict[str, int]: return stats - def _process_image(self, data: Dict[str, Any]) -> bool: + def _process_image(self, data: dict[str, Any]) -> bool: """ Process an image from its metadata JSON. @@ -516,9 +552,7 @@ def _process_image(self, data: Dict[str, Any]) -> bool: return False # Check if exists - existing = self.conn.execute( - "SELECT uid FROM images WHERE uid = ?", (uid,) - ).fetchone() + existing = self.conn.execute("SELECT uid FROM images WHERE uid = ?", (uid,)).fetchone() if existing: return False @@ -531,33 +565,36 @@ def _process_image(self, data: Dict[str, Any]) -> bool: if session_id and embryo_id: result = self.conn.execute( "SELECT embryo_uid FROM embryos WHERE session_id = ? 
AND embryo_id = ?", - (session_id, embryo_id) + (session_id, embryo_id), ).fetchone() embryo_uid = result[0] if result else f"{session_id}_{embryo_id}" - self.conn.execute(""" + self.conn.execute( + """ INSERT INTO images (uid, parent_uid, session_id, embryo_id, timepoint, projection_type, shape_json, dtype, b64_size_kb, timestamp, metadata_json, embryo_uid) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - uid, - metadata.get("parent_uid"), - session_id, - embryo_id, - metadata.get("timepoint"), - metadata.get("projection_type"), - json.dumps(data.get("shape")), - data.get("dtype"), - metadata.get("b64_size_kb"), - data.get("timestamp"), - json.dumps(metadata) if metadata else None, - embryo_uid, - )) + """, + ( + uid, + metadata.get("parent_uid"), + session_id, + embryo_id, + metadata.get("timepoint"), + metadata.get("projection_type"), + json.dumps(data.get("shape")), + data.get("dtype"), + metadata.get("b64_size_kb"), + data.get("timestamp"), + json.dumps(metadata) if metadata else None, + embryo_uid, + ), + ) return True - def aggregate_ground_truth(self) -> Dict[str, int]: + def aggregate_ground_truth(self) -> dict[str, int]: """ Import ground truth annotations from benchmark data. @@ -586,7 +623,7 @@ def aggregate_ground_truth(self) -> Dict[str, int]: try: for gt_file in self.ground_truth_dir.glob("*.json"): try: - with open(gt_file, 'r', encoding='utf-8') as f: + with open(gt_file, encoding="utf-8") as f: data = json.load(f) session_id = data.get("session_id") @@ -597,31 +634,53 @@ def aggregate_ground_truth(self) -> Dict[str, int]: for embryo_id, stages in transitions.items(): for stage, start_timepoint in stages.items(): # Check if exists - existing = self.conn.execute(""" + existing = self.conn.execute( + """ SELECT id FROM ground_truth WHERE session_id = ? AND embryo_id = ? AND stage = ? 
- """, (session_id, embryo_id, stage)).fetchone() + """, + (session_id, embryo_id, stage), + ).fetchone() if existing: - self.conn.execute(""" + self.conn.execute( + """ UPDATE ground_truth SET start_timepoint = ?, annotator = ?, notes = ? WHERE id = ? - """, (start_timepoint, annotator, notes, existing[0])) + """, + (start_timepoint, annotator, notes, existing[0]), + ) stats["updated"] += 1 else: - self.conn.execute(""" + self.conn.execute( + """ INSERT INTO ground_truth - (session_id, embryo_id, stage, start_timepoint, annotator, notes) + (session_id, embryo_id, stage, start_timepoint, + annotator, notes) VALUES (?, ?, ?, ?, ?, ?) - """, (session_id, embryo_id, stage, start_timepoint, annotator, notes)) + """, + ( + session_id, + embryo_id, + stage, + start_timepoint, + annotator, + notes, + ), + ) stats["added"] += 1 except Exception as e: logger.error(f"Error processing ground truth {gt_file}: {e}") self.conn.commit() - self._complete_aggregation_log(log_id, stats["added"] + stats["updated"], stats["added"], stats["updated"]) + self._complete_aggregation_log( + log_id, + stats["added"] + stats["updated"], + stats["added"], + stats["updated"], + ) except Exception as e: self._fail_aggregation_log(log_id, str(e)) @@ -629,7 +688,7 @@ def aggregate_ground_truth(self) -> Dict[str, int]: return stats - def _get_last_run_time(self) -> Optional[datetime]: + def _get_last_run_time(self) -> datetime | None: """Get the timestamp of the last successful aggregation.""" result = self.conn.execute(""" SELECT MAX(completed_at) FROM aggregation_log @@ -642,47 +701,56 @@ def _get_last_run_time(self) -> Optional[datetime]: def _set_last_run_time(self, timestamp: datetime): """Record the current aggregation time.""" - self.conn.execute(""" + self.conn.execute( + """ INSERT OR REPLACE INTO metadata (key, value, updated_at) VALUES ('last_aggregation', ?, ?) 
- """, (timestamp.isoformat(), timestamp.isoformat())) + """, + (timestamp.isoformat(), timestamp.isoformat()), + ) self.conn.commit() def _start_aggregation_log(self, source_type: str, source_path: str) -> int: """Start a new aggregation log entry.""" - cursor = self.conn.execute(""" + cursor = self.conn.execute( + """ INSERT INTO aggregation_log (source_type, source_path, started_at, status) VALUES (?, ?, ?, 'running') - """, (source_type, source_path, datetime.now().isoformat())) + """, + (source_type, source_path, datetime.now().isoformat()), + ) self.conn.commit() return cursor.lastrowid def _complete_aggregation_log(self, log_id: int, processed: int, added: int, updated: int): """Mark an aggregation log as completed.""" - self.conn.execute(""" + self.conn.execute( + """ UPDATE aggregation_log SET items_processed = ?, items_added = ?, items_updated = ?, completed_at = ?, status = 'completed' WHERE id = ? - """, (processed, added, updated, datetime.now().isoformat(), log_id)) + """, + (processed, added, updated, datetime.now().isoformat(), log_id), + ) self.conn.commit() def _fail_aggregation_log(self, log_id: int, error_message: str): """Mark an aggregation log as failed.""" - self.conn.execute(""" + self.conn.execute( + """ UPDATE aggregation_log SET completed_at = ?, status = 'failed', error_message = ? WHERE id = ? - """, (datetime.now().isoformat(), error_message, log_id)) + """, + (datetime.now().isoformat(), error_message, log_id), + ) self.conn.commit() def get_stage_at_timepoint( - conn: sqlite3.Connection, - session_id: str, - embryo_id: str, - timepoint: int -) -> Optional[str]: + conn: sqlite3.Connection, session_id: str, embryo_id: str, timepoint: int +) -> str | None: """ Get the ground truth stage at a specific timepoint. @@ -705,11 +773,14 @@ def get_stage_at_timepoint( str or None Stage name or None if not found """ - result = conn.execute(""" + result = conn.execute( + """ SELECT stage FROM ground_truth WHERE session_id = ? AND embryo_id = ? 
AND start_timepoint <= ? ORDER BY start_timepoint DESC LIMIT 1 - """, (session_id, embryo_id, timepoint)).fetchone() + """, + (session_id, embryo_id, timepoint), + ).fetchone() return result[0] if result else None diff --git a/gently/dataset/cli.py b/gently/dataset/cli.py index 964354fb..e7385c01 100644 --- a/gently/dataset/cli.py +++ b/gently/dataset/cli.py @@ -13,11 +13,10 @@ from pathlib import Path from .aggregator import DatasetAggregator -from .schema import get_connection, get_database_stats, DEFAULT_DB_PATH +from .schema import DEFAULT_DB_PATH, get_connection, get_database_stats logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) @@ -74,7 +73,7 @@ def cmd_stats(args): print() print(f"Unique Embryo-Sessions: {stats.get('unique_embryo_sessions', 0):,}") print() - if stats.get('earliest_volume'): + if stats.get("earliest_volume"): print(f"Date Range: {stats['earliest_volume'][:10]} to {stats['latest_volume'][:10]}") @@ -89,6 +88,7 @@ def cmd_serve(args): try: from .explorer_server import DatasetExplorer + explorer = DatasetExplorer(db_path=db_path, port=args.port) explorer.run() except ImportError as e: @@ -121,7 +121,7 @@ def cmd_query(args): print("-" * 80) # Print rows - for row in rows[:args.limit]: + for row in rows[: args.limit]: print("\t".join(str(v) if v is not None else "NULL" for v in row)) if len(rows) > args.limit: @@ -152,22 +152,17 @@ def main(): # Run a SQL query python -m gently.dataset.cli query "SELECT * FROM sessions LIMIT 5" - """ + """, ) - parser.add_argument( - "--db", - help=f"Database path (default: {DEFAULT_DB_PATH})" - ) + parser.add_argument("--db", help=f"Database path (default: {DEFAULT_DB_PATH})") subparsers = parser.add_subparsers(dest="command", help="Command to run") # Aggregate command agg_parser = subparsers.add_parser("aggregate", help="Aggregate 
data into database") agg_parser.add_argument( - "--full", - action="store_true", - help="Run full aggregation (not incremental)" + "--full", action="store_true", help="Run full aggregation (not incremental)" ) # Stats command @@ -176,20 +171,14 @@ def main(): # Serve command serve_parser = subparsers.add_parser("serve", help="Start web explorer") serve_parser.add_argument( - "--port", - type=int, - default=8765, - help="Port to serve on (default: 8765)" + "--port", type=int, default=8765, help="Port to serve on (default: 8765)" ) # Query command query_parser = subparsers.add_parser("query", help="Run SQL query") query_parser.add_argument("sql", help="SQL query to run") query_parser.add_argument( - "--limit", - type=int, - default=100, - help="Max rows to display (default: 100)" + "--limit", type=int, default=100, help="Max rows to display (default: 100)" ) args = parser.parse_args() diff --git a/gently/dataset/embryo_dataset.py b/gently/dataset/embryo_dataset.py index df951710..b6cd3a66 100644 --- a/gently/dataset/embryo_dataset.py +++ b/gently/dataset/embryo_dataset.py @@ -25,18 +25,18 @@ ) """ -import base64 import json import logging import sqlite3 +from collections.abc import Iterator from dataclasses import dataclass, field from datetime import datetime from pathlib import Path -from typing import Optional, Iterator, List, Dict, Any, Tuple +from typing import Any, Optional import numpy as np -from .schema import get_connection, DEFAULT_DB_PATH +from .schema import DEFAULT_DB_PATH, get_connection logger = logging.getLogger(__name__) @@ -44,40 +44,41 @@ @dataclass class ImageData: """Data for a single image in the dataset.""" + uid: str embryo_id: str timepoint: int timestamp: str # Image data (loaded on demand) - _image_b64: Optional[str] = field(default=None, repr=False) - _volume_path: Optional[str] = None - _image_path: Optional[str] = None + _image_b64: str | None = field(default=None, repr=False) + _volume_path: str | None = None + _image_path: str | None 
= None # Ground truth (if available) - ground_truth_stage: Optional[str] = None + ground_truth_stage: str | None = None # Metadata - shape: Optional[Tuple[int, int]] = None + shape: tuple[int, int] | None = None projection_type: str = "max_z" - session_id: Optional[str] = None + session_id: str | None = None # Internal reference to dataset for lazy loading _dataset: Optional["EmbryoDataset"] = field(default=None, repr=False) @property - def image_b64(self) -> Optional[str]: + def image_b64(self) -> str | None: """Load and return base64 image data (lazy loading).""" if self._image_b64 is None and self._dataset: self._image_b64 = self._dataset._load_image_b64(self) return self._image_b64 @property - def volume_path(self) -> Optional[str]: + def volume_path(self) -> str | None: """Path to the source volume TIFF.""" return self._volume_path - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Convert to dictionary (without image data).""" return { "uid": self.uid, @@ -95,20 +96,21 @@ def to_dict(self) -> Dict[str, Any]: @dataclass class DatasetEmbryoEntry: """Information about an embryo in the dataset.""" + embryo_id: str - session_id: Optional[str] + session_id: str | None num_images: int num_volumes: int - timepoint_range: Tuple[int, int] # (min, max) + timepoint_range: tuple[int, int] # (min, max) has_ground_truth: bool - ground_truth_stages: List[str] = field(default_factory=list) + ground_truth_stages: list[str] = field(default_factory=list) # Internal reference to dataset _dataset: Optional["EmbryoDataset"] = field(default=None, repr=False) def iter_images( self, - timepoint_range: Optional[Tuple[int, int]] = None, + timepoint_range: tuple[int, int] | None = None, load_image_data: bool = True, ) -> Iterator[ImageData]: """ @@ -158,7 +160,7 @@ class EmbryoDataset: def __init__( self, - db_path: Optional[Path] = None, + db_path: Path | None = None, data_dir: Path = Path("D:/gently/data"), gently_store=None, ): @@ -170,8 +172,8 @@ def 
__init__( self.db_path = db_path or DEFAULT_DB_PATH self.data_dir = data_dir self._gently_store = None - self._conn: Optional[sqlite3.Connection] = None - self._is_gently_schema: Optional[bool] = None + self._conn: sqlite3.Connection | None = None + self._is_gently_schema: bool | None = None @classmethod def from_store(cls, store) -> "EmbryoDataset": @@ -195,7 +197,7 @@ def is_gently_schema(self) -> bool: self._is_gently_schema = "uid" not in columns return self._is_gently_schema - def _resolve_db_path(self, db_path: Optional[str]) -> Optional[str]: + def _resolve_db_path(self, db_path: str | None) -> str | None: """Convert DB-stored file path to absolute path string.""" if db_path is None: return None @@ -215,8 +217,8 @@ def close(self): def iter_embryos( self, - session_id: Optional[str] = None, - has_ground_truth: Optional[bool] = None, + session_id: str | None = None, + has_ground_truth: bool | None = None, min_images: int = 1, ) -> Iterator[DatasetEmbryoEntry]: """ @@ -281,11 +283,14 @@ def iter_embryos( sess_id = row[1] # Check ground truth - gt_rows = self.conn.execute(""" + gt_rows = self.conn.execute( + """ SELECT stage FROM ground_truth WHERE embryo_id = ? AND (session_id = ? OR ? 
IS NULL) ORDER BY start_timepoint - """, (embryo_id, sess_id, sess_id)).fetchall() + """, + (embryo_id, sess_id, sess_id), + ).fetchall() gt_stages = [r[0] for r in gt_rows] has_gt = len(gt_stages) > 0 @@ -310,8 +315,8 @@ def iter_embryos( def iter_images( self, embryo_id: str, - session_id: Optional[str] = None, - timepoint_range: Optional[Tuple[int, int]] = None, + session_id: str | None = None, + timepoint_range: tuple[int, int] | None = None, load_image_data: bool = True, ) -> Iterator[ImageData]: """ @@ -404,7 +409,7 @@ def iter_images( if row[8]: # image_shape try: shape = tuple(json.loads(row[8])) - except: + except Exception: pass img_data = ImageData( @@ -425,8 +430,8 @@ def get_image( self, embryo_id: str, timepoint: int, - session_id: Optional[str] = None, - ) -> Optional[ImageData]: + session_id: str | None = None, + ) -> ImageData | None: """Get a single image by embryo and timepoint.""" for img in self.iter_images( embryo_id=embryo_id, @@ -440,18 +445,20 @@ def get_image_by_index( self, embryo_id: str, index: int, - session_id: Optional[str] = None, - ) -> Optional[ImageData]: + session_id: str | None = None, + ) -> ImageData | None: """Get a single image by sequential index (for volumes without timepoints).""" - for i, img in enumerate(self.iter_images( - embryo_id=embryo_id, - session_id=session_id, - )): + for i, img in enumerate( + self.iter_images( + embryo_id=embryo_id, + session_id=session_id, + ) + ): if i == index: return img return None - def get_image_by_uid(self, uid: str) -> Optional[ImageData]: + def get_image_by_uid(self, uid: str) -> ImageData | None: """Get a single image by its UID. For GentlyStore schema (no UIDs), returns None. 
@@ -461,7 +468,8 @@ def get_image_by_uid(self, uid: str) -> Optional[ImageData]: return None # Legacy schema — query the volume directly - row = self.conn.execute(""" + row = self.conn.execute( + """ SELECT v.uid as volume_uid, v.embryo_id, @@ -477,7 +485,9 @@ def get_image_by_uid(self, uid: str) -> Optional[ImageData]: AND i.timepoint = v.timepoint WHERE v.uid = ? OR i.uid = ? LIMIT 1 - """, (uid, uid)).fetchone() + """, + (uid, uid), + ).fetchone() if not row: return None @@ -496,7 +506,7 @@ def get_image_by_uid(self, uid: str) -> Optional[ImageData]: if row[8]: try: shape = tuple(json.loads(row[8])) - except: + except Exception: pass return ImageData( @@ -518,8 +528,8 @@ def get_image_by_uid(self, uid: str) -> Optional[ImageData]: def _get_ground_truth_map( self, embryo_id: str, - session_id: Optional[str] = None, - ) -> Dict[str, Tuple[int, int]]: + session_id: str | None = None, + ) -> dict[str, tuple[int, int]]: """ Get ground truth stage → (start_tp, end_tp) mapping. @@ -565,9 +575,9 @@ def set_ground_truth( embryo_id: str, stage: str, start_timepoint: int, - end_timepoint: Optional[int] = None, - annotator: Optional[str] = None, - notes: Optional[str] = None, + end_timepoint: int | None = None, + annotator: str | None = None, + notes: str | None = None, ): """ Set or update a ground truth annotation. @@ -589,11 +599,22 @@ def set_ground_truth( notes : str, optional Additional notes """ - self.conn.execute(""" + self.conn.execute( + """ INSERT OR REPLACE INTO ground_truth (session_id, embryo_id, stage, start_timepoint, end_timepoint, annotator, notes) VALUES (?, ?, ?, ?, ?, ?, ?) 
- """, (session_id, embryo_id, stage, start_timepoint, end_timepoint, annotator, notes)) + """, + ( + session_id, + embryo_id, + stage, + start_timepoint, + end_timepoint, + annotator, + notes, + ), + ) self.conn.commit() end_str = f"-{end_timepoint}" if end_timepoint else "" logger.info(f"Set ground truth: {embryo_id} {stage} @ t={start_timepoint}{end_str}") @@ -602,7 +623,7 @@ def delete_ground_truth( self, session_id: str, embryo_id: str, - stage: Optional[str] = None, + stage: str | None = None, ): """ Delete ground truth annotation(s). @@ -617,29 +638,38 @@ def delete_ground_truth( If provided, delete only this stage. Otherwise delete all. """ if stage: - self.conn.execute(""" + self.conn.execute( + """ DELETE FROM ground_truth WHERE session_id = ? AND embryo_id = ? AND stage = ? - """, (session_id, embryo_id, stage)) + """, + (session_id, embryo_id, stage), + ) else: - self.conn.execute(""" + self.conn.execute( + """ DELETE FROM ground_truth WHERE session_id = ? AND embryo_id = ? - """, (session_id, embryo_id)) + """, + (session_id, embryo_id), + ) self.conn.commit() def get_ground_truth( self, session_id: str, embryo_id: str, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Get all ground truth entries for an embryo.""" - rows = self.conn.execute(""" + rows = self.conn.execute( + """ SELECT stage, start_timepoint, end_timepoint, annotator, notes, created_at FROM ground_truth WHERE session_id = ? AND embryo_id = ? 
ORDER BY start_timepoint - """, (session_id, embryo_id)).fetchall() + """, + (session_id, embryo_id), + ).fetchall() return [ { @@ -661,12 +691,12 @@ def create_perception_run( self, name: str, perception_method: str, - model_name: Optional[str] = None, - config: Optional[Dict] = None, - description: Optional[str] = None, - trace_type: str = 'perception', - source: str = 'benchmark', - session_id: Optional[str] = None, + model_name: str | None = None, + config: dict | None = None, + description: str | None = None, + trace_type: str = "perception", + source: str = "benchmark", + session_id: str | None = None, ) -> int: """ Create a new perception run record. @@ -698,26 +728,43 @@ def create_perception_run( config_json = json.dumps(config) if config else None if self.is_gently_schema: - cursor = self.conn.execute(""" + cursor = self.conn.execute( + """ INSERT INTO perception_runs (session_id, name, perception_method, model_name, config, status, trace_type, source, created_at) VALUES (?, ?, ?, ?, ?, 'running', ?, ?, ?) - """, ( - session_id, name, perception_method, model_name, - config_json, trace_type, source, - datetime.now().isoformat(), - )) + """, + ( + session_id, + name, + perception_method, + model_name, + config_json, + trace_type, + source, + datetime.now().isoformat(), + ), + ) else: - cursor = self.conn.execute(""" + cursor = self.conn.execute( + """ INSERT INTO perception_runs (name, perception_method, model_name, config_json, description, status, trace_type, source, session_id) VALUES (?, ?, ?, ?, ?, 'running', ?, ?, ?) 
- """, ( - name, perception_method, model_name, - config_json, description, trace_type, source, session_id, - )) + """, + ( + name, + perception_method, + model_name, + config_json, + description, + trace_type, + source, + session_id, + ), + ) self.conn.commit() return cursor.lastrowid @@ -727,15 +774,15 @@ def store_prediction( embryo_id: str, timepoint: int, predicted_stage: str, - confidence: Optional[float] = None, - reasoning: Optional[str] = None, - image_uid: Optional[str] = None, - session_id: Optional[str] = None, + confidence: float | None = None, + reasoning: str | None = None, + image_uid: str | None = None, + session_id: str | None = None, is_transitional: bool = False, - observed_features: Optional[Dict] = None, - reasoning_trace: Optional[Dict] = None, - execution_time_ms: Optional[float] = None, - trace_file_path: Optional[str] = None, + observed_features: dict | None = None, + reasoning_trace: dict | None = None, + execution_time_ms: float | None = None, + trace_file_path: str | None = None, ) -> int: """ Store a perception prediction. @@ -759,22 +806,32 @@ def store_prediction( if self.is_gently_schema: # GentlyStore schema: single predictions table with JSON blobs - cursor = self.conn.execute(""" + cursor = self.conn.execute( + """ INSERT INTO predictions (run_id, session_id, embryo_id, timepoint, predicted_stage, confidence, reasoning, is_transitional, ground_truth_stage, is_correct, execution_time_ms, trace_file, observed_features, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- """, ( - run_id, session_id, embryo_id, timepoint, - predicted_stage, confidence, reasoning, - 1 if is_transitional else 0, - gt_stage, is_correct, execution_time_ms, - trace_file_path, - json.dumps(observed_features) if observed_features else None, - datetime.now().isoformat(), - )) + """, + ( + run_id, + session_id, + embryo_id, + timepoint, + predicted_stage, + confidence, + reasoning, + 1 if is_transitional else 0, + gt_stage, + is_correct, + execution_time_ms, + trace_file_path, + json.dumps(observed_features) if observed_features else None, + datetime.now().isoformat(), + ), + ) prediction_id = cursor.lastrowid else: # Legacy schema: separate observed_features + reasoning_traces tables @@ -787,56 +844,78 @@ def store_prediction( else: confidence_level = "LOW" - cursor = self.conn.execute(""" + cursor = self.conn.execute( + """ INSERT INTO predictions (perception_run_id, image_uid, session_id, embryo_id, timepoint, predicted_stage, confidence, confidence_level, is_transitional, reasoning, ground_truth_stage, is_correct, execution_time_ms) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - run_id, image_uid, session_id, embryo_id, timepoint, - predicted_stage, confidence, confidence_level, - 1 if is_transitional else 0, - reasoning, gt_stage, is_correct, execution_time_ms, - )) + """, + ( + run_id, + image_uid, + session_id, + embryo_id, + timepoint, + predicted_stage, + confidence, + confidence_level, + 1 if is_transitional else 0, + reasoning, + gt_stage, + is_correct, + execution_time_ms, + ), + ) prediction_id = cursor.lastrowid # Store observed features if provided if observed_features: - self.conn.execute(""" + self.conn.execute( + """ INSERT INTO observed_features (prediction_id, shape, curvature, shell_status, body_segments, emergence, movement, texture, features_json) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) 
- """, ( - prediction_id, - observed_features.get("shape"), - observed_features.get("curvature"), - observed_features.get("shell_status"), - observed_features.get("body_segments"), - observed_features.get("emergence"), - observed_features.get("movement"), - observed_features.get("texture"), - json.dumps(observed_features), - )) + """, + ( + prediction_id, + observed_features.get("shape"), + observed_features.get("curvature"), + observed_features.get("shell_status"), + observed_features.get("body_segments"), + observed_features.get("emergence"), + observed_features.get("movement"), + observed_features.get("texture"), + json.dumps(observed_features), + ), + ) # Store reasoning trace if provided if reasoning_trace or trace_file_path: - self.conn.execute(""" + self.conn.execute( + """ INSERT INTO reasoning_traces (prediction_id, contrastive_reasoning, steps_json, tool_calls_json, tools_used_json, total_tool_calls, file_path) VALUES (?, ?, ?, ?, ?, ?, ?) - """, ( - prediction_id, - reasoning_trace.get("contrastive_reasoning") if reasoning_trace else None, - json.dumps(reasoning_trace.get("steps", [])) if reasoning_trace else None, - json.dumps(reasoning_trace.get("tool_calls", [])) if reasoning_trace else None, - json.dumps(reasoning_trace.get("tools_used", [])) if reasoning_trace else None, - reasoning_trace.get("total_tool_calls", 0) if reasoning_trace else 0, - trace_file_path, - )) + """, + ( + prediction_id, + reasoning_trace.get("contrastive_reasoning") if reasoning_trace else None, + json.dumps(reasoning_trace.get("steps", [])) if reasoning_trace else None, + json.dumps(reasoning_trace.get("tool_calls", [])) + if reasoning_trace + else None, + json.dumps(reasoning_trace.get("tools_used", [])) + if reasoning_trace + else None, + reasoning_trace.get("total_tool_calls", 0) if reasoning_trace else 0, + trace_file_path, + ), + ) self.conn.commit() return prediction_id @@ -845,19 +924,23 @@ def complete_perception_run( self, run_id: int, status: str = "completed", - 
error_message: Optional[str] = None, + error_message: str | None = None, ): """Mark a perception run as completed.""" now = datetime.now().isoformat() if self.is_gently_schema: - self.conn.execute(""" + self.conn.execute( + """ UPDATE perception_runs SET status = ?, completed_at = ?, error_message = ? WHERE run_id = ? - """, (status, now, error_message, run_id)) + """, + (status, now, error_message, run_id), + ) else: - self.conn.execute(""" + self.conn.execute( + """ UPDATE perception_runs SET status = ?, completed_at = ?, @@ -866,14 +949,16 @@ def complete_perception_run( SELECT COUNT(*) FROM predictions WHERE perception_run_id = ? ) WHERE id = ? - """, (status, now, error_message, run_id, run_id)) + """, + (status, now, error_message, run_id, run_id), + ) self.conn.commit() # ========================================================================= # Metrics Methods # ========================================================================= - def compute_run_metrics(self, run_id: int) -> Dict[str, Any]: + def compute_run_metrics(self, run_id: int) -> dict[str, Any]: """ Compute accuracy metrics for a perception run. @@ -884,11 +969,14 @@ def compute_run_metrics(self, run_id: int) -> Dict[str, Any]: """ # Get predictions with ground truth run_col = "run_id" if self.is_gently_schema else "perception_run_id" - rows = self.conn.execute(f""" + rows = self.conn.execute( + f""" SELECT predicted_stage, ground_truth_stage, is_correct, confidence FROM predictions WHERE {run_col} = ? 
AND ground_truth_stage IS NOT NULL - """, (run_id,)).fetchall() + """, + (run_id,), + ).fetchall() if not rows: return {"error": "No predictions with ground truth"} @@ -938,7 +1026,7 @@ def compute_run_metrics(self, run_id: int) -> Dict[str, Any]: # Image Loading # ========================================================================= - def _load_image_b64(self, img: ImageData) -> Optional[str]: + def _load_image_b64(self, img: ImageData) -> str | None: """ Load base64 image data for an ImageData object. @@ -957,9 +1045,8 @@ def _load_image_b64(self, img: ImageData) -> Optional[str]: def _load_projection_from_volume(self, volume_path: str) -> str: """Load volume and create max projection as base64 JPEG.""" + import tifffile - from PIL import Image - import io # Load volume volume = tifffile.imread(volume_path) @@ -974,7 +1061,7 @@ def _load_projection_from_volume(self, volume_path: str) -> str: z_depth, height, width = volume.shape # Extract View A (left half) if dual-view format if width > height * 1.5: - volume = volume[:, :, :width // 2] + volume = volume[:, :, : width // 2] # Max projection projection = np.max(volume, axis=0) else: @@ -982,10 +1069,11 @@ def _load_projection_from_volume(self, volume_path: str) -> str: # Extract View A if 2D dual-view height, width = projection.shape if width > height * 1.5: - projection = projection[:, :width // 2] + projection = projection[:, : width // 2] # Normalize and encode - from gently.core.imaging import normalize_to_uint8, image_to_base64 + from gently.core.imaging import image_to_base64, normalize_to_uint8 + projection = normalize_to_uint8(projection, method="percentile", p_low=1, p_high=99.5) return image_to_base64(projection, format="JPEG", quality=85, max_dimension=1024) @@ -993,7 +1081,7 @@ def _load_projection_from_volume(self, volume_path: str) -> str: # Query Methods # ========================================================================= - def get_sessions(self) -> List[Dict[str, Any]]: + def 
get_sessions(self) -> list[dict[str, Any]]: """Get list of sessions with summary stats.""" rows = self.conn.execute(""" SELECT @@ -1024,7 +1112,7 @@ def get_sessions(self) -> List[Dict[str, Any]]: for r in rows ] - def get_perception_runs(self) -> List[Dict[str, Any]]: + def get_perception_runs(self) -> list[dict[str, Any]]: """Get list of perception runs with metrics.""" if self.is_gently_schema: # GentlyStore schema — inline the accuracy view @@ -1062,9 +1150,9 @@ def get_traces_for_image( self, embryo_id: str, timepoint: int, - session_id: Optional[str] = None, - trace_type: Optional[str] = None, - ) -> List[Dict[str, Any]]: + session_id: str | None = None, + trace_type: str | None = None, + ) -> list[dict[str, Any]]: """ Get all traces for a specific image (embryo + timepoint). @@ -1172,8 +1260,8 @@ def get_traces_for_image( def get_runs_for_session( self, session_id: str, - trace_type: Optional[str] = None, - ) -> List[Dict[str, Any]]: + trace_type: str | None = None, + ) -> list[dict[str, Any]]: """ Get all perception runs for a session. @@ -1258,8 +1346,8 @@ def get_runs_for_session( def get_run_predictions( self, run_id: int, - embryo_id: Optional[str] = None, - ) -> List[Dict[str, Any]]: + embryo_id: str | None = None, + ) -> list[dict[str, Any]]: """ Get all predictions for a run, optionally filtered by embryo. @@ -1338,7 +1426,7 @@ def get_run_predictions( # Cross-Session UID Methods # ========================================================================= - def get_embryo_by_uid(self, uid: str) -> List[Dict[str, Any]]: + def get_embryo_by_uid(self, uid: str) -> list[dict[str, Any]]: """ Get all instances of an embryo across sessions by its UID. 
@@ -1353,7 +1441,8 @@ def get_embryo_by_uid(self, uid: str) -> List[Dict[str, Any]]: All embryo instances matching this UID across different sessions """ if self.is_gently_schema: - rows = self.conn.execute(""" + rows = self.conn.execute( + """ SELECT e.embryo_id, e.session_id, @@ -1376,9 +1465,12 @@ def get_embryo_by_uid(self, uid: str) -> List[Dict[str, Any]]: LEFT JOIN sessions s ON e.session_id = s.session_id WHERE e.embryo_uid = ? ORDER BY e.created_at ASC - """, (uid,)).fetchall() + """, + (uid,), + ).fetchall() else: - rows = self.conn.execute(""" + rows = self.conn.execute( + """ SELECT e.embryo_id, e.session_id, @@ -1392,14 +1484,18 @@ def get_embryo_by_uid(self, uid: str) -> List[Dict[str, Any]]: s.name as session_name, s.created_at as session_created_at, (SELECT COUNT(*) FROM volumes v - WHERE v.embryo_uid = e.embryo_uid AND v.session_id = e.session_id) as volume_count, + WHERE v.embryo_uid = e.embryo_uid + AND v.session_id = e.session_id) as volume_count, (SELECT COUNT(*) FROM images i - WHERE i.embryo_uid = e.embryo_uid AND i.session_id = e.session_id) as image_count + WHERE i.embryo_uid = e.embryo_uid + AND i.session_id = e.session_id) as image_count FROM embryos e LEFT JOIN sessions s ON e.session_id = s.session_id WHERE e.embryo_uid = ? ORDER BY e.created_at ASC - """, (uid,)).fetchall() + """, + (uid,), + ).fetchall() return [ { @@ -1501,7 +1597,7 @@ def iter_images_by_uid( if row[8]: try: shape = tuple(json.loads(row[8])) - except: + except Exception: pass img_data = ImageData( @@ -1518,7 +1614,7 @@ def iter_images_by_uid( yield img_data - def get_embryos_with_multiple_sessions(self) -> List[Dict[str, Any]]: + def get_embryos_with_multiple_sessions(self) -> list[dict[str, Any]]: """ Get embryos that appear in multiple sessions (imported). 
@@ -1584,7 +1680,7 @@ def get_embryos_with_multiple_sessions(self) -> List[Dict[str, Any]]: for r in rows ] - def get_embryo_timeline_by_uid(self, uid: str) -> Dict[str, Any]: + def get_embryo_timeline_by_uid(self, uid: str) -> dict[str, Any]: """ Get complete cross-session timeline for an embryo. @@ -1611,7 +1707,8 @@ def get_embryo_timeline_by_uid(self, uid: str) -> Dict[str, Any]: # Get timepoint range and image count for this session if self.is_gently_schema: - stats = self.conn.execute(""" + stats = self.conn.execute( + """ SELECT MIN(v.acquired_at) as first_timestamp, MAX(v.acquired_at) as last_timestamp, @@ -1622,9 +1719,12 @@ def get_embryo_timeline_by_uid(self, uid: str) -> Dict[str, Any]: JOIN embryos e ON v.embryo_id = e.embryo_id AND v.session_id = e.session_id WHERE e.embryo_uid = ? AND v.session_id = ? - """, (uid, session_id)).fetchone() + """, + (uid, session_id), + ).fetchone() else: - stats = self.conn.execute(""" + stats = self.conn.execute( + """ SELECT MIN(v.timestamp) as first_timestamp, MAX(v.timestamp) as last_timestamp, @@ -1633,27 +1733,34 @@ def get_embryo_timeline_by_uid(self, uid: str) -> Dict[str, Any]: MAX(v.timepoint) as max_timepoint FROM volumes v WHERE v.embryo_uid = ? AND v.session_id = ? - """, (uid, session_id)).fetchone() + """, + (uid, session_id), + ).fetchone() # Get ground truth stages for this session - gt_rows = self.conn.execute(""" + gt_rows = self.conn.execute( + """ SELECT stage, start_timepoint FROM ground_truth WHERE embryo_id = ? AND session_id = ? 
ORDER BY start_timepoint - """, (instance["embryo_id"], session_id)).fetchall() - - timeline.append({ - "session_id": session_id, - "session_name": instance["session_name"], - "embryo_id": instance["embryo_id"], - "first_timestamp": stats[0] if stats else None, - "last_timestamp": stats[1] if stats else None, - "volume_count": stats[2] if stats else 0, - "timepoint_range": (stats[3], stats[4]) if stats else (None, None), - "ground_truth_stages": [ - {"stage": gt[0], "start_timepoint": gt[1]} for gt in gt_rows - ], - }) + """, + (instance["embryo_id"], session_id), + ).fetchall() + + timeline.append( + { + "session_id": session_id, + "session_name": instance["session_name"], + "embryo_id": instance["embryo_id"], + "first_timestamp": stats[0] if stats else None, + "last_timestamp": stats[1] if stats else None, + "volume_count": stats[2] if stats else 0, + "timepoint_range": (stats[3], stats[4]) if stats else (None, None), + "ground_truth_stages": [ + {"stage": gt[0], "start_timepoint": gt[1]} for gt in gt_rows + ], + } + ) return { "embryo_uid": uid, diff --git a/gently/dataset/explorer_server.py b/gently/dataset/explorer_server.py index 254a44cf..3c9e816d 100644 --- a/gently/dataset/explorer_server.py +++ b/gently/dataset/explorer_server.py @@ -12,32 +12,34 @@ """ import asyncio -import logging +import base64 import json +import logging from dataclasses import dataclass - -logger = logging.getLogger(__name__) from datetime import datetime from pathlib import Path -from typing import Optional, List, Tuple, Dict import numpy as np from fastapi import FastAPI, HTTPException, Query, WebSocket, WebSocketDisconnect -from fastapi.responses import HTMLResponse, JSONResponse, Response +from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel -from .schema import get_connection, get_database_stats, DEFAULT_DB_PATH -from .embryo_dataset import EmbryoDataset from gently.core.imaging import ( - normalize_to_uint8, - 
image_to_base64 as _image_to_base64, - load_volume, - compute_crop_bounds, apply_crop_bounds, + compute_crop_bounds, + load_volume, + normalize_to_uint8, projection_three_view, ) +from gently.core.imaging import ( + image_to_base64 as _image_to_base64, +) + +from .embryo_dataset import EmbryoDataset +from .schema import DEFAULT_DB_PATH, get_connection, get_database_stats +logger = logging.getLogger(__name__) # Lazy imports for explorer-specific projection functions tifffile = None PIL_Image = None @@ -48,9 +50,11 @@ def ensure_projection_deps(): global tifffile, PIL_Image if tifffile is None: import tifffile as _tifffile + tifffile = _tifffile if PIL_Image is None: from PIL import Image as _Image + PIL_Image = _Image @@ -68,15 +72,25 @@ def find_outer_boundary(img: np.ndarray, percentile: float = 50) -> np.ndarray: """Find outer boundary of embryo by thresholding and extracting mask edge.""" thresh = np.percentile(img, percentile) mask = img > thresh - padded = np.pad(mask, 1, mode='constant', constant_values=True) - eroded = (padded[:-2, :-2] & padded[:-2, 1:-1] & padded[:-2, 2:] & - padded[1:-1, :-2] & padded[1:-1, 1:-1] & padded[1:-1, 2:] & - padded[2:, :-2] & padded[2:, 1:-1] & padded[2:, 2:]) + padded = np.pad(mask, 1, mode="constant", constant_values=True) + eroded = ( + padded[:-2, :-2] + & padded[:-2, 1:-1] + & padded[:-2, 2:] + & padded[1:-1, :-2] + & padded[1:-1, 1:-1] + & padded[1:-1, 2:] + & padded[2:, :-2] + & padded[2:, 1:-1] + & padded[2:, 2:] + ) boundary = (mask & ~eroded).astype(np.uint8) * 255 return boundary -def overlay_edges(img: np.ndarray, edges: np.ndarray, color: Tuple[int, int, int] = (255, 200, 0)) -> np.ndarray: +def overlay_edges( + img: np.ndarray, edges: np.ndarray, color: tuple[int, int, int] = (255, 200, 0) +) -> np.ndarray: """Overlay edge contours on image in specified color.""" if img.ndim == 2: rgb = np.stack([img, img, img], axis=-1) @@ -89,8 +103,8 @@ def overlay_edges(img: np.ndarray, edges: np.ndarray, color: Tuple[int, 
int, int def projection_dual_view( volume: np.ndarray, - voxel_size: Tuple[float, float, float] = (1.0, 0.1625, 0.1625), -) -> Tuple[np.ndarray, str]: + voxel_size: tuple[float, float, float] = (1.0, 0.1625, 0.1625), +) -> tuple[np.ndarray, str]: """Dual-view projection: TOP above, SIDE below with boundary overlay. voxel_size: (dz, dy, dx) in microns. Used to compute the physically @@ -126,9 +140,9 @@ def projection_dual_view( def projection_depth_colored( volume: np.ndarray, - colormap: str = 'turbo', - voxel_size: Tuple[float, float, float] = (1.0, 0.1625, 0.1625), -) -> Tuple[np.ndarray, str]: + colormap: str = "turbo", + voxel_size: tuple[float, float, float] = (1.0, 0.1625, 0.1625), +) -> tuple[np.ndarray, str]: """Depth-colored max intensity projection. voxel_size: (dz, dy, dx) in microns. The side view is rescaled so the @@ -142,8 +156,9 @@ def projection_depth_colored( dz, dy, dx = voxel_size try: import matplotlib.pyplot as plt + cmap = plt.get_cmap(colormap) - except: + except Exception: cmap = None colored_volume = np.zeros((z_depth, height, width, 3), dtype=np.float32) for z in range(z_depth): @@ -168,7 +183,7 @@ def projection_depth_colored( return combined, f"Z-depth colored MIP ({colormap}): TOP + SIDE" -def projection_multi_slice(volume: np.ndarray, n_slices: int = 6) -> Tuple[np.ndarray, str]: +def projection_multi_slice(volume: np.ndarray, n_slices: int = 6) -> tuple[np.ndarray, str]: """Montage of N representative z-slices.""" if volume.ndim != 3: return normalize_image(volume), "2D input" @@ -184,7 +199,7 @@ def projection_multi_slice(volume: np.ndarray, n_slices: int = 6) -> Tuple[np.nd slices.append(np.zeros_like(slices[0])) rows = [] for r in range(n_rows): - row_slices = slices[r * n_cols:(r + 1) * n_cols] + row_slices = slices[r * n_cols : (r + 1) * n_cols] sep = np.ones((height, 2), dtype=np.uint8) * 64 row_with_sep = [] for i, s in enumerate(row_slices): @@ -198,10 +213,15 @@ def projection_multi_slice(volume: np.ndarray, n_slices: int = 
6) -> Tuple[np.nd return montage, f"Multi-slice montage ({n_slices} slices)" -def render_volume_rotated(volume: np.ndarray, angle_y: float, angle_x: float = -0.5, - threshold: float = 0.12, num_slices: int = 48, - perspective: float = 0.4, - voxel_size: Tuple[float, float, float] = (1.0, 0.1625, 0.1625)) -> np.ndarray: +def render_volume_rotated( + volume: np.ndarray, + angle_y: float, + angle_x: float = -0.5, + threshold: float = 0.12, + num_slices: int = 48, + perspective: float = 0.4, + voxel_size: tuple[float, float, float] = (1.0, 0.1625, 0.1625), +) -> np.ndarray: """ Render volume from a rotated viewpoint with parallax and perspective. @@ -281,23 +301,25 @@ def render_volume_rotated(volume: np.ndarray, angle_y: float, angle_x: float = - perspective_scale = 1.0 + (depth_after_rotation * perspective * 1.5 / z_scale) perspective_scale = np.clip(perspective_scale, 0.5, 1.5) - slice_data_list.append({ - 'z_idx': z_idx, - 'shift_x': shift_x, - 'shift_y': shift_y, - 'depth': depth_after_rotation, - 'scale': perspective_scale, - }) + slice_data_list.append( + { + "z_idx": z_idx, + "shift_x": shift_x, + "shift_y": shift_y, + "depth": depth_after_rotation, + "scale": perspective_scale, + } + ) # Sort by depth (back to front for proper alpha compositing) - slice_data_list.sort(key=lambda s: s['depth']) + slice_data_list.sort(key=lambda s: s["depth"]) # Composite slices for slice_info in slice_data_list: - z_idx = slice_info['z_idx'] - shift_x = slice_info['shift_x'] - shift_y = slice_info['shift_y'] - scale = slice_info['scale'] + z_idx = slice_info["z_idx"] + shift_x = slice_info["shift_x"] + shift_y = slice_info["shift_y"] + scale = slice_info["scale"] # Get slice slice_img = vol[z_idx, :, :] @@ -347,8 +369,8 @@ def render_volume_rotated(volume: np.ndarray, angle_y: float, angle_x: float = - # Alpha composite for c in range(3): result[dst_y_start:dst_y_end, dst_x_start:dst_x_end, c] = ( - src_slice * src_alpha + - result[dst_y_start:dst_y_end, dst_x_start:dst_x_end, 
c] * (1 - src_alpha) + src_slice * src_alpha + + result[dst_y_start:dst_y_end, dst_x_start:dst_x_end, c] * (1 - src_alpha) ) # Crop to content (remove empty margins) @@ -374,8 +396,8 @@ def render_volume_rotated(volume: np.ndarray, angle_y: float, angle_x: float = - def projection_spin_3d( volume: np.ndarray, - voxel_size: Tuple[float, float, float] = (1.0, 0.1625, 0.1625), -) -> Tuple[np.ndarray, str]: + voxel_size: tuple[float, float, float] = (1.0, 0.1625, 0.1625), +) -> tuple[np.ndarray, str]: """Multiple 3D perspective views from different angles (2x3 grid). voxel_size: (dz, dy, dx) in microns. Forwarded to render_volume_rotated @@ -392,8 +414,11 @@ def projection_spin_3d( views = [] for angle_y in angles_y: view = render_volume_rotated( - volume, angle_y=angle_y, angle_x=base_tilt, - threshold=py_threshold, perspective=0.5, + volume, + angle_y=angle_y, + angle_x=base_tilt, + threshold=py_threshold, + perspective=0.5, voxel_size=voxel_size, ) views.append(view) @@ -409,10 +434,10 @@ def projection_spin_3d( pad_w = (target_w - w) // 2 if v.ndim == 3: p = np.zeros((target_h, target_w, 3), dtype=v.dtype) - p[pad_h:pad_h+h, pad_w:pad_w+w] = v + p[pad_h : pad_h + h, pad_w : pad_w + w] = v else: p = np.zeros((target_h, target_w), dtype=v.dtype) - p[pad_h:pad_h+h, pad_w:pad_w+w] = v + p[pad_h : pad_h + h, pad_w : pad_w + w] = v padded.append(p) row1 = np.hstack(padded[0:3]) @@ -423,11 +448,11 @@ def projection_spin_3d( PROJECTION_METHODS = { - 'dual_view': projection_dual_view, - 'depth_colored': projection_depth_colored, - 'multi_slice': projection_multi_slice, - 'three_view': projection_three_view, - 'spin_3d': projection_spin_3d, + "dual_view": projection_dual_view, + "depth_colored": projection_depth_colored, + "multi_slice": projection_multi_slice, + "three_view": projection_three_view, + "spin_3d": projection_spin_3d, } @@ -438,9 +463,11 @@ def projection_spin_3d( # Presence Tracking (Collaborative Feature) # 
============================================================================= + @dataclass class ClientInfo: """Information about a connected WebSocket client for presence tracking""" + client_id: str name: str color: str # Hex color for avatar background @@ -452,13 +479,25 @@ class ConnectionManager: # Colors for avatar backgrounds (pleasant, distinct colors) AVATAR_COLORS = [ - '#4a9eff', '#ff6b6b', '#51cf66', '#ffd43b', '#cc5de8', - '#ff922b', '#20c997', '#748ffc', '#f06595', '#69db7c', - '#ffa94d', '#9775fa', '#38d9a9', '#e599f7', '#74c0fc' + "#4a9eff", + "#ff6b6b", + "#51cf66", + "#ffd43b", + "#cc5de8", + "#ff922b", + "#20c997", + "#748ffc", + "#f06595", + "#69db7c", + "#ffa94d", + "#9775fa", + "#38d9a9", + "#e599f7", + "#74c0fc", ] def __init__(self): - self.active_connections: Dict[WebSocket, ClientInfo] = {} + self.active_connections: dict[WebSocket, ClientInfo] = {} self._lock = asyncio.Lock() def _generate_color(self, client_id: str) -> str: @@ -466,12 +505,15 @@ def _generate_color(self, client_id: str) -> str: hash_val = sum(ord(c) for c in client_id) return self.AVATAR_COLORS[hash_val % len(self.AVATAR_COLORS)] - async def connect(self, websocket: WebSocket, client_id: str = None, name: str = None): + async def connect( + self, websocket: WebSocket, client_id: str | None = None, name: str | None = None + ): await websocket.accept() # Generate defaults if not provided if not client_id: import uuid + client_id = str(uuid.uuid4())[:8] if not name: name = f"Anonymous {client_id[:4]}" @@ -480,12 +522,14 @@ async def connect(self, websocket: WebSocket, client_id: str = None, name: str = client_id=client_id, name=name, color=self._generate_color(client_id), - connected_at=datetime.now().isoformat() + connected_at=datetime.now().isoformat(), ) async with self._lock: self.active_connections[websocket] = client_info - logger.info(f"WebSocket connected: {name} ({client_id}). 
Total: {len(self.active_connections)}") + logger.info( + f"WebSocket connected: {name} ({client_id}). Total: {len(self.active_connections)}" + ) # Broadcast updated presence to all clients await self.broadcast_presence() @@ -494,7 +538,9 @@ async def disconnect(self, websocket: WebSocket): async with self._lock: client_info = self.active_connections.pop(websocket, None) if client_info: - logger.info(f"WebSocket disconnected: {client_info.name}. Total: {len(self.active_connections)}") + logger.info( + f"WebSocket disconnected: {client_info.name}. Total: {len(self.active_connections)}" + ) else: logger.info(f"WebSocket disconnected. Total: {len(self.active_connections)}") @@ -510,11 +556,11 @@ async def update_client_name(self, websocket: WebSocket, name: str): client_id=old_info.client_id, name=name, color=old_info.color, - connected_at=old_info.connected_at + connected_at=old_info.connected_at, ) await self.broadcast_presence() - def get_client_info(self, websocket: WebSocket) -> Optional[ClientInfo]: + def get_client_info(self, websocket: WebSocket) -> ClientInfo | None: """Get client info for a websocket""" return self.active_connections.get(websocket) @@ -526,12 +572,12 @@ async def broadcast_presence(self): # Deduplicate by client_id (same user in multiple tabs = one avatar) async with self._lock: seen_clients = {} - for ws, info in self.active_connections.items(): + for _ws, info in self.active_connections.items(): # Keep the most recent entry for each client_id seen_clients[info.client_id] = { - 'client_id': info.client_id, - 'name': info.name, - 'color': info.color + "client_id": info.client_id, + "name": info.name, + "color": info.color, } clients_list = list(seen_clients.values()) @@ -541,14 +587,8 @@ async def broadcast_presence(self): try: personalized = [] for client in clients_list: - personalized.append({ - **client, - 'is_you': client['client_id'] == info.client_id - }) - await ws.send_json({ - 'type': 'presence', - 'clients': personalized - }) + 
personalized.append({**client, "is_you": client["client_id"] == info.client_id}) + await ws.send_json({"type": "presence", "clients": personalized}) except Exception: disconnected.append(ws) @@ -564,15 +604,15 @@ class GroundTruthCreate(BaseModel): embryo_id: str stage: str start_timepoint: int - end_timepoint: Optional[int] = None - annotator: Optional[str] = None - notes: Optional[str] = None + end_timepoint: int | None = None + annotator: str | None = None + notes: str | None = None class GroundTruthDelete(BaseModel): session_id: str embryo_id: str - stage: Optional[str] = None + stage: str | None = None class DatasetExplorer: @@ -658,8 +698,7 @@ async def get_session(session_id: str): """Get session details.""" conn = get_connection(self.db_path) session = conn.execute( - "SELECT * FROM sessions WHERE session_id = ?", - (session_id,) + "SELECT * FROM sessions WHERE session_id = ?", (session_id,) ).fetchone() conn.close() @@ -674,8 +713,8 @@ async def get_session(session_id: str): @app.get("/api/embryos") async def list_embryos( - session_id: Optional[str] = None, - has_ground_truth: Optional[bool] = None, + session_id: str | None = None, + has_ground_truth: bool | None = None, ): """List embryos with optional filters.""" embryos = [] @@ -683,15 +722,17 @@ async def list_embryos( session_id=session_id, has_ground_truth=has_ground_truth, ): - embryos.append({ - "embryo_id": embryo.embryo_id, - "session_id": embryo.session_id, - "num_images": embryo.num_images, - "num_volumes": embryo.num_volumes, - "timepoint_range": embryo.timepoint_range, - "has_ground_truth": embryo.has_ground_truth, - "ground_truth_stages": embryo.ground_truth_stages, - }) + embryos.append( + { + "embryo_id": embryo.embryo_id, + "session_id": embryo.session_id, + "num_images": embryo.num_images, + "num_volumes": embryo.num_volumes, + "timepoint_range": embryo.timepoint_range, + "has_ground_truth": embryo.has_ground_truth, + "ground_truth_stages": embryo.ground_truth_stages, + } + ) return 
embryos @app.get("/api/embryos/{session_id}/{embryo_id}") @@ -745,8 +786,8 @@ async def get_embryos_with_multiple_sessions(): async def list_images( session_id: str, embryo_id: str, - start_tp: Optional[int] = None, - end_tp: Optional[int] = None, + start_tp: int | None = None, + end_tp: int | None = None, ): """List images for an embryo (without image data).""" timepoint_range = None @@ -828,7 +869,10 @@ async def create_ground_truth(data: GroundTruthCreate): notes=data.notes, ) end_str = f"-{data.end_timepoint}" if data.end_timepoint else "" - return {"status": "ok", "message": f"Set {data.stage} @ t={data.start_timepoint}{end_str}"} + return { + "status": "ok", + "message": f"Set {data.stage} @ t={data.start_timepoint}{end_str}", + } @app.delete("/api/ground_truth") async def delete_ground_truth(data: GroundTruthDelete): @@ -857,7 +901,7 @@ async def get_run_metrics(run_id: int): @app.get("/api/runs/{run_id}/predictions") async def get_run_predictions( run_id: int, - embryo_id: Optional[str] = None, + embryo_id: str | None = None, limit: int = Query(100, le=1000), offset: int = 0, ): @@ -895,31 +939,38 @@ async def get_timeline(session_id: str, embryo_id: str): """ # Get all images images = [] - for idx, img in enumerate(self.dataset.iter_images( - embryo_id=embryo_id, - session_id=session_id, - load_image_data=False, - )): - images.append({ - "index": idx, - "timepoint": img.timepoint, - "timestamp": img.timestamp, - "ground_truth_stage": img.ground_truth_stage, - "uid": img.uid, - "volume_path": img.volume_path, - }) + for idx, img in enumerate( + self.dataset.iter_images( + embryo_id=embryo_id, + session_id=session_id, + load_image_data=False, + ) + ): + images.append( + { + "index": idx, + "timepoint": img.timepoint, + "timestamp": img.timestamp, + "ground_truth_stage": img.ground_truth_stage, + "uid": img.uid, + "volume_path": img.volume_path, + } + ) # Get ground truth transitions ground_truth = self.dataset.get_ground_truth(session_id, embryo_id) # Get 
predictions if any conn = get_connection(self.db_path) - predictions = conn.execute(""" + predictions = conn.execute( + """ SELECT timepoint, predicted_stage, confidence, perception_run_id, reasoning FROM predictions WHERE session_id = ? AND embryo_id = ? ORDER BY perception_run_id DESC, timepoint - """, (session_id, embryo_id)).fetchall() + """, + (session_id, embryo_id), + ).fetchall() conn.close() return { @@ -949,7 +1000,7 @@ async def get_perception_trace(session_id: str, embryo_id: str, timepoint: int): return trace_data except Exception as e: logger.error(f"Failed to read trace file: {e}") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail=str(e)) from e @app.get("/api/traces/{session_id}/{embryo_id}") async def list_perception_traces(session_id: str, embryo_id: str): @@ -976,7 +1027,8 @@ async def get_unified_timeline(embryo_uid: str): # Get all volumes for this embryo UID # Sort by session (using min timestamp per session) then by timepoint within session conn = get_connection(self.db_path) - rows = conn.execute(""" + rows = conn.execute( + """ SELECT v.uid, v.session_id, @@ -984,11 +1036,14 @@ async def get_unified_timeline(embryo_uid: str): v.timepoint, v.timestamp, v.file_path, - (SELECT MIN(v2.uid) FROM volumes v2 WHERE v2.session_id = v.session_id) as session_min_uid + (SELECT MIN(v2.uid) FROM volumes v2 + WHERE v2.session_id = v.session_id) as session_min_uid FROM volumes v WHERE v.embryo_uid = ? ORDER BY session_min_uid ASC, v.uid ASC - """, (embryo_uid,)).fetchall() + """, + (embryo_uid,), + ).fetchall() if not rows: conn.close() @@ -999,16 +1054,20 @@ async def get_unified_timeline(embryo_uid: str): session_embryos = set((r[1], r[2]) for r in rows) gt_maps = {} for session_id, embryo_id in session_embryos: - gt_rows = conn.execute(""" + gt_rows = conn.execute( + """ SELECT stage, start_timepoint, end_timepoint FROM ground_truth WHERE session_id = ? AND embryo_id = ? 
ORDER BY start_timepoint - """, (session_id, embryo_id)).fetchall() + """, + (session_id, embryo_id), + ).fetchall() gt_maps[(session_id, embryo_id)] = gt_rows conn.close() - # Build unified image list with ground truth based on row index within each session/embryo + # Build unified image list with ground truth based on row index + # within each session/embryo images = [] session_embryo_counts = {} # Track row index within each session/embryo @@ -1029,16 +1088,18 @@ async def get_unified_timeline(embryo_uid: str): gt_stage = stage break - images.append({ - "index": idx, - "uid": r[0], - "session_id": session_id, - "embryo_id": embryo_id, - "timepoint": r[3], - "timestamp": r[4], - "file_path": r[5], - "ground_truth_stage": gt_stage, - }) + images.append( + { + "index": idx, + "uid": r[0], + "session_id": session_id, + "embryo_id": embryo_id, + "timepoint": r[3], + "timestamp": r[4], + "file_path": r[5], + "ground_truth_stage": gt_stage, + } + ) # Get unique sessions sessions = list(set(img["session_id"] for img in images)) @@ -1085,6 +1146,7 @@ async def get_volume_data(session_id: str, embryo_id: str, index: int): # Apply Gaussian blur along Z axis to reduce banding at side views from scipy import ndimage + vol_norm = ndimage.gaussian_filter1d(vol_norm, sigma=1.0, axis=0) vol_uint8 = (vol_norm * 255).astype(np.uint8) @@ -1126,11 +1188,13 @@ async def get_projections(session_id: str, embryo_id: str, index: int): for method_name, method_func in PROJECTION_METHODS.items(): try: proj_img, desc = method_func(vol) - projections.append({ - "method": method_name, - "description": desc, - "data": image_to_base64(proj_img), - }) + projections.append( + { + "method": method_name, + "description": desc, + "data": image_to_base64(proj_img), + } + ) except Exception as e: logger.warning(f"Projection {method_name} failed: {e}") @@ -1159,12 +1223,12 @@ async def websocket_endpoint(websocket: WebSocket): try: while True: data = await websocket.receive_json() - msg_type = 
data.get('type') + msg_type = data.get("type") - if msg_type == 'join': + if msg_type == "join": # Client joining with ID and name - client_id = data.get('client_id') - name = data.get('name') + client_id = data.get("client_id") + name = data.get("name") # Update client info async with self.manager._lock: if websocket in self.manager.active_connections: @@ -1172,18 +1236,20 @@ async def websocket_endpoint(websocket: WebSocket): self.manager.active_connections[websocket] = ClientInfo( client_id=client_id or old_info.client_id, name=name or old_info.name, - color=self.manager._generate_color(client_id or old_info.client_id), - connected_at=old_info.connected_at + color=self.manager._generate_color( + client_id or old_info.client_id + ), + connected_at=old_info.connected_at, ) await self.manager.broadcast_presence() - elif msg_type == 'update_name': + elif msg_type == "update_name": # Client updating their display name - name = data.get('name') + name = data.get("name") if name: await self.manager.update_client_name(websocket, name) - elif msg_type == 'get_presence': + elif msg_type == "get_presence": # Client requesting current presence list await self.manager.broadcast_presence() @@ -1222,6 +1288,7 @@ def _get_projections_html(self, session_id: str, embryo_id: str, index: int) -> def run(self): """Start the server.""" import uvicorn + logger.info("=== Embryo Dataset Explorer ===") logger.info("Database: %s", self.db_path) logger.info("Open http://localhost:%d in your browser", self.port) diff --git a/gently/dataset/schema.py b/gently/dataset/schema.py index 927b95f7..c806b6c8 100644 --- a/gently/dataset/schema.py +++ b/gently/dataset/schema.py @@ -13,11 +13,10 @@ - reasoning_traces: Full VLM reasoning traces """ -import sqlite3 import logging -from pathlib import Path -from typing import Optional +import sqlite3 from contextlib import contextmanager +from pathlib import Path logger = logging.getLogger(__name__) @@ -25,7 +24,7 @@ DEFAULT_DB_PATH = 
Path("D:/gently/dataset.db") -def get_connection(db_path: Optional[Path] = None) -> sqlite3.Connection: +def get_connection(db_path: Path | None = None) -> sqlite3.Connection: """ Get a database connection with optimized settings. @@ -59,7 +58,7 @@ def transaction(conn: sqlite3.Connection): raise -def init_database(db_path: Optional[Path] = None) -> sqlite3.Connection: +def init_database(db_path: Path | None = None) -> sqlite3.Connection: """ Initialize the database with all tables. @@ -84,7 +83,7 @@ def init_database(db_path: Optional[Path] = None) -> sqlite3.Connection: # Set version conn.execute( "INSERT OR REPLACE INTO metadata (key, value) VALUES ('version', ?)", - (str(DATABASE_VERSION),) + (str(DATABASE_VERSION),), ) conn.commit() @@ -378,7 +377,7 @@ def migrate_to_v2(conn: sqlite3.Connection) -> bool: cursor = conn.execute("PRAGMA table_info(perception_runs)") columns = {row[1] for row in cursor.fetchall()} - if 'trace_type' in columns: + if "trace_type" in columns: logger.info("Database already at v2, no migration needed") return False @@ -397,7 +396,7 @@ def migrate_to_v2(conn: sqlite3.Connection) -> bool: # Check if file_path column exists in reasoning_traces cursor = conn.execute("PRAGMA table_info(reasoning_traces)") trace_columns = {row[1] for row in cursor.fetchall()} - if 'file_path' not in trace_columns: + if "file_path" not in trace_columns: statements.append("ALTER TABLE reasoning_traces ADD COLUMN file_path TEXT") for stmt in statements: @@ -432,7 +431,7 @@ def migrate_to_v3(conn: sqlite3.Connection) -> bool: cursor = conn.execute("PRAGMA table_info(embryos)") columns = {row[1] for row in cursor.fetchall()} - if 'embryo_uid' in columns: + if "embryo_uid" in columns: logger.info("Database already at v3, no migration needed") return False @@ -513,7 +512,7 @@ def migrate_to_v4(conn: sqlite3.Connection) -> bool: cursor = conn.execute("PRAGMA table_info(ground_truth)") columns = {row[1] for row in cursor.fetchall()} - if 'end_timepoint' in columns: + 
if "end_timepoint" in columns: logger.info("Database already at v4, no migration needed") return False @@ -542,28 +541,33 @@ def get_database_stats(conn: sqlite3.Connection) -> dict: """ stats = {} - tables = ['sessions', 'embryos', 'volumes', 'images', 'ground_truth', - 'perception_runs', 'predictions'] + tables = [ + "sessions", + "embryos", + "volumes", + "images", + "ground_truth", + "perception_runs", + "predictions", + ] for table in tables: count = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0] stats[table] = count # Additional stats - stats['unique_embryo_sessions'] = conn.execute( + stats["unique_embryo_sessions"] = conn.execute( "SELECT COUNT(DISTINCT session_id || embryo_id) FROM volumes" ).fetchone()[0] # Count unique embryos with ground truth (more useful than annotation count) - stats['embryos_with_gt'] = conn.execute( + stats["embryos_with_gt"] = conn.execute( "SELECT COUNT(DISTINCT session_id || '|' || embryo_id) FROM ground_truth" ).fetchone()[0] # Date range - result = conn.execute( - "SELECT MIN(timestamp), MAX(timestamp) FROM volumes" - ).fetchone() - stats['earliest_volume'] = result[0] - stats['latest_volume'] = result[1] + result = conn.execute("SELECT MIN(timestamp), MAX(timestamp) FROM volumes").fetchone() + stats["earliest_volume"] = result[0] + stats["latest_volume"] = result[1] return stats diff --git a/gently/detection.py b/gently/detection.py index 72adf7e5..4c4e567e 100644 --- a/gently/detection.py +++ b/gently/detection.py @@ -10,22 +10,31 @@ """ import logging + import numpy as np -from typing import Optional, Tuple, List -from scipy.ndimage import uniform_filter, gaussian_filter, sobel, label, binary_opening, binary_closing +from scipy.ndimage import ( + binary_closing, + binary_opening, + gaussian_filter, + label, + sobel, + uniform_filter, +) logger = logging.getLogger(__name__) -def detect_embryo_roi(image: np.ndarray, - kernel_size_fraction: float = 0.1, - min_kernel_size: int = 24, - variance_weight: float = 0.4, 
- gradient_weight: float = 0.4, - intensity_weight: float = 0.2, - threshold_std: float = 1.5, - min_area_fraction: float = 0.05, - max_area_fraction: float = 0.5) -> Optional[Tuple[int, int, int, int]]: +def detect_embryo_roi( + image: np.ndarray, + kernel_size_fraction: float = 0.1, + min_kernel_size: int = 24, + variance_weight: float = 0.4, + gradient_weight: float = 0.4, + intensity_weight: float = 0.2, + threshold_std: float = 1.5, + min_area_fraction: float = 0.05, + max_area_fraction: float = 0.5, +) -> tuple[int, int, int, int] | None: """ Detect embryo region of interest for sparse bottom camera images @@ -83,9 +92,11 @@ def detect_embryo_roi(image: np.ndarray, intensity_contrast = np.abs(img_norm - background_level) # Combine detection methods - combined_score = (variance_weight * local_var + - gradient_weight * gradient_mag + - intensity_weight * intensity_contrast) + combined_score = ( + variance_weight * local_var + + gradient_weight * gradient_mag + + intensity_weight * intensity_contrast + ) # Find regions above threshold threshold = np.mean(combined_score) + threshold_std * np.std(combined_score) @@ -104,7 +115,7 @@ def detect_embryo_roi(image: np.ndarray, # Find largest connected component (most likely embryo) region_sizes = [(labeled_regions == i).sum() for i in range(1, num_regions + 1)] largest_region_idx = np.argmax(region_sizes) + 1 - largest_region_mask = (labeled_regions == largest_region_idx) + largest_region_mask = labeled_regions == largest_region_idx # Get bounding box with padding rows, cols = np.where(largest_region_mask) @@ -123,7 +134,7 @@ def detect_embryo_roi(image: np.ndarray, total_area = image.shape[0] * image.shape[1] # Check if ROI is reasonable size - if min_area_fraction <= roi_area/total_area <= max_area_fraction: + if min_area_fraction <= roi_area / total_area <= max_area_fraction: return (x_min, y_min, w, h) else: return None @@ -133,9 +144,9 @@ def detect_embryo_roi(image: np.ndarray, return None -def 
detect_multiple_embryos(image: np.ndarray, - max_embryos: int = 5, - min_separation: int = 50) -> List[Tuple[int, int, int, int]]: +def detect_multiple_embryos( + image: np.ndarray, max_embryos: int = 5, min_separation: int = 50 +) -> list[tuple[int, int, int, int]]: """ Detect multiple embryo regions in bottom camera images @@ -187,7 +198,7 @@ def detect_multiple_embryos(image: np.ndarray, # Get all regions with their sizes and centers regions = [] for i in range(1, num_regions + 1): - region_mask = (labeled_regions == i) + region_mask = labeled_regions == i region_size = region_mask.sum() # Skip tiny regions @@ -211,14 +222,16 @@ def detect_multiple_embryos(image: np.ndarray, w = x_max - x_min h = y_max - y_min - regions.append({ - 'roi': (x_min, y_min, w, h), - 'center': (x_center, y_center), - 'size': region_size - }) + regions.append( + { + "roi": (x_min, y_min, w, h), + "center": (x_center, y_center), + "size": region_size, + } + ) # Sort by size (largest first) - regions.sort(key=lambda x: x['size'], reverse=True) + regions.sort(key=lambda x: x["size"], reverse=True) # Filter out overlapping regions selected_regions = [] @@ -226,14 +239,15 @@ def detect_multiple_embryos(image: np.ndarray, if len(selected_regions) >= max_embryos: break - center = region['center'] + center = region["center"] # Check separation from already selected regions too_close = False for selected in selected_regions: - selected_center = selected['center'] - distance = np.sqrt((center[0] - selected_center[0])**2 + - (center[1] - selected_center[1])**2) + selected_center = selected["center"] + distance = np.sqrt( + (center[0] - selected_center[0]) ** 2 + (center[1] - selected_center[1]) ** 2 + ) if distance < min_separation: too_close = True break @@ -241,14 +255,14 @@ def detect_multiple_embryos(image: np.ndarray, if not too_close: selected_regions.append(region) - return [r['roi'] for r in selected_regions] + return [r["roi"] for r in selected_regions] except Exception as e: 
logger.error("Multiple embryo detection failed: %s", e) return [] -def get_embryo_focus_roi(image: np.ndarray) -> Optional[Tuple[int, int, int, int]]: +def get_embryo_focus_roi(image: np.ndarray) -> tuple[int, int, int, int] | None: """ Simple interface for getting embryo ROI for focus analysis @@ -264,4 +278,4 @@ def get_embryo_focus_roi(image: np.ndarray) -> Optional[Tuple[int, int, int, int Optional[Tuple[int, int, int, int]] ROI as (x, y, width, height) or None to use full image """ - return detect_embryo_roi(image) \ No newline at end of file + return detect_embryo_roi(image) diff --git a/gently/eval/__init__.py b/gently/eval/__init__.py new file mode 100644 index 00000000..0db29d8f --- /dev/null +++ b/gently/eval/__init__.py @@ -0,0 +1,35 @@ +"""Eval / replay / shadow primitives. + +Substrate for testing orchestrator architectures without running real +hardware. The three layers: + + EventCapture — records every EventBus event to a per-session jsonl + file so the agent's input stream is durable. + EventReplay — reads a captured jsonl and republishes events to a + target bus, preserving original timestamps. + ShadowRunner — hosts candidate orchestrators that subscribe to the + live (or replayed) bus, log their decisions, and + never touch hardware. Diff their decision logs to + compare architectures. + +See docs/EVAL.md (TODO) for usage. 
+"""
+
+from .candidates import ReactiveCandidate
+from .decision_log import Decision, DecisionLog, DecisionTrigger, prompt_hash
+from .event_capture import EventCapture
+from .event_replay import EventReplay
+from .shadow import NoOpCandidate, OrchestratorCandidate, ShadowRunner
+
+__all__ = [
+    "EventCapture",
+    "EventReplay",
+    "Decision",
+    "DecisionLog",
+    "DecisionTrigger",
+    "prompt_hash",
+    "OrchestratorCandidate",
+    "ShadowRunner",
+    "NoOpCandidate",
+    "ReactiveCandidate",
+]
diff --git a/gently/eval/candidates.py b/gently/eval/candidates.py
new file mode 100644
index 00000000..d4d9937e
--- /dev/null
+++ b/gently/eval/candidates.py
@@ -0,0 +1,270 @@
+"""Canned shadow orchestrator candidates.
+
+NoOpCandidate lives in shadow.py as the trivial baseline. Anything more
+interesting — even pure-rule architectures with state — lives here. As
+LLM-driven candidates land they should slot into this module too.
+
+Conventions every candidate should keep:
+  - It maintains its own tiny world model. The production agent's
+    `experiment` is intentionally not shared (a candidate that mutates
+    production state would defeat the point of shadow mode).
+  - Decisions go through `log_decision`. Never call hardware tools.
+  - State updates from events are cheap (no LLM, no I/O).
+"""
+
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass, field
+from typing import Any
+
+from gently.core.event_bus import Event
+
+from .decision_log import DecisionLog, DecisionTrigger
+from .shadow import OrchestratorCandidate
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class _ReactiveWorldModel:
+    """The tiniest possible world model — everything ReactiveCandidate
+    needs to make rule-based decisions without re-reading the agent."""
+
+    # {embryo_id: {"coarse": {x, y} | None, "fine": {x, y} | None,
+    #              "has_fine": bool, "confidence": float}}
+    embryos: dict[str, dict[str, Any]] = field(default_factory=dict)
+
+    # Last live stage XY (µm) from a STAGE_MOVED event.
+    last_stage_um: dict[str, float] | None = None
+
+    # Last error message + timestamp, so the candidate can avoid
+    # spam-proposing escalations for the same recurring failure.
+    last_error: dict[str, Any] | None = None
+
+    # Count of events seen, by type name — useful debug field that also
+    # ends up in the decision context_summary.
+    seen: dict[str, int] = field(default_factory=dict)
+
+
+class ReactiveCandidate(OrchestratorCandidate):
+    """Pure-rule reactive shadow orchestrator.
+
+    The thesis being tested: *can a rule-based responder do the
+    routine bookkeeping that today only happens when the operator
+    chats with Claude?*
+
+    Reactions
+    ---------
+    OPERATOR_EDITED_EMBRYO
+        Operator moved an embryo on the Map. The PUT also clears fine
+        position. Propose `recalibrate_embryo(embryo_id)` so the new
+        coarse position gets a SPIM-fine alignment before the next
+        acquisition. If `fine_position_invalidated` was False (no fine
+        existed yet) skip the proposal — there's nothing to refresh.
+
+    OPERATOR_MARKED_EMBRYOS
+        Operator just confirmed a fresh set of embryos via the marking
+        canvas. Propose `calibrate_all_embryos` to bring them all into
+        focus. Cheap pattern: kick off calibration the moment sightings
+        land, instead of waiting for the operator to type it.
+
+    OPERATOR_REMOVED_EMBRYO
+        Operator deleted an embryo. Propose a tidy-up step
+        `forget_embryo(embryo_id)` for any candidate that wants to
+        clean caches / learnings keyed on the gone embryo. No-op for
+        production today (state mutation already happened); the
+        proposal is reserved for downstream cleanup tools.
+
+    ERROR_OCCURRED
+        Propose `escalate_to_operator(error_message)` once per distinct
+        error. Suppresses if the same error fires twice within 30s —
+        avoids drowning the operator in repeat alarms.
+
+    EMBRYOS_UPDATE / STAGE_MOVED
+        Update the world model. No decision logged (silent ingest).
+
+    """
+
+    # If two ERROR_OCCURRED events with the same message arrive within
+    # this window, only the first proposes an escalation.
+    ERROR_SUPPRESS_WINDOW_SEC = 30.0
+
+    def __init__(self, name: str, decisions: DecisionLog):
+        super().__init__(name, decisions)
+        self.world = _ReactiveWorldModel()
+
+    # ---- event handlers ----------------------------------------------------
+
+    def on_event(self, event: Event) -> None:
+        name = event.event_type.name
+        self.world.seen[name] = self.world.seen.get(name, 0) + 1
+
+        # Always ingest state-shaped events first.
+        if name == "EMBRYOS_UPDATE":
+            self._ingest_embryos_update(event)
+            return
+        if name == "STAGE_MOVED":
+            self._ingest_stage_moved(event)
+            return
+
+        # Operator + error events produce decisions.
+        if name == "OPERATOR_EDITED_EMBRYO":
+            self._react_operator_edited(event)
+            return
+        if name == "OPERATOR_MARKED_EMBRYOS":
+            self._react_operator_marked(event)
+            return
+        if name == "OPERATOR_REMOVED_EMBRYO":
+            self._react_operator_removed(event)
+            return
+        if name == "ERROR_OCCURRED":
+            self._react_error(event)
+            return
+
+    # ---- ingests -----------------------------------------------------------
+
+    def _ingest_embryos_update(self, event: Event) -> None:
+        embryos = (event.data or {}).get("embryos") or []
+        new_world: dict[str, dict[str, Any]] = {}
+        for emb in embryos:
+            new_world[emb.get("id", "")] = {
+                "coarse": emb.get("position_coarse"),
+                "fine": emb.get("position_fine"),
+                "has_fine": bool(emb.get("has_fine_position")),
+                "confidence": emb.get("detection_confidence", 0.0),
+            }
+        self.world.embryos = new_world
+
+    def _ingest_stage_moved(self, event: Event) -> None:
+        d = event.data or {}
+        if "x" in d and "y" in d:
+            self.world.last_stage_um = {"x": float(d["x"]), "y": float(d["y"])}
+
+    # ---- reactions ---------------------------------------------------------
+
+    def _react_operator_edited(self, event: Event) -> None:
+        data = event.data or {}
+        eid = data.get("embryo_id") or ""
+        invalidated = bool(data.get("fine_position_invalidated"))
+        tool_calls: list[dict[str, Any]] = []
+        # Only propose a recalibration when there was a fine position
+        # that the edit just invalidated. New coarse without any prior
+        # fine has nothing to refresh yet.
+        if invalidated:
+            tool_calls.append(
+                {
+                    "name": "recalibrate_embryo",
+                    "input": {"embryo_id": eid},
+                    "id": None,
+                }
+            )
+        self.log_decision(
+            trigger=DecisionTrigger.EVENT,
+            trigger_detail="OPERATOR_EDITED_EMBRYO",
+            tool_calls=tool_calls,
+            response_text=(
+                f"Operator moved {eid}; proposing recalibration."
+                if invalidated
+                else f"Operator moved {eid}; no prior fine -- no action."
+            ),
+            recent_event_ids=[event.event_id],
+            context_summary=self._summary(),
+        )
+
+    def _react_operator_marked(self, event: Event) -> None:
+        data = event.data or {}
+        ids = data.get("embryo_ids") or []
+        count = data.get("count", len(ids))
+        tool_calls: list[dict[str, Any]] = []
+        if count:
+            tool_calls.append(
+                {
+                    "name": "calibrate_all_embryos",
+                    "input": {"embryo_ids": list(ids)},
+                    "id": None,
+                }
+            )
+        self.log_decision(
+            trigger=DecisionTrigger.EVENT,
+            trigger_detail="OPERATOR_MARKED_EMBRYOS",
+            tool_calls=tool_calls,
+            response_text=(
+                f"Operator marked {count} embryos; proposing calibration."
+                if count
+                else "Operator marked zero embryos; no action."
+            ),
+            recent_event_ids=[event.event_id],
+            context_summary=self._summary(),
+        )
+
+    def _react_operator_removed(self, event: Event) -> None:
+        data = event.data or {}
+        eid = data.get("embryo_id") or ""
+        self.log_decision(
+            trigger=DecisionTrigger.EVENT,
+            trigger_detail="OPERATOR_REMOVED_EMBRYO",
+            tool_calls=[
+                {
+                    "name": "forget_embryo",
+                    "input": {"embryo_id": eid},
+                    "id": None,
+                }
+            ],
+            response_text=f"Operator removed {eid}; proposing cache tidy-up.",
+            recent_event_ids=[event.event_id],
+            context_summary=self._summary(),
+        )
+
+    def _react_error(self, event: Event) -> None:
+        from datetime import datetime
+
+        data = event.data or {}
+        msg = str(data.get("msg") or data.get("error") or data.get("message") or "unknown")
+        now = datetime.now()
+        prior = self.world.last_error
+        suppress = (
+            prior is not None
+            and prior.get("msg") == msg
+            and (now - prior["ts"]).total_seconds() < self.ERROR_SUPPRESS_WINDOW_SEC
+        )
+        self.world.last_error = {"msg": msg, "ts": now}
+        if suppress:
+            self.log_decision(
+                trigger=DecisionTrigger.EVENT,
+                trigger_detail="ERROR_OCCURRED",
+                tool_calls=[],
+                response_text=(
+                    f"Suppressed repeat error within"
+                    f" {self.ERROR_SUPPRESS_WINDOW_SEC:.0f}s window: {msg[:120]}"
+                ),
+                recent_event_ids=[event.event_id],
+                context_summary=self._summary(),
+            )
+            return
+        self.log_decision(
+            trigger=DecisionTrigger.EVENT,
+            trigger_detail="ERROR_OCCURRED",
+            tool_calls=[
+                {
+                    "name": "escalate_to_operator",
+                    "input": {"error_message": msg, "source": event.source},
+                    "id": None,
+                }
+            ],
+            response_text=f"New error -- proposing escalation: {msg[:120]}",
+            recent_event_ids=[event.event_id],
+            context_summary=self._summary(),
+        )
+
+    # ---- helpers -----------------------------------------------------------
+
+    def _summary(self) -> str:
+        n_emb = len(self.world.embryos)
+        n_fine = sum(1 for v in self.world.embryos.values() if v.get("has_fine"))
+        stage = self.world.last_stage_um
+        stage_str = f"({stage['x']:.1f}, {stage['y']:.1f})" if stage else "unknown"
+        seen = sum(self.world.seen.values())
+        return (
+            f"{n_emb} embryos ({n_fine} fine-calibrated); stage {stage_str}; {seen} events ingested"
+        )
diff --git a/gently/eval/decision_log.py b/gently/eval/decision_log.py
new file mode 100644
index 00000000..f8aa334d
--- /dev/null
+++ b/gently/eval/decision_log.py
@@ -0,0 +1,195 @@
+"""DecisionLog — records each "decision moment" the orchestrator (or a
+shadow candidate) acts on.
+
+A "decision moment" is whenever the agent wakes up and produces an output:
+a Claude tool call, a refusal, a chat reply, or even an explicit no-op
+("I see what happened, nothing to do"). Capturing these gives us the diff
+substrate for shadow-mode A/B: same input event stream, different
+candidates, compare what each decided.
+
+File format: one JSON object per line, written to
+D:/Gently3/sessions/{id}/decisions.jsonl (or wherever the caller points
+it). Lossless enough to reconstruct what the agent saw + chose, terse
+enough to skim across sessions.
+"""
+
+from __future__ import annotations
+
+import hashlib
+import json
+import logging
+import threading
+from dataclasses import dataclass, field
+from datetime import datetime
+from enum import Enum
+from pathlib import Path
+from typing import Any
+
+from .event_capture import _json_default
+
+logger = logging.getLogger(__name__)
+
+
+def prompt_hash(system_prompt: Any, messages: Any) -> str:
+    """Stable short fingerprint of the input the orchestrator saw.
+
+    Two candidates seeing byte-identical (system_prompt, messages) get
+    the same hash; a difference here means they're working from different
+    context, so any decision divergence is expected. Used in shadow A/B
+    to filter out apples-to-oranges comparisons.
+
+    SHA-256 truncated to 16 hex chars — enough to make accidental
+    collisions vanishingly unlikely at the scale of one session's
+    decisions, short enough to skim by eye in a log.
+    """
+    h = hashlib.sha256()
+    if isinstance(system_prompt, str):
+        h.update(system_prompt.encode("utf-8"))
+    else:
+        h.update(json.dumps(system_prompt, sort_keys=True, default=_json_default).encode("utf-8"))
+    h.update(b"\x1f")  # separator so prompt boundary can't be ambiguous
+    h.update(json.dumps(messages, sort_keys=True, default=_json_default).encode("utf-8"))
+    return h.hexdigest()[:16]
+
+
+class DecisionTrigger(str, Enum):
+    """What woke the agent up for this decision moment."""
+
+    USER_MESSAGE = "user_message"
+    EVENT = "event"  # event-driven (perception, error, etc.)
+    TICK = "tick"  # scheduled / periodic checkpoint
+    PHASE = "phase"  # plan phase boundary (between embryos / timepoints)
+    STARTUP = "startup"  # initial session bring-up
+    UNKNOWN = "unknown"
+
+
+@dataclass
+class Decision:
+    """A single decision moment.
+ + The fields try to capture three things: + WHY the agent woke up: trigger, trigger_detail + WHAT it saw: context_summary, recent_event_ids + WHAT it did: tool_calls, response_text + + `prompt_hash` is a stable fingerprint of the actual prompt+context + sent to Claude, so a single field shows whether two candidates that + decided differently were working from byte-identical input. + """ + + timestamp: datetime + agent: str # "production" or candidate name + trigger: DecisionTrigger + trigger_detail: str | None = None # event type name, user message excerpt, tick name + + tool_calls: list[dict[str, Any]] = field(default_factory=list) + response_text: str | None = None + prompt_hash: str | None = None + + context_summary: str | None = None # one-line description of state + recent_event_ids: list[str] = field(default_factory=list) + + duration_ms: float | None = None # how long the decision took + error: str | None = None # if the decision moment errored + + def to_dict(self) -> dict[str, Any]: + return { + "timestamp": self.timestamp.isoformat(), + "agent": self.agent, + "trigger": self.trigger.value, + "trigger_detail": self.trigger_detail, + "tool_calls": self.tool_calls, + "response_text": self.response_text, + "prompt_hash": self.prompt_hash, + "context_summary": self.context_summary, + "recent_event_ids": self.recent_event_ids, + "duration_ms": self.duration_ms, + "error": self.error, + } + + @classmethod + def from_dict(cls, d: dict[str, Any]) -> Decision: + return cls( + timestamp=datetime.fromisoformat(d["timestamp"]), + agent=d.get("agent", "unknown"), + trigger=DecisionTrigger(d.get("trigger", "unknown")), + trigger_detail=d.get("trigger_detail"), + tool_calls=d.get("tool_calls") or [], + response_text=d.get("response_text"), + prompt_hash=d.get("prompt_hash"), + context_summary=d.get("context_summary"), + recent_event_ids=d.get("recent_event_ids") or [], + duration_ms=d.get("duration_ms"), + error=d.get("error"), + ) + + +class DecisionLog: + """Append-only jsonl sink 
for Decisions. Thread-safe.""" + + def __init__(self, path: Path): + self.path = Path(path) + self._fp = None + self._lock = threading.Lock() + self._count = 0 + + def open(self) -> None: + if self._fp is not None: + return + self.path.parent.mkdir(parents=True, exist_ok=True) + self._fp = self.path.open("a", encoding="utf-8") + logger.info("DecisionLog: writing to %s", self.path) + + def close(self) -> None: + with self._lock: + if self._fp is not None: + try: + self._fp.close() + except Exception: + logger.exception("DecisionLog: close failed") + self._fp = None + logger.info("DecisionLog: closed (%d decisions written)", self._count) + + def append(self, decision: Decision) -> None: + try: + line = json.dumps(decision.to_dict(), default=_json_default) + except Exception: + logger.exception("DecisionLog: failed to serialise %s", decision) + return + with self._lock: + if self._fp is None: + self.open() + try: + self._fp.write(line + "\n") + self._fp.flush() + self._count += 1 + except Exception: + logger.exception("DecisionLog: write failed") + + @property + def count(self) -> int: + return self._count + + def __enter__(self): + self.open() + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + return False + + def read(self) -> list[Decision]: + """Read every decision back from disk. 
Quick + dirty diff substrate.""" + if not self.path.exists(): + return [] + out: list[Decision] = [] + with self.path.open("r", encoding="utf-8") as f: + for line_no, raw in enumerate(f, start=1): + raw = raw.strip() + if not raw: + continue + try: + out.append(Decision.from_dict(json.loads(raw))) + except Exception: + logger.exception("DecisionLog: parse failure on line %d", line_no) + return out diff --git a/gently/eval/event_capture.py b/gently/eval/event_capture.py new file mode 100644 index 00000000..5bb42811 --- /dev/null +++ b/gently/eval/event_capture.py @@ -0,0 +1,160 @@ +"""EventCapture — wildcard-subscribe to an EventBus and append every event +to a per-session jsonl file. + +The captured file is the substrate for replay and shadow-mode testing of +candidate orchestrators. High-volume telemetry types (DEVICE_STATE_UPDATE, +BOTTOM_CAMERA_FRAME) are filtered out by default — a 12-hour timelapse +would otherwise produce ~250 MB of polling noise and drown the meaningful +events (perception completions, operator actions, errors, plan boundaries). +Replay can reconstruct world state from the meaningful events plus the +state-snapshot model; it doesn't need the raw telemetry frames. + +File format: one JSON object per line, mirroring Event.to_dict(): + { + "event_type": "EMBRYOS_UPDATE", + "data": {...}, + "source": "agent.experiment", + "timestamp": "2026-05-15T15:32:55.123456", + "event_id": "abc12345", + "correlation_id": null + } +""" + +from __future__ import annotations + +import json +import logging +import threading +from dataclasses import asdict, is_dataclass +from datetime import date, datetime +from enum import Enum +from pathlib import Path + +from gently.core.event_bus import _NO_HISTORY_TYPES, Event, EventBus, EventType + +logger = logging.getLogger(__name__) + + +class EventCapture: + """Append-only jsonl sink for an EventBus. + + Lifecycle: + capture = EventCapture(path) + capture.start(bus) # opens file, subscribes + ... 
+ capture.stop() # unsubscribes, closes file + + Thread-safe — bus dispatch can come from any thread; writes are + serialised through a lock. + """ + + # By default the same set of high-volume telemetry types the EventBus + # itself skips for its history deque. The rationale carries over: at + # 5 Hz over hours these would dominate the log without adding signal + # that replay / diff can use. + DEFAULT_SKIP: frozenset[EventType] = frozenset(_NO_HISTORY_TYPES) + + def __init__(self, path: Path, *, skip: set[EventType] | None = None): + self.path = Path(path) + self._skip = self.DEFAULT_SKIP if skip is None else frozenset(skip) + self._fp = None + self._unsub = None + self._lock = threading.Lock() + self._count = 0 + self._skipped = 0 + + def start(self, bus: EventBus) -> None: + """Open the capture file and subscribe to the bus (idempotent).""" + if self._fp is not None: + return + self.path.parent.mkdir(parents=True, exist_ok=True) + self._fp = self.path.open("a", encoding="utf-8") + # Sync subscription on purpose — capture is fast (single file write) + # and we want capture order to match dispatch order without async + # scheduling ambiguity. + self._unsub = bus.subscribe("*", self._on_event) + logger.info("EventCapture: writing to %s", self.path) + + def stop(self) -> None: + """Unsubscribe and close the file (idempotent).""" + if self._unsub is not None: + try: + self._unsub() + except Exception: + logger.exception("EventCapture: unsubscribe failed") + self._unsub = None + with self._lock: + if self._fp is not None: + try: + self._fp.close() + except Exception: + logger.exception("EventCapture: file close failed") + self._fp = None + logger.info("EventCapture: closed (%d captured, %d skipped)", self._count, self._skipped) + + def __del__(self): + # Best-effort safety net for cases where the owner forgets to call + # stop() — never let a forgotten file handle outlive the process' + # capture object. 
We can't rely on this for correctness (GC timing + # is undefined), but it makes tests and dev sessions tidier. + try: + self.stop() + except Exception: + pass + + @property + def count(self) -> int: + return self._count + + def _on_event(self, event: Event) -> None: + if event.event_type in self._skip: + self._skipped += 1 + return + try: + line = json.dumps(event.to_dict(), default=_json_default) + except Exception: + logger.exception("EventCapture: failed to serialise %s", event) + return + with self._lock: + if self._fp is None: + return + try: + self._fp.write(line + "\n") + self._fp.flush() + self._count += 1 + except Exception: + logger.exception("EventCapture: write failed for %s", event) + + +def _json_default(obj): + """Last-resort serialiser for types json.dumps can't natively handle. + + Designed to be lossy-but-useful: numpy arrays become lists, datetimes + become ISO strings, dataclasses become dicts, anything else falls back + to repr() so the line is at least valid JSON. + """ + if isinstance(obj, (datetime, date)): + return obj.isoformat() + if isinstance(obj, Path): + return str(obj) + if isinstance(obj, Enum): + return obj.name + if is_dataclass(obj): + try: + return asdict(obj) + except Exception: + pass + try: + import numpy as np + + if isinstance(obj, np.generic): + return obj.item() + if isinstance(obj, np.ndarray): + return obj.tolist() + except ImportError: + pass + if isinstance(obj, set): + return sorted(obj, key=str) + if isinstance(obj, bytes): + return obj.decode("utf-8", errors="replace") + return repr(obj) diff --git a/gently/eval/event_replay.py b/gently/eval/event_replay.py new file mode 100644 index 00000000..a209b612 --- /dev/null +++ b/gently/eval/event_replay.py @@ -0,0 +1,127 @@ +"""EventReplay — reads a captured events jsonl and republishes events to a +target EventBus. 
+ +Two modes: + fast events as fast as the bus can dispatch (default) + real-time inserts sleep delays between events to preserve the original + cadence — useful when a candidate's behaviour depends on + time-since-last-event + +Original Event timestamps are preserved by going through +EventBus.publish_event() (which keeps the dataclass instance untouched) +rather than EventBus.publish() (which constructs a fresh Event with +datetime.now()). Candidates can therefore reason about historical timing +as if they were live. +""" + +from __future__ import annotations + +import json +import logging +import time +from collections.abc import Callable, Iterator +from datetime import datetime +from pathlib import Path + +from gently.core.event_bus import Event, EventBus + +logger = logging.getLogger(__name__) + + +class EventReplay: + """Stream-replays an events.jsonl into a target bus.""" + + def __init__(self, path: Path): + self.path = Path(path) + if not self.path.exists(): + raise FileNotFoundError(f"event log not found: {self.path}") + + def events(self) -> Iterator[Event]: + """Yield each Event from the captured log, in order. + + Lines that don't parse are skipped with a warning rather than + aborting the whole replay — a partial log is better than no log. + """ + with self.path.open("r", encoding="utf-8") as f: + for line_no, raw in enumerate(f, start=1): + raw = raw.strip() + if not raw: + continue + try: + record = json.loads(raw) + except json.JSONDecodeError: + logger.warning("EventReplay: malformed line %d in %s", line_no, self.path) + continue + try: + yield Event.from_dict(record) + except KeyError: + # Unknown EventType — could be a newer enum the + # capturing process knew about. Skip rather than abort. 
+ logger.warning("EventReplay: unknown event_type on line %d", line_no) + except Exception: + logger.exception("EventReplay: parse failure on line %d", line_no) + + def replay( + self, + target: EventBus, + *, + real_time: bool = False, + time_scale: float = 1.0, + on_event: Callable[[Event], None] | None = None, + ) -> int: + """Replay the captured events to ``target``. Returns count emitted. + + Parameters + ---------- + target: + EventBus to publish into. The bus's existing subscribers (and + any shadow candidates registered on it) will see the events. + real_time: + If True, sleep between events to reproduce the original + cadence. If False, dispatch as fast as the bus can handle. + time_scale: + Only meaningful in real-time mode. ``time_scale=4`` runs the + replay at 4× speed (sleep delays divided by 4). Must be > 0. + on_event: + Optional callback invoked after each event is published, for + instrumentation / progress reporting. Exceptions are caught + and logged. + """ + if time_scale <= 0: + raise ValueError("time_scale must be > 0") + + emitted = 0 + prev_ts: datetime | None = None + wall_start = time.monotonic() + for ev in self.events(): + if real_time and prev_ts is not None: + delta = (ev.timestamp - prev_ts).total_seconds() / time_scale + if delta > 0: + time.sleep(delta) + target.publish_event(ev) + emitted += 1 + if on_event is not None: + try: + on_event(ev) + except Exception: + logger.exception("EventReplay: on_event callback failed") + prev_ts = ev.timestamp + wall = time.monotonic() - wall_start + logger.info( + "EventReplay: emitted %d events in %.2fs (real_time=%s, time_scale=%g)", + emitted, + wall, + real_time, + time_scale, + ) + return emitted + + def event_types(self) -> dict: + """Return a {EventType.name: count} histogram of the log. + + Cheap pre-flight diagnostic before running an expensive replay. 
+ """ + counts: dict = {} + for ev in self.events(): + counts[ev.event_type.name] = counts.get(ev.event_type.name, 0) + 1 + return counts diff --git a/gently/eval/shadow.py b/gently/eval/shadow.py new file mode 100644 index 00000000..c3e2a92f --- /dev/null +++ b/gently/eval/shadow.py @@ -0,0 +1,219 @@ +"""Shadow orchestrator scaffolding. + +A candidate orchestrator runs alongside production: it sees the same +events but its decisions are LOGGED, not enacted. Diff the decision logs +between production and a candidate (or between two candidates) to compare +architectures on identical input streams. + +Two entry points: + + OrchestratorCandidate + Abstract base class that every candidate subclasses. Receives events + via on_event(), with on_start() / on_stop() lifecycle hooks; is given + a DecisionLog to write into. Never gets to call tools that touch + hardware: by construction, its only output is the log. + + ShadowRunner + Hosts a set of candidates against a single EventBus. Wildcards onto + the bus and forwards each event to every registered candidate. + Lifecycle (start / stop) keeps subscriptions tidy. + +The simplest candidate is NoOpCandidate, included as a worked example +and as proof-of-life for the wiring (events visible? decision log +writeable? shutdown clean?). +""" + +from __future__ import annotations + +import logging +import threading +from abc import ABC, abstractmethod +from collections.abc import Callable +from datetime import datetime + +from gently.core.event_bus import Event, EventBus + +from .decision_log import Decision, DecisionLog, DecisionTrigger + +logger = logging.getLogger(__name__) + + +class OrchestratorCandidate(ABC): + """Base class for a shadow orchestrator candidate. + + A candidate is given: + - its name (e.g. "reactive-v1", "haiku-summariser") + - a DecisionLog to write decisions into + + It receives events synchronously via ``on_event``. 
If it needs to + do heavy work (LLM call, long compute), it should hand off to its + own task / thread and write into the log asynchronously. + + Candidates MUST NOT touch hardware. They have no access to the + device-layer client, no permission to publish events back onto the + bus, no MMCore handle. The only side effect they're allowed is + writing to their decision log. + """ + + def __init__(self, name: str, decisions: DecisionLog): + self.name = name + self.decisions = decisions + + @abstractmethod + def on_event(self, event: Event) -> None: + """Handle one event from the bus. Synchronous, must not block long.""" + + def on_start(self) -> None: # noqa: B027 + """Called once when the shadow runner attaches this candidate.""" + + def on_stop(self) -> None: # noqa: B027 + """Called once when the shadow runner detaches this candidate.""" + + # ---- helpers candidates can use --------------------------------------- + + def log_decision( + self, + *, + trigger: DecisionTrigger, + trigger_detail: str | None = None, + tool_calls: list[dict] | None = None, + response_text: str | None = None, + context_summary: str | None = None, + recent_event_ids: list[str] | None = None, + prompt_hash: str | None = None, + duration_ms: float | None = None, + error: str | None = None, + ) -> None: + self.decisions.append( + Decision( + timestamp=datetime.now(), + agent=self.name, + trigger=trigger, + trigger_detail=trigger_detail, + tool_calls=tool_calls or [], + response_text=response_text, + context_summary=context_summary, + recent_event_ids=recent_event_ids or [], + prompt_hash=prompt_hash, + duration_ms=duration_ms, + error=error, + ) + ) + + +class NoOpCandidate(OrchestratorCandidate): + """Trivial candidate: logs every event it sees as a decision marker. + + Useful as the smoke test for the wiring (events visible? decision + log writeable? shutdown clean?) and as the template every real + candidate evolves from. 
+ """ + + def __init__(self, name: str, decisions: DecisionLog, *, watch: list[str] | None = None): + super().__init__(name, decisions) + # Optional whitelist of event_type names to react to. None = all. + self._watch = set(watch) if watch else None + self._seen = 0 + + def on_event(self, event: Event) -> None: + if self._watch is not None and event.event_type.name not in self._watch: + return + self._seen += 1 + self.log_decision( + trigger=DecisionTrigger.EVENT, + trigger_detail=event.event_type.name, + response_text=f"(noop) seen {event.event_type.name} from {event.source}", + recent_event_ids=[event.event_id], + context_summary=f"noop candidate; events seen so far: {self._seen}", + ) + + +class ShadowRunner: + """Hosts a set of OrchestratorCandidates against an EventBus. + + Wildcards onto the bus, dispatches each event to every registered + candidate. Candidates' exceptions are caught and logged so one + bad candidate doesn't take down the others or affect the live bus. + + The runner itself never enacts decisions — it only forwards events + and lets candidates write their own logs. + """ + + def __init__(self, bus: EventBus): + self.bus = bus + self._candidates: list[OrchestratorCandidate] = [] + self._unsub: Callable[[], None] | None = None + self._lock = threading.RLock() + self._running = False + + def add(self, candidate: OrchestratorCandidate) -> None: + with self._lock: + self._candidates.append(candidate) + if self._running: + try: + candidate.on_start() + except Exception: + logger.exception("ShadowRunner: on_start failed for %s", candidate.name) + + def remove(self, candidate: OrchestratorCandidate) -> None: + with self._lock: + try: + self._candidates.remove(candidate) + except ValueError: + return + try: + candidate.on_stop() + except Exception: + logger.exception("ShadowRunner: on_stop failed for %s", candidate.name) + + def start(self) -> None: + """Subscribe to the bus and notify every candidate. 
Idempotent.""" + with self._lock: + if self._running: + return + self._unsub = self.bus.subscribe("*", self._on_event) + for c in self._candidates: + try: + c.on_start() + except Exception: + logger.exception("ShadowRunner: on_start failed for %s", c.name) + self._running = True + logger.info("ShadowRunner: started with %d candidate(s)", len(self._candidates)) + + def stop(self) -> None: + """Unsubscribe from the bus and notify every candidate. Idempotent.""" + with self._lock: + if not self._running: + return + if self._unsub is not None: + try: + self._unsub() + except Exception: + logger.exception("ShadowRunner: unsubscribe failed") + self._unsub = None + for c in self._candidates: + try: + c.on_stop() + except Exception: + logger.exception("ShadowRunner: on_stop failed for %s", c.name) + self._running = False + logger.info("ShadowRunner: stopped") + + @property + def candidates(self) -> list[OrchestratorCandidate]: + with self._lock: + return list(self._candidates) + + def _on_event(self, event: Event) -> None: + # Snapshot under the lock so a remove() mid-dispatch doesn't break us. 
+ with self._lock: + candidates = list(self._candidates) + for c in candidates: + try: + c.on_event(event) + except Exception: + logger.exception( + "ShadowRunner: candidate %s raised on %s", + c.name, + event, + ) diff --git a/gently/exceptions.py b/gently/exceptions.py index cfa4aeec..11b3fc83 100644 --- a/gently/exceptions.py +++ b/gently/exceptions.py @@ -29,98 +29,137 @@ class GentlyError(Exception): """Base exception for all Gently errors.""" + pass class HardwareError(GentlyError): """Physical device communication failure.""" + pass + class DeviceNotFoundError(HardwareError): """A requested device was not found in the hardware configuration.""" + pass + class DeviceTimeoutError(HardwareError): """Device operation timed out.""" + pass + class StageMovementError(HardwareError): """Stage movement failed or was out of range.""" + pass + class AcquisitionError(HardwareError): """Image or volume acquisition failed.""" + pass class CalibrationError(GentlyError): """Piezo-galvo calibration failure.""" + pass + class FocusFitError(CalibrationError): """Focus curve fitting failed (bad R², insufficient data, etc.).""" + pass + class EdgeDetectionError(CalibrationError): """Embryo edge detection failed (VLM couldn't find boundary).""" + pass + class CalibrationQualityError(CalibrationError): """Calibration quality below acceptable threshold.""" + pass class PerceptionError(GentlyError): """VLM perception/classification failure.""" + pass + class StageClassificationError(PerceptionError): """Could not classify developmental stage.""" + pass + class VerificationError(PerceptionError): """Verification subagent disagreed or failed.""" + pass class StorageError(GentlyError): """Data persistence failure.""" + pass + class SessionNotFoundError(StorageError): """Requested session does not exist.""" + pass + class VolumeNotFoundError(StorageError): """Requested volume does not exist.""" + pass class NetworkError(GentlyError): """Inter-service communication failure.""" + pass + 
class DeviceLayerError(NetworkError): """Device layer HTTP API returned an error.""" + pass + class MeshPeerError(NetworkError): """Mesh peer communication failed.""" + pass + class ServiceUnavailableError(NetworkError): """Required service is not running.""" + pass class AgentError(GentlyError): """Agent/conversation failure.""" + pass + class ToolExecutionError(AgentError): """A tool call failed during execution.""" + pass + class PlanSynthesisError(AgentError): """Plan synthesis (natural language → Bluesky plan) failed.""" + pass diff --git a/gently/gently.py b/gently/gently.py index 788a5c17..a901e69a 100644 --- a/gently/gently.py +++ b/gently/gently.py @@ -32,31 +32,30 @@ result = await gently.analyze(volume, pipeline="embryo_detection") """ -import asyncio import logging from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any +from .analysis import ( + Pipeline, + PipelineBuilder, + create_embryo_detection_pipeline, + create_hatching_detection_pipeline, + create_morphology_analysis_pipeline, +) from .core import ( EventBus, EventType, - ServiceRegistry, ServiceClient, ServiceInfo, + ServiceRegistry, get_event_bus, get_service_registry, ) -from .log_config import configure_logging -from .settings import settings from .core.file_store import FileStore from .harness.tools.registry import ToolRegistry, get_tool_registry -from .analysis import ( - Pipeline, - PipelineBuilder, - create_embryo_detection_pipeline, - create_hatching_detection_pipeline, - create_morphology_analysis_pipeline, -) +from .log_config import configure_logging +from .settings import settings logger = logging.getLogger(__name__) @@ -99,16 +98,16 @@ def __init__( self._store = FileStore(self.storage_path) # Current session ID (set by start_session or resume_session) - self._current_session_id: Optional[str] = None + self._current_session_id: str | None = None # Initialize tool registry self._tools = get_tool_registry() # Pre-built pipelines - 
self._pipelines: Dict[str, Pipeline] = { - 'embryo_detection': create_embryo_detection_pipeline(), - 'hatching_detection': create_hatching_detection_pipeline(), - 'morphology_analysis': create_morphology_analysis_pipeline(), + self._pipelines: dict[str, Pipeline] = { + "embryo_detection": create_embryo_detection_pipeline(), + "hatching_detection": create_hatching_detection_pipeline(), + "morphology_analysis": create_morphology_analysis_pipeline(), } # Agent instance (lazy loaded) @@ -131,21 +130,21 @@ def _register_standard_services(self): service_type="rpc", host="localhost", port=18861, - metadata={'description': 'Main microscope control server'}, + metadata={"description": "Main microscope control server"}, ), ServiceInfo( name="sam_server", service_type="rpc", host="localhost", port=18862, - metadata={'description': 'SAM segmentation server'}, + metadata={"description": "SAM segmentation server"}, ), ServiceInfo( name="queue_server", service_type="http", host="localhost", port=settings.network.device_port, - metadata={'description': 'Bluesky queue server'}, + metadata={"description": "Bluesky queue server"}, ), ] @@ -182,7 +181,7 @@ def tools(self) -> ToolRegistry: return self._tools @property - def pipelines(self) -> Dict[str, Pipeline]: + def pipelines(self) -> dict[str, Pipeline]: """Access pre-built pipelines""" return self._pipelines @@ -192,8 +191,8 @@ def pipelines(self) -> Dict[str, Pipeline]: async def start_session( self, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, ) -> str: """ Start a new session @@ -211,13 +210,14 @@ async def start_session( Session ID """ import uuid + session_id = str(uuid.uuid4())[:8] self._store.create_session(session_id, name=name) self._current_session_id = session_id self._event_bus.publish( EventType.SESSION_STARTED, - {'session_id': session_id, 'name': name}, + {"session_id": session_id, "name": name}, source="gently", ) @@ -243,14 +243,14 
@@ async def resume_session(self, session_id: str) -> bool: self._current_session_id = session_id self._event_bus.publish( EventType.SESSION_RESTORED, - {'session_id': session_id}, + {"session_id": session_id}, source="gently", ) logger.info(f"Resumed session: {session_id}") return True return False - def list_sessions(self) -> List[Dict]: + def list_sessions(self) -> list[dict]: """List available sessions""" return self._store.list_sessions() @@ -285,7 +285,7 @@ async def connect_microscope( info.port = port try: - conn = await self._client.connect("microscope_server") + await self._client.connect("microscope_server") logger.info(f"Connected to microscope server at {host}:{port}") return True except Exception as e: @@ -318,7 +318,7 @@ async def connect_sam_server( info.port = port try: - conn = await self._client.connect("sam_server") + await self._client.connect("sam_server") logger.info(f"Connected to SAM server at {host}:{port}") return True except Exception as e: @@ -333,7 +333,7 @@ async def analyze( self, data: Any, pipeline: str = "embryo_detection", - context: Optional[Dict] = None, + context: dict | None = None, ) -> Any: """ Run analysis pipeline on data @@ -353,8 +353,9 @@ async def analyze( Pipeline result with lineage tracking """ if pipeline not in self._pipelines: - raise ValueError(f"Unknown pipeline: {pipeline}. " - f"Available: {list(self._pipelines.keys())}") + raise ValueError( + f"Unknown pipeline: {pipeline}. 
Available: {list(self._pipelines.keys())}" + ) pipe = self._pipelines[pipeline] @@ -363,9 +364,9 @@ async def analyze( self._event_bus.publish( EventType.ANALYSIS_COMPLETED, { - 'pipeline': pipeline, - 'result_uid': result.uid, - 'success': result.success, + "pipeline": pipeline, + "result_uid": result.uid, + "success": result.success, }, source="gently", ) @@ -414,7 +415,7 @@ def on(self, event_type: EventType, handler): """ return self._event_bus.subscribe(event_type, handler) - def emit(self, event_type: EventType, data: Dict): + def emit(self, event_type: EventType, data: dict): """Emit an event""" self._event_bus.publish(event_type, data, source="gently") @@ -438,10 +439,9 @@ def get_agent(self, **kwargs): """ if self._agent is None: from .app.agent import MicroscopyAgent + self._agent = MicroscopyAgent( - storage_path=self.storage_path, - store=self._store, - **kwargs + storage_path=self.storage_path, store=self._store, **kwargs ) return self._agent @@ -465,6 +465,7 @@ async def start_visualization_server(self, port: int = settings.network.viz_port """ if self._viz_server is None: from .ui.web.server import VisualizationServer + self._viz_server = VisualizationServer( port=port, data_store=None, # Legacy DataStore removed; viz uses event bus @@ -479,7 +480,7 @@ async def push_image( array, uid: str, data_type: str = "image", - metadata: Optional[Dict] = None, + metadata: dict | None = None, ): """ Push an image to the visualization server @@ -520,7 +521,7 @@ async def shutdown(self): if self._current_session_id: self._event_bus.publish( EventType.SESSION_ENDED, - {'session_id': self._current_session_id}, + {"session_id": self._current_session_id}, source="gently", ) diff --git a/gently/hardware/__init__.py b/gently/hardware/__init__.py index 2f07e9ce..4ae505f9 100644 --- a/gently/hardware/__init__.py +++ b/gently/hardware/__init__.py @@ -14,12 +14,23 @@ import importlib import logging +import pkgutil from types import ModuleType -from typing import Optional 
logger = logging.getLogger(__name__) -_active_hardware: Optional[ModuleType] = None +_active_hardware: ModuleType | None = None + + +def available_hardware() -> list[str]: + """Names of the hardware plugins shipped under gently.hardware.""" + import gently.hardware as _pkg + + return sorted( + m.name + for m in pkgutil.iter_modules(_pkg.__path__) + if m.ispkg and not m.name.startswith("_") + ) def load_hardware(name: str) -> ModuleType: @@ -43,7 +54,18 @@ def load_hardware(name: str) -> ModuleType: If the hardware module cannot be found. """ global _active_hardware - module = importlib.import_module(f"gently.hardware.{name}") + try: + module = importlib.import_module(f"gently.hardware.{name}") + except ModuleNotFoundError as e: + # Only a missing hardware *package* is a config error; re-raise if a + # dependency inside the module is what's missing. + if e.name in (f"gently.hardware.{name}", name): + avail = ", ".join(available_hardware()) or "(none found)" + raise ValueError( + f"Unknown hardware '{name}'. Available: {avail}. " + f"Set 'hardware:' in config/config.yml." + ) from e + raise _active_hardware = module logger.info("Loaded hardware module: %s", name) return module @@ -60,7 +82,6 @@ def get_hardware() -> ModuleType: """ if _active_hardware is None: raise RuntimeError( - "No hardware loaded. Call load_hardware() at startup, " - "or set 'hardware' in config.yml." + "No hardware loaded. Call load_hardware() at startup, or set 'hardware' in config.yml." ) return _active_hardware diff --git a/gently/hardware/console_ui.py b/gently/hardware/console_ui.py new file mode 100644 index 00000000..60e8dfff --- /dev/null +++ b/gently/hardware/console_ui.py @@ -0,0 +1,189 @@ +"""Lightweight terminal styling for the device-layer console. + +Plain ``print`` to stdout, no third-party dependency. ``rich`` is deliberately +avoided here — it has caused Unicode/encoding issues on Windows consoles (see +the stdout-suppression note in ``dispim/device_layer.py``). 
+ +The point of this module is to give the operator a readable, always-visible +picture of the device layer at the terminal — distinct from the file log. The +file log keeps the full INFO/DEBUG firehose; the console shows a curated set of +milestones and a status panel. + +Robust by construction, because the device layer runs on Windows consoles: + +* **Colour** (ANSI) is auto-disabled unless stdout is a TTY. On Windows we try + to enable virtual-terminal processing first; if that fails, colour is off so + raw escape codes never leak. ``NO_COLOR`` (https://no-color.org) and a + ``dumb`` ``TERM`` also disable it. +* **Box-drawing** glyphs are used only when stdout's encoding is UTF-based; + otherwise ASCII equivalents are used so a cp1252 console shows clean output. +* ``out()`` is defensive: any residual ``UnicodeEncodeError`` is caught and the + line re-emitted with ``errors="replace"`` rather than crashing startup. +""" + +from __future__ import annotations + +import os +import sys + +# Visible width of the status panel (border rules). Content lines are written +# without a right border so coloured text never needs width arithmetic. +WIDTH = 64 + + +def _enable_windows_vt() -> bool: + """Best-effort: turn on ANSI escape handling for the Windows console. + + Returns True if VT processing is (now) enabled or we're not on Windows. 
+ """ + if sys.platform != "win32": + return True + try: + import ctypes + from ctypes import wintypes + + kernel32 = ctypes.windll.kernel32 + ENABLE_VT = 0x0004 + handle = kernel32.GetStdHandle(-11) # STD_OUTPUT_HANDLE + mode = wintypes.DWORD() + if not kernel32.GetConsoleMode(handle, ctypes.byref(mode)): + return False + return bool(kernel32.SetConsoleMode(handle, mode.value | ENABLE_VT)) + except Exception: + return False + + +def _detect_color() -> bool: + if sys.stdout is None or not hasattr(sys.stdout, "isatty") or not sys.stdout.isatty(): + return False + if os.environ.get("NO_COLOR") is not None or os.environ.get("TERM") == "dumb": + return False + return _enable_windows_vt() + + +def _detect_unicode() -> bool: + enc = (getattr(sys.stdout, "encoding", None) or "").lower() + return "utf" in enc + + +_USE_COLOR = _detect_color() +_USE_UNICODE = _detect_unicode() + +# Glyphs: pretty (UTF) vs ASCII fallback. +if _USE_UNICODE: + _HEAVY, _LIGHT, _DOT, _CHECK, _MID, _BULLET = "═", "─", "●", "✓", "·", "•" +else: + _HEAVY, _LIGHT, _DOT, _CHECK, _MID, _BULLET = "=", "-", "*", "+", "-", "-" + +# Public separator for callers that build their own value strings. +MIDDOT = f" {_MID} " + +_CODES = { + "reset": "\033[0m", + "bold": "\033[1m", + "dim": "\033[2m", + "green": "\033[32m", + "cyan": "\033[36m", + "yellow": "\033[33m", + "red": "\033[31m", + "blue": "\033[34m", + "magenta": "\033[35m", + "grey": "\033[90m", +} + + +def supports_color() -> bool: + return _USE_COLOR + + +def c(text, *styles: str) -> str: + """Wrap *text* in ANSI styles, or return it unchanged when colour is off.""" + if not _USE_COLOR or not styles: + return str(text) + prefix = "".join(_CODES.get(s, "") for s in styles) + return f"{prefix}{text}{_CODES['reset']}" + + +def out(text: str = "") -> None: + """Print one line to stdout, flushing so it shows immediately. + + Never raises on encoding: a console that can't represent a character gets + a replacement rather than a crashed startup. 
+ """ + try: + print(text, flush=True) + except UnicodeEncodeError: + enc = getattr(sys.stdout, "encoding", None) or "ascii" + sys.stdout.write(text.encode(enc, "replace").decode(enc, "replace") + "\n") + sys.stdout.flush() + + +def rule(heavy: bool = True, style: str = "grey") -> None: + out(c((_HEAVY if heavy else _LIGHT) * WIDTH, style)) + + +def header(title: str, badge: str | None = None, badge_style: str = "yellow") -> None: + """Top of a panel: a heavy rule, a title row (optional right-aligned badge), + and a closing heavy rule.""" + rule(heavy=True) + line = " " + c(title, "bold", "cyan") + if badge: + # Right-align using uncoloured widths so padding ignores ANSI codes. + pad = max(1, WIDTH - len(" " + title) - len(badge) - 1) + line += " " * pad + c(badge, "bold", badge_style) + out(line) + rule(heavy=True) + + +def row(label: str, value: str, label_w: int = 12, label_style: str = "grey") -> None: + """A `` label value`` line inside a panel.""" + out(f" {c(label.ljust(label_w), label_style)}{value}") + + +def sub(label: str, value: str, label_w: int = 10) -> None: + """An indented sub-row, e.g. a device-group breakdown.""" + out(f" {c(label.ljust(label_w), 'grey')}{value}") + + +def step(n: int, total: int, label: str) -> None: + """A startup progress line: `` [2/5] Starting Micro-Manager core``""" + out(f" {c(f'[{n}/{total}]', 'cyan')} {label}") + + +def step_done(detail: str = "ok") -> None: + """A check-mark continuation under the most recent step.""" + out(f" {c(_CHECK, 'green')} {c(detail, 'grey')}") + + +def note(text: str, style: str = "grey") -> None: + out(f" {c(text, style)}") + + +def bullet(text: str) -> None: + out(f" {c(_BULLET, 'cyan')} {text}") + + +def error_panel( + title: str, summary: str, details: str | None = None, hints=None, log_file=None +) -> None: + """A red FAILED panel: one-line summary, optional detail, fix hints, log path. 
+ + Used at the top-level startup catch so an operator sees a plain-language + diagnosis instead of a Python traceback (which still goes to the log file). + """ + out() + header(title, badge="FAILED", badge_style="red") + note(summary, "yellow") + if details: + out() + row("Details", details, label_w=10) + if hints: + out() + note("Try this:", "bold") + for h in hints: + bullet(h) + if log_file: + out() + row("Full log", str(log_file), label_w=10) + rule(heavy=True) + out() diff --git a/gently/hardware/dispim/__init__.py b/gently/hardware/dispim/__init__.py index 9d479325..9337a79d 100644 --- a/gently/hardware/dispim/__init__.py +++ b/gently/hardware/dispim/__init__.py @@ -7,6 +7,16 @@ from .description import HARDWARE_DESCRIPTION +__all__ = [ + "HARDWARE_DESCRIPTION", + "HARDWARE_NAME", + "HARDWARE_DISPLAY_NAME", + "CAPABILITIES", + "create_device_layer", + "create_client", + "create_microscope", +] + HARDWARE_NAME = "dispim" HARDWARE_DISPLAY_NAME = "diSPIM" CAPABILITIES = { @@ -35,9 +45,10 @@ def create_device_layer(config: dict): The server instance (call .run(port=N) to start) """ from .device_layer import DeviceLayerServer + return DeviceLayerServer( - config_path=config.get('config_path', 'config/config.yml'), - sam_device=config.get('sam_device', 'cuda'), + config_path=config.get("config_path", "config/config.yml"), + sam_device=config.get("sam_device", "cuda"), ) @@ -55,6 +66,7 @@ def create_client(http_url: str): The microscope instance (call .connect() before use) """ from .client import DiSPIMMicroscope + return DiSPIMMicroscope(http_url=http_url) diff --git a/gently/hardware/dispim/calibration.py b/gently/hardware/dispim/calibration.py index a45d0b21..a13a6d1b 100644 --- a/gently/hardware/dispim/calibration.py +++ b/gently/hardware/dispim/calibration.py @@ -12,7 +12,6 @@ from dataclasses import dataclass from datetime import datetime -from typing import Dict, Optional @dataclass @@ -24,13 +23,14 @@ class CalibrationPrior: time by using cross-embryo 
learning. The prior is updated after each successful calibration using an exponential moving average. """ + # Linear relationship: piezo = slope * galvo + offset slope_um_per_deg: float = 100.0 # Default heuristic offset_um: float = 0.0 # Confidence metrics r_squared_mean: float = 0.0 # Average R-squared from contributing calibrations - num_calibrations: int = 0 # Number of embryos contributing to prior + num_calibrations: int = 0 # Number of embryos contributing to prior # Observed ranges (for adaptive sweep window sizing) slope_min: float = 90.0 @@ -39,15 +39,15 @@ class CalibrationPrior: offset_max: float = 20.0 # Edge detection statistics - typical_extent_deg: float = 0.3 # Average embryo Z extent in degrees - extent_std_deg: float = 0.1 # Variation in extent + typical_extent_deg: float = 0.3 # Average embryo Z extent in degrees + extent_std_deg: float = 0.1 # Variation in extent # Timestamp for staleness checking - last_updated: Optional[datetime] = None + last_updated: datetime | None = None # Fast calibration: lock slope after first embryo bootstrap session_slope_locked: bool = False - bootstrap_embryo_id: Optional[str] = None # Which embryo established the slope + bootstrap_embryo_id: str | None = None # Which embryo established the slope def lock_session_slope(self, slope: float, r_squared: float, embryo_id: str): """ @@ -82,7 +82,7 @@ def update_from_calibration( offset: float, r_squared: float, extent_deg: float, - alpha: float = 0.3 + alpha: float = 0.3, ): """ Update prior with new calibration result using exponential moving average. 
@@ -122,41 +122,41 @@ def update_from_calibration( self.num_calibrations += 1 self.last_updated = datetime.now() - def to_dict(self) -> Dict: + def to_dict(self) -> dict: """Serialize for JSON storage""" return { - 'slope_um_per_deg': self.slope_um_per_deg, - 'offset_um': self.offset_um, - 'r_squared_mean': self.r_squared_mean, - 'num_calibrations': self.num_calibrations, - 'slope_min': self.slope_min, - 'slope_max': self.slope_max, - 'offset_min': self.offset_min, - 'offset_max': self.offset_max, - 'typical_extent_deg': self.typical_extent_deg, - 'extent_std_deg': self.extent_std_deg, - 'last_updated': self.last_updated.isoformat() if self.last_updated else None, - 'session_slope_locked': self.session_slope_locked, - 'bootstrap_embryo_id': self.bootstrap_embryo_id, + "slope_um_per_deg": self.slope_um_per_deg, + "offset_um": self.offset_um, + "r_squared_mean": self.r_squared_mean, + "num_calibrations": self.num_calibrations, + "slope_min": self.slope_min, + "slope_max": self.slope_max, + "offset_min": self.offset_min, + "offset_max": self.offset_max, + "typical_extent_deg": self.typical_extent_deg, + "extent_std_deg": self.extent_std_deg, + "last_updated": self.last_updated.isoformat() if self.last_updated else None, + "session_slope_locked": self.session_slope_locked, + "bootstrap_embryo_id": self.bootstrap_embryo_id, } @classmethod - def from_dict(cls, data: Dict) -> 'CalibrationPrior': + def from_dict(cls, data: dict) -> "CalibrationPrior": """Deserialize from JSON""" prior = cls( - slope_um_per_deg=data.get('slope_um_per_deg', 100.0), - offset_um=data.get('offset_um', 0.0), - r_squared_mean=data.get('r_squared_mean', 0.0), - num_calibrations=data.get('num_calibrations', 0), - slope_min=data.get('slope_min', 90.0), - slope_max=data.get('slope_max', 110.0), - offset_min=data.get('offset_min', -20.0), - offset_max=data.get('offset_max', 20.0), - typical_extent_deg=data.get('typical_extent_deg', 0.3), - extent_std_deg=data.get('extent_std_deg', 0.1), - 
session_slope_locked=data.get('session_slope_locked', False), - bootstrap_embryo_id=data.get('bootstrap_embryo_id'), + slope_um_per_deg=data.get("slope_um_per_deg", 100.0), + offset_um=data.get("offset_um", 0.0), + r_squared_mean=data.get("r_squared_mean", 0.0), + num_calibrations=data.get("num_calibrations", 0), + slope_min=data.get("slope_min", 90.0), + slope_max=data.get("slope_max", 110.0), + offset_min=data.get("offset_min", -20.0), + offset_max=data.get("offset_max", 20.0), + typical_extent_deg=data.get("typical_extent_deg", 0.3), + extent_std_deg=data.get("extent_std_deg", 0.1), + session_slope_locked=data.get("session_slope_locked", False), + bootstrap_embryo_id=data.get("bootstrap_embryo_id"), ) - if data.get('last_updated'): - prior.last_updated = datetime.fromisoformat(data['last_updated']) + if data.get("last_updated"): + prior.last_updated = datetime.fromisoformat(data["last_updated"]) return prior diff --git a/gently/hardware/dispim/claude_client.py b/gently/hardware/dispim/claude_client.py index 722add7f..1998dfbe 100644 --- a/gently/hardware/dispim/claude_client.py +++ b/gently/hardware/dispim/claude_client.py @@ -18,18 +18,19 @@ import base64 import os from pathlib import Path -from typing import Optional, Tuple, Dict + import anthropic from gently.settings import settings + from .plans.calibration import EMBRYO_CENTERING_PROMPT, EMBRYO_EDGE_PROMPT _MEDIA_TYPE_MAP = { - '.png': 'image/png', - '.jpg': 'image/jpeg', - '.jpeg': 'image/jpeg', - '.gif': 'image/gif', - '.webp': 'image/webp', + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".gif": "image/gif", + ".webp": "image/webp", } @@ -37,6 +38,7 @@ # ASYNC CLAUDE CLIENT # ============================================================================ + class AsyncClaudeClient: """ Async Claude API client for embryo calibration workflows. 
@@ -65,14 +67,14 @@ class AsyncClaudeClient: def __init__( self, - api_key: Optional[str] = None, + api_key: str | None = None, model: str = settings.models.perception, max_tokens: int = 100, - timeout: float = 30.0 + timeout: float = 30.0, ): """Initialize async Claude client.""" if api_key is None: - api_key = os.environ.get('ANTHROPIC_API_KEY') + api_key = os.environ.get("ANTHROPIC_API_KEY") if api_key is None: raise ValueError( "ANTHROPIC_API_KEY not found in environment. " @@ -99,12 +101,12 @@ def encode_image(image_path: Path) -> str: str Base64-encoded image data """ - with open(image_path, 'rb') as f: + with open(image_path, "rb") as f: image_data = f.read() - return base64.standard_b64encode(image_data).decode('utf-8') + return base64.standard_b64encode(image_data).decode("utf-8") @staticmethod - def parse_yes_no_response(response_text: str) -> Tuple[bool, str]: + def parse_yes_no_response(response_text: str) -> tuple[bool, str]: """ Parse Claude's yes/no response format. @@ -124,14 +126,14 @@ def parse_yes_no_response(response_text: str) -> Tuple[bool, str]: str Description from remaining lines """ - lines = response_text.strip().split('\n', 1) + lines = response_text.strip().split("\n", 1) if len(lines) == 0: return False, "Empty response" # Parse first line for yes/no first_line = lines[0].strip().lower() - is_yes = 'yes' in first_line + is_yes = "yes" in first_line # Get description from remaining lines description = lines[1].strip() if len(lines) > 1 else "No description provided" @@ -139,10 +141,8 @@ def parse_yes_no_response(response_text: str) -> Tuple[bool, str]: return is_yes, description async def check_embryo_centered( - self, - image_path: Path, - custom_prompt: Optional[str] = None - ) -> Tuple[bool, str]: + self, image_path: Path, custom_prompt: str | None = None + ) -> tuple[bool, str]: """ Check if embryo is centered and visible in image. 
@@ -179,7 +179,7 @@ async def check_embryo_centered( image_data = self.encode_image(image_path) # Get media type from file extension - media_type = _MEDIA_TYPE_MAP.get(image_path.suffix.lower(), 'image/png') + media_type = _MEDIA_TYPE_MAP.get(image_path.suffix.lower(), "image/png") # Prepare prompt prompt = custom_prompt if custom_prompt else EMBRYO_CENTERING_PROMPT @@ -190,25 +190,24 @@ async def check_embryo_centered( self.client.messages.create( model=self.model, max_tokens=self.max_tokens, - messages=[{ - "role": "user", - "content": [ - { - "type": "image", - "source": { - "type": "base64", - "media_type": media_type, - "data": image_data - } - }, - { - "type": "text", - "text": prompt - } - ] - }] + messages=[ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": media_type, + "data": image_data, + }, + }, + {"type": "text", "text": prompt}, + ], + } + ], ), - timeout=self.timeout + timeout=self.timeout, ) # Extract text response @@ -225,10 +224,8 @@ async def check_embryo_centered( return False, f"Claude API error: {str(e)}" async def detect_embryo_presence( - self, - image_path: Path, - custom_prompt: Optional[str] = None - ) -> Tuple[bool, int, str]: + self, image_path: Path, custom_prompt: str | None = None + ) -> tuple[bool, int, str]: """ Detect if embryo is present at current Z position (for edge detection). 
@@ -271,7 +268,7 @@ async def detect_embryo_presence( image_data = self.encode_image(image_path) # Get media type - media_type = _MEDIA_TYPE_MAP.get(image_path.suffix.lower(), 'image/png') + media_type = _MEDIA_TYPE_MAP.get(image_path.suffix.lower(), "image/png") # Prepare prompt prompt = custom_prompt if custom_prompt else EMBRYO_EDGE_PROMPT @@ -282,35 +279,34 @@ async def detect_embryo_presence( self.client.messages.create( model=self.model, max_tokens=self.max_tokens, - messages=[{ - "role": "user", - "content": [ - { - "type": "image", - "source": { - "type": "base64", - "media_type": media_type, - "data": image_data - } - }, - { - "type": "text", - "text": prompt - } - ] - }] + messages=[ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": media_type, + "data": image_data, + }, + }, + {"type": "text", "text": prompt}, + ], + } + ], ), - timeout=self.timeout + timeout=self.timeout, ) # Extract text response response_text = response.content[0].text # Parse response - now 3 lines: yes/no, score, description - lines = response_text.strip().split('\n') + lines = response_text.strip().split("\n") # Line 1: yes/no - is_present = 'yes' in lines[0].lower() if lines else False + is_present = "yes" in lines[0].lower() if lines else False # Line 2: feature score (1-10) feature_score = 0 @@ -320,14 +316,15 @@ async def detect_embryo_presence( score_line = lines[1].strip() # Handle cases like "8" or "Score: 8" or "8/10" import re - match = re.search(r'\d+', score_line) + + match = re.search(r"\d+", score_line) if match: feature_score = min(10, max(0, int(match.group()))) except (ValueError, IndexError): feature_score = 5 if is_present else 0 # Default # Line 3+: description - description = '\n'.join(lines[2:]).strip() if len(lines) > 2 else "No description" + description = "\n".join(lines[2:]).strip() if len(lines) > 2 else "No description" # If not present, ensure score is 0 if not is_present: @@ -344,8 +341,8 @@ 
async def validate_focus_montage( self, montage_path: Path, selected_position_um: float, - prompt_template: Optional[str] = None - ) -> Tuple[str, str]: + prompt_template: str | None = None, + ) -> tuple[str, str]: """ Validate algorithmic focus selection by analyzing montage. @@ -384,9 +381,11 @@ async def validate_focus_montage( # Default validation prompt if prompt_template is None: - prompt_template = """You are an expert microscopist reviewing focus quality in embryo images. + prompt_template = """You are an expert microscopist reviewing focus quality in +embryo images. -This montage shows a focus sweep through an embryo sample. Each panel is labeled with its Z position in micrometers. +This montage shows a focus sweep through an embryo sample. Each panel is labeled with its +Z position in micrometers. Our FFT-based algorithm selected position: {position:.2f} µm as optimal focus. @@ -412,7 +411,7 @@ async def validate_focus_montage( # Encode image image_data = self.encode_image(montage_path) - media_type = 'image/png' + media_type = "image/png" try: # Make async API call @@ -420,41 +419,40 @@ async def validate_focus_montage( self.client.messages.create( model=self.model, max_tokens=150, # Slightly longer for reasoning - messages=[{ - "role": "user", - "content": [ - { - "type": "image", - "source": { - "type": "base64", - "media_type": media_type, - "data": image_data - } - }, - { - "type": "text", - "text": prompt - } - ] - }] + messages=[ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": media_type, + "data": image_data, + }, + }, + {"type": "text", "text": prompt}, + ], + } + ], ), - timeout=self.timeout + timeout=self.timeout, ) # Extract response response_text = response.content[0].text - lines = response_text.strip().split('\n', 1) + lines = response_text.strip().split("\n", 1) decision = lines[0].strip().upper() reasoning = lines[1].strip() if len(lines) > 1 else "No reasoning provided" # 
Normalize decision - if 'CONFIRM' in decision: - decision = 'CONFIRM' - elif 'REJECT' in decision: - decision = 'REJECT' + if "CONFIRM" in decision: + decision = "CONFIRM" + elif "REJECT" in decision: + decision = "REJECT" else: - decision = 'REJECT' # Default to reject if unclear + decision = "REJECT" # Default to reject if unclear reasoning = f"Unclear response: {decision}. {reasoning}" return decision, reasoning @@ -468,8 +466,8 @@ async def select_best_focus( self, montage_path: Path, offsets: list[float], - labels: Optional[list[str]] = None - ) -> Tuple[int, str, str]: + labels: list[str] | None = None, + ) -> tuple[int, str, str]: """ Select the best-focused image from a montage using Vision. @@ -506,16 +504,14 @@ async def select_best_focus( montage_path = Path(montage_path) if not montage_path.exists(): - return 1, 'B', f"Montage file not found, defaulting to center" + return 1, "B", "Montage file not found, defaulting to center" # Default labels if labels is None: - labels = [chr(ord('A') + i) for i in range(len(offsets))] + labels = [chr(ord("A") + i) for i in range(len(offsets))] # Build offset description for prompt - offset_desc = ", ".join( - f"{labels[i]}={offsets[i]:+.1f}µm" for i in range(len(offsets)) - ) + offset_desc = ", ".join(f"{labels[i]}={offsets[i]:+.1f}µm" for i in range(len(offsets))) prompt = f"""You are an expert microscopist comparing focus quality in embryo images. 
@@ -529,7 +525,7 @@ async def select_best_focus( - Best overall image clarity and contrast RESPOND FORMAT: -Line 1: Just the letter ({', '.join(labels)}) +Line 1: Just the letter ({", ".join(labels)}) Line 2: Brief reasoning (1 sentence) Example: @@ -538,37 +534,36 @@ async def select_best_focus( # Encode image image_data = self.encode_image(montage_path) - media_type = 'image/png' + media_type = "image/png" try: response = await asyncio.wait_for( self.client.messages.create( model=self.model, max_tokens=100, - messages=[{ - "role": "user", - "content": [ - { - "type": "image", - "source": { - "type": "base64", - "media_type": media_type, - "data": image_data - } - }, - { - "type": "text", - "text": prompt - } - ] - }] + messages=[ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": media_type, + "data": image_data, + }, + }, + {"type": "text", "text": prompt}, + ], + } + ], ), - timeout=self.timeout + timeout=self.timeout, ) # Parse response response_text = response.content[0].text.strip() - lines = response_text.split('\n', 1) + lines = response_text.split("\n", 1) # Extract selected label (first non-empty character that's a valid label) selected = None @@ -590,16 +585,25 @@ async def select_best_focus( except asyncio.TimeoutError: # Default to center on timeout center_idx = len(offsets) // 2 - return center_idx, labels[center_idx], f"Claude API timeout, defaulting to center" + return ( + center_idx, + labels[center_idx], + "Claude API timeout, defaulting to center", + ) except Exception as e: center_idx = len(offsets) // 2 - return center_idx, labels[center_idx], f"Claude API error: {str(e)}, defaulting to center" + return ( + center_idx, + labels[center_idx], + f"Claude API error: {str(e)}, defaulting to center", + ) # ============================================================================ # SYNCHRONOUS WRAPPER FOR BACKWARDS COMPATIBILITY # 
============================================================================ + class ClaudeClient: """ Synchronous wrapper around AsyncClaudeClient for backwards compatibility. @@ -612,19 +616,17 @@ def __init__(self, **kwargs): """Initialize sync client (wraps async client).""" self.async_client = AsyncClaudeClient(**kwargs) - def check_embryo_centered(self, image_path: Path) -> Tuple[bool, str]: + def check_embryo_centered(self, image_path: Path) -> tuple[bool, str]: """Sync version of check_embryo_centered.""" return asyncio.run(self.async_client.check_embryo_centered(image_path)) - def detect_embryo_presence(self, image_path: Path) -> Tuple[bool, int, str]: + def detect_embryo_presence(self, image_path: Path) -> tuple[bool, int, str]: """Sync version of detect_embryo_presence.""" return asyncio.run(self.async_client.detect_embryo_presence(image_path)) def validate_focus_montage( - self, - montage_path: Path, - selected_position_um: float - ) -> Tuple[str, str]: + self, montage_path: Path, selected_position_um: float + ) -> tuple[str, str]: """Sync version of validate_focus_montage.""" return asyncio.run( self.async_client.validate_focus_montage(montage_path, selected_position_um) diff --git a/gently/hardware/dispim/client.py b/gently/hardware/dispim/client.py index 6f097670..2e4567e1 100644 --- a/gently/hardware/dispim/client.py +++ b/gently/hardware/dispim/client.py @@ -8,22 +8,20 @@ import asyncio import logging import traceback -from typing import Any, Dict, List, Optional, Tuple -import numpy as np -import aiohttp +from typing import Any -logger = logging.getLogger(__name__) +import aiohttp +import numpy as np from gently.core.coordinates import ( - pixel_to_stage_position, - stage_to_pixel_position, - get_um_per_pixel, - DEFAULT_PIXEL_SIZE_UM, DEFAULT_OBJECTIVE_MAG, + DEFAULT_PIXEL_SIZE_UM, ) +from gently.exceptions import DeviceLayerError, NetworkError from gently.harness.microscope import Microscope from gently.settings import settings -from 
gently.exceptions import DeviceLayerError, NetworkError, AcquisitionError + +logger = logging.getLogger(__name__) class DiSPIMMicroscope(Microscope): @@ -99,7 +97,7 @@ async def connect(self) -> bool: async with self._session.get(f"{self.http_url}/api/sam/status") as resp: if resp.status == 200: sam_status = await resp.json() - self._sam_available = sam_status.get('available', False) + self._sam_available = sam_status.get("available", False) except aiohttp.ClientError: self._sam_available = False @@ -166,9 +164,7 @@ def has_sam(self) -> bool: def _ensure_connected(self): """Raise error if not connected""" if not self.is_connected: - raise ConnectionError( - "Not connected to Microscope Server. Call connect() first." - ) + raise ConnectionError("Not connected to Microscope Server. Call connect() first.") # ========================================================================= # Session Configuration (FileStore integration) @@ -217,9 +213,10 @@ def _resolve_file_ref(ref: dict) -> tuple: ------- tuple of (np.ndarray, Path) """ - import tifffile from pathlib import Path + import tifffile + path = Path(ref["path"]) arr = tifffile.imread(str(path)) return arr, path @@ -259,7 +256,7 @@ async def _api_get(self, path: str) -> dict: async with self._session.get(f"{self.http_url}{path}") as resp: return await resp.json() - async def _api_post(self, path: str, json: dict = None) -> dict: + async def _api_post(self, path: str, json: dict | None = None) -> dict: """POST request using the shared session.""" self._ensure_connected() async with self._session.post(f"{self.http_url}{path}", json=json) as resp: @@ -268,9 +265,9 @@ async def _api_post(self, path: str, json: dict = None) -> dict: async def _submit_plan_and_wait( self, plan_name: str, - kwargs: Dict = None, + kwargs: dict | None = None, timeout: float = 120.0, - ) -> Dict: + ) -> dict: """Submit a Bluesky plan to the server and wait for completion. 
Parameters @@ -300,37 +297,37 @@ async def _submit_plan_and_wait( async with self._session.post( f"{self.http_url}/api/queue/item/add", json=payload, - timeout=aiohttp.ClientTimeout(total=timeout) + timeout=aiohttp.ClientTimeout(total=timeout), ) as resp: if resp.status != 200: error_text = await resp.text() return { - 'success': False, - 'error': f"HTTP {resp.status}: {error_text}", + "success": False, + "error": f"HTTP {resp.status}: {error_text}", } result = await resp.json() # Resolve file references (zero-copy transfer) if isinstance(result, dict): - docs = result.get('documents', {}) - events = docs.get('events', []) + docs = result.get("documents", {}) + events = docs.get("events", []) for event in events: - data = event.get('data', {}) + data = event.get("data", {}) for key, val in list(data.items()): if self._is_file_ref(val): arr, path = self._resolve_file_ref(val) data[key] = arr # Store path for downstream use - if 'volume_path' not in result: - result['volume_path'] = str(path) + if "volume_path" not in result: + result["volume_path"] = str(path) return result except asyncio.TimeoutError: return { - 'success': False, - 'error': f"Plan '{plan_name}' timed out after {timeout}s", + "success": False, + "error": f"Plan '{plan_name}' timed out after {timeout}s", } except aiohttp.ClientError as e: raise NetworkError(f"Device layer request failed: {e}") from e @@ -339,7 +336,7 @@ async def _submit_plan_and_wait( # Stage Control # ========================================================================= - async def move_to_position(self, x: float, y: float) -> Dict: + async def move_to_position(self, x: float, y: float) -> dict: """ Move stage to absolute position. 
@@ -351,16 +348,15 @@ async def move_to_position(self, x: float, y: float) -> Dict: logger.info("Moving to (%.1f, %.1f) µm", x, y) result = await self._submit_plan_and_wait( - 'move_stage_plan', - kwargs={'xy_stage': 'xy_stage', 'x': x, 'y': y} + "move_stage_plan", kwargs={"xy_stage": "xy_stage", "x": x, "y": y} ) - if result.get('success'): - return {'success': True, 'x': x, 'y': y} + if result.get("success"): + return {"success": True, "x": x, "y": y} return result - async def get_stage_position(self) -> Tuple[float, float]: + async def get_stage_position(self) -> tuple[float, float]: """ Get current stage position. @@ -369,21 +365,23 @@ async def get_stage_position(self) -> Tuple[float, float]: tuple of (float, float) Current (x, y) position in micrometers """ - result = await self._submit_plan_and_wait('read_stage_plan', kwargs={'xy_stage': 'xy_stage'}) + result = await self._submit_plan_and_wait( + "read_stage_plan", kwargs={"xy_stage": "xy_stage"} + ) - if result.get('success'): - docs = result.get('documents', {}) - events = docs.get('events', []) + if result.get("success"): + docs = result.get("documents", {}) + events = docs.get("events", []) if events: - data = events[0].get('data', {}) + data = events[0].get("data", {}) # Look for stage coordinates - for key in ['XY:31', 'xy_stage', 'stage']: + for key in ["XY:31", "xy_stage", "stage"]: if key in data: val = data[key] if isinstance(val, (list, tuple)) and len(val) >= 2: return (float(val[0]), float(val[1])) if isinstance(val, dict): - return (float(val.get('x', 0)), float(val.get('y', 0))) + return (float(val.get("x", 0)), float(val.get("y", 0))) raise DeviceLayerError("Failed to read stage position") @@ -396,20 +394,20 @@ async def get_piezo_position(self) -> float: float Current Z position in micrometers """ - result = await self._submit_plan_and_wait('read_piezo_plan', kwargs={'piezo': 'piezo'}) + result = await self._submit_plan_and_wait("read_piezo_plan", kwargs={"piezo": "piezo"}) - if 
result.get('success'): - docs = result.get('documents', {}) - events = docs.get('events', []) + if result.get("success"): + docs = result.get("documents", {}) + events = docs.get("events", []) if events: - data = events[0].get('data', {}) - for key in ['PiezoStage:P:34', 'piezo', 'z_stage']: + data = events[0].get("data", {}) + for key in ["PiezoStage:P:34", "piezo", "z_stage"]: if key in data: val = data[key] if isinstance(val, (int, float)): return float(val) if isinstance(val, dict): - return float(val.get('z', val.get('position', 0))) + return float(val.get("z", val.get("position", 0))) raise DeviceLayerError("Failed to read piezo position") @@ -419,10 +417,10 @@ async def get_piezo_position(self) -> float: async def calibrate_piezo_galvo( self, - piezo_positions: Optional[List[float]] = None, - galvo_positions: Optional[List[float]] = None, + piezo_positions: list[float] | None = None, + galvo_positions: list[float] | None = None, **kwargs, - ) -> Dict: + ) -> dict: """ Run piezo-galvo calibration plan. 
@@ -431,23 +429,21 @@ async def calibrate_piezo_galvo( dict Calibration results with optimal positions """ - plan_kwargs = {'lightsheet_snap': 'lightsheet_snap'} + plan_kwargs = {"lightsheet_snap": "lightsheet_snap"} if piezo_positions is not None: - plan_kwargs['piezo_positions'] = piezo_positions + plan_kwargs["piezo_positions"] = piezo_positions if galvo_positions is not None: - plan_kwargs['galvo_positions'] = galvo_positions + plan_kwargs["galvo_positions"] = galvo_positions plan_kwargs.update(kwargs) result = await self._submit_plan_and_wait( - 'calibrate_piezo_galvo_plan', - kwargs=plan_kwargs, - timeout=300.0 + "calibrate_piezo_galvo_plan", kwargs=plan_kwargs, timeout=300.0 ) - if result.get('success'): + if result.get("success"): return { - 'success': True, - 'calibration': result.get('calibration', {}), + "success": True, + "calibration": result.get("calibration", {}), } return result @@ -456,22 +452,24 @@ async def calibrate_piezo_galvo( # Imaging # ========================================================================= - def _extract_image(self, result: dict, candidate_keys: List[str], multi_event: bool = False) -> Optional[tuple]: + def _extract_image( + self, result: dict, candidate_keys: list[str], multi_event: bool = False + ) -> tuple | None: """Extract image array from plan result documents. Returns (array, path) or None if not found. 
""" - if not result.get('success'): + if not result.get("success"): return None - docs = result.get('documents', {}) - events = docs.get('events', []) + docs = result.get("documents", {}) + events = docs.get("events", []) if not events: return None search_events = events if multi_event else [events[0]] for event in search_events: - data = event.get('data', {}) + data = event.get("data", {}) for key in candidate_keys: if key in data: val = data[key] @@ -484,11 +482,11 @@ def _extract_image(self, result: dict, candidate_keys: List[str], multi_event: b async def capture_lightsheet_image( self, - piezo_position: Optional[float] = None, - galvo_position: Optional[float] = None, + piezo_position: float | None = None, + galvo_position: float | None = None, exposure_ms: float = 10.0, **kwargs, - ) -> Dict: + ) -> dict: """ Capture a single lightsheet image at specified position. @@ -504,35 +502,36 @@ async def capture_lightsheet_image( Returns ------- dict - ``{'image': np.ndarray, 'piezo_position': float, 'galvo_position': float, 'success': bool}`` + ``{'image': np.ndarray, 'piezo_position': float, + 'galvo_position': float, 'success': bool}`` """ result = await self._submit_plan_and_wait( - 'capture_lightsheet_image_plan', + "capture_lightsheet_image_plan", kwargs={ - 'lightsheet_snap': 'lightsheet_snap', - 'scanner': 'scanner', - 'piezo': 'piezo', - 'laser_control': 'laser_control', - 'piezo_position': piezo_position if piezo_position is not None else 50.0, - 'galvo_position': galvo_position if galvo_position is not None else 0.0, + "lightsheet_snap": "lightsheet_snap", + "scanner": "scanner", + "piezo": "piezo", + "laser_control": "laser_control", + "piezo_position": piezo_position if piezo_position is not None else 50.0, + "galvo_position": galvo_position if galvo_position is not None else 0.0, }, - timeout=30.0 + timeout=30.0, ) - extracted = self._extract_image(result, ['HamCam1', 'lightsheet_snap', 'camera']) + extracted = self._extract_image(result, ["HamCam1", 
"lightsheet_snap", "camera"]) if extracted: arr, fpath = extracted ret = { - 'image': arr, - 'piezo_position': piezo_position or 0.0, - 'galvo_position': galvo_position or 0.0, - 'success': True, + "image": arr, + "piezo_position": piezo_position or 0.0, + "galvo_position": galvo_position or 0.0, + "success": True, } if fpath: - ret['image_path'] = fpath + ret["image_path"] = fpath return ret - return {'error': result.get('error', 'No image data'), 'success': False} + return {"error": result.get("error", "No image data"), "success": False} async def acquire_volume( self, @@ -542,13 +541,13 @@ async def acquire_volume( galvo_center: float = 0.0, piezo_amplitude: float = 25.0, piezo_center: float = 50.0, - laser_config: str = None, - laser_power_488_pct: float = None, - laser_power_561_pct: float = None, - laser_power_405_pct: float = None, - laser_power_637_pct: float = None, + laser_config: str | None = None, + laser_power_488_pct: float | None = None, + laser_power_561_pct: float | None = None, + laser_power_405_pct: float | None = None, + laser_power_637_pct: float | None = None, **kwargs, - ) -> Dict: + ) -> dict: """ Acquire a 3D volume via synchronized galvo-piezo scan. @@ -565,7 +564,8 @@ async def acquire_volume( laser_config : str, optional Laser channel preset ("488 and 561", "488 only", etc.). None uses the device-layer default. - laser_power_488_pct, laser_power_561_pct, laser_power_405_pct, laser_power_637_pct : float, optional + laser_power_488_pct, laser_power_561_pct, laser_power_405_pct, + laser_power_637_pct : float, optional Per-line laser power %. Hard-limited at the device layer (DiSPIMLightSource.POWER_LIMITS_PCT). None leaves current setpoint untouched. 
@@ -576,48 +576,48 @@ async def acquire_volume( ``{'volume': np.ndarray, 'shape': tuple, 'success': bool}`` """ plan_kwargs = { - 'volume_scanner': 'volume_scanner', - 'num_slices': num_slices, - 'exposure_ms': exposure_ms, - 'galvo_amplitude': galvo_amplitude, - 'galvo_center': galvo_center, - 'piezo_amplitude': piezo_amplitude, - 'piezo_center': piezo_center, + "volume_scanner": "volume_scanner", + "num_slices": num_slices, + "exposure_ms": exposure_ms, + "galvo_amplitude": galvo_amplitude, + "galvo_center": galvo_center, + "piezo_amplitude": piezo_amplitude, + "piezo_center": piezo_center, } # Only forward kwargs the user explicitly set — leaves the # acquire_single_volume_plan defaults in place otherwise. if laser_config is not None: - plan_kwargs['laser_config'] = laser_config + plan_kwargs["laser_config"] = laser_config if laser_power_488_pct is not None: - plan_kwargs['laser_power_488_pct'] = laser_power_488_pct + plan_kwargs["laser_power_488_pct"] = laser_power_488_pct if laser_power_561_pct is not None: - plan_kwargs['laser_power_561_pct'] = laser_power_561_pct + plan_kwargs["laser_power_561_pct"] = laser_power_561_pct if laser_power_405_pct is not None: - plan_kwargs['laser_power_405_pct'] = laser_power_405_pct + plan_kwargs["laser_power_405_pct"] = laser_power_405_pct if laser_power_637_pct is not None: - plan_kwargs['laser_power_637_pct'] = laser_power_637_pct + plan_kwargs["laser_power_637_pct"] = laser_power_637_pct result = await self._submit_plan_and_wait( - 'acquire_single_volume_plan', - kwargs=plan_kwargs, - timeout=120.0 + "acquire_single_volume_plan", kwargs=plan_kwargs, timeout=120.0 ) - extracted = self._extract_image(result, ['volume_scanner', 'camera', 'camera_image'], multi_event=True) + extracted = self._extract_image( + result, ["volume_scanner", "camera", "camera_image"], multi_event=True + ) if extracted: arr, fpath = extracted ret = { - 'volume': arr, - 'shape': arr.shape, - 'success': True, + "volume": arr, + "shape": arr.shape, + 
"success": True, } if fpath: - ret['volume_path'] = str(fpath) - elif result.get('volume_path'): - ret['volume_path'] = result['volume_path'] + ret["volume_path"] = str(fpath) + elif result.get("volume_path"): + ret["volume_path"] = result["volume_path"] return ret - return {'error': result.get('error', 'Acquisition failed'), 'success': False} + return {"error": result.get("error", "Acquisition failed"), "success": False} async def acquire_burst( self, @@ -629,13 +629,13 @@ async def acquire_burst( galvo_center: float = 0.0, piezo_amplitude: float = 25.0, piezo_center: float = 50.0, - laser_config: str = None, - laser_power_488_pct: float = None, - laser_power_561_pct: float = None, - laser_power_405_pct: float = None, - laser_power_637_pct: float = None, - timeout: float = None, - ) -> Dict: + laser_config: str | None = None, + laser_power_488_pct: float | None = None, + laser_power_561_pct: float | None = None, + laser_power_405_pct: float | None = None, + laser_power_637_pct: float | None = None, + timeout: float | None = None, + ) -> dict: """ Acquire ``frames`` volumes back-to-back as a single device-layer plan. @@ -657,76 +657,76 @@ async def acquire_burst( 'volume_path': str|None, 'shape': tuple}``. 
""" plan_kwargs = { - 'volume_scanner': 'volume_scanner', - 'frames': frames, - 'mode': mode, - 'num_slices': num_slices, - 'exposure_ms': exposure_ms, - 'galvo_amplitude': galvo_amplitude, - 'galvo_center': galvo_center, - 'piezo_amplitude': piezo_amplitude, - 'piezo_center': piezo_center, + "volume_scanner": "volume_scanner", + "frames": frames, + "mode": mode, + "num_slices": num_slices, + "exposure_ms": exposure_ms, + "galvo_amplitude": galvo_amplitude, + "galvo_center": galvo_center, + "piezo_amplitude": piezo_amplitude, + "piezo_center": piezo_center, } if laser_config is not None: - plan_kwargs['laser_config'] = laser_config + plan_kwargs["laser_config"] = laser_config if laser_power_488_pct is not None: - plan_kwargs['laser_power_488_pct'] = laser_power_488_pct + plan_kwargs["laser_power_488_pct"] = laser_power_488_pct if laser_power_561_pct is not None: - plan_kwargs['laser_power_561_pct'] = laser_power_561_pct + plan_kwargs["laser_power_561_pct"] = laser_power_561_pct if laser_power_405_pct is not None: - plan_kwargs['laser_power_405_pct'] = laser_power_405_pct + plan_kwargs["laser_power_405_pct"] = laser_power_405_pct if laser_power_637_pct is not None: - plan_kwargs['laser_power_637_pct'] = laser_power_637_pct + plan_kwargs["laser_power_637_pct"] = laser_power_637_pct if timeout is None: # 3 s/frame headroom (1 s pacing + ~1.5 s plan overhead) with a 60 s floor. timeout = max(60.0, frames * 3.0) result = await self._submit_plan_and_wait( - 'burst_plan', + "burst_plan", kwargs=plan_kwargs, timeout=timeout, ) - if not result.get('success'): - return {'success': False, 'error': result.get('error', 'Burst failed')} + if not result.get("success"): + return {"success": False, "error": result.get("error", "Burst failed")} # _submit_plan_and_wait already swapped file_refs for ndarrays in-place. # Walk every event and pull (volume, path) per frame. 
- frames_out: List[Dict] = [] - docs = result.get('documents', {}) or {} - events = docs.get('events', []) or [] - candidates = ('volume_scanner', 'camera', 'camera_image') + frames_out: list[dict] = [] + docs = result.get("documents", {}) or {} + events = docs.get("events", []) or [] + candidates = ("volume_scanner", "camera", "camera_image") for ev in events: - data = ev.get('data', {}) or {} + data = ev.get("data", {}) or {} for key in candidates: if key in data: val = data[key] - entry: Dict[str, Any] = {} + entry: dict[str, Any] = {} # Per-frame epoch time from the Bluesky event doc — lets the # orchestrator stamp each saved frame with its real acquisition # time instead of having to interpolate from the burst's # aggregate timing. - ev_time = ev.get('time') + ev_time = ev.get("time") if ev_time is not None: - entry['acquired_at_epoch'] = float(ev_time) + entry["acquired_at_epoch"] = float(ev_time) if isinstance(val, np.ndarray): - entry['volume'] = val - entry['shape'] = val.shape + entry["volume"] = val + entry["shape"] = val.shape elif self._is_file_ref(val): arr, path = self._resolve_file_ref(val) - entry['volume'] = arr - entry['shape'] = arr.shape - entry['volume_path'] = str(path) + entry["volume"] = arr + entry["shape"] = arr.shape + entry["volume_path"] = str(path) else: - entry['volume'] = np.array(val) - entry['shape'] = entry['volume'].shape - if '__resolved_paths__' in data: + entry["volume"] = np.array(val) + entry["shape"] = entry["volume"].shape + if "__resolved_paths__" in data: # _submit_plan_and_wait doesn't currently populate this # for events, but support it if a future change does. 
- rp = data['__resolved_paths__'] + rp = data["__resolved_paths__"] if isinstance(rp, dict) and key in rp: - entry['volume_path'] = str(rp[key]) + entry["volume_path"] = str(rp[key]) frames_out.append(entry) break @@ -736,32 +736,32 @@ async def acquire_burst( duration_s = 0.0 sustained_hz = 0.0 if events: - first_t = events[0].get('time') - last_t = events[-1].get('time') + first_t = events[0].get("time") + last_t = events[-1].get("time") if first_t is not None and last_t is not None and last_t > first_t: duration_s = float(last_t - first_t) if duration_s > 0: sustained_hz = len(frames_out) / duration_s return { - 'success': True, - 'frames': frames_out, - 'frames_captured': len(frames_out), - 'frames_requested': frames, - 'duration_s': duration_s, - 'sustained_hz': sustained_hz, - 'mode': mode, + "success": True, + "frames": frames_out, + "frames_captured": len(frames_out), + "frames_requested": frames, + "duration_s": duration_s, + "sustained_hz": sustained_hz, + "mode": mode, } # ========================================================================= # LED / Camera Controls # ========================================================================= - async def set_led(self, state: str = 'Closed') -> Dict: + async def set_led(self, state: str = "Closed") -> dict: """Set LED state ('Open' or 'Closed').""" - return await self._api_post('/api/led/set', {'state': state}) + return await self._api_post("/api/led/set", {"state": state}) - async def set_laser_power(self, wavelength: int, pct: float) -> Dict: + async def set_laser_power(self, wavelength: int, pct: float) -> dict: """Set per-line laser power %. Hits the device layer's ``POST /api/light_source/power`` directly @@ -777,12 +777,15 @@ async def set_laser_power(self, wavelength: int, pct: float) -> Dict: pct : float Setpoint percent (must be within hard limit for ``wavelength``). 
""" - return await self._api_post('/api/light_source/power', { - 'wavelength': int(wavelength), - 'pct': float(pct), - }) + return await self._api_post( + "/api/light_source/power", + { + "wavelength": int(wavelength), + "pct": float(pct), + }, + ) - async def get_laser_power(self, wavelength: int) -> Dict: + async def get_laser_power(self, wavelength: int) -> dict: """Read the current per-line laser power %. Hits ``GET /api/light_source/power?wavelength={wavelength}`` — @@ -796,16 +799,39 @@ async def get_laser_power(self, wavelength: int) -> Dict: ) as resp: return await resp.json() except Exception as e: - return {'success': False, 'error': str(e)} + return {"success": False, "error": str(e)} - async def get_led_status(self) -> Dict: + async def get_led_status(self) -> dict: """Get current LED status.""" - return await self._api_get('/api/led/status') + return await self._api_get("/api/led/status") + + async def set_room_light(self, state: str = "off") -> dict: + """Switch the diSPIM room light on/off via the SwitchBot Bot. + + Hits ``POST /api/room_light/set`` directly (no Bluesky queue, no + experiment trace) — a setup accessory poke. ``state`` is + 'on' | 'off' | 'press'. Blocks at the device layer until the BLE + command lands (~1-2 s). + """ + return await self._api_post("/api/room_light/set", {"state": state}) + + async def get_room_light_status(self) -> dict: + """Read the room light's cached on/off state (no BLE round-trip).""" + return await self._api_get("/api/room_light/status") + + async def set_temperature(self, target_c: float) -> dict: + """Command the thermal-controller setpoint (Celsius). 
Non-blocking — the + controller ramps; poll get_temperature() for the lock state.""" + return await self._api_post("/api/temperature/set", {"target_c": target_c}) + + async def get_temperature(self) -> dict: + """Get current temperature, setpoint, and lock state.""" + return await self._api_get("/api/temperature/status") # ------------------------------------------------------------------ # Live device-state readout (streamed from the device layer poller) # ------------------------------------------------------------------ - async def get_device_state(self, refresh: bool = False) -> Dict: + async def get_device_state(self, refresh: bool = False) -> dict: """One-shot snapshot of all device positions + properties. Parameters @@ -814,12 +840,12 @@ async def get_device_state(self, refresh: bool = False) -> Dict: If True, force the device layer to re-read MMCore right now. Otherwise return the most recent poller snapshot (typically <500 ms old). """ - path = '/api/devices/state' + path = "/api/devices/state" if refresh: - path += '?refresh=1' + path += "?refresh=1" return await self._api_get(path) - async def stream_device_states(self, timeout: Optional[float] = None): + async def stream_device_states(self, timeout: float | None = None): """Async generator yielding parsed device-state events from the SSE stream. Yields each event payload as a dict. Comment-style heartbeats (lines @@ -869,11 +895,12 @@ async def stream_device_states(self, timeout: Optional[float] = None): raw = b"\n".join(data_lines).decode("utf-8", errors="replace") try: import json as _json + yield _json.loads(raw) except Exception as exc: logger.warning("Malformed SSE payload skipped: %s", exc) - async def stream_bottom_camera(self, timeout: Optional[float] = None): + async def stream_bottom_camera(self, timeout: float | None = None): """Async generator yielding JPEG frames from the bottom-camera SSE stream. Mirrors :meth:`stream_device_states`. 
The device layer's streamer task @@ -912,23 +939,26 @@ async def stream_bottom_camera(self, timeout: Optional[float] = None): raw = b"\n".join(data_lines).decode("utf-8", errors="replace") try: import json as _json + yield _json.loads(raw) except Exception as exc: logger.warning("Malformed bottom-camera SSE payload skipped: %s", exc) - async def set_camera_led_mode(self, use_led: bool = False) -> Dict: + async def set_camera_led_mode(self, use_led: bool = False) -> dict: """Enable/disable automatic LED for bottom camera captures.""" - return await self._api_post('/api/camera/led_mode', {'use_led': use_led}) + return await self._api_post("/api/camera/led_mode", {"use_led": use_led}) - async def set_bottom_camera_exposure(self, exposure_ms: float) -> Dict: + async def set_bottom_camera_exposure(self, exposure_ms: float) -> dict: """Set bottom camera exposure time in milliseconds.""" - return await self._api_post('/api/camera/exposure', {'exposure_ms': exposure_ms}) + return await self._api_post("/api/camera/exposure", {"exposure_ms": exposure_ms}) - async def get_bottom_camera_exposure(self) -> Dict: + async def get_bottom_camera_exposure(self) -> dict: """Get current bottom camera exposure time.""" - return await self._api_get('/api/camera/exposure') + return await self._api_get("/api/camera/exposure") - async def capture_bottom_image(self, use_led: bool = False, exposure_ms: float = None) -> dict: + async def capture_bottom_image( + self, use_led: bool = False, exposure_ms: float | None = None + ) -> dict: """ Capture image from bottom camera. 
@@ -950,16 +980,17 @@ async def capture_bottom_image(self, use_led: bool = False, exposure_ms: float = await self.set_camera_led_mode(use_led) result = await self._submit_plan_and_wait( - 'capture_bottom_image_plan', - kwargs={'bottom_camera': 'bottom_camera'} + "capture_bottom_image_plan", kwargs={"bottom_camera": "bottom_camera"} ) - extracted = self._extract_image(result, ['bottom_camera', 'bottom_camera_image', 'Bottom PCO']) + extracted = self._extract_image( + result, ["bottom_camera", "bottom_camera_image", "Bottom PCO"] + ) if extracted: arr, fpath = extracted - return {'image': arr, 'image_path': fpath} + return {"image": arr, "image_path": fpath} - return {'image': np.zeros((100, 100), dtype=np.uint16), 'image_path': None} + return {"image": np.zeros((100, 100), dtype=np.uint16), "image_path": None} # ========================================================================= # SAM Embryo Detection (HTTP API) @@ -971,11 +1002,11 @@ async def detect_embryos( objective_mag: float = DEFAULT_OBJECTIVE_MAG, use_claude_review: bool = True, min_confidence: float = 0.7, - exposure_ms: float = None, + exposure_ms: float | None = None, brightness_percentile: float = 99.0, min_area: int = 5000, max_area: int = 150000, - ) -> Dict: + ) -> dict: """ Capture image and detect embryos using brightness detection + SAM. 
@@ -1008,7 +1039,7 @@ async def detect_embryos( 'image': np.ndarray, ...}`` """ if not self.has_sam: - return {'error': 'SAM detection not available on server'} + return {"error": "SAM detection not available on server"} self._ensure_connected() @@ -1016,65 +1047,79 @@ async def detect_embryos( logger.info("Calling /api/detect_embryos (server-side capture + SAM)...") payload = { - 'pixel_size_um': pixel_size_um, - 'objective_mag': objective_mag, - 'use_claude_review': use_claude_review, - 'min_confidence': min_confidence, - 'brightness_percentile': brightness_percentile, - 'min_area': min_area, - 'max_area': max_area, + "pixel_size_um": pixel_size_um, + "objective_mag": objective_mag, + "use_claude_review": use_claude_review, + "min_confidence": min_confidence, + "brightness_percentile": brightness_percentile, + "min_area": min_area, + "max_area": max_area, } if exposure_ms is not None: - payload['exposure_ms'] = exposure_ms + payload["exposure_ms"] = exposure_ms async with self._session.post( - f"{self.http_url}/api/detect_embryos", - json=payload, - timeout=300 + f"{self.http_url}/api/detect_embryos", json=payload, timeout=300 ) as resp: result = await resp.json() - if not result.get('success'): + if not result.get("success"): return result # Ensure the caller has the image to feed into the map view. 
- if result.get('image') is None: + if result.get("image") is None: image = await self._get_detection_image(result, exposure_ms) if image is not None: - result['image'] = image + result["image"] = image return result except asyncio.TimeoutError: - return {'success': False, 'error': 'Detection timed out (5 min limit)'} + return {"success": False, "error": "Detection timed out (5 min limit)"} except aiohttp.ClientError as e: - return {'error': str(NetworkError(f"Device layer request failed: {e}")), - 'traceback': traceback.format_exc(), 'success': False} + return { + "error": str(NetworkError(f"Device layer request failed: {e}")), + "traceback": traceback.format_exc(), + "success": False, + } except Exception as e: - return {'error': str(e), 'traceback': traceback.format_exc(), 'success': False} + return { + "error": str(e), + "traceback": traceback.format_exc(), + "success": False, + } - async def _get_detection_image(self, detection_result: dict, exposure_ms: float = None) -> Optional[np.ndarray]: + async def _get_detection_image( + self, detection_result: dict, exposure_ms: float | None = None + ) -> np.ndarray | None: """Load or capture an image for the detection editor.""" - image_path = detection_result.get('image_path') + image_path = detection_result.get("image_path") if image_path: try: import tifffile + return tifffile.imread(image_path) except Exception: pass # Fallback: capture a fresh image snap = await self.capture_bottom_image(exposure_ms=exposure_ms) - image = snap['image'] - if snap.get('image_path'): + image = snap["image"] + if snap.get("image_path"): try: - snap['image_path'].unlink(missing_ok=True) + snap["image_path"].unlink(missing_ok=True) except OSError: pass return image - async def view_image(self, image: np.ndarray = None, title: str = "Image View", - exposure_ms: float = None, save_path: str = None, - show: bool = True, embryo_annotations: list = None) -> Dict: + async def view_image( + self, + image: np.ndarray | None = None, + title: 
str = "Image View", + exposure_ms: float | None = None, + save_path: str | None = None, + show: bool = True, + embryo_annotations: list | None = None, + ) -> dict: """Save a bottom-camera image to disk (replaces the napari display). ``show`` and ``title`` are kept for backwards compatibility with @@ -1084,54 +1129,63 @@ async def view_image(self, image: np.ndarray = None, title: str = "Image View", try: if image is None: snap = await self.capture_bottom_image(exposure_ms=exposure_ms) - image = snap['image'] - if snap.get('image_path'): + image = snap["image"] + if snap.get("image_path"): try: - snap['image_path'].unlink(missing_ok=True) + snap["image_path"].unlink(missing_ok=True) except OSError: pass if image is None: - return {'success': False, 'error': 'No image to save'} + return {"success": False, "error": "No image to save"} - result = {'success': True, 'shape': list(image.shape)} + result = {"success": True, "shape": list(image.shape)} if save_path: # Reuse the existing PNG writer; draws annotations if any. 
from pathlib import Path as _Path + _Path(save_path).parent.mkdir(parents=True, exist_ok=True) annots = [] - for a in (embryo_annotations or []): - px = a.get('pixel_x') - py = a.get('pixel_y') + for a in embryo_annotations or []: + px = a.get("pixel_x") + py = a.get("pixel_y") if px is None or py is None: continue - annots.append({ - 'embryo_number': a.get('label') or a.get('embryo_id') or '?', - 'pixel_position': (px, py), - }) + annots.append( + { + "embryo_number": a.get("label") or a.get("embryo_id") or "?", + "pixel_position": (px, py), + } + ) if annots: from gently.ui.web.embryo_marker import _save_marked_image + _save_marked_image(image, annots, _Path(save_path)) else: from PIL import Image as _PILImage + arr = image if arr.dtype != np.uint8: lo, hi = arr.min(), arr.max() arr = ((arr - lo) / max(hi - lo, 1) * 255).astype(np.uint8) _PILImage.fromarray(arr).save(save_path) - result['saved_to'] = save_path + result["saved_to"] = save_path return result except Exception as e: - return {'error': str(e), 'traceback': traceback.format_exc(), 'success': False} + return { + "error": str(e), + "traceback": traceback.format_exc(), + "success": False, + } async def view_embryos( self, image: np.ndarray, - embryos: List[Dict], + embryos: list[dict], title: str = "Embryos", - save_path: Optional[str] = None, + save_path: str | None = None, show: bool = True, - ) -> Dict: + ) -> dict: """Save an annotated PNG of embryos on an image (replaces napari). 
Markers in ``embryos`` may use ``center_x``/``center_y`` or @@ -1140,32 +1194,40 @@ async def view_embryos( """ try: if image is None or not embryos: - return {'success': False, 'error': 'No image or embryos to display'} + return {"success": False, "error": "No image or embryos to display"} annots = [] for emb in embryos: - px = emb.get('center_x', emb.get('pixel_x', 0)) - py = emb.get('center_y', emb.get('pixel_y', 0)) - annots.append({ - 'embryo_number': emb.get('embryo_id', '?'), - 'pixel_position': (px, py), - }) - - result = {'success': True, 'num_embryos': len(embryos)} + px = emb.get("center_x", emb.get("pixel_x", 0)) + py = emb.get("center_y", emb.get("pixel_y", 0)) + annots.append( + { + "embryo_number": emb.get("embryo_id", "?"), + "pixel_position": (px, py), + } + ) + + result = {"success": True, "num_embryos": len(embryos)} if save_path: from pathlib import Path as _Path + _Path(save_path).parent.mkdir(parents=True, exist_ok=True) from gently.ui.web.embryo_marker import _save_marked_image + _save_marked_image(image, annots, _Path(save_path)) - result['saved_to'] = save_path + result["saved_to"] = save_path return result except Exception as e: - return {'error': str(e), 'traceback': traceback.format_exc(), 'success': False} + return { + "error": str(e), + "traceback": traceback.format_exc(), + "success": False, + } async def capture_for_marking( self, - exposure_ms: float = None, - ) -> Dict: + exposure_ms: float | None = None, + ) -> dict: """ Capture a bottom-camera image for manual marking in the map view. 
@@ -1183,34 +1245,38 @@ async def capture_for_marking( try: snap = await self.capture_bottom_image(use_led=True, exposure_ms=exposure_ms) - image = snap['image'] + image = snap["image"] if image is None or (image.shape == (100, 100) and image.max() == 0): - return {'success': False, 'error': 'Failed to capture image'} + return {"success": False, "error": "Failed to capture image"} - if snap.get('image_path'): + if snap.get("image_path"): try: - snap['image_path'].unlink(missing_ok=True) + snap["image_path"].unlink(missing_ok=True) except OSError: pass stage_pos = await self.get_stage_position() return { - 'success': True, - 'image': image, - 'stage_position': list(stage_pos), - 'image_shape': list(image.shape), + "success": True, + "image": image, + "stage_position": list(stage_pos), + "image_shape": list(image.shape), } except Exception as e: - return {'error': str(e), 'traceback': traceback.format_exc(), 'success': False} + return { + "error": str(e), + "traceback": traceback.format_exc(), + "success": False, + } # ========================================================================= # Status # ========================================================================= - async def _get_server_status(self) -> Dict: + async def _get_server_status(self) -> dict: """Query device layer and SAM server status via HTTP.""" status = {} @@ -1220,18 +1286,18 @@ async def _get_server_status(self) -> Dict: async with self._session.get(f"{self.http_url}/api/status") as resp: if resp.status == 200: server_status = await resp.json() - status['queue_server'] = { - 'manager_state': server_status.get('manager_state'), - 're_state': server_status.get('re_state', 'idle'), - 'devices': server_status.get('devices', []), - 'plans': server_status.get('plans', []), + status["queue_server"] = { + "manager_state": server_status.get("manager_state"), + "re_state": server_status.get("re_state", "idle"), + "devices": server_status.get("devices", []), + "plans": server_status.get("plans", 
[]), } else: - status['queue_server'] = {'error': f'HTTP {resp.status}'} + status["queue_server"] = {"error": f"HTTP {resp.status}"} except (aiohttp.ClientError, Exception) as e: - status['queue_server'] = {'error': str(e)} + status["queue_server"] = {"error": str(e)} else: - status['queue_server'] = {'connected': False} + status["queue_server"] = {"connected": False} # SAM status (via HTTP, same server) if self._session and self._qs_connected: @@ -1239,21 +1305,20 @@ async def _get_server_status(self) -> Dict: async with self._session.get(f"{self.http_url}/api/sam/status") as resp: if resp.status == 200: sam_status = await resp.json() - status['sam_server'] = { - 'available': sam_status.get('available', False), - 'loaded': sam_status.get('loaded', False), - 'device': sam_status.get('device', 'unknown'), + status["sam_server"] = { + "available": sam_status.get("available", False), + "loaded": sam_status.get("loaded", False), + "device": sam_status.get("device", "unknown"), } else: - status['sam_server'] = {'error': f'HTTP {resp.status}'} + status["sam_server"] = {"error": f"HTTP {resp.status}"} except (aiohttp.ClientError, Exception) as e: - status['sam_server'] = {'error': str(e)} + status["sam_server"] = {"error": str(e)} else: - status['sam_server'] = {'connected': False} + status["sam_server"] = {"connected": False} return status - # ========================================================================= # Microscope plan implementations # These map plan names to the existing methods above, enabling @@ -1302,7 +1367,7 @@ async def _plan_status(self, **kw) -> dict: async def create_queue_server_client( http_url: str = f"http://{settings.network.device_host}:{settings.network.device_port}", -) -> Optional[DiSPIMMicroscope]: +) -> DiSPIMMicroscope | None: """Create and connect a diSPIM microscope client.""" client = DiSPIMMicroscope(http_url=http_url) if await client.connect(): diff --git a/gently/hardware/dispim/config.py b/gently/hardware/dispim/config.py 
index 3255b170..55a620a7 100644 --- a/gently/hardware/dispim/config.py +++ b/gently/hardware/dispim/config.py @@ -18,19 +18,17 @@ >>> camera_config = CameraConfig(mode=CameraMode.EXTERNAL_PROGRESSIVE, exposure_ms=10.0) """ -import math import json -from dataclasses import dataclass, asdict +import math +from dataclasses import asdict, dataclass from enum import Enum from pathlib import Path -from typing import Optional, Tuple -from datetime import datetime - # ============================================================================ # ENUMERATIONS # ============================================================================ + class CameraMode(Enum): """ Camera trigger and sensor mode configurations. @@ -39,12 +37,14 @@ class CameraMode(Enum): - INTERNAL_AREA: For calibration, live view, manual acquisitions - EXTERNAL_PROGRESSIVE: For hardware-triggered SPIM volumes """ - INTERNAL_AREA = "internal_area" # Manual trigger, full sensor + + INTERNAL_AREA = "internal_area" # Manual trigger, full sensor EXTERNAL_PROGRESSIVE = "external_progressive" # Hardware trigger, progressive readout class ScannerPattern(Enum): """Galvo scanner waveform patterns.""" + TRIANGLE = "1 - Triangle" SAWTOOTH = "2 - Sawtooth" RAMP = "3 - Ramp" @@ -52,6 +52,7 @@ class ScannerPattern(Enum): class ScannerMode(Enum): """Galvo scanner axis control modes.""" + DISABLED = "1 - Disabled" INTERNAL = "2 - Internal" ENABLED_SYNCED = "3 - Enabled with axes synced" @@ -61,6 +62,7 @@ class ScannerMode(Enum): # HARDWARE PROFILES # ============================================================================ + @dataclass class HardwareProfile: """ @@ -83,17 +85,19 @@ class HardwareProfile: Custom profile for faster camera: >>> profile = HardwareProfile(camera_reset_ms=2.0, camera_readout_ms=8.0) """ - camera_reset_ms: float = 3.0 # Hamamatsu Flash4 reset time - camera_readout_ms: float = 10.0 # Typical for 2048x512 ROI - scan_laser_buffer_ms: float = 0.25 # Buffer before/after laser pulse - 
scan_filter_freq_khz: float = 0.2 # Scanner filter frequency - has_plogic: bool = False # PLogic card present + + camera_reset_ms: float = 3.0 # Hamamatsu Flash4 reset time + camera_readout_ms: float = 10.0 # Typical for 2048x512 ROI + scan_laser_buffer_ms: float = 0.25 # Buffer before/after laser pulse + scan_filter_freq_khz: float = 0.2 # Scanner filter frequency + has_plogic: bool = False # PLogic card present # ============================================================================ # CAMERA CONFIGURATION # ============================================================================ + @dataclass class CameraConfig: """ @@ -120,9 +124,10 @@ class CameraConfig: ... exposure_ms=10.0 ... ) """ + mode: CameraMode exposure_ms: float - roi: Tuple[int, int, int, int] = (128, 896, 2048, 512) # Default diSPIM ROI + roi: tuple[int, int, int, int] = (128, 896, 2048, 512) # Default diSPIM ROI def to_mm_properties(self) -> dict: """ @@ -138,15 +143,12 @@ def to_mm_properties(self) -> dict: ... core.setProperty(camera_device, name, value) """ if self.mode == CameraMode.INTERNAL_AREA: - return { - "TRIGGER SOURCE": "INTERNAL", - "SENSOR MODE": "AREA" - } + return {"TRIGGER SOURCE": "INTERNAL", "SENSOR MODE": "AREA"} elif self.mode == CameraMode.EXTERNAL_PROGRESSIVE: return { "TRIGGER SOURCE": "EXTERNAL", "SENSOR MODE": "PROGRESSIVE", - "TRIGGER ACTIVE": "EDGE" + "TRIGGER ACTIVE": "EDGE", } else: raise ValueError(f"Unknown camera mode: {self.mode}") @@ -156,6 +158,7 @@ def to_mm_properties(self) -> dict: # SCANNER/GALVO CONFIGURATION # ============================================================================ + @dataclass class GalvoAxisConfig: """ @@ -188,6 +191,7 @@ class GalvoAxisConfig: ... mode=ScannerMode.ENABLED_SYNCED ... 
) """ + amplitude_deg: float offset_deg: float pattern: ScannerPattern = ScannerPattern.TRIANGLE @@ -212,7 +216,7 @@ def to_mm_properties(self, axis: str) -> dict: f"SingleAxis{axis}Amplitude(deg)": float(self.amplitude_deg), f"SingleAxis{axis}Offset(deg)": float(self.offset_deg), f"SingleAxis{axis}Pattern": self.pattern.value, - f"SingleAxis{axis}Mode": self.mode.value + f"SingleAxis{axis}Mode": self.mode.value, } @@ -234,6 +238,7 @@ class ScannerConfig: ... beam_enabled=True ... ) """ + x_axis: GalvoAxisConfig y_axis: GalvoAxisConfig beam_enabled: bool = True @@ -262,9 +267,9 @@ def to_mm_properties(self) -> dict: # SPIM TIMING CALCULATION # ============================================================================ + def calculate_spim_timing( - camera_exposure_ms: float, - hardware_profile: Optional[HardwareProfile] = None + camera_exposure_ms: float, hardware_profile: HardwareProfile | None = None ) -> dict: """ Calculate SPIM hardware timing parameters for synchronized acquisition. @@ -356,15 +361,15 @@ def ceil_quarter_ms(val: float) -> float: frame_rate = 1000.0 / slice_duration if slice_duration > 0 else 0.0 return { - 'scanDelay': scan_delay, - 'scanPeriod': scan_period, - 'laserDelay': laser_delay, - 'laserDuration': laser_duration, - 'cameraDelay': camera_delay, - 'cameraDuration': camera_duration, # Must be > 0! - 'cameraExposure': camera_exposure, - 'sliceDuration': slice_duration, - 'frameRate': frame_rate + "scanDelay": scan_delay, + "scanPeriod": scan_period, + "laserDelay": laser_delay, + "laserDuration": laser_duration, + "cameraDelay": camera_delay, + "cameraDuration": camera_duration, # Must be > 0! 
+ "cameraExposure": camera_exposure, + "sliceDuration": slice_duration, + "frameRate": frame_rate, } @@ -372,6 +377,7 @@ def ceil_quarter_ms(val: float) -> float: # PIEZO-GALVO CALIBRATION # ============================================================================ + @dataclass class PiezoGalvoCalibration: """ @@ -420,6 +426,7 @@ class PiezoGalvoCalibration: Calculate piezo position for given galvo angle: >>> piezo_pos = calib.galvo_to_piezo(0.15) # galvo at +0.15° """ + slope_um_per_deg: float offset_um: float galvo_top_deg: float @@ -432,10 +439,10 @@ class PiezoGalvoCalibration: device_galvo: str = "Scanner:AB:33" # Optional metadata - edge_top_deg: Optional[float] = None - edge_bottom_deg: Optional[float] = None - calib_inset_fraction: Optional[float] = None - calib_strategy: Optional[str] = None + edge_top_deg: float | None = None + edge_bottom_deg: float | None = None + calib_inset_fraction: float | None = None + calib_strategy: str | None = None def galvo_to_piezo(self, galvo_deg: float) -> float: """ @@ -469,7 +476,7 @@ def piezo_to_galvo(self, piezo_um: float) -> float: """ return (piezo_um - self.offset_um) / self.slope_um_per_deg - def get_scan_range(self, tolerance_multiplier: float = 1.0) -> Tuple[float, float]: + def get_scan_range(self, tolerance_multiplier: float = 1.0) -> tuple[float, float]: """ Get galvo scan range with optional tolerance multiplier. 
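Note on the calibration hunks above: piezo_to_galvo divides out a linear fit, so the forward mapping is presumably piezo = slope * galvo + offset. The sketch below assumes that forward form, since the body of galvo_to_piezo is not shown in this diff. LinearCalibration is a hypothetical stand-in for the linear part of PiezoGalvoCalibration, and the slope and offset values are made up for illustration.

```python
from dataclasses import dataclass


@dataclass
class LinearCalibration:
    """Hypothetical stand-in for the linear part of PiezoGalvoCalibration."""

    slope_um_per_deg: float
    offset_um: float

    def galvo_to_piezo(self, galvo_deg: float) -> float:
        # Assumed forward mapping: the algebraic inverse of piezo_to_galvo.
        return self.slope_um_per_deg * galvo_deg + self.offset_um

    def piezo_to_galvo(self, piezo_um: float) -> float:
        # Same expression as the method shown in the diff.
        return (piezo_um - self.offset_um) / self.slope_um_per_deg


# Illustrative numbers only, not a real calibration.
calib = LinearCalibration(slope_um_per_deg=100.0, offset_um=50.0)
piezo_um = calib.galvo_to_piezo(0.15)
round_trip_deg = calib.piezo_to_galvo(piezo_um)
```

If the real galvo_to_piezo uses this form, converting a galvo angle to a piezo position and back should return the original angle up to floating-point error, which makes a cheap sanity check after loading a calibration.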
@@ -486,7 +493,7 @@ def get_scan_range(self, tolerance_multiplier: float = 1.0) -> Tuple[float, floa # If we have edge data, adjust tolerance if self.edge_top_deg is not None and self.edge_bottom_deg is not None: # Estimate original tolerance - original_tolerance = (self.galvo_top_deg - self.edge_top_deg) # Should be negative + original_tolerance = self.galvo_top_deg - self.edge_top_deg # Should be negative # Apply multiplier new_tolerance = abs(original_tolerance) * tolerance_multiplier @@ -501,7 +508,7 @@ def get_scan_range(self, tolerance_multiplier: float = 1.0) -> Tuple[float, floa return (self.galvo_top_deg, self.galvo_bottom_deg) @classmethod - def from_file(cls, path: Path) -> 'PiezoGalvoCalibration': + def from_file(cls, path: Path) -> "PiezoGalvoCalibration": """ Load calibration from JSON file. @@ -523,7 +530,7 @@ def from_file(cls, path: Path) -> 'PiezoGalvoCalibration': if not path.exists(): raise FileNotFoundError(f"Calibration file not found: {path}") - with open(path, 'r') as f: + with open(path) as f: data = json.load(f) return cls(**data) @@ -542,7 +549,7 @@ def to_file(self, path: Path): path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) - with open(path, 'w') as f: + with open(path, "w") as f: json.dump(asdict(self), f, indent=2) def __str__(self) -> str: @@ -560,6 +567,7 @@ def __str__(self) -> str: # PRESET CONFIGURATIONS # ============================================================================ + def get_calibration_camera_config(exposure_ms: float = 50.0) -> CameraConfig: """ Get camera configuration for calibration mode. 
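Note on the from_file / to_file hunks above: both methods rely on the dataclass round-tripping through JSON via asdict and cls(**data). The sketch below shows that pattern on its own. DemoCalibration, save, and load are hypothetical names with a reduced field set, not the project's API.

```python
import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path


@dataclass
class DemoCalibration:
    """Hypothetical reduced version of the calibration dataclass."""

    slope_um_per_deg: float
    offset_um: float
    device_galvo: str = "Scanner:AB:33"


def save(calib: DemoCalibration, path: Path) -> None:
    # Mirrors to_file: create parent directories, then dump the field dict.
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as f:
        json.dump(asdict(calib), f, indent=2)


def load(path: Path) -> DemoCalibration:
    # Mirrors from_file: fail early on a missing file, then rebuild from keys.
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Calibration file not found: {path}")
    with open(path) as f:
        return DemoCalibration(**json.load(f))


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "nested" / "calibration.json"
    original = DemoCalibration(slope_um_per_deg=100.0, offset_um=50.0)
    save(original, target)
    restored = load(target)
```

One consequence of cls(**data): a JSON file containing a key that is not a dataclass field raises TypeError on load, so older calibration files may need migrating if fields are renamed.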
@@ -570,10 +578,7 @@ def get_calibration_camera_config(exposure_ms: float = 50.0) -> CameraConfig: Returns: CameraConfig for INTERNAL trigger, AREA sensor mode """ - return CameraConfig( - mode=CameraMode.INTERNAL_AREA, - exposure_ms=exposure_ms - ) + return CameraConfig(mode=CameraMode.INTERNAL_AREA, exposure_ms=exposure_ms) def get_hardware_spim_camera_config(exposure_ms: float = 10.0) -> CameraConfig: @@ -586,10 +591,7 @@ def get_hardware_spim_camera_config(exposure_ms: float = 10.0) -> CameraConfig: Returns: CameraConfig for EXTERNAL trigger, PROGRESSIVE sensor mode """ - return CameraConfig( - mode=CameraMode.EXTERNAL_PROGRESSIVE, - exposure_ms=exposure_ms - ) + return CameraConfig(mode=CameraMode.EXTERNAL_PROGRESSIVE, exposure_ms=exposure_ms) def get_standard_scanner_config(y_amplitude_deg: float = 0.04) -> ScannerConfig: @@ -607,13 +609,13 @@ def get_standard_scanner_config(y_amplitude_deg: float = 0.04) -> ScannerConfig: amplitude_deg=8.0, offset_deg=0.0005, pattern=ScannerPattern.TRIANGLE, - mode=ScannerMode.ENABLED_SYNCED + mode=ScannerMode.ENABLED_SYNCED, ), y_axis=GalvoAxisConfig( amplitude_deg=y_amplitude_deg, offset_deg=0.0, pattern=ScannerPattern.TRIANGLE, - mode=ScannerMode.ENABLED_SYNCED + mode=ScannerMode.ENABLED_SYNCED, ), - beam_enabled=True + beam_enabled=True, ) diff --git a/gently/hardware/dispim/device_factory.py b/gently/hardware/dispim/device_factory.py index 44399076..e914b1b1 100644 --- a/gently/hardware/dispim/device_factory.py +++ b/gently/hardware/dispim/device_factory.py @@ -5,15 +5,13 @@ """ import logging -from typing import Dict, Optional -from pathlib import Path + import pymmcore logger = logging.getLogger(__name__) -def create_devices_from_mmcore(core: pymmcore.CMMCore, - config: Optional[Dict] = None) -> Dict: +def create_devices_from_mmcore(core: pymmcore.CMMCore, config: dict | None = None) -> dict: """ Create all necessary Ophyd devices from Micro-Manager core @@ -34,6 +32,7 @@ def create_devices_from_mmcore(core: 
pymmcore.CMMCore, - lightsheet_snap: DiSPIMLightSheetSnap - scanner: DiSPIMScanner - piezo: DiSPIMPiezo + - fdrive: DiSPIMFDrive (SPIM head Z / F axis) Example ------- @@ -53,22 +52,23 @@ def create_devices_from_mmcore(core: pymmcore.CMMCore, """ # Import devices only when needed (avoids ophyd import issues) from .devices import ( - DiSPIMXYStage, - DiSPIMVolumeScanner, DiSPIMBottomCamera, DiSPIMLightSheetSnap, + DiSPIMPiezo, DiSPIMScanner, - DiSPIMPiezo + DiSPIMVolumeScanner, + DiSPIMXYStage, ) # Default device configuration (from MMConfig_tracking_screening.cfg) default_config = { - 'xy_stage_name': 'XYStage:XY:31', - 'camera_name': 'HamCam1', - 'scanner_name': 'Scanner:AB:33', - 'piezo_name': 'PiezoStage:P:34', - 'bottom_camera_name': 'Bottom PCO', - 'led_name': 'LED:X:31' + "xy_stage_name": "XYStage:XY:31", + "camera_name": "HamCam1", + "scanner_name": "Scanner:AB:33", + "piezo_name": "PiezoStage:P:34", + "fdrive_name": "ZStage:V:37", + "bottom_camera_name": "Bottom PCO", + "led_name": "LED:X:31", } # Merge with user config @@ -86,62 +86,77 @@ def create_devices_from_mmcore(core: pymmcore.CMMCore, led = None try: - scanner = DiSPIMScanner(name=cfg['scanner_name'], core=core) - devices['scanner'] = scanner - logger.info("Created scanner: %s", cfg['scanner_name']) + scanner = DiSPIMScanner(name=cfg["scanner_name"], core=core) + devices["scanner"] = scanner + logger.info("Created scanner: %s", cfg["scanner_name"]) except Exception as e: logger.warning("Could not create scanner: %s", e) try: - piezo = DiSPIMPiezo(name=cfg['piezo_name'], core=core) - devices['piezo'] = piezo - logger.info("Created piezo: %s", cfg['piezo_name']) + piezo = DiSPIMPiezo(name=cfg["piezo_name"], core=core) + devices["piezo"] = piezo + logger.info("Created piezo: %s", cfg["piezo_name"]) except Exception as e: logger.warning("Could not create piezo: %s", e) try: from .devices import DiSPIMCamera - camera = DiSPIMCamera(device_name=cfg['camera_name'], core=core) - devices['camera'] = camera 
- logger.info("Created camera: %s", cfg['camera_name']) + + camera = DiSPIMCamera(device_name=cfg["camera_name"], core=core) + devices["camera"] = camera + logger.info("Created camera: %s", cfg["camera_name"]) except Exception as e: logger.warning("Could not create camera: %s", e) try: from .devices import DiSPIMLightSource + # Single instance, registered under both names: the ophyd name and # devices-dict key keep the historical "laser_control" identifier # so existing Bluesky plans (which take a `laser_control=...` kwarg) # continue to work; "light_source" is the new canonical alias for # callers that want the broader concept (power + channel). - laser_control = DiSPIMLightSource(core=core, name='laser_control', group_name="Laser") - devices['laser_control'] = laser_control - devices['light_source'] = laser_control + laser_control = DiSPIMLightSource(core=core, name="laser_control", group_name="Laser") + devices["laser_control"] = laser_control + devices["light_source"] = laser_control logger.info("Created light source (laser_control)") except Exception as e: logger.warning("Could not create light source: %s", e) try: - devices['xy_stage'] = DiSPIMXYStage(name=cfg['xy_stage_name'], core=core) - logger.info("Created XY stage: %s", cfg['xy_stage_name']) + devices["xy_stage"] = DiSPIMXYStage(name=cfg["xy_stage_name"], core=core) + logger.info("Created XY stage: %s", cfg["xy_stage_name"]) except Exception as e: logger.warning("Could not create XY stage: %s", e) try: - if cfg.get('led_name'): + from .devices import DiSPIMFDrive + + devices["fdrive"] = DiSPIMFDrive(name=cfg["fdrive_name"], core=core) + logger.info("Created F-drive (SPIM head): %s", cfg["fdrive_name"]) + except Exception as e: + logger.warning("Could not create F-drive (SPIM head): %s", e) + + try: + if cfg.get("led_name"): from .devices import DiSPIMLED - led = DiSPIMLED(core=core, name=cfg['led_name'], group_name="LED") - devices['led'] = led - logger.info("Created LED: %s", cfg['led_name']) + + led = 
DiSPIMLED(core=core, name=cfg["led_name"], group_name="LED") + devices["led"] = led + logger.info("Created LED: %s", cfg["led_name"]) except Exception as e: logger.warning("Could not create LED: %s", e) # Compound devices try: if scanner and camera and piezo and laser_control: - devices['volume_scanner'] = DiSPIMVolumeScanner( - scanner=scanner, camera=camera, piezo=piezo, - laser_control=laser_control, core=core, name='volume_scanner', + devices["volume_scanner"] = DiSPIMVolumeScanner( + scanner=scanner, + camera=camera, + piezo=piezo, + laser_control=laser_control, + core=core, + name="volume_scanner", ) logger.info("Created volume scanner") else: @@ -152,11 +167,14 @@ def create_devices_from_mmcore(core: pymmcore.CMMCore, try: if led: bottom_camera = DiSPIMBottomCamera( - device_name=cfg['bottom_camera_name'], core=core, - led_control=led, pixel_size_um=6.5, magnification=10.0, + device_name=cfg["bottom_camera_name"], + core=core, + led_control=led, + pixel_size_um=6.5, + magnification=10.0, ) - devices['bottom_camera'] = bottom_camera - logger.info("Created bottom camera: %s", cfg['bottom_camera_name']) + devices["bottom_camera"] = bottom_camera + logger.info("Created bottom camera: %s", cfg["bottom_camera_name"]) else: logger.warning("Skipping bottom camera (missing LED device)") except Exception as e: @@ -164,8 +182,10 @@ def create_devices_from_mmcore(core: pymmcore.CMMCore, try: if scanner and camera: - devices['lightsheet_snap'] = DiSPIMLightSheetSnap( - scanner=scanner, camera=camera, name='lightsheet_snap', + devices["lightsheet_snap"] = DiSPIMLightSheetSnap( + scanner=scanner, + camera=camera, + name="lightsheet_snap", ) logger.info("Created lightsheet snap device") else: @@ -177,5 +197,3 @@ def create_devices_from_mmcore(core: pymmcore.CMMCore, raise RuntimeError("Failed to create any devices. 
Check your Micro-Manager configuration.") return devices - - diff --git a/gently/hardware/dispim/device_layer.py b/gently/hardware/dispim/device_layer.py index 1172ba1d..6d606621 100644 --- a/gently/hardware/dispim/device_layer.py +++ b/gently/hardware/dispim/device_layer.py @@ -19,16 +19,14 @@ import asyncio import contextlib -import logging import json +import logging import os import sys import time -from pathlib import Path -from typing import Any, Dict, List, Optional - -logger = logging.getLogger(__name__) from dataclasses import dataclass, field +from pathlib import Path +from typing import Any import numpy as np @@ -37,24 +35,27 @@ if str(project_root) not in sys.path: sys.path.insert(0, str(project_root)) -from aiohttp import web -import yaml - -from gently.core.service import Service -from gently.exceptions import HardwareError, AcquisitionError -from gently.log_config import configure_logging -from gently.settings import settings +import yaml # noqa: E402 +from aiohttp import web # noqa: E402 # Bluesky imports -from bluesky import RunEngine +from bluesky import RunEngine # noqa: E402 + +from gently.core.service import Service # noqa: E402 +from gently.hardware import console_ui as cui # noqa: E402 +from gently.log_config import configure_logging # noqa: E402 +from gently.settings import settings # noqa: E402 + +logger = logging.getLogger(__name__) # BestEffortCallback removed — unused @dataclass class PlanRequest: """A request to run a plan""" + plan_name: str - kwargs: Dict[str, Any] + kwargs: dict[str, Any] future: asyncio.Future = field(default_factory=lambda: asyncio.get_event_loop().create_future()) @@ -102,7 +103,7 @@ def __init__( # When set, large numpy arrays are written as TIFF files instead of # being serialized to JSON lists (which can turn a 400MB uint16 volume # into ~2GB of JSON text). 
- self._volume_dir: Optional[str] = None + self._volume_dir: str | None = None # Server lifecycle objects (populated in on_start) self._app = None @@ -126,34 +127,34 @@ def __init__( # run. Plain stage moves / LED changes / snaps don't pause — the # adapter's per-device mutex handles the contention fine, and we # want the readout to stay live. - self._state_pos_interval_sec = 0.2 # 5 Hz target for XY (hard floor ~4 Hz on ASI) + self._state_pos_interval_sec = 0.2 # 5 Hz target for XY (hard floor ~4 Hz on ASI) self._state_slow_pos_interval_sec = 1.0 # 1 Hz piezo + galvo - self._state_prop_interval_sec = 15.0 # ~0.07 Hz full-state cadence + self._state_prop_interval_sec = 15.0 # ~0.07 Hz full-state cadence self._state_pause_counter = 0 self._state_updating_now = False - self._state_latest: Dict[str, Any] = { + self._state_latest: dict[str, Any] = { "positions": {}, "properties": {}, "t": 0.0, "paused": False, } - self._state_subscribers: List[asyncio.Queue] = [] - self._state_pos_task: Optional[asyncio.Task] = None - self._state_slow_pos_task: Optional[asyncio.Task] = None - self._state_prop_task: Optional[asyncio.Task] = None + self._state_subscribers: list[asyncio.Queue] = [] + self._state_pos_task: asyncio.Task | None = None + self._state_slow_pos_task: asyncio.Task | None = None + self._state_prop_task: asyncio.Task | None = None # MMCore push-callback support. Adapters that fire OnPropertyChanged / # OnXYStagePositionChanged etc. let us skip polling for those events. # Whether the ASI Tiger adapter fires on joystick moves is unknown — # the bridge logs every callback and broadcasts to a dedicated SSE # stream so it can be tested empirically. 
- self._mm_callback_bridge = None # MMEventCallback subclass - self._mm_callback_loop: Optional[asyncio.AbstractEventLoop] = None - self._callback_subscribers: List[asyncio.Queue] = [] + self._mm_callback_bridge = None # MMEventCallback subclass + self._mm_callback_loop: asyncio.AbstractEventLoop | None = None + self._callback_subscribers: list[asyncio.Queue] = [] # Debounce timer for state-stream broadcasts triggered by callbacks. # A flurry of OnPropertyChanged events (e.g. during config-group load) # gets coalesced into a single broadcast ~50 ms later. - self._callback_broadcast_handle: Optional[asyncio.Handle] = None + self._callback_broadcast_handle: asyncio.Handle | None = None self._callback_broadcast_debounce_sec: float = 0.05 # Bottom-camera live stream (Phase-1 thumbnail). Off by default; the @@ -161,22 +162,29 @@ def __init__( # so the camera is never grabbed when nobody is watching. # Tuned for low latency: small thumbnail + cheap auto-contrast keeps # the encode path under ~5 ms per frame on the encoding thread. - self._cam_subscribers: List[asyncio.Queue] = [] - self._cam_task: Optional[asyncio.Task] = None - self._cam_interval_sec: float = 0.25 # 4 Hz - self._cam_target_max_dim: int = 360 # ~360px thumbnail + self._cam_subscribers: list[asyncio.Queue] = [] + self._cam_task: asyncio.Task | None = None + self._cam_interval_sec: float = 0.25 # 4 Hz + self._cam_target_max_dim: int = 360 # ~360px thumbnail self._cam_jpeg_quality: int = 55 # Plans that hold MMCore for long performance-critical work. # Anything in this set runs with state polling paused. 
- self._heavy_plans = frozenset({ - 'acquire_single_volume_plan', - 'burst_plan', - 'timelapse_volume_plan', - 'focus_sweep_plan', - 'calibrate_piezo_galvo_plan', - 'multi_embryo_calibration_workflow', - }) + self._heavy_plans = frozenset( + { + "acquire_single_volume_plan", + "burst_plan", + "timelapse_volume_plan", + "focus_sweep_plan", + "calibrate_piezo_galvo_plan", + "multi_embryo_calibration_workflow", + # Slow F-drive traverse (25000 -> ~50 um) + fine focus scan + the + # XY view-registration loop all hold MMCore for a while. + "spim_head_focus_descent_plan", + "register_views_xy_plan", + "spim_head_focus_and_align_plan", + } + ) async def initialize(self): """Initialize hardware and RunEngine""" @@ -184,15 +192,21 @@ async def initialize(self): logger.info("GENTLY DEVICE LAYER") logger.info("=" * 60) + cui.out() + cui.note("Starting device layer...", "bold") + # [1/5] Load config + cui.step(1, 5, "Loading configuration") logger.info("[1/5] Loading configuration...") - with open(self.config_path, 'r') as f: + with open(self.config_path) as f: self.config = yaml.safe_load(f) logger.info("Config loaded from %s", self.config_path) + cui.step_done(str(self.config_path)) # [2/5] MMCore initialization, routed through the DiSPIMSystem facade # so this process never touches `core.*` directly outside the # devices/ package. 
+ cui.step(2, 5, "Initializing Micro-Manager core") logger.info("[2/5] Initializing Micro-Manager Core (direct)...") from .devices.system import DiSPIMSystem @@ -200,12 +214,12 @@ async def initialize(self): self.system.enable_stderr_log(True) # Add MM directory to PATH for device adapters - mm_directory = self.config.get('mmdirectory', 'C:/Program Files/Micro-Manager-1.4') + mm_directory = self.config.get("mmdirectory", "C:/Program Files/Micro-Manager-1.4") os.environ["PATH"] += os.pathsep + mm_directory self.system.set_device_adapter_search_paths([mm_directory]) # Load system configuration - mm_config = self.config.get('mmconfig', 'MMConfig.cfg') + mm_config = self.config.get("mmconfig", "MMConfig.cfg") mm_config_path = os.path.join(mm_directory, mm_config) if not os.path.exists(mm_config_path): # Try config.yml directory @@ -215,6 +229,7 @@ async def initialize(self): self.system.load_system_configuration(mm_config_path) logger.info("MMCore initialized (direct, in-process)") logger.info("Loaded devices: %s", self.system.get_loaded_devices()) + cui.step_done(Path(mm_config_path).name) # Register MMCore event callback so we get push notifications for # property changes, stage moves, exposure changes, etc. — anything the @@ -223,10 +238,13 @@ async def initialize(self): self._register_mmcore_callbacks() # [3/5] Create Ophyd devices + cui.step(3, 5, "Creating devices") logger.info("[3/5] Creating Ophyd devices...") - from .device_factory import create_devices_from_mmcore # Suppress rich console output to avoid Unicode issues on Windows import io + + from .device_factory import create_devices_from_mmcore + old_stdout = sys.stdout sys.stdout = io.StringIO() try: @@ -237,6 +255,46 @@ async def initialize(self): for name in self.devices: logger.debug(" - %s", name) + # Optional BLE accessory: SwitchBot Bot. It's a Bluetooth device, not a + # Micro-Manager adapter, so it's created here (independently of MMCore) + # and added to the same registry. 
Plans address it by name, e.g. + # bps.mv(switchbot, 'on'). Config-gated: no `switchbot:` section => no-op. + sb_cfg = self.config.get("switchbot") + if sb_cfg: + try: + from gently.hardware.switchbot import SwitchBot + + sb_name = sb_cfg.get("name", "switchbot") + self.devices[sb_name] = SwitchBot( + address=sb_cfg["address"], + name=sb_name, + timeout=sb_cfg.get("timeout", 20.0), + ) + logger.info("Created SwitchBot '%s' at %s", sb_name, sb_cfg["address"]) + except Exception as exc: + logger.warning("Could not create SwitchBot: %s", exc) + + # Optional temperature controller (ACUITYnano). Like the SwitchBot it's + # not an MMCore adapter — created here from config and added to the same + # registry. Plans block on it via bps.mv(temperature, 20.0) until the + # controller reports SYSTEM LOCKED. Config-gated: no `temperature:` => no-op. + temp_cfg = self.config.get("temperature") + if temp_cfg: + try: + from gently.hardware.temperature import create_temperature_controller + + tc = create_temperature_controller(temp_cfg) + self.devices[tc.name] = tc + logger.info( + "Created temperature controller '%s' (backend=%s)", + tc.name, + temp_cfg.get("backend", "serial"), + ) + except Exception as exc: + logger.warning("Could not create temperature controller: %s", exc) + + cui.step_done(f"{len(self.devices)} devices") + # Push XY safety bounds down to the ASI Tiger firmware so the joystick # can't drive past Layer-1 software limits. The XY_STAGE_*_UM constants # in devices/stage.py are the single source of truth — both the @@ -245,12 +303,15 @@ async def initialize(self): # outside the requested envelope — operator must drive into bounds # first. We do NOT SaveCardSettings so a code-side limit change always # wins on next device-layer restart (config-as-code). 
- xy_stage = self.devices.get('xy_stage') + xy_stage = self.devices.get("xy_stage") if xy_stage is not None: from .devices.stage import ( - XY_STAGE_X_MIN_UM, XY_STAGE_X_MAX_UM, - XY_STAGE_Y_MIN_UM, XY_STAGE_Y_MAX_UM, + XY_STAGE_X_MAX_UM, + XY_STAGE_X_MIN_UM, + XY_STAGE_Y_MAX_UM, + XY_STAGE_Y_MIN_UM, ) + try: xy_stage.set_firmware_limits( x_min_mm=XY_STAGE_X_MIN_UM / 1000.0, @@ -259,10 +320,11 @@ async def initialize(self): y_max_mm=XY_STAGE_Y_MAX_UM / 1000.0, ) logger.info( - "ASI Tiger firmware soft limits applied: " - "X=[%.2f, %.2f] µm, Y=[%.2f, %.2f] µm", - XY_STAGE_X_MIN_UM, XY_STAGE_X_MAX_UM, - XY_STAGE_Y_MIN_UM, XY_STAGE_Y_MAX_UM, + "ASI Tiger firmware soft limits applied: X=[%.2f, %.2f] µm, Y=[%.2f, %.2f] µm", + XY_STAGE_X_MIN_UM, + XY_STAGE_X_MAX_UM, + XY_STAGE_Y_MIN_UM, + XY_STAGE_Y_MAX_UM, ) except ValueError as exc: # Current position is outside the envelope — refuse to start @@ -274,7 +336,20 @@ async def initialize(self): logger.error("Could not apply ASI firmware soft limits: %s", exc) raise + # Tiger persists JoystickEnabled in non-volatile card settings — + # if a prior session ever called SaveCardSettings with the + # joystick off, every subsequent boot inherits that state and the + # physical controller is dead. Force it on at boot so the + # operator's joystick always works regardless of card history. + try: + xy_stage.enable_joystick(True) + except Exception as exc: + # Not fatal — the agent can still drive the stage. Log loudly + # so the operator knows the joystick is unavailable. 
+ logger.error("Could not enable XY joystick: %s", exc) + # [4/5] Initialize RunEngine + cui.step(4, 5, "Initializing RunEngine") logger.info("[4/5] Initializing RunEngine...") self.RE = RunEngine({}) @@ -296,6 +371,7 @@ def serialize_value(v): # Large array + staging dir configured -> file ref if self._volume_dir and v.nbytes > 1_000_000: import uuid + try: import tifffile except ImportError: @@ -324,22 +400,30 @@ def collect_docs(name, doc): # Serialize the document to handle numpy arrays serialized_doc = serialize_value(dict(doc)) - if name == 'start': - self._last_documents = {'start': serialized_doc, 'descriptors': [], 'events': [], 'stop': None} - elif name == 'descriptor': - self._last_documents['descriptors'].append(serialized_doc) - elif name == 'event': - self._last_documents['events'].append(serialized_doc) - elif name == 'stop': - self._last_documents['stop'] = serialized_doc + if name == "start": + self._last_documents = { + "start": serialized_doc, + "descriptors": [], + "events": [], + "stop": None, + } + elif name == "descriptor": + self._last_documents["descriptors"].append(serialized_doc) + elif name == "event": + self._last_documents["events"].append(serialized_doc) + elif name == "stop": + self._last_documents["stop"] = serialized_doc self._run_history.append(self._last_documents.copy()) self.RE.subscribe(collect_docs) logger.info("RunEngine ready") + cui.step_done("ready") # [5/5] Load plans + cui.step(5, 5, "Loading plans") logger.info("[5/5] Loading plans...") self._load_plans() + cui.step_done(f"{len(self.plans)} plans") logger.info("=" * 60) logger.info("Device layer initialized successfully") @@ -349,27 +433,28 @@ def _load_plans(self): """Load available plans""" try: from .plans.acquisition import ( - move_stage_plan, - read_stage_plan, capture_bottom_image_plan, capture_lightsheet_image_plan, + get_light_source_power_plan, move_piezo_plan, move_scanner_plan, - set_led_plan, + move_stage_plan, + read_stage_plan, set_laser_plan, + 
set_led_plan, set_light_source_power_plan, - get_light_source_power_plan, ) - self.plans['move_stage_plan'] = move_stage_plan - self.plans['read_stage_plan'] = read_stage_plan - self.plans['capture_bottom_image_plan'] = capture_bottom_image_plan - self.plans['capture_lightsheet_image_plan'] = capture_lightsheet_image_plan - self.plans['move_piezo_plan'] = move_piezo_plan - self.plans['move_scanner_plan'] = move_scanner_plan - self.plans['set_led_plan'] = set_led_plan - self.plans['set_laser_plan'] = set_laser_plan - self.plans['set_light_source_power_plan'] = set_light_source_power_plan - self.plans['get_light_source_power_plan'] = get_light_source_power_plan + + self.plans["move_stage_plan"] = move_stage_plan + self.plans["read_stage_plan"] = read_stage_plan + self.plans["capture_bottom_image_plan"] = capture_bottom_image_plan + self.plans["capture_lightsheet_image_plan"] = capture_lightsheet_image_plan + self.plans["move_piezo_plan"] = move_piezo_plan + self.plans["move_scanner_plan"] = move_scanner_plan + self.plans["set_led_plan"] = set_led_plan + self.plans["set_laser_plan"] = set_laser_plan + self.plans["set_light_source_power_plan"] = set_light_source_power_plan + self.plans["get_light_source_power_plan"] = get_light_source_power_plan logger.info("Loaded %d plans", len(self.plans)) except ImportError as e: logger.warning("Could not load some plans: %s", e) @@ -377,18 +462,34 @@ def _load_plans(self): # Also load main plans if available try: from .plans.acquisition import ( - calibrate_piezo_galvo_plan, acquire_single_volume_plan, burst_plan, + calibrate_piezo_galvo_plan, ) - self.plans['calibrate_piezo_galvo_plan'] = calibrate_piezo_galvo_plan - self.plans['acquire_single_volume_plan'] = acquire_single_volume_plan - self.plans['burst_plan'] = burst_plan + + self.plans["calibrate_piezo_galvo_plan"] = calibrate_piezo_galvo_plan + self.plans["acquire_single_volume_plan"] = acquire_single_volume_plan + self.plans["burst_plan"] = burst_plan logger.info("Loaded 
main acquisition plans") except ImportError: logger.info("Main acquisition plans not available") - def _resolve_device_args(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + # SPIM head focus + dual-view registration plans + try: + from .plans.acquisition import ( + register_views_xy_plan, + spim_head_focus_and_align_plan, + spim_head_focus_descent_plan, + ) + + self.plans["spim_head_focus_descent_plan"] = spim_head_focus_descent_plan + self.plans["register_views_xy_plan"] = register_views_xy_plan + self.plans["spim_head_focus_and_align_plan"] = spim_head_focus_and_align_plan + logger.info("Loaded SPIM head focus plans") + except ImportError: + logger.info("SPIM head focus plans not available") + + def _resolve_device_args(self, kwargs: dict[str, Any]) -> dict[str, Any]: """Replace device name strings with actual device objects""" resolved = {} for key, value in kwargs.items(): @@ -409,13 +510,17 @@ def _get_sam_detector(self): keeping server startup fast. """ if self._sam_detector is None: - logger.info("Loading SAM detector (%s on %s)...", self._sam_model_type, self._sam_device) + logger.info( + "Loading SAM detector (%s on %s)...", + self._sam_model_type, + self._sam_device, + ) from .sam_detection import SAMEmbryoDetector self._sam_detector = SAMEmbryoDetector( sam_checkpoint=self._sam_checkpoint, sam_model_type=self._sam_model_type, - device=self._sam_device + device=self._sam_device, ) logger.info("SAM detector ready") @@ -441,7 +546,7 @@ async def pause_state_updates(self): if self._state_pause_counter == 0: self._state_latest["paused"] = False - def _read_xy_position(self) -> Dict[str, Any]: + def _read_xy_position(self) -> dict[str, Any]: """Read just the XY stage via the ophyd device's ``read()``. Two ASI serial round-trips (~250 ms) — the cost is in the underlying @@ -449,12 +554,12 @@ def _read_xy_position(self) -> Dict[str, Any]: through ``xy_stage.read()`` so the only place that touches MMCore is inside the ophyd device class. 
""" - out: Dict[str, Any] = {} - xy = self.devices.get('xy_stage') + out: dict[str, Any] = {} + xy = self.devices.get("xy_stage") if xy is not None: try: data = xy.read() - value = data[xy.name]['value'] + value = data[xy.name]["value"] out[xy.name] = { "X": float(value[0]), "Y": float(value[1]), @@ -464,30 +569,30 @@ def _read_xy_position(self) -> Dict[str, Any]: logger.debug("XY position read failed: %s", exc) return out - def _read_slow_positions(self) -> Dict[str, Any]: + def _read_slow_positions(self) -> dict[str, Any]: """Read piezo + galvo via their ophyd ``read()`` methods. These rarely change on their own — piezo by Z-knob or commands, galvo only programmatically — so a 1 Hz cadence is plenty. """ - out: Dict[str, Any] = {} + out: dict[str, Any] = {} - piezo = self.devices.get('piezo') + piezo = self.devices.get("piezo") if piezo is not None: try: data = piezo.read() out[piezo.name] = { - "Position": float(data[piezo.name]['value']), + "Position": float(data[piezo.name]["value"]), "kind": "piezo", } except Exception as exc: logger.debug("Piezo position read failed: %s", exc) - scanner = self.devices.get('scanner') + scanner = self.devices.get("scanner") if scanner is not None: try: data = scanner.read() - value = data[scanner.name]['value'] + value = data[scanner.name]["value"] out[scanner.name] = { "A": float(value[0]), "B": float(value[1]), @@ -498,7 +603,7 @@ def _read_slow_positions(self) -> Dict[str, Any]: return out - def _read_full_state(self) -> Dict[str, Dict[str, str]]: + def _read_full_state(self) -> dict[str, dict[str, str]]: """Snapshot every property of every loaded MM device via the system state cache. 
`update_system_state_cache()` rereads from hardware, then @@ -514,7 +619,7 @@ def _read_full_state(self) -> Dict[str, Dict[str, str]]: logger.debug("System state cache read failed: %s", exc) return {} - by_device: Dict[str, Dict[str, str]] = {} + by_device: dict[str, dict[str, str]] = {} try: size = cfg.size() except Exception: @@ -561,12 +666,14 @@ async def _position_poller(self): if self._state_pause_counter > 0: now = time.time() if now - last_heartbeat > 2.0: - await self._broadcast_state({ - **self._state_latest, - "t": now, - "paused": True, - "heartbeat": True, - }) + await self._broadcast_state( + { + **self._state_latest, + "t": now, + "paused": True, + "heartbeat": True, + } + ) last_heartbeat = now await asyncio.sleep(self._state_pos_interval_sec) continue @@ -580,7 +687,8 @@ async def _position_poller(self): if read_elapsed > 0.4: logger.warning( "XY position read slow: %.2fs (target<%.2fs)", - read_elapsed, self._state_pos_interval_sec, + read_elapsed, + self._state_pos_interval_sec, ) now = time.time() @@ -631,7 +739,8 @@ async def _slow_positions_poller(self): read_elapsed = time.monotonic() - read_start if read_elapsed > 0.6: logger.warning( - "Slow-positions read slow: %.2fs", read_elapsed, + "Slow-positions read slow: %.2fs", + read_elapsed, ) # Merge — don't clobber XY entries the fast poller maintains. @@ -679,7 +788,8 @@ async def _property_poller(self): read_elapsed = time.monotonic() - read_start if read_elapsed > 1.0: logger.warning( - "Property read slow: %.2fs", read_elapsed, + "Property read slow: %.2fs", + read_elapsed, ) self._state_latest = { @@ -698,11 +808,11 @@ async def _property_poller(self): elapsed = time.monotonic() - tick_start await asyncio.sleep(max(0.0, self._state_prop_interval_sec - elapsed)) - async def _broadcast_state(self, payload: Dict[str, Any]): + async def _broadcast_state(self, payload: dict[str, Any]): """Push a state payload to every SSE subscriber. 
Drop slow clients.""" if not self._state_subscribers: return - dead: List[asyncio.Queue] = [] + dead: list[asyncio.Queue] = [] for q in self._state_subscribers: try: q.put_nowait(payload) @@ -719,14 +829,14 @@ async def _broadcast_state(self, payload: Dict[str, Any]): # Bottom-camera live stream (Phase 1: low-rate thumbnail) # ========================================================================= - def _capture_bottom_frame_sync(self) -> Optional[np.ndarray]: + def _capture_bottom_frame_sync(self) -> np.ndarray | None: """Grab a single frame via the ophyd device's synchronous ``snap()``. Blocking — call via ``asyncio.to_thread``. All MMCore traffic happens inside ``DiSPIMCamera.snap()``; the streamer holds no direct core handle. """ - cam = self.devices.get('bottom_camera') + cam = self.devices.get("bottom_camera") if cam is None: return None try: @@ -735,7 +845,7 @@ def _capture_bottom_frame_sync(self) -> Optional[np.ndarray]: logger.debug("Bottom-camera grab failed: %s", exc) return None - def _encode_frame_for_stream(self, img: np.ndarray) -> Optional[Dict[str, Any]]: + def _encode_frame_for_stream(self, img: np.ndarray) -> dict[str, Any] | None: """Downsample + auto-contrast + JPEG-encode a uint16 frame for SSE. 
Optimised for streaming throughput: @@ -748,8 +858,9 @@ def _encode_frame_for_stream(self, img: np.ndarray) -> Optional[Dict[str, Any]]: if img is None or img.size == 0: return None try: - import cv2 # opencv ships with the agent env (SAM uses it) import base64 + + import cv2 # opencv ships with the agent env (SAM uses it) except ImportError as exc: logger.warning("Cannot encode frame — OpenCV unavailable: %s", exc) return None @@ -778,10 +889,10 @@ def _encode_frame_for_stream(self, img: np.ndarray) -> Optional[Dict[str, Any]]: scale = 255.0 / (hi - lo) small = np.clip((small.astype(np.float32) - lo) * scale, 0, 255).astype(np.uint8) - ok, jpeg = cv2.imencode('.jpg', small, [cv2.IMWRITE_JPEG_QUALITY, self._cam_jpeg_quality]) + ok, jpeg = cv2.imencode(".jpg", small, [cv2.IMWRITE_JPEG_QUALITY, self._cam_jpeg_quality]) if not ok: return None - b64 = base64.b64encode(jpeg.tobytes()).decode('ascii') + b64 = base64.b64encode(jpeg.tobytes()).decode("ascii") return { "t": time.time(), "shape": [int(small.shape[0]), int(small.shape[1])], @@ -821,10 +932,10 @@ async def _bottom_camera_streamer(self): finally: logger.info("Bottom-camera streamer exiting") - async def _broadcast_camera(self, payload: Dict[str, Any]): + async def _broadcast_camera(self, payload: dict[str, Any]): if not self._cam_subscribers: return - dead: List[asyncio.Queue] = [] + dead: list[asyncio.Queue] = [] for q in self._cam_subscribers: try: q.put_nowait(payload) @@ -866,7 +977,11 @@ def _register_mmcore_callbacks(self): class _Bridge(pymmcore.MMEventCallback): def _emit(self, kind: str, **payload): payload = {"t": time.time(), "kind": kind, **payload} - logger.info("MMCore callback: %s %s", kind, {k: v for k, v in payload.items() if k != "t"}) + logger.info( + "MMCore callback: %s %s", + kind, + {k: v for k, v in payload.items() if k != "t"}, + ) loop = outer._mm_callback_loop if loop is None or loop.is_closed(): return @@ -900,8 +1015,7 @@ def onPixelSizeChanged(self, new_pixel_size_um): 
self._emit("pixel_size_changed", um=new_pixel_size_um) def onPixelSizeAffineChanged(self, v0, v1, v2, v3, v4, v5): - self._emit("pixel_size_affine_changed", - affine=[v0, v1, v2, v3, v4, v5]) + self._emit("pixel_size_affine_changed", affine=[v0, v1, v2, v3, v4, v5]) def onSystemConfigurationLoaded(self): self._emit("system_configuration_loaded") @@ -911,7 +1025,7 @@ def onSystemConfigurationLoaded(self): self.system.register_callback(self._mm_callback_bridge) logger.info("MMCore callback bridge registered") - def _enqueue_callback(self, payload: Dict[str, Any]): + def _enqueue_callback(self, payload: dict[str, Any]): """Runs on the asyncio loop (via call_soon_threadsafe). Two jobs: forward to /api/devices/callbacks/stream subscribers (for @@ -921,7 +1035,7 @@ def _enqueue_callback(self, payload: Dict[str, Any]): """ # 1. Forward to the diagnostic callback stream. if self._callback_subscribers: - dead: List[asyncio.Queue] = [] + dead: list[asyncio.Queue] = [] for q in self._callback_subscribers: try: q.put_nowait(payload) @@ -938,7 +1052,7 @@ def _enqueue_callback(self, payload: Dict[str, Any]): if self._apply_callback_to_state(payload): self._schedule_callback_broadcast() - def _apply_callback_to_state(self, payload: Dict[str, Any]) -> bool: + def _apply_callback_to_state(self, payload: dict[str, Any]) -> bool: """Translate a callback payload into a `_state_latest` mutation. 
Returns True iff something visible changed (the caller will then @@ -1037,23 +1151,24 @@ async def _plan_executor(self): while self._running: try: # Wait for a plan request - request = await asyncio.wait_for( - self._plan_queue.get(), - timeout=1.0 - ) + request = await asyncio.wait_for(self._plan_queue.get(), timeout=1.0) except asyncio.TimeoutError: continue # Log execution start with timestamp start_time = datetime.now() execution_record = { - 'plan_name': request.plan_name, - 'kwargs': {k: str(v) for k, v in request.kwargs.items()}, # Stringify for JSON - 'start_time': start_time.isoformat(), - 'start_time_formatted': start_time.strftime('%H:%M:%S.%f')[:-3], + "plan_name": request.plan_name, + "kwargs": {k: str(v) for k, v in request.kwargs.items()}, # Stringify for JSON + "start_time": start_time.isoformat(), + "start_time_formatted": start_time.strftime("%H:%M:%S.%f")[:-3], } - logger.info(">>> [%s] Executing: %s", start_time.strftime('%H:%M:%S'), request.plan_name) + logger.info( + ">>> [%s] Executing: %s", + start_time.strftime("%H:%M:%S"), + request.plan_name, + ) # Reset documents before each plan so stale results from # a previous plan (e.g. 
volume file refs from acquire) don't @@ -1091,37 +1206,52 @@ async def _plan_executor(self): end_time = datetime.now() duration_ms = (end_time - start_time).total_seconds() * 1000 - execution_record.update({ - 'end_time': end_time.isoformat(), - 'end_time_formatted': end_time.strftime('%H:%M:%S.%f')[:-3], - 'duration_ms': duration_ms, - 'success': True, - 'uid': uid, - }) + execution_record.update( + { + "end_time": end_time.isoformat(), + "end_time_formatted": end_time.strftime("%H:%M:%S.%f")[:-3], + "duration_ms": duration_ms, + "success": True, + "uid": uid, + } + ) - logger.info("<<< [%s] Complete: %s (%.0fms)", end_time.strftime('%H:%M:%S'), request.plan_name, duration_ms) + logger.info( + "<<< [%s] Complete: %s (%.0fms)", + end_time.strftime("%H:%M:%S"), + request.plan_name, + duration_ms, + ) # Complete the future with result - request.future.set_result({ - 'success': True, - 'uid': uid, - 'documents': self._last_documents.copy() - }) + request.future.set_result( + { + "success": True, + "uid": uid, + "documents": self._last_documents.copy(), + } + ) except Exception as e: - import traceback end_time = datetime.now() duration_ms = (end_time - start_time).total_seconds() * 1000 - execution_record.update({ - 'end_time': end_time.isoformat(), - 'end_time_formatted': end_time.strftime('%H:%M:%S.%f')[:-3], - 'duration_ms': duration_ms, - 'success': False, - 'error': str(e), - }) + execution_record.update( + { + "end_time": end_time.isoformat(), + "end_time_formatted": end_time.strftime("%H:%M:%S.%f")[:-3], + "duration_ms": duration_ms, + "success": False, + "error": str(e), + } + ) - logger.error("<<< [%s] Failed: %s - %s", end_time.strftime('%H:%M:%S'), request.plan_name, e) + logger.error( + "<<< [%s] Failed: %s - %s", + end_time.strftime("%H:%M:%S"), + request.plan_name, + e, + ) request.future.set_exception(e) # Store execution record @@ -1130,17 +1260,13 @@ async def _plan_executor(self): if len(self._plan_execution_log) > 1000: self._plan_execution_log = 
self._plan_execution_log[-1000:] - async def submit_plan(self, plan_name: str, kwargs: Dict = None) -> Dict: + async def submit_plan(self, plan_name: str, kwargs: dict | None = None) -> dict: """Submit a plan and wait for completion""" kwargs = kwargs or {} # Create request with a future loop = asyncio.get_event_loop() - request = PlanRequest( - plan_name=plan_name, - kwargs=kwargs, - future=loop.create_future() - ) + request = PlanRequest(plan_name=plan_name, kwargs=kwargs, future=loop.create_future()) # Add to queue await self._plan_queue.put(request) @@ -1156,12 +1282,12 @@ async def submit_plan(self, plan_name: str, kwargs: Dict = None) -> Dict: async def handle_status(self, request): """GET /api/status""" status = { - 'manager_state': 'idle' if self._running else 'stopped', - 're_state': 'idle', - 'devices': list(self.devices.keys()), - 'plans': list(self.plans.keys()), - 'queue_size': self._plan_queue.qsize(), - 'sam_loaded': self._sam_detector is not None, + "manager_state": "idle" if self._running else "stopped", + "re_state": "idle", + "devices": list(self.devices.keys()), + "plans": list(self.plans.keys()), + "queue_size": self._plan_queue.qsize(), + "sam_loaded": self._sam_detector is not None, } return web.json_response(status) @@ -1169,13 +1295,12 @@ async def handle_submit_plan(self, request): """POST /api/queue/item/add""" try: data = await request.json() - plan_name = data.get('item', {}).get('name') - kwargs = data.get('item', {}).get('kwargs', {}) + plan_name = data.get("item", {}).get("name") + kwargs = data.get("item", {}).get("kwargs", {}) if not plan_name: return web.json_response( - {'success': False, 'error': 'No plan name provided'}, - status=400 + {"success": False, "error": "No plan name provided"}, status=400 ) result = await self.submit_plan(plan_name, kwargs) @@ -1183,162 +1308,378 @@ async def handle_submit_plan(self, request): except Exception as e: import traceback - return web.json_response({ - 'success': False, - 'error': str(e), 
- 'traceback': traceback.format_exc() - }, status=500) + + return web.json_response( + { + "success": False, + "error": str(e), + "traceback": traceback.format_exc(), + }, + status=500, + ) async def handle_get_history(self, request): """GET /api/history""" - return web.json_response({ - 'success': True, - 'items': self._run_history[-10:] # Last 10 runs - }) + return web.json_response( + { + "success": True, + "items": self._run_history[-10:], # Last 10 runs + } + ) async def handle_get_devices(self, request): """GET /api/devices""" - return web.json_response({ - 'success': True, - 'devices': list(self.devices.keys()) - }) + return web.json_response({"success": True, "devices": list(self.devices.keys())}) async def handle_get_plans(self, request): """GET /api/plans""" - return web.json_response({ - 'success': True, - 'plans': list(self.plans.keys()) - }) + return web.json_response({"success": True, "plans": list(self.plans.keys())}) async def handle_get_led_status(self, request): """GET /api/led/status - Get current LED state and available configs""" try: - led = self.devices.get('led') + led = self.devices.get("led") if led is None: - return web.json_response({ - 'success': False, - 'error': 'LED device not found' - }) + return web.json_response({"success": False, "error": "LED device not found"}) # Read current state current_state = led.read() - led_value = current_state.get(led.name, {}).get('value', 'unknown') + led_value = current_state.get(led.name, {}).get("value", "unknown") # Get available configs available_configs = led._available_configs - return web.json_response({ - 'success': True, - 'current_state': led_value, - 'available_configs': available_configs, - 'group_name': led.group_name - }) + return web.json_response( + { + "success": True, + "current_state": led_value, + "available_configs": available_configs, + "group_name": led.group_name, + } + ) except Exception as e: import traceback - return web.json_response({ - 'success': False, - 'error': 
str(e), - 'traceback': traceback.format_exc() - }, status=500) + + return web.json_response( + { + "success": False, + "error": str(e), + "traceback": traceback.format_exc(), + }, + status=500, + ) async def handle_set_led(self, request): """POST /api/led/set - Set LED state directly (bypass plan queue)""" try: data = await request.json() - state = data.get('state', 'Closed') + state = data.get("state", "Closed") - led = self.devices.get('led') + led = self.devices.get("led") if led is None: - return web.json_response({ - 'success': False, - 'error': 'LED device not found' - }) + return web.json_response({"success": False, "error": "LED device not found"}) # Set LED state directly status = led.set(state) # Wait for completion import time + timeout = 5.0 start = time.time() while not status.done and (time.time() - start) < timeout: await asyncio.sleep(0.1) if status.done and status.success: - return web.json_response({ - 'success': True, - 'state': state, - 'message': f'LED set to {state}' - }) + return web.json_response( + {"success": True, "state": state, "message": f"LED set to {state}"} + ) else: - return web.json_response({ - 'success': False, - 'error': f'Failed to set LED to {state}' - }) + return web.json_response( + {"success": False, "error": f"Failed to set LED to {state}"} + ) + except Exception as e: + import traceback + + return web.json_response( + { + "success": False, + "error": str(e), + "traceback": traceback.format_exc(), + }, + status=500, + ) + + def _room_light_device(self): + """Resolve the room-light SwitchBot from the device registry. + + Prefers the conventional 'room_light' key (config.yml name), but + falls back to scanning for any SwitchBot instance so a differently + named bot still works. Returns None when no bot is configured. 
+ """ + bot = self.devices.get("room_light") + if bot is not None: + return bot + try: + from gently.hardware.switchbot import SwitchBot + except Exception: + return None + for dev in self.devices.values(): + if isinstance(dev, SwitchBot): + return dev + return None + + async def handle_get_room_light_status(self, request): + """GET /api/room_light/status - cached on/off state of the room light. + + Reads the SwitchBot's last-commanded state (no BLE round-trip, so it's + cheap to poll). 'unknown' until the first on/off command lands. + """ + try: + bot = self._room_light_device() + if bot is None: + return web.json_response( + { + "success": False, + "available": False, + "error": "room_light device not configured", + } + ) + state = bot.read().get(bot.name, {}).get("value", "unknown") + return web.json_response({"success": True, "available": True, "state": state}) + except Exception as e: + import traceback + + return web.json_response( + { + "success": False, + "available": False, + "error": str(e), + "traceback": traceback.format_exc(), + }, + status=500, + ) + + async def handle_set_room_light(self, request): + """POST /api/room_light/set - drive the room-light SwitchBot. + + Body: {"state": "on" | "off" | "press"}. Blocks until the BLE command + lands (the bot's servo move is ~0.5-1 s plus connect latency). 
+ """ + try: + data = await request.json() + state = str(data.get("state", "")).lower() + if state not in ("on", "off", "press"): + return web.json_response( + { + "success": False, + "error": f"state {state!r} must be on, off, or press", + }, + status=400, + ) + bot = self._room_light_device() + if bot is None: + return web.json_response( + {"success": False, "error": "room_light device not configured"}, + status=503, + ) + + status = bot.set(state) + import time + + timeout = float(getattr(bot, "timeout", 20.0)) + 5 + start = time.time() + while not status.done and (time.time() - start) < timeout: + await asyncio.sleep(0.1) + + if status.done and status.success: + new_state = bot.read().get(bot.name, {}).get("value", state) + return web.json_response({"success": True, "state": new_state}) + return web.json_response( + {"success": False, "error": f"failed to set room light to {state}"}, + status=502, + ) + except Exception as e: + import traceback + + return web.json_response( + { + "success": False, + "error": str(e), + "traceback": traceback.format_exc(), + }, + status=500, + ) + + async def handle_get_temperature_status(self, request): + """GET /api/temperature/status - current temperature, setpoint, lock state.""" + try: + temp = self.devices.get("temperature") + if temp is None: + return web.json_response( + {"success": False, "error": "temperature device not found"} + ) + r = temp.read() + return web.json_response( + { + "success": True, + "temperature_c": r.get(temp.name, {}).get("value"), + "setpoint_c": r.get(f"{temp.name}_setpoint", {}).get("value"), + "state": r.get(f"{temp.name}_state", {}).get("value"), + "peltier_c": r.get(f"{temp.name}_peltier", {}).get("value"), + } + ) except Exception as e: import traceback - return web.json_response({ - 'success': False, - 'error': str(e), - 'traceback': traceback.format_exc() - }, status=500) + + return web.json_response( + { + "success": False, + "error": str(e), + "traceback": traceback.format_exc(), + }, + 
status=500, + ) + + async def handle_set_temperature(self, request): + """POST /api/temperature/set - command setpoint. Body: {target_c, wait?}. + + Non-blocking by default (controller ramps; poll status). wait=true blocks + until SYSTEM LOCKED or the device's stabilize timeout. + """ + try: + data = await request.json() + target = float(data.get("target_c")) + wait = bool(data.get("wait", False)) + if not (0.0 <= target <= 99.9): + return web.json_response( + {"success": False, "error": f"target {target} outside [0.0, 99.9]"} + ) + temp = self.devices.get("temperature") + if temp is None: + return web.json_response( + {"success": False, "error": "temperature device not found"} + ) + + if not wait: + temp.enable(True) + temp.setpoint(target) + r = temp.read() + return web.json_response( + { + "success": True, + "target_c": target, + "waited": False, + "message": f"commanded {target} C (ramping)", + "temperature_c": r.get(temp.name, {}).get("value"), + "state": r.get(f"{temp.name}_state", {}).get("value"), + } + ) + + import time + + status = temp.set(target) + timeout = float(getattr(temp, "stabilize_timeout", 600.0)) + 10 + start = time.time() + while not status.done and (time.time() - start) < timeout: + await asyncio.sleep(0.5) + r = temp.read() + if status.done and status.success: + return web.json_response( + { + "success": True, + "target_c": target, + "waited": True, + "message": f"locked at {target} C", + "temperature_c": r.get(temp.name, {}).get("value"), + "state": r.get(f"{temp.name}_state", {}).get("value"), + } + ) + return web.json_response( + { + "success": False, + "target_c": target, + "error": f"did not stabilize at {target} C within {timeout:.0f}s", + } + ) + except Exception as e: + import traceback + + return web.json_response( + { + "success": False, + "error": str(e), + "traceback": traceback.format_exc(), + }, + status=500, + ) async def handle_set_camera_led_mode(self, request): """POST /api/camera/led_mode - Enable/disable automatic LED 
for bottom camera""" try: data = await request.json() - use_led = data.get('use_led', False) + use_led = data.get("use_led", False) - bottom_camera = self.devices.get('bottom_camera') + bottom_camera = self.devices.get("bottom_camera") if bottom_camera is None: - return web.json_response({ - 'success': False, - 'error': 'Bottom camera device not found' - }) + return web.json_response( + {"success": False, "error": "Bottom camera device not found"} + ) # Set the use_led attribute bottom_camera.use_led = use_led - return web.json_response({ - 'success': True, - 'use_led': use_led, - 'message': f'Bottom camera LED mode: {"ON" if use_led else "OFF"}' - }) + return web.json_response( + { + "success": True, + "use_led": use_led, + "message": f"Bottom camera LED mode: {'ON' if use_led else 'OFF'}", + } + ) except Exception as e: import traceback - return web.json_response({ - 'success': False, - 'error': str(e), - 'traceback': traceback.format_exc() - }, status=500) + + return web.json_response( + { + "success": False, + "error": str(e), + "traceback": traceback.format_exc(), + }, + status=500, + ) async def handle_set_camera_exposure(self, request): """POST /api/camera/exposure - Set bottom camera exposure time""" try: data = await request.json() - exposure_ms = data.get('exposure_ms', 50.0) + exposure_ms = data.get("exposure_ms", 50.0) - bottom_camera = self.devices.get('bottom_camera') + bottom_camera = self.devices.get("bottom_camera") if bottom_camera is None: - return web.json_response({ - 'success': False, - 'error': 'Bottom camera device not found' - }) + return web.json_response( + {"success": False, "error": "Bottom camera device not found"} + ) # Set exposure using the device's configure_exposure method bottom_camera.configure_exposure(exposure_ms) - return web.json_response({ - 'success': True, - 'exposure_ms': exposure_ms, - 'message': f'Bottom camera exposure set to {exposure_ms} ms' - }) + return web.json_response( + { + "success": True, + "exposure_ms": 
exposure_ms, + "message": f"Bottom camera exposure set to {exposure_ms} ms", + } + ) except Exception as e: import traceback - return web.json_response({ - 'success': False, - 'error': str(e), - 'traceback': traceback.format_exc() - }, status=500) + + return web.json_response( + { + "success": False, + "error": str(e), + "traceback": traceback.format_exc(), + }, + status=500, + ) async def handle_set_light_source_power(self, request): """POST /api/light_source/power — set per-line laser power %. @@ -1350,104 +1691,136 @@ async def handle_set_light_source_power(self, request): """ try: data = await request.json() - wavelength = int(data.get('wavelength', 488)) - pct = float(data.get('pct')) - light_source = self.devices.get('light_source') or self.devices.get('laser_control') + wavelength = int(data.get("wavelength", 488)) + pct = float(data.get("pct")) + light_source = self.devices.get("light_source") or self.devices.get("laser_control") if light_source is None: - return web.json_response({ - 'success': False, - 'error': 'Light source device not found', - }, status=503) + return web.json_response( + { + "success": False, + "error": "Light source device not found", + }, + status=503, + ) try: light_source.set_power_pct(wavelength, pct) except (ValueError, KeyError) as e: - return web.json_response({ - 'success': False, 'error': str(e), - }, status=400) + return web.json_response( + { + "success": False, + "error": str(e), + }, + status=400, + ) readback = light_source.get_power_pct(wavelength) - return web.json_response({ - 'success': True, - 'wavelength': wavelength, - 'pct': pct, - 'readback_pct': readback, - }) + return web.json_response( + { + "success": True, + "wavelength": wavelength, + "pct": pct, + "readback_pct": readback, + } + ) except Exception as e: import traceback - return web.json_response({ - 'success': False, - 'error': str(e), - 'traceback': traceback.format_exc(), - }, status=500) + + return web.json_response( + { + "success": False, + "error": 
str(e), + "traceback": traceback.format_exc(), + }, + status=500, + ) async def handle_get_light_source_power(self, request): """GET /api/light_source/power?wavelength=488 — read laser power %.""" try: - wavelength = int(request.query.get('wavelength', 488)) - light_source = self.devices.get('light_source') or self.devices.get('laser_control') + wavelength = int(request.query.get("wavelength", 488)) + light_source = self.devices.get("light_source") or self.devices.get("laser_control") if light_source is None: - return web.json_response({ - 'success': False, - 'error': 'Light source device not found', - }, status=503) + return web.json_response( + { + "success": False, + "error": "Light source device not found", + }, + status=503, + ) try: pct = light_source.get_power_pct(wavelength) except KeyError as e: - return web.json_response({ - 'success': False, 'error': str(e), - }, status=400) - return web.json_response({ - 'success': True, - 'wavelength': wavelength, - 'pct': float(pct), - }) + return web.json_response( + { + "success": False, + "error": str(e), + }, + status=400, + ) + return web.json_response( + { + "success": True, + "wavelength": wavelength, + "pct": float(pct), + } + ) except Exception as e: import traceback - return web.json_response({ - 'success': False, - 'error': str(e), - 'traceback': traceback.format_exc(), - }, status=500) + + return web.json_response( + { + "success": False, + "error": str(e), + "traceback": traceback.format_exc(), + }, + status=500, + ) async def handle_get_camera_exposure(self, request): """GET /api/camera/exposure - Get bottom camera exposure time""" try: - bottom_camera = self.devices.get('bottom_camera') + bottom_camera = self.devices.get("bottom_camera") if bottom_camera is None: - return web.json_response({ - 'success': False, - 'error': 'Bottom camera device not found' - }) + return web.json_response( + {"success": False, "error": "Bottom camera device not found"} + ) exposure_ms = bottom_camera.exposure_time - return 
web.json_response({ - 'success': True, - 'exposure_ms': exposure_ms - }) + return web.json_response({"success": True, "exposure_ms": exposure_ms}) except Exception as e: import traceback - return web.json_response({ - 'success': False, - 'error': str(e), - 'traceback': traceback.format_exc() - }, status=500) + + return web.json_response( + { + "success": False, + "error": str(e), + "traceback": traceback.format_exc(), + }, + status=500, + ) async def handle_get_plan_log(self, request): """GET /api/plan_log - Get recent plan execution log with timing""" try: - limit = int(request.query.get('limit', 100)) - return web.json_response({ - 'success': True, - 'entries': self._plan_execution_log[-limit:], - 'total_count': len(self._plan_execution_log), - }) + limit = int(request.query.get("limit", 100)) + return web.json_response( + { + "success": True, + "entries": self._plan_execution_log[-limit:], + "total_count": len(self._plan_execution_log), + } + ) except Exception as e: import traceback - return web.json_response({ - 'success': False, - 'error': str(e), - 'traceback': traceback.format_exc() - }, status=500) + + return web.json_response( + { + "success": False, + "error": str(e), + "traceback": traceback.format_exc(), + }, + status=500, + ) async def handle_session_configure(self, request): """POST /session/configure - set staging directory for file-ref protocol. 
@@ -1462,25 +1835,33 @@ async def handle_session_configure(self, request): Path(volume_dir).mkdir(parents=True, exist_ok=True) self._volume_dir = volume_dir logger.info("Session configured: volume_dir = %s", volume_dir) - return web.json_response({ - "success": True, - "volume_dir": volume_dir, - }) + return web.json_response( + { + "success": True, + "volume_dir": volume_dir, + } + ) else: # Clear staging self._volume_dir = None - return web.json_response({ - "success": True, - "volume_dir": None, - "message": "Volume staging disabled", - }) + return web.json_response( + { + "success": True, + "volume_dir": None, + "message": "Volume staging disabled", + } + ) except Exception as e: import traceback - return web.json_response({ - "success": False, - "error": str(e), - "traceback": traceback.format_exc(), - }, status=500) + + return web.json_response( + { + "success": False, + "error": str(e), + "traceback": traceback.format_exc(), + }, + status=500, + ) # ========================================================================= # HTTP API Handlers - SAM Detection @@ -1488,13 +1869,15 @@ async def handle_session_configure(self, request): async def handle_sam_status(self, request): """GET /api/sam/status - Check SAM model availability""" - return web.json_response({ - 'success': True, - 'available': True, # SAM is always available (lazy loaded) - 'loaded': self._sam_detector is not None, - 'device': self._sam_device, - 'model_type': self._sam_model_type, - }) + return web.json_response( + { + "success": True, + "available": True, # SAM is always available (lazy loaded) + "loaded": self._sam_detector is not None, + "device": self._sam_device, + "model_type": self._sam_model_type, + } + ) async def handle_detect_embryos(self, request): """POST /api/detect_embryos - Capture image and detect embryos. 
@@ -1528,76 +1911,81 @@ async def handle_detect_embryos(self, request): data = await request.json() # Extract parameters with defaults - from gently.core.coordinates import DEFAULT_PIXEL_SIZE_UM, DEFAULT_OBJECTIVE_MAG + from gently.core.coordinates import ( + DEFAULT_OBJECTIVE_MAG, + DEFAULT_PIXEL_SIZE_UM, + ) - pixel_size_um = data.get('pixel_size_um', DEFAULT_PIXEL_SIZE_UM) - objective_mag = data.get('objective_mag', DEFAULT_OBJECTIVE_MAG) - use_claude_review = data.get('use_claude_review', True) - min_confidence = data.get('min_confidence', 0.7) - exposure_ms = data.get('exposure_ms') - brightness_percentile = data.get('brightness_percentile', 99.0) - min_area = data.get('min_area', 5000) - max_area = data.get('max_area', 150000) + pixel_size_um = data.get("pixel_size_um", DEFAULT_PIXEL_SIZE_UM) + objective_mag = data.get("objective_mag", DEFAULT_OBJECTIVE_MAG) + use_claude_review = data.get("use_claude_review", True) + data.get("min_confidence", 0.7) + exposure_ms = data.get("exposure_ms") + brightness_percentile = data.get("brightness_percentile", 99.0) + min_area = data.get("min_area", 5000) + max_area = data.get("max_area", 150000) # Set exposure if specified if exposure_ms is not None: - bottom_camera = self.devices.get('bottom_camera') + bottom_camera = self.devices.get("bottom_camera") if bottom_camera: bottom_camera.configure_exposure(exposure_ms) # Capture image via plan logger.info("[detect_embryos] Capturing bottom camera image...") capture_result = await self.submit_plan( - 'capture_bottom_image_plan', - kwargs={'bottom_camera': 'bottom_camera'} + "capture_bottom_image_plan", kwargs={"bottom_camera": "bottom_camera"} ) - if not capture_result.get('success'): - return web.json_response({ - 'success': False, - 'error': f"Image capture failed: {capture_result.get('error', 'Unknown')}" - }, status=500) + if not capture_result.get("success"): + return web.json_response( + { + "success": False, + "error": f"Image capture failed: 
{capture_result.get('error', 'Unknown')}", + }, + status=500, + ) # Extract image from result - docs = capture_result.get('documents', {}) - events = docs.get('events', []) + docs = capture_result.get("documents", {}) + events = docs.get("events", []) image = None if events: - event_data = events[0].get('data', {}) - for key in ['bottom_camera', 'bottom_camera_image', 'Bottom PCO']: + event_data = events[0].get("data", {}) + for key in ["bottom_camera", "bottom_camera_image", "Bottom PCO"]: if key in event_data: val = event_data[key] # Handle file ref - if isinstance(val, dict) and val.get('__file_ref__'): + if isinstance(val, dict) and val.get("__file_ref__"): import tifffile - image = tifffile.imread(val['path']) + + image = tifffile.imread(val["path"]) else: image = np.array(val) break if image is None: - return web.json_response({ - 'success': False, - 'error': 'No image data in capture result' - }, status=500) + return web.json_response( + {"success": False, "error": "No image data in capture result"}, + status=500, + ) logger.info("[detect_embryos] Image shape: %s", image.shape) # Read stage position logger.info("[detect_embryos] Reading stage position...") stage_result = await self.submit_plan( - 'read_stage_plan', - kwargs={'xy_stage': 'xy_stage'} + "read_stage_plan", kwargs={"xy_stage": "xy_stage"} ) stage_x, stage_y = 0.0, 0.0 - if stage_result.get('success'): - stage_docs = stage_result.get('documents', {}) - stage_events = stage_docs.get('events', []) + if stage_result.get("success"): + stage_docs = stage_result.get("documents", {}) + stage_events = stage_docs.get("events", []) if stage_events: - stage_data = stage_events[0].get('data', {}) + stage_data = stage_events[0].get("data", {}) # DiSPIMXYStage.read() returns {device_name: [x, y]} - for key in ['xy_stage', 'XYStage:XY:31', 'xy_stage_position']: + for key in ["xy_stage", "XYStage:XY:31", "xy_stage_position"]: if key in stage_data: val = stage_data[key] if isinstance(val, (list, tuple)) and 
len(val) >= 2: @@ -1621,15 +2009,17 @@ async def handle_detect_embryos(self, request): use_claude_review, brightness_percentile, min_area, - max_area + max_area, ) # Save image if volume_dir configured image_path = None if self._volume_dir: import uuid + try: import tifffile + uid = uuid.uuid4().hex[:12] image_path = str(Path(self._volume_dir) / f"detection_{uid}.tif") tifffile.imwrite(image_path, image) @@ -1638,29 +2028,33 @@ async def handle_detect_embryos(self, request): # Build response response = { - 'success': sam_result.get('success', False), - 'embryos': sam_result.get('embryos', []), - 'initial_detections': sam_result.get('initial_detections', 0), - 'final_detections': sam_result.get('final_detections', 0), - 'stage_position': list(stage_position), - 'verification': sam_result.get('verification', {}), + "success": sam_result.get("success", False), + "embryos": sam_result.get("embryos", []), + "initial_detections": sam_result.get("initial_detections", 0), + "final_detections": sam_result.get("final_detections", 0), + "stage_position": list(stage_position), + "verification": sam_result.get("verification", {}), } if image_path: - response['image_path'] = image_path + response["image_path"] = image_path - if 'error' in sam_result: - response['error'] = sam_result['error'] + if "error" in sam_result: + response["error"] = sam_result["error"] return web.json_response(response) except Exception as e: import traceback - return web.json_response({ - 'success': False, - 'error': str(e), - 'traceback': traceback.format_exc() - }, status=500) + + return web.json_response( + { + "success": False, + "error": str(e), + "traceback": traceback.format_exc(), + }, + status=500, + ) def _run_sam_detection( self, @@ -1672,7 +2066,7 @@ def _run_sam_detection( use_claude_review: bool, brightness_percentile: float, min_area: int, - max_area: int + max_area: int, ) -> dict: """Run SAM detection synchronously (called from thread). 
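Reviewer note on `_run_sam_detection`: the docstring says it runs synchronously and is called from a thread, and the next hunk closes a loop in `finally`. The call site is not shown in this diff, so the following is only a minimal sketch of that pattern under stated assumptions: `fake_detect`, `run_detection_sync`, and `handle_request` are hypothetical names, and `asyncio.to_thread` is assumed as the offloading mechanism rather than confirmed by the source.

```python
import asyncio


async def fake_detect(image_id: str) -> dict:
    # Hypothetical stand-in for the async detector call; not the real SAM API.
    await asyncio.sleep(0)
    return {"embryos": [{"id": image_id}]}


def run_detection_sync(image_id: str) -> dict:
    """Blocking wrapper that owns a private event loop for one call."""
    loop = asyncio.new_event_loop()
    try:
        result = loop.run_until_complete(fake_detect(image_id))
        result["success"] = True
        return result
    except Exception as exc:
        # Return the same keys on failure so callers never need to branch on shape.
        return {"success": False, "error": str(exc), "embryos": []}
    finally:
        # Always release the private loop, on success and on failure.
        loop.close()


async def handle_request(image_id: str) -> dict:
    # Offload the blocking wrapper so the server's own loop stays responsive.
    return await asyncio.to_thread(run_detection_sync, image_id)
```

Calling `asyncio.run(handle_request("img-1"))` exercises the full path. Giving the worker its own loop keeps the blocking detection off the server's loop, which matters because the same process also serves the streaming endpoints.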
@@ -1694,12 +2088,12 @@ def _run_sam_detection( output_dir=Path("./detection_results"), brightness_percentile=brightness_percentile, min_area=min_area, - max_area=max_area + max_area=max_area, ) ) # Ensure results are serializable (convert numpy types) - embryos = result.get('embryos', []) + embryos = result.get("embryos", []) for embryo in embryos: for key, value in list(embryo.items()): if isinstance(value, np.floating): @@ -1710,18 +2104,19 @@ def _run_sam_detection( # Remove mask from response (not JSON serializable) del embryo[key] - result['success'] = True + result["success"] = True return result except Exception as e: import traceback + return { - 'success': False, - 'error': str(e), - 'traceback': traceback.format_exc(), - 'embryos': [], - 'initial_detections': 0, - 'final_detections': 0, + "success": False, + "error": str(e), + "traceback": traceback.format_exc(), + "embryos": [], + "initial_detections": 0, + "final_detections": 0, } finally: loop.close() @@ -1758,12 +2153,36 @@ def _run_sam_detection( "input_schema": { "type": "object", "properties": { - "num_slices": {"type": "integer", "description": "Number of Z slices", "default": 50}, - "exposure_ms": {"type": "number", "description": "Camera exposure per slice in ms", "default": 10.0}, - "galvo_amplitude": {"type": "number", "description": "Galvo scan range in volts", "default": 0.5}, - "galvo_center": {"type": "number", "description": "Galvo center position in volts", "default": 0.0}, - "piezo_amplitude": {"type": "number", "description": "Piezo Z range in µm", "default": 25.0}, - "piezo_center": {"type": "number", "description": "Piezo center position in µm", "default": 50.0}, + "num_slices": { + "type": "integer", + "description": "Number of Z slices", + "default": 50, + }, + "exposure_ms": { + "type": "number", + "description": "Camera exposure per slice in ms", + "default": 10.0, + }, + "galvo_amplitude": { + "type": "number", + "description": "Galvo scan range in volts", + "default": 0.5, + 
}, + "galvo_center": { + "type": "number", + "description": "Galvo center position in volts", + "default": 0.0, + }, + "piezo_amplitude": { + "type": "number", + "description": "Piezo Z range in µm", + "default": 25.0, + }, + "piezo_center": { + "type": "number", + "description": "Piezo center position in µm", + "default": 50.0, + }, }, }, "bluesky_plan": "acquire_single_volume_plan", @@ -1774,9 +2193,19 @@ def _run_sam_detection( "input_schema": { "type": "object", "properties": { - "piezo_position": {"type": "number", "description": "Z position in µm"}, - "galvo_position": {"type": "number", "description": "Galvo angle in volts"}, - "exposure_ms": {"type": "number", "description": "Camera exposure in ms", "default": 10.0}, + "piezo_position": { + "type": "number", + "description": "Z position in µm", + }, + "galvo_position": { + "type": "number", + "description": "Galvo angle in volts", + }, + "exposure_ms": { + "type": "number", + "description": "Camera exposure in ms", + "default": 10.0, + }, }, }, "bluesky_plan": "capture_lightsheet_image_plan", @@ -1787,20 +2216,37 @@ def _run_sam_detection( "input_schema": { "type": "object", "properties": { - "use_led": {"type": "boolean", "description": "Turn on LED during capture", "default": False}, - "exposure_ms": {"type": "number", "description": "Camera exposure in ms"}, + "use_led": { + "type": "boolean", + "description": "Turn on LED during capture", + "default": False, + }, + "exposure_ms": { + "type": "number", + "description": "Camera exposure in ms", + }, }, }, "bluesky_plan": "capture_bottom_image_plan", "extractor": "_extract_image", }, "calibrate": { - "description": "Run piezo-galvo calibration to find optimal focus parameters for an embryo.", + "description": ( + "Run piezo-galvo calibration to find optimal focus parameters for an embryo." 
+ ), "input_schema": { "type": "object", "properties": { - "piezo_positions": {"type": "array", "items": {"type": "number"}, "description": "Piezo positions to sweep (µm)"}, - "galvo_positions": {"type": "array", "items": {"type": "number"}, "description": "Galvo positions to sweep (volts)"}, + "piezo_positions": { + "type": "array", + "items": {"type": "number"}, + "description": "Piezo positions to sweep (µm)", + }, + "galvo_positions": { + "type": "array", + "items": {"type": "number"}, + "description": "Galvo positions to sweep (volts)", + }, }, }, "bluesky_plan": "calibrate_piezo_galvo_plan", @@ -1811,7 +2257,11 @@ def _run_sam_detection( "input_schema": { "type": "object", "properties": { - "state": {"type": "string", "enum": ["Open", "Closed"], "description": "LED state"}, + "state": { + "type": "string", + "enum": ["Open", "Closed"], + "description": "LED state", + }, }, "required": ["state"], }, @@ -1825,17 +2275,46 @@ def _run_sam_detection( "extractor": None, }, "detect": { - "description": "Detect embryos/samples in the current field of view using SAM segmentation.", + "description": ( + "Detect embryos/samples in the current field of view using SAM segmentation." 
+ ), "input_schema": { "type": "object", "properties": { - "pixel_size_um": {"type": "number", "description": "Camera pixel size in µm", "default": 6.5}, - "objective_mag": {"type": "number", "description": "Objective magnification", "default": 10.0}, - "min_confidence": {"type": "number", "description": "Minimum detection confidence", "default": 0.7}, - "exposure_ms": {"type": "number", "description": "Camera exposure in ms"}, - "brightness_percentile": {"type": "number", "description": "Brightness threshold percentile", "default": 99.0}, - "min_area": {"type": "integer", "description": "Minimum embryo area in pixels", "default": 5000}, - "max_area": {"type": "integer", "description": "Maximum embryo area in pixels", "default": 150000}, + "pixel_size_um": { + "type": "number", + "description": "Camera pixel size in µm", + "default": 6.5, + }, + "objective_mag": { + "type": "number", + "description": "Objective magnification", + "default": 10.0, + }, + "min_confidence": { + "type": "number", + "description": "Minimum detection confidence", + "default": 0.7, + }, + "exposure_ms": { + "type": "number", + "description": "Camera exposure in ms", + }, + "brightness_percentile": { + "type": "number", + "description": "Brightness threshold percentile", + "default": 99.0, + }, + "min_area": { + "type": "integer", + "description": "Minimum embryo area in pixels", + "default": 5000, + }, + "max_area": { + "type": "integer", + "description": "Maximum embryo area in pixels", + "default": 150000, + }, }, }, "bluesky_plan": None, @@ -1851,9 +2330,9 @@ def _run_sam_detection( def _extract_from_events(self, documents: dict, candidate_keys: list) -> Any: """Pull data from Bluesky event documents by candidate key names.""" - events = documents.get('events', []) + events = documents.get("events", []) for event in events: - data = event.get('data', {}) + data = event.get("data", {}) for key in candidate_keys: if key in data: return data[key] @@ -1863,45 +2342,57 @@ def 
_extract_move(self, documents: dict, params: dict) -> dict: return {"success": True, "x": params.get("x"), "y": params.get("y")} def _extract_position(self, documents: dict, params: dict) -> dict: - events = documents.get('events', []) + events = documents.get("events", []) if events: - data = events[0].get('data', {}) - for key in ['XY:31', 'xy_stage', 'stage']: + data = events[0].get("data", {}) + for key in ["XY:31", "xy_stage", "stage"]: if key in data: val = data[key] if isinstance(val, (list, tuple)) and len(val) >= 2: return {"success": True, "x": float(val[0]), "y": float(val[1])} if isinstance(val, dict): - return {"success": True, "x": float(val.get('x', 0)), "y": float(val.get('y', 0))} + return { + "success": True, + "x": float(val.get("x", 0)), + "y": float(val.get("y", 0)), + } return {"success": False, "error": "Could not read position"} def _extract_volume(self, documents: dict, params: dict) -> dict: - val = self._extract_from_events(documents, ['volume_scanner', 'camera', 'camera_image']) + val = self._extract_from_events(documents, ["volume_scanner", "camera", "camera_image"]) if val is not None: result = {"success": True} - if isinstance(val, dict) and val.get('__file_ref__'): - result['volume'] = val # file ref — client resolves - result['shape'] = val.get('shape') + if isinstance(val, dict) and val.get("__file_ref__"): + result["volume"] = val # file ref — client resolves + result["shape"] = val.get("shape") else: - result['volume'] = val - if hasattr(val, 'shape'): - result['shape'] = list(val.shape) + result["volume"] = val + if hasattr(val, "shape"): + result["shape"] = list(val.shape) return result return {"success": False, "error": "No volume data in result"} def _extract_image(self, documents: dict, params: dict) -> dict: val = self._extract_from_events( - documents, ['HamCam1', 'lightsheet_snap', 'camera', 'bottom_camera', 'bottom_camera_image', 'Bottom PCO'] + documents, + [ + "HamCam1", + "lightsheet_snap", + "camera", + 
"bottom_camera", + "bottom_camera_image", + "Bottom PCO", + ], ) if val is not None: result = {"success": True} - if isinstance(val, dict) and val.get('__file_ref__'): - result['image'] = val - result['shape'] = val.get('shape') + if isinstance(val, dict) and val.get("__file_ref__"): + result["image"] = val + result["shape"] = val.get("shape") else: - result['image'] = val - if hasattr(val, 'shape'): - result['shape'] = list(val.shape) + result["image"] = val + if hasattr(val, "shape"): + result["shape"] = list(val.shape) return result return {"success": False, "error": "No image data in result"} @@ -1914,8 +2405,8 @@ def _extract_success(self, documents: dict, params: dict) -> dict: async def handle_microscope_info(self, request): """GET /api/microscope — handshake: plans as Anthropic tool schemas.""" + from . import HARDWARE_DISPLAY_NAME, HARDWARE_NAME from .description import HARDWARE_DESCRIPTION - from . import HARDWARE_NAME, HARDWARE_DISPLAY_NAME # Build plan list, filtering to actually-available plans available_plans = [] @@ -1923,18 +2414,22 @@ async def handle_microscope_info(self, request): bluesky_name = schema.get("bluesky_plan") if bluesky_name is None or bluesky_name in self.plans: # Return client-facing fields (Anthropic tool format) - available_plans.append({ - "name": plan_name, - "description": schema["description"], - "input_schema": schema["input_schema"], - }) - - return web.json_response({ - "name": HARDWARE_NAME, - "display_name": HARDWARE_DISPLAY_NAME, - "description": HARDWARE_DESCRIPTION, - "plans": available_plans, - }) + available_plans.append( + { + "name": plan_name, + "description": schema["description"], + "input_schema": schema["input_schema"], + } + ) + + return web.json_response( + { + "name": HARDWARE_NAME, + "display_name": HARDWARE_DISPLAY_NAME, + "description": HARDWARE_DESCRIPTION, + "plans": available_plans, + } + ) async def handle_microscope_execute(self, request): """POST /api/microscope/execute — execute a named plan. 
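Reviewer note on `handle_microscope_info`: the handshake keeps a plan when its `bluesky_plan` is `None` (handled directly by the server) or is present in the loaded plans, and it returns only the client-facing fields. A standalone sketch of that filter follows. The registry name `PLAN_SCHEMAS`, the function `available_tools`, and the entry name `acquire_volume` are illustrative assumptions; only the field names and the two Bluesky plan names are taken from the diff.

```python
from typing import Any

# Hypothetical registry in the same shape as the schema entries in this diff.
PLAN_SCHEMAS: dict[str, dict[str, Any]] = {
    "acquire_volume": {
        "description": "Acquire a single volume.",
        "input_schema": {"type": "object", "properties": {}},
        "bluesky_plan": "acquire_single_volume_plan",
    },
    "detect": {
        "description": "Detect embryos in the current field of view.",
        "input_schema": {"type": "object", "properties": {}},
        "bluesky_plan": None,  # no Bluesky plan: served by a dedicated handler
    },
    "calibrate": {
        "description": "Run piezo-galvo calibration.",
        "input_schema": {"type": "object", "properties": {}},
        "bluesky_plan": "calibrate_piezo_galvo_plan",
    },
}


def available_tools(
    schemas: dict[str, dict[str, Any]], loaded_plans: set[str]
) -> list[dict[str, Any]]:
    """Return tool descriptors for plans that can actually run right now."""
    tools = []
    for name, schema in schemas.items():
        bluesky_name = schema.get("bluesky_plan")
        if bluesky_name is None or bluesky_name in loaded_plans:
            # Expose only client-facing fields; routing details stay server-side.
            tools.append(
                {
                    "name": name,
                    "description": schema["description"],
                    "input_schema": schema["input_schema"],
                }
            )
    return tools
```

With only `acquire_single_volume_plan` loaded, `calibrate` is dropped and `detect` survives because it has no Bluesky dependency, so a client is never offered a tool that `/api/microscope/execute` would reject as not loaded.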
@@ -1970,11 +2465,17 @@ async def handle_microscope_execute(self, request): return await self.handle_get_led_status(request) elif plan_name == "status": return await self.handle_status(request) - return web.json_response({"success": False, "error": f"Plan '{plan_name}' not implemented"}, status=500) + return web.json_response( + {"success": False, "error": f"Plan '{plan_name}' not implemented"}, + status=500, + ) if bluesky_name not in self.plans: return web.json_response( - {"success": False, "error": f"Hardware plan '{bluesky_name}' not loaded"}, + { + "success": False, + "error": f"Hardware plan '{bluesky_name}' not loaded", + }, status=500, ) @@ -1991,8 +2492,13 @@ async def handle_microscope_execute(self, request): except Exception as e: import traceback + return web.json_response( - {"success": False, "error": str(e), "traceback": traceback.format_exc()}, + { + "success": False, + "error": str(e), + "traceback": traceback.format_exc(), + }, status=500, ) @@ -2026,6 +2532,7 @@ async def handle_devices_state(self, request): return web.json_response(self._state_latest) except Exception as exc: import traceback + return web.json_response( {"error": str(exc), "traceback": traceback.format_exc()}, status=500, @@ -2054,8 +2561,11 @@ async def handle_callbacks_stream(self, request): queue: asyncio.Queue = asyncio.Queue(maxsize=256) self._callback_subscribers.append(queue) peer = request.remote - logger.info("Callback subscriber connected from %s (total=%d)", - peer, len(self._callback_subscribers)) + logger.info( + "Callback subscriber connected from %s (total=%d)", + peer, + len(self._callback_subscribers), + ) try: await response.write(b"event: ready\ndata: {}\n\n") @@ -2067,9 +2577,7 @@ async def handle_callbacks_stream(self, request): continue if payload is None: break # shutdown sentinel - await response.write( - f"data: {json.dumps(payload)}\n\n".encode() - ) + await response.write(f"data: {json.dumps(payload)}\n\n".encode()) except (asyncio.CancelledError, 
ConnectionResetError, ConnectionAbortedError): pass except Exception: @@ -2079,8 +2587,11 @@ async def handle_callbacks_stream(self, request): self._callback_subscribers.remove(queue) except ValueError: pass - logger.info("Callback subscriber disconnected from %s (total=%d)", - peer, len(self._callback_subscribers)) + logger.info( + "Callback subscriber disconnected from %s (total=%d)", + peer, + len(self._callback_subscribers), + ) return response @@ -2108,8 +2619,11 @@ async def handle_devices_stream(self, request): queue: asyncio.Queue = asyncio.Queue(maxsize=32) self._state_subscribers.append(queue) peer = request.remote - logger.info("Device-state subscriber connected from %s (total=%d)", - peer, len(self._state_subscribers)) + logger.info( + "Device-state subscriber connected from %s (total=%d)", + peer, + len(self._state_subscribers), + ) try: # Send the most recent snapshot immediately so the UI doesn't @@ -2130,9 +2644,7 @@ async def handle_devices_stream(self, request): # shutdown timeout. 
if payload is None: break - await response.write( - f"data: {json.dumps(payload)}\n\n".encode() - ) + await response.write(f"data: {json.dumps(payload)}\n\n".encode()) except (asyncio.CancelledError, ConnectionResetError, ConnectionAbortedError): pass except Exception: @@ -2142,8 +2654,11 @@ async def handle_devices_stream(self, request): self._state_subscribers.remove(queue) except ValueError: pass - logger.info("Device-state subscriber disconnected from %s (total=%d)", - peer, len(self._state_subscribers)) + logger.info( + "Device-state subscriber disconnected from %s (total=%d)", + peer, + len(self._state_subscribers), + ) return response @@ -2173,8 +2688,11 @@ async def handle_bottom_camera_stream(self, request): self._bottom_camera_streamer(), name="bottom-camera-streamer" ) peer = request.remote - logger.info("Bottom-camera subscriber connected from %s (total=%d)", - peer, len(self._cam_subscribers)) + logger.info( + "Bottom-camera subscriber connected from %s (total=%d)", + peer, + len(self._cam_subscribers), + ) try: # Initial comment so the client knows the connection is alive @@ -2188,9 +2706,7 @@ async def handle_bottom_camera_stream(self, request): continue if payload is None: break # shutdown sentinel - await response.write( - f"data: {json.dumps(payload)}\n\n".encode() - ) + await response.write(f"data: {json.dumps(payload)}\n\n".encode()) except (asyncio.CancelledError, ConnectionResetError, ConnectionAbortedError): pass except Exception: @@ -2200,8 +2716,11 @@ async def handle_bottom_camera_stream(self, request): self._cam_subscribers.remove(queue) except ValueError: pass - logger.info("Bottom-camera subscriber disconnected from %s (total=%d)", - peer, len(self._cam_subscribers)) + logger.info( + "Bottom-camera subscriber disconnected from %s (total=%d)", + peer, + len(self._cam_subscribers), + ) return response @@ -2217,36 +2736,40 @@ async def on_start(self): self._app = web.Application() # Core endpoints (carried forward from simple_server.py) - 
self._app.router.add_get('/api/status', self.handle_status) - self._app.router.add_post('/api/queue/item/add', self.handle_submit_plan) - self._app.router.add_get('/api/history', self.handle_get_history) - self._app.router.add_get('/api/devices', self.handle_get_devices) - self._app.router.add_get('/api/plans', self.handle_get_plans) - self._app.router.add_get('/api/led/status', self.handle_get_led_status) - self._app.router.add_post('/api/led/set', self.handle_set_led) - self._app.router.add_post('/api/camera/led_mode', self.handle_set_camera_led_mode) - self._app.router.add_post('/api/camera/exposure', self.handle_set_camera_exposure) - self._app.router.add_get('/api/camera/exposure', self.handle_get_camera_exposure) - self._app.router.add_post('/api/light_source/power', self.handle_set_light_source_power) - self._app.router.add_get('/api/light_source/power', self.handle_get_light_source_power) - self._app.router.add_get('/api/plan_log', self.handle_get_plan_log) - self._app.router.add_post('/session/configure', self.handle_session_configure) + self._app.router.add_get("/api/status", self.handle_status) + self._app.router.add_post("/api/queue/item/add", self.handle_submit_plan) + self._app.router.add_get("/api/history", self.handle_get_history) + self._app.router.add_get("/api/devices", self.handle_get_devices) + self._app.router.add_get("/api/plans", self.handle_get_plans) + self._app.router.add_get("/api/led/status", self.handle_get_led_status) + self._app.router.add_post("/api/led/set", self.handle_set_led) + self._app.router.add_get("/api/temperature/status", self.handle_get_temperature_status) + self._app.router.add_post("/api/temperature/set", self.handle_set_temperature) + self._app.router.add_get("/api/room_light/status", self.handle_get_room_light_status) + self._app.router.add_post("/api/room_light/set", self.handle_set_room_light) + self._app.router.add_post("/api/camera/led_mode", self.handle_set_camera_led_mode) + 
self._app.router.add_post("/api/camera/exposure", self.handle_set_camera_exposure) + self._app.router.add_get("/api/camera/exposure", self.handle_get_camera_exposure) + self._app.router.add_post("/api/light_source/power", self.handle_set_light_source_power) + self._app.router.add_get("/api/light_source/power", self.handle_get_light_source_power) + self._app.router.add_get("/api/plan_log", self.handle_get_plan_log) + self._app.router.add_post("/session/configure", self.handle_session_configure) # SAM endpoints (new - replaces RPyC sam_server.py) - self._app.router.add_get('/api/sam/status', self.handle_sam_status) - self._app.router.add_post('/api/detect_embryos', self.handle_detect_embryos) + self._app.router.add_get("/api/sam/status", self.handle_sam_status) + self._app.router.add_post("/api/detect_embryos", self.handle_detect_embryos) # Microscope API (generic plan-based interface) - self._app.router.add_get('/api/microscope', self.handle_microscope_info) - self._app.router.add_post('/api/microscope/execute', self.handle_microscope_execute) + self._app.router.add_get("/api/microscope", self.handle_microscope_info) + self._app.router.add_post("/api/microscope/execute", self.handle_microscope_execute) # Device state streaming (positions + properties) - self._app.router.add_get('/api/devices/state', self.handle_devices_state) - self._app.router.add_get('/api/devices/stream', self.handle_devices_stream) - self._app.router.add_get('/api/devices/callbacks/stream', self.handle_callbacks_stream) + self._app.router.add_get("/api/devices/state", self.handle_devices_state) + self._app.router.add_get("/api/devices/stream", self.handle_devices_stream) + self._app.router.add_get("/api/devices/callbacks/stream", self.handle_callbacks_stream) # Bottom-camera live stream (subscriber-gated, off when nobody listens) - self._app.router.add_get('/api/bottom_camera/stream', self.handle_bottom_camera_stream) + self._app.router.add_get("/api/bottom_camera/stream", 
self.handle_bottom_camera_stream) # Start plan executor self._executor_task = asyncio.create_task(self._plan_executor()) @@ -2272,13 +2795,85 @@ async def on_start(self): logger.info("=" * 60) logger.info("HTTP API available at http://%s:%d", self.host, self.port) logger.info("=" * 60) - logger.info("Endpoints: GET /api/status, GET /api/devices, GET /api/plans, POST /api/queue/item/add, ...") + logger.info( + "Endpoints: GET /api/status, GET /api/devices, GET /api/plans," + " POST /api/queue/item/add, ..." + ) await site.start() + self._print_ready_panel() + + def _categorize_devices(self): + """Group device names into human-readable buckets for the console panel. + + First-match-wins so 'room_light' lands in Accessory (not Light) and + 'volume_scanner' in Motion. Accessory entries carry live state. + """ + buckets = { + "Motion": [], + "Imaging": [], + "Light": [], + "Accessory": [], + "Other": [], + } + for name in sorted(self.devices): + low = name.lower() + if low in ("room_light", "temperature"): + label = name + try: + dev = self.devices[name] + val = dev.read().get(dev.name, {}).get("value") + if val is not None: + label = f"{name} ({val})" + except Exception: + pass + buckets["Accessory"].append(label) + elif "cam" in low or "snap" in low: + buckets["Imaging"].append(name) + elif any(k in low for k in ("stage", "piezo", "galvo", "scanner")): + buckets["Motion"].append(name) + elif any(k in low for k in ("laser", "led", "light", "illum")): + buckets["Light"].append(name) + else: + buckets["Other"].append(name) + return list(buckets.items()) + + def _print_ready_panel(self): + """Curated, always-visible status summary at the terminal. + + Separate from the file log: the operator (often a biologist) gets the + URL the agent connects to, a grouped device inventory and accessory + states at a glance — instead of a silent console after the banner. 
+ """ + + def _fmt(names, limit=6): + if len(names) <= limit: + return " · ".join(names) + return " · ".join(names[:limit]) + cui.c(f" +{len(names) - limit} more", "grey") + + host = self.host or "0.0.0.0" + url_host = "localhost" if host in ("0.0.0.0", "::", "") else host + + cui.out() + cui.header(f"GENTLY{cui.MIDDOT}DEVICE LAYER", badge="READY", badge_style="green") + cui.row("URL", cui.c(f"http://{url_host}:{self.port}", "bold")) + cui.row("Hardware", str((self.config or {}).get("hardware", "dispim"))) + cui.row("Devices", f"{len(self.devices)} loaded") + for label, names in self._categorize_devices(): + if names: + cui.sub(label, _fmt(names)) + cui.row("Detection", f"SAM on {self._sam_device} (loads on first use)") + cui.row("Plans", f"{len(self.plans)} available") + cui.rule(heavy=False) + cui.note("Waiting for the agent to connect. Press Ctrl+C to stop.") + cui.rule(heavy=True) + cui.out() async def on_stop(self): """Shut down the HTTP server and plan executor.""" logger.info("Shutting down...") + cui.out() + cui.note("Shutting down device layer...", "yellow") self._running = False # Cancel any pending coalesced-broadcast timer. @@ -2290,7 +2885,11 @@ async def on_stop(self): # sitting in `wait_for(queue.get(), timeout=10)` and forcing aiohttp # to wait out its shutdown timeout. Done first so handlers drain # before we cancel pollers (which would block on in-flight to_thread). - for queues in (self._state_subscribers, self._callback_subscribers, self._cam_subscribers): + for queues in ( + self._state_subscribers, + self._callback_subscribers, + self._cam_subscribers, + ): for q in list(queues): try: q.put_nowait(None) @@ -2308,8 +2907,12 @@ async def on_stop(self): # transfer). If MMCore is hung or a long exposure is set, we'd # otherwise wait forever — so each task gets a 3 s ceiling. A # timed-out thread leaks until interpreter shutdown reaps it. 
- for task_attr in ("_state_pos_task", "_state_slow_pos_task", - "_state_prop_task", "_cam_task"): + for task_attr in ( + "_state_pos_task", + "_state_slow_pos_task", + "_state_prop_task", + "_cam_task", + ): task = getattr(self, task_attr, None) if task is not None: task.cancel() @@ -2319,7 +2922,8 @@ async def on_stop(self): if not task.done(): logger.warning( "%s did not exit within shutdown timeout; " - "leaking thread, continuing shutdown", task_attr, + "leaking thread, continuing shutdown", + task_attr, ) setattr(self, task_attr, None) if self._executor_task: @@ -2331,16 +2935,17 @@ async def on_stop(self): if self._runner: await self._runner.cleanup() logger.info("Device layer stopped.") + cui.note("Device layer stopped.", "grey") - async def health_check(self) -> Dict: + async def health_check(self) -> dict: """Return health status with device count, queue size, SAM status.""" base = await super().health_check() - base['device_count'] = len(self.devices) - base['queue_size'] = self._plan_queue.qsize() - base['sam_loaded'] = self._sam_detector is not None + base["device_count"] = len(self.devices) + base["queue_size"] = self._plan_queue.qsize() + base["sam_loaded"] = self._sam_detector is not None return base - async def run(self, host: str = None, port: int = None): + async def run(self, host: str | None = None, port: int | None = None): """Start the server and run until interrupted.""" if host is not None: self.host = host @@ -2371,13 +2976,18 @@ async def main(port: int = settings.network.device_port, sam_device: str = "cuda parser = argparse.ArgumentParser(description="Gently Device Layer Server") parser.add_argument("--port", type=int, default=settings.network.device_port, help="HTTP port") - parser.add_argument("--sam-device", default="cuda", choices=["cuda", "cpu"], - help="Device for SAM model (default: cuda)") + parser.add_argument( + "--sam-device", + default="cuda", + choices=["cuda", "cpu"], + help="Device for SAM model (default: cuda)", + ) args 
= parser.parse_args() - from pathlib import Path from datetime import datetime + from pathlib import Path + log_dir = Path(settings.storage.base_path) / "logs" log_dir.mkdir(parents=True, exist_ok=True) log_file = str(log_dir / f"device_layer_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log") diff --git a/gently/hardware/dispim/devices/__init__.py b/gently/hardware/dispim/devices/__init__.py index 5910460a..a41c025e 100644 --- a/gently/hardware/dispim/devices/__init__.py +++ b/gently/hardware/dispim/devices/__init__.py @@ -5,20 +5,26 @@ from gently.devices import DiSPIMCamera, DiSPIMPiezo, ... """ -from .stage import DiSPIMZstage, DiSPIMXYStage -from .camera import DiSPIMCamera, DiSPIMDualCamera, DiSPIMBottomCamera +from .acquisition import DiSPIMLightSheetSnap, DiSPIMVolumeScanner +from .camera import DiSPIMBottomCamera, DiSPIMCamera, DiSPIMDualCamera +from .optical import DiSPIMLaserControl, DiSPIMLED, DiSPIMLightSource from .piezo import DiSPIMFDrive, DiSPIMPiezo from .scanner import DiSPIMScanner -from .optical import DiSPIMLED, DiSPIMLightSource, DiSPIMLaserControl -from .acquisition import DiSPIMVolumeScanner, DiSPIMLightSheetSnap +from .stage import DiSPIMXYStage, DiSPIMZstage __all__ = [ - "DiSPIMZstage", "DiSPIMXYStage", - "DiSPIMCamera", "DiSPIMDualCamera", "DiSPIMBottomCamera", - "DiSPIMFDrive", "DiSPIMPiezo", + "DiSPIMZstage", + "DiSPIMXYStage", + "DiSPIMCamera", + "DiSPIMDualCamera", + "DiSPIMBottomCamera", + "DiSPIMFDrive", + "DiSPIMPiezo", "DiSPIMScanner", - "DiSPIMLED", "DiSPIMLightSource", + "DiSPIMLED", + "DiSPIMLightSource", # Backwards-compatible alias for DiSPIMLightSource: "DiSPIMLaserControl", - "DiSPIMVolumeScanner", "DiSPIMLightSheetSnap", + "DiSPIMVolumeScanner", + "DiSPIMLightSheetSnap", ] diff --git a/gently/hardware/dispim/devices/acquisition.py b/gently/hardware/dispim/devices/acquisition.py index c281aaea..9b2658ef 100644 --- a/gently/hardware/dispim/devices/acquisition.py +++ b/gently/hardware/dispim/devices/acquisition.py @@ -2,22 
+2,21 @@ DiSPIM compound acquisition devices (volume scanner and light sheet snap). """ -import time import logging +import time from collections import OrderedDict -from typing import Dict import numpy as np - -from ophyd.status import Status import pymmcore +from ophyd.status import Status from gently.settings import settings + from ._common import _safe_obtain -from .scanner import DiSPIMScanner from .camera import DiSPIMCamera -from .piezo import DiSPIMPiezo from .optical import DiSPIMLaserControl +from .piezo import DiSPIMPiezo +from .scanner import DiSPIMScanner logger = logging.getLogger(__name__) @@ -56,13 +55,15 @@ class DiSPIMVolumeScanner: Device name (default: "volume_scanner") """ - def __init__(self, - scanner: DiSPIMScanner, - camera: DiSPIMCamera, - piezo: DiSPIMPiezo, - laser_control: 'DiSPIMLaserControl', - core: pymmcore.CMMCore, - name: str = "volume_scanner"): + def __init__( + self, + scanner: DiSPIMScanner, + camera: DiSPIMCamera, + piezo: DiSPIMPiezo, + laser_control: "DiSPIMLaserControl", + core: pymmcore.CMMCore, + name: str = "volume_scanner", + ): """ Initialize volume scanner with all required devices. 
@@ -87,19 +88,21 @@ def __init__(self, self._exposure_ms = None self._laser_config = None - def configure(self, - num_slices: int, - exposure_ms: float, - galvo_amplitude: float, - galvo_center: float, - piezo_amplitude: float, - piezo_center: float, - laser_config: str = "488 and 561", - laser_power_488_pct: float = None, - laser_power_561_pct: float = None, - laser_power_405_pct: float = None, - laser_power_637_pct: float = None, - timing_params: Dict = None): + def configure( + self, + num_slices: int, + exposure_ms: float, + galvo_amplitude: float, + galvo_center: float, + piezo_amplitude: float, + piezo_center: float, + laser_config: str = "488 and 561", + laser_power_488_pct: float | None = None, + laser_power_561_pct: float | None = None, + laser_power_405_pct: float | None = None, + laser_power_637_pct: float | None = None, + timing_params: dict | None = None, + ): """ Configure all devices for hardware-triggered volume acquisition. @@ -120,7 +123,8 @@ def configure(self, laser_config : str Laser channel selection preset (default: "488 and 561"). Common options: "488 and 561", "488 only", "561 only" - laser_power_488_pct, laser_power_561_pct, laser_power_405_pct, laser_power_637_pct : float, optional + laser_power_488_pct, laser_power_561_pct, laser_power_405_pct, + laser_power_637_pct : float, optional Per-line laser power %. ``None`` leaves the current setpoint unchanged. Out-of-range values are rejected at the device-layer setter (see DiSPIMLightSource.POWER_LIMITS_PCT). @@ -135,14 +139,12 @@ def configure(self, galvo_amplitude=galvo_amplitude, galvo_center=galvo_center, num_slices=num_slices, - timing_params=timing_params + timing_params=timing_params, ) # Configure piezo for volume acquisition self.piezo.configure_for_volume_acquisition( - amplitude_um=piezo_amplitude, - offset_um=piezo_center, - num_slices=num_slices + amplitude_um=piezo_amplitude, offset_um=piezo_center, num_slices=num_slices ) # Apply per-line laser power if specified. 
Each setter raises if the @@ -258,6 +260,7 @@ def wait(): status.set_finished() import threading + threading.Thread(target=wait).start() return status @@ -267,8 +270,8 @@ def read(self): if self._last_volume is not None: data = OrderedDict() data[self.name] = { - 'value': self._last_volume, - 'timestamp': self._last_volume_time or time.time() + "value": self._last_volume, + "timestamp": self._last_volume_time or time.time(), } return data else: @@ -278,10 +281,10 @@ def describe(self): """Describe volume data format.""" data = OrderedDict() data[self.name] = { - 'source': self.name, - 'dtype': 'array', - 'shape': getattr(self._last_volume, 'shape', []), - 'units': 'counts' + "source": self.name, + "dtype": "array", + "shape": getattr(self._last_volume, "shape", []), + "units": "counts", } return data @@ -302,10 +305,12 @@ class DiSPIMLightSheetSnap: Used during focus sweeps and piezo-galvo calibration. """ - def __init__(self, - scanner: DiSPIMScanner, - camera: DiSPIMCamera, - name: str = "lightsheet_snap"): + def __init__( + self, + scanner: DiSPIMScanner, + camera: DiSPIMCamera, + name: str = "lightsheet_snap", + ): self.name = name self.parent = None # Required for Bluesky self.scanner = scanner @@ -314,10 +319,12 @@ def __init__(self, self._last_image = None self._last_image_time = None - def configure(self, - sheet_width_deg: float = 8.0, - y_position_deg: float = 0.0, - exposure_ms: float = 50.0): + def configure( + self, + sheet_width_deg: float = 8.0, + y_position_deg: float = 0.0, + exposure_ms: float = 50.0, + ): """ Configure light sheet parameters for single snapshot. diff --git a/gently/hardware/dispim/devices/camera.py b/gently/hardware/dispim/devices/camera.py index 34daafca..5141d6d4 100644 --- a/gently/hardware/dispim/devices/camera.py +++ b/gently/hardware/dispim/devices/camera.py @@ -2,16 +2,16 @@ DiSPIM camera detector devices (single, dual, and bottom camera). 
""" -import time import logging +import time from collections import OrderedDict -from typing import Tuple import numpy as np -from ophyd.status import Status import pymmcore +from ophyd.status import Status + +from gently.exceptions import AcquisitionError, HardwareError -from gently.exceptions import HardwareError, AcquisitionError from ._common import _safe_obtain logger = logging.getLogger(__name__) @@ -68,10 +68,10 @@ def wait(): else: status.set_finished() - status = Status(obj=self, timeout=30) import threading + threading.Thread(target=wait).start() return status @@ -81,8 +81,8 @@ def read(self): if self._last_image is not None: data = OrderedDict() data[self.name] = { - 'value': self._last_image, - 'timestamp': self._last_image_time or time.time() + "value": self._last_image, + "timestamp": self._last_image_time or time.time(), } return data else: @@ -92,9 +92,9 @@ def describe(self): """Describe detector data format""" data = OrderedDict() data[self.name] = { - 'source': self.name, - 'dtype': 'array', - 'shape': getattr(self._last_image, 'shape', []) + "source": self.name, + "dtype": "array", + "shape": getattr(self._last_image, "shape", []), } return data @@ -191,8 +191,9 @@ def set_trigger_active(self, mode: str): """ self.core.setProperty(self.name, "TRIGGER ACTIVE", mode) - def configure_for_calibration(self, exposure_ms: float, - roi: Tuple[int, int, int, int] = (128, 896, 2048, 512)): + def configure_for_calibration( + self, exposure_ms: float, roi: tuple[int, int, int, int] = (128, 896, 2048, 512) + ): """ Configure camera for calibration imaging (single light sheet snapshots). 
@@ -212,8 +213,9 @@ def configure_for_calibration(self, exposure_ms: float, self.core.setExposure(self.name, exposure_ms) self.core.waitForDevice(self.name) - def configure_for_volume_acquisition(self, exposure_ms: float, - roi: Tuple[int, int, int, int] = (128, 896, 2048, 512)): + def configure_for_volume_acquisition( + self, exposure_ms: float, roi: tuple[int, int, int, int] = (128, 896, 2048, 512) + ): """ Configure camera for hardware-triggered volume acquisition. @@ -280,8 +282,8 @@ def read(self): camera_data = self.camera.read() if self.camera.name in camera_data: - stitched_image = camera_data[self.camera.name]['value'] - timestamp = camera_data[self.camera.name]['timestamp'] + stitched_image = camera_data[self.camera.name]["value"] + timestamp = camera_data[self.camera.name]["timestamp"] # Split image in the middle (width dimension) height, width = stitched_image.shape[:2] @@ -292,14 +294,8 @@ def read(self): # Return as separate data entries data = OrderedDict() - data[f'{self.name}_image_a'] = { - 'value': image_a, - 'timestamp': timestamp - } - data[f'{self.name}_image_b'] = { - 'value': image_b, - 'timestamp': timestamp - } + data[f"{self.name}_image_a"] = {"value": image_a, "timestamp": timestamp} + data[f"{self.name}_image_b"] = {"value": image_b, "timestamp": timestamp} return data else: return OrderedDict() @@ -314,7 +310,7 @@ def describe(self): # Describe image_a and image_b outputs # Shape will be half width of original stitched image if self.camera.name in camera_desc: - original_shape = camera_desc[self.camera.name].get('shape', []) + original_shape = camera_desc[self.camera.name].get("shape", []) if len(original_shape) >= 2: # Split width dimension in half split_shape = [original_shape[0], original_shape[1] // 2] @@ -323,15 +319,15 @@ def describe(self): else: split_shape = original_shape - data[f'{self.name}_image_a'] = { - 'source': f'{self.name}_image_a', - 'dtype': 'array', - 'shape': split_shape + data[f"{self.name}_image_a"] = { + 
"source": f"{self.name}_image_a", + "dtype": "array", + "shape": split_shape, } - data[f'{self.name}_image_b'] = { - 'source': f'{self.name}_image_b', - 'dtype': 'array', - 'shape': split_shape + data[f"{self.name}_image_b"] = { + "source": f"{self.name}_image_b", + "dtype": "array", + "shape": split_shape, } return data @@ -362,12 +358,14 @@ class DiSPIMBottomCamera(DiSPIMCamera): Used for finding and centering embryos in the sample chamber. """ - def __init__(self, - device_name: str, - core: pymmcore.CMMCore, - led_control: 'DiSPIMLED', - pixel_size_um: float = 6.5, - magnification: float = 10.0): + def __init__( + self, + device_name: str, + core: pymmcore.CMMCore, + led_control: "DiSPIMLED", # noqa: F821 + pixel_size_um: float = 6.5, + magnification: float = 10.0, + ): """ Initialize bottom camera with LED control and calibrated pixel size. @@ -455,6 +453,7 @@ def wait(): status.set_finished() import threading + threading.Thread(target=wait).start() return status diff --git a/gently/hardware/dispim/devices/optical.py b/gently/hardware/dispim/devices/optical.py index 87430df8..a6544cc0 100644 --- a/gently/hardware/dispim/devices/optical.py +++ b/gently/hardware/dispim/devices/optical.py @@ -2,12 +2,12 @@ DiSPIM optical control devices (LED and laser). 
""" -import time import logging +import time from collections import OrderedDict -from ophyd.status import Status import pymmcore +from ophyd.status import Status logger = logging.getLogger(__name__) @@ -20,7 +20,7 @@ class DiSPIMLED: Device-agnostic: any plan that sets device state will work """ - def __init__(self, core: pymmcore.CMMCore, name: str = "LED", group_name: str = None): + def __init__(self, core: pymmcore.CMMCore, name: str = "LED", group_name: str | None = None): self.core = core self.name = name self.group_name = group_name or name @@ -39,8 +39,7 @@ def _get_available_configs(self): def set(self, state: str): """Set LED state - called by bps.mv(led, 'Open') or bps.mv(led, 'Closed')""" if state not in self._available_configs: - raise ValueError(f"State '{state}' not available. " - f"Available: {self._available_configs}") + raise ValueError(f"State '{state}' not available. Available: {self._available_configs}") status = Status(obj=self, timeout=5) @@ -54,6 +53,7 @@ def wait(): status.set_finished() import threading + threading.Thread(target=wait).start() return status @@ -63,23 +63,16 @@ def read(self): try: current_config = self.core.getCurrentConfig(self.group_name) except Exception: - current_config = 'unknown' + current_config = "unknown" data = OrderedDict() - data[self.name] = { - 'value': current_config, - 'timestamp': time.time() - } + data[self.name] = {"value": current_config, "timestamp": time.time()} return data def describe(self): """Describe LED device - required for Bluesky""" data = OrderedDict() - data[self.name] = { - 'source': self.name, - 'dtype': 'string', - 'shape': [] - } + data[self.name] = {"source": self.name, "dtype": "string", "shape": []} return data def read_configuration(self): @@ -135,7 +128,7 @@ class DiSPIMLightSource: 637: (0.0, 100.0), } - def __init__(self, core: pymmcore.CMMCore, name: str = "Laser", group_name: str = None): + def __init__(self, core: pymmcore.CMMCore, name: str = "Laser", group_name: str | None = 
None): self.core = core self.name = name self.group_name = group_name or name @@ -154,8 +147,9 @@ def _get_available_configs(self): def set(self, config_name: str): """Set laser configuration - called by bps.mv(laser, 'config_name')""" if config_name not in self._available_configs: - raise ValueError(f"Config '{config_name}' not available. " - f"Available: {self._available_configs}") + raise ValueError( + f"Config '{config_name}' not available. Available: {self._available_configs}" + ) status = Status(obj=self, timeout=5) @@ -169,6 +163,7 @@ def wait(): status.set_finished() import threading + threading.Thread(target=wait).start() return status @@ -214,7 +209,10 @@ def set_power_pct(self, wavelength: int, pct: float) -> None: # No waitForDevice — analog setpoint applies on the next exposure. logger.debug( "Set %dnm power to %.4f%% (%s / %s)", - wavelength, pct, self.POWER_DEVICE_LABEL, prop, + wavelength, + pct, + self.POWER_DEVICE_LABEL, + prop, ) def get_power_pct(self, wavelength: int) -> float: @@ -224,32 +222,25 @@ def get_power_pct(self, wavelength: int) -> float: f"Unknown laser wavelength {wavelength}nm. 
" f"Available: {sorted(self.POWER_PROPERTY.keys())}" ) - return float(self.core.getProperty( - self.POWER_DEVICE_LABEL, self.POWER_PROPERTY[wavelength] - )) + return float( + self.core.getProperty(self.POWER_DEVICE_LABEL, self.POWER_PROPERTY[wavelength]) + ) def read(self): """Read current laser configuration - required for Bluesky""" try: current_config = self.core.getCurrentConfig(self.group_name) except Exception: - current_config = 'unknown' + current_config = "unknown" data = OrderedDict() - data[self.name] = { - 'value': current_config, - 'timestamp': time.time() - } + data[self.name] = {"value": current_config, "timestamp": time.time()} return data def describe(self): """Describe laser control device - required for Bluesky""" data = OrderedDict() - data[self.name] = { - 'source': self.group_name, - 'dtype': 'string', - 'shape': [] - } + data[self.name] = {"source": self.group_name, "dtype": "string", "shape": []} return data def read_configuration(self): diff --git a/gently/hardware/dispim/devices/piezo.py b/gently/hardware/dispim/devices/piezo.py index 7103814b..780e507e 100644 --- a/gently/hardware/dispim/devices/piezo.py +++ b/gently/hardware/dispim/devices/piezo.py @@ -2,49 +2,91 @@ DiSPIM piezo and F-drive positioner devices. """ -import time import logging +import time from collections import OrderedDict -from typing import Tuple -from ophyd.status import Status import pymmcore +from ophyd.status import Status from gently.exceptions import HardwareError, StageMovementError logger = logging.getLogger(__name__) +# ========================================================================= +# SPIM-HEAD F-DRIVE HARDWARE SAFETY LIMITS — absolute MMCore micrometres. +# +# Layer-1 software fence for the SPIM head: the ASI Tiger "ZStage:V:37" +# axis, which the ASIdiSPIM plugin labels "SPIM Head Height" (the F axis). 
+# This is the drive that LOWERS the objectives into the dish to hunt for +# embryos ("Start Hunting") and RAISES them clear for sample loading +# ("Load Sample"). Every F-drive move planned by any layer above (Bluesky +# plans, agent orchestrators, UI tools) is bounded here. These are NOT +# constructor kwargs and DiSPIMFDrive exposes no setter — no layer above +# can widen them. +# +# F_DRIVE_MIN_UM collision-critical FLOOR. Smaller F drives the head +# DOWN toward the sample/holder; this hard stop keeps the +# objectives off the dish. +# F_DRIVE_MAX_UM fully-raised "Load Sample" ceiling. +# +# Update only after physically verifying head travel on the rig (drive the +# F axis to each extreme by hand, confirm no collision, read the absolute +# MMCore position) and then editing the constants below. +# +# SCOPE: this is a SOFTWARE-ONLY fence. Unlike the XY stage (see +# devices/stage.py, which also pushes ASI Tiger firmware soft-limits) we do +# NOT write these to the controller, so a physical joystick move can still +# drive the head past these bounds — they bind code-issued moves only. +# ========================================================================= +F_DRIVE_MIN_UM: float = 30.0 +F_DRIVE_MAX_UM: float = 25000.0 + + class DiSPIMFDrive: """ DiSPIM F-drive (SPIM Head motor) - works with bps.mv(fdrive, position) - ASI Tiger V:37 axis - controls F-axis module for lowering objectives - Device-agnostic: any plan that moves a positioner will work with this device + ASI Tiger "ZStage:V:37" axis — the ASIdiSPIM "SPIM Head Height" / F + axis that lowers the objectives to hunt for embryos and raises them to + load a sample. Device-agnostic: any plan that moves a positioner works. + + Hard travel bounds are the module-level F_DRIVE_MIN_UM / F_DRIVE_MAX_UM + constants. They are not constructor kwargs and cannot be widened from + above — see the safety-limit note above this class. 
""" - def __init__(self, name: str, core: pymmcore.CMMCore, - limits: Tuple[float, float] = (20.0, 25000.0)): + def __init__(self, name: str, core: pymmcore.CMMCore, move_timeout_s: float = 120.0): self.name = name self.core = core self.parent = None # Required for Bluesky - self._limits = limits + # Full-travel F moves (e.g. 25000 -> 5000 um: "Load Sample" -> approach) + # are slow; the per-move Status timeout must comfortably exceed the + # longest traverse or bps.mv would error mid-move while the stage is + # still travelling. Configurable for unusually slow controllers. + self._move_timeout_s = float(move_timeout_s) self.tolerance = 0.1 # µm @property - def limits(self): - return self._limits + def limits(self) -> tuple[float, float]: + """Read-only view of the hardware safety limits (module constants).""" + return (F_DRIVE_MIN_UM, F_DRIVE_MAX_UM) def set(self, position): """Move F-drive to position - called by bps.mv()""" position = float(position) position = round(position, 2) # Round to 0.01 μm precision - # Safety check - if not (self._limits[0] <= position <= self._limits[1]): - raise ValueError(f"Position {position} outside limits {self._limits}") + # Hardware safety check — pinned to the module-level F_DRIVE_*_UM + # constants; nothing above this layer can widen them. 
+ if not (F_DRIVE_MIN_UM <= position <= F_DRIVE_MAX_UM): + raise ValueError( + f"F-drive position {position} outside hardware limits " + f"[{F_DRIVE_MIN_UM}, {F_DRIVE_MAX_UM}]" + ) - status = Status(obj=self, timeout=30) + status = Status(obj=self, timeout=self._move_timeout_s) def wait(): try: @@ -56,6 +98,7 @@ def wait(): status.set_finished() import threading + threading.Thread(target=wait).start() return status @@ -69,9 +112,9 @@ def read(self): data = OrderedDict() data[self.name] = { - 'value': float(value), - 'timestamp': time.time(), - 'units': 'micrometers' + "value": float(value), + "timestamp": time.time(), + "units": "micrometers", } return data @@ -79,10 +122,10 @@ def describe(self): """Describe F-drive device - required for Bluesky""" data = OrderedDict() data[self.name] = { - 'source': self.name, - 'dtype': 'number', - 'shape': [], - 'units': 'micrometers' + "source": self.name, + "dtype": "number", + "shape": [], + "units": "micrometers", } return data @@ -103,8 +146,12 @@ class DiSPIMPiezo: Device-agnostic: any plan that moves a positioner will work with this device """ - def __init__(self, name: str, core: pymmcore.CMMCore, - limits: Tuple[float, float] = (-200, 200.0)): + def __init__( + self, + name: str, + core: pymmcore.CMMCore, + limits: tuple[float, float] = (-200, 200.0), + ): self.name = name self.core = core self.parent = None # Required for Bluesky @@ -136,6 +183,7 @@ def wait(): status.set_finished() import threading + threading.Thread(target=wait).start() return status @@ -150,9 +198,9 @@ def read(self): data = OrderedDict() data[self.name] = { - 'value': float(value), - 'timestamp': time.time(), - 'units': 'micrometers' + "value": float(value), + "timestamp": time.time(), + "units": "micrometers", } return data @@ -160,10 +208,10 @@ def describe(self): """Describe piezo device - required for Bluesky""" data = OrderedDict() data[self.name] = { - 'source': self.name, - 'dtype': 'number', - 'shape': [], - 'units': 'micrometers' + 
"source": self.name, + "dtype": "number", + "shape": [], + "units": "micrometers", } return data @@ -192,10 +240,9 @@ def set_as_focus_device(self): """Set this piezo as the Micro-Manager focus device.""" self.core.setFocusDevice(self.name) - def configure_amplitude_offset(self, - amplitude_um: float, - offset_um: float, - pattern: str = "1 - Triangle"): + def configure_amplitude_offset( + self, amplitude_um: float, offset_um: float, pattern: str = "1 - Triangle" + ): """ Configure piezo amplitude and offset for scanning. @@ -236,10 +283,9 @@ def configure_for_spim(self, num_slices: int): """ self.core.setProperty(self.name, "SPIMNumSlices", num_slices) - def configure_for_volume_acquisition(self, - amplitude_um: float, - offset_um: float, - num_slices: int): + def configure_for_volume_acquisition( + self, amplitude_um: float, offset_um: float, num_slices: int + ): """ Configure piezo for hardware-triggered volume acquisition. diff --git a/gently/hardware/dispim/devices/scanner.py b/gently/hardware/dispim/devices/scanner.py index b4a92995..1197c5c1 100644 --- a/gently/hardware/dispim/devices/scanner.py +++ b/gently/hardware/dispim/devices/scanner.py @@ -2,15 +2,13 @@ DiSPIM scanner/galvo mirror control devices. 
""" -import time import logging +import time from collections import OrderedDict -from typing import Dict, Tuple import numpy as np - -from ophyd.status import Status import pymmcore +from ophyd.status import Status from gently.exceptions import HardwareError, StageMovementError @@ -36,9 +34,7 @@ def set(self, value): def wait(): try: - self.scanner.core.setProperty( - self.scanner.name, self.property_name, float(value) - ) + self.scanner.core.setProperty(self.scanner.name, self.property_name, float(value)) self.scanner.core.waitForDevice(self.scanner.name) except (RuntimeError, StageMovementError) as exc: status.set_exception(exc) @@ -46,6 +42,7 @@ def wait(): status.set_finished() import threading + threading.Thread(target=wait).start() return status @@ -56,38 +53,32 @@ def setPosition(self, value: float) -> None: the Status/thread plumbing. MMCore traffic stays inside this ophyd boundary. """ - self.scanner.core.setProperty( - self.scanner.name, self.property_name, float(value) - ) + self.scanner.core.setProperty(self.scanner.name, self.property_name, float(value)) self.scanner.core.waitForDevice(self.scanner.name) def read(self): """Read current offset value""" try: - value = float(self.scanner.core.getProperty( - self.scanner.name, self.property_name - )) + value = float(self.scanner.core.getProperty(self.scanner.name, self.property_name)) except (RuntimeError, HardwareError): value = 0.0 - return OrderedDict({ - self.name: { - 'value': value, - 'timestamp': time.time(), - 'units': 'degrees' - } - }) + return OrderedDict( + {self.name: {"value": value, "timestamp": time.time(), "units": "degrees"}} + ) def describe(self): """Describe component""" - return OrderedDict({ - self.name: { - 'source': self.name, - 'dtype': 'number', - 'shape': [], - 'units': 'degrees' + return OrderedDict( + { + self.name: { + "source": self.name, + "dtype": "number", + "shape": [], + "units": "degrees", + } } - }) + ) class DiSPIMScanner: @@ -102,16 +93,20 @@ class DiSPIMScanner: 
bps.mv(scanner.sa_offset_y, y_value) # Y-axis offset (galvo position) """ - def __init__(self, name: str, core: pymmcore.CMMCore, - limits: Tuple[float, float] = (-5.0, 5.0)): + def __init__( + self, + name: str, + core: pymmcore.CMMCore, + limits: tuple[float, float] = (-5.0, 5.0), + ): self.name = name self.core = core self.parent = None # Required for Bluesky self._limits = limits # Create movable axis offset components for use with bps.mv() - self.sa_offset_x = _ScannerAxisOffset(self, 'x', 'SingleAxisXOffset(deg)') - self.sa_offset_y = _ScannerAxisOffset(self, 'y', 'SingleAxisYOffset(deg)') + self.sa_offset_x = _ScannerAxisOffset(self, "x", "SingleAxisXOffset(deg)") + self.sa_offset_y = _ScannerAxisOffset(self, "y", "SingleAxisYOffset(deg)") @property def limits(self): @@ -143,6 +138,7 @@ def wait(): status.set_finished() import threading + threading.Thread(target=wait).start() return status @@ -162,21 +158,17 @@ def read(self): ab_pos = np.array([0.0, 0.0]) data = OrderedDict() - data[self.name] = { - 'value': ab_pos, - 'timestamp': time.time(), - 'units': 'volts' - } + data[self.name] = {"value": ab_pos, "timestamp": time.time(), "units": "volts"} return data def describe(self): """Describe scanner device - required for Bluesky""" data = OrderedDict() data[self.name] = { - 'source': self.name, - 'dtype': 'array', - 'shape': [2], - 'units': 'volts' + "source": self.name, + "dtype": "array", + "shape": [2], + "units": "volts", } return data @@ -224,9 +216,13 @@ def set_spim_state(self, state: str): if state == "Idle": self.core.waitForDevice(self.name) - def configure_x_axis(self, amplitude_deg: float, offset_deg: float, - pattern: str = "1 - Triangle", - mode: str = "3 - Enabled with axes synced"): + def configure_x_axis( + self, + amplitude_deg: float, + offset_deg: float, + pattern: str = "1 - Triangle", + mode: str = "3 - Enabled with axes synced", + ): """ Configure galvo X-axis (light sheet width scanning). 
@@ -246,9 +242,13 @@ def configure_x_axis(self, amplitude_deg: float, offset_deg: float, self.core.setProperty(self.name, "SingleAxisXPattern", pattern) self.core.setProperty(self.name, "SingleAxisXMode", mode) - def configure_y_axis(self, amplitude_deg: float, offset_deg: float, - pattern: str = "1 - Triangle", - mode: str = "3 - Enabled with axes synced"): + def configure_y_axis( + self, + amplitude_deg: float, + offset_deg: float, + pattern: str = "1 - Triangle", + mode: str = "3 - Enabled with axes synced", + ): """ Configure galvo Y-axis (light sheet Z-plane positioning). @@ -282,14 +282,16 @@ def set_y_offset(self, angle_deg: float): self.core.setProperty(self.name, "SingleAxisYOffset(deg)", float(angle_deg)) self.core.waitForDevice(self.name) - def configure_spim_timing(self, - scan_delay_ms: float = 6.75, - num_scans_per_slice: int = 1, - scan_duration_ms: float = 5.5, - laser_delay_ms: float = 8.0, - laser_duration_ms: float = 5.0, - camera_delay_ms: float = 8.0, - camera_duration_ms: float = 1.0): + def configure_spim_timing( + self, + scan_delay_ms: float = 6.75, + num_scans_per_slice: int = 1, + scan_duration_ms: float = 5.5, + laser_delay_ms: float = 8.0, + laser_duration_ms: float = 5.0, + camera_delay_ms: float = 8.0, + camera_duration_ms: float = 1.0, + ): """ Configure SPIM timing parameters for hardware-triggered acquisition. @@ -318,11 +320,13 @@ def configure_spim_timing(self, self.core.setProperty(self.name, "SPIMDelayBeforeCamera(ms)", camera_delay_ms) self.core.setProperty(self.name, "SPIMCameraDuration(ms)", camera_duration_ms) - def configure_spim_parameters(self, - num_slices: int, - slices_per_piezo: int = 1, - num_sides: int = 1, - first_side: str = "A"): + def configure_spim_parameters( + self, + num_slices: int, + slices_per_piezo: int = 1, + num_sides: int = 1, + first_side: str = "A", + ): """ Configure SPIM acquisition parameters. 
@@ -358,11 +362,13 @@ def configure_for_calibration(self): self.configure_y_axis(amplitude_deg=0.0001, offset_deg=0.0) self.core.waitForDevice(self.name) - def configure_for_volume_acquisition(self, - galvo_amplitude: float, - galvo_center: float, - num_slices: int, - timing_params: Dict = None): + def configure_for_volume_acquisition( + self, + galvo_amplitude: float, + galvo_center: float, + num_slices: int, + timing_params: dict | None = None, + ): """ Configure scanner for hardware-triggered volume acquisition. diff --git a/gently/hardware/dispim/devices/stage.py b/gently/hardware/dispim/devices/stage.py index a179135a..d155b5f8 100644 --- a/gently/hardware/dispim/devices/stage.py +++ b/gently/hardware/dispim/devices/stage.py @@ -2,15 +2,13 @@ DiSPIM stage positioner devices (Z-stage and XY-stage). """ -import time import logging +import time from collections import OrderedDict -from typing import Tuple import numpy as np - -from ophyd.status import Status import pymmcore +from ophyd.status import Status from gently.exceptions import HardwareError, StageMovementError @@ -39,9 +37,9 @@ # inset, even a fast-joystick overshoot still lands inside the true safe # travel envelope the operator measured by hand. 
XY_STAGE_X_MIN_UM: float = -2252.1 -XY_STAGE_X_MAX_UM: float = 983.0 +XY_STAGE_X_MAX_UM: float = 983.0 XY_STAGE_Y_MIN_UM: float = -1677.0 -XY_STAGE_Y_MAX_UM: float = 586.6 +XY_STAGE_Y_MAX_UM: float = 586.6 class DiSPIMZstage: @@ -51,8 +49,12 @@ class DiSPIMZstage: Device-agnostic: any plan that moves a positioner will work with this device """ - def __init__(self, name: str, core: pymmcore.CMMCore, - limits: Tuple[float, float] = (50.0, 250.0)): + def __init__( + self, + name: str, + core: pymmcore.CMMCore, + limits: tuple[float, float] = (50.0, 250.0), + ): self.name = name self.core = core self.parent = None # Required for Bluesky @@ -87,6 +89,7 @@ def wait(): status.set_finished() import threading + threading.Thread(target=wait).start() return status @@ -101,9 +104,9 @@ def read(self): data = OrderedDict() data[self.name] = { - 'value': float(value), - 'timestamp': time.time(), - 'units': 'micrometers' + "value": float(value), + "timestamp": time.time(), + "units": "micrometers", } return data @@ -111,10 +114,10 @@ def describe(self): """Describe Z stage device - required for Bluesky""" data = OrderedDict() data[self.name] = { - 'source': self.name, - 'dtype': 'number', - 'shape': [], - 'units': 'micrometers' + "source": self.name, + "dtype": "number", + "shape": [], + "units": "micrometers", } return data @@ -141,12 +144,12 @@ def __init__(self, name: str, core: pymmcore.CMMCore): self.parent = None # Required for Bluesky @property - def x_limits(self) -> Tuple[float, float]: + def x_limits(self) -> tuple[float, float]: """Read-only view of the hardware safety limits (module constants).""" return (XY_STAGE_X_MIN_UM, XY_STAGE_X_MAX_UM) @property - def y_limits(self) -> Tuple[float, float]: + def y_limits(self) -> tuple[float, float]: """Read-only view of the hardware safety limits (module constants).""" return (XY_STAGE_Y_MIN_UM, XY_STAGE_Y_MAX_UM) @@ -183,6 +186,7 @@ def wait(): status.set_finished() import threading + threading.Thread(target=wait).start() 
return status @@ -198,9 +202,9 @@ def read(self): data = OrderedDict() data[self.name] = { - 'value': xy_pos, - 'timestamp': time.time(), - 'units': 'micrometers' + "value": xy_pos, + "timestamp": time.time(), + "units": "micrometers", } return data @@ -208,10 +212,10 @@ def describe(self): """Describe XY stage device - required for Bluesky""" data = OrderedDict() data[self.name] = { - 'source': self.name, - 'dtype': 'array', - 'shape': [2], - 'units': 'micrometers' + "source": self.name, + "dtype": "array", + "shape": [2], + "units": "micrometers", } return data @@ -237,8 +241,10 @@ def describe_configuration(self): def set_firmware_limits( self, - x_min_mm: float, x_max_mm: float, - y_min_mm: float, y_max_mm: float, + x_min_mm: float, + x_max_mm: float, + y_min_mm: float, + y_max_mm: float, *, readback_tolerance_mm: float = 0.001, ) -> None: @@ -289,13 +295,15 @@ def set_firmware_limits( # overshoot we're trying to absorb anyway. POS_SLOP_MM = 0.001 # 1 µm try: - cur = self.read()[self.name]['value'] + cur = self.read()[self.name]["value"] cur_x_mm = float(cur[0]) / 1000.0 cur_y_mm = float(cur[1]) / 1000.0 except Exception as exc: - raise HardwareError(f"Could not read current XY to validate limits: {exc}") - if not (x_min_mm - POS_SLOP_MM <= cur_x_mm <= x_max_mm + POS_SLOP_MM and - y_min_mm - POS_SLOP_MM <= cur_y_mm <= y_max_mm + POS_SLOP_MM): + raise HardwareError(f"Could not read current XY to validate limits: {exc}") from exc + if not ( + x_min_mm - POS_SLOP_MM <= cur_x_mm <= x_max_mm + POS_SLOP_MM + and y_min_mm - POS_SLOP_MM <= cur_y_mm <= y_max_mm + POS_SLOP_MM + ): raise ValueError( f"Current stage position ({cur_x_mm * 1000:.2f}, {cur_y_mm * 1000:.2f}) µm " f"is outside the requested firmware envelope " @@ -318,11 +326,11 @@ def set_firmware_limits( raise HardwareError( f"setProperty {prop}={value_mm} failed: {exc}. The ASI adapter " f"may require EnableAdvancedProperties=Yes for this write." 
- ) + ) from exc try: got = float(self.core.getProperty(self.name, prop)) except RuntimeError as exc: - raise HardwareError(f"getProperty {prop} read-back failed: {exc}") + raise HardwareError(f"getProperty {prop} read-back failed: {exc}") from exc if abs(got - value_mm) > readback_tolerance_mm: raise HardwareError( f"Firmware limit read-back mismatch for {prop}: " @@ -332,6 +340,40 @@ def set_firmware_limits( ) logger.info("ASI firmware limit %s = %.4f mm (verified)", prop, got) + def enable_joystick(self, enabled: bool = True) -> None: + """Set the ASI Tiger 'JoystickEnabled' property on the XY stage. + + Tiger firmware persists this flag in its non-volatile card settings + (touched whenever someone calls SaveCardSettings — we don't, but + previous sessions may have). If it persisted as 'No', the physical + joystick is dead on boot until something writes 'Yes'. This method + is the boot-time fix; it's called from device_layer.initialize right + after the firmware soft limits are applied. + + Read-back verified so a silent rejection by the adapter doesn't + leave the operator wondering why the controller still doesn't move. + """ + target = "Yes" if enabled else "No" + prop = "JoystickEnabled" + try: + self.core.setProperty(self.name, prop, target) + except RuntimeError as exc: + raise HardwareError( + f"setProperty {prop}={target} failed on {self.name}: {exc}" + ) from exc + try: + got = self.core.getProperty(self.name, prop) + except RuntimeError as exc: + raise HardwareError( + f"getProperty {prop} read-back failed on {self.name}: {exc}" + ) from exc + if str(got).strip() != target: + raise HardwareError( + f"{prop} read-back mismatch on {self.name}: " + f"wrote '{target}', controller reports '{got}'." 
+ ) + logger.info("ASI %s.%s = %s (verified)", self.name, prop, got) + # Synchronous convenience methods (usable outside RunEngine) def get_position(self) -> np.ndarray: """ @@ -348,7 +390,7 @@ def get_position(self) -> np.ndarray: the RunEngine for interactive use, setup, and debugging. For use within plans, prefer yield from bps.rd(xy_stage). """ - return self.read()[self.name]['value'] + return self.read()[self.name]["value"] def get_x(self) -> float: """ @@ -374,9 +416,9 @@ def get_y(self) -> float: # Coordinate conversion utilities for embryo centering @staticmethod - def pixel_to_stage_offset(pixel_offset_x: float, - pixel_offset_y: float, - pixel_size_um: float) -> Tuple[float, float]: + def pixel_to_stage_offset( + pixel_offset_x: float, pixel_offset_y: float, pixel_size_um: float + ) -> tuple[float, float]: """ Convert pixel offsets to stage movement in micrometers. @@ -402,4 +444,5 @@ def pixel_to_stage_offset(pixel_offset_x: float, This method delegates to gently.coordinates for the actual calculation. 
""" from gently.core.coordinates import pixel_displacement_to_stage_movement + return pixel_displacement_to_stage_movement(pixel_offset_x, pixel_offset_y, pixel_size_um) diff --git a/gently/hardware/dispim/devices/system.py b/gently/hardware/dispim/devices/system.py index 66100491..fabcc26a 100644 --- a/gently/hardware/dispim/devices/system.py +++ b/gently/hardware/dispim/devices/system.py @@ -31,7 +31,7 @@ from __future__ import annotations import logging -from typing import Any, Dict, List +from typing import Any import pymmcore @@ -53,13 +53,13 @@ def __init__(self) -> None: def enable_stderr_log(self, enabled: bool) -> None: self.core.enableStderrLog(bool(enabled)) - def set_device_adapter_search_paths(self, paths: List[str]) -> None: + def set_device_adapter_search_paths(self, paths: list[str]) -> None: self.core.setDeviceAdapterSearchPaths(list(paths)) def load_system_configuration(self, path: str) -> None: self.core.loadSystemConfiguration(str(path)) - def get_loaded_devices(self) -> List[str]: + def get_loaded_devices(self) -> list[str]: return list(self.core.getLoadedDevices()) # ----- system-wide property cache ------------------------------------ diff --git a/gently/hardware/dispim/devices/test_temp_usb.py b/gently/hardware/dispim/devices/test_temp_usb.py new file mode 100644 index 00000000..8b635bbd --- /dev/null +++ b/gently/hardware/dispim/devices/test_temp_usb.py @@ -0,0 +1,98 @@ +import threading +import time + +import serial + + +class AcuityNanoPrecisionThermalizerSerial: + def __init__(self, com_port, baud_rate=115200): + self.telemetry = { + "target": 20.0, + "water": 20.0, + "peltier": 20.0, + "state": "DISCONNECTED", + "errors": "0", + } + self.running = True + self.ser = serial.Serial(com_port, baud_rate, timeout=0.1) + time.sleep(2) + + self.thread = threading.Thread(target=self._read_loop, daemon=True) + self.thread.start() + + def _read_loop(self): + while self.running and self.ser.is_open: + try: + if self.ser.in_waiting: + line = 
self.ser.readline().decode("utf-8", errors="ignore").strip() + if "=" in line: + key, val = line.split("=", 1) + if key == "TARGET": + self.telemetry["target"] = float(val) + elif key == "WATER": + self.telemetry["water"] = float(val) + elif key == "ACTUAL": + self.telemetry["peltier"] = float(val) + elif key == "STATE": + self.telemetry["state"] = val + elif key == "ERRORS": + self.telemetry["errors"] = val + except Exception: + pass + time.sleep(0.01) + + def close(self): + self.running = False + if self.ser.is_open: + self.ser.close() + + def set_temperature(self, target_celsius): + if 0.0 <= target_celsius <= 99.9: + cmd = f"TEMP={target_celsius}\n" + self.ser.write(cmd.encode("utf-8")) + else: + raise ValueError("Target must be between 0.0 and 99.9 C") + + def enable_tec(self, enable=True): + val = "1" if enable else "0" + cmd = f"ENABLE={val}\n" + self.ser.write(cmd.encode("utf-8")) + + def set_feedback_sensor(self, use_peltier=False): + val = "1" if use_peltier else "0" + cmd = f"SENSOR={val}\n" + self.ser.write(cmd.encode("utf-8")) + + def get_water_temp(self): + return self.telemetry["water"] + + def get_system_state(self): + return self.telemetry["state"] + + def wait_for_target(self, timeout_seconds=300): + start = time.time() + while time.time() - start < timeout_seconds: + if "[ SYSTEM LOCKED ]" in self.telemetry["state"]: + return True + time.sleep(0.5) + return False + + +if __name__ == "__main__": + import time + + print("Connecting to ACUITYnano...") + acuity = AcuityNanoPrecisionThermalizerSerial("COM8") + + print("Commanding 37.0 C...") + acuity.set_temperature(37.0) + acuity.enable_tec(True) + + print("Waiting for thermal stabilization...") + if acuity.wait_for_target(timeout_seconds=600): + print(f"System locked at {acuity.get_water_temp()} C!") + # Trigger external camera or syringe pump here + else: + print("Timeout reached before system stabilized.") + + acuity.close() diff --git 
a/gently/hardware/dispim/devices/test_temperature_controller.py b/gently/hardware/dispim/devices/test_temperature_controller.py new file mode 100644 index 00000000..2bed40c6 --- /dev/null +++ b/gently/hardware/dispim/devices/test_temperature_controller.py @@ -0,0 +1,112 @@ +""" +ACUITYnano Third-Party Integration SDK +Provides a clean, object-oriented API for external software automation. +""" + +import threading +import time + +import paho.mqtt.client as mqtt + + +class AcuityNanoPrecisionThermalizerAPI: + def __init__( + self, + broker="d0246aa97d194c9da52a19e6f46063eb.s1.eu.hivemq.cloud", + port=8883, + user="acuitynano", + password="Bg984V!@wfhBrkp", + ): + self.prefix = "acuitynano_hhmi_shroff_diSPIM_001" + self.telemetry = { + "target": 20.0, + "water": 20.0, + "peltier": 20.0, + "state": "DISCONNECTED", + "errors": "0", + } + + try: + self.client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2) + except AttributeError: + self.client = mqtt.Client() + + self.client.username_pw_set(user, password) + self.client.tls_set() + self.client.on_connect = self._on_connect + self.client.on_message = self._on_message + + self.thread = threading.Thread(target=self._start_loop, args=(broker, port), daemon=True) + self.thread.start() + time.sleep(2) + + def _start_loop(self, broker, port): + self.client.connect(broker, port, 60) + self.client.loop_forever() + + def _on_connect(self, client, userdata, flags, rc, properties=None): + self.client.subscribe(f"{self.prefix}/telemetry/#") + + def _on_message(self, client, userdata, msg): + topic = msg.topic.split("/")[-1] + payload = msg.payload.decode("utf-8") + + if topic == "target": + self.telemetry["target"] = float(payload) + elif topic == "water": + self.telemetry["water"] = float(payload) + elif topic == "actual": + self.telemetry["peltier"] = float(payload) + elif topic == "state": + self.telemetry["state"] = payload + elif topic == "errors": + self.telemetry["errors"] = payload + + def set_temperature(self, target_celsius): + if 0.0 <=
target_celsius <= 99.9: + self.client.publish(f"{self.prefix}/cmd/temp", str(target_celsius)) + else: + raise ValueError("Target must be between 0.0 and 99.9 C") + + def enable_tec(self, enable=True): + val = "1" if enable else "0" + self.client.publish(f"{self.prefix}/cmd/enable", val) + + def set_feedback_sensor(self, use_peltier=False): + val = "1" if use_peltier else "0" + self.client.publish(f"{self.prefix}/cmd/sensor", val) + + def get_water_temp(self): + return self.telemetry["water"] + + def get_peltier_temp(self): + return self.telemetry["peltier"] + + def get_system_state(self): + return self.telemetry["state"] + + def wait_for_target(self, timeout_seconds=300): + start = time.time() + while time.time() - start < timeout_seconds: + if "[ SYSTEM LOCKED ]" in self.telemetry["state"]: + return True + time.sleep(0.5) + return False + + +if __name__ == "__main__": + import time + + from acuitynano_precision_thermalizer_api import AcuityNanoPrecisionThermalizerAPI + + acuity = AcuityNanoPrecisionThermalizerAPI() + print("Commanding ACUITYnano to 30.0 C...") + acuity.set_temperature(30.0) + acuity.enable_tec(True) + + print("Waiting for thermal stabilization...") + if acuity.wait_for_target(timeout_seconds=600): + print(f"System locked at {acuity.get_water_temp()} C!") + # Trigger image acquisition here + else: + print("Timeout reached.") diff --git a/gently/hardware/dispim/plans/acquisition.py b/gently/hardware/dispim/plans/acquisition.py index 8b8ffb35..02aabba8 100644 --- a/gently/hardware/dispim/plans/acquisition.py +++ b/gently/hardware/dispim/plans/acquisition.py @@ -30,15 +30,14 @@ import logging import time -import numpy as np -import matplotlib.pyplot as plt -from typing import Any, Dict, Generator, List, Tuple, Optional +from collections.abc import Generator +from typing import Any + import bluesky.plan_stubs as bps import bluesky.plans as bp import bluesky.preprocessors as bpp -from pathlib import Path -import json -from datetime import datetime
+import matplotlib.pyplot as plt +import numpy as np logger = logging.getLogger(__name__) @@ -47,10 +46,13 @@ # FOCUS ANALYSIS UTILITIES # ======================= -def compute_fft_bandpass_score(image: np.ndarray, - lower_cutoff: float = 0.025, - upper_cutoff: float = 0.14, - roi: Optional[Tuple[int, int, int, int]] = None) -> float: + +def compute_fft_bandpass_score( + image: np.ndarray, + lower_cutoff: float = 0.025, + upper_cutoff: float = 0.14, + roi: tuple[int, int, int, int] | None = None, +) -> float: """ Compute FFT bandpass focus score (ASI diSPIM OughtaFocus algorithm). @@ -95,7 +97,7 @@ def compute_fft_bandpass_score(image: np.ndarray, # Create distance map from center y, x = np.ogrid[:h, :w] - distance_from_center = np.sqrt((x - cx)**2 + (y - cy)**2) + distance_from_center = np.sqrt((x - cx) ** 2 + (y - cy) ** 2) # Maximum frequency (corner distance) max_freq = np.sqrt(cx**2 + cy**2) @@ -117,9 +119,9 @@ def compute_fft_bandpass_score(image: np.ndarray, return mean_power -def detect_embryo_roi(image: np.ndarray, - margin_fraction: float = 0.1, - min_threshold_ratio: float = 1.15) -> Tuple[int, int, int, int]: +def detect_embryo_roi( + image: np.ndarray, margin_fraction: float = 0.1, min_threshold_ratio: float = 1.15 +) -> tuple[int, int, int, int]: """ Detect embryo region and return bounding box ROI for focus analysis. 
@@ -145,12 +147,14 @@ def detect_embryo_roi(image: np.ndarray, # Calculate background from edge regions edge_margin = min(50, h // 10, w // 10) - edge_pixels = np.concatenate([ - image[:edge_margin, :].flatten(), - image[-edge_margin:, :].flatten(), - image[:, :edge_margin].flatten(), - image[:, -edge_margin:].flatten() - ]) + edge_pixels = np.concatenate( + [ + image[:edge_margin, :].flatten(), + image[-edge_margin:, :].flatten(), + image[:, :edge_margin].flatten(), + image[:, -edge_margin:].flatten(), + ] + ) background_level = np.median(edge_pixels) threshold = background_level * min_threshold_ratio @@ -228,6 +232,7 @@ def select_best_camera_view(image: np.ndarray) -> np.ndarray: # STAGE POSITIONING PLAN STUBS # ================================ + def get_stage_position_plan(xy_stage): """ Read XY stage position within a plan. @@ -256,14 +261,12 @@ def get_stage_position_plan(xy_stage): >>> current_pos = yield from get_stage_position_plan(xy_stage) >>> print(f"Stage at: {current_pos}") """ - result = yield from bps.rd(xy_stage) - return result[xy_stage.name]['value'] + # bps.rd returns the device's single reading value directly (here the + # [x, y] position array), not a {name: {value}} dict. + return (yield from bps.rd(xy_stage)) -def move_to_pixel_plan(xy_stage, - bottom_camera, - pixel_x: float, - pixel_y: float): +def move_to_pixel_plan(xy_stage, bottom_camera, pixel_x: float, pixel_y: float): """ Move stage to center on pixel coordinates from bottom camera image. 
@@ -315,9 +318,7 @@ def move_to_pixel_plan(xy_stage, # Convert pixel displacement to stage movement dx, dy = pixel_displacement_to_stage_movement( - pixel_offset_x, - pixel_offset_y, - bottom_camera.effective_pixel_size + pixel_offset_x, pixel_offset_y, bottom_camera.effective_pixel_size ) # Calculate target position @@ -329,10 +330,7 @@ def move_to_pixel_plan(xy_stage, return target_pos -def center_on_feature_plan(xy_stage, - bottom_camera, - image: np.ndarray, - feature_detector_func): +def center_on_feature_plan(xy_stage, bottom_camera, image: np.ndarray, feature_detector_func): """ Complete workflow: detect feature in image and center stage on it. @@ -374,24 +372,22 @@ def center_on_feature_plan(xy_stage, feature_x, feature_y = feature_detector_func(image) # Move to center the feature - final_pos = yield from move_to_pixel_plan( - xy_stage, bottom_camera, feature_x, feature_y - ) + final_pos = yield from move_to_pixel_plan(xy_stage, bottom_camera, feature_x, feature_y) - return { - 'feature_pos': (feature_x, feature_y), - 'stage_pos': final_pos - } + return {"feature_pos": (feature_x, feature_y), "stage_pos": final_pos} # ======================= # CALIBRATION PLANS # ======================= -def focus_sweep_plan(lightsheet_snap, - galvo_positions: List[float], - roi_detection: bool = True, - metadata: Optional[Dict] = None): + +def focus_sweep_plan( + lightsheet_snap, + galvo_positions: list[float], + roi_detection: bool = True, + metadata: dict | None = None, +): """ Perform focus sweep by moving galvo Y-offset and analyzing image quality. 
@@ -424,15 +420,15 @@ def focus_sweep_plan(lightsheet_snap, """ # Prepare storage for results results = { - 'galvo_positions': galvo_positions, - 'images': [], - 'focus_scores': [], - 'rois': [], - 'timestamp': time.time() + "galvo_positions": galvo_positions, + "images": [], + "focus_scores": [], + "rois": [], + "timestamp": time.time(), } # Open data collection run - _md = {'plan_name': 'focus_sweep'} + _md = {"plan_name": "focus_sweep"} if metadata: _md.update(metadata) @@ -440,8 +436,12 @@ def focus_sweep_plan(lightsheet_snap, def inner(): logger.info("=" * 70) logger.info("FOCUS SWEEP: %d positions", len(galvo_positions)) - logger.info("Galvo Y range: [%.4f, %.4f] deg", min(galvo_positions), max(galvo_positions)) - logger.info("ROI detection: %s", 'enabled' if roi_detection else 'disabled') + logger.info( + "Galvo Y range: [%.4f, %.4f] deg", + min(galvo_positions), + max(galvo_positions), + ) + logger.info("ROI detection: %s", "enabled" if roi_detection else "disabled") logger.info("=" * 70) for idx, galvo_y in enumerate(galvo_positions): @@ -454,7 +454,7 @@ def inner(): yield from bps.trigger_and_read([lightsheet_snap]) # Get captured image from device - image_data = lightsheet_snap.read()[lightsheet_snap.camera.name]['value'] + image_data = lightsheet_snap.read()[lightsheet_snap.camera.name]["value"] # Select best view if dual-camera if image_data.shape[1] > image_data.shape[0] * 2: # Heuristic for stitched image @@ -472,34 +472,36 @@ def inner(): focus_score = compute_fft_bandpass_score(image, roi=roi) # Store results - results['images'].append(image) - results['focus_scores'].append(focus_score) - results['rois'].append(roi) + results["images"].append(image) + results["focus_scores"].append(focus_score) + results["rois"].append(roi) logger.debug("Score: %.2e", focus_score) yield from inner() # Analyze results - scores = np.array(results['focus_scores']) + scores = np.array(results["focus_scores"]) best_idx = np.argmax(scores) - results['best_position'] = 
galvo_positions[best_idx] - results['best_score'] = scores[best_idx] - results['best_image'] = results['images'][best_idx] + results["best_position"] = galvo_positions[best_idx] + results["best_score"] = scores[best_idx] + results["best_image"] = results["images"][best_idx] logger.info("BEST FOCUS: Position %d/%d", best_idx + 1, len(galvo_positions)) - logger.info("Galvo Y = %.4f deg", results['best_position']) - logger.info("Score = %.2e", results['best_score']) + logger.info("Galvo Y = %.4f deg", results["best_position"]) + logger.info("Score = %.2e", results["best_score"]) return results -def calibrate_piezo_galvo_plan(lightsheet_snap, - piezo_positions: List[float], - initial_galvo_position: float = 0.0, - search_range_deg: float = 0.02, - n_sweep_points: int = 21, - metadata: Optional[Dict] = None): +def calibrate_piezo_galvo_plan( + lightsheet_snap, + piezo_positions: list[float], + initial_galvo_position: float = 0.0, + search_range_deg: float = 0.02, + n_sweep_points: int = 21, + metadata: dict | None = None, +): """ Calibrate piezo-galvo synchronization using 2-point linear fit. 
@@ -534,14 +536,14 @@ def calibrate_piezo_galvo_plan(lightsheet_snap, Calibration results with slope, offset, and fit quality """ results = { - 'piezo_positions': piezo_positions, - 'galvo_positions': [], - 'focus_scores': [], - 'sweep_results': [], - 'timestamp': time.time() + "piezo_positions": piezo_positions, + "galvo_positions": [], + "focus_scores": [], + "sweep_results": [], + "timestamp": time.time(), } - _md = {'plan_name': 'calibrate_piezo_galvo'} + _md = {"plan_name": "calibrate_piezo_galvo"} if metadata: _md.update(metadata) @@ -557,13 +559,18 @@ def inner(): current_galvo_guess = initial_galvo_position for piezo_idx, piezo_z in enumerate(piezo_positions): - logger.info("[PIEZO %d/%d] Z = %.2f um", piezo_idx + 1, len(piezo_positions), piezo_z) + logger.info( + "[PIEZO %d/%d] Z = %.2f um", + piezo_idx + 1, + len(piezo_positions), + piezo_z, + ) # Generate galvo sweep positions around current guess galvo_sweep = np.linspace( current_galvo_guess - search_range_deg, current_galvo_guess + search_range_deg, - n_sweep_points + n_sweep_points, ) # Perform focus sweep at this piezo position @@ -571,16 +578,16 @@ def inner(): lightsheet_snap, galvo_sweep.tolist(), roi_detection=True, - metadata={'piezo_position': piezo_z} + metadata={"piezo_position": piezo_z}, ) # Store results - best_galvo = sweep_results['best_position'] - best_score = sweep_results['best_score'] + best_galvo = sweep_results["best_position"] + best_score = sweep_results["best_score"] - results['galvo_positions'].append(best_galvo) - results['focus_scores'].append(best_score) - results['sweep_results'].append(sweep_results) + results["galvo_positions"].append(best_galvo) + results["focus_scores"].append(best_score) + results["sweep_results"].append(sweep_results) # Update guess for next iteration (assume roughly linear) current_galvo_guess = best_galvo @@ -591,7 +598,7 @@ def inner(): # Compute linear fit: galvo_y = slope * piezo_z + offset piezo_array = np.array(piezo_positions) - galvo_array 
= np.array(results['galvo_positions']) + galvo_array = np.array(results["galvo_positions"]) # Linear regression coeffs = np.polyfit(piezo_array, galvo_array, deg=1) @@ -604,11 +611,11 @@ def inner(): rmse = np.sqrt(np.mean(residuals**2)) # Store calibration parameters - results['calibration'] = { - 'slope': slope, - 'offset': offset, - 'rmse': rmse, - 'equation': f"galvo_y = {slope:.6e} * piezo_z + {offset:.6f}" + results["calibration"] = { + "slope": slope, + "offset": offset, + "rmse": rmse, + "equation": f"galvo_y = {slope:.6e} * piezo_z + {offset:.6f}", } logger.info("=" * 70) @@ -626,10 +633,10 @@ def inner(): # EMBRYO DETECTION PLANS # ======================= -def mark_embryo_interactive_plan(bottom_camera, - xy_stage, - embryo_number: int, - metadata: Optional[Dict] = None): + +def mark_embryo_interactive_plan( + bottom_camera, xy_stage, embryo_number: int, metadata: dict | None = None +): """ Interactive plan for user to mark embryo position and center it. @@ -661,7 +668,7 @@ def mark_embryo_interactive_plan(bottom_camera, Dict Results with embryo position, stage position, and images """ - _md = {'plan_name': 'mark_embryo_interactive', 'embryo_number': embryo_number} + _md = {"plan_name": "mark_embryo_interactive", "embryo_number": embryo_number} if metadata: _md.update(metadata) @@ -674,10 +681,10 @@ def inner(): # Capture initial image logger.info("Capturing initial image...") yield from bps.trigger_and_read([bottom_camera]) - initial_image = bottom_camera.read()[bottom_camera.name]['value'] + initial_image = bottom_camera.read()[bottom_camera.name]["value"] # Get current stage position - initial_stage_pos = xy_stage.read()[xy_stage.name]['value'] + initial_stage_pos = xy_stage.read()[xy_stage.name]["value"] logger.info("Initial stage: (%.2f, %.2f) um", initial_stage_pos[0], initial_stage_pos[1]) # Display interactive marking interface @@ -687,15 +694,17 @@ def inner(): # Create interactive figure fig, ax = plt.subplots(figsize=(12, 10)) - img_norm = 
(initial_image - initial_image.min()) / (initial_image.max() - initial_image.min()) - ax.imshow(img_norm, cmap='gray') + img_norm = (initial_image - initial_image.min()) / ( + initial_image.max() - initial_image.min() + ) + ax.imshow(img_norm, cmap="gray") # Draw center crosshair h, w = initial_image.shape - ax.axvline(w/2, color='red', linestyle='--', linewidth=2, label='Center') - ax.axhline(h/2, color='red', linestyle='--', linewidth=2) + ax.axvline(w / 2, color="red", linestyle="--", linewidth=2, label="Center") + ax.axhline(h / 2, color="red", linestyle="--", linewidth=2) - ax.set_title(f"Click on Embryo #{embryo_number}", fontsize=14, fontweight='bold') + ax.set_title(f"Click on Embryo #{embryo_number}", fontsize=14, fontweight="bold") # Storage for click embryo_position = [None, None] @@ -705,8 +714,15 @@ def onclick(event): embryo_position[0] = event.xdata embryo_position[1] = event.ydata # Draw marker - ax.plot(event.xdata, event.ydata, 'o', color='lime', - markersize=15, markeredgewidth=3, markeredgecolor='white') + ax.plot( + event.xdata, + event.ydata, + "o", + color="lime", + markersize=15, + markeredgewidth=3, + markeredgecolor="white", + ) fig.canvas.draw() logger.info("Marked at pixel (%.0f, %.0f)", event.xdata, event.ydata) @@ -714,26 +730,23 @@ def on_done(event): plt.close(fig) # Connect handlers - cid = fig.canvas.mpl_connect('button_press_event', onclick) + fig.canvas.mpl_connect("button_press_event", onclick) # Add Done button ax_done = plt.axes([0.81, 0.05, 0.1, 0.04]) - btn_done = Button(ax_done, 'Done') + btn_done = Button(ax_done, "Done") btn_done.on_clicked(on_done) plt.show() if embryo_position[0] is None: logger.warning("No embryo marked!") - return {'success': False} + return {"success": False} # Center the embryo (using plan stub for proper Bluesky message flow) logger.info("Moving stage to center embryo...") yield from move_to_pixel_plan( - xy_stage, - bottom_camera, - embryo_position[0], - embryo_position[1] + xy_stage, 
bottom_camera, embryo_position[0], embryo_position[1] ) time.sleep(0.5) @@ -741,19 +754,19 @@ def on_done(event): # Capture confirmation image logger.info("Capturing confirmation image...") yield from bps.trigger_and_read([bottom_camera]) - centered_image = bottom_camera.read()[bottom_camera.name]['value'] + centered_image = bottom_camera.read()[bottom_camera.name]["value"] - final_stage_pos = xy_stage.read()[xy_stage.name]['value'] + final_stage_pos = xy_stage.read()[xy_stage.name]["value"] logger.info("Final stage: (%.2f, %.2f) um", final_stage_pos[0], final_stage_pos[1]) - results['success'] = True - results['embryo_number'] = embryo_number - results['pixel_position'] = tuple(embryo_position) - results['initial_stage_position'] = initial_stage_pos - results['final_stage_position'] = final_stage_pos - results['initial_image'] = initial_image - results['centered_image'] = centered_image - results['timestamp'] = time.time() + results["success"] = True + results["embryo_number"] = embryo_number + results["pixel_position"] = tuple(embryo_position) + results["initial_stage_position"] = initial_stage_pos + results["final_stage_position"] = final_stage_pos + results["initial_image"] = initial_image + results["centered_image"] = centered_image + results["timestamp"] = time.time() logger.info("Embryo #%d centered!", embryo_number) @@ -766,20 +779,23 @@ def on_done(event): # VOLUME ACQUISITION PLANS # ======================= -def acquire_single_volume_plan(volume_scanner, - num_slices: int = 100, - exposure_ms: float = 5.0, - galvo_amplitude: float = 0.5, - galvo_center: float = 0.0, - piezo_amplitude: float = 25.0, - piezo_center: float = 50.0, - laser_config: str = "488 and 561", - laser_power_488_pct: float = None, - laser_power_561_pct: float = None, - laser_power_405_pct: float = None, - laser_power_637_pct: float = None, - timing_params: Optional[Dict] = None, - metadata: Optional[Dict] = None): + +def acquire_single_volume_plan( + volume_scanner, + num_slices: int = 
100, + exposure_ms: float = 5.0, + galvo_amplitude: float = 0.5, + galvo_center: float = 0.0, + piezo_amplitude: float = 25.0, + piezo_center: float = 50.0, + laser_config: str = "488 and 561", + laser_power_488_pct: float | None = None, + laser_power_561_pct: float | None = None, + laser_power_405_pct: float | None = None, + laser_power_637_pct: float | None = None, + timing_params: dict | None = None, + metadata: dict | None = None, +): """ Acquire a single hardware-triggered 3D volume. @@ -821,18 +837,18 @@ def acquire_single_volume_plan(volume_scanner, Results with volume data """ _md = { - 'plan_name': 'acquire_single_volume', - 'num_slices': num_slices, - 'exposure_ms': exposure_ms, - 'galvo_amplitude': galvo_amplitude, - 'galvo_center': galvo_center, - 'piezo_amplitude': piezo_amplitude, - 'piezo_center': piezo_center, - 'laser_config': laser_config, - 'laser_power_488_pct': laser_power_488_pct, - 'laser_power_561_pct': laser_power_561_pct, - 'laser_power_405_pct': laser_power_405_pct, - 'laser_power_637_pct': laser_power_637_pct, + "plan_name": "acquire_single_volume", + "num_slices": num_slices, + "exposure_ms": exposure_ms, + "galvo_amplitude": galvo_amplitude, + "galvo_center": galvo_center, + "piezo_amplitude": piezo_amplitude, + "piezo_center": piezo_center, + "laser_config": laser_config, + "laser_power_488_pct": laser_power_488_pct, + "laser_power_561_pct": laser_power_561_pct, + "laser_power_405_pct": laser_power_405_pct, + "laser_power_637_pct": laser_power_637_pct, } if metadata: _md.update(metadata) @@ -848,10 +864,14 @@ def inner(): logger.info("Piezo: %.2f um amplitude, %.2f um center", piezo_amplitude, piezo_center) logger.info("Lasers: %s", laser_config) power_log = [] - if laser_power_488_pct is not None: power_log.append(f"488@{laser_power_488_pct}%") - if laser_power_561_pct is not None: power_log.append(f"561@{laser_power_561_pct}%") - if laser_power_405_pct is not None: power_log.append(f"405@{laser_power_405_pct}%") - if 
laser_power_637_pct is not None: power_log.append(f"637@{laser_power_637_pct}%") + if laser_power_488_pct is not None: + power_log.append(f"488@{laser_power_488_pct}%") + if laser_power_561_pct is not None: + power_log.append(f"561@{laser_power_561_pct}%") + if laser_power_405_pct is not None: + power_log.append(f"405@{laser_power_405_pct}%") + if laser_power_637_pct is not None: + power_log.append(f"637@{laser_power_637_pct}%") if power_log: logger.info("Laser power: %s", ", ".join(power_log)) logger.info("=" * 70) @@ -880,35 +900,37 @@ def inner(): elapsed = time.time() - start_time # Get volume data - volume_data = volume_scanner.read()[volume_scanner.name]['value'] + volume_data = volume_scanner.read()[volume_scanner.name]["value"] logger.info("Volume acquired! Shape: %s, Time: %.2f s", volume_data.shape, elapsed) - results['volume'] = volume_data - results['shape'] = volume_data.shape - results['acquisition_time'] = elapsed - results['timestamp'] = time.time() + results["volume"] = volume_data + results["shape"] = volume_data.shape + results["acquisition_time"] = elapsed + results["timestamp"] = time.time() yield from inner() return results -def burst_plan(volume_scanner, - frames: int = 60, - mode: str = "1hz", - num_slices: int = 1, - exposure_ms: float = 5.0, - galvo_amplitude: float = 0.5, - galvo_center: float = 0.0, - piezo_amplitude: float = 25.0, - piezo_center: float = 50.0, - laser_config: str = "488 only", - laser_power_488_pct: float = None, - laser_power_561_pct: float = None, - laser_power_405_pct: float = None, - laser_power_637_pct: float = None, - timing_params: Optional[Dict] = None, - metadata: Optional[Dict] = None): +def burst_plan( + volume_scanner, + frames: int = 60, + mode: str = "1hz", + num_slices: int = 1, + exposure_ms: float = 5.0, + galvo_amplitude: float = 0.5, + galvo_center: float = 0.0, + piezo_amplitude: float = 25.0, + piezo_center: float = 50.0, + laser_config: str = "488 only", + laser_power_488_pct: float | None = None, 
+ laser_power_561_pct: float | None = None, + laser_power_405_pct: float | None = None, + laser_power_637_pct: float | None = None, + timing_params: dict | None = None, + metadata: dict | None = None, +): """ Acquire a burst of N volumes back-to-back as a single Bluesky run. @@ -938,7 +960,8 @@ def burst_plan(volume_scanner, Scan geometry — same meaning as in ``acquire_single_volume_plan``. laser_config : str Laser preset, e.g. ``"488 only"``. - laser_power_488_pct, laser_power_561_pct, laser_power_405_pct, laser_power_637_pct : float, optional + laser_power_488_pct, laser_power_561_pct, laser_power_405_pct, + laser_power_637_pct : float, optional Per-line laser power %. Hard-limited at the device layer. timing_params : Dict, optional Custom SPIM timing parameters. @@ -962,26 +985,26 @@ def burst_plan(volume_scanner, target_dt = 1.0 if mode == "1hz" else 0.0 _md = { - 'plan_name': 'burst', - 'frames': frames, - 'mode': mode, - 'num_slices': num_slices, - 'exposure_ms': exposure_ms, - 'laser_config': laser_config, - 'laser_power_488_pct': laser_power_488_pct, - 'laser_power_561_pct': laser_power_561_pct, - 'laser_power_405_pct': laser_power_405_pct, - 'laser_power_637_pct': laser_power_637_pct, + "plan_name": "burst", + "frames": frames, + "mode": mode, + "num_slices": num_slices, + "exposure_ms": exposure_ms, + "laser_config": laser_config, + "laser_power_488_pct": laser_power_488_pct, + "laser_power_561_pct": laser_power_561_pct, + "laser_power_405_pct": laser_power_405_pct, + "laser_power_637_pct": laser_power_637_pct, } if metadata: _md.update(metadata) results = { - 'frames_captured': 0, - 'frames_requested': frames, - 'mode': mode, - 'duration_s': 0.0, - 'sustained_hz': 0.0, + "frames_captured": 0, + "frames_requested": frames, + "mode": mode, + "duration_s": 0.0, + "sustained_hz": 0.0, } @bpp.run_decorator(md=_md) @@ -990,10 +1013,14 @@ def inner(): logger.info("BURST ACQUISITION") logger.info("Frames: %d, Mode: %s, Slices/frame: %d", frames, mode, 
num_slices) power_log = [] - if laser_power_488_pct is not None: power_log.append(f"488@{laser_power_488_pct}%") - if laser_power_561_pct is not None: power_log.append(f"561@{laser_power_561_pct}%") - if laser_power_405_pct is not None: power_log.append(f"405@{laser_power_405_pct}%") - if laser_power_637_pct is not None: power_log.append(f"637@{laser_power_637_pct}%") + if laser_power_488_pct is not None: + power_log.append(f"488@{laser_power_488_pct}%") + if laser_power_561_pct is not None: + power_log.append(f"561@{laser_power_561_pct}%") + if laser_power_405_pct is not None: + power_log.append(f"405@{laser_power_405_pct}%") + if laser_power_637_pct is not None: + power_log.append(f"637@{laser_power_637_pct}%") if power_log: logger.info("Laser power: %s", ", ".join(power_log)) logger.info("=" * 70) @@ -1020,7 +1047,7 @@ def inner(): tick_start = time.time() # One event per frame, each carrying its own volume file_ref. yield from bps.trigger_and_read([volume_scanner]) - results['frames_captured'] = i + 1 + results["frames_captured"] = i + 1 # Pacing — only between frames, not after the last. 
if i < frames - 1 and target_dt > 0: @@ -1030,24 +1057,24 @@ def inner(): yield from bps.sleep(wait) duration = time.time() - burst_start - results['duration_s'] = duration - results['sustained_hz'] = ( - results['frames_captured'] / duration if duration > 0 else 0.0 - ) + results["duration_s"] = duration + results["sustained_hz"] = results["frames_captured"] / duration if duration > 0 else 0.0 logger.info( "BURST COMPLETE — %d/%d frames in %.2fs (%.2f Hz)", - results['frames_captured'], frames, duration, results['sustained_hz'], + results["frames_captured"], + frames, + duration, + results["sustained_hz"], ) yield from inner() return results -def timelapse_volume_plan(volume_scanner, - num_timepoints: int, - interval_seconds: float, - **volume_kwargs): +def timelapse_volume_plan( + volume_scanner, num_timepoints: int, interval_seconds: float, **volume_kwargs +): """ Acquire time-lapse series of 3D volumes. @@ -1073,16 +1100,16 @@ def timelapse_volume_plan(volume_scanner, Results with all volumes and timestamps """ results = { - 'volumes': [], - 'timestamps': [], - 'num_timepoints': num_timepoints, - 'interval_seconds': interval_seconds + "volumes": [], + "timestamps": [], + "num_timepoints": num_timepoints, + "interval_seconds": interval_seconds, } _md = { - 'plan_name': 'timelapse_volume', - 'num_timepoints': num_timepoints, - 'interval_seconds': interval_seconds + "plan_name": "timelapse_volume", + "num_timepoints": num_timepoints, + "interval_seconds": interval_seconds, } @bpp.run_decorator(md=_md) @@ -1098,20 +1125,18 @@ def inner(): # Acquire volume vol_results = yield from acquire_single_volume_plan( - volume_scanner, - metadata={'timepoint': tp}, - **volume_kwargs + volume_scanner, metadata={"timepoint": tp}, **volume_kwargs ) - results['volumes'].append(vol_results['volume']) - results['timestamps'].append(vol_results['timestamp']) + results["volumes"].append(vol_results["volume"]) + results["timestamps"].append(vol_results["timestamp"]) # Wait for next 
timepoint (except after last one) if tp < num_timepoints - 1: logger.info("Waiting %d s until next timepoint...", interval_seconds) yield from bps.sleep(interval_seconds) - logger.info("TIME-LAPSE COMPLETE - Total volumes: %d", len(results['volumes'])) + logger.info("TIME-LAPSE COMPLETE - Total volumes: %d", len(results["volumes"])) yield from inner() return results @@ -1121,11 +1146,14 @@ def inner(): # MULTI-EMBRYO WORKFLOWS # ======================= -def multi_embryo_calibration_workflow(bottom_camera, - xy_stage, - lightsheet_snap, - num_embryos: int, - calibration_params: Optional[Dict] = None): + +def multi_embryo_calibration_workflow( + bottom_camera, + xy_stage, + lightsheet_snap, + num_embryos: int, + calibration_params: dict | None = None, +): """ Full multi-embryo calibration workflow. @@ -1162,21 +1190,14 @@ def multi_embryo_calibration_workflow(bottom_camera, # Default calibration parameters if calibration_params is None: calibration_params = { - 'piezo_positions': [40.0, 60.0], # Two-point calibration - 'search_range_deg': 0.02, - 'n_sweep_points': 21 + "piezo_positions": [40.0, 60.0], # Two-point calibration + "search_range_deg": 0.02, + "n_sweep_points": 21, } - results = { - 'embryos': [], - 'num_embryos': num_embryos, - 'timestamp': time.time() - } + results = {"embryos": [], "num_embryos": num_embryos, "timestamp": time.time()} - _md = { - 'plan_name': 'multi_embryo_calibration_workflow', - 'num_embryos': num_embryos - } + _md = {"plan_name": "multi_embryo_calibration_workflow", "num_embryos": num_embryos} @bpp.run_decorator(md=_md) def inner(): @@ -1192,35 +1213,37 @@ def inner(): # Mark and center embryo marking_results = yield from mark_embryo_interactive_plan( - bottom_camera, - xy_stage, - embryo_number=emb_num + bottom_camera, xy_stage, embryo_number=emb_num ) - if not marking_results.get('success', False): + if not marking_results.get("success", False): logger.warning("Skipping embryo %d", emb_num) continue # Perform calibration 
calib_results = yield from calibrate_piezo_galvo_plan( lightsheet_snap, - metadata={'embryo_number': emb_num}, - **calibration_params + metadata={"embryo_number": emb_num}, + **calibration_params, ) # Store results embryo_data = { - 'embryo_number': emb_num, - 'marking': marking_results, - 'calibration': calib_results, - 'timestamp': time.time() + "embryo_number": emb_num, + "marking": marking_results, + "calibration": calib_results, + "timestamp": time.time(), } - results['embryos'].append(embryo_data) + results["embryos"].append(embryo_data) logger.info("Embryo %d calibration complete!", emb_num) - logger.info("WORKFLOW COMPLETE - Calibrated %d/%d embryos", len(results['embryos']), num_embryos) + logger.info( + "WORKFLOW COMPLETE - Calibrated %d/%d embryos", + len(results["embryos"]), + num_embryos, + ) yield from inner() return results @@ -1230,10 +1253,11 @@ def inner(): # Utility Plans (simple device operations for HTTP API) # ============================================================================= + def move_stage_plan(xy_stage, x: float, y: float) -> Generator[Any, Any, dict]: """Move XY stage to specified position.""" yield from bps.mv(xy_stage, [x, y]) - return {'x': x, 'y': y, 'success': True} + return {"x": x, "y": y, "success": True} def read_stage_plan(xy_stage) -> Generator[Any, Any, None]: @@ -1250,13 +1274,13 @@ def capture_bottom_image_plan(bottom_camera, led=None) -> Generator[Any, Any, No """Capture a single image from the bottom camera.""" if led is not None: try: - yield from bps.mv(led, 'Open') + yield from bps.mv(led, "Open") except Exception: pass yield from bp.count([bottom_camera], num=1) if led is not None: try: - yield from bps.mv(led, 'Closed') + yield from bps.mv(led, "Closed") except Exception: pass @@ -1268,7 +1292,7 @@ def capture_lightsheet_image_plan( laser_control, piezo_position: float = 50.0, galvo_position: float = 0.0, - laser_config: str = "488 and 561" + laser_config: str = "488 and 561", ) -> Generator[Any, Any, 
None]: """Capture a single lightsheet image at specified piezo/galvo positions.""" yield from bps.mv(piezo, piezo_position) @@ -1284,25 +1308,25 @@ def capture_lightsheet_image_plan( def move_piezo_plan(piezo, position: float) -> Generator[Any, Any, dict]: """Move piezo to specified position.""" yield from bps.mv(piezo, position) - return {'position': position, 'success': True} + return {"position": position, "success": True} def move_scanner_plan(scanner, offset_y: float) -> Generator[Any, Any, dict]: """Move scanner galvo to specified offset.""" yield from bps.mv(scanner.sa_offset_y, offset_y) - return {'offset_y': offset_y, 'success': True} + return {"offset_y": offset_y, "success": True} -def set_laser_plan(laser_control, state: str = 'ON') -> Generator[Any, Any, dict]: +def set_laser_plan(laser_control, state: str = "ON") -> Generator[Any, Any, dict]: """Set laser state.""" yield from bps.mv(laser_control.state, state) - return {'state': state, 'success': True} + return {"state": state, "success": True} -def set_led_plan(led, state: str = 'Open') -> Generator[Any, Any, dict]: +def set_led_plan(led, state: str = "Open") -> Generator[Any, Any, dict]: """Set LED state.""" yield from bps.mv(led, state) - return {'state': state, 'success': True} + return {"state": state, "success": True} def set_light_source_power_plan( @@ -1319,9 +1343,9 @@ def set_light_source_power_plan( light_source.set_power_pct(wavelength, pct) yield from bps.null() # plan must be a generator return { - 'wavelength': wavelength, - 'pct': pct, - 'success': True, + "wavelength": wavelength, + "pct": pct, + "success": True, } @@ -1333,7 +1357,307 @@ def get_light_source_power_plan( value = light_source.get_power_pct(wavelength) yield from bps.null() return { - 'wavelength': wavelength, - 'pct': value, - 'success': True, + "wavelength": wavelength, + "pct": value, + "success": True, } + + +# ============================================================================ +# SPIM HEAD FOCUS +# Bring 
the SPIM head down onto an XY-positioned embryo, lock focus, then +# register the two objective views on top of each other via the XY stage. +# ============================================================================ + + +def spim_head_focus_descent_plan( + fdrive, + camera, + led=None, + *, + traverse_to_um: float = 5000.0, + coarse_step_um: float = 1000.0, + coarse_stop_um: float = 500.0, + fine_top_um: float = 150.0, + fine_bottom_um: float = 35.0, + fine_step_um: float = 5.0, + focus_algorithm: str = "volath", + coarse_settle_s: float = 2.0, + fine_settle_s: float = 0.5, + led_state: str = "Open", + detect_roi: bool = True, +) -> Generator[Any, Any, dict]: + """ + Bring the SPIM head down onto an (already XY-positioned) embryo and lock focus. + + The head parks fully raised (~25000 um, "Load Sample") and the embryo only + comes into focus near the bottom of travel (~50 um), so almost the whole + descent is empty space -- and full-travel F moves are slow. Rather than scan + the entire 25000->50 range, this descends in three phases: + + 1. Traverse -- one fast move down to ``traverse_to_um`` (default 5000). No imaging. + 2. Coarse -- step down by ``coarse_step_um`` (1000) to ``coarse_stop_um`` + (~500), settling at each step. No focus decision yet. + 3. Fine -- bounded sweep-and-fit from ``fine_top_um`` (150) down to + ``fine_bottom_um`` (35, >= the 30 um hard floor) in + ``fine_step_um`` (5) steps: snap ``camera`` + score each + frame, fit the focus curve (gently.analysis.focus), and + move to the best-focus position. + + ``led`` (transmitted illumination) is opened for the descent and ALWAYS + closed afterwards (finalize), mirroring the bottom camera's LED discipline. + + Every F move is clamped to ``fdrive.limits`` (30-25000 um); the 30 um floor + is the hard collision stop and ``fine_bottom_um`` must sit above it. + + Parameters + ---------- + fdrive : DiSPIMFDrive + SPIM-head F-drive positioner (``bps.mv(fdrive, z)``). 
+ camera : DiSPIMCamera + SPIM objective camera; ``camera.snap()`` returns the frame to score. + led : DiSPIMLED, optional + Transmitted-light source. Opened during the descent, closed after. + focus_algorithm : str + Focus metric: 'volath' (default), 'gradient', 'variance', 'fft_bandpass'. + + Returns + ------- + dict + success, best_position_um, best_score, r_squared, start_position_um, + and fine_curve (list of (z_um, score) tuples). + """ + from gently.analysis.focus import ( + FocusAnalysisConfig, + FocusDataPoint, + analyze_focus_sweep, + score_single_image, + ) + + lo, hi = fdrive.limits + config = FocusAnalysisConfig(algorithm=focus_algorithm) + out: dict[str, Any] = {} + + def _clamp(z: float) -> float: + return max(lo, min(hi, float(z))) + + def _grab_and_score(): + frame = camera.snap() + # A single SPIM snap may be a side-by-side dual view; score the brighter + # objective half (focus is per-objective). Heuristic: width >> height. + if frame.ndim == 2 and frame.shape[1] >= frame.shape[0] * 1.5: + view = select_best_camera_view(frame) + else: + view = frame + score, roi = score_single_image(view, config, detect_roi=detect_roi) + return view, score, roi + + def inner(): + # bps.rd returns the device's single reading value directly (a float + # here) in this bluesky version -- not a {name: {value}} dict. 
+ start_z = float((yield from bps.rd(fdrive))) + logger.info("[SPIM head focus] start %.1f um, floor %.0f um", start_z, lo) + + if led is not None: + yield from bps.mv(led, led_state) + + # Phase 1 -- fast traverse (skip if already below the traverse height) + z = start_z + if z > traverse_to_um: + z = _clamp(traverse_to_um) + logger.info("[SPIM head focus] traverse -> %.0f um (slow move)", z) + yield from bps.mv(fdrive, z) + yield from bps.sleep(coarse_settle_s) + + # Phase 2 -- coarse stepped descent to ~coarse_stop_um + while z > coarse_stop_um + 1e-6: + z = _clamp(max(coarse_stop_um, z - coarse_step_um)) + logger.info("[SPIM head focus] coarse -> %.0f um", z) + yield from bps.mv(fdrive, z) + yield from bps.sleep(coarse_settle_s) + + # Phase 3 -- fine bounded sweep-and-fit through the focus window + top, bottom = _clamp(fine_top_um), _clamp(fine_bottom_um) + n = max(3, int(round(abs(top - bottom) / max(fine_step_um, 1e-6))) + 1) + positions = [_clamp(p) for p in np.linspace(top, bottom, n)] # descending + logger.info("[SPIM head focus] fine scan %.0f->%.0f um in %d steps", top, bottom, n) + + sweep = [] + for zp in positions: + yield from bps.mv(fdrive, zp) + yield from bps.sleep(fine_settle_s) + view, score, roi = _grab_and_score() + sweep.append(FocusDataPoint(position=zp, score=score, image=view, roi=roi)) + logger.debug("[SPIM head focus] z=%.1f score=%.3g", zp, score) + + result = analyze_focus_sweep(sweep, config) + if result.success: + best = _clamp(result.best_position) + else: + best = _clamp(max(sweep, key=lambda d: d.score).position) + + yield from bps.mv(fdrive, best) + yield from bps.sleep(fine_settle_s) + + out.update( + success=bool(result.success), + best_position_um=float(best), + best_score=float(result.best_score), + r_squared=float(result.r_squared), + start_position_um=start_z, + fine_curve=[(float(d.position), float(d.score)) for d in sweep], + ) + logger.info( + "[SPIM head focus] locked %.1f um (score %.3g, R2=%.3f, fit=%s)", + best, + 
result.best_score, + result.r_squared, + result.success, + ) + + def cleanup(): + if led is not None: + yield from bps.mv(led, "Closed") + + yield from bpp.finalize_wrapper(inner(), cleanup()) + return out + + +def register_views_xy_plan( + camera, + xy_stage, + *, + view_pixel_size_um: float, + tolerance_px: float = 20.0, + max_iters: int = 4, + settle_s: float = 0.5, + led=None, + led_state: str = "Open", + reference_view: int = 0, + x_sign: float = -1.0, + y_sign: float = 1.0, +) -> Generator[Any, Any, dict]: + """ + Register the two SPIM objective views on top of each other via the XY stage. + + Snaps ``camera`` (a side-by-side dual view: left = view A, right = view B), + detects the embryo centroid in each half, and moves the XY stage to bring + the embryo to the centre of the ``reference_view`` (0 = A/left, 1 = B/right). + The residual offset in the *other* view is recorded each iteration as the + registration error. Iterates up to ``max_iters`` or until the reference + view is within ``tolerance_px``. + + The view->stage mapping is the rig-specific part and likely needs tuning: + - ``view_pixel_size_um`` is the effective um/pixel of the OBJECTIVE view + (NOT the bottom camera) -- required. + - ``x_sign`` / ``y_sign`` flip the stage axes to match the camera + orientation. Defaults follow the bottom-camera convention (X inverted); + confirm them on the rig from the reported per-iteration residuals (if + the offset grows instead of shrinking, flip the offending sign). + + XY moves are bounded by the stage's own hardware limits, so a mis-tuned + sign fails loudly (limit error) rather than driving anywhere unsafe. + + Returns + ------- + dict + converged (bool), iterations (per-iter offsets per view, in px), + final_stage_um, and (if nothing was found) error. 
+ """ + from gently.core.coordinates import pixel_displacement_to_stage_movement + + out: dict[str, Any] = {"converged": False, "iterations": []} + + def _views(frame): + if frame.ndim == 2 and frame.shape[1] >= frame.shape[0] * 1.5: + mid = frame.shape[1] // 2 + return [frame[:, :mid], frame[:, mid:]] + return [frame] + + def _centroid_offset(view): + # (dx, dy) of the embryo centroid from the view centre, in pixels; + # None if no embryo (detect_embryo_roi falls back to the whole frame). + y0, y1, x0, x1 = detect_embryo_roi(view) + h, w = view.shape + if (y1 - y0) >= h and (x1 - x0) >= w: + return None + cx, cy = (x0 + x1) / 2.0, (y0 + y1) / 2.0 + return (cx - w / 2.0, cy - h / 2.0) + + def inner(): + if led is not None: + yield from bps.mv(led, led_state) + + for _ in range(max_iters): + frame = camera.snap() + views = _views(frame) + offsets = [_centroid_offset(v) for v in views] + out["iterations"].append( + {"offsets_px": [None if o is None else (float(o[0]), float(o[1])) for o in offsets]} + ) + + ref = offsets[reference_view] if reference_view < len(offsets) else None + if ref is None: + seen = [o for o in offsets if o is not None] + if not seen: + out["error"] = "no embryo detected in any view" + break + ref = seen[0] + + if max(abs(ref[0]), abs(ref[1])) <= tolerance_px: + out["converged"] = True + break + + dx_um, dy_um = pixel_displacement_to_stage_movement(ref[0], ref[1], view_pixel_size_um) + cur = yield from bps.rd(xy_stage) # -> [x, y] um (single reading value) + target = [cur[0] + x_sign * dx_um, cur[1] + y_sign * dy_um] + yield from bps.mov(xy_stage, target) + yield from bps.sleep(settle_s) + + final = yield from bps.rd(xy_stage) + out["final_stage_um"] = [float(final[0]), float(final[1])] + + def cleanup(): + if led is not None: + yield from bps.mv(led, "Closed") + + yield from bpp.finalize_wrapper(inner(), cleanup()) + return out + + +def spim_head_focus_and_align_plan( + fdrive, + camera, + xy_stage, + led=None, + *, + view_pixel_size_um: float 
| None = None, + align: bool = True, + focus_kwargs: dict | None = None, + align_kwargs: dict | None = None, +) -> Generator[Any, Any, dict]: + """ + Full SPIM-head focus workflow for one XY-positioned embryo. + + Descends and locks focus (``spim_head_focus_descent_plan``), then registers + the two objective views via the XY stage (``register_views_xy_plan``). + + ``align=True`` requires ``view_pixel_size_um`` (objective-view um/pixel). + """ + out: dict[str, Any] = {} + out["focus"] = yield from spim_head_focus_descent_plan( + fdrive, camera, led=led, **(focus_kwargs or {}) + ) + if align: + if view_pixel_size_um is None: + raise ValueError( + "align=True requires view_pixel_size_um (objective-view um/pixel calibration)" + ) + out["align"] = yield from register_views_xy_plan( + camera, + xy_stage, + view_pixel_size_um=view_pixel_size_um, + led=led, + **(align_kwargs or {}), + ) + return out diff --git a/gently/hardware/dispim/plans/calibration.py b/gently/hardware/dispim/plans/calibration.py index 94de95b1..798a203b 100644 --- a/gently/hardware/dispim/plans/calibration.py +++ b/gently/hardware/dispim/plans/calibration.py @@ -18,10 +18,11 @@ import logging import time +from datetime import datetime +from pathlib import Path + import bluesky.plan_stubs as bps import numpy as np -from pathlib import Path -from datetime import datetime logger = logging.getLogger(__name__) @@ -39,9 +40,12 @@ def _safe_obtain(obj): # CLAUDE VISION PROMPTS # ============================================================================ -EMBRYO_CENTERING_PROMPT = """You are an expert microscopist examining a diSPIM light sheet microscopy image of a biological embryo sample. +EMBRYO_CENTERING_PROMPT = """You are an expert microscopist examining a diSPIM light sheet +microscopy image of a biological embryo sample. -This image shows ONE camera view from the diSPIM system. You should look for an embryo structure somewhere in the field of view. 
The embryo will appear as a brighter structure against a dark background, but the signal may be MODERATE (not necessarily super bright). +This image shows ONE camera view from the diSPIM system. You should look for an embryo +structure somewhere in the field of view. The embryo will appear as a brighter structure +against a dark background, but the signal may be MODERATE (not necessarily super bright). IMPORTANT CONTEXT: - This is a REAL microscopy image with typical noise and artifacts @@ -73,11 +77,14 @@ def _safe_obtain(obj): Example response: yes -An irregularly-shaped embryo structure is visible in the left-center region with moderate brightness and defined boundaries against the dark background.""" +An irregularly-shaped embryo structure is visible in the left-center region with moderate +brightness and defined boundaries against the dark background.""" -EMBRYO_EDGE_PROMPT = """You are an expert microscopist specializing in diSPIM light sheet microscopy of embryos. +EMBRYO_EDGE_PROMPT = """You are an expert microscopist specializing in diSPIM light sheet +microscopy of embryos. -This image shows ONE camera view of an embryo captured with light sheet illumination. We are trying to determine if the embryo is still visible at this Z position. +This image shows ONE camera view of an embryo captured with light sheet illumination. We are +trying to determine if the embryo is still visible at this Z position. CONTEXT: - We are sweeping through Z positions to find where the embryo appears/disappears @@ -112,6 +119,7 @@ def _safe_obtain(obj): # PLAN: VERIFY EMBRYO CENTERED # ============================================================================ + def verify_embryo_centered(embryo_detector, image_dir=None): """ Verify that embryo is centered and visible. 
@@ -149,17 +157,19 @@ def verify_embryo_centered(embryo_detector, image_dir=None): # Metadata for this phase metadata = { - 'plan_name': 'verify_embryo_centered', - 'phase': 'centering', - 'timestamp': datetime.now().isoformat() + "plan_name": "verify_embryo_centered", + "phase": "centering", + "timestamp": datetime.now().isoformat(), } # Start run - uid = yield from bps.open_run(md=metadata) + yield from bps.open_run(md=metadata) # Prepare image path if image_dir is not None: - image_path = Path(image_dir) / f"centering_check_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png" + image_path = ( + Path(image_dir) / f"centering_check_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png" + ) else: image_path = None @@ -170,7 +180,7 @@ def verify_embryo_centered(embryo_detector, image_dir=None): galvo_deg=0.0, piezo_um=0.0, prompt=EMBRYO_CENTERING_PROMPT, - save_image_path=image_path + save_image_path=image_path, ) # Log result @@ -182,16 +192,16 @@ def verify_embryo_centered(embryo_detector, image_dir=None): yield from bps.close_run() # Report result - if result['embryo_visible']: + if result["embryo_visible"]: logger.info("Embryo VISIBLE at center") - logger.info("Claude: %s", result['description']) - logger.info("Confidence: %.1f%%", result['confidence'] * 100) + logger.info("Claude: %s", result["description"]) + logger.info("Confidence: %.1f%%", result["confidence"] * 100) if image_path: logger.info("Image: %s", image_path) return True else: logger.warning("Embryo NOT VISIBLE at center") - logger.info("Claude: %s", result['description']) + logger.info("Claude: %s", result["description"]) logger.warning("Please adjust sample position and try again") return False @@ -200,10 +210,17 @@ def verify_embryo_centered(embryo_detector, image_dir=None): # PLAN: DETECT EMBRYO EDGE # ============================================================================ -def detect_embryo_edge(embryo_detector, direction='top', - start_deg=0.0, end_deg=0.5, step_deg=0.05, - tolerance_deg=0.20, 
piezo_um=0.0, - image_dir=None): + +def detect_embryo_edge( + embryo_detector, + direction="top", + start_deg=0.0, + end_deg=0.5, + step_deg=0.05, + tolerance_deg=0.20, + piezo_um=0.0, + image_dir=None, +): """ Detect embryo edge by sweeping until embryo disappears. @@ -263,7 +280,7 @@ def detect_embryo_edge(embryo_detector, direction='top', logger.info("=" * 70) # Determine sweep direction - if direction == 'top': + if direction == "top": # Sweep upward (negative direction) step = -abs(step_deg) tolerance_sign = -1 @@ -277,23 +294,28 @@ def detect_embryo_edge(embryo_detector, direction='top', positions = [start_deg + i * step for i in range(num_steps)] logger.info("Sweep strategy: Start at %+.3f deg, step %+.3f deg", start_deg, step) - logger.info("Testing %d positions from %+.3f deg to %+.3f deg", num_steps, start_deg, end_deg) + logger.info( + "Testing %d positions from %+.3f deg to %+.3f deg", + num_steps, + start_deg, + end_deg, + ) logger.info("Looking for position where embryo disappears...") # Metadata metadata = { - 'plan_name': 'detect_embryo_edge', - 'phase': f'edge_detection_{direction}', - 'direction': direction, - 'start_deg': start_deg, - 'end_deg': end_deg, - 'step_deg': step, - 'tolerance_deg': tolerance_deg, - 'piezo_um': piezo_um, - 'timestamp': datetime.now().isoformat() + "plan_name": "detect_embryo_edge", + "phase": f"edge_detection_{direction}", + "direction": direction, + "start_deg": start_deg, + "end_deg": end_deg, + "step_deg": step, + "tolerance_deg": tolerance_deg, + "piezo_um": piezo_um, + "timestamp": datetime.now().isoformat(), } - uid = yield from bps.open_run(md=metadata) + yield from bps.open_run(md=metadata) # Sweep through positions all_results = [] @@ -303,7 +325,10 @@ def detect_embryo_edge(embryo_detector, direction='top', for i, pos in enumerate(positions): # Prepare image path if image_dir is not None: - image_path = Path(image_dir) / f"edge_{direction}_pos{pos:+.3f}deg_{datetime.now().strftime('%H%M%S')}.png" + image_path 
= ( + Path(image_dir) + / f"edge_{direction}_pos{pos:+.3f}deg_{datetime.now().strftime('%H%M%S')}.png" + ) else: image_path = None @@ -314,7 +339,7 @@ def detect_embryo_edge(embryo_detector, direction='top', galvo_deg=pos, piezo_um=piezo_um, prompt=EMBRYO_EDGE_PROMPT, - save_image_path=image_path + save_image_path=image_path, ) all_results.append(result) @@ -325,7 +350,7 @@ def detect_embryo_edge(embryo_detector, direction='top', yield from bps.save() # Check if embryo disappeared - if result['embryo_visible']: + if result["embryo_visible"]: logger.info(" -> visible") else: logger.info(" -> NOT visible - EDGE FOUND!") @@ -349,11 +374,11 @@ def detect_embryo_edge(embryo_detector, direction='top', # Prepare result edge_result = { - 'edge_deg': edge_position, - 'with_tolerance_deg': edge_with_tolerance, - 'num_steps': len(all_results), - 'all_positions': [r['galvo_deg'] for r in all_results], - 'all_visible': [r['embryo_visible'] for r in all_results] + "edge_deg": edge_position, + "with_tolerance_deg": edge_with_tolerance, + "num_steps": len(all_results), + "all_positions": [r["galvo_deg"] for r in all_results], + "all_visible": [r["embryo_visible"] for r in all_results], } yield from bps.close_run() @@ -365,11 +390,21 @@ def detect_embryo_edge(embryo_detector, direction='top', # PLAN: CALIBRATE FOCUS AT POSITION # ============================================================================ -def calibrate_focus_at_position(camera, galvo, piezo, focus_scorer, core, - galvo_deg, piezo_center_um, - sweep_range_um=20.0, sweep_step_um=2.0, - min_r_squared=0.75, image_dir=None, - phase_name="FOCUS CALIBRATION"): + +def calibrate_focus_at_position( + camera, + galvo, + piezo, + focus_scorer, + core, + galvo_deg, + piezo_center_um, + sweep_range_um=20.0, + sweep_step_um=2.0, + min_r_squared=0.75, + image_dir=None, + phase_name="FOCUS CALIBRATION", +): """ Perform focus sweep at a galvo position to find optimal piezo position. 
@@ -457,16 +492,16 @@ def calibrate_focus_at_position(camera, galvo, piezo, focus_scorer, core, # Metadata metadata = { - 'plan_name': 'calibrate_focus_at_position', - 'phase': phase_name, - 'galvo_deg': galvo_deg, - 'piezo_center_um': piezo_center_um, - 'sweep_range_um': sweep_range_um, - 'sweep_step_um': sweep_step_um, - 'timestamp': datetime.now().isoformat() + "plan_name": "calibrate_focus_at_position", + "phase": phase_name, + "galvo_deg": galvo_deg, + "piezo_center_um": piezo_center_um, + "sweep_range_um": sweep_range_um, + "sweep_step_um": sweep_step_um, + "timestamp": datetime.now().isoformat(), } - uid = yield from bps.open_run(md=metadata) + yield from bps.open_run(md=metadata) # Move galvo to position. setPosition() is synchronous — it issues the # set + waits for the device to settle, so no explicit wait needed here. @@ -511,7 +546,12 @@ def calibrate_focus_at_position(camera, galvo, piezo, focus_scorer, core, view_name = "right" if pos == positions[0]: # Only log once at start - logger.info("Using %s camera view (L:%.1f vs R:%.1f)", view_name, left_intensity, right_intensity) + logger.info( + "Using %s camera view (L:%.1f vs R:%.1f)", + view_name, + left_intensity, + right_intensity, + ) # Store image for ROI detection all_images.append(img) @@ -524,8 +564,16 @@ def calibrate_focus_at_position(camera, galvo, piezo, focus_scorer, core, roi_height = y_max - y_min roi_width = x_max - x_min roi_percent = (roi_width * roi_height) / (img.shape[0] * img.shape[1]) * 100 - logger.info("Embryo ROI: [%d:%d, %d:%d] (%dx%d px, %.1f%% of frame)", - y_min, y_max, x_min, x_max, roi_width, roi_height, roi_percent) + logger.info( + "Embryo ROI: [%d:%d, %d:%d] (%dx%d px, %.1f%% of frame)", + y_min, + y_max, + x_min, + x_max, + roi_width, + roi_height, + roi_percent, + ) # Convert to 8-bit with auto-scaling for better visibility # This matches the working calibrate_embryo_piezo_galvo.py behavior @@ -540,7 +588,10 @@ def calibrate_focus_at_position(camera, galvo, piezo, 
focus_scorer, core, # Save image if requested if image_dir is not None: from PIL import Image - image_path = Path(image_dir) / f"focus_pos{pos:.1f}um_{datetime.now().strftime('%H%M%S')}.png" + + image_path = ( + Path(image_dir) / f"focus_pos{pos:.1f}um_{datetime.now().strftime('%H%M%S')}.png" + ) image_path.parent.mkdir(parents=True, exist_ok=True) Image.fromarray(img_8bit).save(image_path) @@ -561,37 +612,37 @@ def calibrate_focus_at_position(camera, galvo, piezo, focus_scorer, core, logger.info("Fitting Gaussian curve to focus scores...") fit_result = focus_scorer.fit_focus_curve(all_positions, all_scores) - if fit_result['success']: + if fit_result["success"]: logger.info("Fit successful!") - logger.info("Best focus: %.2f um", fit_result['best_position']) - logger.info("R-squared: %.3f (%s)", fit_result['r_squared'], fit_result['fit_quality']) - logger.info("Peak in center: %s", fit_result['peak_in_center']) + logger.info("Best focus: %.2f um", fit_result["best_position"]) + logger.info("R-squared: %.3f (%s)", fit_result["r_squared"], fit_result["fit_quality"]) + logger.info("Peak in center: %s", fit_result["peak_in_center"]) result = { - 'success': True, - 'optimal_position_um': fit_result['best_position'], - 'r_squared': fit_result['r_squared'], - 'all_positions': all_positions, - 'all_scores': all_scores, - 'galvo_deg': galvo_deg, - 'fit_params': fit_result['params'] + "success": True, + "optimal_position_um": fit_result["best_position"], + "r_squared": fit_result["r_squared"], + "all_positions": all_positions, + "all_scores": all_scores, + "galvo_deg": galvo_deg, + "fit_params": fit_result["params"], } else: - logger.warning("Fit failed: %s", fit_result.get('error_message', 'Unknown error')) - logger.warning("R-squared: %.3f (threshold: %.3f)", fit_result['r_squared'], min_r_squared) + logger.warning("Fit failed: %s", fit_result.get("error_message", "Unknown error")) + logger.warning("R-squared: %.3f (threshold: %.3f)", fit_result["r_squared"], 
min_r_squared) logger.warning("Using maximum score position as fallback") max_idx = np.argmax(all_scores) fallback_position = all_positions[max_idx] result = { - 'success': False, - 'optimal_position_um': fallback_position, - 'r_squared': fit_result['r_squared'], - 'all_positions': all_positions, - 'all_scores': all_scores, - 'galvo_deg': galvo_deg, - 'error_message': fit_result.get('error_message', 'Poor fit quality') + "success": False, + "optimal_position_um": fallback_position, + "r_squared": fit_result["r_squared"], + "all_positions": all_positions, + "all_scores": all_scores, + "galvo_deg": galvo_deg, + "error_message": fit_result.get("error_message", "Poor fit quality"), } yield from bps.close_run() @@ -603,8 +654,14 @@ def calibrate_focus_at_position(camera, galvo, piezo, focus_scorer, core, # PLAN: FULL CALIBRATION ORCHESTRATION # ============================================================================ + def calibrate_embryo_piezo_galvo( - camera, galvo, piezo, focus_scorer, embryo_detector, core, + camera, + galvo, + piezo, + focus_scorer, + embryo_detector, + core, calibration_inset_fraction=0.4, edge_detection_step_deg=0.05, edge_tolerance_deg=0.20, @@ -614,7 +671,7 @@ def calibrate_embryo_piezo_galvo( heuristic_slope=100.0, heuristic_offset=0.0, image_dir=None, - save_path=None + save_path=None, ): """ Complete embryo-based piezo-galvo calibration workflow. 
@@ -716,7 +773,7 @@ def calibrate_embryo_piezo_galvo( if not centered: logger.error("CALIBRATION ABORTED: Embryo not centered") logger.error("Please adjust sample position and try again.") - return {'success': False, 'error': 'Embryo not centered'} + return {"success": False, "error": "Embryo not centered"} # ==================================================================== # PHASE 1.5: EDGE DETECTION @@ -728,41 +785,53 @@ def calibrate_embryo_piezo_galvo( logger.info("Detecting TOP edge (sweeping upward from center)...") top_edge_result = yield from detect_embryo_edge( embryo_detector, - direction='top', + direction="top", start_deg=0.0, end_deg=-0.5, step_deg=edge_detection_step_deg, tolerance_deg=edge_tolerance_deg, piezo_um=0.0, - image_dir=image_dir + image_dir=image_dir, ) - edge_top_deg = top_edge_result['edge_deg'] - scan_top_deg = top_edge_result['with_tolerance_deg'] + edge_top_deg = top_edge_result["edge_deg"] + scan_top_deg = top_edge_result["with_tolerance_deg"] # Detect BOTTOM edge (sweep downward from center) logger.info("Detecting BOTTOM edge (sweeping downward from center)...") bottom_edge_result = yield from detect_embryo_edge( embryo_detector, - direction='bottom', + direction="bottom", start_deg=0.0, end_deg=0.5, step_deg=edge_detection_step_deg, tolerance_deg=edge_tolerance_deg, piezo_um=0.0, - image_dir=image_dir + image_dir=image_dir, ) - edge_bottom_deg = bottom_edge_result['edge_deg'] - scan_bottom_deg = bottom_edge_result['with_tolerance_deg'] + edge_bottom_deg = bottom_edge_result["edge_deg"] + scan_bottom_deg = bottom_edge_result["with_tolerance_deg"] detected_range = scan_bottom_deg - scan_top_deg logger.info("EDGE DETECTION SUMMARY") - logger.info("Detected edges: TOP=%+.3f deg, BOTTOM=%+.3f deg", edge_top_deg, edge_bottom_deg) - logger.info("Scan boundaries: TOP=%+.3f deg (incl %.3f deg margin), BOTTOM=%+.3f deg", - scan_top_deg, edge_tolerance_deg, scan_bottom_deg) - logger.info("Total scan range: %.3f deg (~%.1f um)", 
detected_range, detected_range * 100) + logger.info( + "Detected edges: TOP=%+.3f deg, BOTTOM=%+.3f deg", + edge_top_deg, + edge_bottom_deg, + ) + logger.info( + "Scan boundaries: TOP=%+.3f deg (incl %.3f deg margin), BOTTOM=%+.3f deg", + scan_top_deg, + edge_tolerance_deg, + scan_bottom_deg, + ) + logger.info( + "Total scan range: %.3f deg (~%.1f um)", + detected_range, + detected_range * 100, + ) # ==================================================================== # CALCULATE INTERIOR CALIBRATION POSITIONS @@ -778,21 +847,37 @@ def calibrate_embryo_piezo_galvo( calib_bottom_deg = scan_bottom_deg - inset_amount calib_range = calib_bottom_deg - calib_top_deg - logger.info("Inset fraction: %.0f%%, distance: %.3f deg (~%.1f um)", - calibration_inset_fraction * 100, inset_amount, inset_amount * 100) - logger.info("TOP calibration: scan boundary %+.3f deg -> calibrate at %+.3f deg", - scan_top_deg, calib_top_deg) - logger.info("BOTTOM calibration: scan boundary %+.3f deg -> calibrate at %+.3f deg", - scan_bottom_deg, calib_bottom_deg) - logger.info("Calibration range: %.3f deg, Volume scan range: %.3f deg", - calib_range, detected_range) + logger.info( + "Inset fraction: %.0f%%, distance: %.3f deg (~%.1f um)", + calibration_inset_fraction * 100, + inset_amount, + inset_amount * 100, + ) + logger.info( + "TOP calibration: scan boundary %+.3f deg -> calibrate at %+.3f deg", + scan_top_deg, + calib_top_deg, + ) + logger.info( + "BOTTOM calibration: scan boundary %+.3f deg -> calibrate at %+.3f deg", + scan_bottom_deg, + calib_bottom_deg, + ) + logger.info( + "Calibration range: %.3f deg, Volume scan range: %.3f deg", + calib_range, + detected_range, + ) # Estimate piezo positions using heuristic piezo_top_heuristic = calib_top_deg * heuristic_slope + heuristic_offset piezo_bottom_heuristic = calib_bottom_deg * heuristic_slope + heuristic_offset - logger.info("Heuristic piezo positions: TOP=%.1f um, BOTTOM=%.1f um", - piezo_top_heuristic, piezo_bottom_heuristic) + 
logger.info( + "Heuristic piezo positions: TOP=%.1f um, BOTTOM=%.1f um", + piezo_top_heuristic, + piezo_bottom_heuristic, + ) # ==================================================================== # PHASE 2: TOP CALIBRATION @@ -800,21 +885,28 @@ def calibrate_embryo_piezo_galvo( logger.info("[4/8] Phase 2: TOP Interior Focus Calibration") top_result = yield from calibrate_focus_at_position( - camera, galvo, piezo, focus_scorer, core, + camera, + galvo, + piezo, + focus_scorer, + core, galvo_deg=calib_top_deg, piezo_center_um=piezo_top_heuristic, sweep_range_um=sweep_range_um, sweep_step_um=sweep_step_um, min_r_squared=min_r_squared, image_dir=image_dir, - phase_name="PHASE 2: TOP CALIBRATION" + phase_name="PHASE 2: TOP CALIBRATION", ) - if not top_result['success']: - logger.warning("TOP calibration fit quality poor (R-squared=%.3f)", top_result['r_squared']) + if not top_result["success"]: + logger.warning( + "TOP calibration fit quality poor (R-squared=%.3f)", + top_result["r_squared"], + ) logger.warning("Using best score position as fallback") - piezo_top_final = top_result['optimal_position_um'] + piezo_top_final = top_result["optimal_position_um"] logger.info("TOP optimal position: %.2f um", piezo_top_final) # ==================================================================== @@ -823,21 +915,28 @@ def calibrate_embryo_piezo_galvo( logger.info("[5/8] Phase 3: BOTTOM Interior Focus Calibration") bottom_result = yield from calibrate_focus_at_position( - camera, galvo, piezo, focus_scorer, core, + camera, + galvo, + piezo, + focus_scorer, + core, galvo_deg=calib_bottom_deg, piezo_center_um=piezo_bottom_heuristic, sweep_range_um=sweep_range_um, sweep_step_um=sweep_step_um, min_r_squared=min_r_squared, image_dir=image_dir, - phase_name="PHASE 3: BOTTOM CALIBRATION" + phase_name="PHASE 3: BOTTOM CALIBRATION", ) - if not bottom_result['success']: - logger.warning("BOTTOM calibration fit quality poor (R-squared=%.3f)", bottom_result['r_squared']) + if not 
bottom_result["success"]: + logger.warning( + "BOTTOM calibration fit quality poor (R-squared=%.3f)", + bottom_result["r_squared"], + ) logger.warning("Using best score position as fallback") - piezo_bottom_final = bottom_result['optimal_position_um'] + piezo_bottom_final = bottom_result["optimal_position_um"] logger.info("BOTTOM optimal position: %.2f um", piezo_bottom_final) # ==================================================================== @@ -855,8 +954,14 @@ def calibrate_embryo_piezo_galvo( slope = delta_piezo / delta_galvo offset = piezo_top_final - slope * calib_top_deg - logger.info("Calibration points: TOP galvo=%+.3f deg piezo=%+.2f um, BOTTOM galvo=%+.3f deg piezo=%+.2f um", - calib_top_deg, piezo_top_final, calib_bottom_deg, piezo_bottom_final) + logger.info( + "Calibration points: TOP galvo=%+.3f deg piezo=%+.2f um," + " BOTTOM galvo=%+.3f deg piezo=%+.2f um", + calib_top_deg, + piezo_top_final, + calib_bottom_deg, + piezo_bottom_final, + ) logger.info("Linear fit: piezo(um) = %.2f * galvo(deg) + %.2f", slope, offset) logger.info("Slope: %.2f um/deg, Offset: %.2f um", slope, offset) @@ -864,9 +969,15 @@ def calibrate_embryo_piezo_galvo( piezo_top_scan = scan_top_deg * slope + offset piezo_bottom_scan = scan_bottom_deg * slope + offset - logger.info("Volume scan range: TOP galvo=%+.3f deg piezo=%+.2f um, BOTTOM galvo=%+.3f deg piezo=%+.2f um, Total=%.1f um", - scan_top_deg, piezo_top_scan, scan_bottom_deg, piezo_bottom_scan, - abs(piezo_bottom_scan - piezo_top_scan)) + logger.info( + "Volume scan range: TOP galvo=%+.3f deg piezo=%+.2f um," + " BOTTOM galvo=%+.3f deg piezo=%+.2f um, Total=%.1f um", + scan_top_deg, + piezo_top_scan, + scan_bottom_deg, + piezo_bottom_scan, + abs(piezo_bottom_scan - piezo_top_scan), + ) # ==================================================================== # CREATE CALIBRATION OBJECT @@ -882,8 +993,8 @@ def calibrate_embryo_piezo_galvo( piezo_bottom_um=piezo_bottom_scan, edge_top_deg=edge_top_deg, 
edge_bottom_deg=edge_bottom_deg, - sample_type='embryo', - timestamp=datetime.now().strftime('%Y-%m-%d %H:%M:%S') + sample_type="embryo", + timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), ) logger.info("Calibration object created: %s", calibration) @@ -910,24 +1021,27 @@ def calibrate_embryo_piezo_galvo( logger.info("Calibration: piezo(um) = %.2f * galvo(deg) + %.2f", slope, offset) logger.info("Scan range: %+.3f deg to %+.3f deg", scan_top_deg, scan_bottom_deg) logger.info("Piezo range: %+.2f to %+.2f um", piezo_top_scan, piezo_bottom_scan) - logger.info("Quality: TOP R-squared=%.3f, BOTTOM R-squared=%.3f", - top_result['r_squared'], bottom_result['r_squared']) + logger.info( + "Quality: TOP R-squared=%.3f, BOTTOM R-squared=%.3f", + top_result["r_squared"], + bottom_result["r_squared"], + ) if save_path: logger.info("Calibration saved to: %s", save_path) # Store result for databroker access result = { - 'success': True, - 'calibration': calibration, - 'slope_um_per_deg': slope, - 'offset_um': offset, - 'top_r_squared': top_result['r_squared'], - 'bottom_r_squared': bottom_result['r_squared'], - 'scan_top_deg': scan_top_deg, - 'scan_bottom_deg': scan_bottom_deg, - 'edge_top_deg': edge_top_deg, - 'edge_bottom_deg': edge_bottom_deg + "success": True, + "calibration": calibration, + "slope_um_per_deg": slope, + "offset_um": offset, + "top_r_squared": top_result["r_squared"], + "bottom_r_squared": bottom_result["r_squared"], + "scan_top_deg": scan_top_deg, + "scan_bottom_deg": scan_bottom_deg, + "edge_top_deg": edge_top_deg, + "edge_bottom_deg": edge_bottom_deg, } except Exception as e: @@ -936,10 +1050,7 @@ def calibrate_embryo_piezo_galvo( logger.error("=" * 70) logger.error("Error: %s", e, exc_info=True) - result = { - 'success': False, - 'error': str(e) - } + result = {"success": False, "error": str(e)} return result @@ -949,10 +1060,10 @@ def calibrate_embryo_piezo_galvo( # ============================================================================ 
__all__ = [ - 'EMBRYO_CENTERING_PROMPT', - 'EMBRYO_EDGE_PROMPT', - 'verify_embryo_centered', - 'detect_embryo_edge', - 'calibrate_focus_at_position', - 'calibrate_embryo_piezo_galvo' + "EMBRYO_CENTERING_PROMPT", + "EMBRYO_EDGE_PROMPT", + "verify_embryo_centered", + "detect_embryo_edge", + "calibrate_focus_at_position", + "calibrate_embryo_piezo_galvo", ] diff --git a/gently/hardware/dispim/plans/multi_embryo.py b/gently/hardware/dispim/plans/multi_embryo.py index 7a4b760a..6536b37c 100644 --- a/gently/hardware/dispim/plans/multi_embryo.py +++ b/gently/hardware/dispim/plans/multi_embryo.py @@ -15,36 +15,34 @@ """ import logging -import numpy as np -import time -from pathlib import Path from datetime import datetime -from typing import Dict, List, Optional, Tuple -import bluesky.plan_stubs as bps +from pathlib import Path -logger = logging.getLogger(__name__) -from bluesky.preprocessors import finalize_wrapper, run_wrapper +import bluesky.plan_stubs as bps +import numpy as np # Import existing calibration infrastructure -from .calibration import calibrate_embryo_piezo_galvo -from gently.ui.web.embryo_marker import mark_embryos_web from gently.core.database import ( - export_multi_embryo_database, add_embryo_to_database, - save_multi_embryo_database + save_multi_embryo_database, ) +from gently.ui.web.embryo_marker import mark_embryos_web + +from .calibration import calibrate_embryo_piezo_galvo +logger = logging.getLogger(__name__) # ============================================================================ # PLAN: CENTER AND VERIFY EMBRYO # ============================================================================ + def center_and_verify_embryo_plan( bottom_camera, xy_stage, - embryo_data: Dict, + embryo_data: dict, save_verification: bool = True, - image_dir: Optional[Path] = None + image_dir: Path | None = None, ): """ Center XY stage on marked embryo position and capture verification image. 
@@ -78,9 +76,9 @@ def center_and_verify_embryo_plan( dict Updated embryo data with 'centered_stage_position_um' and verification info """ - embryo_id = embryo_data['embryo_id'] - pixel_x, pixel_y = embryo_data['pixel_position'] - initial_stage_x, initial_stage_y = embryo_data['initial_stage_position'] + embryo_id = embryo_data["embryo_id"] + pixel_x, pixel_y = embryo_data["pixel_position"] + initial_stage_x, initial_stage_y = embryo_data["initial_stage_position"] logger.info("[Centering %s]", embryo_id) logger.info("Pixel position: (%.1f, %.1f)", pixel_x, pixel_y) @@ -108,7 +106,11 @@ def center_and_verify_embryo_plan( # Debug output matching original code format logger.debug("Image center: (%.0f, %.0f) pixels", image_center_x, image_center_y) - logger.debug("Pixel displacement: (%+.1f, %+.1f) pixels", pixel_displacement_x, pixel_displacement_y) + logger.debug( + "Pixel displacement: (%+.1f, %+.1f) pixels", + pixel_displacement_x, + pixel_displacement_y, + ) logger.debug("Stage movement: (%+.2f, %+.2f) um", dx, dy) logger.info("Target stage: (%.2f, %.2f) um", target_pos[0], target_pos[1]) @@ -117,15 +119,29 @@ def center_and_verify_embryo_plan( y_min, y_max = xy_stage._y_limits if not (x_min <= target_pos[0] <= x_max): - logger.error("Target X position %.2f outside limits (%s, %s)", target_pos[0], x_min, x_max) + logger.error( + "Target X position %.2f outside limits (%s, %s)", + target_pos[0], + x_min, + x_max, + ) logger.error("Skipping %s - marked position unreachable", embryo_id) - embryo_data['error'] = f"Stage X out of bounds: {target_pos[0]:.2f} not in ({x_min}, {x_max})" + embryo_data["error"] = ( + f"Stage X out of bounds: {target_pos[0]:.2f} not in ({x_min}, {x_max})" + ) return embryo_data if not (y_min <= target_pos[1] <= y_max): - logger.error("Target Y position %.2f outside limits (%s, %s)", target_pos[1], y_min, y_max) + logger.error( + "Target Y position %.2f outside limits (%s, %s)", + target_pos[1], + y_min, + y_max, + ) logger.error("Skipping %s 
- marked position unreachable", embryo_id) - embryo_data['error'] = f"Stage Y out of bounds: {target_pos[1]:.2f} not in ({y_min}, {y_max})" + embryo_data["error"] = ( + f"Stage Y out of bounds: {target_pos[1]:.2f} not in ({y_min}, {y_max})" + ) return embryo_data # Move stage @@ -140,14 +156,14 @@ def center_and_verify_embryo_plan( # Capture verification image logger.info("Capturing verification image...") - yield from bps.trigger_and_read([bottom_camera], name='verification') + yield from bps.trigger_and_read([bottom_camera], name="verification") # Save verification image if requested if save_verification and image_dir is not None: image_dir = Path(image_dir) image_dir.mkdir(parents=True, exist_ok=True) - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") save_path = image_dir / f"{embryo_id}_AFTER_centering_{timestamp}.png" # Get image from device @@ -155,13 +171,14 @@ def center_and_verify_embryo_plan( if verification_image is not None: import tifffile + tifffile.imwrite(save_path, verification_image) logger.info("Saved: %s", save_path.name) # Update embryo data with centered position - embryo_data['centered_stage_x'] = float(final_pos[0]) - embryo_data['centered_stage_y'] = float(final_pos[1]) - embryo_data['centering_timestamp'] = datetime.now().isoformat() + embryo_data["centered_stage_x"] = float(final_pos[0]) + embryo_data["centered_stage_y"] = float(final_pos[1]) + embryo_data["centering_timestamp"] = datetime.now().isoformat() logger.info("%s centered!", embryo_id) @@ -172,12 +189,13 @@ def center_and_verify_embryo_plan( # PLAN: CALIBRATE SINGLE EMBRYO IN SESSION # ============================================================================ + def calibrate_single_embryo_in_session_plan( - embryo_data: Dict, + embryo_data: dict, embryo_detector, laser_control, - image_dir: Optional[Path] = None, - calibration_params: Optional[Dict] = None + image_dir: Path | None = None, + calibration_params: dict | 
None = None, ): """ Calibrate single embryo and store results in databroker. @@ -207,25 +225,25 @@ def calibrate_single_embryo_in_session_plan( dict Updated embryo data with calibration results """ - embryo_id = embryo_data['embryo_id'] - embryo_number = embryo_data['embryo_number'] + embryo_id = embryo_data["embryo_id"] + embryo_number = embryo_data["embryo_number"] logger.info("=" * 70) logger.info("CALIBRATING %s", embryo_id.upper()) logger.info("=" * 70) # Prepare metadata for this embryo run - embryo_metadata = { - 'embryo_id': embryo_id, - 'embryo_number': embryo_number, - 'pixel_x': embryo_data['pixel_position'][0], - 'pixel_y': embryo_data['pixel_position'][1], - 'initial_stage_x': embryo_data['initial_stage_position'][0], - 'initial_stage_y': embryo_data['initial_stage_position'][1], - 'centered_stage_x': embryo_data.get('centered_stage_x', 0.0), - 'centered_stage_y': embryo_data.get('centered_stage_y', 0.0), - 'marking_timestamp': embryo_data.get('marking_timestamp', ''), - 'centering_timestamp': embryo_data.get('centering_timestamp', ''), + { + "embryo_id": embryo_id, + "embryo_number": embryo_number, + "pixel_x": embryo_data["pixel_position"][0], + "pixel_y": embryo_data["pixel_position"][1], + "initial_stage_x": embryo_data["initial_stage_position"][0], + "initial_stage_y": embryo_data["initial_stage_position"][1], + "centered_stage_x": embryo_data.get("centered_stage_x", 0.0), + "centered_stage_y": embryo_data.get("centered_stage_y", 0.0), + "marking_timestamp": embryo_data.get("marking_timestamp", ""), + "centering_timestamp": embryo_data.get("centering_timestamp", ""), } # Merge custom calibration parameters @@ -242,16 +260,16 @@ def calibrate_single_embryo_in_session_plan( embryo_detector=embryo_detector.claude_client, core=embryo_detector.core, image_dir=image_dir, - **calib_params + **calib_params, ) # Extract calibration data from result if calibration_result is not None: - embryo_data['calibration'] = calibration_result + 
embryo_data["calibration"] = calibration_result logger.info("%s calibration complete!", embryo_id) else: logger.error("%s calibration failed!", embryo_id) - embryo_data['calibration'] = None + embryo_data["calibration"] = None return embryo_data @@ -260,16 +278,18 @@ def calibrate_single_embryo_in_session_plan( # PLAN: MULTI-EMBRYO CALIBRATION SESSION # ============================================================================ + def multi_embryo_calibration_session_plan( bottom_camera, xy_stage, embryo_detector, laser_control, output_database_path: Path, - image_dir: Optional[Path] = None, + image_dir: Path | None = None, auto_mark: bool = False, - pre_marked_embryos: Optional[List[Dict]] = None, - calibration_params: Optional[Dict] = None + pre_marked_embryos: list[dict] | None = None, + calibration_params: dict | None = None, + **kwargs, ): """ Complete multi-embryo calibration workflow with Bluesky architecture. @@ -352,11 +372,11 @@ def multi_embryo_calibration_session_plan( # Prepare session metadata session_metadata = { - 'plan_name': 'multi_embryo_calibration_session', - 'output_database_path': str(output_database_path), - 'image_dir': str(image_dir), - 'session_start': datetime.now().isoformat(), - 'embryo_runs': [] # Will be populated with UIDs + "plan_name": "multi_embryo_calibration_session", + "output_database_path": str(output_database_path), + "image_dir": str(image_dir), + "session_start": datetime.now().isoformat(), + "embryo_runs": [], # Will be populated with UIDs } def inner_session_plan(): @@ -369,10 +389,15 @@ def inner_session_plan(): # Use pre-marked embryos if provided (marking done outside plan to avoid Qt threading) if pre_marked_embryos is not None: - logger.info("Using pre-marked embryo positions (%d embryos)", len(pre_marked_embryos)) + logger.info( + "Using pre-marked embryo positions (%d embryos)", + len(pre_marked_embryos), + ) marked_embryos = pre_marked_embryos elif auto_mark: - raise NotImplementedError("Auto-detection not yet 
implemented. Use pre-marked positions.") + raise NotImplementedError( + "Auto-detection not yet implemented. Use pre-marked positions." + ) else: # This path requires capturing overview and interactive marking inside plan # WARNING: This may cause Qt threading issues with napari! @@ -381,11 +406,15 @@ def inner_session_plan(): # Get initial stage position initial_stage_pos = xy_stage.get_position() - logger.info("Initial stage position: (%.2f, %.2f) um", initial_stage_pos[0], initial_stage_pos[1]) + logger.info( + "Initial stage position: (%.2f, %.2f) um", + initial_stage_pos[0], + initial_stage_pos[1], + ) # Capture overview image logger.info("Capturing bottom camera overview...") - yield from bps.trigger_and_read([bottom_camera], name='overview') + yield from bps.trigger_and_read([bottom_camera], name="overview") # Get overview image from device overview_image = bottom_camera._last_image @@ -394,18 +423,20 @@ def inner_session_plan(): raise RuntimeError("Failed to capture overview image!") # Save overview image - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") overview_path = image_dir / f"initial_view_{timestamp}.png" import tifffile + tifffile.imwrite(overview_path, overview_image) logger.info("Saved overview: %s", overview_path.name) logger.info("Launching interactive embryo marking...") # Prefer web-based marking (no Qt dependency, works remotely) - viz_server = kwargs.get('viz_server') + viz_server = kwargs.get("viz_server") if viz_server is not None: import asyncio + logger.info("Using web-based marking via visualization server") marked_embryos = asyncio.get_event_loop().run_until_complete( mark_embryos_web( @@ -413,7 +444,7 @@ def inner_session_plan(): image=overview_image, initial_stage_position=tuple(initial_stage_pos), pixel_size_um=bottom_camera.effective_pixel_size, - save_image_path=image_dir / f"all_embryos_marked_{timestamp}.png" + save_image_path=image_dir / 
f"all_embryos_marked_{timestamp}.png", ) ) else: @@ -425,9 +456,7 @@ def inner_session_plan(): "interactive marking surface (napari has been retired). " "Start the visualization server before calling this plan." ) - raise RuntimeError( - "Interactive marking requires the web visualization server." - ) + raise RuntimeError("Interactive marking requires the web visualization server.") if len(marked_embryos) == 0: logger.error("No embryos marked! Aborting session.") @@ -442,10 +471,9 @@ def inner_session_plan(): logger.info("[PHASE 2] Calibrating each embryo...") calibrated_embryos = [] - embryo_run_uids = [] for i, embryo_data in enumerate(marked_embryos, 1): - embryo_id = embryo_data['embryo_id'] + embryo_id = embryo_data["embryo_id"] logger.info("-" * 70) logger.info("EMBRYO %d/%d: %s", i, len(marked_embryos), embryo_id) @@ -457,11 +485,11 @@ def inner_session_plan(): xy_stage=xy_stage, embryo_data=embryo_data, save_verification=True, - image_dir=image_dir + image_dir=image_dir, ) # Check if centering failed (position out of bounds) - if 'error' in embryo_data: + if "error" in embryo_data: logger.warning("Skipping calibration for %s due to centering error", embryo_id) calibrated_embryos.append(embryo_data) continue @@ -472,7 +500,7 @@ def inner_session_plan(): embryo_detector=embryo_detector, laser_control=laser_control, image_dir=image_dir / embryo_id, - calibration_params=calibration_params + calibration_params=calibration_params, ) calibrated_embryos.append(embryo_data) @@ -488,13 +516,13 @@ def inner_session_plan(): # Build database structure database = { - 'created': session_metadata['session_start'], - 'embryos': {}, - 'last_updated': datetime.now().isoformat() + "created": session_metadata["session_start"], + "embryos": {}, + "last_updated": datetime.now().isoformat(), } for embryo_data in calibrated_embryos: - embryo_id = embryo_data['embryo_id'] + embryo_id = embryo_data["embryo_id"] database = add_embryo_to_database(database, embryo_id, embryo_data) # 
Save JSON database @@ -503,7 +531,10 @@ def inner_session_plan(): logger.info("Exported database: %s", output_database_path_resolved) logger.info("Total embryos: %d", len(calibrated_embryos)) - logger.info("Successful calibrations: %d", sum(1 for e in calibrated_embryos if e.get('calibration') is not None)) + logger.info( + "Successful calibrations: %d", + sum(1 for e in calibrated_embryos if e.get("calibration") is not None), + ) # ==================================================================== # SESSION COMPLETE diff --git a/gently/hardware/dispim/sam_detection.py b/gently/hardware/dispim/sam_detection.py index caeb8e5a..cea9990a 100644 --- a/gently/hardware/dispim/sam_detection.py +++ b/gently/hardware/dispim/sam_detection.py @@ -5,29 +5,28 @@ Returns embryo positions (pixel + stage coordinates) for calibration workflow. """ -import logging -import time +import base64 import json +import logging +import os import uuid -import numpy as np +from io import BytesIO from pathlib import Path + +import anthropic import cv2 -import base64 -from io import BytesIO +import numpy as np from PIL import Image -import anthropic -from typing import Dict, List, Tuple, Optional -import os from gently.settings import settings logger = logging.getLogger(__name__) -from gently.core.coordinates import ( - pixel_to_stage_position, - get_um_per_pixel, - DEFAULT_PIXEL_SIZE_UM, +from gently.core.coordinates import ( # noqa: E402 DEFAULT_OBJECTIVE_MAG, + DEFAULT_PIXEL_SIZE_UM, + get_um_per_pixel, + pixel_to_stage_position, ) @@ -42,11 +41,13 @@ class SAMEmbryoDetector: - Returns embryo positions as simple list of coordinates """ - def __init__(self, - sam_checkpoint: str = "sam_vit_b_01ec64.pth", - sam_model_type: str = "vit_b", - device: str = "cpu", - anthropic_api_key: Optional[str] = None): + def __init__( + self, + sam_checkpoint: str = "sam_vit_b_01ec64.pth", + sam_model_type: str = "vit_b", + device: str = "cpu", + anthropic_api_key: str | None = None, + ): """ Initialize SAM 
detector @@ -85,7 +86,7 @@ def _load_sam(self): if self._mask_generator is not None: return - from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor + from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry if not Path(self.sam_checkpoint).exists(): raise FileNotFoundError(f"SAM checkpoint not found: {self.sam_checkpoint}") @@ -108,13 +109,15 @@ def _load_sam(self): self._predictor = SamPredictor(sam) logger.info("SAM model loaded") - def preprocess_image(self, - image: np.ndarray, - bg_kernel_size: int = 150, - use_clahe: bool = True, - clahe_clip_limit: float = 3.0, - clahe_tile_size: int = 16, - gaussian_sigma: float = 2.0) -> np.ndarray: + def preprocess_image( + self, + image: np.ndarray, + bg_kernel_size: int = 150, + use_clahe: bool = True, + clahe_clip_limit: float = 3.0, + clahe_tile_size: int = 16, + gaussian_sigma: float = 2.0, + ) -> np.ndarray: """ Preprocess image for better SAM detection. @@ -151,14 +154,18 @@ def preprocess_image(self, # This stretches the narrow range (e.g., 84-354) to full 0-255 logger.debug("Percentile normalization (2-98%%)...") p2, p98 = np.percentile(image, (2, 98)) - img_norm = np.clip((image.astype(np.float32) - p2) / (p98 - p2) * 255, 0, 255).astype(np.uint8) + img_norm = np.clip((image.astype(np.float32) - p2) / (p98 - p2) * 255, 0, 255).astype( + np.uint8 + ) logger.debug("Normalized to 0-255") # Step 2: Background subtraction with large morphological opening # Removes large-scale illumination variations if bg_kernel_size > 0: logger.debug("Background subtraction (kernel=%d)...", bg_kernel_size) - kernel_bg = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (bg_kernel_size, bg_kernel_size)) + kernel_bg = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, (bg_kernel_size, bg_kernel_size) + ) background = cv2.morphologyEx(img_norm, cv2.MORPH_OPEN, kernel_bg) img_no_bg = cv2.subtract(img_norm, background) img_no_bg = cv2.normalize(img_no_bg, None, 0, 255, 
cv2.NORM_MINMAX).astype(np.uint8) @@ -171,8 +178,7 @@ def preprocess_image(self, if use_clahe: logger.debug("CLAHE (clip=%.1f, tile=%d)...", clahe_clip_limit, clahe_tile_size) clahe = cv2.createCLAHE( - clipLimit=clahe_clip_limit, - tileGridSize=(clahe_tile_size, clahe_tile_size) + clipLimit=clahe_clip_limit, tileGridSize=(clahe_tile_size, clahe_tile_size) ) img_enhanced = clahe.apply(img_no_bg) logger.debug("CLAHE applied") @@ -187,16 +193,20 @@ def preprocess_image(self, else: img_smooth = img_enhanced - logger.debug("Preprocessing complete (output range: %s - %s)", img_smooth.min(), img_smooth.max()) + logger.debug( + "Preprocessing complete (output range: %s - %s)", img_smooth.min(), img_smooth.max() + ) return img_smooth - def find_embryo_candidates(self, - image: np.ndarray, - brightness_percentile: float = 99.0, - min_area: int = 5000, - max_area: int = 150000, - clahe_clip: float = 3.0, - clahe_tile: int = 16) -> Tuple[List[Dict], np.ndarray]: + def find_embryo_candidates( + self, + image: np.ndarray, + brightness_percentile: float = 99.0, + min_area: int = 5000, + max_area: int = 150000, + clahe_clip: float = 3.0, + clahe_tile: int = 16, + ) -> tuple[list[dict], np.ndarray]: """ Find embryo candidates using brightness-based detection. 
@@ -229,12 +239,16 @@ def find_embryo_candidates(self, enhanced_image : np.ndarray Contrast-enhanced 8-bit image for SAM """ - logger.info("Finding embryo candidates (brightness percentile=%.1f)...", brightness_percentile) + logger.info( + "Finding embryo candidates (brightness percentile=%.1f)...", brightness_percentile + ) logger.debug("Input range: %s - %s", image.min(), image.max()) # Step 1: Percentile normalization (handles low dynamic range) p2, p98 = np.percentile(image, (2, 98)) - img_norm = np.clip((image.astype(np.float32) - p2) / (p98 - p2) * 255, 0, 255).astype(np.uint8) + img_norm = np.clip((image.astype(np.float32) - p2) / (p98 - p2) * 255, 0, 255).astype( + np.uint8 + ) logger.debug("Normalized to 0-255") # Step 2: CLAHE for local contrast enhancement @@ -261,7 +275,9 @@ def find_embryo_candidates(self, logger.debug("Morphological cleanup complete") # Step 7: Find connected components and filter by area - num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask, connectivity=8) + num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats( + mask, connectivity=8 + ) candidates = [] for i in range(1, num_labels): # Skip background (label 0) @@ -273,19 +289,14 @@ def find_embryo_candidates(self, h = stats[i, cv2.CC_STAT_HEIGHT] cx, cy = centroids[i] - candidates.append({ - 'bbox': (x, y, w, h), - 'centroid': (cx, cy), - 'area': area - }) + candidates.append({"bbox": (x, y, w, h), "centroid": (cx, cy), "area": area}) logger.info("Found %d embryo candidates", len(candidates)) return candidates, img_smooth - def refine_with_sam(self, - image: np.ndarray, - candidates: List[Dict], - padding: int = 20) -> List[Dict]: + def refine_with_sam( + self, image: np.ndarray, candidates: list[dict], padding: int = 20 + ) -> list[dict]: """ Refine embryo candidates using SAM with bounding box prompts. 
@@ -322,7 +333,7 @@ def refine_with_sam(self, h, w = image.shape[:2] for i, candidate in enumerate(candidates): - x, y, bw, bh = candidate['bbox'] + x, y, bw, bh = candidate["bbox"] # Add padding and clip to image bounds x1 = max(0, x - padding) @@ -335,10 +346,7 @@ def refine_with_sam(self, # Get SAM prediction with box prompt masks, scores, _ = self._predictor.predict( - point_coords=None, - point_labels=None, - box=input_box, - multimask_output=True + point_coords=None, point_labels=None, box=input_box, multimask_output=True ) # Take best mask (highest score) @@ -348,9 +356,7 @@ def refine_with_sam(self, # Calculate properties from SAM mask contours, _ = cv2.findContours( - mask.astype(np.uint8), - cv2.RETR_EXTERNAL, - cv2.CHAIN_APPROX_SIMPLE + mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE ) if contours: @@ -364,41 +370,47 @@ def refine_with_sam(self, cx = M["m10"] / M["m00"] cy = M["m01"] / M["m00"] else: - cx, cy = candidate['centroid'] + cx, cy = candidate["centroid"] # Calculate circularity perimeter = cv2.arcLength(contour, True) - circularity = 4 * np.pi * area / (perimeter ** 2) if perimeter > 0 else 0 + circularity = 4 * np.pi * area / (perimeter**2) if perimeter > 0 else 0 # Get bounding box from contour bx, by, bw, bh = cv2.boundingRect(contour) - embryos.append({ - 'embryo_id': f'embryo_{i + 1}', - 'uid': str(uuid.uuid4()), # Global unique identifier for cross-session tracking - 'pixel_x': float(cx), - 'pixel_y': float(cy), - 'bbox': (bx, by, bw, bh), # Used by visualization functions - 'area_pixels': int(area), - 'circularity': float(circularity), - 'confidence': float(score), - 'mask': mask - }) + embryos.append( + { + "embryo_id": f"embryo_{i + 1}", + "uid": str( + uuid.uuid4() + ), # Global unique identifier for cross-session tracking + "pixel_x": float(cx), + "pixel_y": float(cy), + "bbox": (bx, by, bw, bh), # Used by visualization functions + "area_pixels": int(area), + "circularity": float(circularity), + "confidence": 
float(score), + "mask": mask, + } + ) logger.info("SAM refined %d embryos", len(embryos)) return embryos - async def detect_embryos(self, - image: np.ndarray, - stage_position: Tuple[float, float], - pixel_size_um: float = DEFAULT_PIXEL_SIZE_UM, - objective_mag: float = DEFAULT_OBJECTIVE_MAG, - use_claude_review: bool = True, - save_visualizations: bool = True, - output_dir: Optional[Path] = None, - brightness_percentile: float = 99.0, - min_area: int = 5000, - max_area: int = 150000) -> Dict: + async def detect_embryos( + self, + image: np.ndarray, + stage_position: tuple[float, float], + pixel_size_um: float = DEFAULT_PIXEL_SIZE_UM, + objective_mag: float = DEFAULT_OBJECTIVE_MAG, + use_claude_review: bool = True, + save_visualizations: bool = True, + output_dir: Path | None = None, + brightness_percentile: float = 99.0, + min_area: int = 5000, + max_area: int = 150000, + ) -> dict: """ Detect embryos using brightness-based detection + SAM refinement. @@ -462,20 +474,17 @@ async def detect_embryos(self, # Step 1: Find candidates using brightness detection logger.info("[1/4] Finding embryo candidates (brightness-based)...") candidates, image_enhanced = self.find_embryo_candidates( - image, - brightness_percentile=brightness_percentile, - min_area=min_area, - max_area=max_area + image, brightness_percentile=brightness_percentile, min_area=min_area, max_area=max_area ) if len(candidates) == 0: logger.warning("No embryo candidates found!") return { - 'embryos': [], - 'initial_detections': 0, - 'final_detections': 0, - 'verification': {'verified': False}, - 'images': {} + "embryos": [], + "initial_detections": 0, + "final_detections": 0, + "verification": {"verified": False}, + "images": {}, } # Step 2: Refine with SAM @@ -489,11 +498,11 @@ async def detect_embryos(self, if len(embryos_sam) == 0: logger.warning("No embryos detected by SAM!") return { - 'embryos': [], - 'initial_detections': 0, - 'final_detections': 0, - 'verification': {'verified': False}, - 'images': 
{} + "embryos": [], + "initial_detections": 0, + "final_detections": 0, + "verification": {"verified": False}, + "images": {}, } # Save initial detection @@ -503,8 +512,8 @@ async def detect_embryos(self, # Claude review (if enabled) embryos_final = embryos_sam - verification = {'verified': True, 'skipped': not use_claude_review} - changes = {'round1': {'removed': [], 'added': []}} + verification = {"verified": True, "skipped": not use_claude_review} + changes = {"round1": {"removed": [], "added": []}} if use_claude_review and self.claude_client: logger.info("[2/4] Claude Vision review (Round 1)...") @@ -512,7 +521,7 @@ async def detect_embryos(self, review_r1 = await self._review_with_claude(image_8bit, annotated, embryos_sam) logger.info("[3/4] Applying corrections...") - embryos_r1, changes['round1'] = self._apply_corrections( + embryos_r1, changes["round1"] = self._apply_corrections( embryos_sam, review_r1, image, self._predictor ) @@ -524,22 +533,22 @@ async def detect_embryos(self, logger.info("[4/4] Claude verification (Round 2)...") r1_viz = self._create_annotated_image(image_8bit, embryos_r1) verification = await self._verify_with_claude( - image_8bit, r1_viz, embryos_r1, changes['round1'] + image_8bit, r1_viz, embryos_r1, changes["round1"] ) # Apply round 2 corrections if needed has_r2_changes = ( - len(verification.get('additional_false_positives', [])) > 0 or - len(verification.get('additional_false_negatives', [])) > 0 + len(verification.get("additional_false_positives", [])) > 0 + or len(verification.get("additional_false_negatives", [])) > 0 ) if has_r2_changes: logger.info("Applying Round 2 corrections...") review_r2 = { - 'false_positives': verification.get('additional_false_positives', []), - 'false_negatives': verification.get('additional_false_negatives', []) + "false_positives": verification.get("additional_false_positives", []), + "false_negatives": verification.get("additional_false_negatives", []), } - embryos_final, changes['round2'] = 
self._apply_corrections( + embryos_final, changes["round2"] = self._apply_corrections( embryos_r1, review_r2, image, self._predictor ) else: @@ -553,7 +562,7 @@ async def detect_embryos(self, stage_position, pixel_size_um, objective_mag, - image_shape=image.shape[:2] # (height, width) + image_shape=image.shape[:2], # (height, width) ) # Save final visualization @@ -563,19 +572,19 @@ async def detect_embryos(self, # Package results results = { - 'embryos': embryo_positions, - 'initial_detections': len(embryos_sam), - 'final_detections': len(embryos_final), - 'verification': verification, - 'changes': changes, - 'images': { - 'initial': str(output_dir / "detection_initial.png"), - 'final': str(output_dir / "detection_final.png") - } + "embryos": embryo_positions, + "initial_detections": len(embryos_sam), + "final_detections": len(embryos_final), + "verification": verification, + "changes": changes, + "images": { + "initial": str(output_dir / "detection_initial.png"), + "final": str(output_dir / "detection_final.png"), + }, } if use_claude_review and save_visualizations: - results['images']['round1'] = str(output_dir / "detection_round1.png") + results["images"]["round1"] = str(output_dir / "detection_round1.png") logger.info("=" * 70) logger.info("DETECTION COMPLETE: %d embryos", len(embryo_positions)) @@ -594,7 +603,7 @@ def _to_rgb8(image: np.ndarray) -> np.ndarray: return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) return image - def _detect_with_sam(self, image: np.ndarray) -> Tuple[List[Dict], np.ndarray]: + def _detect_with_sam(self, image: np.ndarray) -> tuple[list[dict], np.ndarray]: """Run SAM automatic segmentation (extracted from test script)""" image_rgb = self._to_rgb8(image) @@ -604,14 +613,16 @@ def _detect_with_sam(self, image: np.ndarray) -> Tuple[List[Dict], np.ndarray]: # Filter candidates embryo_candidates = [] for mask_data in masks: - area = mask_data['area'] + area = mask_data["area"] if not (self.min_area <= area <= self.max_area): continue - bbox 
= mask_data['bbox'] - mask = mask_data['segmentation'] - contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + bbox = mask_data["bbox"] + mask = mask_data["segmentation"] + contours, _ = cv2.findContours( + mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE + ) if len(contours) == 0: continue @@ -620,40 +631,44 @@ def _detect_with_sam(self, image: np.ndarray) -> Tuple[List[Dict], np.ndarray]: if perimeter == 0: continue - circularity = 4 * np.pi * area / (perimeter ** 2) + circularity = 4 * np.pi * area / (perimeter**2) if circularity < self.min_circularity: continue - embryo_candidates.append({ - 'mask': mask, - 'bbox': bbox, - 'area': area, - 'circularity': circularity, - 'stability_score': mask_data['stability_score'], - 'predicted_iou': mask_data['predicted_iou'] - }) + embryo_candidates.append( + { + "mask": mask, + "bbox": bbox, + "area": area, + "circularity": circularity, + "stability_score": mask_data["stability_score"], + "predicted_iou": mask_data["predicted_iou"], + } + ) # Sort by quality and apply spatial separation - embryo_candidates.sort(key=lambda x: (x['area'] * x['stability_score']), reverse=True) + embryo_candidates.sort(key=lambda x: x["area"] * x["stability_score"], reverse=True) selected_embryos = [] for candidate in embryo_candidates: if len(selected_embryos) >= self.max_embryos: break - bbox = candidate['bbox'] + bbox = candidate["bbox"] candidate_center_x = bbox[0] + bbox[2] / 2 candidate_center_y = bbox[1] + bbox[3] / 2 too_close = False for selected in selected_embryos: - sel_bbox = selected['bbox'] + sel_bbox = selected["bbox"] sel_center_x = sel_bbox[0] + sel_bbox[2] / 2 sel_center_y = sel_bbox[1] + sel_bbox[3] / 2 - distance = np.sqrt((candidate_center_x - sel_center_x)**2 + - (candidate_center_y - sel_center_y)**2) + distance = np.sqrt( + (candidate_center_x - sel_center_x) ** 2 + + (candidate_center_y - sel_center_y) ** 2 + ) if distance < self.min_separation_pixels: 
too_close = True @@ -662,19 +677,27 @@ def _detect_with_sam(self, image: np.ndarray) -> Tuple[List[Dict], np.ndarray]: if not too_close: selected_embryos.append(candidate) - return selected_embryos, image_8bit + return selected_embryos, image_rgb - def _create_annotated_image(self, image: np.ndarray, embryos: List[Dict]) -> np.ndarray: + def _create_annotated_image(self, image: np.ndarray, embryos: list[dict]) -> np.ndarray: """Create annotated image with numbered boxes""" viz = image.copy() if len(viz.shape) == 2: viz = cv2.cvtColor(viz, cv2.COLOR_GRAY2RGB) - colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), - (255, 0, 255), (0, 255, 255), (128, 128, 0), (128, 0, 128)] + colors = [ + (255, 0, 0), + (0, 255, 0), + (0, 0, 255), + (255, 255, 0), + (255, 0, 255), + (0, 255, 255), + (128, 128, 0), + (128, 0, 128), + ] for i, embryo in enumerate(embryos): - bbox = embryo['bbox'] + bbox = embryo["bbox"] x, y, w, h = bbox color = colors[i % len(colors)] @@ -713,7 +736,7 @@ def _encode_image_base64(self, image: np.ndarray) -> str: buffered = BytesIO() pil_image.save(buffered, format="JPEG", quality=quality, optimize=True) if buffered.tell() <= max_bytes: - return base64.b64encode(buffered.getvalue()).decode('utf-8') + return base64.b64encode(buffered.getvalue()).decode("utf-8") quality -= 5 # Last resort @@ -722,18 +745,21 @@ def _encode_image_base64(self, image: np.ndarray) -> str: pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS) buffered = BytesIO() pil_image.save(buffered, format="JPEG", quality=85, optimize=True) - return base64.b64encode(buffered.getvalue()).decode('utf-8') + return base64.b64encode(buffered.getvalue()).decode("utf-8") - async def _review_with_claude(self, image: np.ndarray, annotated: np.ndarray, embryos: List[Dict]) -> Dict: + async def _review_with_claude( + self, image: np.ndarray, annotated: np.ndarray, embryos: list[dict] + ) -> dict: """Round 1: Claude reviews detections (from test script)""" if not 
self.claude_client: - return {'false_positives': [], 'false_negatives': []} + return {"false_positives": [], "false_negatives": []} image_base64 = self._encode_image_base64(annotated) - prompt = f"""You are a microscopy expert analyzing embryo detections from a bottom camera view. + prompt = f"""You are a microscopy expert analyzing embryo detections from a bottom +camera view. -CURRENT DETECTIONS: {len(embryos)} embryos labeled 0-{len(embryos)-1} with colored bounding boxes. +CURRENT DETECTIONS: {len(embryos)} embryos labeled 0-{len(embryos) - 1} with colored bounding boxes. EMBRYO CHARACTERISTICS: - Small, BRIGHT white/light gray oval or rice grain shapes @@ -764,13 +790,22 @@ async def _review_with_claude(self, image: np.ndarray, annotated: np.ndarray, em model=settings.models.perception, max_tokens=8000, thinking={"type": "enabled", "budget_tokens": 5000}, - messages=[{ - "role": "user", - "content": [ - {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": image_base64}}, - {"type": "text", "text": prompt} - ] - }] + messages=[ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": image_base64, + }, + }, + {"type": "text", "text": prompt}, + ], + } + ], ) response_text = next((b.text for b in message.content if b.type == "text"), "") @@ -787,18 +822,19 @@ async def _review_with_claude(self, image: np.ndarray, annotated: np.ndarray, em except Exception as e: logger.warning("Claude review failed: %s", e) - return {'false_positives': [], 'false_negatives': []} + return {"false_positives": [], "false_negatives": []} - async def _verify_with_claude(self, image: np.ndarray, annotated: np.ndarray, - embryos: List[Dict], previous_changes: Dict) -> Dict: + async def _verify_with_claude( + self, image: np.ndarray, annotated: np.ndarray, embryos: list[dict], previous_changes: dict + ) -> dict: """Round 2: Claude verifies corrections (from test script)""" if not 
self.claude_client: - return {'verified': True, 'skipped': True} + return {"verified": True, "skipped": True} image_base64 = self._encode_image_base64(annotated) - removed = previous_changes.get('removed', []) - added = previous_changes.get('added', []) + removed = previous_changes.get("removed", []) + added = previous_changes.get("added", []) prompt = f"""VERIFICATION ROUND - You previously reviewed this image. @@ -806,7 +842,7 @@ async def _verify_with_claude(self, image: np.ndarray, annotated: np.ndarray, - Removed: {removed if removed else "none"} - Added: {added if added else "none"} -CURRENT: {len(embryos)} detections (numbered 0-{len(embryos)-1}) +CURRENT: {len(embryos)} detections (numbered 0-{len(embryos) - 1}) TASK: Verify corrections and catch any remaining issues. Only report CLEAR remaining problems. @@ -824,13 +860,22 @@ async def _verify_with_claude(self, image: np.ndarray, annotated: np.ndarray, model=settings.models.perception, max_tokens=6000, thinking={"type": "enabled", "budget_tokens": 4000}, - messages=[{ - "role": "user", - "content": [ - {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": image_base64}}, - {"type": "text", "text": prompt} - ] - }] + messages=[ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": image_base64, + }, + }, + {"type": "text", "text": prompt}, + ], + } + ], ) response_text = next((b.text for b in message.content if b.type == "text"), "") @@ -846,38 +891,41 @@ async def _verify_with_claude(self, image: np.ndarray, annotated: np.ndarray, except Exception as e: logger.warning("Verification failed: %s", e) - return {'verified': False} + return {"verified": False} - def _apply_corrections(self, embryos: List[Dict], review: Dict, - image: np.ndarray, predictor) -> Tuple[List[Dict], Dict]: + def _apply_corrections( + self, embryos: list[dict], review: dict, image: np.ndarray, predictor + ) -> tuple[list[dict], 
dict]: """Apply Claude's corrections (from test script)""" corrected = [] - changes = {'removed': [], 'added': []} + changes = {"removed": [], "added": []} # Remove false positives - false_positives = set(review.get('false_positives', [])) + false_positives = set(review.get("false_positives", [])) if false_positives: - changes['removed'] = list(false_positives) + changes["removed"] = list(false_positives) for i, embryo in enumerate(embryos): if i not in false_positives: corrected.append(embryo) # Add false negatives - false_negatives = review.get('false_negatives', []) + false_negatives = review.get("false_negatives", []) if false_negatives: for fn in false_negatives: - point = (fn['x'], fn['y']) + point = (fn["x"], fn["y"]) new_embryo = self._segment_with_sam(image, predictor, point) - if new_embryo and (self.min_area <= new_embryo['area'] <= self.max_area and - new_embryo['circularity'] >= self.min_circularity): + if new_embryo and ( + self.min_area <= new_embryo["area"] <= self.max_area + and new_embryo["circularity"] >= self.min_circularity + ): corrected.append(new_embryo) - changes['added'].append(point) + changes["added"].append(point) return corrected, changes - def _segment_with_sam(self, image: np.ndarray, predictor, point: Tuple) -> Optional[Dict]: + def _segment_with_sam(self, image: np.ndarray, predictor, point: tuple) -> dict | None: """Use SAM predictor to segment region (from test script)""" image_rgb = self._to_rgb8(image) predictor.set_image(image_rgb) @@ -886,9 +934,7 @@ def _segment_with_sam(self, image: np.ndarray, predictor, point: Tuple) -> Optio point_labels = np.array([1]) masks, scores, _ = predictor.predict( - point_coords=point_coords, - point_labels=point_labels, - multimask_output=True + point_coords=point_coords, point_labels=point_labels, multimask_output=True ) best_idx = np.argmax(scores) @@ -903,26 +949,33 @@ def _segment_with_sam(self, image: np.ndarray, predictor, point: Tuple) -> Optio bbox = [x_min, y_min, x_max - x_min, y_max 
- y_min] area = mask.sum() - contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + contours, _ = cv2.findContours( + mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE + ) if len(contours) > 0: perimeter = cv2.arcLength(contours[0], True) - circularity = 4 * np.pi * area / (perimeter ** 2) if perimeter > 0 else 0 + circularity = 4 * np.pi * area / (perimeter**2) if perimeter > 0 else 0 else: circularity = 0 return { - 'mask': mask, - 'bbox': bbox, - 'area': int(area), - 'circularity': float(circularity), - 'stability_score': float(scores[best_idx]), - 'predicted_iou': float(scores[best_idx]) + "mask": mask, + "bbox": bbox, + "area": int(area), + "circularity": float(circularity), + "stability_score": float(scores[best_idx]), + "predicted_iou": float(scores[best_idx]), } - def _pixel_to_stage_coordinates(self, embryos: List[Dict], stage_pos: Tuple[float, float], - pixel_size_um: float, objective_mag: float, - image_shape: Tuple[int, int] = (2048, 2048)) -> List[Dict]: + def _pixel_to_stage_coordinates( + self, + embryos: list[dict], + stage_pos: tuple[float, float], + pixel_size_um: float, + objective_mag: float, + image_shape: tuple[int, int] = (2048, 2048), + ) -> list[dict]: """ Convert pixel coordinates to stage coordinates. 
@@ -938,7 +991,7 @@ def _pixel_to_stage_coordinates(self, embryos: List[Dict], stage_pos: Tuple[floa embryo_positions = [] for i, embryo in enumerate(embryos): - bbox = embryo['bbox'] + bbox = embryo["bbox"] x, y, w, h = bbox center_x_px = x + w / 2 @@ -953,24 +1006,26 @@ def _pixel_to_stage_coordinates(self, embryos: List[Dict], stage_pos: Tuple[floa image_center_y=image_center_y, stage_x=stage_x, stage_y=stage_y, - um_per_pixel=effective_pixel_um + um_per_pixel=effective_pixel_um, ) - embryo_positions.append({ - 'embryo_id': f'embryo_{i + 1}', - 'pixel_x': float(center_x_px), - 'pixel_y': float(center_y_px), - 'stage_x_um': float(embryo_stage_x), - 'stage_y_um': float(embryo_stage_y), - 'bbox_pixel': tuple(bbox), - 'area_pixels': embryo.get('area_pixels', embryo.get('area', 0)), - 'circularity': embryo.get('circularity', 0), - 'confidence': embryo.get('confidence', embryo.get('stability_score', 0)) - }) + embryo_positions.append( + { + "embryo_id": f"embryo_{i + 1}", + "pixel_x": float(center_x_px), + "pixel_y": float(center_y_px), + "stage_x_um": float(embryo_stage_x), + "stage_y_um": float(embryo_stage_y), + "bbox_pixel": tuple(bbox), + "area_pixels": embryo.get("area_pixels", embryo.get("area", 0)), + "circularity": embryo.get("circularity", 0), + "confidence": embryo.get("confidence", embryo.get("stability_score", 0)), + } + ) return embryo_positions - def show_in_napari(self, image: np.ndarray, embryos: List[Dict], block: bool = False): + def show_in_napari(self, image: np.ndarray, embryos: list[dict], block: bool = False): """Deprecated: napari display was retired in Phase 1. SAM detection results are now reviewed via the web map view — diff --git a/gently/hardware/switchbot.py b/gently/hardware/switchbot.py new file mode 100644 index 00000000..a295a082 --- /dev/null +++ b/gently/hardware/switchbot.py @@ -0,0 +1,250 @@ +""" +SwitchBot Bot (WoHand) control as a Bluesky/Ophyd-protocol device. + +The SwitchBot Bot is a Bluetooth-LE button pusher. 
In "Switch mode" it supports +explicit on/off; in "Press mode" it does a momentary press. This module talks to +it directly over BLE via ``bleak`` using the documented GATT command protocol — +no SwitchBot cloud, no hub. + +The device follows the same duck-typed Bluesky protocol as the diSPIM devices +(see ``gently.hardware.dispim.devices.optical.DiSPIMLED``): ``set(state)`` returns +an ophyd ``Status``, plus ``read()``/``describe()``. So it drops into plans via +``yield from bps.mv(bot, 'on')``. + +BLE I/O is async (``bleak``); ``set()`` runs a fresh connect→write→disconnect +cycle in a worker thread and resolves the Status when done. Connecting per command +keeps the implementation robust (no stale-connection handling) at the cost of +~1-2 s latency, which is fine for a low-frequency accessory. For lower latency or +encrypted/password-protected Bots, swap the ``_send_command`` body for PySwitchbot. + +Self-test (drives a real Bot):: + + python gently/hardware/switchbot.py EC:6F:04:06:5B:23 on off +""" + +from __future__ import annotations + +import asyncio +import logging +import threading +import time +from collections import OrderedDict + +logger = logging.getLogger(__name__) + +# SwitchBot Bot GATT. Note the UUID group is 9fb8 — the widely-copied 9fb9 is wrong. +_CTRL_CHAR = "cba20002-224d-11e6-9fb8-0002a5d5c51b" # write / write-without-response +_NOTIFY_CHAR = "cba20003-224d-11e6-9fb8-0002a5d5c51b" # notify (command response) + +_COMMANDS = { + "on": bytes([0x57, 0x01, 0x01]), + "off": bytes([0x57, 0x01, 0x02]), + "press": bytes([0x57, 0x01, 0x00]), +} +# Dedicated status query: returns battery %, firmware version, mode flags. +# This is the only reliable source of battery — action-command responses +# also include status bytes but in a different format (byte 1 there isn't +# battery despite what's documented for older firmware). 
+_QUERY_STATUS = bytes([0x57, 0x02]) +# Status-query response format (firmware ≥ 6.x): +# byte 0 = 0x01 success +# byte 1 = battery % +# byte 2 = firmware version (BCD-ish: high nibble.low nibble — 0x42 = v4.2) +# bytes 3+ = mode flags / timer count / counters (firmware-dependent) +_STATUS_BATTERY_IDX = 1 +_STATUS_FIRMWARE_IDX = 2 +# First byte of the success response. Older Bot firmware returns 0x01 alone; +# modern firmware (≥ Bot v4.x) returns 0x05 for action commands followed by +# action-status bytes. Both are "command landed" — the action payload format +# differs from the status-query payload format, so don't reuse parsers. +_RESP_OK = (0x01, 0x05) + + +class SwitchBotError(RuntimeError): + """BLE I/O failed, timed out, or the Bot reported a non-OK response.""" + + +async def _send_command(address: str, command: bytes, timeout: float) -> bytes: + """Connect, send one command, await the response notification, disconnect. + + Returns the raw response bytes; raises SwitchBotError on timeout or non-OK. + """ + from bleak import BleakClient # lazy import keeps module import cheap + + response: dict[str, bytes] = {} + got = asyncio.Event() + + def _on_notify(_char, data: bytearray) -> None: + response["data"] = bytes(data) + got.set() + + async with BleakClient(address, timeout=timeout) as client: + await client.start_notify(_NOTIFY_CHAR, _on_notify) + await client.write_gatt_char(_CTRL_CHAR, command, response=True) + try: + await asyncio.wait_for(got.wait(), timeout=timeout) + except asyncio.TimeoutError as exc: + raise SwitchBotError("no response notification from SwitchBot") from exc + finally: + try: + await client.stop_notify(_NOTIFY_CHAR) + except Exception: # disconnect cleanup is best-effort + pass + + data = response["data"] + if not data or data[0] not in _RESP_OK: + raise SwitchBotError(f"SwitchBot returned non-OK response: {data.hex()}") + return data + + +class SwitchBot: + """Bluesky-protocol device for a SwitchBot Bot button pusher. 
+ + Parameters + ---------- + address : str + BLE MAC address, e.g. "EC:6F:04:06:5B:23". + name : str + Device name used as the key in plans and read() output. + timeout : float + Per-command BLE connect/response timeout in seconds. + + Valid states for ``set``: 'on', 'off', 'press'. + """ + + def __init__(self, address: str, name: str = "switchbot", *, timeout: float = 20.0): + self.address = address + self.name = name + self.timeout = timeout + self.parent = None # required for Bluesky bps.mv() + self._state = "unknown" # last commanded on/off state + # Status fields populated only by read_status(). Left as None until + # first contact — action commands deliberately don't update these, + # see note on _STATUS_BATTERY_IDX above. + self._battery_pct: int | None = None + self._firmware: int | None = None + self._status_ts: float | None = None + self._lock = threading.Lock() # serialize BLE access (one radio, one bot) + + # -- Bluesky settable protocol ------------------------------------------- + def set(self, state: str): + """Send on/off/press. 
Returns an ophyd Status that finishes when done.""" + from ophyd.status import Status + + state = str(state).lower() + if state not in _COMMANDS: + raise ValueError(f"state {state!r} not in {list(_COMMANDS)}") + + status = Status(obj=self, timeout=self.timeout + 5) + + def worker(): + with self._lock: + try: + data = asyncio.run(_send_command(self.address, _COMMANDS[state], self.timeout)) + except Exception as exc: + logger.warning("SwitchBot %s set(%s) failed: %s", self.name, state, exc) + status.set_exception(exc) + return + if state in ("on", "off"): + self._state = state + logger.info("SwitchBot %s -> %s (resp %s)", self.name, state, data.hex()) + status.set_finished() + + threading.Thread(target=worker, name=f"{self.name}-set", daemon=True).start() + return status + + # -- Dedicated status query (no actuation) ------------------------------- + def read_status(self) -> dict: + """Query battery / firmware / mode without touching the switch arm. + + Synchronous: runs its own BLE connect → query → disconnect on the + caller's thread. Updates the cached status fields on success so + read() surfaces fresh values to the device-state stream. Use this + for periodic polls (~hourly is fine; battery doesn't move quickly). + + Returns a dict ``{battery_pct, firmware, raw_hex}``; raises + SwitchBotError on BLE / protocol failure. 
+ """ + with self._lock: + data = asyncio.run(_send_command(self.address, _QUERY_STATUS, self.timeout)) + info = { + "raw_hex": data.hex(), + "battery_pct": data[_STATUS_BATTERY_IDX] if len(data) > _STATUS_BATTERY_IDX else None, + "firmware": data[_STATUS_FIRMWARE_IDX] if len(data) > _STATUS_FIRMWARE_IDX else None, + } + if info["battery_pct"] is not None: + self._battery_pct = info["battery_pct"] + if info["firmware"] is not None: + self._firmware = info["firmware"] + self._status_ts = time.time() + logger.info("SwitchBot %s status: %s", self.name, info) + return info + + # -- Bluesky readable protocol ------------------------------------------- + def read(self): + ts = time.time() + out = OrderedDict({self.name: {"value": self._state, "timestamp": ts}}) + if self._battery_pct is not None: + out[f"{self.name}_battery_pct"] = { + "value": self._battery_pct, + "timestamp": self._status_ts or ts, + } + if self._firmware is not None: + out[f"{self.name}_firmware"] = { + "value": self._firmware, + "timestamp": self._status_ts or ts, + } + return out + + def describe(self): + return OrderedDict( + { + self.name: { + "source": f"switchbot:{self.address}", + "dtype": "string", + "shape": [], + }, + f"{self.name}_battery_pct": { + "source": f"switchbot:{self.address}", + "dtype": "integer", + "shape": [], + }, + f"{self.name}_firmware": { + "source": f"switchbot:{self.address}", + "dtype": "integer", + "shape": [], + }, + } + ) + + def read_configuration(self): + return OrderedDict() + + def describe_configuration(self): + return OrderedDict() + + +if __name__ == "__main__": + # Standalone self-test, e.g.: python gently/hardware/switchbot.py AA:BB:.. 
on off + import sys + + address = "EC:6F:04:06:5B:23" + cmds = [] + for arg in sys.argv[1:]: + if ":" in arg and len(arg) >= 17: # looks like a MAC address + address = arg + else: + cmds.append(arg.lower()) + cmds = cmds or ["on", "off"] + + logging.basicConfig(level=logging.INFO, format="%(message)s") + bot = SwitchBot(address) + print(f"SwitchBot {address} — sequence: {cmds}\n") + for i, cmd in enumerate(cmds): + print(f"set({cmd!r}) ...") + st = bot.set(cmd) + st.wait(30) # blocks; raises if the command failed + print(f" done; read() -> {bot.read()[bot.name]['value']}") + if i != len(cmds) - 1: + time.sleep(1.5) + print("\nOK") diff --git a/gently/hardware/temperature.py b/gently/hardware/temperature.py new file mode 100644 index 00000000..dba8865b --- /dev/null +++ b/gently/hardware/temperature.py @@ -0,0 +1,261 @@ +""" +ACUITYnano Precision Thermal Controller as a Bluesky/Ophyd-protocol device. + +Wraps the vendor SDK — a Peltier/TEC water-cooled controller, 0.0-99.9 C. Two +transports expose the same core API: + - USB serial : acuitynano_precision_thermalizer_serial (vendor-recommended + for closed-loop automation; zero-latency) + - MQTT : acuitynano_precision_thermalizer_api (multi-client; adds + get_peltier_temp()) + +The device follows the same duck-typed Bluesky protocol as the diSPIM devices +(see gently.hardware.dispim.devices.optical.DiSPIMLED). A temperature controller +is the textbook Bluesky "settable that completes on stabilization": + + yield from bps.mv(temperature, 20.0) # blocks until the controller LOCKS + +set(target) commands the setpoint, enables the TEC, and returns an ophyd Status +that finishes only when the controller reports "[ SYSTEM LOCKED ]" (or raises on +timeout). read() reports the live water temperature (plus setpoint / state, and +peltier temp when the transport provides it). The blocking wait runs in a worker +thread so the Status integrates with the RunEngine. 
+ +NOTE: the vendor `acuitynano_precision_thermalizer_*` packages are NOT on PyPI — +install them on the device-layer machine. Local logic can be exercised with the +built-in mock backend: `python gently/hardware/temperature.py --mock 20`. +""" + +from __future__ import annotations + +import logging +import threading +import time +from collections import OrderedDict + +logger = logging.getLogger(__name__) + +TEMP_MIN_C = 0.0 +TEMP_MAX_C = 99.9 + + +def _make_backend(cfg: dict): + """Construct the vendor SDK transport from a config mapping.""" + backend = str(cfg.get("backend", "serial")).lower() + if backend == "mock": + return _MockBackend() + if backend == "serial": + from acuitynano_precision_thermalizer_serial import ( + AcuityNanoPrecisionThermalizerSerial, + ) + + return AcuityNanoPrecisionThermalizerSerial( + cfg["com_port"], baud_rate=cfg.get("baud_rate", 115200) + ) + if backend == "mqtt": + from acuitynano_precision_thermalizer_api import ( + AcuityNanoPrecisionThermalizerAPI, + ) + + # The vendor package ships with an embedded HiveMQ Cloud broker + creds, + # so MQTT can run with no config. Pass only the keys actually provided, + # to override those embedded defaults (and keep secrets in config, not code). 
+ kwargs = {k: cfg[k] for k in ("broker", "port", "user", "password") if k in cfg} + return AcuityNanoPrecisionThermalizerAPI(**kwargs) + raise ValueError(f"unknown temperature backend {backend!r} (use 'serial', 'mqtt', or 'mock')") + + +def create_temperature_controller(cfg: dict) -> TemperatureController: + """Factory used by the device layer: build transport + wrap as a device.""" + backend = _make_backend(cfg) + if "feedback_peltier" in cfg and hasattr(backend, "set_feedback_sensor"): + backend.set_feedback_sensor(use_peltier=bool(cfg["feedback_peltier"])) + return TemperatureController( + backend, + name=cfg.get("name", "temperature"), + stabilize_timeout=cfg.get("stabilize_timeout", 600.0), + ) + + +class TemperatureController: + """Bluesky-protocol device for the ACUITYnano thermal controller. + + Parameters + ---------- + backend : object + Vendor SDK instance exposing set_temperature / get_water_temp / + get_system_state / enable_tec / wait_for_target. + name : str + Device name; the registry key and primary read() field. + stabilize_timeout : float + Seconds to wait for "[ SYSTEM LOCKED ]" before set() fails. 
+ """ + + def __init__(self, backend, name: str = "temperature", *, stabilize_timeout: float = 600.0): + self._dev = backend + self.name = name + self.stabilize_timeout = float(stabilize_timeout) + self.parent = None # required for Bluesky bps.mv() + self._setpoint = None # last commanded target + self._lock = threading.Lock() + + # -- Bluesky settable protocol ------------------------------------------- + def set(self, target_c): + """Command setpoint + enable TEC; Status finishes when the system locks.""" + from ophyd.status import Status + + target = float(target_c) + if not (TEMP_MIN_C <= target <= TEMP_MAX_C): + raise ValueError(f"target {target} C outside [{TEMP_MIN_C}, {TEMP_MAX_C}]") + + status = Status(obj=self, timeout=self.stabilize_timeout + 30) + + def worker(): + with self._lock: + try: + self._dev.set_temperature(target) # vendor also validates range + self._dev.enable_tec(True) + locked = self._dev.wait_for_target(timeout_seconds=self.stabilize_timeout) + except Exception as exc: + logger.warning("temperature %s set(%.2f) failed: %s", self.name, target, exc) + status.set_exception(exc) + return + self._setpoint = target + if locked: + logger.info("temperature %s locked at %.2f C", self.name, target) + status.set_finished() + else: + status.set_exception( + TimeoutError( + f"{self.name} did not stabilize at {target} C" + f" within {self.stabilize_timeout}s" + ) + ) + + threading.Thread(target=worker, name=f"{self.name}-set", daemon=True).start() + return status + + # -- Explicit controls (outside the bps.mv() path) ----------------------- + def enable(self, on: bool = True): + self._dev.enable_tec(bool(on)) + + def setpoint(self, target_c): + """Command the setpoint without blocking for stabilization.""" + self._dev.set_temperature(float(target_c)) + + # -- Bluesky readable protocol ------------------------------------------- + def read(self): + now = time.time() + data = OrderedDict() + data[self.name] = { + "value": 
self._safe(self._dev.get_water_temp), + "timestamp": now, + } + data[f"{self.name}_setpoint"] = {"value": self._setpoint, "timestamp": now} + data[f"{self.name}_state"] = { + "value": self._safe(self._dev.get_system_state, default="unknown"), + "timestamp": now, + } + if self._has_peltier(): + data[f"{self.name}_peltier"] = { + "value": self._safe(self._dev.get_peltier_temp), + "timestamp": now, + } + return data + + def describe(self): + src = f"acuitynano:{self.name}" + d = OrderedDict() + d[self.name] = {"source": src, "dtype": "number", "shape": []} + d[f"{self.name}_setpoint"] = {"source": src, "dtype": "number", "shape": []} + d[f"{self.name}_state"] = {"source": src, "dtype": "string", "shape": []} + if self._has_peltier(): + d[f"{self.name}_peltier"] = {"source": src, "dtype": "number", "shape": []} + return d + + def read_configuration(self): + return OrderedDict() + + def describe_configuration(self): + return OrderedDict() + + def close(self): + """Release the transport (serial port / MQTT client) on shutdown.""" + for method in ("close", "disconnect"): + fn = getattr(self._dev, method, None) + if fn is not None: + try: + fn() + except Exception: + pass + return + + # -- helpers -------------------------------------------------------------- + def _has_peltier(self) -> bool: + return getattr(self._dev, "get_peltier_temp", None) is not None + + @staticmethod + def _safe(fn, default=None): + try: + return fn() + except Exception: + return default + + +class _MockBackend: + """In-memory fake mirroring the vendor API, for local testing without hardware.""" + + def __init__(self, *args, **kwargs): + self._target = 25.0 + self._enabled = False + + def set_temperature(self, t): + if not (TEMP_MIN_C <= float(t) <= TEMP_MAX_C): + raise ValueError("Target must be between 0.0 and 99.9 C") + self._target = float(t) + + def enable_tec(self, on): + self._enabled = bool(on) + + def set_feedback_sensor(self, use_peltier=False): + pass + + def wait_for_target(self, 
timeout_seconds=300): + time.sleep(0.5) # pretend to ramp + settle + return True + + def get_water_temp(self): + return self._target + + def get_peltier_temp(self): + return self._target - 1.0 + + def get_system_state(self): + return "[ SYSTEM LOCKED ]" if self._enabled else "[ IDLE ]" + + def close(self): + pass + + +if __name__ == "__main__": + import sys + + logging.basicConfig(level=logging.INFO, format="%(message)s") + if "--mock" in sys.argv: + target = 20.0 + for arg in sys.argv[1:]: + try: + target = float(arg) + break + except ValueError: + continue + dev = TemperatureController(_MockBackend(), name="temperature", stabilize_timeout=10) + print(f"[mock] set({target}) — blocks until locked ...") + st = dev.set(target) + st.wait(15) + print("[mock] read ->", {k: v["value"] for k, v in dev.read().items()}) + print("OK") + else: + print( + "Real-hardware self-test needs the vendor SDK + a controller. " + "Run with --mock to exercise the device logic locally." + ) diff --git a/gently/harness/bridge.py b/gently/harness/bridge.py index 4561b0f5..ddbfa9ca 100644 --- a/gently/harness/bridge.py +++ b/gently/harness/bridge.py @@ -8,12 +8,11 @@ """ import asyncio -import json import logging -import time -from typing import Any, Callable, Coroutine, Dict, Optional +from collections.abc import Callable, Coroutine +from typing import Any -from .commands import get_command_registry, CommandCategory +from .commands import get_command_registry logger = logging.getLogger(__name__) @@ -39,10 +38,53 @@ class AgentBridge: def __init__(self, agent): self.agent = agent - self._launch_info: Dict[str, Any] = {} + self._launch_info: dict[str, Any] = {} self._wizard = None # StartupWizard, set by init_wizard() - self._active_remote: Optional[Dict[str, Any]] = None # {"peer": PeerInfo, "campaign_id": str} - self._pending_import: Optional[Dict] = None # For /import-embryos picker + self._active_remote: dict[str, Any] | None = None # {"peer": PeerInfo, "campaign_id": str} + 
self._pending_import: dict | None = None # For /import-embryos picker + # Set by the web layer (register_display_broadcaster) so AGENT-INITIATED + # turns (the wake-router) can stream to all chat clients + the transcript. + self._display_broadcaster: Callable | None = None + + def register_display_broadcaster( + self, broadcast_fn, choice_factory=None, choice_discard=None + ) -> None: + """Register the web layer's broadcast fn for autonomous (wake) turns. + + The wake-router has no per-connection send_fn, so to make autonomous + turns visible we route their chunks through the same _broadcast the web + route uses for user turns (records to the display transcript AND fans out + to every connected chat client). Also wires the agent's dangling + on_message_callback to this path, and (for ASK mode) the choice-future + factory so an autonomous turn can round-trip an approval picker. + Idempotent — last registration wins; the registered fns are router-scoped + and fan out to whoever is connected. + """ + self._display_broadcaster = broadcast_fn + try: + self.agent.on_message_callback = self.broadcast_autonomous_chunk + if choice_factory is not None: + self.agent._wake_choice_factory = choice_factory + if choice_discard is not None: + self.agent._wake_choice_discard = choice_discard + except Exception: + pass + + async def broadcast_autonomous_chunk(self, chunk) -> None: + """Fan one autonomous-turn chunk to all chat clients + the transcript. + + No-op when no web client has registered a broadcaster (headless run) — + the wake turn still executes and is persisted to the conversation/log. + """ + fn = self._display_broadcaster + if fn is None: + return + try: + res = fn(chunk) + if asyncio.iscoroutine(res): + await res + except Exception: + logger.debug("broadcast_autonomous_chunk failed", exc_info=True) async def handle_choice_response(self, request_id: str, selected: str, send_fn) -> bool: """Handle a choice response that may belong to a bridge-initiated picker. 
@@ -57,7 +99,7 @@ async def handle_choice_response(self, request_id: str, selected: str, send_fn) return True return False - def set_launch_info(self, info: Dict[str, Any]) -> None: + def set_launch_info(self, info: dict[str, Any]) -> None: """Store launch metadata to include in the connect message.""" self._launch_info = info @@ -80,13 +122,13 @@ def get_session_briefing(self) -> str: returns an empty string — the agent's first turn becomes the opener instead of static text. """ - if not hasattr(self.agent, 'memory') or not self.agent.memory: + if not hasattr(self.agent, "memory") or not self.agent.memory: return "" memory = self.agent.memory # Restore active_plan_item_id from experiment state (session resume) - experiment = getattr(self.agent, 'experiment', None) + experiment = getattr(self.agent, "experiment", None) if experiment and experiment.active_plan_item_id: memory.active_plan_item_id = experiment.active_plan_item_id elif experiment: @@ -98,13 +140,14 @@ def get_session_briefing(self) -> str: logger.info(f"Auto-set active plan item: {active_id}") # Link session to the campaign - cs = getattr(self.agent, 'context_store', None) + cs = getattr(self.agent, "context_store", None) if cs and self.agent.session_id: try: item = cs.get_plan_item(active_id) if item: cs.link_session_campaign( - self.agent.session_id, item.campaign_id, + self.agent.session_id, + item.campaign_id, ) except Exception: pass @@ -114,7 +157,7 @@ def get_session_briefing(self) -> str: # Invalidate prompt cache so the system prompt picks up the # active plan item on the next message - prompts = getattr(self.agent, 'prompts', None) + prompts = getattr(self.agent, "prompts", None) if prompts and memory.active_plan_item_id: prompts.invalidate_context_cache() @@ -130,9 +173,9 @@ def should_enter_resolution(self) -> bool: replaces the older O(memory-scan) ``resolve_plan_context()`` gate that was making the picker take seconds to appear at startup. 
""" - if not hasattr(self.agent, 'memory') or not self.agent.memory: + if not hasattr(self.agent, "memory") or not self.agent.memory: return False - experiment = getattr(self.agent, 'experiment', None) + experiment = getattr(self.agent, "experiment", None) if experiment is None: return False if experiment.active_plan_item_id: @@ -167,7 +210,7 @@ def _sort_key(t): candidates.sort(key=_sort_key) return candidates - def _candidate_to_option(self, item, spec, campaign) -> Dict: + def _candidate_to_option(self, item, spec, campaign) -> dict: """Turn a ``(item, spec, campaign)`` tuple into a picker option.""" memory = getattr(self.agent, "memory", None) spec_summary = "" @@ -177,10 +220,11 @@ def _candidate_to_option(self, item, spec, campaign) -> Dict: except Exception: spec_summary = "" - meta: Dict[str, Any] = {} + meta: dict[str, Any] = {} if campaign is not None: - c_name = getattr(campaign, "shorthand", None) or ( - (getattr(campaign, "description", "") or "")[:60] + c_name = ( + getattr(campaign, "shorthand", None) + or ((getattr(campaign, "description", "") or "")[:60]) ) if c_name: meta["campaign"] = c_name @@ -194,8 +238,7 @@ def _candidate_to_option(self, item, spec, campaign) -> Dict: order = getattr(item, "phase_order", 0) or 0 if total > 0: meta["sequence"] = ( - f"{order} of {total} · {done} done" - if order else f"{done}/{total} done" + f"{order} of {total} · {done} done" if order else f"{done}/{total} done" ) except Exception: pass @@ -205,10 +248,15 @@ def _candidate_to_option(self, item, spec, campaign) -> Dict: meta["status"] = status_val if spec is not None: - spec_dict: Dict[str, Any] = {} + spec_dict: dict[str, Any] = {} for field in ( - "strain", "temperature_c", "num_slices", "exposure_ms", - "interval_s", "stop_condition", "success_criteria", + "strain", + "temperature_c", + "num_slices", + "exposure_ms", + "interval_s", + "stop_condition", + "success_criteria", ): val = getattr(spec, field, None) if val is not None: @@ -224,8 +272,10 @@ def 
_candidate_to_option(self, item, spec, campaign) -> Dict: } def _build_resolution_choice_payload( - self, candidates: list, full_list: bool = False, - ) -> Dict: + self, + candidates: list, + full_list: bool = False, + ) -> dict: """Build the ``choice_data`` payload for the resolution picker. Top 3 candidates by default; ``full_list=True`` shows up to 20 @@ -238,26 +288,33 @@ def _build_resolution_choice_payload( if not full_list and len(candidates) > show_n: remaining = len(candidates) - show_n - options.append({ - "id": "show_all", - "label": f"See all imaging tasks ({remaining} more)…", - "description": "Browse the full unblocked list", - }) - - options.append({ - "id": "standalone", - "label": "Standalone — just exploring", - "description": "No plan attached, default settings", - }) - options.append({ - "id": "plan_new", - "label": "Design a new plan", - "description": "Enter plan mode first", - }) + options.append( + { + "id": "show_all", + "label": f"See all imaging tasks ({remaining} more)…", + "description": "Browse the full unblocked list", + } + ) + + options.append( + { + "id": "standalone", + "label": "Standalone — just exploring", + "description": "No plan attached, default settings", + } + ) + options.append( + { + "id": "plan_new", + "label": "Design a new plan", + "description": "Enter plan mode first", + } + ) question = ( f"All {len(candidates)} unblocked imaging tasks — pick one:" - if full_list else "What is this session for?" + if full_list + else "What is this session for?" ) return { @@ -270,8 +327,8 @@ def _build_resolution_choice_payload( async def bootstrap_resolution_picker( self, - send_fn: Callable[[Dict], Coroutine], - choice_future_factory: Callable[[Dict], "asyncio.Future[str]"], + send_fn: Callable[[dict], Coroutine], + choice_future_factory: Callable[[dict], "asyncio.Future[str]"], ) -> None: """Open the session with a deterministic resolution picker. 
@@ -303,9 +360,7 @@ async def bootstrap_resolution_picker( # Kick off the slow candidate scan in the background. By the time # the user picks "Continue", this is almost always already done. - candidates_task = asyncio.create_task( - asyncio.to_thread(self._build_resolution_candidates) - ) + candidates_task = asyncio.create_task(asyncio.to_thread(self._build_resolution_candidates)) # --- Phase 1: fast top-level question ------------------------- top_payload = { @@ -332,11 +387,13 @@ async def bootstrap_resolution_picker( top_request_id = f"resolve_top_{_uuid.uuid4().hex[:8]}" top_payload["request_id"] = top_request_id top_future = choice_future_factory(top_payload) - await send_fn({ - "type": "choice_request", - "choice_data": top_payload, - "request_id": top_request_id, - }) + await send_fn( + { + "type": "choice_request", + "choice_data": top_payload, + "request_id": top_request_id, + } + ) try: top_choice = await top_future @@ -349,7 +406,9 @@ async def bootstrap_resolution_picker( if top_choice != "resume_plan": candidates_task.cancel() await self._dispatch_resolution_pick( - top_choice or "standalone", send_fn, choice_future_factory, + top_choice or "standalone", + send_fn, + choice_future_factory, ) return @@ -366,11 +425,13 @@ async def bootstrap_resolution_picker( if briefing: await send_fn({"type": "stream_start"}) await send_fn({"type": "text", "text": briefing}) - await send_fn({ - "type": "stream_end", - "tokens": self._get_token_snapshot(), - "mode": self.agent.mode, - }) + await send_fn( + { + "type": "stream_end", + "tokens": self._get_token_snapshot(), + "mode": self.agent.mode, + } + ) else: await self._emit_resolution_result( send_fn, @@ -382,7 +443,8 @@ async def bootstrap_resolution_picker( full_list = False while True: payload = self._build_resolution_choice_payload( - candidates, full_list=full_list, + candidates, + full_list=full_list, ) # Re-label the question for the secondary picker so the user # knows they're now picking the specific plan 
item. @@ -390,11 +452,13 @@ async def bootstrap_resolution_picker( request_id = f"resolve_pick_{_uuid.uuid4().hex[:8]}" payload["request_id"] = request_id future = choice_future_factory(payload) - await send_fn({ - "type": "choice_request", - "choice_data": payload, - "request_id": request_id, - }) + await send_fn( + { + "type": "choice_request", + "choice_data": payload, + "request_id": request_id, + } + ) try: selected = await future except asyncio.CancelledError: @@ -405,15 +469,17 @@ async def bootstrap_resolution_picker( continue await self._dispatch_resolution_pick( - selected, send_fn, choice_future_factory, + selected, + send_fn, + choice_future_factory, ) return async def _dispatch_resolution_pick( self, selected: str, - send_fn: Callable[[Dict], Coroutine], - choice_future_factory: Callable[[Dict], "asyncio.Future[str]"], + send_fn: Callable[[dict], Coroutine], + choice_future_factory: Callable[[dict], "asyncio.Future[str]"], ) -> None: """Apply the user's resolution pick. @@ -456,7 +522,8 @@ async def _dispatch_resolution_pick( logger.error(f"attach_session_to_plan failed: {e}", exc_info=True) await self._emit_resolution_result( send_fn, - "Couldn't attach the session to that plan item. You can try again or pick standalone.", + "Couldn't attach the session to that plan item." + " You can try again or pick standalone.", ) return @@ -472,7 +539,9 @@ async def _dispatch_resolution_pick( spec_dict = self._get_active_plan_spec() closer = self._compose_attach_closer(spec_dict) await self._emit_resolution_result( - send_fn, closer, applied_spec=spec_dict, + send_fn, + closer, + applied_spec=spec_dict, ) return @@ -498,7 +567,8 @@ async def _dispatch_resolution_pick( logger.error(f"enter_plan_mode failed: {e}", exc_info=True) msg = "Plan mode active." 
await self._emit_resolution_result( - send_fn, msg or "Plan mode — what are we designing?", + send_fn, + msg or "Plan mode — what are we designing?", ) return @@ -510,14 +580,17 @@ async def _dispatch_resolution_pick( except Exception as e: logger.warning(f"enter_resolution_mode failed: {e}") await self._emit_resolution_result( - send_fn, "Couldn't enter resolution mode.", + send_fn, + "Couldn't enter resolution mode.", ) return await self.stream_response( - selected or "(no input)", send_fn, choice_future_factory, + selected or "(no input)", + send_fn, + choice_future_factory, ) - def _get_active_plan_spec(self) -> Optional[Dict]: + def _get_active_plan_spec(self) -> dict | None: """Return the ``active_plan_spec`` dict stashed on ``experiment.metadata`` by ``apply_plan_acquisition_spec``.""" try: @@ -526,35 +599,40 @@ def _get_active_plan_spec(self) -> Optional[Dict]: except Exception: return None - def _compose_attach_closer(self, spec_dict: Optional[Dict]) -> str: + def _compose_attach_closer(self, spec_dict: dict | None) -> str: """One-line conversational closer to follow attach + apply.""" title = (spec_dict or {}).get("plan_item_title") or "this plan item" return f"Attached to **{title}**. Mark embryo positions when you're ready." 
async def _emit_resolution_result( self, - send_fn: Callable[[Dict], Coroutine], + send_fn: Callable[[dict], Coroutine], closer_text: str, - applied_spec: Optional[Dict] = None, + applied_spec: dict | None = None, ) -> None: """Emit a deterministic stream_start → text → stream_end pair, followed by an optional ``applied_spec`` panel message.""" await send_fn({"type": "stream_start"}) await send_fn({"type": "text", "text": closer_text}) - await send_fn({ - "type": "stream_end", - "tokens": self._get_token_snapshot(), - "mode": self.agent.mode, - }) + await send_fn( + { + "type": "stream_end", + "tokens": self._get_token_snapshot(), + "mode": self.agent.mode, + } + ) if applied_spec: - await send_fn({ - "type": "applied_spec", - "spec": applied_spec, - }) + await send_fn( + { + "type": "applied_spec", + "spec": applied_spec, + } + ) def init_wizard(self, context_store, claude_client=None) -> None: """Create the startup wizard from a ContextStore.""" from .memory.startup_wizard import StartupWizard + self._context_store = context_store self._wizard = StartupWizard( context_store=context_store, @@ -565,8 +643,8 @@ def init_wizard(self, context_store, claude_client=None) -> None: async def stream_response( self, message: str, - send_fn: Callable[[Dict], Coroutine], - choice_future_factory: Callable[[Dict], "asyncio.Future[str]"], + send_fn: Callable[[dict], Coroutine], + choice_future_factory: Callable[[dict], "asyncio.Future[str]"], ) -> None: """ Stream an agent response over WebSocket. 
@@ -594,11 +672,13 @@ async def stream_response( chunk = await stream_iter.__anext__() except StopAsyncIteration: # Stream finished — send token usage summary + current mode - await send_fn({ - "type": "stream_end", - "tokens": self._get_token_snapshot(), - "mode": self.agent.mode, - }) + await send_fn( + { + "type": "stream_end", + "tokens": self._get_token_snapshot(), + "mode": self.agent.mode, + } + ) return chunk_type = chunk.get("type") @@ -623,12 +703,22 @@ async def stream_response( except Exception as e: logger.error(f"Stream error: {e}", exc_info=True) await send_fn({"type": "error", "error": str(e)}) + finally: + # Deterministically close the agent generator so its turn-lock (and + # any other resources) release immediately. Without this, a cancelled + # or aborted stream leaves the generator suspended at a `yield` still + # holding self._turn_lock until non-deterministic GC, stalling the + # next user turn and any autonomous wake turn on lock.acquire(). + try: + await stream_iter.aclose() + except Exception: + pass async def handle_command( self, command: str, - send_fn: Callable[[Dict], Coroutine], - choice_futures: Dict = None, + send_fn: Callable[[dict], Coroutine], + choice_futures: dict | None = None, ) -> None: """ Execute a slash command and send the result. 
@@ -647,69 +737,87 @@ async def handle_command( cmd = command.strip().lower() cmd_name = cmd.split()[0] cmd_def = registry.get(cmd_name) - logger.info("handle_command: %s (resolved: %s)", cmd_name, cmd_def.name if cmd_def else "NOT FOUND") + logger.info( + "handle_command: %s (resolved: %s)", + cmd_name, + cmd_def.name if cmd_def else "NOT FOUND", + ) if not cmd_def: - await send_fn({ - "type": "command_result", - "command": command, - "error": f"Unknown command: {cmd_name}", - }) + await send_fn( + { + "type": "command_result", + "command": command, + "error": f"Unknown command: {cmd_name}", + } + ) return if cmd in ("/quit", "/exit", "/q"): - await send_fn({ - "type": "command_result", - "command": cmd, - "action": "quit", - }) + await send_fn( + { + "type": "command_result", + "command": cmd, + "action": "quit", + } + ) return if cmd == "/status": status = self._get_status_data() - await send_fn({ - "type": "command_result", - "command": "/status", - "content": status, - }) + await send_fn( + { + "type": "command_result", + "command": "/status", + "content": status, + } + ) return if cmd in ("/peers", "/mesh") or cmd.startswith("/peers ") or cmd.startswith("/mesh "): parts = command.strip().split() if len(parts) >= 3 and parts[2].lower() == "campaigns": data = await self._get_peer_campaigns(parts[1]) - await send_fn({ - "type": "command_result", - "command": "/peers", - "content": data, - }) + await send_fn( + { + "type": "command_result", + "command": "/peers", + "content": data, + } + ) else: data = self._get_peers_data() - await send_fn({ - "type": "command_result", - "command": "/peers", - "content": data, - }) + await send_fn( + { + "type": "command_result", + "command": "/peers", + "content": data, + } + ) return if cmd == "/embryos" or cmd.startswith("/embryos "): parts = cmd.split(maxsplit=1) embryo_id = parts[1].strip() if len(parts) > 1 else None data = self._get_embryos_data(embryo_id) - await send_fn({ - "type": "command_result", - "command": 
"/embryos", - "content": data, - }) + await send_fn( + { + "type": "command_result", + "command": "/embryos", + "content": data, + } + ) return if cmd == "/tokens": data = self._get_tokens_data() - await send_fn({ - "type": "command_result", - "command": "/tokens", - "content": data, - }) + await send_fn( + { + "type": "command_result", + "command": "/tokens", + "content": data, + } + ) return if cmd == "/help" or cmd.startswith("/help "): @@ -721,111 +829,137 @@ async def handle_command( text = f"Unknown command: {help_cmd}" else: text = registry.generate_help_markdown() - await send_fn({ - "type": "command_result", - "command": "/help", - "content": {"text": text}, - }) + await send_fn( + { + "type": "command_result", + "command": "/help", + "content": {"text": text}, + } + ) return if cmd.startswith("/theme"): parts = cmd.split() if len(parts) > 1: - from gently.app.theme import set_theme, get_theme + from gently.app.theme import get_theme, set_theme + try: set_theme(parts[1]) theme = get_theme() - await send_fn({ - "type": "command_result", - "command": "/theme", - "content": {"theme": theme.name, "changed": True}, - }) + await send_fn( + { + "type": "command_result", + "command": "/theme", + "content": {"theme": theme.name, "changed": True}, + } + ) except ValueError as e: - await send_fn({ - "type": "command_result", - "command": "/theme", - "error": str(e), - }) + await send_fn( + { + "type": "command_result", + "command": "/theme", + "error": str(e), + } + ) else: - from gently.app.theme import list_themes, get_theme + from gently.app.theme import get_theme, list_themes + current = get_theme() themes = {k: v.name for k, v in list_themes().items()} - await send_fn({ - "type": "command_result", - "command": "/theme", - "content": {"themes": themes, "current": current.name}, - }) + await send_fn( + { + "type": "command_result", + "command": "/theme", + "content": {"themes": themes, "current": current.name}, + } + ) return if cmd == "/sessions": - await 
send_fn({ - "type": "command_result", - "command": "/sessions", - "content": {"sessions": self._get_sessions_list()}, - }) + await send_fn( + { + "type": "command_result", + "command": "/sessions", + "content": {"sessions": self._get_sessions_list()}, + } + ) return if cmd == "/timelapse" or cmd == "/timelapse watch": data = self._get_timelapse_data() - await send_fn({ - "type": "command_result", - "command": "/timelapse", - "content": data, - }) + await send_fn( + { + "type": "command_result", + "command": "/timelapse", + "content": data, + } + ) return if cmd.startswith("/timeline"): parts = command.strip().split() data = self._get_timeline_data(parts[1:] if len(parts) > 1 else []) - await send_fn({ - "type": "command_result", - "command": "/timeline", - "content": data, - }) + await send_fn( + { + "type": "command_result", + "command": "/timeline", + "content": data, + } + ) return if cmd == "/detectors": data = self._get_detectors_data() - await send_fn({ - "type": "command_result", - "command": "/detectors", - "content": data, - }) + await send_fn( + { + "type": "command_result", + "command": "/detectors", + "content": data, + } + ) return if cmd == "/history": data = self._get_history_data() - await send_fn({ - "type": "command_result", - "command": "/history", - "content": data, - }) + await send_fn( + { + "type": "command_result", + "command": "/history", + "content": data, + } + ) return if cmd == "/save": success = self.agent.save_session() if success: - await send_fn({ - "type": "command_result", - "command": "/save", - "content": {"text": f"Session saved: {self.agent.session_id}"}, - }) + await send_fn( + { + "type": "command_result", + "command": "/save", + "content": {"text": f"Session saved: {self.agent.session_id}"}, + } + ) else: - await send_fn({ - "type": "command_result", - "command": "/save", - "error": "Failed to save session", - }) + await send_fn( + { + "type": "command_result", + "command": "/save", + "error": "Failed to save session", + } 
+ ) return if cmd == "/reset-context": cs = self._require_context_store() if cs is None: - await send_fn({ - "type": "command_result", - "command": "/reset-context", - "error": "Context store not available", - }) + await send_fn( + { + "type": "command_result", + "command": "/reset-context", + "error": "Context store not available", + } + ) else: counts = cs.reset() total = sum(counts.values()) @@ -834,25 +968,32 @@ async def handle_command( self.init_wizard(cs, claude_client) if total > 0: details = ", ".join(f"{v} {k}" for k, v in counts.items()) - msg = f"Context cleared: {total} entries removed ({details}).\nRun /wizard to set up again." + msg = ( + f"Context cleared: {total} entries removed ({details})." + "\nRun /wizard to set up again." + ) else: msg = "Context already empty — nothing to clear." - await send_fn({ - "type": "command_result", - "command": "/reset-context", - "content": {"text": msg}, - }) + await send_fn( + { + "type": "command_result", + "command": "/reset-context", + "content": {"text": msg}, + } + ) return if cmd == "/wizard": # Handled by the WebSocket route (agent_ws.py), not the bridge. # If we reach here, it means the wizard loop called handle_command # — i.e. /wizard was typed while the wizard is already running. 
- await send_fn({ - "type": "command_result", - "command": "/wizard", - "content": {"text": "The wizard is already running."}, - }) + await send_fn( + { + "type": "command_result", + "command": "/wizard", + "content": {"text": "The wizard is already running."}, + } + ) return if cmd == "/campaign" or cmd == "/campaigns" or cmd.startswith("/campaign "): @@ -861,53 +1002,67 @@ async def handle_command( if subcmd == "share" and len(parts) >= 3: data = self._share_campaign(parts[2]) - await send_fn({ - "type": "command_result", - "command": "/campaign", - "content": data, - }) + await send_fn( + { + "type": "command_result", + "command": "/campaign", + "content": data, + } + ) elif subcmd == "unshare" and len(parts) >= 3: data = self._unshare_campaign(parts[2]) - await send_fn({ - "type": "command_result", - "command": "/campaign", - "content": data, - }) + await send_fn( + { + "type": "command_result", + "command": "/campaign", + "content": data, + } + ) elif subcmd == "delete" and len(parts) >= 3: data = self._delete_campaign(parts[2]) - await send_fn({ - "type": "command_result", - "command": "/campaign", - "content": data, - }) + await send_fn( + { + "type": "command_result", + "command": "/campaign", + "content": data, + } + ) elif subcmd == "rename" and len(parts) >= 4: data = self._rename_campaign(parts[2], " ".join(parts[3:])) - await send_fn({ - "type": "command_result", - "command": "/campaign", - "content": data, - }) + await send_fn( + { + "type": "command_result", + "command": "/campaign", + "content": data, + } + ) elif subcmd == "pause" and len(parts) >= 3: data = self._pause_campaign(parts[2]) - await send_fn({ - "type": "command_result", - "command": "/campaign", - "content": data, - }) + await send_fn( + { + "type": "command_result", + "command": "/campaign", + "content": data, + } + ) elif subcmd == "resume" and len(parts) >= 3: data = self._resume_campaign(parts[2]) - await send_fn({ - "type": "command_result", - "command": "/campaign", - "content": 
data, - }) + await send_fn( + { + "type": "command_result", + "command": "/campaign", + "content": data, + } + ) else: data = self._get_campaigns_data(command.strip()) - await send_fn({ - "type": "command_result", - "command": "/campaign", - "content": data, - }) + await send_fn( + { + "type": "command_result", + "command": "/campaign", + "content": data, + } + ) return if cmd.startswith("/join-campaign"): @@ -916,11 +1071,13 @@ async def handle_command( data = await self._join_campaign(parts[1], parts[2]) else: data = {"text": "Usage: /join-campaign "} - await send_fn({ - "type": "command_result", - "command": "/join-campaign", - "content": data, - }) + await send_fn( + { + "type": "command_result", + "command": "/join-campaign", + "content": data, + } + ) return if cmd.startswith("/claim"): @@ -929,11 +1086,13 @@ async def handle_command( data = await self._claim_item(parts[1]) else: data = {"text": "Usage: /claim "} - await send_fn({ - "type": "command_result", - "command": "/claim", - "content": data, - }) + await send_fn( + { + "type": "command_result", + "command": "/claim", + "content": data, + } + ) return if cmd == "/pair" or cmd.startswith("/pair "): @@ -952,17 +1111,30 @@ async def handle_command( elif subcmd == "scopes": extra_args = parts[3] if len(parts) > 3 else "" data = self._pair_scopes(arg, extra_args) - elif subcmd and subcmd not in ("accept", "reject", "list", "unpair", "scopes"): + elif subcmd and subcmd not in ( + "accept", + "reject", + "list", + "unpair", + "scopes", + ): # Treat as hostname — initiate pairing data = await self._pair_initiate(subcmd, send_fn) else: - data = {"text": "Usage: /pair | accept | reject | list | unpair | scopes [hostname] [scope_list]"} + data = { + "text": ( + "Usage: /pair | accept | reject | list" + " | unpair | scopes [hostname] [scope_list]" + ) + } - await send_fn({ - "type": "command_result", - "command": "/pair", - "content": data, - }) + await send_fn( + { + "type": "command_result", + "command": 
"/pair", + "content": data, + } + ) return if cmd == "/plan" or cmd.startswith("/plan "): @@ -971,28 +1143,32 @@ async def handle_command( if subcmd == "exit": msg = self.agent.exit_plan_mode() - await send_fn({ - "type": "command_result", - "command": "/plan", - "content": { - "text": msg, - "mode": self.agent.mode, - }, - }) + await send_fn( + { + "type": "command_result", + "command": "/plan", + "content": { + "text": msg, + "mode": self.agent.mode, + }, + } + ) elif subcmd == "status": summary = self.agent._get_active_plan_summary() if summary: text = f"Mode: {self.agent.mode}\n\n{summary}" else: text = f"Mode: {self.agent.mode}\nNo active campaigns found." - await send_fn({ - "type": "command_result", - "command": "/plan", - "content": { - "text": text, - "mode": self.agent.mode, - }, - }) + await send_fn( + { + "type": "command_result", + "command": "/plan", + "content": { + "text": text, + "mode": self.agent.mode, + }, + } + ) else: # Enter plan mode (or show status if already in it) if self.agent.mode == "plan": @@ -1001,32 +1177,38 @@ async def handle_command( if summary: text += f"\n\n{summary}" text += "\n\nUse /plan exit to return to run mode." 
- await send_fn({ - "type": "command_result", - "command": "/plan", - "content": { - "text": text, - "mode": "plan", - }, - }) + await send_fn( + { + "type": "command_result", + "command": "/plan", + "content": { + "text": text, + "mode": "plan", + }, + } + ) else: msg = self.agent.enter_plan_mode() - await send_fn({ - "type": "command_result", - "command": "/plan", - "content": { - "text": msg, - "mode": "plan", - }, - }) + await send_fn( + { + "type": "command_result", + "command": "/plan", + "content": { + "text": msg, + "mode": "plan", + }, + } + ) return if cmd == "/clear": - await send_fn({ - "type": "command_result", - "command": "/clear", - "action": "clear", - }) + await send_fn( + { + "type": "command_result", + "command": "/clear", + "action": "clear", + } + ) return if cmd.startswith("/resume"): @@ -1037,39 +1219,48 @@ async def handle_command( if success: embryo_count = len(self.agent.experiment.embryos) msg_count = len(self.agent.conversation_history) - await send_fn({ - "type": "command_result", - "command": "/resume", - "content": { - "text": f"Session resumed: {session_id}\n {embryo_count} embryos, {msg_count} messages", - }, - }) + await send_fn( + { + "type": "command_result", + "command": "/resume", + "content": { + "text": ( + f"Session resumed: {session_id}\n" + f" {embryo_count} embryos, {msg_count} messages" + ), + }, + } + ) else: - await send_fn({ - "type": "command_result", - "command": "/resume", - "error": f"Session '{session_id}' not found", - }) + await send_fn( + { + "type": "command_result", + "command": "/resume", + "error": f"Session '{session_id}' not found", + } + ) else: # No session ID — list available sessions for the user sessions = self._get_sessions_list() if sessions: lines = ["Available sessions (use /resume ):"] for s in sessions: - lines.append( - f" {s['session_id']} — {s['embryo_count']} embryos" - ) - await send_fn({ - "type": "command_result", - "command": "/resume", - "content": {"text": "\n".join(lines)}, - }) + 
lines.append(f" {s['session_id']} — {s['embryo_count']} embryos") + await send_fn( + { + "type": "command_result", + "command": "/resume", + "content": {"text": "\n".join(lines)}, + } + ) else: - await send_fn({ - "type": "command_result", - "command": "/resume", - "content": {"text": "No saved sessions found."}, - }) + await send_fn( + { + "type": "command_result", + "command": "/resume", + "content": {"text": "No saved sessions found."}, + } + ) return if cmd.startswith("/import-embryos"): @@ -1081,11 +1272,13 @@ async def handle_command( sessions = self._get_sessions_list() sessions_with = [s for s in sessions if s["embryo_count"] > 0] if not sessions_with: - await send_fn({ - "type": "command_result", - "command": "/import-embryos", - "error": "No sessions with embryos found.", - }) + await send_fn( + { + "type": "command_result", + "command": "/import-embryos", + "error": "No sessions with embryos found.", + } + ) return session_id = sessions_with[0]["session_id"] else: @@ -1098,6 +1291,7 @@ async def handle_command( sessions_with = [s for s in sessions if s["embryo_count"] > 0] if sessions_with: import uuid as _uuid + request_id = f"import_embryos_{_uuid.uuid4().hex[:8]}" options = [] for s in sessions_with[:10]: @@ -1110,6 +1304,7 @@ async def handle_command( if last_active: try: from datetime import datetime + dt = datetime.fromisoformat(last_active) time_str = dt.strftime("%b %d %H:%M") except (ValueError, TypeError): @@ -1120,21 +1315,25 @@ async def handle_command( desc = f"Import embryos from session {sid}" if name: desc = name - options.append({ - "id": sid, - "label": label, - "description": desc, - }) - await send_fn({ - "type": "choice_request", - "choice_data": { - "_type": "single", - "question": "Import embryos from which session?", - "options": options, - "allow_multiple": False, - }, - "request_id": request_id, - }) + options.append( + { + "id": sid, + "label": label, + "description": desc, + } + ) + await send_fn( + { + "type": 
"choice_request", + "choice_data": { + "_type": "single", + "question": "Import embryos from which session?", + "options": options, + "allow_multiple": False, + }, + "request_id": request_id, + } + ) # Register a callback so the choice response triggers the import. # We can't await here (would deadlock the REPL loop), so we # store state for _handle_import_choice to pick up. @@ -1143,11 +1342,13 @@ async def handle_command( "send_fn": send_fn, } else: - await send_fn({ - "type": "command_result", - "command": "/import-embryos", - "content": {"text": "No sessions with embryos found."}, - }) + await send_fn( + { + "type": "command_result", + "command": "/import-embryos", + "content": {"text": "No sessions with embryos found."}, + } + ) return if cmd.startswith("/make-video"): @@ -1170,58 +1371,74 @@ async def handle_command( session_id = self.agent.session_id if not session_id: - await send_fn({ - "type": "command_result", - "command": "/make-video", - "error": "No active session", - }) + await send_fn( + { + "type": "command_result", + "command": "/make-video", + "error": "No active session", + } + ) return try: - from gently.app.video_maker import discover_volumes, create_timelapse_video + from gently.app.video_maker import ( + create_timelapse_video, + discover_volumes, + ) + storage_path = self.agent.storage_path session_images_dir = storage_path / "images" / session_id if not session_images_dir.exists(): - await send_fn({ - "type": "command_result", - "command": "/make-video", - "error": f"No images found for session {session_id}", - }) + await send_fn( + { + "type": "command_result", + "command": "/make-video", + "error": f"No images found for session {session_id}", + } + ) return all_volumes = discover_volumes(session_images_dir, embryo_id) if not all_volumes: - await send_fn({ - "type": "command_result", - "command": "/make-video", - "content": {"text": "No timelapse volumes found."}, - }) + await send_fn( + { + "type": "command_result", + "command": 
"/make-video", + "content": {"text": "No timelapse volumes found."}, + } + ) return lines = [f"Creating timelapse videos (fps={fps})..."] for eid, vol_paths in all_volumes.items(): output_path = session_images_dir / f"{eid}_timelapse.mp4" - create_timelapse_video(vol_paths, str(output_path), fps=fps) + create_timelapse_video(vol_paths, output_path, fps=fps) lines.append(f" {eid}: {len(vol_paths)} frames → {output_path.name}") - await send_fn({ - "type": "command_result", - "command": "/make-video", - "content": {"text": "\n".join(lines)}, - }) + await send_fn( + { + "type": "command_result", + "command": "/make-video", + "content": {"text": "\n".join(lines)}, + } + ) except ImportError: - await send_fn({ - "type": "command_result", - "command": "/make-video", - "error": "Video maker module not available.", - }) + await send_fn( + { + "type": "command_result", + "command": "/make-video", + "error": "Video maker module not available.", + } + ) except Exception as e: - await send_fn({ - "type": "command_result", - "command": "/make-video", - "error": str(e), - }) + await send_fn( + { + "type": "command_result", + "command": "/make-video", + "error": str(e), + } + ) return if cmd.startswith("/test-device") or cmd.startswith("/benchmark"): @@ -1273,17 +1490,26 @@ async def _benchmark_progress(stage, current, total, timing): text = " ".join(parts) else: return - await send_fn({ + await send_fn( + { + "type": "command_result", + "command": "/test-device", + "content": {"text": text}, + } + ) + + await send_fn( + { "type": "command_result", "command": "/test-device", - "content": {"text": text}, - }) - - await send_fn({ - "type": "command_result", - "command": "/test-device", - "content": {"text": f"Running device test ({n_volumes} volumes, {n_slices} slices, {n_warmup} warmup)..."}, - }) + "content": { + "text": ( + f"Running device test ({n_volumes} volumes," + f" {n_slices} slices, {n_warmup} warmup)..." 
+ ) + }, + } + ) results = await run_benchmark( self.agent, n_volumes=n_volumes, @@ -1303,67 +1529,114 @@ async def _benchmark_progress(stage, current, total, timing): f" Throughput: {results.fps:.2f} vol/s", "", " Stage Mean Std Min Max", - f" Acquisition {acq['mean']:.3f}s {acq['std']:.3f}s {acq['min']:.3f}s {acq['max']:.3f}s", - f" Storage {stor['mean']:.3f}s {stor['std']:.3f}s {stor['min']:.3f}s {stor['max']:.3f}s", + f" Acquisition {acq['mean']:.3f}s {acq['std']:.3f}s " + f"{acq['min']:.3f}s {acq['max']:.3f}s", + f" Storage {stor['mean']:.3f}s {stor['std']:.3f}s " + f"{stor['min']:.3f}s {stor['max']:.3f}s", ] - if viz['mean'] > 0: + if viz["mean"] > 0: lines.append( - f" Viz push {viz['mean']:.3f}s {viz['std']:.3f}s {viz['min']:.3f}s {viz['max']:.3f}s" + f" Viz push {viz['mean']:.3f}s {viz['std']:.3f}s " + f"{viz['min']:.3f}s {viz['max']:.3f}s" ) - lines.extend([ - f" Total {total['mean']:.3f}s {total['std']:.3f}s {total['min']:.3f}s {total['max']:.3f}s", - ]) + lines.extend( + [ + f" Total {total['mean']:.3f}s {total['std']:.3f}s " + f"{total['min']:.3f}s {total['max']:.3f}s", + ] + ) if results.avg_file_size_mb > 0: lines.append(f" File size: {results.avg_file_size_mb:.1f} MB avg") if results.failed: lines.append(f" Failures: {len(results.failed)}") - await send_fn({ - "type": "command_result", - "command": "/test-device", - "content": {"text": "\n".join(lines)}, - }) + await send_fn( + { + "type": "command_result", + "command": "/test-device", + "content": {"text": "\n".join(lines)}, + } + ) except ImportError as e: logger.error("Benchmark import failed: %s", e, exc_info=True) - await send_fn({ - "type": "command_result", - "command": "/test-device", - "error": f"Benchmark module not available: {e}", - }) + await send_fn( + { + "type": "command_result", + "command": "/test-device", + "error": f"Benchmark module not available: {e}", + } + ) except Exception as e: - await send_fn({ - "type": "command_result", - "command": "/test-device", - "error": str(e), - 
}) + await send_fn( + { + "type": "command_result", + "command": "/test-device", + "error": str(e), + } + ) return # Fallback for truly unimplemented commands - await send_fn({ - "type": "command_result", - "command": cmd, - "content": {"text": f"Command `{cmd}` is not yet available in the TUI."}, - }) + await send_fn( + { + "type": "command_result", + "command": cmd, + "content": {"text": f"Command `{cmd}` is not yet available in the TUI."}, + } + ) def get_commands_json(self) -> list: """Serialize the command registry for the TUI client.""" registry = get_command_registry() commands = [] for cmd in registry.get_all(): - commands.append({ - "name": cmd.name, - "description": cmd.description, - "aliases": cmd.aliases, - "category": cmd.category.name, - "usage": cmd.usage_string(), - "arg_hint": cmd.arg_hint_string(), - "subcommands": [ - {"name": s.name, "description": s.description} - for s in cmd.subcommands - ], - }) + commands.append( + { + "name": cmd.name, + "description": cmd.description, + "aliases": cmd.aliases, + "category": cmd.category.name, + "usage": cmd.usage_string(), + "arg_hint": cmd.arg_hint_string(), + "subcommands": [ + {"name": s.name, "description": s.description} for s in cmd.subcommands + ], + } + ) return commands + def get_tools_json(self) -> list: + """Serialize the agent tool registry for client-side autocomplete. + + Trimmed on purpose (first description line + lightweight param list) so + the connect frame stays small. The web chat uses this for @tool-name + completion and to show a tool's arguments inline. 
+ """ + try: + from gently.harness.tools.registry import get_tool_registry + + registry = get_tool_registry() + except Exception: + return [] + tools = [] + for t in registry.list_all(): + desc = (t.description or "").strip().split("\n", 1)[0][:200] + category = getattr(t.category, "name", None) or str(t.category) + tools.append( + { + "name": t.name, + "description": desc, + "category": category, + "params": [ + {"name": p.name, "type": p.type, "required": bool(p.required)} + for p in t.parameters + if p.name != "context" + ], + } + ) + tools.sort(key=lambda x: x["name"]) + return tools + # ------------------------------------------------------------------ # Private helpers for structured command data # ------------------------------------------------------------------ @@ -1381,7 +1654,7 @@ def _get_status_data(self) -> dict: "has_sam": client.has_sam if client else False, } - def _get_embryos_data(self, embryo_id: str = None) -> dict: + def _get_embryos_data(self, embryo_id: str | None = None) -> dict: """Build structured embryo data.""" exp = self.agent.experiment if embryo_id: @@ -1399,11 +1672,13 @@ def _get_embryos_data(self, embryo_id: str = None) -> dict: embryos = [] for eid, emb in exp.embryos.items(): - embryos.append({ - "id": eid, - "nickname": emb.nickname, - "user_label": emb.user_label, - }) + embryos.append( + { + "id": eid, + "nickname": emb.nickname, + "user_label": emb.user_label, + } + ) return {"embryos": embryos} def _get_tokens_data(self) -> dict: @@ -1426,10 +1701,12 @@ def _get_token_snapshot(self) -> dict: def get_connect_metadata(self) -> dict: """Metadata sent to the TUI on connect.""" import gently + exp = self.agent.experiment meta = { "session_id": self.agent.session_id, "commands": self.get_commands_json(), + "tools": self.get_tools_json(), "version": getattr(gently, "__version__", "dev"), "tokens": self._get_token_snapshot(), "embryo_count": len(exp.embryos), @@ -1507,12 +1784,14 @@ def _get_sessions_list(self) -> list: for s in 
raw: sid = s.get("session_id", "unknown") embryos = self.agent.store.list_embryos(sid) - sessions.append({ - "session_id": sid, - "name": s.get("name", ""), - "embryo_count": len(embryos) if embryos else 0, - "last_active": s.get("last_active", ""), - }) + sessions.append( + { + "session_id": sid, + "name": s.get("name", ""), + "embryo_count": len(embryos) if embryos else 0, + "last_active": s.get("last_active", ""), + } + ) return sessions async def _send_import_result(self, send_fn, result: dict, session_label: str): @@ -1525,17 +1804,21 @@ async def _send_import_result(self, send_fn, result: dict, session_label: str): lines.append(f" {', '.join(imported)}") if skipped: lines.append(f" Skipped (exist): {', '.join(skipped)}") - await send_fn({ - "type": "command_result", - "command": "/import-embryos", - "content": {"text": "\n".join(lines)}, - }) + await send_fn( + { + "type": "command_result", + "command": "/import-embryos", + "content": {"text": "\n".join(lines)}, + } + ) else: - await send_fn({ - "type": "command_result", - "command": "/import-embryos", - "error": result.get("error", "Import failed"), - }) + await send_fn( + { + "type": "command_result", + "command": "/import-embryos", + "error": result.get("error", "Import failed"), + } + ) def _get_timelapse_data(self) -> dict: """Build structured timelapse status.""" @@ -1637,7 +1920,9 @@ def _delete_campaign(self, campaign_id: str) -> dict: if counts["campaigns"] > 0: parts.append(f"{counts['campaigns']} campaign{'s' if counts['campaigns'] != 1 else ''}") if counts["plan_items"] > 0: - parts.append(f"{counts['plan_items']} plan item{'s' if counts['plan_items'] != 1 else ''}") + parts.append( + f"{counts['plan_items']} plan item{'s' if counts['plan_items'] != 1 else ''}" + ) detail = f" ({', '.join(parts)})" if parts else "" return {"text": f"Deleted **{label}**{detail}."} @@ -1692,6 +1977,7 @@ def _pause_campaign(self, campaign_ref: str) -> dict: if not campaign: return {"text": f"Campaign '{campaign_ref}' 
not found."} from .memory.model import Status + cs.update_campaign_status(campaign.id, Status.PAUSED) label = campaign.shorthand or campaign.display_name return {"text": f"Campaign **{label}** paused."} @@ -1705,6 +1991,7 @@ def _resume_campaign(self, campaign_ref: str) -> dict: if not campaign: return {"text": f"Campaign '{campaign_ref}' not found."} from .memory.model import Status + cs.update_campaign_status(campaign.id, Status.ACTIVE) label = campaign.shorthand or campaign.display_name return {"text": f"Campaign **{label}** resumed."} @@ -1758,6 +2045,7 @@ async def _join_campaign(self, hostname: str, campaign_ref: str) -> dict: return {"text": "Peer client not available."} import socket + local_hostname = socket.gethostname() ok = await pc.join_campaign(peer, campaign_ref, mesh.instance_id, local_hostname) @@ -1765,12 +2053,19 @@ async def _join_campaign(self, hostname: str, campaign_ref: str) -> dict: return {"text": f"Failed to join campaign '{campaign_ref}' on {hostname}."} self._active_remote = {"peer": peer, "campaign_id": campaign_ref} - return {"text": f"Joined campaign **{campaign_ref}** on **{hostname}**.\nUse `/claim ` to claim items."} + return { + "text": ( + f"Joined campaign **{campaign_ref}** on **{hostname}**." + "\nUse `/claim ` to claim items." + ) + } async def _claim_item(self, item_id: str) -> dict: """Claim a plan item from the active remote campaign.""" if self._active_remote is None: - return {"text": "No active remote campaign. Use `/join-campaign ` first."} + return { + "text": "No active remote campaign. Use `/join-campaign ` first." 
+ } mesh = self._require_mesh() if mesh is None: @@ -1781,6 +2076,7 @@ async def _claim_item(self, item_id: str) -> dict: return {"text": "Peer client not available."} import socket + local_hostname = socket.gethostname() peer = self._active_remote["peer"] @@ -1788,9 +2084,17 @@ async def _claim_item(self, item_id: str) -> dict: ok = await pc.claim_item(peer, campaign_id, item_id, mesh.instance_id, local_hostname) if not ok: - return {"text": f"Failed to claim item `{item_id}` — it may already be claimed by another node."} + return { + "text": ( + f"Failed to claim item `{item_id}` — it may already be claimed by another node." + ) + } - return {"text": f"Claimed item `{item_id}` from campaign **{campaign_id}** on **{peer.hostname}**."} + return { + "text": ( + f"Claimed item `{item_id}` from campaign **{campaign_id}** on **{peer.hostname}**." + ) + } # ------------------------------------------------------------------ # /pair helpers @@ -1817,7 +2121,10 @@ async def _pair_initiate(self, hostname: str, send_fn) -> dict: return {"text": "Peer client not available."} resp = await pc.send_pair_request( - peer, pm.instance_id, pm.hostname, nonce_local, + peer, + pm.instance_id, + pm.hostname, + nonce_local, cert_fingerprint=pm.cert_fingerprint, udp_sign_key=pm.udp_sign_key, ) @@ -1838,7 +2145,11 @@ async def _pair_initiate(self, hostname: str, send_fn) -> dict: # Create local session and compute PIN session = pm.process_initiation_response( - peer_id, peer_host, nonce_local, nonce_remote, pairing_id, + peer_id, + peer_host, + nonce_local, + nonce_remote, + pairing_id, ) # Store remote peer's TLS cert fingerprint and UDP signing key session.responder_cert_fingerprint = remote_cert_fp @@ -1850,9 +2161,14 @@ async def _pair_initiate(self, hostname: str, send_fn) -> dict: await pc.confirm_pair_remote(peer, pairing_id, pm.instance_id) # Start background polling for confirmation - asyncio.create_task(self._pair_poll( - mesh, peer, pairing_id, send_fn, - )) + 
asyncio.create_task( + self._pair_poll( + mesh, + peer, + pairing_id, + send_fn, + ) + ) return { "text": ( @@ -1880,15 +2196,24 @@ async def _pair_accept(self, send_fn) -> dict: mesh.mark_peer_trusted(session.initiator_id) from gently.core.event_bus import EventType, get_event_bus + get_event_bus().publish( EventType.MESH_PAIRING_COMPLETED, - {"pairing_id": session.pairing_id, "peer_hostname": session.initiator_hostname}, + { + "pairing_id": session.pairing_id, + "peer_hostname": session.initiator_hostname, + }, source="mesh", ) return {"text": f"Paired with **{session.initiator_hostname}**!"} - return {"text": f"Confirmed pairing with **{session.initiator_hostname}**. Waiting for their confirmation..."} + return { + "text": ( + f"Confirmed pairing with **{session.initiator_hostname}**." + " Waiting for their confirmation..." + ) + } async def _pair_reject(self) -> dict: """Reject the most recent pending pairing request.""" @@ -1945,6 +2270,7 @@ def _pair_scopes(self, hostname: str, scope_arg: str) -> dict: scope_str = ", ".join(tp.scopes) if tp.scopes else "none" lines.append(f" **{tp.hostname}** ({tp.instance_id[:8]}): {scope_str}") from gently.mesh.pairing import ALL_SCOPES + lines.append("") lines.append(f"Available scopes: {', '.join(ALL_SCOPES)}") return {"text": "\n".join(lines)} @@ -1961,9 +2287,12 @@ def _pair_scopes(self, hostname: str, scope_arg: str) -> dict: # Set scopes new_scopes = [s.strip() for s in scope_arg.split(",") if s.strip()] from gently.mesh.pairing import ALL_SCOPES + invalid = [s for s in new_scopes if s not in ALL_SCOPES] if invalid: - return {"text": f"Invalid scopes: {', '.join(invalid)}. Available: {', '.join(ALL_SCOPES)}"} + return { + "text": f"Invalid scopes: {', '.join(invalid)}. 
Available: {', '.join(ALL_SCOPES)}" + } if pm.set_scopes(hostname, new_scopes): return {"text": f"Scopes for **{hostname}** updated to: {', '.join(new_scopes)}"} @@ -1982,12 +2311,18 @@ def _pair_unpair(self, identifier: str) -> dict: # Mark the peer as untrusted in the mesh for peer in mesh.get_all_peers(): - if (peer.hostname.lower() == identifier.lower() - or peer.instance_id.startswith(identifier)): + if peer.hostname.lower() == identifier.lower() or peer.instance_id.startswith( + identifier + ): peer.is_trusted = False break - return {"text": f"Unpaired from **{identifier}**. They will need to re-pair to access mesh services."} + return { + "text": ( + f"Unpaired from **{identifier}**." + " They will need to re-pair to access mesh services." + ) + } async def _pair_poll(self, mesh, peer, pairing_id, send_fn): """Background poll: wait for remote to confirm pairing.""" @@ -2013,35 +2348,42 @@ async def _pair_poll(self, mesh, peer, pairing_id, send_fn): mesh.mark_peer_trusted(peer.instance_id) from gently.core.event_bus import EventType, get_event_bus + get_event_bus().publish( EventType.MESH_PAIRING_COMPLETED, {"pairing_id": pairing_id, "peer_hostname": peer.hostname}, source="mesh", ) - await send_fn({ - "type": "notification", - "level": "success", - "title": f"Paired with {peer.hostname}", - }) + await send_fn( + { + "type": "notification", + "level": "success", + "title": f"Paired with {peer.hostname}", + } + ) return if status in ("rejected", "expired"): - await send_fn({ - "type": "notification", - "level": "warning", - "title": f"Pairing {status}", - "body": peer.hostname, - }) + await send_fn( + { + "type": "notification", + "level": "warning", + "title": f"Pairing {status}", + "body": peer.hostname, + } + ) return # Timeout - await send_fn({ - "type": "notification", - "level": "warning", - "title": f"Pairing timed out", - "body": peer.hostname, - }) + await send_fn( + { + "type": "notification", + "level": "warning", + "title": "Pairing timed out", + 
"body": peer.hostname, + } + ) def _get_campaigns_data(self, command: str) -> dict: """Build structured campaign/plan data.""" @@ -2086,7 +2428,9 @@ def _render_campaign_list(self, cs) -> dict: lines.append("") - lines.append(f"Use `/campaign ` for details, or browse at the viz server /campaigns page.") + lines.append( + "Use `/campaign ` for details, or browse at the viz server /campaigns page." + ) return {"text": "\n".join(lines)} def _render_campaign_detail(self, cs, campaign_id: str) -> dict: @@ -2120,7 +2464,10 @@ def _render_campaign_detail(self, cs, campaign_id: str) -> dict: lines.append(f"Target: {campaign.target}") if campaign.progress: lines.append(f"Progress: {campaign.progress}") - lines.append(f"Status: {campaign.status.value} · {completed}/{total} complete · {in_progress} in progress") + lines.append( + f"Status: {campaign.status.value} · {completed}/{total} complete" + f" · {in_progress} in progress" + ) lines.append("") TYPE_ICONS = { diff --git a/gently/harness/commands.py b/gently/harness/commands.py index 7002f2f9..a3fed77b 100644 --- a/gently/harness/commands.py +++ b/gently/harness/commands.py @@ -10,37 +10,39 @@ from dataclasses import dataclass, field from enum import Enum, auto -from typing import Callable, Dict, List, Optional class CommandCategory(Enum): """Categories for organizing commands in help and welcome""" - NAVIGATION = auto() # /quit, /clear, /help - INSPECTION = auto() # /status, /detectors, /embryos, /timelapse, /timeline - SESSION = auto() # /sessions, /resume, /save, /import-embryos - PLANNING = auto() # /plan - APPEARANCE = auto() # /theme, /history, /tokens - DIAGNOSTICS = auto() # /test-device + + NAVIGATION = auto() # /quit, /clear, /help + INSPECTION = auto() # /status, /detectors, /embryos, /timelapse, /timeline + SESSION = auto() # /sessions, /resume, /save, /import-embryos + PLANNING = auto() # /plan + APPEARANCE = auto() # /theme, /history, /tokens + DIAGNOSTICS = auto() # /test-device @dataclass class 
CommandOption: """Definition of a command option/flag""" - name: str # e.g., "--filter" + + name: str # e.g., "--filter" description: str = "" - short: Optional[str] = None # e.g., "-f" - takes_value: bool = False # True if option requires a value - value_choices: List[str] = field(default_factory=list) # Possible values - value_hint: str = "" # e.g., "TYPE" for "--filter TYPE" - is_flag: bool = False # True for boolean flags (no value) + short: str | None = None # e.g., "-f" + takes_value: bool = False # True if option requires a value + value_choices: list[str] = field(default_factory=list) # Possible values + value_hint: str = "" # e.g., "TYPE" for "--filter TYPE" + is_flag: bool = False # True for boolean flags (no value) @dataclass class SubCommand: """Definition of a sub-command""" - name: str # e.g., "watch", "clear" + + name: str # e.g., "watch", "clear" description: str = "" - options: List[CommandOption] = field(default_factory=list) + options: list[CommandOption] = field(default_factory=list) @dataclass @@ -53,21 +55,22 @@ class CommandDefinition: - Help text generation - Welcome message generation """ - name: str # e.g., "/timelapse" (with leading slash) - description: str # Short description for autocomplete - help_text: str = "" # Detailed help (multi-line OK) - aliases: List[str] = field(default_factory=list) # e.g., ["/q", "/exit"] + + name: str # e.g., "/timelapse" (with leading slash) + description: str # Short description for autocomplete + help_text: str = "" # Detailed help (multi-line OK) + aliases: list[str] = field(default_factory=list) # e.g., ["/q", "/exit"] category: CommandCategory = CommandCategory.NAVIGATION # Positional argument - positional_arg: Optional[str] = None # e.g., "embryo_id" - positional_hint: str = "" # e.g., "ID or 'last'" + positional_arg: str | None = None # e.g., "embryo_id" + positional_hint: str = "" # e.g., "ID or 'last'" # Sub-commands (e.g., /timelapse watch) - subcommands: List[SubCommand] = 
field(default_factory=list) + subcommands: list[SubCommand] = field(default_factory=list) # Options/flags (e.g., /timeline --filter) - options: List[CommandOption] = field(default_factory=list) + options: list[CommandOption] = field(default_factory=list) def usage_string(self) -> str: """Generate usage string like '/timeline [clear] [--filter TYPE]'""" @@ -125,8 +128,8 @@ class CommandRegistry: """ def __init__(self): - self._commands: Dict[str, CommandDefinition] = {} - self._aliases: Dict[str, str] = {} # alias -> canonical name + self._commands: dict[str, CommandDefinition] = {} + self._aliases: dict[str, str] = {} # alias -> canonical name def register(self, command: CommandDefinition) -> None: """Register a command definition""" @@ -134,21 +137,21 @@ def register(self, command: CommandDefinition) -> None: for alias in command.aliases: self._aliases[alias] = command.name - def get(self, name: str) -> Optional[CommandDefinition]: + def get(self, name: str) -> CommandDefinition | None: """Get command by name or alias""" name = name.lower() canonical = self._aliases.get(name, name) return self._commands.get(canonical) - def get_all(self) -> List[CommandDefinition]: + def get_all(self) -> list[CommandDefinition]: """Get all registered commands""" return list(self._commands.values()) - def get_by_category(self, category: CommandCategory) -> List[CommandDefinition]: + def get_by_category(self, category: CommandCategory) -> list[CommandDefinition]: """Get commands in a category""" return [c for c in self._commands.values() if c.category == category] - def get_all_names_and_aliases(self) -> List[str]: + def get_all_names_and_aliases(self) -> list[str]: """Get all command names and aliases for autocomplete""" names = list(self._commands.keys()) names.extend(self._aliases.keys()) @@ -161,10 +164,10 @@ def generate_help_markdown(self) -> str: "", "## Natural Language", "Just type what you want! 
Examples:", - "- \"What detectors do we have?\"", - "- \"Add a detector for comma stage\"", - "- \"Test hatching detector on embryo 1\"", - "- \"Start imaging all embryos\"", + '- "What detectors do we have?"', + '- "Add a detector for comma stage"', + '- "Test hatching detector on embryo 1"', + '- "Start imaging all embryos"', "", "## Slash Commands", "", @@ -215,18 +218,20 @@ def generate_help_markdown(self) -> str: lines.append("") # Keyboard shortcuts - lines.extend([ - "## Keyboard Shortcuts", - "- `Tab` - Autocomplete commands/options", - "- `Right Arrow` - Accept shadow suggestion", - "- `Ctrl+C` - Exit", - "- `Ctrl+L` - Clear screen", - "- `Ctrl+R` - Reverse search history", - ]) + lines.extend( + [ + "## Keyboard Shortcuts", + "- `Tab` - Autocomplete commands/options", + "- `Right Arrow` - Accept shadow suggestion", + "- `Ctrl+C` - Exit", + "- `Ctrl+L` - Clear screen", + "- `Ctrl+R` - Reverse search history", + ] + ) return "\n".join(lines) - def generate_command_help(self, name: str) -> Optional[str]: + def generate_command_help(self, name: str) -> str | None: """Generate detailed help for a specific command""" cmd = self.get(name) if not cmd: @@ -278,228 +283,285 @@ def generate_command_help(self, name: str) -> Optional[str]: # Default Commands Registration # ============================================================================ + def _register_default_commands(registry: CommandRegistry) -> None: """Register all built-in commands""" # === Navigation Commands === - registry.register(CommandDefinition( - name="/quit", - description="Exit the agent", - help_text="Exit the interactive agent session.", - aliases=["/exit", "/q"], - category=CommandCategory.NAVIGATION, - )) - - registry.register(CommandDefinition( - name="/clear", - description="Clear screen", - help_text="Clear the terminal screen and show welcome banner.", - category=CommandCategory.NAVIGATION, - )) - - registry.register(CommandDefinition( - name="/help", - description="Show help", - 
help_text="Show help for all commands or a specific command.\n\nUsage:\n- `/help` - Show all commands\n- `/help timeline` - Show detailed help for /timeline", - positional_arg="command", - positional_hint="command", - category=CommandCategory.NAVIGATION, - )) - - # === Inspection Commands === - registry.register(CommandDefinition( - name="/status", - description="Show experiment status", - help_text="Display current experiment status including microscope connection, active embryos, and detector status.", - category=CommandCategory.INSPECTION, - )) - - registry.register(CommandDefinition( - name="/detectors", - description="List all detectors", - help_text="Show a table of all registered detectors with their status, type, and configuration.", - category=CommandCategory.INSPECTION, - )) - - registry.register(CommandDefinition( - name="/embryos", - description="List embryos or show details", - help_text="List all embryos in the current experiment. Provide an embryo ID to see detailed information about a specific embryo.", - positional_arg="embryo_id", - positional_hint="ID", - category=CommandCategory.INSPECTION, - )) - - registry.register(CommandDefinition( - name="/timelapse", - description="Timelapse status [watch]", - help_text="Display timelapse acquisition status for all embryos.\n\nUse 'watch' for live updating countdown view that refreshes every second.", - subcommands=[ - SubCommand( - name="watch", - description="Live countdown mode (Ctrl+C to exit)", + registry.register( + CommandDefinition( + name="/quit", + description="Exit the agent", + help_text="Exit the interactive agent session.", + aliases=["/exit", "/q"], + category=CommandCategory.NAVIGATION, + ) + ) + + registry.register( + CommandDefinition( + name="/clear", + description="Clear screen", + help_text="Clear the terminal screen and show welcome banner.", + category=CommandCategory.NAVIGATION, + ) + ) + + registry.register( + CommandDefinition( + name="/help", + description="Show help", + 
help_text=( + "Show help for all commands or a specific command.\n\nUsage:\n" + "- `/help` - Show all commands\n" + "- `/help timeline` - Show detailed help for /timeline" ), - ], - category=CommandCategory.INSPECTION, - )) - - registry.register(CommandDefinition( - name="/timeline", - description="Event timeline [--filter, clear]", - help_text="""Display timeline of timelapse and detection events. + positional_arg="command", + positional_hint="command", + category=CommandCategory.NAVIGATION, + ) + ) -Shows events from the current session by default. Use --all to see events from all sessions.""", - subcommands=[ - SubCommand( - name="clear", - description="Clear timeline history", - options=[ - CommandOption( - name="--before", - description="Clear events before time", - takes_value=True, - value_hint="TIME", - ), - ], - ), - ], - options=[ - CommandOption( - name="--filter", - description="Filter by event type", - takes_value=True, - value_choices=["timelapse", "detection"], - value_hint="TYPE", - ), - CommandOption( - name="--embryo", - description="Filter by embryo ID", - takes_value=True, - value_hint="ID", - ), - CommandOption( - name="--since", - description="Show events from time period", - takes_value=True, - value_hint="TIME", - ), - CommandOption( - name="--all", - description="Show events from all sessions", - is_flag=True, - ), - CommandOption( - name="--letters", - description="Lettered markers with legend (default)", - is_flag=True, + # === Inspection Commands === + registry.register( + CommandDefinition( + name="/status", + description="Show experiment status", + help_text=( + "Display current experiment status including microscope connection," + " active embryos, and detector status." 
), - CommandOption( - name="--log", - description="Git-log style vertical timeline", - is_flag=True, + category=CommandCategory.INSPECTION, + ) + ) + + registry.register( + CommandDefinition( + name="/detectors", + description="List all detectors", + help_text=( + "Show a table of all registered detectors with their status, type," + " and configuration." ), - CommandOption( - name="--table", - description="Compact table view", - is_flag=True, + category=CommandCategory.INSPECTION, + ) + ) + + registry.register( + CommandDefinition( + name="/embryos", + description="List embryos or show details", + help_text=( + "List all embryos in the current experiment. Provide an embryo ID to see" + " detailed information about a specific embryo." ), - CommandOption( - name="--axis", - description="Simple horizontal axis", - is_flag=True, + positional_arg="embryo_id", + positional_hint="ID", + category=CommandCategory.INSPECTION, + ) + ) + + registry.register( + CommandDefinition( + name="/timelapse", + description="Timelapse status [watch]", + help_text=( + "Display timelapse acquisition status for all embryos.\n\n" + "Use 'watch' for live updating countdown view that refreshes every second." ), - ], - category=CommandCategory.INSPECTION, - )) + subcommands=[ + SubCommand( + name="watch", + description="Live countdown mode (Ctrl+C to exit)", + ), + ], + category=CommandCategory.INSPECTION, + ) + ) + + registry.register( + CommandDefinition( + name="/timeline", + description="Event timeline [--filter, clear]", + help_text="""Display timeline of timelapse and detection events. + +Shows events from the current session by default. 
Use --all to see events from all sessions.""", + subcommands=[ + SubCommand( + name="clear", + description="Clear timeline history", + options=[ + CommandOption( + name="--before", + description="Clear events before time", + takes_value=True, + value_hint="TIME", + ), + ], + ), + ], + options=[ + CommandOption( + name="--filter", + description="Filter by event type", + takes_value=True, + value_choices=["timelapse", "detection"], + value_hint="TYPE", + ), + CommandOption( + name="--embryo", + description="Filter by embryo ID", + takes_value=True, + value_hint="ID", + ), + CommandOption( + name="--since", + description="Show events from time period", + takes_value=True, + value_hint="TIME", + ), + CommandOption( + name="--all", + description="Show events from all sessions", + is_flag=True, + ), + CommandOption( + name="--letters", + description="Lettered markers with legend (default)", + is_flag=True, + ), + CommandOption( + name="--log", + description="Git-log style vertical timeline", + is_flag=True, + ), + CommandOption( + name="--table", + description="Compact table view", + is_flag=True, + ), + CommandOption( + name="--axis", + description="Simple horizontal axis", + is_flag=True, + ), + ], + category=CommandCategory.INSPECTION, + ) + ) # === Session Commands === - registry.register(CommandDefinition( - name="/sessions", - description="Browse saved sessions", - help_text="Open interactive session browser to view and select from saved sessions.", - category=CommandCategory.SESSION, - )) - - registry.register(CommandDefinition( - name="/resume", - description="Resume a session", - help_text="Resume a previously saved session. 
Opens interactive picker if no session ID is provided.", - positional_arg="session_id", - positional_hint="ID", - category=CommandCategory.SESSION, - )) - - registry.register(CommandDefinition( - name="/save", - description="Save current session", - help_text="Save the current session including embryo states and conversation history.", - category=CommandCategory.SESSION, - )) - - registry.register(CommandDefinition( - name="/import-embryos", - description="Import embryos from session", - help_text="""Import embryo definitions from another session into the current session. + registry.register( + CommandDefinition( + name="/sessions", + description="Browse saved sessions", + help_text="Open interactive session browser to view and select from saved sessions.", + category=CommandCategory.SESSION, + ) + ) + + registry.register( + CommandDefinition( + name="/resume", + description="Resume a session", + help_text=( + "Resume a previously saved session." + " Opens interactive picker if no session ID is provided." + ), + positional_arg="session_id", + positional_hint="ID", + category=CommandCategory.SESSION, + ) + ) + + registry.register( + CommandDefinition( + name="/save", + description="Save current session", + help_text="Save the current session including embryo states and conversation history.", + category=CommandCategory.SESSION, + ) + ) + + registry.register( + CommandDefinition( + name="/import-embryos", + description="Import embryos from session", + help_text="""Import embryo definitions from another session into the current session. 
Use 'last' to import from the most recent session with embryos.""", - positional_arg="session_id", - positional_hint="ID|last", - category=CommandCategory.SESSION, - )) + positional_arg="session_id", + positional_hint="ID|last", + category=CommandCategory.SESSION, + ) + ) - registry.register(CommandDefinition( - name="/make-video", - description="Create timelapse video", - help_text="""Generate MP4 video from timelapse volumes in current session. + registry.register( + CommandDefinition( + name="/make-video", + description="Create timelapse video", + help_text="""Generate MP4 video from timelapse volumes in current session. -Creates max projection videos for each embryo. Optionally specify embryo ID to generate video for a single embryo. +Creates max projection videos for each embryo. Optionally specify embryo ID to generate +video for a single embryo. Options: --fps N Frames per second (default: 10) --all Include all embryos""", - positional_arg="embryo_id", - positional_hint="embryo_id", - options=[ - CommandOption( - name="--fps", - description="Frames per second", - takes_value=True, - value_hint="N", - ), - ], - category=CommandCategory.SESSION, - )) + positional_arg="embryo_id", + positional_hint="embryo_id", + options=[ + CommandOption( + name="--fps", + description="Frames per second", + takes_value=True, + value_hint="N", + ), + ], + category=CommandCategory.SESSION, + ) + ) # === Appearance Commands === - registry.register(CommandDefinition( - name="/theme", - description="Switch color theme", - help_text="Change the CLI color theme.\n\nAvailable themes: vibrant, scientific, claude, monochrome", - positional_arg="name", - positional_hint="name", - category=CommandCategory.APPEARANCE, - )) - - registry.register(CommandDefinition( - name="/history", - description="Show conversation history", - help_text="Display recent conversation history with the agent.", - category=CommandCategory.APPEARANCE, - )) - - registry.register(CommandDefinition( - 
name="/tokens", - description="Show API token usage", - help_text="Display token usage statistics and estimated cost for the current session.", - category=CommandCategory.APPEARANCE, - )) + registry.register( + CommandDefinition( + name="/theme", + description="Switch color theme", + help_text=( + "Change the CLI color theme.\n\n" + "Available themes: vibrant, scientific, claude, monochrome" + ), + positional_arg="name", + positional_hint="name", + category=CommandCategory.APPEARANCE, + ) + ) + + registry.register( + CommandDefinition( + name="/history", + description="Show conversation history", + help_text="Display recent conversation history with the agent.", + category=CommandCategory.APPEARANCE, + ) + ) + + registry.register( + CommandDefinition( + name="/tokens", + description="Show API token usage", + help_text="Display token usage statistics and estimated cost for the current session.", + category=CommandCategory.APPEARANCE, + ) + ) # === Diagnostics Commands === - registry.register(CommandDefinition( - name="/test-device", - aliases=["/benchmark"], - description="Test device layer pipeline (acquisition FPS benchmark)", - help_text="""Run end-to-end volume acquisition benchmark. + registry.register( + CommandDefinition( + name="/test-device", + aliases=["/benchmark"], + description="Test device layer pipeline (acquisition FPS benchmark)", + help_text="""Run end-to-end volume acquisition benchmark. 
Measures the full pipeline latency: - Acquisition: HTTP → device layer → hardware → file written @@ -507,42 +569,44 @@ def _register_default_commands(registry: CommandRegistry) -> None: - Viz push: Push to visualization server (if running) Requires microscope connection and at least one registered embryo.""", - options=[ - CommandOption( - name="--volumes", - short="-n", - description="Number of volumes to acquire", - takes_value=True, - value_hint="N", - ), - CommandOption( - name="--slices", - short="-s", - description="Slices per volume", - takes_value=True, - value_hint="N", - ), - CommandOption( - name="--warmup", - short="-w", - description="Warmup volumes (not timed)", - takes_value=True, - value_hint="N", - ), - CommandOption( - name="--save", - description="Save results to CSV", - is_flag=True, - ), - ], - category=CommandCategory.DIAGNOSTICS, - )) + options=[ + CommandOption( + name="--volumes", + short="-n", + description="Number of volumes to acquire", + takes_value=True, + value_hint="N", + ), + CommandOption( + name="--slices", + short="-s", + description="Slices per volume", + takes_value=True, + value_hint="N", + ), + CommandOption( + name="--warmup", + short="-w", + description="Warmup volumes (not timed)", + takes_value=True, + value_hint="N", + ), + CommandOption( + name="--save", + description="Save results to CSV", + is_flag=True, + ), + ], + category=CommandCategory.DIAGNOSTICS, + ) + ) # === Planning Commands === - registry.register(CommandDefinition( - name="/campaign", - description="View or manage campaigns", - help_text="""Browse and manage campaigns and experimental plans. + registry.register( + CommandDefinition( + name="/campaign", + description="View or manage campaigns", + help_text="""Browse and manage campaigns and experimental plans. 
Usage: /campaign List all campaigns with progress summary @@ -552,30 +616,32 @@ def _register_default_commands(registry: CommandRegistry) -> None: /campaign unshare Stop sharing a campaign Use plan mode (/plan) to create and modify campaigns.""", - aliases=["/campaigns"], - positional_arg="campaign_id", - positional_hint="ID", - subcommands=[ - SubCommand( - name="delete", - description="Delete a campaign and its plan items", - ), - SubCommand( - name="share", - description="Share a campaign on the mesh for coordination", - ), - SubCommand( - name="unshare", - description="Stop sharing a campaign on the mesh", - ), - ], - category=CommandCategory.PLANNING, - )) - - registry.register(CommandDefinition( - name="/plan", - description="Switch to plan mode for experimental design", - help_text="""Enter plan mode to design experiments with the agent. + aliases=["/campaigns"], + positional_arg="campaign_id", + positional_hint="ID", + subcommands=[ + SubCommand( + name="delete", + description="Delete a campaign and its plan items", + ), + SubCommand( + name="share", + description="Share a campaign on the mesh for coordination", + ), + SubCommand( + name="unshare", + description="Stop sharing a campaign on the mesh", + ), + ], + category=CommandCategory.PLANNING, + ) + ) + + registry.register( + CommandDefinition( + name="/plan", + description="Switch to plan mode for experimental design", + help_text="""Enter plan mode to design experiments with the agent. 
In plan mode, the agent acts as a scientific collaborator — helping design campaigns, choose strains, set imaging parameters, plan controls, @@ -585,85 +651,103 @@ def _register_default_commands(registry: CommandRegistry) -> None: /plan Enter plan mode (or show status if already in plan mode) /plan status Show current plan progress /plan exit Return to run mode""", - subcommands=[ - SubCommand( - name="status", - description="Show current plan progress", + subcommands=[ + SubCommand( + name="status", + description="Show current plan progress", + ), + SubCommand( + name="exit", + description="Return to run mode", + ), + ], + category=CommandCategory.PLANNING, + ) + ) + + registry.register( + CommandDefinition( + name="/reset-context", + description="Clear the context database (for testing)", + help_text=( + "Wipe all campaigns, learnings, session intents, and other context.\n" + "The startup wizard will run again on next launch." ), - SubCommand( - name="exit", - description="Return to run mode", + category=CommandCategory.DIAGNOSTICS, + ) + ) + + registry.register( + CommandDefinition( + name="/wizard", + description="Run the startup wizard", + help_text=( + "Re-run the onboarding wizard to set organism, campaign, and session intent.\n" + "Useful after /reset-context or to change your current setup." 
), - ], - category=CommandCategory.PLANNING, - )) - - registry.register(CommandDefinition( - name="/reset-context", - description="Clear the context database (for testing)", - help_text="Wipe all campaigns, learnings, session intents, and other context.\nThe startup wizard will run again on next launch.", - category=CommandCategory.DIAGNOSTICS, - )) - - registry.register(CommandDefinition( - name="/wizard", - description="Run the startup wizard", - help_text="Re-run the onboarding wizard to set organism, campaign, and session intent.\nUseful after /reset-context or to change your current setup.", - category=CommandCategory.SESSION, - )) - - registry.register(CommandDefinition( - name="/peers", - description="Show mesh peers on the network", - help_text="""List all Gently instances discovered on the LAN. + category=CommandCategory.SESSION, + ) + ) + + registry.register( + CommandDefinition( + name="/peers", + description="Show mesh peers on the network", + help_text="""List all Gently instances discovered on the LAN. Shows hostname, capabilities (GPU, SAM, microscope), and status for each peer. Usage: /peers List all peers /peers campaigns Show shared campaigns on a peer""", - aliases=["/mesh"], - positional_arg="hostname", - positional_hint="HOSTNAME", - subcommands=[ - SubCommand( - name="campaigns", - description="Show shared campaigns on a peer", - ), - ], - category=CommandCategory.INSPECTION, - )) + aliases=["/mesh"], + positional_arg="hostname", + positional_hint="HOSTNAME", + subcommands=[ + SubCommand( + name="campaigns", + description="Show shared campaigns on a peer", + ), + ], + category=CommandCategory.INSPECTION, + ) + ) # === Mesh coordination commands === - registry.register(CommandDefinition( - name="/join-campaign", - description="Join a shared campaign on a peer", - help_text="""Join a campaign shared by a mesh peer. 
+ registry.register( + CommandDefinition( + name="/join-campaign", + description="Join a shared campaign on a peer", + help_text="""Join a campaign shared by a mesh peer. Usage: /join-campaign After joining, use /claim to claim items for execution.""", - positional_hint="HOSTNAME CAMPAIGN_ID", - category=CommandCategory.PLANNING, - )) + positional_hint="HOSTNAME CAMPAIGN_ID", + category=CommandCategory.PLANNING, + ) + ) - registry.register(CommandDefinition( - name="/claim", - description="Claim a plan item from a shared campaign", - help_text="""Claim a plan item from a joined remote campaign. + registry.register( + CommandDefinition( + name="/claim", + description="Claim a plan item from a shared campaign", + help_text="""Claim a plan item from a joined remote campaign. Usage: /claim Requires an active remote campaign (via /join-campaign).""", - positional_hint="ITEM_ID", - category=CommandCategory.PLANNING, - )) + positional_hint="ITEM_ID", + category=CommandCategory.PLANNING, + ) + ) - registry.register(CommandDefinition( - name="/pair", - description="Pair with a mesh peer for secure communication", - help_text="""Bluetooth-style pairing with mesh peers. + registry.register( + CommandDefinition( + name="/pair", + description="Pair with a mesh peer for secure communication", + help_text="""Bluetooth-style pairing with mesh peers. 
Usage: /pair Initiate pairing with a peer (shows PIN) @@ -673,24 +757,28 @@ def _register_default_commands(registry: CommandRegistry) -> None: /pair unpair Remove trust for a peer (hostname or instance_id) /pair scopes Show scopes for all peers /pair scopes Set scopes for a peer""", - positional_arg="target", - positional_hint="HOSTNAME|accept|reject|list|unpair|scopes", - subcommands=[ - SubCommand(name="accept", description="Accept a pending pairing request"), - SubCommand(name="reject", description="Reject a pending pairing request"), - SubCommand(name="list", description="Show all trusted peers"), - SubCommand(name="unpair", description="Remove trust for a peer"), - SubCommand(name="scopes", description="View or set permission scopes for a peer"), - ], - category=CommandCategory.PLANNING, - )) + positional_arg="target", + positional_hint="HOSTNAME|accept|reject|list|unpair|scopes", + subcommands=[ + SubCommand(name="accept", description="Accept a pending pairing request"), + SubCommand(name="reject", description="Reject a pending pairing request"), + SubCommand(name="list", description="Show all trusted peers"), + SubCommand(name="unpair", description="Remove trust for a peer"), + SubCommand( + name="scopes", + description="View or set permission scopes for a peer", + ), + ], + category=CommandCategory.PLANNING, + ) + ) # ============================================================================ # Global Registry # ============================================================================ -_command_registry: Optional[CommandRegistry] = None +_command_registry: CommandRegistry | None = None def get_command_registry() -> CommandRegistry: diff --git a/gently/harness/conversation.py b/gently/harness/conversation.py index 8384ba9c..e154c724 100644 --- a/gently/harness/conversation.py +++ b/gently/harness/conversation.py @@ -10,13 +10,34 @@ import logging import re import time -from typing import Dict, List, Optional, Any - -from ..settings import settings +from 
typing import Any logger = logging.getLogger(__name__) +def _extend_tool_calls(out: list[dict[str, Any]], content_blocks) -> None: + """Append every tool_use block in content_blocks to out. + + Tolerates absent attributes (some SDK versions / mock objects) so it + never crashes the live agent on a content-shape surprise. + """ + if not content_blocks: + return + for block in content_blocks: + try: + if getattr(block, "type", None) != "tool_use": + continue + out.append( + { + "name": getattr(block, "name", None), + "input": getattr(block, "input", None), + "id": getattr(block, "id", None), + } + ) + except Exception: + continue + + class ConversationManager: """ Manages Claude API conversations, tool execution, and token tracking. @@ -34,7 +55,7 @@ def __init__(self, client, model, tool_registry): self._tool_registry = tool_registry # Conversation state - self.conversation_history: List[Dict] = [] + self.conversation_history: list[dict] = [] # Token counters self.total_input_tokens: int = 0 @@ -48,10 +69,16 @@ def __init__(self, client, model, tool_registry): self.choice_handler = None self.context_store = None # for tool_label + # Decision capture for orchestrator A/B testing. Set by the agent + # alongside the EventCapture once the session folder is known. None + # = no capture, so tests / harnesses without a session still work. + self.decision_log = None + # ===== Quick Response ===== - def try_quick_response(self, message: str, experiment, mode: str, - enter_plan_fn, exit_plan_fn) -> Optional[str]: + def try_quick_response( + self, message: str, experiment, mode: str, enter_plan_fn, exit_plan_fn + ) -> str | None: """ Answer simple queries from state without LLM call. 
@@ -80,7 +107,13 @@ def try_quick_response(self, message: str, experiment, mode: str, return experiment.get_summary() # Plan mode switching via natural language - plan_enter_phrases = ("plan mode", "enter plan", "switch to plan", "let's plan", "design an experiment") + plan_enter_phrases = ( + "plan mode", + "enter plan", + "switch to plan", + "let's plan", + "design an experiment", + ) plan_exit_phrases = ("exit plan", "leave plan", "back to run", "run mode") if mode != "plan" and any(p in message_lower for p in plan_enter_phrases): @@ -116,30 +149,36 @@ def should_use_thinking(self, message: str, mode: str) -> bool: if mode == "plan": return True - import re msg_lower = message.lower() - if re.search(r'\bthink(ing)?\b', message, re.IGNORECASE): + if re.search(r"\bthink(ing)?\b", message, re.IGNORECASE): return True - if re.search(r'\bcalibrat', msg_lower): + if re.search(r"\bcalibrat", msg_lower): return True - if re.search(r'\b(plan|timelapse|time-lapse|acquisition)\b', msg_lower): + if re.search(r"\b(plan|timelapse|time-lapse|acquisition)\b", msg_lower): return True - if re.search(r'\b(analy[sz]e|look at|check|inspect|review).*(image|volume|embryo)', msg_lower): + if re.search( + r"\b(analy[sz]e|look at|check|inspect|review).*(image|volume|embryo)", + msg_lower, + ): return True - if re.search(r'\b(all|every|each)\s+(embryo|sample)', msg_lower): + if re.search(r"\b(all|every|each)\s+(embryo|sample)", msg_lower): return True - if re.search(r'\b(first|then|after|next|finally)\b.*\b(first|then|after|next|finally)\b', msg_lower): + if re.search( + r"\b(first|then|after|next|finally)\b.*\b(first|then|after|next|finally)\b", + msg_lower, + ): return True - if re.search(r'\b(why|problem|issue|error|wrong|fail|debug|troubleshoot)', msg_lower): + if re.search(r"\b(why|problem|issue|error|wrong|fail|debug|troubleshoot)", msg_lower): return True return False # ===== Non-Streaming API Call ===== - async def call_claude(self, user_message: str, system_prompt, tools, - mode: 
str, auto_save_fn) -> str: + async def call_claude( + self, user_message: str, system_prompt, tools, mode: str, auto_save_fn + ) -> str: """ Call Claude API with full context and tool access (non-streaming). @@ -171,10 +210,28 @@ async def call_claude(self, user_message: str, system_prompt, tools, interaction = self.interaction_logger.start_interaction( user_prompt=user_message, system_state={ - 'acquisition_status': 'unknown', - } + "acquisition_status": "unknown", + }, ) + # Snapshot inputs for decision capture BEFORE the tool loop starts + # appending to conversation_history. This is the state shadow + # candidates would need to reproduce production's input — same + # system_prompt and same starting messages. + decision_prompt_hash = None + if self.decision_log is not None: + try: + from gently.eval import prompt_hash as _prompt_hash + + decision_prompt_hash = _prompt_hash( + system_prompt, + list(self.conversation_history), + ) + except Exception: + logger.exception("Failed to compute decision prompt_hash") + + tool_calls_collected: list[dict[str, Any]] = [] + assistant_message = "" error_occurred = None try: @@ -189,47 +246,35 @@ async def call_claude(self, user_message: str, system_prompt, tools, budget = 30000 if mode == "plan" else 10000 api_kwargs["thinking"] = {"type": "enabled", "budget_tokens": budget} - response = await self._call_api_with_retry( - self.claude.messages.create, - **api_kwargs - ) + response = await self._call_api_with_retry(self.claude.messages.create, **api_kwargs) self._track_token_usage(response) + _extend_tool_calls(tool_calls_collected, response.content) # Process tool calls while response.stop_reason == "tool_use": - tool_results = await self._execute_tools_with_logging( - response.content, interaction - ) + tool_results = await self._execute_tools_with_logging(response.content, interaction) - self.conversation_history.append({ - "role": "assistant", - "content": response.content - }) - self.conversation_history.append({ - "role": 
"user", - "content": tool_results - }) + self.conversation_history.append({"role": "assistant", "content": response.content}) + self.conversation_history.append({"role": "user", "content": tool_results}) api_kwargs["messages"] = self.conversation_history response = await self._call_api_with_retry( - self.claude.messages.create, - **api_kwargs + self.claude.messages.create, **api_kwargs ) self._track_token_usage(response) + _extend_tool_calls(tool_calls_collected, response.content) # Extract text response assistant_message = "" for block in response.content: - if hasattr(block, 'text'): + if hasattr(block, "text"): assistant_message += block.text - self.conversation_history.append({ - "role": "assistant", - "content": response.content - }) + self.conversation_history.append({"role": "assistant", "content": response.content}) except Exception as e: import traceback + error_occurred = str(e) error_tb = traceback.format_exc() assistant_message = f"Error: {error_occurred}" @@ -242,6 +287,14 @@ async def call_claude(self, user_message: str, system_prompt, tools, error=error_occurred, error_traceback=error_tb, ) + self._write_production_decision( + user_message=user_message, + tool_calls=tool_calls_collected, + response_text=assistant_message, + duration_ms=(time.time() - start_time) * 1000.0, + prompt_hash_value=decision_prompt_hash, + error=error_occurred, + ) raise if interaction and self.interaction_logger: @@ -251,13 +304,61 @@ async def call_claude(self, user_message: str, system_prompt, tools, total_duration_seconds=time.time() - start_time, ) + self._write_production_decision( + user_message=user_message, + tool_calls=tool_calls_collected, + response_text=assistant_message, + duration_ms=(time.time() - start_time) * 1000.0, + prompt_hash_value=decision_prompt_hash, + error=None, + ) + auto_save_fn() return assistant_message + def _write_production_decision( + self, + *, + user_message: str, + tool_calls: list[dict[str, Any]], + response_text: str, + duration_ms: 
float, + prompt_hash_value: str | None, + error: str | None, + ) -> None: + """Persist one production Decision row (best-effort). + + Failures here are swallowed — decision capture must never break + the live agent. The DecisionLog itself is also tolerant of + serialisation errors. + """ + if self.decision_log is None: + return + try: + from datetime import datetime + + from gently.eval import Decision, DecisionTrigger + + self.decision_log.append( + Decision( + timestamp=datetime.now(), + agent="production", + trigger=DecisionTrigger.USER_MESSAGE, + trigger_detail=(user_message or "")[:200], + tool_calls=tool_calls, + response_text=response_text, + prompt_hash=prompt_hash_value, + duration_ms=duration_ms, + error=error, + ) + ) + except Exception: + logger.exception("Failed to write production Decision") + # ===== Dry-Run Tool Call (Benchmarking) ===== - async def get_tool_call(self, user_message: str, system_prompt, tools) -> Optional[Dict]: + async def get_tool_call(self, user_message: str, system_prompt, tools) -> dict | None: """ Get what tool Claude would call without executing it (dry-run mode). 
@@ -280,10 +381,7 @@ async def get_tool_call(self, user_message: str, system_prompt, tools) -> Option start_time = time.time() messages = self.conversation_history.copy() - messages.append({ - "role": "user", - "content": user_message - }) + messages.append({"role": "user", "content": user_message}) try: api_kwargs = { @@ -294,15 +392,12 @@ async def get_tool_call(self, user_message: str, system_prompt, tools) -> Option "max_tokens": 4096, } - response = await self._call_api_with_retry( - self.claude.messages.create, - **api_kwargs - ) + response = await self._call_api_with_retry(self.claude.messages.create, **api_kwargs) latency_ms = (time.time() - start_time) * 1000 - input_tokens = getattr(response.usage, 'input_tokens', 0) - output_tokens = getattr(response.usage, 'output_tokens', 0) + input_tokens = getattr(response.usage, "input_tokens", 0) + output_tokens = getattr(response.usage, "output_tokens", 0) for block in response.content: if block.type == "tool_use": @@ -322,7 +417,7 @@ async def get_tool_call(self, user_message: str, system_prompt, tools) -> Option # ===== Tool Execution ===== - async def _execute_tools_with_logging(self, content_blocks, interaction) -> List[Dict]: + async def _execute_tools_with_logging(self, content_blocks, interaction) -> list[dict]: """ Execute Claude's tool calls with interaction logging. 
@@ -354,7 +449,10 @@ async def _execute_tools_with_logging(self, content_blocks, interaction) -> List if self.choice_handler and isinstance(result, str): try: choice_data = json.loads(result) - if isinstance(choice_data, dict) and choice_data.get("_type") == CHOICE_RESPONSE_TYPE: + if ( + isinstance(choice_data, dict) + and choice_data.get("_type") == CHOICE_RESPONSE_TYPE + ): user_selection = await self.choice_handler(choice_data) result = user_selection except (json.JSONDecodeError, TypeError): @@ -378,19 +476,20 @@ async def _execute_tools_with_logging(self, content_blocks, interaction) -> List error_message=error_message, ) - results.append({ - "type": "tool_result", - "tool_use_id": block.id, - "content": result, - "is_error": is_error, - }) + results.append( + { + "type": "tool_result", + "tool_use_id": block.id, + "content": result, + "is_error": is_error, + } + ) return results # ===== Streaming API Call ===== - async def call_claude_stream(self, system_prompt, tools, - tool_label_fn, auto_save_fn): + async def call_claude_stream(self, system_prompt, tools, tool_label_fn, auto_save_fn): """ Call Claude API with streaming enabled. 
@@ -421,7 +520,7 @@ def stream_and_collect(): system=system_prompt, messages=self.conversation_history, tools=tools, - max_tokens=4096 + max_tokens=4096, ) as stream: for event in stream: events.append(event) @@ -439,15 +538,24 @@ def stream_and_collect(): self._track_token_usage(final_message) break except APIStatusError as e: - error_type = getattr(e, 'body', {}) + error_type = getattr(e, "body", {}) if isinstance(error_type, dict): - error_type = error_type.get('error', {}).get('type', '') + error_type = error_type.get("error", {}).get("type", "") - if error_type in ('overloaded_error', 'rate_limit_error') or 'overloaded' in str(e).lower(): + if ( + error_type in ("overloaded_error", "rate_limit_error") + or "overloaded" in str(e).lower() + ): if attempt < max_retries - 1: - wait_time = retry_delay * (2 ** attempt) - logger.warning(f"API overloaded, retrying in {wait_time:.1f}s (attempt {attempt + 1}/{max_retries})") - yield {'type': 'text', 'text': f"\n*[API busy, retrying in {wait_time:.0f}s...]*\n"} + wait_time = retry_delay * (2**attempt) + logger.warning( + f"API overloaded, retrying in {wait_time:.1f}s" + f" (attempt {attempt + 1}/{max_retries})" + ) + yield { + "type": "text", + "text": f"\n*[API busy, retrying in {wait_time:.0f}s...]*\n", + } await asyncio.sleep(wait_time) continue raise @@ -456,27 +564,31 @@ def stream_and_collect(): # Diagnostic: log stop_reason and tool block counts tool_block_count = sum( - 1 for b in final_message.content - if hasattr(b, 'type') and b.type == 'tool_use' + 1 for b in final_message.content if hasattr(b, "type") and b.type == "tool_use" ) logger.warning( - "Claude response: stop_reason=%s, content_blocks=%d, tool_use_blocks=%d, tools_passed=%d, model=%s", - final_message.stop_reason, len(final_message.content), - tool_block_count, len(tools), self.model, + "Claude response: stop_reason=%s, content_blocks=%d, tool_use_blocks=%d," + " tools_passed=%d, model=%s", + final_message.stop_reason, + len(final_message.content), 
+ tool_block_count, + len(tools), + self.model, ) if tool_block_count > 0 and final_message.stop_reason != "tool_use": logger.error( "BUG: Claude returned %d tool_use blocks but stop_reason=%s (expected 'tool_use')", - tool_block_count, final_message.stop_reason, + tool_block_count, + final_message.stop_reason, ) # Process events and yield text full_text = [] for event in events: if event.type == "content_block_delta": - if hasattr(event.delta, 'text'): + if hasattr(event.delta, "text"): full_text.append(event.delta.text) - yield {'type': 'text', 'text': event.delta.text} + yield {"type": "text", "text": event.delta.text} # Detect fake XML tool calls in text (Claude writing tool_use as text) joined_text = "".join(full_text) @@ -484,7 +596,8 @@ def stream_and_collect(): logger.error( "DETECTED: Claude wrote XML tool tags as plain text instead of " "using API tool_use mechanism. stop_reason=%s, text_preview=%.200s", - final_message.stop_reason, joined_text[:200], + final_message.stop_reason, + joined_text[:200], ) response_content = final_message.content @@ -498,60 +611,83 @@ def stream_and_collect(): tool_results = [] for block in response_content: - if hasattr(block, 'type') and block.type == "tool_use": + if hasattr(block, "type") and block.type == "tool_use": start_time = time.time() yield { - 'type': 'tool_start', - 'tool_name': block.name, - 'tool_input': block.input, - 'tool_label': tool_label_fn(block.name, block.input), + "type": "tool_start", + "tool_name": block.name, + "tool_input": block.input, + "tool_label": tool_label_fn(block.name, block.input), } + is_error_flag = False + result_text = "" try: tool_result = await self._execute_single_tool(block.name, block.input) if isinstance(tool_result, str): try: - from gently.app.tools.interaction_tools import CHOICE_RESPONSE_TYPE + from gently.app.tools.interaction_tools import ( + CHOICE_RESPONSE_TYPE, + ) + choice_data = json.loads(tool_result) - if isinstance(choice_data, dict) and choice_data.get("_type") 
== CHOICE_RESPONSE_TYPE: + if ( + isinstance(choice_data, dict) + and choice_data.get("_type") == CHOICE_RESPONSE_TYPE + ): user_selection = yield { - 'type': 'choice_request', - 'choice_data': choice_data + "type": "choice_request", + "choice_data": choice_data, } tool_result = user_selection or "cancelled" except (json.JSONDecodeError, TypeError): pass - tool_results.append({ - "type": "tool_result", - "tool_use_id": block.id, - "content": tool_result - }) + result_text = ( + tool_result if isinstance(tool_result, str) else str(tool_result) + ) + tool_results.append( + { + "type": "tool_result", + "tool_use_id": block.id, + "content": tool_result, + } + ) except Exception as e: - tool_results.append({ - "type": "tool_result", - "tool_use_id": block.id, - "content": f"Error: {str(e)}", - "is_error": True - }) + is_error_flag = True + result_text = f"Error: {str(e)}" + tool_results.append( + { + "type": "tool_result", + "tool_use_id": block.id, + "content": result_text, + "is_error": True, + } + ) + + # First non-empty line of the result, trimmed — gives the chat + # UI a one-line summary so the operator can see what a tool did + # (or didn't do), not just that it ran. 
+ result_summary = next( + (ln.strip() for ln in (result_text or "").splitlines() if ln.strip()), + "", + ) + if len(result_summary) > 140: + result_summary = result_summary[:139] + "…" yield { - 'type': 'tool_call', - 'tool_name': block.name, - 'tool_input': block.input, - 'duration': time.time() - start_time, + "type": "tool_call", + "tool_name": block.name, + "tool_input": block.input, + "duration": time.time() - start_time, + "result_summary": result_summary, + "is_error": is_error_flag, } - self.conversation_history.append({ - "role": "assistant", - "content": response_content - }) - self.conversation_history.append({ - "role": "user", - "content": tool_results - }) + self.conversation_history.append({"role": "assistant", "content": response_content}) + self.conversation_history.append({"role": "user", "content": tool_results}) auto_save_fn() @@ -573,15 +709,12 @@ def stream_and_collect(): else: # No tool calls - add final message to history - self.conversation_history.append({ - "role": "assistant", - "content": response_content - }) + self.conversation_history.append({"role": "assistant", "content": response_content}) auto_save_fn() # ===== Tool Label ===== - def tool_label(self, tool_name: str, tool_input: Dict) -> str: + def tool_label(self, tool_name: str, tool_input: dict) -> str: """Build a human-readable label for a tool call. 
Used in tool_start chunks so the TUI shows biologist-friendly @@ -595,8 +728,13 @@ def tool_label(self, tool_name: str, tool_input: Dict) -> str: campaign = self.context_store.get_campaign(campaign_id) if campaign: campaign_label = campaign.shorthand or campaign.description - if tool_name in ("propose_plan", "get_plan_status", "export_plan", - "snapshot_plan", "list_plan_versions"): + if tool_name in ( + "propose_plan", + "get_plan_status", + "export_plan", + "snapshot_plan", + "list_plan_versions", + ): return campaign_label if tool_name == "create_campaign" and inp.get("parent_id"): return f"phase under {campaign_label}" @@ -614,8 +752,12 @@ def tool_label(self, tool_name: str, tool_input: Dict) -> str: # Item reference tools item_ref = inp.get("item_ref") or inp.get("ref") or inp.get("item_id") - if item_ref and tool_name in ("get_plan_item", "update_plan_item", - "delete_plan_item", "move_plan_item"): + if item_ref and tool_name in ( + "get_plan_item", + "update_plan_item", + "delete_plan_item", + "move_plan_item", + ): if self.context_store: item = self.context_store.resolve_plan_item(str(item_ref), campaign_id=campaign_id) if item: @@ -642,8 +784,9 @@ def tool_label(self, tool_name: str, tool_input: Dict) -> str: return "" - async def _execute_single_tool(self, tool_name: str, tool_input: Dict, - context: Optional[Dict] = None) -> str: + async def _execute_single_tool( + self, tool_name: str, tool_input: dict, context: dict | None = None + ) -> str: """Execute a single tool call using the tool registry. 
Parameters @@ -663,13 +806,13 @@ async def _execute_single_tool(self, tool_name: str, tool_input: Dict, def _track_token_usage(self, response): """Track token usage from API response, including cache metrics.""" - if hasattr(response, 'usage'): + if hasattr(response, "usage"): usage = response.usage self.total_input_tokens += usage.input_tokens self.total_output_tokens += usage.output_tokens self.api_call_count += 1 - self.cache_creation_tokens += getattr(usage, 'cache_creation_input_tokens', 0) - self.cache_read_tokens += getattr(usage, 'cache_read_input_tokens', 0) + self.cache_creation_tokens += getattr(usage, "cache_creation_input_tokens", 0) + self.cache_read_tokens += getattr(usage, "cache_read_input_tokens", 0) @property def current_context_tokens(self) -> int: @@ -680,14 +823,14 @@ def current_context_tokens(self) -> int: conv_chars = 0 for msg in self.conversation_history: - content = msg.get('content', '') + content = msg.get("content", "") if isinstance(content, str): conv_chars += len(content) elif isinstance(content, list): for block in content: if isinstance(block, dict): - conv_chars += len(str(block.get('text', ''))) - elif hasattr(block, 'text'): + conv_chars += len(str(block.get("text", ""))) + elif hasattr(block, "text"): conv_chars += len(str(block.text)) else: conv_chars += len(str(block)) @@ -753,19 +896,22 @@ async def _call_api_with_retry(self, api_func, *args, max_retries=3, **kwargs): try: return await asyncio.to_thread(api_func, *args, **kwargs) except APIStatusError as e: - error_type = getattr(e, 'body', {}) + error_type = getattr(e, "body", {}) if isinstance(error_type, dict): - error_type = error_type.get('error', {}).get('type', '') + error_type = error_type.get("error", {}).get("type", "") is_retryable = ( - error_type in ('overloaded_error', 'rate_limit_error') or - 'overloaded' in str(e).lower() or - 'rate_limit' in str(e).lower() + error_type in ("overloaded_error", "rate_limit_error") + or "overloaded" in str(e).lower() + or 
"rate_limit" in str(e).lower() ) if is_retryable and attempt < max_retries - 1: - wait_time = retry_delay * (2 ** attempt) - logger.warning(f"API error ({error_type}), retrying in {wait_time:.1f}s (attempt {attempt + 1}/{max_retries})") + wait_time = retry_delay * (2**attempt) + logger.warning( + f"API error ({error_type}), retrying in {wait_time:.1f}s" + f" (attempt {attempt + 1}/{max_retries})" + ) await asyncio.sleep(wait_time) continue diff --git a/gently/harness/detection/detector.py b/gently/harness/detection/detector.py index 8cb5fa8b..d5729681 100644 --- a/gently/harness/detection/detector.py +++ b/gently/harness/detection/detector.py @@ -2,21 +2,23 @@ Generic detector system for runtime-configurable event detection """ -from dataclasses import dataclass, field, asdict +from dataclasses import asdict, dataclass, field from datetime import datetime -from typing import Dict, List, Optional, Any from enum import Enum +from typing import Any class DetectionMode(str, Enum): """Action mode when detector fires""" - PASSIVE = "passive" # Just flag, no action - RECOMMEND = "recommend" # Suggest actions to user - AUTO = "auto" # Execute actions automatically + + PASSIVE = "passive" # Just flag, no action + RECOMMEND = "recommend" # Suggest actions to user + AUTO = "auto" # Execute actions automatically class ConfidenceLevel(str, Enum): """Confidence levels for detections""" + LOW = "LOW" MEDIUM = "MEDIUM" HIGH = "HIGH" @@ -25,15 +27,20 @@ class ConfidenceLevel(str, Enum): @dataclass class DetectorConditions: """Conditions for when to run a detector""" - min_timepoint: Optional[int] = None # Don't run before this timepoint - max_timepoint: Optional[int] = None # Don't run after this timepoint - embryo_ids: Optional[List[str]] = None # Only run on these embryos (None = all) - run_if_detected: bool = True # Continue running after first detection? 
- min_interval_timepoints: int = 1 # Minimum timepoints between runs - - def should_run(self, embryo_id: str, timepoint: int, - last_run_timepoint: Optional[int], - already_detected: bool) -> bool: + + min_timepoint: int | None = None # Don't run before this timepoint + max_timepoint: int | None = None # Don't run after this timepoint + embryo_ids: list[str] | None = None # Only run on these embryos (None = all) + run_if_detected: bool = True # Continue running after first detection? + min_interval_timepoints: int = 1 # Minimum timepoints between runs + + def should_run( + self, + embryo_id: str, + timepoint: int, + last_run_timepoint: int | None, + already_detected: bool, + ) -> bool: """ Check if detector should run @@ -78,11 +85,12 @@ def should_run(self, embryo_id: str, timepoint: int, @dataclass class DetectorActions: """Actions to take when detector fires""" + mode: DetectionMode = DetectionMode.RECOMMEND - parameter_changes: Optional[Dict[str, Any]] = None # e.g., {"interval_seconds": 60} - stop_timelapse: bool = False # Stop timelapse when detected - custom_message: Optional[str] = None # Custom notification - webhook_url: Optional[str] = None # External notification + parameter_changes: dict[str, Any] | None = None # e.g., {"interval_seconds": 60} + stop_timelapse: bool = False # Stop timelapse when detected + custom_message: str | None = None # Custom notification + webhook_url: str | None = None # External notification def get_recommendation_message(self, detector_name: str, embryo_id: str) -> str: """Generate recommendation message for user""" @@ -103,25 +111,26 @@ def get_recommendation_message(self, detector_name: str, embryo_id: str) -> str: @dataclass class DetectionResult: """Result of a single detection attempt""" + detector_name: str embryo_id: str timepoint: int timestamp: datetime detected: bool - confidence: Optional[ConfidenceLevel] = None - reasoning: Optional[str] = None + confidence: ConfidenceLevel | None = None + reasoning: str | None = 
None error: bool = False - error_message: Optional[str] = None - api_duration: Optional[float] = None # seconds + error_message: str | None = None + api_duration: float | None = None # seconds num_images: int = 1 - full_response: Optional[str] = None + full_response: str | None = None - def to_dict(self) -> Dict: + def to_dict(self) -> dict: """Convert to dictionary""" d = asdict(self) - d['timestamp'] = self.timestamp.isoformat() + d["timestamp"] = self.timestamp.isoformat() if self.confidence: - d['confidence'] = self.confidence.value + d["confidence"] = self.confidence.value return d @@ -133,22 +142,23 @@ class Detector: A detector analyzes volumes using Claude Vision API to detect specific events or states (e.g., "comma stage", "hatching", "neural activity"). """ - name: str # Unique identifier (e.g., "comma_stage") - description: str # Human-readable description - detection_prompt: str # Claude Vision API prompt - enabled: bool = True # Can be toggled on/off + + name: str # Unique identifier (e.g., "comma_stage") + description: str # Human-readable description + detection_prompt: str # Claude Vision API prompt + enabled: bool = True # Can be toggled on/off conditions: DetectorConditions = field(default_factory=DetectorConditions) actions: DetectorActions = field(default_factory=DetectorActions) confidence_threshold: ConfidenceLevel = ConfidenceLevel.MEDIUM - use_temporal_context: bool = True # Include recent images? - temporal_context_size: int = 5 # How many recent images + use_temporal_context: bool = True # Include recent images? 
+ temporal_context_size: int = 5 # How many recent images created_at: datetime = field(default_factory=datetime.now) modified_at: datetime = field(default_factory=datetime.now) - detection_count: int = 0 # Total detections fired - run_count: int = 0 # Total times run + detection_count: int = 0 # Total detections fired + run_count: int = 0 # Total times run # Tracking per-embryo state - _last_run_timepoint: Dict[str, int] = field(default_factory=dict, repr=False) + _last_run_timepoint: dict[str, int] = field(default_factory=dict, repr=False) _detected_embryos: set = field(default_factory=set, repr=False) def should_run(self, embryo_id: str, timepoint: int) -> bool: @@ -173,9 +183,7 @@ def should_run(self, embryo_id: str, timepoint: int) -> bool: last_run = self._last_run_timepoint.get(embryo_id) already_detected = embryo_id in self._detected_embryos - return self.conditions.should_run( - embryo_id, timepoint, last_run, already_detected - ) + return self.conditions.should_run(embryo_id, timepoint, last_run, already_detected) def mark_run(self, embryo_id: str, timepoint: int): """Mark that detector ran for this embryo/timepoint""" @@ -191,8 +199,9 @@ def was_detected(self, embryo_id: str) -> bool: """Check if detector already fired for this embryo""" return embryo_id in self._detected_embryos - def build_detection_content(self, images: List[Dict], - embryo_id: str, timepoint: int) -> List[Dict]: + def build_detection_content( + self, images: list[dict], embryo_id: str, timepoint: int + ) -> list[dict]: """ Build Claude Vision API content array @@ -213,57 +222,59 @@ def build_detection_content(self, images: List[Dict], content = [] # Add instruction - content.append({ - "type": "text", - "text": f"Analyzing {embryo_id} at timepoint {timepoint}" - }) + content.append({"type": "text", "text": f"Analyzing {embryo_id} at timepoint {timepoint}"}) # Add temporal context if enabled if self.use_temporal_context and len(images) > 1: - content.append({ - "type": "text", - 
"text": f"Recent images (for temporal context, {len(images)} timepoints):" - }) + content.append( + { + "type": "text", + "text": f"Recent images (for temporal context, {len(images)} timepoints):", + } + ) # Add older images first for img_data in images[:-1]: - content.append({ - "type": "text", - "text": f"Timepoint {img_data['timepoint']:04d}" - }) - content.append({ - "type": "image", - "source": { - "type": "base64", - "media_type": "image/jpeg", - "data": img_data['b64_image'] + content.append({"type": "text", "text": f"Timepoint {img_data['timepoint']:04d}"}) + content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": img_data["b64_image"], + }, } - }) + ) # Add current/latest image latest = images[-1] - content.append({ - "type": "text", - "text": f"Current image (timepoint {latest['timepoint']:04d}) - FOCUS YOUR ANALYSIS HERE:" - }) - content.append({ - "type": "image", - "source": { - "type": "base64", - "media_type": "image/jpeg", - "data": latest['b64_image'] + content.append( + { + "type": "text", + "text": ( + f"Current image (timepoint {latest['timepoint']:04d})" + " - FOCUS YOUR ANALYSIS HERE:" + ), + } + ) + content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": latest["b64_image"], + }, } - }) + ) # Add detection prompt - content.append({ - "type": "text", - "text": self.detection_prompt - }) + content.append({"type": "text", "text": self.detection_prompt}) return content - def parse_detection_response(self, response_text: str) -> Dict: + def parse_detection_response(self, response_text: str) -> dict: """ Parse Claude's detection response @@ -286,98 +297,102 @@ def parse_detection_response(self, response_text: str) -> Dict: confidence = None reasoning = None - lines = response_text.strip().split('\n') + lines = response_text.strip().split("\n") for line in lines: line = line.strip() - if line.startswith('DETECTED:'): - value = 
line.split(':', 1)[1].strip().upper() - detected = value in ['YES', 'TRUE', '1'] - elif line.startswith('CONFIDENCE:'): - conf_str = line.split(':', 1)[1].strip().upper() + if line.startswith("DETECTED:"): + value = line.split(":", 1)[1].strip().upper() + detected = value in ["YES", "TRUE", "1"] + elif line.startswith("CONFIDENCE:"): + conf_str = line.split(":", 1)[1].strip().upper() try: confidence = ConfidenceLevel(conf_str) except ValueError: confidence = None - elif line.startswith('REASONING:'): - reasoning = line.split(':', 1)[1].strip() + elif line.startswith("REASONING:"): + reasoning = line.split(":", 1)[1].strip() # If multiline reasoning, capture it if reasoning is None: reasoning_start = False reasoning_lines = [] for line in lines: - if line.startswith('REASONING:'): + if line.startswith("REASONING:"): reasoning_start = True - reasoning_lines.append(line.split(':', 1)[1].strip()) + reasoning_lines.append(line.split(":", 1)[1].strip()) elif reasoning_start and line: reasoning_lines.append(line) if reasoning_lines: - reasoning = ' '.join(reasoning_lines) + reasoning = " ".join(reasoning_lines) return { - 'detected': detected if detected is not None else False, - 'confidence': confidence, - 'reasoning': reasoning or "No reasoning provided" + "detected": detected if detected is not None else False, + "confidence": confidence, + "reasoning": reasoning or "No reasoning provided", } - def to_dict(self) -> Dict: + def to_dict(self) -> dict: """Convert to dictionary for serialization""" return { - 'name': self.name, - 'description': self.description, - 'detection_prompt': self.detection_prompt, - 'enabled': self.enabled, - 'conditions': asdict(self.conditions), - 'actions': { - 'mode': self.actions.mode.value, - 'parameter_changes': self.actions.parameter_changes, - 'custom_message': self.actions.custom_message, - 'webhook_url': self.actions.webhook_url, + "name": self.name, + "description": self.description, + "detection_prompt": self.detection_prompt, + 
"enabled": self.enabled, + "conditions": asdict(self.conditions), + "actions": { + "mode": self.actions.mode.value, + "parameter_changes": self.actions.parameter_changes, + "custom_message": self.actions.custom_message, + "webhook_url": self.actions.webhook_url, }, - 'confidence_threshold': self.confidence_threshold.value, - 'use_temporal_context': self.use_temporal_context, - 'temporal_context_size': self.temporal_context_size, - 'created_at': self.created_at.isoformat(), - 'modified_at': self.modified_at.isoformat(), - 'detection_count': self.detection_count, - 'run_count': self.run_count, + "confidence_threshold": self.confidence_threshold.value, + "use_temporal_context": self.use_temporal_context, + "temporal_context_size": self.temporal_context_size, + "created_at": self.created_at.isoformat(), + "modified_at": self.modified_at.isoformat(), + "detection_count": self.detection_count, + "run_count": self.run_count, } @classmethod - def from_dict(cls, data: Dict) -> 'Detector': + def from_dict(cls, data: dict) -> "Detector": """Create detector from dictionary""" # Parse dates - created_at = datetime.fromisoformat(data['created_at']) if 'created_at' in data else datetime.now() - modified_at = datetime.fromisoformat(data['modified_at']) if 'modified_at' in data else datetime.now() + created_at = ( + datetime.fromisoformat(data["created_at"]) if "created_at" in data else datetime.now() + ) + modified_at = ( + datetime.fromisoformat(data["modified_at"]) if "modified_at" in data else datetime.now() + ) # Parse conditions - conditions = DetectorConditions(**data.get('conditions', {})) + conditions = DetectorConditions(**data.get("conditions", {})) # Parse actions - actions_data = data.get('actions', {}) + actions_data = data.get("actions", {}) actions = DetectorActions( - mode=DetectionMode(actions_data.get('mode', 'recommend')), - parameter_changes=actions_data.get('parameter_changes'), - custom_message=actions_data.get('custom_message'), - 
webhook_url=actions_data.get('webhook_url'), + mode=DetectionMode(actions_data.get("mode", "recommend")), + parameter_changes=actions_data.get("parameter_changes"), + custom_message=actions_data.get("custom_message"), + webhook_url=actions_data.get("webhook_url"), ) # Parse confidence threshold - confidence_threshold = ConfidenceLevel(data.get('confidence_threshold', 'MEDIUM')) + confidence_threshold = ConfidenceLevel(data.get("confidence_threshold", "MEDIUM")) return cls( - name=data['name'], - description=data['description'], - detection_prompt=data['detection_prompt'], - enabled=data.get('enabled', True), + name=data["name"], + description=data["description"], + detection_prompt=data["detection_prompt"], + enabled=data.get("enabled", True), conditions=conditions, actions=actions, confidence_threshold=confidence_threshold, - use_temporal_context=data.get('use_temporal_context', True), - temporal_context_size=data.get('temporal_context_size', 5), + use_temporal_context=data.get("use_temporal_context", True), + temporal_context_size=data.get("temporal_context_size", 5), created_at=created_at, modified_at=modified_at, - detection_count=data.get('detection_count', 0), - run_count=data.get('run_count', 0), + detection_count=data.get("detection_count", 0), + run_count=data.get("run_count", 0), ) diff --git a/gently/harness/detection/queue.py b/gently/harness/detection/queue.py index e0f0d2c1..f0d7ee4c 100644 --- a/gently/harness/detection/queue.py +++ b/gently/harness/detection/queue.py @@ -3,14 +3,16 @@ """ import asyncio -from typing import List, Dict, Optional, Callable +from collections.abc import Callable from datetime import datetime + import anthropic from gently.settings import settings -from .detector import Detector, DetectionResult, ConfidenceLevel -from .registry import DetectorRegistry + from ..state import EmbryoState +from .detector import ConfidenceLevel, DetectionResult, Detector +from .registry import DetectorRegistry class DetectionQueue: @@ -25,8 
+27,8 @@ def __init__( registry: DetectorRegistry, claude_client: anthropic.Anthropic, model: str = settings.models.perception, - on_detection_callback: Optional[Callable] = None, - on_evaluation_callback: Optional[Callable] = None + on_detection_callback: Callable | None = None, + on_evaluation_callback: Callable | None = None, ): """ Parameters @@ -49,10 +51,8 @@ def __init__( self.on_evaluation_callback = on_evaluation_callback async def run_detectors( - self, - embryo_state: EmbryoState, - timepoint: int - ) -> List[DetectionResult]: + self, embryo_state: EmbryoState, timepoint: int + ) -> list[DetectionResult]: """ Run all applicable detectors for an embryo/timepoint @@ -80,9 +80,7 @@ async def run_detectors( continue # Run detector - result = await self._run_single_detector( - detector, embryo_state, timepoint - ) + result = await self._run_single_detector(detector, embryo_state, timepoint) results.append(result) @@ -107,10 +105,7 @@ async def run_detectors( return results async def _run_single_detector( - self, - detector: Detector, - embryo_state: EmbryoState, - timepoint: int + self, detector: Detector, embryo_state: EmbryoState, timepoint: int ) -> DetectionResult: """ Run a single detector @@ -135,7 +130,9 @@ async def _run_single_detector( try: # Get recent images num_images = detector.temporal_context_size if detector.use_temporal_context else 1 - recent_images = embryo_state.recent_images[-num_images:] if embryo_state.recent_images else [] + recent_images = ( + embryo_state.recent_images[-num_images:] if embryo_state.recent_images else [] + ) if not recent_images: # No images available @@ -146,15 +143,15 @@ async def _run_single_detector( timestamp=datetime.now(), detected=False, error=True, - error_message="No images available" + error_message="No images available", ) # Build image data for detector image_data = [ { - 'timepoint': img.timepoint, - 'b64_image': img.max_projection_b64, - 'size': img.size_kb + "timepoint": img.timepoint, + "b64_image": 
img.max_projection_b64, + "size": img.size_kb, } for img in recent_images ] @@ -167,7 +164,7 @@ async def _run_single_detector( self.claude.messages.create, model=self.model, max_tokens=1024, - messages=[{"role": "user", "content": content}] + messages=[{"role": "user", "content": content}], ) response_text = response.content[0].text @@ -181,13 +178,13 @@ async def _run_single_detector( embryo_id=embryo_id, timepoint=timepoint, timestamp=datetime.now(), - detected=parsed['detected'], - confidence=parsed['confidence'], - reasoning=parsed['reasoning'], + detected=parsed["detected"], + confidence=parsed["confidence"], + reasoning=parsed["reasoning"], error=False, api_duration=api_duration, num_images=len(image_data), - full_response=response_text + full_response=response_text, ) except Exception as e: @@ -202,14 +199,10 @@ async def _run_single_detector( detected=False, error=True, error_message=str(e), - api_duration=api_duration + api_duration=api_duration, ) - def _meets_confidence_threshold( - self, - result: DetectionResult, - detector: Detector - ) -> bool: + def _meets_confidence_threshold(self, result: DetectionResult, detector: Detector) -> bool: """ Check if detection result meets confidence threshold @@ -232,7 +225,7 @@ def _meets_confidence_threshold( confidence_map = { ConfidenceLevel.LOW: 1, ConfidenceLevel.MEDIUM: 2, - ConfidenceLevel.HIGH: 3 + ConfidenceLevel.HIGH: 3, } result_value = confidence_map.get(result.confidence, 0) @@ -244,8 +237,8 @@ async def test_detector( self, detector_name: str, embryo_state: EmbryoState, - timepoint: Optional[int] = None - ) -> Optional[DetectionResult]: + timepoint: int | None = None, + ) -> DetectionResult | None: """ Test a detector on a specific embryo/timepoint @@ -280,7 +273,7 @@ async def test_detector( return result - def get_detection_summary(self, embryo_states: Dict[str, EmbryoState]) -> Dict: + def get_detection_summary(self, embryo_states: dict[str, EmbryoState]) -> dict: """ Get summary of all detections 
across all embryos @@ -294,48 +287,44 @@ def get_detection_summary(self, embryo_states: Dict[str, EmbryoState]) -> Dict: dict Summary of detections """ - summary = { - 'detectors': {}, - 'embryos': {} - } + summary = {"detectors": {}, "embryos": {}} # Per-detector summary for detector in self.registry.list_all(): detector_summary = { - 'name': detector.name, - 'description': detector.description, - 'enabled': detector.enabled, - 'total_runs': detector.run_count, - 'total_detections': detector.detection_count, - 'embryos_detected': [] + "name": detector.name, + "description": detector.description, + "enabled": detector.enabled, + "total_runs": detector.run_count, + "total_detections": detector.detection_count, + "embryos_detected": [], } for embryo_id, embryo_state in embryo_states.items(): if embryo_state.was_detected(detector.name): latest = embryo_state.get_latest_detection(detector.name) - detector_summary['embryos_detected'].append({ - 'embryo_id': embryo_id, - 'timepoint': latest.get('timepoint'), - 'confidence': latest.get('confidence') - }) + detector_summary["embryos_detected"].append( + { + "embryo_id": embryo_id, + "timepoint": latest.get("timepoint"), + "confidence": latest.get("confidence"), + } + ) - summary['detectors'][detector.name] = detector_summary + summary["detectors"][detector.name] = detector_summary # Per-embryo summary for embryo_id, embryo_state in embryo_states.items(): - embryo_summary = { - 'embryo_id': embryo_id, - 'detections': {} - } + embryo_summary = {"embryo_id": embryo_id, "detections": {}} for detector_name in embryo_state.detection_results.keys(): latest = embryo_state.get_latest_detection(detector_name) - embryo_summary['detections'][detector_name] = { - 'detected': latest.get('detected', False) if latest else False, - 'timepoint': latest.get('timepoint') if latest else None, - 'confidence': latest.get('confidence') if latest else None + embryo_summary["detections"][detector_name] = { + "detected": latest.get("detected", 
False) if latest else False, + "timepoint": latest.get("timepoint") if latest else None, + "confidence": latest.get("confidence") if latest else None, } - summary['embryos'][embryo_id] = embryo_summary + summary["embryos"][embryo_id] = embryo_summary return summary diff --git a/gently/harness/detection/registry.py b/gently/harness/detection/registry.py index e7da601f..64f920bb 100644 --- a/gently/harness/detection/registry.py +++ b/gently/harness/detection/registry.py @@ -2,15 +2,17 @@ Detector registry for managing all configured detectors """ -import logging import json -from pathlib import Path -from typing import Dict, List, Optional +import logging from datetime import datetime +from pathlib import Path logger = logging.getLogger(__name__) -from .detector import Detector, DetectorConditions, DetectorActions, DetectionMode, ConfidenceLevel +from .detector import ( # noqa: E402 + ConfidenceLevel, + Detector, +) class DetectorRegistry: @@ -20,14 +22,14 @@ class DetectorRegistry: Handles CRUD operations, persistence, and querying of detectors. 
""" - def __init__(self, storage_path: Optional[Path] = None): + def __init__(self, storage_path: Path | None = None): """ Parameters ---------- storage_path : Path, optional Where to save detector registry JSON """ - self.detectors: Dict[str, Detector] = {} + self.detectors: dict[str, Detector] = {} self.storage_path = storage_path or Path("./detector_registry.json") # Load existing detectors if file exists @@ -76,15 +78,15 @@ def remove(self, name: str) -> bool: self.save() return True - def get(self, name: str) -> Optional[Detector]: + def get(self, name: str) -> Detector | None: """Get detector by name""" return self.detectors.get(name) - def list_all(self) -> List[Detector]: + def list_all(self) -> list[Detector]: """Get all detectors""" return list(self.detectors.values()) - def list_enabled(self) -> List[Detector]: + def list_enabled(self) -> list[Detector]: """Get all enabled detectors""" return [d for d in self.detectors.values() if d.enabled] @@ -162,7 +164,7 @@ def update(self, name: str, **kwargs) -> bool: self.save() return True - def get_stats(self) -> Dict: + def get_stats(self) -> dict: """ Get registry statistics @@ -177,34 +179,31 @@ def get_stats(self) -> Dict: total_runs = sum(d.run_count for d in self.detectors.values()) return { - 'total_detectors': total, - 'enabled_detectors': enabled, - 'disabled_detectors': total - enabled, - 'total_detections_fired': total_detections, - 'total_runs': total_runs, - 'detectors': { + "total_detectors": total, + "enabled_detectors": enabled, + "disabled_detectors": total - enabled, + "total_detections_fired": total_detections, + "total_runs": total_runs, + "detectors": { name: { - 'enabled': d.enabled, - 'detection_count': d.detection_count, - 'run_count': d.run_count + "enabled": d.enabled, + "detection_count": d.detection_count, + "run_count": d.run_count, } for name, d in self.detectors.items() - } + }, } def save(self): """Save registry to disk""" data = { - 'version': '1.0', - 'saved_at': 
datetime.now().isoformat(), - 'detectors': { - name: detector.to_dict() - for name, detector in self.detectors.items() - } + "version": "1.0", + "saved_at": datetime.now().isoformat(), + "detectors": {name: detector.to_dict() for name, detector in self.detectors.items()}, } self.storage_path.parent.mkdir(parents=True, exist_ok=True) - with open(self.storage_path, 'w') as f: + with open(self.storage_path, "w") as f: json.dump(data, f, indent=2) def load(self): @@ -212,18 +211,18 @@ def load(self): if not self.storage_path.exists(): return - with open(self.storage_path, 'r') as f: + with open(self.storage_path) as f: data = json.load(f) self.detectors = {} - for name, detector_data in data.get('detectors', {}).items(): + for name, detector_data in data.get("detectors", {}).items(): try: detector = Detector.from_dict(detector_data) self.detectors[name] = detector except Exception as e: logger.warning("Failed to load detector '%s': %s", name, e) - def create_preset_detector(self, preset_name: str) -> Optional[Detector]: + def create_preset_detector(self, preset_name: str) -> Detector | None: """ Create a detector from preset @@ -243,12 +242,12 @@ def create_preset_detector(self, preset_name: str) -> Optional[Detector]: preset_data = presets[preset_name] detector = Detector( - name=preset_data['name'], - description=preset_data['description'], - detection_prompt=preset_data['prompt'], - use_temporal_context=preset_data.get('use_temporal_context', True), - temporal_context_size=preset_data.get('temporal_context_size', 5), - confidence_threshold=ConfidenceLevel(preset_data.get('confidence_threshold', 'MEDIUM')), + name=preset_data["name"], + description=preset_data["description"], + detection_prompt=preset_data["prompt"], + use_temporal_context=preset_data.get("use_temporal_context", True), + temporal_context_size=preset_data.get("temporal_context_size", 5), + confidence_threshold=ConfidenceLevel(preset_data.get("confidence_threshold", "MEDIUM")), ) return detector @@ 
-258,6 +257,7 @@ def create_preset_detector(self, preset_name: str) -> Optional[Detector]: def get_detector_presets(): """Get detector presets from the active organism module.""" from gently.organisms import get_organism + org = get_organism() presets_module = __import__( f"gently.organisms.{org.ORGANISM_NAME}.detector_presets", diff --git a/gently/harness/detection/verifier.py b/gently/harness/detection/verifier.py index 26534842..0370abb7 100644 --- a/gently/harness/detection/verifier.py +++ b/gently/harness/detection/verifier.py @@ -12,17 +12,17 @@ import asyncio import logging -import re from dataclasses import dataclass, field from datetime import datetime -from typing import Any, Callable, Dict, List, Optional +from typing import Any import anthropic +from gently.core import EventType, get_event_bus from gently.settings import settings -from .detector import Detector, DetectionResult, ConfidenceLevel + from ..state import EmbryoState -from gently.core import EventType, get_event_bus +from .detector import ConfidenceLevel, DetectionResult, Detector logger = logging.getLogger(__name__) @@ -30,17 +30,19 @@ @dataclass class AdversarialResult: """Result of adversarial verification strategy""" + found_counter_evidence: bool - concerns: List[str] - confidence_in_original: Optional[ConfidenceLevel] + concerns: list[str] + confidence_in_original: ConfidenceLevel | None raw_response: str @dataclass class IndependentResult: """Result of independent verification strategy""" + detected: bool - confidence: Optional[ConfidenceLevel] + confidence: ConfidenceLevel | None key_evidence: str raw_response: str @@ -48,28 +50,31 @@ class IndependentResult: @dataclass class TemporalResult: """Result of temporal comparison strategy""" + change_detected: bool description: str - confidence: Optional[ConfidenceLevel] + confidence: ConfidenceLevel | None raw_response: str @dataclass class EnsembleResult: """Result of ensemble voting strategy""" + votes_yes: int votes_no: int total_votes: 
int agreement_ratio: float # votes_yes / total_votes if detected, votes_no / total_votes if not consensus_detected: bool # True if >70% agree on YES - raw_responses: List[str] = field(default_factory=list) + raw_responses: list[str] = field(default_factory=list) @dataclass class HardwareContextResult: """Result of hardware context analysis strategy""" + suspicious: bool # True if hardware errors could have caused false positive - concerns: List[str] # Specific concerns identified + concerns: list[str] # Specific concerns identified reasoning: str raw_response: str @@ -77,15 +82,16 @@ class HardwareContextResult: @dataclass class VerificationResult: """Combined result of all verification strategies""" + original_detected: bool - original_confidence: Optional[ConfidenceLevel] + original_confidence: ConfidenceLevel | None # Strategy results adversarial: AdversarialResult independent: IndependentResult temporal: TemporalResult - ensemble: Optional[EnsembleResult] = None # Only for hatching detection - hardware_context: Optional[HardwareContextResult] = None # Only when errors present + ensemble: EnsembleResult | None = None # Only for hatching detection + hardware_context: HardwareContextResult | None = None # Only when errors present # Consensus consensus: bool = False @@ -95,44 +101,50 @@ class VerificationResult: timestamp: datetime = field(default_factory=datetime.now) verification_duration_seconds: float = 0.0 - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Serialize to dictionary""" result = { - 'original_detected': self.original_detected, - 'original_confidence': self.original_confidence.value if self.original_confidence else None, - 'adversarial': { - 'found_counter_evidence': self.adversarial.found_counter_evidence, - 'concerns': self.adversarial.concerns, - 'confidence_in_original': self.adversarial.confidence_in_original.value if self.adversarial.confidence_in_original else None, + "original_detected": self.original_detected, 
+ "original_confidence": self.original_confidence.value + if self.original_confidence + else None, + "adversarial": { + "found_counter_evidence": self.adversarial.found_counter_evidence, + "concerns": self.adversarial.concerns, + "confidence_in_original": self.adversarial.confidence_in_original.value + if self.adversarial.confidence_in_original + else None, }, - 'independent': { - 'detected': self.independent.detected, - 'confidence': self.independent.confidence.value if self.independent.confidence else None, - 'key_evidence': self.independent.key_evidence, + "independent": { + "detected": self.independent.detected, + "confidence": self.independent.confidence.value + if self.independent.confidence + else None, + "key_evidence": self.independent.key_evidence, }, - 'temporal': { - 'change_detected': self.temporal.change_detected, - 'description': self.temporal.description, - 'confidence': self.temporal.confidence.value if self.temporal.confidence else None, + "temporal": { + "change_detected": self.temporal.change_detected, + "description": self.temporal.description, + "confidence": self.temporal.confidence.value if self.temporal.confidence else None, }, - 'consensus': self.consensus, - 'consensus_reasoning': self.consensus_reasoning, - 'timestamp': self.timestamp.isoformat(), - 'verification_duration_seconds': self.verification_duration_seconds, + "consensus": self.consensus, + "consensus_reasoning": self.consensus_reasoning, + "timestamp": self.timestamp.isoformat(), + "verification_duration_seconds": self.verification_duration_seconds, } if self.ensemble: - result['ensemble'] = { - 'votes_yes': self.ensemble.votes_yes, - 'votes_no': self.ensemble.votes_no, - 'total_votes': self.ensemble.total_votes, - 'agreement_ratio': self.ensemble.agreement_ratio, - 'consensus_detected': self.ensemble.consensus_detected, + result["ensemble"] = { + "votes_yes": self.ensemble.votes_yes, + "votes_no": self.ensemble.votes_no, + "total_votes": self.ensemble.total_votes, + 
"agreement_ratio": self.ensemble.agreement_ratio, + "consensus_detected": self.ensemble.consensus_detected, } if self.hardware_context: - result['hardware_context'] = { - 'suspicious': self.hardware_context.suspicious, - 'concerns': self.hardware_context.concerns, - 'reasoning': self.hardware_context.reasoning, + result["hardware_context"] = { + "suspicious": self.hardware_context.suspicious, + "concerns": self.hardware_context.concerns, + "reasoning": self.hardware_context.reasoning, } return result @@ -177,7 +189,7 @@ def __init__( self.ensemble_threshold = ensemble_threshold self._event_bus = event_bus or get_event_bus() - def _emit_event(self, event_type: EventType, data: Dict): + def _emit_event(self, event_type: EventType, data: dict): """Emit event to viz server""" if self._event_bus: self._event_bus.publish(event_type, data) @@ -211,19 +223,13 @@ async def verify( start_time = datetime.now() # Run all strategies in parallel for speed - adversarial_task = self._run_adversarial( - detector, embryo_state, original_result, timepoint - ) - independent_task = self._run_independent( - detector, embryo_state, timepoint - ) - temporal_task = self._run_temporal_check( - detector, embryo_state, timepoint - ) + adversarial_task = self._run_adversarial(detector, embryo_state, original_result, timepoint) + independent_task = self._run_independent(detector, embryo_state, timepoint) + temporal_task = self._run_temporal_check(detector, embryo_state, timepoint) # For hatching detection, also run ensemble voting ensemble_result = None - if detector.name == 'hatching': + if detector.name == "hatching": ensemble_task = self._run_ensemble_hatching(embryo_state) adversarial, independent, temporal, ensemble_result = await asyncio.gather( adversarial_task, independent_task, temporal_task, ensemble_task @@ -255,7 +261,11 @@ async def verify( logger.info( f"Verification complete for {detector.name}: " f"consensus={consensus}, duration={duration:.2f}s" - + (f", 
ensemble={ensemble_result.votes_yes}/{ensemble_result.total_votes} YES" if ensemble_result else "") + + ( + f", ensemble={ensemble_result.votes_yes}/{ensemble_result.total_votes} YES" + if ensemble_result + else "" + ) ) return result @@ -294,21 +304,15 @@ async def verify_with_context( start_time = datetime.now() # Run all strategies in parallel for speed - adversarial_task = self._run_adversarial( - detector, embryo_state, original_result, timepoint - ) - independent_task = self._run_independent( - detector, embryo_state, timepoint - ) - temporal_task = self._run_temporal_check( - detector, embryo_state, timepoint - ) + adversarial_task = self._run_adversarial(detector, embryo_state, original_result, timepoint) + independent_task = self._run_independent(detector, embryo_state, timepoint) + temporal_task = self._run_temporal_check(detector, embryo_state, timepoint) # For hatching detection, also run ensemble voting ensemble_result = None hardware_result = None - if detector.name == 'hatching': + if detector.name == "hatching": ensemble_task = self._run_ensemble_hatching(embryo_state) # Run hardware context analysis if there are errors @@ -316,11 +320,26 @@ async def verify_with_context( hardware_task = self._run_hardware_context_analysis( global_error_context, embryo_state.id ) - adversarial, independent, temporal, ensemble_result, hardware_result = await asyncio.gather( - adversarial_task, independent_task, temporal_task, ensemble_task, hardware_task + ( + adversarial, + independent, + temporal, + ensemble_result, + hardware_result, + ) = await asyncio.gather( + adversarial_task, + independent_task, + temporal_task, + ensemble_task, + hardware_task, ) else: - adversarial, independent, temporal, ensemble_result = await asyncio.gather( + ( + adversarial, + independent, + temporal, + ensemble_result, + ) = await asyncio.gather( adversarial_task, independent_task, temporal_task, ensemble_task ) else: @@ -335,91 +354,151 @@ async def verify_with_context( # Adversarial 
result strategies_complete += 1 - self._emit_event(EventType.VERIFICATION_STRATEGY, { - 'embryo_id': embryo_id, - 'detector_name': detector.name, - 'strategy': 'adversarial', - 'passed': not adversarial.found_counter_evidence, - 'summary': f"Counter-evidence: {'YES - ' + ', '.join(adversarial.concerns) if adversarial.found_counter_evidence else 'None found'}", - 'confidence': adversarial.confidence_in_original.value if adversarial.confidence_in_original else None, - }) - self._emit_event(EventType.VERIFICATION_PROGRESS, { - 'embryo_id': embryo_id, - 'strategies_complete': strategies_complete, - 'total_strategies': total_strategies, - }) + self._emit_event( + EventType.VERIFICATION_STRATEGY, + { + "embryo_id": embryo_id, + "detector_name": detector.name, + "strategy": "adversarial", + "passed": not adversarial.found_counter_evidence, + "summary": ( + "Counter-evidence: " + + ( + "YES - " + ", ".join(adversarial.concerns) + if adversarial.found_counter_evidence + else "None found" + ) + ), + "confidence": adversarial.confidence_in_original.value + if adversarial.confidence_in_original + else None, + }, + ) + self._emit_event( + EventType.VERIFICATION_PROGRESS, + { + "embryo_id": embryo_id, + "strategies_complete": strategies_complete, + "total_strategies": total_strategies, + }, + ) # Independent result strategies_complete += 1 - self._emit_event(EventType.VERIFICATION_STRATEGY, { - 'embryo_id': embryo_id, - 'detector_name': detector.name, - 'strategy': 'independent', - 'passed': independent.detected, - 'summary': f"Independent detection: {'YES' if independent.detected else 'NO'} - {independent.key_evidence}", - 'confidence': independent.confidence.value if independent.confidence else None, - }) - self._emit_event(EventType.VERIFICATION_PROGRESS, { - 'embryo_id': embryo_id, - 'strategies_complete': strategies_complete, - 'total_strategies': total_strategies, - }) + self._emit_event( + EventType.VERIFICATION_STRATEGY, + { + "embryo_id": embryo_id, + "detector_name": 
detector.name, + "strategy": "independent", + "passed": independent.detected, + "summary": ( + f"Independent detection: {'YES' if independent.detected else 'NO'}" + f" - {independent.key_evidence}" + ), + "confidence": independent.confidence.value if independent.confidence else None, + }, + ) + self._emit_event( + EventType.VERIFICATION_PROGRESS, + { + "embryo_id": embryo_id, + "strategies_complete": strategies_complete, + "total_strategies": total_strategies, + }, + ) # Temporal result strategies_complete += 1 - self._emit_event(EventType.VERIFICATION_STRATEGY, { - 'embryo_id': embryo_id, - 'detector_name': detector.name, - 'strategy': 'temporal', - 'passed': temporal.change_detected, - 'summary': f"Change detected: {'YES' if temporal.change_detected else 'NO'} - {temporal.description}", - 'confidence': temporal.confidence.value if temporal.confidence else None, - }) - self._emit_event(EventType.VERIFICATION_PROGRESS, { - 'embryo_id': embryo_id, - 'strategies_complete': strategies_complete, - 'total_strategies': total_strategies, - }) + self._emit_event( + EventType.VERIFICATION_STRATEGY, + { + "embryo_id": embryo_id, + "detector_name": detector.name, + "strategy": "temporal", + "passed": temporal.change_detected, + "summary": ( + f"Change detected: {'YES' if temporal.change_detected else 'NO'}" + f" - {temporal.description}" + ), + "confidence": temporal.confidence.value if temporal.confidence else None, + }, + ) + self._emit_event( + EventType.VERIFICATION_PROGRESS, + { + "embryo_id": embryo_id, + "strategies_complete": strategies_complete, + "total_strategies": total_strategies, + }, + ) # Ensemble result (if applicable) if ensemble_result: strategies_complete += 1 - self._emit_event(EventType.VERIFICATION_STRATEGY, { - 'embryo_id': embryo_id, - 'detector_name': detector.name, - 'strategy': 'ensemble', - 'passed': ensemble_result.consensus_detected, - 'summary': f"Ensemble vote: {ensemble_result.votes_yes}/{ensemble_result.total_votes} YES 
({ensemble_result.agreement_ratio*100:.0f}%)", - 'votes_yes': ensemble_result.votes_yes, - 'votes_no': ensemble_result.votes_no, - 'total_votes': ensemble_result.total_votes, - }) - self._emit_event(EventType.VERIFICATION_PROGRESS, { - 'embryo_id': embryo_id, - 'strategies_complete': strategies_complete, - 'total_strategies': total_strategies, - }) + self._emit_event( + EventType.VERIFICATION_STRATEGY, + { + "embryo_id": embryo_id, + "detector_name": detector.name, + "strategy": "ensemble", + "passed": ensemble_result.consensus_detected, + "summary": ( + f"Ensemble vote: {ensemble_result.votes_yes}/{ensemble_result.total_votes}" + f" YES ({ensemble_result.agreement_ratio * 100:.0f}%)" + ), + "votes_yes": ensemble_result.votes_yes, + "votes_no": ensemble_result.votes_no, + "total_votes": ensemble_result.total_votes, + }, + ) + self._emit_event( + EventType.VERIFICATION_PROGRESS, + { + "embryo_id": embryo_id, + "strategies_complete": strategies_complete, + "total_strategies": total_strategies, + }, + ) # Hardware context result (if applicable) if hardware_result: strategies_complete += 1 - self._emit_event(EventType.VERIFICATION_STRATEGY, { - 'embryo_id': embryo_id, - 'detector_name': detector.name, - 'strategy': 'hardware_context', - 'passed': not hardware_result.suspicious, - 'summary': f"Hardware errors suspicious: {'YES - ' + ', '.join(hardware_result.concerns) if hardware_result.suspicious else 'No'}", - 'reasoning': hardware_result.reasoning, - }) - self._emit_event(EventType.VERIFICATION_PROGRESS, { - 'embryo_id': embryo_id, - 'strategies_complete': strategies_complete, - 'total_strategies': total_strategies, - }) + self._emit_event( + EventType.VERIFICATION_STRATEGY, + { + "embryo_id": embryo_id, + "detector_name": detector.name, + "strategy": "hardware_context", + "passed": not hardware_result.suspicious, + "summary": ( + "Hardware errors suspicious: " + + ( + "YES - " + ", ".join(hardware_result.concerns) + if hardware_result.suspicious + else "No" + ) + ), 
+ "reasoning": hardware_result.reasoning, + }, + ) + self._emit_event( + EventType.VERIFICATION_PROGRESS, + { + "embryo_id": embryo_id, + "strategies_complete": strategies_complete, + "total_strategies": total_strategies, + }, + ) # Determine consensus (with hardware context) consensus, reasoning = self._evaluate_consensus_with_hardware( - original_result, adversarial, independent, temporal, ensemble_result, hardware_result + original_result, + adversarial, + independent, + temporal, + ensemble_result, + hardware_result, ) duration = (datetime.now() - start_time).total_seconds() @@ -440,26 +519,37 @@ async def verify_with_context( logger.info( f"Verification (with context) complete for {detector.name}: " f"consensus={consensus}, duration={duration:.2f}s" - + (f", ensemble={ensemble_result.votes_yes}/{ensemble_result.total_votes} YES" if ensemble_result else "") + + ( + f", ensemble={ensemble_result.votes_yes}/{ensemble_result.total_votes} YES" + if ensemble_result + else "" + ) + (f", hardware_suspicious={hardware_result.suspicious}" if hardware_result else "") ) # Emit VERIFICATION_COMPLETED event with full summary - self._emit_event(EventType.VERIFICATION_COMPLETED, { - 'embryo_id': embryo_id, - 'detector_name': detector.name, - 'consensus': consensus, - 'reasoning': reasoning, - 'duration_seconds': duration, - 'strategies': { - 'adversarial': not adversarial.found_counter_evidence, - 'independent': independent.detected, - 'temporal': temporal.change_detected, - 'ensemble': ensemble_result.consensus_detected if ensemble_result else None, - 'hardware_context': (not hardware_result.suspicious) if hardware_result else None, + self._emit_event( + EventType.VERIFICATION_COMPLETED, + { + "embryo_id": embryo_id, + "detector_name": detector.name, + "consensus": consensus, + "reasoning": reasoning, + "duration_seconds": duration, + "strategies": { + "adversarial": not adversarial.found_counter_evidence, + "independent": independent.detected, + "temporal": 
temporal.change_detected, + "ensemble": ensemble_result.consensus_detected if ensemble_result else None, + "hardware_context": (not hardware_result.suspicious) + if hardware_result + else None, + }, + "ensemble_votes": f"{ensemble_result.votes_yes}/{ensemble_result.total_votes}" + if ensemble_result + else None, }, - 'ensemble_votes': f"{ensemble_result.votes_yes}/{ensemble_result.total_votes}" if ensemble_result else None, - }) + ) return result @@ -487,7 +577,8 @@ async def _run_hardware_context_analysis( Analysis result """ try: - prompt = f"""You are analyzing hardware error context for a microscopy detection verification. + prompt = f"""You are analyzing hardware error context for a microscopy detection +verification. GLOBAL ERROR LOG: {global_error_context} @@ -500,10 +591,12 @@ async def _run_hardware_context_analysis( - Stage positioning errors could cause wrong embryo to be imaged - Acquisition timeouts could cause partial/blank images (blank images look like empty FOV = hatched) - Camera errors could produce corrupted data -- Errors on OTHER embryos in the same round could indicate systemic issues (stage drift, hardware instability) +- Errors on OTHER embryos in the same round could indicate systemic issues + (stage drift, hardware instability) - Multiple errors in quick succession suggests hardware problems -If ANY errors occurred that could have affected the image quality or positioning for {embryo_id}, report as SUSPICIOUS. +If ANY errors occurred that could have affected the image quality or positioning for +{embryo_id}, report as SUSPICIOUS. 
Respond in EXACTLY this format: SUSPICIOUS: [YES/NO] @@ -515,7 +608,7 @@ async def _run_hardware_context_analysis( self.claude.messages.create, model=self.ensemble_model, # Use Haiku for speed max_tokens=300, - messages=[{"role": "user", "content": prompt}] + messages=[{"role": "user", "content": prompt}], ) response_text = response.content[0].text @@ -536,16 +629,16 @@ def _parse_hardware_context_response(self, response: str) -> HardwareContextResu concerns = [] reasoning = "" - for line in response.split('\n'): + for line in response.split("\n"): line = line.strip() - if line.startswith('SUSPICIOUS:'): - value = line.split(':', 1)[1].strip().upper() - suspicious = value == 'YES' - elif line.startswith('CONCERNS:'): - concerns_str = line.split(':', 1)[1].strip() - concerns = [c.strip() for c in concerns_str.split(';') if c.strip()] - elif line.startswith('REASONING:'): - reasoning = line.split(':', 1)[1].strip() + if line.startswith("SUSPICIOUS:"): + value = line.split(":", 1)[1].strip().upper() + suspicious = value == "YES" + elif line.startswith("CONCERNS:"): + concerns_str = line.split(":", 1)[1].strip() + concerns = [c.strip() for c in concerns_str.split(";") if c.strip()] + elif line.startswith("REASONING:"): + reasoning = line.split(":", 1)[1].strip() return HardwareContextResult( suspicious=suspicious, @@ -560,8 +653,8 @@ def _evaluate_consensus_with_hardware( adversarial: AdversarialResult, independent: IndependentResult, temporal: TemporalResult, - ensemble: Optional[EnsembleResult] = None, - hardware_context: Optional[HardwareContextResult] = None, + ensemble: EnsembleResult | None = None, + hardware_context: HardwareContextResult | None = None, ) -> tuple[bool, str]: """ Evaluate consensus across all verification strategies including hardware context. 
@@ -577,7 +670,9 @@ def _evaluate_consensus_with_hardware( if not adversarial.found_counter_evidence: agreements += 1 else: - disagreements.append(f"Adversarial found counter-evidence: {', '.join(adversarial.concerns[:2])}") + disagreements.append( + f"Adversarial found counter-evidence: {', '.join(adversarial.concerns[:2])}" + ) # Check independent: should also detect if independent.detected: @@ -617,12 +712,19 @@ def _evaluate_consensus_with_hardware( consensus = agreements == total_strategies if consensus: - parts = ["no counter-evidence found", "independent analysis confirms", "temporal change observed"] + parts = [ + "no counter-evidence found", + "independent analysis confirms", + "temporal change observed", + ] if ensemble: parts.append(f"ensemble confirms ({ensemble.votes_yes}/{ensemble.total_votes} YES)") if hardware_context: parts.append("no hardware error concerns") - reasoning = f"All verification strategies agree ({total_strategies}/{total_strategies}): " + ", ".join(parts) + reasoning = ( + f"All verification strategies agree ({total_strategies}/{total_strategies}): " + + ", ".join(parts) + ) else: reasoning = ( f"Verification disagreement ({agreements}/{total_strategies} agree): " @@ -656,7 +758,7 @@ async def _run_adversarial( ) # Build detector-specific critical review guidance - if detector.name == 'hatching': + if detector.name == "hatching": specific_guidance = """ For HATCHING specifically, look for these common FALSE POSITIVE patterns: - Is the worm still COILED/PRETZEL-SHAPED inside the eggshell? @@ -669,11 +771,12 @@ async def _run_adversarial( else: specific_guidance = "" - prompt = f"""You are reviewing a detection result for a C. elegans embryo (diSPIM max projection). + prompt = f"""You are reviewing a detection result for a C. elegans embryo +(diSPIM max projection). 
The system detected: {detector.name} -Original confidence: {original_result.confidence.value if original_result.confidence else 'unknown'} -Original reasoning: {original_result.reasoning or 'not provided'} +Original confidence: {original_result.confidence.value if original_result.confidence else "unknown"} +Original reasoning: {original_result.reasoning or "not provided"} NOW ACT AS A CRITICAL REVIEWER. Your job is to find reasons why this detection might be INCORRECT: - Could this be a false positive? @@ -693,7 +796,7 @@ async def _run_adversarial( self.claude.messages.create, model=self.model, max_tokens=500, - messages=[{"role": "user", "content": content}] + messages=[{"role": "user", "content": content}], ) response_text = response.content[0].text @@ -732,7 +835,7 @@ async def _run_independent( ) # Build detector-specific criteria - if detector.name == 'hatching': + if detector.name == "hatching": criteria = """ TRUE HATCHING criteria (must meet at least one): - Worm body is OUTSIDE the eggshell boundary (free-floating, elongated) @@ -746,7 +849,8 @@ async def _run_independent( criteria = detector.description # Use a neutral prompt that doesn't reveal the previous detection - prompt = f"""Analyze this C. elegans embryo image (diSPIM max projection) at timepoint {timepoint}. + prompt = f"""Analyze this C. elegans embryo image (diSPIM max projection) at +timepoint {timepoint}. Question: Has '{detector.name}' occurred in this embryo? 
@@ -767,7 +871,7 @@ async def _run_independent( self.claude.messages.create, model=self.model, max_tokens=400, - messages=[{"role": "user", "content": content}] + messages=[{"role": "user", "content": content}], ) response_text = response.content[0].text @@ -811,14 +915,16 @@ async def _run_temporal_check( prev_images = [] for img in embryo_state.recent_images[-3:-1]: if img.max_projection_b64: - prev_images.append({ - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": img.max_projection_b64, + prev_images.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": img.max_projection_b64, + }, } - }) + ) if not prev_images: return TemporalResult( @@ -829,7 +935,7 @@ async def _run_temporal_check( ) # Build detector-specific temporal criteria - if detector.name == 'hatching': + if detector.name == "hatching": temporal_criteria = """For HATCHING, look for: - A visible BREACH appearing in the eggshell boundary (not just expansion) - The worm physically EXITING the shell (part of body moves outside) @@ -841,10 +947,11 @@ async def _run_temporal_check( - Not just a static state that could have existed before - Clear evidence of progression or event occurrence""" - prompt = f"""Compare these sequential timepoints of a C. elegans embryo (diSPIM max projection). + prompt = f"""Compare these sequential timepoints of a C. elegans embryo +(diSPIM max projection). 
PREVIOUS TIMEPOINTS (shown first): -These are from t={timepoint-2} to t={timepoint-1} +These are from t={timepoint - 2} to t={timepoint - 1} CURRENT TIMEPOINT (shown last): This is t={timepoint} @@ -866,7 +973,7 @@ async def _run_temporal_check( self.claude.messages.create, model=self.model, max_tokens=400, - messages=[{"role": "user", "content": content}] + messages=[{"role": "user", "content": content}], ) response_text = response.content[0].text @@ -881,25 +988,23 @@ async def _run_temporal_check( raw_response="", ) - def _get_image_content( - self, - embryo_state: EmbryoState, - num_images: int = 1 - ) -> List[Dict]: + def _get_image_content(self, embryo_state: EmbryoState, num_images: int = 1) -> list[dict]: """Get image content blocks for Claude API""" images = [] recent = embryo_state.recent_images[-num_images:] if embryo_state.recent_images else [] for img in recent: if img.max_projection_b64: - images.append({ - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": img.max_projection_b64, + images.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": img.max_projection_b64, + }, } - }) + ) return images @@ -933,8 +1038,10 @@ async def _run_ensemble_hatching(self, embryo_state: EmbryoState) -> EnsembleRes Answer ONE question: Has the embryo HATCHED? -HATCHED means: The worm body is OUTSIDE the eggshell (free-floating, elongated, or field is empty because worm left). -NOT HATCHED means: The worm is still INSIDE the eggshell (coiled/pretzel-shaped, even if shell looks expanded). +HATCHED means: The worm body is OUTSIDE the eggshell (free-floating, elongated, or field is +empty because worm left). +NOT HATCHED means: The worm is still INSIDE the eggshell (coiled/pretzel-shaped, even if +shell looks expanded). 
Respond with ONLY: YES or NO""" @@ -946,7 +1053,7 @@ async def single_vote() -> str: self.claude.messages.create, model=self.ensemble_model, max_tokens=10, # Very short response expected - messages=[{"role": "user", "content": content}] + messages=[{"role": "user", "content": content}], ) return response.content[0].text.strip().upper() except Exception as e: @@ -954,7 +1061,10 @@ async def single_vote() -> str: return "ERROR" # Run all votes in parallel - logger.info(f"[ENSEMBLE] Running {self.ensemble_size} parallel Haiku calls for hatching verification") + logger.info( + f"[ENSEMBLE] Running {self.ensemble_size} parallel Haiku calls" + " for hatching verification" + ) tasks = [single_vote() for _ in range(self.ensemble_size)] responses = await asyncio.gather(*tasks) @@ -1009,16 +1119,16 @@ def _parse_adversarial_response(self, response: str) -> AdversarialResult: concerns = [] confidence = None - for line in response.split('\n'): + for line in response.split("\n"): line = line.strip() - if line.startswith('COUNTER_EVIDENCE_FOUND:'): - value = line.split(':', 1)[1].strip().upper() - found_counter = value == 'YES' - elif line.startswith('CONCERNS:'): - concerns_str = line.split(':', 1)[1].strip() - concerns = [c.strip() for c in concerns_str.split(';') if c.strip()] - elif line.startswith('CONFIDENCE_IN_ORIGINAL:'): - value = line.split(':', 1)[1].strip().upper() + if line.startswith("COUNTER_EVIDENCE_FOUND:"): + value = line.split(":", 1)[1].strip().upper() + found_counter = value == "YES" + elif line.startswith("CONCERNS:"): + concerns_str = line.split(":", 1)[1].strip() + concerns = [c.strip() for c in concerns_str.split(";") if c.strip()] + elif line.startswith("CONFIDENCE_IN_ORIGINAL:"): + value = line.split(":", 1)[1].strip().upper() try: confidence = ConfidenceLevel(value) except ValueError: @@ -1037,19 +1147,19 @@ def _parse_independent_response(self, response: str) -> IndependentResult: confidence = None evidence = "" - for line in response.split('\n'): + 
for line in response.split("\n"): line = line.strip() - if line.startswith('DETECTED:'): - value = line.split(':', 1)[1].strip().upper() - detected = value == 'YES' - elif line.startswith('CONFIDENCE:'): - value = line.split(':', 1)[1].strip().upper() + if line.startswith("DETECTED:"): + value = line.split(":", 1)[1].strip().upper() + detected = value == "YES" + elif line.startswith("CONFIDENCE:"): + value = line.split(":", 1)[1].strip().upper() try: confidence = ConfidenceLevel(value) except ValueError: pass - elif line.startswith('KEY_EVIDENCE:'): - evidence = line.split(':', 1)[1].strip() + elif line.startswith("KEY_EVIDENCE:"): + evidence = line.split(":", 1)[1].strip() return IndependentResult( detected=detected, @@ -1064,15 +1174,15 @@ def _parse_temporal_response(self, response: str) -> TemporalResult: description = "" confidence = None - for line in response.split('\n'): + for line in response.split("\n"): line = line.strip() - if line.startswith('CHANGE_DETECTED:'): - value = line.split(':', 1)[1].strip().upper() - change_detected = value == 'YES' - elif line.startswith('DESCRIPTION:'): - description = line.split(':', 1)[1].strip() - elif line.startswith('CONFIDENCE:'): - value = line.split(':', 1)[1].strip().upper() + if line.startswith("CHANGE_DETECTED:"): + value = line.split(":", 1)[1].strip().upper() + change_detected = value == "YES" + elif line.startswith("DESCRIPTION:"): + description = line.split(":", 1)[1].strip() + elif line.startswith("CONFIDENCE:"): + value = line.split(":", 1)[1].strip().upper() try: confidence = ConfidenceLevel(value) except ValueError: @@ -1091,7 +1201,7 @@ def _evaluate_consensus( adversarial: AdversarialResult, independent: IndependentResult, temporal: TemporalResult, - ensemble: Optional[EnsembleResult] = None, + ensemble: EnsembleResult | None = None, ) -> tuple[bool, str]: """ Evaluate consensus across all verification strategies. 
@@ -1107,7 +1217,9 @@ def _evaluate_consensus( if not adversarial.found_counter_evidence: agreements += 1 else: - disagreements.append(f"Adversarial found counter-evidence: {', '.join(adversarial.concerns[:2])}") + disagreements.append( + f"Adversarial found counter-evidence: {', '.join(adversarial.concerns[:2])}" + ) # Check independent: should also detect if independent.detected: @@ -1141,13 +1253,14 @@ def _evaluate_consensus( f"All verification strategies agree ({total_strategies}/{total_strategies}): " f"no counter-evidence found, independent analysis confirms detection, " f"temporal change observed, ensemble voting confirms " - f"({ensemble.votes_yes}/{ensemble.total_votes} = {ensemble.agreement_ratio:.0%} YES)." + f"({ensemble.votes_yes}/{ensemble.total_votes}" + f" = {ensemble.agreement_ratio:.0%} YES)." ) else: reasoning = ( - f"All verification strategies agree: " - f"no counter-evidence found, independent analysis confirms detection, " - f"temporal change observed." + "All verification strategies agree: " + "no counter-evidence found, independent analysis confirms detection, " + "temporal change observed." 
) else: reasoning = ( diff --git a/gently/harness/error_log.py b/gently/harness/error_log.py index 4b741e57..9b2b228f 100644 --- a/gently/harness/error_log.py +++ b/gently/harness/error_log.py @@ -6,9 +6,8 @@ """ import logging -from dataclasses import dataclass, field +from dataclasses import dataclass from datetime import datetime -from typing import Dict, List, Optional logger = logging.getLogger(__name__) @@ -16,13 +15,14 @@ @dataclass class ErrorEntry: """Single error entry in the global log""" + timestamp: datetime round_number: int embryo_id: str timepoint: int error_type: str message: str - exception: Optional[Exception] = None + exception: Exception | None = None class GlobalErrorLog: @@ -40,7 +40,7 @@ def __init__(self, max_entries: int = 100): max_entries : int Maximum number of entries to keep (oldest are dropped) """ - self._entries: List[ErrorEntry] = [] + self._entries: list[ErrorEntry] = [] self._max_entries = max_entries def log_error( @@ -50,7 +50,7 @@ def log_error( timepoint: int, error_type: str, message: str, - exception: Optional[Exception] = None, + exception: Exception | None = None, ): """ Log an error during timelapse acquisition. 
@@ -84,22 +84,20 @@ def log_error( # Trim old entries if len(self._entries) > self._max_entries: - self._entries = self._entries[-self._max_entries:] + self._entries = self._entries[-self._max_entries :] # Also log to standard logger - logger.warning( - f"[{error_type}] Round {round_number}, {embryo_id} t{timepoint}: {message}" - ) + logger.warning(f"[{error_type}] Round {round_number}, {embryo_id} t{timepoint}: {message}") - def get_recent_errors(self, limit: int = 10) -> List[ErrorEntry]: + def get_recent_errors(self, limit: int = 10) -> list[ErrorEntry]: """Get most recent errors""" return self._entries[-limit:] - def get_errors_for_embryo(self, embryo_id: str) -> List[ErrorEntry]: + def get_errors_for_embryo(self, embryo_id: str) -> list[ErrorEntry]: """Get all errors for a specific embryo""" return [e for e in self._entries if e.embryo_id == embryo_id] - def get_errors_in_round(self, round_number: int) -> List[ErrorEntry]: + def get_errors_in_round(self, round_number: int) -> list[ErrorEntry]: """Get all errors from a specific round""" return [e for e in self._entries if e.round_number == round_number] diff --git a/gently/harness/memory/__init__.py b/gently/harness/memory/__init__.py index 3d76b540..ce9083c0 100644 --- a/gently/harness/memory/__init__.py +++ b/gently/harness/memory/__init__.py @@ -13,43 +13,50 @@ """ from .model import ( + Attention, + BenchSpec, Campaign, - Project, - SessionIntent, - PlannedSession, - PlanItem, - PlanItemStatus, - PlanItemType, + Confidence, + Context, + ContextUpdates, + EmbryoUnderstanding, + Expectation, + ExpectationStatus, ImagingSpec, - BenchSpec, + Intentions, Learning, Observation, - Expectation, - Watchpoint, + PlanItem, + PlanItemStatus, + PlanItemType, + PlannedSession, + PlannedSessionStatus, + Project, Question, - EmbryoUnderstanding, - Intentions, - Understanding, - Attention, - Context, - ContextUpdates, + QuestionStatus, + SessionIntent, # Enums Significance, - Confidence, Status, - PlannedSessionStatus, - 
ExpectationStatus, + Understanding, + Watchpoint, WatchpointStatus, - QuestionStatus, ) -from .store import ContextStore # legacy SQLite store (read-only, replaced by FileContextStore) +from .store import ( + ContextStore, +) + try: from .file_store import FileContextStore except ImportError: FileContextStore = None -from .serialization import context_to_dict, context_to_json, context_summary -from .gap_assessment import assess_gaps, ContextGapReport, Gap, GapLayer, GapSeverity -from .onboarding import generate_onboarding_messages, process_onboarding_response, OnboardingMessage +from .gap_assessment import ContextGapReport, Gap, GapLayer, GapSeverity, assess_gaps +from .onboarding import ( + OnboardingMessage, + generate_onboarding_messages, + process_onboarding_response, +) +from .serialization import context_summary, context_to_dict, context_to_json from .startup_wizard import StartupWizard __all__ = [ diff --git a/gently/harness/memory/_intentions.py b/gently/harness/memory/_intentions.py index 0de5e540..59de4048 100644 --- a/gently/harness/memory/_intentions.py +++ b/gently/harness/memory/_intentions.py @@ -9,8 +9,9 @@ import logging import sqlite3 from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any +from ._protocols import StoreProtocol from .model import ( Campaign, Intentions, @@ -24,7 +25,7 @@ logger = logging.getLogger(__name__) -class IntentionsMixin: +class IntentionsMixin(StoreProtocol): """Campaign management, projects, session intents, and planned sessions.""" # ------------------------------------------------------------------ @@ -47,11 +48,11 @@ def _load_intentions(self) -> Intentions: def create_campaign( self, description: str, - shorthand: Optional[str] = None, - summary: Optional[str] = None, - target: Optional[str] = None, - parent_id: Optional[str] = None, - campaign_id: Optional[str] = None, + shorthand: str | None = None, + summary: str | None = None, + target: str | None = None, + parent_id: 
str | None = None, + campaign_id: str | None = None, ) -> str: """Create a new campaign. Returns campaign ID.""" cid = campaign_id or self._gen_id() @@ -59,7 +60,8 @@ def create_campaign( with self._tx(): self._conn.execute( "INSERT INTO campaigns " - "(id, description, shorthand, summary, target, parent_id, status, created_at, updated_at) " + "(id, description, shorthand, summary, target, parent_id," + " status, created_at, updated_at) " "VALUES (?, ?, ?, ?, ?, ?, 'active', ?, ?)", (cid, description, shorthand, summary, target, parent_id, now, now), ) @@ -67,7 +69,7 @@ def create_campaign( logger.info(f"Created campaign {cid} [{label}]") return cid - def get_active_campaigns(self) -> List[Campaign]: + def get_active_campaigns(self) -> list[Campaign]: """Get all active campaigns.""" rows = self._conn.execute( "SELECT * FROM campaigns WHERE status = 'active' ORDER BY created_at DESC" @@ -83,12 +85,10 @@ def count_non_active_campaigns(self) -> int: def count_session_intents(self) -> int: """Count total session intent records.""" - row = self._conn.execute( - "SELECT COUNT(*) as cnt FROM session_intents" - ).fetchone() + row = self._conn.execute("SELECT COUNT(*) as cnt FROM session_intents").fetchone() return row["cnt"] if row else 0 - def get_all_campaigns(self, limit: int = 50) -> List[Campaign]: + def get_all_campaigns(self, limit: int = 50) -> list[Campaign]: """Get all campaigns regardless of status, ordered by created_at descending.""" rows = self._conn.execute( "SELECT * FROM campaigns ORDER BY created_at DESC LIMIT ?", @@ -96,7 +96,7 @@ def get_all_campaigns(self, limit: int = 50) -> List[Campaign]: ).fetchall() return [self._row_to_campaign(row) for row in rows] - def get_recent_session_intents(self, limit: int = 50) -> List["SessionIntent"]: + def get_recent_session_intents(self, limit: int = 50) -> list["SessionIntent"]: """Get recent session intents, ordered by created_at descending.""" rows = self._conn.execute( "SELECT * FROM session_intents ORDER BY 
created_at DESC LIMIT ?", @@ -107,24 +107,26 @@ def get_recent_session_intents(self, limit: int = 50) -> List["SessionIntent"]: d = dict(row) sid = d["session_id"] campaign_ids = self.get_campaign_ids_for_session(sid) - results.append(SessionIntent( - session_id=sid, - planned_intent=d.get("planned_intent"), - actual_summary=d.get("actual_summary"), - campaign_ids=campaign_ids, - created_at=datetime.fromisoformat(d["created_at"]), - completed_at=datetime.fromisoformat(d["completed_at"]) if d.get("completed_at") else None, - )) + results.append( + SessionIntent( + session_id=sid, + planned_intent=d.get("planned_intent"), + actual_summary=d.get("actual_summary"), + campaign_ids=campaign_ids, + created_at=datetime.fromisoformat(d["created_at"]), + completed_at=datetime.fromisoformat(d["completed_at"]) + if d.get("completed_at") + else None, + ) + ) return results - def get_campaign(self, campaign_id: str) -> Optional[Campaign]: + def get_campaign(self, campaign_id: str) -> Campaign | None: """Get a specific campaign by exact ID.""" - row = self._conn.execute( - "SELECT * FROM campaigns WHERE id = ?", (campaign_id,) - ).fetchone() + row = self._conn.execute("SELECT * FROM campaigns WHERE id = ?", (campaign_id,)).fetchone() return self._row_to_campaign(row) if row else None - def resolve_campaign(self, ref: str) -> Optional[Campaign]: + def resolve_campaign(self, ref: str) -> Campaign | None: """Resolve a campaign by UUID, shorthand, UUID prefix, or description. Tries exact ID first, then falls back to _resolve_campaign_label @@ -159,7 +161,7 @@ def update_campaign_status(self, campaign_id: str, status: Status): (status.value, now, campaign_id), ) - def delete_campaign(self, campaign_id: str, cascade: bool = True) -> Dict[str, int]: + def delete_campaign(self, campaign_id: str, cascade: bool = True) -> dict[str, int]: """ Delete a campaign and optionally its children and plan items. 
@@ -195,18 +197,21 @@ def _delete_recursive(cid: str): # Delete plan items r = self._conn.execute( - "DELETE FROM plan_items WHERE campaign_id = ?", (cid,), + "DELETE FROM plan_items WHERE campaign_id = ?", + (cid,), ) counts["plan_items"] += r.rowcount # Delete campaign participants self._conn.execute( - "DELETE FROM campaign_participants WHERE campaign_id = ?", (cid,), + "DELETE FROM campaign_participants WHERE campaign_id = ?", + (cid,), ) # Delete campaign r = self._conn.execute( - "DELETE FROM campaigns WHERE id = ?", (cid,), + "DELETE FROM campaigns WHERE id = ?", + (cid,), ) counts["campaigns"] += r.rowcount @@ -215,7 +220,7 @@ def _delete_recursive(cid: str): return counts - def get_subcampaigns(self, campaign_id: str) -> List[Campaign]: + def get_subcampaigns(self, campaign_id: str) -> list[Campaign]: """Get direct children of a campaign.""" rows = self._conn.execute( "SELECT * FROM campaigns WHERE parent_id = ? ORDER BY created_at", @@ -223,14 +228,14 @@ def get_subcampaigns(self, campaign_id: str) -> List[Campaign]: ).fetchall() return [self._row_to_campaign(row) for row in rows] - def get_nth_subcampaign(self, parent_id: str, n: int) -> Optional[Campaign]: + def get_nth_subcampaign(self, parent_id: str, n: int) -> Campaign | None: """Get the nth child campaign (1-indexed) of a parent, ordered by creation.""" phases = self.get_subcampaigns(parent_id) if 1 <= n <= len(phases): return phases[n - 1] return None - def get_campaign_tree(self, campaign_id: str) -> Dict[str, Any]: + def get_campaign_tree(self, campaign_id: str) -> dict[str, Any]: """Get a campaign and all its descendants as a tree.""" campaign = self.get_campaign(campaign_id) if not campaign: @@ -241,12 +246,11 @@ def get_campaign_tree(self, campaign_id: str) -> Dict[str, Any]: "children": [self.get_campaign_tree(c.id) for c in children], } - def get_root_campaigns(self, status: Optional[str] = "active") -> List[Campaign]: + def get_root_campaigns(self, status: str | None = "active") -> 
list[Campaign]: """Get top-level campaigns (no parent). If status is None, returns all.""" if status is None: rows = self._conn.execute( - "SELECT * FROM campaigns WHERE parent_id IS NULL " - "ORDER BY updated_at DESC LIMIT 50" + "SELECT * FROM campaigns WHERE parent_id IS NULL ORDER BY updated_at DESC LIMIT 50" ).fetchall() else: rows = self._conn.execute( @@ -259,11 +263,11 @@ def get_root_campaigns(self, status: Optional[str] = "active") -> List[Campaign] def update_campaign( self, campaign_id: str, - description: Optional[str] = None, - shorthand: Optional[str] = None, - summary: Optional[str] = None, - target: Optional[str] = None, - parent_id: Optional[str] = None, + description: str | None = None, + shorthand: str | None = None, + summary: str | None = None, + target: str | None = None, + parent_id: str | None = None, ): """Update campaign fields. Only non-None values are applied.""" now = self._now() @@ -310,7 +314,7 @@ def unshare_campaign(self, campaign_id: str): (self._now(), campaign_id), ) - def get_shared_campaigns(self) -> List[Campaign]: + def get_shared_campaigns(self) -> list[Campaign]: """Get all campaigns marked as shared.""" rows = self._conn.execute( "SELECT * FROM campaigns WHERE is_shared = 1 ORDER BY created_at", @@ -326,7 +330,7 @@ def add_campaign_participant(self, campaign_id: str, instance_id: str, hostname: (campaign_id, instance_id, hostname, self._now()), ) - def get_campaign_participants(self, campaign_id: str) -> List[Dict]: + def get_campaign_participants(self, campaign_id: str) -> list[dict]: """Get all participants for a campaign.""" rows = self._conn.execute( "SELECT * FROM campaign_participants WHERE campaign_id = ? 
ORDER BY joined_at", @@ -380,22 +384,23 @@ def _row_to_campaign(self, row: sqlite3.Row) -> Campaign: def create_project( self, description: str, - campaign_id: Optional[str] = None, - project_id: Optional[str] = None, + campaign_id: str | None = None, + project_id: str | None = None, ) -> str: """Create a new project. Returns project ID.""" pid = project_id or self._gen_id() now = self._now() with self._tx(): self._conn.execute( - "INSERT INTO projects (id, description, campaign_id, status, created_at, updated_at) " + "INSERT INTO projects" + " (id, description, campaign_id, status, created_at, updated_at) " "VALUES (?, ?, ?, 'active', ?, ?)", (pid, description, campaign_id, now, now), ) logger.info(f"Created project {pid}: {description}") return pid - def get_active_projects(self) -> List[Project]: + def get_active_projects(self) -> list[Project]: """Get all active projects.""" rows = self._conn.execute( "SELECT * FROM projects WHERE status = 'active' ORDER BY created_at DESC" @@ -420,8 +425,8 @@ def _row_to_project(self, row: sqlite3.Row) -> Project: def create_session_intent( self, session_id: str, - planned_intent: Optional[str] = None, - campaign_ids: Optional[List[str]] = None, + planned_intent: str | None = None, + campaign_ids: list[str] | None = None, ): """Create or update session intent, optionally linking to campaigns.""" now = self._now() @@ -436,7 +441,7 @@ def create_session_intent( for cid in campaign_ids: self.link_session_campaign(session_id, cid) - def get_current_session_intent(self) -> Optional[SessionIntent]: + def get_current_session_intent(self) -> SessionIntent | None: """Get the most recent incomplete session intent.""" row = self._conn.execute( "SELECT * FROM session_intents WHERE completed_at IS NULL " @@ -453,7 +458,9 @@ def get_current_session_intent(self) -> Optional[SessionIntent]: actual_summary=d.get("actual_summary"), campaign_ids=campaign_ids, created_at=datetime.fromisoformat(d["created_at"]), - 
completed_at=datetime.fromisoformat(d["completed_at"]) if d.get("completed_at") else None, + completed_at=datetime.fromisoformat(d["completed_at"]) + if d.get("completed_at") + else None, ) def complete_session_intent(self, session_id: str, actual_summary: str): @@ -476,8 +483,7 @@ def link_session_campaign(self, session_id: str, campaign_id: str): with self._tx(): # Ensure session_intents row exists (FK target) self._conn.execute( - "INSERT OR IGNORE INTO session_intents " - "(session_id, created_at) VALUES (?, ?)", + "INSERT OR IGNORE INTO session_intents (session_id, created_at) VALUES (?, ?)", (session_id, now), ) self._conn.execute( @@ -494,16 +500,15 @@ def unlink_session_campaign(self, session_id: str, campaign_id: str): (session_id, campaign_id), ) - def get_campaign_ids_for_session(self, session_id: str) -> List[str]: + def get_campaign_ids_for_session(self, session_id: str) -> list[str]: """Get campaign IDs linked to a session.""" rows = self._conn.execute( - "SELECT campaign_id FROM session_campaigns WHERE session_id = ? " - "ORDER BY linked_at", + "SELECT campaign_id FROM session_campaigns WHERE session_id = ? 
ORDER BY linked_at", (session_id,), ).fetchall() return [row["campaign_id"] for row in rows] - def get_campaigns_for_session(self, session_id: str) -> List[Campaign]: + def get_campaigns_for_session(self, session_id: str) -> list[Campaign]: """Get campaigns linked to a session.""" rows = self._conn.execute( "SELECT c.* FROM campaigns c " @@ -513,7 +518,7 @@ def get_campaigns_for_session(self, session_id: str) -> List[Campaign]: ).fetchall() return [self._row_to_campaign(row) for row in rows] - def get_sessions_for_campaign(self, campaign_id: str) -> List[SessionIntent]: + def get_sessions_for_campaign(self, campaign_id: str) -> list[SessionIntent]: """Get session intents linked to a campaign.""" rows = self._conn.execute( "SELECT si.* FROM session_intents si " @@ -526,14 +531,18 @@ def get_sessions_for_campaign(self, campaign_id: str) -> List[SessionIntent]: d = dict(row) sid = d["session_id"] cids = self.get_campaign_ids_for_session(sid) - results.append(SessionIntent( - session_id=sid, - planned_intent=d.get("planned_intent"), - actual_summary=d.get("actual_summary"), - campaign_ids=cids, - created_at=datetime.fromisoformat(d["created_at"]), - completed_at=datetime.fromisoformat(d["completed_at"]) if d.get("completed_at") else None, - )) + results.append( + SessionIntent( + session_id=sid, + planned_intent=d.get("planned_intent"), + actual_summary=d.get("actual_summary"), + campaign_ids=cids, + created_at=datetime.fromisoformat(d["created_at"]), + completed_at=datetime.fromisoformat(d["completed_at"]) + if d.get("completed_at") + else None, + ) + ) return results # ================================================================== @@ -543,14 +552,14 @@ def get_sessions_for_campaign(self, campaign_id: str) -> List[SessionIntent]: def create_planned_session( self, scheduled_date: str, - title: Optional[str] = None, - notes: Optional[str] = None, - scheduled_time: Optional[str] = None, - estimated_duration_minutes: Optional[int] = None, - acquisition_params: 
Optional[Dict] = None, - source_session_id: Optional[str] = None, - campaign_ids: Optional[List[str]] = None, - planned_session_id: Optional[str] = None, + title: str | None = None, + notes: str | None = None, + scheduled_time: str | None = None, + estimated_duration_minutes: int | None = None, + acquisition_params: dict | None = None, + source_session_id: str | None = None, + campaign_ids: list[str] | None = None, + planned_session_id: str | None = None, ) -> str: """Create a planned imaging session. Returns its ID.""" psid = planned_session_id or self._gen_id() @@ -563,19 +572,27 @@ def create_planned_session( " source_session_id, status, created_at, updated_at) " "VALUES (?, ?, ?, ?, ?, ?, ?, ?, 'planned', ?, ?)", ( - psid, title, notes, scheduled_date, scheduled_time, + psid, + title, + notes, + scheduled_date, + scheduled_time, estimated_duration_minutes, json.dumps(acquisition_params) if acquisition_params else None, - source_session_id, now, now, + source_session_id, + now, + now, ), ) if campaign_ids: for cid in campaign_ids: self.link_planned_session_campaign(psid, cid) - logger.info(f"Created planned session {psid} for {scheduled_date}: {title or notes or '(untitled)'}") + logger.info( + f"Created planned session {psid} for {scheduled_date}: {title or notes or '(untitled)'}" + ) return psid - def get_planned_session(self, planned_session_id: str) -> Optional[PlannedSession]: + def get_planned_session(self, planned_session_id: str) -> PlannedSession | None: """Get a specific planned session.""" row = self._conn.execute( "SELECT * FROM planned_sessions WHERE id = ?", @@ -585,11 +602,11 @@ def get_planned_session(self, planned_session_id: str) -> Optional[PlannedSessio def get_planned_sessions( self, - status: Optional[str] = None, - campaign_id: Optional[str] = None, - from_date: Optional[str] = None, - to_date: Optional[str] = None, - ) -> List[PlannedSession]: + status: str | None = None, + campaign_id: str | None = None, + from_date: str | None = None, + 
to_date: str | None = None, + ) -> list[PlannedSession]: """ Query planned sessions with optional filters. @@ -627,7 +644,7 @@ def get_planned_sessions( rows = self._conn.execute(query, params).fetchall() return [self._row_to_planned_session(row) for row in rows] - def get_upcoming_sessions(self, limit: int = 10) -> List[PlannedSession]: + def get_upcoming_sessions(self, limit: int = 10) -> list[PlannedSession]: """Get upcoming planned sessions (today and future, status=planned).""" today = datetime.now().strftime("%Y-%m-%d") rows = self._conn.execute( @@ -638,7 +655,7 @@ def get_upcoming_sessions(self, limit: int = 10) -> List[PlannedSession]: ).fetchall() return [self._row_to_planned_session(row) for row in rows] - def get_todays_sessions(self) -> List[PlannedSession]: + def get_todays_sessions(self) -> list[PlannedSession]: """Get planned sessions for today.""" today = datetime.now().strftime("%Y-%m-%d") rows = self._conn.execute( @@ -652,15 +669,15 @@ def get_todays_sessions(self) -> List[PlannedSession]: def update_planned_session( self, planned_session_id: str, - title: Optional[str] = None, - notes: Optional[str] = None, - scheduled_date: Optional[str] = None, - scheduled_time: Optional[str] = None, - estimated_duration_minutes: Optional[int] = None, - acquisition_params: Optional[Dict] = None, - source_session_id: Optional[str] = None, - status: Optional[PlannedSessionStatus] = None, - session_id: Optional[str] = None, + title: str | None = None, + notes: str | None = None, + scheduled_date: str | None = None, + scheduled_time: str | None = None, + estimated_duration_minutes: int | None = None, + acquisition_params: dict | None = None, + source_session_id: str | None = None, + status: PlannedSessionStatus | None = None, + session_id: str | None = None, ): """Update a planned session. 
Only non-None values are applied.""" now = self._now() @@ -729,7 +746,7 @@ def unlink_planned_session_campaign(self, planned_session_id: str, campaign_id: (planned_session_id, campaign_id), ) - def get_campaign_ids_for_planned_session(self, planned_session_id: str) -> List[str]: + def get_campaign_ids_for_planned_session(self, planned_session_id: str) -> list[str]: """Get campaign IDs linked to a planned session.""" rows = self._conn.execute( "SELECT campaign_id FROM planned_session_campaigns " @@ -749,7 +766,9 @@ def _row_to_planned_session(self, row: sqlite3.Row) -> PlannedSession: scheduled_date=d.get("scheduled_date"), scheduled_time=d.get("scheduled_time"), estimated_duration_minutes=d.get("estimated_duration_minutes"), - acquisition_params=json.loads(d["acquisition_params"]) if d.get("acquisition_params") else None, + acquisition_params=json.loads(d["acquisition_params"]) + if d.get("acquisition_params") + else None, source_session_id=d.get("source_session_id"), status=PlannedSessionStatus(d.get("status", "planned")), session_id=d.get("session_id"), diff --git a/gently/harness/memory/_ml_pipelines.py b/gently/harness/memory/_ml_pipelines.py index a93f021d..e15a8eb6 100644 --- a/gently/harness/memory/_ml_pipelines.py +++ b/gently/harness/memory/_ml_pipelines.py @@ -9,7 +9,9 @@ import json import logging -from typing import Any, Dict, List, Optional +from typing import Any + +from ._protocols import StoreProtocol logger = logging.getLogger(__name__) @@ -74,7 +76,7 @@ """ -class MlPipelinesMixin: +class MlPipelinesMixin(StoreProtocol): """ContextStore mixin for ML pipeline management.""" def _ensure_ml_tables(self): @@ -91,10 +93,10 @@ def create_ml_pipeline( campaign_id: str, name: str, task: str = "embryo_stage_classification", - model_config: Optional[Dict] = None, - data_split: Optional[Dict] = None, - training_config: Optional[Dict] = None, - ) -> Dict[str, Any]: + model_config: dict | None = None, + data_split: dict | None = None, + training_config: dict | 
None = None, + ) -> dict[str, Any]: """Create a new ML pipeline.""" self._ensure_ml_tables() pipeline_id = self._gen_id() @@ -106,16 +108,22 @@ def create_ml_pipeline( training_config, created_at, updated_at) VALUES (?, ?, ?, ?, 'planned', ?, ?, ?, ?, ?)""", ( - pipeline_id, campaign_id, name, task, + pipeline_id, + campaign_id, + name, + task, json.dumps(model_config) if model_config else None, json.dumps(data_split) if data_split else None, json.dumps(training_config) if training_config else None, - now, now, + now, + now, ), ) - return self.get_ml_pipeline(pipeline_id) + pipeline = self.get_ml_pipeline(pipeline_id) + assert pipeline is not None + return pipeline - def get_ml_pipeline(self, pipeline_id: str) -> Optional[Dict[str, Any]]: + def get_ml_pipeline(self, pipeline_id: str) -> dict[str, Any] | None: """Get a pipeline by ID.""" self._ensure_ml_tables() row = self._conn.execute( @@ -125,7 +133,7 @@ def get_ml_pipeline(self, pipeline_id: str) -> Optional[Dict[str, Any]]: return None return self._row_to_pipeline(row) - def list_ml_pipelines(self, campaign_id: Optional[str] = None) -> List[Dict[str, Any]]: + def list_ml_pipelines(self, campaign_id: str | None = None) -> list[dict[str, Any]]: """List pipelines, optionally filtered by campaign.""" self._ensure_ml_tables() if campaign_id: @@ -139,11 +147,18 @@ def list_ml_pipelines(self, campaign_id: Optional[str] = None) -> List[Dict[str, ).fetchall() return [self._row_to_pipeline(r) for r in rows] - def update_ml_pipeline(self, pipeline_id: str, **kwargs) -> Optional[Dict[str, Any]]: + def update_ml_pipeline(self, pipeline_id: str, **kwargs) -> dict[str, Any] | None: """Update pipeline fields.""" self._ensure_ml_tables() - allowed = {"status", "model_config", "data_split", "training_config", - "best_run_id", "best_accuracy", "name"} + allowed = { + "status", + "model_config", + "data_split", + "training_config", + "best_run_id", + "best_accuracy", + "name", + } updates = [] values = [] for k, v in 
kwargs.items(): @@ -168,7 +183,7 @@ def update_ml_pipeline(self, pipeline_id: str, **kwargs) -> Optional[Dict[str, A ) return self.get_ml_pipeline(pipeline_id) - def _row_to_pipeline(self, row) -> Dict[str, Any]: + def _row_to_pipeline(self, row) -> dict[str, Any]: return { "id": row["id"], "campaign_id": row["campaign_id"], @@ -177,7 +192,9 @@ def _row_to_pipeline(self, row) -> Dict[str, Any]: "status": row["status"], "model_config": json.loads(row["model_config"]) if row["model_config"] else None, "data_split": json.loads(row["data_split"]) if row["data_split"] else None, - "training_config": json.loads(row["training_config"]) if row["training_config"] else None, + "training_config": json.loads(row["training_config"]) + if row["training_config"] + else None, "best_run_id": row["best_run_id"], "best_accuracy": row["best_accuracy"], "created_at": row["created_at"], @@ -191,11 +208,11 @@ def _row_to_pipeline(self, row) -> Dict[str, Any]: def create_training_run( self, pipeline_id: str, - model_config: Optional[Dict] = None, - training_config: Optional[Dict] = None, - data_split: Optional[Dict] = None, + model_config: dict | None = None, + training_config: dict | None = None, + data_split: dict | None = None, peer_instance_id: str = "", - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Create a new training run.""" self._ensure_ml_tables() run_id = self._gen_id() @@ -206,16 +223,19 @@ def create_training_run( data_split, peer_instance_id) VALUES (?, ?, 'planned', ?, ?, ?, ?)""", ( - run_id, pipeline_id, + run_id, + pipeline_id, json.dumps(model_config) if model_config else None, json.dumps(training_config) if training_config else None, json.dumps(data_split) if data_split else None, peer_instance_id, ), ) - return self.get_training_run(run_id) + run = self.get_training_run(run_id) + assert run is not None + return run - def get_training_run(self, run_id: str) -> Optional[Dict[str, Any]]: + def get_training_run(self, run_id: str) -> dict[str, Any] | None: """Get a 
training run by ID.""" self._ensure_ml_tables() row = self._conn.execute( @@ -225,7 +245,7 @@ def get_training_run(self, run_id: str) -> Optional[Dict[str, Any]]: return None return self._row_to_run(row) - def list_training_runs(self, pipeline_id: str) -> List[Dict[str, Any]]: + def list_training_runs(self, pipeline_id: str) -> list[dict[str, Any]]: """List runs for a pipeline.""" self._ensure_ml_tables() rows = self._conn.execute( @@ -234,13 +254,22 @@ def list_training_runs(self, pipeline_id: str) -> List[Dict[str, Any]]: ).fetchall() return [self._row_to_run(r) for r in rows] - def update_training_run(self, run_id: str, **kwargs) -> Optional[Dict[str, Any]]: + def update_training_run(self, run_id: str, **kwargs) -> dict[str, Any] | None: """Update training run fields.""" self._ensure_ml_tables() allowed = { - "status", "current_epoch", "total_epochs", "train_loss", "val_loss", - "val_accuracy", "best_val_accuracy", "model_weights_path", "metrics_path", - "started_at", "completed_at", "error_message", + "status", + "current_epoch", + "total_epochs", + "train_loss", + "val_loss", + "val_accuracy", + "best_val_accuracy", + "model_weights_path", + "metrics_path", + "started_at", + "completed_at", + "error_message", } updates = [] values = [] @@ -261,13 +290,15 @@ def update_training_run(self, run_id: str, **kwargs) -> Optional[Dict[str, Any]] ) return self.get_training_run(run_id) - def _row_to_run(self, row) -> Dict[str, Any]: + def _row_to_run(self, row) -> dict[str, Any]: return { "id": row["id"], "pipeline_id": row["pipeline_id"], "status": row["status"], "model_config": json.loads(row["model_config"]) if row["model_config"] else None, - "training_config": json.loads(row["training_config"]) if row["training_config"] else None, + "training_config": json.loads(row["training_config"]) + if row["training_config"] + else None, "data_split": json.loads(row["data_split"]) if row["data_split"] else None, "current_epoch": row["current_epoch"], "total_epochs": 
row["total_epochs"], @@ -289,15 +320,15 @@ def _row_to_run(self, row) -> Dict[str, Any]: def save_data_assessment( self, - pipeline_id: Optional[str] = None, + pipeline_id: str | None = None, total_sessions: int = 0, total_embryos: int = 0, total_volumes: int = 0, annotated_embryos: int = 0, - stage_distribution: Optional[Dict] = None, - coverage_gaps: Optional[List] = None, + stage_distribution: dict | None = None, + coverage_gaps: list | None = None, quality_notes: str = "", - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Save a data assessment snapshot.""" self._ensure_ml_tables() assessment_id = self._gen_id() @@ -310,16 +341,23 @@ def save_data_assessment( quality_notes, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", ( - assessment_id, pipeline_id, total_sessions, total_embryos, - total_volumes, annotated_embryos, + assessment_id, + pipeline_id, + total_sessions, + total_embryos, + total_volumes, + annotated_embryos, json.dumps(stage_distribution) if stage_distribution else None, json.dumps(coverage_gaps) if coverage_gaps else None, - quality_notes, now, + quality_notes, + now, ), ) - return self.get_data_assessment(assessment_id) + assessment = self.get_data_assessment(assessment_id) + assert assessment is not None + return assessment - def get_data_assessment(self, assessment_id: str) -> Optional[Dict[str, Any]]: + def get_data_assessment(self, assessment_id: str) -> dict[str, Any] | None: """Get a data assessment by ID.""" self._ensure_ml_tables() row = self._conn.execute( @@ -334,7 +372,9 @@ def get_data_assessment(self, assessment_id: str) -> Optional[Dict[str, Any]]: "total_embryos": row["total_embryos"], "total_volumes": row["total_volumes"], "annotated_embryos": row["annotated_embryos"], - "stage_distribution": json.loads(row["stage_distribution"]) if row["stage_distribution"] else None, + "stage_distribution": json.loads(row["stage_distribution"]) + if row["stage_distribution"] + else None, "coverage_gaps": json.loads(row["coverage_gaps"]) if 
row["coverage_gaps"] else None, "quality_notes": row["quality_notes"], "created_at": row["created_at"], diff --git a/gently/harness/memory/_plans.py b/gently/harness/memory/_plans.py index 0a872e2f..96e4d18a 100644 --- a/gently/harness/memory/_plans.py +++ b/gently/harness/memory/_plans.py @@ -9,8 +9,9 @@ import logging import sqlite3 from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any +from ._protocols import StoreProtocol from .model import ( BenchSpec, ImagingSpec, @@ -22,7 +23,7 @@ logger = logging.getLogger(__name__) -class PlansMixin: +class PlansMixin(StoreProtocol): """Plan items, templates, snapshots, and dependency management.""" # ================================================================== @@ -34,15 +35,15 @@ def create_plan_item( campaign_id: str, type: str, title: str, - description: Optional[str] = None, - spec: Optional[Dict] = None, - inherit_from: Optional[str] = None, - planned_session_id: Optional[str] = None, + description: str | None = None, + spec: dict | None = None, + inherit_from: str | None = None, + planned_session_id: str | None = None, phase_order: int = -1, - depends_on: Optional[List[str]] = None, - item_id: Optional[str] = None, - references: Optional[List[Dict]] = None, - estimated_days: Optional[int] = None, + depends_on: list[str] | None = None, + item_id: str | None = None, + references: list[dict] | None = None, + estimated_days: int | None = None, ) -> str: """Create a plan item. Returns its ID. 
@@ -55,8 +56,7 @@ def create_plan_item( if phase_order < 0: # Auto-assign: next number in this campaign row = self._conn.execute( - "SELECT COALESCE(MAX(phase_order), 0) FROM plan_items " - "WHERE campaign_id = ?", + "SELECT COALESCE(MAX(phase_order), 0) FROM plan_items WHERE campaign_id = ?", (campaign_id,), ).fetchone() phase_order = row[0] + 1 @@ -65,14 +65,23 @@ def create_plan_item( self._conn.execute( "INSERT INTO plan_items " "(id, campaign_id, type, title, description, spec, inherit_from, " - " planned_session_id, estimated_days, phase_order, \"references\", status, created_at, updated_at) " + " planned_session_id, estimated_days, phase_order," + ' "references", status, created_at, updated_at) ' "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'planned', ?, ?)", ( - pid, campaign_id, type, title, description, + pid, + campaign_id, + type, + title, + description, json.dumps(spec) if spec else None, - inherit_from, planned_session_id, estimated_days, phase_order, + inherit_from, + planned_session_id, + estimated_days, + phase_order, json.dumps(references) if references else None, - now, now, + now, + now, ), ) if depends_on: @@ -85,16 +94,16 @@ def create_plan_item( logger.info(f"Created plan item {pid} [{type}] #{phase_order}: {title}") return pid - def get_plan_item(self, item_id: str) -> Optional[PlanItem]: + def get_plan_item(self, item_id: str) -> PlanItem | None: """Get a specific plan item.""" - row = self._conn.execute( - "SELECT * FROM plan_items WHERE id = ?", (item_id,) - ).fetchone() + row = self._conn.execute("SELECT * FROM plan_items WHERE id = ?", (item_id,)).fetchone() return self._row_to_plan_item(row) if row else None def resolve_plan_item( - self, ref: str, campaign_id: Optional[str] = None, - ) -> Optional[PlanItem]: + self, + ref: str, + campaign_id: str | None = None, + ) -> PlanItem | None: """Resolve a human-friendly plan item reference. 
Supported formats: @@ -114,16 +123,14 @@ def resolve_plan_item( ref = ref.strip().lower() # --- Direct ID match (UUID prefix) --- - row = self._conn.execute( - "SELECT * FROM plan_items WHERE id = ?", (ref,) - ).fetchone() + row = self._conn.execute("SELECT * FROM plan_items WHERE id = ?", (ref,)).fetchone() if row: return self._row_to_plan_item(row) # Also try UUID prefix match (e.g. first few chars) - if len(ref) >= 4 and re.match(r'^[0-9a-f]+$', ref): + if len(ref) >= 4 and re.match(r"^[0-9a-f]+$", ref): row = self._conn.execute( - "SELECT * FROM plan_items WHERE id LIKE ?", (ref + '%',) + "SELECT * FROM plan_items WHERE id LIKE ?", (ref + "%",) ).fetchone() if row: return self._row_to_plan_item(row) @@ -133,7 +140,7 @@ def resolve_plan_item( task_num = None # "campaign.phase.task" — e.g. "nerve-ring.1.3" or "ec11.2.1" - m = re.match(r'^([^.\s]+)\.(\d+)\.(\d+)$', ref) + m = re.match(r"^([^.\s]+)\.(\d+)\.(\d+)$", ref) if m: campaign_label = m.group(1) phase_num, task_num = int(m.group(2)), int(m.group(3)) @@ -144,25 +151,25 @@ def resolve_plan_item( # "1.3" or "2.1" if not task_num: - m = re.match(r'^(\d+)\.(\d+)$', ref) + m = re.match(r"^(\d+)\.(\d+)$", ref) if m: phase_num, task_num = int(m.group(1)), int(m.group(2)) # "task 3 of phase 1" / "task 3 phase 1" if not task_num: - m = re.search(r'task\s+(\d+)\s+(?:of\s+)?phase\s+(\d+)', ref) + m = re.search(r"task\s+(\d+)\s+(?:of\s+)?phase\s+(\d+)", ref) if m: task_num, phase_num = int(m.group(1)), int(m.group(2)) # "phase 1 task 3" if not task_num: - m = re.search(r'phase\s+(\d+)\s+task\s+(\d+)', ref) + m = re.search(r"phase\s+(\d+)\s+task\s+(\d+)", ref) if m: phase_num, task_num = int(m.group(1)), int(m.group(2)) # "task 3" / "#3" / just "3" if not task_num: - m = re.match(r'^(?:task\s+|#)?(\d+)$', ref) + m = re.match(r"^(?:task\s+|#)?(\d+)$", ref) if m: task_num = int(m.group(1)) @@ -212,7 +219,7 @@ def resolve_plan_item( return None - def _resolve_campaign_label(self, label: str) -> Optional[str]: + def 
_resolve_campaign_label(self, label: str) -> str | None: """Resolve a campaign shorthand or UUID prefix to an ID. Checks shorthand (case-insensitive), then UUID prefix, then @@ -232,7 +239,7 @@ def _resolve_campaign_label(self, label: str) -> Optional[str]: if len(label) >= 4: row = self._conn.execute( "SELECT id FROM campaigns WHERE id LIKE ? AND parent_id IS NULL", - (label_lower + '%',), + (label_lower + "%",), ).fetchone() if row: return row["id"] @@ -240,7 +247,7 @@ def _resolve_campaign_label(self, label: str) -> Optional[str]: # Description substring match (first word or hyphenated slug) row = self._conn.execute( "SELECT id FROM campaigns WHERE LOWER(description) LIKE ? AND parent_id IS NULL", - ('%' + label_lower + '%',), + ("%" + label_lower + "%",), ).fetchone() if row: return row["id"] @@ -249,11 +256,11 @@ def _resolve_campaign_label(self, label: str) -> Optional[str]: def get_plan_items( self, - campaign_id: Optional[str] = None, - status: Optional[str] = None, - type: Optional[str] = None, + campaign_id: str | None = None, + status: str | None = None, + type: str | None = None, include_children: bool = False, - ) -> List[PlanItem]: + ) -> list[PlanItem]: """ Query plan items with optional filters. 
@@ -296,17 +303,17 @@ def get_plan_items( def update_plan_item( self, item_id: str, - title: Optional[str] = None, - description: Optional[str] = None, - status: Optional[PlanItemStatus] = None, - outcome: Optional[str] = None, - spec: Optional[Dict] = None, - planned_session_id: Optional[str] = None, - session_id: Optional[str] = None, - phase_order: Optional[int] = None, - campaign_id: Optional[str] = None, - references: Optional[List[Dict]] = None, - estimated_days: Optional[int] = None, + title: str | None = None, + description: str | None = None, + status: PlanItemStatus | None = None, + outcome: str | None = None, + spec: dict | None = None, + planned_session_id: str | None = None, + session_id: str | None = None, + phase_order: int | None = None, + campaign_id: str | None = None, + references: list[dict] | None = None, + estimated_days: int | None = None, ): """Update a plan item. Only non-None values are applied.""" now = self._now() @@ -334,7 +341,7 @@ def update_plan_item( updates.append("phase_order = ?") values.append(phase_order) if references is not None: - updates.append("\"references\" = ?") + updates.append('"references" = ?') values.append(json.dumps(references)) if not updates: return @@ -350,10 +357,12 @@ def update_plan_item( def complete_plan_item(self, item_id: str, outcome: str): """Mark a plan item as completed with an outcome description.""" self.update_plan_item( - item_id, status=PlanItemStatus.COMPLETED, outcome=outcome, + item_id, + status=PlanItemStatus.COMPLETED, + outcome=outcome, ) - def skip_plan_item(self, item_id: str, reason: Optional[str] = None): + def skip_plan_item(self, item_id: str, reason: str | None = None): """Mark a plan item as skipped.""" self.update_plan_item( item_id, @@ -369,12 +378,12 @@ def delete_plan_item(self, item_id: str) -> bool: with self._tx(): # Remove dependency links (both directions) self._conn.execute( - "DELETE FROM plan_item_dependencies " - "WHERE item_id = ? 
OR depends_on_id = ?", + "DELETE FROM plan_item_dependencies WHERE item_id = ? OR depends_on_id = ?", (item_id, item_id), ) r = self._conn.execute( - "DELETE FROM plan_items WHERE id = ?", (item_id,), + "DELETE FROM plan_items WHERE id = ?", + (item_id,), ) deleted = r.rowcount > 0 if deleted: @@ -394,12 +403,11 @@ def remove_plan_item_dependency(self, item_id: str, depends_on_id: str): """Remove a dependency between plan items.""" with self._tx(): self._conn.execute( - "DELETE FROM plan_item_dependencies " - "WHERE item_id = ? AND depends_on_id = ?", + "DELETE FROM plan_item_dependencies WHERE item_id = ? AND depends_on_id = ?", (item_id, depends_on_id), ) - def get_plan_item_dependencies(self, item_id: str) -> List[str]: + def get_plan_item_dependencies(self, item_id: str) -> list[str]: """Get IDs of items this item depends on.""" rows = self._conn.execute( "SELECT depends_on_id FROM plan_item_dependencies WHERE item_id = ?", @@ -407,7 +415,7 @@ def get_plan_item_dependencies(self, item_id: str) -> List[str]: ).fetchall() return [row["depends_on_id"] for row in rows] - def get_plan_item_dependents(self, item_id: str) -> List[str]: + def get_plan_item_dependents(self, item_id: str) -> list[str]: """Get IDs of items that depend on this item.""" rows = self._conn.execute( "SELECT item_id FROM plan_item_dependencies WHERE depends_on_id = ?", @@ -415,13 +423,15 @@ def get_plan_item_dependents(self, item_id: str) -> List[str]: ).fetchall() return [row["item_id"] for row in rows] - def get_unblocked_plan_items(self, campaign_id: str) -> List[PlanItem]: + def get_unblocked_plan_items(self, campaign_id: str) -> list[PlanItem]: """ Get plan items that are planned and have all dependencies completed. These are the items that can be started next. 
""" items = self.get_plan_items( - campaign_id=campaign_id, status="planned", include_children=True, + campaign_id=campaign_id, + status="planned", + include_children=True, ) unblocked = [] for item in items: @@ -433,7 +443,8 @@ def get_unblocked_plan_items(self, campaign_id: str) -> List[PlanItem]: for dep_id in item.depends_on: dep = self.get_plan_item(dep_id) if dep and dep.status not in ( - PlanItemStatus.COMPLETED, PlanItemStatus.SKIPPED, + PlanItemStatus.COMPLETED, + PlanItemStatus.SKIPPED, ): all_resolved = False break @@ -441,7 +452,7 @@ def get_unblocked_plan_items(self, campaign_id: str) -> List[PlanItem]: unblocked.append(item) return unblocked - def get_plan_status(self, campaign_id: str) -> Dict[str, Any]: + def get_plan_status(self, campaign_id: str) -> dict[str, Any]: """ Get a summary of plan progress for a campaign and its children. @@ -461,9 +472,10 @@ def get_plan_status(self, campaign_id: str) -> Dict[str, Any]: } """ items = self.get_plan_items( - campaign_id=campaign_id, include_children=True, + campaign_id=campaign_id, + include_children=True, ) - result = { + result: dict[str, Any] = { "total": len(items), "completed": 0, "in_progress": 0, @@ -488,10 +500,7 @@ def get_plan_status(self, campaign_id: str) -> Dict[str, Any]: result["by_type"][type_key]["completed"] += 1 # Pending decisions - if ( - item.type == PlanItemType.DECISION_POINT - and item.status == PlanItemStatus.PLANNED - ): + if item.type == PlanItemType.DECISION_POINT and item.status == PlanItemStatus.PLANNED: result["pending_decisions"].append(item) # Next actions = unblocked items @@ -499,7 +508,7 @@ def get_plan_status(self, campaign_id: str) -> Dict[str, Any]: return result - def resolve_imaging_spec(self, item: PlanItem) -> Optional[ImagingSpec]: + def resolve_imaging_spec(self, item: PlanItem) -> ImagingSpec | None: """ Resolve the full ImagingSpec for an item, following inheritance. 
@@ -542,7 +551,7 @@ def resolve_imaging_spec(self, item: PlanItem) -> Optional[ImagingSpec]: def save_plan_template( self, name: str, - description: Optional[str], + description: str | None, campaign_id: str, ) -> str: """ @@ -568,7 +577,7 @@ def save_plan_template( logger.info(f"Saved plan template '{name}' ({tid})") return tid - def _serialize_campaign_tree(self, campaign_id: str) -> Dict: + def _serialize_campaign_tree(self, campaign_id: str) -> dict: """Recursively serialize a campaign and its children/items.""" campaign = self.get_campaign(campaign_id) if not campaign: @@ -581,7 +590,7 @@ def _serialize_campaign_tree(self, campaign_id: str) -> Dict: all_item_ids = [it.id for it in items] serialized_items = [] for item in items: - item_data = { + item_data: dict[str, Any] = { "type": item.type.value, "title": item.title, "description": item.description, @@ -590,6 +599,7 @@ def _serialize_campaign_tree(self, campaign_id: str) -> Dict: # Serialize spec if item.imaging_spec: import dataclasses as _dc + spec_dict = {} for f in _dc.fields(item.imaging_spec): val = getattr(item.imaging_spec, f.name) @@ -598,6 +608,7 @@ def _serialize_campaign_tree(self, campaign_id: str) -> Dict: item_data["spec"] = spec_dict elif item.bench_spec: import dataclasses as _dc + spec_dict = {} for f in _dc.fields(item.bench_spec): val = getattr(item.bench_spec, f.name) @@ -633,7 +644,7 @@ def _serialize_campaign_tree(self, campaign_id: str) -> Dict: "children": serialized_children, } - def list_plan_templates(self) -> List[Dict]: + def list_plan_templates(self) -> list[dict]: """List all plan templates (id, name, description, dates).""" rows = self._conn.execute( "SELECT id, name, description, created_at, updated_at " @@ -641,7 +652,7 @@ def list_plan_templates(self) -> List[Dict]: ).fetchall() return [dict(row) for row in rows] - def get_plan_template(self, id_or_name: str) -> Optional[Dict]: + def get_plan_template(self, id_or_name: str) -> dict | None: """Get a plan template by ID or 
name.""" row = self._conn.execute( "SELECT * FROM plan_templates WHERE id = ? OR name = ?", @@ -656,7 +667,7 @@ def get_plan_template(self, id_or_name: str) -> Optional[Dict]: def apply_plan_template( self, template_id: str, - overrides: Optional[Dict] = None, + overrides: dict | None = None, ) -> str: """ Instantiate a template into a new campaign with plan items. @@ -673,9 +684,9 @@ def apply_plan_template( def _instantiate_template_tree( self, - data: Dict, - parent_id: Optional[str], - overrides: Dict, + data: dict, + parent_id: str | None, + overrides: dict, ) -> str: """Recursively create campaigns and items from template data.""" cid = self.create_campaign( @@ -687,7 +698,7 @@ def _instantiate_template_tree( # Create items, track new IDs for dependency wiring items_data = data.get("items", []) - new_item_ids: List[str] = [] + new_item_ids: list[str] = [] for item_data in items_data: spec = item_data.get("spec") @@ -697,9 +708,15 @@ def _instantiate_template_tree( spec = dict(spec) # copy for k, v in overrides.items(): if k in spec or k in ( - "strain", "genotype", "reporter", "temperature_c", - "num_slices", "exposure_ms", "interval_s", - "num_embryos", "stop_condition", + "strain", + "genotype", + "reporter", + "temperature_c", + "num_slices", + "exposure_ms", + "interval_s", + "num_embryos", + "stop_condition", ): spec[k] = v @@ -720,7 +737,8 @@ def _instantiate_template_tree( for dep_idx in dep_indices: if 0 <= dep_idx < len(new_item_ids): self.add_plan_item_dependency( - new_item_ids[idx], new_item_ids[dep_idx], + new_item_ids[idx], + new_item_ids[dep_idx], ) # Recurse into children @@ -745,8 +763,8 @@ def delete_plan_template(self, template_id: str) -> bool: def create_plan_snapshot( self, campaign_id: str, - label: Optional[str] = None, - summary: Optional[str] = None, + label: str | None = None, + summary: str | None = None, ) -> str: """Create a snapshot of the current plan state. 
@@ -770,8 +788,7 @@ def create_plan_snapshot( # Auto-increment version number for this campaign row = self._conn.execute( - "SELECT COALESCE(MAX(version_number), 0) FROM plan_snapshots " - "WHERE campaign_id = ?", + "SELECT COALESCE(MAX(version_number), 0) FROM plan_snapshots WHERE campaign_id = ?", (campaign_id,), ).fetchone() version_number = row[0] + 1 @@ -793,12 +810,19 @@ def create_plan_snapshot( " summary, label, parent_version_id, created_at) " "VALUES (?, ?, ?, ?, ?, ?, ?, ?)", ( - version_id, campaign_id, version_number, - json.dumps(snapshot_data), summary, label, - parent_version_id, now, + version_id, + campaign_id, + version_number, + json.dumps(snapshot_data), + summary, + label, + parent_version_id, + now, ), ) - logger.info(f"Created plan snapshot v{version_number} ({version_id}) for campaign {campaign_id}") + logger.info( + f"Created plan snapshot v{version_number} ({version_id}) for campaign {campaign_id}" + ) return version_id def _generate_snapshot_summary(self, campaign_id: str) -> str: @@ -811,7 +835,7 @@ def _generate_snapshot_summary(self, campaign_id: str) -> str: items = self.get_plan_items(campaign_id=campaign_id, include_children=True) # Count items by status - status_counts: Dict[str, int] = {} + status_counts: dict[str, int] = {} for item in items: key = item.status.value status_counts[key] = status_counts.get(key, 0) + 1 @@ -827,8 +851,10 @@ def _generate_snapshot_summary(self, campaign_id: str) -> str: return "\n".join(parts) def list_plan_snapshots( - self, campaign_id: str, limit: int = 50, - ) -> List[Dict]: + self, + campaign_id: str, + limit: int = 50, + ) -> list[dict]: """List snapshots for a campaign (metadata only, no blob). 
Returns list of dicts with version_id, version_number, label, @@ -843,7 +869,7 @@ def list_plan_snapshots( ).fetchall() return [dict(row) for row in rows] - def get_plan_snapshot(self, version_id: str) -> Optional[Dict]: + def get_plan_snapshot(self, version_id: str) -> dict | None: """Get a full snapshot including the parsed JSON blob.""" row = self._conn.execute( "SELECT * FROM plan_snapshots WHERE version_id = ?", @@ -901,7 +927,7 @@ def restore_plan_snapshot(self, version_id: str) -> str: ) return new_campaign_id - def _get_campaign_tree_ids(self, campaign_id: str) -> List[str]: + def _get_campaign_tree_ids(self, campaign_id: str) -> list[str]: """Get all campaign IDs in a tree (recursive).""" ids = [campaign_id] children = self._conn.execute( @@ -932,16 +958,14 @@ def _row_to_plan_item(self, row: sqlite3.Row) -> PlanItem: if spec_data: if item_type == PlanItemType.IMAGING: import dataclasses as _dc + valid = {f.name for f in _dc.fields(ImagingSpec)} - imaging_spec = ImagingSpec(**{ - k: v for k, v in spec_data.items() if k in valid - }) + imaging_spec = ImagingSpec(**{k: v for k, v in spec_data.items() if k in valid}) else: import dataclasses as _dc + valid = {f.name for f in _dc.fields(BenchSpec)} - bench_spec = BenchSpec(**{ - k: v for k, v in spec_data.items() if k in valid - }) + bench_spec = BenchSpec(**{k: v for k, v in spec_data.items() if k in valid}) references = json.loads(d["references"]) if d.get("references") else [] diff --git a/gently/harness/memory/_protocols.py b/gently/harness/memory/_protocols.py new file mode 100644 index 00000000..16162a29 --- /dev/null +++ b/gently/harness/memory/_protocols.py @@ -0,0 +1,60 @@ +""" +StoreProtocol — typing-only interface for members the memory mixins expect +from their host class (ContextStore) and from each other. + +IntentionsMixin, PlansMixin, UnderstandingMixin, and MlPipelinesMixin are +combined into ContextStore via multiple inheritance. 
Each mixin calls +methods/attributes defined either on ContextStore itself (_conn, _tx, _now, +_gen_id) or on one of the sibling mixins (e.g. PlansMixin.get_plan_items +called from IntentionsMixin). Declaring this Protocol as a base lets mypy +see those members without introducing a runtime dependency between mixins. +""" + +import sqlite3 +from contextlib import AbstractContextManager +from typing import Protocol, runtime_checkable + +from .model import Campaign, PlanItem + + +@runtime_checkable +class StoreProtocol(Protocol): + _conn: sqlite3.Connection + + def _tx(self) -> AbstractContextManager[sqlite3.Connection]: ... + + def _now(self) -> str: ... + + def _gen_id(self) -> str: ... + + def get_state(self, key: str) -> str | None: ... + + def get_campaign(self, campaign_id: str) -> Campaign | None: ... + + def get_root_campaigns(self, status: str | None = "active") -> list[Campaign]: ... + + def get_subcampaigns(self, campaign_id: str) -> list[Campaign]: ... + + def create_campaign( + self, + description: str, + shorthand: str | None = None, + summary: str | None = None, + target: str | None = None, + parent_id: str | None = None, + campaign_id: str | None = None, + ) -> str: ... + + def delete_campaign(self, campaign_id: str, cascade: bool = True) -> dict[str, int]: ... + + def update_campaign_progress(self, campaign_id: str, progress: str) -> None: ... + + def get_plan_items( + self, + campaign_id: str | None = None, + status: str | None = None, + type: str | None = None, + include_children: bool = False, + ) -> list[PlanItem]: ... + + def _resolve_campaign_label(self, label: str) -> str | None: ... 
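A minimal sketch of the pattern the `_protocols.py` docstring describes: mixins inherit from a typing-only Protocol so a type checker can see members supplied by the host class or by a sibling mixin. All names here (`HostProtocol`, `IdMixin`, `LabelMixin`, `Store`) are hypothetical and are not part of gently; the real `StoreProtocol` declares a much larger surface.

```python
from typing import Protocol


class HostProtocol(Protocol):
    """Typing-only view of what each mixin expects from its host (hypothetical)."""

    _prefix: str

    def _gen_id(self) -> str: ...

    def label_for(self, item_id: str) -> str: ...


class IdMixin(HostProtocol):
    """Calls label_for(), which is implemented on a sibling mixin."""

    def new_labelled_id(self) -> str:
        # A type checker resolves _gen_id and label_for through HostProtocol,
        # so this mixin never has to import LabelMixin or Store.
        return self.label_for(self._gen_id())


class LabelMixin(HostProtocol):
    """Provides the real label_for(); reads _prefix from the host."""

    def label_for(self, item_id: str) -> str:
        return f"{self._prefix}:{item_id}"


class Store(IdMixin, LabelMixin):
    """Host class: supplies the attribute and helper the mixins assume."""

    def __init__(self, prefix: str) -> None:
        self._prefix = prefix
        self._counter = 0

    def _gen_id(self) -> str:
        self._counter += 1
        return f"id{self._counter}"
```

In this sketch the Protocol lands after both mixins in `Store.__mro__`, so a real implementation anywhere earlier in the chain shadows the `...` stub. One consequence worth keeping in mind: if no class provides a declared method, the stub is what runs and it returns `None` at runtime, so the Protocol should list only members that some class in the final hierarchy actually implements.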
diff --git a/gently/harness/memory/_understanding.py b/gently/harness/memory/_understanding.py index 19d81f57..29cc9c9a 100644 --- a/gently/harness/memory/_understanding.py +++ b/gently/harness/memory/_understanding.py @@ -10,8 +10,8 @@ import logging import sqlite3 from datetime import datetime -from typing import Dict, List, Optional +from ._protocols import StoreProtocol from .model import ( Attention, Confidence, @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) -class UnderstandingMixin: +class UnderstandingMixin(StoreProtocol): """Observations, expectations, watchpoints, questions, learnings, embryo understanding, agent state, and batch updates.""" @@ -46,7 +46,7 @@ def _load_understanding(self) -> Understanding: learnings=self.get_learnings(), ) - def _load_embryo_states(self) -> Dict[str, EmbryoUnderstanding]: + def _load_embryo_states(self) -> dict[str, EmbryoUnderstanding]: rows = self._conn.execute( "SELECT * FROM embryo_understanding WHERE is_tracked = 1" ).fetchall() @@ -56,10 +56,14 @@ def _load_embryo_states(self) -> Dict[str, EmbryoUnderstanding]: result[d["embryo_id"]] = EmbryoUnderstanding( embryo_id=d["embryo_id"], current_stage=d.get("current_stage"), - stage_confidence=Confidence(d["stage_confidence"]) if d.get("stage_confidence") else None, + stage_confidence=Confidence(d["stage_confidence"]) + if d.get("stage_confidence") + else None, health_assessment=d.get("health_assessment"), notes=json.loads(d["notes"]) if d.get("notes") else [], - last_observed=datetime.fromisoformat(d["last_observed"]) if d.get("last_observed") else None, + last_observed=datetime.fromisoformat(d["last_observed"]) + if d.get("last_observed") + else None, is_tracked=bool(d.get("is_tracked", True)), is_hatched=bool(d.get("is_hatched", False)), needs_attention=bool(d.get("needs_attention", False)), @@ -98,7 +102,7 @@ def add_observation(self, obs: Observation): ), ) - def get_recent_observations(self, limit: int = 50) -> List[Observation]: + def 
get_recent_observations(self, limit: int = 50) -> list[Observation]: """Get recent observations.""" rows = self._conn.execute( "SELECT * FROM observations ORDER BY timestamp DESC LIMIT ?", @@ -106,11 +110,10 @@ def get_recent_observations(self, limit: int = 50) -> List[Observation]: ).fetchall() return [self._row_to_observation(row) for row in reversed(rows)] - def get_observations_for_embryo(self, embryo_id: str, limit: int = 20) -> List[Observation]: + def get_observations_for_embryo(self, embryo_id: str, limit: int = 20) -> list[Observation]: """Get observations for a specific embryo.""" rows = self._conn.execute( - "SELECT * FROM observations WHERE embryo_id = ? " - "ORDER BY timestamp DESC LIMIT ?", + "SELECT * FROM observations WHERE embryo_id = ? ORDER BY timestamp DESC LIMIT ?", (embryo_id, limit), ).fetchall() return [self._row_to_observation(row) for row in reversed(rows)] @@ -152,14 +155,14 @@ def add_expectation(self, exp: Expectation): ), ) - def get_pending_expectations(self) -> List[Expectation]: + def get_pending_expectations(self) -> list[Expectation]: """Get all pending expectations.""" rows = self._conn.execute( "SELECT * FROM expectations WHERE status = 'pending' ORDER BY expected_time" ).fetchall() return [self._row_to_expectation(row) for row in rows] - def get_expectation_for(self, target: str) -> Optional[Expectation]: + def get_expectation_for(self, target: str) -> Expectation | None: """Get the pending expectation for a specific target.""" row = self._conn.execute( "SELECT * FROM expectations WHERE target = ? 
AND status = 'pending' " @@ -211,7 +214,7 @@ def add_watchpoint(self, wp: Watchpoint): ), ) - def get_active_watchpoints(self) -> List[Watchpoint]: + def get_active_watchpoints(self) -> list[Watchpoint]: """Get all active watchpoints.""" rows = self._conn.execute( "SELECT * FROM watchpoints WHERE status = 'active' ORDER BY priority DESC, created_at" @@ -253,16 +256,14 @@ def add_question(self, q: Question): """Add a question.""" with self._tx(): self._conn.execute( - "INSERT INTO questions (id, content, status, created_at) " - "VALUES (?, ?, ?, ?)", + "INSERT INTO questions (id, content, status, created_at) VALUES (?, ?, ?, ?)", (q.id, q.content, q.status.value, q.created_at.isoformat()), ) - def get_open_questions(self) -> List[Question]: + def get_open_questions(self) -> list[Question]: """Get all open questions.""" rows = self._conn.execute( - "SELECT * FROM questions WHERE status IN ('open', 'investigating') " - "ORDER BY created_at" + "SELECT * FROM questions WHERE status IN ('open', 'investigating') ORDER BY created_at" ).fetchall() return [self._row_to_question(row) for row in rows] @@ -306,7 +307,7 @@ def add_learning(self, learning: Learning): ), ) - def get_learnings(self, limit: int = 50) -> List[Learning]: + def get_learnings(self, limit: int = 50) -> list[Learning]: """Get learnings.""" rows = self._conn.execute( "SELECT * FROM learnings ORDER BY created_at DESC LIMIT ?", @@ -331,13 +332,13 @@ def _row_to_learning(self, row: sqlite3.Row) -> Learning: def update_embryo_understanding( self, embryo_id: str, - current_stage: Optional[str] = None, - stage_confidence: Optional[Confidence] = None, - health_assessment: Optional[str] = None, - note: Optional[str] = None, - is_hatched: Optional[bool] = None, - needs_attention: Optional[bool] = None, - attention_reason: Optional[str] = None, + current_stage: str | None = None, + stage_confidence: Confidence | None = None, + health_assessment: str | None = None, + note: str | None = None, + is_hatched: bool | None 
= None, + needs_attention: bool | None = None, + attention_reason: str | None = None, ): """Update understanding of an embryo.""" now = self._now() @@ -403,11 +404,9 @@ def update_embryo_understanding( # Agent State # ================================================================== - def get_state(self, key: str) -> Optional[str]: + def get_state(self, key: str) -> str | None: """Get a state value.""" - row = self._conn.execute( - "SELECT value FROM agent_state WHERE key = ?", (key,) - ).fetchone() + row = self._conn.execute("SELECT value FROM agent_state WHERE key = ?", (key,)).fetchone() return row["value"] if row else None def set_state(self, key: str, value: str): @@ -415,8 +414,7 @@ def set_state(self, key: str, value: str): now = self._now() with self._tx(): self._conn.execute( - "INSERT OR REPLACE INTO agent_state (key, value, updated_at) " - "VALUES (?, ?, ?)", + "INSERT OR REPLACE INTO agent_state (key, value, updated_at) VALUES (?, ?, ?)", (key, value, now), ) diff --git a/gently/harness/memory/file_store.py b/gently/harness/memory/file_store.py index 84beebf2..0531dcce 100644 --- a/gently/harness/memory/file_store.py +++ b/gently/harness/memory/file_store.py @@ -28,6 +28,7 @@ assessments/{id}.yaml """ +import copy import dataclasses import json import logging @@ -35,10 +36,9 @@ import re import shutil import uuid -from contextlib import contextmanager from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any import yaml @@ -56,11 +56,11 @@ Intentions, Learning, Observation, - PlannedSession, - PlannedSessionStatus, PlanItem, PlanItemStatus, PlanItemType, + PlannedSession, + PlannedSessionStatus, Project, Question, QuestionStatus, @@ -79,8 +79,10 @@ # YAML helpers -- keep datetimes as ISO strings in files # --------------------------------------------------------------------------- + class _ISODumper(yaml.SafeDumper): """Custom dumper that serialises datetime objects as ISO strings.""" + 
pass @@ -95,6 +97,7 @@ def _datetime_representer(dumper, data): # FileContextStore # --------------------------------------------------------------------------- + class FileContextStore: """ File-based storage for the agent's context. @@ -106,8 +109,13 @@ class FileContextStore: def __init__(self, agent_dir: Path): self.agent_dir = Path(agent_dir) self._ensure_dirs() + # YAML parse cache: str(path) -> ((mtime, size), parsed). Collapses the + # O(N^2) re-parsing in campaign-tree builds; auto-invalidated by file + # mtime/size changes and explicitly on _write_yaml. Set BEFORE the index + # rebuild below, which reads YAML through the cache. + self._yaml_cache: dict[str, tuple] = {} # In-memory index: campaign_id -> folder Path - self._campaign_index: Dict[str, Path] = {} + self._campaign_index: dict[str, Path] = {} self._rebuild_campaign_index() # ------------------------------------------------------------------ @@ -149,7 +157,7 @@ def _rebuild_campaign_index(self): if data and "id" in data: self._campaign_index[data["id"]] = entry - def _campaign_folder(self, campaign_id: str) -> Optional[Path]: + def _campaign_folder(self, campaign_id: str) -> Path | None: """Return the folder for a campaign, or None.""" return self._campaign_index.get(campaign_id) @@ -174,25 +182,41 @@ def _write_yaml(self, path: Path, data): path.parent.mkdir(parents=True, exist_ok=True) tmp = path.with_suffix(".tmp") with open(tmp, "w", encoding="utf-8") as fh: - yaml.dump(data, fh, Dumper=_ISODumper, default_flow_style=False, - allow_unicode=True, sort_keys=False) + yaml.dump( + data, + fh, + Dumper=_ISODumper, + default_flow_style=False, + allow_unicode=True, + sort_keys=False, + ) # Atomic rename (on Windows this replaces the target). - if os.name == "nt": - # os.replace is atomic on Windows when on same volume. 
- os.replace(str(tmp), str(path)) - else: - os.replace(str(tmp), str(path)) + os.replace(str(tmp), str(path)) + # Invalidate the parse cache so the next read reloads (new mtime anyway). + self._yaml_cache.pop(str(path), None) def _read_yaml(self, path: Path): - """Read a YAML file; return None if missing or empty.""" - if not path.exists(): + """Read a YAML file, parse-cached by (mtime, size). Returns None if + missing or empty. The cached object is never handed out directly — every + return is a deepcopy — so callers may freely mutate the result without + corrupting the cache.""" + try: + st = path.stat() + except OSError: return None + key = str(path) + sig = (st.st_mtime, st.st_size) + cached = self._yaml_cache.get(key) + if cached is not None and cached[0] == sig: + return copy.deepcopy(cached[1]) try: - with open(path, "r", encoding="utf-8") as fh: - return yaml.safe_load(fh) + with open(path, encoding="utf-8") as fh: + data = yaml.safe_load(fh) except Exception: logger.warning(f"Failed to read {path}", exc_info=True) return None + self._yaml_cache[key] = (sig, data) + return copy.deepcopy(data) def _append_jsonl(self, path: Path, record: dict): """Append one JSON line to a file.""" @@ -223,7 +247,7 @@ def __repr__(self): def reset(self) -> dict: """Delete all data files; return counts of deleted items by category.""" - counts: Dict[str, int] = {} + counts: dict[str, int] = {} def _count_and_remove(subdir: str, label: str): d = self.agent_dir / subdir @@ -270,7 +294,9 @@ def _count_and_remove(subdir: str, label: str): self._campaign_index.clear() total = sum(counts.values()) - logger.info(f"File context store reset -- {total} items cleared from {len(counts)} categories") + logger.info( + f"File context store reset -- {total} items cleared from {len(counts)} categories" + ) return counts # ================================================================== @@ -301,8 +327,8 @@ def _load_understanding(self) -> Understanding: learnings=self.get_learnings(), ) - def 
_load_embryo_states(self) -> Dict[str, EmbryoUnderstanding]: - result: Dict[str, EmbryoUnderstanding] = {} + def _load_embryo_states(self) -> dict[str, EmbryoUnderstanding]: + result: dict[str, EmbryoUnderstanding] = {} eu_dir = self.agent_dir / "embryo_understanding" if not eu_dir.exists(): return result @@ -330,11 +356,11 @@ def _load_attention(self) -> Attention: def create_campaign( self, description: str, - shorthand: Optional[str] = None, - summary: Optional[str] = None, - target: Optional[str] = None, - parent_id: Optional[str] = None, - campaign_id: Optional[str] = None, + shorthand: str | None = None, + summary: str | None = None, + target: str | None = None, + parent_id: str | None = None, + campaign_id: str | None = None, ) -> str: cid = campaign_id or self._gen_id() now = self._now() @@ -368,7 +394,7 @@ def create_campaign( logger.info(f"Created campaign {cid} [{label}]") return cid - def get_campaign(self, campaign_id: str) -> Optional[Campaign]: + def get_campaign(self, campaign_id: str) -> Campaign | None: folder = self._campaign_folder(campaign_id) if not folder: return None @@ -377,9 +403,9 @@ def get_campaign(self, campaign_id: str) -> Optional[Campaign]: return None return self._dict_to_campaign(data) - def get_active_campaigns(self) -> List[Campaign]: + def get_active_campaigns(self) -> list[Campaign]: campaigns = [] - for cid, folder in self._campaign_index.items(): + for _cid, folder in self._campaign_index.items(): data = self._read_yaml(folder / "campaign.yaml") if data and data.get("status") == "active": campaigns.append(self._dict_to_campaign(data)) @@ -389,7 +415,7 @@ def get_active_campaigns(self) -> List[Campaign]: def count_non_active_campaigns(self) -> int: """Count campaigns whose status is not 'active'.""" count = 0 - for cid, folder in self._campaign_index.items(): + for _cid, folder in self._campaign_index.items(): data = self._read_yaml(folder / "campaign.yaml") if data and data.get("status") != "active": count += 1 @@ -402,17 
+428,17 @@ def count_session_intents(self) -> int: return 0 return sum(1 for f in si_dir.iterdir() if f.suffix in (".yaml", ".yml")) - def get_all_campaigns(self, limit: int = 50) -> List[Campaign]: + def get_all_campaigns(self, limit: int = 50) -> list[Campaign]: """Get all campaigns regardless of status, ordered by created_at descending.""" campaigns = [] - for cid, folder in self._campaign_index.items(): + for _cid, folder in self._campaign_index.items(): data = self._read_yaml(folder / "campaign.yaml") if data: campaigns.append(self._dict_to_campaign(data)) campaigns.sort(key=lambda c: c.created_at, reverse=True) return campaigns[:limit] - def get_recent_session_intents(self, limit: int = 50) -> List[SessionIntent]: + def get_recent_session_intents(self, limit: int = 50) -> list[SessionIntent]: """Get recent session intents, ordered by created_at descending.""" si_dir = self.agent_dir / "session_intents" if not si_dir.exists(): @@ -426,7 +452,7 @@ def get_recent_session_intents(self, limit: int = 50) -> List[SessionIntent]: intents.sort(key=lambda i: i.created_at, reverse=True) return intents[:limit] - def resolve_campaign(self, ref: str) -> Optional[Campaign]: + def resolve_campaign(self, ref: str) -> Campaign | None: campaign = self.get_campaign(ref) if campaign: return campaign @@ -435,7 +461,7 @@ def resolve_campaign(self, ref: str) -> Optional[Campaign]: return self.get_campaign(resolved_id) return None - def _resolve_campaign_label(self, label: str) -> Optional[str]: + def _resolve_campaign_label(self, label: str) -> str | None: label_lower = label.lower() # Shorthand match (case-insensitive), root campaigns only @@ -492,11 +518,11 @@ def update_campaign_status(self, campaign_id: str, status: Status): def update_campaign( self, campaign_id: str, - description: Optional[str] = None, - shorthand: Optional[str] = None, - summary: Optional[str] = None, - target: Optional[str] = None, - parent_id: Optional[str] = None, + description: str | None = None, + 
shorthand: str | None = None, + summary: str | None = None, + target: str | None = None, + parent_id: str | None = None, ): folder = self._campaign_folder(campaign_id) if not folder: @@ -520,8 +546,8 @@ def update_campaign( data["updated_at"] = self._now() self._write_yaml(folder / "campaign.yaml", data) - def delete_campaign(self, campaign_id: str, cascade: bool = True) -> Dict[str, int]: - counts: Dict[str, int] = {"campaigns": 0, "plan_items": 0, "dependencies": 0} + def delete_campaign(self, campaign_id: str, cascade: bool = True) -> dict[str, int]: + counts: dict[str, int] = {"campaigns": 0, "plan_items": 0, "dependencies": 0} def _delete_recursive(cid: str): if cascade: @@ -546,22 +572,22 @@ def _delete_recursive(cid: str): _delete_recursive(campaign_id) return counts - def get_subcampaigns(self, campaign_id: str) -> List[Campaign]: + def get_subcampaigns(self, campaign_id: str) -> list[Campaign]: children = [] - for cid, folder in self._campaign_index.items(): + for _cid, folder in self._campaign_index.items(): data = self._read_yaml(folder / "campaign.yaml") if data and data.get("parent_id") == campaign_id: children.append(self._dict_to_campaign(data)) children.sort(key=lambda c: c.created_at) return children - def get_nth_subcampaign(self, parent_id: str, n: int) -> Optional[Campaign]: + def get_nth_subcampaign(self, parent_id: str, n: int) -> Campaign | None: phases = self.get_subcampaigns(parent_id) if 1 <= n <= len(phases): return phases[n - 1] return None - def get_campaign_tree(self, campaign_id: str) -> Dict[str, Any]: + def get_campaign_tree(self, campaign_id: str) -> dict[str, Any]: campaign = self.get_campaign(campaign_id) if not campaign: return {} @@ -571,10 +597,10 @@ def get_campaign_tree(self, campaign_id: str) -> Dict[str, Any]: "children": [self.get_campaign_tree(c.id) for c in children], } - def get_root_campaigns(self, status: Optional[str] = "active") -> List[Campaign]: + def get_root_campaigns(self, status: str | None = "active") -> 
list[Campaign]: """Get root campaigns (no parent). If status is None, returns all.""" roots = [] - for cid, folder in self._campaign_index.items(): + for _cid, folder in self._campaign_index.items(): data = self._read_yaml(folder / "campaign.yaml") if data and data.get("parent_id") is None: if status is None or data.get("status") == status: @@ -608,9 +634,9 @@ def unshare_campaign(self, campaign_id: str): data["updated_at"] = self._now() self._write_yaml(folder / "campaign.yaml", data) - def get_shared_campaigns(self) -> List[Campaign]: + def get_shared_campaigns(self) -> list[Campaign]: shared = [] - for cid, folder in self._campaign_index.items(): + for _cid, folder in self._campaign_index.items(): data = self._read_yaml(folder / "campaign.yaml") if data and data.get("is_shared"): shared.append(self._dict_to_campaign(data)) @@ -627,16 +653,18 @@ def add_campaign_participant(self, campaign_id: str, instance_id: str, hostname: participants = data.get("participants", []) # Replace existing entry for this instance_id participants = [p for p in participants if p.get("instance_id") != instance_id] - participants.append({ - "campaign_id": campaign_id, - "instance_id": instance_id, - "hostname": hostname, - "joined_at": self._now(), - }) + participants.append( + { + "campaign_id": campaign_id, + "instance_id": instance_id, + "hostname": hostname, + "joined_at": self._now(), + } + ) data["participants"] = participants self._write_yaml(folder / "campaign.yaml", data) - def get_campaign_participants(self, campaign_id: str) -> List[Dict]: + def get_campaign_participants(self, campaign_id: str) -> list[dict]: folder = self._campaign_folder(campaign_id) if not folder: return [] @@ -683,8 +711,8 @@ def unclaim_plan_item(self, item_id: str): def create_project( self, description: str, - campaign_id: Optional[str] = None, - project_id: Optional[str] = None, + campaign_id: str | None = None, + project_id: str | None = None, ) -> str: pid = project_id or self._gen_id() now = 
self._now() @@ -698,13 +726,14 @@ def create_project( "updated_at": now, } self._write_yaml( - self.agent_dir / "projects" / f"{pid}_{slug}.yaml", data, + self.agent_dir / "projects" / f"{pid}_{slug}.yaml", + data, ) logger.info(f"Created project {pid}: {description}") return pid - def get_active_projects(self) -> List[Project]: - projects: List[Project] = [] + def get_active_projects(self) -> list[Project]: + projects: list[Project] = [] proj_dir = self.agent_dir / "projects" if not proj_dir.exists(): return projects @@ -723,19 +752,23 @@ def get_active_projects(self) -> List[Project]: def create_session_intent( self, session_id: str, - planned_intent: Optional[str] = None, - campaign_ids: Optional[List[str]] = None, + planned_intent: str | None = None, + campaign_ids: list[str] | None = None, ): now = self._now() path = self.agent_dir / "session_intents" / f"{session_id}.yaml" existing = self._read_yaml(path) data = existing or {} - data.update({ - "session_id": session_id, - "planned_intent": planned_intent if planned_intent is not None else data.get("planned_intent"), - "created_at": data.get("created_at", now), - "campaign_ids": data.get("campaign_ids", []), - }) + data.update( + { + "session_id": session_id, + "planned_intent": planned_intent + if planned_intent is not None + else data.get("planned_intent"), + "created_at": data.get("created_at", now), + "campaign_ids": data.get("campaign_ids", []), + } + ) if "actual_summary" not in data: data["actual_summary"] = None if "completed_at" not in data: @@ -746,7 +779,7 @@ def create_session_intent( for cid in campaign_ids: self.link_session_campaign(session_id, cid) - def get_current_session_intent(self) -> Optional[SessionIntent]: + def get_current_session_intent(self) -> SessionIntent | None: si_dir = self.agent_dir / "session_intents" if not si_dir.exists(): return None @@ -806,14 +839,14 @@ def unlink_session_campaign(self, session_id: str, campaign_id: str): data["campaign_ids"] = cids 
self._write_yaml(path, data) - def get_campaign_ids_for_session(self, session_id: str) -> List[str]: + def get_campaign_ids_for_session(self, session_id: str) -> list[str]: path = self.agent_dir / "session_intents" / f"{session_id}.yaml" data = self._read_yaml(path) if not data: return [] return data.get("campaign_ids", []) - def get_campaigns_for_session(self, session_id: str) -> List[Campaign]: + def get_campaigns_for_session(self, session_id: str) -> list[Campaign]: cids = self.get_campaign_ids_for_session(session_id) result = [] for cid in cids: @@ -822,8 +855,8 @@ def get_campaigns_for_session(self, session_id: str) -> List[Campaign]: result.append(c) return result - def get_sessions_for_campaign(self, campaign_id: str) -> List[SessionIntent]: - results: List[SessionIntent] = [] + def get_sessions_for_campaign(self, campaign_id: str) -> list[SessionIntent]: + results: list[SessionIntent] = [] si_dir = self.agent_dir / "session_intents" if not si_dir.exists(): return results @@ -842,14 +875,14 @@ def get_sessions_for_campaign(self, campaign_id: str) -> List[SessionIntent]: def create_planned_session( self, scheduled_date: str, - title: Optional[str] = None, - notes: Optional[str] = None, - scheduled_time: Optional[str] = None, - estimated_duration_minutes: Optional[int] = None, - acquisition_params: Optional[Dict] = None, - source_session_id: Optional[str] = None, - campaign_ids: Optional[List[str]] = None, - planned_session_id: Optional[str] = None, + title: str | None = None, + notes: str | None = None, + scheduled_time: str | None = None, + estimated_duration_minutes: int | None = None, + acquisition_params: dict | None = None, + source_session_id: str | None = None, + campaign_ids: list[str] | None = None, + planned_session_id: str | None = None, ) -> str: psid = planned_session_id or self._gen_id() now = self._now() @@ -869,15 +902,15 @@ def create_planned_session( "updated_at": now, } self._write_yaml( - self.agent_dir / "planned_sessions" / 
f"{psid}.yaml", data, + self.agent_dir / "planned_sessions" / f"{psid}.yaml", + data, ) logger.info( - f"Created planned session {psid} for {scheduled_date}: " - f"{title or notes or '(untitled)'}" + f"Created planned session {psid} for {scheduled_date}: {title or notes or '(untitled)'}" ) return psid - def get_planned_session(self, planned_session_id: str) -> Optional[PlannedSession]: + def get_planned_session(self, planned_session_id: str) -> PlannedSession | None: path = self.agent_dir / "planned_sessions" / f"{planned_session_id}.yaml" data = self._read_yaml(path) if not data: @@ -886,12 +919,12 @@ def get_planned_session(self, planned_session_id: str) -> Optional[PlannedSessio def get_planned_sessions( self, - status: Optional[str] = None, - campaign_id: Optional[str] = None, - from_date: Optional[str] = None, - to_date: Optional[str] = None, - ) -> List[PlannedSession]: - results: List[PlannedSession] = [] + status: str | None = None, + campaign_id: str | None = None, + from_date: str | None = None, + to_date: str | None = None, + ) -> list[PlannedSession]: + results: list[PlannedSession] = [] ps_dir = self.agent_dir / "planned_sessions" if not ps_dir.exists(): return results @@ -914,9 +947,9 @@ def get_planned_sessions( results.sort(key=lambda ps: (ps.scheduled_date or "", ps.scheduled_time or "")) return results - def get_upcoming_sessions(self, limit: int = 10) -> List[PlannedSession]: + def get_upcoming_sessions(self, limit: int = 10) -> list[PlannedSession]: today = datetime.now().strftime("%Y-%m-%d") - results: List[PlannedSession] = [] + results: list[PlannedSession] = [] ps_dir = self.agent_dir / "planned_sessions" if not ps_dir.exists(): return results @@ -935,9 +968,9 @@ def get_upcoming_sessions(self, limit: int = 10) -> List[PlannedSession]: results.sort(key=lambda ps: (ps.scheduled_date or "", ps.scheduled_time or "")) return results[:limit] - def get_todays_sessions(self) -> List[PlannedSession]: + def get_todays_sessions(self) -> 
list[PlannedSession]: today = datetime.now().strftime("%Y-%m-%d") - results: List[PlannedSession] = [] + results: list[PlannedSession] = [] ps_dir = self.agent_dir / "planned_sessions" if not ps_dir.exists(): return results @@ -958,15 +991,15 @@ def get_todays_sessions(self) -> List[PlannedSession]: def update_planned_session( self, planned_session_id: str, - title: Optional[str] = None, - notes: Optional[str] = None, - scheduled_date: Optional[str] = None, - scheduled_time: Optional[str] = None, - estimated_duration_minutes: Optional[int] = None, - acquisition_params: Optional[Dict] = None, - source_session_id: Optional[str] = None, - status: Optional[PlannedSessionStatus] = None, - session_id: Optional[str] = None, + title: str | None = None, + notes: str | None = None, + scheduled_date: str | None = None, + scheduled_time: str | None = None, + estimated_duration_minutes: int | None = None, + acquisition_params: dict | None = None, + source_session_id: str | None = None, + status: PlannedSessionStatus | None = None, + session_id: str | None = None, ): path = self.agent_dir / "planned_sessions" / f"{planned_session_id}.yaml" data = self._read_yaml(path) @@ -1025,7 +1058,7 @@ def unlink_planned_session_campaign(self, planned_session_id: str, campaign_id: data["campaign_ids"] = cids self._write_yaml(path, data) - def get_campaign_ids_for_planned_session(self, planned_session_id: str) -> List[str]: + def get_campaign_ids_for_planned_session(self, planned_session_id: str) -> list[str]: path = self.agent_dir / "planned_sessions" / f"{planned_session_id}.yaml" data = self._read_yaml(path) if not data: @@ -1036,7 +1069,7 @@ def get_campaign_ids_for_planned_session(self, planned_session_id: str) -> List[ # Plan Items # ================================================================== - def _read_plan_items_raw(self, campaign_id: str) -> List[Dict]: + def _read_plan_items_raw(self, campaign_id: str) -> list[dict]: """Read the raw plan items list for a campaign.""" folder 
= self._campaign_folder(campaign_id) if not folder: @@ -1046,7 +1079,7 @@ def _read_plan_items_raw(self, campaign_id: str) -> List[Dict]: return [] return data - def _write_plan_items(self, campaign_id: str, items: List[Dict]): + def _write_plan_items(self, campaign_id: str, items: list[dict]): """Write the plan items list for a campaign.""" folder = self._campaign_folder(campaign_id) if not folder: @@ -1054,8 +1087,9 @@ def _write_plan_items(self, campaign_id: str, items: List[Dict]): self._write_yaml(folder / "plan" / "current.yaml", items) def _find_plan_item_location( - self, item_id: str, - ) -> Optional[tuple]: + self, + item_id: str, + ) -> tuple | None: """Find a plan item across all campaigns. Returns (campaign_id, items_list, index) or None. @@ -1072,15 +1106,15 @@ def create_plan_item( campaign_id: str, type: str, title: str, - description: Optional[str] = None, - spec: Optional[Dict] = None, - inherit_from: Optional[str] = None, - planned_session_id: Optional[str] = None, + description: str | None = None, + spec: dict | None = None, + inherit_from: str | None = None, + planned_session_id: str | None = None, phase_order: int = -1, - depends_on: Optional[List[str]] = None, - item_id: Optional[str] = None, - references: Optional[List[Dict]] = None, - estimated_days: Optional[int] = None, + depends_on: list[str] | None = None, + item_id: str | None = None, + references: list[dict] | None = None, + estimated_days: int | None = None, ) -> str: pid = item_id or self._gen_id() now = self._now() @@ -1120,7 +1154,7 @@ def create_plan_item( logger.info(f"Created plan item {pid} [{type}] #{phase_order}: {title}") return pid - def get_plan_item(self, item_id: str) -> Optional[PlanItem]: + def get_plan_item(self, item_id: str) -> PlanItem | None: loc = self._find_plan_item_location(item_id) if not loc: return None @@ -1128,8 +1162,10 @@ def get_plan_item(self, item_id: str) -> Optional[PlanItem]: return self._dict_to_plan_item(items[idx]) def resolve_plan_item( - 
self, ref: str, campaign_id: Optional[str] = None, - ) -> Optional[PlanItem]: + self, + ref: str, + campaign_id: str | None = None, + ) -> PlanItem | None: ref = ref.strip().lower() # Direct ID match @@ -1138,7 +1174,7 @@ def resolve_plan_item( return item # UUID prefix match - if len(ref) >= 4 and re.match(r'^[0-9a-f]+$', ref): + if len(ref) >= 4 and re.match(r"^[0-9a-f]+$", ref): for cid in self._campaign_index: for raw in self._read_plan_items_raw(cid): if raw.get("id", "").startswith(ref): @@ -1149,7 +1185,7 @@ def resolve_plan_item( task_num = None # "campaign.phase.task" - m = re.match(r'^([^.\s]+)\.(\d+)\.(\d+)$', ref) + m = re.match(r"^([^.\s]+)\.(\d+)\.(\d+)$", ref) if m: campaign_label = m.group(1) phase_num, task_num = int(m.group(2)), int(m.group(3)) @@ -1159,25 +1195,25 @@ def resolve_plan_item( # "1.3" or "2.1" if not task_num: - m = re.match(r'^(\d+)\.(\d+)$', ref) + m = re.match(r"^(\d+)\.(\d+)$", ref) if m: phase_num, task_num = int(m.group(1)), int(m.group(2)) # "task 3 of phase 1" if not task_num: - m = re.search(r'task\s+(\d+)\s+(?:of\s+)?phase\s+(\d+)', ref) + m = re.search(r"task\s+(\d+)\s+(?:of\s+)?phase\s+(\d+)", ref) if m: task_num, phase_num = int(m.group(1)), int(m.group(2)) # "phase 1 task 3" if not task_num: - m = re.search(r'phase\s+(\d+)\s+task\s+(\d+)', ref) + m = re.search(r"phase\s+(\d+)\s+task\s+(\d+)", ref) if m: phase_num, task_num = int(m.group(1)), int(m.group(2)) # "task 3" / "#3" / just "3" if not task_num: - m = re.match(r'^(?:task\s+|#)?(\d+)$', ref) + m = re.match(r"^(?:task\s+|#)?(\d+)$", ref) if m: task_num = int(m.group(1)) @@ -1204,7 +1240,7 @@ def resolve_plan_item( else: phases = self.get_subcampaigns(root_id) if phases: - all_items: List[PlanItem] = [] + all_items: list[PlanItem] = [] for phase in phases: p_items = self.get_plan_items(campaign_id=phase.id) p_items.sort(key=lambda x: x.phase_order) @@ -1224,11 +1260,11 @@ def resolve_plan_item( def get_plan_items( self, - campaign_id: Optional[str] = None, - status: 
Optional[str] = None, - type: Optional[str] = None, + campaign_id: str | None = None, + status: str | None = None, + type: str | None = None, include_children: bool = False, - ) -> List[PlanItem]: + ) -> list[PlanItem]: if campaign_id and include_children: cids = self._get_campaign_tree_ids(campaign_id) elif campaign_id: @@ -1236,7 +1272,7 @@ def get_plan_items( else: cids = list(self._campaign_index.keys()) - result: List[PlanItem] = [] + result: list[PlanItem] = [] for cid in cids: for raw in self._read_plan_items_raw(cid): if status and raw.get("status") != status: @@ -1250,17 +1286,17 @@ def get_plan_items( def update_plan_item( self, item_id: str, - title: Optional[str] = None, - description: Optional[str] = None, - status: Optional[PlanItemStatus] = None, - outcome: Optional[str] = None, - spec: Optional[Dict] = None, - planned_session_id: Optional[str] = None, - session_id: Optional[str] = None, - phase_order: Optional[int] = None, - campaign_id: Optional[str] = None, - references: Optional[List[Dict]] = None, - estimated_days: Optional[int] = None, + title: str | None = None, + description: str | None = None, + status: PlanItemStatus | None = None, + outcome: str | None = None, + spec: dict | None = None, + planned_session_id: str | None = None, + session_id: str | None = None, + phase_order: int | None = None, + campaign_id: str | None = None, + references: list[dict] | None = None, + estimated_days: int | None = None, ): loc = self._find_plan_item_location(item_id) if not loc: @@ -1302,10 +1338,12 @@ def update_plan_item( def complete_plan_item(self, item_id: str, outcome: str): self.update_plan_item( - item_id, status=PlanItemStatus.COMPLETED, outcome=outcome, + item_id, + status=PlanItemStatus.COMPLETED, + outcome=outcome, ) - def skip_plan_item(self, item_id: str, reason: Optional[str] = None): + def skip_plan_item(self, item_id: str, reason: str | None = None): self.update_plan_item( item_id, status=PlanItemStatus.SKIPPED, @@ -1362,27 +1400,29 @@ def 
remove_plan_item_dependency(self, item_id: str, depends_on_id: str): items[idx]["depends_on"] = deps self._write_plan_items(campaign_id, items) - def get_plan_item_dependencies(self, item_id: str) -> List[str]: + def get_plan_item_dependencies(self, item_id: str) -> list[str]: loc = self._find_plan_item_location(item_id) if not loc: return [] _, items, idx = loc return list(items[idx].get("depends_on", [])) - def get_plan_item_dependents(self, item_id: str) -> List[str]: + def get_plan_item_dependents(self, item_id: str) -> list[str]: """Get IDs of items that depend on this item.""" - dependents: List[str] = [] + dependents: list[str] = [] for cid in self._campaign_index: for raw in self._read_plan_items_raw(cid): if item_id in raw.get("depends_on", []): dependents.append(raw["id"]) return dependents - def get_unblocked_plan_items(self, campaign_id: str) -> List[PlanItem]: + def get_unblocked_plan_items(self, campaign_id: str) -> list[PlanItem]: items = self.get_plan_items( - campaign_id=campaign_id, status="planned", include_children=True, + campaign_id=campaign_id, + status="planned", + include_children=True, ) - unblocked: List[PlanItem] = [] + unblocked: list[PlanItem] = [] for item in items: if not item.depends_on: unblocked.append(item) @@ -1391,7 +1431,8 @@ def get_unblocked_plan_items(self, campaign_id: str) -> List[PlanItem]: for dep_id in item.depends_on: dep = self.get_plan_item(dep_id) if dep and dep.status not in ( - PlanItemStatus.COMPLETED, PlanItemStatus.SKIPPED, + PlanItemStatus.COMPLETED, + PlanItemStatus.SKIPPED, ): all_resolved = False break @@ -1399,11 +1440,12 @@ def get_unblocked_plan_items(self, campaign_id: str) -> List[PlanItem]: unblocked.append(item) return unblocked - def get_plan_status(self, campaign_id: str) -> Dict[str, Any]: + def get_plan_status(self, campaign_id: str) -> dict[str, Any]: items = self.get_plan_items( - campaign_id=campaign_id, include_children=True, + campaign_id=campaign_id, + include_children=True, ) - result: 
Dict[str, Any] = { + result: dict[str, Any] = { "total": len(items), "completed": 0, "in_progress": 0, @@ -1426,16 +1468,13 @@ def get_plan_status(self, campaign_id: str) -> Dict[str, Any]: if item.status == PlanItemStatus.COMPLETED: result["by_type"][type_key]["completed"] += 1 - if ( - item.type == PlanItemType.DECISION_POINT - and item.status == PlanItemStatus.PLANNED - ): + if item.type == PlanItemType.DECISION_POINT and item.status == PlanItemStatus.PLANNED: result["pending_decisions"].append(item) result["next_actions"] = self.get_unblocked_plan_items(campaign_id) return result - def resolve_imaging_spec(self, item: PlanItem) -> Optional[ImagingSpec]: + def resolve_imaging_spec(self, item: PlanItem) -> ImagingSpec | None: if item.type != PlanItemType.IMAGING: return None if not item.inherit_from: @@ -1462,7 +1501,7 @@ def resolve_imaging_spec(self, item: PlanItem) -> Optional[ImagingSpec]: def save_plan_template( self, name: str, - description: Optional[str], + description: str | None, campaign_id: str, ) -> str: campaign = self.get_campaign(campaign_id) @@ -1488,7 +1527,7 @@ def save_plan_template( logger.info(f"Saved plan template '{name}' ({tid})") return tid - def _serialize_campaign_tree(self, campaign_id: str) -> Dict: + def _serialize_campaign_tree(self, campaign_id: str) -> dict: campaign = self.get_campaign(campaign_id) if not campaign: return {} @@ -1496,9 +1535,9 @@ def _serialize_campaign_tree(self, campaign_id: str) -> Dict: items.sort(key=lambda x: x.phase_order) all_item_ids = [it.id for it in items] - serialized_items: List[Dict] = [] + serialized_items: list[dict] = [] for item in items: - item_data: Dict[str, Any] = { + item_data: dict[str, Any] = { "type": item.type.value, "title": item.title, "description": item.description, @@ -1533,9 +1572,7 @@ def _serialize_campaign_tree(self, campaign_id: str) -> Dict: serialized_items.append(item_data) children = self.get_subcampaigns(campaign_id) - serialized_children = [ - 
self._serialize_campaign_tree(child.id) for child in children - ] + serialized_children = [self._serialize_campaign_tree(child.id) for child in children] return { "description": campaign.description, @@ -1545,9 +1582,9 @@ def _serialize_campaign_tree(self, campaign_id: str) -> Dict: "children": serialized_children, } - def list_plan_templates(self) -> List[Dict]: - templates: List[Dict] = [] - for cid, folder in self._campaign_index.items(): + def list_plan_templates(self) -> list[dict]: + templates: list[dict] = [] + for _cid, folder in self._campaign_index.items(): tpl_dir = folder / "templates" if not tpl_dir.exists(): continue @@ -1555,35 +1592,34 @@ def list_plan_templates(self) -> List[Dict]: if f.suffix in (".yaml", ".yml"): data = self._read_yaml(f) if data: - templates.append({ - "id": data.get("id", ""), - "name": data.get("name", ""), - "description": data.get("description"), - "created_at": data.get("created_at", ""), - "updated_at": data.get("updated_at", ""), - }) + templates.append( + { + "id": data.get("id", ""), + "name": data.get("name", ""), + "description": data.get("description"), + "created_at": data.get("created_at", ""), + "updated_at": data.get("updated_at", ""), + } + ) templates.sort(key=lambda t: t.get("created_at", ""), reverse=True) return templates - def get_plan_template(self, id_or_name: str) -> Optional[Dict]: - for cid, folder in self._campaign_index.items(): + def get_plan_template(self, id_or_name: str) -> dict | None: + for _cid, folder in self._campaign_index.items(): tpl_dir = folder / "templates" if not tpl_dir.exists(): continue for f in tpl_dir.iterdir(): if f.suffix in (".yaml", ".yml"): data = self._read_yaml(f) - if data and ( - data.get("id") == id_or_name - or data.get("name") == id_or_name - ): + if data and (data.get("id") == id_or_name or data.get("name") == id_or_name): return data return None def apply_plan_template( self, template_id: str, - overrides: Optional[Dict] = None, + overrides: dict | None = None, ) -> 
str: tmpl = self.get_plan_template(template_id) if not tmpl: @@ -1594,9 +1630,9 @@ def apply_plan_template( def _instantiate_template_tree( self, - data: Dict, - parent_id: Optional[str], - overrides: Dict, + data: dict, + parent_id: str | None, + overrides: dict, ) -> str: cid = self.create_campaign( description=data.get("description", "Untitled"), @@ -1605,7 +1641,7 @@ def _instantiate_template_tree( parent_id=parent_id, ) items_data = data.get("items", []) - new_item_ids: List[str] = [] + new_item_ids: list[str] = [] for item_data in items_data: spec = item_data.get("spec") @@ -1613,9 +1649,15 @@ def _instantiate_template_tree( spec = dict(spec) for k, v in overrides.items(): if k in spec or k in ( - "strain", "genotype", "reporter", "temperature_c", - "num_slices", "exposure_ms", "interval_s", - "num_embryos", "stop_condition", + "strain", + "genotype", + "reporter", + "temperature_c", + "num_slices", + "exposure_ms", + "interval_s", + "num_embryos", + "stop_condition", ): spec[k] = v @@ -1635,7 +1677,8 @@ def _instantiate_template_tree( for dep_idx in dep_indices: if 0 <= dep_idx < len(new_item_ids): self.add_plan_item_dependency( - new_item_ids[idx_item], new_item_ids[dep_idx], + new_item_ids[idx_item], + new_item_ids[dep_idx], ) for child_data in data.get("children", []): @@ -1644,17 +1687,14 @@ def _instantiate_template_tree( return cid def delete_plan_template(self, template_id: str) -> bool: - for cid, folder in self._campaign_index.items(): + for _cid, folder in self._campaign_index.items(): tpl_dir = folder / "templates" if not tpl_dir.exists(): continue for f in tpl_dir.iterdir(): if f.suffix in (".yaml", ".yml"): data = self._read_yaml(f) - if data and ( - data.get("id") == template_id - or data.get("name") == template_id - ): + if data and (data.get("id") == template_id or data.get("name") == template_id): f.unlink() return True return False @@ -1666,8 +1706,8 @@ def delete_plan_template(self, template_id: str) -> bool: def create_plan_snapshot( 
self, campaign_id: str, - label: Optional[str] = None, - summary: Optional[str] = None, + label: str | None = None, + summary: str | None = None, ) -> str: snapshot_data = self._serialize_campaign_tree(campaign_id) if not summary: @@ -1703,7 +1743,9 @@ def create_plan_snapshot( "created_at": now, } self._write_yaml(history_dir / f"{timestamp_slug}.yaml", snapshot) - logger.info(f"Created plan snapshot v{version_number} ({version_id}) for campaign {campaign_id}") + logger.info( + f"Created plan snapshot v{version_number} ({version_id}) for campaign {campaign_id}" + ) return version_id def _generate_snapshot_summary(self, campaign_id: str) -> str: @@ -1712,7 +1754,7 @@ def _generate_snapshot_summary(self, campaign_id: str) -> str: return "Unknown campaign" phases = self.get_subcampaigns(campaign_id) items = self.get_plan_items(campaign_id=campaign_id, include_children=True) - status_counts: Dict[str, int] = {} + status_counts: dict[str, int] = {} for item in items: key = item.status.value status_counts[key] = status_counts.get(key, 0) + 1 @@ -1725,7 +1767,7 @@ def _generate_snapshot_summary(self, campaign_id: str) -> str: parts.append(f" {status_name}: {count}") return "\n".join(parts) - def _read_all_snapshots(self, campaign_id: str) -> List[Dict]: + def _read_all_snapshots(self, campaign_id: str) -> list[dict]: """Read all snapshot files for a campaign.""" folder = self._campaign_folder(campaign_id) if not folder: @@ -1733,7 +1775,7 @@ def _read_all_snapshots(self, campaign_id: str) -> List[Dict]: history_dir = folder / "plan" / "history" if not history_dir.exists(): return [] - snapshots: List[Dict] = [] + snapshots: list[dict] = [] for f in history_dir.iterdir(): if f.suffix in (".yaml", ".yml"): data = self._read_yaml(f) @@ -1742,25 +1784,29 @@ def _read_all_snapshots(self, campaign_id: str) -> List[Dict]: return snapshots def list_plan_snapshots( - self, campaign_id: str, limit: int = 50, - ) -> List[Dict]: + self, + campaign_id: str, + limit: int = 50, + ) -> 
list[dict]: snapshots = self._read_all_snapshots(campaign_id) # Return metadata only (no blob) result = [] for s in snapshots: - result.append({ - "version_id": s.get("version_id"), - "campaign_id": s.get("campaign_id"), - "version_number": s.get("version_number"), - "summary": s.get("summary"), - "label": s.get("label"), - "parent_version_id": s.get("parent_version_id"), - "created_at": s.get("created_at"), - }) + result.append( + { + "version_id": s.get("version_id"), + "campaign_id": s.get("campaign_id"), + "version_number": s.get("version_number"), + "summary": s.get("summary"), + "label": s.get("label"), + "parent_version_id": s.get("parent_version_id"), + "created_at": s.get("created_at"), + } + ) result.sort(key=lambda s: s.get("version_number", 0), reverse=True) return result[:limit] - def get_plan_snapshot(self, version_id: str) -> Optional[Dict]: + def get_plan_snapshot(self, version_id: str) -> dict | None: for cid in self._campaign_index: for snap in self._read_all_snapshots(cid): if snap.get("version_id") == version_id: @@ -1801,7 +1847,7 @@ def restore_plan_snapshot(self, version_id: str) -> str: ) return new_campaign_id - def _get_campaign_tree_ids(self, campaign_id: str) -> List[str]: + def _get_campaign_tree_ids(self, campaign_id: str) -> list[str]: ids = [campaign_id] for cid, folder in self._campaign_index.items(): data = self._read_yaml(folder / "campaign.yaml") @@ -1827,14 +1873,15 @@ def add_observation(self, obs: Observation): "relates_to": obs.relates_to, } self._write_yaml( - self.agent_dir / "observations" / f"{obs.id}_{slug}.yaml", data, + self.agent_dir / "observations" / f"{obs.id}_{slug}.yaml", + data, ) - def get_recent_observations(self, limit: int = 50) -> List[Observation]: + def get_recent_observations(self, limit: int = 50) -> list[Observation]: obs_dir = self.agent_dir / "observations" if not obs_dir.exists(): return [] - all_obs: List[Observation] = [] + all_obs: list[Observation] = [] for f in obs_dir.iterdir(): if f.suffix in 
(".yaml", ".yml"): data = self._read_yaml(f) @@ -1844,11 +1891,11 @@ def get_recent_observations(self, limit: int = 50) -> List[Observation]: # Return in chronological order (oldest first in the window) return list(reversed(all_obs[:limit])) - def get_observations_for_embryo(self, embryo_id: str, limit: int = 20) -> List[Observation]: + def get_observations_for_embryo(self, embryo_id: str, limit: int = 20) -> list[Observation]: obs_dir = self.agent_dir / "observations" if not obs_dir.exists(): return [] - matches: List[Observation] = [] + matches: list[Observation] = [] for f in obs_dir.iterdir(): if f.suffix in (".yaml", ".yml"): data = self._read_yaml(f) @@ -1864,34 +1911,34 @@ def get_observations_for_embryo(self, embryo_id: str, limit: int = 20) -> List[O def add_expectation(self, exp: Expectation): path = self.agent_dir / "active" / "expectations.yaml" items = self._read_yaml(path) or [] - items.append({ - "id": exp.id, - "target": exp.target, - "prediction": exp.prediction, - "expected_time": exp.expected_time.isoformat(), - "uncertainty": exp.uncertainty, - "basis": exp.basis, - "status": exp.status.value, - "created_at": exp.created_at.isoformat(), - "resolved_at": None, - }) + items.append( + { + "id": exp.id, + "target": exp.target, + "prediction": exp.prediction, + "expected_time": exp.expected_time.isoformat(), + "uncertainty": exp.uncertainty, + "basis": exp.basis, + "status": exp.status.value, + "created_at": exp.created_at.isoformat(), + "resolved_at": None, + } + ) self._write_yaml(path, items) - def get_pending_expectations(self) -> List[Expectation]: + def get_pending_expectations(self) -> list[Expectation]: path = self.agent_dir / "active" / "expectations.yaml" items = self._read_yaml(path) or [] - pending = [ - self._dict_to_expectation(d) for d in items - if d.get("status") == "pending" - ] + pending = [self._dict_to_expectation(d) for d in items if d.get("status") == "pending"] pending.sort(key=lambda e: e.expected_time) return pending - def 
get_expectation_for(self, target: str) -> Optional[Expectation]: + def get_expectation_for(self, target: str) -> Expectation | None: path = self.agent_dir / "active" / "expectations.yaml" items = self._read_yaml(path) or [] candidates = [ - self._dict_to_expectation(d) for d in items + self._dict_to_expectation(d) + for d in items if d.get("target") == target and d.get("status") == "pending" ] if not candidates: @@ -1917,23 +1964,22 @@ def resolve_expectation(self, exp_id: str, status: ExpectationStatus): def add_watchpoint(self, wp: Watchpoint): path = self.agent_dir / "active" / "watchpoints.yaml" items = self._read_yaml(path) or [] - items.append({ - "id": wp.id, - "target": wp.target, - "condition": wp.condition, - "priority": wp.priority.value if wp.priority else "medium", - "status": wp.status.value, - "created_at": wp.created_at.isoformat(), - }) + items.append( + { + "id": wp.id, + "target": wp.target, + "condition": wp.condition, + "priority": wp.priority.value if wp.priority else "medium", + "status": wp.status.value, + "created_at": wp.created_at.isoformat(), + } + ) self._write_yaml(path, items) - def get_active_watchpoints(self) -> List[Watchpoint]: + def get_active_watchpoints(self) -> list[Watchpoint]: path = self.agent_dir / "active" / "watchpoints.yaml" items = self._read_yaml(path) or [] - active = [ - self._dict_to_watchpoint(d) for d in items - if d.get("status") == "active" - ] + active = [self._dict_to_watchpoint(d) for d in items if d.get("status") == "active"] # Sort: high > medium > low, then by created_at priority_order = {"high": 0, "medium": 1, "low": 2} active.sort(key=lambda w: (priority_order.get(w.priority.value, 1), w.created_at)) @@ -1964,22 +2010,23 @@ def resolve_watchpoint(self, wp_id: str): def add_question(self, q: Question): path = self.agent_dir / "active" / "questions.yaml" items = self._read_yaml(path) or [] - items.append({ - "id": q.id, - "content": q.content, - "status": q.status.value, - "resolution": None, - 
"created_at": q.created_at.isoformat(), - "resolved_at": None, - }) + items.append( + { + "id": q.id, + "content": q.content, + "status": q.status.value, + "resolution": None, + "created_at": q.created_at.isoformat(), + "resolved_at": None, + } + ) self._write_yaml(path, items) - def get_open_questions(self) -> List[Question]: + def get_open_questions(self) -> list[Question]: path = self.agent_dir / "active" / "questions.yaml" items = self._read_yaml(path) or [] open_qs = [ - self._dict_to_question(d) for d in items - if d.get("status") in ("open", "investigating") + self._dict_to_question(d) for d in items if d.get("status") in ("open", "investigating") ] open_qs.sort(key=lambda q: q.created_at) return open_qs @@ -2010,20 +2057,21 @@ def add_learning(self, learning: Learning): "created_at": learning.created_at.isoformat(), } self._write_yaml( - self.agent_dir / "learnings" / f"{learning.id}_{slug}.yaml", data, + self.agent_dir / "learnings" / f"{learning.id}_{slug}.yaml", + data, ) - def get_learnings(self, limit: int = 50) -> List[Learning]: + def get_learnings(self, limit: int = 50) -> list[Learning]: learn_dir = self.agent_dir / "learnings" if not learn_dir.exists(): return [] - all_learnings: List[Learning] = [] + all_learnings: list[Learning] = [] for f in learn_dir.iterdir(): if f.suffix in (".yaml", ".yml"): data = self._read_yaml(f) if data: all_learnings.append(self._dict_to_learning(data)) - all_learnings.sort(key=lambda l: l.created_at, reverse=True) + all_learnings.sort(key=lambda learning: learning.created_at, reverse=True) return all_learnings[:limit] # ================================================================== @@ -2033,13 +2081,13 @@ def get_learnings(self, limit: int = 50) -> List[Learning]: def update_embryo_understanding( self, embryo_id: str, - current_stage: Optional[str] = None, - stage_confidence: Optional[Confidence] = None, - health_assessment: Optional[str] = None, - note: Optional[str] = None, - is_hatched: Optional[bool] = None, 
- needs_attention: Optional[bool] = None, - attention_reason: Optional[str] = None, + current_stage: str | None = None, + stage_confidence: Confidence | None = None, + health_assessment: str | None = None, + note: str | None = None, + is_hatched: bool | None = None, + needs_attention: bool | None = None, + attention_reason: str | None = None, ): now = self._now() path = self.agent_dir / "embryo_understanding" / f"{embryo_id}.yaml" @@ -2085,7 +2133,7 @@ def update_embryo_understanding( # Agent State # ================================================================== - def get_state(self, key: str) -> Optional[str]: + def get_state(self, key: str) -> str | None: path = self.agent_dir / "state.yaml" data = self._read_yaml(path) if not data or not isinstance(data, dict): @@ -2142,10 +2190,10 @@ def create_ml_pipeline( campaign_id: str, name: str, task: str = "embryo_stage_classification", - model_config: Optional[Dict] = None, - data_split: Optional[Dict] = None, - training_config: Optional[Dict] = None, - ) -> Dict[str, Any]: + model_config: dict | None = None, + data_split: dict | None = None, + training_config: dict | None = None, + ) -> dict[str, Any]: pipeline_id = self._gen_id() now = self._now() data = { @@ -2163,11 +2211,12 @@ def create_ml_pipeline( "updated_at": now, } self._write_yaml( - self.agent_dir / "ml" / "pipelines" / f"{pipeline_id}.yaml", data, + self.agent_dir / "ml" / "pipelines" / f"{pipeline_id}.yaml", + data, ) return self.get_ml_pipeline(pipeline_id) - def get_ml_pipeline(self, pipeline_id: str) -> Optional[Dict[str, Any]]: + def get_ml_pipeline(self, pipeline_id: str) -> dict[str, Any] | None: path = self.agent_dir / "ml" / "pipelines" / f"{pipeline_id}.yaml" data = self._read_yaml(path) if not data: @@ -2187,11 +2236,11 @@ def get_ml_pipeline(self, pipeline_id: str) -> Optional[Dict[str, Any]]: "updated_at": data.get("updated_at"), } - def list_ml_pipelines(self, campaign_id: Optional[str] = None) -> List[Dict[str, Any]]: + def 
list_ml_pipelines(self, campaign_id: str | None = None) -> list[dict[str, Any]]: pipe_dir = self.agent_dir / "ml" / "pipelines" if not pipe_dir.exists(): return [] - results: List[Dict[str, Any]] = [] + results: list[dict[str, Any]] = [] for f in pipe_dir.iterdir(): if f.suffix in (".yaml", ".yml"): data = self._read_yaml(f) @@ -2205,13 +2254,20 @@ def list_ml_pipelines(self, campaign_id: Optional[str] = None) -> List[Dict[str, results.sort(key=lambda p: p.get("created_at", ""), reverse=True) return results - def update_ml_pipeline(self, pipeline_id: str, **kwargs) -> Optional[Dict[str, Any]]: + def update_ml_pipeline(self, pipeline_id: str, **kwargs) -> dict[str, Any] | None: path = self.agent_dir / "ml" / "pipelines" / f"{pipeline_id}.yaml" data = self._read_yaml(path) if not data: return self.get_ml_pipeline(pipeline_id) - allowed = {"status", "model_config", "data_split", "training_config", - "best_run_id", "best_accuracy", "name"} + allowed = { + "status", + "model_config", + "data_split", + "training_config", + "best_run_id", + "best_accuracy", + "name", + } changed = False for k, v in kwargs.items(): if k in allowed: @@ -2229,11 +2285,11 @@ def update_ml_pipeline(self, pipeline_id: str, **kwargs) -> Optional[Dict[str, A def create_training_run( self, pipeline_id: str, - model_config: Optional[Dict] = None, - training_config: Optional[Dict] = None, - data_split: Optional[Dict] = None, + model_config: dict | None = None, + training_config: dict | None = None, + data_split: dict | None = None, peer_instance_id: str = "", - ) -> Dict[str, Any]: + ) -> dict[str, Any]: run_id = self._gen_id() data = { "id": run_id, @@ -2256,11 +2312,12 @@ def create_training_run( "error_message": "", } self._write_yaml( - self.agent_dir / "ml" / "runs" / f"{run_id}.yaml", data, + self.agent_dir / "ml" / "runs" / f"{run_id}.yaml", + data, ) return self.get_training_run(run_id) - def get_training_run(self, run_id: str) -> Optional[Dict[str, Any]]: + def get_training_run(self, 
run_id: str) -> dict[str, Any] | None: path = self.agent_dir / "ml" / "runs" / f"{run_id}.yaml" data = self._read_yaml(path) if not data: @@ -2286,11 +2343,11 @@ def get_training_run(self, run_id: str) -> Optional[Dict[str, Any]]: "error_message": data.get("error_message", ""), } - def list_training_runs(self, pipeline_id: str) -> List[Dict[str, Any]]: + def list_training_runs(self, pipeline_id: str) -> list[dict[str, Any]]: runs_dir = self.agent_dir / "ml" / "runs" if not runs_dir.exists(): return [] - results: List[Dict[str, Any]] = [] + results: list[dict[str, Any]] = [] for f in runs_dir.iterdir(): if f.suffix in (".yaml", ".yml"): data = self._read_yaml(f) @@ -2300,15 +2357,24 @@ def list_training_runs(self, pipeline_id: str) -> List[Dict[str, Any]]: results.append(run) return results - def update_training_run(self, run_id: str, **kwargs) -> Optional[Dict[str, Any]]: + def update_training_run(self, run_id: str, **kwargs) -> dict[str, Any] | None: path = self.agent_dir / "ml" / "runs" / f"{run_id}.yaml" data = self._read_yaml(path) if not data: return self.get_training_run(run_id) allowed = { - "status", "current_epoch", "total_epochs", "train_loss", "val_loss", - "val_accuracy", "best_val_accuracy", "model_weights_path", "metrics_path", - "started_at", "completed_at", "error_message", + "status", + "current_epoch", + "total_epochs", + "train_loss", + "val_loss", + "val_accuracy", + "best_val_accuracy", + "model_weights_path", + "metrics_path", + "started_at", + "completed_at", + "error_message", } changed = False for k, v in kwargs.items(): @@ -2325,15 +2391,15 @@ def update_training_run(self, run_id: str, **kwargs) -> Optional[Dict[str, Any]] def save_data_assessment( self, - pipeline_id: Optional[str] = None, + pipeline_id: str | None = None, total_sessions: int = 0, total_embryos: int = 0, total_volumes: int = 0, annotated_embryos: int = 0, - stage_distribution: Optional[Dict] = None, - coverage_gaps: Optional[List] = None, + stage_distribution: dict | None 
= None, + coverage_gaps: list | None = None, quality_notes: str = "", - ) -> Dict[str, Any]: + ) -> dict[str, Any]: assessment_id = self._gen_id() now = self._now() data = { @@ -2349,11 +2415,12 @@ def save_data_assessment( "created_at": now, } self._write_yaml( - self.agent_dir / "ml" / "assessments" / f"{assessment_id}.yaml", data, + self.agent_dir / "ml" / "assessments" / f"{assessment_id}.yaml", + data, ) return self.get_data_assessment(assessment_id) - def get_data_assessment(self, assessment_id: str) -> Optional[Dict[str, Any]]: + def get_data_assessment(self, assessment_id: str) -> dict[str, Any] | None: path = self.agent_dir / "ml" / "assessments" / f"{assessment_id}.yaml" data = self._read_yaml(path) if not data: @@ -2376,7 +2443,7 @@ def get_data_assessment(self, assessment_id: str) -> Optional[Dict[str, Any]]: # ================================================================== @staticmethod - def _dict_to_campaign(d: Dict) -> Campaign: + def _dict_to_campaign(d: dict) -> Campaign: return Campaign( id=d["id"], description=d["description"], @@ -2387,22 +2454,30 @@ def _dict_to_campaign(d: Dict) -> Campaign: parent_id=d.get("parent_id"), status=Status(d.get("status", "active")), is_shared=bool(d.get("is_shared", False)), - created_at=datetime.fromisoformat(d["created_at"]) if isinstance(d["created_at"], str) else d["created_at"], - updated_at=datetime.fromisoformat(d["updated_at"]) if isinstance(d["updated_at"], str) else d["updated_at"], + created_at=datetime.fromisoformat(d["created_at"]) + if isinstance(d["created_at"], str) + else d["created_at"], + updated_at=datetime.fromisoformat(d["updated_at"]) + if isinstance(d["updated_at"], str) + else d["updated_at"], ) @staticmethod - def _dict_to_project(d: Dict) -> Project: + def _dict_to_project(d: dict) -> Project: return Project( id=d["id"], description=d["description"], campaign_id=d.get("campaign_id"), status=Status(d.get("status", "active")), - created_at=datetime.fromisoformat(d["created_at"]) if 
isinstance(d["created_at"], str) else d["created_at"], - updated_at=datetime.fromisoformat(d["updated_at"]) if isinstance(d["updated_at"], str) else d["updated_at"], + created_at=datetime.fromisoformat(d["created_at"]) + if isinstance(d["created_at"], str) + else d["created_at"], + updated_at=datetime.fromisoformat(d["updated_at"]) + if isinstance(d["updated_at"], str) + else d["updated_at"], ) - def _dict_to_session_intent(self, d: Dict) -> SessionIntent: + def _dict_to_session_intent(self, d: dict) -> SessionIntent: session_id = d["session_id"] campaign_ids = d.get("campaign_ids", []) return SessionIntent( @@ -2410,12 +2485,16 @@ def _dict_to_session_intent(self, d: Dict) -> SessionIntent: planned_intent=d.get("planned_intent"), actual_summary=d.get("actual_summary"), campaign_ids=campaign_ids, - created_at=datetime.fromisoformat(d["created_at"]) if isinstance(d.get("created_at"), str) else d.get("created_at", datetime.now()), - completed_at=datetime.fromisoformat(d["completed_at"]) if d.get("completed_at") and isinstance(d["completed_at"], str) else None, + created_at=datetime.fromisoformat(d["created_at"]) + if isinstance(d.get("created_at"), str) + else d.get("created_at", datetime.now()), + completed_at=datetime.fromisoformat(d["completed_at"]) + if d.get("completed_at") and isinstance(d["completed_at"], str) + else None, ) @staticmethod - def _dict_to_planned_session(d: Dict) -> PlannedSession: + def _dict_to_planned_session(d: dict) -> PlannedSession: return PlannedSession( id=d["id"], title=d.get("title"), @@ -2428,12 +2507,16 @@ def _dict_to_planned_session(d: Dict) -> PlannedSession: status=PlannedSessionStatus(d.get("status", "planned")), session_id=d.get("session_id"), campaign_ids=d.get("campaign_ids", []), - created_at=datetime.fromisoformat(d["created_at"]) if isinstance(d.get("created_at"), str) else d.get("created_at", datetime.now()), - updated_at=datetime.fromisoformat(d["updated_at"]) if isinstance(d.get("updated_at"), str) else 
d.get("updated_at", datetime.now()), + created_at=datetime.fromisoformat(d["created_at"]) + if isinstance(d.get("created_at"), str) + else d.get("created_at", datetime.now()), + updated_at=datetime.fromisoformat(d["updated_at"]) + if isinstance(d.get("updated_at"), str) + else d.get("updated_at", datetime.now()), ) @staticmethod - def _dict_to_plan_item(d: Dict) -> PlanItem: + def _dict_to_plan_item(d: dict) -> PlanItem: item_type = PlanItemType(d["type"]) spec_data = d.get("spec") imaging_spec = None @@ -2442,14 +2525,10 @@ def _dict_to_plan_item(d: Dict) -> PlanItem: if spec_data: if item_type == PlanItemType.IMAGING: valid = {f.name for f in dataclasses.fields(ImagingSpec)} - imaging_spec = ImagingSpec(**{ - k: v for k, v in spec_data.items() if k in valid - }) + imaging_spec = ImagingSpec(**{k: v for k, v in spec_data.items() if k in valid}) else: valid = {f.name for f in dataclasses.fields(BenchSpec)} - bench_spec = BenchSpec(**{ - k: v for k, v in spec_data.items() if k in valid - }) + bench_spec = BenchSpec(**{k: v for k, v in spec_data.items() if k in valid}) references = d.get("references") or [] @@ -2472,15 +2551,21 @@ def _dict_to_plan_item(d: Dict) -> PlanItem: inherit_from=d.get("inherit_from"), estimated_days=d.get("estimated_days"), phase_order=d.get("phase_order", 0), - created_at=datetime.fromisoformat(d["created_at"]) if isinstance(d.get("created_at"), str) else d.get("created_at", datetime.now()), - updated_at=datetime.fromisoformat(d["updated_at"]) if isinstance(d.get("updated_at"), str) else d.get("updated_at", datetime.now()), + created_at=datetime.fromisoformat(d["created_at"]) + if isinstance(d.get("created_at"), str) + else d.get("created_at", datetime.now()), + updated_at=datetime.fromisoformat(d["updated_at"]) + if isinstance(d.get("updated_at"), str) + else d.get("updated_at", datetime.now()), ) @staticmethod - def _dict_to_observation(d: Dict) -> Observation: + def _dict_to_observation(d: dict) -> Observation: return Observation( 
id=d["id"], - timestamp=datetime.fromisoformat(d["timestamp"]) if isinstance(d.get("timestamp"), str) else d.get("timestamp", datetime.now()), + timestamp=datetime.fromisoformat(d["timestamp"]) + if isinstance(d.get("timestamp"), str) + else d.get("timestamp", datetime.now()), type=d["type"], content=d["content"], embryo_id=d.get("embryo_id"), @@ -2491,60 +2576,78 @@ def _dict_to_observation(d: Dict) -> Observation: ) @staticmethod - def _dict_to_expectation(d: Dict) -> Expectation: + def _dict_to_expectation(d: dict) -> Expectation: return Expectation( id=d["id"], target=d["target"], prediction=d["prediction"], - expected_time=datetime.fromisoformat(d["expected_time"]) if isinstance(d.get("expected_time"), str) else d.get("expected_time", datetime.now()), + expected_time=datetime.fromisoformat(d["expected_time"]) + if isinstance(d.get("expected_time"), str) + else d.get("expected_time", datetime.now()), uncertainty=d.get("uncertainty"), basis=d.get("basis"), status=ExpectationStatus(d.get("status", "pending")), - created_at=datetime.fromisoformat(d["created_at"]) if isinstance(d.get("created_at"), str) else d.get("created_at", datetime.now()), - resolved_at=datetime.fromisoformat(d["resolved_at"]) if d.get("resolved_at") and isinstance(d["resolved_at"], str) else None, + created_at=datetime.fromisoformat(d["created_at"]) + if isinstance(d.get("created_at"), str) + else d.get("created_at", datetime.now()), + resolved_at=datetime.fromisoformat(d["resolved_at"]) + if d.get("resolved_at") and isinstance(d["resolved_at"], str) + else None, ) @staticmethod - def _dict_to_watchpoint(d: Dict) -> Watchpoint: + def _dict_to_watchpoint(d: dict) -> Watchpoint: return Watchpoint( id=d["id"], target=d["target"], condition=d["condition"], priority=Significance(d.get("priority", "medium")), status=WatchpointStatus(d.get("status", "active")), - created_at=datetime.fromisoformat(d["created_at"]) if isinstance(d.get("created_at"), str) else d.get("created_at", datetime.now()), + 
created_at=datetime.fromisoformat(d["created_at"]) + if isinstance(d.get("created_at"), str) + else d.get("created_at", datetime.now()), ) @staticmethod - def _dict_to_question(d: Dict) -> Question: + def _dict_to_question(d: dict) -> Question: return Question( id=d["id"], content=d["content"], status=QuestionStatus(d.get("status", "open")), resolution=d.get("resolution"), - created_at=datetime.fromisoformat(d["created_at"]) if isinstance(d.get("created_at"), str) else d.get("created_at", datetime.now()), - resolved_at=datetime.fromisoformat(d["resolved_at"]) if d.get("resolved_at") and isinstance(d["resolved_at"], str) else None, + created_at=datetime.fromisoformat(d["created_at"]) + if isinstance(d.get("created_at"), str) + else d.get("created_at", datetime.now()), + resolved_at=datetime.fromisoformat(d["resolved_at"]) + if d.get("resolved_at") and isinstance(d["resolved_at"], str) + else None, ) @staticmethod - def _dict_to_learning(d: Dict) -> Learning: + def _dict_to_learning(d: dict) -> Learning: return Learning( id=d["id"], content=d["content"], confidence=Confidence(d.get("confidence", "medium")), basis=d.get("basis"), - created_at=datetime.fromisoformat(d["created_at"]) if isinstance(d.get("created_at"), str) else d.get("created_at", datetime.now()), + created_at=datetime.fromisoformat(d["created_at"]) + if isinstance(d.get("created_at"), str) + else d.get("created_at", datetime.now()), ) @staticmethod - def _dict_to_embryo_understanding(d: Dict) -> EmbryoUnderstanding: + def _dict_to_embryo_understanding(d: dict) -> EmbryoUnderstanding: return EmbryoUnderstanding( embryo_id=d["embryo_id"], current_stage=d.get("current_stage"), - stage_confidence=Confidence(d["stage_confidence"]) if d.get("stage_confidence") else None, + stage_confidence=Confidence(d["stage_confidence"]) + if d.get("stage_confidence") + else None, health_assessment=d.get("health_assessment"), notes=d.get("notes") or [], - last_observed=datetime.fromisoformat(d["last_observed"]) if 
d.get("last_observed") and isinstance(d["last_observed"], str) else None, + last_observed=datetime.fromisoformat(d["last_observed"]) + if d.get("last_observed") and isinstance(d["last_observed"], str) + else None, is_tracked=d.get("is_tracked", True), is_hatched=bool(d.get("is_hatched", False)), needs_attention=bool(d.get("needs_attention", False)), diff --git a/gently/harness/memory/gap_assessment.py b/gently/harness/memory/gap_assessment.py index d43a5ec0..4b64a10b 100644 --- a/gently/harness/memory/gap_assessment.py +++ b/gently/harness/memory/gap_assessment.py @@ -9,9 +9,9 @@ import logging from dataclasses import dataclass, field from enum import Enum -from typing import List, Optional from .model import Campaign + try: from .file_store import FileContextStore as ContextStore except ImportError: @@ -22,22 +22,25 @@ class GapLayer(str, Enum): """Which context layer has a gap.""" - LAB = "lab" # Layer 1: lab identity, setup, organism + + LAB = "lab" # Layer 1: lab identity, setup, organism CAMPAIGN = "campaign" # Layer 2: research direction, goals - SESSION = "session" # Layer 3: today's intent + SESSION = "session" # Layer 3: today's intent REALTIME = "realtime" # Layer 4: current observations (never gapped) class GapSeverity(str, Enum): """How critical this gap is.""" - EMPTY = "empty" # Nothing at all — full onboarding needed - THIN = "thin" # Some context exists but insufficient + + EMPTY = "empty" # Nothing at all — full onboarding needed + THIN = "thin" # Some context exists but insufficient ADEQUATE = "adequate" # Enough to function, could be richer @dataclass class Gap: """A single identified gap in the daemon's knowledge.""" + layer: GapLayer severity: GapSeverity description: str @@ -51,12 +54,13 @@ class ContextGapReport: Drives the startup wizard: which steps to show, what to ask. 
""" - gaps: List[Gap] = field(default_factory=list) + + gaps: list[Gap] = field(default_factory=list) readiness: float = 0.0 # 0.0 (blank) to 1.0 (fully oriented) needs_lab_onboarding: bool = False needs_campaign: bool = False needs_session_intent: bool = False - active_campaigns: List[Campaign] = field(default_factory=list) # All active campaigns + active_campaigns: list[Campaign] = field(default_factory=list) # All active campaigns past_campaign_count: int = 0 # Completed/paused campaigns session_count: int = 0 # How many past sessions exist learning_count: int = 0 @@ -102,28 +106,41 @@ def assess_gaps(context_store: ContextStore) -> ContextGapReport: # Check for lab-level knowledge (learnings about the setup, organism, etc.) lab_learnings = [ - l for l in learnings - if l.basis and any( - kw in l.basis.lower() - for kw in ("lab", "setup", "microscope", "organism", "onboarding", "identity") + learning + for learning in learnings + if learning.basis + and any( + kw in learning.basis.lower() + for kw in ( + "lab", + "setup", + "microscope", + "organism", + "onboarding", + "identity", + ) ) ] if not learnings: report.needs_lab_onboarding = True - report.gaps.append(Gap( - layer=GapLayer.LAB, - severity=GapSeverity.EMPTY, - description="No learnings at all — this appears to be a first launch.", - suggested_action="Conduct lab onboarding conversation.", - )) + report.gaps.append( + Gap( + layer=GapLayer.LAB, + severity=GapSeverity.EMPTY, + description="No learnings at all — this appears to be a first launch.", + suggested_action="Conduct lab onboarding conversation.", + ) + ) elif not lab_learnings: - report.gaps.append(Gap( - layer=GapLayer.LAB, - severity=GapSeverity.THIN, - description="Have learnings but none about lab identity/setup.", - suggested_action="Ask about lab setup and research program.", - )) + report.gaps.append( + Gap( + layer=GapLayer.LAB, + severity=GapSeverity.THIN, + description="Have learnings but none about lab identity/setup.", + 
suggested_action="Ask about lab setup and research program.", + ) + ) readiness_score += 0.1 else: readiness_score += 0.25 @@ -139,19 +156,23 @@ def assess_gaps(context_store: ContextStore) -> ContextGapReport: report.needs_campaign = True if report.past_campaign_count == 0: - report.gaps.append(Gap( - layer=GapLayer.CAMPAIGN, - severity=GapSeverity.EMPTY, - description="No campaigns ever created — no research direction.", - suggested_action="Ask about research goals or suggest ingesting papers.", - )) + report.gaps.append( + Gap( + layer=GapLayer.CAMPAIGN, + severity=GapSeverity.EMPTY, + description="No campaigns ever created — no research direction.", + suggested_action="Ask about research goals or suggest ingesting papers.", + ) + ) else: - report.gaps.append(Gap( - layer=GapLayer.CAMPAIGN, - severity=GapSeverity.THIN, - description=f"{report.past_campaign_count} past campaigns but none active.", - suggested_action="Ask if starting new work or resuming.", - )) + report.gaps.append( + Gap( + layer=GapLayer.CAMPAIGN, + severity=GapSeverity.THIN, + description=f"{report.past_campaign_count} past campaigns but none active.", + suggested_action="Ask if starting new work or resuming.", + ) + ) readiness_score += 0.1 else: readiness_score += 0.25 @@ -166,19 +187,23 @@ def assess_gaps(context_store: ContextStore) -> ContextGapReport: report.session_count = context_store.count_session_intents() if report.session_count == 0: - report.gaps.append(Gap( - layer=GapLayer.SESSION, - severity=GapSeverity.EMPTY, - description="No session history at all.", - suggested_action="Establish session intent after campaign context.", - )) + report.gaps.append( + Gap( + layer=GapLayer.SESSION, + severity=GapSeverity.EMPTY, + description="No session history at all.", + suggested_action="Establish session intent after campaign context.", + ) + ) else: - report.gaps.append(Gap( - layer=GapLayer.SESSION, - severity=GapSeverity.THIN, - description=f"{report.session_count} past sessions, but no 
intent for current.", - suggested_action="Quick check-in: continuing campaign or starting fresh?", - )) + report.gaps.append( + Gap( + layer=GapLayer.SESSION, + severity=GapSeverity.THIN, + description=f"{report.session_count} past sessions, but no intent for current.", + suggested_action="Quick check-in: continuing campaign or starting fresh?", + ) + ) readiness_score += 0.1 else: readiness_score += 0.25 diff --git a/gently/harness/memory/interface.py b/gently/harness/memory/interface.py index 49712371..2e27db8d 100644 --- a/gently/harness/memory/interface.py +++ b/gently/harness/memory/interface.py @@ -10,7 +10,6 @@ """ import logging -from typing import Dict, List, Optional, Tuple logger = logging.getLogger(__name__) @@ -36,17 +35,17 @@ class AgentMemory: or prompt injection. """ - def __init__(self, context_store, session_id: str = None): + def __init__(self, context_store, session_id: str | None = None): self.store = context_store self.session_id = session_id # Set by startup flow after resolve_plan_context() - self.active_plan_item_id: Optional[str] = None + self.active_plan_item_id: str | None = None # ------------------------------------------------------------------ # Plan context resolution — called at startup # ------------------------------------------------------------------ - def resolve_plan_context(self) -> Tuple[Optional[str], List]: + def resolve_plan_context(self) -> tuple[str | None, list]: """Determine which plan item to activate for this session. Scans all active campaigns for unblocked imaging items. 
@@ -61,6 +60,7 @@ def resolve_plan_context(self) -> Tuple[Optional[str], List]: """ try: from .model import PlanItemType + root_campaigns = self.store.get_root_campaigns(status="active") imaging_candidates = [] for campaign in root_campaigns: @@ -186,13 +186,11 @@ def get_awareness_summary(self) -> str: # Session-campaign link if self.session_id: try: - session_campaigns = self.store.get_campaigns_for_session( - self.session_id - ) + session_campaigns = self.store.get_campaigns_for_session(self.session_id) if session_campaigns: names = [_short_name(c) for c in session_campaigns[:2]] lines.append( - f'- This session: linked to {", ".join(f"{n!r}" for n in names)}' + f"- This session: linked to {', '.join(f'{n!r}' for n in names)}" ) except Exception: pass @@ -245,7 +243,7 @@ def get_awareness_summary(self) -> str: # Briefing layer — auto-briefing at session start # ------------------------------------------------------------------ - def get_session_briefing(self, campaign_id: str = None) -> str: + def get_session_briefing(self, campaign_id: str | None = None) -> str: """Generate a session briefing for new sessions. 
If campaign_id is provided (or session is linked to a campaign), @@ -291,7 +289,9 @@ def _briefing_for_campaign(self, campaign) -> str: for p in phases: try: ps = self.store.get_plan_status(p.id) - items_str = f" ({ps['completed']}/{ps['total']} items)" if ps["total"] > 0 else "" + items_str = ( + f" ({ps['completed']}/{ps['total']} items)" if ps["total"] > 0 else "" + ) except Exception: items_str = "" lines.append(f" - {_friendly_name(p)}{items_str}") @@ -301,6 +301,7 @@ def _briefing_for_campaign(self, campaign) -> str: # Plan status (root campaign) try: from .model import PlanItemType + status = self.store.get_plan_status(campaign.id) if status["total"] > 0: lines.append( @@ -328,9 +329,9 @@ def _briefing_for_campaign(self, campaign) -> str: learnings = self.store.get_learnings(limit=50) if learnings: lines.append("\n**Recent learnings**:") - for l in learnings[:5]: - conf = l.confidence.value if l.confidence else "?" - lines.append(f" - [{conf}] {l.content[:150]}") + for learning in learnings[:5]: + conf = learning.confidence.value if learning.confidence else "?" 
+ lines.append(f" - [{conf}] {learning.content[:150]}") # Other active root campaigns (brief mention) root_campaigns = self.store.get_root_campaigns() @@ -379,9 +380,7 @@ def _briefing_broad(self) -> str: ) else: lines.append("## Ready to image") - lines.append( - f"{len(candidates)} imaging tasks are unblocked:" - ) + lines.append(f"{len(candidates)} imaging tasks are unblocked:") lines.append("") for item, spec, campaign in candidates: spec_summary = self.format_imaging_spec_summary(spec) if spec else "no spec" @@ -441,7 +440,11 @@ def recall_campaigns(self, status: str = "active") -> str: for p in phases: try: ps = self.store.get_plan_status(p.id) - items_str = f" ({ps['completed']}/{ps['total']} items)" if ps["total"] > 0 else "" + items_str = ( + f" ({ps['completed']}/{ps['total']} items)" + if ps["total"] > 0 + else "" + ) except Exception: items_str = "" lines.append(f" - {_friendly_name(p)}{items_str}") @@ -452,7 +455,7 @@ def recall_campaigns(self, status: str = "active") -> str: return "\n".join(lines) - def recall_learnings(self, query: str = None, limit: int = 20) -> str: + def recall_learnings(self, query: str | None = None, limit: int = 20) -> str: """Search or list learnings.""" learnings = self.store.get_learnings(limit=max(limit, 50)) @@ -460,10 +463,11 @@ def recall_learnings(self, query: str = None, limit: int = 20) -> str: query_lower = query.lower() terms = query_lower.split() learnings = [ - l - for l in learnings + learning + for learning in learnings if any( - term in l.content.lower() or (l.basis and term in l.basis.lower()) + term in learning.content.lower() + or (learning.basis and term in learning.basis.lower()) for term in terms ) ] @@ -471,22 +475,24 @@ def recall_learnings(self, query: str = None, limit: int = 20) -> str: learnings = learnings[:limit] if not learnings: - msg = f"No learnings found matching '{query}'." if query else "No learnings recorded yet." + msg = ( + f"No learnings found matching '{query}'." 
if query else "No learnings recorded yet." + ) return msg header = f"## Learnings matching '{query}'" if query else "## Recent Learnings" lines = [header] - for l in learnings: - conf = l.confidence.value if l.confidence else "?" - lines.append(f"\n- [{conf}] {l.content}") - if l.basis: - lines.append(f" _Basis_: {l.basis[:200]}") - lines.append(f" _{l.created_at.strftime('%Y-%m-%d %H:%M')}_") + for learning in learnings: + conf = learning.confidence.value if learning.confidence else "?" + lines.append(f"\n- [{conf}] {learning.content}") + if learning.basis: + lines.append(f" _Basis_: {learning.basis[:200]}") + lines.append(f" _{learning.created_at.strftime('%Y-%m-%d %H:%M')}_") return "\n".join(lines) def recall_observations( - self, query: str = None, embryo_id: str = None, limit: int = 20 + self, query: str | None = None, embryo_id: str | None = None, limit: int = 20 ) -> str: """Search or list observations.""" if embryo_id: @@ -498,15 +504,17 @@ def recall_observations( query_lower = query.lower() terms = query_lower.split() observations = [ - o - for o in observations - if any(term in o.content.lower() for term in terms) + o for o in observations if any(term in o.content.lower() for term in terms) ] observations = observations[:limit] if not observations: - msg = f"No observations found matching '{query}'." if query else "No observations recorded yet." + msg = ( + f"No observations found matching '{query}'." + if query + else "No observations recorded yet." + ) return msg header = "## Observations" @@ -525,7 +533,7 @@ def recall_observations( return "\n".join(lines) - def recall_full_context(self, campaign_id: str = None) -> str: + def recall_full_context(self, campaign_id: str | None = None) -> str: """Full context snapshot — the 'catch me up' method. If campaign_id provided (or session is linked), focuses there. 
@@ -559,7 +567,9 @@ def recall_full_context(self, campaign_id: str = None) -> str: lines.append("\n### Active Campaigns") for c in root_campaigns: name = _friendly_name(c) - is_focus = " ← this session" if (focus_campaign and c.id == focus_campaign.id) else "" + is_focus = ( + " ← this session" if (focus_campaign and c.id == focus_campaign.id) else "" + ) progress = f" — {c.progress}" if c.progress else "" try: status = self.store.get_plan_status(c.id) @@ -580,7 +590,9 @@ def recall_full_context(self, campaign_id: str = None) -> str: for p in phases: try: ps = self.store.get_plan_status(p.id) - items_str = f" ({ps['completed']}/{ps['total']})" if ps["total"] > 0 else "" + items_str = ( + f" ({ps['completed']}/{ps['total']})" if ps["total"] > 0 else "" + ) except Exception: items_str = "" lines.append(f" - {_short_name(p)}{items_str}") @@ -606,9 +618,9 @@ def recall_full_context(self, campaign_id: str = None) -> str: learnings = self.store.get_learnings(limit=10) if learnings: lines.append("\n### Recent Learnings") - for l in learnings[:10]: - conf = l.confidence.value if l.confidence else "?" - lines.append(f"- [{conf}] {l.content[:150]}") + for learning in learnings[:10]: + conf = learning.confidence.value if learning.confidence else "?" 
+ lines.append(f"- [{conf}] {learning.content[:150]}") # Expectations expectations = self.store.get_pending_expectations() diff --git a/gently/harness/memory/model.py b/gently/harness/memory/model.py index d9dc9a4e..2a164176 100644 --- a/gently/harness/memory/model.py +++ b/gently/harness/memory/model.py @@ -6,12 +6,13 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import Any, Dict, List, Optional from enum import Enum +from typing import Any class Significance(str, Enum): """How important something is.""" + HIGH = "high" MEDIUM = "medium" LOW = "low" @@ -19,6 +20,7 @@ class Significance(str, Enum): class Confidence(str, Enum): """How confident we are in a belief.""" + HIGH = "high" MEDIUM = "medium" LOW = "low" @@ -26,6 +28,7 @@ class Confidence(str, Enum): class Status(str, Enum): """Generic status for things that can be active/completed.""" + ACTIVE = "active" PAUSED = "paused" COMPLETED = "completed" @@ -33,15 +36,17 @@ class Status(str, Enum): class PlannedSessionStatus(str, Enum): """Status for planned imaging sessions.""" + PLANNED = "planned" - ACTIVE = "active" # Currently in progress + ACTIVE = "active" # Currently in progress COMPLETED = "completed" - SKIPPED = "skipped" # Decided not to do it + SKIPPED = "skipped" # Decided not to do it CANCELLED = "cancelled" class PlanItemStatus(str, Enum): """Status for plan items.""" + PLANNED = "planned" IN_PROGRESS = "in_progress" COMPLETED = "completed" @@ -51,6 +56,7 @@ class PlanItemStatus(str, Enum): class PlanItemType(str, Enum): """Type of plan item.""" + IMAGING = "imaging" BENCH = "bench" GENETICS = "genetics" @@ -60,6 +66,7 @@ class PlanItemType(str, Enum): class ExpectationStatus(str, Enum): """Status for expectations/predictions.""" + PENDING = "pending" CONFIRMED = "confirmed" SURPRISED = "surprised" @@ -68,6 +75,7 @@ class ExpectationStatus(str, Enum): class WatchpointStatus(str, Enum): """Status for watchpoints.""" + ACTIVE = "active" TRIGGERED = "triggered" 
RESOLVED = "resolved" @@ -75,6 +83,7 @@ class WatchpointStatus(str, Enum): class QuestionStatus(str, Enum): """Status for open questions.""" + OPEN = "open" INVESTIGATING = "investigating" RESOLVED = "resolved" @@ -84,6 +93,7 @@ class QuestionStatus(str, Enum): # Intentions: Why are we doing this? # --------------------------------------------------------------------------- + @dataclass class Campaign: """ @@ -95,13 +105,14 @@ class Campaign: Example: "Capture 50 hatching events from wild-type embryos" """ + id: str description: str # Natural language, as the researcher said it - shorthand: Optional[str] = None # Short label: "temp-division", "hatching-50" - summary: Optional[str] = None # Agent-rephrased structured summary - target: Optional[str] = None # Measurable goal: "50 hatching events" - progress: Optional[str] = None # Current state: "23/50" - parent_id: Optional[str] = None # Parent campaign (for hierarchy) + shorthand: str | None = None # Short label: "temp-division", "hatching-50" + summary: str | None = None # Agent-rephrased structured summary + target: str | None = None # Measurable goal: "50 hatching events" + progress: str | None = None # Current state: "23/50" + parent_id: str | None = None # Parent campaign (for hierarchy) status: Status = Status.ACTIVE is_shared: bool = False created_at: datetime = field(default_factory=datetime.now) @@ -122,9 +133,10 @@ class Project: Example: "Optimize imaging parameters for early stages" """ + id: str description: str - campaign_id: Optional[str] = None + campaign_id: str | None = None status: Status = Status.ACTIVE created_at: datetime = field(default_factory=datetime.now) updated_at: datetime = field(default_factory=datetime.now) @@ -138,12 +150,13 @@ class SessionIntent: Tracks planned intent vs what actually happened. A session can belong to multiple campaigns (linked via session_campaigns). 
""" + session_id: str - planned_intent: Optional[str] = None # What was planned - actual_summary: Optional[str] = None # What happened - campaign_ids: List[str] = field(default_factory=list) # Linked campaigns + planned_intent: str | None = None # What was planned + actual_summary: str | None = None # What happened + campaign_ids: list[str] = field(default_factory=list) # Linked campaigns created_at: datetime = field(default_factory=datetime.now) - completed_at: Optional[datetime] = None + completed_at: datetime | None = None @dataclass @@ -159,17 +172,18 @@ class PlannedSession: can match the planned session to the actual session and pre-populate intent + parameters. """ + id: str - title: Optional[str] = None # "N2 baseline imaging round 3" - notes: Optional[str] = None # Free-form: what to do, what to watch for - scheduled_date: Optional[str] = None # ISO date: "2026-02-15" - scheduled_time: Optional[str] = None # ISO time: "14:00" (optional) - estimated_duration_minutes: Optional[int] = None - acquisition_params: Optional[Dict[str, Any]] = None # From previous session - source_session_id: Optional[str] = None # "use params from this session" + title: str | None = None # "N2 baseline imaging round 3" + notes: str | None = None # Free-form: what to do, what to watch for + scheduled_date: str | None = None # ISO date: "2026-02-15" + scheduled_time: str | None = None # ISO time: "14:00" (optional) + estimated_duration_minutes: int | None = None + acquisition_params: dict[str, Any] | None = None # From previous session + source_session_id: str | None = None # "use params from this session" status: PlannedSessionStatus = PlannedSessionStatus.PLANNED - session_id: Optional[str] = None # Linked actual session once started - campaign_ids: List[str] = field(default_factory=list) + session_id: str | None = None # Linked actual session once started + campaign_ids: list[str] = field(default_factory=list) created_at: datetime = field(default_factory=datetime.now) updated_at: 
datetime = field(default_factory=datetime.now) @@ -190,40 +204,41 @@ class ImagingSpec: Everything needed to auto-configure the microscope when it's time to execute this plan item. """ + # Sample - strain: Optional[str] = None # "OH904" - genotype: Optional[str] = None # "otIs355[rab-3p::2xNLS::TagRFP]" - reporter: Optional[str] = None # "rab-3p::GFP (pan-neuronal)" - sample_prep: Optional[str] = None # "Standard egg prep, poly-lysine pads" - temperature_c: Optional[float] = None # 20.0 - num_embryos: Optional[int] = None # 4 + strain: str | None = None # "OH904" + genotype: str | None = None # "otIs355[rab-3p::2xNLS::TagRFP]" + reporter: str | None = None # "rab-3p::GFP (pan-neuronal)" + sample_prep: str | None = None # "Standard egg prep, poly-lysine pads" + temperature_c: float | None = None # 20.0 + num_embryos: int | None = None # 4 # Acquisition - num_slices: Optional[int] = None # 80 - exposure_ms: Optional[float] = None # 10.0 - laser_wavelength_nm: Optional[int] = None # 488 - laser_power_pct: Optional[float] = None # 10.0 - galvo_amplitude: Optional[float] = None # 8.0 - piezo_amplitude_um: Optional[float] = None # 50.0 + num_slices: int | None = None # 80 + exposure_ms: float | None = None # 10.0 + laser_wavelength_nm: int | None = None # 488 + laser_power_pct: float | None = None # 10.0 + galvo_amplitude: float | None = None # 8.0 + piezo_amplitude_um: float | None = None # 50.0 # Timing - interval_s: Optional[int] = None # 180 - adaptive_intervals: Optional[Dict[str, int]] = None + interval_s: int | None = None # 180 + adaptive_intervals: dict[str, int] | None = None # e.g. 
{"early_to_comma": 300, "comma_to_2fold": 60} # Developmental window - target_window: Optional[str] = None # "comma → pretzel" - start_stage: Optional[str] = None # "comma" - stop_condition: Optional[str] = None # "pretzel" - estimated_duration_h: Optional[float] = None # 4.0 + target_window: str | None = None # "comma → pretzel" + start_stage: str | None = None # "comma" + stop_condition: str | None = None # "pretzel" + estimated_duration_h: float | None = None # 4.0 # Detection - detectors: Optional[List[str]] = None # ["comma", "pretzel"] - pre_terminal_speedup: Optional[bool] = None + detectors: list[str] | None = None # ["comma", "pretzel"] + pre_terminal_speedup: bool | None = None # Validation - success_criteria: Optional[str] = None - comparison_to: Optional[str] = None # "Compare to WT session 1" + success_criteria: str | None = None + comparison_to: str | None = None # "Compare to WT session 1" @dataclass @@ -231,13 +246,14 @@ class BenchSpec: """ Specification for non-imaging tasks (bench work, genetics, analysis). """ - protocol: Optional[str] = None # "Standard chemotaxis assay" - reagents: Optional[List[str]] = None # ["anti-UNC-33", "secondary 568"] - strains: Optional[List[str]] = None # ["OH904", "N2"] - target_genotype: Optional[str] = None # "unc-6(ev400); otIs355" - estimated_days: Optional[int] = None # 14 - success_criteria: Optional[str] = None - notes: Optional[str] = None + + protocol: str | None = None # "Standard chemotaxis assay" + reagents: list[str] | None = None # ["anti-UNC-33", "secondary 568"] + strains: list[str] | None = None # ["OH904", "N2"] + target_genotype: str | None = None # "unc-6(ev400); otIs355" + estimated_days: int | None = None # 14 + success_criteria: str | None = None + notes: str | None = None @dataclass @@ -252,29 +268,30 @@ class PlanItem: Dependencies between items (depends_on) enable the agent to track what's blocked and what's newly unblocked. 
""" + id: str - campaign_id: str # Which campaign/phase - type: PlanItemType # imaging, bench, genetics, analysis, decision_point - title: str # "Pilot — rab-3p::GFP visibility test" - description: Optional[str] = None # Detailed notes, protocols, what to watch for + campaign_id: str # Which campaign/phase + type: PlanItemType # imaging, bench, genetics, analysis, decision_point + title: str # "Pilot — rab-3p::GFP visibility test" + description: str | None = None # Detailed notes, protocols, what to watch for status: PlanItemStatus = PlanItemStatus.PLANNED - depends_on: List[str] = field(default_factory=list) # PlanItem IDs - outcome: Optional[str] = None # What happened (filled after completion) - claimed_by: Optional[str] = None # instance_id of claiming node - claimed_by_hostname: Optional[str] = None # human-readable hostname - references: List[Dict[str, str]] = field(default_factory=list) # Source citations + depends_on: list[str] = field(default_factory=list) # PlanItem IDs + outcome: str | None = None # What happened (filled after completion) + claimed_by: str | None = None # instance_id of claiming node + claimed_by_hostname: str | None = None # human-readable hostname + references: list[dict[str, str]] = field(default_factory=list) # Source citations # Specifications (type-dependent) - imaging_spec: Optional[ImagingSpec] = None - bench_spec: Optional[BenchSpec] = None + imaging_spec: ImagingSpec | None = None + bench_spec: BenchSpec | None = None # Linking - planned_session_id: Optional[str] = None # → PlannedSession (for imaging items) - session_id: Optional[str] = None # → Actual session (once executed) - inherit_from: Optional[str] = None # PlanItem ID to inherit spec from + planned_session_id: str | None = None # → PlannedSession (for imaging items) + session_id: str | None = None # → Actual session (once executed) + inherit_from: str | None = None # PlanItem ID to inherit spec from # Scheduling — relative timeline from Day 0 - estimated_days: 
Optional[int] = None # Duration in days (for Gantt/timeline views) + estimated_days: int | None = None # Duration in days (for Gantt/timeline views) # Ordering phase_order: int = 0 @@ -291,17 +308,19 @@ def is_actionable(self) -> bool: @dataclass class Intentions: """Collection of the agent's intentions at multiple levels.""" - campaigns: List[Campaign] = field(default_factory=list) - projects: List[Project] = field(default_factory=list) - planned_sessions: List[PlannedSession] = field(default_factory=list) - current_focus: Optional[str] = None - session_intent: Optional[SessionIntent] = None + + campaigns: list[Campaign] = field(default_factory=list) + projects: list[Project] = field(default_factory=list) + planned_sessions: list[PlannedSession] = field(default_factory=list) + current_focus: str | None = None + session_intent: SessionIntent | None = None # --------------------------------------------------------------------------- # Understanding: What do we believe? # --------------------------------------------------------------------------- + @dataclass class Learning: """ @@ -309,10 +328,11 @@ class Learning: Example: "Batch 7 develops 20% faster than average" """ + id: str content: str # Human-readable insight confidence: Confidence = Confidence.MEDIUM - basis: Optional[str] = None # What observations support this + basis: str | None = None # What observations support this created_at: datetime = field(default_factory=datetime.now) @@ -323,31 +343,34 @@ class EmbryoUnderstanding: This is synthesized understanding, not raw data. 
""" + embryo_id: str - current_stage: Optional[str] = None - stage_confidence: Optional[Confidence] = None - health_assessment: Optional[str] = None - notes: List[str] = field(default_factory=list) - last_observed: Optional[datetime] = None + current_stage: str | None = None + stage_confidence: Confidence | None = None + health_assessment: str | None = None + notes: list[str] = field(default_factory=list) + last_observed: datetime | None = None # Tracking flags is_tracked: bool = True is_hatched: bool = False needs_attention: bool = False - attention_reason: Optional[str] = None + attention_reason: str | None = None @dataclass class Understanding: """The agent's overall understanding of the experiment.""" - embryo_states: Dict[str, EmbryoUnderstanding] = field(default_factory=dict) - learnings: List[Learning] = field(default_factory=list) + + embryo_states: dict[str, EmbryoUnderstanding] = field(default_factory=dict) + learnings: list[Learning] = field(default_factory=list) # --------------------------------------------------------------------------- # Observations: What have we seen? (synthesized, not raw) # --------------------------------------------------------------------------- + @dataclass class Observation: """ @@ -355,21 +378,23 @@ class Observation: Not raw data — a meaningful note about what happened. 
""" + id: str timestamp: datetime type: str # stage_transition, anomaly, session_summary, milestone content: str # Human-readable description - embryo_id: Optional[str] = None + embryo_id: str | None = None significance: Significance = Significance.MEDIUM - session_id: Optional[str] = None - gently_refs: Optional[Dict[str, Any]] = None # References to FileStore data - relates_to: Optional[List[str]] = None # Related goals/observations + session_id: str | None = None + gently_refs: dict[str, Any] | None = None # References to FileStore data + relates_to: list[str] | None = None # Related goals/observations # --------------------------------------------------------------------------- # Expectations: What do we predict? # --------------------------------------------------------------------------- + @dataclass class Expectation: """ @@ -377,21 +402,23 @@ class Expectation: Example: "embryo_3 will reach comma stage by 14:30" """ + id: str target: str # What this is about (embryo_id, etc) prediction: str # "will reach comma stage" expected_time: datetime - uncertainty: Optional[str] = None # "±30 minutes" - basis: Optional[str] = None # Why we expect this + uncertainty: str | None = None # "±30 minutes" + basis: str | None = None # Why we expect this status: ExpectationStatus = ExpectationStatus.PENDING created_at: datetime = field(default_factory=datetime.now) - resolved_at: Optional[datetime] = None + resolved_at: datetime | None = None # --------------------------------------------------------------------------- # Attention: What should we watch? # --------------------------------------------------------------------------- + @dataclass class Watchpoint: """ @@ -399,6 +426,7 @@ class Watchpoint: Example: Watch embryo_3 for "approaching hatching" """ + id: str target: str # "embryo_3" condition: str # "approaching hatching" @@ -414,25 +442,28 @@ class Question: Example: "Why is batch 7 slower than batch 6?" 
""" + id: str content: str status: QuestionStatus = QuestionStatus.OPEN - resolution: Optional[str] = None + resolution: str | None = None created_at: datetime = field(default_factory=datetime.now) - resolved_at: Optional[datetime] = None + resolved_at: datetime | None = None @dataclass class Attention: """What the agent is watching/thinking about.""" - watchpoints: List[Watchpoint] = field(default_factory=list) - open_questions: List[Question] = field(default_factory=list) + + watchpoints: list[Watchpoint] = field(default_factory=list) + open_questions: list[Question] = field(default_factory=list) # --------------------------------------------------------------------------- # Full Context # --------------------------------------------------------------------------- + @dataclass class Context: """ @@ -440,24 +471,25 @@ class Context: This is the agent's "working memory" for a single think. """ + intentions: Intentions = field(default_factory=Intentions) understanding: Understanding = field(default_factory=Understanding) - observations: List[Observation] = field(default_factory=list) - expectations: List[Expectation] = field(default_factory=list) + observations: list[Observation] = field(default_factory=list) + expectations: list[Expectation] = field(default_factory=list) attention: Attention = field(default_factory=Attention) @property - def active_campaigns(self) -> List[Campaign]: + def active_campaigns(self) -> list[Campaign]: """Get active campaigns.""" return [c for c in self.intentions.campaigns if c.status == Status.ACTIVE] @property - def pending_expectations(self) -> List[Expectation]: + def pending_expectations(self) -> list[Expectation]: """Get pending expectations.""" return [e for e in self.expectations if e.status == ExpectationStatus.PENDING] @property - def active_watchpoints(self) -> List[Watchpoint]: + def active_watchpoints(self) -> list[Watchpoint]: """Get active watchpoints.""" return [w for w in self.attention.watchpoints if w.status == 
WatchpointStatus.ACTIVE] @@ -466,6 +498,7 @@ def active_watchpoints(self) -> List[Watchpoint]: # Context Updates (from agent response) # --------------------------------------------------------------------------- + @dataclass class ContextUpdates: """ @@ -473,23 +506,24 @@ class ContextUpdates: The agent returns these after thinking. """ + # New items to add - new_observations: List[Observation] = field(default_factory=list) - new_expectations: List[Expectation] = field(default_factory=list) - new_watchpoints: List[Watchpoint] = field(default_factory=list) - new_learnings: List[Learning] = field(default_factory=list) - new_questions: List[Question] = field(default_factory=list) + new_observations: list[Observation] = field(default_factory=list) + new_expectations: list[Expectation] = field(default_factory=list) + new_watchpoints: list[Watchpoint] = field(default_factory=list) + new_learnings: list[Learning] = field(default_factory=list) + new_questions: list[Question] = field(default_factory=list) # Status updates - resolved_expectations: Dict[str, ExpectationStatus] = field(default_factory=dict) - triggered_watchpoints: List[str] = field(default_factory=list) - resolved_questions: Dict[str, str] = field(default_factory=dict) # id -> resolution + resolved_expectations: dict[str, ExpectationStatus] = field(default_factory=dict) + triggered_watchpoints: list[str] = field(default_factory=list) + resolved_questions: dict[str, str] = field(default_factory=dict) # id -> resolution # Understanding updates - embryo_updates: Dict[str, Dict[str, Any]] = field(default_factory=dict) + embryo_updates: dict[str, dict[str, Any]] = field(default_factory=dict) # Campaign/project progress - campaign_progress: Dict[str, str] = field(default_factory=dict) # id -> progress + campaign_progress: dict[str, str] = field(default_factory=dict) # id -> progress # Focus update - new_focus: Optional[str] = None + new_focus: str | None = None diff --git a/gently/harness/memory/onboarding.py 
b/gently/harness/memory/onboarding.py index 937fba7b..75f24d19 100644 --- a/gently/harness/memory/onboarding.py +++ b/gently/harness/memory/onboarding.py @@ -8,20 +8,18 @@ import logging import uuid -from dataclasses import dataclass, field -from datetime import datetime -from typing import Any, Dict, List, Optional +from dataclasses import dataclass +from typing import Any from gently.settings import settings + from .gap_assessment import ContextGapReport, GapLayer from .model import ( - Campaign, + Confidence, Learning, - SessionIntent, Watchpoint, - Confidence, - Significance, ) + try: from .file_store import FileContextStore as ContextStore except ImportError: @@ -31,11 +29,13 @@ @dataclass class OnboardingMessage: """A message to surface to the researcher during onboarding.""" + message: str layer: GapLayer priority: str = "normal" # "high", "normal", "low" reason: str = "" + logger = logging.getLogger(__name__) @@ -80,8 +80,8 @@ class OnboardingMessage: def generate_onboarding_messages( gap_report: ContextGapReport, - session_id: Optional[str] = None, -) -> List[OnboardingMessage]: + session_id: str | None = None, +) -> list[OnboardingMessage]: """ Generate onboarding messages based on the gap assessment. 
@@ -100,12 +100,14 @@ def generate_onboarding_messages( messages = [] if gap_report.needs_lab_onboarding: - messages.append(OnboardingMessage( - message=LAB_ONBOARDING_GREETING, - layer=GapLayer.LAB, - priority="high", - reason="First launch — need to learn about the lab.", - )) + messages.append( + OnboardingMessage( + message=LAB_ONBOARDING_GREETING, + layer=GapLayer.LAB, + priority="high", + reason="First launch — need to learn about the lab.", + ) + ) if gap_report.needs_campaign: if gap_report.past_campaign_count > 0: @@ -116,12 +118,14 @@ def generate_onboarding_messages( else: prompt = CAMPAIGN_PROMPT_FRESH - messages.append(OnboardingMessage( - message=prompt, - layer=GapLayer.CAMPAIGN, - priority="normal", - reason="No active campaign — need research direction.", - )) + messages.append( + OnboardingMessage( + message=prompt, + layer=GapLayer.CAMPAIGN, + priority="normal", + reason="No active campaign — need research direction.", + ) + ) if gap_report.needs_session_intent and session_id: if gap_report.has_campaigns: @@ -137,12 +141,14 @@ def generate_onboarding_messages( else: prompt = SESSION_PROMPT_NO_CAMPAIGN - messages.append(OnboardingMessage( - message=prompt, - layer=GapLayer.SESSION, - priority="normal", - reason="Need to establish session intent.", - )) + messages.append( + OnboardingMessage( + message=prompt, + layer=GapLayer.SESSION, + priority="normal", + reason="Need to establish session intent.", + ) + ) if messages: logger.info( @@ -155,8 +161,8 @@ def generate_onboarding_messages( def get_onboarding_messages( gap_report: ContextGapReport, - session_id: Optional[str] = None, -) -> List[str]: + session_id: str | None = None, +) -> list[str]: """ Get plain-text onboarding messages for direct CLI display. 
@@ -179,10 +185,12 @@ def get_onboarding_messages( if gap_report.needs_campaign: if gap_report.past_campaign_count > 0: - messages.append(CAMPAIGN_PROMPT_RETURNING.format( - campaign_description="(previous campaigns completed)", - campaign_status="completed", - )) + messages.append( + CAMPAIGN_PROMPT_RETURNING.format( + campaign_description="(previous campaigns completed)", + campaign_status="completed", + ) + ) else: messages.append(CAMPAIGN_PROMPT_FRESH) @@ -192,10 +200,12 @@ def get_onboarding_messages( if gap_report.session_count > 0: history = f"This is session #{gap_report.session_count + 1}. " campaign_name = gap_report.active_campaigns[0].display_name - messages.append(SESSION_PROMPT_WITH_CAMPAIGN.format( - campaign_description=campaign_name, - history_context=history, - )) + messages.append( + SESSION_PROMPT_WITH_CAMPAIGN.format( + campaign_description=campaign_name, + history_context=history, + ) + ) else: messages.append(SESSION_PROMPT_NO_CAMPAIGN) @@ -238,9 +248,9 @@ async def process_onboarding_response( response: str, topic: str, context_store: ContextStore, - claude_client: Optional[Any] = None, - session_id: Optional[str] = None, -) -> Dict[str, Any]: + claude_client: Any | None = None, + session_id: str | None = None, +) -> dict[str, Any]: """ Process a researcher's response during onboarding. 
@@ -274,8 +284,7 @@ async def process_onboarding_response( extracted = _extract_basic(response, topic, context_store, session_id) logger.info( - f"Onboarding response processed ({topic}): " - f"{extracted['entries_created']} entries created" + f"Onboarding response processed ({topic}): {extracted['entries_created']} entries created" ) return extracted @@ -285,8 +294,8 @@ async def _extract_with_llm( topic: str, context_store: ContextStore, claude_client: Any, - session_id: Optional[str], -) -> Dict[str, Any]: + session_id: str | None, +) -> dict[str, Any]: """Use Claude to extract structured context from a response.""" import asyncio import json @@ -314,33 +323,37 @@ async def _extract_with_llm( entries = 0 # Store learnings - for item in (data.get("learnings") or []): + for item in data.get("learnings") or []: if item and item.get("content"): - context_store.add_learning(Learning( - id=str(uuid.uuid4())[:8], - content=item["content"], - confidence=Confidence(item.get("confidence", "medium")), - basis=item.get("basis", f"onboarding:{topic}"), - )) + context_store.add_learning( + Learning( + id=str(uuid.uuid4())[:8], + content=item["content"], + confidence=Confidence(item.get("confidence", "medium")), + basis=item.get("basis", f"onboarding:{topic}"), + ) + ) entries += 1 # Store campaign campaign_data = data.get("campaign") if campaign_data and campaign_data.get("description"): - cid = context_store.create_campaign( + context_store.create_campaign( description=campaign_data["description"], target=campaign_data.get("target"), ) entries += 1 # Store watchpoints - for item in (data.get("watchpoints") or []): + for item in data.get("watchpoints") or []: if item and item.get("target"): - context_store.add_watchpoint(Watchpoint( - id=str(uuid.uuid4())[:8], - target=item["target"], - condition=item.get("condition", "monitor"), - )) + context_store.add_watchpoint( + Watchpoint( + id=str(uuid.uuid4())[:8], + target=item["target"], + condition=item.get("condition", 
"monitor"), + ) + ) entries += 1 # Store session intent @@ -356,12 +369,14 @@ async def _extract_with_llm( for field_name in ("organism", "microscope"): value = data.get(field_name) if value: - context_store.add_learning(Learning( - id=str(uuid.uuid4())[:8], - content=f"Lab {field_name}: {value}", - confidence=Confidence.HIGH, - basis="onboarding:identity", - )) + context_store.add_learning( + Learning( + id=str(uuid.uuid4())[:8], + content=f"Lab {field_name}: {value}", + confidence=Confidence.HIGH, + basis="onboarding:identity", + ) + ) entries += 1 return { @@ -375,8 +390,8 @@ def _extract_basic( response: str, topic: str, context_store: ContextStore, - session_id: Optional[str], -) -> Dict[str, Any]: + session_id: str | None, +) -> dict[str, Any]: """ Basic keyword-based extraction when no LLM is available. @@ -385,12 +400,14 @@ def _extract_basic( entries = 0 # Store the raw response as a learning - context_store.add_learning(Learning( - id=str(uuid.uuid4())[:8], - content=f"Researcher ({topic}): {response[:500]}", - confidence=Confidence.MEDIUM, - basis=f"onboarding:{topic}", - )) + context_store.add_learning( + Learning( + id=str(uuid.uuid4())[:8], + content=f"Researcher ({topic}): {response[:500]}", + confidence=Confidence.MEDIUM, + basis=f"onboarding:{topic}", + ) + ) entries += 1 # If this is a session topic and we have a session ID, create intent @@ -411,8 +428,9 @@ def _extract_basic( # Ingestion result → context store # --------------------------------------------------------------------------- + def apply_ingestion_to_context( - result: "IngestionResult", + result: "IngestionResult", # noqa: F821 context_store: ContextStore, ) -> int: """ @@ -443,34 +461,40 @@ def apply_ingestion_to_context( # Learnings for item in result.learnings: if item.get("content"): - context_store.add_learning(Learning( - id=str(uuid.uuid4())[:8], - content=item["content"], - confidence=Confidence(item.get("confidence", "medium")), - basis=f"ingestion:{result.source}", - )) + 
context_store.add_learning( + Learning( + id=str(uuid.uuid4())[:8], + content=item["content"], + confidence=Confidence(item.get("confidence", "medium")), + basis=f"ingestion:{result.source}", + ) + ) entries += 1 # Imaging parameters as learnings if result.imaging_parameters: for key, value in result.imaging_parameters.items(): if value is not None and key != "notes": - context_store.add_learning(Learning( - id=str(uuid.uuid4())[:8], - content=f"Recommended {key}: {value}", - confidence=Confidence.MEDIUM, - basis=f"ingestion:{result.source}", - )) + context_store.add_learning( + Learning( + id=str(uuid.uuid4())[:8], + content=f"Recommended {key}: {value}", + confidence=Confidence.MEDIUM, + basis=f"ingestion:{result.source}", + ) + ) entries += 1 # Store notes separately if present notes = result.imaging_parameters.get("notes") if notes: - context_store.add_learning(Learning( - id=str(uuid.uuid4())[:8], - content=f"Imaging notes: {notes}", - confidence=Confidence.MEDIUM, - basis=f"ingestion:{result.source}", - )) + context_store.add_learning( + Learning( + id=str(uuid.uuid4())[:8], + content=f"Imaging notes: {notes}", + confidence=Confidence.MEDIUM, + basis=f"ingestion:{result.source}", + ) + ) entries += 1 # Sample requirements as a learning @@ -480,22 +504,26 @@ def apply_ingestion_to_context( if value and key != "notes": parts.append(f"{key}: {value}") if parts: - context_store.add_learning(Learning( - id=str(uuid.uuid4())[:8], - content=f"Sample requirements: {', '.join(parts)}", - confidence=Confidence.MEDIUM, - basis=f"ingestion:{result.source}", - )) + context_store.add_learning( + Learning( + id=str(uuid.uuid4())[:8], + content=f"Sample requirements: {', '.join(parts)}", + confidence=Confidence.MEDIUM, + basis=f"ingestion:{result.source}", + ) + ) entries += 1 # Watchpoints for item in result.watchpoints: if item.get("target"): - context_store.add_watchpoint(Watchpoint( - id=str(uuid.uuid4())[:8], - target=item["target"], - condition=item.get("condition", 
"monitor"), - )) + context_store.add_watchpoint( + Watchpoint( + id=str(uuid.uuid4())[:8], + target=item["target"], + condition=item.get("condition", "monitor"), + ) + ) entries += 1 logger.info(f"Applied {entries} entries from ingestion of {result.source}") diff --git a/gently/harness/memory/serialization.py b/gently/harness/memory/serialization.py index 71442c53..6996997f 100644 --- a/gently/harness/memory/serialization.py +++ b/gently/harness/memory/serialization.py @@ -6,29 +6,26 @@ """ import json -from datetime import datetime -from typing import Any, Dict, List +from typing import Any from .model import ( Campaign, - Project, - SessionIntent, - PlannedSession, - Learning, - Observation, - Expectation, - Watchpoint, - Question, - EmbryoUnderstanding, Context, - Significance, - Confidence, + EmbryoUnderstanding, + Expectation, ExpectationStatus, + Learning, + Observation, + PlannedSession, PlannedSessionStatus, + Project, + Question, + SessionIntent, + Watchpoint, ) -def context_to_dict(context: Context) -> Dict[str, Any]: +def context_to_dict(context: Context) -> dict[str, Any]: """ Serialize a Context to a dictionary. 
@@ -38,7 +35,9 @@ def context_to_dict(context: Context) -> Dict[str, Any]: "intentions": { "campaigns": [_campaign_to_dict(c) for c in context.intentions.campaigns], "projects": [_project_to_dict(p) for p in context.intentions.projects], - "planned_sessions": [_planned_session_to_dict(ps) for ps in context.intentions.planned_sessions], + "planned_sessions": [ + _planned_session_to_dict(ps) for ps in context.intentions.planned_sessions + ], "current_focus": context.intentions.current_focus, "session_intent": _session_intent_to_dict(context.intentions.session_intent) if context.intentions.session_intent @@ -46,10 +45,11 @@ def context_to_dict(context: Context) -> Dict[str, Any]: }, "understanding": { "embryo_states": { - eid: _embryo_to_dict(e) - for eid, e in context.understanding.embryo_states.items() + eid: _embryo_to_dict(e) for eid, e in context.understanding.embryo_states.items() }, - "learnings": [_learning_to_dict(l) for l in context.understanding.learnings], + "learnings": [ + _learning_to_dict(learning) for learning in context.understanding.learnings + ], }, "observations": [_observation_to_dict(o) for o in context.observations], "expectations": [_expectation_to_dict(e) for e in context.expectations], @@ -82,8 +82,11 @@ def context_summary(context: Context) -> str: lines.append(f" - {c.display_name}{progress}") # Planned sessions - upcoming = [ps for ps in context.intentions.planned_sessions - if ps.status == PlannedSessionStatus.PLANNED] + upcoming = [ + ps + for ps in context.intentions.planned_sessions + if ps.status == PlannedSessionStatus.PLANNED + ] if upcoming: lines.append(f"Planned sessions: {len(upcoming)} upcoming") for ps in upcoming[:2]: @@ -100,7 +103,10 @@ def context_summary(context: Context) -> str: tracked = [e for e in embryos.values() if e.is_tracked] hatched = [e for e in embryos.values() if e.is_hatched] attention = [e for e in embryos.values() if e.needs_attention] - lines.append(f"Embryos: {len(tracked)} tracked, {len(hatched)} 
hatched, {len(attention)} need attention") + lines.append( + f"Embryos: {len(tracked)} tracked, {len(hatched)} hatched," + f" {len(attention)} need attention" + ) # Expectations pending = [e for e in context.expectations if e.status == ExpectationStatus.PENDING] @@ -113,7 +119,9 @@ def context_summary(context: Context) -> str: lines.append(f"Watchpoints: {len(active_wp)} active") # Questions - open_q = [q for q in context.attention.open_questions if q.status.value in ("open", "investigating")] + open_q = [ + q for q in context.attention.open_questions if q.status.value in ("open", "investigating") + ] if open_q: lines.append(f"Questions: {len(open_q)} open") @@ -128,7 +136,8 @@ def context_summary(context: Context) -> str: # Helper functions # --------------------------------------------------------------------------- -def _campaign_to_dict(c: Campaign) -> Dict[str, Any]: + +def _campaign_to_dict(c: Campaign) -> dict[str, Any]: return { "id": c.id, "description": c.description, @@ -143,7 +152,7 @@ def _campaign_to_dict(c: Campaign) -> Dict[str, Any]: } -def _project_to_dict(p: Project) -> Dict[str, Any]: +def _project_to_dict(p: Project) -> dict[str, Any]: return { "id": p.id, "description": p.description, @@ -154,7 +163,7 @@ def _project_to_dict(p: Project) -> Dict[str, Any]: } -def _planned_session_to_dict(ps: PlannedSession) -> Dict[str, Any]: +def _planned_session_to_dict(ps: PlannedSession) -> dict[str, Any]: return { "id": ps.id, "title": ps.title, @@ -172,7 +181,7 @@ def _planned_session_to_dict(ps: PlannedSession) -> Dict[str, Any]: } -def _session_intent_to_dict(s: SessionIntent) -> Dict[str, Any]: +def _session_intent_to_dict(s: SessionIntent) -> dict[str, Any]: return { "session_id": s.session_id, "planned_intent": s.planned_intent, @@ -183,17 +192,17 @@ def _session_intent_to_dict(s: SessionIntent) -> Dict[str, Any]: } -def _learning_to_dict(l: Learning) -> Dict[str, Any]: +def _learning_to_dict(learning: Learning) -> dict[str, Any]: return { - "id": 
l.id, - "content": l.content, - "confidence": l.confidence.value, - "basis": l.basis, - "created_at": l.created_at.isoformat(), + "id": learning.id, + "content": learning.content, + "confidence": learning.confidence.value, + "basis": learning.basis, + "created_at": learning.created_at.isoformat(), } -def _embryo_to_dict(e: EmbryoUnderstanding) -> Dict[str, Any]: +def _embryo_to_dict(e: EmbryoUnderstanding) -> dict[str, Any]: return { "embryo_id": e.embryo_id, "current_stage": e.current_stage, @@ -208,7 +217,7 @@ def _embryo_to_dict(e: EmbryoUnderstanding) -> Dict[str, Any]: } -def _observation_to_dict(o: Observation) -> Dict[str, Any]: +def _observation_to_dict(o: Observation) -> dict[str, Any]: return { "id": o.id, "timestamp": o.timestamp.isoformat(), @@ -222,7 +231,7 @@ def _observation_to_dict(o: Observation) -> Dict[str, Any]: } -def _expectation_to_dict(e: Expectation) -> Dict[str, Any]: +def _expectation_to_dict(e: Expectation) -> dict[str, Any]: return { "id": e.id, "target": e.target, @@ -236,7 +245,7 @@ def _expectation_to_dict(e: Expectation) -> Dict[str, Any]: } -def _watchpoint_to_dict(w: Watchpoint) -> Dict[str, Any]: +def _watchpoint_to_dict(w: Watchpoint) -> dict[str, Any]: return { "id": w.id, "target": w.target, @@ -247,7 +256,7 @@ def _watchpoint_to_dict(w: Watchpoint) -> Dict[str, Any]: } -def _question_to_dict(q: Question) -> Dict[str, Any]: +def _question_to_dict(q: Question) -> dict[str, Any]: return { "id": q.id, "content": q.content, diff --git a/gently/harness/memory/startup_wizard.py b/gently/harness/memory/startup_wizard.py index 4b22b3d7..f82c706e 100644 --- a/gently/harness/memory/startup_wizard.py +++ b/gently/harness/memory/startup_wizard.py @@ -13,14 +13,17 @@ import logging import uuid -from typing import Any, Callable, Coroutine, Optional +from collections.abc import Callable, Coroutine +from typing import Any from gently.settings import settings -from .gap_assessment import assess_gaps, ContextGapReport + +from .gap_assessment 
import ContextGapReport, assess_gaps from .model import Confidence, Learning from .onboarding import ( process_onboarding_response, ) + try: from .file_store import FileContextStore as ContextStore except ImportError: @@ -50,12 +53,12 @@ def __init__( self, context_store: ContextStore, session_id: str, - claude_client: Optional[Any] = None, + claude_client: Any | None = None, ): self.context_store = context_store self.session_id = session_id self.claude_client = claude_client - self._gap_report: Optional[ContextGapReport] = None + self._gap_report: ContextGapReport | None = None # ------------------------------------------------------------------ # Public properties @@ -181,14 +184,15 @@ async def _step_campaign_select(self, send_fn, wait_for_input, wait_for_choice): """Show active campaigns in a picker.""" campaigns = self.context_store.get_active_campaigns() options = [ - {"id": c.id, "label": c.display_name, "description": c.target or ""} - for c in campaigns + {"id": c.id, "label": c.display_name, "description": c.target or ""} for c in campaigns ] - options.append({ - "id": "__new__", - "label": "Start something new", - "description": "Describe a new research direction", - }) + options.append( + { + "id": "__new__", + "label": "Start something new", + "description": "Describe a new research direction", + } + ) await self._say(send_fn, "Welcome back.") @@ -233,11 +237,13 @@ async def _step_planned_session(self, planned_sessions, send_fn, wait_for_choice } for ps in planned_sessions ] - options.append({ - "id": "__other__", - "label": "Something else", - "description": "Not one of these", - }) + options.append( + { + "id": "__other__", + "label": "Something else", + "description": "Not one of these", + } + ) count = len(planned_sessions) await self._say( @@ -272,7 +278,7 @@ async def _step_session_intent(self, send_fn, wait_for_input, wait_for_choice): if campaigns: label = campaigns[0].display_name - await self._say(send_fn, f"Continuing \"{label}\" — what's 
the plan?") + await self._say(send_fn, f'Continuing "{label}" — what\'s the plan?') else: await self._say(send_fn, "What's the plan for this session?") @@ -338,12 +344,14 @@ async def _step_session_intent(self, send_fn, wait_for_input, wait_for_choice): def _store_learning(self, content: str, basis: str = "onboarding:identity"): """Write a learning directly — no LLM round-trip.""" - self.context_store.add_learning(Learning( - id=str(uuid.uuid4())[:8], - content=content, - confidence=Confidence.HIGH, - basis=basis, - )) + self.context_store.add_learning( + Learning( + id=str(uuid.uuid4())[:8], + content=content, + confidence=Confidence.HIGH, + basis=basis, + ) + ) async def _extract(self, send_fn, response: str, topic: str): """Run LLM extraction silently — no acknowledgment message. @@ -388,9 +396,9 @@ async def _finish(self, send_fn): campaign_name = campaigns[0].display_name if campaigns else None plan = intent.planned_intent if intent else None organism = None - for l in learnings: - if l.content.startswith("Lab organism:"): - organism = l.content.split(":", 1)[1].strip() + for learning in learnings: + if learning.content.startswith("Lab organism:"): + organism = learning.content.split(":", 1)[1].strip() break # Try LLM-generated summary @@ -425,22 +433,28 @@ async def _finish(self, send_fn): if not summary: if organism: - summary = f"Got it — {organism}. I'm ready to help with your imaging session. What would you like to do?" + summary = ( + f"Got it — {organism}. I'm ready to help with your imaging session." + " What would you like to do?" + ) else: summary = "All set. What can I help with?" 
await send_fn({"type": "text", "text": summary}) - await send_fn({ - "type": "stream_end", - "tokens": _empty_tokens(), - "wizard_complete": True, - }) + await send_fn( + { + "type": "stream_end", + "tokens": _empty_tokens(), + "wizard_complete": True, + } + ) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _is_skip(text: str) -> bool: return text.strip().lower() in SKIP_PHRASES diff --git a/gently/harness/memory/store.py b/gently/harness/memory/store.py index f3ae128f..5792981a 100644 --- a/gently/harness/memory/store.py +++ b/gently/harness/memory/store.py @@ -12,19 +12,18 @@ import logging import sqlite3 +import uuid from contextlib import contextmanager from datetime import datetime from pathlib import Path -import uuid - -from .model import ( - Context, -) from ._intentions import IntentionsMixin from ._ml_pipelines import MlPipelinesMixin from ._plans import PlansMixin from ._understanding import UnderstandingMixin +from .model import ( + Context, +) logger = logging.getLogger(__name__) @@ -248,8 +247,10 @@ CREATE INDEX IF NOT EXISTS idx_session_campaigns_campaign ON session_campaigns(campaign_id); CREATE INDEX IF NOT EXISTS idx_planned_sessions_date ON planned_sessions(scheduled_date); CREATE INDEX IF NOT EXISTS idx_planned_sessions_status ON planned_sessions(status); -CREATE INDEX IF NOT EXISTS idx_planned_session_campaigns_ps ON planned_session_campaigns(planned_session_id); -CREATE INDEX IF NOT EXISTS idx_planned_session_campaigns_c ON planned_session_campaigns(campaign_id); +CREATE INDEX IF NOT EXISTS idx_planned_session_campaigns_ps + ON planned_session_campaigns(planned_session_id); +CREATE INDEX IF NOT EXISTS idx_planned_session_campaigns_c + ON planned_session_campaigns(campaign_id); CREATE INDEX IF NOT EXISTS idx_plan_items_campaign ON plan_items(campaign_id); CREATE INDEX IF NOT EXISTS idx_plan_items_status ON 
plan_items(status); CREATE INDEX IF NOT EXISTS idx_plan_items_type ON plan_items(type); @@ -257,7 +258,8 @@ CREATE INDEX IF NOT EXISTS idx_plan_item_deps_item ON plan_item_dependencies(item_id); CREATE INDEX IF NOT EXISTS idx_plan_item_deps_dep ON plan_item_dependencies(depends_on_id); CREATE INDEX IF NOT EXISTS idx_plan_snapshots_campaign ON plan_snapshots(campaign_id); -CREATE INDEX IF NOT EXISTS idx_plan_snapshots_version ON plan_snapshots(campaign_id, version_number); +CREATE INDEX IF NOT EXISTS idx_plan_snapshots_version + ON plan_snapshots(campaign_id, version_number); """ @@ -265,6 +267,7 @@ # ContextStore # --------------------------------------------------------------------------- + class ContextStore(IntentionsMixin, PlansMixin, UnderstandingMixin, MlPipelinesMixin): """ SQLite-backed storage for the agent's context. diff --git a/gently/harness/microscope.py b/gently/harness/microscope.py index af25716c..6c684469 100644 --- a/gently/harness/microscope.py +++ b/gently/harness/microscope.py @@ -19,7 +19,6 @@ import logging from pathlib import Path -from typing import Dict, Optional, Set logger = logging.getLogger(__name__) @@ -51,7 +50,7 @@ class Microscope: DESCRIPTION: str = "" @property - def plans(self) -> Set[str]: + def plans(self) -> set[str]: """Discover available plans by inspecting _plan_* methods.""" return { name[6:] # strip "_plan_" prefix @@ -154,11 +153,11 @@ def __init__(self, http_url: str): self.http_url = http_url self._session = None self._connected = False - self._available_plans: Set[str] = set() + self._available_plans: set[str] = set() self._plan_schemas: list = [] # Anthropic tool-format schemas @property - def plans(self) -> Set[str]: + def plans(self) -> set[str]: return self._available_plans @property @@ -185,7 +184,8 @@ async def connect(self) -> bool: return False info = await resp.json() - # Plans come as Anthropic tool schemas (list of dicts with name, description, input_schema) + # Plans come as Anthropic tool schemas + # 
(list of dicts with name, description, input_schema) plans_data = info.get("plans", []) if plans_data and isinstance(plans_data[0], dict): # New format: list of tool schemas @@ -194,7 +194,14 @@ async def connect(self) -> bool: else: # Legacy format: list of plan name strings self._available_plans = set(plans_data) - self._plan_schemas = [{"name": p, "description": p, "input_schema": {"type": "object", "properties": {}}} for p in plans_data] + self._plan_schemas = [ + { + "name": p, + "description": p, + "input_schema": {"type": "object", "properties": {}}, + } + for p in plans_data + ] self.DESCRIPTION = info.get("description", "") self._connected = True @@ -256,6 +263,7 @@ def _resolve_file_refs(self, data: dict) -> None: for key, val in list(data.items()): if self._is_file_ref(val): import tifffile + path = Path(val["path"]) data[key] = tifffile.imread(str(path)) data[f"{key}_path"] = str(path) @@ -293,24 +301,31 @@ def register_microscope_tools(microscope: Microscope, registry=None) -> int: """ if registry is None: from gently.harness.tools.registry import get_tool_registry + registry = get_tool_registry() - from gently.harness.tools.registry import ToolDefinition, ToolParameter, ToolCategory + from gently.harness.tools.registry import ( + ToolCategory, + ToolDefinition, + ToolParameter, + ) - schemas = getattr(microscope, 'plan_schemas', []) + schemas = getattr(microscope, "plan_schemas", []) if not schemas: return 0 def _make_handler(pname): """Create an async handler that delegates to microscope.execute().""" - async def handler(context: dict = None, **params): - ms = context.get('client') if context else microscope + + async def handler(context: dict | None = None, **params): + ms = context.get("client") if context else microscope if ms is None: return "Error: microscope not connected" result = await ms.execute(pname, **params) - if not result.get('success', False): + if not result.get("success", False): return f"Error: {result.get('error', 'unknown')}" 
return result + handler.__name__ = f"microscope_{pname}" return handler diff --git a/gently/harness/orchestration/plan_synthesis.py b/gently/harness/orchestration/plan_synthesis.py index a3ae54a1..7c31f5cb 100644 --- a/gently/harness/orchestration/plan_synthesis.py +++ b/gently/harness/orchestration/plan_synthesis.py @@ -2,19 +2,19 @@ Bluesky plan synthesis from natural language goals """ -import logging -from typing import Dict, List, Optional -from jinja2 import Template import ast +import logging import re +from jinja2 import Template + logger = logging.getLogger(__name__) class PlanValidator: """Validates generated Bluesky plans for safety and correctness""" - def __init__(self, devices: Optional[Dict] = None): + def __init__(self, devices: dict | None = None): """ Parameters ---------- @@ -22,7 +22,7 @@ def __init__(self, devices: Optional[Dict] = None): Available Ophyd devices with their limits """ self.devices = devices or {} - self.errors: List[str] = [] + self.errors: list[str] = [] def is_valid(self, plan_code: str) -> bool: """ @@ -55,11 +55,11 @@ def is_valid(self, plan_code: str) -> bool: # Check for dangerous operations dangerous_patterns = [ - r'import\s+os', - r'import\s+subprocess', - r'eval\(', - r'exec\(', - r'__import__', + r"import\s+os", + r"import\s+subprocess", + r"eval\(", + r"exec\(", + r"__import__", ] for pattern in dangerous_patterns: @@ -68,19 +68,19 @@ def is_valid(self, plan_code: str) -> bool: return False # Check for proper yield from usage - if 'def ' in plan_code and 'yield from' not in plan_code: + if "def " in plan_code and "yield from" not in plan_code: self.errors.append("Plan function must use 'yield from' for Bluesky operations") return False # Warnings (not errors, but noted) - if 'metadata=' not in plan_code: + if "metadata=" not in plan_code: self.errors.append("Warning: Plan should include metadata for provenance") # If we have errors (not just warnings), fail actual_errors = [e for e in self.errors if not 
e.startswith("Warning:")] return len(actual_errors) == 0 - def check_parameters(self, params: Dict) -> bool: + def check_parameters(self, params: dict) -> bool: """ Validate acquisition parameters @@ -97,19 +97,22 @@ def check_parameters(self, params: Dict) -> bool: self.errors = [] # num_slices - if 'num_slices' in params: - if not (10 <= params['num_slices'] <= 200): + if "num_slices" in params: + if not (10 <= params["num_slices"] <= 200): self.errors.append(f"num_slices {params['num_slices']} outside range [10, 200]") # exposure_ms - if 'exposure_ms' in params: - if not (5 <= params['exposure_ms'] <= 100): + if "exposure_ms" in params: + if not (5 <= params["exposure_ms"] <= 100): self.errors.append(f"exposure_ms {params['exposure_ms']} outside range [5, 100]") # interval_seconds - if 'interval_seconds' in params: - if params['interval_seconds'] < 10: - self.errors.append(f"interval_seconds {params['interval_seconds']} too short (min 10s for hardware settle)") + if "interval_seconds" in params: + if params["interval_seconds"] < 10: + self.errors.append( + f"interval_seconds {params['interval_seconds']} too short" + " (min 10s for hardware settle)" + ) return len(self.errors) == 0 @@ -131,16 +134,16 @@ class PlanLibrary: """Collection of plan templates""" def __init__(self): - self.templates: Dict[str, PlanTemplate] = {} + self.templates: dict[str, PlanTemplate] = {} self._load_default_templates() def _load_default_templates(self): """Load built-in plan templates""" # Multi-embryo adaptive timelapse - self.templates['adaptive_timelapse'] = PlanTemplate( - name='adaptive_timelapse', - description='Multi-embryo timelapse with adaptive parameters', + self.templates["adaptive_timelapse"] = PlanTemplate( + name="adaptive_timelapse", + description="Multi-embryo timelapse with adaptive parameters", template_str=''' def adaptive_timelapse_plan( volume_scanner, @@ -233,12 +236,13 @@ def adaptive_timelapse_plan( # Wait for interval if timepoint < num_timepoints - 1: yield 
from bps.sleep(next_interval) -''') +''', + ) # Single embryo high-resolution scan - self.templates['single_highres'] = PlanTemplate( - name='single_highres', - description='Single embryo high-resolution volume', + self.templates["single_highres"] = PlanTemplate( + name="single_highres", + description="Single embryo high-resolution volume", template_str=''' def single_highres_scan_plan( volume_scanner, @@ -284,15 +288,18 @@ def single_highres_scan_plan( # Notify agent agent.on_volume_acquired(embryo_id, 0, volume_scanner) -''') +''', + ) def get_template(self, plan_type: str) -> PlanTemplate: """Get template by name""" if plan_type not in self.templates: - raise ValueError(f"Unknown plan type: {plan_type}. Available: {list(self.templates.keys())}") + raise ValueError( + f"Unknown plan type: {plan_type}. Available: {list(self.templates.keys())}" + ) return self.templates[plan_type] - def list_templates(self) -> List[str]: + def list_templates(self) -> list[str]: """List available template names""" return list(self.templates.keys()) @@ -300,13 +307,21 @@ def list_templates(self) -> List[str]: class PlanSynthesizer: """Converts scientific goals into executable Bluesky plans""" - def __init__(self, plan_library: Optional[PlanLibrary] = None, - validator: Optional[PlanValidator] = None): + def __init__( + self, + plan_library: PlanLibrary | None = None, + validator: PlanValidator | None = None, + ): self.library = plan_library or PlanLibrary() self.validator = validator or PlanValidator() - def synthesize(self, goal: str, embryo_ids: List[str], - params: Dict, plan_type: str = 'adaptive_timelapse') -> str: + def synthesize( + self, + goal: str, + embryo_ids: list[str], + params: dict, + plan_type: str = "adaptive_timelapse", + ) -> str: """ Generate Bluesky plan from goal @@ -337,10 +352,10 @@ def synthesize(self, goal: str, embryo_ids: List[str], plan_code = template.render( goal=goal, embryo_ids=embryo_ids, - num_slices=params.get('num_slices', 50), - 
exposure_ms=params.get('exposure_ms', 10.0), - interval_seconds=params.get('interval_seconds', 120), - num_timepoints=params.get('num_timepoints', 500), + num_slices=params.get("num_slices", 50), + exposure_ms=params.get("exposure_ms", 10.0), + interval_seconds=params.get("interval_seconds", 120), + num_timepoints=params.get("num_timepoints", 500), ) # Validate generated code @@ -366,10 +381,15 @@ def classify_goal(self, goal: str) -> str: goal_lower = goal.lower() # Pattern matching for plan types - if any(word in goal_lower for word in ['timelapse', 'time-lapse', 'monitor', 'track', 'all embryos']): - return 'adaptive_timelapse' - elif any(word in goal_lower for word in ['high-res', 'high resolution', 'detailed', 'single']): - return 'single_highres' + if any( + word in goal_lower + for word in ["timelapse", "time-lapse", "monitor", "track", "all embryos"] + ): + return "adaptive_timelapse" + elif any( + word in goal_lower for word in ["high-res", "high resolution", "detailed", "single"] + ): + return "single_highres" else: # Default to timelapse - return 'adaptive_timelapse' + return "adaptive_timelapse" diff --git a/gently/harness/plan_mode/prompt.py b/gently/harness/plan_mode/prompt.py index 19ea52b7..4228249c 100644 --- a/gently/harness/plan_mode/prompt.py +++ b/gently/harness/plan_mode/prompt.py @@ -5,11 +5,8 @@ rather than a live microscope control agent. """ -from typing import Optional - -from gently.organisms import get_organism from gently.hardware import get_hardware - +from gently.organisms import get_organism PLAN_MODE_IDENTITY = """\ You are a scientific research planner — the same microscopy agent, but right now @@ -147,9 +144,9 @@ def build_plan_prompt( - context_summary: Optional[str] = None, - active_plan_summary: Optional[str] = None, - memory_awareness: Optional[str] = None, + context_summary: str | None = None, + active_plan_summary: str | None = None, + memory_awareness: str | None = None, ) -> str: """ Build the system prompt for plan mode. 
diff --git a/gently/harness/plan_mode/tools/__init__.py b/gently/harness/plan_mode/tools/__init__.py index e39e6524..c59de53c 100644 --- a/gently/harness/plan_mode/tools/__init__.py +++ b/gently/harness/plan_mode/tools/__init__.py @@ -6,8 +6,6 @@ """ # Import tool modules so @tool decorators register them -from . import planning -from . import lab_context -from . import research -from . import validation -from . import templates +from . import lab_context, planning, research, templates, validation + +__all__ = ["lab_context", "planning", "research", "templates", "validation"] diff --git a/gently/harness/plan_mode/tools/lab_context.py b/gently/harness/plan_mode/tools/lab_context.py index 876092ef..574a725d 100644 --- a/gently/harness/plan_mode/tools/lab_context.py +++ b/gently/harness/plan_mode/tools/lab_context.py @@ -5,9 +5,7 @@ learnings, and hardware specs to inform experimental design. """ -from typing import Dict, Optional - -from ...tools.registry import tool, ToolCategory, ToolExample +from ...tools.registry import ToolCategory, ToolExample, tool @tool( @@ -27,7 +25,7 @@ ) async def query_lab_history( query: str, - context: Dict = None, + context: dict | None = None, ) -> str: """Search lab history for relevant context.""" agent = context.get("agent") if context else None @@ -59,19 +57,19 @@ async def query_lab_history( # Search learnings learnings = store.get_learnings(limit=100) matching_learnings = [ - l for l in learnings - if any(term in l.content.lower() for term in query_lower.split()) + learning + for learning in learnings + if any(term in learning.content.lower() for term in query_lower.split()) ] if matching_learnings: results.append("\n## Relevant Learnings") - for l in matching_learnings[:10]: - results.append(f"- {l.content} (confidence: {l.confidence.value})") + for learning in matching_learnings[:10]: + results.append(f"- {learning.content} (confidence: {learning.confidence.value})") # Search observations observations = 
store.get_recent_observations(limit=100) matching_obs = [ - o for o in observations - if any(term in o.content.lower() for term in query_lower.split()) + o for o in observations if any(term in o.content.lower() for term in query_lower.split()) ] if matching_obs: results.append("\n## Relevant Observations") @@ -116,7 +114,7 @@ async def query_lab_history( ) async def check_hardware_capability( question: str, - context: Dict = None, + context: dict | None = None, ) -> str: """Check hardware capabilities against a question.""" from gently.hardware import get_hardware diff --git a/gently/harness/plan_mode/tools/planning.py b/gently/harness/plan_mode/tools/planning.py index 2c33d43c..34785e5f 100644 --- a/gently/harness/plan_mode/tools/planning.py +++ b/gently/harness/plan_mode/tools/planning.py @@ -7,16 +7,14 @@ """ import dataclasses -import json -from typing import Dict, List, Optional - -from ...tools.registry import tool, ToolCategory, ToolExample +from ...tools.registry import ToolCategory, ToolExample, tool # --------------------------------------------------------------------------- # Campaign / Phase Management # --------------------------------------------------------------------------- + @tool( name="create_campaign", description=( @@ -39,10 +37,10 @@ ) async def create_campaign( description: str, - shorthand: str = None, - target: str = None, - parent_id: str = None, - context: Dict = None, + shorthand: str | None = None, + target: str | None = None, + parent_id: str | None = None, + context: dict | None = None, ) -> str: """Create a campaign or sub-campaign (phase).""" agent = context.get("agent") if context else None @@ -53,8 +51,9 @@ async def create_campaign( if shorthand: import re from datetime import datetime + current_year = str(datetime.now().year) - shorthand = re.sub(r'-20\d{2}$', f'-{current_year}', shorthand) + shorthand = re.sub(r"-20\d{2}$", f"-{current_year}", shorthand) store = agent.context_store cid = store.create_campaign( @@ -72,6 +71,7 
@@ async def create_campaign( # Plan Item Management # --------------------------------------------------------------------------- + @tool( name="create_plan_item", description=( @@ -112,15 +112,15 @@ async def create_plan_item( campaign_id: str, type: str, title: str, - description: str = None, - spec: Dict = None, - inherit_from: str = None, - depends_on: List[str] = None, - phase_number: int = None, + description: str | None = None, + spec: dict | None = None, + inherit_from: str | None = None, + depends_on: list[str] | None = None, + phase_number: int | None = None, phase_order: int = -1, - references: List[Dict] = None, - estimated_days: int = None, - context: Dict = None, + references: list[dict] | None = None, + estimated_days: int | None = None, + context: dict | None = None, ) -> str: """Create a plan item within a campaign/phase. @@ -198,15 +198,15 @@ async def create_plan_item( ) async def update_plan_item( item_id: str, - status: str = None, - title: str = None, - description: str = None, - outcome: str = None, - spec: Dict = None, - references: List[Dict] = None, - estimated_days: int = None, - campaign_id: str = None, - context: Dict = None, + status: str | None = None, + title: str | None = None, + description: str | None = None, + outcome: str | None = None, + spec: dict | None = None, + references: list[dict] | None = None, + estimated_days: int | None = None, + campaign_id: str | None = None, + context: dict | None = None, ) -> str: """Update a plan item. item_id can be a UUID, task number (e.g. '3'), or phase.task reference (e.g. '1.3'). 
campaign_id scopes resolution @@ -240,9 +240,9 @@ async def update_plan_item( if status: changes.append(f"status -> {status}") if outcome: - changes.append(f"outcome recorded") + changes.append("outcome recorded") if spec: - changes.append(f"spec updated") + changes.append("spec updated") if title: changes.append(f"title -> {title}") if references: @@ -263,8 +263,8 @@ async def update_plan_item( async def link_plan_items( item_id: str, depends_on_id: str, - campaign_id: str = None, - context: Dict = None, + campaign_id: str | None = None, + context: dict | None = None, ) -> str: """Add a dependency between plan items. campaign_id scopes resolution when using shorthand refs (e.g. '1.3') with multiple plans.""" @@ -308,8 +308,8 @@ async def link_plan_items( ) async def get_plan_item_tool( ref: str, - campaign_id: str = None, - context: Dict = None, + campaign_id: str | None = None, + context: dict | None = None, ) -> str: """Look up a plan item by natural reference.""" agent = context.get("agent") if context else None @@ -328,6 +328,7 @@ async def get_plan_item_tool( # Plan Review # --------------------------------------------------------------------------- + @tool( name="propose_plan", description=( @@ -346,7 +347,7 @@ async def get_plan_item_tool( ) async def propose_plan( campaign_id: str, - context: Dict = None, + context: dict | None = None, ) -> str: """Render the full plan for review.""" agent = context.get("agent") if context else None @@ -401,7 +402,7 @@ async def propose_plan( # Summary status = store.get_plan_status(campaign_id) - lines.append(f"── Summary ──") + lines.append("── Summary ──") lines.append(f"Total items: {status['total']}") lines.append(f"Completed: {status['completed']}") if status["pending_decisions"]: @@ -485,8 +486,8 @@ def _format_plan_item(item, store, task_num: str = "") -> str: if item.references: ref_strs = [] for r in item.references: - tag = f"[{r.get('source', '').upper()}]" if r.get('source') else "" - cite = r.get('citation', 
'') + tag = f"[{r.get('source', '').upper()}]" if r.get("source") else "" + cite = r.get("citation", "") ref_strs.append(f"{tag} {cite}") details.append(f" Refs: {'; '.join(ref_strs)}") @@ -505,7 +506,7 @@ def _format_plan_item(item, store, task_num: str = "") -> str: ) async def get_plan_status( campaign_id: str, - context: Dict = None, + context: dict | None = None, ) -> str: """Get plan progress summary.""" agent = context.get("agent") if context else None @@ -551,6 +552,7 @@ async def get_plan_status( # Batch Operations # --------------------------------------------------------------------------- + @tool( name="batch_update_status", description=( @@ -574,10 +576,10 @@ async def get_plan_status( async def batch_update_status( campaign_id: str, new_status: str, - outcome: str = None, - phase_number: int = None, - item_type: str = None, - context: Dict = None, + outcome: str | None = None, + phase_number: int | None = None, + item_type: str | None = None, + context: dict | None = None, ) -> str: """Batch-update status of plan items.""" agent = context.get("agent") if context else None @@ -613,7 +615,8 @@ async def batch_update_status( for dep_id in item.depends_on: dep = store.get_plan_item(dep_id) if dep and dep.status not in ( - PlanItemStatus.COMPLETED, PlanItemStatus.SKIPPED, + PlanItemStatus.COMPLETED, + PlanItemStatus.SKIPPED, ): all_resolved = False break @@ -656,8 +659,8 @@ async def batch_update_spec( campaign_id: str, field_name: str, field_value: object, - phase_number: int = None, - context: Dict = None, + phase_number: int | None = None, + context: dict | None = None, ) -> str: """Batch-update a spec field on imaging items.""" agent = context.get("agent") if context else None @@ -668,6 +671,7 @@ async def batch_update_spec( # Validate field name against ImagingSpec from gently.harness.memory.model import ImagingSpec + valid_fields = {f.name for f in dataclasses.fields(ImagingSpec)} if field_name not in valid_fields: return ( @@ -709,6 +713,7 @@ async 
def batch_update_spec( # Plan Reorganization # --------------------------------------------------------------------------- + @tool( name="move_plan_item", description=( @@ -732,10 +737,10 @@ async def batch_update_spec( async def move_plan_item( campaign_id: str, item_ref: str, - to_phase_number: int = None, - to_campaign_id: str = None, - position: int = None, - context: Dict = None, + to_phase_number: int | None = None, + to_campaign_id: str | None = None, + position: int | None = None, + context: dict | None = None, ) -> str: """Move a plan item to a different phase.""" agent = context.get("agent") if context else None @@ -794,8 +799,8 @@ async def move_plan_item( ) async def delete_plan_item_tool( item_ref: str, - campaign_id: str = None, - context: Dict = None, + campaign_id: str | None = None, + context: dict | None = None, ) -> str: """Delete a plan item.""" agent = context.get("agent") if context else None @@ -850,9 +855,9 @@ async def delete_plan_item_tool( ) async def reorder_plan_items( campaign_id: str, - item_order: List[str], - phase_number: int = None, - context: Dict = None, + item_order: list[str], + phase_number: int | None = None, + context: dict | None = None, ) -> str: """Reorder plan items within a phase.""" agent = context.get("agent") if context else None @@ -914,10 +919,10 @@ async def reorder_plan_items( async def update_phase( campaign_id: str, phase_number: int, - description: str = None, - shorthand: str = None, - target: str = None, - context: Dict = None, + description: str | None = None, + shorthand: str | None = None, + target: str | None = None, + context: dict | None = None, ) -> str: """Update a phase's metadata.""" agent = context.get("agent") if context else None @@ -958,7 +963,7 @@ async def update_phase( async def delete_phase( campaign_id: str, phase_number: int, - context: Dict = None, + context: dict | None = None, ) -> str: """Delete a phase and its contents.""" agent = context.get("agent") if context else None @@ -992,6 
+997,7 @@ async def delete_phase( # Plan Export # --------------------------------------------------------------------------- + @tool( name="export_plan", description=( @@ -1011,7 +1017,7 @@ async def delete_phase( async def export_plan( campaign_id: str, include_validation: bool = False, - context: Dict = None, + context: dict | None = None, ) -> str: """Export a plan as a shareable markdown document.""" agent = context.get("agent") if context else None @@ -1088,7 +1094,7 @@ async def export_plan( seen_ids = set() for item in all_items: for ref in item.references: - ref_key = ref.get('id') or ref.get('citation', '') + ref_key = ref.get("id") or ref.get("citation", "") if ref_key and ref_key not in seen_ids: seen_ids.add(ref_key) all_refs.append(ref) @@ -1096,10 +1102,10 @@ async def export_plan( if all_refs: lines.append("---\n## References\n") for i, r in enumerate(all_refs, 1): - tag = f"[{r.get('source', '').upper()}]" if r.get('source') else "" - cite = r.get('citation', '') - rid = r.get('id', '') - note = r.get('note', '') + tag = f"[{r.get('source', '').upper()}]" if r.get("source") else "" + cite = r.get("citation", "") + rid = r.get("id", "") + note = r.get("note", "") line = f"{i}. 
{tag} {cite}" if rid: line += f" ({rid})" @@ -1128,10 +1134,11 @@ async def export_plan( def _export_date() -> str: """Return current date in human-readable format.""" from datetime import datetime + return datetime.now().strftime("%Y-%m-%d") -def _export_item(item, store, num: str) -> List[str]: +def _export_item(item, store, num: str) -> list[str]: """Format a plan item for the export document.""" from gently.harness.memory.model import PlanItemStatus @@ -1211,10 +1218,10 @@ def _export_item(item, store, num: str) -> List[str]: if item.references: lines.append("**References:**") for r in item.references: - tag = f"[{r.get('source', '').upper()}]" if r.get('source') else "" - cite = r.get('citation', '') - rid = r.get('id', '') - note = r.get('note', '') + tag = f"[{r.get('source', '').upper()}]" if r.get("source") else "" + cite = r.get("citation", "") + rid = r.get("id", "") + note = r.get("note", "") line = f"- {tag} {cite}" if rid: line += f" ({rid})" @@ -1229,9 +1236,9 @@ def _export_item(item, store, num: str) -> List[str]: async def validate_plan_for_export(campaign_id: str, store) -> str: """Run validation and return a markdown-formatted report for export.""" from .validation import ( - HARDWARE_LIMITS, CONTROL_KEYWORDS, - _check_dependency_cycles, _stage_order, _normalise_stage, - _get_temp_factor, STAGE_TIMING_20C, + CONTROL_KEYWORDS, + HARDWARE_LIMITS, + _check_dependency_cycles, ) items = store.get_plan_items(campaign_id=campaign_id, include_children=True) @@ -1240,11 +1247,15 @@ async def validate_plan_for_export(campaign_id: str, store) -> str: try: from gently.organisms import get_organism + org = get_organism() - presets_mod = __import__(f"gently.organisms.{org.ORGANISM_NAME}.detector_presets", fromlist=["get_detector_presets"]) - valid_detectors = set(presets_mod.get_detector_presets().keys()) + presets_mod = __import__( + f"gently.organisms.{org.ORGANISM_NAME}.detector_presets", + fromlist=["get_detector_presets"], + ) + 
set(presets_mod.get_detector_presets().keys()) except ImportError: - valid_detectors = set() + pass issues = [] has_control = False @@ -1253,9 +1264,18 @@ async def validate_plan_for_export(campaign_id: str, store) -> str: label = f"{item.title}" text_blob = " ".join(filter(None, [item.title, item.description])).lower() if item.imaging_spec: - text_blob += " " + " ".join(filter(None, [ - item.imaging_spec.strain, item.imaging_spec.genotype, - ])).lower() + text_blob += ( + " " + + " ".join( + filter( + None, + [ + item.imaging_spec.strain, + item.imaging_spec.genotype, + ], + ) + ).lower() + ) if any(kw in text_blob for kw in CONTROL_KEYWORDS): has_control = True @@ -1286,6 +1306,7 @@ async def validate_plan_for_export(campaign_id: str, store) -> str: # Plan Versioning # --------------------------------------------------------------------------- + @tool( name="snapshot_plan", description=( @@ -1306,8 +1327,8 @@ async def validate_plan_for_export(campaign_id: str, store) -> str: ) async def snapshot_plan( campaign_id: str, - label: str = None, - context: Dict = None, + label: str | None = None, + context: dict | None = None, ) -> str: """Save a snapshot of the current plan.""" agent = context.get("agent") if context else None @@ -1346,7 +1367,7 @@ async def snapshot_plan( ) async def list_plan_versions( campaign_id: str, - context: Dict = None, + context: dict | None = None, ) -> str: """List saved plan versions.""" agent = context.get("agent") if context else None @@ -1367,10 +1388,7 @@ async def list_plan_versions( if s.get("summary"): first_line = s["summary"].split("\n")[0] summary_line = f" {first_line}" - lines.append( - f" v{s['version_number']}{label} ({s['version_id']}) " - f"{s['created_at']}" - ) + lines.append(f" v{s['version_number']}{label} ({s['version_id']}) {s['created_at']}") if summary_line: lines.append(f" {summary_line}") @@ -1397,9 +1415,9 @@ async def list_plan_versions( ) async def restore_plan_version( campaign_id: str, - version_id: str = 
None, - version_number: int = None, - context: Dict = None, + version_id: str | None = None, + version_number: int | None = None, + context: dict | None = None, ) -> str: """Restore a plan to a previous snapshot.""" agent = context.get("agent") if context else None diff --git a/gently/harness/plan_mode/tools/research.py b/gently/harness/plan_mode/tools/research.py index df8e4359..dabd92b5 100644 --- a/gently/harness/plan_mode/tools/research.py +++ b/gently/harness/plan_mode/tools/research.py @@ -14,10 +14,9 @@ import json import logging import ssl -from typing import Dict, List, Optional -from ...tools.registry import tool, ToolCategory, ToolExample from ....settings import settings +from ...tools.registry import ToolCategory, ToolExample, tool logger = logging.getLogger(__name__) @@ -37,12 +36,13 @@ def _ssl_context() -> ssl.SSLContext: """ try: import certifi + return ssl.create_default_context(cafile=certifi.where()) except ImportError: return ssl.create_default_context() -def _ncbi_params(**kwargs) -> Dict: +def _ncbi_params(**kwargs) -> dict: """Add standard NCBI tool/email to query params.""" kwargs["tool"] = _NCBI_TOOL kwargs["email"] = _NCBI_EMAIL @@ -53,6 +53,7 @@ def _ncbi_params(**kwargs) -> Dict: def _http_session(): """Create an aiohttp ClientSession with explicit SSL certs.""" import aiohttp + connector = aiohttp.TCPConnector(ssl=_ssl_context()) return aiohttp.ClientSession(connector=connector) @@ -61,11 +62,12 @@ def _http_session(): # PubMed literature search # --------------------------------------------------------------------------- -async def _pubmed_search(query: str, max_results: int) -> List[Dict]: + +async def _pubmed_search(query: str, max_results: int) -> list[dict]: """Search PubMed via E-utilities and return article summaries.""" import aiohttp - results: List[Dict] = [] + results: list[dict] = [] try: async with _http_session() as session: @@ -120,21 +122,23 @@ async def _pubmed_search(query: str, max_results: int) -> List[Dict]: else: 
author_str = first - results.append({ - "pmid": pmid, - "title": article.get("title", ""), - "authors": author_str, - "journal": article.get("fulljournalname", article.get("source", "")), - "year": article.get("pubdate", "")[:4], - "doi": next( - ( - aid.get("value", "") - for aid in article.get("articleids", []) - if aid.get("idtype") == "doi" + results.append( + { + "pmid": pmid, + "title": article.get("title", ""), + "authors": author_str, + "journal": article.get("fulljournalname", article.get("source", "")), + "year": article.get("pubdate", "")[:4], + "doi": next( + ( + aid.get("value", "") + for aid in article.get("articleids", []) + if aid.get("idtype") == "doi" + ), + "", ), - "", - ), - }) + } + ) except Exception as e: logger.warning(f"PubMed search failed: {e}", exc_info=True) @@ -146,11 +150,12 @@ async def _pubmed_search(query: str, max_results: int) -> List[Dict]: # NCBI Gene search # --------------------------------------------------------------------------- -async def _ncbi_gene_search(query: str, max_results: int = 5) -> List[Dict]: + +async def _ncbi_gene_search(query: str, max_results: int = 5) -> list[dict]: """Search NCBI Gene database for C. 
elegans genes.""" import aiohttp - results: List[Dict] = [] + results: list[dict] = [] try: async with _http_session() as session: @@ -199,14 +204,16 @@ async def _ncbi_gene_search(query: str, max_results: int = 5) -> List[Dict]: if org.get("taxid") != 6239: continue - results.append({ - "gene_id": gid, - "name": gene.get("name", ""), - "description": gene.get("description", ""), - "summary": gene.get("summary", ""), - "chromosome": gene.get("chromosome", ""), - "aliases": gene.get("otheraliases", ""), - }) + results.append( + { + "gene_id": gid, + "name": gene.get("name", ""), + "description": gene.get("description", ""), + "summary": gene.get("summary", ""), + "chromosome": gene.get("chromosome", ""), + "aliases": gene.get("otheraliases", ""), + } + ) except Exception as e: logger.warning(f"NCBI Gene search failed: {e}", exc_info=True) @@ -218,7 +225,8 @@ async def _ncbi_gene_search(query: str, max_results: int = 5) -> List[Dict]: # WormBase strain lookup (by gene ID) # --------------------------------------------------------------------------- -async def _wormbase_gene_strains(wbgene_id: str) -> List[Dict]: + +async def _wormbase_gene_strains(wbgene_id: str) -> list[dict]: """Get strains carrying a gene from WormBase REST API. 
Parameters @@ -228,7 +236,7 @@ async def _wormbase_gene_strains(wbgene_id: str) -> List[Dict]: """ import aiohttp - results: List[Dict] = [] + results: list[dict] = [] try: url = f"https://rest.wormbase.org/rest/field/gene/{wbgene_id}/strains" @@ -255,12 +263,14 @@ async def _wormbase_gene_strains(wbgene_id: str) -> List[Dict]: for entry in strain_list: if not isinstance(entry, dict): continue - results.append({ - "name": entry.get("label", ""), - "genotype": entry.get("genotype", ""), - "cgc_available": is_cgc or category == "available_from_cgc", - "category": category, - }) + results.append( + { + "name": entry.get("label", ""), + "genotype": entry.get("genotype", ""), + "cgc_available": is_cgc or category == "available_from_cgc", + "category": category, + } + ) except Exception as e: logger.warning(f"WormBase strain lookup failed: {e}", exc_info=True) @@ -268,7 +278,7 @@ async def _wormbase_gene_strains(wbgene_id: str) -> List[Dict]: return results -async def _wormbase_gene_id_lookup(gene_name: str) -> Optional[str]: +async def _wormbase_gene_id_lookup(gene_name: str) -> str | None: """Look up a WormBase gene ID from an NCBI gene name. Uses NCBI gene → dbxrefs to find WormBase ID, or falls back to @@ -314,7 +324,8 @@ async def _wormbase_gene_id_lookup(gene_name: str) -> Optional[str]: # Look for WBGene ID in the response text import re - match = re.search(r'(WBGene\d+)', text) + + match = re.search(r"(WBGene\d+)", text) if match: return match.group(1) @@ -328,7 +339,8 @@ async def _wormbase_gene_id_lookup(gene_name: str) -> Optional[str]: # CGC strain search (HTML scraping fallback) # --------------------------------------------------------------------------- -async def _cgc_search(query: str, field: str = "strain") -> List[Dict]: + +async def _cgc_search(query: str, field: str = "strain") -> list[dict]: """Search CGC (Caenorhabditis Genetics Center) strain database. 
Parameters @@ -338,10 +350,11 @@ async def _cgc_search(query: str, field: str = "strain") -> List[Dict]: field : str Field to search: "strain", "genotype", "description", "all". """ - import aiohttp import re - results: List[Dict] = [] + import aiohttp + + results: list[dict] = [] try: url = "https://cgc.umn.edu/strain/search" @@ -361,25 +374,27 @@ async def _cgc_search(query: str, field: str = "strain") -> List[Dict]: # CGC uses table rows with strain name, species, genotype, description # Pattern: look for strain links like /strain/OH904 strain_pattern = re.compile( - r'/strain/([A-Z]{1,3}\d+).*?' - r']*>(.*?).*?' # species - r']*>(.*?).*?' # genotype - r']*>(.*?)', # description + r"/strain/([A-Z]{1,3}\d+).*?" + r"]*>(.*?).*?" # species + r"]*>(.*?).*?" # genotype + r"]*>(.*?)", # description re.DOTALL, ) for match in strain_pattern.finditer(html): strain_name = match.group(1).strip() - genotype = re.sub(r'<[^>]+>', '', match.group(3)).strip() - description = re.sub(r'<[^>]+>', '', match.group(4)).strip() + genotype = re.sub(r"<[^>]+>", "", match.group(3)).strip() + description = re.sub(r"<[^>]+>", "", match.group(4)).strip() if strain_name: - results.append({ - "name": strain_name, - "genotype": genotype, - "description": description[:200] if description else "", - "source": "CGC", - }) + results.append( + { + "name": strain_name, + "genotype": genotype, + "description": description[:200] if description else "", + "source": "CGC", + } + ) # If regex didn't work, try a simpler approach — find strain names if not results: @@ -389,12 +404,14 @@ async def _cgc_search(query: str, field: str = "strain") -> List[Dict]: name = m.group(1) if name not in seen: seen.add(name) - results.append({ - "name": name, - "genotype": "", - "description": "", - "source": "CGC", - }) + results.append( + { + "name": name, + "genotype": "", + "description": "", + "source": "CGC", + } + ) # Check for "no results" message if "no results for this search" in html.lower(): @@ -410,6 +427,7 
@@ async def _cgc_search(query: str, field: str = "strain") -> List[Dict]: # Tools # --------------------------------------------------------------------------- + @tool( name="search_literature", description=( @@ -430,7 +448,7 @@ async def _cgc_search(query: str, field: str = "strain") -> List[Dict]: async def search_literature( query: str, max_results: int = 5, - context: Dict = None, + context: dict | None = None, ) -> str: """Search PubMed for relevant papers. @@ -447,11 +465,10 @@ async def search_literature( # If no results, try progressively simpler queries if not results: - import re as _re words = query.split() if len(words) > 3: # Strategy 1: keep first ~60% of terms (drop trailing specifics) - shorter = " ".join(words[:max(3, len(words) * 2 // 3)]) + shorter = " ".join(words[: max(3, len(words) * 2 // 3)]) results = await _pubmed_search(shorter, max_results) if results: used_query = shorter @@ -459,8 +476,23 @@ async def search_literature( if not results and len(words) > 4: # Strategy 2: keep only the core noun phrases (drop adjectives/filler) # Remove common filler words - stopwords = {"and", "or", "the", "of", "in", "for", "with", "a", "an", - "using", "based", "via", "during", "after", "before"} + stopwords = { + "and", + "or", + "the", + "of", + "in", + "for", + "with", + "a", + "an", + "using", + "based", + "via", + "during", + "after", + "before", + } core = [w for w in words if w.lower() not in stopwords] if len(core) > 3: core = core[:4] @@ -514,9 +546,9 @@ async def search_literature( lines.append("\n---") lines.append("**Citation references for plan items** (pass to `references` param):") for r in results: - author = r['authors'] or 'Unknown' - year = r['year'] or '' - journal = r['journal'] or '' + author = r["authors"] or "Unknown" + year = r["year"] or "" + journal = r["journal"] or "" cite = f"{author} ({year}) {r['title'][:80]}. 
{journal}" ref = { "source": "pubmed", @@ -558,7 +590,7 @@ async def search_literature( async def search_strains( query: str, organism: str = "celegans", - context: Dict = None, + context: dict | None = None, ) -> str: """Search for strains and genes via NCBI Gene + WormBase REST API. @@ -574,7 +606,7 @@ async def search_strains( import re as _re # Detect if query looks like a strain name (e.g. OH904, N2, CB1370) - is_strain_query = bool(_re.match(r'^[A-Z]{1,3}\d+$', query.strip())) + is_strain_query = bool(_re.match(r"^[A-Z]{1,3}\d+$", query.strip())) # Step 0: If it looks like a strain name, search CGC directly cgc_results = [] @@ -632,7 +664,8 @@ async def search_strains( lines.append(f" **Strains ({len(strains)} found):**") # Show up to 6, prioritising CGC-available sorted_strains = sorted( - strains, key=lambda s: (not s["cgc_available"], s["name"]), + strains, + key=lambda s: (not s["cgc_available"], s["name"]), ) for s in sorted_strains[:6]: cgc_tag = " [CGC]" if s["cgc_available"] else "" @@ -663,17 +696,21 @@ async def search_strains( citation_refs = [] if cgc_results: for s in cgc_results: - citation_refs.append({ - "source": "cgc", - "citation": f"{s['name']} available from CGC", - "id": f"CGC:{s['name']}", - }) + citation_refs.append( + { + "source": "cgc", + "citation": f"{s['name']} available from CGC", + "id": f"CGC:{s['name']}", + } + ) for gene in genes: - citation_refs.append({ - "source": "ncbi_gene", - "citation": f"{gene['name']} — {gene['description']}", - "id": f"GeneID:{gene['gene_id']}", - }) + citation_refs.append( + { + "source": "ncbi_gene", + "citation": f"{gene['name']} — {gene['description']}", + "id": f"GeneID:{gene['gene_id']}", + } + ) if citation_refs: lines.append("\n---") lines.append("**Citation references for plan items** (pass to `references` param):") @@ -687,7 +724,8 @@ async def search_strains( # Paper reading — full text retrieval # --------------------------------------------------------------------------- -async def 
_pmid_to_pmcid(pmid: str) -> Optional[str]: + +async def _pmid_to_pmcid(pmid: str) -> str | None: """Convert a PMID to a PMCID via the NCBI ID converter.""" import aiohttp @@ -716,11 +754,12 @@ async def _pmid_to_pmcid(pmid: str) -> Optional[str]: return None -async def _fetch_pmc_fulltext(pmcid: str) -> Optional[str]: +async def _fetch_pmc_fulltext(pmcid: str) -> str | None: """Fetch full text from PubMed Central as XML, parse into sections.""" - import aiohttp import xml.etree.ElementTree as ET + import aiohttp + try: # Strip "PMC" prefix for efetch — it wants just the number pmc_num = pmcid.replace("PMC", "") @@ -836,7 +875,7 @@ def _find_parent(root, target): return None -async def _unpaywall_lookup(doi: str) -> Optional[str]: +async def _unpaywall_lookup(doi: str) -> str | None: """Find an open access full text URL via Unpaywall.""" import aiohttp @@ -861,11 +900,12 @@ async def _unpaywall_lookup(doi: str) -> Optional[str]: return None -async def _fetch_url_text(url: str) -> Optional[str]: +async def _fetch_url_text(url: str) -> str | None: """Fetch a URL and extract text content (HTML → plain text).""" - import aiohttp import re as _re + import aiohttp + try: async with _http_session() as session: async with session.get( @@ -884,10 +924,10 @@ async def _fetch_url_text(url: str) -> Optional[str]: html = await resp.text() # Simple HTML → text: strip tags, collapse whitespace - text = _re.sub(r']*>.*?', '', html, flags=_re.DOTALL) - text = _re.sub(r']*>.*?', '', html, flags=_re.DOTALL) - text = _re.sub(r'<[^>]+>', ' ', text) - text = _re.sub(r'\s+', ' ', text).strip() + text = _re.sub(r"]*>.*?", "", html, flags=_re.DOTALL) + text = _re.sub(r"]*>.*?", "", html, flags=_re.DOTALL) + text = _re.sub(r"<[^>]+>", " ", text) + text = _re.sub(r"\s+", " ", text).strip() if len(text) > 15000: text = text[:15000] + "\n\n[... 
truncated ...]" @@ -899,7 +939,7 @@ async def _fetch_url_text(url: str) -> Optional[str]: return None -def _read_pdf_file(path: str) -> Optional[str]: +def _read_pdf_file(path: str) -> str | None: """Extract text from a local PDF file using pymupdf if available.""" import os @@ -908,6 +948,7 @@ def _read_pdf_file(path: str) -> Optional[str]: try: import fitz # pymupdf + doc = fitz.open(path) pages = [] for page in doc: @@ -920,15 +961,16 @@ def _read_pdf_file(path: str) -> Optional[str]: return text if text.strip() else None except ImportError: - logger.info("pymupdf not installed — cannot extract PDF text. " - "Install with: pip install pymupdf") + logger.info( + "pymupdf not installed — cannot extract PDF text. Install with: pip install pymupdf" + ) return None except Exception as e: logger.warning(f"PDF extraction failed for {path}: {e}") return None -async def _pubmed_abstract(pmid: str) -> Optional[Dict]: +async def _pubmed_abstract(pmid: str) -> dict | None: """Fetch article metadata + abstract from PubMed.""" import aiohttp @@ -948,6 +990,7 @@ async def _pubmed_abstract(pmid: str) -> Optional[Dict]: xml_text = await resp.text() import xml.etree.ElementTree as ET + root = ET.fromstring(xml_text) article = root.find(".//PubmedArticle") @@ -998,26 +1041,26 @@ async def _pubmed_abstract(pmid: str) -> Optional[Dict]: return None -async def _resolve_reference(reference: str) -> Dict: +async def _resolve_reference(reference: str) -> dict: """Parse a reference string and determine what kind of input it is. 
Returns a dict with keys: type, pmid, doi, url, path, query """ - import re import os + import re ref = reference.strip() result = {"type": "unknown", "raw": ref} # PMID - m = re.match(r'^(?:PMID[:\s]*)?(\d{6,9})$', ref, re.IGNORECASE) + m = re.match(r"^(?:PMID[:\s]*)?(\d{6,9})$", ref, re.IGNORECASE) if m: result["type"] = "pmid" result["pmid"] = m.group(1) return result # DOI - m = re.search(r'(10\.\d{4,}/[^\s]+)', ref) + m = re.search(r"(10\.\d{4,}/[^\s]+)", ref) if m: result["type"] = "doi" result["doi"] = m.group(1).rstrip(".,;)") @@ -1029,13 +1072,13 @@ async def _resolve_reference(reference: str) -> Dict: result["url"] = ref if ref.startswith("http") else "https://" + ref # Extract PMID from PubMed URLs - m = re.search(r'pubmed\.ncbi.*?/(\d{6,9})', ref) + m = re.search(r"pubmed\.ncbi.*?/(\d{6,9})", ref) if m: result["type"] = "pmid" result["pmid"] = m.group(1) # Extract PMCID from PMC URLs - m = re.search(r'/pmc/articles/(PMC\d+)', ref) + m = re.search(r"/pmc/articles/(PMC\d+)", ref) if m: result["type"] = "pmcid" result["pmcid"] = m.group(1) @@ -1054,7 +1097,7 @@ async def _resolve_reference(reference: str) -> Dict: return result -async def _search_pmid(query: str) -> Optional[str]: +async def _search_pmid(query: str) -> str | None: """Search PubMed for a citation string and return the best PMID. 
Tries multiple query strategies to handle imprecise citations like @@ -1066,7 +1109,7 @@ async def _search_pmid(query: str) -> Optional[str]: # Detect "Author et al YEAR topic" pattern m = _re.match( - r'^([A-Z][a-z]+)\s+(?:et\s+al\.?\s+)?(\d{4})?\s*(.*)?$', + r"^([A-Z][a-z]+)\s+(?:et\s+al\.?\s+)?(\d{4})?\s*(.*)?$", query.strip(), ) if m: @@ -1076,8 +1119,10 @@ async def _search_pmid(query: str) -> Optional[str]: # Fix common organism names in topic topic_fixed = _re.sub( - r'\bC\.?\s*elegans\b', '"Caenorhabditis elegans"', - topic, flags=_re.IGNORECASE, + r"\bC\.?\s*elegans\b", + '"Caenorhabditis elegans"', + topic, + flags=_re.IGNORECASE, ) # Strategy 1: author + organism MeSH + quoted topic (most specific) @@ -1087,15 +1132,11 @@ async def _search_pmid(query: str) -> Optional[str]: ) # Strategy 2: author + organism MeSH (no topic — topic may not be in title) - strategies.append( - f'{author}[author] AND "Caenorhabditis elegans"[Mesh]' - ) + strategies.append(f'{author}[author] AND "Caenorhabditis elegans"[Mesh]') # Strategy 3: author + year + topic (exact year, may be wrong) if year and topic_fixed: - strategies.append( - f'{author}[author] AND {year}[pdat] AND {topic_fixed}' - ) + strategies.append(f"{author}[author] AND {year}[pdat] AND {topic_fixed}") # Strategy 4: author + year only if year: @@ -1103,8 +1144,10 @@ async def _search_pmid(query: str) -> Optional[str]: # Strategy 5: original query with organism name fix fixed = _re.sub( - r'\bC\.?\s*elegans\b', '"Caenorhabditis elegans"', - query, flags=_re.IGNORECASE, + r"\bC\.?\s*elegans\b", + '"Caenorhabditis elegans"', + query, + flags=_re.IGNORECASE, ) if fixed != query: strategies.append(fixed) @@ -1122,6 +1165,7 @@ async def _search_pmid(query: str) -> Optional[str]: # Try each strategy import aiohttp + for attempt in unique: try: async with _http_session() as session: @@ -1145,7 +1189,7 @@ async def _search_pmid(query: str) -> Optional[str]: return None -async def _doi_to_pmid(doi: str) -> 
Optional[str]: +async def _doi_to_pmid(doi: str) -> str | None: """Resolve a DOI to a PMID via PubMed search.""" return await _search_pmid(f"{doi}[doi]") @@ -1177,7 +1221,7 @@ async def _doi_to_pmid(doi: str) -> Optional[str]: ) async def read_paper( reference: str, - context: Dict = None, + context: dict | None = None, ) -> str: """Read a scientific paper and return its content. @@ -1229,11 +1273,7 @@ async def read_paper( status_lines.append(f"Reading local PDF: {path}") text = _read_pdf_file(path) if text: - return ( - f"[Paper from local file: {path}]\n\n" - f"{text}\n\n---\n" - f"Source: local file" - ) + return f"[Paper from local file: {path}]\n\n{text}\n\n---\nSource: local file" else: return ( f"[Paper from local file: {path}]\n\n" @@ -1308,7 +1348,7 @@ async def read_paper( meta = await _pubmed_abstract(pmid) if meta: lines = [ - f"[Abstract only — full text not freely available]\n", + "[Abstract only — full text not freely available]\n", f"# {meta['title']}\n", f"**Authors:** {meta['authors']}", f"**Journal:** {meta['journal']} ({meta['year']})", @@ -1319,9 +1359,9 @@ async def read_paper( lines.append(f"\n## Abstract\n\n{meta['abstract']}") lines.append( - f"\n---\n" - f"*Full text not available through open access channels. " - f"If you have a PDF, provide the file path and I can read it.*" + "\n---\n" + "*Full text not available through open access channels. " + "If you have a PDF, provide the file path and I can read it.*" ) lines.append(f"\n*Resolution path: {' → '.join(status_lines)}*") diff --git a/gently/harness/plan_mode/tools/templates.py b/gently/harness/plan_mode/tools/templates.py index 006c4c95..03557477 100644 --- a/gently/harness/plan_mode/tools/templates.py +++ b/gently/harness/plan_mode/tools/templates.py @@ -5,9 +5,7 @@ dependencies) for re-use with different strains, temperatures, etc. 
""" -from typing import Dict, Optional - -from ...tools.registry import tool, ToolCategory, ToolExample +from ...tools.registry import ToolCategory, ToolExample, tool @tool( @@ -33,8 +31,8 @@ async def save_plan_template( campaign_id: str, name: str, - description: str = None, - context: Dict = None, + description: str | None = None, + context: dict | None = None, ) -> str: """Save a campaign as a reusable template.""" agent = context.get("agent") if context else None @@ -67,7 +65,7 @@ async def save_plan_template( category=ToolCategory.UTILITY, ) async def list_templates( - context: Dict = None, + context: dict | None = None, ) -> str: """List available plan templates.""" agent = context.get("agent") if context else None @@ -112,8 +110,8 @@ async def list_templates( ) async def apply_template( template_id: str, - overrides: Dict = None, - context: Dict = None, + overrides: dict | None = None, + context: dict | None = None, ) -> str: """Instantiate a template into a new campaign.""" agent = context.get("agent") if context else None diff --git a/gently/harness/plan_mode/tools/validation.py b/gently/harness/plan_mode/tools/validation.py index 8e09802a..031918ca 100644 --- a/gently/harness/plan_mode/tools/validation.py +++ b/gently/harness/plan_mode/tools/validation.py @@ -5,11 +5,9 @@ detector validity, missing controls, dependency cycles, and completeness. 
""" -import json import logging -from typing import Dict, List, Optional, Set, Tuple -from ...tools.registry import tool, ToolCategory, ToolExample +from ...tools.registry import ToolCategory, ToolExample, tool logger = logging.getLogger(__name__) @@ -22,7 +20,7 @@ "num_slices": (10, 200), "exposure_ms": (5.0, 100.0), "laser_power_pct": (0.0, 100.0), - "interval_s": (10, None), # minimum 10s, no hard max + "interval_s": (10, None), # minimum 10s, no hard max "piezo_amplitude_um": (None, 200.0), # max ±200 μm } @@ -40,9 +38,9 @@ # Temperature scaling factors (relative to 20°C) TEMP_SCALE = { - 15.0: 24.0 / 14.0, # ~1.71× slower + 15.0: 24.0 / 14.0, # ~1.71× slower 20.0: 1.0, - 25.0: 10.0 / 14.0, # ~0.71× faster + 25.0: 10.0 / 14.0, # ~0.71× faster } CONTROL_KEYWORDS = {"control", "wildtype", "n2", "wt", "wild-type", "wild type"} @@ -52,7 +50,8 @@ # Helpers # --------------------------------------------------------------------------- -def _get_temp_factor(temperature_c: Optional[float]) -> float: + +def _get_temp_factor(temperature_c: float | None) -> float: """Return scaling factor for developmental timing at given temperature.""" if temperature_c is None: return 1.0 @@ -68,19 +67,19 @@ def _get_temp_factor(temperature_c: Optional[float]) -> float: return TEMP_SCALE[20.0] + frac * (TEMP_SCALE[25.0] - TEMP_SCALE[20.0]) -def _check_dependency_cycles(items) -> List[str]: +def _check_dependency_cycles(items) -> list[str]: """DFS-based cycle detection on the dependency graph.""" # Build adjacency list: item_id -> list of dependency IDs - adj: Dict[str, List[str]] = {} - id_to_title: Dict[str, str] = {} + adj: dict[str, list[str]] = {} + id_to_title: dict[str, str] = {} for item in items: adj[item.id] = list(item.depends_on) id_to_title[item.id] = item.title WHITE, GRAY, BLACK = 0, 1, 2 - color: Dict[str, int] = {nid: WHITE for nid in adj} - cycles: List[str] = [] - path: List[str] = [] + color: dict[str, int] = {nid: WHITE for nid in adj} + cycles: list[str] = [] + path: 
list[str] = [] def dfs(node: str): if node not in color: @@ -110,13 +109,16 @@ def dfs(node: str): return cycles -def _stage_order(stage_name: str) -> Optional[int]: +def _stage_order(stage_name: str) -> int | None: """Get ordinal position of a stage, or None if unrecognised.""" from gently_perception.organism import CELEGANS + stages = CELEGANS.stages aliases = { - "3fold": "pretzel", "threefold": "pretzel", - "1.5-fold": "1.5fold", "2-fold": "2fold", + "3fold": "pretzel", + "threefold": "pretzel", + "1.5-fold": "1.5fold", + "2-fold": "2fold", } normed = stage_name.lower().replace("-", "").replace(" ", "") name = aliases.get(normed, normed) @@ -127,9 +129,10 @@ def _stage_order(stage_name: str) -> Optional[int]: return None -def _normalise_stage(name: str) -> Optional[str]: +def _normalise_stage(name: str) -> str | None: """Normalise a stage name to canonical form, or None.""" from gently_perception.organism import CELEGANS + stages = CELEGANS.stages low = name.lower().strip() for s in stages: @@ -143,6 +146,7 @@ def _normalise_stage(name: str) -> Optional[str]: # Tool # --------------------------------------------------------------------------- + @tool( name="validate_plan", description=( @@ -161,7 +165,7 @@ def _normalise_stage(name: str) -> Optional[str]: ) async def validate_plan( campaign_id: str, - context: Dict = None, + context: dict | None = None, ) -> str: """Validate a plan and return errors/warnings.""" agent = context.get("agent") if context else None @@ -178,14 +182,18 @@ async def validate_plan( if not items: return f"Campaign '{campaign.description}' has no plan items to validate." 
- errors: List[str] = [] - warnings: List[str] = [] + errors: list[str] = [] + warnings: list[str] = [] # Load detector presets for validation try: from gently.organisms import get_organism + org = get_organism() - presets_mod = __import__(f"gently.organisms.{org.ORGANISM_NAME}.detector_presets", fromlist=["get_detector_presets"]) + presets_mod = __import__( + f"gently.organisms.{org.ORGANISM_NAME}.detector_presets", + fromlist=["get_detector_presets"], + ) valid_detectors = set(presets_mod.get_detector_presets().keys()) except ImportError: valid_detectors = set() @@ -199,16 +207,31 @@ async def validate_plan( label = f"[{item.type.value}] '{item.title}'" # Check for control mentions - text_blob = " ".join(filter(None, [ - item.title, item.description, item.outcome, - ])).lower() + text_blob = " ".join( + filter( + None, + [ + item.title, + item.description, + item.outcome, + ], + ) + ).lower() if item.imaging_spec: - text_blob += " " + " ".join(filter(None, [ - item.imaging_spec.strain, - item.imaging_spec.genotype, - item.imaging_spec.reporter, - item.imaging_spec.success_criteria, - ])).lower() + text_blob += ( + " " + + " ".join( + filter( + None, + [ + item.imaging_spec.strain, + item.imaging_spec.genotype, + item.imaging_spec.reporter, + item.imaging_spec.success_criteria, + ], + ) + ).lower() + ) if any(kw in text_blob for kw in CONTROL_KEYWORDS): has_control = True @@ -221,13 +244,9 @@ async def validate_plan( if val is None: continue if lo is not None and val < lo: - errors.append( - f"{label}: {field_name}={val} below minimum {lo}" - ) + errors.append(f"{label}: {field_name}={val} below minimum {lo}") if hi is not None and val > hi: - errors.append( - f"{label}: {field_name}={val} exceeds maximum {hi}" - ) + errors.append(f"{label}: {field_name}={val} exceeds maximum {hi}") # Stage consistency if spec.start_stage and spec.stop_condition: diff --git a/gently/harness/prompts/manager.py b/gently/harness/prompts/manager.py index 23cb9291..9405c820 100644 --- 
a/gently/harness/prompts/manager.py +++ b/gently/harness/prompts/manager.py @@ -9,13 +9,13 @@ import json import logging from datetime import datetime -from typing import Dict, List, Optional from gently.settings import settings -from .templates import build_system_prompt, build_context_message + from ..plan_mode.prompt import build_plan_prompt from ..resolution_mode.prompt import build_resolution_prompt from ..tools.registry import get_tool_registry +from .templates import build_system_prompt logger = logging.getLogger(__name__) @@ -33,21 +33,22 @@ def __init__(self, claude_client, model): self.model = model # Context summary caching - self._context_summary_cache: Optional[str] = None - self._context_summary_time: Optional[datetime] = None + self._context_summary_cache: str | None = None + self._context_summary_time: datetime | None = None self._context_summary_ttl: int = 300 # 5 minutes # Memory awareness caching - self._memory_awareness_cache: Optional[str] = None - self._memory_awareness_time: Optional[datetime] = None + self._memory_awareness_cache: str | None = None + self._memory_awareness_time: datetime | None = None self._memory_awareness_ttl: int = 600 # 10 minutes # Set by agent after construction self.context_store = None self.memory = None # AgentMemory instance - def update_system_prompt(self, experiment, client, mode: str, - context_summary: str = None) -> str: + def update_system_prompt( + self, experiment, client, mode: str, context_summary: str | None = None, perceiver=None + ) -> str: """ Rebuild system prompt with current experiment state and connection status. 
@@ -86,16 +87,19 @@ def update_system_prompt(self, experiment, client, mode: str, # Execution mode if client: connection_status = { - 'device_layer': client.is_connected, - 'sam_detection': client.has_sam, + "device_layer": client.is_connected, + "sam_detection": client.has_sam, } else: connection_status = None return build_system_prompt( - experiment, connection_status, context_summary, + experiment, + connection_status, + context_summary, memory_awareness=memory_awareness, microscope=client, + perceiver=perceiver, ) def get_tools_for_mode(self, mode: str, has_microscope: bool) -> list: @@ -142,17 +146,32 @@ def get_tools_for_mode(self, mode: str, has_microscope: bool) -> list: return [t for t in all_tools if t["name"] in resolution_tool_names] if mode == "plan": plan_tool_names = { - "create_campaign", "create_plan_item", "update_plan_item", - "link_plan_items", "propose_plan", "get_plan_status", + "create_campaign", + "create_plan_item", + "update_plan_item", + "link_plan_items", + "propose_plan", + "get_plan_status", "get_plan_item", - "move_plan_item", "delete_plan_item", "reorder_plan_items", - "update_phase", "delete_phase", + "move_plan_item", + "delete_plan_item", + "reorder_plan_items", + "update_phase", + "delete_phase", "export_plan", - "query_lab_history", "check_hardware_capability", - "search_literature", "search_strains", + "query_lab_history", + "check_hardware_capability", + "search_literature", + "search_strains", "validate_plan", - "batch_update_status", "batch_update_spec", - "save_plan_template", "list_templates", "apply_template", + "batch_update_status", + "batch_update_spec", + "snapshot_plan", + "list_plan_versions", + "restore_plan_version", + "save_plan_template", + "list_templates", + "apply_template", "ask_user_choice", } all_tools = registry.get_claude_schemas(has_microscope=False) @@ -165,14 +184,16 @@ def get_cached_memory_awareness(self) -> str: if not self.memory: return "" now = datetime.now() - if (self._memory_awareness_cache 
is None or - self._memory_awareness_time is None or - (now - self._memory_awareness_time).total_seconds() > self._memory_awareness_ttl): + if ( + self._memory_awareness_cache is None + or self._memory_awareness_time is None + or (now - self._memory_awareness_time).total_seconds() > self._memory_awareness_ttl + ): self._memory_awareness_cache = self.memory.get_awareness_summary() self._memory_awareness_time = now return self._memory_awareness_cache - def get_active_plan_summary(self) -> Optional[str]: + def get_active_plan_summary(self) -> str | None: """Get a summary of the active experimental plan, if any.""" if not self.context_store: return None @@ -190,13 +211,14 @@ def get_active_plan_summary(self) -> Optional[str]: f" ({status['completed']}/{status['total']} items done)" ) if status["next_actions"]: - lines.append(" Next: " + ", ".join( - a.title for a in status["next_actions"][:3] - )) + lines.append( + " Next: " + ", ".join(a.title for a in status["next_actions"][:3]) + ) if status["pending_decisions"]: - lines.append(" Decisions pending: " + ", ".join( - d.title for d in status["pending_decisions"] - )) + lines.append( + " Decisions pending: " + + ", ".join(d.title for d in status["pending_decisions"]) + ) return "\n".join(lines) if lines else None except Exception: return None @@ -220,36 +242,36 @@ def gather_context_data(self, experiment, timelapse_orch, timeline_mgr) -> dict: Context data including timelapse status, events, and detections """ data = { - 'current_time': datetime.now().isoformat(), - 'timelapse_status': None, - 'recent_events': [], - 'recent_detections': [], - 'detection_reasoning': [], + "current_time": datetime.now().isoformat(), + "timelapse_status": None, + "recent_events": [], + "recent_detections": [], + "detection_reasoning": [], } if timelapse_orch: try: status = timelapse_orch.get_status() - data['timelapse_status'] = { - 'state': status.status.value if status.status else 'unknown', - 'total_timepoints': status.total_timepoints 
or 0, - 'started_at': status.started_at.isoformat() if status.started_at else None, - 'embryo_count': len(status.embryos) if status.embryos else 0, + data["timelapse_status"] = { + "state": status.status.value if status.status else "unknown", + "total_timepoints": status.total_timepoints or 0, + "started_at": status.started_at.isoformat() if status.started_at else None, + "embryo_count": len(status.embryos) if status.embryos else 0, } except Exception as e: logger.debug(f"Could not get timelapse status: {e}") if timeline_mgr: try: - events = timeline_mgr.get_events(limit=20, session_id='current') - data['recent_events'] = [ + events = timeline_mgr.get_events(limit=20, session_id="current") + data["recent_events"] = [ { - 'type': e.event_subtype, - 'time': e.timestamp.isoformat(), - 'embryo': e.embryo_id, - 'detector': e.detector_name, - 'timepoint': e.timepoint, - 'confidence': e.confidence, + "type": e.event_subtype, + "time": e.timestamp.isoformat(), + "embryo": e.embryo_id, + "detector": e.detector_name, + "timepoint": e.timepoint, + "confidence": e.confidence, } for e in events ] @@ -258,32 +280,35 @@ def gather_context_data(self, experiment, timelapse_orch, timeline_mgr) -> dict: try: for embryo_id, embryo_state in experiment.embryos.items(): - if not hasattr(embryo_state, 'detection_results'): + if not hasattr(embryo_state, "detection_results"): continue for detector_name, results in embryo_state.detection_results.items(): recent_results = results[-3:] if len(results) > 3 else results for r in recent_results: - if r.get('detected'): - data['recent_detections'].append({ - 'detector': detector_name, - 'embryo': embryo_id, - 'timepoint': r.get('timepoint'), - 'confidence': r.get('confidence'), - }) - if r.get('reasoning'): - data['detection_reasoning'].append({ - 'detector': detector_name, - 'embryo': embryo_id, - 'timepoint': r.get('timepoint'), - 'reasoning': r.get('reasoning')[:500], - }) + if r.get("detected"): + data["recent_detections"].append( + { + 
"detector": detector_name, + "embryo": embryo_id, + "timepoint": r.get("timepoint"), + "confidence": r.get("confidence"), + } + ) + if r.get("reasoning"): + data["detection_reasoning"].append( + { + "detector": detector_name, + "embryo": embryo_id, + "timepoint": r.get("timepoint"), + "reasoning": r.get("reasoning")[:500], + } + ) except Exception as e: logger.debug(f"Could not get detection results: {e}") return data - async def generate_context_summary(self, experiment, timelapse_orch, - timeline_mgr) -> str: + async def generate_context_summary(self, experiment, timelapse_orch, timeline_mgr) -> str: """ Generate concise context summary using Haiku. @@ -300,22 +325,23 @@ async def generate_context_summary(self, experiment, timelapse_orch, """ raw_data = self.gather_context_data(experiment, timelapse_orch, timeline_mgr) - has_timelapse = raw_data['timelapse_status'] is not None - has_events = len(raw_data['recent_events']) > 0 - has_detections = len(raw_data['recent_detections']) > 0 + has_timelapse = raw_data["timelapse_status"] is not None + has_events = len(raw_data["recent_events"]) > 0 + has_detections = len(raw_data["recent_detections"]) > 0 if not (has_timelapse or has_events or has_detections): return "" - prompt = f"""Summarize the current microscopy session state in 2-3 sentences for another AI assistant. -Focus on: timelapse status (is it running, completed, or idle?), time since last activity, and notable detections. -Be factual and concise. + prompt = f"""Summarize the current microscopy session state in 2-3 sentences for +another AI assistant. Focus on: timelapse status (is it running, completed, or idle?), +time since last activity, and notable detections. Be factual and concise. Raw session data: {json.dumps(raw_data, indent=2, default=str)} Write a brief status summary. Examples: -- "Timelapse completed 10h ago with 233 timepoints. Hatching was detected at timepoints 175-193 with HIGH confidence." 
+- "Timelapse completed 10h ago with 233 timepoints. Hatching was detected at timepoints + 175-193 with HIGH confidence." - "Timelapse is currently running for embryo_1 at timepoint 45. No detections yet." - "No active timelapse. Last session had 50 timepoints, with comma stage detected at t=30." """ @@ -325,15 +351,14 @@ async def generate_context_summary(self, experiment, timelapse_orch, self.claude.messages.create, model=settings.models.fast, max_tokens=150, - messages=[{"role": "user", "content": prompt}] + messages=[{"role": "user", "content": prompt}], ) return response.content[0].text.strip() except Exception as e: logger.warning(f"Failed to generate context summary: {e}") return "" - async def get_cached_context_summary(self, experiment, timelapse_orch, - timeline_mgr) -> str: + async def get_cached_context_summary(self, experiment, timelapse_orch, timeline_mgr) -> str: """ Get context summary with caching (5-minute TTL). @@ -349,9 +374,11 @@ async def get_cached_context_summary(self, experiment, timelapse_orch, Cached or newly generated context summary """ now = datetime.now() - if (self._context_summary_cache is None or - self._context_summary_time is None or - (now - self._context_summary_time).total_seconds() > self._context_summary_ttl): + if ( + self._context_summary_cache is None + or self._context_summary_time is None + or (now - self._context_summary_time).total_seconds() > self._context_summary_ttl + ): self._context_summary_cache = await self.generate_context_summary( experiment, timelapse_orch, timeline_mgr ) @@ -382,6 +409,6 @@ def get_cached_system_prompt(self, system_prompt: str) -> list: { "type": "text", "text": system_prompt, - "cache_control": {"type": "ephemeral", "ttl": "1h"} + "cache_control": {"type": "ephemeral", "ttl": "1h"}, } ] diff --git a/gently/harness/prompts/templates.py b/gently/harness/prompts/templates.py index 55894c22..f75fc7a5 100644 --- a/gently/harness/prompts/templates.py +++ b/gently/harness/prompts/templates.py @@ 
-2,17 +2,20 @@ System prompts and context builders for the Microscopy Agent """ -from typing import Dict, List -from ..state import ExperimentState -from gently.organisms import get_organism from gently.hardware import get_hardware +from gently.organisms import get_organism +from ..state import ExperimentState # Interactive choice guidance USER_INTERACTION_GUIDELINES = """ # Interactive User Choices — MANDATORY -CRITICAL RULE: Whenever you need to ask the user a question — whether it's a yes/no confirmation, a choice between options, or any question where the answer could be one of several discrete responses — you MUST use the `ask_user_choice` tool. NEVER present options as numbered text lists or bullet points. NEVER ask the user to type their choice as text when you could present selectable options instead. +CRITICAL RULE: Whenever you need to ask the user a question — whether it's a yes/no +confirmation, a choice between options, or any question where the answer could be one of +several discrete responses — you MUST use the `ask_user_choice` tool. NEVER present options +as numbered text lists or bullet points. NEVER ask the user to type their choice as text when +you could present selectable options instead. ## When to use ask_user_choice @@ -45,14 +48,21 @@ GOOD (always do this): Call the `ask_user_choice` tool. Example parameters: question: "What would you like to work on today?" - options: [{"id": "new", "label": "Start a new experiment"}, {"id": "resume", "label": "Resume a session"}] + options: [{"id": "new", "label": "Start a new experiment"}, + {"id": "resume", "label": "Resume a session"}] -The user interface renders these as an interactive picker with arrow-key navigation — much better UX than typing. -Do NOT write tool calls as XML tags or code blocks in your text — always invoke tools through the tool mechanism. +The user interface renders these as an interactive picker with arrow-key navigation — much +better UX than typing. 
+Do NOT write tool calls as XML tags or code blocks in your text — always invoke tools through +the tool mechanism. -IMPORTANT: This is not optional. ALWAYS use ask_user_choice when presenting choices or asking questions. The ONLY exception is when you need a completely free-form text response (like asking for a name or description). +IMPORTANT: This is not optional. ALWAYS use ask_user_choice when presenting choices or asking +questions. The ONLY exception is when you need a completely free-form text response (like +asking for a name or description). -Each option should be a specific, distinct choice. The picker automatically adds a free-text "Something else..." input at the bottom for custom responses, so your options can focus on the most likely concrete answers. +Each option should be a specific, distinct choice. The picker automatically adds a free-text +"Something else..." input at the bottom for custom responses, so your options can focus on +the most likely concrete answers. """ @@ -87,62 +97,32 @@ # CV Subagent capabilities CV_SUBAGENT = """ -# CV Subagent for Advanced Analysis - -For complex computer vision analysis, you have access to a specialized CV subagent via the `cv_analyze` tool. - -## IMPORTANT: Volume Required First! - -Before using cv_analyze or classify_embryo_stage, you MUST ensure the embryo has a volume acquired -in this session. If the user asks for cell counting, stage classification, or any analysis: - -1. Check if the embryo has been imaged (recent_images exists) -2. If NOT, acquire a volume first with `acquire_volume` -3. 
Then proceed with analysis - -Example workflow: -User: "Count the cells in embryo_3" -→ First: acquire_volume(embryo_id="embryo_3") # Get fresh data -→ Then: cv_analyze(intent="count cells", embryo_id="embryo_3") - -## When to use cv_analyze - -Use the CV subagent when you need: -- **Accurate stage classification** - It segments nuclei (Cellpose) and uses count + morphology for staging -- **Cell counting** - 3D segmentation gives precise nuclei counts, not visual estimates -- **Division tracking** - Tracks cells across timepoints, identifies division events -- **Morphology measurements** - Elongation ratio, circularity (important for comma/fold stages) -- **Anomaly detection** - Compares to expected developmental patterns - -## When NOT to use cv_analyze - -Don't use it for: -- Quick visual checks (use simple image viewing instead) -- Hatching detection (the hatching detector handles this) -- Basic "what stage is this?" if rough estimate is fine - -## How it works - -The CV subagent is itself an AI agent that: -1. Loads volume data from the data store -2. Segments with Cellpose/StarDist (nuclei count!) -3. Measures morphology (elongation for fold stages) -4. Adds scale bars and annotations -5. Uses Claude Vision with rich quantitative context - -This gives much more accurate results than just sending an image to vision. - -## Example usage - -User: "How many cells does embryo 1 have?" -→ First acquire_volume if needed, then cv_analyze with intent="count cells and nuclei" - -User: "What stage is embryo 2?" -→ If precision matters: acquire_volume then cv_analyze intent="classify developmental stage" -→ If quick check: view the image yourself - -User: "Track cell divisions over the last 5 timepoints" -→ cv_analyze with intent="track cell divisions" and timepoints=[t-4, t-3, t-2, t-1, t] +# Perception & Analysis + +You see and reason about embryo development through three channels: + +1. 
**Live perception (the perceiver).** During a timelapse a vision-language + perceiver classifies each acquired volume's developmental stage and tracks + each embryo's trajectory. Its current read is injected into your context + under "## Perception (live)" — stage, stability (how long it's held that + stage), time-in-stage, and a possible-arrest flag. Call + `get_recent_perceptions(embryo_id)` for the fuller picture: stage history, + trajectory, the arrest signal, and the perceiver's own reasoning. This is + your primary signal for "how is it developing?" and for deciding whether to + adapt acquisition. + +2. **On-demand vision (`analyze_volume`).** Ask Claude Vision a specific + question about an acquired volume (e.g. "is the reporter saturating?", + "describe the morphology"). Requires a volume in this session — acquire one + first with `acquire_volume` if none exists. + +3. **Stage tools.** `classify_embryo_stage` (a vision spot-check of the latest + image), `get_stage_history`, and `predict_hatching` — the latter two read the + live perceiver when available, so they work without a manual classify call. + +Prefer the live perception snapshot + `get_recent_perceptions` for routine +"what stage / is anything stuck" questions; reach for `analyze_volume` when you +need a specific visual judgement about a particular volume. """ @@ -222,7 +202,7 @@ | User describes... | Mode to install | |---|---| -| reporter expression, GFP/mCherry onset, "neurons lighting up", dopaminergic signal, anything where fluorescence turns on | `expression_monitoring` | +| reporter onset: GFP/mCherry, dopaminergic signal, neurons lighting up | `expression_monitoring` | | hatching timing, pre-hatch dynamics, "track until they hatch" | `pre_terminal_monitoring` | | plain imaging, exploratory, no specific signal target | none (idle) | @@ -252,7 +232,8 @@ 1. **Non-blocking operation**: The timelapse runs independently - you can still chat with the user 2. 
**Per-embryo stop conditions**: Each embryo can stop at different times (e.g., when hatching) 3. **Dynamic intervals**: Adjust imaging frequency per-embryo during the experiment -4. **Detector integration**: Stop conditions triggered by visual detection (hatching, comma stage, etc.) +4. **Detector integration**: Stop conditions triggered by visual detection + (hatching, comma stage, etc.) ## Stop Conditions @@ -266,36 +247,129 @@ 1. User: "Run timelapse until all embryos hatch" 2. Agent: - - Enables hatching detector (enable_preset_detector) - - Starts timelapse with stop_condition="hatching" + - Starts the timelapse with stop_condition="hatching" (the stop condition + wires the detection; the perception loop classifies each acquired volume) + - Optionally installs a monitoring mode (enable_monitoring_mode) for + reactive cadence/power - Reports progress on request - Each embryo stops automatically when it hatches -## Available Preset Detectors +## Stage detection -- **hatching**: Detects eggshell breach and embryo emergence -- **comma**: Detects comma stage morphology -- **pretzel**: Detects 3-fold/pretzel stage -- **gastrulation**: Detects cell internalization -- **first_division**: Detects 1-cell to 2-cell transition +Developmental stage comes from the live perception loop (see "Perception & +Analysis"), surfaced in your context and via get_recent_perceptions. Stop +conditions can key on it — e.g. stop_condition="hatching" or "comma". 
## Commands During Timelapse - Query status: get_timelapse_status - Stop one embryo: stop_timelapse_embryo -- Change interval: modify_timelapse_embryo +- Change interval (all embryos): modify_timelapse_interval +- Change one embryo's cadence: set_embryo_cadence +- Other per-embryo params: modify_timelapse_embryo / modify_parameters - Pause all: pause_timelapse - Resume: resume_timelapse - Stop all: stop_timelapse """ +AUTONOMY_AND_ADAPTATION = """ +# Adapting Acquisition — Gently + +Gentleness is the prime directive: every imaging action spends photodose on a +precious, living sample. Always prefer the *least* light that answers the +question. When you do adapt, you have direct, live knobs — each takes effect on +the embryo's next acquisition, no restart: + +- **Cadence**: `modify_timelapse_interval` (whole run) / `set_embryo_cadence` + (one embryo). Speed up only around events worth catching (e.g. approaching + hatching); slow back down when nothing is changing. +- **Dose levers**: `modify_parameters` — num_slices, exposure_ms, acquisition + mode (volume ↔ snap, snap is far gentler), and per-embryo 488 power (hard + clamped 2–6%). `set_photodose_budget` caps cumulative exposure and pauses an + embryo that exceeds it; `get_photodose_status` shows where each stands. +- **Events**: `add_stop_condition` (auto-stop on hatching/stage/duration), + `queue_burst` (one-shot high-rate capture of a transient), and per-embryo + pause / resume / stop. +- **Reactive modes**: `enable_monitoring_mode` installs perception-driven rules + that fire on their own (pre-hatching speedup, 488 rampdown on saturation, + burst on stable structure). + +Bias toward the gentlest sufficient action — snap over volume, fewer slices, +lower power, longer interval — unless an event genuinely needs the resolution. + +# Autonomy (OFF / ASK / AUTO) + +You may act between user messages, but only as far as the operator allows. 
The +mode is set with `set_autonomy` and is **OFF by default**: + +- **off** — act only when the user messages you. +- **ask** — on a notable event (a developmental stage transition, possible + arrest, hatching, an embryo terminating, or an error) you wake, briefly state + your PROPOSED change and why, then call `ask_user_choice` with + Approve / Modify / Skip and act ONLY on Approve. +- **auto** — you adapt on your own on those events. Still: prefer the gentlest + action, and a few irreversible tools (turning the laser on via + `set_laser_power`, `remove_embryo`, `stop_timelapse`) are hard-blocked from + autonomous use — ask the operator for those. + +When you wake autonomously, your turn and the trigger that woke you are shown to +the operator in the chat. Keep autonomous turns tight: assess, make the smallest +helpful change (or none), and explain it in a sentence or two. +""" + + +def build_perception_snapshot(perceiver, embryos) -> str: + """One compact line per embryo of live perception state for the system prompt. + + Reads straight from the perception sessions (current stage, stability, time in + stage, arrest signal, short trajectory). Every read here is synchronous and + side-effect-free — it never triggers a VLM call. Returns '' when there is + nothing to show, so callers can drop the section entirely. 
+ """ + if not perceiver or not embryos: + return "" + lines = [] + for embryo_id in sorted(embryos): + try: + session = perceiver.get_session(embryo_id) + summary = session.summary() if session is not None else None + except Exception: + summary = None + if not summary or not summary.get("current_stage"): + lines.append(f"- {embryo_id}: no perception yet") + continue + parts = [ + f"stage={summary['current_stage']}", + f"stable={summary.get('stability', 0)}x", + ] + temporal = summary.get("temporal") # TemporalContext dataclass or None + if temporal is not None: + tmin = getattr(temporal, "time_in_stage_min", None) + exp = getattr(temporal, "expected_duration_min", None) + if tmin is not None: + seg = f"in_stage={tmin:.0f}min" + if exp: + seg += f"/{exp:.0f}" + parts.append(seg) + if getattr(temporal, "is_potentially_arrested", False): + parts.append("ARRESTED?") + seq = summary.get("stage_sequence") or [] + if len(seq) > 1: + parts.append("traj=" + "->".join(seq[-4:])) + lines.append(f"- {embryo_id}: " + " ".join(parts)) + if not lines: + return "" + return "## Perception (live)\n\n" + "\n".join(lines) + + def build_system_prompt( experiment_state: ExperimentState, - connection_status: dict = None, - context_summary: str = None, - memory_awareness: str = None, + connection_status: dict | None = None, + context_summary: str | None = None, + memory_awareness: str | None = None, microscope=None, + perceiver=None, ) -> str: """ Build complete system prompt for Claude @@ -314,14 +388,16 @@ def build_system_prompt( str Complete system prompt """ - embryo_summary = experiment_state.get_summary() if experiment_state.embryos else "No embryos loaded yet" + embryo_summary = ( + experiment_state.get_summary() if experiment_state.embryos else "No embryos loaded yet" + ) # Build connection status section if connection_status: - device_layer = "connected" if connection_status.get('device_layer') else "NOT CONNECTED" - sam = "available" if connection_status.get('sam_detection') 
else "not available" + device_layer = "connected" if connection_status.get("device_layer") else "NOT CONNECTED" + sam = "available" if connection_status.get("sam_detection") else "not available" - if not connection_status.get('device_layer'): + if not connection_status.get("device_layer"): connection_section = f"""# Hardware Connection Status ⚠️ **OFFLINE MODE** - Device layer is not connected. @@ -329,7 +405,8 @@ def build_system_prompt( - Device Layer: {device_layer} - SAM Detection: {sam} -**Important**: You cannot perform hardware operations (detect embryos, capture images, move stage, etc.) +**Important**: You cannot perform hardware operations (detect embryos, capture images, +move stage, etc.) without a connected device layer. If the user asks for hardware operations, inform them that the microscope is not connected and suggest they start the server or check the connection.""" else: @@ -357,6 +434,15 @@ def build_system_prompt( else: context_section = "" + # Live per-embryo perception snapshot (deterministic, read straight from the + # perception sessions — bypasses the AI context-summary cache so stage data is + # never stale). 
+ perception_section = "" + if perceiver is not None and experiment_state.embryos: + snap = build_perception_snapshot(perceiver, experiment_state.embryos) + if snap: + perception_section = f"\n{snap}\n" + # Pull organism-specific content from the active organism module organism = get_organism() organism_display = organism.ORGANISM_DISPLAY_NAME @@ -364,13 +450,13 @@ def build_system_prompt( biology_knowledge = organism.BIOLOGY_KNOWLEDGE # Build stop conditions list from organism module - stop_condition_names = list(organism.STOP_CONDITIONS.keys()) - detector_names = list(organism.get_detector_presets().keys()) + list(organism.STOP_CONDITIONS.keys()) + list(organism.get_detector_presets().keys()) # Pull hardware description — prefer microscope (from device layer handshake), # fall back to the static hardware module hardware = get_hardware() - hardware_description = getattr(microscope, 'DESCRIPTION', '') or hardware.HARDWARE_DESCRIPTION + hardware_description = getattr(microscope, "DESCRIPTION", "") or hardware.HARDWARE_DESCRIPTION hardware_display = hardware.HARDWARE_DISPLAY_NAME return f"""You are Gently — an AI scientific collaborator running {hardware_display} @@ -397,6 +483,8 @@ def build_system_prompt( {REACTIVE_MONITORING_MODES} +{AUTONOMY_AND_ADAPTATION} + {USER_INTERACTION_GUIDELINES} {SESSION_MANAGEMENT} @@ -404,33 +492,46 @@ def build_system_prompt( # Current Experiment State {embryo_summary} +{perception_section} {context_section} # Tool Use Guidelines Answer the user's request using relevant tools. Before calling a tool, do some analysis: 1. Think about which of the provided tools is relevant to answer the user's request -2. Go through each required parameter and determine if the user has provided or given enough information to infer a value +2. Go through each required parameter and determine if the user has provided or given enough + information to infer a value 3. 
If all required parameters are present or can be reasonably inferred, PROCEED WITH THE TOOL CALL 4. If a required parameter is missing, ask the user to provide it 5. DO NOT ask for more information on optional parameters if not provided - use defaults IMPORTANT: When you need information (status, positions, etc.), CALL THE TOOL IMMEDIATELY. -Do NOT explain what you "would need to do" - just do it. Never say "I would need to query..." - just query it. +Do NOT explain what you "would need to do" - just do it. Never say "I would need to +query..." - just query it. # Behavior Guidelines -1. **Act, then explain**: Call tools first, then explain results. Don't describe what you would do - do it. -2. **Be scientifically accurate**: Base interpretations on actual developmental biology, not speculation +1. **Act, then explain**: Call tools first, then explain results. Don't describe what you + would do - do it. +2. **Be scientifically accurate**: Base interpretations on actual developmental biology, + not speculation 3. **Prioritize sample health**: Always minimize photobleaching and photodamage -4. **Respect embryo roles**: Every embryo line shows `[role=TEST]`, `[role=CALIBRATION]`, or `[role=UNASSIGNED]`. Calibrate / sweep / classify on CALIBRATION embryos; conserve photodose on TEST. Never suggest calibrating against a TEST embryo (see Embryo Roles section). +4. **Respect embryo roles**: Every embryo line shows `[role=TEST]`, `[role=CALIBRATION]`, + or `[role=UNASSIGNED]`. Calibrate / sweep / classify on CALIBRATION embryos; conserve + photodose on TEST. Never suggest calibrating against a TEST embryo (see Embryo Roles + section). 5. **Use proper terminology**: Refer to embryos by ID, nickname, or user label naturally 6. **Track temporal context**: Remember what you've seen in recent images when analyzing new data 6. **Generate safe plans**: Always validate parameters are within hardware limits 7. **Be conversational**: You're a scientific colleague, not a robot -8. 
**Stop after success**: When a tool returns a success message (starts with ✓), do NOT retry. Report success and wait for next request. -9. **Single tool = complete action**: Tools like capture_lightsheet, view_image, and acquire_volume are COMPLETE actions. Do NOT chain them unless explicitly asked. -10. **Use defaults**: If a tool has default parameters and the user doesn't specify values, use the defaults. -11. **ALWAYS use ask_user_choice**: When asking the user ANY question with selectable answers, MUST use the `ask_user_choice` tool. NEVER list options as text. This is the #1 UX rule. +8. **Stop after success**: When a tool returns a success message (starts with ✓), do NOT + retry. Report success and wait for next request. +9. **Single tool = complete action**: Tools like capture_lightsheet, view_image, and + acquire_volume are COMPLETE actions. Do NOT chain them unless explicitly asked. +10. **Use defaults**: If a tool has default parameters and the user doesn't specify values, + use the defaults. +11. **ALWAYS use ask_user_choice**: When asking the user ANY question with selectable + answers, MUST use the `ask_user_choice` tool. NEVER list options as text. This is the + #1 UX rule. 
# Embryo Naming @@ -446,7 +547,7 @@ def build_system_prompt( """ -def build_context_message(experiment_state: ExperimentState) -> Dict: +def build_context_message(experiment_state: ExperimentState) -> dict: """ Build context message with current experiment state @@ -464,5 +565,7 @@ def build_context_message(experiment_state: ExperimentState) -> Dict: """ return { "role": "user", - "content": f"[System update - current experiment state]\n\n{experiment_state.get_summary()}" + "content": ( + f"[System update - current experiment state]\n\n{experiment_state.get_summary()}" + ), } diff --git a/gently/harness/protocols.py b/gently/harness/protocols.py index 9b8b29e9..eb2e811a 100644 --- a/gently/harness/protocols.py +++ b/gently/harness/protocols.py @@ -10,7 +10,7 @@ from gently.harness.protocols import MicroscopeClientProtocol """ -from typing import Protocol, runtime_checkable, Dict, List, Set, Tuple, Optional +from typing import Protocol, runtime_checkable @runtime_checkable @@ -24,11 +24,11 @@ class OrganismProtocol(Protocol): ORGANISM_NAME: str ORGANISM_DISPLAY_NAME: str - SAMPLE_TERM: str # "embryo", "cell", "organoid" + SAMPLE_TERM: str # "embryo", "cell", "organoid" SAMPLE_TERM_PLURAL: str STAGES: list TERMINAL_STAGES: set - BIOLOGY_KNOWLEDGE: str # Markdown text for LLM context + BIOLOGY_KNOWLEDGE: str # Markdown text for LLM context PERCEPTION_SYSTEM_PROMPT: str @@ -55,10 +55,10 @@ class HardwareProtocol(Protocol): HARDWARE_NAME: str HARDWARE_DISPLAY_NAME: str - HARDWARE_DESCRIPTION: str # Markdown text for LLM context - CAPABILITIES: set # Set of capability strings + HARDWARE_DESCRIPTION: str # Markdown text for LLM context + CAPABILITIES: set # Set of capability strings # Backward-compat alias — the Microscope base class in harness/microscope.py # replaces this Protocol. Import from there for new code. 
-from .microscope import Microscope as MicroscopeClientProtocol # noqa: F401 +from .microscope import Microscope as MicroscopeClientProtocol # noqa: E402, F401 diff --git a/gently/harness/resolution_mode/prompt.py b/gently/harness/resolution_mode/prompt.py index 6b3d0e5e..ad4ae321 100644 --- a/gently/harness/resolution_mode/prompt.py +++ b/gently/harness/resolution_mode/prompt.py @@ -13,11 +13,8 @@ and call one of the resolution lifecycle tools to record it. """ -from typing import Optional - -from gently.organisms import get_organism from gently.hardware import get_hardware - +from gently.organisms import get_organism RESOLUTION_MODE_IDENTITY = """\ You're in **session resolution** — figure out what the researcher @@ -110,8 +107,8 @@ def build_resolution_prompt( - context_summary: Optional[str] = None, - memory_awareness: Optional[str] = None, + context_summary: str | None = None, + memory_awareness: str | None = None, ) -> str: """ Build the system prompt for resolution mode. diff --git a/gently/harness/roles.py b/gently/harness/roles.py index 8708f817..beaa59a4 100644 --- a/gently/harness/roles.py +++ b/gently/harness/roles.py @@ -20,7 +20,6 @@ """ from dataclasses import dataclass -from typing import Dict, List, Optional @dataclass(frozen=True) @@ -29,10 +28,11 @@ class EmbryoRole: Frozen so role definitions are immutable references after registry build. """ + name: str description: str default_cadence_seconds: float = 300.0 - detector_name: Optional[str] = None + detector_name: str | None = None photodose_budget_multiplier: float = 1.0 ui_color: str = "#888888" ui_icon: str = "circle" @@ -43,10 +43,10 @@ class EmbryoRole: # drift back; once they're out of view they stay out, so they get # a short threshold. Test embryos can occasionally pop out and # back, so they get a longer one. 
- no_object_consecutive_terminal: Optional[int] = None + no_object_consecutive_terminal: int | None = None -REGISTRY: Dict[str, EmbryoRole] = { +REGISTRY: dict[str, EmbryoRole] = { "unassigned": EmbryoRole( name="unassigned", description="No role assigned yet — treated like 'test' for safety.", @@ -94,10 +94,7 @@ class EmbryoRole: def get_role(name: str) -> EmbryoRole: """Look up a role by name. Raises KeyError with helpful message.""" if name not in REGISTRY: - raise KeyError( - f"Unknown embryo role: {name!r}. " - f"Available: {sorted(REGISTRY.keys())}" - ) + raise KeyError(f"Unknown embryo role: {name!r}. Available: {sorted(REGISTRY.keys())}") return REGISTRY[name] @@ -105,6 +102,6 @@ def is_valid_role(name: str) -> bool: return name in REGISTRY -def list_roles() -> List[str]: +def list_roles() -> list[str]: """All registered role names, sorted.""" return sorted(REGISTRY.keys()) diff --git a/gently/harness/session/interaction_logger.py b/gently/harness/session/interaction_logger.py index 15c32ce1..1caa1e90 100644 --- a/gently/harness/session/interaction_logger.py +++ b/gently/harness/session/interaction_logger.py @@ -14,12 +14,12 @@ """ import json +import logging import subprocess -from dataclasses import dataclass, field, asdict +from dataclasses import asdict, dataclass, field from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Optional -import logging +from typing import Any logger = logging.getLogger(__name__) @@ -27,12 +27,13 @@ @dataclass class ToolCallRecord: """Record of a single tool call""" + tool_name: str - tool_input: Dict[str, Any] + tool_input: dict[str, Any] result: str duration_seconds: float is_error: bool = False - error_message: Optional[str] = None + error_message: str | None = None @dataclass @@ -43,6 +44,7 @@ class InteractionRecord: An interaction is one user message and the agent's response, including any tool calls made during that response. 
""" + # Unique ID for this interaction interaction_id: str @@ -51,44 +53,42 @@ class InteractionRecord: timestamp: datetime # System state snapshot at time of request - system_state: Dict[str, Any] = field(default_factory=dict) + system_state: dict[str, Any] = field(default_factory=dict) # What happened - tool_calls: List[ToolCallRecord] = field(default_factory=list) + tool_calls: list[ToolCallRecord] = field(default_factory=list) assistant_response: str = "" total_duration_seconds: float = 0.0 # Errors - error: Optional[str] = None - error_traceback: Optional[str] = None + error: str | None = None + error_traceback: str | None = None # Correction detection (filled in after next turn) was_corrected: bool = False - correction_prompt: Optional[str] = None - correction_indicators: List[str] = field(default_factory=list) + correction_prompt: str | None = None + correction_indicators: list[str] = field(default_factory=list) # Metadata session_id: str = "" codebase_version: str = "" model: str = "" - def to_dict(self) -> Dict: + def to_dict(self) -> dict: """Serialize to dictionary for JSON storage""" d = asdict(self) # Convert datetime to ISO format - d['timestamp'] = self.timestamp.isoformat() + d["timestamp"] = self.timestamp.isoformat() # Convert tool calls - d['tool_calls'] = [asdict(tc) for tc in self.tool_calls] + d["tool_calls"] = [asdict(tc) for tc in self.tool_calls] return d @classmethod - def from_dict(cls, d: Dict) -> 'InteractionRecord': + def from_dict(cls, d: dict) -> "InteractionRecord": """Deserialize from dictionary""" d = d.copy() - d['timestamp'] = datetime.fromisoformat(d['timestamp']) - d['tool_calls'] = [ - ToolCallRecord(**tc) for tc in d.get('tool_calls', []) - ] + d["timestamp"] = datetime.fromisoformat(d["timestamp"]) + d["tool_calls"] = [ToolCallRecord(**tc) for tc in d.get("tool_calls", [])] return cls(**d) @@ -152,7 +152,7 @@ def __init__( self.log_file = self.logs_dir / f"{session_id}.jsonl" # In-memory buffer of recent interactions (for 
correction detection) - self._recent_interactions: List[InteractionRecord] = [] + self._recent_interactions: list[InteractionRecord] = [] self._max_recent = 10 # Get codebase version (git commit) @@ -171,7 +171,7 @@ def _get_git_version(self) -> str: capture_output=True, text=True, cwd=str(self.storage_path.parent), - timeout=5 + timeout=5, ) if result.returncode == 0: return result.stdout.strip() @@ -182,7 +182,7 @@ def _get_git_version(self) -> str: def start_interaction( self, user_prompt: str, - system_state: Dict[str, Any], + system_state: dict[str, Any], ) -> InteractionRecord: """ Start recording a new interaction @@ -219,11 +219,11 @@ def record_tool_call( self, interaction: InteractionRecord, tool_name: str, - tool_input: Dict[str, Any], + tool_input: dict[str, Any], result: str, duration_seconds: float, is_error: bool = False, - error_message: Optional[str] = None, + error_message: str | None = None, ): """ Record a tool call within an interaction @@ -268,8 +268,8 @@ def complete_interaction( interaction: InteractionRecord, assistant_response: str, total_duration_seconds: float, - error: Optional[str] = None, - error_traceback: Optional[str] = None, + error: str | None = None, + error_traceback: str | None = None, ): """ Complete and save an interaction record @@ -342,17 +342,13 @@ def _detect_correction(self, current: InteractionRecord): f"(indicators: {indicators_found})" ) - def _save_interaction( - self, - interaction: InteractionRecord, - append: bool = True - ): + def _save_interaction(self, interaction: InteractionRecord, append: bool = True): """Save interaction to JSONL file""" try: if append: # Append to log file - with open(self.log_file, 'a', encoding='utf-8') as f: - f.write(json.dumps(interaction.to_dict()) + '\n') + with open(self.log_file, "a", encoding="utf-8") as f: + f.write(json.dumps(interaction.to_dict()) + "\n") else: # Need to update existing record - rewrite file # This is less efficient but corrections are rare @@ -368,7 +364,7 
@@ def _rewrite_with_update(self, updated: InteractionRecord): # Read all interactions interactions = [] try: - with open(self.log_file, 'r', encoding='utf-8') as f: + with open(self.log_file, encoding="utf-8") as f: for line in f: if line.strip(): record = InteractionRecord.from_dict(json.loads(line)) @@ -382,30 +378,30 @@ def _rewrite_with_update(self, updated: InteractionRecord): # Rewrite file try: - with open(self.log_file, 'w', encoding='utf-8') as f: + with open(self.log_file, "w", encoding="utf-8") as f: for record in interactions: - f.write(json.dumps(record.to_dict()) + '\n') + f.write(json.dumps(record.to_dict()) + "\n") except Exception as e: logger.error(f"Failed to rewrite log file: {e}") - def _sanitize_state(self, state: Dict[str, Any]) -> Dict[str, Any]: + def _sanitize_state(self, state: dict[str, Any]) -> dict[str, Any]: """Remove large/sensitive data from state snapshot""" sanitized = {} # Keep summary info - if 'embryos' in state: - sanitized['embryo_count'] = len(state['embryos']) - sanitized['embryo_ids'] = list(state['embryos'].keys()) + if "embryos" in state: + sanitized["embryo_count"] = len(state["embryos"]) + sanitized["embryo_ids"] = list(state["embryos"].keys()) - if 'detectors' in state: - sanitized['detector_count'] = len(state['detectors']) + if "detectors" in state: + sanitized["detector_count"] = len(state["detectors"]) - if 'acquisition_status' in state: - sanitized['acquisition_status'] = state['acquisition_status'] + if "acquisition_status" in state: + sanitized["acquisition_status"] = state["acquisition_status"] return sanitized - def _sanitize_tool_input(self, tool_input: Dict[str, Any]) -> Dict[str, Any]: + def _sanitize_tool_input(self, tool_input: dict[str, Any]) -> dict[str, Any]: """Remove large/binary data from tool input""" sanitized = {} for key, value in tool_input.items(): @@ -425,14 +421,14 @@ def _sanitize_tool_input(self, tool_input: Dict[str, Any]) -> Dict[str, Any]: sanitized[key] = f"[{type(value).__name__}]" 
return sanitized - def get_session_stats(self) -> Dict[str, Any]: + def get_session_stats(self) -> dict[str, Any]: """Get statistics for current session""" if not self.log_file.exists(): return { - 'total_interactions': 0, - 'corrections': 0, - 'errors': 0, - 'tool_calls': 0, + "total_interactions": 0, + "corrections": 0, + "errors": 0, + "tool_calls": 0, } total = 0 @@ -441,35 +437,35 @@ def get_session_stats(self) -> Dict[str, Any]: tool_calls = 0 try: - with open(self.log_file, 'r', encoding='utf-8') as f: + with open(self.log_file, encoding="utf-8") as f: for line in f: if line.strip(): record = json.loads(line) total += 1 - if record.get('was_corrected'): + if record.get("was_corrected"): corrections += 1 - if record.get('error'): + if record.get("error"): errors += 1 - tool_calls += len(record.get('tool_calls', [])) + tool_calls += len(record.get("tool_calls", [])) except Exception: pass return { - 'total_interactions': total, - 'corrections': corrections, - 'errors': errors, - 'tool_calls': tool_calls, - 'correction_rate': corrections / total if total > 0 else 0, + "total_interactions": total, + "corrections": corrections, + "errors": errors, + "tool_calls": tool_calls, + "correction_rate": corrections / total if total > 0 else 0, } - def load_session_interactions(self) -> List[InteractionRecord]: + def load_session_interactions(self) -> list[InteractionRecord]: """Load all interactions from current session""" if not self.log_file.exists(): return [] interactions = [] try: - with open(self.log_file, 'r', encoding='utf-8') as f: + with open(self.log_file, encoding="utf-8") as f: for line in f: if line.strip(): record = InteractionRecord.from_dict(json.loads(line)) diff --git a/gently/harness/session/manager.py b/gently/harness/session/manager.py index 9985b324..b8d12824 100644 --- a/gently/harness/session/manager.py +++ b/gently/harness/session/manager.py @@ -8,7 +8,6 @@ import json import logging import uuid -from typing import Dict, List, Optional logger = 
logging.getLogger(__name__) @@ -24,7 +23,7 @@ class SessionManager: def __init__(self, store, storage_path): self.store = store self.storage_path = storage_path - self._session_id: Optional[str] = None + self._session_id: str | None = None @property def session_id(self) -> str: @@ -74,41 +73,45 @@ def _resume_session(self, session_id: str, experiment): conversation_history = [] if snapshot: - raw_history = snapshot.get('conversation_history', []) + raw_history = snapshot.get("conversation_history", []) conversation_history = self.sanitize_loaded_messages(raw_history) - experiment_data = snapshot.get('experiment_data', {}) - experiment.active_plan_item_id = experiment_data.get('active_plan_item_id') - embryo_states = experiment_data.get('embryos', {}) + experiment_data = snapshot.get("experiment_data", {}) + experiment.active_plan_item_id = experiment_data.get("active_plan_item_id") + embryo_states = experiment_data.get("embryos", {}) for embryo_id, embryo_data in embryo_states.items(): - pos = embryo_data.get('stage_position', {}) + pos = embryo_data.get("stage_position", {}) experiment.add_embryo( embryo_id=embryo_id, position=pos, - calibration=embryo_data.get('calibration', {}), - user_label=embryo_data.get('user_label'), - uid=embryo_data.get('uid'), + calibration=embryo_data.get("calibration", {}), + user_label=embryo_data.get("user_label"), + uid=embryo_data.get("uid"), ) embryo = experiment.embryos[embryo_id] - embryo.nickname = embryo_data.get('nickname') - embryo.interval_seconds = embryo_data.get('interval_seconds') - embryo.num_slices = embryo_data.get('num_slices', 50) - embryo.exposure_ms = embryo_data.get('exposure_ms', 10.0) - embryo.priority = embryo_data.get('priority', 'normal') - embryo.timepoints_acquired = embryo_data.get('timepoints_acquired', 0) - embryo.should_skip = embryo_data.get('should_skip', False) - embryo.skip_reason = embryo_data.get('skip_reason') - - # Also load embryos from store's embryo table + embryo.nickname = 
embryo_data.get("nickname") + embryo.interval_seconds = embryo_data.get("interval_seconds") + embryo.num_slices = embryo_data.get("num_slices", 50) + embryo.exposure_ms = embryo_data.get("exposure_ms", 10.0) + embryo.priority = embryo_data.get("priority", "normal") + embryo.timepoints_acquired = embryo_data.get("timepoints_acquired", 0) + embryo.should_skip = embryo_data.get("should_skip", False) + embryo.skip_reason = embryo_data.get("skip_reason") + + # Also load embryos from store's embryo table. FileStore returns + # position_coarse / position_fine (with legacy position_x / position_y + # backfilled into coarse on read), so both calibration stages survive + # the resume. store_embryos = self.store.list_embryos(session_id) for e in store_embryos: - eid = e['embryo_id'] + eid = e["embryo_id"] if eid not in experiment.embryos: experiment.add_embryo( embryo_id=eid, - position={'x': e.get('position_x'), 'y': e.get('position_y')}, - calibration=json.loads(e['calibration']) if e.get('calibration') else {}, + position=e.get("position_coarse") or {}, + position_fine=e.get("position_fine") or {}, + calibration=json.loads(e["calibration"]) if e.get("calibration") else {}, ) self.store.touch_session(session_id) @@ -137,11 +140,14 @@ def save_session(self, experiment, conversation_history, system_prompt) -> bool: if not self._session_id: return False try: - self.store.save_session_snapshot(self._session_id, { - 'conversation_history': self.serialize_messages(conversation_history), - 'experiment_data': experiment.to_dict(), - 'system_prompt': system_prompt, - }) + self.store.save_session_snapshot( + self._session_id, + { + "conversation_history": self.serialize_messages(conversation_history), + "experiment_data": experiment.to_dict(), + "system_prompt": system_prompt, + }, + ) self._sync_embryos_to_db(experiment) self.store.touch_session(self._session_id) return True @@ -154,11 +160,14 @@ def auto_save(self, experiment, conversation_history, system_prompt): if not 
self._session_id: return try: - self.store.save_session_snapshot(self._session_id, { - 'conversation_history': self.serialize_messages(conversation_history), - 'experiment_data': experiment.to_dict(), - 'system_prompt': system_prompt, - }) + self.store.save_session_snapshot( + self._session_id, + { + "conversation_history": self.serialize_messages(conversation_history), + "experiment_data": experiment.to_dict(), + "system_prompt": system_prompt, + }, + ) self._sync_embryos_to_db(experiment) self.store.touch_session(self._session_id) except Exception: @@ -169,15 +178,16 @@ def _sync_embryos_to_db(self, experiment): for embryo_id, embryo in experiment.embryos.items(): pos = embryo.stage_position or {} self.store.register_embryo( - self._session_id, embryo_id, - embryo_uid=getattr(embryo, 'uid', None), - nickname=getattr(embryo, 'user_label', None), - position_x=pos.get('x'), - position_y=pos.get('y'), + self._session_id, + embryo_id, + embryo_uid=getattr(embryo, "uid", None), + nickname=getattr(embryo, "user_label", None), + position_x=pos.get("x"), + position_y=pos.get("y"), calibration=embryo.calibration, ) - def list_sessions(self) -> List[Dict]: + def list_sessions(self) -> list[dict]: """ List available sessions from FileStore. @@ -188,8 +198,9 @@ def list_sessions(self) -> List[Dict]: """ return self.store.list_sessions() - def resume_session(self, session_id: str, experiment, conversation_mgr, - prompt_mgr_update_fn) -> bool: + def resume_session( + self, session_id: str, experiment, conversation_mgr, prompt_mgr_update_fn + ) -> bool: """ Resume a session (public interface for CLI). @@ -230,7 +241,7 @@ def resume_session(self, session_id: str, experiment, conversation_mgr, # ===== Message Serialization ===== @staticmethod - def sanitize_loaded_messages(messages: List[Dict]) -> List[Dict]: + def sanitize_loaded_messages(messages: list[dict]) -> list[dict]: """Fix conversation history loaded from JSON snapshots. 
Old snapshots may contain content blocks that were serialized @@ -240,7 +251,7 @@ def sanitize_loaded_messages(messages: List[Dict]) -> List[Dict]: """ clean = [] for msg in messages: - content = msg.get('content') + content = msg.get("content") if content is None: continue if isinstance(content, str): @@ -252,16 +263,16 @@ def sanitize_loaded_messages(messages: List[Dict]) -> List[Dict]: if isinstance(block, dict): valid_blocks.append(block) elif isinstance(block, str): - if block.startswith(('TextBlock(', 'ToolUseBlock(')): + if block.startswith(("TextBlock(", "ToolUseBlock(")): continue valid_blocks.append(block) if valid_blocks: - clean.append({**msg, 'content': valid_blocks}) + clean.append({**msg, "content": valid_blocks}) continue return clean @staticmethod - def serialize_messages(messages: List[Dict]) -> List[Dict]: + def serialize_messages(messages: list[dict]) -> list[dict]: """Convert conversation history to JSON-safe plain dicts. Anthropic SDK returns content blocks as objects (TextBlock, @@ -269,30 +280,31 @@ def serialize_messages(messages: List[Dict]) -> List[Dict]: repr strings. This converts everything to plain dicts so the history round-trips cleanly through JSON. 
""" + def _block_to_dict(block): if isinstance(block, dict): return block if isinstance(block, str): return block - if hasattr(block, 'model_dump'): + if hasattr(block, "model_dump"): return block.model_dump() - if hasattr(block, 'to_dict'): + if hasattr(block, "to_dict"): return block.to_dict() - if hasattr(block, 'type'): - d = {'type': block.type} - if block.type == 'text' and hasattr(block, 'text'): - d['text'] = block.text - elif block.type == 'tool_use': - d['id'] = getattr(block, 'id', '') - d['name'] = getattr(block, 'name', '') - d['input'] = getattr(block, 'input', {}) + if hasattr(block, "type"): + d = {"type": block.type} + if block.type == "text" and hasattr(block, "text"): + d["text"] = block.text + elif block.type == "tool_use": + d["id"] = getattr(block, "id", "") + d["name"] = getattr(block, "name", "") + d["input"] = getattr(block, "input", {}) return d return str(block) serialized = [] for msg in messages: - content = msg.get('content') + content = msg.get("content") if isinstance(content, list): content = [_block_to_dict(b) for b in content] - serialized.append({**msg, 'content': content}) + serialized.append({**msg, "content": content}) return serialized diff --git a/gently/harness/session/timeline.py b/gently/harness/session/timeline.py index 5d7ea6f0..b0ae7b92 100644 --- a/gently/harness/session/timeline.py +++ b/gently/harness/session/timeline.py @@ -11,12 +11,12 @@ import json import logging import threading -import uuid from collections import deque -from dataclasses import dataclass, field, asdict +from collections.abc import Callable +from dataclasses import dataclass, field from datetime import datetime, timedelta from pathlib import Path -from typing import Any, Callable, Dict, List, Optional +from typing import Any from gently.core.event_bus import Event, EventType, get_event_bus @@ -55,83 +55,84 @@ class TimelineEvent: severity : str Severity level: info | success | warning | error """ + event_id: str event_type: str event_subtype: 
str timestamp: datetime source: str - session_id: Optional[str] = None # Session this event belongs to - embryo_id: Optional[str] = None - detector_name: Optional[str] = None - timepoint: Optional[int] = None - confidence: Optional[str] = None - data: Dict[str, Any] = field(default_factory=dict) + session_id: str | None = None # Session this event belongs to + embryo_id: str | None = None + detector_name: str | None = None + timepoint: int | None = None + confidence: str | None = None + data: dict[str, Any] = field(default_factory=dict) icon: str = ">" severity: str = "info" - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Serialize to dictionary""" return { - 'event_id': self.event_id, - 'event_type': self.event_type, - 'event_subtype': self.event_subtype, - 'timestamp': self.timestamp.isoformat(), - 'source': self.source, - 'session_id': self.session_id, - 'embryo_id': self.embryo_id, - 'detector_name': self.detector_name, - 'timepoint': self.timepoint, - 'confidence': self.confidence, - 'data': self.data, - 'icon': self.icon, - 'severity': self.severity, + "event_id": self.event_id, + "event_type": self.event_type, + "event_subtype": self.event_subtype, + "timestamp": self.timestamp.isoformat(), + "source": self.source, + "session_id": self.session_id, + "embryo_id": self.embryo_id, + "detector_name": self.detector_name, + "timepoint": self.timepoint, + "confidence": self.confidence, + "data": self.data, + "icon": self.icon, + "severity": self.severity, } @classmethod - def from_dict(cls, d: Dict[str, Any]) -> 'TimelineEvent': + def from_dict(cls, d: dict[str, Any]) -> "TimelineEvent": """Deserialize from dictionary""" return cls( - event_id=d['event_id'], - event_type=d['event_type'], - event_subtype=d['event_subtype'], - timestamp=datetime.fromisoformat(d['timestamp']), - source=d.get('source', 'unknown'), - session_id=d.get('session_id'), - embryo_id=d.get('embryo_id'), - detector_name=d.get('detector_name'), - 
timepoint=d.get('timepoint'), - confidence=d.get('confidence'), - data=d.get('data', {}), - icon=d.get('icon', '>'), - severity=d.get('severity', 'info'), + event_id=d["event_id"], + event_type=d["event_type"], + event_subtype=d["event_subtype"], + timestamp=datetime.fromisoformat(d["timestamp"]), + source=d.get("source", "unknown"), + session_id=d.get("session_id"), + embryo_id=d.get("embryo_id"), + detector_name=d.get("detector_name"), + timepoint=d.get("timepoint"), + confidence=d.get("confidence"), + data=d.get("data", {}), + icon=d.get("icon", ">"), + severity=d.get("severity", "info"), ) @property def short_label(self) -> str: """Short label for timeline display (e.g., 'TL', 'DET')""" - if self.event_type == 'timelapse': - return 'TL' - elif self.event_type == 'detection': - return 'DET' + if self.event_type == "timelapse": + return "TL" + elif self.event_type == "detection": + return "DET" else: - return 'SYS' + return "SYS" @property def description(self) -> str: """Human-readable description of the event""" - if self.event_type == 'timelapse': - if self.event_subtype == 'started': - embryos = self.data.get('embryo_ids', []) + if self.event_type == "timelapse": + if self.event_subtype == "started": + embryos = self.data.get("embryo_ids", []) return f"Started timelapse with {len(embryos)} embryo(s)" - elif self.event_subtype == 'volume_acquired': + elif self.event_subtype == "volume_acquired": return f"{self.embryo_id} @ t={self.timepoint}" - elif self.event_subtype == 'completed': - total = self.data.get('total_timepoints', '?') + elif self.event_subtype == "completed": + total = self.data.get("total_timepoints", "?") return f"Completed ({total} timepoints)" - elif self.event_subtype == 'failed': + elif self.event_subtype == "failed": return f"Failed: {self.data.get('error', 'unknown error')}" - elif self.event_type == 'detection': - detected = self.data.get('detected', False) + elif self.event_type == "detection": + detected = self.data.get("detected", 
False) status = "Detected" if detected else "Not detected" conf = f" ({self.confidence})" if self.confidence else "" return f"{self.detector_name} on {self.embryo_id} - {status}{conf}" @@ -141,91 +142,91 @@ def description(self) -> str: # Mapping from EventBus EventType to TimelineEvent properties EVENT_MAPPING = { EventType.ACQUISITION_STARTED: { - 'event_type': 'timelapse', - 'event_subtype': 'started', - 'icon': '>', - 'severity': 'info', + "event_type": "timelapse", + "event_subtype": "started", + "icon": ">", + "severity": "info", }, EventType.VOLUME_ACQUIRED: { - 'event_type': 'timelapse', - 'event_subtype': 'volume_acquired', - 'icon': '+', - 'severity': 'success', + "event_type": "timelapse", + "event_subtype": "volume_acquired", + "icon": "+", + "severity": "success", }, EventType.ACQUISITION_COMPLETED: { - 'event_type': 'timelapse', - 'event_subtype': 'completed', - 'icon': '+', - 'severity': 'success', + "event_type": "timelapse", + "event_subtype": "completed", + "icon": "+", + "severity": "success", }, EventType.ACQUISITION_STOPPED: { - 'event_type': 'timelapse', - 'event_subtype': 'stopped', - 'icon': '-', - 'severity': 'info', + "event_type": "timelapse", + "event_subtype": "stopped", + "icon": "-", + "severity": "info", }, EventType.ACQUISITION_FAILED: { - 'event_type': 'timelapse', - 'event_subtype': 'failed', - 'icon': 'x', - 'severity': 'error', + "event_type": "timelapse", + "event_subtype": "failed", + "icon": "x", + "severity": "error", }, EventType.DETECTOR_EVALUATED: { - 'event_type': 'detection', - 'event_subtype': 'evaluated', - 'icon': '?', - 'severity': 'info', + "event_type": "detection", + "event_subtype": "evaluated", + "icon": "?", + "severity": "info", }, EventType.DETECTION_TRIGGERED: { - 'event_type': 'detection', - 'event_subtype': 'triggered', - 'icon': '!', - 'severity': 'success', + "event_type": "detection", + "event_subtype": "triggered", + "icon": "!", + "severity": "success", }, EventType.HATCHING_DETECTED: { - 
'event_type': 'detection', - 'event_subtype': 'hatching', - 'icon': '+', - 'severity': 'success', + "event_type": "detection", + "event_subtype": "hatching", + "icon": "+", + "severity": "success", }, # Strategy / experiment view persistence — these were already emitted on # the EventBus but weren't being captured to timeline.jsonl, so the # swimlane view had no event history to replay. EventType.EMBRYO_CADENCE_CHANGED: { - 'event_type': 'timelapse', - 'event_subtype': 'cadence_changed', - 'icon': '~', - 'severity': 'info', + "event_type": "timelapse", + "event_subtype": "cadence_changed", + "icon": "~", + "severity": "info", }, EventType.POWER_RAMP_STEP: { - 'event_type': 'timelapse', - 'event_subtype': 'power_changed', - 'icon': '*', - 'severity': 'info', + "event_type": "timelapse", + "event_subtype": "power_changed", + "icon": "*", + "severity": "info", }, EventType.TRIGGER_FIRED: { - 'event_type': 'timelapse', - 'event_subtype': 'trigger_fired', - 'icon': '<>', - 'severity': 'info', + "event_type": "timelapse", + "event_subtype": "trigger_fired", + "icon": "<>", + "severity": "info", }, EventType.BURST_QUEUED: { - 'event_type': 'timelapse', - 'event_subtype': 'burst_queued', - 'icon': 'q', - 'severity': 'info', + "event_type": "timelapse", + "event_subtype": "burst_queued", + "icon": "q", + "severity": "info", }, EventType.BURST_START: { - 'event_type': 'timelapse', - 'event_subtype': 'burst_started', - 'icon': '^', - 'severity': 'info', + "event_type": "timelapse", + "event_subtype": "burst_started", + "icon": "^", + "severity": "info", }, EventType.BURST_COMPLETE: { - 'event_type': 'timelapse', - 'event_subtype': 'burst_completed', - 'icon': 'v', - 'severity': 'success', + "event_type": "timelapse", + "event_subtype": "burst_completed", + "icon": "v", + "severity": "success", }, } @@ -243,9 +244,9 @@ class TimelineManager: def __init__( self, - storage_path: Optional[Path] = None, + storage_path: Path | None = None, max_events: int = 1000, - session_id: 
Optional[str] = None, + session_id: str | None = None, ): """ Parameters @@ -262,7 +263,7 @@ def __init__( self._session_id = session_id self._events: deque[TimelineEvent] = deque(maxlen=max_events) self._lock = threading.RLock() - self._unsubscribers: List[Callable] = [] + self._unsubscribers: list[Callable] = [] self._started = False # Load existing events from storage @@ -274,7 +275,7 @@ def set_session_id(self, session_id: str) -> None: self._session_id = session_id @property - def storage_file(self) -> Optional[Path]: + def storage_file(self) -> Path | None: """Path to the timeline JSONL file""" if self._storage_path: return self._storage_path / "timeline.jsonl" @@ -317,18 +318,18 @@ def _on_event(self, event: Event) -> None: timeline_event = TimelineEvent( event_id=event.event_id, - event_type=mapping['event_type'], - event_subtype=mapping['event_subtype'], + event_type=mapping["event_type"], + event_subtype=mapping["event_subtype"], timestamp=event.timestamp, source=event.source, session_id=self._session_id, # Tag with current session - embryo_id=data.get('embryo_id'), - detector_name=data.get('detector_name'), - timepoint=data.get('timepoint'), - confidence=data.get('confidence'), + embryo_id=data.get("embryo_id"), + detector_name=data.get("detector_name"), + timepoint=data.get("timepoint"), + confidence=data.get("confidence"), data=data, - icon=mapping['icon'], - severity=mapping['severity'], + icon=mapping["icon"], + severity=mapping["severity"], ) self.add_event(timeline_event) @@ -352,13 +353,13 @@ def add_event(self, event: TimelineEvent) -> None: def get_events( self, - event_type: Optional[str] = None, - embryo_id: Optional[str] = None, - since: Optional[datetime] = None, - until: Optional[datetime] = None, - session_id: Optional[str] = "current", + event_type: str | None = None, + embryo_id: str | None = None, + since: datetime | None = None, + until: datetime | None = None, + session_id: str | None = "current", limit: int = 50, - ) -> 
List[TimelineEvent]: + ) -> list[TimelineEvent]: """ Get filtered events from timeline @@ -405,7 +406,7 @@ def get_events( # Return limited, oldest first (chronological) return events[-limit:] if len(events) > limit else events - def get_time_range(self) -> tuple[Optional[datetime], Optional[datetime]]: + def get_time_range(self) -> tuple[datetime | None, datetime | None]: """ Get the time range of events in the timeline @@ -421,7 +422,7 @@ def get_time_range(self) -> tuple[Optional[datetime], Optional[datetime]]: return events[0].timestamp, events[-1].timestamp - def clear_events(self, before: Optional[datetime] = None) -> int: + def clear_events(self, before: datetime | None = None) -> int: """ Clear events from timeline @@ -443,7 +444,7 @@ def clear_events(self, before: Optional[datetime] = None) -> int: old_count = len(self._events) self._events = deque( (e for e in self._events if e.timestamp >= before), - maxlen=self._max_events + maxlen=self._max_events, ) count = old_count - len(self._events) @@ -460,11 +461,11 @@ def _load_from_file(self) -> None: return try: - with open(self.storage_file, 'r', encoding='utf-8') as f: + with open(self.storage_file, encoding="utf-8") as f: for line in f: line = line.strip() # Only parse lines that look like JSON objects - if line and line.startswith('{'): + if line and line.startswith("{"): try: data = json.loads(line) event = TimelineEvent.from_dict(data) @@ -484,8 +485,8 @@ def _persist_event(self, event: TimelineEvent) -> None: # Ensure directory exists self._storage_path.mkdir(parents=True, exist_ok=True) - with open(self.storage_file, 'a', encoding='utf-8') as f: - f.write(json.dumps(event.to_dict()) + '\n') + with open(self.storage_file, "a", encoding="utf-8") as f: + f.write(json.dumps(event.to_dict()) + "\n") except Exception as e: logger.error(f"Error persisting timeline event: {e}") @@ -500,9 +501,9 @@ def _rewrite_storage(self) -> None: with self._lock: events = list(self._events) - with open(self.storage_file, 
'w', encoding='utf-8') as f: + with open(self.storage_file, "w", encoding="utf-8") as f: for event in events: - f.write(json.dumps(event.to_dict()) + '\n') + f.write(json.dumps(event.to_dict()) + "\n") except Exception as e: logger.error(f"Error rewriting timeline storage: {e}") @@ -512,7 +513,7 @@ def __len__(self) -> int: return len(self._events) -def parse_time_delta(s: str) -> Optional[timedelta]: +def parse_time_delta(s: str) -> timedelta | None: """ Parse a time delta string like "1h", "30m", "2d" @@ -531,13 +532,13 @@ def parse_time_delta(s: str) -> Optional[timedelta]: return None try: - if s.endswith('m'): + if s.endswith("m"): return timedelta(minutes=int(s[:-1])) - elif s.endswith('h'): + elif s.endswith("h"): return timedelta(hours=int(s[:-1])) - elif s.endswith('d'): + elif s.endswith("d"): return timedelta(days=int(s[:-1])) - elif s.endswith('w'): + elif s.endswith("w"): return timedelta(weeks=int(s[:-1])) else: # Try parsing as minutes diff --git a/gently/harness/state.py b/gently/harness/state.py index a6113480..34b7d8f6 100644 --- a/gently/harness/state.py +++ b/gently/harness/state.py @@ -25,11 +25,13 @@ fields are now on ``EmbryoState`` directly. """ +import logging import re +from collections.abc import Callable from dataclasses import dataclass, field from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple -from pathlib import Path +from typing import Any + import numpy as np # Re-export CalibrationPrior from its hardware-specific home for backward compat. @@ -37,6 +39,8 @@ # modules will define their own calibration models. 
from gently.hardware.dispim.calibration import CalibrationPrior +logger = logging.getLogger(__name__) + @dataclass class FocusDataPoint: @@ -56,13 +60,14 @@ class FocusDataPoint: - z: primary focus axis (µm) — piezo for diSPIM, Z-motor for 2P/confocal - secondary_axis: optional second axis — galvo for diSPIM, unused (0.0) for single-axis systems """ - z: float # Primary focus position (µm) + + z: float # Primary focus position (µm) secondary_axis: float # Secondary axis position (galvo deg for diSPIM, 0.0 otherwise) - score: float # Focus quality score (algorithm-dependent) - r_squared: float # Gaussian fit quality (0-1), higher = more reliable - timestamp: datetime # When this measurement was made - method: str # 'calibration', 'fine_focus', 'manual' - algorithm: str = 'fft_bandpass' # Focus algorithm used + score: float # Focus quality score (algorithm-dependent) + r_squared: float # Gaussian fit quality (0-1), higher = more reliable + timestamp: datetime # When this measurement was made + method: str # 'calibration', 'fine_focus', 'manual' + algorithm: str = "fft_bandpass" # Focus algorithm used # Backward-compatible properties for code that uses the old field names @property @@ -73,38 +78,39 @@ def piezo(self) -> float: def galvo(self) -> float: return self.secondary_axis - def to_dict(self) -> Dict: + def to_dict(self) -> dict: """Serialize for JSON storage""" return { - 'z': self.z, - 'secondary_axis': self.secondary_axis, - 'score': self.score, - 'r_squared': self.r_squared, - 'timestamp': self.timestamp.isoformat(), - 'method': self.method, - 'algorithm': self.algorithm, + "z": self.z, + "secondary_axis": self.secondary_axis, + "score": self.score, + "r_squared": self.r_squared, + "timestamp": self.timestamp.isoformat(), + "method": self.method, + "algorithm": self.algorithm, # Backward-compatible keys for existing serialized data - 'galvo': self.secondary_axis, - 'piezo': self.z, + "galvo": self.secondary_axis, + "piezo": self.z, } @classmethod - def 
from_dict(cls, data: Dict) -> 'FocusDataPoint': + def from_dict(cls, data: dict) -> "FocusDataPoint": """Deserialize from JSON. Handles both old (galvo/piezo) and new (z/secondary_axis) keys.""" return cls( - z=data.get('z', data.get('piezo', 0.0)), - secondary_axis=data.get('secondary_axis', data.get('galvo', 0.0)), - score=data['score'], - r_squared=data['r_squared'], - timestamp=datetime.fromisoformat(data['timestamp']), - method=data['method'], - algorithm=data.get('algorithm', 'fft_bandpass'), + z=data.get("z", data.get("piezo", 0.0)), + secondary_axis=data.get("secondary_axis", data.get("galvo", 0.0)), + score=data["score"], + r_squared=data["r_squared"], + timestamp=datetime.fromisoformat(data["timestamp"]), + method=data["method"], + algorithm=data.get("algorithm", "fft_bandpass"), ) @dataclass class ImageRecord: """Record of a single acquired image/volume""" + embryo_id: str timepoint: int timestamp: datetime @@ -112,8 +118,8 @@ class ImageRecord: max_projection_b64: str # Base64-encoded JPEG for Claude Vision size_kb: float # UID-based data references (new data layer) - volume_uid: Optional[str] = None # UID for volume in DataStore - projection_uid: Optional[str] = None # UID for max projection in DataStore + volume_uid: str | None = None # UID for volume in DataStore + projection_uid: str | None = None # UID for max projection in DataStore @dataclass @@ -122,22 +128,27 @@ class EmbryoState: # Identity id: str # "embryo_1" - uid: Optional[str] = None # Global unique identifier for cross-session tracking - nickname: Optional[str] = None # Agent-assigned: "the fast one" - user_label: Optional[str] = None # User-provided: "control_1" + uid: str | None = None # Global unique identifier for cross-session tracking + nickname: str | None = None # Agent-assigned: "the fast one" + user_label: str | None = None # User-provided: "control_1" # Role key into gently.harness.roles.REGISTRY. Drives cadence policy, # detector selection, photodose budget, UI presentation. 
Default "test" # is the safe choice — accidental Calibration→Test only over-protects; # accidental Test→Calibration would burn extra dose on the precious sample. role: str = "test" - # Position - stage_position: Dict[str, float] = field(default_factory=dict) # {'x': 1234.5, 'y': 5678.9} - calibration: Dict = field(default_factory=dict) # Galvo/piezo parameters + # Position — two-stage: coarse (bottom-camera detection or manual map + # placement, always present once an embryo exists) and fine (populated + # later by SPIM-objective alignment). Resolved value is exposed by the + # `stage_position` property so downstream motion/perception can stay + # agnostic about which stage we're in. + position_coarse: dict[str, float] = field(default_factory=dict) # {'x': ..., 'y': ...} + position_fine: dict[str, float] = field(default_factory=dict) # empty until SPIM head alignment + calibration: dict = field(default_factory=dict) # Galvo/piezo parameters detection_confidence: float = 0.0 # SAM/detection confidence score (0-1) # Acquisition Parameters (current) - interval_seconds: Optional[float] = None # Per-embryo interval; None = use timelapse default + interval_seconds: float | None = None # Per-embryo interval; None = use timelapse default num_slices: int = 50 exposure_ms: float = 10.0 priority: str = "normal" # high/normal/low @@ -145,13 +156,13 @@ class EmbryoState: # Per-embryo 488 laser power %. None = use device-layer default (no # change at acquire time). Float values are hard-limited at the device # layer by DiSPIMLightSource.POWER_LIMITS_PCT[488] (default 2-6%). - laser_power_488_pct: Optional[float] = None + laser_power_488_pct: float | None = None # Status - last_imaged: Optional[datetime] = None + last_imaged: datetime | None = None timepoints_acquired: int = 0 should_skip: bool = False - skip_reason: Optional[str] = None + skip_reason: str | None = None # Timelapse runtime state (consolidated from former EmbryoAcquisitionState). 
# Populated/used by TimelapseOrchestrator while this embryo is part of an @@ -161,12 +172,12 @@ class EmbryoState: # is a Phase 9 concern. stop_condition: Any = None is_complete: bool = False - completion_reason: Optional[str] = None + completion_reason: str | None = None error_count: int = 0 - last_error: Optional[str] = None - detection_triggered_at: Optional[int] = None - detection_type: Optional[str] = None - no_object_since_timepoint: Optional[int] = None + last_error: str | None = None + detection_triggered_at: int | None = None + detection_type: str | None = None + no_object_since_timepoint: int | None = None # Count of consecutive "no_object" detections. Reset to 0 whenever # the embryo is detected again. When this crosses the role's # ``no_object_consecutive_terminal`` threshold, the orchestrator @@ -184,58 +195,65 @@ class EmbryoState: # - paused: skip in the due loop (over-budget, manually paused, or # idle during another embryo's burst) cadence_phase: str = "normal" - next_due_at: Optional[datetime] = None + next_due_at: datetime | None = None # Light exposure tracking (for phototoxicity monitoring) exposure_count: int = 0 # Number of imaging events (snaps + volumes) total_exposure_ms: float = 0.0 # Cumulative laser-on time in milliseconds # Analysis Results (cached) - hatching_status: Dict = field(default_factory=dict) + hatching_status: dict = field(default_factory=dict) # {hatched: bool, confidence: str, timepoint: int} - morphology_history: List[Dict] = field(default_factory=list) + morphology_history: list[dict] = field(default_factory=list) # [{timepoint, size, shape, activity_score}] - fluorescence_history: List[Dict] = field(default_factory=list) + fluorescence_history: list[dict] = field(default_factory=list) # [{timepoint, mean_intensity, photobleaching_estimate}] - custom_classifications: Dict = field(default_factory=dict) + custom_classifications: dict = field(default_factory=dict) # User-defined: {"first_cleavage": {detected: bool, 
timepoint: 42}} # Verification round tracking (for consecutive confirmation) pending_verification: bool = False # True when detection fired, awaiting verification consecutive_detection_count: int = 0 # Must reach 5 consecutive verified detections to stop - last_detection_round: Optional[int] = None # Round when detection was last verified + last_detection_round: int | None = None # Round when detection was last verified # Detection results from detector system - detection_results: Dict[str, List[Dict]] = field(default_factory=dict) + detection_results: dict[str, list[dict]] = field(default_factory=dict) # detector_name -> list of detection results # e.g., {"comma_stage": [{"timepoint": 120, "detected": False, "confidence": "HIGH"}, ...]} # CV Subagent analysis results (populated from CV_RESULT_READY events) - cv_analyses: Dict[str, List[Dict]] = field(default_factory=dict) + cv_analyses: dict[str, list[dict]] = field(default_factory=dict) # result_type -> list of results by timepoint # e.g., {"nuclei_count": [{"timepoint": 5, "num_nuclei": 66, ...}]} # Quick-access fields for latest CV results (for /embryos display) - latest_nuclei_count: Optional[int] = None - latest_developmental_stage: Optional[str] = None - latest_elongation_ratio: Optional[float] = None + latest_nuclei_count: int | None = None + latest_developmental_stage: str | None = None + latest_elongation_ratio: float | None = None # Images (recent for context) - recent_images: List[ImageRecord] = field(default_factory=list) + recent_images: list[ImageRecord] = field(default_factory=list) # Keep last 10 for temporal context in Claude Vision calls # Focus history - accumulated piezo-galvo measurements over time - focus_history: List[FocusDataPoint] = field(default_factory=list) + focus_history: list[FocusDataPoint] = field(default_factory=list) # Each focus operation adds a datapoint, building a focus map for this embryo - def add_focus_datapoint(self, z: float = None, secondary_axis: float = 0.0, - score: 
float = 0.0, r_squared: float = 0.0, - method: str = 'manual', algorithm: str = 'fft_bandpass', - # Backward-compatible kwargs - galvo: float = None, piezo: float = None): + def add_focus_datapoint( + self, + z: float | None = None, + secondary_axis: float = 0.0, + score: float = 0.0, + r_squared: float = 0.0, + method: str = "manual", + algorithm: str = "fft_bandpass", + # Backward-compatible kwargs + galvo: float | None = None, + piezo: float | None = None, + ): """ Record a focus measurement for this embryo. @@ -268,19 +286,24 @@ def add_focus_datapoint(self, z: float = None, secondary_axis: float = 0.0, if galvo is not None: secondary_axis = galvo - self.focus_history.append(FocusDataPoint( - z=z, - secondary_axis=secondary_axis, - score=score, - r_squared=r_squared, - timestamp=datetime.now(), - method=method, - algorithm=algorithm, - )) - - def get_focus_at_secondary(self, secondary_position: float, - max_age_hours: Optional[float] = None, - min_r_squared: float = 0.5) -> Optional[float]: + self.focus_history.append( + FocusDataPoint( + z=z, + secondary_axis=secondary_axis, + score=score, + r_squared=r_squared, + timestamp=datetime.now(), + method=method, + algorithm=algorithm, + ) + ) + + def get_focus_at_secondary( + self, + secondary_position: float, + max_age_hours: float | None = None, + min_r_squared: float = 0.5, + ) -> float | None: """ Get the best Z position for a given secondary axis position. 
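Review note on the hunk above: `add_focus_datapoint` still accepts the legacy `galvo` / `piezo` keyword arguments next to the new `z` / `secondary_axis` names. A minimal sketch of how that resolution could look in isolation follows. The helper name `resolve_focus_axes` is hypothetical. The `piezo` branch is assumed to mirror the `galvo` branch shown in the hunk, and raising when no Z value is supplied is an illustrative choice, not confirmed behavior.

```python
def resolve_focus_axes(z=None, secondary_axis=0.0, galvo=None, piezo=None):
    """Map legacy galvo/piezo kwargs onto the new z/secondary_axis names."""
    # Assumption: piezo is the legacy name for z, mirroring the galvo branch.
    if piezo is not None:
        z = piezo
    # Shown in the hunk: a legacy galvo value overrides secondary_axis.
    if galvo is not None:
        secondary_axis = galvo
    # Assumption: a missing Z is treated as a caller error in this sketch.
    if z is None:
        raise ValueError("a z (or legacy piezo) value is required")
    return z, secondary_axis


new_style = resolve_focus_axes(z=1.0)
old_style = resolve_focus_axes(piezo=5.0, galvo=0.3)
```

Keeping the legacy names as optional keyword arguments lets old call sites keep working while the annotations (`float | None`) make the "not supplied" case explicit.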
@@ -326,35 +349,38 @@ def get_focus_at_secondary(self, secondary_position: float, axis_distance = abs(fp.secondary_axis - secondary_position) age_hours = (now - fp.timestamp).total_seconds() / 3600 - candidates.append({ - 'z': fp.z, - 'axis_distance': axis_distance, - 'age_hours': age_hours, - 'r_squared': fp.r_squared, - }) + candidates.append( + { + "z": fp.z, + "axis_distance": axis_distance, + "age_hours": age_hours, + "r_squared": fp.r_squared, + } + ) if not candidates: return None # If we have exact matches, use the most recent - exact_matches = [c for c in candidates if c['axis_distance'] < 0.01] + exact_matches = [c for c in candidates if c["axis_distance"] < 0.01] if exact_matches: # Sort by recency, return most recent - exact_matches.sort(key=lambda x: x['age_hours']) - return exact_matches[0]['z'] + exact_matches.sort(key=lambda x: x["age_hours"]) + return exact_matches[0]["z"] # Otherwise, interpolate from nearby measurements # Sort by axis distance - candidates.sort(key=lambda x: x['axis_distance']) - return candidates[0]['z'] # Return closest match + candidates.sort(key=lambda x: x["axis_distance"]) + return candidates[0]["z"] # Return closest match # Backward-compatible alias - def get_focus_at_galvo(self, galvo_position: float, **kwargs) -> Optional[float]: + def get_focus_at_galvo(self, galvo_position: float, **kwargs) -> float | None: """Backward-compatible alias for get_focus_at_secondary.""" return self.get_focus_at_secondary(galvo_position, **kwargs) - def get_z_axis_fit(self, max_age_hours: Optional[float] = None, - min_r_squared: float = 0.5) -> Optional[Tuple[float, float]]: + def get_z_axis_fit( + self, max_age_hours: float | None = None, min_r_squared: float = 0.5 + ) -> tuple[float, float] | None: """ Fit a linear relationship between Z and secondary axis from accumulated data. 
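Review note on the hunk above: the selection logic in `get_focus_at_secondary` prefers the most recent measurement among near-exact axis matches (distance below 0.01) and otherwise falls back to the closest measurement. Despite the "interpolate" comment, the fallback shown returns the nearest point as-is. The sketch below restates that logic on a stand-in type. `FocusPoint` and `pick_focus_z` are hypothetical names, and the age and R² filtering is inferred from the parameter names because that part of the method lies outside the hunk.

```python
from dataclasses import dataclass
from datetime import datetime, timedelta


@dataclass
class FocusPoint:
    # Hypothetical stand-in for FocusDataPoint, trimmed to the fields used here.
    z: float
    secondary_axis: float
    r_squared: float
    timestamp: datetime


def pick_focus_z(points, secondary_position, now, max_age_hours=None, min_r_squared=0.5):
    """Return the z to use at secondary_position, or None if nothing qualifies."""
    candidates = []
    for fp in points:
        # Assumed filters: drop poor fits and, optionally, stale measurements.
        if fp.r_squared < min_r_squared:
            continue
        age_hours = (now - fp.timestamp).total_seconds() / 3600
        if max_age_hours is not None and age_hours > max_age_hours:
            continue
        axis_distance = abs(fp.secondary_axis - secondary_position)
        candidates.append((axis_distance, age_hours, fp.z))

    if not candidates:
        return None

    # Near-exact matches: take the most recent one.
    exact = [c for c in candidates if c[0] < 0.01]
    if exact:
        return min(exact, key=lambda c: c[1])[2]

    # Otherwise return the closest measurement (no interpolation).
    return min(candidates, key=lambda c: c[0])[2]


now = datetime(2024, 1, 1, 12, 0, 0)
points = [
    FocusPoint(z=10.0, secondary_axis=0.0, r_squared=0.9, timestamp=now - timedelta(hours=2)),
    FocusPoint(z=12.0, secondary_axis=0.005, r_squared=0.9, timestamp=now - timedelta(hours=1)),
    FocusPoint(z=20.0, secondary_axis=1.0, r_squared=0.9, timestamp=now),
    FocusPoint(z=99.0, secondary_axis=0.0, r_squared=0.1, timestamp=now),  # poor fit, ignored
]
z_exact = pick_focus_z(points, 0.0, now)
z_nearest = pick_focus_z(points, 0.8, now)
```

If true interpolation is the intent, `get_z_axis_fit` in the same class looks like the more natural building block, since it fits Z against the secondary axis from the accumulated history.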
@@ -407,13 +433,16 @@ def get_z_axis_fit(self, max_age_hours: Optional[float] = None, return None # Backward-compatible alias - def get_piezo_galvo_fit(self, **kwargs) -> Optional[Tuple[float, float]]: + def get_piezo_galvo_fit(self, **kwargs) -> tuple[float, float] | None: """Backward-compatible alias for get_z_axis_fit.""" return self.get_z_axis_fit(**kwargs) - def get_focus_drift_rate(self, secondary_position: float = 0.0, - galvo_position: float = None, - min_measurements: int = 3) -> Optional[float]: + def get_focus_drift_rate( + self, + secondary_position: float = 0.0, + galvo_position: float | None = None, + min_measurements: int = 3, + ) -> float | None: """ Calculate how fast focus is drifting (µm/hour) at a given secondary axis position. @@ -434,8 +463,11 @@ def get_focus_drift_rate(self, secondary_position: float = 0.0, if galvo_position is not None: secondary_position = galvo_position # Get measurements at similar secondary axis position - relevant = [fp for fp in self.focus_history - if abs(fp.secondary_axis - secondary_position) < 0.1 and fp.r_squared >= 0.5] + relevant = [ + fp + for fp in self.focus_history + if abs(fp.secondary_axis - secondary_position) < 0.1 and fp.r_squared >= 0.5 + ] if len(relevant) < min_measurements: return None @@ -462,9 +494,12 @@ def get_focus_drift_rate(self, secondary_position: float = 0.0, except Exception: return None - def needs_refocus(self, max_age_minutes: float = 60, - secondary_position: float = 0.0, - galvo_position: float = None) -> bool: + def needs_refocus( + self, + max_age_minutes: float = 60, + secondary_position: float = 0.0, + galvo_position: float | None = None, + ) -> bool: """ Determine if this embryo needs focus re-measurement. 
@@ -515,7 +550,8 @@ def get_focus_summary(self) -> str: lines = [ f"Focus history: {n_points} measurements over {span_hours:.1f} hours", - f"Latest: z={last.z:.2f}µm @ secondary={last.secondary_axis:.2f} (R²={last.r_squared:.3f})", + f"Latest: z={last.z:.2f}µm @ secondary={last.secondary_axis:.2f}" + f" (R²={last.r_squared:.3f})", ] if drift is not None: @@ -528,7 +564,7 @@ def get_focus_summary(self) -> str: return "\n".join(lines) - def add_detection_result(self, detector_name: str, result: Dict): + def add_detection_result(self, detector_name: str, result: dict): """ Add detection result from detector system @@ -544,7 +580,7 @@ def add_detection_result(self, detector_name: str, result: Dict): self.detection_results[detector_name].append(result) - def get_latest_detection(self, detector_name: str) -> Optional[Dict]: + def get_latest_detection(self, detector_name: str) -> dict | None: """Get most recent detection result for a detector""" if detector_name not in self.detection_results: return None @@ -573,15 +609,15 @@ def was_detected(self, detector_name: str, require_verified: bool = False) -> bo return False for result in self.detection_results[detector_name]: - if result.get('detected', False): + if result.get("detected", False): if require_verified: - if result.get('verified', False): + if result.get("verified", False): return True else: return True return False - def mark_detection_verified(self, detector_name: str, timepoint: Optional[int] = None) -> bool: + def mark_detection_verified(self, detector_name: str, timepoint: int | None = None) -> bool: """ Mark a detection result as verified by the challenger system. 
@@ -608,19 +644,19 @@ def mark_detection_verified(self, detector_name: str, timepoint: Optional[int] = if timepoint is not None: # Find by timepoint for result in results: - if result.get('timepoint') == timepoint and result.get('detected', False): - result['verified'] = True + if result.get("timepoint") == timepoint and result.get("detected", False): + result["verified"] = True return True else: # Mark the most recent detected result for result in reversed(results): - if result.get('detected', False): - result['verified'] = True + if result.get("detected", False): + result["verified"] = True return True return False - def add_cv_result(self, result_type: str, result: Dict): + def add_cv_result(self, result_type: str, result: dict): """ Add CV analysis result from CV subagent. @@ -635,8 +671,8 @@ def add_cv_result(self, result_type: str, result: Dict): self.cv_analyses[result_type] = [] # Add timestamp if not present - if 'timestamp' not in result: - result['timestamp'] = datetime.now().isoformat() + if "timestamp" not in result: + result["timestamp"] = datetime.now().isoformat() self.cv_analyses[result_type].append(result) @@ -648,11 +684,7 @@ def add_cv_result(self, result_type: str, result: Dict): elif result_type == "elongation" and "elongation_ratio" in result: self.latest_elongation_ratio = result["elongation_ratio"] - def get_cv_result( - self, - result_type: str, - timepoint: Optional[int] = None - ) -> Optional[Dict]: + def get_cv_result(self, result_type: str, timepoint: int | None = None) -> dict | None: """ Get CV analysis result, optionally filtered by timepoint. @@ -683,7 +715,7 @@ def get_cv_result( # Return most recent return results[-1] - def get_cv_summary(self) -> Dict: + def get_cv_summary(self) -> dict: """ Get summary of CV analysis results for display. 
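Review note on the hunks above: `was_detected` and `mark_detection_verified` together implement a two-step confirmation. A positive detection is recorded first and flagged `verified` later. The sketch below exercises that flow on plain lists of result dicts, assuming the dict keys shown in the diff (`timepoint`, `detected`, `verified`). The free functions `mark_verified` and `any_detected` are hypothetical stand-ins for the `EmbryoState` methods.

```python
def mark_verified(results, timepoint=None):
    """Flag one positive detection as verified; return True if one was found."""
    if timepoint is not None:
        # Match a specific timepoint, but only if it was a positive detection.
        for result in results:
            if result.get("timepoint") == timepoint and result.get("detected", False):
                result["verified"] = True
                return True
    else:
        # No timepoint given: flag the most recent positive detection.
        for result in reversed(results):
            if result.get("detected", False):
                result["verified"] = True
                return True
    return False


def any_detected(results, require_verified=False):
    """True if any result is positive (and verified, when required)."""
    for result in results:
        if result.get("detected", False):
            if not require_verified or result.get("verified", False):
                return True
    return False


results = [
    {"timepoint": 1, "detected": False},
    {"timepoint": 2, "detected": True},
    {"timepoint": 3, "detected": True},
]
before = any_detected(results, require_verified=True)
marked_latest = mark_verified(results)
after = any_detected(results, require_verified=True)
```

Note that a negative result is never flagged, even when its timepoint matches, so `mark_verified(results, timepoint=1)` reports `False` here.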
@@ -697,27 +729,27 @@ def get_cv_summary(self) -> Dict: "developmental_stage": self.latest_developmental_stage, "elongation_ratio": self.latest_elongation_ratio, "analyses_count": { - result_type: len(results) - for result_type, results in self.cv_analyses.items() + result_type: len(results) for result_type, results in self.cv_analyses.items() }, } - def update_from_analysis(self, analysis_result: Dict): + def update_from_analysis(self, analysis_result: dict): """Update state with new analysis""" - if 'hatching' in analysis_result: - self.hatching_status = analysis_result['hatching'] + if "hatching" in analysis_result: + self.hatching_status = analysis_result["hatching"] - if 'morphology' in analysis_result: - self.morphology_history.append({ - 'timepoint': self.timepoints_acquired, - **analysis_result['morphology'] - }) + if "morphology" in analysis_result: + self.morphology_history.append( + {"timepoint": self.timepoints_acquired, **analysis_result["morphology"]} + ) - if 'fluorescence' in analysis_result: - self.fluorescence_history.append({ - 'timepoint': self.timepoints_acquired, - **analysis_result['fluorescence'] - }) + if "fluorescence" in analysis_result: + self.fluorescence_history.append( + { + "timepoint": self.timepoints_acquired, + **analysis_result["fluorescence"], + } + ) def to_summary(self) -> str: """Format for Claude system prompt""" @@ -741,7 +773,7 @@ def to_summary(self) -> str: status_parts.append("not yet imaged") # Status - if self.hatching_status.get('hatched'): + if self.hatching_status.get("hatched"): status_parts.append(f"hatched at t{self.hatching_status['timepoint']:04d}") elif self.should_skip: status_parts.append(f"skipped ({self.skip_reason})") @@ -757,7 +789,12 @@ def to_summary(self) -> str: return " | ".join(status_parts) - def record_exposure(self, exposure_ms: float, num_frames: int = 1, timestamp: Optional[datetime] = None): + def record_exposure( + self, + exposure_ms: float, + num_frames: int = 1, + timestamp: datetime | 
None = None, + ): """ Record light exposure for phototoxicity tracking. @@ -789,39 +826,69 @@ def get_exposure_summary(self) -> str: return f"{self.exposure_count} exposures, {time_str} total" - def to_dict(self) -> Dict: + @property + def stage_position(self) -> dict[str, float]: + """Resolved XY position — fine if SPIM-aligned, else coarse. + + Coarse comes from the bottom-camera detection / manual map placement. + Fine comes from the SPIM-objective alignment workflow (not built yet). + Callers that just want "where is this embryo" read this; callers that + care about calibration state read position_coarse / position_fine + directly. + """ + return self.position_fine if self.position_fine else self.position_coarse + + @stage_position.setter + def stage_position(self, value: dict[str, float]) -> None: + """Back-compat setter — writes to coarse. + + Legacy callers that assigned `embryo.stage_position = {...}` were + writing a bottom-camera / manual position; that's the coarse stage. + New code should set position_coarse or position_fine explicitly. 
+ """ + self.position_coarse = value or {} + + @property + def has_fine_position(self) -> bool: + """True once SPIM-objective alignment has refined the coarse position.""" + return bool(self.position_fine) + + def to_dict(self) -> dict: """Serialize for API responses""" return { - 'id': self.id, - 'uid': self.uid, - 'nickname': self.nickname, - 'user_label': self.user_label, - 'role': self.role, - 'stage_position': self.stage_position, - 'calibration': self.calibration, - 'detection_confidence': self.detection_confidence, - 'interval_seconds': self.interval_seconds, - 'num_slices': self.num_slices, - 'exposure_ms': self.exposure_ms, - 'priority': self.priority, - 'acquisition_mode': self.acquisition_mode, - 'laser_power_488_pct': self.laser_power_488_pct, - 'last_imaged': self.last_imaged.isoformat() if self.last_imaged else None, - 'timepoints_acquired': self.timepoints_acquired, - 'should_skip': self.should_skip, - 'skip_reason': self.skip_reason, - 'exposure_count': self.exposure_count, - 'total_exposure_ms': self.total_exposure_ms, - 'hatching_status': self.hatching_status, - 'pending_verification': self.pending_verification, - 'consecutive_detection_count': self.consecutive_detection_count, - 'last_detection_round': self.last_detection_round, - 'recent_analyses': { - 'morphology': self.morphology_history[-5:] if self.morphology_history else [], - 'fluorescence': self.fluorescence_history[-5:] if self.fluorescence_history else [], - 'custom': self.custom_classifications, + "id": self.id, + "uid": self.uid, + "nickname": self.nickname, + "user_label": self.user_label, + "role": self.role, + "stage_position": self.stage_position, + "position_coarse": self.position_coarse, + "position_fine": self.position_fine, + "has_fine_position": self.has_fine_position, + "calibration": self.calibration, + "detection_confidence": self.detection_confidence, + "interval_seconds": self.interval_seconds, + "num_slices": self.num_slices, + "exposure_ms": self.exposure_ms, + 
"priority": self.priority, + "acquisition_mode": self.acquisition_mode, + "laser_power_488_pct": self.laser_power_488_pct, + "last_imaged": self.last_imaged.isoformat() if self.last_imaged else None, + "timepoints_acquired": self.timepoints_acquired, + "should_skip": self.should_skip, + "skip_reason": self.skip_reason, + "exposure_count": self.exposure_count, + "total_exposure_ms": self.total_exposure_ms, + "hatching_status": self.hatching_status, + "pending_verification": self.pending_verification, + "consecutive_detection_count": self.consecutive_detection_count, + "last_detection_round": self.last_detection_round, + "recent_analyses": { + "morphology": self.morphology_history[-5:] if self.morphology_history else [], + "fluorescence": self.fluorescence_history[-5:] if self.fluorescence_history else [], + "custom": self.custom_classifications, }, - 'focus_history': [fp.to_dict() for fp in self.focus_history], + "focus_history": [fp.to_dict() for fp in self.focus_history], } @@ -829,38 +896,71 @@ class ExperimentState: """Global experiment state""" def __init__(self): - self.embryos: Dict[str, EmbryoState] = {} - self.start_time: Optional[datetime] = None + self.embryos: dict[str, EmbryoState] = {} + self.start_time: datetime | None = None self.acquisition_status: str = "idle" # idle/running/paused/completed - self.current_plan_name: Optional[str] = None - self.plan_history: List[Dict] = [] - self.metadata: Dict = {} + self.current_plan_name: str | None = None + self.plan_history: list[dict] = [] + self.metadata: dict = {} # Active plan item — set during plan context resolution at startup. # When set, the agent's system prompt includes the full ImagingSpec # so it knows what it's here to do without being told. 
- self.active_plan_item_id: Optional[str] = None + self.active_plan_item_id: str | None = None # Session-level calibration prior for cross-embryo learning # Updated after each successful calibration, used to initialize subsequent embryos self.calibration_prior: CalibrationPrior = CalibrationPrior() - def add_embryo(self, embryo_id: str, position: Dict = None, - calibration: Dict = None, user_label: Optional[str] = None, - confidence: float = 0.0, uid: Optional[str] = None, - role: str = "test"): + # Observer hook — agent wires this at startup to publish EMBRYOS_UPDATE + # over the event bus. Kept as a plain callback so this module stays + # bus-agnostic. + self.on_embryos_changed: Callable[[], None] | None = None + + def notify_embryos_changed(self) -> None: + """Fire the on_embryos_changed observer if one is wired. + + Call this after any mutation the agent can't intercept through + add_embryo / remove_embryo (e.g. a direct write to + embryo.position_coarse). UI hooks must not raise — failures here are + swallowed so state mutations stay durable. + """ + cb = self.on_embryos_changed + if cb is None: + return + try: + cb() + except Exception: + logger.exception("ExperimentState.on_embryos_changed callback failed") + + def add_embryo( + self, + embryo_id: str, + position: dict | None = None, + calibration: dict | None = None, + user_label: str | None = None, + confidence: float = 0.0, + uid: str | None = None, + role: str = "test", + position_fine: dict | None = None, + ): """Register new embryo. ``role`` must be a key in :data:`gently.harness.roles.REGISTRY` (e.g. ``"test"``, ``"calibration"``, ``"unassigned"``). Unknown roles raise KeyError. + `position` is the coarse XY (bottom-camera detection or manual map + placement). `position_fine` is reserved for the future SPIM-objective + alignment workflow and defaults to empty. + Emits an ``EMBRYO_DETECTED`` event so listeners (e.g. 
the viz server's TimelapseStateTracker, which feeds the device map) learn about marked embryos immediately — not just after the first acquisition. """ from gently.harness.roles import get_role + get_role(role) # raises KeyError if unknown # Auto-start experiment when first embryo is added @@ -871,17 +971,20 @@ def add_embryo(self, embryo_id: str, position: Dict = None, self.embryos[embryo_id] = EmbryoState( id=embryo_id, uid=uid, - stage_position=pos, + position_coarse=position or {}, + position_fine=position_fine or {}, calibration=calibration or {}, user_label=user_label, detection_confidence=confidence, role=role, ) + self.notify_embryos_changed() # Fire the registration event. Late-bound import keeps this module # decoupled from the event bus until first use. try: from gently.core import EventType, get_event_bus + get_event_bus().publish( event_type=EventType.EMBRYO_DETECTED, data={ @@ -903,6 +1006,7 @@ def remove_embryo(self, embryo_id: str) -> bool: """Remove embryo from experiment (e.g., false detection)""" if embryo_id in self.embryos: del self.embryos[embryo_id] + self.notify_embryos_changed() return True return False @@ -910,8 +1014,9 @@ def assign_nickname(self, embryo_id: str, nickname: str): """Agent assigns intuitive name""" if embryo_id in self.embryos: self.embryos[embryo_id].nickname = nickname + self.notify_embryos_changed() - def get_embryo_by_any_name(self, name: str) -> Optional[EmbryoState]: + def get_embryo_by_any_name(self, name: str) -> EmbryoState | None: """Get embryo by ID, nickname, or user label""" # Direct ID match if name in self.embryos: @@ -923,7 +1028,7 @@ def get_embryo_by_any_name(self, name: str) -> Optional[EmbryoState]: return embryo # Try extracting number from name like "embryo 3" -> "embryo_3" - match = re.search(r'(\d+)', name) + match = re.search(r"(\d+)", name) if match: num = int(match.group(1)) # Try simple format first (embryo_3) @@ -951,7 +1056,7 @@ def get_summary(self) -> str: f"Duration: {hours}h {minutes}m", 
f"Embryos: {len(self.embryos)}", "", - "Per-embryo status:" + "Per-embryo status:", ] for embryo in sorted(self.embryos.values(), key=lambda e: e.id): @@ -965,15 +1070,15 @@ def get_summary(self) -> str: return "\n".join(lines) - def to_dict(self) -> Dict: + def to_dict(self) -> dict: """Serialize for API responses""" return { - 'start_time': self.start_time.isoformat() if self.start_time else None, - 'acquisition_status': self.acquisition_status, - 'current_plan_name': self.current_plan_name, - 'active_plan_item_id': self.active_plan_item_id, - 'embryo_count': len(self.embryos), - 'embryos': {eid: e.to_dict() for eid, e in self.embryos.items()}, - 'metadata': self.metadata, - 'calibration_prior': self.calibration_prior.to_dict(), + "start_time": self.start_time.isoformat() if self.start_time else None, + "acquisition_status": self.acquisition_status, + "current_plan_name": self.current_plan_name, + "active_plan_item_id": self.active_plan_item_id, + "embryo_count": len(self.embryos), + "embryos": {eid: e.to_dict() for eid, e in self.embryos.items()}, + "metadata": self.metadata, + "calibration_prior": self.calibration_prior.to_dict(), } diff --git a/gently/harness/tools/helpers.py b/gently/harness/tools/helpers.py index 3d6b03b8..b4c58f39 100644 --- a/gently/harness/tools/helpers.py +++ b/gently/harness/tools/helpers.py @@ -5,17 +5,39 @@ used across multiple tools to reduce code duplication. """ -from typing import Any, Dict, List, Optional, Tuple from datetime import datetime +from typing import Any -def require_agent(context: Dict) -> Tuple[Optional[Any], Optional[str]]: +def ctx_get(context: dict | None, key: str) -> Any: + """ + Look up a key in a (possibly missing) tool execution context + + Parameters + ---------- + context : dict | None + Tool execution context + key : str + Key to look up + + Returns + ------- + Any + The value for ``key``, or ``None`` if ``context`` is ``None`` or + the key is absent. 
+ """ + if context is None: + return None + return context.get(key) + + +def require_agent(context: dict | None) -> tuple[Any, str | None]: """ Extract agent from context or return error message Parameters ---------- - context : dict + context : dict | None Tool execution context Returns @@ -23,13 +45,13 @@ def require_agent(context: Dict) -> Tuple[Optional[Any], Optional[str]]: tuple (agent, None) if found, (None, error_message) if not """ - agent = context.get('agent') + agent = ctx_get(context, "agent") if not agent: return None, "Error: No agent context" return agent, None -def get_embryo_or_error(agent, embryo_id: str) -> Tuple[Optional[Any], Optional[str]]: +def get_embryo_or_error(agent, embryo_id: str) -> tuple[Any, str | None]: """ Get embryo by any name or return error message @@ -51,13 +73,13 @@ def get_embryo_or_error(agent, embryo_id: str) -> Tuple[Optional[Any], Optional[ return embryo, None -def require_microscope(context: Dict) -> Tuple[Optional[Any], Optional[str]]: +def require_microscope(context: dict | None) -> tuple[Any, str | None]: """ Get microscope client from context or return error message Parameters ---------- - context : dict + context : dict | None Tool execution context Returns @@ -65,13 +87,13 @@ def require_microscope(context: Dict) -> Tuple[Optional[Any], Optional[str]]: tuple (client, None) if connected, (None, error_message) if not """ - client = context.get('client') + client = ctx_get(context, "client") if not client: return None, "Not connected to microscope. Use connect_microscope first." 
return client, None -def require_interaction_logger(agent) -> Tuple[Optional[Any], Optional[str]]: +def require_interaction_logger(agent) -> tuple[Any, str | None]: """ Get interaction logger or return error message @@ -85,12 +107,12 @@ def require_interaction_logger(agent) -> Tuple[Optional[Any], Optional[str]]: tuple (logger, None) if available, (None, error_message) if not """ - if not hasattr(agent, 'interaction_logger') or not agent.interaction_logger: + if not hasattr(agent, "interaction_logger") or not agent.interaction_logger: return None, "Interaction logging not enabled." return agent.interaction_logger, None -def require_developmental_tracker(agent) -> Tuple[Optional[Any], Optional[str]]: +def require_developmental_tracker(agent) -> tuple[Any, str | None]: """ Get developmental tracker or return error message @@ -104,12 +126,15 @@ def require_developmental_tracker(agent) -> Tuple[Optional[Any], Optional[str]]: tuple (tracker, None) if available, (None, error_message) if not """ - if not hasattr(agent, 'developmental_tracker') or not agent.developmental_tracker: - return None, "No stage classifications recorded yet. Use classify_embryo_stage first." + if not hasattr(agent, "developmental_tracker") or not agent.developmental_tracker: + return ( + None, + "No stage classifications recorded yet. 
Use classify_embryo_stage first.", + ) return agent.developmental_tracker, None -def require_timelapse_orchestrator(agent) -> Tuple[Optional[Any], Optional[str]]: +def require_timelapse_orchestrator(agent) -> tuple[Any, str | None]: """ Get timelapse orchestrator or return error message @@ -123,12 +148,12 @@ def require_timelapse_orchestrator(agent) -> Tuple[Optional[Any], Optional[str]] tuple (orchestrator, None) if available, (None, error_message) if not """ - if not hasattr(agent, 'timelapse_orchestrator') or agent.timelapse_orchestrator is None: + if not hasattr(agent, "timelapse_orchestrator") or agent.timelapse_orchestrator is None: return None, "Timelapse orchestrator not initialized." return agent.timelapse_orchestrator, None -def require_databroker(agent) -> Tuple[Optional[Any], Optional[str]]: +def require_databroker(agent) -> tuple[Any, str | None]: """ Get databroker connection or return error message @@ -142,7 +167,7 @@ def require_databroker(agent) -> Tuple[Optional[Any], Optional[str]]: tuple (databroker, None) if available, (None, error_message) if not """ - if not hasattr(agent, 'databroker') or agent.databroker is None: + if not hasattr(agent, "databroker") or agent.databroker is None: return None, "No databroker connection. Data persistence not available." return agent.databroker, None @@ -184,13 +209,13 @@ def format_duration(seconds: float) -> str: def build_snapshot_metadata( - stage_position: Tuple[float, float], - image_shape: Tuple[int, ...], + stage_position: tuple[float, float], + image_shape: tuple[int, ...], experiment=None, pixel_size_um: float = 6.5, objective_mag: float = 10.0, - safety_limits: Optional[Dict] = None, -) -> Dict: + safety_limits: dict | None = None, +) -> dict: """Build metadata dict for a bottom camera snapshot. Captures everything needed to reconstruct embryo positions @@ -226,7 +251,7 @@ def build_snapshot_metadata( # gently/hardware/dispim/devices/stage.py::DiSPIMXYStage.__init__. 
safety_limits = {"x": (2000.0, 4000.0), "y": (-1000.0, 1000.0)} - meta: Dict[str, Any] = { + meta: dict[str, Any] = { "stage_x": stage_position[0], "stage_y": stage_position[1], "image_width": w, @@ -242,15 +267,17 @@ def build_snapshot_metadata( } if experiment and experiment.embryos: - embryos: List[Dict] = [] + embryos: list[dict] = [] for eid, emb in experiment.embryos.items(): pos = emb.stage_position or {} - embryos.append({ - "embryo_id": eid, - "stage_x": pos.get("x"), - "stage_y": pos.get("y"), - "nickname": getattr(emb, "nickname", None), - }) + embryos.append( + { + "embryo_id": eid, + "stage_x": pos.get("x"), + "stage_y": pos.get("y"), + "nickname": getattr(emb, "nickname", None), + } + ) meta["embryos"] = embryos return meta diff --git a/gently/harness/tools/registry.py b/gently/harness/tools/registry.py index 106bd7a2..88c37241 100644 --- a/gently/harness/tools/registry.py +++ b/gently/harness/tools/registry.py @@ -13,49 +13,56 @@ import functools import inspect import logging +import time +from collections.abc import Callable from dataclasses import dataclass, field from enum import Enum, auto from typing import ( - Any, Callable, Dict, List, Optional, Type, Union, - get_type_hints, get_origin, get_args + Any, + Union, + get_args, + get_origin, + get_type_hints, ) -import time logger = logging.getLogger(__name__) class ToolCategory(Enum): """Categories for organizing tools""" - ACQUISITION = auto() # Volume/image acquisition - MOVEMENT = auto() # Stage movement, positioning - CALIBRATION = auto() # Calibration procedures - ANALYSIS = auto() # Image/volume analysis - DETECTION = auto() # Detector management - EXPERIMENT = auto() # Experiment state management - EMBRYO = auto() # Embryo-specific operations - HARDWARE = auto() # Direct hardware control - DATA = auto() # Data/Databroker operations - UTILITY = auto() # Utility functions - ML = auto() # Machine learning training - TRANSFER = auto() # Bulk data transfer + + ACQUISITION = auto() # 
Volume/image acquisition + MOVEMENT = auto() # Stage movement, positioning + CALIBRATION = auto() # Calibration procedures + ANALYSIS = auto() # Image/volume analysis + DETECTION = auto() # Detector management + EXPERIMENT = auto() # Experiment state management + EMBRYO = auto() # Embryo-specific operations + HARDWARE = auto() # Direct hardware control + DATA = auto() # Data/Databroker operations + UTILITY = auto() # Utility functions + ML = auto() # Machine learning training + TRANSFER = auto() # Bulk data transfer @dataclass class ToolParameter: """Definition of a tool parameter""" + name: str type: str # JSON schema type description: str required: bool = True default: Any = None - enum: Optional[List[str]] = None + enum: list[str] | None = None @dataclass class ToolExample: """Example of when to use a tool""" + user_query: str - tool_input: Dict = field(default_factory=dict) + tool_input: dict = field(default_factory=dict) @dataclass @@ -69,17 +76,18 @@ class ToolDefinition: - Documentation - Filtering/discovery """ + name: str description: str handler: Callable - parameters: List[ToolParameter] = field(default_factory=list) + parameters: list[ToolParameter] = field(default_factory=list) category: ToolCategory = ToolCategory.UTILITY requires_microscope: bool = False is_async: bool = False - tags: List[str] = field(default_factory=list) - examples: List[ToolExample] = field(default_factory=list) + tags: list[str] = field(default_factory=list) + examples: list[ToolExample] = field(default_factory=list) - def to_claude_schema(self) -> Dict: + def to_claude_schema(self) -> dict: """Generate Claude API tool schema with examples embedded in description""" properties = {} required = [] @@ -114,7 +122,7 @@ def to_claude_schema(self) -> Dict: "type": "object", "properties": properties, "required": required, - } + }, } @@ -156,29 +164,29 @@ def _python_type_to_json_schema(python_type) -> str: return "string" -def _extract_parameters_from_function(func: Callable) -> 
List[ToolParameter]: +def _extract_parameters_from_function(func: Callable) -> list[ToolParameter]: """Extract parameter definitions from function signature and type hints""" sig = inspect.signature(func) - hints = get_type_hints(func) if hasattr(func, '__annotations__') else {} + hints = get_type_hints(func) if hasattr(func, "__annotations__") else {} doc = inspect.getdoc(func) or "" # Parse docstring for parameter descriptions param_docs = {} in_params = False current_param = None - for line in doc.split('\n'): + for line in doc.split("\n"): line = line.strip() - if line.lower().startswith('parameters'): + if line.lower().startswith("parameters"): in_params = True continue if in_params: - if line.startswith('---'): + if line.startswith("---"): continue - if line.lower().startswith('returns'): + if line.lower().startswith("returns"): in_params = False continue - if ' : ' in line: - parts = line.split(' : ') + if " : " in line: + parts = line.split(" : ") current_param = parts[0].strip() param_docs[current_param] = "" elif current_param and line: @@ -187,7 +195,7 @@ def _extract_parameters_from_function(func: Callable) -> List[ToolParameter]: parameters = [] for param_name, param in sig.parameters.items(): # Skip 'self', 'tool_input' (legacy pattern), and 'context' (injected at runtime) - if param_name in ('self', 'tool_input', 'context'): + if param_name in ("self", "tool_input", "context"): continue python_type = hints.get(param_name, str) @@ -200,13 +208,15 @@ def _extract_parameters_from_function(func: Callable) -> List[ToolParameter]: # Get description from docstring description = param_docs.get(param_name, f"The {param_name} parameter").strip() - parameters.append(ToolParameter( - name=param_name, - type=json_type, - description=description, - required=required, - default=default, - )) + parameters.append( + ToolParameter( + name=param_name, + type=json_type, + description=description, + required=required, + default=default, + ) + ) return parameters @@ 
-223,8 +233,8 @@ class ToolRegistry: """ def __init__(self): - self._tools: Dict[str, ToolDefinition] = {} - self._context: Dict[str, Any] = {} # Shared context (agent, client, etc.) + self._tools: dict[str, ToolDefinition] = {} + self._context: dict[str, Any] = {} # Shared context (agent, client, etc.) def set_context(self, key: str, value: Any): """Set shared context available to all tools""" @@ -236,13 +246,13 @@ def get_context(self, key: str) -> Any: def register( self, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, category: ToolCategory = ToolCategory.UTILITY, requires_microscope: bool = False, - tags: Optional[List[str]] = None, - parameters: Optional[List[ToolParameter]] = None, - examples: Optional[List[ToolExample]] = None, + tags: list[str] | None = None, + parameters: list[ToolParameter] | None = None, + examples: list[ToolExample] | None = None, ) -> Callable: """ Decorator to register a function as a tool @@ -277,9 +287,10 @@ async def acquire_volume(embryo_id: str, num_slices: int = 50) -> str: examples : list of ToolExample, optional Usage examples showing when to call this tool """ + def decorator(func: Callable) -> Callable: tool_name = name or func.__name__ - tool_desc = description or (inspect.getdoc(func) or "").split('\n')[0] + tool_desc = description or (inspect.getdoc(func) or "").split("\n")[0] # Extract or use provided parameters tool_params = parameters or _extract_parameters_from_function(func) @@ -311,11 +322,11 @@ async def wrapper(*args, **kwargs): def register_function( self, func: Callable, - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, category: ToolCategory = ToolCategory.UTILITY, requires_microscope: bool = False, - tags: Optional[List[str]] = None, + tags: list[str] | None = None, ): """ Register an existing function as a tool (non-decorator form) @@ -336,7 +347,7 @@ 
def register_function( Additional tags """ tool_name = name or func.__name__ - tool_desc = description or (inspect.getdoc(func) or "").split('\n')[0] + tool_desc = description or (inspect.getdoc(func) or "").split("\n")[0] tool_params = _extract_parameters_from_function(func) tool_def = ToolDefinition( @@ -360,23 +371,23 @@ def unregister(self, name: str) -> bool: return True return False - def get(self, name: str) -> Optional[ToolDefinition]: + def get(self, name: str) -> ToolDefinition | None: """Get tool definition by name""" return self._tools.get(name) - def list_all(self) -> List[ToolDefinition]: + def list_all(self) -> list[ToolDefinition]: """List all registered tools""" return list(self._tools.values()) - def list_by_category(self, category: ToolCategory) -> List[ToolDefinition]: + def list_by_category(self, category: ToolCategory) -> list[ToolDefinition]: """List tools in a category""" return [t for t in self._tools.values() if t.category == category] - def list_by_tag(self, tag: str) -> List[ToolDefinition]: + def list_by_tag(self, tag: str) -> list[ToolDefinition]: """List tools with a specific tag""" return [t for t in self._tools.values() if tag in t.tags] - def list_available(self, has_microscope: bool = False) -> List[ToolDefinition]: + def list_available(self, has_microscope: bool = False) -> list[ToolDefinition]: """List tools available given current context""" tools = [] for tool in self._tools.values(): @@ -385,14 +396,11 @@ def list_available(self, has_microscope: bool = False) -> List[ToolDefinition]: tools.append(tool) return tools - def get_claude_schemas(self, has_microscope: bool = False) -> List[Dict]: + def get_claude_schemas(self, has_microscope: bool = False) -> list[dict]: """Get Claude API tool schemas for available tools""" - return [ - tool.to_claude_schema() - for tool in self.list_available(has_microscope) - ] + return [tool.to_claude_schema() for tool in self.list_available(has_microscope)] - async def execute(self, tool_name: 
str, tool_input: Dict, context: Dict = None) -> str: + async def execute(self, tool_name: str, tool_input: dict, context: dict | None = None) -> str: """ Execute a tool by name @@ -420,14 +428,30 @@ async def execute(self, tool_name: str, tool_input: Dict, context: Dict = None) # 3. Fall back to stored registry context if context is not None: exec_context = context - elif 'context' in tool_input and tool_input['context'] is not None: - exec_context = tool_input['context'] + elif "context" in tool_input and tool_input["context"] is not None: + exec_context = tool_input["context"] else: exec_context = self._context + # Hybrid-autonomy backstop: during an autonomous (wake) turn, a small set + # of irreversible tools (laser-on, embryo termination, stopping the run) + # must NEVER execute without a human — even if the model tries to call + # them directly. The agent sets these flags around its autonomous turns; + # user-driven turns are unaffected. The blocked set is supplied by the + # agent so this layer stays free of app-specific tool names. + _agent = exec_context.get("agent") if isinstance(exec_context, dict) else None + if _agent is not None and getattr(_agent, "_autonomous_active", False): + blocked = getattr(_agent, "_autonomous_blocked_tools", None) or () + if tool_name in blocked: + logger.info("Autonomy backstop blocked '%s' (irreversible)", tool_name) + return ( + f"'{tool_name}' is an irreversible action and cannot run " + f"autonomously. Ask the operator to confirm it." + ) + # Check microscope requirement if tool.requires_microscope: - client = exec_context.get('client') + client = exec_context.get("client") if client is None: return "Error: Not connected to microscope server. Start the server and reconnect." 
@@ -439,8 +463,8 @@ async def execute(self, tool_name: str, tool_input: Dict, context: Dict = None) # Inject context if handler expects it (but don't overwrite if already provided) sig = inspect.signature(tool.handler) - if 'context' in sig.parameters and 'context' not in kwargs: - kwargs['context'] = exec_context + if "context" in sig.parameters and "context" not in kwargs: + kwargs["context"] = exec_context # Execute handler if tool.is_async: @@ -455,6 +479,7 @@ async def execute(self, tool_name: str, tool_input: Dict, context: Dict = None) except Exception as e: import traceback + logger.error(f"Tool {tool_name} failed: {e}") return f"Error executing {tool_name}: {str(e)}\n{traceback.format_exc()}" @@ -466,7 +491,7 @@ def __len__(self) -> int: # Global registry instance -_global_registry: Optional[ToolRegistry] = None +_global_registry: ToolRegistry | None = None def get_tool_registry() -> ToolRegistry: @@ -485,12 +510,12 @@ def set_tool_registry(registry: ToolRegistry): # Convenience decorator using global registry def tool( - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, category: ToolCategory = ToolCategory.UTILITY, requires_microscope: bool = False, - tags: Optional[List[str]] = None, - examples: Optional[List[ToolExample]] = None, + tags: list[str] | None = None, + examples: list[ToolExample] | None = None, ) -> Callable: """ Decorator to register a tool with the global registry diff --git a/gently/log_config.py b/gently/log_config.py index 4db29ad1..69a673a0 100644 --- a/gently/log_config.py +++ b/gently/log_config.py @@ -9,6 +9,7 @@ GENTLY_LOG_FORMAT — console format string GENTLY_LOG_DATEFMT — timestamp format (default: %H:%M:%S) """ + import logging import os import sys @@ -19,8 +20,8 @@ def configure_logging( - level: str = None, - log_file: str = None, + level: str | None = None, + log_file: str | None = None, ): """Configure root logger for the Gently system. 
@@ -54,9 +55,18 @@ def configure_logging( lgr.addHandler(console) # Suppress noisy third-party loggers on console - for name in ("uvicorn", "uvicorn.error", "uvicorn.access", - "httpx", "httpcore", "anthropic", "aiohttp", - "aiohttp.access", "bluesky", "bluesky.RE.state"): + for name in ( + "uvicorn", + "uvicorn.error", + "uvicorn.access", + "httpx", + "httpcore", + "anthropic", + "aiohttp", + "aiohttp.access", + "bluesky", + "bluesky.RE.state", + ): logging.getLogger(name).setLevel(logging.WARNING) # File handler — always INFO+ regardless of console level diff --git a/gently/mesh/audit.py b/gently/mesh/audit.py index 06e76729..e4f06746 100644 --- a/gently/mesh/audit.py +++ b/gently/mesh/audit.py @@ -53,7 +53,7 @@ def _count_lines(self): """Count existing lines for rotation tracking.""" if self._log_file.exists(): try: - with open(self._log_file, "r") as f: + with open(self._log_file) as f: self._line_count = sum(1 for _ in f) except OSError: self._line_count = 0 @@ -88,7 +88,7 @@ def log( def _rotate(self): """Keep last KEEP_LINES, discard the rest.""" try: - with open(self._log_file, "r") as f: + with open(self._log_file) as f: lines = f.readlines() keep = lines[-KEEP_LINES:] with open(self._log_file, "w") as f: diff --git a/gently/mesh/capability_provider.py b/gently/mesh/capability_provider.py index 2453dce9..7dfc6fd1 100644 --- a/gently/mesh/capability_provider.py +++ b/gently/mesh/capability_provider.py @@ -8,18 +8,19 @@ import logging import os import platform -from typing import Any, Dict, List, Optional +from typing import Any from .models import DatasetAdvertisement, GpuInfo, PeerRole logger = logging.getLogger(__name__) -def _detect_gpus() -> List[GpuInfo]: +def _detect_gpus() -> list[GpuInfo]: """Detect available NVIDIA GPUs via torch.cuda (pynvml fallback).""" gpus = [] try: import torch + if torch.cuda.is_available(): for i in range(torch.cuda.device_count()): props = torch.cuda.get_device_properties(i) @@ -28,23 +29,26 @@ def _detect_gpus() -> 
List[GpuInfo]: mem_used_gb = 0.0 try: import pynvml + pynvml.nvmlInit() handle = pynvml.nvmlDeviceGetHandleByIndex(i) util = pynvml.nvmlDeviceGetUtilizationRates(handle) util_pct = float(util.gpu) mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) - mem_used_gb = mem_info.used / (1024 ** 3) + mem_used_gb = mem_info.used / (1024**3) except Exception: pass - gpus.append(GpuInfo( - device_index=i, - name=props.name, - vram_gb=round(props.total_mem / (1024 ** 3), 1), - compute_capability=f"{props.major}.{props.minor}", - utilization_pct=util_pct, - memory_used_gb=round(mem_used_gb, 2), - )) + gpus.append( + GpuInfo( + device_index=i, + name=props.name, + vram_gb=round(props.total_memory / (1024**3), 1), + compute_capability=f"{props.major}.{props.minor}", + utilization_pct=util_pct, + memory_used_gb=round(mem_used_gb, 2), + ) + ) except ImportError: pass except Exception as e: @@ -52,13 +56,14 @@ def _detect_gpus() -> List[GpuInfo]: return gpus -def _get_system_info() -> Dict[str, Any]: +def _get_system_info() -> dict[str, Any]: """Get CPU and RAM info.""" cpu_cores = os.cpu_count() or 0 ram_gb = 0.0 try: if platform.system() == "Windows": import ctypes + kernel32 = ctypes.windll.kernel32 mem_status = ctypes.c_ulonglong() kernel32.GetPhysicallyInstalledSystemMemory(ctypes.byref(mem_status)) @@ -92,12 +97,12 @@ def __init__( self, gently_store=None, device_layer=None, - static_caps: Optional[Dict[str, Any]] = None, + static_caps: dict[str, Any] | None = None, ): self._store = gently_store self._device_layer = device_layer self._static = static_caps or {} - self._gpus: List[GpuInfo] = [] + self._gpus: list[GpuInfo] = [] self._system_info = _get_system_info() # Initial GPU detection (cached, refreshed on demand) self._gpus = _detect_gpus() @@ -106,7 +111,7 @@ def refresh_gpus(self): """Re-detect GPUs (call periodically for live utilization).""" self._gpus = _detect_gpus() - def _compute_roles(self) -> List[str]: + def _compute_roles(self) -> list[str]: """Determine dynamic
roles based on current state.""" roles = [] # Microscope controller if device is connected and responding @@ -131,7 +136,7 @@ def _compute_roles(self) -> List[str]: roles.append(PeerRole.PLANNER.value) return roles - def _get_datasets(self) -> List[DatasetAdvertisement]: + def _get_datasets(self) -> list[DatasetAdvertisement]: """Query FileStore for dataset advertisements.""" if self._store is None: return [] @@ -164,15 +169,17 @@ def _get_datasets(self) -> List[DatasetAdvertisement]: except Exception: pass - datasets.append(DatasetAdvertisement( - session_id=sid, - session_name=sname, - embryo_count=embryo_count, - volume_count=vol_count, - has_ground_truth=gt_count > 0, - ground_truth_count=gt_count, - stages_covered=sorted(stages), - )) + datasets.append( + DatasetAdvertisement( + session_id=sid, + session_name=sname, + embryo_count=embryo_count, + volume_count=vol_count, + has_ground_truth=gt_count > 0, + ground_truth_count=gt_count, + stages_covered=sorted(stages), + ) + ) except Exception as e: logger.debug(f"Dataset advertisement failed: {e}") return datasets @@ -186,7 +193,7 @@ def _is_microscope_connected(self) -> bool: pass return False - def __call__(self) -> Dict[str, Any]: + def __call__(self) -> dict[str, Any]: """Build the full capability dict. 
Called on each heartbeat.""" datasets = self._get_datasets() roles = self._compute_roles() @@ -199,10 +206,12 @@ def __call__(self) -> Dict[str, Any]: storage_total_gb = 0.0 try: import shutil + from ..settings import settings + usage = shutil.disk_usage(str(settings.storage.base_path)) - storage_free_gb = round(usage.free / (1024 ** 3), 1) - storage_total_gb = round(usage.total / (1024 ** 3), 1) + storage_free_gb = round(usage.free / (1024**3), 1) + storage_total_gb = round(usage.total / (1024**3), 1) except Exception: pass diff --git a/gently/mesh/discovery.py b/gently/mesh/discovery.py index 29fecd4a..878a9b81 100644 --- a/gently/mesh/discovery.py +++ b/gently/mesh/discovery.py @@ -15,7 +15,7 @@ import logging import socket import time -from typing import Callable, Optional +from collections.abc import Callable from ..settings import settings @@ -74,7 +74,7 @@ def __init__( self._pairing_manager = pairing_manager self._audit_log = audit_log self._known_ids: set = set() - self.transport: Optional[asyncio.DatagramTransport] = None + self.transport: asyncio.DatagramTransport | None = None def connection_made(self, transport: asyncio.DatagramTransport): self.transport = transport @@ -101,9 +101,12 @@ def datagram_received(self, data: bytes, addr: tuple): logger.debug(f"Mesh: rejected stale packet from {peer_id[:8]} (ts={ts})") if self._audit_log: from .audit import AuditEvent + self._audit_log.log( - AuditEvent.REPLAY_REJECTED, outcome="deny", - peer_id=peer_id, ip=sender_ip, + AuditEvent.REPLAY_REJECTED, + outcome="deny", + peer_id=peer_id, + ip=sender_ip, detail=f"ts_delta={abs(time.time() - ts):.1f}s", ) return @@ -120,9 +123,12 @@ def datagram_received(self, data: bytes, addr: tuple): logger.debug(f"Mesh: bad signature from {peer_id[:8]}") if self._audit_log: from .audit import AuditEvent + self._audit_log.log( - AuditEvent.SIG_INVALID, outcome="deny", - peer_id=peer_id, ip=sender_ip, + AuditEvent.SIG_INVALID, + outcome="deny", + peer_id=peer_id, + ip=sender_ip, 
) if msg_type == "nudge": @@ -138,7 +144,7 @@ def datagram_received(self, data: bytes, addr: tuple): def error_received(self, exc: Exception): logger.debug(f"Mesh UDP error: {exc}") - def connection_lost(self, exc: Optional[Exception]): + def connection_lost(self, exc: Exception | None): pass def forget_peer(self, instance_id: str): @@ -169,9 +175,9 @@ def __init__( self._pairing_manager = pairing_manager self._audit_log = audit_log - self._protocol: Optional[_MeshProtocol] = None - self._transport: Optional[asyncio.DatagramTransport] = None - self._broadcast_task: Optional[asyncio.Task] = None + self._protocol: _MeshProtocol | None = None + self._transport: asyncio.DatagramTransport | None = None + self._broadcast_task: asyncio.Task | None = None self._running = False # Callbacks — set by MeshService before start() @@ -213,8 +219,7 @@ async def start(self): self._running = True self._broadcast_task = asyncio.create_task(self._broadcast_loop()) logger.info( - f"Mesh discovery started on port {self.mesh_port} " - f"(instance={self.instance_id[:8]})" + f"Mesh discovery started on port {self.mesh_port} (instance={self.instance_id[:8]})" ) async def stop(self): @@ -259,9 +264,7 @@ def send_nudge(self): packet = json.dumps(payload).encode("utf-8") try: - self._transport.sendto( - packet, ("255.255.255.255", self.mesh_port) - ) + self._transport.sendto(packet, ("255.255.255.255", self.mesh_port)) except OSError as e: logger.debug(f"Mesh nudge broadcast failed: {e}") @@ -286,9 +289,7 @@ async def _broadcast_loop(self): try: if self._transport: - self._transport.sendto( - heartbeat, ("255.255.255.255", self.mesh_port) - ) + self._transport.sendto(heartbeat, ("255.255.255.255", self.mesh_port)) except OSError as e: logger.debug(f"Mesh broadcast failed: {e}") diff --git a/gently/mesh/mesh_service.py b/gently/mesh/mesh_service.py index edac242c..3a4a7e52 100644 --- a/gently/mesh/mesh_service.py +++ b/gently/mesh/mesh_service.py @@ -11,12 +11,13 @@ import asyncio import 
logging import time +from collections.abc import Callable from pathlib import Path -from typing import Callable, Dict, List, Optional from gently.core.event_bus import EventType from gently.core.service import Service +from ..settings import settings from .discovery import MeshDiscovery from .models import PeerCapability, PeerInfo, PeerStatus from .peer_client import PeerClient @@ -24,8 +25,6 @@ logger = logging.getLogger(__name__) -from ..settings import settings - REAPER_INTERVAL = settings.mesh.reaper_interval_s STATUS_REFRESH_INTERVAL = settings.mesh.status_refresh_s @@ -59,7 +58,7 @@ def __init__( mesh_port: int = settings.network.mesh_port, pairing_manager=None, audit_log=None, - config_dir: Optional[Path] = None, + config_dir: Path | None = None, ): import socket as _socket @@ -78,12 +77,12 @@ def __init__( self._audit_log = audit_log self._hostname = _socket.gethostname() - self._peers: Dict[str, PeerInfo] = {} - self._discovery: Optional[MeshDiscovery] = None - self._peer_client: Optional[PeerClient] = None - self._reaper_task: Optional[asyncio.Task] = None - self._refresh_task: Optional[asyncio.Task] = None - self._cleanup_task: Optional[asyncio.Task] = None + self._peers: dict[str, PeerInfo] = {} + self._discovery: MeshDiscovery | None = None + self._peer_client: PeerClient | None = None + self._reaper_task: asyncio.Task | None = None + self._refresh_task: asyncio.Task | None = None + self._cleanup_task: asyncio.Task | None = None # Persistent verse map if config_dir is None: @@ -121,7 +120,8 @@ async def on_start(self): # When our own status changes, broadcast a nudge to all peers self._status_unsub = self._event_bus.subscribe( - EventType.STATUS_CHANGED, self._on_local_status_changed, + EventType.STATUS_CHANGED, + self._on_local_status_changed, ) async def on_stop(self): @@ -156,7 +156,8 @@ def _on_peer_discovered(self, data: dict, sender_ip: str, verified: bool = False # Check if this peer is already trusted trusted = ( 
self._pairing_manager.is_trusted(peer_id) - if self._pairing_manager else True # no manager = trust all (backward compat) + if self._pairing_manager + else True # no manager = trust all (backward compat) ) # Determine TLS status — trusted peers with a cert fingerprint use TLS @@ -187,24 +188,28 @@ def _on_peer_discovered(self, data: dict, sender_ip: str, verified: bool = False if was_offline: # Previously offline peer returned self._verse_map.on_peer_returned(peer_id) - self._emit_event(EventType.MESH_PEER_RETURNED, { - "instance_id": peer_id, - "hostname": peer.hostname, - "ip_address": sender_ip, - "is_trusted": trusted, - }) - logger.info( - f"Mesh: peer returned {peer.hostname} ({peer_id[:8]}) at {sender_ip}" + self._emit_event( + EventType.MESH_PEER_RETURNED, + { + "instance_id": peer_id, + "hostname": peer.hostname, + "ip_address": sender_ip, + "is_trusted": trusted, + }, ) + logger.info(f"Mesh: peer returned {peer.hostname} ({peer_id[:8]}) at {sender_ip}") else: - self._emit_event(EventType.MESH_PEER_DISCOVERED, { - "instance_id": peer_id, - "hostname": peer.hostname, - "ip_address": sender_ip, - "is_trusted": trusted, - "udp_verified": verified, - "tls_enabled": tls_enabled, - }) + self._emit_event( + EventType.MESH_PEER_DISCOVERED, + { + "instance_id": peer_id, + "hostname": peer.hostname, + "ip_address": sender_ip, + "is_trusted": trusted, + "udp_verified": verified, + "tls_enabled": tls_enabled, + }, + ) logger.info( f"Mesh: discovered peer {peer.hostname} ({peer_id[:8]}) at {sender_ip} " f"[trusted={trusted}, udp_verified={verified}, tls={tls_enabled}]" @@ -257,20 +262,28 @@ async def _reaper_loop(self): self._verse_map.on_peer_offline(pid) if self._discovery: self._discovery.forget_peer(pid) - self._emit_event(EventType.MESH_PEER_OFFLINE, { - "instance_id": pid, - "hostname": peer.hostname, - }) - logger.info(f"Mesh: peer offline {peer.hostname} ({pid[:8]}) — kept in verse map") + self._emit_event( + EventType.MESH_PEER_OFFLINE, + { + "instance_id": 
pid, + "hostname": peer.hostname, + }, + ) + logger.info( + f"Mesh: peer offline {peer.hostname} ({pid[:8]}) — kept in verse map" + ) else: # Untrusted peer: fully remove self._peers.pop(pid, None) if self._discovery: self._discovery.forget_peer(pid) - self._emit_event(EventType.MESH_PEER_LOST, { - "instance_id": pid, - "hostname": peer.hostname, - }) + self._emit_event( + EventType.MESH_PEER_LOST, + { + "instance_id": pid, + "hostname": peer.hostname, + }, + ) logger.info(f"Mesh: lost peer {peer.hostname} ({pid[:8]})") async def _refresh_loop(self): @@ -306,10 +319,13 @@ async def _fetch_and_update_peer(self, peer: PeerInfo): # Update verse map with latest capabilities self._verse_map.on_peer_updated(peer) - self._emit_event(EventType.MESH_PEER_UPDATED, { - "instance_id": peer.instance_id, - "hostname": peer.hostname, - }) + self._emit_event( + EventType.MESH_PEER_UPDATED, + { + "instance_id": peer.instance_id, + "hostname": peer.hostname, + }, + ) # ------------------------------------------------------------------ # Pairing integration @@ -343,19 +359,19 @@ def mark_peer_trusted(self, instance_id: str): # Public query API # ------------------------------------------------------------------ - def get_peers(self) -> List[PeerInfo]: + def get_peers(self) -> list[PeerInfo]: """Return all live (non-dead) peers.""" return [p for p in self._peers.values() if not p.is_dead] - def get_all_peers(self) -> List[PeerInfo]: + def get_all_peers(self) -> list[PeerInfo]: """Return all tracked peers including stale/dead ones.""" return list(self._peers.values()) - def get_peer(self, instance_id: str) -> Optional[PeerInfo]: + def get_peer(self, instance_id: str) -> PeerInfo | None: """Get a specific peer by instance_id.""" return self._peers.get(instance_id) - def find_peers_with(self, capability: str) -> List[PeerInfo]: + def find_peers_with(self, capability: str) -> list[PeerInfo]: """ Find live peers that have a given capability flag. 
@@ -386,11 +402,11 @@ def get_local_info(self) -> dict: } @property - def peer_client(self) -> Optional[PeerClient]: + def peer_client(self) -> PeerClient | None: """Expose the peer client for direct campaign operations.""" return self._peer_client - def find_peer_by_hostname(self, hostname: str) -> Optional[PeerInfo]: + def find_peer_by_hostname(self, hostname: str) -> PeerInfo | None: """Find a live peer by hostname (case-insensitive).""" hostname_lower = hostname.lower() for p in self.get_peers(): diff --git a/gently/mesh/models.py b/gently/mesh/models.py index 8c59877f..721525d0 100644 --- a/gently/mesh/models.py +++ b/gently/mesh/models.py @@ -14,13 +14,14 @@ import time from dataclasses import dataclass, field from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any from ..settings import settings class PeerRole(str, Enum): """Dynamic roles a gently node can fill.""" + MICROSCOPE_CONTROLLER = "microscope_controller" ML_TRAINER = "ml_trainer" DATA_SERVER = "data_server" @@ -30,6 +31,7 @@ class PeerRole(str, Enum): @dataclass class GpuInfo: """Details about a single GPU device.""" + device_index: int = 0 name: str = "" vram_gb: float = 0.0 @@ -37,7 +39,7 @@ class GpuInfo: utilization_pct: float = 0.0 memory_used_gb: float = 0.0 - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "device_index": self.device_index, "name": self.name, @@ -48,7 +50,7 @@ def to_dict(self) -> Dict[str, Any]: } @classmethod - def from_dict(cls, d: Dict[str, Any]) -> "GpuInfo": + def from_dict(cls, d: dict[str, Any]) -> "GpuInfo": return cls( device_index=d.get("device_index", 0), name=d.get("name", ""), @@ -62,16 +64,17 @@ def from_dict(cls, d: Dict[str, Any]) -> "GpuInfo": @dataclass class DatasetAdvertisement: """Advertises what data a node has available for training.""" + session_id: str = "" session_name: str = "" embryo_count: int = 0 volume_count: int = 0 has_ground_truth: bool = False ground_truth_count: int 
= 0 - stages_covered: List[str] = field(default_factory=list) + stages_covered: list[str] = field(default_factory=list) total_size_gb: float = 0.0 - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "session_id": self.session_id, "session_name": self.session_name, @@ -84,7 +87,7 @@ def to_dict(self) -> Dict[str, Any]: } @classmethod - def from_dict(cls, d: Dict[str, Any]) -> "DatasetAdvertisement": + def from_dict(cls, d: dict[str, Any]) -> "DatasetAdvertisement": return cls( session_id=d.get("session_id", ""), session_name=d.get("session_name", ""), @@ -107,19 +110,19 @@ class PeerCapability: gpu_name: str = "" gpu_vram_gb: float = 0.0 storage_free_gb: float = 0.0 - tool_categories: List[str] = field(default_factory=list) + tool_categories: list[str] = field(default_factory=list) organism: str = "" hardware_profile: str = "" # Enhanced capability fields (backward-compatible — old peers get defaults) - gpus: List[GpuInfo] = field(default_factory=list) - roles: List[str] = field(default_factory=list) - datasets: List[DatasetAdvertisement] = field(default_factory=list) + gpus: list[GpuInfo] = field(default_factory=list) + roles: list[str] = field(default_factory=list) + datasets: list[DatasetAdvertisement] = field(default_factory=list) microscope_connected: bool = False cpu_cores: int = 0 ram_gb: float = 0.0 storage_total_gb: float = 0.0 - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "has_microscope": self.has_microscope, "has_sam": self.has_sam, @@ -140,7 +143,7 @@ def to_dict(self) -> Dict[str, Any]: } @classmethod - def from_dict(cls, d: Dict[str, Any]) -> "PeerCapability": + def from_dict(cls, d: dict[str, Any]) -> "PeerCapability": return cls( has_microscope=d.get("has_microscope", False), has_sam=d.get("has_sam", False), @@ -174,7 +177,7 @@ class PeerStatus: active_plan: str = "" version: str = "" - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { 
"session_id": self.session_id, "acquisition_status": self.acquisition_status, @@ -187,7 +190,7 @@ def to_dict(self) -> Dict[str, Any]: } @classmethod - def from_dict(cls, d: Dict[str, Any]) -> "PeerStatus": + def from_dict(cls, d: dict[str, Any]) -> "PeerStatus": return cls( session_id=d.get("session_id", ""), acquisition_status=d.get("acquisition_status", "idle"), @@ -232,7 +235,7 @@ def is_dead(self) -> bool: """True if no heartbeat beyond the dead threshold.""" return (time.time() - self.last_seen) > settings.mesh.dead_threshold_s - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "instance_id": self.instance_id, "hostname": self.hostname, @@ -252,7 +255,7 @@ def to_dict(self) -> Dict[str, Any]: } @classmethod - def from_dict(cls, d: Dict[str, Any]) -> "PeerInfo": + def from_dict(cls, d: dict[str, Any]) -> "PeerInfo": return cls( instance_id=d.get("instance_id", ""), hostname=d.get("hostname", ""), @@ -289,15 +292,15 @@ class PersistedPeer: # Persistence fields online: bool = True last_online: float = field(default_factory=time.time) - roles: List[str] = field(default_factory=list) - datasets: List[DatasetAdvertisement] = field(default_factory=list) + roles: list[str] = field(default_factory=list) + datasets: list[DatasetAdvertisement] = field(default_factory=list) @property def base_url(self) -> str: scheme = "https" if self.tls_enabled else "http" return f"{scheme}://{self.ip_address}:{self.viz_port}" - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "instance_id": self.instance_id, "hostname": self.hostname, @@ -316,7 +319,7 @@ def to_dict(self) -> Dict[str, Any]: } @classmethod - def from_dict(cls, d: Dict[str, Any]) -> "PersistedPeer": + def from_dict(cls, d: dict[str, Any]) -> "PersistedPeer": return cls( instance_id=d.get("instance_id", ""), hostname=d.get("hostname", ""), diff --git a/gently/mesh/pairing.py b/gently/mesh/pairing.py index 32e939d5..e8eda9d9 100644 --- 
a/gently/mesh/pairing.py +++ b/gently/mesh/pairing.py @@ -24,7 +24,6 @@ import uuid from dataclasses import dataclass, field from pathlib import Path -from typing import Dict, List, Optional, Tuple logger = logging.getLogger(__name__) @@ -71,7 +70,7 @@ class TrustedPeer: paired_at: str = "" # ISO timestamp cert_fingerprint: str = "" # SHA256 of peer's TLS cert (DER) udp_signing_key: str = "" # hex-encoded key for UDP HMAC verification - scopes: List[str] = field(default_factory=lambda: list(ALL_SCOPES)) + scopes: list[str] = field(default_factory=lambda: list(ALL_SCOPES)) class PairingManager: @@ -94,8 +93,8 @@ def __init__(self, instance_id: str, hostname: str, config_dir: Path, audit_log= self._config_dir = config_dir self._audit_log = audit_log - self._sessions: Dict[str, PairingSession] = {} - self._trusted: Dict[str, TrustedPeer] = {} # keyed by instance_id + self._sessions: dict[str, PairingSession] = {} + self._trusted: dict[str, TrustedPeer] = {} # keyed by instance_id self._trust_file = config_dir / "mesh_trusted_peers.json" # Phase 2: TLS cert fingerprint (set by launch_gently after cert gen) @@ -109,7 +108,7 @@ def __init__(self, instance_id: str, hostname: str, config_dir: Path, audit_log= ).hexdigest() # Phase 2: rate limiting state - self._pair_attempts: Dict[str, List[float]] = {} # IP -> timestamps + self._pair_attempts: dict[str, list[float]] = {} # IP -> timestamps self._load_trusted() @@ -163,7 +162,7 @@ def is_trusted(self, instance_id: str) -> bool: """Check if a peer is trusted.""" return instance_id in self._trusted - def get_token_for_peer(self, instance_id: str) -> Optional[str]: + def get_token_for_peer(self, instance_id: str) -> str | None: """Get the current daily auth token for a trusted peer.""" tp = self._trusted.get(instance_id) if tp is None: @@ -171,7 +170,7 @@ def get_token_for_peer(self, instance_id: str) -> Optional[str]: epoch_day = self._current_epoch_day() return self._derive_daily_token(tp.base_token, epoch_day) - def 
verify_token(self, token: str) -> Optional[str]: + def verify_token(self, token: str) -> str | None: """ Check if a token matches any trusted peer (timing-safe). @@ -186,26 +185,26 @@ def verify_token(self, token: str) -> Optional[str]: return tp.instance_id return None - def get_all_trusted(self) -> List[TrustedPeer]: + def get_all_trusted(self) -> list[TrustedPeer]: """Return all trusted peers.""" return list(self._trusted.values()) - def get_udp_key_for_peer(self, instance_id: str) -> Optional[str]: + def get_udp_key_for_peer(self, instance_id: str) -> str | None: """Get the UDP signing key for a trusted peer.""" tp = self._trusted.get(instance_id) return tp.udp_signing_key if tp else None - def get_cert_fingerprint_for_peer(self, instance_id: str) -> Optional[str]: + def get_cert_fingerprint_for_peer(self, instance_id: str) -> str | None: """Get the TLS cert fingerprint for a trusted peer.""" tp = self._trusted.get(instance_id) return tp.cert_fingerprint if tp else None - def get_scopes_for_peer(self, instance_id: str) -> List[str]: + def get_scopes_for_peer(self, instance_id: str) -> list[str]: """Get the permission scopes for a trusted peer.""" tp = self._trusted.get(instance_id) return list(tp.scopes) if tp else [] - def set_scopes(self, identifier: str, scopes: List[str]) -> bool: + def set_scopes(self, identifier: str, scopes: list[str]) -> bool: """ Set permission scopes for a peer (by instance_id, prefix, or hostname). 
@@ -269,8 +268,10 @@ def unpair(self, identifier: str) -> bool: self._save_trusted() if self._audit_log: from .audit import AuditEvent + self._audit_log.log( - AuditEvent.PEER_UNPAIRED, outcome="info", + AuditEvent.PEER_UNPAIRED, + outcome="info", peer_id=removed_id, ) return True @@ -281,7 +282,7 @@ def unpair(self, identifier: str) -> bool: # Rate limiting # ------------------------------------------------------------------ - def check_rate_limit(self, ip: str) -> Tuple[bool, float]: + def check_rate_limit(self, ip: str) -> tuple[bool, float]: """ Check if a pairing attempt from this IP is allowed. @@ -298,9 +299,12 @@ def check_rate_limit(self, ip: str) -> Tuple[bool, float]: retry_after = RATE_LIMIT_WINDOW - (now - attempts[0]) if self._audit_log: from .audit import AuditEvent + self._audit_log.log( - AuditEvent.RATE_LIMITED, outcome="deny", - ip=ip, detail=f"max_attempts={RATE_LIMIT_MAX}", + AuditEvent.RATE_LIMITED, + outcome="deny", + ip=ip, + detail=f"max_attempts={RATE_LIMIT_MAX}", ) return False, max(retry_after, 1.0) @@ -312,9 +316,12 @@ def check_rate_limit(self, ip: str) -> Tuple[bool, float]: if elapsed < backoff: if self._audit_log: from .audit import AuditEvent + self._audit_log.log( - AuditEvent.RATE_LIMITED, outcome="deny", - ip=ip, detail=f"backoff={backoff:.1f}s", + AuditEvent.RATE_LIMITED, + outcome="deny", + ip=ip, + detail=f"backoff={backoff:.1f}s", ) return False, backoff - elapsed @@ -405,7 +412,7 @@ def handle_pair_request( # Confirmation # ------------------------------------------------------------------ - def confirm_pairing(self, pairing_id: str, confirmer_id: str) -> Optional[PairingSession]: + def confirm_pairing(self, pairing_id: str, confirmer_id: str) -> PairingSession | None: """ Mark one side as confirmed. 
@@ -428,36 +435,39 @@ def confirm_pairing(self, pairing_id: str, confirmer_id: str) -> Optional[Pairin return session - def reject_pairing(self, pairing_id: str) -> Optional[PairingSession]: + def reject_pairing(self, pairing_id: str) -> PairingSession | None: """Reject a pending pairing session.""" session = self._sessions.get(pairing_id) if session and session.status == "pending": session.status = "rejected" if self._audit_log: from .audit import AuditEvent + self._audit_log.log( - AuditEvent.PAIR_REJECTED, outcome="deny", + AuditEvent.PAIR_REJECTED, + outcome="deny", peer_id=session.initiator_id, ) return session - def get_session(self, pairing_id: str) -> Optional[PairingSession]: + def get_session(self, pairing_id: str) -> PairingSession | None: """Get a pairing session by ID.""" return self._sessions.get(pairing_id) - def get_pending_sessions(self) -> List[PairingSession]: + def get_pending_sessions(self) -> list[PairingSession]: """Get all pending pairing sessions (for /pair accept).""" return [ - s for s in self._sessions.values() - if s.status == "pending" - and s.responder_id == self.instance_id + s + for s in self._sessions.values() + if s.status == "pending" and s.responder_id == self.instance_id ] def cleanup_expired(self): """Remove expired pending sessions.""" now = time.time() expired = [ - pid for pid, s in self._sessions.items() + pid + for pid, s in self._sessions.items() if s.status == "pending" and (now - s.created_at) > PAIRING_EXPIRY ] for pid in expired: @@ -497,9 +507,12 @@ def _finalize_pairing(self, session: PairingSession): logger.info(f"Paired with {peer_hostname} ({peer_id[:8]})") if self._audit_log: from .audit import AuditEvent + self._audit_log.log( - AuditEvent.PAIR_COMPLETED, outcome="info", - peer_id=peer_id, detail=f"hostname={peer_hostname}", + AuditEvent.PAIR_COMPLETED, + outcome="info", + peer_id=peer_id, + detail=f"hostname={peer_hostname}", ) def _load_trusted(self): diff --git a/gently/mesh/peer_client.py 
b/gently/mesh/peer_client.py index c0b1be5e..c6db282b 100644 --- a/gently/mesh/peer_client.py +++ b/gently/mesh/peer_client.py @@ -10,12 +10,12 @@ import asyncio import logging import ssl -from typing import Any, Dict, List, Optional +from typing import Any import aiohttp -from .models import PeerInfo from ..settings import settings +from .models import PeerInfo logger = logging.getLogger(__name__) @@ -24,7 +24,7 @@ class PeerClient: """Fetches full status from a peer's viz server over HTTP.""" def __init__(self, pairing_manager=None, audit_log=None): - self._session: Optional[aiohttp.ClientSession] = None + self._session: aiohttp.ClientSession | None = None self._pairing_manager = pairing_manager self._audit_log = audit_log self._pinning_verified: set = set() # track first-success per peer @@ -40,7 +40,7 @@ async def _ensure_session(self): self._session = aiohttp.ClientSession(timeout=timeout, connector=connector) self._pinning_verified.clear() - def _auth_headers(self, peer: PeerInfo) -> Dict[str, str]: + def _auth_headers(self, peer: PeerInfo) -> dict[str, str]: """Build auth headers for a trusted peer.""" if self._pairing_manager is None: return {} @@ -59,16 +59,13 @@ def _ssl_for_peer(self, peer: PeerInfo): """ if self._pairing_manager is None: return False - fingerprint = self._pairing_manager.get_cert_fingerprint_for_peer( - peer.instance_id - ) + fingerprint = self._pairing_manager.get_cert_fingerprint_for_peer(peer.instance_id) if fingerprint: try: return aiohttp.Fingerprint(bytes.fromhex(fingerprint)) except (ValueError, TypeError): logger.warning( - f"Invalid cert fingerprint for {peer.instance_id[:8]}, " - "falling back to unpinned" + f"Invalid cert fingerprint for {peer.instance_id[:8]}, falling back to unpinned" ) return False @@ -78,9 +75,12 @@ def _log_pinning_success(self, peer: PeerInfo, ssl_fp): self._pinning_verified.add(peer.instance_id) if self._audit_log: from .audit import AuditEvent + self._audit_log.log( - AuditEvent.CERT_PIN_OK, 
outcome="allow", - peer_id=peer.instance_id, ip=peer.ip_address, + AuditEvent.CERT_PIN_OK, + outcome="allow", + peer_id=peer.instance_id, + ip=peer.ip_address, ) def _log_pinning_failure(self, peer: PeerInfo, error): @@ -88,13 +88,16 @@ def _log_pinning_failure(self, peer: PeerInfo, error): logger.warning(f"CERT PINNING FAILED for {peer.instance_id[:8]}: {error}") if self._audit_log: from .audit import AuditEvent + self._audit_log.log( - AuditEvent.CERT_PIN_FAIL, outcome="deny", - peer_id=peer.instance_id, ip=peer.ip_address, + AuditEvent.CERT_PIN_FAIL, + outcome="deny", + peer_id=peer.instance_id, + ip=peer.ip_address, detail=str(error), ) - async def fetch_peer_info(self, peer: PeerInfo) -> Optional[Dict[str, Any]]: + async def fetch_peer_info(self, peer: PeerInfo) -> dict[str, Any] | None: """ GET /api/mesh/status from a peer. @@ -109,9 +112,7 @@ async def fetch_peer_info(self, peer: PeerInfo) -> Optional[Dict[str, Any]]: if resp.status == 200: self._log_pinning_success(peer, ssl_fp) return await resp.json() - logger.debug( - f"Peer {peer.instance_id[:8]} returned HTTP {resp.status}" - ) + logger.debug(f"Peer {peer.instance_id[:8]} returned HTTP {resp.status}") except aiohttp.ServerFingerprintMismatch as e: self._log_pinning_failure(peer, e) except (aiohttp.ClientError, asyncio.TimeoutError, OSError) as e: @@ -123,7 +124,7 @@ async def fetch_peer_info(self, peer: PeerInfo) -> Optional[Dict[str, Any]]: # Campaign coordination methods # ------------------------------------------------------------------ - async def fetch_peer_campaigns(self, peer: PeerInfo) -> Optional[List]: + async def fetch_peer_campaigns(self, peer: PeerInfo) -> list | None: """GET /api/campaigns from a peer.""" await self._ensure_session() url = f"{peer.base_url}/api/campaigns" @@ -141,7 +142,7 @@ async def fetch_peer_campaigns(self, peer: PeerInfo) -> Optional[List]: logger.debug(f"Failed to fetch campaigns from {peer.instance_id[:8]}: {e}") return None - async def fetch_campaign_export(self, 
peer: PeerInfo, campaign_id: str) -> Optional[Dict]: + async def fetch_campaign_export(self, peer: PeerInfo, campaign_id: str) -> dict | None: """GET /api/campaigns/{id}/export from a peer.""" await self._ensure_session() url = f"{peer.base_url}/api/campaigns/{campaign_id}/export" @@ -159,7 +160,11 @@ async def fetch_campaign_export(self, peer: PeerInfo, campaign_id: str) -> Optio return None async def join_campaign( - self, peer: PeerInfo, campaign_id: str, instance_id: str, hostname: str, + self, + peer: PeerInfo, + campaign_id: str, + instance_id: str, + hostname: str, ) -> bool: """POST /api/campaigns/{id}/join on a peer.""" await self._ensure_session() @@ -167,10 +172,15 @@ async def join_campaign( headers = self._auth_headers(peer) ssl_fp = self._ssl_for_peer(peer) try: - async with self._session.post(url, json={ - "instance_id": instance_id, - "hostname": hostname, - }, headers=headers, ssl=ssl_fp) as resp: + async with self._session.post( + url, + json={ + "instance_id": instance_id, + "hostname": hostname, + }, + headers=headers, + ssl=ssl_fp, + ) as resp: if resp.status == 200: self._log_pinning_success(peer, ssl_fp) return resp.status == 200 @@ -194,10 +204,15 @@ async def claim_item( headers = self._auth_headers(peer) ssl_fp = self._ssl_for_peer(peer) try: - async with self._session.post(url, json={ - "instance_id": instance_id, - "hostname": hostname, - }, headers=headers, ssl=ssl_fp) as resp: + async with self._session.post( + url, + json={ + "instance_id": instance_id, + "hostname": hostname, + }, + headers=headers, + ssl=ssl_fp, + ) as resp: if resp.status == 200: self._log_pinning_success(peer, ssl_fp) return resp.status == 200 @@ -208,7 +223,10 @@ async def claim_item( return False async def unclaim_item( - self, peer: PeerInfo, campaign_id: str, item_id: str, + self, + peer: PeerInfo, + campaign_id: str, + item_id: str, ) -> bool: """POST /api/campaigns/{id}/items/{item_id}/unclaim on a peer.""" await self._ensure_session() @@ -232,12 +250,12 @@ 
async def update_item_status( campaign_id: str, item_id: str, status: str, - outcome: Optional[str] = None, + outcome: str | None = None, ) -> bool: """POST /api/campaigns/{id}/items/{item_id}/status on a peer.""" await self._ensure_session() url = f"{peer.base_url}/api/campaigns/{campaign_id}/items/{item_id}/status" - body: Dict[str, Any] = {"status": status} + body: dict[str, Any] = {"status": status} if outcome is not None: body["outcome"] = outcome headers = self._auth_headers(peer) @@ -258,9 +276,14 @@ async def update_item_status( # ------------------------------------------------------------------ async def send_pair_request( - self, peer: PeerInfo, initiator_id: str, hostname: str, nonce: str, - cert_fingerprint: str = "", udp_sign_key: str = "", - ) -> Optional[Dict]: + self, + peer: PeerInfo, + initiator_id: str, + hostname: str, + nonce: str, + cert_fingerprint: str = "", + udp_sign_key: str = "", + ) -> dict | None: """POST /api/mesh/pair — initiate pairing with a peer. Returns response dict on success, or {"_error": "..."} on failure. 
@@ -278,8 +301,9 @@ async def send_pair_request( base = f"{peer.ip_address}:{peer.viz_port}" return {"_error": f"Could not reach {base} via HTTPS or HTTP"} - async def _pairing_request(self, peer: PeerInfo, method: str, path: str, - json_body: Optional[Dict] = None) -> Optional[Dict]: + async def _pairing_request( + self, peer: PeerInfo, method: str, path: str, json_body: dict | None = None + ) -> dict | None: """Make an HTTP request trying HTTPS first, then HTTP (for pre-pairing).""" await self._ensure_session() base = f"{peer.ip_address}:{peer.viz_port}" @@ -297,18 +321,25 @@ async def _pairing_request(self, peer: PeerInfo, method: str, path: str, continue return None - async def poll_pair_status(self, peer: PeerInfo, pairing_id: str) -> Optional[Dict]: + async def poll_pair_status(self, peer: PeerInfo, pairing_id: str) -> dict | None: """GET /api/mesh/pair/{id}/status — poll pairing status.""" return await self._pairing_request( - peer, "GET", f"/api/mesh/pair/{pairing_id}/status", + peer, + "GET", + f"/api/mesh/pair/{pairing_id}/status", ) async def confirm_pair_remote( - self, peer: PeerInfo, pairing_id: str, confirmer_id: str, + self, + peer: PeerInfo, + pairing_id: str, + confirmer_id: str, ) -> bool: """POST /api/mesh/pair/{id}/confirm — confirm pairing on remote side.""" resp = await self._pairing_request( - peer, "POST", f"/api/mesh/pair/{pairing_id}/confirm", + peer, + "POST", + f"/api/mesh/pair/{pairing_id}/confirm", json_body={"confirmer_id": confirmer_id}, ) return resp is not None @@ -317,7 +348,7 @@ async def confirm_pair_remote( # Data catalog methods (Phase 2 — requires "data" scope) # ------------------------------------------------------------------ - async def fetch_peer_sessions(self, peer: PeerInfo) -> Optional[List]: + async def fetch_peer_sessions(self, peer: PeerInfo) -> list | None: """GET /api/data/sessions from a peer.""" await self._ensure_session() url = f"{peer.base_url}/api/data/sessions" @@ -335,7 +366,7 @@ async def 
fetch_peer_sessions(self, peer: PeerInfo) -> Optional[List]: logger.debug(f"Failed to fetch sessions from {peer.instance_id[:8]}: {e}") return None - async def fetch_peer_session_detail(self, peer: PeerInfo, session_id: str) -> Optional[Dict]: + async def fetch_peer_session_detail(self, peer: PeerInfo, session_id: str) -> dict | None: """GET /api/data/sessions/{id} from a peer.""" await self._ensure_session() url = f"{peer.base_url}/api/data/sessions/{session_id}" @@ -352,7 +383,7 @@ async def fetch_peer_session_detail(self, peer: PeerInfo, session_id: str) -> Op logger.debug(f"Failed to fetch session detail from {peer.instance_id[:8]}: {e}") return None - async def fetch_peer_coverage(self, peer: PeerInfo) -> Optional[Dict]: + async def fetch_peer_coverage(self, peer: PeerInfo) -> dict | None: """GET /api/data/coverage from a peer.""" await self._ensure_session() url = f"{peer.base_url}/api/data/coverage" @@ -369,7 +400,7 @@ async def fetch_peer_coverage(self, peer: PeerInfo) -> Optional[Dict]: logger.debug(f"Failed to fetch coverage from {peer.instance_id[:8]}: {e}") return None - async def fetch_peer_stage_distribution(self, peer: PeerInfo) -> Optional[Dict]: + async def fetch_peer_stage_distribution(self, peer: PeerInfo) -> dict | None: """GET /api/data/stages from a peer.""" await self._ensure_session() url = f"{peer.base_url}/api/data/stages" diff --git a/gently/mesh/routes.py b/gently/mesh/routes.py index 1308340e..2f0edd38 100644 --- a/gently/mesh/routes.py +++ b/gently/mesh/routes.py @@ -51,15 +51,21 @@ async def require_mesh_auth(request: Request): if required_scope not in scopes: if audit_log: audit_log.log( - AuditEvent.SCOPE_DENIED, outcome="deny", - peer_id=peer_id, ip=host, + AuditEvent.SCOPE_DENIED, + outcome="deny", + peer_id=peer_id, + ip=host, detail=f"scope={required_scope} path={request.url.path}", ) if viz_server.event_bus is not None: viz_server.event_bus.publish( EventType.MESH_SCOPE_DENIED, - {"peer_id": peer_id, "scope": required_scope, - 
"ip": host, "path": str(request.url.path)}, + { + "peer_id": peer_id, + "scope": required_scope, + "ip": host, + "path": str(request.url.path), + }, source="mesh", ) raise HTTPException( @@ -68,16 +74,20 @@ async def require_mesh_auth(request: Request): ) if audit_log: audit_log.log( - AuditEvent.AUTH_SUCCESS, outcome="allow", - peer_id=peer_id, ip=host, + AuditEvent.AUTH_SUCCESS, + outcome="allow", + peer_id=peer_id, + ip=host, ) return # Auth failed if audit_log: audit_log.log( - AuditEvent.AUTH_FAILURE, outcome="deny", - ip=host, detail=f"path={request.url.path}", + AuditEvent.AUTH_FAILURE, + outcome="deny", + ip=host, + detail=f"path={request.url.path}", ) if viz_server.event_bus is not None: viz_server.event_bus.publish( @@ -106,13 +116,15 @@ async def mesh_status(): shared_list = [] for c in shared: status = cs.get_plan_status(c.id) - shared_list.append({ - "id": c.id, - "shorthand": c.shorthand, - "description": c.description, - "item_count": status["total"], - "completed_count": status["completed"], - }) + shared_list.append( + { + "id": c.id, + "shorthand": c.shorthand, + "description": c.description, + "item_count": status["total"], + "completed_count": status["completed"], + } + ) info["shared_campaigns"] = shared_list except Exception: pass @@ -123,12 +135,17 @@ async def mesh_status(): async def mesh_peers(): """List all discovered peers.""" peers = mesh_service.get_peers() - return JSONResponse({ - "peers": [p.to_dict() for p in peers], - "count": len(peers), - }) + return JSONResponse( + { + "peers": [p.to_dict() for p in peers], + "count": len(peers), + } + ) - @router.get("/api/mesh/peers/{instance_id}", dependencies=[Depends(_make_auth_dep("status"))]) + @router.get( + "/api/mesh/peers/{instance_id}", + dependencies=[Depends(_make_auth_dep("status"))], + ) async def mesh_peer_detail(instance_id: str): """Get specific peer details.""" peer = mesh_service.get_peer(instance_id) @@ -144,11 +161,13 @@ async def mesh_topology(): """Full mesh view: self 
+ all peers.""" local = mesh_service.get_local_info() peers = mesh_service.get_all_peers() - return JSONResponse({ - "self": local, - "peers": [p.to_dict() for p in peers], - "total_nodes": 1 + len(peers), - }) + return JSONResponse( + { + "self": local, + "peers": [p.to_dict() for p in peers], + "total_nodes": 1 + len(peers), + } + ) # ------------------------------------------------------------------ # Pairing endpoints (no auth — these bootstrap trust) @@ -181,15 +200,19 @@ async def pair_request(request: Request): raise HTTPException(status_code=400, detail="initiator_id and nonce required") session = pairing_mgr.handle_pair_request( - initiator_id, hostname, nonce, + initiator_id, + hostname, + nonce, initiator_cert_fingerprint=initiator_cert_fp, initiator_udp_sign_key=initiator_udp_key, ) if audit_log: audit_log.log( - AuditEvent.PAIR_REQUESTED, outcome="info", - peer_id=initiator_id, ip=client_ip, + AuditEvent.PAIR_REQUESTED, + outcome="info", + peer_id=initiator_id, + ip=client_ip, detail=f"hostname={hostname}", ) @@ -205,15 +228,17 @@ async def pair_request(request: Request): source="mesh", ) - return JSONResponse({ - "nonce": session.nonce_responder, - "pairing_id": session.pairing_id, - "status": session.status, - "responder_id": mesh_service.instance_id, - "responder_hostname": mesh_service._hostname, - "cert_fingerprint": pairing_mgr.cert_fingerprint, - "udp_sign_key": pairing_mgr.udp_sign_key, - }) + return JSONResponse( + { + "nonce": session.nonce_responder, + "pairing_id": session.pairing_id, + "status": session.status, + "responder_id": mesh_service.instance_id, + "responder_hostname": mesh_service._hostname, + "cert_fingerprint": pairing_mgr.cert_fingerprint, + "udp_sign_key": pairing_mgr.udp_sign_key, + } + ) @router.get("/api/mesh/pair/{pairing_id}/status") async def pair_status(pairing_id: str): @@ -225,12 +250,14 @@ async def pair_status(pairing_id: str): if session is None: raise HTTPException(status_code=404, detail="Pairing session not 
found") - return JSONResponse({ - "pairing_id": session.pairing_id, - "status": session.status, - "confirmed_by_initiator": session.confirmed_by_initiator, - "confirmed_by_responder": session.confirmed_by_responder, - }) + return JSONResponse( + { + "pairing_id": session.pairing_id, + "status": session.status, + "confirmed_by_initiator": session.confirmed_by_initiator, + "confirmed_by_responder": session.confirmed_by_responder, + } + ) @router.post("/api/mesh/pair/{pairing_id}/confirm") async def pair_confirm(pairing_id: str, request: Request): @@ -267,10 +294,12 @@ async def pair_confirm(pairing_id: str, request: Request): source="mesh", ) - return JSONResponse({ - "pairing_id": session.pairing_id, - "status": session.status, - }) + return JSONResponse( + { + "pairing_id": session.pairing_id, + "status": session.status, + } + ) @router.post("/api/mesh/pair/{pairing_id}/reject") async def pair_reject(pairing_id: str): @@ -282,10 +311,12 @@ async def pair_reject(pairing_id: str): if session is None: raise HTTPException(status_code=404, detail="Pairing session not found") - return JSONResponse({ - "pairing_id": session.pairing_id, - "status": session.status, - }) + return JSONResponse( + { + "pairing_id": session.pairing_id, + "status": session.status, + } + ) # ------------------------------------------------------------------ # Verse map routes (scope: status) @@ -296,12 +327,14 @@ async def verse_map(): """Full persistent topology — includes offline peers.""" vm = mesh_service.verse_map peers = vm.get_all_peers() - return JSONResponse({ - "peers": [p.to_dict() for p in peers], - "online_count": len(vm.get_online_peers()), - "offline_count": len(vm.get_offline_peers()), - "total_count": len(peers), - }) + return JSONResponse( + { + "peers": [p.to_dict() for p in peers], + "online_count": len(vm.get_online_peers()), + "offline_count": len(vm.get_offline_peers()), + "total_count": len(peers), + } + ) @router.get( "/api/mesh/verse-map/resources/{capability}", @@ 
-311,11 +344,13 @@ async def verse_map_resources(capability: str): """Find peers matching a capability (route-finding).""" vm = mesh_service.verse_map peers = vm.find_resource(capability) - return JSONResponse({ - "capability": capability, - "peers": [p.to_dict() for p in peers], - "count": len(peers), - }) + return JSONResponse( + { + "capability": capability, + "peers": [p.to_dict() for p in peers], + "count": len(peers), + } + ) # ------------------------------------------------------------------ # Data catalog routes (scope: data) @@ -334,25 +369,32 @@ async def data_sessions(): sid = s.session_id if hasattr(s, "session_id") else s.get("session_id", "") name = s.name if hasattr(s, "name") else s.get("name", "") created = s.created_at if hasattr(s, "created_at") else s.get("created_at", "") - last_active = s.last_active if hasattr(s, "last_active") else s.get("last_active", "") + last_active = ( + s.last_active if hasattr(s, "last_active") else s.get("last_active", "") + ) embryos = store.list_embryos(sid) vol_count = 0 for e in embryos: eid = e.embryo_id if hasattr(e, "embryo_id") else e.get("embryo_id", "") vol_count += len(store.list_volumes(sid, eid)) - result.append({ - "session_id": sid, - "name": name, - "embryo_count": len(embryos), - "volume_count": vol_count, - "created_at": created, - "last_active": last_active, - }) + result.append( + { + "session_id": sid, + "name": name, + "embryo_count": len(embryos), + "volume_count": vol_count, + "created_at": created, + "last_active": last_active, + } + ) return JSONResponse({"sessions": result, "count": len(result)}) except Exception as e: return JSONResponse({"error": str(e)}, status_code=500) - @router.get("/api/data/sessions/{session_id}", dependencies=[Depends(_make_auth_dep("data"))]) + @router.get( + "/api/data/sessions/{session_id}", + dependencies=[Depends(_make_auth_dep("data"))], + ) async def data_session_detail(session_id: str): """Detailed session info with embryo list.""" store = 
getattr(viz_server, "gently_store", None) @@ -376,26 +418,29 @@ async def data_session_detail(session_id: str): try: gts = store.get_ground_truth(session_id, eid) has_gt = len(gts) > 0 - stages = list({ - (gt.stage if hasattr(gt, "stage") else gt.get("stage", "")) - for gt in gts - }) + stages = list( + {(gt.stage if hasattr(gt, "stage") else gt.get("stage", "")) for gt in gts} + ) except Exception: pass - embryo_list.append({ - "embryo_id": eid, - "nickname": nickname, - "volume_count": vol_count, - "has_ground_truth": has_gt, - "stages_annotated": stages, - }) + embryo_list.append( + { + "embryo_id": eid, + "nickname": nickname, + "volume_count": vol_count, + "has_ground_truth": has_gt, + "stages_annotated": stages, + } + ) sname = session.name if hasattr(session, "name") else session.get("name", "") - return JSONResponse({ - "session_id": session_id, - "name": sname, - "embryos": embryo_list, - "total_volumes": total_vols, - }) + return JSONResponse( + { + "session_id": session_id, + "name": sname, + "embryos": embryo_list, + "total_volumes": total_vols, + } + ) except HTTPException: raise except Exception as e: @@ -406,10 +451,16 @@ async def data_coverage(): """Annotation coverage summary across all sessions.""" store = getattr(viz_server, "gently_store", None) if store is None: - return JSONResponse({ - "total_embryos": 0, "annotated_embryos": 0, - "coverage_pct": 0.0, "stage_counts": {}, "imbalance_ratio": 0.0, "gaps": [], - }) + return JSONResponse( + { + "total_embryos": 0, + "annotated_embryos": 0, + "coverage_pct": 0.0, + "stage_counts": {}, + "imbalance_ratio": 0.0, + "gaps": [], + } + ) try: sessions = store.list_sessions() total_embryos = 0 @@ -437,14 +488,16 @@ async def data_coverage(): # Find stages with notably low counts avg = sum(counts) / len(counts) if counts else 0 gaps = [s for s, c in stage_counts.items() if c < avg * 0.5] - return JSONResponse({ - "total_embryos": total_embryos, - "annotated_embryos": annotated_embryos, - "coverage_pct": 
round(coverage_pct, 1), - "stage_counts": stage_counts, - "imbalance_ratio": round(imbalance_ratio, 2), - "gaps": gaps, - }) + return JSONResponse( + { + "total_embryos": total_embryos, + "annotated_embryos": annotated_embryos, + "coverage_pct": round(coverage_pct, 1), + "stage_counts": stage_counts, + "imbalance_ratio": round(imbalance_ratio, 2), + "gaps": gaps, + } + ) except Exception as e: return JSONResponse({"error": str(e)}, status_code=500) @@ -475,10 +528,12 @@ async def data_stages(): pass if session_dist: by_session[sid] = session_dist - return JSONResponse({ - "stage_distribution": total_dist, - "by_session": by_session, - }) + return JSONResponse( + { + "stage_distribution": total_dist, + "by_session": by_session, + } + ) except Exception as e: return JSONResponse({"error": str(e)}, status_code=500) diff --git a/gently/mesh/tls.py b/gently/mesh/tls.py index 9d1ec01c..412cc485 100644 --- a/gently/mesh/tls.py +++ b/gently/mesh/tls.py @@ -12,7 +12,6 @@ import logging import ssl from pathlib import Path -from typing import Optional, Tuple logger = logging.getLogger(__name__) @@ -21,7 +20,7 @@ CERT_DAYS = 3650 # ~10 years -def ensure_tls_cert(config_dir: Path) -> Tuple[Optional[Path], Optional[Path]]: +def ensure_tls_cert(config_dir: Path) -> tuple[Path | None, Path | None]: """ Ensure a TLS cert/key pair exists in config_dir. 
@@ -51,9 +50,11 @@ def ensure_tls_cert(config_dir: Path) -> Tuple[Optional[Path], Optional[Path]]: private_key = ec.generate_private_key(ec.SECP256R1()) now = datetime.datetime.now(datetime.timezone.utc) - subject = issuer = x509.Name([ - x509.NameAttribute(NameOID.COMMON_NAME, "gently-mesh"), - ]) + subject = issuer = x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, "gently-mesh"), + ] + ) cert = ( x509.CertificateBuilder() @@ -64,20 +65,24 @@ def ensure_tls_cert(config_dir: Path) -> Tuple[Optional[Path], Optional[Path]]: .not_valid_before(now) .not_valid_after(now + datetime.timedelta(days=CERT_DAYS)) .add_extension( - x509.SubjectAlternativeName([ - x509.IPAddress(ipaddress.IPv4Address("0.0.0.0")), - ]), + x509.SubjectAlternativeName( + [ + x509.IPAddress(ipaddress.IPv4Address("0.0.0.0")), + ] + ), critical=False, ) .sign(private_key, hashes.SHA256()) ) # Write PEM files - key_path.write_bytes(private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - )) + key_path.write_bytes( + private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + ) cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM)) logger.info(f"Generated TLS cert: {cert_path}") @@ -85,8 +90,7 @@ def ensure_tls_cert(config_dir: Path) -> Tuple[Optional[Path], Optional[Path]]: except ImportError: logger.warning( - "cryptography package not installed — TLS disabled " - "(pip install cryptography)" + "cryptography package not installed — TLS disabled (pip install cryptography)" ) return None, None except Exception as e: @@ -114,7 +118,8 @@ def get_cert_fingerprint(cert_path: Path) -> str: def build_server_ssl_context( - cert_path: Path, key_path: Path, + cert_path: Path, + key_path: Path, ) -> ssl.SSLContext: """Build an SSL context for the uvicorn/FastAPI server.""" ctx 
= ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) diff --git a/gently/mesh/transfer/__init__.py b/gently/mesh/transfer/__init__.py index d7ec6818..cf8df0fc 100644 --- a/gently/mesh/transfer/__init__.py +++ b/gently/mesh/transfer/__init__.py @@ -4,8 +4,14 @@ Resumable, authenticated transfers for datasets and model weights. """ -from .models import TransferFile, TransferJob, TransferManifest, TransferStatus, TransferType from .client import TransferClient +from .models import ( + TransferFile, + TransferJob, + TransferManifest, + TransferStatus, + TransferType, +) from .server import TransferService from .tracker import TransferTracker diff --git a/gently/mesh/transfer/client.py b/gently/mesh/transfer/client.py index bdcd3507..066aea08 100644 --- a/gently/mesh/transfer/client.py +++ b/gently/mesh/transfer/client.py @@ -3,15 +3,12 @@ """ import asyncio -import hashlib import logging import time import uuid from pathlib import Path -from typing import List, Optional from ...core.event_bus import EventType, get_event_bus -from ...settings import settings from .models import TransferJob, TransferStatus, TransferType from .protocol import send_file @@ -42,7 +39,7 @@ async def send_dataset( peer_ip: str, peer_port: int, peer_instance_id: str, - file_paths: List[Path], + file_paths: list[Path], session_id: str = "", ) -> TransferJob: """Send dataset files to a peer. @@ -87,7 +84,10 @@ async def send_dataset( reader, writer = await asyncio.open_connection(peer_ip, peer_port) try: success, sha256 = await send_file( - writer, file_path, job.id, auth_token=token, + writer, + file_path, + job.id, + auth_token=token, ) if success: job.bytes_transferred += file_path.stat().st_size diff --git a/gently/mesh/transfer/models.py b/gently/mesh/transfer/models.py index c1534eae..a932d69e 100644 --- a/gently/mesh/transfer/models.py +++ b/gently/mesh/transfer/models.py @@ -2,14 +2,14 @@ Transfer data models. 
""" -import time from dataclasses import dataclass, field from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any class TransferType(str, Enum): """Type of data being transferred.""" + DATASET = "dataset" MODEL_WEIGHTS = "model_weights" SESSION = "session" @@ -17,6 +17,7 @@ class TransferType(str, Enum): class TransferStatus(str, Enum): """Transfer state machine.""" + PENDING = "pending" TRANSFERRING = "transferring" PAUSED = "paused" @@ -28,12 +29,13 @@ class TransferStatus(str, Enum): @dataclass class TransferFile: """A single file in a transfer manifest.""" + relative_path: str = "" total_size: int = 0 sha256: str = "" transferred: int = 0 - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "relative_path": self.relative_path, "total_size": self.total_size, @@ -42,7 +44,7 @@ def to_dict(self) -> Dict[str, Any]: } @classmethod - def from_dict(cls, d: Dict[str, Any]) -> "TransferFile": + def from_dict(cls, d: dict[str, Any]) -> "TransferFile": return cls( relative_path=d.get("relative_path", ""), total_size=d.get("total_size", 0), @@ -54,11 +56,12 @@ def from_dict(cls, d: Dict[str, Any]) -> "TransferFile": @dataclass class TransferManifest: """List of files to transfer.""" - files: List[TransferFile] = field(default_factory=list) + + files: list[TransferFile] = field(default_factory=list) total_size: int = 0 file_count: int = 0 - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "files": [f.to_dict() for f in self.files], "total_size": self.total_size, @@ -66,7 +69,7 @@ def to_dict(self) -> Dict[str, Any]: } @classmethod - def from_dict(cls, d: Dict[str, Any]) -> "TransferManifest": + def from_dict(cls, d: dict[str, Any]) -> "TransferManifest": files = [TransferFile.from_dict(f) for f in d.get("files", [])] return cls( files=files, @@ -78,6 +81,7 @@ def from_dict(cls, d: Dict[str, Any]) -> "TransferManifest": @dataclass class TransferJob: """State of a single 
transfer (send or receive).""" + id: str = "" transfer_type: str = TransferType.DATASET.value status: str = TransferStatus.PENDING.value @@ -103,7 +107,7 @@ def progress_pct(self) -> float: return 0.0 return (self.bytes_transferred / self.total_bytes) * 100 - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "id": self.id, "transfer_type": self.transfer_type, @@ -125,7 +129,7 @@ def to_dict(self) -> Dict[str, Any]: } @classmethod - def from_dict(cls, d: Dict[str, Any]) -> "TransferJob": + def from_dict(cls, d: dict[str, Any]) -> "TransferJob": return cls( id=d.get("id", ""), transfer_type=d.get("transfer_type", TransferType.DATASET.value), diff --git a/gently/mesh/transfer/protocol.py b/gently/mesh/transfer/protocol.py index 89af0258..df72c837 100644 --- a/gently/mesh/transfer/protocol.py +++ b/gently/mesh/transfer/protocol.py @@ -22,7 +22,6 @@ import logging import struct from pathlib import Path -from typing import Optional, Tuple from ...settings import settings @@ -39,7 +38,7 @@ async def send_file( auth_token: str = "", offset: int = 0, chunk_size: int = 0, -) -> Tuple[bool, str]: +) -> tuple[bool, str]: """Send a file over a TCP connection. Parameters @@ -108,7 +107,7 @@ async def send_file( async def receive_file( reader: asyncio.StreamReader, dest_dir: Path, -) -> Tuple[Optional[dict], Optional[Path], str]: +) -> tuple[dict | None, Path | None, str]: """Receive a file over a TCP connection. 
Parameters diff --git a/gently/mesh/transfer/server.py b/gently/mesh/transfer/server.py index ff2cfb2b..9ecf0fb9 100644 --- a/gently/mesh/transfer/server.py +++ b/gently/mesh/transfer/server.py @@ -4,15 +4,12 @@ import asyncio import logging -import time import uuid from pathlib import Path -from typing import Optional from ...core.event_bus import EventType, get_event_bus from ...core.service import Service from ...settings import settings -from .models import TransferJob, TransferStatus from .protocol import receive_file logger = logging.getLogger(__name__) @@ -46,7 +43,7 @@ def __init__( self._dest_dir = dest_dir self._pairing_manager = pairing_manager self._port = port - self._server: Optional[asyncio.AbstractServer] = None + self._server: asyncio.AbstractServer | None = None self._active_transfers: dict = {} async def on_start(self): diff --git a/gently/mesh/transfer/tracker.py b/gently/mesh/transfer/tracker.py index 43f52cb2..be896b5f 100644 --- a/gently/mesh/transfer/tracker.py +++ b/gently/mesh/transfer/tracker.py @@ -5,7 +5,6 @@ import json import logging from pathlib import Path -from typing import Dict, List, Optional from .models import TransferJob, TransferStatus @@ -24,7 +23,7 @@ class TransferTracker: def __init__(self, config_dir: Path): self._config_dir = config_dir self._state_file = config_dir / "mesh_transfers.json" - self._jobs: Dict[str, TransferJob] = {} + self._jobs: dict[str, TransferJob] = {} self._load() def _load(self): @@ -65,20 +64,21 @@ def update_job(self, job_id: str, **kwargs): setattr(job, k, v) self._save() - def get_job(self, job_id: str) -> Optional[TransferJob]: + def get_job(self, job_id: str) -> TransferJob | None: """Get a transfer job by ID.""" return self._jobs.get(job_id) - def list_jobs(self, status: Optional[str] = None) -> List[TransferJob]: + def list_jobs(self, status: str | None = None) -> list[TransferJob]: """List all jobs, optionally filtered by status.""" if status: return [j for j in self._jobs.values() if 
j.status == status] return list(self._jobs.values()) - def get_resumable(self) -> List[TransferJob]: + def get_resumable(self) -> list[TransferJob]: """Get transfers that were interrupted and can be resumed.""" return [ - j for j in self._jobs.values() + j + for j in self._jobs.values() if j.status == TransferStatus.TRANSFERRING.value and j.bytes_transferred > 0 and j.bytes_transferred < j.total_bytes @@ -87,9 +87,11 @@ def get_resumable(self) -> List[TransferJob]: def cleanup_completed(self, max_age_hours: float = 24.0): """Remove old completed/failed transfers.""" import time + cutoff = time.time() - (max_age_hours * 3600) to_remove = [ - jid for jid, j in self._jobs.items() + jid + for jid, j in self._jobs.items() if j.status in (TransferStatus.COMPLETED.value, TransferStatus.FAILED.value) and j.completed_at > 0 and j.completed_at < cutoff diff --git a/gently/mesh/verse_map.py b/gently/mesh/verse_map.py index 7309f753..7fd36934 100644 --- a/gently/mesh/verse_map.py +++ b/gently/mesh/verse_map.py @@ -9,13 +9,9 @@ import logging import time from pathlib import Path -from typing import Dict, List, Optional from .models import ( - DatasetAdvertisement, - PeerCapability, PeerInfo, - PeerStatus, PersistedPeer, ) @@ -28,7 +24,7 @@ class VerseMap: def __init__(self, config_dir: Path): self._config_dir = config_dir self._map_file = config_dir / "mesh_verse_map.json" - self._peers: Dict[str, PersistedPeer] = {} + self._peers: dict[str, PersistedPeer] = {} self._load() # ------------------------------------------------------------------ @@ -109,19 +105,19 @@ def on_peer_returned(self, instance_id: str): # Queries # ------------------------------------------------------------------ - def get_all_peers(self) -> List[PersistedPeer]: + def get_all_peers(self) -> list[PersistedPeer]: """All peers, online and offline.""" return list(self._peers.values()) - def get_online_peers(self) -> List[PersistedPeer]: + def get_online_peers(self) -> list[PersistedPeer]: """Only online 
peers.""" return [p for p in self._peers.values() if p.online] - def get_offline_peers(self) -> List[PersistedPeer]: + def get_offline_peers(self) -> list[PersistedPeer]: """Only offline peers.""" return [p for p in self._peers.values() if not p.online] - def get_peer(self, instance_id: str) -> Optional[PersistedPeer]: + def get_peer(self, instance_id: str) -> PersistedPeer | None: """Get a specific peer by instance_id.""" return self._peers.get(instance_id) @@ -138,27 +134,25 @@ def was_online(self, instance_id: str) -> bool: # Route-finding: sorted online-first, then by last_seen recency # ------------------------------------------------------------------ - def _sorted_peers(self, peers: List[PersistedPeer]) -> List[PersistedPeer]: + def _sorted_peers(self, peers: list[PersistedPeer]) -> list[PersistedPeer]: """Sort peers: online first, then by last_seen descending.""" return sorted(peers, key=lambda p: (not p.online, -p.last_seen)) - def find_gpu_peers(self) -> List[PersistedPeer]: + def find_gpu_peers(self) -> list[PersistedPeer]: """Find peers with GPU capability, best candidates first.""" - results = [ - p for p in self._peers.values() - if p.capabilities.has_gpu or p.capabilities.gpus - ] + results = [p for p in self._peers.values() if p.capabilities.has_gpu or p.capabilities.gpus] return self._sorted_peers(results) - def find_microscope_peers(self) -> List[PersistedPeer]: + def find_microscope_peers(self) -> list[PersistedPeer]: """Find peers with microscope capability.""" results = [ - p for p in self._peers.values() + p + for p in self._peers.values() if p.capabilities.has_microscope or p.capabilities.microscope_connected ] return self._sorted_peers(results) - def find_data_peers(self, session_id: str = None) -> List[PersistedPeer]: + def find_data_peers(self, session_id: str | None = None) -> list[PersistedPeer]: """Find peers with data, optionally filtering by session.""" results = [] for p in self._peers.values(): @@ -171,7 +165,7 @@ def 
find_data_peers(self, session_id: str = None) -> List[PersistedPeer]: results.append(p) return self._sorted_peers(results) - def find_resource(self, capability: str) -> List[PersistedPeer]: + def find_resource(self, capability: str) -> list[PersistedPeer]: """Find peers matching a generic capability attribute. The capability string is checked against: diff --git a/gently/ml/__init__.py b/gently/ml/__init__.py index 9c817a33..7b4c2dc1 100644 --- a/gently/ml/__init__.py +++ b/gently/ml/__init__.py @@ -9,6 +9,7 @@ - Federated averaging for distributed training """ +from .architectures import ARCHITECTURE_REGISTRY, get_suitable_architectures from .models import ( DataSplit, MLPipeline, @@ -18,7 +19,6 @@ TrainingRun, TrainingStatus, ) -from .architectures import ARCHITECTURE_REGISTRY, get_suitable_architectures __all__ = [ "ARCHITECTURE_REGISTRY", diff --git a/gently/ml/_train_worker.py b/gently/ml/_train_worker.py index bc905a87..04448de2 100644 --- a/gently/ml/_train_worker.py +++ b/gently/ml/_train_worker.py @@ -43,7 +43,7 @@ def main(): sys.exit(1) try: - import torchvision.models as models + import torchvision.models as models # noqa: F401 except ImportError: _write_progress(progress_file, {"error": "torchvision not installed"}) sys.exit(1) @@ -56,7 +56,7 @@ def main(): sys.exit(1) # Build datasets - from gently.ml.data_loader import GentlyDataset, create_data_splits + from gently.ml.data_loader import create_data_splits architecture = model_config.get("architecture", "resnet18") num_classes = model_config.get("num_classes", 8) @@ -75,8 +75,11 @@ def main(): # Create datasets from labels train_data, val_data, test_data = create_data_splits( - labels_data, data_root, input_size, - train_ratio=0.7, val_ratio=0.15, + labels_data, + data_root, + input_size, + train_ratio=0.7, + val_ratio=0.15, ) train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2) @@ -171,17 +174,20 @@ def main(): patience_counter += 1 # Write progress - 
_write_progress(progress_file, { - "epoch": epoch + 1, - "total_epochs": epochs, - "train_loss": round(train_loss, 4), - "train_accuracy": round(train_acc, 4), - "val_loss": round(val_loss, 4), - "val_accuracy": round(val_acc, 4), - "best_val_accuracy": round(best_val_acc, 4), - "lr": optimizer.param_groups[0]["lr"], - "timestamp": time.time(), - }) + _write_progress( + progress_file, + { + "epoch": epoch + 1, + "total_epochs": epochs, + "train_loss": round(train_loss, 4), + "train_accuracy": round(train_acc, 4), + "val_loss": round(val_loss, 4), + "val_accuracy": round(val_acc, 4), + "best_val_accuracy": round(best_val_acc, 4), + "lr": optimizer.param_groups[0]["lr"], + "timestamp": time.time(), + }, + ) # Early stopping if patience_counter >= early_stopping_patience: @@ -218,7 +224,12 @@ def _build_model(architecture, num_classes, pretrained, input_channels, dropout) if input_channels != 3: old_conv = model.conv1 model.conv1 = nn.Conv2d( - input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False, + input_channels, + 64, + kernel_size=7, + stride=2, + padding=3, + bias=False, ) if pretrained and input_channels == 1: # Average RGB weights for grayscale @@ -241,7 +252,8 @@ def _build_model(architecture, num_classes, pretrained, input_channels, dropout) old_conv = model.features[0][0] out_channels = old_conv.out_channels model.features[0][0] = nn.Conv2d( - input_channels, out_channels, + input_channels, + out_channels, kernel_size=old_conv.kernel_size, stride=old_conv.stride, padding=old_conv.padding, @@ -260,7 +272,8 @@ def _build_model(architecture, num_classes, pretrained, input_channels, dropout) old_conv = model.features[0][0] out_channels = old_conv.out_channels model.features[0][0] = nn.Conv2d( - input_channels, out_channels, + input_channels, + out_channels, kernel_size=old_conv.kernel_size, stride=old_conv.stride, padding=old_conv.padding, @@ -278,7 +291,8 @@ def _build_model(architecture, num_classes, pretrained, input_channels, dropout) old_conv 
= model.features[0][0] out_channels = old_conv.out_channels model.features[0][0] = nn.Conv2d( - input_channels, out_channels, + input_channels, + out_channels, kernel_size=old_conv.kernel_size, stride=old_conv.stride, padding=old_conv.padding, diff --git a/gently/ml/architectures.py b/gently/ml/architectures.py index 41f4fe77..2c6f91a8 100644 --- a/gently/ml/architectures.py +++ b/gently/ml/architectures.py @@ -5,12 +5,12 @@ for a given task, dataset size, and hardware constraints. """ -from typing import Any, Dict, List +from typing import Any from .models import ModelArchitectureType # Architecture registry: metadata per architecture -ARCHITECTURE_REGISTRY: Dict[str, Dict[str, Any]] = { +ARCHITECTURE_REGISTRY: dict[str, dict[str, Any]] = { ModelArchitectureType.RESNET_18.value: { "name": "ResNet-18", "family": "resnet", @@ -147,7 +147,7 @@ def get_suitable_architectures( dataset_size: int, vram_gb: float, image_type: str = "microscopy", -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """Filter architectures suitable for given constraints. Parameters diff --git a/gently/ml/data_loader.py b/gently/ml/data_loader.py index 082fae1e..ccf080b0 100644 --- a/gently/ml/data_loader.py +++ b/gently/ml/data_loader.py @@ -2,11 +2,10 @@ GentlyDataset — PyTorch Dataset loading projections + ground_truth from FileStore. 
""" -import json import logging import random from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Any logger = logging.getLogger(__name__) @@ -14,9 +13,11 @@ import numpy as np import torch from torch.utils.data import Dataset + HAS_TORCH = True except ImportError: HAS_TORCH = False + # Stub for import-time safety class Dataset: pass @@ -37,7 +38,7 @@ class GentlyDataset(Dataset): def __init__( self, - samples: List[Tuple[str, int]], + samples: list[tuple[str, int]], input_size: int = 224, augment: bool = False, ): @@ -57,6 +58,7 @@ def __getitem__(self, idx): # Load image try: from PIL import Image + img = Image.open(img_path).convert("L") # grayscale img = img.resize((self.input_size, self.input_size)) img_np = np.array(img, dtype=np.float32) / 255.0 @@ -92,13 +94,13 @@ def _apply_augmentations(self, img: np.ndarray) -> np.ndarray: def create_data_splits( - labels_data: Dict[str, Any], + labels_data: dict[str, Any], data_root: Path, input_size: int = 224, train_ratio: float = 0.7, val_ratio: float = 0.15, random_seed: int = 42, -) -> Tuple: +) -> tuple: """Create train/val/test datasets from a labels file. 
Parameters @@ -141,15 +143,15 @@ def create_data_splits( val_samples = [] test_samples = [] - for label, items in by_label.items(): + for _label, items in by_label.items(): random.shuffle(items) n = len(items) n_train = max(1, int(n * train_ratio)) n_val = max(1, int(n * val_ratio)) train_samples.extend(items[:n_train]) - val_samples.extend(items[n_train:n_train + n_val]) - test_samples.extend(items[n_train + n_val:]) + val_samples.extend(items[n_train : n_train + n_val]) + test_samples.extend(items[n_train + n_val :]) train_ds = GentlyDataset(train_samples, input_size=input_size, augment=True) val_ds = GentlyDataset(val_samples, input_size=input_size, augment=False) @@ -158,7 +160,7 @@ def create_data_splits( return train_ds, val_ds, test_ds -def build_labels_from_store(gently_store, session_ids: Optional[List[str]] = None) -> Dict: +def build_labels_from_store(gently_store, session_ids: list[str] | None = None) -> dict: """Build a labels dict from FileStore ground truth. Returns @@ -193,13 +195,15 @@ def build_labels_from_store(gently_store, session_ids: Optional[List[str]] = Non try: proj_path = gently_store.get_projection_path(sid, eid, start_tp) if proj_path: - samples.append({ - "path": str(proj_path), - "label": stage_to_idx[stage], - "session_id": sid, - "embryo_id": eid, - "stage": stage, - }) + samples.append( + { + "path": str(proj_path), + "label": stage_to_idx[stage], + "session_id": sid, + "embryo_id": eid, + "stage": stage, + } + ) except Exception: pass except Exception: diff --git a/gently/ml/evaluation.py b/gently/ml/evaluation.py index b03f35b5..deffce4d 100644 --- a/gently/ml/evaluation.py +++ b/gently/ml/evaluation.py @@ -6,7 +6,7 @@ import logging from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any logger = logging.getLogger(__name__) @@ -14,17 +14,18 @@ @dataclass class EvaluationReport: """Complete evaluation report for a trained model.""" + run_id: str = "" 
accuracy: float = 0.0 - per_stage_precision: Dict[str, float] = field(default_factory=dict) - per_stage_recall: Dict[str, float] = field(default_factory=dict) - per_stage_f1: Dict[str, float] = field(default_factory=dict) - confusion_matrix: List[List[int]] = field(default_factory=list) - class_names: List[str] = field(default_factory=list) + per_stage_precision: dict[str, float] = field(default_factory=dict) + per_stage_recall: dict[str, float] = field(default_factory=dict) + per_stage_f1: dict[str, float] = field(default_factory=dict) + confusion_matrix: list[list[int]] = field(default_factory=list) + class_names: list[str] = field(default_factory=list) total_samples: int = 0 correct: int = 0 - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "run_id": self.run_id, "accuracy": self.accuracy, @@ -38,7 +39,7 @@ def to_dict(self) -> Dict[str, Any]: } @classmethod - def from_dict(cls, d: Dict[str, Any]) -> "EvaluationReport": + def from_dict(cls, d: dict[str, Any]) -> "EvaluationReport": return cls( run_id=d.get("run_id", ""), accuracy=d.get("accuracy", 0.0), @@ -66,7 +67,7 @@ def summary(self) -> str: def evaluate_model( model, data_loader, - class_names: List[str], + class_names: list[str], device=None, run_id: str = "", ) -> EvaluationReport: @@ -111,7 +112,7 @@ def evaluate_model( outputs = model(batch_x) _, predicted = outputs.max(1) - for true, pred in zip(batch_y.cpu().tolist(), predicted.cpu().tolist()): + for true, pred in zip(batch_y.cpu().tolist(), predicted.cpu().tolist(), strict=False): cm[true][pred] += 1 total += 1 if true == pred: diff --git a/gently/ml/federated.py b/gently/ml/federated.py index c114b18e..f257bbd5 100644 --- a/gently/ml/federated.py +++ b/gently/ml/federated.py @@ -12,11 +12,9 @@ import asyncio import copy -import json import logging -import time from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Any from ..core.event_bus import EventType, 
get_event_bus @@ -24,9 +22,9 @@ def federated_average( - state_dicts: List[Dict[str, Any]], - weights: List[float], -) -> Dict[str, Any]: + state_dicts: list[dict[str, Any]], + weights: list[float], +) -> dict[str, Any]: """Compute weighted average of model state dicts. Parameters @@ -50,7 +48,7 @@ def federated_average( try: import torch except ImportError: - raise ImportError("PyTorch required for federated averaging") + raise ImportError("PyTorch required for federated averaging") from None total_weight = sum(weights) if total_weight == 0: @@ -66,7 +64,7 @@ def federated_average( averaged[key] = torch.zeros_like(state_dicts[0][key], dtype=torch.float32) # Weighted sum - for sd, w in zip(state_dicts, norm_weights): + for sd, w in zip(state_dicts, norm_weights, strict=False): for key in averaged: averaged[key] += sd[key].float() * w @@ -98,14 +96,14 @@ def __init__(self, verse_map, transfer_client=None, peer_client=None): async def run_federated_training( self, pipeline_id: str, - worker_peers: List, + worker_peers: list, initial_weights_path: Path, local_epochs_per_round: int = 5, max_rounds: int = 20, convergence_threshold: float = 0.001, - training_config: Optional[Dict] = None, - model_config: Optional[Dict] = None, - ) -> Dict[str, Any]: + training_config: dict | None = None, + model_config: dict | None = None, + ) -> dict[str, Any]: """Run federated averaging across mesh peers. Parameters @@ -175,10 +173,12 @@ async def run_federated_training( # 3. 
Federated average state_dicts = [r["state_dict"] for r in worker_results if r.get("state_dict")] - dataset_sizes = [r.get("dataset_size", 1) for r in worker_results if r.get("state_dict")] + dataset_sizes = [ + r.get("dataset_size", 1) for r in worker_results if r.get("state_dict") + ] if state_dicts: - global_state = federated_average(state_dicts, dataset_sizes) + federated_average(state_dicts, dataset_sizes) else: logger.warning(f"Round {round_num}: no state dicts to average") continue @@ -234,13 +234,13 @@ async def run_federated_training( async def _train_workers( self, - workers: List, + workers: list, pipeline_id: str, round_num: int, local_epochs: int, - training_config: Optional[Dict], - model_config: Optional[Dict], - ) -> List[Dict]: + training_config: dict | None, + model_config: dict | None, + ) -> list[dict]: """Send training jobs to all workers and collect results. In production this uses PeerClient to POST /api/ml/train on each @@ -253,18 +253,20 @@ async def _train_workers( for worker in workers: tasks.append( self._train_single_worker( - worker, pipeline_id, round_num, local_epochs, - training_config, model_config, + worker, + pipeline_id, + round_num, + local_epochs, + training_config, + model_config, ) ) completed = await asyncio.gather(*tasks, return_exceptions=True) - for worker, result in zip(workers, completed): + for worker, result in zip(workers, completed, strict=False): if isinstance(result, Exception): - logger.warning( - f"Worker {worker.hostname} failed in round {round_num}: {result}" - ) + logger.warning(f"Worker {worker.hostname} failed in round {round_num}: {result}") continue if result: results.append(result) @@ -277,9 +279,9 @@ async def _train_single_worker( pipeline_id: str, round_num: int, local_epochs: int, - training_config: Optional[Dict], - model_config: Optional[Dict], - ) -> Optional[Dict]: + training_config: dict | None, + model_config: dict | None, + ) -> dict | None: """Train on a single worker peer via HTTP API. 
Returns worker result dict with state_dict, val_accuracy, dataset_size. @@ -289,7 +291,8 @@ async def _train_single_worker( # Build a PeerInfo for the HTTP client from ..models import PeerInfo - peer = PeerInfo( + + PeerInfo( instance_id=worker.instance_id, hostname=worker.hostname, ip_address=worker.ip_address, diff --git a/gently/ml/models.py b/gently/ml/models.py index 0a5b5c00..dcf689ab 100644 --- a/gently/ml/models.py +++ b/gently/ml/models.py @@ -4,11 +4,12 @@ from dataclasses import dataclass, field from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any class TrainingStatus(str, Enum): """Status of an ML pipeline or training run.""" + PLANNED = "planned" DATA_PREP = "data_prep" TRAINING = "training" @@ -20,6 +21,7 @@ class TrainingStatus(str, Enum): class ModelArchitectureType(str, Enum): """Supported model architecture families.""" + RESNET_18 = "resnet18" RESNET_50 = "resnet50" EFFICIENTNET_B0 = "efficientnet_b0" @@ -33,6 +35,7 @@ class ModelArchitectureType(str, Enum): @dataclass class ModelConfig: """Configuration for a model architecture.""" + architecture: str = "resnet18" num_classes: int = 8 pretrained: bool = True @@ -41,7 +44,7 @@ class ModelConfig: dropout: float = 0.2 freeze_backbone_epochs: int = 5 # freeze backbone for N epochs, then unfreeze - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "architecture": self.architecture, "num_classes": self.num_classes, @@ -53,7 +56,7 @@ def to_dict(self) -> Dict[str, Any]: } @classmethod - def from_dict(cls, d: Dict[str, Any]) -> "ModelConfig": + def from_dict(cls, d: dict[str, Any]) -> "ModelConfig": return cls( architecture=d.get("architecture", "resnet18"), num_classes=d.get("num_classes", 8), @@ -68,6 +71,7 @@ def from_dict(cls, d: Dict[str, Any]) -> "ModelConfig": @dataclass class TrainingConfig: """Training hyperparameters.""" + batch_size: int = 32 epochs: int = 50 learning_rate: float = 1e-4 @@ -76,13 +80,15 @@ class 
TrainingConfig: warmup_epochs: int = 5 mixed_precision: bool = True # AMP on A5000 early_stopping_patience: int = 10 - augmentations: List[str] = field(default_factory=lambda: [ - "random_horizontal_flip", - "random_rotation", - "random_brightness", - ]) + augmentations: list[str] = field( + default_factory=lambda: [ + "random_horizontal_flip", + "random_rotation", + "random_brightness", + ] + ) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "batch_size": self.batch_size, "epochs": self.epochs, @@ -96,7 +102,7 @@ def to_dict(self) -> Dict[str, Any]: } @classmethod - def from_dict(cls, d: Dict[str, Any]) -> "TrainingConfig": + def from_dict(cls, d: dict[str, Any]) -> "TrainingConfig": return cls( batch_size=d.get("batch_size", 32), epochs=d.get("epochs", 50), @@ -113,14 +119,15 @@ def from_dict(cls, d: Dict[str, Any]) -> "TrainingConfig": @dataclass class DataSplit: """Defines how data is split for training.""" + train_ratio: float = 0.7 val_ratio: float = 0.15 test_ratio: float = 0.15 stratify_by: str = "stage" # stratify splits by stage label - session_ids: List[str] = field(default_factory=list) + session_ids: list[str] = field(default_factory=list) random_seed: int = 42 - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "train_ratio": self.train_ratio, "val_ratio": self.val_ratio, @@ -131,7 +138,7 @@ def to_dict(self) -> Dict[str, Any]: } @classmethod - def from_dict(cls, d: Dict[str, Any]) -> "DataSplit": + def from_dict(cls, d: dict[str, Any]) -> "DataSplit": return cls( train_ratio=d.get("train_ratio", 0.7), val_ratio=d.get("val_ratio", 0.15), @@ -145,12 +152,13 @@ def from_dict(cls, d: Dict[str, Any]) -> "DataSplit": @dataclass class TrainingRun: """State of a single training run.""" + id: str = "" pipeline_id: str = "" status: str = TrainingStatus.PLANNED.value - model_config: Optional[ModelConfig] = None - training_config: Optional[TrainingConfig] = None - data_split: 
Optional[DataSplit] = None + model_config: ModelConfig | None = None + training_config: TrainingConfig | None = None + data_split: DataSplit | None = None current_epoch: int = 0 total_epochs: int = 0 train_loss: float = 0.0 @@ -164,7 +172,7 @@ class TrainingRun: completed_at: str = "" error_message: str = "" - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "id": self.id, "pipeline_id": self.pipeline_id, @@ -187,7 +195,7 @@ def to_dict(self) -> Dict[str, Any]: } @classmethod - def from_dict(cls, d: Dict[str, Any]) -> "TrainingRun": + def from_dict(cls, d: dict[str, Any]) -> "TrainingRun": mc = d.get("model_config") tc = d.get("training_config") ds = d.get("data_split") @@ -216,20 +224,21 @@ def from_dict(cls, d: Dict[str, Any]) -> "TrainingRun": @dataclass class MLPipeline: """Top-level pipeline that coordinates one ML task.""" + id: str = "" campaign_id: str = "" name: str = "" task: str = "embryo_stage_classification" status: str = TrainingStatus.PLANNED.value - model_config: Optional[ModelConfig] = None - data_split: Optional[DataSplit] = None - training_config: Optional[TrainingConfig] = None + model_config: ModelConfig | None = None + data_split: DataSplit | None = None + training_config: TrainingConfig | None = None best_run_id: str = "" best_accuracy: float = 0.0 created_at: str = "" updated_at: str = "" - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "id": self.id, "campaign_id": self.campaign_id, @@ -246,7 +255,7 @@ def to_dict(self) -> Dict[str, Any]: } @classmethod - def from_dict(cls, d: Dict[str, Any]) -> "MLPipeline": + def from_dict(cls, d: dict[str, Any]) -> "MLPipeline": mc = d.get("model_config") ds = d.get("data_split") tc = d.get("training_config") diff --git a/gently/ml/trainer.py b/gently/ml/trainer.py index 67da88cb..2ef764f5 100644 --- a/gently/ml/trainer.py +++ b/gently/ml/trainer.py @@ -9,14 +9,12 @@ import asyncio import json import logging -import os import sys 
from datetime import datetime from pathlib import Path -from typing import Optional from ..core.event_bus import EventType, get_event_bus -from .models import ModelConfig, TrainingConfig, TrainingRun, TrainingStatus +from .models import TrainingRun, TrainingStatus logger = logging.getLogger(__name__) @@ -33,8 +31,8 @@ class LocalTrainer: def __init__(self, run_dir: Path): self._run_dir = run_dir self._run_dir.mkdir(parents=True, exist_ok=True) - self._process: Optional[asyncio.subprocess.Process] = None - self._monitor_task: Optional[asyncio.Task] = None + self._process: asyncio.subprocess.Process | None = None + self._monitor_task: asyncio.Task | None = None @property def progress_file(self) -> Path: @@ -95,16 +93,16 @@ async def start_training( # Launch subprocess train_script = Path(__file__).parent / "_train_worker.py" self._process = await asyncio.create_subprocess_exec( - sys.executable, str(train_script), str(config_file), + sys.executable, + str(train_script), + str(config_file), stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, cwd=str(self._run_dir), ) # Start progress monitor - self._monitor_task = asyncio.create_task( - self._monitor_progress(run.id, run.pipeline_id) - ) + self._monitor_task = asyncio.create_task(self._monitor_progress(run.id, run.pipeline_id)) logger.info(f"Training started: run={run.id}, pid={self._process.pid}") return run @@ -175,7 +173,7 @@ async def cancel(self): if self._monitor_task and not self._monitor_task.done(): self._monitor_task.cancel() - def get_latest_progress(self) -> Optional[dict]: + def get_latest_progress(self) -> dict | None: """Read the last line from progress.jsonl.""" if not self.progress_file.exists(): return None diff --git a/gently/organisms/__init__.py b/gently/organisms/__init__.py index d388ab60..1db2c317 100644 --- a/gently/organisms/__init__.py +++ b/gently/organisms/__init__.py @@ -14,12 +14,23 @@ import importlib import logging +import pkgutil from types import ModuleType -from typing 
import Optional logger = logging.getLogger(__name__) -_active_organism: Optional[ModuleType] = None +_active_organism: ModuleType | None = None + + +def available_organisms() -> list[str]: + """Names of the organism plugins shipped under gently.organisms.""" + import gently.organisms as _pkg + + return sorted( + m.name + for m in pkgutil.iter_modules(_pkg.__path__) + if m.ispkg and not m.name.startswith("_") + ) def load_organism(name: str) -> ModuleType: @@ -43,7 +54,19 @@ def load_organism(name: str) -> ModuleType: If the organism module cannot be found. """ global _active_organism - module = importlib.import_module(f"gently.organisms.{name}") + try: + module = importlib.import_module(f"gently.organisms.{name}") + except ModuleNotFoundError as e: + # Only treat a missing organism *package* as a config error; if a + # dependency *inside* the organism module is missing, re-raise so the + # real ImportError isn't masked. + if e.name in (f"gently.organisms.{name}", name): + avail = ", ".join(available_organisms()) or "(none found)" + raise ValueError( + f"Unknown organism '{name}'. Available: {avail}. " + f"Set 'organism:' in config/config.yml." + ) from e + raise _active_organism = module logger.info("Loaded organism module: %s (%s)", name, module.ORGANISM_DISPLAY_NAME) return module @@ -60,7 +83,6 @@ def get_organism() -> ModuleType: """ if _active_organism is None: raise RuntimeError( - "No organism loaded. Call load_organism() at startup, " - "or set 'organism' in config.yml." + "No organism loaded. Call load_organism() at startup, or set 'organism' in config.yml." 
) return _active_organism diff --git a/gently/organisms/celegans/__init__.py b/gently/organisms/celegans/__init__.py index b8d353a6..afd4db01 100644 --- a/gently/organisms/celegans/__init__.py +++ b/gently/organisms/celegans/__init__.py @@ -10,21 +10,45 @@ from pathlib import Path +from .biology import BIOLOGY_KNOWLEDGE +from .detection_defaults import DETECTION_DEFAULTS +from .detector_presets import get_detector_presets +from .perception_prompt import PERCEPTION_SYSTEM_PROMPT from .stages import ( - DevelopmentalStage, - STAGES, STAGE_CRITERIA, + STAGES, TRANSITION_ZONES, - get_transition_zone, - get_adjacent_stages, - get_stage_description, + DevelopmentalStage, format_stage_criteria_for_prompt, + get_adjacent_stages, get_all_criteria_for_prompt, + get_stage_description, + get_transition_zone, ) -from .biology import BIOLOGY_KNOWLEDGE -from .detector_presets import get_detector_presets -from .detection_defaults import DETECTION_DEFAULTS -from .perception_prompt import PERCEPTION_SYSTEM_PROMPT + +__all__ = [ + "BIOLOGY_KNOWLEDGE", + "DETECTION_DEFAULTS", + "get_detector_presets", + "PERCEPTION_SYSTEM_PROMPT", + "STAGE_CRITERIA", + "STAGES", + "TRANSITION_ZONES", + "DevelopmentalStage", + "format_stage_criteria_for_prompt", + "get_adjacent_stages", + "get_all_criteria_for_prompt", + "get_stage_description", + "get_transition_zone", + "ORGANISM_NAME", + "ORGANISM_DISPLAY_NAME", + "SAMPLE_TERM", + "SAMPLE_TERM_PLURAL", + "TERMINAL_STAGES", + "STOP_CONDITIONS", + "PRE_TERMINAL_SPEEDUP_STAGE", + "EXAMPLES_PATH", +] # --- Organism identity --- ORGANISM_NAME = "celegans" diff --git a/gently/organisms/celegans/biology.py b/gently/organisms/celegans/biology.py index be2f44f1..6c47ae70 100644 --- a/gently/organisms/celegans/biology.py +++ b/gently/organisms/celegans/biology.py @@ -8,7 +8,8 @@ BIOLOGY_KNOWLEDGE = """ # C. elegans Embryonic Development -C. elegans embryogenesis is highly stereotyped and invariant, proceeding through well-defined stages: +C. 
elegans embryogenesis is highly stereotyped and invariant, proceeding through +well-defined stages: ## Key Developmental Stages diff --git a/gently/organisms/celegans/detector_presets.py b/gently/organisms/celegans/detector_presets.py index b2a0be4d..30d8d191 100644 --- a/gently/organisms/celegans/detector_presets.py +++ b/gently/organisms/celegans/detector_presets.py @@ -5,10 +5,8 @@ (hatching, comma stage, pretzel, gastrulation, first division). """ -from typing import Dict - -def get_detector_presets() -> Dict: +def get_detector_presets() -> dict: """ Get predefined detector presets for common C. elegans stages. @@ -18,10 +16,11 @@ def get_detector_presets() -> Dict: Preset detector configurations keyed by event name. """ return { - 'hatching': { - 'name': 'hatching', - 'description': 'Detects when C. elegans embryo hatches from eggshell', - 'prompt': """Analyze this C. elegans embryo image (diSPIM light sheet max projection) and determine if the embryo has HATCHED. + "hatching": { + "name": "hatching", + "description": "Detects when C. elegans embryo hatches from eggshell", + "prompt": """Analyze this C. elegans embryo image (diSPIM light sheet max +projection) and determine if the embryo has HATCHED. TRUE HATCHING looks like (must meet at least one): - Most or all of the worm body is OUTSIDE the eggshell boundary @@ -46,16 +45,16 @@ def get_detector_presets() -> Dict: DETECTED: [YES/NO] CONFIDENCE: [HIGH/MEDIUM/LOW] REASONING: [Brief explanation - specifically state if worm is INSIDE or OUTSIDE the shell]""", - 'use_temporal_context': True, - 'temporal_context_size': 10, - 'confidence_threshold': 'HIGH', - 'stop_timelapse': True, # Auto-stop when hatching detected + "use_temporal_context": True, + "temporal_context_size": 10, + "confidence_threshold": "HIGH", + "stop_timelapse": True, # Auto-stop when hatching detected }, - - 'comma': { - 'name': 'comma', - 'description': 'Detects comma stage (major morphogenesis)', - 'prompt': """Analyze this C. 
elegans embryo and determine if it has reached the COMMA STAGE. + "comma": { + "name": "comma", + "description": "Detects comma stage (major morphogenesis)", + "prompt": """Analyze this C. elegans embryo and determine if it has reached the +COMMA STAGE. Key characteristics of comma stage (~400 minutes, ~6.5 hours): - Distinct comma or bean shape (ventral curvature) @@ -70,15 +69,15 @@ def get_detector_presets() -> Dict: DETECTED: [YES/NO] CONFIDENCE: [HIGH/MEDIUM/LOW] REASONING: [Brief explanation]""", - 'use_temporal_context': True, - 'temporal_context_size': 5, - 'confidence_threshold': 'MEDIUM' + "use_temporal_context": True, + "temporal_context_size": 5, + "confidence_threshold": "MEDIUM", }, - - 'pretzel': { - 'name': 'pretzel', - 'description': 'Detects pretzel/3-fold stage (highly elongated)', - 'prompt': """Analyze this C. elegans embryo and determine if it has reached the PRETZEL/3-FOLD STAGE. + "pretzel": { + "name": "pretzel", + "description": "Detects pretzel/3-fold stage (highly elongated)", + "prompt": """Analyze this C. elegans embryo and determine if it has reached the +PRETZEL/3-FOLD STAGE. Key characteristics of 3-fold stage (~550 minutes, ~9 hours): - Highly elongated, approximately 3x the eggshell length @@ -93,15 +92,14 @@ def get_detector_presets() -> Dict: DETECTED: [YES/NO] CONFIDENCE: [HIGH/MEDIUM/LOW] REASONING: [Brief explanation]""", - 'use_temporal_context': True, - 'temporal_context_size': 5, - 'confidence_threshold': 'MEDIUM' + "use_temporal_context": True, + "temporal_context_size": 5, + "confidence_threshold": "MEDIUM", }, - - 'gastrulation': { - 'name': 'gastrulation', - 'description': 'Detects onset of gastrulation', - 'prompt': """Analyze this C. elegans embryo and determine if GASTRULATION has begun. + "gastrulation": { + "name": "gastrulation", + "description": "Detects onset of gastrulation", + "prompt": """Analyze this C. elegans embryo and determine if GASTRULATION has begun. 
Key characteristics of gastrulation (~210 minutes, ~3.5 hours): - Visible internalization of cells (especially E cells - gut precursors) @@ -115,15 +113,15 @@ def get_detector_presets() -> Dict: DETECTED: [YES/NO] CONFIDENCE: [HIGH/MEDIUM/LOW] REASONING: [Brief explanation]""", - 'use_temporal_context': True, - 'temporal_context_size': 5, - 'confidence_threshold': 'MEDIUM' + "use_temporal_context": True, + "temporal_context_size": 5, + "confidence_threshold": "MEDIUM", }, - - 'first_division': { - 'name': 'first_division', - 'description': 'Detects first cell division (1-cell to 2-cell)', - 'prompt': """Analyze this C. elegans embryo and determine if FIRST CELL DIVISION has occurred. + "first_division": { + "name": "first_division", + "description": "Detects first cell division (1-cell to 2-cell)", + "prompt": """Analyze this C. elegans embryo and determine if FIRST CELL DIVISION +has occurred. Key characteristics: - Transition from single large cell to two cells @@ -137,8 +135,8 @@ def get_detector_presets() -> Dict: DETECTED: [YES/NO] CONFIDENCE: [HIGH/MEDIUM/LOW] REASONING: [Brief explanation]""", - 'use_temporal_context': True, - 'temporal_context_size': 3, - 'confidence_threshold': 'HIGH' + "use_temporal_context": True, + "temporal_context_size": 3, + "confidence_threshold": "HIGH", }, } diff --git a/gently/organisms/celegans/developmental_tracker.py b/gently/organisms/celegans/developmental_tracker.py index bd10a01e..1750ba36 100644 --- a/gently/organisms/celegans/developmental_tracker.py +++ b/gently/organisms/celegans/developmental_tracker.py @@ -11,8 +11,8 @@ import logging from dataclasses import dataclass, field from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional, Tuple from enum import Enum +from typing import Any import anthropic @@ -29,6 +29,7 @@ class DevelopmentalStage(str, Enum): distinguish. The perception enum (in stages.py) maps "early" to everything before comma. 
""" + ONE_CELL = "1-cell" TWO_CELL = "2-cell" FOUR_CELL = "4-cell" @@ -97,12 +98,13 @@ class DevelopmentalStage(str, Enum): @dataclass class HatchingPrediction: """Prediction of time to hatching with confidence interval""" + embryo_id: str current_stage: DevelopmentalStage predicted_minutes: int min_minutes: int # Lower bound (optimistic) max_minutes: int # Upper bound (conservative) - confidence: str # Based on stage classification confidence + confidence: str # Based on stage classification confidence timestamp: datetime = field(default_factory=datetime.now) @property @@ -110,20 +112,20 @@ def predicted_hours(self) -> float: return self.predicted_minutes / 60 @property - def range_hours(self) -> Tuple[float, float]: + def range_hours(self) -> tuple[float, float]: return (self.min_minutes / 60, self.max_minutes / 60) - def to_dict(self) -> Dict: + def to_dict(self) -> dict: return { - 'embryo_id': self.embryo_id, - 'current_stage': self.current_stage.value, - 'predicted_minutes': self.predicted_minutes, - 'predicted_hours': self.predicted_hours, - 'min_minutes': self.min_minutes, - 'max_minutes': self.max_minutes, - 'range_hours': self.range_hours, - 'confidence': self.confidence, - 'timestamp': self.timestamp.isoformat(), + "embryo_id": self.embryo_id, + "current_stage": self.current_stage.value, + "predicted_minutes": self.predicted_minutes, + "predicted_hours": self.predicted_hours, + "min_minutes": self.min_minutes, + "max_minutes": self.max_minutes, + "range_hours": self.range_hours, + "confidence": self.confidence, + "timestamp": self.timestamp.isoformat(), } def __str__(self) -> str: @@ -137,25 +139,27 @@ def __str__(self) -> str: @dataclass class StageClassification: """Result of a stage classification""" + stage: DevelopmentalStage confidence: str # HIGH, MEDIUM, LOW reasoning: str timestamp: datetime = field(default_factory=datetime.now) timepoint: int = 0 - predicted_minutes_to_hatching: Optional[int] = None + predicted_minutes_to_hatching: int | None = 
None - def to_dict(self) -> Dict: + def to_dict(self) -> dict: return { - 'stage': self.stage.value, - 'confidence': self.confidence, - 'reasoning': self.reasoning, - 'timestamp': self.timestamp.isoformat(), - 'timepoint': self.timepoint, - 'predicted_minutes_to_hatching': self.predicted_minutes_to_hatching, + "stage": self.stage.value, + "confidence": self.confidence, + "reasoning": self.reasoning, + "timestamp": self.timestamp.isoformat(), + "timepoint": self.timepoint, + "predicted_minutes_to_hatching": self.predicted_minutes_to_hatching, } -STAGE_CLASSIFICATION_PROMPT = """Analyze this C. elegans embryo image and determine its DEVELOPMENTAL STAGE. +STAGE_CLASSIFICATION_PROMPT = """Analyze this C. elegans embryo image and determine its +DEVELOPMENTAL STAGE. Stages in order (earliest to latest): - 1-cell: Single cell, spherical, no division @@ -194,7 +198,7 @@ class DevelopmentalTracker: def __init__( self, - claude_client: Optional[anthropic.Anthropic] = None, + claude_client: anthropic.Anthropic | None = None, model: str = settings.models.perception, ): """ @@ -209,14 +213,14 @@ def __init__( self.model = model # Stage history per embryo - self._stage_history: Dict[str, List[StageClassification]] = {} + self._stage_history: dict[str, list[StageClassification]] = {} def classify_stage( self, image_b64: str, embryo_id: str, timepoint: int = 0, - recent_images: Optional[List[Dict]] = None, + recent_images: list[dict] | None = None, ) -> StageClassification: """ Classify the developmental stage of an embryo @@ -242,49 +246,51 @@ def classify_stage( # Add temporal context if available if recent_images and len(recent_images) > 1: - content.append({ - "type": "text", - "text": f"Recent images from {embryo_id} (for temporal context):" - }) - for img in recent_images[:-1]: # All but last - content.append({ + content.append( + { "type": "text", - "text": f"Timepoint {img.get('timepoint', '?')}" - }) - content.append({ - "type": "image", - "source": { - "type": "base64", 
- "media_type": "image/jpeg", - "data": img['b64_image'] + "text": f"Recent images from {embryo_id} (for temporal context):", + } + ) + for img in recent_images[:-1]: # All but last + content.append({"type": "text", "text": f"Timepoint {img.get('timepoint', '?')}"}) + content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": img["b64_image"], + }, } - }) + ) # Add current image - content.append({ - "type": "text", - "text": f"CURRENT image (timepoint {timepoint}) - classify this one:" - }) - content.append({ - "type": "image", - "source": { - "type": "base64", - "media_type": "image/jpeg", - "data": image_b64 + content.append( + { + "type": "text", + "text": f"CURRENT image (timepoint {timepoint}) - classify this one:", } - }) + ) + content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": image_b64, + }, + } + ) # Add prompt - content.append({ - "type": "text", - "text": STAGE_CLASSIFICATION_PROMPT - }) + content.append({"type": "text", "text": STAGE_CLASSIFICATION_PROMPT}) try: response = self.claude.messages.create( model=self.model, max_tokens=500, - messages=[{"role": "user", "content": content}] + messages=[{"role": "user", "content": content}], ) result = self._parse_classification(response.content[0].text) @@ -321,31 +327,31 @@ def _parse_classification(self, response_text: str) -> StageClassification: confidence = "LOW" reasoning = "" - lines = response_text.strip().split('\n') + lines = response_text.strip().split("\n") for line in lines: line = line.strip() - if line.startswith('STAGE:'): - stage_str = line.split(':', 1)[1].strip().lower() + if line.startswith("STAGE:"): + stage_str = line.split(":", 1)[1].strip().lower() # Map to enum stage = self._parse_stage_name(stage_str) - elif line.startswith('CONFIDENCE:'): - confidence = line.split(':', 1)[1].strip().upper() - elif line.startswith('REASONING:'): - reasoning = line.split(':', 
1)[1].strip() + elif line.startswith("CONFIDENCE:"): + confidence = line.split(":", 1)[1].strip().upper() + elif line.startswith("REASONING:"): + reasoning = line.split(":", 1)[1].strip() # Capture multi-line reasoning if not reasoning: in_reasoning = False reasoning_lines = [] for line in lines: - if line.startswith('REASONING:'): + if line.startswith("REASONING:"): in_reasoning = True - reasoning_lines.append(line.split(':', 1)[1].strip()) + reasoning_lines.append(line.split(":", 1)[1].strip()) elif in_reasoning and line: reasoning_lines.append(line) if reasoning_lines: - reasoning = ' '.join(reasoning_lines) + reasoning = " ".join(reasoning_lines) return StageClassification( stage=stage, @@ -359,43 +365,43 @@ def _parse_stage_name(self, name: str) -> DevelopmentalStage: # Direct matches mappings = { - '1-cell': DevelopmentalStage.ONE_CELL, - 'one-cell': DevelopmentalStage.ONE_CELL, - '2-cell': DevelopmentalStage.TWO_CELL, - 'two-cell': DevelopmentalStage.TWO_CELL, - '4-cell': DevelopmentalStage.FOUR_CELL, - 'four-cell': DevelopmentalStage.FOUR_CELL, - '8-cell': DevelopmentalStage.EIGHT_CELL, - 'eight-cell': DevelopmentalStage.EIGHT_CELL, - 'gastrulation': DevelopmentalStage.GASTRULATION, - 'comma': DevelopmentalStage.COMMA, - '1.5-fold': DevelopmentalStage.ONE_POINT_FIVE_FOLD, - '1.5 fold': DevelopmentalStage.ONE_POINT_FIVE_FOLD, - '2-fold': DevelopmentalStage.TWO_FOLD, - '2 fold': DevelopmentalStage.TWO_FOLD, - 'pretzel': DevelopmentalStage.PRETZEL, - '3-fold': DevelopmentalStage.PRETZEL, - '3 fold': DevelopmentalStage.PRETZEL, - 'pre-hatching': DevelopmentalStage.PRE_HATCHING, - 'prehatching': DevelopmentalStage.PRE_HATCHING, - 'hatching': DevelopmentalStage.HATCHING, - 'hatched': DevelopmentalStage.HATCHED, - 'dead': DevelopmentalStage.DEAD, - 'unknown': DevelopmentalStage.UNKNOWN, + "1-cell": DevelopmentalStage.ONE_CELL, + "one-cell": DevelopmentalStage.ONE_CELL, + "2-cell": DevelopmentalStage.TWO_CELL, + "two-cell": DevelopmentalStage.TWO_CELL, + "4-cell": 
DevelopmentalStage.FOUR_CELL, + "four-cell": DevelopmentalStage.FOUR_CELL, + "8-cell": DevelopmentalStage.EIGHT_CELL, + "eight-cell": DevelopmentalStage.EIGHT_CELL, + "gastrulation": DevelopmentalStage.GASTRULATION, + "comma": DevelopmentalStage.COMMA, + "1.5-fold": DevelopmentalStage.ONE_POINT_FIVE_FOLD, + "1.5 fold": DevelopmentalStage.ONE_POINT_FIVE_FOLD, + "2-fold": DevelopmentalStage.TWO_FOLD, + "2 fold": DevelopmentalStage.TWO_FOLD, + "pretzel": DevelopmentalStage.PRETZEL, + "3-fold": DevelopmentalStage.PRETZEL, + "3 fold": DevelopmentalStage.PRETZEL, + "pre-hatching": DevelopmentalStage.PRE_HATCHING, + "prehatching": DevelopmentalStage.PRE_HATCHING, + "hatching": DevelopmentalStage.HATCHING, + "hatched": DevelopmentalStage.HATCHED, + "dead": DevelopmentalStage.DEAD, + "unknown": DevelopmentalStage.UNKNOWN, } return mappings.get(name, DevelopmentalStage.UNKNOWN) - def get_stage_history(self, embryo_id: str) -> List[StageClassification]: + def get_stage_history(self, embryo_id: str) -> list[StageClassification]: """Get stage classification history for an embryo""" return self._stage_history.get(embryo_id, []) - def get_current_stage(self, embryo_id: str) -> Optional[StageClassification]: + def get_current_stage(self, embryo_id: str) -> StageClassification | None: """Get the most recent stage classification""" history = self._stage_history.get(embryo_id, []) return history[-1] if history else None - def predict_time_to_hatching(self, embryo_id: str) -> Optional[timedelta]: + def predict_time_to_hatching(self, embryo_id: str) -> timedelta | None: """ Predict time to hatching based on current stage @@ -422,7 +428,7 @@ def predict_time_to_stage( self, embryo_id: str, target_stage: DevelopmentalStage, - ) -> Optional[timedelta]: + ) -> timedelta | None: """ Predict time until embryo reaches target stage @@ -457,7 +463,7 @@ def predict_time_to_stage( minutes = target_timing - current_timing return timedelta(minutes=minutes) - def get_progression_summary(self, 
embryo_id: str) -> Dict[str, Any]: + def get_progression_summary(self, embryo_id: str) -> dict[str, Any]: """ Get a summary of stage progression for an embryo @@ -475,28 +481,28 @@ def get_progression_summary(self, embryo_id: str) -> Dict[str, Any]: if not history: return { - 'embryo_id': embryo_id, - 'observations': 0, - 'current_stage': None, - 'stages_observed': [], - 'predicted_hatching': None, + "embryo_id": embryo_id, + "observations": 0, + "current_stage": None, + "stages_observed": [], + "predicted_hatching": None, } current = history[-1] stages_observed = list(set(h.stage.value for h in history)) return { - 'embryo_id': embryo_id, - 'observations': len(history), - 'current_stage': current.stage.value, - 'current_confidence': current.confidence, - 'stages_observed': stages_observed, - 'first_observation': history[0].timestamp.isoformat(), - 'last_observation': current.timestamp.isoformat(), - 'predicted_minutes_to_hatching': current.predicted_minutes_to_hatching, + "embryo_id": embryo_id, + "observations": len(history), + "current_stage": current.stage.value, + "current_confidence": current.confidence, + "stages_observed": stages_observed, + "first_observation": history[0].timestamp.isoformat(), + "last_observation": current.timestamp.isoformat(), + "predicted_minutes_to_hatching": current.predicted_minutes_to_hatching, } - def get_hatching_prediction(self, embryo_id: str) -> Optional[HatchingPrediction]: + def get_hatching_prediction(self, embryo_id: str) -> HatchingPrediction | None: """ Get detailed hatching prediction with confidence interval @@ -525,9 +531,9 @@ def get_hatching_prediction(self, embryo_id: str) -> Optional[HatchingPrediction # Adjust confidence interval based on classification confidence confidence_multiplier = { - 'HIGH': 1.0, - 'MEDIUM': 1.5, - 'LOW': 2.0, + "HIGH": 1.0, + "MEDIUM": 1.5, + "LOW": 2.0, }.get(current.confidence, 2.0) adjusted_variability = int(variability * confidence_multiplier) @@ -541,7 +547,7 @@ def 
get_hatching_prediction(self, embryo_id: str) -> Optional[HatchingPrediction confidence=current.confidence, ) - def get_all_predictions(self, embryo_ids: List[str]) -> Dict[str, HatchingPrediction]: + def get_all_predictions(self, embryo_ids: list[str]) -> dict[str, HatchingPrediction]: """ Get predictions for multiple embryos @@ -562,7 +568,7 @@ def get_all_predictions(self, embryo_ids: List[str]) -> Dict[str, HatchingPredic predictions[embryo_id] = pred return predictions - def estimate_development_rate(self, embryo_id: str) -> Optional[float]: + def estimate_development_rate(self, embryo_id: str) -> float | None: """ Estimate relative development rate compared to standard @@ -584,7 +590,9 @@ def estimate_development_rate(self, embryo_id: str) -> Optional[float]: return None # Need at least two different stages - stages_seen = [(h.stage, h.timestamp) for h in history if h.stage != DevelopmentalStage.UNKNOWN] + stages_seen = [ + (h.stage, h.timestamp) for h in history if h.stage != DevelopmentalStage.UNKNOWN + ] if len(stages_seen) < 2: return None diff --git a/gently/organisms/celegans/perception_prompt.py b/gently/organisms/celegans/perception_prompt.py index 2604a6c7..895fa3e6 100644 --- a/gently/organisms/celegans/perception_prompt.py +++ b/gently/organisms/celegans/perception_prompt.py @@ -5,7 +5,8 @@ Extracted from gently/agent/perception/engine.py. """ -PERCEPTION_SYSTEM_PROMPT = """You are an expert microscopy perception system analyzing C. elegans embryo development. +PERCEPTION_SYSTEM_PROMPT = """You are an expert microscopy perception system analyzing +C. elegans embryo development. IMPORTANT PRINCIPLES: 1. 
DESCRIBE FIRST: Always describe what you actually see BEFORE classifying @@ -23,27 +24,35 @@ - XY (top-left): Looking DOWN - Best for end asymmetry, ventral indentation, folding - YZ (top-right): Looking from SIDE - Best for body height/thickness -- XZ (bottom): Looking from FRONT - CRITICAL for early->bean transition (look for "peanut" or central constriction) +- XZ (bottom): Looking from FRONT - CRITICAL for early->bean transition (look for "peanut" + or central constriction) -**ALWAYS ANALYZE XZ VIEW**: The XZ view often shows bean-stage features (central constriction, "peanut" shape) BEFORE they're visible in XY. If XZ shows ANY central narrowing or figure-8 appearance, this suggests bean stage even if XY looks symmetric. +**ALWAYS ANALYZE XZ VIEW**: The XZ view often shows bean-stage features (central +constriction, "peanut" shape) BEFORE they're visible in XY. If XZ shows ANY central +narrowing or figure-8 appearance, this suggests bean stage even if XY looks symmetric. DEVELOPMENTAL STAGES: EARLY: Elongated oval (~2:1), SYMMETRIC ENDS, both edges CONVEX, NO central constriction in XZ -BEAN: Even SUBTLE end asymmetry OR central constriction/"peanut" shape in XZ view, edges still CONVEX -COMMA: One edge FLAT or curves INWARD (ventral indentation). XZ shows side-by-side lobes (horizontal figure-8) -1.5-FOLD: Body folding back. XZ shows STACKED horizontal layers (two parallel bands, one above the other) +BEAN: Even SUBTLE end asymmetry OR central constriction/"peanut" shape in XZ view, edges + still CONVEX +COMMA: One edge FLAT or curves INWARD (ventral indentation). XZ shows side-by-side lobes + (horizontal figure-8) +1.5-FOLD: Body folding back. XZ shows STACKED horizontal layers (two parallel bands, one + above the other) 2-FOLD: Body doubled back completely. 
XZ shows TWO DISTINCT HORIZONTAL LINES with dark gap between PRETZEL: Tightly coiled, 3+ body segments visible as multiple stacked layers HATCHED: Worm exited shell CRITICAL FOR EARLY vs BEAN vs COMMA: - EARLY: Both ends symmetric AND both edges convex AND no central constriction in XZ -- BEAN: ANY of these: subtle end tapering, central constriction in XZ, "peanut" shape - edges still convex +- BEAN: ANY of these: subtle end tapering, central constriction in XZ, "peanut" shape - + edges still convex - COMMA: One edge is flat or curves INWARD (not convex) CRITICAL FOR BEAN/COMMA vs FOLD STAGES (examine XZ view carefully): -The XZ view shows two masses in BOTH bean/comma AND fold stages - the key is their VERTICAL ARRANGEMENT: +The XZ view shows two masses in BOTH bean/comma AND fold stages - the key is their VERTICAL +ARRANGEMENT: BEAN/COMMA XZ: Two lobes at the SAME VERTICAL LEVEL - Lobes are side-by-side horizontally, spanning the same vertical range @@ -67,11 +76,14 @@ - Figure-8 or peanut appearance in any view Mark as TRANSITIONAL (early->bean) or BEAN with appropriate confidence. -SPECIAL: If the field of view is EMPTY (no embryo, no eggshell, only background/debris), return "no_object". +SPECIAL: If the field of view is EMPTY (no embryo, no eggshell, only background/debris), +return "no_object". Respond with JSON: { - "observed_features": {"shape": "...", "curvature": "...", "shell_status": "...", "emergence": "..."}, + "observed_features": { + "shape": "...", "curvature": "...", "shell_status": "...", "emergence": "..." 
+ }, "contrastive_reasoning": {"why_not_previous_stage": "...", "why_not_next_stage": "..."}, "stage": "early|bean|comma|1.5fold|2fold|pretzel|hatching|hatched|arrested|no_object", "is_transitional": true/false, diff --git a/gently/organisms/celegans/stages.py b/gently/organisms/celegans/stages.py index 4ad16970..b75f8d34 100644 --- a/gently/organisms/celegans/stages.py +++ b/gently/organisms/celegans/stages.py @@ -5,7 +5,7 @@ """ from enum import Enum -from typing import List, Dict, Any +from typing import Any class DevelopmentalStage(str, Enum): @@ -21,27 +21,34 @@ class DevelopmentalStage(str, Enum): Special states: - "arrested" is not part of normal progression (dead/arrested embryo) """ - EARLY = "early" # Gastrulation through early morphogenesis, oval shape - BEAN = "bean" # Elongated oval, "bean-shaped", pre-comma curvature - COMMA = "comma" # Clear C-shape, head/tail distinguishable - FOLD_1_5 = "1.5fold" # Elongation, ~1.5x eggshell length - FOLD_2 = "2fold" # Body folded back twice, between 1.5fold and pretzel - PRETZEL = "pretzel" # Tight coil, 3+ body segments (formerly 3fold) - HATCHING = "hatching" # Active emergence, shell breach visible - HATCHED = "hatched" # Fully emerged L1 larva - ARRESTED = "arrested" # Dead or developmentally arrested embryo (special state) - NO_OBJECT = "no_object" # No embryo visible in field of view (special state) + + EARLY = "early" # Gastrulation through early morphogenesis, oval shape + BEAN = "bean" # Elongated oval, "bean-shaped", pre-comma curvature + COMMA = "comma" # Clear C-shape, head/tail distinguishable + FOLD_1_5 = "1.5fold" # Elongation, ~1.5x eggshell length + FOLD_2 = "2fold" # Body folded back twice, between 1.5fold and pretzel + PRETZEL = "pretzel" # Tight coil, 3+ body segments (formerly 3fold) + HATCHING = "hatching" # Active emergence, shell breach visible + HATCHED = "hatched" # Fully emerged L1 larva + ARRESTED = "arrested" # Dead or developmentally arrested embryo (special state) + NO_OBJECT = 
"no_object" # No embryo visible in field of view (special state) @classmethod - def ordered_list(cls) -> List["DevelopmentalStage"]: + def ordered_list(cls) -> list["DevelopmentalStage"]: """Return stages in developmental order.""" return [ - cls.EARLY, cls.BEAN, cls.COMMA, cls.FOLD_1_5, - cls.FOLD_2, cls.PRETZEL, cls.HATCHING, cls.HATCHED + cls.EARLY, + cls.BEAN, + cls.COMMA, + cls.FOLD_1_5, + cls.FOLD_2, + cls.PRETZEL, + cls.HATCHING, + cls.HATCHED, ] @classmethod - def ordered_values(cls) -> List[str]: + def ordered_values(cls) -> list[str]: """Return stage string values in developmental order.""" return [s.value for s in cls.ordered_list()] @@ -61,7 +68,7 @@ def is_valid(cls, stage: str) -> bool: return stage in cls.all_valid_values() @classmethod - def all_valid_values(cls) -> List[str]: + def all_valid_values(cls) -> list[str]: """Return all valid stage values including special states like 'arrested'.""" return cls.ordered_values() + ["arrested", "no_object"] @@ -97,7 +104,7 @@ def compare(cls, stage_a: str, stage_b: str) -> int: # Each stage has: # - features: what to look for (positive indicators) # - NOT_if: what rules out this stage (negative indicators) -STAGE_CRITERIA: Dict[str, Dict[str, Any]] = { +STAGE_CRITERIA: dict[str, dict[str, Any]] = { "early": { "features": [ "oval/elliptical shape", @@ -249,7 +256,7 @@ def compare(cls, stage_a: str, stage_b: str) -> int: # Transition zones between stages # Used for detecting transitional states and setting expectations for temporal analysis -TRANSITION_ZONES: Dict[str, Dict[str, Any]] = { +TRANSITION_ZONES: dict[str, dict[str, Any]] = { "early_to_bean": { "from_stage": "early", "to_stage": "bean", @@ -325,7 +332,7 @@ def compare(cls, stage_a: str, stage_b: str) -> int: } -def get_transition_zone(from_stage: str, to_stage: str) -> Dict[str, Any]: +def get_transition_zone(from_stage: str, to_stage: str) -> dict[str, Any]: """Get transition zone info between two stages.""" key = f"{from_stage}_to_{to_stage}" 
return TRANSITION_ZONES.get(key, {}) diff --git a/gently/settings.py b/gently/settings.py index 5a0bdb89..68cebd38 100644 --- a/gently/settings.py +++ b/gently/settings.py @@ -4,6 +4,7 @@ All configurable values live here. Override via environment variables prefixed with GENTLY_ (e.g., GENTLY_VIZ_PORT=9090). """ + import os from dataclasses import dataclass, field from pathlib import Path @@ -29,6 +30,7 @@ def _env(key: str, default): @dataclass(frozen=True) class NetworkSettings: """Ports, hosts, and bind addresses.""" + viz_port: int = field(default_factory=lambda: _env("VIZ_PORT", 8080)) viz_host: str = field(default_factory=lambda: _env("VIZ_HOST", "0.0.0.0")) device_port: int = field(default_factory=lambda: _env("DEVICE_PORT", 60610)) @@ -40,7 +42,10 @@ class NetworkSettings: @dataclass(frozen=True) class MeshSettings: """Mesh networking parameters.""" - broadcast_interval_s: float = field(default_factory=lambda: _env("MESH_BROADCAST_INTERVAL", 5.0)) + + broadcast_interval_s: float = field( + default_factory=lambda: _env("MESH_BROADCAST_INTERVAL", 5.0) + ) replay_window_s: float = field(default_factory=lambda: _env("MESH_REPLAY_WINDOW", 30.0)) reaper_interval_s: float = field(default_factory=lambda: _env("MESH_REAPER_INTERVAL", 10.0)) status_refresh_s: float = field(default_factory=lambda: _env("MESH_STATUS_REFRESH", 30.0)) @@ -52,8 +57,11 @@ class MeshSettings: @dataclass(frozen=True) class ModelSettings: """Claude model identifiers.""" + main: str = field(default_factory=lambda: _env("MODEL_MAIN", "claude-opus-4-6")) - perception: str = field(default_factory=lambda: _env("MODEL_PERCEPTION", "claude-opus-4-5-20251101")) + perception: str = field( + default_factory=lambda: _env("MODEL_PERCEPTION", "claude-opus-4-5-20251101") + ) fast: str = field(default_factory=lambda: _env("MODEL_FAST", "claude-haiku-4-5-20251001")) medium: str = field(default_factory=lambda: _env("MODEL_MEDIUM", "claude-sonnet-4-5-20250929")) @@ -61,6 +69,7 @@ class ModelSettings: 
@dataclass(frozen=True) class StorageSettings: """File paths for data storage.""" + base_path: Path = field(default_factory=lambda: _env("STORAGE_PATH", Path("D:/Gently3"))) @property @@ -75,6 +84,7 @@ def traces_dir(self) -> Path: @dataclass(frozen=True) class TimeoutSettings: """Timeout values in seconds.""" + plan_execution: int = field(default_factory=lambda: _env("TIMEOUT_PLAN", 300)) rpc_call: int = field(default_factory=lambda: _env("TIMEOUT_RPC", 60)) volume_acquisition: int = field(default_factory=lambda: _env("TIMEOUT_VOLUME", 15)) @@ -84,6 +94,7 @@ class TimeoutSettings: @dataclass(frozen=True) class ApiSettings: """External API configuration.""" + ncbi_tool: str = field(default_factory=lambda: _env("NCBI_TOOL", "gently")) ncbi_email: str = field(default_factory=lambda: _env("NCBI_EMAIL", "pskeshu@gmail.com")) @@ -91,6 +102,7 @@ class ApiSettings: @dataclass(frozen=True) class MlSettings: """Machine learning training parameters.""" + model_cache_dir: Path = field(default_factory=lambda: _env("ML_MODEL_CACHE", Path("models"))) default_batch_size: int = field(default_factory=lambda: _env("ML_BATCH_SIZE", 32)) default_epochs: int = field(default_factory=lambda: _env("ML_EPOCHS", 50)) @@ -100,14 +112,18 @@ class MlSettings: @dataclass(frozen=True) class TransferSettings: """Bulk transfer protocol parameters.""" + transfer_port: int = field(default_factory=lambda: _env("TRANSFER_PORT", 19548)) chunk_size: int = field(default_factory=lambda: _env("TRANSFER_CHUNK_SIZE", 1048576)) # 1MB - max_concurrent_transfers: int = field(default_factory=lambda: _env("TRANSFER_MAX_CONCURRENT", 4)) + max_concurrent_transfers: int = field( + default_factory=lambda: _env("TRANSFER_MAX_CONCURRENT", 4) + ) @dataclass(frozen=True) class Settings: """Top-level settings container.""" + network: NetworkSettings = field(default_factory=NetworkSettings) mesh: MeshSettings = field(default_factory=MeshSettings) models: ModelSettings = field(default_factory=ModelSettings) diff --git 
a/gently/ui/web/__init__.py b/gently/ui/web/__init__.py index 1ef5ab2d..ab5bcd57 100644 --- a/gently/ui/web/__init__.py +++ b/gently/ui/web/__init__.py @@ -13,20 +13,23 @@ from .embryo_marker import mark_embryos_web from .plots import ( - generate_focus_curve_plot, generate_calibration_summary_plot, generate_edge_detection_plot, + generate_focus_curve_plot, ) + # Lazy import for server (requires FastAPI) def get_visualization_server(): from .server import VisualizationServer, create_visualization_server + return VisualizationServer, create_visualization_server + __all__ = [ - 'mark_embryos_web', - 'get_visualization_server', - 'generate_focus_curve_plot', - 'generate_calibration_summary_plot', - 'generate_edge_detection_plot', + "mark_embryos_web", + "get_visualization_server", + "generate_focus_curve_plot", + "generate_calibration_summary_plot", + "generate_edge_detection_plot", ] diff --git a/gently/ui/web/accounts.py b/gently/ui/web/accounts.py new file mode 100644 index 00000000..3da1c1a0 --- /dev/null +++ b/gently/ui/web/accounts.py @@ -0,0 +1,189 @@ +"""Self-managed user accounts for the web UI. + +A small, dependency-free account store: users live in a YAML file under the +storage directory (NOT the repo), passwords are PBKDF2-hashed, and browser +sessions are stateless HMAC-signed cookies. This is the "self-managed +accounts" backend chosen for the LAN deployment; institute SSO can be layered +on later behind the same ``resolve_role`` surface in ``auth.py``. + +Roles +----- + viewer -- read-only. Sees everything (today's watching experience). + operator -- viewer + may take the microscope control lock and drive. + admin -- operator + may manage users. 
+ +Layout (under /auth/) + users.yaml -- { users: { : {role, salt, hash, iterations, created_at} } } + secret.key -- random key used to sign session cookies (created on first run) +""" + +from __future__ import annotations + +import base64 +import hashlib +import hmac +import logging +import secrets +import time +from datetime import datetime +from pathlib import Path + +import yaml + +logger = logging.getLogger(__name__) + +ROLES = ("viewer", "operator", "admin") +CONTROL_ROLES = frozenset({"operator", "admin"}) +_PBKDF2_ITERATIONS = 200_000 +_SESSION_TTL_SECONDS = 7 * 24 * 3600 # 1 week + + +def _b64(raw: bytes) -> str: + return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=") + + +def _unb64(s: str) -> bytes: + pad = "=" * (-len(s) % 4) + return base64.urlsafe_b64decode(s + pad) + + +class AccountStore: + """File-backed user accounts + signed session tokens.""" + + def __init__(self, auth_dir: Path): + self.auth_dir = Path(auth_dir) + self.auth_dir.mkdir(parents=True, exist_ok=True) + self.users_path = self.auth_dir / "users.yaml" + self.secret_path = self.auth_dir / "secret.key" + self._users: dict = self._load_users() + self._secret: bytes = self._load_or_create_secret() + + # ── Persistence ─────────────────────────────────────────── + def _load_users(self) -> dict: + if not self.users_path.exists(): + return {} + try: + data = yaml.safe_load(self.users_path.read_text(encoding="utf-8")) or {} + return data.get("users", {}) or {} + except Exception as e: + logger.error("Failed to read users.yaml: %s", e) + return {} + + def _save_users(self) -> None: + tmp = self.users_path.with_suffix(".yaml.tmp") + tmp.write_text(yaml.safe_dump({"users": self._users}, sort_keys=True), encoding="utf-8") + tmp.replace(self.users_path) # atomic + + def _load_or_create_secret(self) -> bytes: + if self.secret_path.exists(): + return self.secret_path.read_bytes() + secret = secrets.token_bytes(32) + self.secret_path.write_bytes(secret) + try: + 
self.secret_path.chmod(0o600) + except OSError: + pass # best-effort on Windows + return secret + + # ── Users ───────────────────────────────────────────────── + def has_users(self) -> bool: + return bool(self._users) + + def list_users(self) -> list: + return [ + {"username": u, "role": r.get("role", "viewer")} for u, r in sorted(self._users.items()) + ] + + def get_role(self, username: str) -> str | None: + rec = self._users.get(username) + return rec.get("role") if rec else None + + def _hash(self, password: str, salt: bytes, iterations: int) -> bytes: + return hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iterations) + + def create_user(self, username: str, password: str, role: str = "viewer") -> None: + username = (username or "").strip() + if not username: + raise ValueError("username required") + if role not in ROLES: + raise ValueError(f"role must be one of {ROLES}") + salt = secrets.token_bytes(16) + self._users[username] = { + "role": role, + "salt": salt.hex(), + "hash": self._hash(password, salt, _PBKDF2_ITERATIONS).hex(), + "iterations": _PBKDF2_ITERATIONS, + "created_at": datetime.now().isoformat(timespec="seconds"), + } + self._save_users() + + def verify_password(self, username: str, password: str) -> str | None: + """Return the user's role if the password matches, else None.""" + rec = self._users.get((username or "").strip()) + if not rec: + return None + try: + salt = bytes.fromhex(rec["salt"]) + expected = bytes.fromhex(rec["hash"]) + iterations = int(rec.get("iterations", _PBKDF2_ITERATIONS)) + except (KeyError, ValueError): + return None + candidate = self._hash(password, salt, iterations) + if hmac.compare_digest(candidate, expected): + return rec.get("role", "viewer") + return None + + def bootstrap_admin_if_empty(self) -> tuple[str, str] | None: + """If no users exist, create an admin with a random password. + + Returns (username, password) so the launcher can print it once, or + None if users already exist. 
+ """ + if self._users: + return None + password = secrets.token_urlsafe(12) + self.create_user("admin", password, role="admin") + logger.info("Bootstrapped default admin account") + return ("admin", password) + + # ── Sessions (stateless signed cookie) ──────────────────── + def issue_session(self, username: str, ttl: int = _SESSION_TTL_SECONDS) -> str: + expiry = int(time.time()) + ttl + payload = f"{username}|{expiry}".encode() + sig = hmac.new(self._secret, payload, hashlib.sha256).digest() + return f"{_b64(payload)}.{_b64(sig)}" + + def verify_session(self, token: str) -> str | None: + """Return the username for a valid, unexpired token, else None.""" + if not token or "." not in token: + return None + try: + payload_b64, sig_b64 = token.split(".", 1) + payload = _unb64(payload_b64) + sig = _unb64(sig_b64) + except Exception: + return None + expected = hmac.new(self._secret, payload, hashlib.sha256).digest() + if not hmac.compare_digest(sig, expected): + return None + try: + username, expiry_s = payload.decode("utf-8").rsplit("|", 1) + if int(expiry_s) < int(time.time()): + return None + except Exception: + return None + # The user may have been deleted since the token was issued. + return username if username in self._users else None + + +# ── Module-level singleton (set during server init) ─────────── +_store: AccountStore | None = None + + +def set_account_store(store: AccountStore | None) -> None: + global _store + _store = store + + +def get_account_store() -> AccountStore | None: + return _store diff --git a/gently/ui/web/auth.py b/gently/ui/web/auth.py new file mode 100644 index 00000000..f57ccb2d --- /dev/null +++ b/gently/ui/web/auth.py @@ -0,0 +1,127 @@ +"""Web-UI authorization roles. + +Two roles: + view -- read-only. GET endpoints, SSE / WebSocket event streams. + control -- can drive hardware (POST/PUT/DELETE). 
Localhost is always + control; remote callers must present a matching token in the + X-Gently-Token header (token read from GENTLY_CONTROL_TOKEN). + +Routes that move hardware or mutate persistent state declare a dependency: + + from gently.ui.web.auth import require_control + + @router.post("/api/devices/foo") + async def foo(_=Depends(require_control)): + ... + +Default-deny on control: if the token env var is unset, remote callers get +view-only access until the operator provisions a token. That matches the +"diSPIM computer alone gives control directions" intent while leaving room +for authenticated remote operators later. +""" + +from __future__ import annotations + +import logging +import os +from enum import Enum + +from fastapi import HTTPException, Request + +logger = logging.getLogger(__name__) + + +_LOOPBACK_HOSTS = frozenset({"127.0.0.1", "::1", "localhost"}) + +# Header name used to upgrade a remote session to control role (legacy +# single-shared-token path, used only when no user accounts are configured). +_TOKEN_HEADER = "X-Gently-Token" +_TOKEN_ENV = "GENTLY_CONTROL_TOKEN" + +# Browser session cookie set by the login flow (see routes/auth_routes.py). +SESSION_COOKIE = "gently_session" + + +class Role(str, Enum): + VIEW = "view" + CONTROL = "control" + + +def current_username(request: Request) -> str | None: + """Return the authenticated username from the session cookie, or None. + + None when no account store is configured or the cookie is missing/invalid. + """ + from gently.ui.web.accounts import get_account_store + + store = get_account_store() + if store is None: + return None + token = request.cookies.get(SESSION_COOKIE) + return store.verify_session(token) if token else None + + +def _configured_token() -> str | None: + """Return the shared control token, or None if no token is provisioned. + + Read fresh each request so the operator can rotate the token without + restarting the web server. 
+ """ + tok = os.environ.get(_TOKEN_ENV, "").strip() + return tok or None + + +def resolve_role(request: Request) -> Role: + """Determine the effective role for a request. + + Account mode (preferred): if user accounts are configured, identity comes + from the signed session cookie — operators/admins get control, everyone + else (including anonymous) gets view. + + Legacy mode (no accounts configured): localhost is always control (the + diSPIM box); remote callers need X-Gently-Token matching + GENTLY_CONTROL_TOKEN. This keeps existing single-operator rigs working + until an admin provisions accounts. + """ + from gently.ui.web.accounts import CONTROL_ROLES, get_account_store + + store = get_account_store() + if store is not None and store.has_users(): + username = current_username(request) + if username: + role = store.get_role(username) + return Role.CONTROL if role in CONTROL_ROLES else Role.VIEW + return Role.VIEW + + # Legacy mode — no accounts configured. + client = request.client + host = client.host if client else None + if host in _LOOPBACK_HOSTS: + return Role.CONTROL + + token = _configured_token() + if token is not None: + supplied = request.headers.get(_TOKEN_HEADER, "").strip() + if supplied and supplied == token: + return Role.CONTROL + + return Role.VIEW + + +def require_control(request: Request) -> Role: + """FastAPI dependency — 403 unless the caller has the control role. + + Logs the denied client host (without leaking the token) so the operator + can spot if a remote browser is trying to drive hardware. 
+ """ + role = resolve_role(request) + if role is Role.CONTROL: + return role + host = request.client.host if request.client else "unknown" + logger.warning("control-route 403 for %s -> %s %s", host, request.method, request.url.path) + raise HTTPException( + status_code=403, + detail="control role required (this endpoint moves hardware or " + "mutates persistent state; localhost has it by default, " + "remote callers need X-Gently-Token)", + ) diff --git a/gently/ui/web/connection_manager.py b/gently/ui/web/connection_manager.py index 266e9cb2..11cc3fe4 100644 --- a/gently/ui/web/connection_manager.py +++ b/gently/ui/web/connection_manager.py @@ -9,7 +9,6 @@ import json import logging from datetime import datetime -from typing import Dict, Optional from .models import ClientInfo, ImageData @@ -18,6 +17,7 @@ # Optional imports try: from fastapi import WebSocket + FASTAPI_AVAILABLE = True except ImportError: FASTAPI_AVAILABLE = False @@ -28,13 +28,25 @@ class ConnectionManager: # Colors for avatar backgrounds (pleasant, distinct colors) AVATAR_COLORS = [ - '#4a9eff', '#ff6b6b', '#51cf66', '#ffd43b', '#cc5de8', - '#ff922b', '#20c997', '#748ffc', '#f06595', '#69db7c', - '#ffa94d', '#9775fa', '#38d9a9', '#e599f7', '#74c0fc' + "#4a9eff", + "#ff6b6b", + "#51cf66", + "#ffd43b", + "#cc5de8", + "#ff922b", + "#20c997", + "#748ffc", + "#f06595", + "#69db7c", + "#ffa94d", + "#9775fa", + "#38d9a9", + "#e599f7", + "#74c0fc", ] def __init__(self): - self.active_connections: Dict[WebSocket, ClientInfo] = {} + self.active_connections: dict[WebSocket, ClientInfo] = {} self._lock = asyncio.Lock() def _generate_color(self, client_id: str) -> str: @@ -42,12 +54,15 @@ def _generate_color(self, client_id: str) -> str: hash_val = sum(ord(c) for c in client_id) return self.AVATAR_COLORS[hash_val % len(self.AVATAR_COLORS)] - async def connect(self, websocket: WebSocket, client_id: str = None, name: str = None): + async def connect( + self, websocket: WebSocket, client_id: str | None = None, 
name: str | None = None + ): await websocket.accept() # Generate defaults if not provided if not client_id: import uuid + client_id = str(uuid.uuid4())[:8] if not name: name = f"Anonymous {client_id[:4]}" @@ -56,12 +71,14 @@ async def connect(self, websocket: WebSocket, client_id: str = None, name: str = client_id=client_id, name=name, color=self._generate_color(client_id), - connected_at=datetime.now().isoformat() + connected_at=datetime.now().isoformat(), ) async with self._lock: self.active_connections[websocket] = client_info - logger.info(f"WebSocket connected: {name} ({client_id}). Total: {len(self.active_connections)}") + logger.info( + f"WebSocket connected: {name} ({client_id}). Total: {len(self.active_connections)}" + ) # Broadcast updated presence to all clients await self.broadcast_presence() @@ -70,7 +87,9 @@ async def disconnect(self, websocket: WebSocket): async with self._lock: client_info = self.active_connections.pop(websocket, None) if client_info: - logger.info(f"WebSocket disconnected: {client_info.name}. Total: {len(self.active_connections)}") + logger.info( + f"WebSocket disconnected: {client_info.name}. Total: {len(self.active_connections)}" + ) else: logger.info(f"WebSocket disconnected. 
Total: {len(self.active_connections)}") @@ -86,11 +105,11 @@ async def update_client_name(self, websocket: WebSocket, name: str): client_id=old_info.client_id, name=name, color=old_info.color, - connected_at=old_info.connected_at + connected_at=old_info.connected_at, ) await self.broadcast_presence() - def get_client_info(self, websocket: WebSocket) -> Optional[ClientInfo]: + def get_client_info(self, websocket: WebSocket) -> ClientInfo | None: """Get client info for a websocket""" return self.active_connections.get(websocket) @@ -102,12 +121,12 @@ async def broadcast_presence(self): # Deduplicate by client_id (same user in multiple tabs = one avatar) async with self._lock: seen_clients = {} - for ws, info in self.active_connections.items(): + for _ws, info in self.active_connections.items(): # Keep the most recent entry for each client_id seen_clients[info.client_id] = { - 'client_id': info.client_id, - 'name': info.name, - 'color': info.color + "client_id": info.client_id, + "name": info.name, + "color": info.color, } clients_list = list(seen_clients.values()) @@ -117,14 +136,8 @@ async def broadcast_presence(self): try: personalized = [] for client in clients_list: - personalized.append({ - **client, - 'is_you': client['client_id'] == info.client_id - }) - await ws.send_json({ - 'type': 'presence', - 'clients': personalized - }) + personalized.append({**client, "is_you": client["client_id"] == info.client_id}) + await ws.send_json({"type": "presence", "clients": personalized}) except Exception: disconnected.append(ws) @@ -133,7 +146,7 @@ async def broadcast_presence(self): for ws in disconnected: self.active_connections.pop(ws, None) - async def broadcast(self, message: Dict): + async def broadcast(self, message: dict): """Broadcast message to all connected clients""" if not self.active_connections: return @@ -154,18 +167,19 @@ async def broadcast(self, message: Dict): async def send_image(self, image_data: ImageData): """Send image data to all connected 
clients""" - await self.broadcast({ - 'type': 'image', - 'data': image_data.to_dict() - }) + await self.broadcast({"type": "image", "data": image_data.to_dict()}) - async def send_event(self, event_type: str, data: Dict, source: str = None, event_id: str = None): + async def send_event( + self, event_type: str, data: dict, source: str | None = None, event_id: str | None = None + ): """Send event notification to all clients""" - await self.broadcast({ - 'type': 'event', - 'event_type': event_type, - 'data': data, - 'source': source or 'unknown', - 'event_id': event_id or '', - 'timestamp': datetime.now().isoformat() - }) + await self.broadcast( + { + "type": "event", + "event_type": event_type, + "data": data, + "source": source or "unknown", + "event_id": event_id or "", + "timestamp": datetime.now().isoformat(), + } + ) diff --git a/gently/ui/web/embryo_marker.py b/gently/ui/web/embryo_marker.py index e84b26a4..b6671f71 100644 --- a/gently/ui/web/embryo_marker.py +++ b/gently/ui/web/embryo_marker.py @@ -13,9 +13,7 @@ """ import logging -from typing import List, Dict, Tuple, Optional from pathlib import Path -from datetime import datetime import numpy as np @@ -25,13 +23,13 @@ async def mark_embryos_web( viz_server, image: np.ndarray, - initial_stage_position: Tuple[float, float], + initial_stage_position: tuple[float, float], pixel_size_um: float = 0.65, - timeout: Optional[float] = None, - save_image_path: Optional[Path] = None, - initial_markers: Optional[List[Dict]] = None, + timeout: float | None = None, + save_image_path: Path | None = None, + initial_markers: list[dict] | None = None, default_role: str = "test", -) -> List[Dict]: +) -> list[dict]: """ Interactive embryo marking via the web map view. 
@@ -86,43 +84,52 @@ async def mark_embryos_web( return embryos -def _save_marked_image(image: np.ndarray, embryos: List[Dict], output_path: Path): +def _save_marked_image(image: np.ndarray, embryos: list[dict], output_path: Path): """Save image with embryo markers drawn on it.""" - from PIL import Image as PILImage, ImageDraw, ImageFont + from PIL import Image as PILImage + from PIL import ImageDraw, ImageFont output_path = Path(output_path) if image.dtype != np.uint8: - img_normalized = ((image - image.min()) / - max(image.max() - image.min(), 1) * 255).astype(np.uint8) + img_normalized = ((image - image.min()) / max(image.max() - image.min(), 1) * 255).astype( + np.uint8 + ) else: img_normalized = image pil_image = PILImage.fromarray(img_normalized) - if pil_image.mode != 'RGB': - pil_image = pil_image.convert('RGB') + if pil_image.mode != "RGB": + pil_image = pil_image.convert("RGB") draw = ImageDraw.Draw(pil_image) for embryo in embryos: - pixel_x, pixel_y = embryo['pixel_position'] - embryo_num = embryo.get('embryo_number') or embryo.get('embryo_id') or '?' + pixel_x, pixel_y = embryo["pixel_position"] + embryo_num = embryo.get("embryo_number") or embryo.get("embryo_id") or "?" 
marker_size = 20 draw.line( [(pixel_x - marker_size, pixel_y), (pixel_x + marker_size, pixel_y)], - fill=(0, 255, 255), width=3 + fill=(0, 255, 255), + width=3, ) draw.line( [(pixel_x, pixel_y - marker_size), (pixel_x, pixel_y + marker_size)], - fill=(0, 255, 255), width=3 + fill=(0, 255, 255), + width=3, ) circle_radius = 40 draw.ellipse( - [pixel_x - circle_radius, pixel_y - circle_radius, - pixel_x + circle_radius, pixel_y + circle_radius], - outline=(0, 255, 255), width=2 + [ + pixel_x - circle_radius, + pixel_y - circle_radius, + pixel_x + circle_radius, + pixel_y + circle_radius, + ], + outline=(0, 255, 255), + width=2, ) try: @@ -130,8 +137,12 @@ def _save_marked_image(image: np.ndarray, embryos: List[Dict], output_path: Path except Exception: font = ImageFont.load_default() - draw.text((pixel_x - 10, pixel_y + circle_radius + 5), - str(embryo_num), fill=(0, 255, 255), font=font) + draw.text( + (pixel_x - 10, pixel_y + circle_radius + 5), + str(embryo_num), + fill=(0, 255, 255), + font=font, + ) pil_image.save(output_path) logger.info("Saved marked image: %s", output_path) diff --git a/gently/ui/web/image_store.py b/gently/ui/web/image_store.py index d27b236c..09fcd49f 100644 --- a/gently/ui/web/image_store.py +++ b/gently/ui/web/image_store.py @@ -5,11 +5,13 @@ Organized storage for images by type and embryo. 
""" -from typing import Dict, List, Optional - from .models import ( - ImageData, EmbryoImageCache, Volume3DData, - CALIBRATION_TYPES, VOLUME_TYPES, ANALYSIS_TYPES, + ANALYSIS_TYPES, + CALIBRATION_TYPES, + VOLUME_TYPES, + EmbryoImageCache, + ImageData, + Volume3DData, ) @@ -17,11 +19,11 @@ class ImageStore: """Organized storage for images by type and embryo (unlimited)""" def __init__(self): - self._embryo_caches: Dict[str, EmbryoImageCache] = {} - self._global_images: List[ImageData] = [] # Images without embryo_id - self._calibration_images: List[ImageData] = [] # Global calibration - self._volume_images: List[ImageData] = [] # Global volumes - self._volumes_3d: Dict[str, Volume3DData] = {} # 3D volumes by UID + self._embryo_caches: dict[str, EmbryoImageCache] = {} + self._global_images: list[ImageData] = [] # Images without embryo_id + self._calibration_images: list[ImageData] = [] # Global calibration + self._volume_images: list[ImageData] = [] # Global volumes + self._volumes_3d: dict[str, Volume3DData] = {} # 3D volumes by UID def _get_embryo_cache(self, embryo_id: str) -> EmbryoImageCache: if embryo_id not in self._embryo_caches: @@ -30,7 +32,7 @@ def _get_embryo_cache(self, embryo_id: str) -> EmbryoImageCache: def add_image(self, image: ImageData): """Add image to appropriate storage based on type and embryo""" - embryo_id = image.metadata.get('embryo_id') + embryo_id = image.metadata.get("embryo_id") data_type = image.data_type if data_type in CALIBRATION_TYPES or data_type in ANALYSIS_TYPES: @@ -55,7 +57,7 @@ def add_image(self, image: ImageData): else: self._global_images.append(image) - def get_all_calibration(self, embryo_id: Optional[str] = None) -> List[ImageData]: + def get_all_calibration(self, embryo_id: str | None = None) -> list[ImageData]: """Get calibration images, optionally filtered by embryo""" if embryo_id: cache = self._embryo_caches.get(embryo_id) @@ -66,7 +68,7 @@ def get_all_calibration(self, embryo_id: Optional[str] = None) -> 
List[ImageData all_cal.extend(cache.calibration) return sorted(all_cal, key=lambda x: x.timestamp) - def get_all_volumes(self, embryo_id: Optional[str] = None) -> List[ImageData]: + def get_all_volumes(self, embryo_id: str | None = None) -> list[ImageData]: """Get volume images, optionally filtered by embryo""" if embryo_id: cache = self._embryo_caches.get(embryo_id) @@ -76,7 +78,7 @@ def get_all_volumes(self, embryo_id: Optional[str] = None) -> List[ImageData]: all_vol.extend(cache.volumes) return sorted(all_vol, key=lambda x: x.timestamp) - def get_all_snapshots(self, embryo_id: Optional[str] = None) -> List[ImageData]: + def get_all_snapshots(self, embryo_id: str | None = None) -> list[ImageData]: """Get snapshot images (including volume projections), optionally filtered by embryo""" if embryo_id: cache = self._embryo_caches.get(embryo_id) @@ -91,11 +93,11 @@ def get_all_snapshots(self, embryo_id: Optional[str] = None) -> List[ImageData]: all_snap.extend(cache.volumes) return sorted(all_snap, key=lambda x: x.timestamp) - def get_embryo_ids(self) -> List[str]: + def get_embryo_ids(self) -> list[str]: """Get list of all embryo IDs with images""" return list(self._embryo_caches.keys()) - def get_image_by_uid(self, uid: str) -> Optional[ImageData]: + def get_image_by_uid(self, uid: str) -> ImageData | None: """Find image by UID across all storage""" for img in self._global_images: if img.uid == uid: @@ -120,11 +122,11 @@ def add_volume_3d(self, volume_data: Volume3DData): oldest_uid = next(iter(self._volumes_3d)) del self._volumes_3d[oldest_uid] - def get_volume_3d(self, uid: str) -> Optional[Volume3DData]: + def get_volume_3d(self, uid: str) -> Volume3DData | None: """Get a 3D volume by UID""" return self._volumes_3d.get(uid) - def get_all_volumes_3d(self) -> List[Dict]: + def get_all_volumes_3d(self) -> list[dict]: """Get info for all 3D volumes (without heavy data)""" return [v.to_info_dict() for v in self._volumes_3d.values()] @@ -132,9 +134,9 @@ def 
get_sequence( self, embryo_id: str, start: int = 0, - end: Optional[int] = None, - data_type: Optional[str] = None - ) -> List[ImageData]: + end: int | None = None, + data_type: str | None = None, + ) -> list[ImageData]: """Get ordered sequence of images for an embryo within a timepoint range. Args: @@ -158,8 +160,8 @@ def get_sequence( all_images = [img for img in all_images if img.data_type == data_type] # Filter by timepoint range - def get_timepoint(img: ImageData) -> Optional[int]: - tp = img.metadata.get('timepoint') + def get_timepoint(img: ImageData) -> int | None: + tp = img.metadata.get("timepoint") if tp is not None: return int(tp) return None @@ -179,7 +181,7 @@ def get_timepoint(img: ImageData) -> Optional[int]: filtered.sort(key=lambda x: get_timepoint(x) or 0) return filtered - def get_stats(self) -> Dict: + def get_stats(self) -> dict: """Get storage statistics""" total_cal = len(self._calibration_images) total_vol = len(self._volume_images) @@ -191,10 +193,10 @@ def get_stats(self) -> Dict: total_snap += len(cache.snapshots) return { - 'embryo_count': len(self._embryo_caches), - 'calibration_count': total_cal, - 'volume_count': total_vol, - 'snapshot_count': total_snap, - 'volumes_3d_count': len(self._volumes_3d), - 'embryo_ids': list(self._embryo_caches.keys()), + "embryo_count": len(self._embryo_caches), + "calibration_count": total_cal, + "volume_count": total_vol, + "snapshot_count": total_snap, + "volumes_3d_count": len(self._volumes_3d), + "embryo_ids": list(self._embryo_caches.keys()), } diff --git a/gently/ui/web/models.py b/gently/ui/web/models.py index b61147fd..43c05498 100644 --- a/gently/ui/web/models.py +++ b/gently/ui/web/models.py @@ -5,38 +5,47 @@ Dataclasses and type constants used across the visualization package. 
""" -from dataclasses import dataclass, field, asdict -from typing import Any, Dict, List, Optional +from dataclasses import asdict, dataclass, field +from typing import Any import numpy as np - # Data types for routing to tabs CALIBRATION_TYPES = { - 'focus_sweep', 'focus_plot', 'edge_detection', 'calibration_summary', - 'focus_snap', 'focus_coarse', 'focus_curve', 'focus_assess' + "focus_sweep", + "focus_plot", + "edge_detection", + "calibration_summary", + "focus_snap", + "focus_coarse", + "focus_curve", + "focus_assess", } -VOLUME_TYPES = { - 'volume', 'volume_projection', 'z_stack', 'timelapse' -} +VOLUME_TYPES = {"volume", "volume_projection", "z_stack", "timelapse"} # CV/Analysis types - shown in a separate "Analysis" category within Calibration ANALYSIS_TYPES = { - 'segmentation', 'detection', 'classification', 'tracking', + "segmentation", + "detection", + "classification", + "tracking", # CV agent visualization types - 'roi_detection', 'cropped_roi', 'vision_prepared', 'timeline', 'cv_visualization' + "roi_detection", + "cropped_roi", + "vision_prepared", + "timeline", + "cv_visualization", } # 3D types that support Z-slider browsing -VOLUME_3D_TYPES = { - 'segmentation_3d' -} +VOLUME_3D_TYPES = {"segmentation_3d"} @dataclass class ClientInfo: """Information about a connected WebSocket client for presence tracking""" + client_id: str name: str color: str # Hex color for avatar background @@ -46,13 +55,14 @@ class ClientInfo: @dataclass class Volume3DData: """Container for 3D volume data with segmentation overlay""" + uid: str data_type: str timestamp: str volume: np.ndarray # Original volume (Z, H, W) - masks: np.ndarray # Segmentation masks (Z, H, W) + masks: np.ndarray # Segmentation masks (Z, H, W) colors: np.ndarray # Cell colors (num_labels, 3) - metadata: Dict[str, Any] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) @property def num_slices(self) -> int: @@ -89,37 +99,39 @@ def get_slice_overlay(self, z: int, 
alpha: float = 0.4) -> np.ndarray: return rgb - def to_info_dict(self) -> Dict: + def to_info_dict(self) -> dict: """Return metadata without the heavy arrays""" return { - 'uid': self.uid, - 'data_type': self.data_type, - 'timestamp': self.timestamp, - 'shape': list(self.shape), - 'num_slices': self.num_slices, - 'num_cells': int(self.masks.max()), - 'metadata': self.metadata + "uid": self.uid, + "data_type": self.data_type, + "timestamp": self.timestamp, + "shape": list(self.shape), + "num_slices": self.num_slices, + "num_cells": int(self.masks.max()), + "metadata": self.metadata, } @dataclass class ImageData: """Container for image data sent to clients""" + uid: str data_type: str # 'volume', 'projection', 'snapshot', 'detection', 'focus_sweep', etc. timestamp: str - metadata: Dict[str, Any] = field(default_factory=dict) - base64_png: Optional[str] = None - shape: Optional[tuple] = None + metadata: dict[str, Any] = field(default_factory=dict) + base64_png: str | None = None + shape: tuple | None = None - def to_dict(self) -> Dict: + def to_dict(self) -> dict: return asdict(self) @dataclass class EmbryoImageCache: """Per-embryo image organization""" + embryo_id: str - volumes: List[ImageData] = field(default_factory=list) - calibration: List[ImageData] = field(default_factory=list) - snapshots: List[ImageData] = field(default_factory=list) + volumes: list[ImageData] = field(default_factory=list) + calibration: list[ImageData] = field(default_factory=list) + snapshots: list[ImageData] = field(default_factory=list) diff --git a/gently/ui/web/plots.py b/gently/ui/web/plots.py index 437f792d..45575f6d 100644 --- a/gently/ui/web/plots.py +++ b/gently/ui/web/plots.py @@ -5,11 +5,10 @@ Uses matplotlib with Agg backend for thread safety. 
""" +import matplotlib import numpy as np -from typing import Optional, Tuple, List -import matplotlib -matplotlib.use('Agg') # Non-interactive backend for thread safety +matplotlib.use("Agg") # Non-interactive backend for thread safety import matplotlib.pyplot as plt @@ -17,10 +16,10 @@ def generate_focus_curve_plot( positions: np.ndarray, scores: np.ndarray, best_position: float, - fit_params: Optional[np.ndarray] = None, + fit_params: np.ndarray | None = None, r_squared: float = 0.0, title: str = "Focus Curve", - figsize: Tuple[int, int] = (6, 4), + figsize: tuple[int, int] = (6, 4), dpi: int = 100, ) -> np.ndarray: """ @@ -53,24 +52,34 @@ def generate_focus_curve_plot( fig, ax = plt.subplots(figsize=figsize, dpi=dpi) # Data points - ax.scatter(positions, scores, c='#2196F3', s=50, zorder=3, label='Measurements') + ax.scatter(positions, scores, c="#2196F3", s=50, zorder=3, label="Measurements") # Gaussian fit curve if fit_params is not None and len(fit_params) >= 4: a, mu, sigma, c = fit_params[:4] x_fit = np.linspace(positions.min(), positions.max(), 200) - y_fit = a * np.exp(-((x_fit - mu) ** 2) / (2 * sigma ** 2)) + c - ax.plot(x_fit, y_fit, color='#F44336', linewidth=2, - label=f'Gaussian fit (R²={r_squared:.3f})') + y_fit = a * np.exp(-((x_fit - mu) ** 2) / (2 * sigma**2)) + c + ax.plot( + x_fit, + y_fit, + color="#F44336", + linewidth=2, + label=f"Gaussian fit (R²={r_squared:.3f})", + ) # Best position marker - ax.axvline(best_position, color='#4CAF50', linestyle='--', linewidth=2, - label=f'Best: {best_position:.2f} µm') + ax.axvline( + best_position, + color="#4CAF50", + linestyle="--", + linewidth=2, + label=f"Best: {best_position:.2f} µm", + ) - ax.set_xlabel('Piezo Position (µm)', fontsize=11) - ax.set_ylabel('Focus Score', fontsize=11) - ax.set_title(title, fontsize=12, fontweight='bold') - ax.legend(loc='upper right', framealpha=0.9) + ax.set_xlabel("Piezo Position (µm)", fontsize=11) + ax.set_ylabel("Focus Score", fontsize=11) + ax.set_title(title, 
fontsize=12, fontweight="bold") + ax.legend(loc="upper right", framealpha=0.9) ax.grid(True, alpha=0.3) # Tight layout @@ -94,7 +103,7 @@ def generate_calibration_summary_plot( offset: float, r_squared_top: float = 0.0, r_squared_bottom: float = 0.0, - figsize: Tuple[int, int] = (7, 5), + figsize: tuple[int, int] = (7, 5), dpi: int = 100, ) -> np.ndarray: """ @@ -135,32 +144,46 @@ def generate_calibration_summary_plot( # Calibration points galvos = [galvo_top, galvo_bottom] piezos = [piezo_top, piezo_bottom] - ax.scatter(galvos, piezos, c='#2196F3', s=100, zorder=3, - label='Calibration points') + ax.scatter(galvos, piezos, c="#2196F3", s=100, zorder=3, label="Calibration points") # Linear fit line margin = 0.05 galvo_range = np.linspace( min(galvo_top, galvo_bottom) - margin, max(galvo_top, galvo_bottom) + margin, - 100 + 100, ) piezo_fit = slope * galvo_range + offset - ax.plot(galvo_range, piezo_fit, color='#F44336', linewidth=2, - label=f'Linear fit: piezo = {slope:.1f}·galvo + {offset:.1f}') + ax.plot( + galvo_range, + piezo_fit, + color="#F44336", + linewidth=2, + label=f"Linear fit: piezo = {slope:.1f}·galvo + {offset:.1f}", + ) # Annotations - ax.annotate(f'Top\nR²={r_squared_top:.3f}', - (galvo_top, piezo_top), textcoords="offset points", - xytext=(10, 10), fontsize=9, color='#666') - ax.annotate(f'Bottom\nR²={r_squared_bottom:.3f}', - (galvo_bottom, piezo_bottom), textcoords="offset points", - xytext=(10, -20), fontsize=9, color='#666') - - ax.set_xlabel('Galvo Position (degrees)', fontsize=11) - ax.set_ylabel('Piezo Position (µm)', fontsize=11) - ax.set_title(f'{embryo_id} - Piezo-Galvo Calibration', fontsize=12, fontweight='bold') - ax.legend(loc='upper left', framealpha=0.9) + ax.annotate( + f"Top\nR²={r_squared_top:.3f}", + (galvo_top, piezo_top), + textcoords="offset points", + xytext=(10, 10), + fontsize=9, + color="#666", + ) + ax.annotate( + f"Bottom\nR²={r_squared_bottom:.3f}", + (galvo_bottom, piezo_bottom), + textcoords="offset points", + 
xytext=(10, -20), + fontsize=9, + color="#666", + ) + + ax.set_xlabel("Galvo Position (degrees)", fontsize=11) + ax.set_ylabel("Piezo Position (µm)", fontsize=11) + ax.set_title(f"{embryo_id} - Piezo-Galvo Calibration", fontsize=12, fontweight="bold") + ax.legend(loc="upper left", framealpha=0.9) ax.grid(True, alpha=0.3) fig.tight_layout() @@ -172,12 +195,12 @@ def generate_calibration_summary_plot( def generate_edge_detection_plot( - galvo_positions: List[float], - visibility: List[bool], - edge_top: Optional[float] = None, - edge_bottom: Optional[float] = None, + galvo_positions: list[float], + visibility: list[bool], + edge_top: float | None = None, + edge_bottom: float | None = None, embryo_id: str = "embryo", - figsize: Tuple[int, int] = (6, 4), + figsize: tuple[int, int] = (6, 4), dpi: int = 100, ) -> np.ndarray: """ @@ -211,31 +234,47 @@ def generate_edge_detection_plot( vis_numeric = [1 if v else 0 for v in visibility] # Plot visibility as step function - colors = ['#4CAF50' if v else '#F44336' for v in visibility] + colors = ["#4CAF50" if v else "#F44336" for v in visibility] ax.scatter(galvo_positions, vis_numeric, c=colors, s=80, zorder=3) # Draw step-like connecting lines for i in range(len(galvo_positions) - 1): - color = '#4CAF50' if visibility[i] else '#F44336' - ax.hlines(vis_numeric[i], galvo_positions[i], galvo_positions[i+1], - color=color, alpha=0.3, linewidth=2) + color = "#4CAF50" if visibility[i] else "#F44336" + ax.hlines( + vis_numeric[i], + galvo_positions[i], + galvo_positions[i + 1], + color=color, + alpha=0.3, + linewidth=2, + ) # Mark edges if provided if edge_top is not None: - ax.axvline(edge_top, color='#2196F3', linestyle='--', linewidth=2, - label=f'Top edge: {edge_top:.3f}°') + ax.axvline( + edge_top, + color="#2196F3", + linestyle="--", + linewidth=2, + label=f"Top edge: {edge_top:.3f}°", + ) if edge_bottom is not None: - ax.axvline(edge_bottom, color='#FF9800', linestyle='--', linewidth=2, - label=f'Bottom edge: 
{edge_bottom:.3f}°') - - ax.set_xlabel('Galvo Position (degrees)', fontsize=11) - ax.set_ylabel('Embryo Visible', fontsize=11) + ax.axvline( + edge_bottom, + color="#FF9800", + linestyle="--", + linewidth=2, + label=f"Bottom edge: {edge_bottom:.3f}°", + ) + + ax.set_xlabel("Galvo Position (degrees)", fontsize=11) + ax.set_ylabel("Embryo Visible", fontsize=11) ax.set_yticks([0, 1]) - ax.set_yticklabels(['No', 'Yes']) - ax.set_title(f'{embryo_id} - Edge Detection', fontsize=12, fontweight='bold') + ax.set_yticklabels(["No", "Yes"]) + ax.set_title(f"{embryo_id} - Edge Detection", fontsize=12, fontweight="bold") if edge_top is not None or edge_bottom is not None: - ax.legend(loc='best', framealpha=0.9) - ax.grid(True, alpha=0.3, axis='x') + ax.legend(loc="best", framealpha=0.9) + ax.grid(True, alpha=0.3, axis="x") fig.tight_layout() fig.canvas.draw() diff --git a/gently/ui/web/routes/__init__.py b/gently/ui/web/routes/__init__.py index 13f72b53..ebd90770 100644 --- a/gently/ui/web/routes/__init__.py +++ b/gently/ui/web/routes/__init__.py @@ -6,22 +6,24 @@ a FastAPI ``APIRouter`` bound to the server instance. 
""" -from .pages import create_router as create_pages_router -from .sessions import create_router as create_sessions_router -from .images import create_router as create_images_router -from .volumes import create_router as create_volumes_router -from .data import create_router as create_data_router -from .websocket import create_router as create_websocket_router from .agent_ws import create_router as create_agent_ws_router +from .auth_routes import create_router as create_auth_router from .campaigns import create_router as create_campaigns_router from .chat import create_router as create_chat_router +from .data import create_router as create_data_router from .experiments import create_router as create_experiments_router +from .images import create_router as create_images_router +from .pages import create_router as create_pages_router +from .sessions import create_router as create_sessions_router +from .volumes import create_router as create_volumes_router +from .websocket import create_router as create_websocket_router def register_all_routes(server): """Register all route groups on the server's FastAPI app.""" for factory in ( create_pages_router, + create_auth_router, create_sessions_router, create_campaigns_router, create_experiments_router, diff --git a/gently/ui/web/routes/agent_ws.py b/gently/ui/web/routes/agent_ws.py index df59aaeb..fdf2fc5f 100644 --- a/gently/ui/web/routes/agent_ws.py +++ b/gently/ui/web/routes/agent_ws.py @@ -9,8 +9,8 @@ import asyncio import json import logging +from collections.abc import Callable from datetime import datetime -from typing import Dict, Optional from fastapi import APIRouter, WebSocket, WebSocketDisconnect @@ -29,15 +29,180 @@ def create_router(server) -> APIRouter: router = APIRouter() # Pending choice futures keyed by request_id - _choice_futures: Dict[str, asyncio.Future] = {} + _choice_futures: dict[str, asyncio.Future] = {} + + # ── Single-driver control arbitration ───────────────────── + # Shared across all 
/ws/agent clients (the router is created once). + # Only the control holder may drive the agent (chat/command/cancel); + # everyone else is an observer until they take control. This is the + # seed of the multi-user control lock and also prevents the shared + # agent conversation from being corrupted when >1 client connects. + _control: dict[str, str | None] = {"holder": None} + _clients: dict[str, Callable] = {} + _client_labels: dict[str, str] = {} + _client_counter = {"n": 0} + _raw_clients: dict[str, WebSocket] = {} # client_id -> websocket (broadcast) + + # ── Uniform display transcript ──────────────────────────── + # A single conversation history shared by every client of this session. + # Persisted to /chat_display.json so it survives reconnects and + # restarts; broadcast live so all instances stay in sync. + _history: list = [] + _history_state = {"sid": None, "path": None, "agent_buf": None, "autonomous": False} + + async def _broadcast_control_status(): + """Tell every connected agent client who currently holds control.""" + holder = _control["holder"] + holder_label = _client_labels.get(holder) if holder else None + for cid, fn in list(_clients.items()): + try: + await fn( + { + "type": "control_status", + "holder": holder, + "holder_label": holder_label, + "you_have_control": (cid == holder), + } + ) + except Exception: + pass + + def _load_history_for_session(bridge): + """Load the current session's display history, reloading if the + session changed (e.g. after a resume from the Sessions tab).""" + try: + agent = bridge.agent + store = getattr(agent, "store", None) + sid = getattr(agent, "session_id", None) + except Exception: + return + if sid == _history_state["sid"]: + return # already loaded for this session + # Session changed (or first load): reset and reload from disk. 
+ _history.clear() + _history_state["sid"] = sid + _history_state["path"] = None + _history_state["agent_buf"] = None + _history_state["autonomous"] = False + try: + if store and sid: + sdir = store._session_dir(sid) + if sdir: + p = sdir / "chat_display.json" + _history_state["path"] = p + if p.exists(): + loaded = json.loads(p.read_text(encoding="utf-8")) or [] + if isinstance(loaded, list): + _history.extend(loaded) + except Exception: + logger.debug("Could not load chat history", exc_info=True) + + # Fallback: sessions created before chat_display.json existed (or any + # session resumed for the first time) — derive a best-effort transcript + # from the saved Claude conversation so the chat still shows history. + if not _history and store and sid: + try: + snap = store.load_session_snapshot(sid) or {} + for m in snap.get("conversation_history") or []: + role = m.get("role") + content = m.get("content") + if isinstance(content, list): + text = "".join( + b.get("text", "") + for b in content + if isinstance(b, dict) and b.get("type") == "text" + ) + else: + text = content if isinstance(content, str) else "" + text = (text or "").strip() + if not text: + continue + if role == "user": + _history.append({"role": "user", "text": text}) + elif role == "assistant": + _history.append({"role": "agent", "text": text}) + except Exception: + logger.debug("Could not derive history from conversation", exc_info=True) + + def _save_history(): + p = _history_state["path"] + if not p: + return + try: + tmp = p.with_suffix(".json.tmp") + tmp.write_text(json.dumps(_history[-500:]), encoding="utf-8") + tmp.replace(p) + except Exception: + pass - async def _run_wizard(wizard, websocket, send_fn, _choice_futures, bridge=None, log_transcript=None): + def _record(item): + _history.append(item) + if len(_history) > 500: + del _history[: len(_history) - 500] + _save_history() + + def _flush_agent_buf(): + buf = _history_state["agent_buf"] + if buf: + # An autonomous (wake) turn's text is 
recorded distinctly so replay + # shows it as "Gently · autonomous", not an ordinary agent reply. + role = "autonomous" if _history_state.get("autonomous") else "agent" + _record({"role": role, "text": buf}) + _history_state["agent_buf"] = None + + def _record_display(msg): + """Fold a streamed chunk into the persistent display history.""" + t = msg.get("type") + if t == "user_message": + _flush_agent_buf() + _history_state["autonomous"] = False + _record( + { + "role": "user", + "text": msg.get("text", ""), + "author": msg.get("author"), + } + ) + elif t == "autonomous_start": + # An autonomous wake turn is beginning — record the trigger banner + # and mark following text as autonomous until stream_end. + _flush_agent_buf() + _history_state["autonomous"] = True + _record({"role": "autonomous_start", "trigger": msg.get("trigger", "")}) + elif t == "text": + _history_state["agent_buf"] = (_history_state["agent_buf"] or "") + msg.get("text", "") + elif t == "tool_call": + _flush_agent_buf() + _record( + { + "role": "tool", + "name": msg.get("tool_name"), + "duration": msg.get("duration"), + "summary": msg.get("result_summary"), + } + ) + elif t == "stream_end": + _flush_agent_buf() + _history_state["autonomous"] = False + + async def _broadcast(msg): + """Record to history + send a display message to ALL clients.""" + _record_display(msg) + for _cid, ws in list(_raw_clients.items()): + try: + await ws.send_json(msg) + except Exception: + pass + + async def _run_wizard( + wizard, websocket, send_fn, _choice_futures, bridge=None, log_transcript=None + ): """Run the wizard's interactive loop. Returns the wizard task so callers can check for exceptions. Used both at startup and for the /wizard command. 
""" - _wizard_input_future: Optional[asyncio.Future] = None + _wizard_input_future: asyncio.Future | None = None async def _wizard_wait_for_input() -> str: nonlocal _wizard_input_future @@ -48,11 +213,13 @@ async def _wizard_wait_for_input() -> str: async def _wizard_wait_for_choice(choice_data: dict) -> str: request_id = _make_request_id() choice_data["request_id"] = request_id - await send_fn({ - "type": "choice_request", - "choice_data": choice_data, - "request_id": request_id, - }) + await send_fn( + { + "type": "choice_request", + "choice_data": choice_data, + "request_id": request_id, + } + ) loop = asyncio.get_event_loop() future = loop.create_future() _choice_futures[request_id] = future @@ -65,7 +232,8 @@ async def _wizard_wait_for_choice(choice_data: dict) -> str: while not wizard_task.done(): try: raw = await asyncio.wait_for( - websocket.receive_text(), timeout=60.0, + websocket.receive_text(), + timeout=60.0, ) except asyncio.TimeoutError: await websocket.send_json({"type": "ping"}) @@ -101,12 +269,18 @@ async def _wizard_wait_for_choice(choice_data: dict) -> str: # /reset-context kills the wizard — context is gone if command.strip().lower() == "/reset-context": wizard_task.cancel() - await send_fn({ - "type": "stream_end", - "tokens": {"input_tokens": 0, "output_tokens": 0, - "total_tokens": 0, "api_calls": 0}, - "wizard_complete": True, - }) + await send_fn( + { + "type": "stream_end", + "tokens": { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + "api_calls": 0, + }, + "wizard_complete": True, + } + ) return wizard_task elif msg_type == "ping": @@ -129,13 +303,46 @@ async def agent_websocket(websocket: WebSocket): bridge = getattr(server, "agent_bridge", None) if bridge is None: - await websocket.send_json({ - "type": "error", - "error": "Agent bridge not initialized", - }) + await websocket.send_json( + { + "type": "error", + "error": "Agent bridge not initialized", + } + ) await websocket.close() return + # Route autonomous 
(wake-router) turns through this router's _broadcast so + # they stream to all chat clients + persist to the display transcript. + # Idempotent; _broadcast is router-scoped and fans out to whoever is live. + bridge.register_display_broadcaster(_broadcast) + + # ── Authenticate the connection (account mode) ──────────── + # When user accounts are configured, identity comes from the signed + # session cookie (set at login). Viewers may watch but not drive; + # operators/admins may take the control lock. With no accounts + # configured we fall back to the legacy "anyone connected can drive". + from gently.ui.web.accounts import CONTROL_ROLES, get_account_store + from gently.ui.web.auth import SESSION_COOKIE + + _acct = get_account_store() + username = None + can_control = True # legacy default when no accounts are configured + if _acct is not None and _acct.has_users(): + # Viewing is open: anonymous clients may connect and *watch* the + # conversation. Only authenticated operators/admins can hold or + # take the control lock (enforced on the drive actions below). + _token = websocket.cookies.get(SESSION_COOKIE) + username = _acct.verify_session(_token) if _token else None + role = _acct.get_role(username) if username else None + can_control = role in CONTROL_ROLES + + # Assign a stable id for control arbitration. The label shown to other + # clients is the username when authenticated, else a generic window id. 
+ _client_counter["n"] += 1 + client_id = f"agent_client_{_client_counter['n']}" + client_label = username or f"window {_client_counter['n']}" + # Send connection metadata (version, tokens, embryo count, commands) meta = bridge.get_connect_metadata() _connected_msg = { @@ -189,10 +396,16 @@ async def _push_peer_discovered(event): "_type": "single", "question": f"New peer discovered: {hostname}", "options": [ - {"id": "pair", "label": "Pair", - "description": f"Start pairing with {hostname}"}, - {"id": "ignore", "label": "Ignore", - "description": "Dismiss (you can pair later via /pair)"}, + { + "id": "pair", + "label": "Pair", + "description": f"Start pairing with {hostname}", + }, + { + "id": "ignore", + "label": "Ignore", + "description": "Dismiss (you can pair later via /pair)", + }, ], "allow_multiple": False, }, @@ -241,12 +454,20 @@ async def _push_pairing_requested(event): "type": "choice_request", "choice_data": { "_type": "single", - "question": f"{hostname} wants to pair\nVerify this code matches: {pin}", + "question": ( + f"{hostname} wants to pair\nVerify this code matches: {pin}" + ), "options": [ - {"id": "accept", "label": "Accept pairing", - "description": f"Trust {hostname} and allow mesh communication"}, - {"id": "reject", "label": "Reject", - "description": "Decline this pairing request"}, + { + "id": "accept", + "label": "Accept pairing", + "description": f"Trust {hostname} and allow mesh communication", + }, + { + "id": "reject", + "label": "Reject", + "description": "Decline this pairing request", + }, ], "allow_multiple": False, }, @@ -327,7 +548,9 @@ async def _push_scope_denied(event): pass # Peer discovery - unsub = server.event_bus.subscribe_async(_ET.MESH_PEER_DISCOVERED, _push_peer_discovered) + unsub = server.event_bus.subscribe_async( + _ET.MESH_PEER_DISCOVERED, _push_peer_discovered + ) _mesh_unsubs.append(unsub) unsub = server.event_bus.subscribe_async(_ET.MESH_PEER_LOST, _push_peer_lost) _mesh_unsubs.append(unsub) @@ -335,23 
+558,29 @@ async def _push_scope_denied(event): _mesh_unsubs.append(unsub) # Pairing events - unsub = server.event_bus.subscribe_async(_ET.MESH_PAIRING_REQUESTED, _push_pairing_requested) + unsub = server.event_bus.subscribe_async( + _ET.MESH_PAIRING_REQUESTED, _push_pairing_requested + ) _mesh_unsubs.append(unsub) - unsub = server.event_bus.subscribe_async(_ET.MESH_PAIRING_COMPLETED, _push_pairing_completed) + unsub = server.event_bus.subscribe_async( + _ET.MESH_PAIRING_COMPLETED, _push_pairing_completed + ) _mesh_unsubs.append(unsub) # Security events unsub = server.event_bus.subscribe_async(_ET.MESH_AUTH_FAILURE, _push_auth_failure) _mesh_unsubs.append(unsub) - unsub = server.event_bus.subscribe_async(_ET.MESH_CERT_PIN_FAILURE, _push_cert_pin_failure) + unsub = server.event_bus.subscribe_async( + _ET.MESH_CERT_PIN_FAILURE, _push_cert_pin_failure + ) _mesh_unsubs.append(unsub) unsub = server.event_bus.subscribe_async(_ET.MESH_SCOPE_DENIED, _push_scope_denied) _mesh_unsubs.append(unsub) # Active streaming task (so we can cancel on disconnect) - active_task: Optional[asyncio.Task] = None + active_task: asyncio.Task | None = None wizard_task = None - bootstrap_task: Optional[asyncio.Task] = None + bootstrap_task: asyncio.Task | None = None # ── Session transcript ──────────────────────────────── # Log every WebSocket message (both directions) to a JSONL @@ -367,7 +596,9 @@ async def _push_scope_denied(event): sdir = store._session_dir(sid) if sdir and sdir.exists(): _transcript_file = open( - sdir / "transcript.jsonl", "a", encoding="utf-8", + sdir / "transcript.jsonl", + "a", + encoding="utf-8", ) logger.info("Transcript logging to %s", sdir / "transcript.jsonl") except Exception as e: @@ -409,24 +640,65 @@ def choice_future_factory(choice_data: dict) -> asyncio.Future: _choice_futures[request_id] = future return future + def _discard_choice(request_id: str) -> None: + _choice_futures.pop(request_id, None) + + # Give the bridge the choice-factory + discard too, so 
ASK-mode autonomous + # turns can round-trip an approval picker through this connection's channel + # and clean up the future on timeout/cancel. + bridge.register_display_broadcaster(_broadcast, choice_future_factory, _discard_choice) + + # Register this client for control arbitration; grant control if free + # (only to clients allowed to drive — viewers never auto-hold). + _clients[client_id] = send_fn + _client_labels[client_id] = client_label + _raw_clients[client_id] = websocket + if _control["holder"] is None and can_control: + _control["holder"] = client_id + await _broadcast_control_status() + + # Replay the uniform session transcript so every client (and every + # reconnect/refresh) shows the same conversation. + _load_history_for_session(bridge) + if _history: + try: + await websocket.send_json({"type": "history", "items": list(_history)}) + except Exception: + pass + try: # ── Wizard phase ────────────────────────────────────── - # Run startup wizard (if needed) before entering the REPL. + # The startup wizard no longer auto-pops in the chat — setup is now + # launched on demand from the Home page (which sends /wizard) or via + # the /wizard command. Re-enable auto-run by setting + # server.wizard_autorun = True. NOTE: wizard_ran below is still + # derived from wizard.needed, so the briefing/resolution path is + # unaffected by this gate. 
wizard = getattr(bridge, "_wizard", None) - if wizard is not None and wizard.needed: + if wizard is not None and wizard.needed and getattr(server, "wizard_autorun", False): wizard_task = await _run_wizard( - wizard, websocket, send_fn, _choice_futures, bridge, + wizard, + websocket, + send_fn, + _choice_futures, + bridge, log_transcript=_log_transcript, ) exc = _handle_wizard_result(wizard_task) if exc: logger.error(f"Wizard error: {exc}", exc_info=exc) - await send_fn({ - "type": "stream_end", - "tokens": {"input_tokens": 0, "output_tokens": 0, - "total_tokens": 0, "api_calls": 0}, - "wizard_complete": True, - }) + await send_fn( + { + "type": "stream_end", + "tokens": { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + "api_calls": 0, + }, + "wizard_complete": True, + } + ) # ── Auto-briefing or resolution picker ──────────────── # New sessions with multiple unblocked imaging candidates @@ -441,26 +713,33 @@ def choice_future_factory(choice_data: dict) -> asyncio.Future: async def _run_resolution_bootstrap(): try: await bridge.bootstrap_resolution_picker( - send_fn, choice_future_factory, + send_fn, + choice_future_factory, ) except asyncio.CancelledError: raise except Exception as exc: logger.error( - "Resolution picker failed; falling back to " - "static briefing: %s", - exc, exc_info=exc, + "Resolution picker failed; falling back to static briefing: %s", + exc, + exc_info=exc, ) try: briefing = bridge.get_session_briefing() if briefing: await send_fn({"type": "stream_start"}) await send_fn({"type": "text", "text": briefing}) - await send_fn({ - "type": "stream_end", - "tokens": {"input_tokens": 0, "output_tokens": 0, - "total_tokens": 0, "api_calls": 0}, - }) + await send_fn( + { + "type": "stream_end", + "tokens": { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + "api_calls": 0, + }, + } + ) except Exception: pass @@ -472,17 +751,24 @@ async def _run_resolution_bootstrap(): if briefing: await send_fn({"type": "stream_start"}) 
await send_fn({"type": "text", "text": briefing}) - await send_fn({ - "type": "stream_end", - "tokens": {"input_tokens": 0, "output_tokens": 0, - "total_tokens": 0, "api_calls": 0}, - }) + await send_fn( + { + "type": "stream_end", + "tokens": { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + "api_calls": 0, + }, + } + ) # ── Main REPL loop ──────────────────────────────────── while True: try: raw = await asyncio.wait_for( - websocket.receive_text(), timeout=60.0, + websocket.receive_text(), + timeout=60.0, ) except asyncio.TimeoutError: await websocket.send_json({"type": "ping"}) @@ -497,6 +783,53 @@ async def _run_resolution_bootstrap(): _log_transcript("in", data) msg_type = data.get("type") + # ── Control arbitration ─────────────────────────── + # A client requesting the wheel. + if msg_type == "take_control": + if not can_control: + await send_fn( + { + "type": "notification", + "level": "warning", + "title": "View-only role", + "body": "Your account can watch but not control the microscope.", + } + ) + await _broadcast_control_status() + continue + prev = _control["holder"] + _control["holder"] = client_id + if prev and prev != client_id and prev in _clients: + try: + await _clients[prev]( + { + "type": "notification", + "level": "warning", + "title": f"Control taken by {client_label}", + "body": "You are now viewing.", + } + ) + except Exception: + pass + await _broadcast_control_status() + continue + + # Only the holder may drive the agent. Observers are told + # to take control rather than silently corrupting the + # single shared conversation. 
+ if msg_type in ("chat", "command", "cancel") and client_id != _control["holder"]: + holder_label = _client_labels.get(_control["holder"]) or "another client" + await send_fn( + { + "type": "notification", + "level": "info", + "title": f"Viewing only — control is held by {holder_label}", + "body": "Take control to drive the microscope.", + } + ) + await _broadcast_control_status() + continue + if msg_type == "chat": text = data.get("text", "").strip() if not text: @@ -506,11 +839,18 @@ async def _run_resolution_bootstrap(): if active_task and not active_task.done(): active_task.cancel() + # Echo the user's message to ALL clients (so observers see + # what was asked), then stream the reply to everyone. + await _broadcast({"type": "user_message", "text": text, "author": client_label}) active_task = asyncio.create_task( - bridge.stream_response(text, send_fn, choice_future_factory) + bridge.stream_response(text, _broadcast, choice_future_factory) ) elif msg_type == "choice_response": + # Only the control holder answers pickers (observers see + # them read-only). + if _control["holder"] != client_id: + continue request_id = data.get("request_id", "") selected = data.get("selected", "") # Check if bridge owns this choice (e.g. 
/import-embryos picker) @@ -535,11 +875,13 @@ async def _run_resolution_bootstrap(): if command.lower() in ("/wizard",): w = getattr(bridge, "_wizard", None) if w is None: - await send_fn({ - "type": "command_result", - "command": "/wizard", - "error": "Wizard not available", - }) + await send_fn( + { + "type": "command_result", + "command": "/wizard", + "error": "Wizard not available", + } + ) else: # Re-create wizard so it re-assesses gaps cs = getattr(bridge, "_context_store", None) @@ -548,40 +890,60 @@ async def _run_resolution_bootstrap(): w = bridge._wizard # Tell TUI we're entering wizard mode - await send_fn({ - "type": "command_result", - "command": "/wizard", - "content": {"wizard_active": True}, - }) + await send_fn( + { + "type": "command_result", + "command": "/wizard", + "content": {"wizard_active": True}, + } + ) wizard_task = await _run_wizard( - w, websocket, send_fn, _choice_futures, bridge, + w, + websocket, + send_fn, + _choice_futures, + bridge, log_transcript=_log_transcript, ) exc = _handle_wizard_result(wizard_task) if exc: logger.error(f"Wizard error: {exc}", exc_info=exc) - await send_fn({ - "type": "stream_end", - "tokens": {"input_tokens": 0, "output_tokens": 0, - "total_tokens": 0, "api_calls": 0}, - "wizard_complete": True, - }) + await send_fn( + { + "type": "stream_end", + "tokens": { + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + "api_calls": 0, + }, + "wizard_complete": True, + } + ) else: try: - await bridge.handle_command(command, send_fn, choice_futures=_choice_futures) + await bridge.handle_command( + command, send_fn, choice_futures=_choice_futures + ) except Exception as e: logger.error("Command '%s' failed: %s", command, e, exc_info=True) - await send_fn({ - "type": "command_result", - "command": command, - "error": str(e), - }) + await send_fn( + { + "type": "command_result", + "command": command, + "error": str(e), + } + ) elif msg_type == "browse": target = data.get("target", "") await _handle_browse( 
- target, data, server, bridge, send_fn, + target, + data, + server, + bridge, + send_fn, ) elif msg_type == "ping": @@ -615,11 +977,24 @@ async def _run_resolution_bootstrap(): active_task.cancel() if bootstrap_task is not None and not bootstrap_task.done(): bootstrap_task.cancel() - # Clean up pending futures - for future in _choice_futures.values(): - if not future.done(): - future.cancel() - _choice_futures.clear() + # Release control arbitration for this client; hand the wheel + # to any remaining client (or free it) and resync everyone. + _clients.pop(client_id, None) + _client_labels.pop(client_id, None) + _raw_clients.pop(client_id, None) + if _control["holder"] == client_id: + _control["holder"] = next(iter(_clients), None) + try: + await _broadcast_control_status() + except Exception: + pass + # Clean up pending futures only when the last client leaves — + # otherwise we'd cancel another connected client's pending choices. + if not _clients: + for future in _choice_futures.values(): + if not future.done(): + future.cancel() + _choice_futures.clear() return router @@ -643,13 +1018,20 @@ def serialize_campaign(c): else: children = [] items_raw = cs.get_plan_items(campaign_id=c.id) - items = [{ - "id": item.id, - "title": item.title, - "status": item.status.value if hasattr(item.status, "value") else str(item.status), - "type": item.type.value if hasattr(item.type, "value") else str(item.type), - "claimed_by_hostname": getattr(item, "claimed_by_hostname", None), - } for item in items_raw] + items = [ + { + "id": item.id, + "title": item.title, + "status": item.status.value + if hasattr(item.status, "value") + else str(item.status), + "type": item.type.value + if hasattr(item.type, "value") + else str(item.type), + "claimed_by_hostname": getattr(item, "claimed_by_hostname", None), + } + for item in items_raw + ] return { "id": c.id, "shorthand": c.shorthand or "", @@ -676,17 +1058,19 @@ def serialize_campaign(c): peers = mesh_svc.get_peers() result = [] for p 
in peers: - result.append({ - "instance_id": p.instance_id, - "hostname": p.hostname, - "ip_address": p.ip_address, - "viz_port": p.viz_port, - "mode": p.status.agent_mode if p.status else "unknown", - "embryo_count": p.status.embryo_count if p.status else 0, - "is_trusted": p.is_trusted, - "tls_enabled": p.tls_enabled, - "shared_campaigns": [], - }) + result.append( + { + "instance_id": p.instance_id, + "hostname": p.hostname, + "ip_address": p.ip_address, + "viz_port": p.viz_port, + "mode": p.status.agent_mode if p.status else "unknown", + "embryo_count": p.status.embryo_count if p.status else 0, + "is_trusted": p.is_trusted, + "tls_enabled": p.tls_enabled, + "shared_campaigns": [], + } + ) await send_fn({"type": "browse_result", "target": "peers", "data": result}) elif target == "peer_campaigns": @@ -708,25 +1092,29 @@ def serialize_campaign(c): campaigns = [] if p.instance_id == peer.instance_id: for c in shared: - campaigns.append({ - "id": c.get("id", ""), - "shorthand": c.get("shorthand", ""), - "description": c.get("description", ""), - "total": c.get("item_count", 0), - "completed": c.get("completed_count", 0), - "items": [], - }) - result.append({ - "instance_id": p.instance_id, - "hostname": p.hostname, - "ip_address": p.ip_address, - "viz_port": p.viz_port, - "mode": p.status.agent_mode if p.status else "unknown", - "embryo_count": p.status.embryo_count if p.status else 0, - "is_trusted": p.is_trusted, - "tls_enabled": p.tls_enabled, - "shared_campaigns": campaigns, - }) + campaigns.append( + { + "id": c.get("id", ""), + "shorthand": c.get("shorthand", ""), + "description": c.get("description", ""), + "total": c.get("item_count", 0), + "completed": c.get("completed_count", 0), + "items": [], + } + ) + result.append( + { + "instance_id": p.instance_id, + "hostname": p.hostname, + "ip_address": p.ip_address, + "viz_port": p.viz_port, + "mode": p.status.agent_mode if p.status else "unknown", + "embryo_count": p.status.embryo_count if p.status else 0, + 
"is_trusted": p.is_trusted, + "tls_enabled": p.tls_enabled, + "shared_campaigns": campaigns, + } + ) await send_fn({"type": "browse_result", "target": "peer_campaigns", "data": result}) elif target == "peer_campaign_items": @@ -734,31 +1122,53 @@ def serialize_campaign(c): campaign_id = data.get("campaign_id", "") mesh_svc = getattr(server, "mesh_service", None) if not mesh_svc or not hostname or not campaign_id: - await send_fn({"type": "browse_result", "target": "peer_campaign_items", "data": []}) + await send_fn( + { + "type": "browse_result", + "target": "peer_campaign_items", + "data": [], + } + ) return peer = mesh_svc.find_peer_by_hostname(hostname) if not peer or not mesh_svc.peer_client: - await send_fn({"type": "browse_result", "target": "peer_campaign_items", "data": []}) + await send_fn( + { + "type": "browse_result", + "target": "peer_campaign_items", + "data": [], + } + ) return export = await mesh_svc.peer_client.fetch_campaign_export(peer, campaign_id) if not export: - await send_fn({"type": "browse_result", "target": "peer_campaign_items", "data": []}) + await send_fn( + { + "type": "browse_result", + "target": "peer_campaign_items", + "data": [], + } + ) return items = [] for item in export.get("items", []): - items.append({ - "id": item.get("id", ""), - "title": item.get("title", ""), - "status": item.get("status", "planned"), - "claimed_by_hostname": item.get("claimed_by_hostname"), - }) - await send_fn({ - "type": "browse_result", - "target": "peer_campaign_items", - "data": items, - "campaign_id": campaign_id, - "hostname": hostname, - }) + items.append( + { + "id": item.get("id", ""), + "title": item.get("title", ""), + "status": item.get("status", "planned"), + "claimed_by_hostname": item.get("claimed_by_hostname"), + } + ) + await send_fn( + { + "type": "browse_result", + "target": "peer_campaign_items", + "data": items, + "campaign_id": campaign_id, + "hostname": hostname, + } + ) except Exception as e: logger.debug(f"Browse error 
({target}): {e}") diff --git a/gently/ui/web/routes/auth_routes.py b/gently/ui/web/routes/auth_routes.py new file mode 100644 index 00000000..07e338f1 --- /dev/null +++ b/gently/ui/web/routes/auth_routes.py @@ -0,0 +1,123 @@ +"""Auth routes — login / logout / me, plus the login page. + +Self-managed accounts (see gently/ui/web/accounts.py). Login issues a signed +session cookie; roles (viewer/operator/admin) gate control elsewhere via +gently.ui.web.auth.resolve_role and the /ws/agent control lock. +""" + +import logging + +from fastapi import APIRouter, Request +from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse + +from gently.ui.web.accounts import ( + _SESSION_TTL_SECONDS, + CONTROL_ROLES, + ROLES, + get_account_store, +) +from gently.ui.web.auth import SESSION_COOKIE, current_username + +logger = logging.getLogger(__name__) + + +def create_router(server) -> APIRouter: + router = APIRouter() + + def _secure(request: Request) -> bool: + # Only mark the cookie Secure over HTTPS, else the browser drops it on + # plain-HTTP LAN deployments. + return request.url.scheme == "https" + + @router.get("/login", response_class=HTMLResponse) + async def login_page(request: Request): + store = get_account_store() + if store is None or not store.has_users(): + return RedirectResponse("/", status_code=302) + if current_username(request): + return RedirectResponse("/", status_code=302) + return server.templates.TemplateResponse(request, "login.html") + + @router.post("/api/auth/login") + async def login(request: Request): + store = get_account_store() + if store is None or not store.has_users(): + return JSONResponse({"error": "accounts not configured"}, status_code=400) + try: + body = await request.json() + except Exception: + body = {} + username = (body.get("username") or "").strip() + password = body.get("password") or "" + role = store.verify_password(username, password) + if not role: + host = request.client.host if request.client else "?" 
+ logger.warning("login failed for %r from %s", username, host) + return JSONResponse({"error": "Invalid username or password"}, status_code=401) + token = store.issue_session(username) + resp = JSONResponse({"ok": True, "username": username, "role": role}) + resp.set_cookie( + SESSION_COOKIE, + token, + httponly=True, + samesite="lax", + secure=_secure(request), + max_age=_SESSION_TTL_SECONDS, + path="/", + ) + logger.info("login ok: %s (%s)", username, role) + return resp + + @router.post("/api/auth/logout") + async def logout(request: Request): + resp = JSONResponse({"ok": True}) + resp.delete_cookie(SESSION_COOKIE, path="/") + return resp + + @router.get("/api/auth/me") + async def me(request: Request): + store = get_account_store() + if store is None or not store.has_users(): + return JSONResponse({"accounts": False, "authenticated": False}) + username = current_username(request) + if not username: + return JSONResponse({"accounts": True, "authenticated": False}) + role = store.get_role(username) + return JSONResponse( + { + "accounts": True, + "authenticated": True, + "username": username, + "role": role, + "can_control": role in CONTROL_ROLES, + } + ) + + @router.post("/api/auth/users") + async def create_user(request: Request): + """Admin-only: provision a new account.""" + store = get_account_store() + if store is None: + return JSONResponse({"error": "accounts not configured"}, status_code=400) + requester = current_username(request) + if not requester or store.get_role(requester) != "admin": + return JSONResponse({"error": "admin role required"}, status_code=403) + try: + body = await request.json() + except Exception: + body = {} + new_user = (body.get("username") or "").strip() + password = body.get("password") or "" + role = body.get("role") or "viewer" + if not new_user or not password: + return JSONResponse({"error": "username and password required"}, status_code=400) + if role not in ROLES: + return JSONResponse({"error": f"role must be one of 
{list(ROLES)}"}, status_code=400) + try: + store.create_user(new_user, password, role) + except ValueError as e: + return JSONResponse({"error": str(e)}, status_code=400) + logger.info("admin %s created user %s (%s)", requester, new_user, role) + return JSONResponse({"ok": True, "username": new_user, "role": role}) + + return router diff --git a/gently/ui/web/routes/campaigns.py b/gently/ui/web/routes/campaigns.py index e5e59611..a3fea08c 100644 --- a/gently/ui/web/routes/campaigns.py +++ b/gently/ui/web/routes/campaigns.py @@ -4,7 +4,7 @@ import logging from dataclasses import asdict from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any from fastapi import APIRouter, Depends, HTTPException, Request @@ -121,13 +121,13 @@ async def get_campaign_document(campaign_id: str): # Pre-index every item in the tree once. The naive enrichment used to call # cs.get_plan_item(...) per dep + per dependent, each one walking the on-disk # campaign index — O(items × deps × campaigns) YAML reads per request. 
- items_by_id: Dict[str, Dict] = {} - dependents_map: Dict[str, List[str]] = {} + items_by_id: dict[str, dict] = {} + dependents_map: dict[str, list[str]] = {} def _index(node): for it in node.get("items", []): items_by_id[it["id"]] = it - for dep_id in (it.get("depends_on") or []): + for dep_id in it.get("depends_on") or []: dependents_map.setdefault(dep_id, []).append(it["id"]) for child in node.get("children", []): _index(child) @@ -147,16 +147,12 @@ def _enrich_tree(node): for item in node.get("items", []): item_id = item["id"] dep_ids = list(item.get("depends_on") or []) - item["dependencies"] = [ - {"id": d, "title": _resolve_title(d)} for d in dep_ids - ] + item["dependencies"] = [{"id": d, "title": _resolve_title(d)} for d in dep_ids] dnt_ids = dependents_map.get(item_id, []) - item["dependents"] = [ - {"id": d, "title": _resolve_title(d)} for d in dnt_ids - ] + item["dependents"] = [{"id": d, "title": _resolve_title(d)} for d in dnt_ids] # Collect references into bibliography - for ref in (item.get("references") or []): + for ref in item.get("references") or []: source = ref.get("source", "") key = ref.get("key", ref.get("id", ref.get("title", ""))) dedup_key = (source, key) @@ -227,16 +223,26 @@ async def get_item_detail(campaign_id: str, item_id: str): dependencies = [] for did in dep_ids: dep = cs.get_plan_item(did) - dependencies.append({"id": did, "title": dep.title if dep else did[:8], - "status": dep.status.value if dep else None}) + dependencies.append( + { + "id": did, + "title": dep.title if dep else did[:8], + "status": dep.status.value if dep else None, + } + ) # Dependents with titles dnt_ids = cs.get_plan_item_dependents(item_id) dependents = [] for did in dnt_ids: dnt = cs.get_plan_item(did) - dependents.append({"id": did, "title": dnt.title if dnt else did[:8], - "status": dnt.status.value if dnt else None}) + dependents.append( + { + "id": did, + "title": dnt.title if dnt else did[:8], + "status": dnt.status.value if dnt else None, + } + ) 
# Sessions linked to this campaign sessions = cs.get_sessions_for_campaign(item.campaign_id) @@ -280,9 +286,12 @@ async def _require(request: Request): if required_scope not in scopes: if _audit: from gently.mesh.audit import AuditEvent + _audit.log( - AuditEvent.SCOPE_DENIED, outcome="deny", - peer_id=peer_id, ip=host, + AuditEvent.SCOPE_DENIED, + outcome="deny", + peer_id=peer_id, + ip=host, detail=f"scope={required_scope} path={request.url.path}", ) raise HTTPException( @@ -291,36 +300,51 @@ async def _require(request: Request): ) if _audit: from gently.mesh.audit import AuditEvent + _audit.log( - AuditEvent.AUTH_SUCCESS, outcome="allow", - peer_id=peer_id, ip=host, + AuditEvent.AUTH_SUCCESS, + outcome="allow", + peer_id=peer_id, + ip=host, ) return if _audit: from gently.mesh.audit import AuditEvent + _audit.log( - AuditEvent.AUTH_FAILURE, outcome="deny", - ip=host, detail=f"path={request.url.path}", + AuditEvent.AUTH_FAILURE, + outcome="deny", + ip=host, + detail=f"path={request.url.path}", ) raise HTTPException(status_code=403, detail="Mesh authentication required") return _require - @router.post("/api/campaigns/{campaign_id}/share", dependencies=[Depends(_make_campaign_auth("campaigns:admin"))]) + @router.post( + "/api/campaigns/{campaign_id}/share", + dependencies=[Depends(_make_campaign_auth("campaigns:admin"))], + ) async def share_campaign(campaign_id: str): cs = _get_store() campaign = _resolve(cs, campaign_id) cs.share_campaign(campaign.id) return {"ok": True} - @router.post("/api/campaigns/{campaign_id}/unshare", dependencies=[Depends(_make_campaign_auth("campaigns:admin"))]) + @router.post( + "/api/campaigns/{campaign_id}/unshare", + dependencies=[Depends(_make_campaign_auth("campaigns:admin"))], + ) async def unshare_campaign(campaign_id: str): cs = _get_store() campaign = _resolve(cs, campaign_id) cs.unshare_campaign(campaign.id) return {"ok": True} - @router.get("/api/campaigns/{campaign_id}/export", 
dependencies=[Depends(_make_campaign_auth("campaigns"))]) + @router.get( + "/api/campaigns/{campaign_id}/export", + dependencies=[Depends(_make_campaign_auth("campaigns"))], + ) async def export_campaign(campaign_id: str): cs = _get_store() campaign = _resolve(cs, campaign_id) @@ -328,7 +352,10 @@ async def export_campaign(campaign_id: str): _enrich_export_with_claims(tree, cs, campaign.id) return tree - @router.post("/api/campaigns/{campaign_id}/join", dependencies=[Depends(_make_campaign_auth("campaigns"))]) + @router.post( + "/api/campaigns/{campaign_id}/join", + dependencies=[Depends(_make_campaign_auth("campaigns"))], + ) async def join_campaign(campaign_id: str, request: Request): cs = _get_store() campaign = _resolve(cs, campaign_id) @@ -340,14 +367,20 @@ async def join_campaign(campaign_id: str, request: Request): cs.add_campaign_participant(campaign.id, instance_id, hostname) return {"ok": True} - @router.get("/api/campaigns/{campaign_id}/participants", dependencies=[Depends(_make_campaign_auth("campaigns"))]) + @router.get( + "/api/campaigns/{campaign_id}/participants", + dependencies=[Depends(_make_campaign_auth("campaigns"))], + ) async def get_participants(campaign_id: str): cs = _get_store() campaign = _resolve(cs, campaign_id) participants = cs.get_campaign_participants(campaign.id) return {"participants": participants} - @router.post("/api/campaigns/{campaign_id}/items/{item_id}/claim", dependencies=[Depends(_make_campaign_auth("campaigns"))]) + @router.post( + "/api/campaigns/{campaign_id}/items/{item_id}/claim", + dependencies=[Depends(_make_campaign_auth("campaigns"))], + ) async def claim_item(campaign_id: str, item_id: str, request: Request): cs = _get_store() _resolve(cs, campaign_id) @@ -361,14 +394,20 @@ async def claim_item(campaign_id: str, item_id: str, request: Request): raise HTTPException(status_code=409, detail="Item already claimed by another node") return {"ok": True} - 
@router.post("/api/campaigns/{campaign_id}/items/{item_id}/unclaim", dependencies=[Depends(_make_campaign_auth("campaigns"))]) + @router.post( + "/api/campaigns/{campaign_id}/items/{item_id}/unclaim", + dependencies=[Depends(_make_campaign_auth("campaigns"))], + ) async def unclaim_item(campaign_id: str, item_id: str): cs = _get_store() _resolve(cs, campaign_id) cs.unclaim_plan_item(item_id) return {"ok": True} - @router.post("/api/campaigns/{campaign_id}/items/{item_id}/status", dependencies=[Depends(_make_campaign_auth("campaigns"))]) + @router.post( + "/api/campaigns/{campaign_id}/items/{item_id}/status", + dependencies=[Depends(_make_campaign_auth("campaigns"))], + ) async def update_item_status(campaign_id: str, item_id: str, request: Request): cs = _get_store() _resolve(cs, campaign_id) @@ -380,7 +419,7 @@ async def update_item_status(campaign_id: str, item_id: str, request: Request): try: item_status = PlanItemStatus(status_str) except ValueError: - raise HTTPException(status_code=400, detail=f"Invalid status: {status_str}") + raise HTTPException(status_code=400, detail=f"Invalid status: {status_str}") from None cs.update_plan_item(item_id, status=item_status, outcome=outcome) return {"ok": True} @@ -388,7 +427,7 @@ async def update_item_status(campaign_id: str, item_id: str, request: Request): # Helpers # ------------------------------------------------------------------ - def _build_campaign_tree(cs, campaign_id: str) -> Optional[Dict]: + def _build_campaign_tree(cs, campaign_id: str) -> dict | None: """Recursively build campaign tree with plan items and status.""" campaign = cs.get_campaign(campaign_id) if not campaign: @@ -410,13 +449,10 @@ def _build_campaign_tree(cs, campaign_id: str) -> Optional[Dict]: "in_progress": status["in_progress"], "planned": status["planned"], }, - "children": [ - _build_campaign_tree(cs, child.id) - for child in children - ], + "children": [_build_campaign_tree(cs, child.id) for child in children], } - def 
_enrich_export_with_claims(tree: Dict, cs, campaign_id: str): + def _enrich_export_with_claims(tree: dict, cs, campaign_id: str): """Walk a serialized campaign tree and annotate items with IDs and claim info.""" items = cs.get_plan_items(campaign_id=campaign_id) items.sort(key=lambda x: x.phase_order) diff --git a/gently/ui/web/routes/chat.py b/gently/ui/web/routes/chat.py index 12833b13..8c66dd45 100644 --- a/gently/ui/web/routes/chat.py +++ b/gently/ui/web/routes/chat.py @@ -14,12 +14,13 @@ import logging from datetime import datetime from pathlib import Path -from typing import Optional -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import StreamingResponse from pydantic import BaseModel +from gently.ui.web.auth import require_control + logger = logging.getLogger(__name__) CHAT_MODEL = "claude-opus-4-7" @@ -43,21 +44,21 @@ class ChatRequest(BaseModel): message: str -def _resolve_session_dir(server, sid: str) -> Optional[Path]: +def _resolve_session_dir(server, sid: str) -> Path | None: store = getattr(server, "gently_store", None) if store is None: return None return store._session_dir(sid) -def _trace_path(server, sid: str, eid: str, tp: int) -> Optional[Path]: +def _trace_path(server, sid: str, eid: str, tp: int) -> Path | None: sd = _resolve_session_dir(server, sid) if sd is None: return None return sd / "embryos" / eid / "traces" / f"t{tp:04d}.json" -def _chat_path(server, sid: str, eid: str, tp: int) -> Optional[Path]: +def _chat_path(server, sid: str, eid: str, tp: int) -> Path | None: sd = _resolve_session_dir(server, sid) if sd is None: return None @@ -68,7 +69,7 @@ def _load_history(path: Path) -> list[dict]: if not path.exists(): return [] turns: list[dict] = [] - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: for line in f: line = line.strip() if not line: @@ -106,7 +107,13 @@ async def get_chat(sid: str, eid: str, tp: int): return 
{"turns": _load_history(path)} @router.post("/api/perception/chat/{sid}/{eid}/{tp}") - async def post_chat(sid: str, eid: str, tp: int, body: ChatRequest): + async def post_chat( + sid: str, + eid: str, + tp: int, + body: ChatRequest, + _control=Depends(require_control), # noqa: B008 + ): """Append a user message and stream the assistant reply as SSE. Each SSE event is JSON: ``{"type": "delta", "text": "..."}`` for @@ -126,7 +133,7 @@ async def post_chat(sid: str, eid: str, tp: int, body: ChatRequest): detail=f"No perception trace for T{tp}", ) - with open(trace_path, "r", encoding="utf-8") as f: + with open(trace_path, encoding="utf-8") as f: trace = json.load(f) stage = trace.get("predicted_stage", "unknown") reasoning = trace.get("reasoning", "") @@ -164,9 +171,7 @@ async def post_chat(sid: str, eid: str, tp: int, body: ChatRequest): } seed_assistant = { "role": "assistant", - "content": [ - {"type": "text", "text": f"Stage: {stage}\n\n{reasoning}"} - ], + "content": [{"type": "text", "text": f"Stage: {stage}\n\n{reasoning}"}], } messages: list[dict] = [seed_user, seed_assistant] @@ -174,9 +179,7 @@ async def post_chat(sid: str, eid: str, tp: int, body: ChatRequest): role = turn.get("role") content = turn.get("content", "") if role in ("user", "assistant") and content: - messages.append( - {"role": role, "content": [{"type": "text", "text": content}]} - ) + messages.append({"role": role, "content": [{"type": "text", "text": content}]}) messages.append( { "role": "user", diff --git a/gently/ui/web/routes/data.py b/gently/ui/web/routes/data.py index 3e66763f..93d49451 100644 --- a/gently/ui/web/routes/data.py +++ b/gently/ui/web/routes/data.py @@ -3,10 +3,11 @@ import logging from datetime import datetime from pathlib import Path -from typing import Optional import yaml -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, Body, Depends, HTTPException + +from gently.ui.web.auth import require_control logger = logging.getLogger(__name__) @@ 
-26,7 +27,7 @@ async def get_status(): "status": "running", "connections": len(server.manager.active_connections), **stats, - "timestamp": datetime.now().isoformat() + "timestamp": datetime.now().isoformat(), } @router.get("/api/device-status") @@ -63,6 +64,131 @@ async def get_device_status(): "microscope": microscope_up, } + def _require_agent_with_experiment(): + """Resolve the live agent from the server bridge, or 503. + + Edit endpoints write through ExperimentState so the notify hook fires + EMBRYOS_UPDATE and the Map re-renders without a follow-up fetch. + """ + bridge = getattr(server, "agent_bridge", None) + agent = bridge.agent if bridge is not None else None + if agent is None or not hasattr(agent, "experiment"): + raise HTTPException(status_code=503, detail="Agent not ready") + return agent + + @router.put("/api/embryos/{embryo_id}/position", dependencies=[Depends(require_control)]) + async def update_embryo_position( + embryo_id: str, + body: dict = Body(...), # noqa: B008 + ): + """Update an embryo's coarse XY position. + + Map-side edits write to the coarse stage and CLEAR any prior fine + position — the operator is overriding the sighting, so any + SPIM-objective fine alignment derived from the old coarse is no + longer trustworthy and must be re-run. + + Publishes OPERATOR_EDITED_EMBRYO with both the old and new + positions so candidates can reason about the magnitude of the + correction and trigger re-calibration suggestions. 
+ """ + agent = _require_agent_with_experiment() + emb = agent.experiment.embryos.get(embryo_id) + if emb is None: + raise HTTPException(status_code=404, detail=f"Embryo {embryo_id} not found") + try: + x = float(body.get("x")) + y = float(body.get("y")) + except (TypeError, ValueError): + raise HTTPException(status_code=400, detail="Body needs numeric x and y") from None + old_coarse = dict(emb.position_coarse) if emb.position_coarse else None + had_fine = bool(emb.position_fine) + emb.position_coarse = {"x": x, "y": y} + emb.position_fine = {} + agent.experiment.notify_embryos_changed() + + bus = getattr(agent, "_event_bus", None) + if bus is not None: + from gently.core.event_bus import EventType + + try: + bus.publish( + event_type=EventType.OPERATOR_EDITED_EMBRYO, + data={ + "embryo_id": embryo_id, + "old_position_coarse": old_coarse, + "new_position_coarse": {"x": x, "y": y}, + "fine_position_invalidated": had_fine, + }, + source="web:map-edit", + ) + except Exception: + logger.exception("Failed to publish OPERATOR_EDITED_EMBRYO") + return emb.to_dict() + + @router.delete("/api/embryos/{embryo_id}", dependencies=[Depends(require_control)]) + async def delete_embryo(embryo_id: str): + """Remove an embryo from the experiment. + + Goes through ExperimentState.remove_embryo so the observer hook + fires EMBRYOS_UPDATE automatically. Also publishes + OPERATOR_REMOVED_EMBRYO carrying the embryo's last known position + — candidates can use that to e.g. clean up associated cache or + log the deletion in their own world model. 
+ """ + agent = _require_agent_with_experiment() + emb = agent.experiment.embryos.get(embryo_id) + last_position = None + if emb is not None: + last_position = { + "coarse": dict(emb.position_coarse) if emb.position_coarse else None, + "fine": dict(emb.position_fine) if emb.position_fine else None, + } + if not agent.experiment.remove_embryo(embryo_id): + raise HTTPException(status_code=404, detail=f"Embryo {embryo_id} not found") + + bus = getattr(agent, "_event_bus", None) + if bus is not None: + from gently.core.event_bus import EventType + + try: + bus.publish( + event_type=EventType.OPERATOR_REMOVED_EMBRYO, + data={ + "embryo_id": embryo_id, + "last_position": last_position, + }, + source="web:map-delete", + ) + except Exception: + logger.exception("Failed to publish OPERATOR_REMOVED_EMBRYO") + return {"ok": True, "embryo_id": embryo_id} + + @router.get("/api/embryos/current") + async def get_current_embryos(): + """Return the agent's current embryo list as an EMBRYOS_UPDATE payload. + + EMBRYOS_UPDATE is published only on mutation, so a Map page opened + mid-session would otherwise see an empty embryo layer until the next + add/remove/edit. This endpoint serves the same payload shape as the + event so clients can bootstrap and then switch to the live stream. + """ + empty = {"embryos": [], "count": 0, "session_id": None} + bridge = getattr(server, "agent_bridge", None) + agent = bridge.agent if bridge is not None else None + if agent is None or not hasattr(agent, "experiment"): + return empty + try: + embryos = [e.to_dict() for e in agent.experiment.embryos.values()] + except Exception: + logger.exception("Failed to serialise embryos for snapshot") + return empty + return { + "embryos": embryos, + "count": len(embryos), + "session_id": getattr(agent, "session_id", None), + } + @router.get("/api/devices/coverslip") async def get_coverslip(): """Return the coverslip outline metadata for the Map view. 
@@ -74,17 +200,19 @@ async def get_coverslip(): block in this config and no zone endpoint here. """ try: - with open(_HARDWARE_CONFIG_PATH, "r") as f: + with open(_HARDWARE_CONFIG_PATH) as f: cfg = yaml.safe_load(f) or {} except FileNotFoundError: return {"coverslip": None} cs = cfg.get("coverslip") if not isinstance(cs, dict): return {"coverslip": None} - return {"coverslip": { - "center_um": list(cs.get("center_um") or [0.0, 0.0]), - "size_mm": list(cs.get("size_mm") or [50.0, 24.0]), - }} + return { + "coverslip": { + "center_um": list(cs.get("center_um") or [0.0, 0.0]), + "size_mm": list(cs.get("size_mm") or [50.0, 24.0]), + } + } @router.get("/api/devices/bottom_camera/status") async def get_bottom_camera_status(): @@ -98,7 +226,10 @@ async def get_bottom_camera_status(): "last_frame_ts": getattr(monitor, "_last_frame_ts", None) if monitor else None, } - @router.post("/api/devices/bottom_camera/stream/start") + @router.post( + "/api/devices/bottom_camera/stream/start", + dependencies=[Depends(require_control)], + ) async def start_bottom_camera_stream(): """Start the bottom-camera stream bridge. @@ -116,10 +247,13 @@ async def start_bottom_camera_stream(): await monitor.start() except Exception as exc: logger.exception("Failed to start bottom-camera monitor") - raise HTTPException(status_code=500, detail=f"start failed: {exc}") + raise HTTPException(status_code=500, detail=f"start failed: {exc}") from exc return {"streaming": monitor.running} - @router.post("/api/devices/bottom_camera/stream/stop") + @router.post( + "/api/devices/bottom_camera/stream/stop", + dependencies=[Depends(require_control)], + ) async def stop_bottom_camera_stream(): """Stop the bottom-camera stream bridge. 
Idempotent.""" bridge = getattr(server, "agent_bridge", None) @@ -131,42 +265,133 @@ async def stop_bottom_camera_stream(): await monitor.stop() except Exception as exc: logger.exception("Failed to stop bottom-camera monitor") - raise HTTPException(status_code=500, detail=f"stop failed: {exc}") + raise HTTPException(status_code=500, detail=f"stop failed: {exc}") from exc return {"streaming": False} + def _resolve_client(): + """Resolve the live microscope client from the agent bridge, or None.""" + bridge = getattr(server, "agent_bridge", None) + agent = bridge.agent if bridge is not None else None + return getattr(agent, "client", None) if agent else None + + @router.get("/api/devices/room_light/status") + async def get_room_light_status(): + """Cached on/off state of the room-light SwitchBot (cheap to poll).""" + client = _resolve_client() + if client is None: + return {"available": False, "state": "unknown"} + try: + res = await client.get_room_light_status() + except Exception as exc: + logger.debug("room light status fetch failed: %s", exc) + return {"available": False, "state": "unknown"} + return { + "available": bool(res.get("available", res.get("success", False))), + "state": res.get("state", "unknown"), + } + + @router.post("/api/devices/room_light/set", dependencies=[Depends(require_control)]) + async def set_room_light(payload: dict = Body(...)): # noqa: B008 + """Switch the room light on/off. 
Body: {"state": "on"|"off"|"press"}.""" + state = str(payload.get("state", "")).lower() + if state not in ("on", "off", "press"): + raise HTTPException(status_code=400, detail="state must be on, off, or press") + client = _resolve_client() + if client is None: + raise HTTPException(status_code=503, detail="Microscope not connected") + try: + res = await client.set_room_light(state) + except Exception as exc: + logger.exception("Room light command failed") + raise HTTPException( + status_code=502, detail=f"room light command failed: {exc}" + ) from exc + if not res.get("success"): + raise HTTPException( + status_code=502, detail=res.get("error", "room light command failed") + ) + return {"state": res.get("state", state)} + + @router.get("/api/devices/temperature/status") + async def get_temperature_status(): + """Live water temperature, setpoint, and lock state (cheap to poll). + + Cached at the device layer (no per-call hardware round trip), so the + Devices header can poll it like the room light. ``available`` is false + when no controller is configured/connected, which hides the control. + """ + client = _resolve_client() + if client is None: + return {"available": False, "state": "unknown"} + try: + res = await client.get_temperature() + except Exception as exc: + logger.debug("temperature status fetch failed: %s", exc) + return {"available": False, "state": "unknown"} + return { + "available": bool(res.get("success", False)), + "temperature_c": res.get("temperature_c"), + "setpoint_c": res.get("setpoint_c"), + "state": res.get("state", "unknown"), + "peltier_c": res.get("peltier_c"), + } + + @router.post("/api/devices/temperature/set", dependencies=[Depends(require_control)]) + async def set_temperature(payload: dict = Body(...)): # noqa: B008 + """Command the temperature setpoint. Body: {"target_c": float}. + + Non-blocking: the controller ramps and the status poll reflects progress + (and the SYSTEM LOCKED state once it stabilizes). 
+ """ + try: + target = float(payload.get("target_c")) + except (TypeError, ValueError): + raise HTTPException(status_code=400, detail="target_c must be a number") from None + if not (0.0 <= target <= 99.9): + raise HTTPException(status_code=400, detail="target_c must be between 0.0 and 99.9 C") + client = _resolve_client() + if client is None: + raise HTTPException(status_code=503, detail="Microscope not connected") + try: + res = await client.set_temperature(target) + except Exception as exc: + logger.exception("Temperature command failed") + raise HTTPException( + status_code=502, detail=f"temperature command failed: {exc}" + ) from exc + if not res.get("success"): + raise HTTPException( + status_code=502, detail=res.get("error", "temperature command failed") + ) + return { + "target_c": res.get("target_c", target), + "temperature_c": res.get("temperature_c"), + "state": res.get("state", "unknown"), + "waited": res.get("waited", False), + } + @router.get("/api/calibration") - async def list_calibration(embryo_id: Optional[str] = None): + async def list_calibration(embryo_id: str | None = None): """Get calibration images""" images = server.store.get_all_calibration(embryo_id) - return { - "calibration": [img.to_dict() for img in images], - "count": len(images) - } + return {"calibration": [img.to_dict() for img in images], "count": len(images)} @router.get("/api/volumes") - async def list_volumes(embryo_id: Optional[str] = None): + async def list_volumes(embryo_id: str | None = None): """Get volume images""" images = server.store.get_all_volumes(embryo_id) - return { - "volumes": [img.to_dict() for img in images], - "count": len(images) - } + return {"volumes": [img.to_dict() for img in images], "count": len(images)} @router.get("/api/snapshots") - async def list_snapshots(embryo_id: Optional[str] = None): + async def list_snapshots(embryo_id: str | None = None): """Get snapshot images""" images = server.store.get_all_snapshots(embryo_id) - return { - 
"snapshots": [img.to_dict() for img in images], - "count": len(images) - } + return {"snapshots": [img.to_dict() for img in images], "count": len(images)} @router.get("/api/embryos") async def list_embryos(): """Get list of embryos with images""" - return { - "embryos": server.store.get_embryo_ids() - } + return {"embryos": server.store.get_embryo_ids()} @router.get("/api/embryos/positions") async def embryo_positions(): @@ -190,26 +415,28 @@ async def embryo_positions(): # Embryo registered but no position yet (e.g. only the # ID arrived from another path). Skip — nothing to render. continue - points.append({ - "embryo_id": eid, - "uid": emb.get("uid"), - "x": float(x), - "y": float(y), - "role": emb.get("role", "test"), - "user_label": emb.get("user_label"), - "confidence": emb.get("confidence"), - "cadence_phase": emb.get("cadence_phase"), - "is_complete": bool(emb.get("is_complete")), - }) + points.append( + { + "embryo_id": eid, + "uid": emb.get("uid"), + "x": float(x), + "y": float(y), + "role": emb.get("role", "test"), + "user_label": emb.get("user_label"), + "confidence": emb.get("confidence"), + "cadence_phase": emb.get("cadence_phase"), + "is_complete": bool(emb.get("is_complete")), + } + ) return {"embryos": points} @router.get("/api/sequence/{embryo_id}") async def get_image_sequence( embryo_id: str, start: int = 0, - end: Optional[int] = None, + end: int | None = None, data_type: str = "volume_projection", - buffer_percent: float = 0.15 + buffer_percent: float = 0.15, ): """Get ordered sequence of images for timepoint range. 
@@ -230,7 +457,7 @@ async def get_image_sequence( embryo_id=embryo_id, start=buffered_start, end=buffered_end, - data_type=data_type + data_type=data_type, ) # Return lightweight metadata (no base64 data) @@ -238,26 +465,25 @@ async def get_image_sequence( seen_uids = set() for img in images: seen_uids.add(img.uid) - sequence.append({ - "uid": img.uid, - "timepoint": img.metadata.get("timepoint"), - "timestamp": img.timestamp, - "data_type": img.data_type, - "shape": img.shape, - "embryo_id": img.metadata.get("embryo_id") - }) + sequence.append( + { + "uid": img.uid, + "timepoint": img.metadata.get("timepoint"), + "timestamp": img.timestamp, + "data_type": img.data_type, + "shape": img.shape, + "embryo_id": img.metadata.get("embryo_id"), + } + ) # Fallback to persistent DataStore for missing timepoints if server.data_store and (len(sequence) == 0 or buffered_end is not None): try: - refs = server.data_store.query( - data_type=data_type, - embryo_id=embryo_id - ) + refs = server.data_store.query(data_type=data_type, embryo_id=embryo_id) for ref in refs: if ref.uid in seen_uids: continue - tp = ref.metadata.get('timepoint') + tp = ref.metadata.get("timepoint") if tp is None: continue tp = int(tp) @@ -266,16 +492,18 @@ async def get_image_sequence( if buffered_end is not None and tp > buffered_end: continue seen_uids.add(ref.uid) - sequence.append({ - "uid": ref.uid, - "timepoint": tp, - "timestamp": ref.metadata.get('timestamp', ''), - "data_type": ref.data_type, - "shape": ref.metadata.get('shape'), - "embryo_id": embryo_id - }) + sequence.append( + { + "uid": ref.uid, + "timepoint": tp, + "timestamp": ref.metadata.get("timestamp", ""), + "data_type": ref.data_type, + "shape": ref.metadata.get("shape"), + "embryo_id": embryo_id, + } + ) # Re-sort by timepoint - sequence.sort(key=lambda x: x.get('timepoint') or 0) + sequence.sort(key=lambda x: x.get("timepoint") or 0) except Exception as e: logger.warning(f"DataStore fallback failed: {e}") @@ -284,14 +512,12 @@ async 
def get_image_sequence( "requested_range": {"start": start, "end": end}, "buffered_range": {"start": buffered_start, "end": buffered_end}, "sequence": sequence, - "count": len(sequence) + "count": len(sequence), } @router.get("/api/events") async def list_events( - event_type: Optional[str] = None, - source: Optional[str] = None, - limit: int = 100 + event_type: str | None = None, source: str | None = None, limit: int = 100 ): """Get event history from EventBus""" if not server.event_bus: @@ -299,6 +525,7 @@ async def list_events( # Get history from event bus from gently.core import EventType + et = None if event_type: try: @@ -306,24 +533,22 @@ async def list_events( except KeyError: pass - events = server.event_bus.get_history( - event_type=et, - source=source, - limit=limit - ) + events = server.event_bus.get_history(event_type=et, source=source, limit=limit) return { "events": [ { - "event_type": e.event_type.name if hasattr(e.event_type, 'name') else str(e.event_type), + "event_type": e.event_type.name + if hasattr(e.event_type, "name") + else str(e.event_type), "data": e.data, "source": e.source, "timestamp": e.timestamp.isoformat(), - "event_id": e.event_id + "event_id": e.event_id, } for e in events ], - "total": len(events) + "total": len(events), } return router diff --git a/gently/ui/web/routes/experiments.py b/gently/ui/web/routes/experiments.py index 3b8c9506..c2c5de77 100644 --- a/gently/ui/web/routes/experiments.py +++ b/gently/ui/web/routes/experiments.py @@ -61,6 +61,6 @@ async def get_strategy(session_id: str): raise HTTPException( status_code=500, detail=f"Failed to build strategy: {e}", - ) + ) from e return router diff --git a/gently/ui/web/routes/images.py b/gently/ui/web/routes/images.py index 2ecc805a..ce690493 100644 --- a/gently/ui/web/routes/images.py +++ b/gently/ui/web/routes/images.py @@ -2,11 +2,10 @@ import base64 import logging -from typing import Optional import numpy as np from fastapi import APIRouter, HTTPException, Request 
-from fastapi.responses import Response, FileResponse +from fastapi.responses import FileResponse, Response from ..volume_helpers import parse_volume_uid @@ -66,17 +65,22 @@ async def get_image_png(uid: str): if data is None and parsed: embryo_id, timepoint = parsed if embryo_id in server.timelapse_tracker.projection_uids: - real_uid = server.timelapse_tracker.projection_uids[embryo_id].get(timepoint) + real_uid = server.timelapse_tracker.projection_uids[embryo_id].get( + timepoint + ) if real_uid: data = server.data_store.retrieve(real_uid) if data is not None: from io import BytesIO + from PIL import Image + from gently.core.imaging import ( - projection_three_view, - compute_crop_bounds, apply_crop_bounds, + compute_crop_bounds, + projection_three_view, ) + # Handle numpy array if isinstance(data, np.ndarray): # Handle 4D volumes (Views, Z, Y, X) - take View A @@ -87,30 +91,45 @@ async def get_image_png(uid: str): z_depth, height, width = data.shape # Handle dual-view format if width > height * 2: - data = data[:, :, :width // 2] + data = data[:, :, : width // 2] # Auto-crop and project bounds = compute_crop_bounds(data) data = apply_crop_bounds(data, bounds) data, _ = projection_three_view(data) # Normalize to uint8 if needed if data.dtype != np.uint8: - data = ((data - data.min()) / (data.max() - data.min() + 1e-8) * 255).astype(np.uint8) + data = ( + (data - data.min()) / (data.max() - data.min() + 1e-8) * 255 + ).astype(np.uint8) img = Image.fromarray(data) buf = BytesIO() - img.save(buf, format='PNG') - return Response(content=buf.getvalue(), media_type="image/png", headers=cache_headers) + img.save(buf, format="PNG") + return Response( + content=buf.getvalue(), + media_type="image/png", + headers=cache_headers, + ) except Exception as e: logger.warning(f"Failed to load image {uid} from DataStore: {e}") - # Fallback to FileStore JPEG projections (persistent on-disk) + # Fallback to FileStore JPEG projections (persistent on-disk). 
+ # Unlike the in-memory base64 images, an on-disk projection CAN change + # (e.g. regenerated after a projection-format fix), so we must NOT mark + # it immutable with a content-independent (uid) ETag — that pins the + # browser to the stale image. Use a content-aware ETag (mtime+size) + # and a short max-age so a regeneration is picked up. if server.gently_store and parsed: embryo_id, timepoint = parsed proj_path = server._resolve_projection_path(embryo_id, timepoint) if proj_path: + st = proj_path.stat() return FileResponse( str(proj_path), media_type="image/jpeg", - headers=cache_headers, + headers={ + "Cache-Control": "public, max-age=3600", + "ETag": f'"{uid}-{int(st.st_mtime)}-{st.st_size}"', + }, ) raise HTTPException(status_code=404, detail=f"Image {uid} not found") @@ -122,21 +141,18 @@ async def push_image_http(request: Request): data = await request.json() # Decode the image from base64 - image_b64 = data.get('image_b64') - uid = data.get('uid') - shape = data.get('shape') - dtype = data.get('dtype', 'uint8') - data_type = data.get('data_type', 'cv_visualization') - metadata = data.get('metadata', {}) + image_b64 = data.get("image_b64") + uid = data.get("uid") + shape = data.get("shape") + dtype = data.get("dtype", "uint8") + data_type = data.get("data_type", "cv_visualization") + metadata = data.get("metadata", {}) if not all([image_b64, uid, shape]): raise HTTPException(status_code=400, detail="Missing required fields") # Decode array - array = np.frombuffer( - base64.b64decode(image_b64), - dtype=np.dtype(dtype) - ).reshape(shape) + array = np.frombuffer(base64.b64decode(image_b64), dtype=np.dtype(dtype)).reshape(shape) # Push using the existing method await server.push_image(array, uid, data_type, metadata) @@ -145,6 +161,6 @@ async def push_image_http(request: Request): except Exception as e: logger.error(f"Failed to push image via HTTP: {e}") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, 
detail=str(e)) from e return router diff --git a/gently/ui/web/routes/pages.py b/gently/ui/web/routes/pages.py index 1412e014..3858f4e7 100644 --- a/gently/ui/web/routes/pages.py +++ b/gently/ui/web/routes/pages.py @@ -9,10 +9,14 @@ def create_router(server) -> APIRouter: @router.get("/", response_class=HTMLResponse) async def index(request: Request): - """Serve the main SPA page""" + """Serve the main SPA page. + + Viewing is open to everyone — the dashboard loads in view mode with no + login. Signing in is an *elevation* to control (handled in-app via the + chat window's "Sign in" affordance), not a gate on the page itself. + """ return server.templates.TemplateResponse( - "index.html", - {"request": request, "active_section": "embryos", "is_live": True} + request, "index.html", {"active_section": "embryos", "is_live": True} ) # Standalone URLs redirect to SPA with hash fragment for tab routing @@ -32,8 +36,8 @@ async def plan_review_page(campaign_id: str): async def settings_page(request: Request): """Serve the dashboard settings page""" return server.templates.TemplateResponse( + request, "settings.html", - {"request": request} ) return router diff --git a/gently/ui/web/routes/sessions.py b/gently/ui/web/routes/sessions.py index 69d70d47..ccbc3a93 100644 --- a/gently/ui/web/routes/sessions.py +++ b/gently/ui/web/routes/sessions.py @@ -1,9 +1,12 @@ -"""Session routes - list and retrieve saved sessions.""" +"""Session routes - list, retrieve, and resume saved sessions.""" -import json import logging +from pathlib import Path -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import FileResponse + +from gently.ui.web.auth import require_control logger = logging.getLogger(__name__) @@ -11,39 +14,221 @@ def create_router(server) -> APIRouter: router = APIRouter() + def _file_store(): + """The live FileStore (current Gently3 layout), via the agent.""" + bridge = getattr(server, "agent_bridge", 
None) + if bridge is not None and getattr(bridge, "agent", None) is not None: + st = getattr(bridge.agent, "store", None) + if st is not None: + return st + return getattr(server, "gently_store", None) + + def _active_session_id(): + bridge = getattr(server, "agent_bridge", None) + agent = bridge.agent if bridge is not None else None + return getattr(agent, "session_id", None) if agent is not None else None + @router.get("/api/sessions") async def list_sessions(): - """List available sessions with metadata""" + """List available sessions (from the live FileStore).""" + store = _file_store() + if store is None: + return {"sessions": []} + active_id = _active_session_id() sessions = [] - if server.sessions_dir.exists(): - for path in server.sessions_dir.glob("*.json"): + try: + for s in store.list_sessions(): + sid = s.get("session_id") try: - with open(path, encoding='utf-8') as f: - data = json.load(f) - sessions.append({ - 'session_id': data.get('session_id', path.stem), - 'name': data.get('name', path.stem), - 'created_at': data.get('created_at', ''), - 'last_active': data.get('last_active', ''), - 'embryo_count': len(data.get('embryo_states', {})), - 'description': data.get('description', '') - }) - except Exception as e: - logger.warning(f"Failed to read session {path}: {e}") - # Sort by created_at descending (newest first) - sessions.sort(key=lambda x: x.get('created_at', ''), reverse=True) - return {'sessions': sessions} + count = len(store.list_embryos(sid) or []) + except Exception: + count = 0 + sessions.append( + { + "session_id": sid, + "name": s.get("name") or sid, + "created_at": s.get("created_at", ""), + "last_active": s.get("last_active", ""), + "embryo_count": count, + "description": s.get("description", ""), + "active": sid == active_id, + } + ) + except Exception as e: + logger.warning("Failed to list sessions from FileStore: %s", e) + return {"sessions": sessions} - @router.get("/api/sessions/{session_id}") - async def get_session(session_id: 
str): - """Get full session state for review""" - path = server.sessions_dir / f"{session_id}.json" - if not path.exists(): + @router.get("/api/home/recent-images") + async def recent_images(limit: int = 8, scan: int = 200): + """Latest projection per embryo, aggregated across recent sessions. + + Unlike /api/snapshots (in-memory, current session only), this walks the + FileStore on disk so the home page can show imagery from *previous* + sessions. Cheap by construction: recent session IDs come from folder + names (no session.yaml parse), embryo IDs from directory names (no + embryo.yaml parse), timepoints from a filename glob (no pixel decode), + and the walk stops as soon as `limit` images are collected. + + `scan` is the *budget* of most-recent sessions to walk while hunting for + images, NOT a hard window — empty/aborted sessions (common at the head: + a rig accrues many no-capture sessions) are skipped nearly for free + (one iterdir each), so the default is generous enough to reach older + sessions that actually hold projections. Both bounds are clamped so a + crafted ?scan=/?limit= can't turn this unauthenticated read into an + unbounded scan. Returns components; the client builds the (encoded) URL. 
+ """ + store = _file_store() + if store is None: + return {"images": []} + limit = max(1, min(int(limit), 48)) + scan = max(1, min(int(scan), 500)) + out = [] + try: + for sid in store.recent_session_ids(scan) or []: + try: + eids = store.list_embryo_ids(sid) + except Exception: + eids = [] + sname = None # parsed lazily, only if this session contributes + for eid in eids: + try: + tps = store.list_projection_timepoints(sid, eid) or [] + except Exception: + tps = [] + if not tps: + continue + if sname is None: + try: + info = store.get_session(sid) + except Exception: + info = None + sname = (info.get("name") if info else None) or sid + out.append( + { + "session_id": sid, + "session_name": sname, + "embryo_id": eid, + "timepoint": int(max(tps)), + } + ) + if len(out) >= limit: + break + if len(out) >= limit: + break + except Exception as e: + logger.warning("recent_images failed: %s", e) + return {"images": out[:limit]} + + @router.get("/api/sessions/{session_id}/projection") + async def get_session_projection(session_id: str, embryo: str, t: int): + """Serve a saved JPEG projection from any session on disk. + + Path-traversal safe: the resolved file must live inside the session's + own directory, so a crafted `embryo` (e.g. '../..') can't escape. + """ + store = _file_store() + if store is None: + raise HTTPException(status_code=503, detail="Store not available") + path = store.get_projection_path(session_id, embryo, t) + if path is None: + raise HTTPException(status_code=404, detail="Projection not found") + try: + sd = store._session_dir(session_id) + resolved = Path(path).resolve() + # Component-wise ancestor check (not str.startswith, which would + # let a sibling like `_evil` slip through the prefix match). 
+ sd_resolved = Path(sd).resolve() if sd is not None else None + if sd_resolved is None or sd_resolved not in resolved.parents: + raise HTTPException(status_code=404, detail="Not found") + except HTTPException: + raise + except Exception: + raise HTTPException(status_code=404, detail="Not found") from None + try: + st = resolved.stat() + etag = f'"{int(st.st_mtime)}-{st.st_size}"' + except OSError: + etag = None + headers = {"Cache-Control": "private, max-age=60"} + if etag: + headers["ETag"] = etag + return FileResponse(str(resolved), media_type="image/jpeg", headers=headers) + + @router.post("/api/sessions/{session_id}/resume", dependencies=[Depends(require_control)]) + async def resume_session(session_id: str): + """Switch the live agent to a different saved session. + + Reuses the same machinery as CLI resume (saves the current session, + loads the target's embryos + conversation). Then nudges all browser + clients to reload so they pick up the new session's state and + transcript. + """ + bridge = getattr(server, "agent_bridge", None) + agent = bridge.agent if bridge is not None else None + if agent is None: + raise HTTPException(status_code=503, detail="Agent not ready") + store = getattr(agent, "store", None) + if store is None or store.get_session(session_id) is None: raise HTTPException(status_code=404, detail="Session not found") + if session_id == getattr(agent, "session_id", None): + return { + "ok": True, + "session_id": session_id, + "active": True, + "note": "already active", + } try: - with open(path, encoding='utf-8') as f: - return json.load(f) + ok = agent.resume_session(session_id) except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to load session: {e}") + logger.exception("Session resume failed") + raise HTTPException(status_code=500, detail=f"resume failed: {e}") from e + if not ok: + raise HTTPException(status_code=500, detail="resume returned false") + # Rehydrate the viz image store from disk so the resumed 
session's + # projections/filmstrips show (pixels load lazily from the FileStore). + rehydrated = 0 + try: + rehydrated = server.rehydrate_session(session_id) + except Exception: + logger.exception("rehydrate_session failed") + # Tell every connected browser to reload — they'll reconnect to the + # new session's state (embryos, transcript, rehydrated imagery). + try: + await server.manager.broadcast({"type": "session_changed", "session_id": session_id}) + except Exception: + pass + return { + "ok": True, + "session_id": session_id, + "active": True, + "rehydrated_projections": rehydrated, + } + + @router.get("/api/sessions/{session_id}") + async def get_session(session_id: str): + """Get session state for review, from the live FileStore. + + Maps the FileStore session snapshot onto the shape the Sessions review + view expects (embryo_states / conversation). detection_history isn't + reconstructed here (per-timepoint predictions live elsewhere). + """ + store = _file_store() + if store is None: + raise HTTPException(status_code=503, detail="Store not available") + info = store.get_session(session_id) + if info is None: + raise HTTPException(status_code=404, detail="Session not found") + snapshot = store.load_session_snapshot(session_id) or {} + experiment = snapshot.get("experiment_data", {}) or {} + return { + "session_id": session_id, + "name": info.get("name") or session_id, + "description": info.get("description", ""), + "created_at": info.get("created_at", ""), + "last_active": info.get("last_active", ""), + "embryo_states": experiment.get("embryos", {}) or {}, + "conversation": snapshot.get("conversation_history", []) or [], + "detection_history": {}, + } return router diff --git a/gently/ui/web/routes/volumes.py b/gently/ui/web/routes/volumes.py index e48475c3..d2c66dac 100644 --- a/gently/ui/web/routes/volumes.py +++ b/gently/ui/web/routes/volumes.py @@ -3,18 +3,18 @@ import base64 import io import logging -from typing import Optional import numpy as np from 
fastapi import APIRouter, HTTPException, Request from fastapi.responses import Response -from ..volume_helpers import load_volume_from_disk, image_to_base64_png +from ..volume_helpers import image_to_base64_png, load_volume_from_disk logger = logging.getLogger(__name__) try: from PIL import Image + PIL_AVAILABLE = True except ImportError: PIL_AVAILABLE = False @@ -31,7 +31,8 @@ async def get_projections(embryo_id: str, timepoint: int, method: str = "all"): Args: embryo_id: Embryo identifier timepoint: Timepoint number (1-indexed) - method: Projection method - 'all', 'three_view', 'dual_view', 'depth_colored', 'multi_slice' + method: Projection method - 'all', 'three_view', 'dual_view', + 'depth_colored', 'multi_slice' Returns: List of projections with method name, description, and base64 PNG data @@ -41,7 +42,10 @@ async def get_projections(embryo_id: str, timepoint: int, method: str = "all"): # Look up volume path (timelapse tracker + FileStore fallback) volume_path = server._resolve_volume_path(embryo_id, timepoint) if not volume_path: - raise HTTPException(status_code=404, detail=f"No volume for {embryo_id} at timepoint {timepoint}") + raise HTTPException( + status_code=404, + detail=f"No volume for {embryo_id} at timepoint {timepoint}", + ) # Load volume from disk try: @@ -52,29 +56,32 @@ async def get_projections(embryo_id: str, timepoint: int, method: str = "all"): vol = (vol - vol.min()) / (vol.max() - vol.min() + 1e-8) except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) + raise HTTPException(status_code=404, detail=str(e)) from e except Exception as e: logger.error(f"Failed to load volume: {e}") - raise HTTPException(status_code=500, detail=f"Failed to load volume: {e}") + raise HTTPException(status_code=500, detail=f"Failed to load volume: {e}") from e PROJECTION_METHODS = { - 'three_view': projection_three_view, + "three_view": projection_three_view, } # Try to import additional projection methods from explorer try: from 
gently.dataset.explorer_server import ( - projection_dual_view, projection_depth_colored, + projection_dual_view, projection_multi_slice, projection_spin_3d, ) - PROJECTION_METHODS.update({ - 'dual_view': projection_dual_view, - 'depth_colored': projection_depth_colored, - 'multi_slice': projection_multi_slice, - 'spin_3d': projection_spin_3d, - }) + + PROJECTION_METHODS.update( + { + "dual_view": projection_dual_view, + "depth_colored": projection_depth_colored, + "multi_slice": projection_multi_slice, + "spin_3d": projection_spin_3d, + } + ) except ImportError: pass # Explorer projections not available @@ -84,22 +91,31 @@ async def get_projections(embryo_id: str, timepoint: int, method: str = "all"): for method_name, method_func in PROJECTION_METHODS.items(): try: proj_img, desc = method_func(vol) - projections.append({ - "method": method_name, - "description": desc, - "data": image_to_base64_png(proj_img), - }) + projections.append( + { + "method": method_name, + "description": desc, + "data": image_to_base64_png(proj_img), + } + ) except Exception as e: logger.warning(f"Projection {method_name} failed: {e}") else: if method not in PROJECTION_METHODS: - raise HTTPException(status_code=400, detail=f"Unknown method: {method}. Available: {list(PROJECTION_METHODS.keys())}") + raise HTTPException( + status_code=400, + detail=( + f"Unknown method: {method}. 
Available: {list(PROJECTION_METHODS.keys())}" + ), + ) proj_img, desc = PROJECTION_METHODS[method](vol) - projections.append({ - "method": method, - "description": desc, - "data": image_to_base64_png(proj_img), - }) + projections.append( + { + "method": method, + "description": desc, + "data": image_to_base64_png(proj_img), + } + ) return { "embryo_id": embryo_id, @@ -120,7 +136,10 @@ async def get_volume_raw(embryo_id: str, timepoint: int): # Look up volume path (timelapse tracker + FileStore fallback) volume_path = server._resolve_volume_path(embryo_id, timepoint) if not volume_path: - raise HTTPException(status_code=404, detail=f"No volume for {embryo_id} at timepoint {timepoint}") + raise HTTPException( + status_code=404, + detail=f"No volume for {embryo_id} at timepoint {timepoint}", + ) try: vol = load_volume_from_disk(volume_path) @@ -137,7 +156,7 @@ async def get_volume_raw(embryo_id: str, timepoint: int): # Encode as base64 vol_bytes = vol_uint8.tobytes() - vol_b64 = base64.b64encode(vol_bytes).decode('utf-8') + vol_b64 = base64.b64encode(vol_bytes).decode("utf-8") # Physical voxel size for isometric 3D rendering. 
# Matches the default in gently.core.imaging.projection_three_view: @@ -155,17 +174,17 @@ async def get_volume_raw(embryo_id: str, timepoint: int): } except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) + raise HTTPException(status_code=404, detail=str(e)) from e except Exception as e: logger.error(f"Failed to load volume: {e}") - raise HTTPException(status_code=500, detail=f"Failed to load volume: {e}") + raise HTTPException(status_code=500, detail=f"Failed to load volume: {e}") from e @router.get("/api/volumes3d") async def list_volumes_3d(): """Get list of 3D volumes (without heavy data)""" return { "volumes_3d": server.store.get_all_volumes_3d(), - "count": len(server.store._volumes_3d) + "count": len(server.store._volumes_3d), } @router.get("/api/volumes3d/{uid}") @@ -188,7 +207,7 @@ async def get_volume_3d_slice(uid: str, z: int): if PIL_AVAILABLE: img = Image.fromarray(rgb) buffer = io.BytesIO() - img.save(buffer, format='PNG') + img.save(buffer, format="PNG") return Response(content=buffer.getvalue(), media_type="image/png") raise HTTPException(status_code=500, detail="PIL not available") @@ -209,14 +228,17 @@ async def get_volume_data_for_3d_viewer(uid: str): volume = np.zeros(volume.shape, dtype=np.uint8) return { "shape": list(volume.shape), - "data": base64.b64encode(volume.tobytes()).decode('utf-8'), - "uid": uid + "data": base64.b64encode(volume.tobytes()).decode("utf-8"), + "uid": uid, } # Check if it's a regular image with stored volume data image = server.store.get_image_by_uid(uid) if image and image.shape and len(image.shape) == 3: - raise HTTPException(status_code=404, detail=f"Volume data for {uid} not available - only segmented volumes supported") + raise HTTPException( + status_code=404, + detail=f"Volume data for {uid} not available - only segmented volumes supported", + ) raise HTTPException(status_code=404, detail=f"Volume {uid} not found") @@ -227,27 +249,25 @@ async def push_volume_3d_http(request: 
Request): data = await request.json() # Decode the volume and masks from base64 - volume_b64 = data.get('volume_b64') - masks_b64 = data.get('masks_b64') - uid = data.get('uid') - shape = data.get('shape') - dtype_vol = data.get('dtype_vol', 'uint16') - dtype_mask = data.get('dtype_mask', 'uint16') - metadata = data.get('metadata', {}) + volume_b64 = data.get("volume_b64") + masks_b64 = data.get("masks_b64") + uid = data.get("uid") + shape = data.get("shape") + dtype_vol = data.get("dtype_vol", "uint16") + dtype_mask = data.get("dtype_mask", "uint16") + metadata = data.get("metadata", {}) if not all([volume_b64, masks_b64, uid, shape]): raise HTTPException(status_code=400, detail="Missing required fields") # Decode arrays - volume = np.frombuffer( - base64.b64decode(volume_b64), - dtype=np.dtype(dtype_vol) - ).reshape(shape) + volume = np.frombuffer(base64.b64decode(volume_b64), dtype=np.dtype(dtype_vol)).reshape( + shape + ) - masks = np.frombuffer( - base64.b64decode(masks_b64), - dtype=np.dtype(dtype_mask) - ).reshape(shape) + masks = np.frombuffer(base64.b64decode(masks_b64), dtype=np.dtype(dtype_mask)).reshape( + shape + ) # Push using the existing method await server.push_volume_3d(volume, masks, uid, metadata) @@ -256,6 +276,6 @@ async def push_volume_3d_http(request: Request): except Exception as e: logger.error(f"Failed to push 3D volume via HTTP: {e}") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail=str(e)) from e return router diff --git a/gently/ui/web/routes/websocket.py b/gently/ui/web/routes/websocket.py index b49518e0..de66552c 100644 --- a/gently/ui/web/routes/websocket.py +++ b/gently/ui/web/routes/websocket.py @@ -11,6 +11,36 @@ logger = logging.getLogger(__name__) +# /ws message types that mutate experiment state (define what gets imaged). +# These are control actions and are gated by role; pure read/presence +# messages stay open so anyone can watch. 
+_MARKING_TYPES = frozenset( + { + "embryo_marked", + "marking_update", + "marking_done", + "marking_redetect", + } +) + + +def _ws_can_control(websocket: WebSocket) -> bool: + """Whether this /ws client may perform control actions (marking). + + Account mode: operators/admins (by session cookie) only. Legacy mode + (no accounts configured): open, preserving prior behavior. + """ + from gently.ui.web.accounts import CONTROL_ROLES, get_account_store + from gently.ui.web.auth import SESSION_COOKIE + + store = get_account_store() + if store is None or not store.has_users(): + return True + token = websocket.cookies.get(SESSION_COOKIE) + user = store.verify_session(token) if token else None + role = store.get_role(user) if user else None + return role in CONTROL_ROLES + def create_router(server) -> APIRouter: router = APIRouter() @@ -23,27 +53,30 @@ async def websocket_endpoint(websocket: WebSocket): try: # Send current status on connect stats = server.store.get_stats() - await websocket.send_json({ - "type": "connected", - **stats, - "timestamp": datetime.now().isoformat() - }) + await websocket.send_json( + {"type": "connected", **stats, "timestamp": datetime.now().isoformat()} + ) # Always send timelapse state on connect so client can reconcile # (if IDLE with no session_id, client will clear stale cached state) timelapse_state = server.timelapse_tracker.to_dict() - await websocket.send_json({ - "type": "timelapse_state", - "data": timelapse_state - }) + # The header's session id is driven by this payload; the tracker's + # session_id goes stale after a resume with no active timelapse, so + # override it with the live agent session (the source of truth). 
+ try: + bridge = getattr(server, "agent_bridge", None) + if bridge is not None and getattr(bridge, "agent", None) is not None: + live_sid = bridge.agent.session_id + if live_sid: + timelapse_state["session_id"] = live_sid + except Exception: + pass + await websocket.send_json({"type": "timelapse_state", "data": timelapse_state}) # Keep connection alive and handle incoming messages while True: try: - data = await asyncio.wait_for( - websocket.receive_text(), - timeout=30.0 - ) + data = await asyncio.wait_for(websocket.receive_text(), timeout=30.0) # Handle client messages (e.g., requests) await _handle_ws_message(server, websocket, data) except asyncio.TimeoutError: @@ -77,41 +110,37 @@ async def _handle_ws_message(server, websocket: WebSocket, message: str): msg_type = data.get("type") embryo_id = data.get("embryo_id") + # Gate control actions (marking) by role; viewing/presence stays open. + if msg_type in _MARKING_TYPES and not _ws_can_control(websocket): + logger.warning("Ignored %s from a view-only /ws client", msg_type) + return + if msg_type == "get_calibration": images = server.store.get_all_calibration(embryo_id) - await websocket.send_json({ - "type": "calibration", - "data": [img.to_dict() for img in images] - }) + await websocket.send_json( + {"type": "calibration", "data": [img.to_dict() for img in images]} + ) elif msg_type == "get_volumes": images = server.store.get_all_volumes(embryo_id) - await websocket.send_json({ - "type": "volumes", - "data": [img.to_dict() for img in images] - }) + await websocket.send_json( + {"type": "volumes", "data": [img.to_dict() for img in images]} + ) elif msg_type == "get_snapshots": images = server.store.get_all_snapshots(embryo_id) - await websocket.send_json({ - "type": "snapshots", - "data": [img.to_dict() for img in images] - }) + await websocket.send_json( + {"type": "snapshots", "data": [img.to_dict() for img in images]} + ) elif msg_type == "get_embryos": - await websocket.send_json({ - "type": "embryos", - 
"data": server.store.get_embryo_ids() - }) + await websocket.send_json({"type": "embryos", "data": server.store.get_embryo_ids()}) elif msg_type == "get_image": uid = data.get("uid") image = server.store.get_image_by_uid(uid) if image: - await websocket.send_json({ - "type": "image", - "data": image.to_dict() - }) + await websocket.send_json({"type": "image", "data": image.to_dict()}) elif msg_type == "pong": pass # Client responding to ping @@ -125,7 +154,8 @@ async def _handle_ws_message(server, websocket: WebSocket, message: str): # Sanitize name: strip HTML tags, limit length if name: import re - name = re.sub(r'<[^>]+>', '', name)[:50] + + name = re.sub(r"<[^>]+>", "", name)[:50] # Update the client's info async with server.manager._lock: if websocket in server.manager.active_connections: @@ -134,7 +164,7 @@ async def _handle_ws_message(server, websocket: WebSocket, message: str): client_id=client_id, name=name or old_info.name, color=server.manager._generate_color(client_id), - connected_at=old_info.connected_at + connected_at=old_info.connected_at, ) await server.manager.broadcast_presence() @@ -144,7 +174,8 @@ async def _handle_ws_message(server, websocket: WebSocket, message: str): if name: # Sanitize name: strip HTML tags, limit length import re - name = re.sub(r'<[^>]+>', '', name)[:50] + + name = re.sub(r"<[^>]+>", "", name)[:50] await server.manager.update_client_name(websocket, name) elif msg_type == "get_presence": @@ -155,16 +186,19 @@ async def _handle_ws_message(server, websocket: WebSocket, message: str): elif msg_type == "embryo_marked": session_id = data.get("session_id") marker = data.get("marker") - if session_id and marker and hasattr(server, '_marking_sessions'): + if session_id and marker and hasattr(server, "_marking_sessions"): session = server._marking_sessions.get(session_id) if session: session["markers"].append(marker) - logger.info(f"Embryo marked: #{marker['number']} at ({marker['pixelX']}, {marker['pixelY']})") + logger.info( + 
f"Embryo marked: #{marker['number']}" + f" at ({marker['pixelX']}, {marker['pixelY']})" + ) elif msg_type == "marking_update": session_id = data.get("session_id") markers = data.get("markers", []) - if session_id and hasattr(server, '_marking_sessions'): + if session_id and hasattr(server, "_marking_sessions"): session = server._marking_sessions.get(session_id) if session: session["markers"] = markers @@ -173,7 +207,7 @@ async def _handle_ws_message(server, websocket: WebSocket, message: str): elif msg_type == "marking_done": session_id = data.get("session_id") markers = data.get("markers", []) - if session_id and hasattr(server, '_marking_sessions'): + if session_id and hasattr(server, "_marking_sessions"): session = server._marking_sessions.get(session_id) if session: session["markers"] = markers @@ -183,8 +217,7 @@ async def _handle_ws_message(server, websocket: WebSocket, message: str): r = m.get("role", "test") role_summary[r] = role_summary.get(r, 0) + 1 logger.info( - f"Marking complete: {len(markers)} embryo(s) " - f"(roles: {role_summary})" + f"Marking complete: {len(markers)} embryo(s) (roles: {role_summary})" ) elif msg_type == "marking_redetect": @@ -194,13 +227,14 @@ async def _handle_ws_message(server, websocket: WebSocket, message: str): # listen for. Once recapture lands, the agent calls # start_marking_session again with the new image + markers. 
session_id = data.get("session_id") - if session_id and hasattr(server, '_marking_sessions'): + if session_id and hasattr(server, "_marking_sessions"): session = server._marking_sessions.get(session_id) if session is not None: session["redetect_requested"] = True logger.info(f"Marking redetect requested for session {session_id}") try: from gently.core import EventType, get_event_bus + get_event_bus().publish( event_type=EventType.STATUS_CHANGED, data={ diff --git a/gently/ui/web/server.py b/gently/ui/web/server.py index ff14f9bb..c16c548d 100644 --- a/gently/ui/web/server.py +++ b/gently/ui/web/server.py @@ -24,7 +24,6 @@ import sys from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Optional import numpy as np @@ -35,11 +34,12 @@ # Optional imports try: + import uvicorn from fastapi import FastAPI + from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates - from fastapi.middleware.cors import CORSMiddleware - import uvicorn + FASTAPI_AVAILABLE = True except ImportError: FASTAPI_AVAILABLE = False @@ -60,8 +60,7 @@ class _InvalidHttpFilter(logging.Filter): def filter(self, record: logging.LogRecord) -> bool: msg = record.getMessage() if "Invalid HTTP request received" in msg: - logger.debug("uvicorn dropped non-HTTP bytes on viz port " - "(probable TLS/peer mismatch)") + logger.debug("uvicorn dropped non-HTTP bytes on viz port (probable TLS/peer mismatch)") return False return True @@ -76,25 +75,39 @@ def filter(self, record: logging.LogRecord) -> bool: try: from PIL import Image + PIL_AVAILABLE = True except ImportError: PIL_AVAILABLE = False # Import data models and components -from .models import ( - ClientInfo, Volume3DData, ImageData, EmbryoImageCache, - CALIBRATION_TYPES, VOLUME_TYPES, ANALYSIS_TYPES, VOLUME_3D_TYPES, +from .connection_manager import ConnectionManager # noqa: E402 +from .image_store import ImageStore # noqa: E402 +from 
.models import ( # noqa: E402 + ANALYSIS_TYPES, + CALIBRATION_TYPES, + VOLUME_3D_TYPES, + VOLUME_TYPES, + ClientInfo, + EmbryoImageCache, + ImageData, + Volume3DData, ) -from .image_store import ImageStore -from .timelapse_tracker import TimelapseStateTracker -from .connection_manager import ConnectionManager +from .timelapse_tracker import TimelapseStateTracker # noqa: E402 # Re-export for backward compatibility __all__ = [ - 'VisualizationServer', 'create_visualization_server', - 'ClientInfo', 'Volume3DData', 'ImageData', 'EmbryoImageCache', - 'CALIBRATION_TYPES', 'VOLUME_TYPES', 'ANALYSIS_TYPES', 'VOLUME_3D_TYPES', - 'ImageStore', + "VisualizationServer", + "create_visualization_server", + "ClientInfo", + "Volume3DData", + "ImageData", + "EmbryoImageCache", + "CALIBRATION_TYPES", + "VOLUME_TYPES", + "ANALYSIS_TYPES", + "VOLUME_3D_TYPES", + "ImageStore", ] @@ -130,8 +143,8 @@ def __init__( event_bus=None, sessions_dir: str = str(settings.storage.sessions_dir), gently_store=None, - ssl_certfile: str = None, - ssl_keyfile: str = None, + ssl_certfile: str | None = None, + ssl_keyfile: str | None = None, ): super().__init__(name="visualization", service_type="http", host=host, port=port) if not FASTAPI_AVAILABLE: @@ -161,7 +174,7 @@ def __init__( self.app = FastAPI( title="Gently Visualization Server", description="Real-time microscopy visualization", - version="2.0.0" + version="2.0.0", ) # Setup templates and static files @@ -179,6 +192,7 @@ def __init__( # Register route groups from .routes import register_all_routes + register_all_routes(self) # Subscribe to events if event bus provided @@ -193,7 +207,7 @@ def set_context_store(self, context_store) -> None: """Set the FileContextStore for campaign/plan data access.""" self.context_store = context_store - def _resolve_volume_path(self, embryo_id: str, timepoint: int) -> Optional[str]: + def _resolve_volume_path(self, embryo_id: str, timepoint: int) -> str | None: """Resolve volume file path from timelapse tracker 
or FileStore.""" # 1. Try timelapse tracker (in-memory, fastest) if embryo_id in self.timelapse_tracker.volume_paths: @@ -201,11 +215,17 @@ def _resolve_volume_path(self, embryo_id: str, timepoint: int) -> Optional[str]: if path: return path - # 2. Try FileStore (file-based, persistent) - if self.gently_store and self.timelapse_tracker.session_id: + # 2. Try FileStore (file-based, persistent). Key on the LIVE agent + # session, not the tracker's (which goes stale after a resume with no + # active timelapse) — mirrors _resolve_projection_path so an agent-driven + # open_volume hand-off doesn't 404 after a /resume. + sid = self._current_session_id() + if self.gently_store and sid: try: vol_path = self.gently_store.get_volume_path( - self.timelapse_tracker.session_id, embryo_id, timepoint, + sid, + embryo_id, + timepoint, ) if vol_path and vol_path.exists(): return str(vol_path) @@ -214,12 +234,26 @@ def _resolve_volume_path(self, embryo_id: str, timepoint: int) -> Optional[str]: return None - def _resolve_projection_path(self, embryo_id: str, timepoint: int) -> Optional[Path]: - """Resolve projection file path from FileStore.""" - if self.gently_store and self.timelapse_tracker.session_id: + def _current_session_id(self) -> str | None: + """The live agent session (source of truth), falling back to the + timelapse tracker. 
The tracker's session_id goes stale after a resume + with no active timelapse, so the live agent session is preferred.""" + bridge = getattr(self, "agent_bridge", None) + if bridge is not None and getattr(bridge, "agent", None) is not None: + sid = getattr(bridge.agent, "session_id", None) + if sid: + return sid + return self.timelapse_tracker.session_id + + def _resolve_projection_path(self, embryo_id: str, timepoint: int) -> Path | None: + """Resolve projection file path from FileStore (current session).""" + sid = self._current_session_id() + if self.gently_store and sid: try: proj_path = self.gently_store.get_projection_path( - self.timelapse_tracker.session_id, embryo_id, timepoint, + sid, + embryo_id, + timepoint, ) if proj_path and proj_path.exists(): return proj_path @@ -227,6 +261,109 @@ def _resolve_projection_path(self, embryo_id: str, timepoint: int) -> Optional[P logger.debug(f"FileStore projection path lookup failed: {e}") return None + def rehydrate_session(self, session_id: str) -> int: + """Repopulate the in-memory image store with the FileStore's persisted + projections for a (resumed) session, so galleries and filmstrips show + its historical data. + + Lightweight: only metadata-bearing ImageData entries are created (uid + ``volume_{embryo}_t{NNNN}``); the JPEG pixels load lazily on demand via + /api/images/{uid}/png (which falls back to the FileStore projection). + Resets the store first so the previous session's images don't linger. + Returns the number of projection entries added. 
+ """ + if self.gently_store is None or not session_id: + return 0 + self.store = ImageStore() # drop the previous session's images + added = 0 + try: + embryos = self.gently_store.list_embryos(session_id) or [] + except Exception: + embryos = [] + for emb in embryos: + eid = emb.get("embryo_id") if isinstance(emb, dict) else getattr(emb, "embryo_id", None) + if not eid: + continue + try: + tps = self.gently_store.list_projection_timepoints(session_id, eid) + except Exception: + tps = [] + for tp in tps: + self.store.add_image( + ImageData( + uid=f"volume_{eid}_t{tp:04d}", + data_type="volume_projection", + timestamp=f"{tp:06d}", # monotonic with timepoint for ordering + metadata={"embryo_id": eid, "timepoint": tp}, + ) + ) + added += 1 + + # Rehydrate the timelapse tracker's per-embryo perception state from + # predictions.jsonl so the Default / Film / reasoning views populate + # (those are driven by detection_reasoning, not the raw image store). + # Thumbnails resolve via the projection uids added above. 
+ tracker = self.timelapse_tracker + try: + tracker.session_id = session_id + tracker.detection_reasoning = {} + tracker.projection_uids = {} + for emb in embryos: + eid = ( + emb.get("embryo_id") + if isinstance(emb, dict) + else getattr(emb, "embryo_id", None) + ) + if not eid: + continue + try: + preds = self.gently_store.get_predictions(session_id, eid) or [] + except Exception: + preds = [] + if not preds: + continue + items, puids, last_stage = [], {}, None + for p in preds: + tp = p.get("timepoint") + if tp is None: + continue + uid = f"volume_{eid}_t{tp:04d}" + puids[tp] = uid + stage = p.get("predicted_stage") + last_stage = stage or last_stage + items.append( + { + "timepoint": tp, + "stage": stage, + "detected_stage": stage, + "reasoning": p.get("reasoning"), + "confidence": p.get("confidence"), + "projection_uid": uid, + "image_uid": uid, + "detector_name": "perception", + } + ) + tracker.detection_reasoning[eid] = items + tracker.projection_uids[eid] = puids + entry = tracker.embryos.setdefault( + eid, + { + "embryo_id": eid, + "timepoints": 0, + "is_complete": False, + "detections": {}, + "current_stage": None, + }, + ) + entry["timepoints"] = max((it["timepoint"] for it in items), default=0) + entry["current_stage"] = last_stage + tracker.total_timepoints = sum(len(v) for v in tracker.detection_reasoning.values()) + except Exception: + logger.exception("Tracker perception rehydration failed") + + logger.info("Rehydrated %d projections for session %s", added, session_id) + return added + def _subscribe_to_events(self): """Subscribe to EventBus for automatic updates - broadcasts ALL events""" @@ -236,7 +373,11 @@ def _subscribe_to_events(self): async def on_event_async(event): """Async handler for all events - broadcasts to WebSocket clients""" - event_type_str = event.event_type.name if hasattr(event.event_type, 'name') else str(event.event_type) + event_type_str = ( + event.event_type.name + if hasattr(event.event_type, "name") + else 
str(event.event_type) + ) # Update timelapse state tracker self.timelapse_tracker.handle_event(event_type_str, event.data) @@ -246,15 +387,17 @@ async def on_event_async(event): event_type=event_type_str, data=event.data, source=event.source, - event_id=event.event_id + event_id=event.event_id, ) # For session events, also broadcast updated timelapse_state so clients can sync if event_type_str in ("SESSION_STARTED", "SESSION_RESTORED"): - await self.manager.broadcast({ - "type": "timelapse_state", - "data": self.timelapse_tracker.to_dict() - }) + await self.manager.broadcast( + { + "type": "timelapse_state", + "data": self.timelapse_tracker.to_dict(), + } + ) # Subscribe to ALL events using wildcard with async handler self.event_bus.subscribe_async("*", on_event_async) @@ -272,11 +415,19 @@ def _init_from_event_history(self): # Process events in chronological order (history is newest-first) for event in reversed(history): - event_type_str = event.event_type.name if hasattr(event.event_type, 'name') else str(event.event_type) + event_type_str = ( + event.event_type.name + if hasattr(event.event_type, "name") + else str(event.event_type) + ) self.timelapse_tracker.handle_event(event_type_str, event.data) if self.timelapse_tracker.session_id: - logger.info(f"Initialized timelapse state from history: session={self.timelapse_tracker.session_id}, status={self.timelapse_tracker.status}") + logger.info( + f"Initialized timelapse state from history:" + f" session={self.timelapse_tracker.session_id}," + f" status={self.timelapse_tracker.status}" + ) except Exception as e: logger.warning(f"Failed to initialize from event history: {e}") @@ -285,13 +436,13 @@ def _array_to_image_data( array: np.ndarray, uid: str, data_type: str, - metadata: Optional[Dict] = None + metadata: dict | None = None, ) -> ImageData: """Convert numpy array to ImageData with base64 PNG""" from gently.core.imaging import ( - projection_three_view, - compute_crop_bounds, apply_crop_bounds, + 
compute_crop_bounds, + projection_three_view, ) # Handle 4D arrays (Views, Z, Y, X) - select View A only @@ -308,7 +459,7 @@ def _array_to_image_data( z_depth, height, width = array.shape # Handle dual-view format (width > 2*height) if width > height * 2: - array = array[:, :, :width // 2] + array = array[:, :, : width // 2] # Auto-crop to embryo region bounds = compute_crop_bounds(array) array = apply_crop_bounds(array, bounds) @@ -328,8 +479,8 @@ def _array_to_image_data( if PIL_AVAILABLE: img = Image.fromarray(array) buffer = io.BytesIO() - img.save(buffer, format='PNG') - base64_png = base64.b64encode(buffer.getvalue()).decode('utf-8') + img.save(buffer, format="PNG") + base64_png = base64.b64encode(buffer.getvalue()).decode("utf-8") return ImageData( uid=uid, @@ -337,7 +488,7 @@ def _array_to_image_data( timestamp=datetime.now().isoformat(), metadata=metadata or {}, base64_png=base64_png, - shape=array.shape + shape=array.shape, ) async def push_image( @@ -345,7 +496,7 @@ async def push_image( array: np.ndarray, uid: str, data_type: str = "image", - metadata: Optional[Dict] = None, + metadata: dict | None = None, ): """ Push an image to connected clients @@ -369,14 +520,16 @@ async def push_image( # Broadcast to clients await self.manager.send_image(image_data) - logger.debug(f"Pushed image {uid} ({data_type}) to {len(self.manager.active_connections)} clients") + logger.debug( + f"Pushed image {uid} ({data_type}) to {len(self.manager.active_connections)} clients" + ) async def start_marking_session( self, image: np.ndarray, initial_stage_position: tuple = (0.0, 0.0), pixel_size_um: float = 0.65, - initial_markers: Optional[list] = None, + initial_markers: list | None = None, default_role: str = "test", ) -> str: """ @@ -412,7 +565,7 @@ async def start_marking_session( """ import uuid - if not hasattr(self, '_marking_sessions'): + if not hasattr(self, "_marking_sessions"): self._marking_sessions = {} session_id = str(uuid.uuid4())[:8] @@ -424,15 +577,17 @@ 
async def start_marking_session( py = m.get("pixel_y", m.get("pixelY")) if px is None or py is None: continue - normalized.append({ - "number": i + 1, - "pixelX": round(float(px), 1), - "pixelY": round(float(py), 1), - "role": m.get("role", default_role), - "source": m.get("source", "sam"), - "embryo_id": m.get("embryo_id"), - "confidence": m.get("confidence"), - }) + normalized.append( + { + "number": i + 1, + "pixelX": round(float(px), 1), + "pixelY": round(float(py), 1), + "role": m.get("role", default_role), + "source": m.get("source", "sam"), + "embryo_id": m.get("embryo_id"), + "confidence": m.get("confidence"), + } + ) self._marking_sessions[session_id] = { "markers": list(normalized), @@ -445,31 +600,34 @@ async def start_marking_session( # Encode image as base64 PNG from PIL import Image as PILImage + img = image if img.dtype != np.uint8: img = ((img - img.min()) / max(img.max() - img.min(), 1) * 255).astype(np.uint8) pil_img = PILImage.fromarray(img) buf = io.BytesIO() - pil_img.save(buf, format='PNG') - b64 = base64.b64encode(buf.getvalue()).decode('ascii') + pil_img.save(buf, format="PNG") + b64 = base64.b64encode(buf.getvalue()).decode("ascii") h, w = image.shape[:2] # Broadcast to all clients - await self.manager.broadcast({ - "type": "marking_image", - "data": { - "session_id": session_id, - "image_b64": b64, - "width": w, - "height": h, - "initial_markers": normalized, - "default_role": default_role, - "stage_x_um": float(initial_stage_position[0]), - "stage_y_um": float(initial_stage_position[1]), - "pixel_size_um": pixel_size_um, + await self.manager.broadcast( + { + "type": "marking_image", + "data": { + "session_id": session_id, + "image_b64": b64, + "width": w, + "height": h, + "initial_markers": normalized, + "default_role": default_role, + "stage_x_um": float(initial_stage_position[0]), + "stage_y_um": float(initial_stage_position[1]), + "pixel_size_um": pixel_size_um, + }, } - }) + ) logger.info( f"Marking session {session_id} started, image 
{w}x{h}, " @@ -478,7 +636,7 @@ async def start_marking_session( ) return session_id - async def wait_for_marking(self, session_id: str, timeout: float = None) -> list: + async def wait_for_marking(self, session_id: str, timeout: float | None = None) -> list: """ Wait for a marking session to complete. @@ -505,9 +663,9 @@ async def wait_for_marking(self, session_id: str, timeout: float = None) -> list markers = session["markers"] initial_pos = session["initial_stage_position"] - pixel_size = session["pixel_size_um"] + session["pixel_size_um"] h, w = session["image_shape"][:2] - center_x, center_y = w / 2, h / 2 + _center_x, _center_y = w / 2, h / 2 # Convert to embryo entries. Carries role + source so callers can # register each embryo with the right experimental classification. @@ -515,18 +673,20 @@ async def wait_for_marking(self, session_id: str, timeout: float = None) -> list embryos = [] for m in markers: px, py = m["pixelX"], m["pixelY"] - embryos.append({ - "embryo_number": m["number"], - "embryo_id": m.get("embryo_id") or f"embryo_{m['number']:03d}", - "pixel_position": (px, py), - "pixel_x": px, - "pixel_y": py, - "initial_stage_position": initial_pos, - "role": m.get("role", default_role), - "source": m.get("source", "manual"), - "confidence": m.get("confidence"), - "marking_timestamp": m.get("timestamp", datetime.now().isoformat()), - }) + embryos.append( + { + "embryo_number": m["number"], + "embryo_id": m.get("embryo_id") or f"embryo_{m['number']:03d}", + "pixel_position": (px, py), + "pixel_x": px, + "pixel_y": py, + "initial_stage_position": initial_pos, + "role": m.get("role", default_role), + "source": m.get("source", "manual"), + "confidence": m.get("confidence"), + "marking_timestamp": m.get("timestamp", datetime.now().isoformat()), + } + ) # Clean up del self._marking_sessions[session_id] @@ -538,7 +698,7 @@ async def push_volume_3d( volume: np.ndarray, masks: np.ndarray, uid: str, - metadata: Optional[Dict] = None, + metadata: dict | None = 
None, ): """ Push a 3D segmentation volume to connected clients @@ -562,24 +722,54 @@ async def push_volume_3d( volume_data = Volume3DData( uid=uid, - data_type='segmentation_3d', + data_type="segmentation_3d", timestamp=datetime.now().isoformat(), volume=volume, masks=masks, colors=colors, - metadata=metadata or {} + metadata=metadata or {}, ) # Store the 3D volume self.store.add_volume_3d(volume_data) # Broadcast notification to clients (without the heavy data) - await self.manager.broadcast({ - 'type': 'volume_3d', - 'data': volume_data.to_info_dict() - }) + await self.manager.broadcast({"type": "volume_3d", "data": volume_data.to_info_dict()}) - logger.info(f"Pushed 3D volume {uid} ({volume.shape}) to {len(self.manager.active_connections)} clients") + logger.info( + f"Pushed 3D volume {uid} ({volume.shape}) to" + f" {len(self.manager.active_connections)} clients" + ) + + async def open_volume_in_browser( + self, + embryo_id: str, + timepoint: int, + view: str = "3d_viewer", + ) -> int: + """Ask every connected browser to open the in-browser volume viewer. + + This is the web-native replacement for the old napari ``view_volume``: + the agent triggers the existing ProjectionViewer (WebGL raymarcher + + projections) instead of launching a desktop Qt window that would block + the shared agent/web event loop. Returns the number of clients notified. + """ + await self.manager.broadcast( + { + "type": "open_volume", + "embryo_id": embryo_id, + "timepoint": timepoint, + "view": view, + } + ) + n = len(self.manager.active_connections) + logger.info( + "Requested browser open_volume for %s t%s (%d client(s))", + embryo_id, + timepoint, + n, + ) + return n async def on_start(self): """Start the visualization server""" @@ -592,6 +782,7 @@ async def on_start(self): # off to uvicorn (whose bind error surfaces inside a background # task and produces an unhelpful log line). 
import socket + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: sock.bind((self.host, self.port)) @@ -600,7 +791,7 @@ async def on_start(self): f"Port {self.port} is already in use. " "Is another instance of the agent running? " "Close it first and try again." - ) + ) from None finally: sock.close() @@ -662,10 +853,10 @@ async def on_stop(self): self._server_task = None logger.info("Visualization server stopped") - async def health_check(self) -> Dict: + async def health_check(self) -> dict: """Return health status with connected client count.""" base = await super().health_check() - base['connected_clients'] = len(self.manager.active_connections) + base["connected_clients"] = len(self.manager.active_connections) return base async def run_forever(self): @@ -682,20 +873,20 @@ def signal_handler(*args): loop = asyncio.get_running_loop() signals_installed = False - if hasattr(signal, 'SIGINT'): + if hasattr(signal, "SIGINT"): try: loop.add_signal_handler(signal.SIGINT, signal_handler) signals_installed = True except NotImplementedError: pass - if hasattr(signal, 'SIGTERM'): + if hasattr(signal, "SIGTERM"): try: loop.add_signal_handler(signal.SIGTERM, signal_handler) except NotImplementedError: pass - if sys.platform == 'win32' and not signals_installed: + if sys.platform == "win32" and not signals_installed: signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) @@ -704,21 +895,17 @@ def signal_handler(*args): logger.info(f"Server running at http://{self.host}:{self.port} - Press Ctrl+C to stop") try: - if sys.platform == 'win32': + if sys.platform == "win32": while not stop_event.is_set(): try: - await asyncio.wait_for( - asyncio.shield(self._server_task), - timeout=0.5 - ) + await asyncio.wait_for(asyncio.shield(self._server_task), timeout=0.5) break except asyncio.TimeoutError: continue else: stop_task = asyncio.create_task(stop_event.wait()) done, pending = await asyncio.wait( - [self._server_task, stop_task], - 
return_when=asyncio.FIRST_COMPLETED + [self._server_task, stop_task], return_when=asyncio.FIRST_COMPLETED ) for task in pending: task.cancel() diff --git a/gently/ui/web/static/css/agent-chat.css b/gently/ui/web/static/css/agent-chat.css new file mode 100644 index 00000000..fdaaa8e3 --- /dev/null +++ b/gently/ui/web/static/css/agent-chat.css @@ -0,0 +1,443 @@ +/* Floating agent-chat window — the web-side control surface. + Restrained, professional styling for a lab instrument. */ + +/* ── Header toggle (replaces the floating FAB) ─────────────── */ +.header-agent-toggle { + display: inline-flex; align-items: center; gap: 7px; + padding: 5px 10px; border-radius: 8px; + border: 1px solid var(--border); + background: var(--bg-hover); color: var(--text); + font: 500 12.5px/1 'Inter Tight', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; + cursor: pointer; position: relative; + transition: border-color 0.15s ease, background 0.15s ease, color 0.15s ease; +} +.header-agent-toggle:hover { border-color: var(--accent); } +.header-agent-toggle[aria-pressed="true"] { + border-color: var(--accent); color: var(--accent); + background: rgba(96, 165, 250, 0.12); +} +.header-agent-toggle svg { display: block; } +.header-agent-label { letter-spacing: 0.01em; } +.header-agent-dot { + width: 7px; height: 7px; border-radius: 50%; + background: var(--text-muted); flex: 0 0 auto; +} +.header-agent-dot.ok { background: var(--accent-green); } +.header-agent-badge { + position: absolute; top: -6px; right: -6px; + min-width: 16px; height: 16px; padding: 0 4px; + border-radius: 999px; background: var(--accent-purple); color: #fff; + font-size: 10px; font-weight: 700; line-height: 16px; text-align: center; +} +.header-agent-badge.hidden { display: none; } + +/* ── Docked agent panel ──────────────────────────────────── + Default = overlay slide-over, absolutely positioned inside .app-shell (which + sits below the global header/navbar). 
Pin (body.chat-docked) turns it into a + real column that pushes .app-main. */ +.agent-chat { + position: absolute; + top: 0; right: 0; bottom: 0; + width: var(--chat-w, 460px); + max-width: 92vw; + display: flex; + flex-direction: column; + background: var(--bg-card); + border-left: 1px solid var(--border); + box-shadow: -16px 0 40px -16px var(--panel-edge-shadow); + z-index: 50; + overflow: hidden; + transform: translateX(100%); + transition: transform 0.22s cubic-bezier(0.22, 1, 0.36, 1); + will-change: transform; + font-family: 'Inter Tight', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; +} +.agent-chat.open { transform: translateX(0); } + +/* Pinned: a real pushing column — no float shadow, just a seam. */ +body.chat-docked .agent-chat { + position: relative; + transform: none; + box-shadow: none; + border-left: 1px solid var(--border-strong); + flex: 0 0 auto; + transition: none; + z-index: auto; +} +body.chat-docked .agent-chat:not(.open) { + width: 0; border-left: none; overflow: hidden; +} + +@media (prefers-reduced-motion: reduce) { + .agent-chat { transition: opacity 0.12s ease; } +} + +/* Left-edge resize handle (thin seam, generous hit area). */ +.agent-chat-resize { + position: absolute; left: -2px; top: 0; bottom: 0; width: 6px; + cursor: ew-resize; z-index: 3; +} +.agent-chat-resize::after { + content: ''; position: absolute; left: 2px; top: 0; bottom: 0; width: 1px; + background: transparent; transition: background 0.12s ease; +} +.agent-chat-resize:hover::after, .agent-chat-resize.dragging::after { background: var(--accent); } + +.agent-control-banner.hidden { display: none; } + +/* Pin button in the panel header. 
*/ +.agent-chat-pin { + background: none; border: none; color: var(--text-muted); + cursor: pointer; padding: 2px; display: flex; align-items: center; + border-radius: 5px; +} +.agent-chat-pin:hover { color: var(--text); background: var(--bg-hover); } +.agent-chat-pin[aria-pressed="true"] { color: var(--accent); } + +/* ── Header ─────────────────────────────────────────────── */ +.agent-chat-header { + display: flex; + align-items: center; + gap: 10px; + padding: 12px 14px; + border-bottom: 1px solid var(--border); +} +.agent-chat-id { display: flex; align-items: center; gap: 9px; } +.agent-chat-mark { display: block; } +.agent-chat-title { font-weight: 600; font-size: 14px; color: var(--text); letter-spacing: 0.01em; } +.agent-chat-user { + font-size: 11px; color: var(--text-muted); + padding-left: 7px; margin-left: 1px; + border-left: 1px solid var(--border); +} +.agent-chat-user:empty { display: none; } +.agent-chat-signout { + background: none; border: none; + color: var(--text-muted); font-size: 11px; cursor: pointer; + padding: 0 4px; font-family: inherit; +} +.agent-chat-signout:hover { color: var(--text); text-decoration: underline; } + +.agent-chat-conn { + margin-left: auto; + font-size: 11px; + font-weight: 500; + padding: 3px 9px; + border-radius: 999px; + border: 1px solid var(--border); + color: var(--text-muted); + white-space: nowrap; +} +.agent-chat-conn.ac-conn-ok { color: var(--accent-green); border-color: rgba(74, 222, 128, 0.35); } +.agent-chat-conn.ac-conn-bad { color: var(--accent-orange, #fb923c); border-color: rgba(251, 146, 60, 0.35); } +.agent-chat-close { + background: none; border: none; + color: var(--text-muted); font-size: 20px; line-height: 1; + cursor: pointer; padding: 0 2px; +} +.agent-chat-close:hover { color: var(--text); } + +/* ── Control banner ─────────────────────────────────────── */ +.agent-control-banner { + display: flex; align-items: center; gap: 10px; + padding: 9px 14px; + background: rgba(251, 146, 60, 0.10); + 
border-bottom: 1px solid var(--border); + color: var(--accent-orange, #fb923c); + font-size: 12.5px; +} +.ac-take-control { + margin-left: auto; + padding: 4px 12px; border-radius: 7px; + border: 1px solid var(--accent); + background: var(--accent); color: #fff; + cursor: pointer; font-size: 12px; font-weight: 600; +} +.ac-take-control:hover { background: var(--accent-hover); } + +/* ── Transcript ─────────────────────────────────────────── */ +.agent-chat-log { + flex: 1 1 auto; + overflow-y: auto; + padding: 16px; + display: flex; flex-direction: column; gap: 14px; + font-size: 13.5px; line-height: 1.6; + color: var(--text); +} + +.ac-turn { display: flex; flex-direction: column; } +.ac-role { + font-size: 10.5px; font-weight: 600; + letter-spacing: 0.06em; text-transform: uppercase; + color: var(--accent-purple); + margin-bottom: 4px; +} +.ac-turn-agent .ac-content { color: var(--text); } +.ac-turn-agent .ac-content code { + font-family: 'JetBrains Mono', ui-monospace, monospace; + font-size: 12px; + background: rgba(127, 127, 127, 0.14); + padding: 1px 5px; border-radius: 4px; +} + +/* User: right-aligned, subtle accent block (not a loud bubble) */ +.ac-turn-user { align-items: flex-end; } +.ac-turn-user .ac-content { + background: rgba(96, 165, 250, 0.12); + border: 1px solid rgba(96, 165, 250, 0.22); + color: var(--text); + padding: 7px 11px; + border-radius: 10px 10px 2px 10px; + max-width: 88%; + white-space: pre-wrap; word-wrap: break-word; +} + +/* ── Autonomous (wake) turns ────────────────────────────── */ +.ac-autonomous-banner { + display: flex; align-items: center; gap: 8px; + align-self: stretch; + margin: 2px 0; + padding: 6px 10px; + font-size: 11.5px; font-weight: 500; + color: var(--accent-purple); + background: rgba(167, 139, 250, 0.10); + border: 1px solid rgba(167, 139, 250, 0.28); + border-radius: 8px; +} +.ac-autonomous-dot { + width: 7px; height: 7px; border-radius: 50%; + background: var(--accent-purple); + box-shadow: 0 0 0 3px rgba(167, 
139, 250, 0.20); + flex: 0 0 auto; +} +/* Autonomous agent bubbles get an accent rail + a distinct role label. */ +.ac-turn-autonomous { border-left: 2px solid rgba(167, 139, 250, 0.45); padding-left: 8px; } +.ac-turn-autonomous .ac-role { color: var(--accent-purple); } + +/* ── Activity indicator ─────────────────────────────────── */ +.ac-activity { + display: flex; align-items: center; gap: 9px; + color: var(--text-muted); font-size: 12.5px; +} +.ac-dots { display: inline-flex; gap: 4px; } +.ac-dots i { + width: 5px; height: 5px; border-radius: 50%; + background: var(--accent); + display: inline-block; + animation: ac-blink 1.2s infinite both; +} +.ac-dots i:nth-child(2) { animation-delay: 0.18s; } +.ac-dots i:nth-child(3) { animation-delay: 0.36s; } +@keyframes ac-blink { 0%, 80%, 100% { opacity: 0.22; } 40% { opacity: 1; } } + +/* ── Tool calls ─────────────────────────────────────────── */ +.ac-tool { + display: flex; align-items: center; gap: 8px; + font-family: 'JetBrains Mono', ui-monospace, monospace; + font-size: 11.5px; + color: var(--text-muted); + padding: 6px 10px; + border: 1px solid var(--border); + border-radius: 8px; + background: rgba(127, 127, 127, 0.05); +} +.ac-tool-name { color: var(--text); } +.ac-tool-meta { color: var(--text-muted); } +.ac-tool-check { color: var(--accent-green); } + +/* Multi-line tool rows: head (icon + name + meta) over args / summary. 
*/ +.ac-tool { flex-direction: column; align-items: stretch; gap: 4px; } +.ac-tool-head { display: flex; align-items: center; gap: 8px; } +.ac-tool-args { color: var(--text-muted); padding-left: 19px; word-break: break-word; } +.ac-tool-summary { color: var(--text-muted); padding-left: 19px; word-break: break-word; } +.ac-tool-summary-err, .ac-tool-warn { color: var(--accent-orange, #fb923c); } +.ac-tool-err { border-color: rgba(251, 146, 60, 0.35); } +.ac-tool-spin { + width: 11px; height: 11px; border-radius: 50%; + border: 1.6px solid var(--border); + border-top-color: var(--accent); + display: inline-block; + animation: ac-spin 0.7s linear infinite; +} +@keyframes ac-spin { to { transform: rotate(360deg); } } + +/* ── System lines / notifications ───────────────────────── */ +.ac-system { + align-self: center; + font-size: 11.5px; color: var(--text-muted); + text-align: center; max-width: 95%; +} +.ac-level-error { color: var(--color-danger, #f87171); } +.ac-level-warning { color: var(--accent-orange, #fb923c); } +.ac-level-success { color: var(--accent-green); } + +/* ── Choice picker ──────────────────────────────────────── */ +.ac-choice { + display: flex; flex-direction: column; gap: 7px; + padding: 12px; + border: 1px solid var(--border); + border-radius: 10px; + background: rgba(127, 127, 127, 0.04); +} +.ac-choice-q { color: var(--text); font-weight: 500; } +.ac-choice-opt { + text-align: left; + padding: 9px 12px; border-radius: 8px; + border: 1px solid var(--border); + background: var(--bg-card); color: var(--text); + cursor: pointer; + display: flex; flex-direction: column; gap: 2px; + transition: border-color 0.12s ease, background 0.12s ease; +} +.ac-choice-opt:hover:not(:disabled) { border-color: var(--accent); background: var(--bg-hover); } +.ac-choice-opt:disabled { opacity: 0.5; cursor: default; } +.ac-choice-label { font-weight: 600; font-size: 13px; } +.ac-choice-desc { font-size: 12px; color: var(--text-muted); } +.ac-choice-picked { 
border-color: var(--accent-green); background: rgba(74, 222, 128, 0.08); } +.ac-choice-wake { + border-color: rgba(167, 139, 250, 0.45); + border-left: 3px solid var(--accent-purple); + background: rgba(167, 139, 250, 0.06); +} +.ac-choice-origin { + font-size: 10.5px; font-weight: 600; letter-spacing: 0.04em; text-transform: uppercase; + color: var(--accent-purple); margin-bottom: 2px; +} + +/* Sticky ASK-approval slot: pinned above the composer so it never scrolls away. */ +.ac-pending { + flex: 0 0 auto; + border-top: 1px solid var(--border); + background: var(--bg-card); + padding: 8px 12px 0; +} +.ac-pending.hidden { display: none; } +.ac-pending .ac-choice { margin-bottom: 8px; } + +/* "↓ N new" jump-to-bottom pill (shown when scrolled up during streaming). */ +.ac-jump { + position: absolute; + left: 50%; transform: translateX(-50%); + bottom: 74px; + padding: 4px 12px; border-radius: 999px; + border: 1px solid var(--accent); + background: var(--bg-card); color: var(--accent); + font: 600 11.5px/1.4 'Inter Tight', -apple-system, sans-serif; + cursor: pointer; z-index: 4; + box-shadow: 0 4px 14px rgba(0, 0, 0, 0.35); +} +.ac-jump.hidden { display: none; } + +/* ── Applied-spec card ──────────────────────────────────── */ +.ac-spec { + border: 1px solid var(--border); + border-radius: 10px; + padding: 11px 13px; + background: rgba(127, 127, 127, 0.04); + font-size: 12.5px; +} +.ac-spec-title { + font-weight: 600; color: var(--accent-purple); + font-size: 11px; letter-spacing: 0.04em; text-transform: uppercase; + margin-bottom: 6px; +} +.ac-spec-row { display: flex; justify-content: space-between; gap: 16px; padding: 1px 0; color: var(--text-muted); } +.ac-spec-row span:last-child { color: var(--text); font-family: 'JetBrains Mono', ui-monospace, monospace; } + +/* ── Composer ───────────────────────────────────────────── */ +.agent-chat-input { + display: flex; gap: 8px; + padding: 12px; + border-top: 1px solid var(--border); + position: relative; /* anchor 
for the autocomplete dropdown */ +} + +/* ── Autocomplete dropdown ──────────────────────────────── */ +.ac-complete { + position: absolute; + left: 12px; right: 12px; bottom: calc(100% + 4px); + max-height: 240px; overflow-y: auto; + background: var(--bg-card); + border: 1px solid var(--border); + border-radius: 9px; + box-shadow: 0 -8px 28px rgba(0, 0, 0, 0.45); + padding: 4px; + z-index: 5; +} +.ac-complete.hidden { display: none; } +.ac-complete-item { + display: flex; flex-direction: column; gap: 1px; + padding: 6px 9px; border-radius: 6px; + cursor: pointer; +} +.ac-complete-item.active, +.ac-complete-item:hover { background: var(--bg-hover, rgba(127, 127, 127, 0.12)); } +.ac-complete-name { + font-family: 'JetBrains Mono', ui-monospace, monospace; + font-size: 12.5px; color: var(--accent); +} +.ac-complete-desc { + font-size: 11.5px; color: var(--text-muted); + white-space: nowrap; overflow: hidden; text-overflow: ellipsis; +} +.agent-chat-input textarea { + flex: 1 1 auto; resize: none; + border: 1px solid var(--border); border-radius: 9px; + background: var(--bg-dark); color: var(--text); + padding: 9px 11px; + font-family: inherit; font-size: 13.5px; line-height: 1.45; + max-height: 140px; +} +.agent-chat-input textarea::placeholder { color: var(--text-muted); } +.agent-chat-input textarea:focus { outline: none; border-color: var(--accent); } +.agent-chat-input textarea:disabled { opacity: 0.55; } +.agent-chat-send { + flex: 0 0 auto; align-self: flex-end; + padding: 9px 16px; border-radius: 9px; + border: none; background: var(--accent); color: #fff; + font-weight: 600; font-size: 13px; cursor: pointer; + transition: background 0.12s ease; +} +.agent-chat-send:hover:not(:disabled) { background: var(--accent-hover); } +.agent-chat-send:disabled { opacity: 0.5; cursor: default; } +/* Send now queues while busy (it no longer doubles as Stop), so just dim it. 
*/ +.agent-chat-send.ac-busy { opacity: 0.6; } + +/* Explicit Stop (separate from Send), shown only during a cancellable turn. */ +.ac-stop { + flex: 0 0 auto; align-self: flex-end; + padding: 9px 12px; border-radius: 9px; + border: 1px solid var(--color-danger, #f87171); + background: transparent; color: var(--color-danger, #f87171); + font-weight: 600; font-size: 13px; cursor: pointer; +} +.ac-stop:hover { background: rgba(248, 113, 113, 0.12); } +.ac-stop.hidden { display: none; } + +/* ── Queued-message panel (type-while-busy) ─────────────── */ +.ac-queue { + margin: 0 12px 6px; + border: 1px solid var(--border); border-radius: 9px; + background: rgba(127, 127, 127, 0.06); + padding: 6px; font-size: 12px; +} +.ac-queue.hidden { display: none; } +.ac-queue-head { + display: flex; align-items: center; justify-content: space-between; + padding: 2px 4px 6px; color: var(--text-muted); +} +.ac-queue-clear { + background: none; border: none; color: var(--accent); + cursor: pointer; font-size: 11.5px; font-family: inherit; +} +.ac-queue-clear:hover { text-decoration: underline; } +.ac-queue-item { display: flex; align-items: center; gap: 8px; padding: 4px; } +.ac-queue-text { + flex: 1 1 auto; color: var(--text); + white-space: nowrap; overflow: hidden; text-overflow: ellipsis; +} +.ac-queue-remove { + flex: 0 0 auto; background: none; border: none; + color: var(--text-muted); cursor: pointer; font-size: 12px; line-height: 1; +} +.ac-queue-remove:hover { color: var(--color-danger, #f87171); } diff --git a/gently/ui/web/static/css/main.css b/gently/ui/web/static/css/main.css index 4d19c79d..10bc6f62 100644 --- a/gently/ui/web/static/css/main.css +++ b/gently/ui/web/static/css/main.css @@ -34,6 +34,10 @@ /* Image backgrounds */ --img-bg: #000; + + /* Docked agent panel */ + --panel-edge-shadow: rgba(0, 0, 0, 0.55); + --border-strong: #444c56; } /* ======================================== @@ -70,6 +74,10 @@ /* Image backgrounds */ --img-bg: #1e293b; + + /* Docked agent 
panel — softer shadow + stronger seam for light mode */ + --panel-edge-shadow: rgba(0, 0, 0, 0.18); + --border-strong: #cbd5e1; } * { margin: 0; padding: 0; box-sizing: border-box; } @@ -90,6 +98,25 @@ body { transition: background-color 0.3s ease, color 0.3s ease; } +/* App shell: main column + docked agent panel side by side. The flex row lets + the panel become a real column (pushing content) when pinned to dock; in the + default overlay mode the panel is position:absolute and sits out of this flow. */ +.app-shell { + flex: 1 1 auto; + min-height: 0; + display: flex; + flex-direction: row; + position: relative; /* anchor for the overlay-mode agent panel */ +} +.app-main { + flex: 1 1 auto; + min-width: 0; /* allow canvases to shrink (not overflow) when docked */ + min-height: 0; + display: flex; + flex-direction: column; + overflow: hidden; +} + /* Smooth theme transitions for key elements */ .header, .tabs, .tab-content, .panel, .gallery-item, .events-container, .lightbox-container, .shortcuts-content { @@ -625,6 +652,114 @@ a.tab-link.active { flex-direction: column; } +/* ── Home (landing) tab ───────────────────────────────────── + #home-content is a flex column with overflow:hidden, so the scroll lives on + .home-scroll. 
*/ +.home-scroll { + flex: 1 1 auto; + min-height: 0; + overflow-y: auto; + padding: 24px; + display: flex; + flex-direction: column; + gap: 20px; +} +.home-hero { + display: flex; + align-items: center; + justify-content: space-between; + gap: 16px; + padding: 20px 22px; + border: 1px solid var(--border); + border-radius: 14px; + background: var(--bg-card); +} +.home-hero-title { font-size: 1.35rem; font-weight: 700; color: var(--text); margin: 0; } +.home-hero-status { font-size: 12.5px; color: var(--text-muted); margin-top: 4px; } +.home-start-btn { + flex: 0 0 auto; + padding: 10px 18px; + border: none; border-radius: 10px; + background: var(--gradient-primary, var(--accent)); + color: #fff; font-weight: 600; font-size: 13.5px; cursor: pointer; + box-shadow: var(--shadow-glow); + transition: transform 0.12s ease, box-shadow 0.12s ease; +} +.home-start-btn:hover { transform: translateY(-1px); box-shadow: var(--shadow-glow-strong); } + +.home-grid { + display: grid; + grid-template-columns: repeat(2, minmax(0, 1fr)); + gap: 16px; +} +.home-card-wide { grid-column: 1 / -1; } +@media (max-width: 820px) { + .home-grid { grid-template-columns: 1fr; } + .home-card-wide { grid-column: auto; } +} + +.home-card { + display: flex; flex-direction: column; + padding: 14px 16px; + border: 1px solid var(--border); + border-radius: 12px; + background: var(--bg-card); + min-height: 120px; +} +.home-card-head { + display: flex; align-items: center; justify-content: space-between; + margin-bottom: 10px; +} +.home-card-title { + font-size: 11px; font-weight: 600; letter-spacing: 0.06em; + text-transform: uppercase; color: var(--text-muted); +} +.home-card-link { font-size: 11.5px; color: var(--accent); text-decoration: none; } +.home-card-link:hover { text-decoration: underline; } +.home-card-body { display: flex; flex-direction: column; gap: 6px; } + +.home-item { + display: flex; align-items: center; justify-content: space-between; gap: 10px; + padding: 8px 10px; border-radius: 
8px; + background: rgba(127, 127, 127, 0.05); + border: 1px solid transparent; +} +.home-item-clickable { cursor: pointer; } +.home-item-clickable:hover { border-color: var(--accent); background: var(--bg-hover); } +.home-item-main { display: flex; flex-direction: column; gap: 2px; min-width: 0; } +.home-item-row { display: flex; align-items: center; gap: 7px; } +.home-item-name { + font-size: 13px; color: var(--text); font-weight: 500; + white-space: nowrap; overflow: hidden; text-overflow: ellipsis; +} +.home-item-meta { font-size: 11.5px; color: var(--text-muted); } +.home-tag { + font-size: 9.5px; font-weight: 700; letter-spacing: 0.04em; text-transform: uppercase; + padding: 1px 6px; border-radius: 999px; +} +.home-tag-live { color: var(--accent-green); border: 1px solid rgba(74, 222, 128, 0.4); } +.home-resume { + flex: 0 0 auto; + padding: 4px 11px; border-radius: 7px; + border: 1px solid var(--accent); background: transparent; color: var(--accent); + font-size: 12px; font-weight: 600; cursor: pointer; +} +.home-resume:hover { background: var(--accent); color: #fff; } +.home-resume:disabled { opacity: 0.6; cursor: default; } +.home-chip { + flex: 0 0 auto; + font-size: 11px; font-weight: 600; + padding: 2px 8px; border-radius: 999px; + background: var(--bg-hover); color: var(--text-muted); +} + +.home-image-strip { display: flex; gap: 8px; flex-wrap: wrap; } +.home-image { + width: 84px; height: 84px; border-radius: 8px; overflow: hidden; + border: 1px solid var(--border); background: var(--img-bg); flex: 0 0 auto; +} +.home-image img { width: 100%; height: 100%; object-fit: cover; display: block; } + /* Live View - Clean full-width layout */ .live-view { display: flex; @@ -1482,12 +1617,200 @@ a.tab-link.active { background: #000; border-radius: 3px; border: 1px solid var(--border); - display: none; /* shown when frame arrives via .has-frame */ - opacity: 1; + display: inline-block; + opacity: 0.35; /* dim until a real frame arrives (.has-frame) */ } 
.cal-spim-thumb.has-frame { - display: inline-block; + opacity: 1; +} + +/* Thumb wrapped in a button so click pops out a larger live view. + Sized to match the thumb so it remains clickable even before the + first frame arrives. */ +.cal-spim-thumb-btn { + position: relative; + padding: 0; + background: none; + border: 0; + cursor: pointer; + display: inline-flex; + align-items: center; + line-height: 0; + color: inherit; + width: 96px; + height: 72px; +} + +.cal-spim-thumb-btn:focus-visible { + outline: 2px solid var(--accent, #4f8cff); + outline-offset: 2px; + border-radius: 4px; +} + +.cal-spim-expand-icon { + position: absolute; + top: 2px; + right: 2px; + background: rgba(0, 0, 0, 0.55); + color: #fff; + font-size: 11px; + line-height: 1; + padding: 2px 4px; + border-radius: 3px; + opacity: 0; + transition: opacity 0.12s ease; + pointer-events: none; +} + +.cal-spim-thumb-btn:hover .cal-spim-expand-icon, +.cal-spim-thumb-btn:focus-visible .cal-spim-expand-icon { + opacity: 1; +} + +/* Hide the expand chip when the thumb has no frame yet — nothing to expand. 
*/ +.cal-spim-thumb-btn:has(.cal-spim-thumb:not(.has-frame)) .cal-spim-expand-icon { + display: none; +} + +/* ---------- Floating SPIM popout ---------- */ +.cal-spim-popout { + position: fixed; + top: 80px; + right: 24px; + width: 560px; + height: 480px; + min-width: 320px; + min-height: 260px; + z-index: 9000; + background: var(--bg-card); + border: 1px solid var(--border); + border-radius: 10px; + box-shadow: 0 16px 40px rgba(0, 0, 0, 0.45), + 0 2px 8px rgba(0, 0, 0, 0.25); + display: flex; + flex-direction: column; + overflow: hidden; + resize: both; +} + +.cal-spim-popout[hidden] { + display: none; +} + +.cal-spim-popout.dragging { + user-select: none; + cursor: grabbing; +} + +.cal-spim-popout-header { + flex: 0 0 auto; + display: flex; + align-items: center; + gap: 8px; + padding: 8px 12px; + background: var(--bg-elevated, var(--bg-card)); + border-bottom: 1px solid var(--border); + cursor: grab; + touch-action: none; +} + +.cal-spim-popout.dragging .cal-spim-popout-header { + cursor: grabbing; +} + +.cal-spim-popout-led { + width: 8px; + height: 8px; + border-radius: 50%; + background: #4ade80; + box-shadow: 0 0 6px rgba(74, 222, 128, 0.7); + animation: cal-spim-led-blink 1.6s ease-in-out infinite; +} + +.cal-spim-popout-led.idle { + background: var(--text-muted, #666); + box-shadow: none; + animation: none; +} + +.cal-spim-popout-title { + font-size: 11px; + font-weight: 700; + text-transform: uppercase; + letter-spacing: 0.6px; + color: var(--text); +} + +.cal-spim-popout-embryo { + font-family: 'JetBrains Mono', ui-monospace, monospace; + font-size: 11px; + color: var(--text-muted); +} + +.cal-spim-popout-spacer { + flex: 1; +} + +.cal-spim-popout-close { + background: transparent; + border: 0; + color: var(--text-muted); + font-size: 20px; + line-height: 1; + padding: 0 6px; + cursor: pointer; + border-radius: 4px; +} + +.cal-spim-popout-close:hover { + color: var(--text); + background: var(--border); +} + +.cal-spim-popout-body { + flex: 1; + 
min-height: 0; + display: flex; + align-items: center; + justify-content: center; + background: #000; + overflow: hidden; + padding: 4px; +} + +.cal-spim-popout-img { + max-width: 100%; + max-height: 100%; + object-fit: contain; + display: none; +} + +.cal-spim-popout-img.has-frame { + display: block; +} + +.cal-spim-popout-placeholder { + color: var(--text-muted); + font-size: 12px; + letter-spacing: 0.4px; +} + +.cal-spim-popout-placeholder[hidden] { + display: none; +} + +.cal-spim-popout-footer { + flex: 0 0 auto; + padding: 6px 12px; + border-top: 1px solid var(--border); + background: var(--bg-elevated, var(--bg-card)); + font-family: 'JetBrains Mono', ui-monospace, monospace; + font-size: 11px; + color: var(--text-muted); + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; } /* When a live frame is active, let the SPIM cell breathe a bit so the @@ -2036,6 +2359,20 @@ a.tab-link.active { .event-type-badge.error { background: rgba(248, 81, 73, 0.2); color: #f85149; } .event-type-badge.default { background: var(--bg-hover); color: var(--text-muted); } +/* Log-record badges per level. The level is the badge text (DEBUG / INFO / + WARN / ERROR) for log rows; the LOG_RECORD type itself is collapsed into + the level so the column doesn't read the same string for every line. */ +.event-type-badge.log-debug { background: rgba(125, 134, 145, 0.18); color: #9ba3b0; } +.event-type-badge.log-info { background: rgba(88, 166, 255, 0.16); color: var(--accent); } +.event-type-badge.log-warn { background: rgba(210, 153, 34, 0.22); color: var(--accent-orange); } +.event-type-badge.log-error { background: rgba(248, 81, 73, 0.22); color: #f85149; } + +/* Log line message: monospace, faint logger prefix, expandable trace. 
*/ +.log-row .event-data { font-family: 'JetBrains Mono', ui-monospace, monospace; } +.log-logger { color: var(--text-muted); opacity: 0.85; margin-right: 0.5rem; } +.log-message { color: var(--text); } +.log-exc { color: #f85149; opacity: 0.85; } + .event-source { color: var(--text-muted); font-size: 0.75rem; @@ -3330,6 +3667,15 @@ kbd { padding: 0; } +/* Filmstrip: rows on the left, reasoning/detail panel pinned on the right. + (Recovered from the lost WIP commit 0269e18d.) */ +.view-filmstrip { + display: flex; + flex-direction: row; + align-items: stretch; + overflow: hidden; +} + /* ======================================== AMBIENT HEALTH PULSE ======================================== */ @@ -3405,12 +3751,24 @@ kbd { .board-col { padding: 0 0.5rem; } .board-col-embryo { width: 100px; flex-shrink: 0; } .board-col-stage { width: 130px; flex-shrink: 0; } -.board-col-conf { width: 60px; flex-shrink: 0; text-align: center; } -.board-col-rate { width: 70px; flex-shrink: 0; text-align: center; } -.board-col-eta { width: 70px; flex-shrink: 0; text-align: center; } +.board-col-clock { width: 72px; flex-shrink: 0; text-align: right; font-variant-numeric: tabular-nums; } +.board-col-stereo { width: 140px; flex-shrink: 0; font-variant-numeric: tabular-nums; } +.board-col-pace { width: 90px; flex-shrink: 0; text-align: center; font-variant-numeric: tabular-nums; } +.board-col-eta { width: 70px; flex-shrink: 0; text-align: right; font-variant-numeric: tabular-nums; } .board-col-spark { flex: 1; min-width: 100px; } .board-col-alert { width: 110px; flex-shrink: 0; text-align: right; } +/* Pace cell coloring — green when on reference, orange when slow, + red when seriously slow. Class names mirror _formatPace(). 
*/ +.board-col-pace.pace-unknown { color: var(--text-muted); } +.board-col-pace.pace-normal { color: var(--accent-green, #4ade80); } +.board-col-pace.pace-slow { color: #fb923c; } +.board-col-pace.pace-slow-bad { color: #f87171; font-weight: 600; } + +/* Subtle overdue mark in the stereo cell when clock has run past the + expected stage duration. */ +.stereo-overdue { color: #fb923c; margin-left: 4px; } + .board-rows { flex: 1; } .board-row { @@ -3500,8 +3858,10 @@ kbd { scrollbar shared by all rows. Labels pin to the left via position:sticky inside each row. */ display: block; + flex: 1 1 0; /* flex-1 child: shrinks/grows as the panel opens */ + min-width: 0; overflow-x: auto; - overflow-y: hidden; + overflow-y: auto; scrollbar-width: thin; scrollbar-gutter: stable; position: relative; @@ -3634,6 +3994,10 @@ kbd { border-radius: 4px; border: 2px solid; object-fit: cover; + /* The stored projection is a three-view ([XY|YZ] over [XZ]); the embryo is + in the LEFT column (XY/XZ), and the centre is the black XY|YZ divider. + Crop to the left so the square thumbnail shows the embryo, not the gap. */ + object-position: left center; background: var(--bg-dark); } @@ -3658,15 +4022,25 @@ kbd { .filmstrip-detail { background: var(--bg-card); - /* Cap the detail panel so it always leaves the rows visible AND - has a scrollable body of its own. Without this, when an item is - expanded the detail panel can grow past the viewport bottom and - the parent's scroll is unintuitive (mouse wheel over the rows - converts to horizontal). max-height keeps it bounded; overflow-y - lets long VLM summaries scroll on their own. */ - max-height: 60vh; + /* Right-side panel: fixed-ish width that shrinks gracefully on narrow + viewports. Scrolls vertically inside itself so long VLM summaries + don't push the layout. When empty (no frame selected) it collapses + entirely so the rows get full width. (Recovered from WIP 0269e18d.) 
*/ + flex: 0 0 auto; + width: clamp(360px, 32vw, 520px); + border-left: 1px solid var(--border); overflow-y: auto; overscroll-behavior: contain; + animation: filmstripDetailIn 0.18s ease-out; +} +.filmstrip-detail:empty { display: none; } +@keyframes filmstripDetailIn { + from { transform: translateX(8px); opacity: 0; } + to { transform: translateX(0); opacity: 1; } +} +/* In the narrow side panel, stack the image | VLM summary split vertically. */ +#filmstrip-detail .detail-split { + grid-template-columns: 1fr; } /* ======================================== @@ -9172,6 +9546,7 @@ body.modal-open { --map-zone-green: 90, 168, 122; /* RGB triples for compositing */ --map-zone-orange: 215, 152, 84; --map-zone-red: 220, 96, 88; + --map-embryo: 156, 120, 220; /* lavender — distinct from zones and marker */ --map-overlay-bg: rgba(11, 14, 19, 0.78); --map-overlay-bg-2: rgba(11, 14, 19, 0.92); --map-overlay-edge: rgba(212, 221, 232, 0.18); @@ -9189,6 +9564,7 @@ body.modal-open { --map-accent: #0e7490; --map-accent-2: #155e75; --map-warm: #a16207; + --map-embryo: 100, 60, 180; /* deeper purple for cream paper */ --map-overlay-bg: rgba(246, 243, 236, 0.82); --map-overlay-bg-2: rgba(246, 243, 236, 0.96); --map-overlay-edge: rgba(29, 43, 58, 0.18); @@ -9253,6 +9629,113 @@ body.modal-open { .devices-status-led.stale::before { background: var(--map-warm); } .devices-status-led.error::before { background: #f87171; } +/* --- Room-light toggle (header) -------------------------------------- */ +.devices-room-light { + display: inline-flex; + align-items: center; + gap: 0.4rem; + padding: 0.18rem 0.6rem 0.18rem 0.45rem; + border: 1px solid var(--map-overlay-edge); + background: var(--map-overlay-bg); + border-radius: 999px; + color: var(--map-ink-mute); + font-family: inherit; + font-size: 0.65rem; + font-weight: 600; + letter-spacing: 0.06em; + text-transform: uppercase; + cursor: pointer; + transition: color 0.15s, border-color 0.15s, background 0.15s; +} 
+.devices-room-light[hidden] { display: none; } +.devices-room-light:hover:not(:disabled) { + border-color: var(--map-accent); + color: var(--map-ink); +} +.devices-room-light:disabled { opacity: 0.5; cursor: default; } +.devices-room-light-bulb { + display: inline-flex; + align-items: center; + color: var(--map-ink-mute); + transition: color 0.15s, filter 0.15s; +} +/* "on" — warm glow on the bulb to read like a lit lamp */ +.devices-room-light.is-on { + border-color: rgba(255, 210, 74, 0.7); + color: #ffd24a; + background: rgba(255, 210, 74, 0.12); +} +.devices-room-light.is-on .devices-room-light-bulb { + color: #ffd24a; + filter: drop-shadow(0 0 5px rgba(255, 210, 74, 0.7)); +} +.devices-room-light.is-busy { opacity: 0.65; cursor: progress; } + +/* --- Temperature controller ------------------------------------------- */ +/* Readout + setpoint input + Set, styled as a pill to sit beside the + room-light toggle in the Devices header. Hidden until a controller is + available (mirrors the room light). 
*/ +.devices-temp { + display: inline-flex; + align-items: center; + gap: 0.35rem; + padding: 0.12rem 0.28rem 0.12rem 0.5rem; + border: 1px solid var(--map-overlay-edge); + background: var(--map-overlay-bg); + border-radius: 999px; + color: var(--map-ink-mute); + font-family: inherit; + font-size: 0.65rem; + font-weight: 600; + letter-spacing: 0.06em; + text-transform: uppercase; +} +.devices-temp[hidden] { display: none; } +.devices-temp-icon { display: inline-flex; align-items: center; color: var(--map-ink-mute); transition: color 0.15s, filter 0.15s; } +/* "locked" — cool glow once the controller reports SYSTEM LOCKED */ +.devices-temp.is-locked { border-color: rgba(90, 200, 250, 0.6); color: #5ac8fa; background: rgba(90, 200, 250, 0.1); } +.devices-temp.is-locked .devices-temp-icon { color: #5ac8fa; filter: drop-shadow(0 0 5px rgba(90, 200, 250, 0.6)); } +.devices-temp.is-busy { opacity: 0.7; cursor: progress; } +.devices-temp-readout { + font-variant-numeric: tabular-nums; + letter-spacing: 0.02em; + min-width: 3.4em; + text-align: right; + white-space: nowrap; +} +.devices-temp-input { + width: 3.4em; + padding: 0.1rem 0.3rem; + border: 1px solid var(--map-overlay-edge); + background: var(--map-paper, rgba(0, 0, 0, 0.2)); + border-radius: 6px; + color: var(--map-ink); + font-family: inherit; + font-size: 0.72rem; + font-variant-numeric: tabular-nums; + text-align: right; + -moz-appearance: textfield; +} +.devices-temp-input:focus { outline: none; border-color: var(--map-accent); } +.devices-temp-input::-webkit-outer-spin-button, +.devices-temp-input::-webkit-inner-spin-button { -webkit-appearance: none; margin: 0; } +.devices-temp-set { + padding: 0.16rem 0.5rem; + border: 1px solid var(--map-overlay-edge); + background: var(--map-overlay-bg); + border-radius: 999px; + color: var(--map-ink-mute); + font-family: inherit; + font-size: 0.62rem; + font-weight: 700; + letter-spacing: 0.06em; + text-transform: uppercase; + cursor: pointer; + transition: color 
0.15s, border-color 0.15s, background 0.15s; +} +.devices-temp-set:hover:not(:disabled) { border-color: var(--map-accent); color: var(--map-ink); } +.devices-temp-set:disabled { opacity: 0.5; cursor: default; } + /* --- Containers ------------------------------------------------------- */ .devices-view { display: flex; flex-direction: column; flex: 1; min-height: 0; } .devices-view-details { gap: 1rem; } @@ -9429,6 +9912,53 @@ body.modal-open { 100% { opacity: 0; r: 28; } } +/* --- Embryo waypoints ------------------------------------------------ */ +/* Coarse = bottom-camera / manual placement; fine = SPIM-objective + alignment. Coarse reads as an outlined ring (provisional), fine as a + filled disc (committed). Same hue so the row of embryos still reads as + one cohort, but visual weight signals calibration state at a glance. */ +.devices-embryo-group { + cursor: pointer; +} +.devices-embryo-ring { + fill: rgba(var(--map-embryo), 0.08); + stroke: rgba(var(--map-embryo), 0.85); + stroke-width: 1.4; + vector-effect: non-scaling-stroke; +} +.devices-embryo-disc { + fill: rgba(var(--map-embryo), 0.65); + stroke: rgba(var(--map-embryo), 0.95); + stroke-width: 1.4; + vector-effect: non-scaling-stroke; +} +.devices-embryo-label { + fill: var(--map-ink); + font-family: 'JetBrains Mono', ui-monospace, monospace; + font-weight: 600; + text-anchor: middle; + dominant-baseline: central; + pointer-events: none; + paint-order: stroke; + stroke: var(--map-paper); + stroke-width: 2; + stroke-linejoin: round; +} + +/* Selected = "picked up" — outlined dashed, hollow fill, brighter label. + Click on empty map drops the picked-up embryo at that XY; Delete / + Backspace removes it; Escape deselects. 
*/ +.devices-embryo-group.devices-embryo-selected .devices-embryo-ring, +.devices-embryo-group.devices-embryo-selected .devices-embryo-disc { + fill: rgba(var(--map-embryo), 0.12); + stroke: rgba(var(--map-embryo), 1); + stroke-width: 2; + stroke-dasharray: 4 3; +} +.devices-embryo-group.devices-embryo-selected .devices-embryo-label { + fill: rgba(var(--map-embryo), 1); +} + /* --- Overlay panels (compass, readout, scalebar, legend) ------------- */ .devices-compass, .devices-map-readout, @@ -9635,9 +10165,44 @@ body.modal-open { display: block; opacity: 0; transition: opacity 0.25s; + /* Zoom anchored at frame centre; scroll-wheel + cursor adjust translate + so the point under the cursor stays under the cursor. */ + transform-origin: center center; + will-change: transform; } .devices-camera-img.has-frame { opacity: 1; } +/* Cursor hints for zoom/pan mode. Default cursor stays untouched at zoom 1 + so the operator can still interact with overlays under the camera. */ +.devices-camera-stage.camera-zoomed { cursor: grab; } +.devices-camera-stage.camera-panning { cursor: grabbing; } + +/* Centre reticle — full-span horizontal + vertical hairline marking the + FOV centre IN the image. SVG is a sibling of <img>; the inner <g> + receives the same translate/scale (in viewBox units) so the lines + track the camera image through zoom/pan instead of staying pinned to + the viewer rect. Transform lives on the <g>, not the SVG element, so + the renderer re-rasterises at each zoom step — otherwise the strokes + get bitmap-scaled and go blurry. 
*/ +.devices-camera-crosshair { + position: absolute; + inset: 0; + width: 100%; + height: 100%; + pointer-events: none; + opacity: 0; + transition: opacity 0.25s; +} +.devices-camera-stage:has(.devices-camera-img.has-frame) .devices-camera-crosshair { + opacity: 1; +} +.devices-camera-crosshair line { + stroke: var(--map-warm); + stroke-width: 1; + vector-effect: non-scaling-stroke; + stroke-opacity: 0.85; +} + .devices-camera-placeholder { position: absolute; inset: 0; diff --git a/gently/ui/web/static/css/review.css b/gently/ui/web/static/css/review.css index de2bd66d..735d42a5 100644 --- a/gently/ui/web/static/css/review.css +++ b/gently/ui/web/static/css/review.css @@ -443,3 +443,30 @@ color: var(--text-muted); } + +/* Resume-in-agent action on session list items */ +.session-resume-btn { + margin-top: 8px; + padding: 5px 10px; + border-radius: 7px; + border: 1px solid var(--accent, #60a5fa); + background: transparent; + color: var(--accent, #60a5fa); + font-size: 12px; + font-weight: 600; + cursor: pointer; +} +.session-resume-btn:hover { background: var(--accent, #60a5fa); color: #fff; } +.session-active-badge { + font-size: 10px; + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.05em; + color: var(--accent-green, #4ade80); + border: 1px solid var(--accent-green, #4ade80); + border-radius: 999px; + padding: 1px 7px; + margin-left: 6px; + vertical-align: middle; +} +.session-item.active-session { border-left: 2px solid var(--accent-green, #4ade80); } diff --git a/gently/ui/web/static/js/agent-chat.js b/gently/ui/web/static/js/agent-chat.js new file mode 100644 index 00000000..63339b92 --- /dev/null +++ b/gently/ui/web/static/js/agent-chat.js @@ -0,0 +1,968 @@ +/** + * Floating agent-chat window — the web-side control surface. + * + * Connects to the same /ws/agent bridge the TUI uses, streams the agent's + * responses, and renders interactive choice pickers. 
A single-driver control + * lock on the server arbitrates who may drive the microscope; this client + * shows a banner and offers "Take control" when another client holds it. + * + * Self-contained IIFE (no build step). All untrusted text is escaped before + * insertion — never assign agent/user/tool strings to innerHTML directly. + */ +const AgentChat = (() => { + let ws = null; + let reconnectDelay = 1000; + const MAX_DELAY = 30000; + + let panelOpen = false; + let hasControl = true; // optimistic until the server says otherwise + let holderLabel = null; + let streaming = false; + let currentAgentEl = null; // the agent content element being streamed into + let activityEl = null; // the persistent "working…" indicator (reused) + let me = null; // { authenticated, username, role, can_control } + + // Autocomplete: slash-command + @tool registries (pushed by the server on + // connect) and the live dropdown state. + let commands = []; // [{name, description, aliases, ...}] + let tools = []; // [{name, description, params, ...}] + let acItems = []; // current completion items shown in the dropdown + let acIdx = -1; // highlighted item index + let autonomousTurn = false; // true while rendering an autonomous (wake) turn + let agentBusy = false; // a turn (user or autonomous) is currently running + let busySource = null; // 'user' | 'wake' while busy + let msgQueue = []; // messages typed while busy, sent on idle + let queuePanel = null; // the "⏳ Queued (N)" panel element + let stopBtn = null; // explicit Stop button (separate from Send) + + // DOM refs (resolved in init) + let panel, log, input, sendBtn, conn, banner, closeBtn, userEl, signoutBtn; + let toggleBtn, pinBtn, resizeEl, toggleDot, toggleBadge; // docked-panel chrome + let pendingSlot = null; // sticky slot for ASK approval proposals + let acComplete = null; // the autocomplete dropdown element + + // ── Safe rendering ──────────────────────────────────────── + function escapeHtml(s) { + const d = 
document.createElement('div'); + d.textContent = String(s == null ? '' : s); + return d.innerHTML; + } + + /** Minimal, safe markdown: escape first, then a few inline transforms. */ + function mdToHtml(text) { + let html = escapeHtml(text); + html = html.replace(/`([^`]+)`/g, '<code>$1</code>'); + html = html.replace(/\*\*([^*]+)\*\*/g, '<strong>$1</strong>'); + html = html.replace(/\*([^*]+)\*/g, '<em>$1</em>'); + html = html.replace(/\n/g, '<br>');
+ return html; + } + + // Pin-to-bottom autoscroll: only follow new content if the user is already + // near the bottom; otherwise count unseen items and show a "↓ N new" pill so + // a streaming agent never yanks the operator away from something they're reading. + let stickBottom = true; + let newCount = 0; + let jumpPill = null; + function nearBottom() { return (log.scrollHeight - log.scrollTop - log.clientHeight) < 60; } + function renderJumpPill() { + if (!jumpPill) return; + if (!stickBottom && newCount > 0) { + jumpPill.textContent = `↓ ${newCount} new`; + jumpPill.classList.remove('hidden'); + } else { + jumpPill.classList.add('hidden'); + } + } + function scrollToBottom(isNewItem = true) { + if (stickBottom) { log.scrollTop = log.scrollHeight; } + // Only count genuinely new items (bubbles/rows), not in-place streaming + // text edits — otherwise the "N new" pill inflates per chunk. + else { if (isNewItem) newCount += 1; renderJumpPill(); } + } + function jumpToBottom() { + stickBottom = true; newCount = 0; + log.scrollTop = log.scrollHeight; + renderJumpPill(); + } + + // ── Activity indicator ──────────────────────────────────── + // A single reusable "the agent is working" row, always pinned to the + // bottom of the log. This is the trust signal — something is happening. 
+ function setActivity(label) { + if (!activityEl) { + activityEl = document.createElement('div'); + activityEl.className = 'ac-activity'; + activityEl.innerHTML = + '' + + ''; + } + activityEl.querySelector('.ac-activity-label').textContent = label; + log.appendChild(activityEl); // (re)pin to bottom + scrollToBottom(); + } + function hideActivity() { + if (activityEl && activityEl.parentNode) activityEl.parentNode.removeChild(activityEl); + } + + // ── Message elements ────────────────────────────────────── + function addTurn(role) { + const wrap = document.createElement('div'); + wrap.className = `ac-turn ac-turn-${role}`; + if (role === 'agent' && autonomousTurn) wrap.classList.add('ac-turn-autonomous'); + if (role === 'agent') { + const label = document.createElement('div'); + label.className = 'ac-role'; + label.textContent = autonomousTurn ? 'Gently · autonomous' : 'Gently'; + wrap.appendChild(label); + } + const content = document.createElement('div'); + content.className = 'ac-content'; + wrap.appendChild(content); + log.appendChild(wrap); + scrollToBottom(); + return content; + } + + function addUserMessage(text, author) { + const wrap = document.createElement('div'); + wrap.className = 'ac-turn ac-turn-user'; + if (author) { + const label = document.createElement('div'); + label.className = 'ac-role ac-role-user'; + label.textContent = author; + wrap.appendChild(label); + } + const content = document.createElement('div'); + content.className = 'ac-content'; + content.textContent = text; + wrap.appendChild(content); + log.appendChild(wrap); + scrollToBottom(); + } + + /** Rebuild the transcript from a persisted/replayed history list. 
*/ + function renderHistory(items) { + log.innerHTML = ''; + currentAgentEl = null; + activityEl = null; + stickBottom = true; newCount = 0; // a full rebuild jumps to latest + (items || []).forEach(it => { + if (it.role === 'user') { + addUserMessage(it.text, it.author); + } else if (it.role === 'agent') { + const c = addTurn('agent'); + c._raw = it.text || ''; + c.innerHTML = mdToHtml(c._raw); + } else if (it.role === 'autonomous_start') { + addAutonomousBanner(it.trigger || ''); + } else if (it.role === 'autonomous') { + autonomousTurn = true; + const c = addTurn('agent'); + c._raw = it.text || ''; + c.innerHTML = mdToHtml(c._raw); + autonomousTurn = false; + } else if (it.role === 'tool') { + const el = document.createElement('div'); + el.className = 'ac-tool ac-tool-done'; + const dur = it.duration ? ` · ${(it.duration.toFixed ? it.duration.toFixed(1) : it.duration)}s` : ''; + const summary = it.summary ? ` — ${escapeHtml(it.summary)}` : ''; + el.innerHTML = `${escapeHtml(it.name || 'tool')}${dur}${summary}`; + log.appendChild(el); + } else if (it.role === 'system') { + addSystemLine(it.text, it.level || 'info'); + } + }); + scrollToBottom(); + } + + /** A divider announcing the agent woke itself, with the trigger reason. */ + function addAutonomousBanner(trigger) { + const el = document.createElement('div'); + el.className = 'ac-autonomous-banner'; + const t = trigger ? `Gently woke up — ${trigger}` : 'Gently woke up'; + el.innerHTML = `${escapeHtml(t)}`; + log.appendChild(el); + scrollToBottom(); + } + + function addSystemLine(text, level = 'info') { + const el = document.createElement('div'); + el.className = `ac-system ac-level-${level}`; + el.textContent = text; + log.appendChild(el); + scrollToBottom(); + } + + // ── Protocol handlers ───────────────────────────────────── + function handle(msg) { + switch (msg.type) { + case 'connected': + reconnectDelay = 1000; + setConn(true, msg.version ? 
`Connected · v${msg.version}` : 'Connected'); + // The bridge ships the command + tool registries on connect. + // Capture them so the composer can offer autocomplete — the + // data was always on the wire; we just never used it. + commands = Array.isArray(msg.commands) ? msg.commands : []; + tools = Array.isArray(msg.tools) ? msg.tools : []; + break; + + case 'control_status': + hasControl = !!msg.you_have_control; + holderLabel = msg.holder_label || null; + renderControl(); + break; + + case 'history': + renderHistory(msg.items || []); + break; + + case 'user_message': + hideActivity(); + addUserMessage(msg.text, msg.author); + break; + + case 'stream_start': + streaming = true; + currentAgentEl = null; // created lazily on first text + setBusy(true, 'user'); + setActivity('Working…'); + break; + + case 'autonomous_start': + // The agent woke itself — render a distinct banner + label the + // following text as autonomous (no stream_start precedes this). + hideActivity(); + autonomousTurn = true; + currentAgentEl = null; + setBusy(true, 'wake'); + addAutonomousBanner(msg.trigger || ''); + bumpBadge(); + break; + + case 'thinking': + if (streaming) setActivity('Thinking…'); + break; + + case 'text': { + if (!currentAgentEl) { + hideActivity(); + currentAgentEl = addTurn('agent'); + currentAgentEl._raw = ''; + } + currentAgentEl._raw += (msg.text || ''); + currentAgentEl.innerHTML = mdToHtml(currentAgentEl._raw); + scrollToBottom(false); // in-place edit, not a new item + break; + } + + case 'tool_start': { + hideActivity(); // the running tool row is the signal now + currentAgentEl = null; // text after a tool starts a fresh bubble + const label = msg.tool_label || msg.tool_name || 'tool'; + const args = fmtArgs(msg.tool_input); + const el = document.createElement('div'); + el.className = 'ac-tool ac-tool-running'; + el.dataset.tool = msg.tool_name || ''; + el.innerHTML = + `
` + + `${escapeHtml(label)}
` + + (args ? `
${escapeHtml(args)}
` : ''); + log.appendChild(el); + scrollToBottom(); + break; + } + + case 'tool_call': { + const running = [...log.querySelectorAll('.ac-tool-running')] + .filter(e => e.dataset.tool === (msg.tool_name || '')); + const el = running[running.length - 1]; + const label = msg.tool_name || 'tool'; + const dur = msg.duration + ? ` · ${(msg.duration.toFixed ? msg.duration.toFixed(1) : msg.duration)}s` : ''; + const args = fmtArgs(msg.tool_input); + const summary = msg.result_summary || ''; + // Show ⚠ instead of ✓ when the tool errored or its result reads + // like a failure — so the operator can tell when a tool did nothing. + const isErr = !!msg.is_error || looksLikeError(summary); + const icon = isErr + ? `` + : ``; + const html = + `
${icon}` + + `${escapeHtml(label)}` + + `${dur}
` + + (args ? `
${escapeHtml(args)}
` : '') + + (summary ? `
${escapeHtml(summary)}
` : ''); + if (el) { + el.className = 'ac-tool ac-tool-done' + (isErr ? ' ac-tool-err' : ''); + el.innerHTML = html; + } else { + // No matching running row (e.g. after a reconnect) — append fresh. + const fresh = document.createElement('div'); + fresh.className = 'ac-tool ac-tool-done' + (isErr ? ' ac-tool-err' : ''); + fresh.innerHTML = html; + log.appendChild(fresh); + } + if (streaming) setActivity('Working…'); // agent continues after the tool + scrollToBottom(); + break; + } + + case 'choice_request': + hideActivity(); + renderChoice(msg); + bumpBadge(); + break; + + case 'applied_spec': + renderSpec(msg.spec || {}); + break; + + case 'stream_end': + streaming = false; + currentAgentEl = null; + autonomousTurn = false; + hideActivity(); + setBusy(false); + break; + + case 'command_result': + if (msg.error) addSystemLine(`${msg.command}: ${msg.error}`, 'error'); + else if (msg.content) addSystemLine(`${msg.command} ✓`, 'info'); + break; + + case 'notification': + addSystemLine(msg.body ? `${msg.title} — ${msg.body}` : msg.title, msg.level || 'info'); + bumpBadge(); + break; + + case 'error': + streaming = false; + hideActivity(); + setBusy(false); + addSystemLine(msg.error || 'Unknown error', 'error'); + break; + + case 'ping': + send({ type: 'pong' }); + break; + + default: + break; // pong / state_update / browse_result / unknown — ignored + } + } + + function renderChoice(msg) { + const data = msg.choice_data || {}; + const reqId = msg.request_id || data.request_id || ''; + const isWake = msg.origin === 'wake'; + const wrap = document.createElement('div'); + wrap.className = 'ac-choice' + (isWake ? 
' ac-choice-wake' : ''); + if (isWake) { + const tag = document.createElement('div'); + tag.className = 'ac-choice-origin'; + tag.textContent = 'Autonomy proposal — your approval needed'; + wrap.appendChild(tag); + } + const q = document.createElement('div'); + q.className = 'ac-choice-q'; + q.innerHTML = mdToHtml(data.question || 'Choose:'); + wrap.appendChild(q); + + (data.options || []).forEach(opt => { + const btn = document.createElement('button'); + btn.className = 'ac-choice-opt'; + btn.disabled = !!opt.disabled || !hasControl; // observers see it read-only + const desc = opt.description ? `${escapeHtml(opt.description)}` : ''; + btn.innerHTML = `${escapeHtml(opt.label)}${desc}`; + btn.addEventListener('click', () => { + send({ type: 'choice_response', request_id: reqId, selected: opt.id }); + [...wrap.querySelectorAll('button')].forEach(b => b.disabled = true); + wrap.classList.add('ac-choice-answered'); + btn.classList.add('ac-choice-picked'); + if (streaming) setActivity('Working…'); + if (isWake && pendingSlot) { + setTimeout(() => { pendingSlot.classList.add('hidden'); pendingSlot.innerHTML = ''; }, 700); + } + }); + wrap.appendChild(btn); + }); + // ASK approvals pin to the sticky slot above the composer so they can't + // scroll out of reach; ordinary choices stay inline in the transcript. + if (isWake && pendingSlot) { + pendingSlot.innerHTML = ''; + pendingSlot.appendChild(wrap); + pendingSlot.classList.remove('hidden'); + return; + } + log.appendChild(wrap); + scrollToBottom(); + } + + function renderSpec(spec) { + const rows = []; + const add = (k, v) => { if (v !== undefined && v !== null && v !== '') rows.push([k, v]); }; + add('Strain', spec.strain); + add('Temperature', spec.temperature_c != null ? `${spec.temperature_c} °C` : null); + add('Slices', spec.num_slices); + add('Exposure', spec.exposure_ms != null ? `${spec.exposure_ms} ms` : null); + add('Interval', spec.interval_s != null ? 
`${spec.interval_s} s` : null); + add('Stop at', spec.stop_condition); + if (!rows.length) return; + const el = document.createElement('div'); + el.className = 'ac-spec'; + el.innerHTML = '
Imaging spec applied
' + + rows.map(([k, v]) => `
${escapeHtml(k)}${escapeHtml(v)}
`).join(''); + log.appendChild(el); + scrollToBottom(); + } + + // ── Tool argument formatting ────────────────────────────── + /** Compact, escaped "key=value" rendering of a tool's input for the chat. */ + function fmtArgs(input) { + if (!input || typeof input !== 'object') return ''; + const parts = []; + for (const [k, v] of Object.entries(input)) { + if (k === 'context' || v === null || v === undefined || v === '') continue; + let val = (typeof v === 'object') ? JSON.stringify(v) : String(v); + if (val.length > 48) val = val.slice(0, 47) + '…'; + parts.push(`${k}=${val}`); + } + return parts.join(' '); + } + + /** Heuristic: does a tool's result summary read like a failure? + * Used to show ⚠ for tools that return an error STRING (the agent only + * flags raised exceptions). Avoids false alarms like "No errors found". */ + function looksLikeError(s) { + if (!s) return false; + const t = s.trim(); + if (/^no\s+(errors?|issues?|problems?|anomal|changes?|warnings?)\b/i.test(t)) return false; + if (/^(error|failed|failure|unable|cannot|can'?t|could\s?n'?t|could not|denied|invalid|no |not )/i.test(t)) return true; + // mid-string failure markers, e.g. "Timepoint 7 not found for embryo_2". + return /\bnot (found|available|connected|recognized|valid|supported)\b/i.test(t); + } + + // ── Autocomplete ────────────────────────────────────────── + /** The whitespace-delimited token immediately left of the caret. */ + function currentToken() { + const v = input.value; + const pos = (input.selectionStart != null) ? input.selectionStart : v.length; + const before = v.slice(0, pos); + const m = before.match(/(\S+)$/); + return { token: m ? m[1] : '', start: m ? pos - m[1].length : pos, pos }; + } + + /** Compute completion items for the current input/caret, or []. */ + function computeCompletions() { + const trimmed = input.value.trimStart().toLowerCase(); + // Slash commands: whole-input prefix (mirrors the TUI). A trailing space + // (i.e. 
typing args) naturally yields no matches and hides the menu. + if (trimmed.startsWith('/')) { + return commands.filter(c => + (c.name && c.name.toLowerCase().startsWith(trimmed)) || + (c.aliases || []).some(a => String(a).toLowerCase().startsWith(trimmed)) + ).slice(0, 8).map(c => ({ kind: 'command', name: c.name, desc: c.description || '' })); + } + // @tool mention: complete the token under the caret against tool names. + const tok = currentToken(); + if (tok.token.startsWith('@') && tools.length) { + const q = tok.token.slice(1).toLowerCase(); + return tools.filter(t => t.name.toLowerCase().includes(q)) + .slice(0, 8) + .map(t => ({ kind: 'tool', name: t.name, desc: t.description || '', token: tok })); + } + return []; + } + + function renderCompletions(items) { + acItems = items || []; + acIdx = acItems.length ? 0 : -1; + if (!acComplete) return; + if (!acItems.length) { hideCompletions(); return; } + acComplete.innerHTML = ''; + acItems.forEach((it, i) => { + const row = document.createElement('div'); + row.className = 'ac-complete-item' + (i === acIdx ? ' active' : ''); + row.innerHTML = + `${escapeHtml(it.name)}` + + (it.desc ? `${escapeHtml(it.desc)}` : ''); + // mousedown (not click) so it fires before the textarea blurs. 
+ row.addEventListener('mousedown', (e) => { e.preventDefault(); acceptCompletion(it); }); + acComplete.appendChild(row); + }); + acComplete.classList.remove('hidden'); + } + + function hideCompletions() { + acItems = []; + acIdx = -1; + if (acComplete) { acComplete.classList.add('hidden'); acComplete.innerHTML = ''; } + } + + function updateCompletions() { + renderCompletions(computeCompletions()); + } + + function moveCompletion(delta) { + if (!acItems.length || !acComplete) return; + acIdx = (acIdx + delta + acItems.length) % acItems.length; + [...acComplete.children].forEach((c, i) => c.classList.toggle('active', i === acIdx)); + } + + function acceptCompletion(item) { + if (!item) return; + if (item.kind === 'command') { + input.value = item.name + ' '; + const p = input.value.length; + try { input.setSelectionRange(p, p); } catch (_) {} + } else if (item.kind === 'tool') { + const tok = item.token || currentToken(); + const v = input.value; + const insert = '@' + item.name + ' '; + input.value = v.slice(0, tok.start) + insert + v.slice(tok.pos); + const p = tok.start + insert.length; + try { input.setSelectionRange(p, p); } catch (_) {} + } + hideCompletions(); + input.focus(); + autosize(); + } + + // ── Control / UI state ──────────────────────────────────── + function renderControl() { + if (hasControl) { + banner.classList.add('hidden'); + banner.innerHTML = ''; + input.disabled = false; + sendBtn.disabled = false; + input.placeholder = 'Message Gently… ( / commands · @ tools )'; + } else { + banner.classList.remove('hidden'); + const who = holderLabel || 'another session'; + input.disabled = true; + sendBtn.disabled = true; + if (me && me.accounts && !me.authenticated) { + // Anonymous — viewing is open; sign in to control. 
+ banner.innerHTML = `Viewing — sign in to control.`; + const btn = document.createElement('button'); + btn.className = 'ac-take-control'; + btn.textContent = 'Sign in'; + btn.addEventListener('click', () => { window.location.href = '/login'; }); + banner.appendChild(btn); + input.placeholder = 'Viewing — sign in to control…'; + } else if (me && me.authenticated && me.can_control === false) { + // Viewer-role account — watching is all this account can do. + banner.innerHTML = `View-only access — you can watch but not control.`; + input.placeholder = 'View-only access'; + } else { + banner.innerHTML = `Control held by ${escapeHtml(who)}`; + const btn = document.createElement('button'); + btn.className = 'ac-take-control'; + btn.textContent = 'Take control'; + btn.addEventListener('click', () => send({ type: 'take_control' })); + banner.appendChild(btn); + input.placeholder = 'Viewing only — take control to drive…'; + } + } + } + + function setBusy(busy, source) { + agentBusy = !!busy; + busySource = agentBusy ? (source || 'user') : null; + // Send no longer doubles as Stop — it queues while busy. A separate Stop + // (shown only for a cancellable user turn) aborts the current turn. + if (stopBtn) stopBtn.classList.toggle('hidden', !(agentBusy && busySource === 'user')); + sendBtn.classList.toggle('ac-busy', agentBusy); + if (agentBusy) { + input.placeholder = (busySource === 'wake') + ? 
'Gently is acting autonomously — your message will queue' + : 'Gently is working — your message will queue'; + } else { + if (hasControl) input.placeholder = 'Message Gently… ( / commands · @ tools )'; + drainQueue(); // a turn just ended — send the next queued message + } + } + + // ── Message queue (type-while-busy) ─────────────────────── + function enqueue(text) { msgQueue.push(text); renderQueue(); } + function removeQueued(i) { + if (i >= 0 && i < msgQueue.length) { msgQueue.splice(i, 1); renderQueue(); } + } + function clearQueue() { msgQueue = []; renderQueue(); } + function drainQueue() { + if (agentBusy || !msgQueue.length) return; + if (!ws || ws.readyState !== WebSocket.OPEN) return; // keep queued until reconnect + const next = msgQueue.shift(); + renderQueue(); + actuallySend(next); + } + function renderQueue() { + if (!queuePanel) return; + if (!msgQueue.length) { queuePanel.classList.add('hidden'); queuePanel.innerHTML = ''; return; } + queuePanel.classList.remove('hidden'); + queuePanel.innerHTML = ''; + const head = document.createElement('div'); + head.className = 'ac-queue-head'; + const lbl = document.createElement('span'); + lbl.textContent = `⏳ Queued (${msgQueue.length})`; + const clear = document.createElement('button'); + clear.className = 'ac-queue-clear'; + clear.textContent = 'Clear all'; + clear.addEventListener('click', clearQueue); + head.appendChild(lbl); + head.appendChild(clear); + queuePanel.appendChild(head); + msgQueue.forEach((m, i) => { + const row = document.createElement('div'); + row.className = 'ac-queue-item'; + const span = document.createElement('span'); + span.className = 'ac-queue-text'; + span.textContent = m; + const x = document.createElement('button'); + x.className = 'ac-queue-remove'; + x.textContent = '✕'; + x.title = 'Remove from queue'; + x.addEventListener('click', () => removeQueued(i)); + row.appendChild(span); + row.appendChild(x); + queuePanel.appendChild(row); + }); + } + + function setConn(ok, label) 
{ + conn.classList.toggle('ac-conn-ok', ok); + conn.classList.toggle('ac-conn-bad', !ok); + conn.textContent = label || (ok ? 'Connected' : 'Reconnecting…'); + if (toggleDot) toggleDot.classList.toggle('ok', ok); + } + + // ── Transport ───────────────────────────────────────────── + function send(obj) { + if (ws && ws.readyState === WebSocket.OPEN) ws.send(JSON.stringify(obj)); + } + + function connect() { + const proto = location.protocol === 'https:' ? 'wss:' : 'ws:'; + setConn(false, 'Connecting…'); + ws = new WebSocket(`${proto}//${location.host}/ws/agent`); + ws.onopen = () => { reconnectDelay = 1000; setConn(true); }; + ws.onclose = () => { + setConn(false, 'Reconnecting…'); + setBusy(false); + streaming = false; + hideActivity(); + setTimeout(connect, reconnectDelay); + reconnectDelay = Math.min(reconnectDelay * 2, MAX_DELAY); + }; + ws.onerror = () => {}; + ws.onmessage = (e) => { + let msg; + try { msg = JSON.parse(e.data); } catch { return; } + handle(msg); + }; + } + + // ── Input handling ──────────────────────────────────────── + function actuallySend(text) { + if (text.startsWith('/')) { + addUserMessage(text); // commands aren't broadcast; echo locally + send({ type: 'command', command: text }); // slash commands (e.g. /status) + // Most commands reply with a single 'command_result' and no stream — + // do NOT mark the composer busy, or the queue would stick forever. + // Commands that DO stream (e.g. /wizard) set busy via stream_start. + return; + } + send({ type: 'chat', text }); // echoed to all via 'user_message' + // Instant feedback before the first chunk arrives. + setBusy(true, 'user'); + setActivity('Working…'); + } + + function submit() { + hideCompletions(); + const text = input.value.trim(); + if (!text) return; + if (!hasControl) { renderControl(); return; } + input.value = ''; + autosize(); + // While the agent is busy (a user OR autonomous turn), queue instead of + // cancelling — Send no longer doubles as Stop. 
+ if (agentBusy) { enqueue(text); return; } + actuallySend(text); + } + + function autosize() { + input.style.height = 'auto'; + input.style.height = Math.min(input.scrollHeight, 140) + 'px'; + } + + function togglePanel(open) { + panelOpen = (open === undefined) ? !panelOpen : open; + panel.classList.toggle('open', panelOpen); + if (toggleBtn) { + toggleBtn.setAttribute('aria-pressed', panelOpen ? 'true' : 'false'); + toggleBtn.setAttribute('aria-expanded', panelOpen ? 'true' : 'false'); + } + if (panelOpen) { + clearBadge(); + if (!ws) connect(); + // Re-pin to the latest content (it may have streamed while closed, + // where scroll events don't fire to keep stickBottom current). + setTimeout(() => { input.focus(); jumpToBottom(); }, 50); + } + // Opening/closing while docked reflows .app-main — tell viewers to resize. + if (document.body.classList.contains('chat-docked')) emitLayoutChanged(); + } + + // ── Layout: dock, resize, persistence ───────────────────── + const CHAT_MIN_W = 320; + const CHAT_DEFAULT_W = 460; + // Roomy ceiling: the panel shows agent reasoning, tool calls, approvals and + // pickers — content that wraps badly in a narrow column — so allow up to + // 60% of the viewport, capped at 760px (was min(560, 45vw), which capped power users too low). + function chatMaxW() { return Math.min(760, Math.round(window.innerWidth * 0.60)); } + + function emitLayoutChanged() { + // Let the CSS settle, then notify viewers (e.g. the 3D canvas) to resize.
+ requestAnimationFrame(() => window.dispatchEvent(new CustomEvent('gently:layout-changed'))); + } + + function curChatWidth() { + return parseInt(getComputedStyle(document.documentElement).getPropertyValue('--chat-w')) || CHAT_DEFAULT_W; + } + + function setChatWidth(px, persist) { + const w = Math.max(CHAT_MIN_W, Math.min(chatMaxW(), Math.round(px))); + document.documentElement.style.setProperty('--chat-w', w + 'px'); + if (persist) { try { localStorage.setItem('gently-chat-w', String(w)); } catch (_) {} } + return w; + } + + function applyDock(docked, persist) { + document.body.classList.toggle('chat-docked', docked); + if (pinBtn) { + pinBtn.setAttribute('aria-pressed', docked ? 'true' : 'false'); + pinBtn.title = docked ? 'Unpin (float over content)' : 'Pin to dock'; + } + if (persist) { try { localStorage.setItem('gently-chat-docked', docked ? '1' : '0'); } catch (_) {} } + // Suppress the slide animation across the mode flip, then notify viewers. + panel.style.transition = 'none'; + requestAnimationFrame(() => { panel.style.transition = ''; emitLayoutChanged(); }); + } + + function togglePin() { + const docked = !document.body.classList.contains('chat-docked'); + if (docked && !panelOpen) togglePanel(true); // pinning implies showing + applyDock(docked, true); + } + + function setupResize() { + if (!resizeEl) return; + let startX = 0, startW = 0, dragging = false, rafId = 0, pid = null; + const onMove = (e) => { + if (!dragging) return; + setChatWidth(startW + (startX - e.clientX), false); // right panel: drag left = wider + if (document.body.classList.contains('chat-docked')) { + if (rafId) cancelAnimationFrame(rafId); + rafId = requestAnimationFrame(emitLayoutChanged); // coalesce dock reflow + } + }; + const onUp = () => { + if (!dragging) return; + dragging = false; + resizeEl.classList.remove('dragging'); + resizeEl.removeEventListener('pointermove', onMove); + resizeEl.removeEventListener('pointerup', onUp); + 
resizeEl.removeEventListener('pointercancel', onUp); + if (pid !== null && resizeEl.hasPointerCapture && resizeEl.hasPointerCapture(pid)) { + try { resizeEl.releasePointerCapture(pid); } catch (_) {} + } + pid = null; + document.body.style.userSelect = ''; + setChatWidth(curChatWidth(), true); + emitLayoutChanged(); + }; + resizeEl.addEventListener('pointerdown', (e) => { + if (e.button !== 0) return; // primary button only + e.preventDefault(); + dragging = true; + startX = e.clientX; + startW = curChatWidth(); + pid = e.pointerId; + // Capture so move/up/cancel always reach the handle (touch/pen-safe). + try { resizeEl.setPointerCapture(pid); } catch (_) {} + resizeEl.classList.add('dragging'); + document.body.style.userSelect = 'none'; + resizeEl.addEventListener('pointermove', onMove); + resizeEl.addEventListener('pointerup', onUp); + resizeEl.addEventListener('pointercancel', onUp); + }); + resizeEl.addEventListener('dblclick', () => { setChatWidth(CHAT_DEFAULT_W, true); emitLayoutChanged(); }); + } + + function restorePrefs() { + try { + const w = parseInt(localStorage.getItem('gently-chat-w')); + if (w) setChatWidth(w, false); + if (localStorage.getItem('gently-chat-docked') === '1') applyDock(true, false); + } catch (_) {} + } + + // Unseen-activity badge on the header toggle — so a closed panel still tells + // the operator the agent did something (woke, proposed an approval, notified). + let badgeCount = 0; + function bumpBadge() { + if (panelOpen) return; // they're watching; no badge needed + badgeCount += 1; + if (toggleBadge) { + toggleBadge.textContent = badgeCount > 9 ? 
'9+' : String(badgeCount); + toggleBadge.classList.remove('hidden'); + } + } + function clearBadge() { + badgeCount = 0; + if (toggleBadge) { toggleBadge.classList.add('hidden'); toggleBadge.textContent = ''; } + } + + // ── Identity ────────────────────────────────────────────── + function fetchMe() { + fetch('/api/auth/me').then(r => r.json()).then(m => { + me = m; + if (m && m.authenticated) { + userEl.textContent = m.username; + userEl.title = `Signed in as ${m.username} (${m.role})`; + signoutBtn.textContent = 'Sign out'; + signoutBtn.dataset.action = 'logout'; + signoutBtn.style.display = ''; + } else if (m && m.accounts) { + // Anonymous — viewing is open; sign in to gain control. + userEl.textContent = 'viewing'; + userEl.title = 'Not signed in — view-only'; + signoutBtn.textContent = 'Sign in'; + signoutBtn.dataset.action = 'login'; + signoutBtn.style.display = ''; + } else { + // No accounts configured (legacy mode). + userEl.textContent = ''; + signoutBtn.style.display = 'none'; + } + renderControl(); + }).catch(() => {}); + } + + // ── Init ────────────────────────────────────────────────── + function init() { + panel = document.getElementById('agent-chat'); + log = document.getElementById('agent-chat-log'); + input = document.getElementById('agent-chat-text'); + sendBtn = document.getElementById('agent-chat-send'); + conn = document.getElementById('agent-chat-conn'); + banner = document.getElementById('agent-control-banner'); + closeBtn = document.getElementById('agent-chat-close'); + userEl = document.getElementById('agent-chat-user'); + signoutBtn = document.getElementById('agent-chat-signout'); + toggleBtn = document.getElementById('agent-chat-toggle'); + pinBtn = document.getElementById('agent-chat-pin'); + resizeEl = document.getElementById('agent-chat-resize'); + toggleDot = document.getElementById('agent-chat-toggle-dot'); + toggleBadge = document.getElementById('agent-chat-toggle-badge'); + if (!panel) return; // markup not present + + 
restorePrefs(); + if (toggleBtn) toggleBtn.addEventListener('click', () => togglePanel()); + closeBtn.addEventListener('click', () => togglePanel(false)); + if (pinBtn) pinBtn.addEventListener('click', togglePin); + setupResize(); + // Ctrl/Cmd+J toggles the panel from anywhere. + document.addEventListener('keydown', (e) => { + if ((e.ctrlKey || e.metaKey) && (e.key === 'j' || e.key === 'J')) { + e.preventDefault(); // suppress browser default (downloads) always + if (e.repeat) return; // ignore held-key auto-repeat + if (document.activeElement === input) return; // don't toggle while composing + togglePanel(); + } + }); + signoutBtn.addEventListener('click', async () => { + if (signoutBtn.dataset.action === 'login') { + window.location.href = '/login'; + return; + } + try { await fetch('/api/auth/logout', { method: 'POST' }); } catch (_) {} + window.location.reload(); + }); + fetchMe(); + + // Build the autocomplete dropdown inside the composer (positioned above + // the textarea via CSS). + const inputWrap = input.parentNode; + if (inputWrap) { + acComplete = document.createElement('div'); + acComplete.className = 'ac-complete hidden'; + inputWrap.insertBefore(acComplete, inputWrap.firstChild); + + // Queued-message panel (above the composer) for type-while-busy. + queuePanel = document.createElement('div'); + queuePanel.className = 'ac-queue hidden'; + if (inputWrap.parentNode) inputWrap.parentNode.insertBefore(queuePanel, inputWrap); + + // Explicit Stop button — shown only during a cancellable user turn. + stopBtn = document.createElement('button'); + stopBtn.className = 'ac-stop hidden'; + stopBtn.textContent = 'Stop'; + stopBtn.title = 'Stop the current turn'; + stopBtn.addEventListener('click', () => { send({ type: 'cancel' }); setBusy(false); }); + inputWrap.appendChild(stopBtn); + + // Sticky ASK-approval slot — above the queue + composer, never scrolls away. 
+ pendingSlot = document.createElement('div'); + pendingSlot.className = 'ac-pending hidden'; + if (inputWrap.parentNode) inputWrap.parentNode.insertBefore(pendingSlot, queuePanel); + } + + // "↓ N new" jump pill + pin-to-bottom scroll tracking. + jumpPill = document.createElement('button'); + jumpPill.className = 'ac-jump hidden'; + jumpPill.addEventListener('click', jumpToBottom); + panel.appendChild(jumpPill); + log.addEventListener('scroll', () => { + stickBottom = nearBottom(); + if (stickBottom) newCount = 0; + renderJumpPill(); + }); + + sendBtn.addEventListener('click', submit); + input.addEventListener('input', () => { autosize(); updateCompletions(); }); + // Close the menu shortly after blur (delay lets a mousedown selection land). + input.addEventListener('blur', () => setTimeout(hideCompletions, 120)); + input.addEventListener('keydown', (e) => { + // While the completion menu is open it owns the navigation keys. + if (acItems.length) { + if (e.key === 'ArrowDown') { e.preventDefault(); moveCompletion(1); return; } + if (e.key === 'ArrowUp') { e.preventDefault(); moveCompletion(-1); return; } + if (e.key === 'Tab') { e.preventDefault(); acceptCompletion(acItems[acIdx]); return; } + if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); acceptCompletion(acItems[acIdx]); return; } + if (e.key === 'Escape') { e.preventDefault(); hideCompletions(); return; } + } + if (e.key === 'Enter' && !e.shiftKey) { e.preventDefault(); submit(); } + // Escape mirrors Stop: cancel a cancellable (user) turn and clear busy + // (a cancelled turn emits no stream_end, so clear optimistically). + if (e.key === 'Escape' && agentBusy && busySource === 'user') { + e.preventDefault(); send({ type: 'cancel' }); setBusy(false); + } + }); + } + + document.addEventListener('DOMContentLoaded', init); + + // Public: programmatically send a message/command (e.g. the Home page's + // "Start / continue an experiment" button sends '/wizard'). 
+ function runCommand(text) { + if (!text) return; + if (!hasControl) { renderControl(); return; } + actuallySend(text); + } + + return { togglePanel, runCommand }; +})(); diff --git a/gently/ui/web/static/js/app.js b/gently/ui/web/static/js/app.js index 203cea04..d1e75a72 100644 --- a/gently/ui/web/static/js/app.js +++ b/gently/ui/web/static/js/app.js @@ -6,7 +6,7 @@ const state = { ws: null, connected: false, - tab: TABS.EMBRYOS, // Default to Embryos tab + tab: TABS.HOME, // Default to the Home landing tab snapshots: [], calibration: [], embryos: [], @@ -71,6 +71,9 @@ function switchTab(tabName) { const content = document.getElementById(`${tabName}-content`); if (content) content.classList.add('active'); + // Lazy-init Home landing tab + if (tabName === TABS.HOME && typeof HomeApp !== 'undefined') HomeApp.init(); + // Render galleries if (tabName === TABS.CALIBRATION) renderCalibrationGallery(); if (tabName === TABS.EVENTS) renderEventsTable(); @@ -631,7 +634,7 @@ document.addEventListener('DOMContentLoaded', () => { const hash = window.location.hash.slice(1); // remove # if (hash) { const [tab, param] = hash.split(':'); - if (tab === TABS.PLANS || tab === TABS.SESSIONS || tab === TABS.EMBRYOS || tab === TABS.CALIBRATION || tab === TABS.EVENTS || tab === TABS.EXPERIMENT) { + if (tab === TABS.HOME || tab === TABS.PLANS || tab === TABS.SESSIONS || tab === TABS.EMBRYOS || tab === TABS.CALIBRATION || tab === TABS.EVENTS || tab === TABS.EXPERIMENT) { switchTab(tab); if (tab === TABS.PLANS && param && typeof openCampaign === 'function') { setTimeout(() => openCampaign(param), 200); diff --git a/gently/ui/web/static/js/devices.js b/gently/ui/web/static/js/devices.js index 0f2f5316..caa1bf48 100644 --- a/gently/ui/web/static/js/devices.js +++ b/gently/ui/web/static/js/devices.js @@ -32,11 +32,14 @@ const DevicesManager = (function () { let _mapWrap; let _scalebarLabel; - // Embryos overlay state: list of {embryo_id, x, y, role, ...}. 
- // Populated by /api/embryos/positions on init + EMBRYO_DETECTED / - // STATUS_CHANGED WS pushes thereafter. Roles drive the marker color + // Embryo waypoints — driven by EMBRYOS_UPDATE events (the canonical bulk + // mutation broadcast added by the embryos-broadcast commit) and the + // initial /api/embryos/current snapshot. Each entry mirrors + // EmbryoState.to_dict() (id, position_coarse, position_fine, + // has_fine_position, nickname, role, ...). Role drives marker color // (mirrors the marking-window legend: magenta=test, cyan=calibration, - // grey=unassigned). + // grey=unassigned). EMBRYO_DETECTED / STATUS_CHANGED listeners stay + // hooked as a belt-and-braces refresh path. let _embryos = []; const _ROLE_COLOR = { test: '#ff66cc', @@ -44,8 +47,14 @@ const DevicesManager = (function () { unassigned: '#888888', }; + // Map-side edit state. _selectedEmbryoId means "picked up": the next + // click on empty map space drops it there (with a confirm), Delete / + // Backspace removes it (with a confirm), Escape clears the selection. + let _selectedEmbryoId = null; + // Bottom-camera panel DOM + state let _camPanel, _camToggle, _camImg, _camPlaceholder, _camLed, _camMeta; + let _camStage, _camCrosshair, _camCrosshairGroup; let _camStreaming = false; let _camLastFrameTs = 0; let _camHasFrame = false; @@ -53,6 +62,32 @@ const DevicesManager = (function () { const _CAM_FPS_WINDOW = 12; let _camFrameTimes = []; + // Camera zoom / pan. Identity transform = (zoom 1, tx 0, ty 0); pan only + // engages once zoom > 1. Reset on double-click and on stream-off. + let _camZoom = 1; + let _camTx = 0; + let _camTy = 0; + let _camPanLast = null; // {x, y} clientX/Y of last pointermove during pan + const _CAM_ZOOM_MIN = 1; + const _CAM_ZOOM_MAX = 8; + const _CAM_ZOOM_STEP = 1.15; // multiplicative per wheel notch + + // Room-light toggle (header). Drives the SwitchBot Bot that switches the + // diSPIM room light. 
State is the bot's cached on/off; hidden until the + // device layer reports the accessory is configured. + let _roomLightToggle, _roomLightLabel; + let _roomLightState = 'unknown'; + let _roomLightAvailable = false; + let _roomLightBusy = false; + let _roomLightTimer = null; + + // Temperature-controller panel DOM + state + let _tempEl, _tempReadout, _tempInput, _tempSet; + let _tempState = 'unknown'; + let _tempAvailable = false; + let _tempBusy = false; + let _tempTimer = null; + let _lastTs = 0; let _previousTs = 0; let _lastWallTs = 0; @@ -110,9 +145,20 @@ const DevicesManager = (function () { _camToggle = document.getElementById('devices-camera-toggle'); _camImg = document.getElementById('devices-camera-img'); _camPlaceholder = document.getElementById('devices-camera-placeholder'); + _camStage = _camPanel ? _camPanel.querySelector('.devices-camera-stage') : null; + _camCrosshair = document.getElementById('devices-camera-crosshair'); + _camCrosshairGroup = document.getElementById('devices-camera-crosshair-group'); _camLed = document.getElementById('devices-camera-led'); _camMeta = document.getElementById('devices-camera-meta'); + _roomLightToggle = document.getElementById('devices-room-light-toggle'); + _roomLightLabel = document.getElementById('devices-room-light-label'); + + _tempEl = document.getElementById('devices-temp'); + _tempReadout = document.getElementById('devices-temp-readout'); + _tempInput = document.getElementById('devices-temp-input'); + _tempSet = document.getElementById('devices-temp-set'); + // Recompute the scale bar caption whenever the canvas resizes. if (_mapSvg && window.ResizeObserver) { new ResizeObserver(() => updateScalebar()).observe(_mapSvg); @@ -257,6 +303,30 @@ const DevicesManager = (function () { } } + // Initial embryo snapshot — closes the gap for clients that connect + // mid-session, after the last EMBRYOS_UPDATE has already been broadcast + // and aged out of history. Subsequent updates arrive over the event bus. 
+ async function loadEmbryosSnapshot() { + try { + const res = await fetch('/api/embryos/current'); + if (!res.ok) return; + const data = await res.json(); + handleEmbryosUpdate(data); + } catch (err) { + console.debug('embryos snapshot fetch failed:', err); + } + } + + function handleEmbryosUpdate(payload) { + _embryos = (payload && Array.isArray(payload.embryos)) ? payload.embryos : []; + if (!_viewBox) { + computeViewBox(); + renderMap(); + } else { + renderEmbryos(); + } + } + // ===================================================================== // Properties table (Details view) // ===================================================================== @@ -744,6 +814,215 @@ const DevicesManager = (function () { return Math.round(v).toString(); } + // ===================================================================== + // Embryo waypoints + // ===================================================================== + + // "embryo_007" / "embryo_7" -> 7. Falls back to a 1-based index from the + // caller so the label always shows *something*, even for stray ids. + function embryoLabelText(id, fallbackIndex) { + const m = id && String(id).match(/(\d+)/); + if (m) { + const n = parseInt(m[1], 10); + if (Number.isFinite(n)) return String(n); + } + return String(fallbackIndex + 1); + } + + // Resolve XY for rendering — fine if SPIM-aligned, else coarse. Returns + // null when neither stage carries usable values so the entry is skipped + // (e.g. an embryo whose detection record came in malformed). 
+ function embryoResolvedXY(emb) { + const f = emb && emb.position_fine; + if (f && Number.isFinite(f.x) && Number.isFinite(f.y)) return { x: f.x, y: f.y }; + const c = emb && emb.position_coarse; + if (c && Number.isFinite(c.x) && Number.isFinite(c.y)) return { x: c.x, y: c.y }; + return null; + } + + function renderEmbryos() { + if (!_mapEmbryos || !_viewBox) return; + _mapEmbryos.innerHTML = ''; + if (!_embryos || !_embryos.length) return; + const span = Math.max(_viewBox.xMax - _viewBox.xMin, + _viewBox.yMax - _viewBox.yMin); + const radius = span * 0.012; + const fontSize = span * 0.015; + + _embryos.forEach((emb, i) => { + const xy = embryoResolvedXY(emb); + if (!xy) return; + + const isFine = !!emb.has_fine_position; + const isSelected = _selectedEmbryoId !== null + && emb.id === _selectedEmbryoId; + + // Wrap circle + label in a group so a single closest() lookup + // finds the embryo regardless of which child the click hit. + const group = document.createElementNS(SVG_NS, 'g'); + group.setAttribute('class', + 'devices-embryo-group' + (isSelected ? ' devices-embryo-selected' : '')); + group.setAttribute('data-embryo-id', emb.id || ''); + group.setAttribute('data-embryo-stage', isFine ? 'fine' : 'coarse'); + + const circle = document.createElementNS(SVG_NS, 'circle'); + circle.setAttribute('cx', xy.x); + circle.setAttribute('cy', svgY(xy.y)); + circle.setAttribute('r', radius); + circle.setAttribute('class', + isFine ? 
'devices-embryo-disc' : 'devices-embryo-ring'); + group.appendChild(circle); + + const label = document.createElementNS(SVG_NS, 'text'); + label.setAttribute('x', xy.x); + label.setAttribute('y', svgY(xy.y)); + label.setAttribute('class', 'devices-embryo-label'); + label.setAttribute('font-size', fontSize); + label.textContent = embryoLabelText(emb.id, i); + group.appendChild(label); + + _mapEmbryos.appendChild(group); + }); + } + + // ---- Map-side edit interactions ------------------------------------ + // Convert a pointer event's client coords into stage µm. SVG y axis is + // positive-down and stage y is positive-up, so the y component is + // negated to match the convention used elsewhere in this module. + function eventToStageXY(event) { + if (!_mapSvg || !_mapSvg.getScreenCTM) return null; + const ctm = _mapSvg.getScreenCTM(); + if (!ctm) return null; + const pt = _mapSvg.createSVGPoint(); + pt.x = event.clientX; + pt.y = event.clientY; + const local = pt.matrixTransform(ctm.inverse()); + return { x: local.x, y: -local.y }; + } + + function findEmbryoIdAt(target) { + if (!target) return null; + const node = target.closest && target.closest('[data-embryo-id]'); + return node ? node.getAttribute('data-embryo-id') : null; + } + + function embryoById(id) { + return _embryos.find(e => e.id === id) || null; + } + + function embryoNumberFor(emb) { + return embryoLabelText(emb.id, _embryos.indexOf(emb)); + } + + function setSelectedEmbryo(id) { + if (_selectedEmbryoId === id) return; + _selectedEmbryoId = id; + renderEmbryos(); + } + + function clearSelection() { + if (_selectedEmbryoId === null) return; + _selectedEmbryoId = null; + renderEmbryos(); + } + + async function attemptMoveSelected(targetStage) { + const id = _selectedEmbryoId; + if (!id) return; + const emb = embryoById(id); + if (!emb) { clearSelection(); return; } + const cur = embryoResolvedXY(emb); + const num = embryoNumberFor(emb); + const oldStr = cur ? 
`(${cur.x.toFixed(1)}, ${cur.y.toFixed(1)})` : '(unknown)'; + const newStr = `(${targetStage.x.toFixed(1)}, ${targetStage.y.toFixed(1)})`; + if (!window.confirm(`Move embryo ${num} from ${oldStr} to ${newStr}?`)) { + return; // keep the embryo picked up so they can try again + } + try { + const res = await fetch(`/api/embryos/${encodeURIComponent(id)}/position`, { + method: 'PUT', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ x: targetStage.x, y: targetStage.y }), + }); + if (!res.ok) { + window.alert(`Move failed (${res.status}): ${await res.text()}`); + return; + } + // EMBRYOS_UPDATE will arrive over the bus and refresh the layer; + // dropping clears the picked-up state regardless. + clearSelection(); + } catch (err) { + console.error('move embryo:', err); + window.alert(`Move failed: ${err.message}`); + } + } + + async function attemptDeleteSelected() { + const id = _selectedEmbryoId; + if (!id) return; + const emb = embryoById(id); + const num = emb ? embryoNumberFor(emb) : id; + if (!window.confirm(`Remove embryo ${num}?`)) return; + try { + const res = await fetch(`/api/embryos/${encodeURIComponent(id)}`, { + method: 'DELETE', + }); + if (!res.ok) { + window.alert(`Delete failed (${res.status}): ${await res.text()}`); + return; + } + // The embryo is gone from the server snapshot; EMBRYOS_UPDATE + // will arrive and drop it from _embryos. Clear locally too. + _selectedEmbryoId = null; + } catch (err) { + console.error('delete embryo:', err); + window.alert(`Delete failed: ${err.message}`); + } + } + + function onMapPointerDown(event) { + // Ignore non-primary buttons so right-clicks etc. don't trigger UI. + if (event.button !== undefined && event.button !== 0) return; + const id = findEmbryoIdAt(event.target); + if (id) { + setSelectedEmbryo(id); + return; + } + // Empty-space click: drop the picked-up embryo here. 
+ if (_selectedEmbryoId !== null) { + const stage = eventToStageXY(event); + if (stage) attemptMoveSelected(stage); + } + } + + function onMapKeyDown(event) { + // Only honour keys when the operator is actually looking at the Map: + // not on another top-level tab, not on the Details subview, and not + // typing into an input / textarea / select / contenteditable. + if (typeof state !== 'undefined' && typeof TABS !== 'undefined' + && state.tab !== TABS.DEVICES) { + return; + } + if (_currentView !== 'map') return; + const a = document.activeElement; + if (a && (a.tagName === 'INPUT' || a.tagName === 'TEXTAREA' || + a.tagName === 'SELECT' || a.isContentEditable)) { + return; + } + if (event.key === 'Escape') { + if (_selectedEmbryoId !== null) { + clearSelection(); + event.preventDefault(); + } + return; + } + if (_selectedEmbryoId === null) return; + if (event.key === 'Delete' || event.key === 'Backspace') { + event.preventDefault(); // Backspace would otherwise navigate back + attemptDeleteSelected(); + } + } + function updateMapMarker() { if (!_mapMarker || !_lastXY) return; const sx = _lastXY.X; @@ -820,6 +1099,9 @@ const DevicesManager = (function () { if (_camPlaceholder) _camPlaceholder.hidden = false; if (_camMeta) _camMeta.textContent = 'stream off'; if (_camStaleTimer) { clearTimeout(_camStaleTimer); _camStaleTimer = null; } + // Operator may have zoomed in; reset so the next stream session + // starts at 1× rather than inheriting a stale view. 
+ resetCameraZoom(); } else { _camFrameTimes = []; if (_camMeta) _camMeta.textContent = 'waiting…'; @@ -881,6 +1163,116 @@ const DevicesManager = (function () { } } + // ---- Camera zoom / pan --------------------------------------------- + function applyCameraTransform() { + if (!_camImg) return; + _camImg.style.transform = + `translate(${_camTx}px, ${_camTy}px) scale(${_camZoom})`; + // Reticle uses an SVG transform attribute on the inner <g> instead + // of a CSS transform on the SVG element — same geometric effect, + // but the SVG renderer re-rasterises at the new zoom so the 1px + // strokes stay crisp instead of getting bitmap-scaled. + if (_camCrosshairGroup && _camStage) { + const rect = _camStage.getBoundingClientRect(); + // Convert pixel-space translation to viewBox units (viewBox is + // 0..100 in both axes, preserveAspectRatio=none). + const txV = rect.width > 0 ? (_camTx * 100) / rect.width : 0; + const tyV = rect.height > 0 ? (_camTy * 100) / rect.height : 0; + // translate(50+tx, 50+ty) scale(zoom) translate(-50, -50) keeps + // the viewBox centre (50, 50) as the zoom anchor and offsets by + // the converted pixel translation. + _camCrosshairGroup.setAttribute( + 'transform', + `translate(${50 + txV} ${50 + tyV}) ` + + `scale(${_camZoom}) ` + + `translate(-50 -50)` + ); + } + } + + function resetCameraZoom() { + _camZoom = 1; + _camTx = 0; + _camTy = 0; + applyCameraTransform(); + if (_camStage) _camStage.classList.remove('camera-zoomed', 'camera-panning'); + } + + // Clamp the pan so the zoomed image always covers the stage and the + // operator can't accidentally pan blank space into view. At + // zoom 1 this collapses to (0, 0).
+ function clampCameraPan() { + if (!_camStage) return; + const rect = _camStage.getBoundingClientRect(); + const maxX = (rect.width * (_camZoom - 1)) / 2; + const maxY = (rect.height * (_camZoom - 1)) / 2; + _camTx = Math.max(-maxX, Math.min(maxX, _camTx)); + _camTy = Math.max(-maxY, Math.min(maxY, _camTy)); + } + + function onCameraWheel(event) { + if (!_camStage) return; + // Always preventDefault so the page doesn't scroll under the + // operator while they're framing a sample. + event.preventDefault(); + const rect = _camStage.getBoundingClientRect(); + const cx = event.clientX - rect.left - rect.width / 2; + const cy = event.clientY - rect.top - rect.height / 2; + const oldZoom = _camZoom; + const factor = event.deltaY < 0 ? _CAM_ZOOM_STEP : 1 / _CAM_ZOOM_STEP; + const newZoom = Math.max(_CAM_ZOOM_MIN, + Math.min(_CAM_ZOOM_MAX, oldZoom * factor)); + if (newZoom === oldZoom) return; + + // Keep the image point under the cursor anchored under the cursor + // across the zoom: cursor_new = cursor_old after the transform + // change, which means newT = cursor - (cursor - oldT) * (new/old). 
+ const ratio = newZoom / oldZoom; + _camTx = cx - (cx - _camTx) * ratio; + _camTy = cy - (cy - _camTy) * ratio; + _camZoom = newZoom; + + if (Math.abs(_camZoom - 1) < 0.001) { + resetCameraZoom(); + return; + } + clampCameraPan(); + applyCameraTransform(); + _camStage.classList.add('camera-zoomed'); + } + + function onCameraPointerDown(event) { + if (event.button !== 0) return; + if (_camZoom <= 1) return; + _camPanLast = { x: event.clientX, y: event.clientY }; + try { _camStage.setPointerCapture(event.pointerId); } catch (_) {} + _camStage.classList.add('camera-panning'); + event.preventDefault(); + } + + function onCameraPointerMove(event) { + if (!_camPanLast) return; + _camTx += event.clientX - _camPanLast.x; + _camTy += event.clientY - _camPanLast.y; + _camPanLast = { x: event.clientX, y: event.clientY }; + clampCameraPan(); + applyCameraTransform(); + } + + function onCameraPointerEnd(event) { + if (!_camPanLast) return; + _camPanLast = null; + try { _camStage.releasePointerCapture(event.pointerId); } catch (_) {} + if (_camStage) _camStage.classList.remove('camera-panning'); + } + + function onCameraDoubleClick(event) { + if (_camZoom !== 1 || _camTx !== 0 || _camTy !== 0) { + event.preventDefault(); + resetCameraZoom(); + } + } + function setupCameraWiring() { if (!_camToggle) return; _camToggle.addEventListener('click', toggleCameraStream); @@ -888,6 +1280,210 @@ const DevicesManager = (function () { if (typeof ClientEventBus !== 'undefined') { ClientEventBus.on('BOTTOM_CAMERA_FRAME', handleCameraFrame); } + // Camera zoom/pan. wheel needs passive:false so we can preventDefault + // and stop the page from scrolling beneath the FOV. 
+ if (_camStage) { + _camStage.addEventListener('wheel', onCameraWheel, { passive: false }); + _camStage.addEventListener('pointerdown', onCameraPointerDown); + _camStage.addEventListener('pointermove', onCameraPointerMove); + _camStage.addEventListener('pointerup', onCameraPointerEnd); + _camStage.addEventListener('pointercancel', onCameraPointerEnd); + _camStage.addEventListener('dblclick', onCameraDoubleClick); + } + } + + // ===================================================================== + // Room-light toggle + // ===================================================================== + + function applyRoomLight(state, available) { + _roomLightState = state || 'unknown'; + _roomLightAvailable = !!available; + if (!_roomLightToggle) return; + _roomLightToggle.hidden = !_roomLightAvailable; + _roomLightToggle.disabled = !_roomLightAvailable || _roomLightBusy; + const on = _roomLightState === 'on'; + _roomLightToggle.classList.toggle('is-on', on); + _roomLightToggle.setAttribute('aria-pressed', on ? 'true' : 'false'); + if (_roomLightLabel && !_roomLightBusy) { + _roomLightLabel.textContent = on ? 'Room light: on' + : (_roomLightState === 'off' ? 'Room light: off' : 'Room light'); + } + } + + async function loadRoomLightStatus() { + if (!_roomLightToggle || _roomLightBusy) return; + try { + const res = await fetch('/api/devices/room_light/status'); + if (!res.ok) { applyRoomLight('unknown', false); return; } + const data = await res.json(); + applyRoomLight(data.state, data.available); + } catch (err) { + console.debug('room light status fetch failed:', err); + applyRoomLight('unknown', false); + } + } + + async function toggleRoomLight() { + if (!_roomLightToggle || _roomLightBusy || !_roomLightAvailable) return; + const next = _roomLightState === 'on' ? 'off' : 'on'; + _roomLightBusy = true; + _roomLightToggle.classList.add('is-busy'); + _roomLightToggle.disabled = true; + if (_roomLightLabel) { + _roomLightLabel.textContent = next === 'on' ? 
'Turning on…' : 'Turning off…'; + } + + // Settle back to the resolved state, or surface a transient message + // (insufficient control / error) for 2 s before reverting. + const finish = (msg) => { + _roomLightBusy = false; + _roomLightToggle.classList.remove('is-busy'); + if (msg) { + if (_roomLightLabel) _roomLightLabel.textContent = msg; + _roomLightToggle.disabled = false; + setTimeout(() => applyRoomLight(_roomLightState, _roomLightAvailable), 2000); + } else { + applyRoomLight(_roomLightState, _roomLightAvailable); + } + }; + + try { + const res = await fetch('/api/devices/room_light/set', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ state: next }), + }); + if (res.status === 401 || res.status === 403) { finish('Need control'); return; } + if (!res.ok) { + console.error('room light set failed:', await res.text()); + finish('Error'); + return; + } + const data = await res.json(); + _roomLightState = data.state || next; + finish(null); + } catch (err) { + console.error('room light toggle failed:', err); + finish('Error'); + } + } + + function setupRoomLight() { + if (!_roomLightToggle) return; + _roomLightToggle.addEventListener('click', toggleRoomLight); + loadRoomLightStatus(); + // Light periodic refresh: state can also change from agent plans + // (e.g. brightfield imaging turns it on). Status read is cached at the + // device layer (no BLE), so polling is cheap; it also makes the toggle + // appear automatically once the device layer connects. + if (_roomLightTimer) clearInterval(_roomLightTimer); + _roomLightTimer = setInterval(loadRoomLightStatus, 15000); + } + + // ===================================================================== + // Temperature controller (ACUITYnano) — readout + setpoint + // ===================================================================== + + function fmtTemp(v) { + return (v === null || v === undefined || isNaN(v)) ? 
'—' : Number(v).toFixed(1) + '°'; + } + + function applyTemperature(data) { + _tempAvailable = !!(data && data.available); + if (!_tempEl) return; + _tempEl.hidden = !_tempAvailable; + if (!_tempAvailable) return; + _tempState = (data && data.state) || 'unknown'; + const locked = /LOCK/i.test(_tempState); + _tempEl.classList.toggle('is-locked', locked); + if (_tempBusy) return; // a set() is in flight; leave its transient label + const cur = fmtTemp(data.temperature_c); + const hasSp = data.setpoint_c !== null && data.setpoint_c !== undefined; + const sp = hasSp ? fmtTemp(data.setpoint_c) : null; + _tempReadout.textContent = sp ? (cur + ' → ' + sp) : cur; + _tempReadout.title = 'Water ' + cur + (sp ? (', setpoint ' + sp) : '') + + (locked ? ' (locked)' : ''); + // Seed the input with the current setpoint once, while untouched, so the + // operator sees where it is before nudging it. + if (_tempInput && document.activeElement !== _tempInput && _tempInput.value === '' && hasSp) { + _tempInput.value = Number(data.setpoint_c).toFixed(1); + } + } + + async function loadTemperatureStatus() { + if (!_tempEl || _tempBusy) return; + try { + const res = await fetch('/api/devices/temperature/status'); + if (!res.ok) { applyTemperature({ available: false }); return; } + applyTemperature(await res.json()); + } catch (err) { + console.debug('temperature status fetch failed:', err); + applyTemperature({ available: false }); + } + } + + async function setTemperature() { + if (!_tempEl || _tempBusy || !_tempAvailable) return; + const target = parseFloat(_tempInput && _tempInput.value); + if (isNaN(target) || target < 0 || target > 99.9) { + _tempReadout.textContent = '0–99.9 °C'; + setTimeout(loadTemperatureStatus, 1500); + return; + } + _tempBusy = true; + _tempEl.classList.add('is-busy'); + if (_tempSet) _tempSet.disabled = true; + _tempReadout.textContent = 'Set ' + target.toFixed(1) + '°…'; + + // Settle back to the resolved state, or surface a transient message + // 
(insufficient control / error) for 2 s before reverting. + const finish = (msg) => { + _tempBusy = false; + _tempEl.classList.remove('is-busy'); + if (_tempSet) _tempSet.disabled = false; + if (msg) { + _tempReadout.textContent = msg; + setTimeout(loadTemperatureStatus, 2000); + } else { + loadTemperatureStatus(); // controller ramps; poll shows progress + } + }; + + try { + const res = await fetch('/api/devices/temperature/set', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ target_c: target }), + }); + if (res.status === 401 || res.status === 403) { finish('Need control'); return; } + if (!res.ok) { + console.error('temperature set failed:', await res.text()); + finish('Error'); + return; + } + await res.json(); + finish(null); + } catch (err) { + console.error('temperature set failed:', err); + finish('Error'); + } + } + + function setupTemperature() { + if (!_tempEl) return; + if (_tempSet) _tempSet.addEventListener('click', setTemperature); + if (_tempInput) { + _tempInput.addEventListener('keydown', (e) => { + if (e.key === 'Enter') { e.preventDefault(); setTemperature(); } + }); + } + loadTemperatureStatus(); + // Periodic refresh: the setpoint can also change from agent plans, and a + // commanded ramp settles over time. Status is cached at the device layer, + // so polling is cheap; it also reveals the control once the layer connects. 
+ if (_tempTimer) clearInterval(_tempTimer); + _tempTimer = setInterval(loadTemperatureStatus, 15000); } // ===================================================================== @@ -949,18 +1545,27 @@ const DevicesManager = (function () { cacheDom(); setupViewSwitcher(); setupCameraWiring(); + setupRoomLight(); + setupTemperature(); loadCoverslip(); - loadEmbryos(); + loadEmbryosSnapshot(); switchView(_currentView); if (typeof ClientEventBus !== 'undefined') { ClientEventBus.on('DEVICE_STATE_UPDATE', handlePayload); - // Embryo events: a fresh marking session emits one - // EMBRYO_DETECTED per registered embryo (via - // ExperimentState.add_embryo). assign_embryo_roles emits - // STATUS_CHANGED with change=role_assigned per change. + ClientEventBus.on('EMBRYOS_UPDATE', handleEmbryosUpdate); + // Belt-and-braces: also listen for the fine-grained events that + // existed before EMBRYOS_UPDATE so direct emitters still refresh. ClientEventBus.on('EMBRYO_DETECTED', handleEmbryoDetected); ClientEventBus.on('STATUS_CHANGED', handleStatusChanged); } + // Map-side edit handlers. Pointer events on the SVG cover both + // "click an embryo" (selects it) and "click empty map" (drops the + // selected embryo). Keyboard listener is document-wide but guards + // against firing while an input is focused. 
+ if (_mapSvg) { + _mapSvg.addEventListener('pointerdown', onMapPointerDown); + } + document.addEventListener('keydown', onMapKeyDown); setStatus('stale', 'waiting', 'no payload yet'); syncInitialCameraState(); // Stop the camera stream if the tab is closed while it's running, diff --git a/gently/ui/web/static/js/embryos.js b/gently/ui/web/static/js/embryos.js index f8b7cea3..29340cf5 100644 --- a/gently/ui/web/static/js/embryos.js +++ b/gently/ui/web/static/js/embryos.js @@ -59,7 +59,7 @@ const EmbryosManager = { dashboardConfig: { defaultView: 'default', board: { - columns: ['stage', 'confidence', 'rate', 'eta', 'sparkline', 'alert'], + columns: ['stage', 'clock', 'stereo', 'pace', 'eta', 'sparkline', 'alert'], sparklineLength: 20, warnOvertimeRatio: 1.5, criticalOvertimeRatio: 2.5 @@ -266,6 +266,23 @@ const EmbryosManager = { // Deep merge with defaults this.dashboardConfig = this._deepMerge(this.dashboardConfig, parsed); } + // Migrate legacy board columns: drop the never-populated + // 'confidence' column and the misleading 'rate' column in + // favour of clock/stereo/pace. Idempotent — runs on every load. + const cols = this.dashboardConfig.board?.columns; + if (Array.isArray(cols)) { + const filtered = cols.filter(c => c !== 'confidence' && c !== 'rate'); + const ensure = (key, after) => { + if (filtered.includes(key)) return; + const idx = filtered.indexOf(after); + if (idx === -1) filtered.push(key); + else filtered.splice(idx + 1, 0, key); + }; + ensure('clock', 'stage'); + ensure('stereo', 'clock'); + ensure('pace', 'stereo'); + this.dashboardConfig.board.columns = filtered; + } } catch (e) { console.warn('Failed to load dashboard config:', e); } @@ -370,9 +387,10 @@ const EmbryosManager = {
Embryo ${cols.includes('stage') ? 'Stage' : ''} - ${cols.includes('confidence') ? 'Conf' : ''} - ${cols.includes('rate') ? 'Rate' : ''} - ${cols.includes('eta') ? 'ETA' : ''} + ${cols.includes('clock') ? 'Clock' : ''} + ${cols.includes('stereo') ? 'Stereo' : ''} + ${cols.includes('pace') ? 'Pace' : ''} + ${cols.includes('eta') ? 'ETA' : ''} ${cols.includes('sparkline') ? 'Progression' : ''} ${cols.includes('alert') ? 'Alert' : ''}
@@ -412,54 +430,27 @@ const EmbryosManager = { const latest = reasoning.length > 0 ? reasoning[reasoning.length - 1] : null; const cols = this.dashboardConfig.board.columns; - // Stage const stage = latest?.stage || embryo.current_stage || '—'; const stageIcon = this.getStageIcon(stage); const stageName = this.formatStageName(stage); - // Confidence - const conf = latest ? this.normalizeConfidence(latest.confidence) : 'unknown'; - const confDots = conf === 'high' ? '●●●' : conf === 'medium' ? '●●○' : conf === 'low' ? '●○○' : '○○○'; - const confClass = conf === 'high' ? 'conf-high' : conf === 'medium' ? 'conf-med' : 'conf-low'; + const align = this._computeAlignment(latest); + const overtime = align?.overtime; - // Rate - const overtime = latest?.temporal_analysis?.overtime_ratio; - let rateText = '—'; - let rateClass = ''; - if (overtime != null) { - const rate = (1 / overtime).toFixed(1); - rateText = overtime < 0.9 ? `${rate}x↑` : overtime > 1.1 ? `${rate}x↓` : `${rate}x→`; - rateClass = overtime < 0.9 ? 'rate-fast' : overtime > 1.5 ? 'rate-slow' : 'rate-normal'; - } + const clockText = align ? this._formatMinutes(align.inStageClockMin) : '—'; + const stereoText = align ? this._formatStereoLabel(align) : '—'; + const pace = align ? this._formatPace(align) : { text: '—', className: '' }; + const eta = align ? this._formatEta(align) : '—'; - // ETA - let eta = '—'; - if (stage && this.STAGE_TIMING[stage] != null) { - const stageMinutes = this.STAGE_TIMING[stage]; - const hatchMinutes = this.STAGE_TIMING['hatched'] || 570; - const remaining = hatchMinutes - stageMinutes; - if (remaining > 0) { - const hours = (remaining / 60).toFixed(1); - eta = `~${hours}h`; - } else { - eta = 'done'; - } - } - - // Sparkline const sparklineSvg = cols.includes('sparkline') ? 
this._renderBoardSparkline(reasoning) : ''; - // Alert const arrested = latest?.temporal_analysis?.is_potentially_arrested; const slow = overtime && overtime > (this.dashboardConfig.board.warnOvertimeRatio || 1.5); - const lowConf = conf === 'low'; let alertHtml = ''; if (arrested) { alertHtml = '⚠ arrested'; } else if (slow) { - alertHtml = `⚠ slow ${overtime.toFixed(1)}x`; - } else if (lowConf) { - alertHtml = '⚠ low conf'; + alertHtml = `⚠ slow ${overtime.toFixed(1)}×`; } const status = embryo.isComplete ? 'complete' : embryo.lastError ? 'error' : 'running'; @@ -472,8 +463,9 @@ const EmbryosManager = { ${embryo.embryoId.replace(/embryo_?/i, 'E')} ${cols.includes('stage') ? `${stageIcon} ${stageName}` : ''} - ${cols.includes('confidence') ? `${confDots}` : ''} - ${cols.includes('rate') ? `${rateText}` : ''} + ${cols.includes('clock') ? `${clockText}` : ''} + ${cols.includes('stereo') ? `${stereoText}` : ''} + ${cols.includes('pace') ? `${pace.text}` : ''} ${cols.includes('eta') ? `${eta}` : ''} ${cols.includes('sparkline') ? `${sparklineSvg}` : ''} ${cols.includes('alert') ? `${alertHtml}` : ''} @@ -481,6 +473,99 @@ const EmbryosManager = { `; }, + /** Compute clock↔stereotypic alignment from perception temporal_analysis. + * + * Definitions: + * inStageClockMin — wall-clock minutes elapsed in current stage + * inStageStereoMin — stereotypic minutes "used" within the stage, + * capped at the stage's expected duration. An + * overdue embryo is stuck at the stage end in + * stereo time while clock keeps ticking. + * overtime — ratio inStageClockMin / expected_duration. + * >1 means the embryo has spent more clock time + * in the stage than the reference 20°C textbook + * duration. <1 just means "still within stage" — + * no slow/fast signal yet. + * stereoAgeMin — total stereotypic age, anchored at the start + * minute of the current stage in the reference + * table plus the (capped) in-stage stereo offset. 
+ */ + _computeAlignment(latest) { + const ta = latest?.temporal_analysis; + if (!ta || !ta.current_stage) return null; + const stage = ta.current_stage; + const stageStart = this.STAGE_TIMING[stage]; + if (stageStart == null) return null; + + const expDur = Number(ta.expected_duration_min) || 0; + const inClock = Number(ta.time_in_stage_min) || 0; + const overtime = Number(ta.overtime_ratio) || 0; + + const inStereo = expDur > 0 ? Math.min(inClock, expDur) : inClock; + const stereoAge = stageStart + inStereo; + + return { + stage, + stageStart, + expDur, + inStageClockMin: inClock, + inStageStereoMin: inStereo, + stereoAgeMin: stereoAge, + overtime, + }; + }, + + /** Render the stereo cell: "≈early", "≈bean +12m", or "≈comma +88m ⚠" + * when overdue (stereo capped at stage end while clock keeps running). */ + _formatStereoLabel(align) { + const stageName = this.formatStageName(align.stage); + const offsetMin = Math.round(align.inStageStereoMin); + const overdue = align.expDur > 0 && align.inStageClockMin > align.expDur + 1; + const offsetStr = offsetMin > 0 ? ` +${offsetMin}m` : ''; + const overdueMark = overdue ? ' ' : ''; + return `≈${stageName}${offsetStr}${overdueMark}`; + }, + + _formatPace(align) { + // Only emit a pace signal once we have meaningful clock data. + // Within the first few minutes the ratio is tiny and noisy — show + // a dashed placeholder so the column doesn't lie about precision. 
+ const NORMAL_BAND = 1.05; + const SLOW_BAND = 1.5; + if (align.inStageClockMin < 1 || align.expDur <= 0) { + return { text: '—', className: 'pace-unknown' }; + } + const r = align.overtime; + if (r <= NORMAL_BAND) { + return { text: '1.0×', className: 'pace-normal' }; + } + if (r <= SLOW_BAND) { + return { text: `${r.toFixed(1)}× slow`, className: 'pace-slow' }; + } + return { text: `⚠ ${r.toFixed(1)}×`, className: 'pace-slow-bad' }; + }, + + /** ETA in hours from current stereotypic position to hatched, scaled + * by observed pace when the embryo is demonstrably slow. */ + _formatEta(align) { + const hatchStereo = this.STAGE_TIMING['hatched'] || 570; + const remainStereo = hatchStereo - align.stereoAgeMin; + if (remainStereo <= 0) return 'done'; + const paceFactor = align.overtime > 1.05 ? align.overtime : 1.0; + const remainClockMin = remainStereo * paceFactor; + return `~${(remainClockMin / 60).toFixed(1)}h`; + }, + + /** Compact minute formatter: "45s" / "10m" / "1h 22m" / "3h". */ + _formatMinutes(min) { + if (min == null || !isFinite(min)) return '—'; + if (min < 1) return `${Math.round(min * 60)}s`; + if (min < 60) return `${Math.round(min)}m`; + const h = Math.floor(min / 60); + const m = Math.round(min - h * 60); + return m > 0 ? `${h}h ${m}m` : `${h}h`; + }, + _renderBoardSparkline(reasoning) { if (!reasoning.length) return ''; const sorted = [...reasoning].sort((a, b) => (a.timepoint ?? 0) - (b.timepoint ?? 0)); @@ -565,12 +650,24 @@ const EmbryosManager = { const shortName = embryo.embryoId.replace(/embryo_?/i, 'E'); const latestStage = sorted.length > 0 ? this.formatStageName(sorted[sorted.length - 1].stage) : '—'; + const isTerminated = !!embryo.isComplete; + const termReason = embryo.completionReason || ''; + // Short label for the badge — humanise the no_object terminal + // reason, otherwise keep the first clause of whatever the + // backend sent so the user still gets a hint. + const termBadge = isTerminated + ? 
(termReason.includes('no_object') ? 'HATCHED?' : 'STOPPED') + : ''; + const termTooltip = isTerminated + ? `Terminated — ${termReason || 'no reason given'}` + : ''; - html += `
`; + html += `
`; html += `
${shortName} ${latestStage} ${reasoning.length} eval + ${isTerminated ? `${termBadge}` : ''}
`; html += `
`; @@ -1000,7 +1097,7 @@ const EmbryosManager = { intervalSeconds: embryoData.interval_seconds || this.state.baseInterval, timepoints: embryoData.timepoints || 0, isComplete: embryoData.is_complete || false, - completionReason: null, + completionReason: embryoData.completion_reason || null, firstAcquired: embryoData.first_acquired ? new Date(embryoData.first_acquired) : null, lastAcquired: embryoData.last_acquired ? new Date(embryoData.last_acquired) : null, detections: embryoData.detections || {}, @@ -2726,10 +2823,17 @@ const EmbryosManager = { `; } - // Format confidence display - const confDisplay = typeof item.confidence === 'number' - ? `${Math.round(item.confidence * 100)}%` - : (item.confidence || 'Unknown'); + // Format confidence display. Hide entirely when the detector + // doesn't emit a probabilistic confidence (e.g. dopaminergic_signal + // returns structured intensity/structure findings instead) — the + // string "Unknown confidence" was actively confusing. + const hasNumericConf = typeof item.confidence === 'number'; + const hasTextConf = typeof item.confidence === 'string' && item.confidence.trim() !== ''; + const confHtml = hasNumericConf + ? `${Math.round(item.confidence * 100)}% confidence` + : hasTextConf + ? `${item.confidence}` + : ''; return `
@@ -2748,7 +2852,7 @@ const EmbryosManager = {
${item.stage ? this.formatStageName(item.stage) : (item.detected ? 'DETECTED' : 'Not detected')} - ${confDisplay} confidence + ${confHtml} ${transitionalHtml}
${detectorFindingsHtml} @@ -2979,13 +3083,22 @@ const EmbryosManager = { container.classList.remove('visible'); container.innerHTML = ''; } + // Filmstrip side panel — clearing innerHTML lets the :empty CSS + // rule collapse the panel and let the rows reclaim full width. + const filmstripDetail = document.getElementById('filmstrip-detail'); + if (filmstripDetail) { + filmstripDetail.innerHTML = ''; + } this.detailPanelVisible = false; this.currentDetailItem = null; - // Clear eval dot highlight + // Clear eval dot + filmstrip cell highlight document.querySelectorAll('.eval-dot.active').forEach(dot => { dot.classList.remove('active'); }); + document.querySelectorAll('.filmstrip-cell.active').forEach(cell => { + cell.classList.remove('active'); + }); }, // Navigate to previous/next item in detail panel diff --git a/gently/ui/web/static/js/events.js b/gently/ui/web/static/js/events.js index 02365e40..998431a1 100644 --- a/gently/ui/web/static/js/events.js +++ b/gently/ui/web/static/js/events.js @@ -41,6 +41,40 @@ function getEventBadgeClass(eventType) { return 'default'; } +// Log-record helpers -------------------------------------------------- +// LOG_RECORD events come from the Python logging bridge. We collapse the +// generic "LOG_RECORD" type into the actual level (DEBUG / INFO / WARN / +// ERROR) so the table is readable -- otherwise every row in a busy +// session reads the same string in the Type column. +function isLogEvent(event) { + return event && event.event_type === 'LOG_RECORD'; +} + +function logLevelLabel(d) { + // levelname is fastest path; fall back to numeric mapping if missing. 
+ const lvl = (d && (d.level_name || '')).toString().toUpperCase(); + if (lvl) { + if (lvl === 'WARNING') return 'WARN'; + if (lvl === 'CRITICAL') return 'CRIT'; + return lvl; + } + const n = d && Number(d.level); + if (!isFinite(n)) return 'LOG'; + if (n >= 50) return 'CRIT'; + if (n >= 40) return 'ERROR'; + if (n >= 30) return 'WARN'; + if (n >= 20) return 'INFO'; + return 'DEBUG'; +} + +function logBadgeClass(d) { + const label = logLevelLabel(d); + if (label === 'DEBUG') return 'log-debug'; + if (label === 'INFO') return 'log-info'; + if (label === 'WARN') return 'log-warn'; + return 'log-error'; // ERROR / CRIT collapse together +} + // Search helper functions function escapeRegex(str) { return str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); @@ -67,12 +101,17 @@ function eventMatchesSearch(event) { } function highlightSearchTerms(text) { - if (!searchQuery || !text) return text; + // Escape first — event keys/values/messages are arbitrary text (perception + // prose, file paths, agent output) and are inserted via innerHTML by the + // callers. Escaping here closes the XSS hole at every call site; the + // injected tags are the only markup we add. + const safe = escapeHtml(text == null ? '' : String(text)); + if (!searchQuery) return safe; try { const regex = new RegExp(`(${escapeRegex(searchQuery)})`, 'gi'); - return String(text).replace(regex, '$1'); + return safe.replace(regex, '$1'); } catch (e) { - return text; + return safe; } } @@ -212,45 +251,81 @@ function addEventToTable(event, prepend = true) { if (hasImage) tr.classList.add('has-image'); tr.dataset.eventId = event.event_id || ''; - const badgeClass = getEventBadgeClass(event.event_type); - - // Image indicator icon - const imageIndicator = hasImage - ? ` - - - - - - ` - : ''; - - // Thumbnail preview - const thumbnailHtml = hasImage - ? `Event image` - : ''; - - tr.innerHTML = ` - ${formatEventTime(event.timestamp)} - ${imageIndicator}${event.event_type} - ${event.source || '-'} - ${thumbnailHtml}
${formatEventData(event.data)}
- `; - - // Click to expand data - tr.addEventListener('click', () => { - const dataDiv = tr.querySelector('.event-data'); - dataDiv.classList.toggle('expanded'); - if (dataDiv.classList.contains('expanded')) { - dataDiv.innerHTML = `
${JSON.stringify(event.data, null, 2)}
`; - } else { - dataDiv.innerHTML = formatEventData(event.data); - } - }); + if (isLogEvent(event)) { + // Log rows have a compact, distinctive shape: level badge in the + // Type column, logger name + message in the Data column. Click to + // toggle a pre with the full payload (incl. exception trace). + tr.classList.add('log-row'); + const d = event.data || {}; + const badgeCls = logBadgeClass(d); + const label = logLevelLabel(d); + const message = highlightSearchTerms(d.message || ''); + const loggerName = highlightSearchTerms(d.logger || '-'); + const excTag = d.exc_text ? ' ⏎ trace…' : ''; + tr.innerHTML = ` + ${formatEventTime(event.timestamp)} + ${label} + ${event.source || '-'} +
+ ${loggerName}${message}${excTag} +
+ `; + tr.addEventListener('click', () => { + const dataDiv = tr.querySelector('.event-data'); + dataDiv.classList.toggle('expanded'); + if (dataDiv.classList.contains('expanded')) { + const tracePart = d.exc_text + ? `\n\n${d.exc_text}` : ''; + dataDiv.innerHTML = + `
${escapeHtml(String(d.logger || ''))}  ${escapeHtml(String(d.func || ''))}:${escapeHtml(String(d.line || ''))}\n` +
+                    `${escapeHtml(String(d.message || '') + tracePart)}
`; + } else { + dataDiv.innerHTML = + `${loggerName}` + + `${message}${excTag}`; + } + }); + } else { + const badgeClass = getEventBadgeClass(event.event_type); + + // Image indicator icon + const imageIndicator = hasImage + ? ` + + + + + + ` + : ''; + + // Thumbnail preview + const thumbnailHtml = hasImage + ? `Event image` + : ''; + + tr.innerHTML = ` + ${formatEventTime(event.timestamp)} + ${imageIndicator}${event.event_type} + ${event.source || '-'} + ${thumbnailHtml}
${formatEventData(event.data)}
+ `; + + // Click to expand data + tr.addEventListener('click', () => { + const dataDiv = tr.querySelector('.event-data'); + dataDiv.classList.toggle('expanded'); + if (dataDiv.classList.contains('expanded')) { + dataDiv.innerHTML = `
${escapeHtml(JSON.stringify(event.data, null, 2))}
`; + } else { + dataDiv.innerHTML = formatEventData(event.data); + } + }); + } if (prepend) { tbody.insertBefore(tr, tbody.firstChild); diff --git a/gently/ui/web/static/js/gallery.js b/gently/ui/web/static/js/gallery.js index 51ee4207..42159d18 100644 --- a/gently/ui/web/static/js/gallery.js +++ b/gently/ui/web/static/js/gallery.js @@ -422,13 +422,19 @@ const CalibrationProfileView = { /** Compact SPIM live indicator used inside the metrics strip. * Carries the same IDs as the old big preview so SpimLivePreview's - * apply-on-render logic continues to work unchanged. */ + * apply-on-render logic continues to work unchanged. The thumb is a + * button — click to open the floating popout for a larger view. */ _renderSpimIndicator() { return `
SPIM - +
@@ -1329,21 +1335,33 @@ const SpimLivePreview = { const placeholder = document.getElementById('cal-spim-placeholder'); const metaEl = document.getElementById('cal-spim-meta'); const led = document.getElementById('cal-spim-led'); - if (!img) return; // not in profile view const latest = embryoId ? this._latestByEmbryo[embryoId] : null; - if (latest) { - img.src = `data:image/png;base64,${latest.base64_png}`; - img.classList.add('has-frame'); - if (placeholder) placeholder.hidden = true; - if (metaEl) metaEl.textContent = this._formatMeta(latest); - if (led) led.classList.remove('idle'); - } else { - img.removeAttribute('src'); - img.classList.remove('has-frame'); - if (placeholder) placeholder.hidden = false; - if (metaEl) metaEl.textContent = '—'; - if (led) led.classList.add('idle'); + + if (img) { + if (latest) { + img.src = `data:image/png;base64,${latest.base64_png}`; + img.classList.add('has-frame'); + if (placeholder) placeholder.hidden = true; + if (metaEl) metaEl.textContent = this._formatMeta(latest); + if (led) led.classList.remove('idle'); + } else { + img.removeAttribute('src'); + img.classList.remove('has-frame'); + if (placeholder) placeholder.hidden = false; + if (metaEl) metaEl.textContent = '—'; + if (led) led.classList.add('idle'); + } + } + + // Mirror into popout if it's open — the popout lives outside the + // calibration panel's innerHTML reset, so we paint it independently. + if (typeof SpimPopout !== 'undefined') { + SpimPopout.paint(latest ? { + base64_png: latest.base64_png, + meta: this._formatMeta(latest), + embryoId, + } : null); } }, @@ -1375,6 +1393,227 @@ const SpimLivePreview = { document.addEventListener('DOMContentLoaded', () => SpimLivePreview.init()); +// ========================================== +// SPIM live popout (floating draggable window) +// ========================================== +// Lazy-built floating window that mirrors SpimLivePreview at a larger +// size. 
Draggable via the header bar, resizable from the bottom-right +// corner. Position and size persist in localStorage so the window +// re-opens where the operator last left it. Closes on Escape. +const SpimPopout = { + _STORAGE_KEY: 'gently.spimPopout.v1', + _root: null, + _isOpen: false, + + _ensureBuilt() { + if (this._root) return this._root; + + const el = document.createElement('div'); + el.className = 'cal-spim-popout'; + el.id = 'cal-spim-popout'; + el.hidden = true; + el.innerHTML = ` +
+ + SPIM Live + + + +
+
+ +
+ Awaiting SPIM frame… +
+
+ + `; + document.body.appendChild(el); + this._root = el; + + // Restore persisted geometry + const saved = this._loadGeometry(); + if (saved) { + el.style.left = `${saved.left}px`; + el.style.top = `${saved.top}px`; + el.style.width = `${saved.width}px`; + el.style.height = `${saved.height}px`; + } + + el.querySelector('#cal-spim-popout-close').addEventListener('click', () => this.close()); + this._wireDrag(el); + this._wireResizeObserver(el); + + return el; + }, + + open() { + const el = this._ensureBuilt(); + if (this._isOpen) return; + el.hidden = false; + this._isOpen = true; + + // Clamp into viewport in case window was resized while popout was hidden + this._clampIntoViewport(el); + + // Paint current frame for the selected embryo + const selected = (typeof CalibrationManager !== 'undefined') + ? CalibrationManager.selectedEmbryoId : null; + if (selected && typeof SpimLivePreview !== 'undefined') { + const latest = SpimLivePreview._latestByEmbryo[selected]; + this.paint(latest ? { + base64_png: latest.base64_png, + meta: SpimLivePreview._formatMeta(latest), + embryoId: selected, + } : null); + } else { + this.paint(null); + } + + document.addEventListener('keydown', this._onKey); + }, + + close() { + if (!this._root || !this._isOpen) return; + this._root.hidden = true; + this._isOpen = false; + document.removeEventListener('keydown', this._onKey); + }, + + toggle() { + this._isOpen ? this.close() : this.open(); + }, + + /** Called by SpimLivePreview whenever the current embryo's latest + * frame changes. Frame is {base64_png, meta, embryoId} or null. 
*/ + paint(frame) { + if (!this._root || !this._isOpen) return; + const img = this._root.querySelector('#cal-spim-popout-img'); + const placeholder = this._root.querySelector('#cal-spim-popout-placeholder'); + const meta = this._root.querySelector('#cal-spim-popout-meta'); + const embryoEl = this._root.querySelector('#cal-spim-popout-embryo'); + const led = this._root.querySelector('#cal-spim-popout-led'); + + if (frame) { + img.src = `data:image/png;base64,${frame.base64_png}`; + img.classList.add('has-frame'); + placeholder.hidden = true; + meta.textContent = frame.meta || '—'; + embryoEl.textContent = frame.embryoId || ''; + led.classList.remove('idle'); + } else { + img.removeAttribute('src'); + img.classList.remove('has-frame'); + placeholder.hidden = false; + meta.textContent = '—'; + embryoEl.textContent = ''; + led.classList.add('idle'); + } + }, + + _onKey: (e) => { + if (e.key === 'Escape') SpimPopout.close(); + }, + + _wireDrag(el) { + const header = el.querySelector('#cal-spim-popout-header'); + let dragging = false; + let startX = 0, startY = 0, startLeft = 0, startTop = 0; + + header.addEventListener('pointerdown', (e) => { + // Don't start drag on the close button + if (e.target.closest('.cal-spim-popout-close')) return; + dragging = true; + const rect = el.getBoundingClientRect(); + startX = e.clientX; + startY = e.clientY; + startLeft = rect.left; + startTop = rect.top; + // Switch to absolute positioning if currently default + el.style.left = `${startLeft}px`; + el.style.top = `${startTop}px`; + el.style.right = 'auto'; + el.style.bottom = 'auto'; + header.setPointerCapture(e.pointerId); + el.classList.add('dragging'); + }); + + header.addEventListener('pointermove', (e) => { + if (!dragging) return; + const dx = e.clientX - startX; + const dy = e.clientY - startY; + let nextLeft = startLeft + dx; + let nextTop = startTop + dy; + // Keep at least 40px of header on-screen + const w = el.offsetWidth; + const h = el.offsetHeight; + nextLeft = 
Math.max(-(w - 80), Math.min(window.innerWidth - 80, nextLeft)); + nextTop = Math.max(0, Math.min(window.innerHeight - 40, nextTop)); + el.style.left = `${nextLeft}px`; + el.style.top = `${nextTop}px`; + }); + + const endDrag = (e) => { + if (!dragging) return; + dragging = false; + el.classList.remove('dragging'); + try { header.releasePointerCapture(e.pointerId); } catch (_) {} + this._saveGeometry(el); + }; + header.addEventListener('pointerup', endDrag); + header.addEventListener('pointercancel', endDrag); + }, + + _wireResizeObserver(el) { + if (typeof ResizeObserver === 'undefined') return; + let saveTimer = null; + const ro = new ResizeObserver(() => { + if (!this._isOpen) return; + clearTimeout(saveTimer); + saveTimer = setTimeout(() => this._saveGeometry(el), 250); + }); + ro.observe(el); + }, + + _clampIntoViewport(el) { + const rect = el.getBoundingClientRect(); + if (rect.left + 80 > window.innerWidth || rect.top + 40 > window.innerHeight + || rect.left < -(rect.width - 80) || rect.top < 0) { + // Recenter + const w = Math.min(rect.width || 520, window.innerWidth - 40); + const h = Math.min(rect.height || 440, window.innerHeight - 40); + el.style.width = `${w}px`; + el.style.height = `${h}px`; + el.style.left = `${Math.max(20, (window.innerWidth - w) / 2)}px`; + el.style.top = `${Math.max(20, (window.innerHeight - h) / 2)}px`; + } + }, + + _saveGeometry(el) { + const rect = el.getBoundingClientRect(); + const data = { + left: Math.round(rect.left), + top: Math.round(rect.top), + width: Math.round(rect.width), + height: Math.round(rect.height), + }; + try { localStorage.setItem(this._STORAGE_KEY, JSON.stringify(data)); } catch (_) {} + }, + + _loadGeometry() { + try { + const raw = localStorage.getItem(this._STORAGE_KEY); + if (!raw) return null; + const data = JSON.parse(raw); + if (typeof data.left !== 'number') return null; + return data; + } catch (_) { return null; } + }, +}; + // Legacy wrappers kept for backward compatibility function 
renderCalibrationGallery() { CalibrationManager.render(); } diff --git a/gently/ui/web/static/js/home.js b/gently/ui/web/static/js/home.js new file mode 100644 index 00000000..089d7de3 --- /dev/null +++ b/gently/ui/web/static/js/home.js @@ -0,0 +1,177 @@ +/** + * HomeApp — the landing tab. + * + * A light at-a-glance landing surface: recent sessions, recent plans, recent + * images, a thin status line, and a "Start / continue an experiment" button + * that launches the setup flow (the wizard, which no longer auto-pops in chat). + * + * Read-only fetches against existing endpoints (/api/sessions, /api/campaigns, + * /api/home/recent-images); mirrors the ReviewApp/CampaignsApp module pattern. + */ +const HomeApp = (() => { + let _inited = false; + const SESSIONS_N = 5; + const CAMPAIGNS_N = 5; + const IMAGES_N = 8; + // Recent images are stable (latest projection per embryo). refresh() runs on + // every Home-tab entry, so guard against redundant disk-walking fetches: + // skip if one is in flight or the strip was loaded within IMAGES_TTL_MS. + const IMAGES_TTL_MS = 15000; + let _imgState = { at: 0, inflight: false }; + + function relTime(iso) { + if (!iso) return ''; + const t = Date.parse(iso); + if (isNaN(t)) return ''; + const s = Math.max(0, (Date.now() - t) / 1000); + if (s < 60) return 'just now'; + if (s < 3600) return `${Math.floor(s / 60)}m ago`; + if (s < 86400) return `${Math.floor(s / 3600)}h ago`; + const d = Math.floor(s / 86400); + return d < 30 ? `${d}d ago` : new Date(t).toLocaleDateString(); + } + + function empty(el, msg) { + el.innerHTML = `
${escapeHtml(msg)}
`; + } + + function wireGoTab(scope) { + (scope || document).querySelectorAll('[data-go-tab]').forEach(el => { + if (el._goWired) return; + el._goWired = true; + el.addEventListener('click', (e) => { + e.preventDefault(); + if (typeof switchTab === 'function') switchTab(el.dataset.goTab); + }); + }); + } + + async function loadSessions() { + const el = document.getElementById('home-recent-sessions'); + if (!el) return; + try { + const data = await (await fetch('/api/sessions')).json(); + const sessions = (data.sessions || []).slice(0, SESSIONS_N); + if (!sessions.length) { empty(el, 'No sessions yet.'); return; } + el.innerHTML = sessions.map(s => { + const live = s.active ? 'live' : ''; + const resume = s.active ? '' : + ``; + return `
+
+
${escapeHtml(s.name || s.session_id)}${live}
+ ${escapeHtml(relTime(s.last_active))} · ${s.embryo_count || 0} embryo${s.embryo_count === 1 ? '' : 's'} +
${resume} +
`; + }).join(''); + el.querySelectorAll('[data-resume]').forEach(b => b.addEventListener('click', async () => { + b.disabled = true; + b.textContent = 'Resuming…'; + try { + await fetch(`/api/sessions/${encodeURIComponent(b.dataset.resume)}/resume`, { method: 'POST' }); + } catch (_) { b.disabled = false; b.textContent = 'Resume'; } + })); + } catch (e) { empty(el, 'Could not load sessions.'); } + } + + async function loadCampaigns() { + const el = document.getElementById('home-recent-campaigns'); + if (!el) return; + try { + const data = await (await fetch('/api/campaigns')).json(); + const items = (data.campaigns || []).slice(0, CAMPAIGNS_N); + if (!items.length) { empty(el, 'No plans yet.'); return; } + el.innerHTML = items.map(t => { + const c = t.campaign || {}; + const st = t.status || {}; + const name = c.shorthand || c.description || 'Untitled plan'; + const total = st.total || 0; + const chip = total ? `${st.completed || 0}/${total} done` : ''; + return `
+ ${escapeHtml(name)}${chip} +
`; + }).join(''); + wireGoTab(el); + } catch (e) { empty(el, 'Could not load plans.'); } + } + + async function loadImages(force) { + const el = document.getElementById('home-recent-images'); + if (!el) return; + if (_imgState.inflight) return; + // _imgState.at is set only after a completed fetch (images or empty), + // never after an error — so failures still retry on the next entry. + if (!force && _imgState.at && (Date.now() - _imgState.at) < IMAGES_TTL_MS) return; + _imgState.inflight = true; + try { + const data = await (await fetch(`/api/home/recent-images?limit=${IMAGES_N}`)).json(); + // Latest projection per embryo across recent sessions (server orders + // most-recent session first). + const recent = (data.images || []).slice(0, IMAGES_N); + if (!recent.length) { + empty(el, 'No images yet — they appear once a session has captured volumes.'); + _imgState.at = Date.now(); + return; + } + el.innerHTML = '
' + recent.map(s => { + const tp = (s.timepoint != null) ? ` · t${s.timepoint}` : ''; + const label = `${s.embryo_id || ''}${tp}`; + const sub = s.session_name && s.session_name !== s.session_id + ? ` (${s.session_name})` : ''; + const src = `/api/sessions/${encodeURIComponent(s.session_id)}` + + `/projection?embryo=${encodeURIComponent(s.embryo_id)}` + + `&t=${encodeURIComponent(s.timepoint)}`; + return `
+ ${escapeHtml(label)} +
`; + }).join('') + '
'; + _imgState.at = Date.now(); + } catch (e) { + empty(el, 'Could not load images.'); + } finally { + _imgState.inflight = false; + } + } + + function updateStatus() { + const el = document.getElementById('home-status'); + if (!el) return; + const connected = (typeof state !== 'undefined' && state.connected); + const n = (typeof state !== 'undefined' && Array.isArray(state.embryos)) ? state.embryos.length : 0; + el.textContent = connected + ? `Connected · ${n} embryo${n === 1 ? '' : 's'} in view` + : 'Offline — start the agent to connect.'; + } + + function refresh() { + updateStatus(); + loadSessions(); + loadCampaigns(); + loadImages(); + } + + function init() { + if (!_inited) { + _inited = true; + wireGoTab(document.getElementById('home-content')); + const start = document.getElementById('home-start-btn'); + if (start) start.addEventListener('click', () => { + if (typeof AgentChat !== 'undefined' && AgentChat.togglePanel) { + AgentChat.togglePanel(true); + // Let the panel's WS connect before sending the command. + if (AgentChat.runCommand) setTimeout(() => AgentChat.runCommand('/wizard'), 250); + } + }); + } + refresh(); // re-fetch on every entry to the tab + } + + // Self-initialise on load when Home is the default-active tab (switchTab's + // lazy-init hook only fires on a tab click / hash route, not initial paint). 
+ document.addEventListener('DOMContentLoaded', () => { + const home = document.getElementById('home-content'); + if (home && home.classList.contains('active')) init(); + }); + + return { init, refresh }; +})(); diff --git a/gently/ui/web/static/js/projection-viewer.js b/gently/ui/web/static/js/projection-viewer.js index 1f5f530e..0e20b9fa 100644 --- a/gently/ui/web/static/js/projection-viewer.js +++ b/gently/ui/web/static/js/projection-viewer.js @@ -118,6 +118,10 @@ const ProjectionViewer = { this.projections = []; this.selectedMethod = null; this.isOpen = true; + // Clear any volume from a previous open so a failed /api/volume-raw fetch + // can't leave the prior embryo/timepoint's 3D data bound (stale-render). + this.volumeData = null; + this.volumeShape = null; const modal = document.getElementById('projection-viewer-modal'); const loading = document.getElementById('pv-loading'); @@ -280,6 +284,10 @@ const ProjectionViewer = { }, selectMethod(method) { + // If the 3D view is requested but no volume loaded (e.g. /api/volume-raw + // failed while projections succeeded), fall back to the projections grid + // rather than showing an empty, never-initialized 3D panel. + if (method === '3d_viewer' && !this.volumeData) method = null; this.selectedMethod = method; this.renderProjections(); this.renderTabs(); @@ -298,6 +306,19 @@ const ProjectionViewer = { this.updateViewerVisibility(); }, + // Resize the WebGL canvas + camera to the container's current width. + // (Height is fixed at 400px; only width tracks the layout.) The animation + // loop handles re-rendering. 
+ _resize3D() { + const container = document.getElementById('pv-3d-container'); + if (!container || !this.renderer3d || !this.camera3d) return; + const w = container.clientWidth || 500; + const h = 400; + this.renderer3d.setSize(w, h); + this.camera3d.aspect = w / h; + this.camera3d.updateProjectionMatrix(); + }, + // 3D Viewer Methods init3DViewer() { const container = document.getElementById('pv-3d-container'); @@ -319,6 +340,22 @@ const ProjectionViewer = { container.innerHTML = ''; container.appendChild(this.renderer3d.domElement); + // Keep the WebGL canvas in sync with its container width — the chat + // panel can dock/resize and the window can resize. The animation loop + // re-renders every frame, so on a size change we only need to resize the + // renderer + camera (coalesced to one rAF). Also listen for the explicit + // layout-change event the chat dock fires on pin/unpin. + if (this._resizeObserver) this._resizeObserver.disconnect(); + this._resizeObserver = new ResizeObserver(() => { + if (this._resizeRaf) cancelAnimationFrame(this._resizeRaf); + this._resizeRaf = requestAnimationFrame(() => this._resize3D()); + }); + this._resizeObserver.observe(container); + if (!this._onLayoutChanged) { + this._onLayoutChanged = () => this._resize3D(); + window.addEventListener('gently:layout-changed', this._onLayoutChanged); + } + // Root group is the object the user rotates. Raymarched volume // mesh is added here. The group scale flips Y so the image // orientation matches 2D projections. 
@@ -609,6 +646,18 @@ const ProjectionViewer = { cancelAnimationFrame(this.animationId); this.animationId = null; } + if (this._resizeObserver) { + this._resizeObserver.disconnect(); + this._resizeObserver = null; + } + if (this._resizeRaf) { + cancelAnimationFrame(this._resizeRaf); + this._resizeRaf = null; + } + if (this._onLayoutChanged) { + window.removeEventListener('gently:layout-changed', this._onLayoutChanged); + this._onLayoutChanged = null; + } // Dispose the volume cube's geometry, material, and 3D texture. if (this.volumeMesh) { this.volumeMesh.geometry?.dispose(); diff --git a/gently/ui/web/static/js/review.js b/gently/ui/web/static/js/review.js index bd2c75fa..07dcf49e 100644 --- a/gently/ui/web/static/js/review.js +++ b/gently/ui/web/static/js/review.js @@ -86,17 +86,35 @@ const ReviewApp = { } list.innerHTML = filtered.map(s => ` -
-
${this.escapeHtml(s.name || s.session_id)}
+
+
${this.escapeHtml(s.name || s.session_id)}${s.active ? ' active' : ''}
${this.formatDate(s.created_at)} ${s.embryo_count ? `${s.embryo_count} embryo${s.embryo_count !== 1 ? 's' : ''}` : ''}
${s.description ? `
${this.escapeHtml(s.description)}
` : ''} + ${s.active ? '' : ``}
`).join(''); }, + async resumeSession(sessionId) { + if (!confirm('Switch the live agent to this session?\nThe current session is saved first.')) return; + try { + const resp = await fetch(`/api/sessions/${encodeURIComponent(sessionId)}/resume`, { method: 'POST' }); + if (resp.ok) { + // Server broadcasts session_changed to reload all clients; we + // navigate home as well so the operator lands on the new session. + window.location.href = '/'; + } else { + const d = await resp.json().catch(() => ({})); + alert('Resume failed: ' + (d.detail || ('HTTP ' + resp.status))); + } + } catch (e) { + alert('Resume failed: ' + e); + } + }, + renderSessionContent() { const content = document.getElementById('session-content'); const s = this.currentSession; diff --git a/gently/ui/web/static/js/utils.js b/gently/ui/web/static/js/utils.js index b0e6d2ac..4b8ff62b 100644 --- a/gently/ui/web/static/js/utils.js +++ b/gently/ui/web/static/js/utils.js @@ -3,7 +3,7 @@ // ══════════════════════════════════════════════════════════ // Tab and view name constants -const TABS = { EMBRYOS: 'embryos', CALIBRATION: 'calibration', EVENTS: 'events', PLANS: 'plans', SESSIONS: 'sessions', DEVICES: 'devices', EXPERIMENT: 'experiment' }; +const TABS = { HOME: 'home', EMBRYOS: 'embryos', CALIBRATION: 'calibration', EVENTS: 'events', PLANS: 'plans', SESSIONS: 'sessions', DEVICES: 'devices', EXPERIMENT: 'experiment' }; /** * HTML-escape a string (safe for insertion into innerHTML). diff --git a/gently/ui/web/static/js/websocket.js b/gently/ui/web/static/js/websocket.js index 38a2794c..069724a2 100644 --- a/gently/ui/web/static/js/websocket.js +++ b/gently/ui/web/static/js/websocket.js @@ -127,6 +127,24 @@ function handleMessage(msg) { // Switch to embryos tab if not already there if (state.tab !== 'embryos') switchTab('embryos'); } + } else if (msg.type === 'open_volume') { + // The agent asked us to open the in-browser volume viewer — the + // web-native replacement for the old desktop napari window. 
+ if (typeof ProjectionViewer !== 'undefined' && msg.embryo_id != null) { + const view = msg.view || '3d_viewer'; + Promise.resolve(ProjectionViewer.open(msg.embryo_id, msg.timepoint)) + .then(() => { + // Default to the 3D viewer tab when the agent opens it. + if (view && typeof ProjectionViewer.selectMethod === 'function') { + ProjectionViewer.selectMethod(view); + } + }) + .catch((e) => console.warn('open_volume failed', e)); + } + } else if (msg.type === 'session_changed') { + // The live agent switched sessions (resume from the Sessions tab) — + // reload so every client picks up the new session's state + transcript. + window.location.href = '/'; } else if (msg.type === 'ping') { state.ws.send(JSON.stringify({type: 'pong'})); } else if (msg.type === 'presence') { diff --git a/gently/ui/web/strategy_snapshot.py b/gently/ui/web/strategy_snapshot.py index 46569758..91c560e7 100644 --- a/gently/ui/web/strategy_snapshot.py +++ b/gently/ui/web/strategy_snapshot.py @@ -22,7 +22,7 @@ from dataclasses import dataclass from datetime import datetime from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import yaml @@ -37,10 +37,10 @@ # ui_icon names but resolves them to actual unicode glyphs the swimlane SVG # can render directly. _ROLE_ICONS = { - "star": "★", # ★ - "diamond": "◆", # ◆ - "circle": "●", # ● - "triangle": "▲", # ▲ + "star": "★", # ★ + "diamond": "◆", # ◆ + "circle": "●", # ● + "triangle": "▲", # ▲ } # Default per-timepoint exposure when nothing on disk tells us otherwise. 
@@ -60,9 +60,9 @@ # --------------------------------------------------------------------------- -def _read_yaml(path: Path) -> Optional[dict]: +def _read_yaml(path: Path) -> dict | None: try: - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: return yaml.safe_load(f) or {} except FileNotFoundError: return None @@ -80,7 +80,7 @@ def _pick_timelapse_yaml(session_dir: Path, legacy_session_dir: Path) -> dict: orchestrator that hasn't yet been restarted (still writing to legacy) isn't shadowed by a stale new-path file. """ - candidates: List[Path] = [ + candidates: list[Path] = [ session_dir / "timelapse.yaml", legacy_session_dir / "timelapse.yaml", ] @@ -93,6 +93,7 @@ def _pick_timelapse_yaml(session_dir: Path, legacy_session_dir: Path) -> dict: return {} if len(docs) == 1: return docs[0][1] + # Pick by saved_at if present, falling back to file mtime. def _saved_at_key(item): path, doc = item @@ -105,11 +106,12 @@ def _saved_at_key(item): return path.stat().st_mtime except OSError: return 0.0 + docs.sort(key=_saved_at_key, reverse=True) return docs[0][1] -def _parse_iso(s: Optional[str]) -> Optional[datetime]: +def _parse_iso(s: str | None) -> datetime | None: if not s: return None try: @@ -118,7 +120,7 @@ def _parse_iso(s: Optional[str]) -> Optional[datetime]: return None -def _elapsed_s(t: Optional[datetime], started_at: datetime) -> Optional[float]: +def _elapsed_s(t: datetime | None, started_at: datetime) -> float | None: if t is None: return None return (t - started_at).total_seconds() @@ -132,19 +134,21 @@ def _elapsed_s(t: Optional[datetime], started_at: datetime) -> Optional[float]: @dataclass class _EmbryoAccum: """Mutable accumulator while replaying timeline events for one embryo.""" + eid: str - phases: List[dict] - trigger_events: List[dict] - power_history_488: List[dict] + phases: list[dict] + trigger_events: list[dict] + power_history_488: list[dict] - def open_phase(self, mode: str, start_s: float, cadence_s: 
Optional[float] = None, - **extra) -> None: + def open_phase( + self, mode: str, start_s: float, cadence_s: float | None = None, **extra + ) -> None: # If the last phase has no end yet, close it at start_s. if self.phases: last = self.phases[-1] if "end" not in last or last["end"] is None: last["end"] = start_s - ph: Dict[str, Any] = {"mode": mode, "start": start_s, "end": None} + ph: dict[str, Any] = {"mode": mode, "start": start_s, "end": None} if cadence_s is not None: ph["cadence_s"] = cadence_s ph.update(extra) @@ -167,7 +171,7 @@ def build_strategy_snapshot( session_id: str, *, horizon_padding_s: float = 1800.0, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Read the session folder and return the strategy dict the frontend wants. Parameters @@ -272,7 +276,9 @@ def build_strategy_snapshot( "now_offset_s": now_offset_s, "horizon_s": horizon_s, "base_interval_s": base_interval_s, - "dose_budget_base_ms": float(dose_budget_base_ms) if dose_budget_base_ms is not None else None, + "dose_budget_base_ms": float(dose_budget_base_ms) + if dose_budget_base_ms is not None + else None, "per_timepoint_ms": per_timepoint_ms, "monitoring_modes": monitoring_modes, "triggers": triggers, @@ -285,7 +291,7 @@ def build_strategy_snapshot( # --------------------------------------------------------------------------- -def _build_monitoring_modes(mode_names: List[str]) -> List[dict]: +def _build_monitoring_modes(mode_names: list[str]) -> list[dict]: """Resolve each active monitoring mode name into a serialized dict. 
The orchestrator persists only the names in ``timelapse.yaml``; we @@ -300,16 +306,18 @@ def _build_monitoring_modes(mode_names: List[str]) -> List[dict]: logger.debug("Could not import MONITORING_MODES; skipping mode resolution") return [] - out: List[dict] = [] + out: list[dict] = [] for name in mode_names: factory = MONITORING_MODES.get(name) if factory is None: - out.append({ - "name": name, - "description": "", - "applies_to_roles": [], - "params": {}, - }) + out.append( + { + "name": name, + "description": "", + "applies_to_roles": [], + "params": {}, + } + ) continue try: mode = factory() @@ -319,15 +327,16 @@ def _build_monitoring_modes(mode_names: List[str]) -> List[dict]: # Pull declarative knobs (fast_interval, rampdown_*) off the instance. excluded = {"name", "description", "applies_to_roles"} params = { - k: v for k, v in vars(mode).items() - if not k.startswith("_") and k not in excluded + k: v for k, v in vars(mode).items() if not k.startswith("_") and k not in excluded } - out.append({ - "name": mode.name, - "description": mode.description, - "applies_to_roles": list(mode.applies_to_roles), - "params": params, - }) + out.append( + { + "name": mode.name, + "description": mode.description, + "applies_to_roles": list(mode.applies_to_roles), + "params": params, + } + ) return out @@ -338,31 +347,35 @@ def _build_monitoring_modes(mode_names: List[str]) -> List[dict]: def _build_triggers( *, - interval_rules: List[dict], - power_rules: List[dict], - embryo_roles: Dict[str, str], -) -> List[dict]: - triggers: List[dict] = [] + interval_rules: list[dict], + power_rules: list[dict], + embryo_roles: dict[str, str], +) -> list[dict]: + triggers: list[dict] = [] for r in interval_rules: - triggers.append({ - "id": r["name"], - "kind": "interval_rule", - "label": _humanize_rule_name(r["name"]), - "when_text": _interval_when_text(r), - "then_text": _interval_then_text(r), - "applies_to": _resolve_applies_to_roles(r.get("applies_to"), embryo_roles), - "one_time": 
bool(r.get("one_time", True)), - }) + triggers.append( + { + "id": r["name"], + "kind": "interval_rule", + "label": _humanize_rule_name(r["name"]), + "when_text": _interval_when_text(r), + "then_text": _interval_then_text(r), + "applies_to": _resolve_applies_to_roles(r.get("applies_to"), embryo_roles), + "one_time": bool(r.get("one_time", True)), + } + ) for r in power_rules: - triggers.append({ - "id": r["name"], - "kind": "power_rule", - "label": _humanize_rule_name(r["name"]), - "when_text": _power_when_text(r), - "then_text": _power_then_text(r), - "applies_to": _resolve_applies_to_roles(r.get("applies_to"), embryo_roles), - "one_time": bool(r.get("one_time", False)), - }) + triggers.append( + { + "id": r["name"], + "kind": "power_rule", + "label": _humanize_rule_name(r["name"]), + "when_text": _power_when_text(r), + "then_text": _power_then_text(r), + "applies_to": _resolve_applies_to_roles(r.get("applies_to"), embryo_roles), + "one_time": bool(r.get("one_time", False)), + } + ) return triggers @@ -406,9 +419,9 @@ def _power_then_text(r: dict) -> str: def _resolve_applies_to_roles( - applies_to: Optional[List[str]], - embryo_roles: Dict[str, str], -) -> List[str]: + applies_to: list[str] | None, + embryo_roles: dict[str, str], +) -> list[str]: """``applies_to`` is a list of embryo ids; resolve to a deduplicated list of role names for the chips. ``None`` means "all roles in the timelapse". @@ -430,14 +443,14 @@ def _resolve_applies_to_roles( # --------------------------------------------------------------------------- -def _read_embryo_roles(session_dir: Path) -> Dict[str, str]: +def _read_embryo_roles(session_dir: Path) -> dict[str, str]: """Map embryo_id -> role by scanning ``embryos/*/embryo.yaml``. We read this from the durable per-embryo file rather than timelapse.yaml so the role is correct even when the embryo isn't in the active timelapse (yet). 
""" - out: Dict[str, str] = {} + out: dict[str, str] = {} embryos_dir = session_dir / "embryos" if not embryos_dir.is_dir(): return out @@ -451,7 +464,7 @@ def _read_embryo_roles(session_dir: Path) -> Dict[str, str]: return out -def _stop_condition_from_serialized(d: Any) -> Tuple[str, str]: +def _stop_condition_from_serialized(d: Any) -> tuple[str, str]: """Read the per-embryo stop_condition dict and return ``(spec, kind)``. ``kind`` is ``"bounded"`` when ANY component of the (possibly composite) @@ -463,10 +476,10 @@ def _stop_condition_from_serialized(d: Any) -> Tuple[str, str]: if not isinstance(d, dict): return "manual", "open_ended" spec = d.get("spec") or "manual" - types: List[str] = [] + types: list[str] = [] if d.get("condition_type"): types.append(d["condition_type"]) - for ad in (d.get("additional") or []): + for ad in d.get("additional") or []: if ad.get("condition_type"): types.append(ad["condition_type"]) auto_stop = any(t != "manual" for t in types) @@ -476,13 +489,13 @@ def _stop_condition_from_serialized(d: Any) -> Tuple[str, str]: def _build_embryos_static( *, session_dir: Path, - tl_embryos: Dict[str, dict], - embryo_roles: Dict[str, str], - dose_budget_base_ms: Optional[float], + tl_embryos: dict[str, dict], + embryo_roles: dict[str, str], + dose_budget_base_ms: float | None, base_interval_s: float, - started_at: Optional[datetime] = None, - now_offset_s: Optional[float] = None, -) -> List[dict]: + started_at: datetime | None = None, + now_offset_s: float | None = None, +) -> list[dict]: """Build the per-embryo static portion of the snapshot. Dynamic fields (phases, trigger_events, power_history_488) are seeded @@ -493,7 +506,7 @@ def _build_embryos_static( except Exception: ROLE_REGISTRY = {} - out: List[dict] = [] + out: list[dict] = [] # Sort embryo ids so the snapshot is deterministic. 
for eid in sorted(tl_embryos.keys()): ed = tl_embryos[eid] or {} @@ -503,9 +516,7 @@ def _build_embryos_static( icon = _ROLE_ICONS.get(role_def.ui_icon if role_def else "circle", "●") mult = role_def.photodose_budget_multiplier if role_def else 1.0 dose_budget_ms = ( - float(dose_budget_base_ms) * float(mult) - if dose_budget_base_ms is not None - else 0.0 + float(dose_budget_base_ms) * float(mult) if dose_budget_base_ms is not None else 0.0 ) laser_488 = ed.get("laser_power_488_pct") if laser_488 is None: @@ -519,49 +530,47 @@ def _build_embryos_static( stop_spec, stop_kind = _stop_condition_from_serialized(ed.get("stop_condition")) # Seed: one base-cadence phase from t=0 until now (replay will # split it as cadence_changed events come in). - out.append({ - "id": eid, - "role": role, - "color": color, - "icon": icon, - "dose_used_ms": float(ed.get("total_exposure_ms") or 0.0), - "dose_budget_ms": dose_budget_ms, - "tp_acquired": int(ed.get("timepoints_acquired") or 0), - "stop_condition": stop_spec, - "stop_kind": stop_kind, - "laser_488_pct_now": float(laser_488), - "phases": [ - { - "mode": "base", - "start": 0.0, - "end": None, - "cadence_s": initial_cadence, - } - ], - "trigger_events": [], - "power_history_488": [ - {"at": 0.0, "pct": float(laser_488)} - ], - # Filled in by _project_forward. - "projected_cadence_s": initial_cadence, - "projected_end_s": None, - # When the embryo was marked complete/terminated, as seconds - # from session start. Null while still acquiring. The - # frontend uses this to draw a TERMINATED cap and stop the - # projection bar; without it, a finished embryo's row would - # appear to still be acquiring forever. 
- "terminated_at_s": _terminated_at_offset( - ed, started_at, now_offset_s - ), - }) + out.append( + { + "id": eid, + "role": role, + "color": color, + "icon": icon, + "dose_used_ms": float(ed.get("total_exposure_ms") or 0.0), + "dose_budget_ms": dose_budget_ms, + "tp_acquired": int(ed.get("timepoints_acquired") or 0), + "stop_condition": stop_spec, + "stop_kind": stop_kind, + "laser_488_pct_now": float(laser_488), + "phases": [ + { + "mode": "base", + "start": 0.0, + "end": None, + "cadence_s": initial_cadence, + } + ], + "trigger_events": [], + "power_history_488": [{"at": 0.0, "pct": float(laser_488)}], + # Filled in by _project_forward. + "projected_cadence_s": initial_cadence, + "projected_end_s": None, + # When the embryo was marked complete/terminated, as seconds + # from session start. Null while still acquiring. The + # frontend uses this to draw a TERMINATED cap and stop the + # projection bar; without it, a finished embryo's row would + # appear to still be acquiring forever. + "terminated_at_s": _terminated_at_offset(ed, started_at, now_offset_s), + } + ) return out def _terminated_at_offset( ed: dict, - started_at: Optional[datetime], - now_offset_s: Optional[float], -) -> Optional[float]: + started_at: datetime | None, + now_offset_s: float | None, +) -> float | None: """Map an embryo's ``completed_at`` ISO timestamp into seconds-from- session-start. Returns ``None`` if the embryo isn't complete yet or we don't have the data to compute the offset. @@ -590,14 +599,14 @@ def _terminated_at_offset( def _resolve_timeline_paths( session_dir: Path, legacy_session_dir: Path, -) -> List[Tuple[Path, bool]]: +) -> list[tuple[Path, bool]]: """Return the timeline.jsonl paths to read, with a per-source flag. The flag indicates whether the file is the global legacy timeline (which mixes multiple sessions and must be filtered by session_id) or a per-session file (no filtering needed). 
""" - paths: List[Tuple[Path, bool]] = [] + paths: list[tuple[Path, bool]] = [] # Per-session (new) location. p = session_dir / "timeline.jsonl" if p.exists(): @@ -620,8 +629,8 @@ def _replay_timeline( session_dir: Path, legacy_session_dir: Path, session_id: str, - embryo_dicts: List[dict], - triggers: List[dict], + embryo_dicts: list[dict], + triggers: list[dict], started_at: datetime, now_offset_s: float, base_interval_s: float, @@ -649,20 +658,19 @@ def _replay_timeline( # We need to know each embryo's current cadence_s as we go (for the # phase records). Seed from each embryo's initial phase. - current_cadence: Dict[str, float] = { - e["id"]: e["phases"][0].get("cadence_s", base_interval_s) - for e in embryo_dicts + current_cadence: dict[str, float] = { + e["id"]: e["phases"][0].get("cadence_s", base_interval_s) for e in embryo_dicts } # Track last trigger_fired per (embryo, rule) so we can cluster # consecutive fires into one event with a count. - last_trigger: Dict[Tuple[str, str], dict] = {} + last_trigger: dict[tuple[str, str], dict] = {} # Collect events from all sources, filtering global file by session_id. 
- events: List[Tuple[datetime, dict]] = [] + events: list[tuple[datetime, dict]] = [] seen_ids: set = set() for path, is_global in paths: try: - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: for line in f: line = line.strip() if not line or not line.startswith("{"): @@ -711,12 +719,14 @@ def _replay_timeline( pass mode = _phase_mode_from_name(new_phase_name) _close_open_phase(emb, at_s) - emb["phases"].append({ - "mode": mode, - "start": at_s, - "end": None, - "cadence_s": current_cadence.get(embryo_id, base_interval_s), - }) + emb["phases"].append( + { + "mode": mode, + "start": at_s, + "end": None, + "cadence_s": current_cadence.get(embryo_id, base_interval_s), + } + ) elif subtype == "power_changed" and embryo_id in by_id: wavelength = data.get("wavelength") @@ -726,10 +736,12 @@ def _replay_timeline( if new_pct is None: continue emb = by_id[embryo_id] - emb["power_history_488"].append({ - "at": at_s, - "pct": float(new_pct), - }) + emb["power_history_488"].append( + { + "at": at_s, + "pct": float(new_pct), + } + ) emb["laser_488_pct_now"] = float(new_pct) elif subtype == "trigger_fired" and embryo_id in by_id: @@ -739,10 +751,7 @@ def _replay_timeline( emb = by_id[embryo_id] key = (embryo_id, rule_name) prev = last_trigger.get(key) - if ( - prev is not None - and at_s - prev["at"] <= _TRIGGER_CLUSTER_GAP_S - ): + if prev is not None and at_s - prev["at"] <= _TRIGGER_CLUSTER_GAP_S: prev["count"] = prev.get("count", 1) + 1 prev["at"] = at_s # extend cluster to last fire else: @@ -755,13 +764,15 @@ def _replay_timeline( mode = data.get("mode") or "1hz" hz = 1.0 if mode == "1hz" else 20.0 _close_open_phase(emb, at_s) - emb["phases"].append({ - "mode": "burst", - "start": at_s, - "end": None, - "frames": int(data.get("frames") or 0), - "hz": hz, - }) + emb["phases"].append( + { + "mode": "burst", + "start": at_s, + "end": None, + "frames": int(data.get("frames") or 0), + "hz": hz, + } + ) elif subtype == "burst_completed" and 
embryo_id in by_id: emb = by_id[embryo_id] @@ -795,7 +806,12 @@ def _ensure_tail_power(emb: dict, now_offset_s: float) -> None: extends the steady segment to the right edge.""" hist = emb.get("power_history_488") or [] if not hist: - emb["power_history_488"] = [{"at": now_offset_s, "pct": emb.get("laser_488_pct_now", _DEFAULT_INITIAL_POWER_PCT)}] + emb["power_history_488"] = [ + { + "at": now_offset_s, + "pct": emb.get("laser_488_pct_now", _DEFAULT_INITIAL_POWER_PCT), + } + ] return if hist[-1]["at"] < now_offset_s: hist.append({"at": now_offset_s, "pct": hist[-1]["pct"]}) @@ -820,7 +836,7 @@ def _phase_mode_from_name(name: str) -> str: def _project_forward( *, - embryo_dicts: List[dict], + embryo_dicts: list[dict], now_offset_s: float, per_timepoint_ms: float, ) -> None: @@ -851,7 +867,7 @@ def _project_forward( def _compute_horizon( now_offset_s: float, - embryo_dicts: List[dict], + embryo_dicts: list[dict], padding_s: float, ) -> float: """Pick a horizon that comfortably contains the past + projected future.""" diff --git a/gently/ui/web/templates/_header.html b/gently/ui/web/templates/_header.html index d8f692e0..3f113590 100644 --- a/gently/ui/web/templates/_header.html +++ b/gently/ui/web/templates/_header.html @@ -25,6 +25,17 @@
+ + {% endif %} diff --git a/gently/ui/web/templates/_navbar.html b/gently/ui/web/templates/_navbar.html index 821c8710..33d5f675 100644 --- a/gently/ui/web/templates/_navbar.html +++ b/gently/ui/web/templates/_navbar.html @@ -2,7 +2,8 @@
{% if is_live %} {# SPA tabs — JS-driven via switchTab() in app.js #} -
+
Home
+
Embryos 0
@@ -24,6 +25,7 @@
Sessions
{% else %} {# Standalone pages — all tabs link back to the SPA #} +
Home Embryos Calibration System diff --git a/gently/ui/web/templates/index.html b/gently/ui/web/templates/index.html index dd3bbfb0..85a60345 100644 --- a/gently/ui/web/templates/index.html +++ b/gently/ui/web/templates/index.html @@ -14,12 +14,59 @@ + {% include '_header.html' %} {% include '_navbar.html' %} + +
+
+ +
+
+
+
+

Welcome to Gently

+
Connecting…
+
+ +
+
+
+
+ Recent sessions + All +
+
+
Loading…
+
+
+
+
+ Recent plans + All +
+
+
Loading…
+
+
+
+
+ Recent images + All +
+
+
No images yet — they appear once a session is active.
+
+
+
+
+
+

Calibration

@@ -122,7 +169,7 @@

Calibration

-
+
Monitoring
@@ -257,6 +304,32 @@

Device

+ + disconnected no data
@@ -359,6 +432,20 @@

Device
bottom camera live frame + +
@@ -498,6 +585,34 @@

Properties

+
+ + + +
@@ -517,7 +632,9 @@

Properties

+ + diff --git a/gently/ui/web/templates/login.html b/gently/ui/web/templates/login.html new file mode 100644 index 00000000..62e5ec98 --- /dev/null +++ b/gently/ui/web/templates/login.html @@ -0,0 +1,114 @@ + + + + + + Sign in · Gently + + + + + + +
+
+ +

Gently

+
+

Sign in to control the microscope — or keep watching in view-only mode.

+ + + + + +
+
or
+ Continue without signing in → +

View-only: watch live sessions and imagery. You can sign in any time to take control.

+
+ + + diff --git a/gently/ui/web/timelapse_tracker.py b/gently/ui/web/timelapse_tracker.py index b57086f9..f980001d 100644 --- a/gently/ui/web/timelapse_tracker.py +++ b/gently/ui/web/timelapse_tracker.py @@ -6,7 +6,6 @@ """ from datetime import datetime -from typing import Dict, List, Optional class TimelapseStateTracker: @@ -21,15 +20,17 @@ class TimelapseStateTracker: """ def __init__(self): - self.session_id: Optional[str] = None # Unique ID per experiment + self.session_id: str | None = None # Unique ID per experiment self.status = "IDLE" # IDLE, RUNNING, PAUSED, COMPLETED - self.started_at: Optional[str] = None - self.embryos: Dict[str, dict] = {} # embryo_id -> state + self.started_at: str | None = None + self.embryos: dict[str, dict] = {} # embryo_id -> state self.total_timepoints = 0 self.base_interval = 120 - self.detection_reasoning: Dict[str, List[dict]] = {} # embryo_id -> list of detections - self.projection_uids: Dict[str, Dict[int, str]] = {} # embryo_id -> {timepoint -> projection_uid} - self.volume_paths: Dict[str, Dict[int, str]] = {} # embryo_id -> {timepoint -> volume_path} + self.detection_reasoning: dict[str, list[dict]] = {} # embryo_id -> list of detections + self.projection_uids: dict[ + str, dict[int, str] + ] = {} # embryo_id -> {timepoint -> projection_uid} + self.volume_paths: dict[str, dict[int, str]] = {} # embryo_id -> {timepoint -> volume_path} def handle_event(self, event_type: str, data: dict): """Update state based on incoming event""" @@ -128,6 +129,16 @@ def handle_event(self, event_type: str, data: dict): self.status = "STOPPED" # Don't mark embryos as complete - they were stopped, not finished + elif event_type == "EMBRYO_TERMINATED": + # A single embryo's imaging was halted by the orchestrator + # (no_object terminal, configured stop condition, errors, etc). + # Carry the completion_reason through so the UI can show why. 
+ eid = data.get("embryo_id") + if eid and eid in self.embryos: + self.embryos[eid]["is_complete"] = True + self.embryos[eid]["completion_reason"] = data.get("completion_reason") + self.embryos[eid].setdefault("completed_at", datetime.now().isoformat()) + elif event_type == "DETECTOR_EVALUATED": # All detector/perception evaluations (with reasoning) - populates reasoning panel eid = data.get("embryo_id") @@ -148,7 +159,8 @@ def handle_event(self, event_type: str, data: dict): "description": data.get("description"), "timepoint": timepoint, "volume_uid": data.get("volume_uid"), - "projection_uid": data.get("projection_uid") or projection_uid, # Use stored UID as fallback + "projection_uid": data.get("projection_uid") + or projection_uid, # Use stored UID as fallback "timestamp": datetime.now().isoformat(), # Perception-specific fields "stage": data.get("stage"), @@ -186,15 +198,18 @@ def handle_event(self, event_type: str, data: dict): # before any acquisition has happened. eid = data.get("embryo_id") if eid: - emb = self.embryos.setdefault(eid, { - "embryo_id": eid, - "timepoints": 0, - "is_complete": False, - "first_acquired": None, - "last_acquired": None, - "detections": {}, - "current_stage": None, - }) + emb = self.embryos.setdefault( + eid, + { + "embryo_id": eid, + "timepoints": 0, + "is_complete": False, + "first_acquired": None, + "last_acquired": None, + "detections": {}, + "current_stage": None, + }, + ) if data.get("x") is not None: emb["stage_x_um"] = data["x"] if data.get("y") is not None: @@ -215,13 +230,11 @@ def handle_event(self, event_type: str, data: dict): detector_name = data.get("detector_name", "unknown") self.embryos[eid]["detections"][detector_name] = { "detected": True, - "confidence": data.get("confidence") + "confidence": data.get("confidence"), } if detector_name == "hatching": self.embryos[eid]["is_complete"] = True - self.embryos[eid].setdefault( - "completed_at", datetime.now().isoformat() - ) + 
self.embryos[eid].setdefault("completed_at", datetime.now().isoformat()) elif event_type == "VERIFICATION_STARTED": # Verification round started for embryo @@ -250,8 +263,12 @@ def handle_event(self, event_type: str, data: dict): # Progress update eid = data.get("embryo_id") if eid and eid in self.embryos and "verification" in self.embryos[eid]: - self.embryos[eid]["verification"]["strategies_complete"] = data.get("strategies_complete", 0) - self.embryos[eid]["verification"]["total_strategies"] = data.get("total_strategies", 5) + self.embryos[eid]["verification"]["strategies_complete"] = data.get( + "strategies_complete", 0 + ) + self.embryos[eid]["verification"]["total_strategies"] = data.get( + "total_strategies", 5 + ) elif event_type == "VERIFICATION_COMPLETED": # Final verification result @@ -292,10 +309,16 @@ def handle_event(self, event_type: str, data: dict): if data.get("change") == "role_assigned": eid = data.get("embryo_id") if eid: - emb = self.embryos.setdefault(eid, { - "embryo_id": eid, "timepoints": 0, "is_complete": False, - "detections": {}, "current_stage": None, - }) + emb = self.embryos.setdefault( + eid, + { + "embryo_id": eid, + "timepoints": 0, + "is_complete": False, + "detections": {}, + "current_stage": None, + }, + ) if data.get("new_role"): emb["role"] = data["new_role"] @@ -304,10 +327,16 @@ def handle_event(self, event_type: str, data: dict): elif event_type == "EMBRYO_CADENCE_CHANGED": eid = data.get("embryo_id") if eid: - emb = self.embryos.setdefault(eid, { - "embryo_id": eid, "timepoints": 0, "is_complete": False, - "detections": {}, "current_stage": None, - }) + emb = self.embryos.setdefault( + eid, + { + "embryo_id": eid, + "timepoints": 0, + "is_complete": False, + "detections": {}, + "current_stage": None, + }, + ) if data.get("new_phase") is not None: emb["cadence_phase"] = data["new_phase"] if data.get("new_interval_s") is not None: @@ -319,22 +348,30 @@ def handle_event(self, event_type: str, data: dict): elif event_type == 
"POWER_RAMP_STEP": eid = data.get("embryo_id") if eid: - emb = self.embryos.setdefault(eid, { - "embryo_id": eid, "timepoints": 0, "is_complete": False, - "detections": {}, "current_stage": None, - }) + emb = self.embryos.setdefault( + eid, + { + "embryo_id": eid, + "timepoints": 0, + "is_complete": False, + "detections": {}, + "current_stage": None, + }, + ) wavelength = data.get("wavelength", 488) if wavelength == 488: emb["laser_power_488_pct"] = data.get("new_pct") - emb.setdefault("power_history", []).append({ - "wavelength": wavelength, - "old_pct": data.get("old_pct"), - "new_pct": data.get("new_pct"), - "direction": data.get("direction"), - "rule": data.get("rule"), - "intensity_level": data.get("intensity_level"), - "timestamp": datetime.now().isoformat(), - }) + emb.setdefault("power_history", []).append( + { + "wavelength": wavelength, + "old_pct": data.get("old_pct"), + "new_pct": data.get("new_pct"), + "direction": data.get("direction"), + "rule": data.get("rule"), + "intensity_level": data.get("intensity_level"), + "timestamp": datetime.now().isoformat(), + } + ) # cap history per embryo if len(emb["power_history"]) > 200: emb["power_history"] = emb["power_history"][-200:] @@ -342,10 +379,16 @@ def handle_event(self, event_type: str, data: dict): elif event_type == "CLAUDE_DETECTOR_RESULT": eid = data.get("embryo_id") if eid: - emb = self.embryos.setdefault(eid, { - "embryo_id": eid, "timepoints": 0, "is_complete": False, - "detections": {}, "current_stage": None, - }) + emb = self.embryos.setdefault( + eid, + { + "embryo_id": eid, + "timepoints": 0, + "is_complete": False, + "detections": {}, + "current_stage": None, + }, + ) findings = data.get("findings") or {} emb["last_intensity_level"] = findings.get("intensity_level") emb["last_structure_quality"] = findings.get("structure_quality") @@ -353,13 +396,24 @@ def handle_event(self, event_type: str, data: dict): if findings.get("has_hatched"): emb["hatched"] = True - elif event_type in 
("BURST_QUEUED", "BURST_START", "BURST_FRAME", "BURST_COMPLETE"): + elif event_type in ( + "BURST_QUEUED", + "BURST_START", + "BURST_FRAME", + "BURST_COMPLETE", + ): eid = data.get("embryo_id") if eid: - emb = self.embryos.setdefault(eid, { - "embryo_id": eid, "timepoints": 0, "is_complete": False, - "detections": {}, "current_stage": None, - }) + emb = self.embryos.setdefault( + eid, + { + "embryo_id": eid, + "timepoints": 0, + "is_complete": False, + "detections": {}, + "current_stage": None, + }, + ) emb.setdefault("burst", {}) burst_state = emb["burst"] if event_type == "BURST_QUEUED": @@ -397,7 +451,7 @@ def to_dict(self) -> dict: "embryos": self.embryos, "total_timepoints": self.total_timepoints, "base_interval": self.base_interval, - "detection_reasoning": self.detection_reasoning + "detection_reasoning": self.detection_reasoning, } def reset(self): @@ -424,14 +478,17 @@ def seed_from_experiment(self, experiment) -> int: y = pos.get("y") if isinstance(pos, dict) else None if x is None or y is None: continue - self.handle_event("EMBRYO_DETECTED", { - "embryo_id": eid, - "uid": getattr(emb, "uid", None), - "x": x, - "y": y, - "role": getattr(emb, "role", "test"), - "user_label": getattr(emb, "user_label", None), - "confidence": getattr(emb, "detection_confidence", None), - }) + self.handle_event( + "EMBRYO_DETECTED", + { + "embryo_id": eid, + "uid": getattr(emb, "uid", None), + "x": x, + "y": y, + "role": getattr(emb, "role", "test"), + "user_label": getattr(emb, "user_label", None), + "confidence": getattr(emb, "detection_confidence", None), + }, + ) seeded += 1 return seeded diff --git a/gently/ui/web/volume_helpers.py b/gently/ui/web/volume_helpers.py index e92954ae..2ff479d9 100644 --- a/gently/ui/web/volume_helpers.py +++ b/gently/ui/web/volume_helpers.py @@ -10,11 +10,10 @@ import re from io import BytesIO from pathlib import Path -from typing import Optional import numpy as np -from gently.core.imaging import normalize_to_uint8, image_to_base64 +from 
gently.core.imaging import image_to_base64, normalize_to_uint8 logger = logging.getLogger(__name__) @@ -22,7 +21,7 @@ VOLUME_UID_PATTERN = re.compile(r"volume_(.+)_t(\d+)$") -def parse_volume_uid(uid: str) -> Optional[tuple]: +def parse_volume_uid(uid: str) -> tuple | None: """Parse a volume UID into (embryo_id, timepoint) or return None.""" if not uid.startswith("volume_"): return None @@ -42,9 +41,9 @@ def load_volume_from_disk(volume_path: str) -> np.ndarray: Cropped 3D numpy array (Z, H, W) """ from gently.core.imaging import ( - load_volume, - compute_crop_bounds, apply_crop_bounds, + compute_crop_bounds, + load_volume, ) path = Path(volume_path) @@ -67,7 +66,7 @@ def array_to_png_bytes(img_array: np.ndarray) -> bytes: img_array = normalize_to_uint8(img_array, method="simple") img = Image.fromarray(img_array) buf = BytesIO() - img.save(buf, format='PNG') + img.save(buf, format="PNG") return buf.getvalue() diff --git a/launch_gently.py b/launch_gently.py index 1727dd2a..64efdc15 100644 --- a/launch_gently.py +++ b/launch_gently.py @@ -4,36 +4,58 @@ Conversational AI agent for diSPIM microscope control. +Starts the agent + web visualization server, then opens the browser UI. +The web UI is the control surface (the legacy Ink TUI is retired — its +source is kept in the tree but no longer launched). 
+ Usage: - python launch_gently.py # Ink TUI (default) - python launch_gently.py --offline + python launch_gently.py # Start server + open browser + python launch_gently.py --no-browser # Start server, don't open a browser + python launch_gently.py --offline # Run without the device layer + python launch_gently.py --no-api # UI-only: boot the web UI without an API key python launch_gently.py --sessions # List sessions and exit - python launch_gently.py --resume # Interactive session picker + python launch_gently.py --resume # Resume most recent session python launch_gently.py --resume latest # Resume most recent session python launch_gently.py --resume # Resume specific session python launch_gently.py -v # Verbose (INFO) logging python launch_gently.py --debug # Debug logging """ +import argparse import asyncio -import json import logging import os -import sys import shutil import subprocess -import argparse -from pathlib import Path +import sys from datetime import datetime +from pathlib import Path import yaml -from gently.log_config import configure_logging +# Load a project-root .env (if present) so ANTHROPIC_API_KEY and other +# settings can live in a file instead of being exported every session. +# Existing environment variables take precedence. +try: + from dotenv import load_dotenv + + load_dotenv(Path(__file__).resolve().parent / ".env") +except ImportError: + pass + +# The gently imports below pull in heavy dependencies (anthropic, torch, scipy, +# perception) and take several seconds. Print immediate feedback first so the +# terminal isn't silent during that load. Skipped for -h/--help.
+if not any(flag in sys.argv for flag in ("-h", "--help")): + print("Starting gently — loading modules (this can take a few seconds)...", flush=True) + from gently.app.agent import MicroscopyAgent +from gently.core.file_store import FileStore +from gently.core.log_bridge import configure_log_bridge +from gently.hardware import get_hardware, load_hardware +from gently.log_config import configure_logging from gently.organisms import load_organism -from gently.hardware import load_hardware, get_hardware from gently.settings import settings -from gently.core.file_store import FileStore logger = logging.getLogger(__name__) @@ -65,11 +87,13 @@ def _build_session_items(store: FileStore) -> list: session_id = session.get("session_id", "unknown") embryos = store.list_embryos(session_id) embryo_count = len(embryos) if embryos else 0 - items.append({ - "session_id": session_id, - "embryo_count": embryo_count, - "time": _format_elapsed(session.get("last_active", "")), - }) + items.append( + { + "session_id": session_id, + "embryo_count": embryo_count, + "time": _format_elapsed(session.get("last_active", "")), + } + ) return items @@ -89,10 +113,93 @@ def list_sessions(store: FileStore): print("Use: python launch_gently.py --resume ") +def _print_banner(viz_url, device_connected, offline, storage_dir, log_file, resumed, no_api=False): + """Print a human-readable launch banner to the terminal. + + This is the "what you see when you open it" surface now that the + server (not a TUI) is the long-running process. 
+ """ + line = "─" * 56 + if offline: + dev = "○ offline (--offline)" + elif device_connected: + dev = "● connected" + else: + dev = "○ offline — run: python start_device_layer.py" + agent_status = "○ disabled — UI only (--no-api)" if no_api else "● enabled" + url = viz_url or "(viz server failed to start — check the log)" + tag = " [resumed session]" if resumed else "" + print() + print(f" ✦ Gently is running.{tag}") + print(f" {line}") + print(f" Open: {url}") + print(f" Agent: {agent_status}") + print(f" Device: {dev}") + print(f" Storage: {storage_dir}") + print(f" Logs: {log_file}") + print(" Stop: Ctrl-C") + print(f" {line}") + print() + + +def _open_browser(url: str) -> None: + """Open the web UI, preferring Google Chrome. + + Override with GENTLY_BROWSER (a webbrowser name like 'firefox', or a full + path to a browser executable). Falls back to the OS default browser if + Chrome can't be found, so this never blocks startup. + """ + import webbrowser + + override = os.environ.get("GENTLY_BROWSER", "").strip() + + # 1) Registered browser names (override first, then Chrome aliases). + for name in ([override] if override else []) + [ + "chrome", + "google-chrome", + "chromium", + ]: + try: + webbrowser.get(name).open(url) + return + except Exception: + pass + + # 2) Explicit executables (an override path, then known Chrome locations). + candidates = [override] if override else [] + candidates += [ + shutil.which("chrome"), + r"C:\Program Files\Google\Chrome\Application\chrome.exe", + r"C:\Program Files (x86)\Google\Chrome\Application\chrome.exe", + ] + for exe in candidates: + try: + if exe and Path(exe).exists(): + webbrowser.register( + "gently-browser", + None, + webbrowser.BackgroundBrowser(exe), + preferred=True, + ) + webbrowser.get("gently-browser").open(url) + return + except Exception: + pass + + # 3) Fall back to the OS default. 
+ try: + webbrowser.open(url) + except Exception: + pass + + def run_ink_picker(tui_dist: Path, sessions_json: str) -> str | None: """ Spawn the Ink TUI in session-picker mode and capture the selection. + Retired: kept for reference / potential reuse by a future web session + picker. No longer called by the launcher. + Returns the selected session ID, or None for a new session. """ proc = subprocess.run( @@ -110,20 +217,36 @@ def run_ink_picker(tui_dist: Path, sessions_json: str) -> str | None: # Parse the SESSION: protocol line from stdout for line in (proc.stdout or "").splitlines(): if line.startswith("SESSION:"): - selected = line[len("SESSION:"):].strip() + selected = line[len("SESSION:") :].strip() return selected if selected else None return None -async def main(offline: bool = False, resume_session: str = None, show_sessions: bool = False, pick_session: bool = False, log_level: str = "WARNING"): +async def main( + offline: bool = False, + resume_session: str | None = None, + show_sessions: bool = False, + pick_session: bool = False, + log_level: str = "WARNING", + no_browser: bool = False, + no_api: bool = False, +): # Set up log file in storage directory - storage_base = Path(os.environ.get("GENTLY_STORAGE", "D:/Gently3")) + # Unified with FileStore: logs live under the same root as data + # (settings.storage.base_path reads GENTLY_STORAGE_PATH). Previously this + # read a separate GENTLY_STORAGE env var, so setting only one split logs + # from data. + storage_base = settings.storage.base_path log_dir = storage_base / "logs" log_dir.mkdir(parents=True, exist_ok=True) log_file = str(log_dir / f"gently_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log") # File always gets INFO+, console uses the requested level configure_logging(level=log_level, log_file=log_file) + # Mirror gently / gently_perception log lines onto the EventBus so the + # Events page in the viz server shows them too. 
Env vars control level + # and whether to include noisy third-party loggers (off by default). + configure_log_bridge() logger.info("Logging to %s (console level: %s)", log_file, log_level) # Load organism module from config @@ -142,39 +265,47 @@ async def main(offline: bool = False, resume_session: str = None, show_sessions: # Create unified store (FileStore) early for session queries from gently.core.gently_manifest import write_manifest + write_manifest(storage_dir) store = FileStore(storage_dir) + # ── Accounts / auth ─────────────────────────────────────────── + # Self-managed user accounts gate microscope control on the LAN. On first + # run we bootstrap an admin and print its one-time password in the banner. + # Set GENTLY_NO_AUTH=1 to disable accounts (legacy localhost-control mode). + admin_creds = None + if os.environ.get("GENTLY_NO_AUTH", "").strip().lower() not in ("1", "true", "yes"): + try: + from gently.ui.web.accounts import AccountStore, set_account_store + + account_store = AccountStore(storage_dir / "auth") + set_account_store(account_store) + admin_creds = account_store.bootstrap_admin_if_empty() + except Exception as e: + logger.error("Account store init failed (continuing without auth): %s", e) + # Handle --sessions (just list and exit) if show_sessions: list_sessions(store) store.close() return - # Ensure TUI is available - tui_dist = Path(__file__).parent / "gently" / "tui" / "dist" / "index.js" - if not tui_dist.exists() or not shutil.which("node"): - print("Error: TUI not available.") - if not tui_dist.exists(): - print(" Run: cd gently/tui && npm install && npm run build") - if not shutil.which("node"): - print(" Node.js not found in PATH") - store.close() - return + # Web-only: the TUI is retired. The browser is the control surface and + # the launcher just starts the server — no Node/dist requirement. - # Handle --resume (interactive picker, "latest", or specific session) + # Handle --resume. 
Interactive session picking has moved to the browser; + # without an explicit ID ("latest" or bare --resume) we resume the most + # recent session. session_to_resume = None - if pick_session: - # Two-phase launch: spawn Ink picker to select a session - items = _build_session_items(store) - if not items: - print("No saved sessions found. Starting new session.") - else: - session_to_resume = run_ink_picker(tui_dist, json.dumps(items)) - elif resume_session == "latest": + if pick_session or resume_session == "latest": sessions = store.list_sessions() if sessions: session_to_resume = sessions[0].get("session_id") + if pick_session: + print( + f"Resuming most recent session: {session_to_resume} " + "(interactive session picking is moving into the browser)" + ) else: print("No sessions found - starting fresh") elif resume_session: @@ -193,11 +324,12 @@ async def main(offline: bool = False, resume_session: str = None, show_sessions: if not offline: hw = get_hardware() http_url = f"http://{settings.network.device_host}:{settings.network.device_port}" - if hasattr(hw, 'create_client'): + if hasattr(hw, "create_client"): client = hw.create_client(http_url=http_url) else: # Fallback for hardware modules without create_client from gently.app.queue_server_client import QueueServerClient + client = QueueServerClient(http_url=http_url) connected = await client.connect() if not connected: @@ -208,14 +340,15 @@ async def main(offline: bool = False, resume_session: str = None, show_sessions: # makes Claude hallucinate XML tool calls as plain text. 
logger.debug( "Device layer not reachable at %s — microscope tools " - "available but will return errors until connected", http_url, + "available but will return errors until connected", + http_url, ) # Configure device session for zero-copy volume transfer if client and client.is_connected: try: incoming = str(store.incoming_dir) - resp = await client.configure_device_session(incoming) + await client.configure_device_session(incoming) logger.info("Device session configured: volume_dir=%s", incoming) except Exception as e: logger.error("Failed to configure device session (volumes will be slow): %s", e) @@ -224,6 +357,7 @@ async def main(offline: bool = False, resume_session: str = None, show_sessions: if client and client.is_connected: try: from gently.harness.microscope import register_microscope_tools + n = register_microscope_tools(client) if n: logger.info("Registered %d microscope tools from device layer", n) @@ -236,12 +370,14 @@ async def main(offline: bool = False, resume_session: str = None, show_sessions: storage_path=storage_dir, session_id=session_to_resume, store=store, + no_api=no_api, ) # Generate TLS certificate for mesh communication cert_path, key_path = None, None try: from gently.mesh.tls import ensure_tls_cert, get_cert_fingerprint + _config_dir = Path(__file__).parent / "config" cert_path, key_path = ensure_tls_cert(_config_dir) except Exception: @@ -251,15 +387,20 @@ async def main(offline: bool = False, resume_session: str = None, show_sessions: # self-signed certs trigger browser "unsafe" warnings for visitors). 
await agent.start_viz_server(port=settings.network.viz_port) scheme = "http" - viz_url = f"{scheme}://localhost:{settings.network.viz_port}" if agent.viz_server is not None else None + viz_url = ( + f"{scheme}://localhost:{settings.network.viz_port}" + if agent.viz_server is not None + else None + ) # ── Mesh discovery ────────────────────────────────────────────── mesh = None try: + import uuid as _uuid + from gently.mesh import MeshService, register_mesh_routes from gently.mesh.audit import MeshAuditLog from gently.mesh.pairing import PairingManager - import uuid as _uuid # Persistent instance ID instance_id_path = Path(__file__).parent / "config" / "mesh_instance_id" @@ -285,6 +426,7 @@ def _capability_provider(): # GPU detection — try torch first, fall back to nvidia-smi try: import torch + if torch.cuda.is_available(): caps["has_gpu"] = True caps["gpu_name"] = torch.cuda.get_device_name(0) @@ -294,10 +436,15 @@ def _capability_provider(): except ImportError: try: import subprocess as _sp + out = _sp.check_output( - ["nvidia-smi", "--query-gpu=name,memory.total", - "--format=csv,noheader,nounits"], - timeout=5, text=True, + [ + "nvidia-smi", + "--query-gpu=name,memory.total", + "--format=csv,noheader,nounits", + ], + timeout=5, + text=True, ).strip() if out: parts = out.split(",", 1) @@ -317,6 +464,7 @@ def _capability_provider(): def _status_provider(): import gently as _gently + return { "session_id": agent.session_id or "", "acquisition_status": "idle", @@ -329,6 +477,7 @@ def _status_provider(): } import socket as _socket + config_dir = Path(__file__).parent / "config" audit_log = MeshAuditLog(config_dir) pairing_mgr = PairingManager( @@ -361,27 +510,32 @@ def _status_provider(): await mesh.start() except Exception as e: import logging as _log + _log.getLogger(__name__).warning(f"Mesh discovery failed to start: {e}") mesh = None # ── End mesh ──────────────────────────────────────────────────── # Attach the agent bridge to the viz server from 
gently.harness.bridge import AgentBridge + bridge = AgentBridge(agent) - bridge.set_launch_info({ - "device_connected": client.is_connected if client else False, - "sam_available": client.has_sam if client else False, - "offline": offline or (client is None) or not client.is_connected, - "store_path": str(storage_dir), - "viz_url": viz_url, - "log_path": str(log_file), - "resumed": session_to_resume is not None, - "mesh_service": mesh, - }) + bridge.set_launch_info( + { + "device_connected": client.is_connected if client else False, + "sam_available": client.has_sam if client else False, + "offline": offline or (client is None) or not client.is_connected, + "store_path": str(storage_dir), + "viz_url": viz_url, + "log_path": str(log_file), + "resumed": session_to_resume is not None, + "mesh_service": mesh, + } + ) # Initialize startup wizard (gap-driven onboarding) from gently.harness.memory.file_store import FileContextStore + agent_dir = storage_dir / "agent" context_store = FileContextStore(agent_dir) agent.set_context_store(context_store) @@ -390,34 +544,71 @@ def _status_provider(): if agent.viz_server is not None: agent.viz_server.agent_bridge = bridge agent.viz_server.set_context_store(context_store) + # If launched into an existing session, rehydrate its persisted + # imagery so the galleries/filmstrips show data from the start. + if session_to_resume: + try: + agent.viz_server.rehydrate_session(session_to_resume) + except Exception: + logger.debug("Startup rehydrate failed", exc_info=True) + + # ── Banner + serve ────────────────────────────────────────────── + # The viz server runs in-process (uvicorn in a background task). With + # the TUI retired, the launcher's job is to keep that server alive and + # point the operator at the browser. 
+ _print_banner( + viz_url=viz_url, + device_connected=bool(client and client.is_connected), + offline=offline, + storage_dir=storage_dir, + log_file=log_file, + resumed=session_to_resume is not None, + no_api=no_api, + ) - ws_url = f"ws://localhost:{settings.network.viz_port}/ws/agent" + if admin_creds: + _u, _p = admin_creds + print(" First-run admin account created — sign in at the URL above:") + print(f" username: {_u}") + print(f" password: {_p}") + print(" (Save this now. Add users via the admin API; GENTLY_NO_AUTH=1 disables auth.)\n") - # Spawn the Node.js TUI — it inherits stdin/stdout/stderr so Ink - # takes over the terminal. - tui_proc = subprocess.Popen( - ["node", str(tui_dist), "--ws-url", ws_url], - stdin=sys.stdin, - stdout=sys.stdout, - stderr=sys.stderr, - ) + if viz_url and not no_browser: + _open_browser(viz_url) + # Keep the event loop alive so the in-process viz server keeps serving. + # On Windows the Proactor loop won't surface Ctrl-C while blocked on a + # bare Event().wait(), so install signal handlers and poll on a short + # interval (which also lets a pending KeyboardInterrupt surface). + import signal as _signal + + _loop = asyncio.get_running_loop() + _stop = asyncio.Event() try: - # Wait for TUI to exit (blocks the event loop in a thread so - # the asyncio loop stays responsive for the viz server). - exit_code = await asyncio.get_event_loop().run_in_executor( - None, tui_proc.wait - ) - except (KeyboardInterrupt, asyncio.CancelledError): - tui_proc.terminate() + _loop.add_signal_handler(_signal.SIGINT, _stop.set) + _loop.add_signal_handler(_signal.SIGTERM, _stop.set) + except (NotImplementedError, AttributeError, RuntimeError, ValueError): + # Windows Proactor: add_signal_handler is unsupported — fall back to + # signal.signal, waking the loop via call_soon_threadsafe. 
+ def _sig(*_a): + _loop.call_soon_threadsafe(_stop.set) + try: - tui_proc.wait(timeout=5) - except Exception: + _signal.signal(_signal.SIGINT, _sig) + _signal.signal(_signal.SIGTERM, _sig) + except (ValueError, OSError): pass + + try: + while not _stop.is_set(): + await asyncio.sleep(0.3) + except (KeyboardInterrupt, asyncio.CancelledError): + pass finally: # Suppress noisy CancelledError / overlapped IO errors from # uvicorn during shutdown on Windows. import logging as _logging + _logging.getLogger("uvicorn.error").setLevel(_logging.CRITICAL) _logging.getLogger("uvicorn").setLevel(_logging.CRITICAL) # Cleanup: stop mesh service @@ -436,37 +627,65 @@ def _status_provider(): def cli_main(): """Sync entry point for ``gently`` console script (pyproject.toml).""" - if not os.getenv("ANTHROPIC_API_KEY"): - print("Error: ANTHROPIC_API_KEY not set") - print("Set with: set ANTHROPIC_API_KEY=your-key") - exit(1) - parser = argparse.ArgumentParser(description="Launch Microscopy Agent") parser.add_argument("--offline", action="store_true", help="Run without server connections") + parser.add_argument( + "--no-api", + action="store_true", + help="UI-only mode: boot the web UI without any Anthropic API key. " + "Chat, perception, and plan generation are disabled.", + ) parser.add_argument("--sessions", action="store_true", help="List available sessions and exit") - parser.add_argument("--resume", nargs="?", const="__PICK__", metavar="ID", - help="Resume a session. Without ID: shows picker. With ID: resumes that session.") - parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose (INFO) logging") + parser.add_argument( + "--resume", + nargs="?", + const="__PICK__", + metavar="ID", + help="Resume a session. Without ID: resumes the most recent session. 
With ID: resumes that session.", + ) + parser.add_argument( + "-v", "--verbose", action="store_true", help="Enable verbose (INFO) logging" + ) parser.add_argument("--debug", action="store_true", help="Enable debug logging (most verbose)") + parser.add_argument( + "--no-browser", + action="store_true", + help="Do not auto-open the web UI in a browser", + ) args = parser.parse_args() + # An API key is required unless running in UI-only mode. + if not args.no_api and not os.getenv("ANTHROPIC_API_KEY"): + print("Error: ANTHROPIC_API_KEY not set") + if os.name == "nt": + print("Set with: set ANTHROPIC_API_KEY=your-key") + else: + print("Set with: export ANTHROPIC_API_KEY=your-key") + print("Or add it to a .env file in the project root: ANTHROPIC_API_KEY=your-key") + print("Or run UI-only without a key: python launch_gently.py --no-api") + exit(1) + log_level = "WARNING" if args.verbose: log_level = "INFO" if args.debug: log_level = "DEBUG" - pick_session = (args.resume == "__PICK__") + pick_session = args.resume == "__PICK__" resume_id = args.resume if args.resume and args.resume != "__PICK__" else None try: - asyncio.run(main( - offline=args.offline, - show_sessions=args.sessions, - resume_session=resume_id, - pick_session=pick_session, - log_level=log_level, - )) + asyncio.run( + main( + offline=args.offline, + show_sessions=args.sessions, + resume_session=resume_id, + pick_session=pick_session, + log_level=log_level, + no_browser=args.no_browser, + no_api=args.no_api, + ) + ) except (KeyboardInterrupt, RuntimeError, SystemExit): pass diff --git a/notes/biologist-readiness-plan.md b/notes/biologist-readiness-plan.md new file mode 100644 index 00000000..03c25844 --- /dev/null +++ b/notes/biologist-readiness-plan.md @@ -0,0 +1,342 @@ +# Gently — Biologist-Readiness Plan + +> Engineering plan to make Gently more robust, easier for a non-programmer biologist to operate, +> and to evolve it into a multi-user, web-first microscope control system. 
+> Compiled from a codebase audit (architecture map, complexity audit of all >200-line files in `gently/`, +> robustness + UX review, frontend audit, startup/topology trace, and auth/multi-user ground-truth). + +**Author:** engineering analysis · **Date:** 2026-05-28 · **Horizon:** 1 focused week + a multi-sprint convergence arc + +--- + +## 0. Strategic decisions (already made) + +These are settled and shape everything below: + +1. **Frontend → converge on web-only.** The browser becomes the single surface (a floating agent chat window + the existing rich visuals). The Ink TUI becomes **legacy / maintenance-only** and is retired once the web reaches control parity. → *Do not invest in TUI refactors.* +2. **Processes → keep the two-process split, improve feedback.** The device layer (`start_device_layer.py`) stays a separate process from the agent (`launch_gently.py`) — this isolation is a safety feature, not an accident. Fix the *visibility* of its state, not the topology. +3. **Multi-user → LAN deployment, pluggable auth (no IT dependency to start).** Auth is a thin pluggable layer. Start with **Gently-managed accounts** (or shared/role tokens as an MVP) — needs nothing from institute IT. **Institute SSO (e.g. Janelia/HHMI login via a reverse proxy) is an optional later upgrade** that slots into the same layer if/when IT provides an endpoint. Gently owns the **control arbitration + roles + audit**, regardless of which login backend is used. +4. **Roles → viewers vs operators.** Anyone authenticated can **watch** (today's read-only experience, unchanged). Only **operators** can take control and drive the microscope. **Admins** can force-release and manage roles. +5. **Permission model → an explicit observable-vs-inputable classification.** Every endpoint/WS-message is tagged `observable` (read-only) or `inputable` (control). One registry drives all gating: viewer = observable set; operator-with-lock = observable + inputable. 
Adding a new action forces a classification; the audit log falls out of the `inputable` tag. +6. **Plan shape → balanced.** Interleave robustness/UX hardening with safe, high-value refactors. Bold-but-safe: refactor where features *won't* break; add tests *before* touching anything that might. + +--- + +## 1. Executive summary + +Gently is in **good architectural shape**. The hard parts (async acquisition state machine, hardware-safety code, the LLM loop) are well-factored. The problems that matter are **not "too complex"** — they are a handful of **silent, high-consequence failure modes**, an **opt-in/jargon UX that assumes a programmer**, and the **operational friction** of starting and using a multi-process, dual-frontend system. The web-only + multi-user direction resolves much of the friction *by construction* (e.g. it dissolves the embryo-marking hand-off and removes the Node dependency). + +**Top priorities, in order:** + +1. **Fix the verified, provable bugs** (status-tool KeyError, non-atomic writes, the silent device-down, the env-var split). Low risk, immediate value. +2. **Wire crash/restart auto-resume** — the single biggest data-loss risk; the code already exists but is never called. +3. **Harden transient-failure handling** (device hiccups, perception/Claude outages) so a brief blip doesn't silently end a run or image a dead embryo. +4. **Make state visible** — live device heartbeat, connection banner, liveness line, acquisition-settings panel, armed-rules display. +5. **Begin the web-only + multi-user arc** — browser agent chat, then the auth + single-driver control lock (the control lock must land *with* browser control, not after). + +--- + +## 2. State of the codebase — legitimate vs. accidental complexity + +Most large files are **legitimately large** (broad-but-cohesive domain modules), not tangled. Accidental complexity is concentrated and well-localized. 
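Priority 1 in the executive summary lists non-atomic writes among the provable bugs. A minimal sketch of one common remedy follows: write to a temporary file in the same directory, flush it to disk, then swap it into place with `os.replace`. The helper name `atomic_write_text` is hypothetical and is not taken from the Gently codebase; the actual fix may look different.

```python
import os
import tempfile
from pathlib import Path


def atomic_write_text(path: Path, text: str) -> None:
    """Hypothetical helper: write text so a reader sees the old file or the new one."""
    path = Path(path)
    # Create the temp file next to the target so the final rename stays
    # on a single filesystem.
    fd, tmp_name = tempfile.mkstemp(
        dir=path.parent, prefix=path.name + ".", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            handle.write(text)
            handle.flush()
            os.fsync(handle.fileno())  # push the bytes to disk before the swap
        # os.replace overwrites an existing target on both POSIX and Windows,
        # so no separate unlink step is needed.
        os.replace(tmp_name, path)
    except BaseException:
        # Remove the temp file if anything failed before the swap.
        try:
            os.unlink(tmp_name)
        except FileNotFoundError:
            pass
        raise
```

With this shape, an interruption before `os.replace` leaves the previous file untouched, which is the property a resume path would rely on.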
+ +### Leave alone — legitimate complexity (high feature-break risk) +- `harness/state.py` (979L) — shared mutable `EmbryoState`/`ExperimentState`. Splitting *creates* the duplication the design avoids. **Riskiest refactor target in the repo.** +- `harness/conversation.py` (774L) — core LLM loop (asend-recursion, observed-failure guards). +- `hardware/dispim/devices/*` (stage/optical/scanner/acquisition/camera/piezo) — laser/stage safety constants + MMCore vocab. +- `hardware/dispim/plans/calibration.py` (958L) — irreducible multi-phase calibration state machine. +- `core/imaging.py`, `event_bus.py`, `service.py`; `app/device_state_monitor.py`; `organisms/celegans/stages.py`. + +### Top refactor targets — accidental complexity worth fixing + +| File | Verdict | Risk | Effort | The fix | +|---|---|---|---|---| +| `app/tools/timelapse_tools.py` (815L) | REFACTORABLE | low | ~4h | Contains the confirmed KeyError bug. `@timelapse_tool` decorator kills the 6-line preamble in 17 tools; stop reaching into `orchestrator._embryo_states`. | +| `app/tools/calibration_tools.py` (1504L) | REFACTORABLE | low | ~2h | Delete ~450 lines of **dead code** (`fast_calibrate_embryo`, `hybrid_focus_selection`, `binary_edge_search`, `_fine_focus_sweep`) — unregistered, uncalled, reference nonexistent agent attrs. | +| `harness/bridge.py` (2215L) | REFACTORABLE | med | ~10h | God-object: 720-line `handle_command` if/elif ladder + case-folding bug (lowercases session/embryo IDs). Dispatch table off `CommandRegistry`. **High value for web convergence** — the browser control surface leans on this. | +| `harness/detection/verifier.py` (1158L) | REFACTORABLE | med | ~6h | `verify()`/`verify_with_context()` + two `_evaluate_consensus*` are superset/subset dupes; 5 `_run_*` + 4 `_parse_*` copy-paste. ~250 lines. 
**Capture consensus truth-table fixtures first.** | +| `mesh/peer_client.py` (393L) | REFACTORABLE | low | ~4h | 11 near-identical authed methods → one `_authed_json` helper (~270→~80 lines). | +| `hardware/dispim/claude_client.py` (631L) | REFACTORABLE | low | ~3h | 4 vision methods copy-paste → one `_vision_call`. | +| `harness/memory/file_store.py` (2552L) | MIXED | med | ~10h | Mixin split + shared serde. Lower priority than deleting the SQLite twin. | + +### The dominant *reduction* opportunity — ~4000 lines of dead duplicate code +The **legacy SQLite store stack** is a complete duplicate of the live file stores (CLAUDE.md says "No SQLite databases"): +- `core/store.py` (1064L) twins `core/file_store.py` +- `harness/memory/{store,_intentions,_plans,_understanding,_ml_pipelines}.py` (~2960L) twin `harness/memory/file_store.py` + +Dead in production, pinned only by ~41 tests. Delete **after** migrating tests to the `file_context_store` fixture → ~4000 lines gone, zero runtime change. **Friday work** (gated on test migration). + +--- + +## 3. Verified bugs (confirmed in source, not just inferred) + +| # | Bug | Location | Impact | +|---|---|---|---| +| V1 | `get_timelapse_status` reads `next_embryo`/`next_acquisition_in_seconds` that `to_dict()` never emits → **KeyError every call**. Same dead keys in `detection_tools.py`. | `app/tools/timelapse_tools.py:145-146,154` | Biologist's primary "is it working?" tool is broken. | +| V2 | `load_state()` fully implemented, `save_state()` runs every acquisition — but `load_state()` has **zero callers**. | `app/orchestration/timelapse.py:1643` | **No crash/restart auto-resume.** Overnight crash = whole night lost. | +| V3 | `_write_yaml` does `unlink()` then `rename()`; `save_state()` writes with no temp file. | `core/file_store.py:123-125` | **Non-atomic on Windows** — a power blip corrupts the files `/resume` needs. | +| V4 | Launcher reads `GENTLY_STORAGE`; everything else uses `GENTLY_STORAGE_PATH`. 
| `launch_gently.py:121` | Logs and data silently split to different paths. | +| V5 | Device-layer-down is a `logger.debug` (invisible at default log level). | `launch_gently.py:209-212` | Biologist starts with scope off, gets a normal-looking startup, discovers it mid-conversation. | +| V6 | **XSS / HTML injection** — event key/value (perception prose, paths, agent text) assigned via `innerHTML` with no escaping. | `ui/web/static/js/events.js:69-77, 130-151, 237` | Real injection surface in the events table. `escapeHtml` exists and is used elsewhere. | +| V7 | `/ws/agent` has **no connection guard/lock**; conversation state is a single shared object. | `routes/agent_ws.py:128`, `bridge.py:565`, `agent.py:759` | Latent today (TUI is sole client); **becomes live corruption the moment a browser drives the agent.** Fixed by the control lock (§9). | +| V8 | `bridge.handle_command` does `command.strip().lower()` then branches on it. | `harness/bridge.py:647,696` | Case-sensitive args (session IDs, hostnames, embryo IDs) silently corrupted. | +| V9 | Embryo marking blocks forever; `wait_for_marking(timeout=None)`; TUI never shows the viz URL or signals a browser is needed. | `ui/web/embryo_marker.py:79`, `server.py:481`, `detection_tools.py` | **Worst operational friction** — hangs if no browser is open. Dissolved by web-only convergence. | +| V10 | Marking is global shared state broadcast to all `/ws` clients; any client's `marking_done` clobbers. | `server.py:459-472`, `websocket.py:164-188` | Two browsers marking simultaneously clobber each other. Fixed by driver-only gating (§9). | + +--- + +## 4. Robustness gaps (ranked, for unattended multi-hour sessions) + +1. **[CRITICAL] No crash/restart auto-resume** (V2). `_resume_session` (`manager.py:40-117`) restores embryos+conversation but never the orchestrator or runtime fields (stop_condition, cadence_phase, next_due_at, error_count). +2. 
**[CRITICAL] Device hiccup permanently drops embryos.** `_acquire_embryo` (`timelapse.py:712`) treats network/timeout as terminal; 3 strikes → `complete: errors`. No auto-reconnect in `client.py`. +3. **[CRITICAL] Silent perception/detector outage.** `_run_perception`/`_run_detector` are log-only, no retry, no event. A Claude outage silently freezes stage/hatching detection **while the laser keeps firing.** +4. **[HIGH] Non-atomic writes** (V3). +5. **[HIGH] No abort path for a hung device-layer plan** — one RunEngine, no abort endpoint; one stuck acquisition freezes the wheel each round. +6. **[HIGH] Disk-full silently stops persistence** — `save_state` failures are `logger.debug`. +7. **[HIGH] Orphaned volume TIFFs** — swallowed `register_volume` failure + 300s `cleanup_incoming` race deletes valid volumes. +8. **[HIGH] Unbounded fatal exception kills the whole session** — `_run_loop` top-level except → FAILED, no per-iteration recovery. +9. **[MEDIUM]** Perception task leak / no per-call timeout (`timelapse.py:2706, 2483`). +10. **[MEDIUM]** Startup picker / `wait_for_marking` block forever. +11. **[MEDIUM]** Advisory `session.lock`, no PID check. + +--- + +## 5. Biologist usability gaps (ranked) + +1. **[CRITICAL] "Microscope not connected" is silent** (V5). → persistent banner worded as consequence+fix; live heartbeat dot; `/reconnect`. +2. **[CRITICAL] Phototoxicity protection is opt-in, silent, expert-only** — only arms if Claude is passed `monitoring_mode='expression_monitoring'`. → make it **default** for reporter/hatching experiments; agent states plainly what it armed; show armed rules in plain English. +3. **[CRITICAL] No LLM-independent emergency stop** — pause/stop are only LLM tools. → `/stop` `/pause` that call the orchestrator directly (no API round-trip). +4. **[HIGH] Silent auto-complete / auto-pause** — biologist must inspect `completion_reason`. → push plain-language notice; distinguish hardware-error (offer retry) from biological endpoint. +5. 
**[HIGH] No liveness reassurance.** → "last volume 0:47 ago · next in 1:13" line, yellow/red when stalled. +6. **[HIGH] Marking blocks with no browser cue** (V9). +7. **[HIGH] Cryptic launch hard-stops** (`ANTHROPIC_API_KEY not set`, "TUI not available", Node/npm). +8. **[HIGH] First-run setup landmines** — stale model IDs (`settings.py:55-58`), env-var split, raw `ModuleNotFoundError` on bad organism, README version drift (v0.11.0 vs 0.20.0). → `--doctor` preflight. +9. **[MEDIUM]** Jargon mismatch (campaign/role=test/burst/SAM/photodose). → relabel human-facing strings. +10. **[MEDIUM]** Stop-condition vocabulary mismatch ("pretzel"/"2fold" shown but rejected as targets); casing drift. +11. **[MEDIUM]** Generic error strings (raw `str(e)`/tracebacks reach the biologist). + +--- + +## 6. Frontend audit + +### Web UI (`gently/ui/web`) — the future single surface +- **Stack:** vanilla JS, no build step, FastAPI + Jinja2, Three.js for 3D. ~15k JS / 21 `