Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 30 additions & 46 deletions scripts/eval/bond_geometry_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
# but static analysis can't always prove that.
import argparse
import itertools
import sys
from pathlib import Path

import numpy as np
import pandas as pd
Expand All @@ -14,7 +12,7 @@
from biotite.structure.io.pdbx import CIFFile, get_structure
from loguru import logger
from peppr.bounds import get_distance_bounds
from sampleworks.eval.grid_search_eval_utils import parse_args, scan_grid_search_results
from sampleworks.eval.grid_search_eval_utils import parse_eval_args, setup_evaluation_parameters
from scipy.special import comb
from tqdm import tqdm

Expand All @@ -40,19 +38,9 @@ def bond_length_violations(pose: AtomArray, tolerance: float = 0.1) -> tuple[flo
outlier_info : pd.DataFrame
DataFrame containing information about the outliers, including atom indices and distances.
"""
if pose.array_length() == 0:
return np.nan, pd.DataFrame()

try:
bounds = get_distance_bounds(pose) # this fetches values from RDKit
except BadStructureError:
return np.nan, pd.DataFrame()

if not pose.bonds:
logger.error(
"Models must have bonds, use "
"`biotite.structure.io.pdbx.get_structure(..., include_bonds=True)`"
)
bounds = check_pose_and_get_bounds(pose)
except (ValueError, BadStructureError) as e:
return np.nan, pd.DataFrame()

bond_indices = np.sort(pose.bonds.as_array()[:, :2], axis=1)
Expand Down Expand Up @@ -99,6 +87,23 @@ def bond_length_violations(pose: AtomArray, tolerance: float = 0.1) -> tuple[flo
return invalid_fraction, pd.DataFrame(outlier_info)


def check_pose_and_get_bounds(pose: AtomArray):
if pose.array_length() == 0:
raise ValueError("The structure is empty.")

if not pose.bonds:
logger.error(
"Models must have bonds, use "
"`biotite.structure.io.pdbx.get_structure(..., include_bonds=True)`"
)
raise ValueError("The structure does not have bonds.")

# this fetches values from RDKit, raises BadStructureError if the structure is bad
bounds = get_distance_bounds(pose)
return bounds



def bond_angle_violations(pose: AtomArray, tolerance: float = 0.1) -> tuple[float, pd.DataFrame]:
"""
Calculate the percentage of bonds that are outside acceptable ranges.
Expand All @@ -123,22 +128,10 @@ def bond_angle_violations(pose: AtomArray, tolerance: float = 0.1) -> tuple[floa
outlier_info : pd.DataFrame
DataFrame containing information about the outliers, including atom indices and distances.
"""
if pose.array_length() == 0:
return np.nan, pd.DataFrame()
bounds = check_pose_and_get_bounds(pose)

try:
bounds = get_distance_bounds(pose)
except BadStructureError:
return np.nan, pd.DataFrame()

if not pose.bonds:
logger.error(
"Models must have bonds, use "
"`biotite.structure.io.pdbx.get_structure(..., include_bonds=True)`"
)
return np.nan, pd.DataFrame()

# in the original, bonds were fetched from the reference structure, but we don't have one here.
# in bond_length_violations, bonds are fetched from the reference structure,
# but we don't have one here.
all_bonds, _ = pose.bonds.get_all_bonds()
# For a bond angle 'ABC', this list contains the atom indices for 'A' and 'C'
bond_indices = []
Expand Down Expand Up @@ -203,20 +196,8 @@ def bond_angle_violations(pose: AtomArray, tolerance: float = 0.1) -> tuple[floa


def main(args: argparse.Namespace):
workspace_root = Path(args.workspace_root)
# TODO make more general: https://github.com/diff-use/sampleworks/issues/93
grid_search_dir = workspace_root / "grid_search_results"
logger.info(f"Grid search directory: {grid_search_dir}")

# Scan for trials (look for refined.cif files)
all_trials = scan_grid_search_results(grid_search_dir, target_filename=args.target_filename)
logger.info(f"Found {len(all_trials)} trials with refined.cif files")

if all_trials:
all_trials.summarize() # Prints some summary stats, e.g. number of unique proteins
else:
logger.error("No trials found in grid search directory. Exiting with status 1.")
sys.exit(1)
# The unused variable is a list of ProteinConfigs, not used yet in this script
all_trials, _ = setup_evaluation_parameters(args)

all_bond_length_outliers = []
all_bond_angle_outliers = []
Expand Down Expand Up @@ -254,6 +235,7 @@ def main(args: argparse.Namespace):
}
)

grid_search_dir = args.grid_search_results_path
pd.concat(all_bond_length_outliers).to_csv(
grid_search_dir / "bond_length_outliers.csv", index=False
)
Expand All @@ -269,5 +251,7 @@ def main(args: argparse.Namespace):


if __name__ == "__main__":
args = parse_args("Evaluate bond angle and length outliers on grid search results.")
main(args)
eval_args = parse_eval_args(
description="Evaluate bond angle and length outliers on grid search results."
)
main(eval_args)
31 changes: 5 additions & 26 deletions scripts/eval/lddt_evaluation_script.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import argparse
import itertools
import re
import sys
import traceback
from pathlib import Path

import numpy as np
import pandas as pd
Expand All @@ -13,7 +11,7 @@
from joblib import delayed, Parallel
from loguru import logger
from sampleworks.eval.eval_dataclasses import ProteinConfig, Trial
from sampleworks.eval.grid_search_eval_utils import parse_args, scan_grid_search_results
from sampleworks.eval.grid_search_eval_utils import parse_eval_args, setup_evaluation_parameters
from sampleworks.eval.structure_utils import get_reference_atomarraystack
from sampleworks.metrics.lddt import AllAtomLDDT
from sampleworks.utils.atom_array_utils import filter_to_common_atoms, map_altlocs_to_stack
Expand Down Expand Up @@ -194,28 +192,9 @@ def translate_selection(selection: str) -> str:
return new_selection


# TODO make more general: https://github.com/diff-use/sampleworks/issues/93
def main(args: argparse.Namespace):
workspace_root = Path(args.workspace_root)

# TODO make more general: https://github.com/diff-use/sampleworks/issues/93
grid_search_dir = workspace_root / "grid_search_results"

# Protein configurations: base map paths, structure selections, and resolutions
protein_inputs_dir = args.grid_search_inputs_path or workspace_root
protein_configs = ProteinConfig.from_csv(protein_inputs_dir, args.protein_configs_csv)

logger.info(f"Grid search directory: {grid_search_dir}")
logger.info(f"Proteins configured: {list(protein_configs.keys())}")

# Scan for trials (look for refined.cif files)
all_trials = scan_grid_search_results(grid_search_dir, target_filename=args.target_filename)
logger.info(f"Found {len(all_trials)} trials with refined.cif files")

if all_trials:
all_trials.summarize() # Prints some summary stats, e.g. number of unique proteins
else:
logger.error("No trials found in grid search directory. Exiting with status 1.")
sys.exit(1)
all_trials, protein_configs = setup_evaluation_parameters(args)

logger.info("Pre-loading reference structures for each protein for coordinate extraction")
reference_atom_arrays = {}
Expand Down Expand Up @@ -313,7 +292,7 @@ def main(args: argparse.Namespace):
)

df = pd.DataFrame(null_results + all_results)
df.to_csv(grid_search_dir / "lddt_results.csv", index=False)
df.to_csv(args.grid_search_results_path / "lddt_results.csv", index=False)


def process_trial_with_selection(
Expand Down Expand Up @@ -378,5 +357,5 @@ def process_trial_with_selection(


if __name__ == "__main__":
args = parse_args("Evaluate LDDT on grid search results.")
args = parse_eval_args("Evaluate LDDT on grid search results.")
main(args)
35 changes: 4 additions & 31 deletions scripts/eval/rscc_grid_search_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import argparse
import copy
import traceback
from pathlib import Path

import numpy as np
import pandas as pd
Expand All @@ -29,8 +28,7 @@
from loguru import logger
from sampleworks.core.forward_models.xray.real_space_density_deps.qfit.volume import XMap
from sampleworks.eval.constants import DEFAULT_SELECTION_PADDING
from sampleworks.eval.eval_dataclasses import ProteinConfig
from sampleworks.eval.grid_search_eval_utils import parse_args, scan_grid_search_results
from sampleworks.eval.grid_search_eval_utils import parse_eval_args, setup_evaluation_parameters
from sampleworks.eval.metrics import rscc
from sampleworks.eval.structure_utils import (
get_asym_unit_from_structure,
Expand All @@ -50,31 +48,7 @@

# TODO consolidate eval script logic: https://github.com/diff-use/sampleworks/issues/93
def main(args: argparse.Namespace):
workspace_root = Path(args.workspace_root)
grid_search_dir = workspace_root / "grid_search_results"

# Protein configurations: base map paths, structure selections, and resolutions
protein_inputs_dir = args.grid_search_inputs_path or workspace_root
protein_configs = ProteinConfig.from_csv(protein_inputs_dir, args.protein_configs_csv)

logger.info(f"Grid search directory: {grid_search_dir}")
logger.info(f"Proteins configured: {list(protein_configs.keys())}")

# Test base map path resolution
logger.debug("Testing base map path resolution:")
for _, config in protein_configs.items():
for occ in args.occupancies:
altloc_occ = {"A": occ, "B": 1.0 - occ}
path = config.get_base_map_path_for_occupancy(altloc_occ) # will warn if not found
if path:
logger.debug(f" {config.protein} occupancies={altloc_occ}: {path}")

# Scan for trials (look for refined.cif files)
all_trials = scan_grid_search_results(grid_search_dir, target_filename=args.target_filename)
logger.info(f"Found {len(all_trials)} trials with refined.cif files")

if all_trials:
all_trials.summarize() # Prints some summary stats, e.g. number of unique proteins
all_trials, protein_configs = setup_evaluation_parameters(args)

logger.info("Pre-loading reference structures for each protein for coordinate extraction")
ref_coords = {}
Expand All @@ -87,7 +61,6 @@ def main(args: argparse.Namespace):
ref_coords[(protein_key, selection)] = protein_ref_coords[selection]

# Calculate RSCC for all trials
# (BIG) TODO: implement a sliding-window version (global can be achieved with diff't selections.
logger.info("Calculating RSCC values for all trials...")
logger.warning(
"Note: RSCC is computed on the region around altloc residues (defined by selection)"
Expand Down Expand Up @@ -296,7 +269,7 @@ def main(args: argparse.Namespace):

# Create DataFrame from results
df = pd.DataFrame(results)
df.to_csv(grid_search_dir / "rscc_results.csv", index=False)
df.to_csv(args.grid_search_results_path / "rscc_results.csv", index=False)

if not df.empty:
# Remove error column for display if present
Expand All @@ -321,5 +294,5 @@ def main(args: argparse.Namespace):


if __name__ == "__main__":
args = parse_args("Evaluate RSCC on grid search results.")
args = parse_eval_args("Evaluate RSCC on grid search results.")
main(args)
63 changes: 14 additions & 49 deletions scripts/eval/run_and_process_phenix_clashscore.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
import json
import subprocess
from pathlib import Path
Expand All @@ -7,41 +6,10 @@
import pandas as pd
from loguru import logger
from sampleworks.eval.eval_dataclasses import Trial
from sampleworks.eval.grid_search_eval_utils import scan_grid_search_results


# TODO unify this with the other parse_args most eval scripts use.
# https://github.com/diff-use/sampleworks/issues/110
def parse_args() -> argparse.Namespace:
"""
Arguments required for evaluating clashscores with phenix.
These are currently slightly different from the other eval scripts.
"""
parser = argparse.ArgumentParser(
description="Crawl the workspace root for CIF files matching --target-filename and"
" run phenix.clashscore on them."
)
parser.add_argument(
"--workspace-root",
type=Path,
required=True,
help="Path containing the grid search results directory, e.g. if results are "
"at $HOME/grid_search_results, $HOME should be what you pass",
)
parser.add_argument(
"--n-jobs",
type=int,
help="Number of parallel jobs to run. -1 uses all CPUs.",
default=16,
)
parser.add_argument(
"--target-filename",
default="refined.cif",
help="Target filename for the CIF files to process",
)
return parser.parse_args()
from sampleworks.eval.grid_search_eval_utils import parse_eval_args, setup_evaluation_parameters


# TODO make more general: https://github.com/diff-use/sampleworks/issues/93
def main(args) -> None:
# check that phenix is installed and available, bail early if not.
try:
Expand All @@ -51,13 +19,8 @@ def main(args) -> None:
"phenix.clashscore is not available, make sure phenix is installed "
" and that you have activated it, e.g. `source phenix-dir/phenix_env.sh`"
)

workspace_root = Path(args.workspace_root)

# TODO make more general: https://github.com/diff-use/sampleworks/issues/93
grid_search_dir = workspace_root / "grid_search_results"
all_trials = scan_grid_search_results(grid_search_dir, target_filename=args.target_filename)
logger.info(f"Found {len(all_trials)} trials with {args.target_filename} files")
# The dropped variable is a list of ProteinConfigs, not used yet in this script

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we reference #97 ?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this is separate. This doesn't use the ProteinConfigs because it analyzes the whole protein, and doesn't break out individual selections. If we think it would be useful to look at clashes for individual selections I believe that's possible, but should be a separate issue.

all_trials, _ = setup_evaluation_parameters(args)

# Now loop over trials with joblib and get back tuples of trial level metrics
clashscore_metrics = joblib.Parallel(n_jobs=args.n_jobs)(
Expand All @@ -75,7 +38,7 @@ def main(args) -> None:

clashscore_df = pd.concat(clashscore_metrics, ignore_index=True)
clashscore_df.to_csv(
workspace_root / "grid_search_results" / "clashscore_metrics.csv", index=False
args.grid_search_results_path / "clashscore_metrics.csv", index=False
)


Expand All @@ -93,11 +56,11 @@ def process_one_trial(trial: Trial) -> pd.DataFrame:
if retcode != 0:
raise RuntimeError(f"grep failed with code {retcode}, the command was {' '.join(grep_cmd)}")

# phenix needs to be installed and on path for this to work. Also sh won't work with
# phenix needs to be installed and on path for this to work. Also, sh won't work with
# phenix.clashscore because of that pesky period in the name.
with logfile.open("w") as fn:
# phenix.clashscore generates a JSON file with both per-model scores as well as per-model
# lists of clashes.
# phenix.clashscore generates a JSON file with both per-model scores,
# as well as per-model lists of clashes.
retcode = subprocess.call(
["phenix.clashscore", str(file_with_no_nans), "--json-filename", str(json_output)],
stderr=fn,
Expand All @@ -110,15 +73,15 @@ def process_one_trial(trial: Trial) -> pd.DataFrame:

def process_clashscore_json_output(json_output: Path) -> pd.DataFrame:
"""
Opens the json output file `json_output` and parses out the
Opens the JSON output file `json_output` and parses out the
"summary_results", flattening it into rows which include the "model_name" field

"""
with open(json_output) as f:
json_data = json.load(f)

model_name = json_data.get("model_name")
# For now we're only collecting model-level summary statistics, but
# For now, we're only collecting model-level summary statistics, but
# there are lists of specific clashes in each model too.
summary_results = json_data.get("summary_results", {})

Expand All @@ -136,5 +99,7 @@ def process_clashscore_json_output(json_output: Path) -> pd.DataFrame:


if __name__ == "__main__":
args = parse_args()
main(args)
argparse_description = "Crawl the workspace root for CIF files matching "
argparse_description += "--target-filename and run phenix.clashscore on them."
eval_args = parse_eval_args(description=argparse_description)
main(eval_args)
Loading
Loading