feat(rewards): add reciprocal-space reward (Fprotein from SFC)#272
DorisMai wants to merge 10 commits into
📝 Walkthrough
Changes
Structure Factor Reward Function
atomarray_to_gemmi shared utility
Synthetic SF generation multi-label MTZ refactor
Reward function contract test suite
Sequence Diagram(s)
sequenceDiagram
participant Caller
participant StructureFactorRewardFunction
participant _detect_mtz_metadata
participant atomarray_to_gemmi
participant SFcalculator
rect rgba(100, 149, 237, 0.5)
note over Caller,StructureFactorRewardFunction: Phase 1 — Configuration
Caller->>StructureFactorRewardFunction: __init__(mtzfile, resolution, loss, ...)
StructureFactorRewardFunction->>_detect_mtz_metadata: read unit_cell, expcolumns from MTZ
_detect_mtz_metadata-->>StructureFactorRewardFunction: unit_cell, space_group, expcolumns
end
rect rgba(60, 179, 113, 0.5)
note over Caller,SFcalculator: Phase 2 — Preparation
Caller->>StructureFactorRewardFunction: prepare(atom_array)
StructureFactorRewardFunction->>atomarray_to_gemmi: convert AtomArray → gemmi.Structure
atomarray_to_gemmi-->>StructureFactorRewardFunction: gemmi.Structure
StructureFactorRewardFunction->>SFcalculator: __init__(pdbmodel=PDBParser(gemmi_structure))
StructureFactorRewardFunction->>SFcalculator: inspect_data()
SFcalculator-->>StructureFactorRewardFunction: Fo, Eo, outlier flags
StructureFactorRewardFunction-->>Caller: ready (reflection mask built)
end
rect rgba(210, 105, 30, 0.5)
note over Caller,SFcalculator: Phase 3 — Forward pass
Caller->>StructureFactorRewardFunction: __call__(coordinates, elements, b_factors, occupancies)
StructureFactorRewardFunction->>SFcalculator: write occupancy/B, calc_fprotein_batch
SFcalculator-->>StructureFactorRewardFunction: per-reflection complex SFs
StructureFactorRewardFunction->>StructureFactorRewardFunction: sum ensemble, compute amplitudes, apply loss mask
StructureFactorRewardFunction-->>Caller: scalar loss tensor
end
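The Phase 3 step "sum ensemble, compute amplitudes, apply loss mask" can be sketched in plain Python as below. This is only an illustration under stated assumptions: the helper name `ensemble_amplitude_loss`, the unweighted complex sum over ensemble members, and the masked mean-squared-error loss are hypothetical choices, and the real reward obtains its per-reflection structure factors from SFcalculator rather than from hand-written lists.

```python
def ensemble_amplitude_loss(f_members, f_obs, mask):
    """Hypothetical helper: masked mean-squared error on ensemble amplitudes.

    f_members is a list of per-member lists of complex structure factors,
    f_obs is a list of observed amplitudes, mask is a list of booleans.
    """
    n_reflections = len(f_obs)
    # Sum the complex structure factors first, so member phases can interfere.
    f_total = [sum(member[i] for member in f_members) for i in range(n_reflections)]
    # Only then take amplitudes of the summed structure factors.
    amplitudes = [abs(f) for f in f_total]
    # Keep only reflections that pass the mask (assumed: True means "use it").
    residuals = [amplitudes[i] - f_obs[i] for i in range(n_reflections) if mask[i]]
    if not residuals:
        raise ValueError("mask excludes every reflection")
    return sum(r * r for r in residuals) / len(residuals)


# Two ensemble members, three reflections; the third reflection is masked out.
f_members = [
    [1 + 1j, 2 + 0j, 0 + 3j],
    [1 - 1j, 0 + 2j, 0 - 3j],
]
f_obs = [2.0, 3.0, 5.0]
mask = [True, True, False]
loss = ensemble_amplitude_loss(f_members, f_obs, mask)
```

Summing before taking the modulus matters: the third reflection cancels to zero in the complex sum even though each member has amplitude 3.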
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 7
🧹 Nitpick comments (2)
src/sampleworks/eval/generate_synthetic_sf.py (1)
137-145: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Use a NumPy-style docstring for this new helper.
The summary is useful, but this new function is missing the required Parameters/Returns sections. As per coding guidelines, “Always include NumPy-style docstrings for every function and class.”
Proposed docstring update
-    """Build a one-amplitude rs.DataSet with labelled F / SIGF / PHIF columns.
-
-    ``sfc.prepare_dataset`` returns an amplitude column and a phase column (degrees)
-    for the given ``structure_factor_column`` attribute. We auto-detect those by MTZ
-    dtype (rather than assuming the unexposed ``FMODEL`` / ``PHIFMODEL`` names),
-    rename them to ``F{label}`` / ``PHIF{label}``, and synthesize a ``SIGF{label}``
-    column so several structure-factor sets (e.g. protein and total) can coexist in
-    one MTZ.
-    """
+    """Build a one-amplitude dataset with labelled F / SIGF / PHIF columns.
+
+    Parameters
+    ----------
+    sfc
+        Structure-factor calculator containing the requested ASU amplitudes.
+    label
+        Output column label suffix, e.g. ``protein`` or ``total``.
+    structure_factor_column
+        SFcalculator attribute passed to ``prepare_dataset``.
+    miller_index_column
+        SFcalculator attribute containing Miller indices.
+    sigma_f_scale
+        Scale factor used to synthesize dummy SIGF values from amplitudes.
+
+    Returns
+    -------
+    rs.DataSet
+        Dataset containing ``F{label}``, ``SIGF{label}``, and ``PHIF{label}``.
+    """
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@src/sampleworks/eval/generate_synthetic_sf.py` around lines 137 - 145, The new helper in generate_synthetic_sf.py has a summary docstring but is missing the required NumPy-style structure. Update the docstring for the helper that builds the rs.DataSet to use NumPy format with explicit Parameters and Returns sections, documenting each input and the returned dataset/columns clearly; keep the existing behavior unchanged and ensure the docstring matches the function name and its role in auto-detecting and renaming the structure-factor columns.
Source: Coding guidelines
tests/rewards/test_reward_function_contract.py (1)
47-66: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Freeze `RewardCase` before sharing it across tests.
This bundle is passed around as shared state, so leaving the dataclass mutable makes accidental test-side mutation hard to spot. `@dataclass(frozen=True)` matches the repo's immutable-state convention and still works with `batch()`. As per coding guidelines, "Use frozen dataclasses with functional updates for immutable state management."
Suggested change
-@dataclass
+@dataclass(frozen=True)
 class RewardCase:
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/rewards/test_reward_function_contract.py` around lines 47 - 66, Make RewardCase immutable by marking the RewardCase dataclass as frozen so shared test state cannot be mutated accidentally. Update the RewardCase definition to use a frozen dataclass while keeping batch() unchanged, since it only reads fields and still works with immutable instances. Use the existing RewardCase symbol in tests/rewards/test_reward_function_contract.py to locate the class and apply the repo’s immutable-state convention.
Source: Coding guidelines
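For reference, the frozen-dataclass pattern suggested here behaves roughly as in the sketch below. The class `RewardCaseSketch` and its fields are hypothetical stand-ins, since the real `RewardCase` fields are not shown in this review; only the `dataclasses` behaviour itself is standard.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class RewardCaseSketch:
    """Hypothetical stand-in for RewardCase; the field names are assumptions."""

    name: str
    n_atoms: int

    def batch(self, size):
        # Read-only use of fields keeps working on a frozen instance.
        return [self.n_atoms] * size


case = RewardCaseSketch(name="toy", n_atoms=4)

# Attribute assignment on a frozen dataclass raises instead of mutating.
try:
    case.n_atoms = 8
    mutated = True
except FrozenInstanceError:
    mutated = False

# Functional update: build a modified copy and leave the original untouched.
bigger = replace(case, n_atoms=8)
```

A test that needs a variant of a shared case would call `replace(...)` rather than editing the shared instance in place.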
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/sampleworks/core/rewards/structure_factor.py`:
- Around line 152-163: Validate the batch_partition argument eagerly in the
structure_factor constructor before assigning it on self, and raise a clear
ValueError when it is zero or negative. Add the check in the initializer where
device, mtzfile, and batch_partition are set so invalid OOM-tuning input is
rejected immediately instead of failing later in calc_fprotein_batch().
- Around line 69-78: The auto-detection in structure_factor.py is ambiguous
because it independently chooses the first amplitude and sigma columns, which
can silently pair the wrong MTZ labels in multi-label datasets. Update the
selection logic in the structure-factor loading path to either require explicit
expcolumns when more than one StructureFactorAmplitudeDtype or
StandardDeviationDtype candidate exists, or ensure the chosen sigma column is
matched to the selected amplitude label in the same ds.select_mtzdtype/return
path.
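The two structure_factor.py findings above could be addressed along the lines of this sketch. It is not the PR's code: the helper names `validate_batch_partition` and `pick_amplitude_sigma_pair` are hypothetical, `batch_partition` is assumed to be a plain positive integer, and the sigma column is assumed to be named `SIG` plus the amplitude label (following the `F{label}` / `SIGF{label}` naming used elsewhere in this PR).

```python
def validate_batch_partition(batch_partition):
    """Hypothetical eager check, meant to run inside the constructor."""
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(batch_partition, bool) or not isinstance(batch_partition, int):
        raise ValueError(f"batch_partition must be an int, got {batch_partition!r}")
    if batch_partition <= 0:
        raise ValueError(f"batch_partition must be positive, got {batch_partition}")
    return batch_partition


def pick_amplitude_sigma_pair(amplitude_labels, sigma_labels):
    """Hypothetical selection that never pairs columns independently.

    Refuses to guess when several amplitude columns exist, and otherwise
    requires the sigma column whose name matches the chosen amplitude.
    """
    if len(amplitude_labels) != 1:
        raise ValueError(
            f"expected exactly one amplitude column, found {list(amplitude_labels)}; "
            "pass expcolumns explicitly"
        )
    amplitude = amplitude_labels[0]
    expected_sigma = "SIG" + amplitude  # assumed naming convention
    if expected_sigma not in sigma_labels:
        raise ValueError(f"no sigma column {expected_sigma!r} for {amplitude!r}")
    return amplitude, expected_sigma


pair = pick_amplitude_sigma_pair(["Fprotein"], ["SIGFtotal", "SIGFprotein"])
```

Note that the matched lookup picks `SIGFprotein` even though `SIGFtotal` comes first in the list, which is the silent mispairing the review comment warns about.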
In `@src/sampleworks/eval/generate_synthetic_sf.py`:
- Around line 543-547: The help text in generate_synthetic_sf should not escape
the braces around label because it is a plain string, not an f-string. Update
the help text in the argument definition near the existing bulk solvent option
so users see F{label}/SIGF{label}/PHIF{label} in --help, and verify the same
wording is used consistently wherever that option text is defined.
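The brace behaviour behind this comment can be checked with a small argparse sketch. The parser name and the `--example-flag` option are hypothetical and do not correspond to the script's real options; the point is only that a plain string keeps doubled braces verbatim, while argparse renders single braces unchanged in `--help`.

```python
import argparse

# In a plain (non-f) string, doubled braces are NOT collapsed.
escaped = "F{{label}}"
# In an f-string, doubled braces collapse to single literal braces.
label = "protein"
formatted = f"F{{label}} becomes F{label}"

parser = argparse.ArgumentParser(prog="generate_synthetic_sf_sketch")
parser.add_argument(
    "--example-flag",  # hypothetical option, for illustration only
    action="store_true",
    help="Write F{label}/SIGF{label}/PHIF{label} columns for each label.",
)
help_text = parser.format_help()
```

argparse expands help strings with `%`-style formatting, so single braces pass through untouched; a literal percent sign is the character that would need escaping there.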
In `@tests/eval/test_generate_synthetic_sf.py`:
- Around line 135-139: The round-trip test in test_generate_synthetic_sf is only
checking that each non-blank altloc label exists, so it can miss cases where
some atoms lose their altloc assignment. Update the assertion near
find_all_altloc_ids/loaded to verify multiplicity as well, either by comparing
per-label counts for altloc_id or by comparing the full altloc_id annotation
when order is stable. Keep the check black-box by asserting observable
annotation behavior rather than implementation details.
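The difference between checking label presence and checking multiplicity can be illustrated with `collections.Counter`. The lists below are made-up altloc annotations, and the helper `altloc_counts` and its set of blank markers are assumptions for the sketch, not the test's actual fixtures.

```python
from collections import Counter


def altloc_counts(altloc_ids, blank=("", ".", " ")):
    """Hypothetical helper: count atoms per non-blank altloc label."""
    return Counter(a for a in altloc_ids if a not in blank)


original = ["A", "A", "B", "B", "", ""]
# A lossy round trip where one atom silently lost its "A" assignment.
lossy = ["A", "", "B", "B", "", ""]

# Presence-only check: passes even though an atom lost its altloc.
labels_match = set(altloc_counts(original)) == set(altloc_counts(lossy))
# Multiplicity check: catches the regression.
counts_match = altloc_counts(original) == altloc_counts(lossy)
```

When atom order is stable across the round trip, comparing the full annotation lists directly is stricter still.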
In `@tests/rewards/reward_input_helpers.py`:
- Around line 7-13: The shared helper build_scattering_indices() still has a
prose-only docstring, so update it to the repository’s NumPy-style format. Add
the standard sections for parameters, returns, and any relevant notes/details so
the contract is explicit for the reward tests that reuse it, while keeping the
behavior unchanged.
In `@tests/rewards/test_reward_function_contract.py`:
- Around line 248-270: The debug-only `test_gradient_descent_loss_trace` in
`test_reward_function_contract.py` should be removed from collected tests
because it only prints loss values and never asserts anything. Delete this
`test_...` method from `TestRewardFunctionContract`, or if you want to keep the
trace, rename it to a non-test helper so pytest won’t run it; make sure no
`print`-only optimization loop remains in the test suite.
- Around line 39-44: The module-level test marker in the reward contract suite
should include both GPU and slow tagging. Update the existing pytestmark
assignment in the test module so the suite remains GPU-marked via
pytest.mark.gpu while also adding pytest.mark.slow, keeping the change localized
to the module-level marker used by these reward contract tests.
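If the loss trace is worth keeping as a test rather than deleting it, one option is to assert on the trace instead of printing it. The sketch below uses a toy quadratic loss with a hand-written gradient; `loss_fn`, `grad_fn`, and `run_gradient_descent` are hypothetical names, and the real test would use the reward function under test together with its own optimizer.

```python
def loss_fn(x):
    # Toy stand-in for a reward loss: minimum at x = 3.
    return (x - 3.0) ** 2


def grad_fn(x):
    # Analytic gradient of the toy loss.
    return 2.0 * (x - 3.0)


def run_gradient_descent(x0, learning_rate, steps):
    """Return the loss trace, including the initial loss."""
    x = x0
    trace = [loss_fn(x)]
    for _ in range(steps):
        x -= learning_rate * grad_fn(x)
        trace.append(loss_fn(x))
    return trace


trace = run_gradient_descent(x0=0.0, learning_rate=0.1, steps=20)
# The assertions replace the print-only loop: the loss must actually go down.
assert trace[-1] < trace[0]
assert all(later <= earlier for earlier, later in zip(trace, trace[1:]))
```

With this shape, a failing optimization shows up as a test failure instead of output that nobody reads.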
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b9ae1d00-b4d9-46d0-87da-383baa1d031f
📒 Files selected for processing (11)
- src/sampleworks/core/rewards/structure_factor.py
- src/sampleworks/eval/generate_synthetic_sf.py
- src/sampleworks/eval/synthetic_utils.py
- tests/conftest.py
- tests/eval/test_generate_synthetic_sf.py
- tests/resources/1vme/1vme_final_crystalframe_0.5occA_0.5occB_1.80A.cif
- tests/resources/1vme/1vme_final_crystalframe_0.5occA_0.5occB_1.80A.mtz
- tests/rewards/reward_input_helpers.py
- tests/rewards/test_real_space_density_reward.py
- tests/rewards/test_reward_function_contract.py
- tests/rewards/test_structure_factor_reward.py
…fixing in prepare()
b3f83d6 to de306a5
What changed
- StructureFactorRewardFunction (core/rewards/structure_factor.py) via SFC. Two-phase construction (__init__ config + prepare(atom_array)) as SFC requires knowing topology.
- |Fprotein|
- normalize_amplitude for testing and batch_partition in case of OOM
- eval/generate_synthetic_sf.py to generate test data with both Fprotein and Ftotal in the same mtz for debugging/development and support the round trip of cif --> atomarray --> gemmi --> cif --> atomarray.
- test_real_space_density_reward.py to test_reward_function_contract.py. Real space specific tests that are unused (e.g. vmap related) remain untouched.
- test_structure_factor_reward.py and synthetic test data (1vme cif and mtz) generated from eval/generate_synthetic_sf.py.
Next steps before draft --> ready for review
- |Ftotal| (should be trivial)
Next steps
Summary by CodeRabbit
New Features
Bug Fixes
Tests