From 6eda26cb3cf598489ff76717a3036bcd7eeaeafb Mon Sep 17 00:00:00 2001 From: Ishaan Mathur Date: Tue, 16 Jun 2026 21:18:56 +0000 Subject: [PATCH] oss sync --- .pre-commit-config.yaml | 20 +- README.md | 6 +- cookbook/tutorials/binder_design.ipynb | 2 +- cookbook/tutorials/binder_design.py | 8 +- cookbook/tutorials/embed.ipynb | 54 ++---- .../esmc_sae_feature_interpretation.ipynb | 4 +- cookbook/tutorials/gfp_design.ipynb | 2 +- esm/layers/attention.py | 7 +- esm/layers/codebook.py | 25 +-- esm/layers/rotary.py | 22 +-- esm/models/esm3.py | 6 +- esm/models/esmc.py | 22 +-- esm/models/esmfold2/paired_msa.py | 15 +- esm/models/esmfold2/prepare_input.py | 12 +- esm/models/vqvae.py | 6 +- esm/sdk/api.py | 182 ++++++++++++++++-- .../experimental/constrained_generation.py | 2 +- esm/sdk/experimental/guided_generation.py | 8 +- esm/sdk/forge.py | 7 +- esm/tokenization/__init__.py | 8 +- esm/tokenization/residue_tokenizer.py | 10 +- esm/tokenization/sequence_tokenizer.py | 2 +- esm/utils/encoding.py | 6 +- esm/utils/forge_context_manager.py | 2 +- esm/utils/function/interpro.py | 6 +- esm/utils/function/lsh.py | 6 +- esm/utils/generation.py | 11 +- esm/utils/misc.py | 10 +- esm/utils/msa/msa.py | 131 +++++++++++-- esm/utils/msa/msa_test.py | 174 +++++++++++++++++ esm/utils/parsing.py | 2 +- esm/utils/residue_constants.py | 6 +- esm/utils/structure/affine3d.py | 33 ++-- esm/utils/structure/aligner.py | 4 +- esm/utils/structure/input_builder.py | 2 +- esm/utils/structure/metrics.py | 2 +- esm/utils/structure/mmcif_parsing.py | 31 ++- esm/utils/structure/molecular_complex.py | 134 ++++++++++++- esm/utils/structure/normalize_coordinates.py | 2 +- .../structure/predicted_aligned_error.py | 2 +- esm/utils/structure/protein_chain.py | 40 ++-- esm/utils/structure/protein_complex.py | 37 ++-- esm/utils/structure/protein_structure.py | 6 +- esm/utils/system.py | 2 +- esm/widgets/components/results_visualizer.py | 2 +- .../components/sasa_prompt_selector.py | 4 +- .../secondary_structure_prompt_selector.py | 4 +- .../components/sequence_prompt_selector.py | 4 +- .../components/structure_prompt_selector.py | 4 +- esm/widgets/utils/drawing/colors.py | 2 +- .../drawing/draw_function_annotations.py | 2 +- esm/widgets/utils/prompting.py | 4 +- esm/widgets/utils/protein_import.py | 2 +- esm/widgets/views/esm3_prompt_preview.py | 2 +- pixi.lock | 177 +++++++---------- pyproject.toml | 40 +++- 56 files changed, 964 insertions(+), 362 deletions(-) create mode 100644 esm/utils/msa/msa_test.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a127f34c..14f1cb70 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,17 +22,17 @@ repos: args: [ --fix ] - id: ruff-format # formatter types_or: [python, jupyter] -- repo: https://github.com/RobertCraigie/pyright-python - rev: v1.1.399 +- repo: local hooks: - - id: pyright - name: pyright - entry: pyright - language: system - types: [python] - pass_filenames: true # For speed, we only check the files that are changed - # Modal-app tutorial: deps (modal, abnumber) and dynamic decorators aren't resolvable in the lint env. - exclude: ^cookbook/tutorials/binder_design\.py$ + - id: ty + name: ty + entry: ty check + language: system # ty is a pixi dev dep (pyproject.toml [tool.pixi.feature.dev]); pre-commit runs in the pixi env + # pass_filenames: false — ty checks the whole project (it needs the full module graph and + # is fast enough), and per-file invocation would bypass [tool.ty.src].exclude (astral-sh/ty#269). + pass_filenames: false + always_run: true + require_serial: true - repo: https://github.com/gitleaks/gitleaks rev: v8.24.2 hooks: diff --git a/README.md b/README.md index 40932d2f..7b6ddb3e 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ # A world model of protein biology: ESMC, ESMFold2, & ESM Atlas -[ESMC & ESMFold2 Preprint](https://biohub.ai/papers/esm_protein.pdf) ⋅ [Atlas](https://biohub.ai/esm/protein/atlas) ⋅ [Tutorials](https://github.com/Biohub/esm/tree/main/cookbook/tutorials) ⋅ [Slack](https://bit.ly/esm-slack)
+[ESMC & ESMFold2 Preprint](https://www.biorxiv.org/content/10.64898/2026.06.03.729735) ⋅ [Atlas](https://biohub.ai/esm/protein/atlas) ⋅ [Tutorials](https://github.com/Biohub/esm/tree/main/cookbook/tutorials) ⋅ [Slack](https://bit.ly/esm-slack)
We are releasing a world model for protein biology: a scientific engine for prediction, design, and discovery. Built on the latest generation of Evolutionary Scale Modeling (ESM), this system learns from the protein sequences produced by evolution and uses that knowledge to represent, map, predict, and design proteins across scales — from atomic interactions to evolutionary relationships spanning billions of years. The system includes three artifacts: ESMC, ESMFold2, and ESM Atlas. @@ -25,7 +25,7 @@ We are releasing a world model for protein biology: a scientific engine for pred -ESMFold2 is validated in the lab across five therapeutic targets. Inversion of ESMFold2 enables generation of de novo minibinders and antibody-derived scFvs with high hit rates, nanomolar affinities, target specificity, and functional activity. We've released the full protocol from target sequence to ranked binder design in this [notebook](https://github.com/Biohub/esm/blob/main/cookbook/tutorials/binder_design.ipynb). For additional details, please refer to the [preprint](https://biohub.ai/papers/esm_protein.pdf). +ESMFold2 is validated in the lab across five therapeutic targets. Inversion of ESMFold2 enables generation of de novo minibinders and antibody-derived scFvs with high hit rates, nanomolar affinities, target specificity, and functional activity. We've released the full protocol from target sequence to ranked binder design in this [notebook](https://github.com/Biohub/esm/blob/main/cookbook/tutorials/binder_design.ipynb). For additional details, please refer to the [preprint](https://www.biorxiv.org/content/10.64898/2026.06.03.729735).
@@ -320,7 +320,7 @@ If you use ESM in your work, please cite one of the following: and Pannu, Jassi and Bachas, Sharrol and Liu, Daniel S. and Sercu, Tom and Rives, Alexander}, year = {2026}, - url = {https://biohub.ai/papers/esm_protein.pdf}, + url = {https://www.biorxiv.org/content/10.64898/2026.06.03.729735}, note = {Preprint} } ``` diff --git a/cookbook/tutorials/binder_design.ipynb b/cookbook/tutorials/binder_design.ipynb index 7ec1b867..be4ddd1c 100644 --- a/cookbook/tutorials/binder_design.ipynb +++ b/cookbook/tutorials/binder_design.ipynb @@ -7,7 +7,7 @@ "source": [ "## [Tutorial](https://github.com/biohub/esm/tree/main/cookbook/tutorials): How to run minibinder + scFv design fully end-to-end.\n", "\n", - "In this notebook we will use [Modal](https://modal.com/) to parallelize binder design and synthesize a selection, using the protocol described in the ESMC and ESMFold2 paper titled [\"Language Modeling Materializes a World Model of Protein Biology\"](https://biohub.ai/papers/esm_protein.pdf).\n", + "In this notebook we will use [Modal](https://modal.com/) to parallelize binder design and synthesize a selection, using the protocol described in the ESMC and ESMFold2 paper titled [\"Language Modeling Materializes a World Model of Protein Biology\"](https://www.biorxiv.org/content/10.64898/2026.06.03.729735).\n", "\n", "Biohub used this approach to design minibinders and scFvs against five therapeutically relevant targets — PDGFRB, EGFR, PD-L1, CD45, and CTLA4 — spanning receptor tyrosine kinases, immune checkpoints, and cell-surface phosphatases. Binders exhibit nanomolar affinity, target specificity, and functional activity in laboratory assays.\n", "\n", diff --git a/cookbook/tutorials/binder_design.py b/cookbook/tutorials/binder_design.py index 8040f164..3fd97049 100644 --- a/cookbook/tutorials/binder_design.py +++ b/cookbook/tutorials/binder_design.py @@ -9,7 +9,7 @@ """ Code for binder design with ESMFold2 and ESMC. -As described in [Language Modeling Materializes a World Model of Protein Biology](https://biohub.ai/papers/esm_protein.pdf). +As described in [Language Modeling Materializes a World Model of Protein Biology](https://www.biorxiv.org/content/10.64898/2026.06.03.729735). """ import logging @@ -1051,7 +1051,7 @@ def _apply_torch_compile(model: torch.nn.Module) -> None: def _maybe_compile_module(module: torch.nn.Module) -> None: if not isinstance(module, compile_targets): return - module.forward = torch.compile(module.forward) # pyright: ignore + module.forward = torch.compile(module.forward) # ty:ignore[invalid-assignment] model.apply(_maybe_compile_module) @@ -1206,7 +1206,9 @@ def main( app.load(use_scaling_critics) run_fn = app.design else: - app = ESMFold2DesignModal(use_scaling_critics=use_scaling_critics) + app = ESMFold2DesignModal( + use_scaling_critics=use_scaling_critics # ty:ignore[unknown-argument] + ) run_fn = app.design.remote seq, trajectory, results = run_fn( diff --git a/cookbook/tutorials/embed.ipynb b/cookbook/tutorials/embed.ipynb index 4b1596da..38678117 100644 --- a/cookbook/tutorials/embed.ipynb +++ b/cookbook/tutorials/embed.ipynb @@ -25,7 +25,8 @@ "# If you are working in colab, uncomment these lines to install dependencies\n", "#! pip install esm@git+https://github.com/Biohub/esm.git@main\n", "#! pip install matplotlib\n", - "#! pip install seaborn" + "#! pip install seaborn\n", + "#! pip install remotezip" ] }, { @@ -137,39 +138,24 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "--2026-05-28 16:44:15-- https://docs.google.com/uc?export=download&id=1SpOkL11MJxIgy99dqufvUNJuCiuhxuyg\n", - "Resolving docs.google.com (docs.google.com)... 142.251.210.78\n", - "Connecting to docs.google.com (docs.google.com)|142.251.210.78|:443... connected.\n", - "HTTP request sent, awaiting response... 303 See Other\n", - "Location: https://drive.usercontent.google.com/download?id=1SpOkL11MJxIgy99dqufvUNJuCiuhxuyg&export=download [following]\n", - "--2026-05-28 16:44:15-- https://drive.usercontent.google.com/download?id=1SpOkL11MJxIgy99dqufvUNJuCiuhxuyg&export=download\n", - "Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 192.178.50.65\n", - "Connecting to drive.usercontent.google.com (drive.usercontent.google.com)|192.178.50.65|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 43132 (42K) [application/octet-stream]\n", - "Saving to: 'adk.csv'\n", - "\n", - "adk.csv 100%[===================>] 42.12K --.-KB/s in 0.02s \n", - "\n", - "2026-05-28 16:44:16 (2.17 MB/s) - 'adk.csv' saved [43132/43132]\n", - "\n" - ] - } - ], + "outputs": [], "source": [ - "!wget --no-check-certificate \"https://docs.google.com/uc?export=download&id=1SpOkL11MJxIgy99dqufvUNJuCiuhxuyg\" -O adk.csv" + "# Stream just the one CSV we need out of the 200MB archive, no full download.\n", + "from remotezip import RemoteZip\n", + "\n", + "DATA_URL = \"https://zenodo.org/records/15022271/files/data.zip\"\n", + "MEMBER = \"data/adk_ml_dataset.csv\"\n", + "\n", + "with RemoteZip(DATA_URL) as zf:\n", + " with zf.open(MEMBER) as src, open(\"adk_ml_dataset.csv\", \"wb\") as dst:\n", + " dst.write(src.read())" ] }, { "cell_type": "code", - "execution_count": 28, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -177,7 +163,7 @@ "import pandas as pd\n", "import seaborn as sns\n", "\n", - "adk_path = \"adk.csv\"\n", + "adk_path = \"adk_ml_dataset.csv\"\n", "df = pd.read_csv(adk_path)\n", "df = df[[\"org_name\", \"sequence\", \"lid_type\", \"temperature\"]]\n", "df = df[df[\"lid_type\"] != \"other\"] # drop one structural class for simplicity" @@ -231,7 +217,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -244,7 +230,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -322,9 +308,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Pixi (ESM)", + "display_name": "default", "language": "python", - "name": "pixi-esm" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -336,7 +322,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.12" + "version": "3.12.9" } }, "nbformat": 4, diff --git a/cookbook/tutorials/esmc_sae_feature_interpretation.ipynb b/cookbook/tutorials/esmc_sae_feature_interpretation.ipynb index 7e0f8e35..5791f683 100644 --- a/cookbook/tutorials/esmc_sae_feature_interpretation.ipynb +++ b/cookbook/tutorials/esmc_sae_feature_interpretation.ipynb @@ -493,7 +493,7 @@ " for pos in top_positions:\n", " if activations[pos] > 0:\n", " ax.annotate(\n", - " f\"{sequence[pos]}{pos+1}\",\n", + " f\"{sequence[pos]}{pos + 1}\",\n", " (pos + 1, activations[pos]),\n", " textcoords=\"offset points\",\n", " xytext=(0, 6),\n", @@ -508,7 +508,7 @@ " top_positions = np.argsort(activations)[::-1][:5]\n", " for pos in top_positions:\n", " if activations[pos] > 0:\n", - " print(f\" Position {pos+1} ({sequence[pos]}): {activations[pos]:.3f}\")\n", + " print(f\" Position {pos + 1} ({sequence[pos]}): {activations[pos]:.3f}\")\n", " print()" ] }, diff --git a/cookbook/tutorials/gfp_design.ipynb b/cookbook/tutorials/gfp_design.ipynb index f5e4140b..550434b3 100644 --- a/cookbook/tutorials/gfp_design.ipynb +++ b/cookbook/tutorials/gfp_design.ipynb @@ -432,7 +432,7 @@ "alignment = alignments[0]\n", "\n", "identity = align.get_sequence_identity(alignment)\n", - "print(f\"Sequence identity: {100*identity:.2f}%\")\n", + "print(f\"Sequence identity: {100 * identity:.2f}%\")\n", "\n", "print(\"\\nSequence alignment:\")\n", "fig = pl.figure(figsize=(8.0, 4.0))\n", diff --git a/esm/layers/attention.py b/esm/layers/attention.py index 3a8a602d..718dea5f 100644 --- a/esm/layers/attention.py +++ b/esm/layers/attention.py @@ -8,9 +8,9 @@ from esm.layers.rotary import RotaryEmbedding, TritonRotaryEmbedding try: - from flash_attn import flash_attn_varlen_qkvpacked_func # type: ignore + from flash_attn import flash_attn_varlen_qkvpacked_func except (ImportError, RuntimeError): - flash_attn_varlen_qkvpacked_func = None + flash_attn_varlen_qkvpacked_func = None # ty:ignore[invalid-assignment] class MultiHeadAttention(nn.Module): @@ -135,7 +135,8 @@ def forward( ) qkv_N3HD = self.rotary(qkv_N3HD, cu_seqlens, max_seqlen) - context_NHD = flash_attn_varlen_qkvpacked_func( # type: ignore + assert flash_attn_varlen_qkvpacked_func is not None + context_NHD = flash_attn_varlen_qkvpacked_func( qkv_N3HD, cu_seqlens, max_seqlen, softmax_scale=self.d_head**-0.5 ) context_ND = einops.rearrange(context_NHD, "n h d -> n (h d)") diff --git a/esm/layers/codebook.py b/esm/layers/codebook.py index c2136bf2..14ce37a0 100644 --- a/esm/layers/codebook.py +++ b/esm/layers/codebook.py @@ -6,6 +6,10 @@ class EMACodebook(nn.Module): + embeddings: torch.Tensor + N: torch.Tensor + z_avg: torch.Tensor + def __init__( self, n_codes, @@ -17,7 +21,7 @@ def __init__( super().__init__() self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim)) self.register_buffer("N", torch.zeros(n_codes)) - self.register_buffer("z_avg", self.embeddings.data.clone()) # pyright: ignore[reportCallIssue] + self.register_buffer("z_avg", self.embeddings.data.clone()) self.n_codes = n_codes self.embedding_dim = embedding_dim @@ -50,9 +54,9 @@ def _init_embeddings(self, z): _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes] if dist.is_initialized(): dist.broadcast(_k_rand, 0) - self.embeddings.data.copy_(_k_rand) # pyright: ignore[reportCallIssue] - self.z_avg.data.copy_(_k_rand) # pyright: ignore[reportCallIssue] - self.N.data.copy_(torch.ones(self.n_codes)) # pyright: ignore[reportCallIssue] + self.embeddings.data.copy_(_k_rand) + self.z_avg.data.copy_(_k_rand) + self.N.data.copy_(torch.ones(self.n_codes)) def forward(self, z): # z: [b, t, c] @@ -62,17 +66,14 @@ def forward(self, z): flat_inputs = z.view(-1, self.embedding_dim) distances = ( (flat_inputs**2).sum(dim=1, keepdim=True) - - 2 * flat_inputs @ self.embeddings.t() # pyright: ignore[reportCallIssue] - + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) # pyright: ignore[reportCallIssue] + - 2 * flat_inputs @ self.embeddings.t() + + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) ) # [bt, c] encoding_indices = torch.argmin(distances, dim=1) encoding_indices = encoding_indices.view(*z.shape[:2]) # [b, t, ncode] - embeddings = F.embedding( - encoding_indices, - self.embeddings, # pyright: ignore[reportArgumentType] - ) # [b, t, c] # pyright: ignore[reportArgumentType] + embeddings = F.embedding(encoding_indices, self.embeddings) # [b, t, c] commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) @@ -84,8 +85,8 @@ def forward(self, z): return embeddings_st, encoding_indices, commitment_loss def dictionary_lookup(self, encodings): - embeddings = F.embedding(encodings, self.embeddings) # pyright: ignore[reportArgumentType] + embeddings = F.embedding(encodings, self.embeddings) return embeddings def soft_codebook_lookup(self, weights: torch.Tensor) -> torch.Tensor: - return weights @ self.embeddings # pyright: ignore[reportOperatorIssue] + return weights @ self.embeddings diff --git a/esm/layers/rotary.py b/esm/layers/rotary.py index 4ff9d387..08ca8b70 100644 --- a/esm/layers/rotary.py +++ b/esm/layers/rotary.py @@ -26,11 +26,9 @@ from einops import rearrange, repeat try: - from flash_attn.ops.triton.rotary import ( # type:ignore - apply_rotary as apply_triton_rotary, - ) + from flash_attn.ops.triton.rotary import apply_rotary as apply_triton_rotary except ImportError: - apply_triton_rotary = None + apply_triton_rotary = None # ty:ignore[invalid-assignment] def rotate_half(x, interleaved=False): @@ -169,23 +167,23 @@ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): else: inv_freq = self.inv_freq else: - t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) # pyright: ignore[reportArgumentType, reportCallIssue] + t = torch.arange( + seqlen, device=device, dtype=self.inv_freq.dtype + ) # ty:ignore[no-matching-overload] t /= self.scaling_factor inv_freq = self.inv_freq # Don't do einsum, it converts fp32 to fp16 under AMP # freqs = torch.einsum("i,j->ij", t, self.inv_freq) - freqs = torch.outer(t, inv_freq) # pyright: ignore[reportArgumentType] + freqs = torch.outer(t, inv_freq) # ty:ignore[invalid-argument-type] if self.scale is None: self._cos_cached = torch.cos(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype) else: power = ( - torch.arange( # pyright: ignore[reportCallIssue] - seqlen, - dtype=self.scale.dtype, # pyright: ignore[reportArgumentType] - device=self.scale.device, # pyright: ignore[reportArgumentType] - ) + torch.arange( + seqlen, dtype=self.scale.dtype, device=self.scale.device + ) # ty:ignore[no-matching-overload] - seqlen // 2 ) / self.scale_base scale = self.scale.to(device=power.device) ** power.unsqueeze(-1) @@ -225,7 +223,7 @@ def forward( self.interleaved, True, # inplace=True ), - ) # type: ignore + ) else: assert False diff --git a/esm/models/esm3.py b/esm/models/esm3.py index a9faa371..7a3a0d8d 100644 --- a/esm/models/esm3.py +++ b/esm/models/esm3.py @@ -421,8 +421,8 @@ def batch_generate( self, inputs, # type: ignore configs, - self.tokenizers, # type: ignore - ) + self.tokenizers, + ) # ty:ignore[invalid-return-type] else: raise ValueError("Input must be an ESMProtein or ESMProteinTensor") @@ -535,7 +535,7 @@ def logits( with ( torch.no_grad(), # Assume no gradients for now... - torch.autocast(enabled=True, device_type=device.type, dtype=torch.bfloat16) # type: ignore + torch.autocast(enabled=True, device_type=device.type, dtype=torch.bfloat16) if device.type == "cuda" else contextlib.nullcontext(), ): diff --git a/esm/models/esmc.py b/esm/models/esmc.py index f9a0f1cc..10a8ca7e 100644 --- a/esm/models/esmc.py +++ b/esm/models/esmc.py @@ -8,12 +8,12 @@ from attr import dataclass try: - from flash_attn.bert_padding import pad_input, unpad_input # type:ignore + from flash_attn.bert_padding import pad_input, unpad_input is_flash_attn_available = True except ImportError: - pad_input = None - unpad_input = None + pad_input = None # ty:ignore[invalid-assignment] + unpad_input = None # ty:ignore[invalid-assignment] is_flash_attn_available = False from esm.layers.regression_head import RegressionHead @@ -141,7 +141,9 @@ def forward( output_attentions = bool(output_attentions) if sequence_id is None: # For EMSC, a boolean mask is created in place of sequence_id if not specified. - sequence_id = sequence_tokens != self.tokenizer.pad_token_id + sequence_id = ( + sequence_tokens != self.tokenizer.pad_token_id + ) # ty:ignore[invalid-assignment] x = self.embed(sequence_tokens) @@ -154,13 +156,11 @@ def forward( "output_attentions is not supported with flash attention." ) assert ( - sequence_id.dtype == torch.bool + sequence_id.dtype == torch.bool # ty:ignore[unresolved-attribute] ), "sequence_id must be a boolean mask if Flash Attention is used" - assert sequence_id.shape == (B, L) + assert sequence_id.shape == (B, L) # ty:ignore[unresolved-attribute] assert unpad_input is not None - x, indices, *_ = unpad_input( # type: ignore - x, sequence_id - ) + x, indices, *_ = unpad_input(x, sequence_id) else: indices = None @@ -179,7 +179,7 @@ def forward( ] # Stack hidden states into a [n_layers, B, L, D] matrix. - hidden_states = torch.stack(hidden_states, dim=0) # type: ignore + hidden_states = torch.stack(hidden_states, dim=0) sequence_logits = self.sequence_head(x) output = ESMCOutput( @@ -225,7 +225,7 @@ def logits( with ( torch.no_grad(), - torch.autocast(enabled=True, device_type=device.type, dtype=torch.bfloat16) # type: ignore + torch.autocast(enabled=True, device_type=device.type, dtype=torch.bfloat16) if device.type == "cuda" else contextlib.nullcontext(), ): diff --git a/esm/models/esmfold2/paired_msa.py b/esm/models/esmfold2/paired_msa.py index 7a9d3027..57c2f8ea 100644 --- a/esm/models/esmfold2/paired_msa.py +++ b/esm/models/esmfold2/paired_msa.py @@ -16,7 +16,7 @@ PROTEIN_RESIDUE_TO_RES_TYPE, PROTEIN_UNK_RES_TYPE, ) -from esm.utils.msa.msa import MSA +from esm.utils.msa.msa import MSA, is_a3m_insertion _KEY_RE = re.compile(r"key=(-?\d+)") @@ -48,9 +48,13 @@ def msa_to_res_type_and_deletions( insertions and are not emitted; their count is accumulated into the next non-insertion position's deletion value. ``L`` is the query length after stripping insertions from row 0. + + If ``msa.deletions`` is set (e.g. by :meth:`MSA.from_a3m`) it is returned + directly: the stored sequences may already be insertion-stripped, which would + otherwise yield all-zero deletions. """ query = msa.entries[0].sequence - L = sum(1 for ch in query if not (ch.islower() or ch == ".")) + L = sum(1 for ch in query if not is_a3m_insertion(ch)) M = msa.depth res_type = np.full((M, L), MSA_GAP_TOKEN_ID, dtype=np.int64) @@ -60,7 +64,7 @@ def msa_to_res_type_and_deletions( col = 0 ins = 0 for ch in entry.sequence: - if ch == "." or (ch.islower() and ch != "-"): + if is_a3m_insertion(ch): ins += 1 continue if col >= L: @@ -75,6 +79,11 @@ def msa_to_res_type_and_deletions( deletions[r, col] = float(ins) ins = 0 col += 1 + + if msa.deletions is not None: + msg = f"stored deletions {msa.deletions.shape} != expected {(M, L)}" + assert msa.deletions.shape == (M, L), msg + deletions = msa.deletions.astype(np.float32) return res_type, deletions diff --git a/esm/models/esmfold2/prepare_input.py b/esm/models/esmfold2/prepare_input.py index fbbbe4e8..272ded59 100644 --- a/esm/models/esmfold2/prepare_input.py +++ b/esm/models/esmfold2/prepare_input.py @@ -582,7 +582,7 @@ def tokenize_ligand_smiles( mol = Chem.AddHs(mol) # Assign atom names using canonical ranking - canonical_order = AllChem.CanonicalRankAtoms(mol) # type: ignore[attr-defined] + canonical_order = AllChem.CanonicalRankAtoms(mol) # ty:ignore[unresolved-attribute] for atom, can_idx in zip(mol.GetAtoms(), canonical_order): atom_name = atom.GetSymbol().upper() + str(can_idx + 1) if len(atom_name) > 4: @@ -592,17 +592,19 @@ def tokenize_ligand_smiles( atom.SetProp("name", atom_name) # Generate 3D conformer - options = AllChem.ETKDGv3() # type: ignore[attr-defined] + options = AllChem.ETKDGv3() # ty:ignore[unresolved-attribute] options.clearConfs = False if seed is not None: options.randomSeed = seed - conf_id = AllChem.EmbedMolecule(mol, options) # type: ignore[attr-defined] + conf_id = AllChem.EmbedMolecule(mol, options) # ty:ignore[unresolved-attribute] if conf_id == -1: options.useRandomCoords = True - conf_id = AllChem.EmbedMolecule(mol, options) # type: ignore[attr-defined] + conf_id = AllChem.EmbedMolecule(mol, options) # ty:ignore[unresolved-attribute] if conf_id != -1: try: - AllChem.UFFOptimizeMolecule(mol, confId=conf_id, maxIters=1000) # type: ignore[attr-defined] + AllChem.UFFOptimizeMolecule( # ty:ignore[unresolved-attribute] + mol, confId=conf_id, maxIters=1000 + ) except (RuntimeError, ValueError): pass diff --git a/esm/models/vqvae.py b/esm/models/vqvae.py index 713dd716..8070286f 100644 --- a/esm/models/vqvae.py +++ b/esm/models/vqvae.py @@ -280,7 +280,7 @@ def find_knn_edges( (coords.shape[0], coords.shape[1]), device=coords.device ).long() - with torch.no_grad(), torch.cuda.amp.autocast(enabled=False): # type: ignore + with torch.no_grad(), torch.cuda.amp.autocast(enabled=False): ca = coords[..., 1, :] edges, edge_mask = knn_graph( ca, coord_mask, padding_mask, sequence_id, no_knn=knn @@ -419,8 +419,8 @@ def decode( # This might be broken for chainbreak tokens? We might align to the chainbreak ptm = compute_tm( pae_logits, # type: ignore - aa_mask=~special_tokens_mask, - max_bin=self.max_pae_bin, + aa_mask=~special_tokens_mask, # ty:ignore[unknown-argument] + max_bin=self.max_pae_bin, # ty:ignore[unknown-argument] ) plddt_logits = self.plddt_head(x) diff --git a/esm/sdk/api.py b/esm/sdk/api.py index d0a68e47..86777646 100644 --- a/esm/sdk/api.py +++ b/esm/sdk/api.py @@ -320,20 +320,38 @@ class ESMProteinError(Exception, ProteinType): ## High Level Endpoint Types @define class GenerationConfig: + """ + track (str): Track to generate: sequence, structure, secondary_structure, sasa, + or function. + invalid_ids (Sequence[int]): Token indices that should not be sampled. + schedule (str): Unmasking schedule for generation. Controls the number of tokens + to unmask during each round of iterative generation. + strategy (str): Unmasking strategy to use. Controls which tokens to unmask + during each round of iterative generation. 'random' will unmask a correct + number of tokens randomly. 'entropy' will unmask the tokens with the lowest + logit entropy first. Default was random. Updated on 02/14/2025. + num_steps (int): Number of steps for generation. There is diminishing return for + decoding steps more than 20. Note that this needs to be less than or equal + to the sequence length. Default was 8. Updated on 02/14/2025. + temperature (float): Temperature for sampling. Default was 1.0. Updated on + 02/14/2025. + temperature_annealing (bool): Whether temperature should be annealed during + generation. Default was False. Updated on 02/14/2025. + top_p (float): Top-p sampling. + condition_on_coordinates_only (bool): Use coordinates instead of structure + tokens as generation conditioning. + only_compute_backbone_rmsd (bool): Only compute the RMSD of the backbone atoms. + Affects the returned crmsd. + """ + track: str = "" invalid_ids: Sequence[int] = [] - # Controls the number of tokens to unmask during each round of iterative generation. schedule: str = attr.field( validator=attr.validators.in_(["cosine", "linear"]), default="cosine" ) - # Controls which tokens to unmask during each round of iterative generation. - # "random" will unmask a correct number of tokens randomly. - # "entropy" will unmask the tokens with the lowest logit entropy first. strategy: str = attr.field( validator=attr.validators.in_(["random", "entropy"]), default="random" ) - # Setting default to 20, as there is diminishing return for decoding steps more than 20. - # Note that this needs to be less than or equal to the sequence length. num_steps: int = 20 temperature: float = 1.0 temperature_annealing: bool = True @@ -356,21 +374,50 @@ def use_generative_unmasking_strategy(self): @define class InverseFoldingConfig: + """ + invalid_ids (Sequence[int]): Token indices that should not be sampled. + temperature (float): Temperature for sampling. For inverse folding models, we + recommend getting diverse predictions by changing the seed and not by + increasing the temperature. + """ + invalid_ids: Sequence[int] = [] temperature: float = 0.1 @define class FoldingConfig: + """ + include_distogram (bool): (ESMFold2) Whether to include distogram predictions in + the response. + include_pae (bool): (ESMFold2) Whether to include Predicted Aligned Error (PAE) + matrix in the response. + include_pair_chains_iptm (bool): (ESMFold2) Whether to include pair-chain IPTM + predictions in the response. + num_sampling_steps (int): (ESMFold2) Diffusion ODE solver steps. Lower for + speed, higher for quality. + num_loops (int): (ESMFold2) Number of trunk loops for iterative refinement. + lm_dropout (float): (ESMFold2) Dropout probability on LM pair embeddings. When > + 0, dropout is applied. + lm_mask_pct (float | None): (ESMFold2) Fraction of sequence residues randomly + masked before the PLM backbone. If not provided, defaults to 0.1 for + ESMFOLD2_FAST and 0.0 for ESMFOLD2 + msa_max_depth (int | None): (ESMFold2) Number of MSA rows randomly subsampled + each loop. Set to null to disable (sets msa_subsample_at_inference to + False). + msa_column_mask_rate (float): (ESMFold2) Fraction of MSA columns randomly masked + in non-query rows for inference-time diversity. + include_embeddings (bool): (ESMFold2) Whether to include sequence and pair + embeddings in the response. + """ + include_distogram: bool = False include_pae: bool = False include_pair_chains_iptm: bool = False num_sampling_steps: int = 100 num_loops: int = 20 lm_dropout: float = 0.3 - lm_mask_pct: float | None = ( - None # If not provided, defaults to 0.1 for ESMFOLD2_FAST and 0.0 for ESMFOLD2 - ) + lm_mask_pct: float | None = None msa_max_depth: int | None = 1024 msa_column_mask_rate: float = 0.1 include_embeddings: bool = False @@ -379,6 +426,14 @@ class FoldingConfig: ## Low Level Endpoint Types @define class SamplingTrackConfig: + """ + temperature (float): Temperature for sampling. + top_p (float): Sample from logits within the top-p probability. + only_sample_masked_tokens (bool): Only sample for masked tokens. + invalid_ids (Sequence[int]): Token indices that should not be sampled. + topk_logprobs (int): Number of top ranking prediction and logprobs to return. + """ + temperature: float = 1.0 top_p: float = 1.0 only_sample_masked_tokens: bool = True @@ -388,6 +443,21 @@ class SamplingTrackConfig: @define class SamplingConfig: + """ + sequence (SamplingTrackConfig | None): Sampling configuration for the sequence + track. + structure (SamplingTrackConfig | None): Sampling configuration for the structure + track. + secondary_structure (SamplingTrackConfig | None): Sampling configuration for the + secondary structure track. + sasa (SamplingTrackConfig | None): Sampling configuration for the SASA track. + function (SamplingTrackConfig | None): Sampling configuration for the function + annotation track. + return_per_residue_embeddings (bool): Whether to return per-residue embeddings. + return_mean_embedding (bool): Whether to return the embedding mean-pooled over + the sequence length. + """ + sequence: SamplingTrackConfig | None = attr.field( default=None, metadata={"max_topk": C.MAX_TOPK_SEQUENCE} ) @@ -410,6 +480,14 @@ class SamplingConfig: @define class ForwardTrackData: + """ + sequence (torch.Tensor | None): Sequence track logits. + structure (torch.Tensor | None): Structure track logits. + secondary_structure (torch.Tensor | None): Secondary structure track logits. + sasa (torch.Tensor | None): Solvent accessible surface area (SASA) track logits. + function (torch.Tensor | None): Function annotations logits. + """ + sequence: torch.Tensor | None = None structure: torch.Tensor | None = None secondary_structure: torch.Tensor | None = None @@ -419,6 +497,39 @@ class ForwardTrackData: @define class LogitsConfig: + """ + sequence (bool): Return sequence logits. + structure (bool): Return structure logits. + secondary_structure (bool): Return secondary structure logits. + sasa (bool): Return sasa logits. + function (bool): Return function logits. + residue_annotations (bool): Return residue annotations logits. + return_embeddings (bool): Whether embeddings should be returned. + return_hidden_states (bool): Whether to return per-residue hidden states. With + ith_hidden_layer=-1, returns all layers as a tensor of shape [n_layers + 1, + B, L, D]. With ith_hidden_layer!= -1, returns the selected layer as a tensor + of shape [1, B, L, D]. + return_mean_embedding (bool): Whether mean embeddings should be returned. + return_mean_hidden_states (bool): Whether hidden states mean-pooled along the + sequence length (L) dimension should be returned. Returns a tensor of shape + [B, n_layers + 1, D]. + ith_hidden_layer (int): Valid values for ith_hidden_layer are 0 to + max_ith_hidden_layer (inclusive), where index 0 is the embedding layer. -1 + returns all layers, but is not supported for ESMC 6B or any ESM3 model. Here + is the max_ith_hidden_layer for each ESMC and ESM3 model (except ESM3 + Large). + | Model Name | max_ith_hidden_layer | + |-------------------------------|--------------------------------| + | esmc-300-2024-12 | 30 | + | esmc-600-2024-12 | 36 | + | esmc-6b-2024-12 | 80 | + | esm3-small-2024-03 | 48 | + | esm3-small-2024-08 | 48 | + | esm3-medium-2024-03 | 96 | + | esm3-medium-2024-08 | 96 | + sae_config (SAEConfig | None): SAE config. Only applies to ESMC models. + """ + # Logits. sequence: bool = False @@ -439,15 +550,22 @@ class LogitsConfig: return_mean_hidden_states: bool = False ith_hidden_layer: int = -1 - # SAE config only applies to ESMC models sae_config: SAEConfig | None = None @define class SAEConfig: + """ + models (list[str]): List of SAE models with specific layer and codebook size. + normalize_features (bool): Normalize computed features before return. Default to + True. + model (str | None): Deprecated, use 'models' instead. SAE model with specific + layer and codebook size. + """ + models: list[str] = attr.Factory(list) normalize_features: bool = True - model: str | None = None # deprecated, use models + model: str | None = None def __attrs_post_init__(self): if self.model is not None: @@ -474,6 +592,24 @@ def __attrs_post_init__(self): @define class LogitsOutput: + """ + logits (ForwardTrackData | None): Per-track categorical logits, populated for each + track requested via LogitsConfig. + embeddings (torch.Tensor | None): Per-residue embeddings (final hidden state). + Returned when LogitsConfig.return_embeddings is set. + mean_embedding (torch.Tensor | None): Embedding mean-pooled over the sequence + length. Returned when LogitsConfig.return_mean_embedding is set. + residue_annotation_logits (torch.Tensor | None): Residue annotation logits. These + are multi-hot (bernoulli), so they are kept separate from `logits` (which holds + categorical per-track logits). + hidden_states (torch.Tensor | None): Hidden states for the requested layer(s). + Returned when LogitsConfig.return_hidden_states is set. + mean_hidden_state (torch.Tensor | None): Hidden states mean-pooled over the + sequence length. Returned when LogitsConfig.return_mean_hidden_states is set. + sae_outputs (dict[str, torch.Tensor] | None): SAE activations keyed by SAE model + name. Returned when LogitsConfig.sae_config is set. + """ + logits: ForwardTrackData | None = None embeddings: torch.Tensor | None = None mean_embedding: torch.Tensor | None = None @@ -484,22 +620,38 @@ class LogitsOutput: residue_annotation_logits: torch.Tensor | None = None hidden_states: torch.Tensor | None = None mean_hidden_state: torch.Tensor | None = None - # sae_outputs keys are sae model names and values are sparse representations of the sae activations sae_outputs: dict[str, torch.Tensor] | None = None @define class ForwardAndSampleOutput(LogitsOutput): + """Output of forward_and_sample. Extends LogitsOutput with the sampled tokens and + per-position sampling statistics (each ForwardTrackData holds one value per track). + + protein_tensor (ESMProteinTensor): The sampled tokens. + entropy (ForwardTrackData | None): Per-position entropy of the predicted + distribution, per track. + prob (ForwardTrackData | None): Probability of the sampled token at each position. + logprob (ForwardTrackData | None): Log-probability of the sampled token at each + position. + top_prob (ForwardTrackData | None): Highest token probability at each position. + topk_logprob (ForwardTrackData | None): Log-probabilities of the top-k tokens at + each position. Populated when PerTrackSamplingConfig.topk_logprobs is set. + topk_tokens (ForwardTrackData | None): Token ids of the top-k tokens at each + position. Populated when PerTrackSamplingConfig.topk_logprobs is set. + per_residue_embedding (torch.Tensor | None): Per-residue embeddings. Returned when + SamplingConfig.return_per_residue_embeddings is set. + mean_embedding (torch.Tensor | None): Embedding mean-pooled over the sequence + length. Returned when SamplingConfig.return_mean_embedding is set. + """ + protein_tensor: ESMProteinTensor = ESMProteinTensor() entropy: ForwardTrackData | None = None - # Probability of sampled token prob: ForwardTrackData | None = None logprob: ForwardTrackData | None = None - # Top probability at this position top_prob: ForwardTrackData | None = None topk_logprob: ForwardTrackData | None = None - # Which tokens correspond to top probability topk_tokens: ForwardTrackData | None = None per_residue_embedding: torch.Tensor | None = None mean_embedding: torch.Tensor | None = None diff --git a/esm/sdk/experimental/constrained_generation.py b/esm/sdk/experimental/constrained_generation.py index 967b7519..e3809b71 100644 --- a/esm/sdk/experimental/constrained_generation.py +++ b/esm/sdk/experimental/constrained_generation.py @@ -176,7 +176,7 @@ def _propose_and_eval(pt: ESMProteinTensor): if isinstance(results, Exception): raise results - samples, rewards, gh_lists, val_lists = zip(*results) # type: ignore + samples, rewards, gh_lists, val_lists = zip(*results) else: samples, rewards, gh_lists, val_lists = [], [], [], [] for _ in range(num_samples_per_step): diff --git a/esm/sdk/experimental/guided_generation.py b/esm/sdk/experimental/guided_generation.py index 9acd32e8..e11344e7 100644 --- a/esm/sdk/experimental/guided_generation.py +++ b/esm/sdk/experimental/guided_generation.py @@ -114,7 +114,7 @@ def _sample_and_score( ) # Separate samples and their scores returned by the executor - samples, scores = zip(*results) # type: ignore + samples, scores = zip(*results) else: # ---------------------------------------------- # Local client: sequential sampling (single thread) @@ -128,7 +128,7 @@ def _sample_and_score( # Select best scoring sample scores_list = list(scores) - best_sample = samples[scores_list.index(max(scores_list))] # type: ignore + best_sample = samples[scores_list.index(max(scores_list))] current_score = max(scores_list) protein_tensor = best_sample @@ -168,7 +168,7 @@ def get_number_of_masked_positions( track_tensor = getattr(protein_tensor, track) track_tokenizer = getattr(self.tokenizers, track) is_mask = track_tensor == track_tokenizer.mask_token_id - return is_mask.sum().item() # type: ignore + return is_mask.sum().item() def randomly_unmask_positions( self, @@ -189,7 +189,7 @@ def randomly_unmask_positions( num_masked_positions = is_mask.sum().item() if num_positions_to_unmask > num_masked_positions: - num_positions_to_unmask = num_masked_positions # type: ignore + num_positions_to_unmask = num_masked_positions mask_indices = is_mask.nonzero(as_tuple=False) mask_indices = mask_indices[torch.randperm(mask_indices.size(0))] diff --git a/esm/sdk/forge.py b/esm/sdk/forge.py index c1273dcb..dda3f10e 100644 --- a/esm/sdk/forge.py +++ b/esm/sdk/forge.py @@ -157,7 +157,7 @@ def _process_fold_request( UserWarning, stacklevel=4, ) - request["msa"] = {"sequences": msa.sequences} + request["msa"] = msa.state_dict(json_serializable=True) else: error_msg = f"MSA must be None or MSA. Got {msa} instead." raise AttributeError(error_msg) @@ -541,6 +541,7 @@ def _process_generate_protein_request( "condition_on_coordinates_only": config.condition_on_coordinates_only, "strategy": config.strategy, "temperature_annealing": config.temperature_annealing, + "only_compute_backbone_rmsd": config.only_compute_backbone_rmsd, } return request @@ -586,6 +587,10 @@ def _process_generate_protein_response(data: dict[str, Any]) -> ESMProtein: ), plddt=maybe_tensor(data["outputs"]["plddt"]), ptm=maybe_tensor(data["outputs"]["ptm"]), + pae=maybe_tensor(data["outputs"]["pae"]), + crmsd=maybe_tensor(data["outputs"]["crmsd"]), + globularity=maybe_tensor(data["outputs"]["globularity"]), + interface_ptm=maybe_tensor(data["outputs"]["interface_ptm"]), ) @staticmethod diff --git a/esm/tokenization/__init__.py b/esm/tokenization/__init__.py index 6db76554..1dc628a3 100644 --- a/esm/tokenization/__init__.py +++ b/esm/tokenization/__init__.py @@ -52,10 +52,10 @@ def get_esmc_model_tokenizers() -> EsmSequenceTokenizer: def get_invalid_tokenizer_ids(tokenizer: EsmTokenizerBase) -> list[int]: if isinstance(tokenizer, EsmSequenceTokenizer): return [ - tokenizer.mask_token_id, # type: ignore - tokenizer.pad_token_id, # type: ignore - tokenizer.cls_token_id, # type: ignore - tokenizer.eos_token_id, # type: ignore + tokenizer.mask_token_id, + tokenizer.pad_token_id, + tokenizer.cls_token_id, + tokenizer.eos_token_id, ] else: return [ diff --git a/esm/tokenization/residue_tokenizer.py b/esm/tokenization/residue_tokenizer.py index c64fdb38..05e90323 100644 --- a/esm/tokenization/residue_tokenizer.py +++ b/esm/tokenization/residue_tokenizer.py @@ -21,26 +21,26 @@ def __init__(self, csv_path: str | None = None, max_annotations: int = 16): @cached_property def _description2label(self) -> dict[str, str]: - with AnyPath(self.csv_path).open() as f: # type: ignore + with AnyPath(self.csv_path).open() as f: df = pd.read_csv(f) return dict(zip(df.label, df.label_clean)) @cached_property def _labels(self) -> list[str]: - with AnyPath(self.csv_path).open() as f: # type: ignore + with AnyPath(self.csv_path).open() as f: df = pd.read_csv(f) labels = ( df.groupby("label_clean")["count"] .sum() - .sort_values(ascending=False, kind="stable") # type: ignore + .sort_values(ascending=False, kind="stable") .index.tolist() ) assert isinstance(labels, list) - return labels # type: ignore + return labels def _description2id(self, description: str) -> int | None: label = self._description2label.get(description) - return self._label2id.get(label) # type: ignore + return self._label2id.get(label) @cached_property def _label2id(self) -> dict[str, int]: diff --git a/esm/tokenization/sequence_tokenizer.py b/esm/tokenization/sequence_tokenizer.py index 2df0d44c..5ed59345 100644 --- a/esm/tokenization/sequence_tokenizer.py +++ b/esm/tokenization/sequence_tokenizer.py @@ -45,7 +45,7 @@ def __init__( # This is where we configure the automatic addition of special tokens when we call # tokenizer(text, add_special_tokens=True). Note that you can also configure how two # sequences are merged if you want. - tokenizer.post_processor = TemplateProcessing( # type: ignore + tokenizer.post_processor = TemplateProcessing( single=" $A ", special_tokens=[ ("", tokenizer.token_to_id("")), diff --git a/esm/utils/encoding.py b/esm/utils/encoding.py index 8461709d..d078a85b 100644 --- a/esm/utils/encoding.py +++ b/esm/utils/encoding.py @@ -77,9 +77,9 @@ def tokenize_structure( _, structure_tokens = structure_encoder.encode( coordinates, residue_index=residue_index ) - coordinates = torch.squeeze(coordinates, dim=0) # (L, 37, 3) # type: ignore - plddt = torch.squeeze(plddt, dim=0) # (L,) # type: ignore - structure_tokens = torch.squeeze(structure_tokens, dim=0) # (L,) # type: ignore + coordinates = torch.squeeze(coordinates, dim=0) # (L, 37, 3) + plddt = torch.squeeze(plddt, dim=0) # (L,) + structure_tokens = torch.squeeze(structure_tokens, dim=0) # (L,) # Add space for BOS and EOS tokens if add_special_tokens: diff --git a/esm/utils/forge_context_manager.py b/esm/utils/forge_context_manager.py index 98de9ee4..f84ddde7 100644 --- a/esm/utils/forge_context_manager.py +++ b/esm/utils/forge_context_manager.py @@ -143,7 +143,7 @@ def execute_batch( retry_count += 1 pbar.update(0) else: - results[idx] = e # type: ignore + results[idx] = e fail_count += 1 pbar.update(1) diff --git a/esm/utils/function/interpro.py b/esm/utils/function/interpro.py index b9e0c742..4beda7ad 100644 --- a/esm/utils/function/interpro.py +++ b/esm/utils/function/interpro.py @@ -43,15 +43,15 @@ def _parse_interpro2go(path: PathLike) -> dict[str, list[str]]: df["interpro_id"] = df.line.apply(lambda line: re.findall(r"IPR\d+", line)) df["go_ids"] = df.line.apply(parse_go_terms) df = df[df.go_ids.apply(len).gt(0) & df.interpro_id.apply(len).eq(1)] - df["interpro_id"] = df["interpro_id"].apply(lambda xs: xs[0]) # type: ignore + df["interpro_id"] = df["interpro_id"].apply(lambda xs: xs[0]) # Group all mappints together into a single map. df = ( - df.groupby("interpro_id")["go_ids"] # type: ignore + df.groupby("interpro_id")["go_ids"] .apply(lambda group: list(itertools.chain.from_iterable(group))) .reset_index() ) - return dict(zip(df.interpro_id, df.go_ids)) # type: ignore + return dict(zip(df.interpro_id, df.go_ids)) class InterProEntryType(IntEnum): diff --git a/esm/utils/function/lsh.py b/esm/utils/function/lsh.py index 87f4c674..fd2c7e9b 100644 --- a/esm/utils/function/lsh.py +++ b/esm/utils/function/lsh.py @@ -44,7 +44,7 @@ def __init__( filepath = AnyPath(filepath) if not filepath.exists(): raise FileNotFoundError(filepath) - table_hyperplanes = np.load(filepath) # type: ignore + table_hyperplanes = np.load(filepath) for i in range(num_tables): assert str(i) in table_hyperplanes, f"Missing hyperplane for table {i}" elif not allow_create_hyperplanes: @@ -62,10 +62,10 @@ def __init__( ] def write_hyperplanes(self, filepath: PathLike): - hyperplanes: dict[str, np.ndarray] = { # type: ignore + hyperplanes: dict[str, np.ndarray] = { str(i): table.hyperplanes for i, table in enumerate(self.tables) } - np.savez(filepath, **hyperplanes) # type: ignore + np.savez(filepath, **hyperplanes) def __call__(self, array): tokens = np.stack([table(array) for table in self.tables], 1) diff --git a/esm/utils/generation.py b/esm/utils/generation.py index 4d35e7ee..68b690ac 100644 --- a/esm/utils/generation.py +++ b/esm/utils/generation.py @@ -209,8 +209,7 @@ def _stack_field(fn: str): o, fn, stack_variable_length_tensors( - sequences=tensors, # type: ignore - constant_value=mask_token_id, + sequences=tensors, constant_value=mask_token_id ), ) @@ -658,7 +657,9 @@ def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None: # Format output forward_and_sample_output_dir = {} - forward_and_sample_output_dir["protein_tensor"] = ESMProteinTensor(**tokens_dir) + forward_and_sample_output_dir["protein_tensor"] = ESMProteinTensor( + **tokens_dir # ty:ignore[invalid-argument-type] + ) for property in [ "entropy", "prob", @@ -682,7 +683,7 @@ def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None: forward_and_sample_output_dir[property] = None per_res_embed = ( - logits_output.embeddings # type: ignore + logits_output.embeddings if sampling_config.return_per_residue_embeddings else None ) @@ -696,7 +697,7 @@ def maybe_clone(x: torch.Tensor | None) -> torch.Tensor | None: return ForwardAndSampleOutput( per_residue_embedding=per_res_embed, mean_embedding=mean_embedding, - **forward_and_sample_output_dir, + **forward_and_sample_output_dir, # ty:ignore[invalid-argument-type] ) diff --git a/esm/utils/misc.py b/esm/utils/misc.py index 6fa670a9..cc874ff5 100644 --- a/esm/utils/misc.py +++ b/esm/utils/misc.py @@ -68,7 +68,7 @@ def slice_python_object_as_numpy( case _: sliced_obj = obj.__class__(sliced_obj) # type: ignore - return sliced_obj # type: ignore + return sliced_obj def slice_any_object( @@ -265,7 +265,7 @@ def unbinpack( return stack_variable_length_tensors(unpacked_tensors, pad_value) -def fp32_autocast_context(device_type: str) -> ContextManager[Any]: # type: ignore +def fp32_autocast_context(device_type: str) -> ContextManager[Any]: """ Returns an autocast context manager that disables downcasting by AMP. @@ -276,12 +276,12 @@ def fp32_autocast_context(device_type: str) -> ContextManager[Any]: # type: ign An autocast context manager with the specified behavior. """ if device_type == "cpu": - return torch.amp.autocast(device_type, enabled=False) # type: ignore + return torch.amp.autocast(device_type, enabled=False) elif device_type == "mps": # For MPS, just return a no-op context manager (nullcontext) since MPS does not support autocast. return nullcontext() elif device_type == "cuda": - return torch.amp.autocast(device_type, dtype=torch.float32) # type: ignore + return torch.amp.autocast(device_type, dtype=torch.float32) else: raise ValueError(f"Unsupported device type: {device_type}") @@ -475,7 +475,7 @@ def concat_objects(objs: Sequence[Any], separator: Any | None = None): """ match objs[0]: case Concatable(): - return objs[0].__class__.concat(objs) # type: ignore + return objs[0].__class__.concat(objs) case str(): assert isinstance( separator, str diff --git a/esm/utils/msa/msa.py b/esm/utils/msa/msa.py index 120f1333..ebd81dfa 100644 --- a/esm/utils/msa/msa.py +++ b/esm/utils/msa/msa.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from functools import cached_property from itertools import islice -from typing import Sequence +from typing import Any, Sequence import numpy as np from Bio import SeqIO @@ -24,6 +24,25 @@ def remove_insertions_from_sequence(seq: str) -> str: return seq.translate(REMOVE_LOWERCASE_TRANSLATION) +def is_a3m_insertion(ch: str) -> bool: + """True for an a3m insertion character (a lowercase letter or ``.``).""" + return ch == "." or ch.islower() + + +def a3m_deletion_counts(seq: str) -> np.ndarray: + """Per-match-column count of preceding a3m insertions (lowercase letters / ``.``). + + Each insertion run is accumulated into the next match column. Vectorized over the + sequence: encode to bytes, mask insertions, and difference their cumsum at the + match positions. Returns an int array of length = number of match columns. + """ + codes = np.frombuffer(seq.encode("ascii"), dtype=np.uint8) + is_insertion = ((codes >= ord("a")) & (codes <= ord("z"))) | (codes == ord(".")) + insertions_before = np.concatenate(([0], np.cumsum(is_insertion))) + match_positions = np.nonzero(~is_insertion)[0] + return np.diff(insertions_before[match_positions], prepend=0) + + @dataclass(frozen=True) class MSA(SequentialDataclass): """Object-oriented interface to an MSA. @@ -35,6 +54,7 @@ class MSA(SequentialDataclass): """ entries: list[FastaEntry] + deletions: np.ndarray | None = dataclasses.field(default=None, compare=False) @cached_property def sequences(self) -> list[str]: @@ -60,15 +80,19 @@ def from_a3m( max_sequences: int | None = None, ) -> MSA: entries = [] - for header, seq in islice(read_sequences(path), max_sequences): - if remove_insertions: - seq = remove_insertions_from_sequence(seq) + deletion_rows: list[np.ndarray] = [] + for header, raw in islice(read_sequences(path), max_sequences): + deletion_rows.append(a3m_deletion_counts(raw)) + seq = remove_insertions_from_sequence(raw) if remove_insertions else raw if entries: assert ( len(seq) == len(entries[0].sequence) ), f"Sequence length mismatch. Expected: {len(entries[0].sequence)}, Received: {len(seq)}" entries.append(FastaEntry(header, seq)) - return cls(entries) + deletions = ( + np.stack(deletion_rows).astype(np.float32) if deletion_rows else None + ) + return cls(entries, deletions=deletions) def to_a3m(self, path: PathOrBuffer) -> None: write_sequences(self.entries, path) @@ -149,6 +173,33 @@ def from_sequences( entries = [FastaEntry("", seq) for seq in sequences] return cls(entries) + def state_dict(self, json_serializable: bool = False) -> dict[str, Any]: + """Serialize for the Forge wire / storage (mirrors ``ProteinComplex``). + + ``deletions`` carries the per-(row, match-column) a3m deletion counts (set by + :meth:`from_a3m`) alongside the sequences, so the feature survives even when the + default ``remove_insertions`` strips the lowercase insertions out of the + sequences. With ``json_serializable=True`` the array is returned as a list. + Headers are not serialized. + """ + dct: dict[str, Any] = {"sequences": self.sequences} + if self.deletions is not None: + dct["deletions"] = ( + self.deletions.tolist() if json_serializable else self.deletions + ) + return dct + + @classmethod + def from_state_dict(cls, dct: dict[str, Any]) -> MSA: + """Inverse of :meth:`state_dict`; sequences are taken verbatim.""" + deletions = dct.get("deletions") + return cls( + entries=[FastaEntry("", seq) for seq in dct["sequences"]], + deletions=None + if deletions is None + else np.asarray(deletions, dtype=np.float32), + ) + def to_sequence_bytes(self) -> bytes: """Stores ONLY SEQUENCES in array format as bytes. Header information will be lost.""" seqlen_bytes = self.seqlen.to_bytes(4, "little") @@ -181,10 +232,30 @@ def array(self) -> np.ndarray: def query(self) -> str: return self.entries[0].sequence + def _aligned_deletions(self) -> np.ndarray | None: + """``deletions`` if it is row/column-aligned with the sequences, else None. + + Misalignment means the sequences still carry insertions (length != + match-column count), so the stored deletions no longer describe them.""" + if self.deletions is None or self.deletions.shape != (self.depth, self.seqlen): + return None + return self.deletions + + def _select_deletion_columns(self, indices) -> np.ndarray | None: + """Column-subselect ``deletions`` to match a position subselect, or None + when ``deletions`` is not column-aligned with the sequences (e.g. the + sequences still carry insertions, so length != match-column count).""" + if self.deletions is None or self.deletions.shape[1] != self.seqlen: + return None + return self.deletions[:, indices] + def select_sequences(self, indices: Sequence[int] | np.ndarray) -> MSA: """Subselect rows of the MSA.""" entries = [self.entries[idx] for idx in indices] - return dataclasses.replace(self, entries=entries) + deletions = ( + None if self.deletions is None else self.deletions[np.asarray(indices)] + ) + return dataclasses.replace(self, entries=entries, deletions=deletions) def select_positions(self, indices: Sequence[int] | np.ndarray) -> MSA: """Subselect columns of the MSA.""" @@ -192,7 +263,9 @@ def select_positions(self, indices: Sequence[int] | np.ndarray) -> MSA: FastaEntry(header, "".join(seq[idx] for idx in indices)) for header, seq in self.entries ] - return dataclasses.replace(self, entries=entries) + return dataclasses.replace( + self, entries=entries, deletions=self._select_deletion_columns(indices) + ) def __getitem__(self, indices: int | list[int] | slice | np.ndarray): if isinstance(indices, int): @@ -202,7 +275,9 @@ def __getitem__(self, indices: int | list[int] | slice | np.ndarray): FastaEntry(header, slice_any_object(seq, indices)) for header, seq in self.entries ] - return dataclasses.replace(self, entries=entries) + return dataclasses.replace( + self, entries=entries, deletions=self._select_deletion_columns(indices) + ) def __len__(self): return self.seqlen @@ -264,7 +339,7 @@ def select_random_sequences(self, num_seqs: int) -> MSA: 0, np.random.choice(self.depth - 1, num_seqs - 1, replace=False) + 1 ) ) - msa = self.select_sequences(indices) # type: ignore + msa = self.select_sequences(indices) return msa def select_diverse_sequences(self, num_seqs: int) -> MSA: @@ -287,7 +362,14 @@ def pad_to_depth(self, depth: int) -> MSA: num_to_add = depth - self.depth extra_entries = [FastaEntry("", "-" * self.seqlen) for _ in range(num_to_add)] - return dataclasses.replace(self, entries=self.entries + extra_entries) + # Padded rows are all-gap, so they contribute zero deletions. + deletions = self._aligned_deletions() + if deletions is not None: + pad = np.zeros((num_to_add, self.seqlen), dtype=deletions.dtype) + deletions = np.concatenate([deletions, pad], axis=0) + return dataclasses.replace( + self, entries=self.entries + extra_entries, deletions=deletions + ) @classmethod def stack( @@ -295,12 +377,26 @@ def stack( ) -> MSA: """Stack a series of MSAs. Optionally remove the query from msas after the first.""" all_entries = [] + deletion_rows: list[np.ndarray] = [] for i, msa in enumerate(msas): entries = msa.entries + dels = msa._aligned_deletions() if i > 0 and remove_query_from_later_msas: entries = entries[1:] + if dels is not None: + dels = dels[1:] all_entries.extend(entries) - return cls(entries=all_entries) + if dels is not None: + deletion_rows.append(dels) + # Carry deletions only if every input contributed a column-aligned array of + # matching width + deletions = None + if ( + len(deletion_rows) == len(msas) + and len({d.shape[1] for d in deletion_rows}) == 1 + ): + deletions = np.concatenate(deletion_rows, axis=0) + return cls(entries=all_entries, deletions=deletions) @cached_property def seqid(self) -> np.ndarray: @@ -335,7 +431,16 @@ def concat( seqs = [join_token.join(vals) for vals in zip(*(msa.sequences for msa in msas))] entries = [FastaEntry(header, seq) for header, seq in zip(headers, seqs)] - return cls(entries) + # A non-empty join token inserts columns with no deletion counterpart, so + # the column alignment only survives when chains are concatenated directly. + deletions = None + if join_token == "": + per_msa = [msa._aligned_deletions() for msa in msas] + if all(d is not None for d in per_msa): + deletions = np.concatenate( + per_msa, axis=1 + ) # ty: ignore[no-matching-overload] + return cls(entries, deletions=deletions) @dataclass(frozen=True) @@ -421,7 +526,7 @@ def select_random_sequences(self, num_seqs: int) -> FastMSA: 0, np.random.choice(self.depth - 1, num_seqs - 1, replace=False) + 1 ) ) - msa = self.select_sequences(indices) # type: ignore + msa = self.select_sequences(indices) return msa def pad_to_depth(self, depth: int) -> FastMSA: diff --git a/esm/utils/msa/msa_test.py b/esm/utils/msa/msa_test.py new file mode 100644 index 00000000..31816093 --- /dev/null +++ b/esm/utils/msa/msa_test.py @@ -0,0 +1,174 @@ +"""Tests for MSA.from_a3m deletion handling (a3m lowercase insertions).""" + +import gzip + +import numpy as np + +from esm.models.esmfold2.paired_msa import msa_to_res_type_and_deletions +from esm.utils.msa.msa import MSA, a3m_deletion_counts +from esm.utils.parsing import FastaEntry + +# query has no insertions (5 match columns); row1 has "aa" inserted before col2, +# row2 has "c" inserted before col4. +_A3M = ">query\nMKLNT\n>s1 key=101\nMKaaLNT\n>s2 key=102\nM-LNcT\n" +_EXPECTED_DELETIONS = np.array( + [[0, 0, 0, 0, 0], [0, 0, 2, 0, 0], [0, 0, 0, 0, 1]], dtype=np.float32 +) +_LETTER_TO_RES_TYPE = {c: i for i, c in enumerate("ACDEFGHIKLMNPQRSTVWY")} + + +def _write_a3m(path, gz: bool): + if gz: + with gzip.open(path, "wt") as f: + f.write(_A3M) + else: + path.write_text(_A3M) + + +def _naive_deletion_counts(seq: str) -> list[int]: + out, ins = [], 0 + for ch in seq: + if ch == "." or ch.islower(): + ins += 1 + else: + out.append(ins) + ins = 0 + return out + + +def _a3m_msa(tmp_path) -> MSA: + """Build the shared `_A3M` fixture as an MSA (insertion-stripped, deletions set).""" + p = tmp_path / "m.a3m" + _write_a3m(p, gz=False) + return MSA.from_a3m(str(p)) + + +def test_a3m_deletion_counts_vectorized(): + # leading "aa" before M, interior "xx" before L; trailing insertions are dropped + np.testing.assert_array_equal(a3m_deletion_counts("aaMKxxLN"), [2, 0, 2, 0]) + for seq in ["MKLNT", "MKaaLNcT", ".aMK..LN", "MKLNdd"]: + np.testing.assert_array_equal( + a3m_deletion_counts(seq), _naive_deletion_counts(seq) + ) + + +def test_from_a3m_records_deletions(tmp_path): + msa = _a3m_msa(tmp_path) # remove_insertions=True (default) + # stored sequences are insertion-stripped (equal length = query length) + assert msa.sequences == ["MKLNT", "MKLNT", "M-LNT"] + assert msa.deletions is not None + np.testing.assert_array_equal(msa.deletions, _EXPECTED_DELETIONS) + + +def test_from_a3m_gz(tmp_path): + p = tmp_path / "m.a3m.gz" + _write_a3m(p, gz=True) + msa = MSA.from_a3m(str(p)) + assert msa.deletions is not None + np.testing.assert_array_equal(msa.deletions, _EXPECTED_DELETIONS) + + +def test_deletions_flow_through_featurization(tmp_path): + """Stripped sequences would give zero deletions; stored ones must survive.""" + msa = _a3m_msa(tmp_path) + _, deletions = msa_to_res_type_and_deletions(msa, _LETTER_TO_RES_TYPE) + np.testing.assert_array_equal(deletions, _EXPECTED_DELETIONS) + + +def test_from_sequences_has_no_deletions_and_recomputes(): + """MSAs without stored deletions fall back to parsing the sequences.""" + no_del = MSA.from_sequences(["MKLNT", "MKLNT"]) + assert no_del.deletions is None + _, deletions = msa_to_res_type_and_deletions(no_del, _LETTER_TO_RES_TYPE) + np.testing.assert_array_equal(deletions, np.zeros((2, 5), dtype=np.float32)) + + +def test_slicing_carries_deletions(tmp_path): + """Row/column subselects carry the matching slice of deletions.""" + msa = _a3m_msa(tmp_path) + by_row = msa.select_sequences([0, 2]).deletions + by_col = msa.select_positions([2, 4]).deletions + by_slice = msa[1:].deletions + assert by_row is not None and by_col is not None and by_slice is not None + np.testing.assert_array_equal(by_row, _EXPECTED_DELETIONS[[0, 2]]) + np.testing.assert_array_equal(by_col, _EXPECTED_DELETIONS[:, [2, 4]]) + np.testing.assert_array_equal(by_slice, _EXPECTED_DELETIONS[:, 1:]) + + +def test_sliced_deletions_flow_through_featurization(tmp_path): + """A per-chain column subselect (as chainbreak splitting does) keeps deletions, + so featurizing the sliced MSA yields the sliced counts, not zeros.""" + sub = _a3m_msa(tmp_path).select_positions([2, 3, 4]) + _, deletions = msa_to_res_type_and_deletions(sub, _LETTER_TO_RES_TYPE) + np.testing.assert_array_equal(deletions, _EXPECTED_DELETIONS[:, [2, 3, 4]]) + assert deletions.sum() > 0 + + +def test_deletions_dropped_when_not_column_aligned(): + """If sequences keep insertions (length != match-column count), a column + subselect drops deletions rather than misaligning them.""" + unstripped = MSA( + entries=[FastaEntry("q", "MKaaLNT"), FastaEntry("s", "MKaaLNT")], + deletions=np.zeros((2, 5), dtype=np.float32), + ) + assert unstripped.select_positions([0, 1]).deletions is None + + +def test_pad_to_depth_carries_deletions(tmp_path): + """Padded (all-gap) rows contribute zero deletions and keep the array aligned.""" + msa = _a3m_msa(tmp_path) + padded = msa.pad_to_depth(5) + assert padded.deletions is not None + np.testing.assert_array_equal(padded.deletions[:3], _EXPECTED_DELETIONS) + np.testing.assert_array_equal(padded.deletions[3:], np.zeros((2, 5), np.float32)) + + +def test_stack_carries_deletions(tmp_path): + """Row-stacking concatenates deletion rows, dropping the duplicated query row.""" + msa = _a3m_msa(tmp_path) + stacked = MSA.stack([msa, msa], remove_query_from_later_msas=True) + assert stacked.deletions is not None + expected = np.concatenate([_EXPECTED_DELETIONS, _EXPECTED_DELETIONS[1:]], axis=0) + np.testing.assert_array_equal(stacked.deletions, expected) + + +def test_stack_drops_deletions_if_any_missing(tmp_path): + msa = _a3m_msa(tmp_path) + no_del = MSA.from_sequences(["MKLNT", "MKLNT"]) + assert MSA.stack([msa, no_del]).deletions is None + + +def test_concat_carries_deletions_without_join_token(tmp_path): + """Column-concatenation joins deletion columns when no join token is inserted.""" + msa = _a3m_msa(tmp_path) + concatenated = MSA.concat([msa, msa], join_token=None) + assert concatenated.deletions is not None + np.testing.assert_array_equal( + concatenated.deletions, + np.concatenate([_EXPECTED_DELETIONS, _EXPECTED_DELETIONS], axis=1), + ) + + +def test_concat_drops_deletions_with_join_token(tmp_path): + """A join token inserts columns with no deletion counterpart, so drop them.""" + msa = _a3m_msa(tmp_path) + assert MSA.concat([msa, msa], join_token="|").deletions is None + + +def test_state_dict_round_trip(tmp_path): + """state_dict/from_state_dict preserve sequences and deletions over the wire.""" + msa = _a3m_msa(tmp_path) + payload = msa.state_dict(json_serializable=True) + assert payload["sequences"] == msa.sequences + assert isinstance(payload["deletions"], list) + restored = MSA.from_state_dict(payload) + assert restored.sequences == msa.sequences + assert restored.deletions is not None + np.testing.assert_array_equal(restored.deletions, _EXPECTED_DELETIONS) + + +def test_state_dict_omits_deletions_when_absent(): + """from_sequences MSAs carry no deletions, so none are serialized.""" + payload = MSA.from_sequences(["MKLNT", "MKLNT"]).state_dict(json_serializable=True) + assert "deletions" not in payload + assert MSA.from_state_dict(payload).deletions is None diff --git a/esm/utils/parsing.py b/esm/utils/parsing.py index 3137af5b..eb5e2acb 100644 --- a/esm/utils/parsing.py +++ b/esm/utils/parsing.py @@ -41,7 +41,7 @@ def read_sequences(path: PathOrBuffer) -> Generator[FastaEntry, None, None]: # Doesn't use explicit isinstance check to support # inputs that are not explicitly str/Path/TextIOBase but # may support similar functionality - data = None # type: ignore + data = None try: if str(path).endswith(".gz"): import gzip diff --git a/esm/utils/residue_constants.py b/esm/utils/residue_constants.py index 81b379e8..da1d1b1e 100644 --- a/esm/utils/residue_constants.py +++ b/esm/utils/residue_constants.py @@ -1044,12 +1044,14 @@ def _make_rigid_group_constants(): for restype, restype_letter in enumerate(restypes_with_x): resname = restype_1to3[restype_letter] for atomname, group_idx, atom_position in rigid_group_atom_positions[resname]: - atomtype = atom_order[atomname] + atomtype = atom_order[atomname] # ty:ignore[invalid-argument-type] restype_atom37_to_rigid_group[restype, atomtype] = group_idx restype_atom37_mask[restype, atomtype] = 1 restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position - atom14idx = restype_name_to_atom14_names[resname].index(atomname) + atom14idx = restype_name_to_atom14_names[resname].index( + atomname # ty:ignore[invalid-argument-type] + ) restype_atom14_to_rigid_group[restype, atom14idx] = group_idx restype_atom14_mask[restype, atom14idx] = 1 restype_atom14_rigid_group_positions[restype, atom14idx, :] = atom_position diff --git a/esm/utils/structure/affine3d.py b/esm/utils/structure/affine3d.py index e2d9b27e..11b1307b 100644 --- a/esm/utils/structure/affine3d.py +++ b/esm/utils/structure/affine3d.py @@ -13,43 +13,49 @@ class Rotation(ABC): @classmethod - def identity(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self: ... + def identity( + cls, shape: tuple[int, ...], **tensor_kwargs + ) -> Self: ... # ty:ignore[empty-body] @classmethod - def random(cls, shape: tuple[int, ...], **tensor_kwargs) -> Self: ... + def random( + cls, shape: tuple[int, ...], **tensor_kwargs + ) -> Self: ... # ty:ignore[empty-body] - def __getitem__(self, idx: T.Any) -> Self: ... + def __getitem__(self, idx: T.Any) -> Self: ... # ty:ignore[empty-body] @property - def tensor(self) -> torch.Tensor: + def tensor(self) -> torch.Tensor: # ty:ignore[empty-body] # We claim that this should be zero-cost abstraction that returns the raw tensor backing this # object. The raw tensor should always have exactly 1 more dim than self.shape, which should be # implemented using reshaping ... @property - def shape(self) -> torch.Size: + def shape(self) -> torch.Size: # ty:ignore[empty-body] # The "shape" of the rotation, as if it was a torch.tensor object # This means that 1x4 quaternions are treated as size (1,) for example ... - def as_matrix(self) -> RotationMatrix: ... + def as_matrix(self) -> RotationMatrix: ... # ty:ignore[empty-body] - def as_quat(self, normalize: bool = False) -> RotationQuat: ... + def as_quat( + self, normalize: bool = False + ) -> RotationQuat: ... # ty:ignore[empty-body] - def compose(self, other: Self) -> Self: + def compose(self, other: Self) -> Self: # ty:ignore[empty-body] # To be safe, we force users to explicitly convert between rotation types. ... - def convert_compose(self, other: Self) -> Self: + def convert_compose(self, other: Self) -> Self: # ty:ignore[empty-body] # This function will automatically convert between types of rotations ... - def apply(self, p: torch.Tensor) -> torch.Tensor: + def apply(self, p: torch.Tensor) -> torch.Tensor: # ty:ignore[empty-body] # rotates points by this rotation object ... - def invert(self) -> Self: ... + def invert(self) -> Self: ... # ty:ignore[empty-body] @property def dtype(self) -> torch.dtype: @@ -292,7 +298,8 @@ def identity( kwargs = tensor_kwargs shape = shape_or_affine return Affine3D( - torch.zeros((*shape, 3), **kwargs), rotation_type.identity(shape, **kwargs) + torch.zeros((*shape, 3), **kwargs), # ty:ignore[no-matching-overload] + rotation_type.identity(shape, **kwargs), ) @staticmethod @@ -406,7 +413,7 @@ def from_tensor(t: torch.Tensor) -> "Affine3D": rot = RotationMatrix(t[..., :-3].unflatten(-1, (3, 3))) case _: raise RuntimeError( - f"Cannot detect rotation fromat from {t.shape[-1] -3}-d flat vector" + f"Cannot detect rotation fromat from {t.shape[-1] - 3}-d flat vector" ) return Affine3D(trans, rot) diff --git a/esm/utils/structure/aligner.py b/esm/utils/structure/aligner.py index f25d9987..ab4eaa1f 100644 --- a/esm/utils/structure/aligner.py +++ b/esm/utils/structure/aligner.py @@ -14,11 +14,11 @@ class Alignable(Protocol): __dataclass_fields__: ClassVar[dict[str, Field[Any]]] @property - def atom37_positions(self) -> np.ndarray: # type: ignore + def atom37_positions(self) -> np.ndarray: pass @property - def atom37_mask(self) -> np.ndarray: # type: ignore + def atom37_mask(self) -> np.ndarray: pass def __len__(self) -> int: ... diff --git a/esm/utils/structure/input_builder.py b/esm/utils/structure/input_builder.py index 1a1c7344..0a7d7113 100644 --- a/esm/utils/structure/input_builder.py +++ b/esm/utils/structure/input_builder.py @@ -97,7 +97,7 @@ def create_chain_data(seq_input, chain_type: str) -> dict[str, Any]: elif seq_input.msa is None: chain_data["msa"] = None elif isinstance(seq_input.msa, MSA): - chain_data["msa"] = {"sequences": seq_input.msa.sequences} + chain_data["msa"] = seq_input.msa.state_dict(json_serializable=True) else: error_msg = f"MSA must be None or MSA. Got {seq_input.msa} instead." raise AttributeError(error_msg) diff --git a/esm/utils/structure/metrics.py b/esm/utils/structure/metrics.py index b2e590db..138ff9db 100644 --- a/esm/utils/structure/metrics.py +++ b/esm/utils/structure/metrics.py @@ -3,7 +3,7 @@ import torch.nn.functional as F from einops import rearrange from torch import Tensor -from torch.amp import autocast # type: ignore +from torch.amp import autocast from esm.utils import residue_constants from esm.utils.misc import binpack, unbinpack diff --git a/esm/utils/structure/mmcif_parsing.py b/esm/utils/structure/mmcif_parsing.py index 3509ffb1..a3d4e760 100644 --- a/esm/utils/structure/mmcif_parsing.py +++ b/esm/utils/structure/mmcif_parsing.py @@ -9,12 +9,39 @@ import biotite.structure as bs import biotite.structure.io.pdbx as pdbx +import numpy as np +from biotite.structure.io.pdbx import CIFColumn, CIFData, CIFFile from esm.utils import residue_constants # Define PathOrBuffer for the opensource version PathOrBuffer = Union[str, os.PathLike, io.StringIO] +# pLDDT/confidence is stored internally on a 0-1 scale, but written to the mmCIF +# B-factor column on the conventional 0-100 scale +PLDDT_B_FACTOR_SCALE = 100.0 + +_MMCIF_COLUMN_DECIMALS = {"Cartn_x": 3, "Cartn_y": 3, "Cartn_z": 3, "B_iso_or_equiv": 2} + + +def round_mmcif_columns(cif_file: CIFFile) -> None: + """Round coordinate and B-factor columns in ``cif_file`` in-place.""" + if "atom_site" not in cif_file.block: + return + atom_site = cif_file.block["atom_site"] + for col_name, decimals in _MMCIF_COLUMN_DECIMALS.items(): + if col_name not in atom_site: + continue + column = atom_site[col_name] + values = column.as_array(np.float64) + atom_site[col_name] = CIFColumn( + data=CIFData( + array=np.array([f"{v:.{decimals}f}" for v in values], dtype=np.str_), + dtype=np.str_, + ), + mask=column.mask, + ) + class NoProteinError(Exception): pass @@ -246,10 +273,10 @@ def _parse_sequences(self): clean_ins_code = "" res_num = None - is_het = hetflag.upper() == "Y" # type: ignore + is_het = hetflag.upper() == "Y" chain_data[asym_id][seq_index] = Residue( residue_number=res_num, - insertion_code=clean_ins_code, # type: ignore + insertion_code=clean_ins_code, hetflag=is_het, ) diff --git a/esm/utils/structure/molecular_complex.py b/esm/utils/structure/molecular_complex.py index dc679d83..fe007954 100644 --- a/esm/utils/structure/molecular_complex.py +++ b/esm/utils/structure/molecular_complex.py @@ -25,6 +25,7 @@ from esm.utils import residue_constants from esm.utils.structure.metrics import compute_lddt, compute_rmsd +from esm.utils.structure.mmcif_parsing import PLDDT_B_FACTOR_SCALE, round_mmcif_columns from esm.utils.structure.protein_complex import ProteinComplex, ProteinComplexMetadata @@ -489,7 +490,7 @@ def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex": _raw = ( _col.as_array(str) if hasattr(_col, "as_array") - else np.array(list(_col), dtype=str) # type: ignore[arg-type] + else np.array(list(_col), dtype=str) ) # biotite's get_structure(model=1) filters to model 1 AND # removes alternate conformations. We must apply the same @@ -500,7 +501,7 @@ def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex": _models = ( _mc.as_array(str) if hasattr(_mc, "as_array") - else np.array(list(_mc), dtype=str) # type: ignore[arg-type] + else np.array(list(_mc), dtype=str) ) keep &= _models == "1" if "label_alt_id" in atom_site: @@ -508,7 +509,7 @@ def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex": _alts = ( _ac.as_array(str) if hasattr(_ac, "as_array") - else np.array(list(_ac), dtype=str) # type: ignore[arg-type] + else np.array(list(_ac), dtype=str) ) keep &= np.isin(_alts, [".", "?", "", "A"]) filtered = _raw[keep] @@ -529,8 +530,8 @@ def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex": entity_types, "__iter__" ): # Type annotation to help pyright understand these are iterable - entity_ids_list = list(entity_ids) # type: ignore - entity_types_list = list(entity_types) # type: ignore + entity_ids_list = list(entity_ids) + entity_types_list = list(entity_types) for eid, etype in zip(entity_ids_list, entity_types_list): entity_info[eid] = etype except Exception: @@ -634,7 +635,7 @@ def from_mmcif(cls, inp: str, id: str | None = None) -> "MolecularComplex": # Add confidence score (B-factor if available, otherwise 1.0) bfactor = getattr(atoms[0], "b_factor", 50.0) if atoms else 50.0 - confidence_scores.append(min(bfactor / 100.0, 1.0)) + confidence_scores.append(min(bfactor / PLDDT_B_FACTOR_SCALE, 1.0)) # Convert to numpy arrays if not flat_positions: @@ -783,6 +784,111 @@ def _add_entity_information( }, ) + # Add _entity_poly + _entity_poly_seq (SEQRES) for OST chain-grouping. + # Without these, OST's ligand scorer aborts with "Extracting chem grouping + # from mmCIF file requires all SEQRES information set" on every polymer + # chain — this is what cost us ~1234 runs_n_poses targets in the first + # OST pass. Iterate the polymer entities; one ``_entity_poly`` row per + # entity, one ``_entity_poly_seq`` row per residue. + entity_to_chains: dict[int, list[str]] = {} + for chain_id, eid in chain_to_entity.items(): + entity_to_chains.setdefault(eid, []).append(chain_id) + + ep_eids: list[str] = [] + ep_types: list[str] = [] + ep_strand_ids: list[str] = [] + ep_seq_codes: list[str] = [] # one-letter where possible, else (XXX) + eps_eids: list[str] = [] + eps_nums: list[str] = [] + eps_mons: list[str] = [] + eps_het: list[str] = [] + + for eid in sorted(entity_sequences.keys()): + seq = entity_sequences[eid] + has_protein = any(t in residue_constants.restype_3to1 for t in seq) + has_na = any( + t in ("A", "T", "G", "C", "U", "DA", "DT", "DG", "DC") for t in seq + ) + if not (has_protein or has_na): + continue # non-polymer entity, OST doesn't need SEQRES + + if has_protein: + # Detect L-peptide vs D — OF3 only outputs L for now. + ep_types.append("polypeptide(L)") + one_letter = "".join( + residue_constants.restype_3to1.get(t, "(X)") for t in seq + ) + else: + # Distinguish DNA vs RNA by presence of "U" or "T". + if any(t in ("U",) for t in seq): + ep_types.append("polyribonucleotide") + elif any(t in ("DA", "DT", "DG", "DC") for t in seq): + ep_types.append("polydeoxyribonucleotide") + else: + ep_types.append("polyribonucleotide") + one_letter = "".join( + "T" + if t == "DT" + else "A" + if t == "DA" + else "G" + if t == "DG" + else "C" + if t == "DC" + else t + for t in seq + ) + + ep_eids.append(str(eid)) + chains = sorted(entity_to_chains.get(eid, [])) + ep_strand_ids.append(",".join(chains) if chains else "?") + ep_seq_codes.append(one_letter) + + for num, t in enumerate(seq, start=1): + eps_eids.append(str(eid)) + eps_nums.append(str(num)) + # ``_entity_poly_seq.mon_id`` is the 3-letter (or CCD) code. + # Non-canonical residues come through as the raw token already. + eps_mons.append(t) + eps_het.append("n") + + if ep_eids: + cif_file.block["entity_poly"] = CIFCategory( + name="entity_poly", + columns={ + "entity_id": CIFColumn( + data=CIFData(array=np.array(ep_eids), dtype=np.str_) + ), + "type": CIFColumn( + data=CIFData(array=np.array(ep_types), dtype=np.str_) + ), + "pdbx_strand_id": CIFColumn( + data=CIFData(array=np.array(ep_strand_ids), dtype=np.str_) + ), + "pdbx_seq_one_letter_code_can": CIFColumn( + data=CIFData(array=np.array(ep_seq_codes), dtype=np.str_) + ), + }, + ) + if eps_eids: + cif_file.block["entity_poly_seq"] = CIFCategory( + name="entity_poly_seq", + columns={ + "entity_id": CIFColumn( + data=CIFData(array=np.array(eps_eids), dtype=np.str_) + ), + "num": CIFColumn( + data=CIFData(array=np.array(eps_nums), dtype=np.str_) + ), + "mon_id": CIFColumn( + data=CIFData(array=np.array(eps_mons), dtype=np.str_) + ), + "hetero": CIFColumn( + data=CIFData(array=np.array(eps_het), dtype=np.str_) + ), + }, + ) + def to_mmcif(self) -> str: """Write MolecularComplex to mmcif string using biotite. @@ -841,10 +947,10 @@ def to_mmcif(self) -> str: names = standard_names[: end - start] # Pad if needed while len(names) < (end - start): - names.append(f"X{len(names)+1}") + names.append(f"X{len(names) + 1}") else: # Fallback: generate names for ligands/nucleic acids - names = [f"C{i+1}" for i in range(end - start)] + names = [f"C{i + 1}" for i in range(end - start)] # Vectorized assignment for this token's atoms atom_res_ids[start:end] = res_id @@ -855,7 +961,7 @@ def to_mmcif(self) -> str: atom_hetero[start:end] = self.atom_hetero[start:end] else: atom_hetero[start:end] = not is_protein - atom_bfactors[start:end] = self.plddt[token_idx] * 100.0 + atom_bfactors[start:end] = self.plddt[token_idx] * PLDDT_B_FACTOR_SCALE atom_names[start:end] = names atom_entity_ids[start:end] = chain_to_entity.get(chain_id_str, 1) @@ -870,6 +976,10 @@ def to_mmcif(self) -> str: atom_array.atom_name = np.array(atom_names, dtype="U4") atom_array.add_annotation("b_factor", dtype=float) atom_array.b_factor = atom_bfactors + atom_array.add_annotation("occupancy", dtype=float) + atom_array.occupancy = np.ones( + n_atoms, dtype=np.float32 + ) # Necessary for BioPython MMCIFParser atom_array.add_annotation("entity_id", dtype=int) atom_array.entity_id = atom_entity_ids @@ -893,7 +1003,7 @@ def to_mmcif(self) -> str: if hasattr(label_asym_ids, "as_array"): chain_ids_list = label_asym_ids.as_array(str).tolist() elif hasattr(label_asym_ids, "__iter__"): - chain_ids_list = list(label_asym_ids) # type: ignore[arg-type] + chain_ids_list = list(label_asym_ids) else: chain_ids_list = [] updated_entity_ids = [ @@ -907,6 +1017,10 @@ def to_mmcif(self) -> str: # Add _entity category for OST compatibility self._add_entity_information(cif_file, entity_sequences) + # biotite echoes unmasked float columns at full precision + # so we round every float column to conventional mmCIF precision + round_mmcif_columns(cif_file) + # Convert to string output = io.StringIO() cif_file.write(output) diff --git a/esm/utils/structure/normalize_coordinates.py b/esm/utils/structure/normalize_coordinates.py index d26f2e0a..9ec2cf97 100644 --- a/esm/utils/structure/normalize_coordinates.py +++ b/esm/utils/structure/normalize_coordinates.py @@ -28,7 +28,7 @@ def index_by_atom_name( result = atom37[index] # type: ignore if squeeze: result = result.squeeze(dim) - return result + return result # ty:ignore[invalid-return-type] def get_protein_normalization_frame(coords: Tensor) -> Affine3D: diff --git a/esm/utils/structure/predicted_aligned_error.py b/esm/utils/structure/predicted_aligned_error.py index 1071baf4..8a97aa98 100644 --- a/esm/utils/structure/predicted_aligned_error.py +++ b/esm/utils/structure/predicted_aligned_error.py @@ -48,7 +48,7 @@ def compute_predicted_aligned_error( return (probs * bins).sum(dim=-1) -@torch.no_grad +@torch.no_grad # ty:ignore[too-many-positional-arguments] def compute_tm(logits: torch.Tensor, aa_mask: torch.Tensor, max_bin: float = 31.0): square_mask = _compute_pae_masks(aa_mask) seqlens = aa_mask.sum(-1, keepdim=True) diff --git a/esm/utils/structure/protein_chain.py b/esm/utils/structure/protein_chain.py index 08f39df2..4e433bd4 100644 --- a/esm/utils/structure/protein_chain.py +++ b/esm/utils/structure/protein_chain.py @@ -26,7 +26,12 @@ from esm.utils.structure.aligner import Aligner from esm.utils.structure.atom_indexer import AtomIndexer from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca -from esm.utils.structure.mmcif_parsing import MmcifWrapper, Residue +from esm.utils.structure.mmcif_parsing import ( + PLDDT_B_FACTOR_SCALE, + MmcifWrapper, + Residue, + round_mmcif_columns, +) from esm.utils.structure.normalize_coordinates import ( apply_frame_to_coords, get_protein_normalization_frame, @@ -42,7 +47,9 @@ def _str_key_to_int_key(dct: dict, ignore_keys: list[str] | None = None) -> dict new_dict = {} for k, v in dct.items(): v_new = v - if k not in ignore_keys and isinstance(v, dict): + if k not in ignore_keys and isinstance( # ty:ignore[unsupported-operator] + v, dict + ): v_new = _str_key_to_int_key(v, ignore_keys=ignore_keys) # Note assembly_composition is *supposed* to have string keys. if isinstance(k, str) and k.isdigit(): @@ -125,7 +132,7 @@ def chain_to_ndarray( ) atom_mask[res_index, residue_constants.atom_order[atom_name]] = True if is_predicted and atom_name == "CA": - confidence[res_index] = atom.b_factor + confidence[res_index] = atom.b_factor / PLDDT_B_FACTOR_SCALE assert all(sequence), "Some residue name was not specified correctly" return ( @@ -226,7 +233,8 @@ def atom_array(self) -> bs.AtomArray: hetero=False, atom_name=residue_constants.atom_types[i], element=residue_constants.atom_types[i][0], - b_factor=float(b_factor), + b_factor=float(b_factor) * PLDDT_B_FACTOR_SCALE, + occupancy=1.0, # Necessary for BioPython MMCIFParser ) atoms.append(atom) return bs.array(atoms) @@ -261,7 +269,8 @@ def atom_array_no_insertions(self) -> bs.AtomArray: hetero=False, atom_name=residue_constants.atom_types[i], element=residue_constants.atom_types[i][0], - b_factor=float(b_factor), + b_factor=float(b_factor) * PLDDT_B_FACTOR_SCALE, + occupancy=1.0, # Necessary for BioPython MMCIFParser ) atoms.append(atom) return bs.array(atoms) @@ -359,7 +368,9 @@ def to_mmcif(self, path: PathOrBuffer): data=CIFData(array=["2"] * len(self.residue_index), dtype=np.str_) ), "metric_value": CIFColumn( - data=CIFData(array=self.confidence, dtype=np.float32) + data=CIFData( + array=self.confidence * PLDDT_B_FACTOR_SCALE, dtype=np.float32 + ) ), # hard coded to show there are the initial version, there are no revisions "model_id": CIFColumn( @@ -369,6 +380,9 @@ def to_mmcif(self, path: PathOrBuffer): f.block["ma_qa_metric_local"] = CIFCategory( name="ma_qa_metric_local", columns=resid_pldd_table ) + # biotite echoes unmasked float columns at full precision + # so we round every float column to conventional mmCIF precision + round_mmcif_columns(f) f.write(path) def to_mmcif_string(self) -> str: @@ -453,7 +467,7 @@ def from_blob(cls, input: Path | str | io.BytesIO | bytes): def sasa(self, by_residue: bool = True): arr = self.atom_array_no_insertions - sasa_per_atom = bs.sasa(arr) # type: ignore + sasa_per_atom = bs.sasa(arr) if by_residue: # Sum per-atom SASA into residue "bins", with np.bincount. assert arr.res_id is not None @@ -545,7 +559,7 @@ def sap_score(self, aggregation: str = "atom") -> np.ndarray: assert len(sap_by_residue) == len(self) return sap_by_residue case "protein": - return sum(sap_by_atom[sap_by_atom > 0]) # pyright: ignore[reportReturnType] + return sum(sap_by_atom[sap_by_atom > 0]) case _: raise ValueError( f"Invalid aggregation method: {aggregation}. Must be one of 'atom', 'residue', or 'protein'" @@ -561,7 +575,7 @@ def globularity(self) -> float: # NOTE(@zeming): due to the approximation we make here, that atoms never overlap, you might get >1 globularity mask = self.atom37_mask.any(-1) points = self.atom37_positions[self.atom37_mask] - sequence = [aa for aa, m in zip(self.sequence, mask) if m] # type: ignore + sequence = [aa for aa, m in zip(self.sequence, mask) if m] A, _ = self._mvee(points, tol=1e-3) mvee_volume = (4 * np.pi) / (3 * np.sqrt(np.linalg.det(A))) volume = sum(residue_constants.amino_acid_volumes[x] for x in sequence) @@ -954,7 +968,7 @@ def from_atom37( return cls( id=id, - sequence=sequence, # type: ignore + sequence=sequence, chain_id=chain_id, entity_id=entity_id, atom37_positions=atom37_positions, @@ -1081,7 +1095,7 @@ def from_pdb( ) atom_mask[i, residue_constants.atom_order[atom_name]] = True if is_predicted and atom_name == "CA": - confidence[i] = atom.b_factor + confidence[i] = atom.b_factor / PLDDT_B_FACTOR_SCALE assert all(sequence), "Some residue name was not specified correctly" @@ -1121,7 +1135,7 @@ def from_rcsb( entity_id: int | None = None, keep_source: bool = False, ) -> ProteinChain: - f: io.StringIO = rcsb.fetch(pdb_id, "cif") # type: ignore + f: io.StringIO = rcsb.fetch(pdb_id, "cif") return cls.from_mmcif( f, id=pdb_id, @@ -1139,7 +1153,7 @@ def from_atomarray( Uses PDB file format as intermediate.""" atom_array = atom_array.copy() atom_array.box = None # remove surrounding box, from_pdb won't handle this - pdb_file = PDBFile() # pyright: ignore + pdb_file = PDBFile() pdb_file.set_structure(atom_array) buf = io.StringIO() diff --git a/esm/utils/structure/protein_complex.py b/esm/utils/structure/protein_complex.py index 4f74f697..eaea2911 100644 --- a/esm/utils/structure/protein_complex.py +++ b/esm/utils/structure/protein_complex.py @@ -33,7 +33,11 @@ from esm.utils.structure.aligner import Aligner from esm.utils.structure.atom_indexer import AtomIndexer from esm.utils.structure.metrics import compute_gdt_ts, compute_lddt_ca -from esm.utils.structure.mmcif_parsing import MmcifWrapper, NoProteinError +from esm.utils.structure.mmcif_parsing import ( + MmcifWrapper, + NoProteinError, + round_mmcif_columns, +) from esm.utils.structure.protein_chain import ( ProteinChain, _str_key_to_int_key, @@ -474,7 +478,7 @@ def from_blob(cls, input: Path | str | io.BytesIO | bytes): @classmethod def from_rcsb(cls, pdb_id: str, keep_source: bool = False) -> ProteinComplex: - f: io.StringIO = rcsb.fetch(pdb_id, "cif") # type: ignore + f: io.StringIO = rcsb.fetch(pdb_id, "cif") return cls.from_mmcif(f, id=pdb_id, keep_source=keep_source, is_predicted=False) @classmethod @@ -890,7 +894,7 @@ def sanity_check_chain_ids(pc: ProteinComplex): def parse_dict(d: dict[str, Any]) -> DockQSingleScore: return DockQSingleScore( - native_chains=tuple(d["Native chains"]), # type: ignore + native_chains=tuple(d["Native chains"]), DockQ=float(d["DockQ"]), interface_rms=float(d["irms"]), ligand_rms=float(d["Lrms"]), # Note the capitalization difference @@ -901,7 +905,10 @@ def parse_dict(d: dict[str, Any]) -> DockQSingleScore: DockQ_F1=float(d["DockQ_F1"]), ) - inv_mapping = {v: k for k, v in result["mapping"].items()} + inv_mapping = { + v: k + for k, v in result["mapping"].items() # ty:ignore[unresolved-attribute] + } self_chain_map = {c.chain_id: c for c in self.chain_iter()} realigned = [] @@ -913,9 +920,11 @@ def parse_dict(d: dict[str, Any]) -> DockQSingleScore: realigned = aligner.apply(realigned) result = DockQResult( - total_dockq=result["value"], - native_interfaces=result["native interfaces"], - chain_mapping=result["mapping"], + total_dockq=result["value"], # ty:ignore[invalid-argument-type] + native_interfaces=result[ + "native interfaces" + ], # ty:ignore[invalid-argument-type] + chain_mapping=result["mapping"], # ty:ignore[invalid-argument-type] interfaces={ (i["Model chains"][0], i["Model chains"][1]): parse_dict(i) for i in interfaces @@ -1011,6 +1020,10 @@ def to_mmcif_string(self) -> str: # Add entity information for proper mmCIF structure self._add_entity_information(f) + # biotite echoes unmasked float columns at full precision + # so we round every float column to conventional mmCIF precision + round_mmcif_columns(f) + # Write to string output = io.StringIO() f.write(output) @@ -1142,20 +1155,20 @@ def get_assembly_fast( ### Get structure according to additional parameters structure = get_structure( pdbx_file, model, data_block, altloc, ["label_asym_id"], use_author_fields - )[0] # type: ignore + )[0] # TODO(@zeming) This line will remove all non-protein structural elements, # we should remove this when we want to parse these too. structure: bs.AtomArray = structure[ - bs.filter_amino_acids(structure) & ~structure.hetero # type: ignore + bs.filter_amino_acids(structure) & ~structure.hetero ] if len(structure) == 0: raise NoProteinError - unique_asym_ids = np.unique(structure.label_asym_id) # type: ignore + unique_asym_ids = np.unique(structure.label_asym_id) asym2chain = {} asym2auth = {} for asym_id in unique_asym_ids: - sub_structure: bs.AtomArray = structure[structure.label_asym_id == asym_id] # type: ignore - chain_id: str = sub_structure[0].chain_id # type: ignore + sub_structure: bs.AtomArray = structure[structure.label_asym_id == asym_id] + chain_id: str = sub_structure[0].chain_id ( sequence, atom_positions, diff --git a/esm/utils/structure/protein_structure.py b/esm/utils/structure/protein_structure.py index cb91d110..d5ed7f83 100644 --- a/esm/utils/structure/protein_structure.py +++ b/esm/utils/structure/protein_structure.py @@ -6,7 +6,7 @@ import torch import torch.nn.functional as F from torch import Tensor -from torch.amp import autocast # type: ignore +from torch.amp import autocast from esm.utils import residue_constants from esm.utils.misc import unbinpack @@ -28,7 +28,7 @@ def index_by_atom_name( result = atom37[index] # type: ignore if squeeze: result = result.squeeze(dim) - return result + return result # ty:ignore[invalid-return-type] def infer_cbeta_from_atom37( @@ -52,7 +52,7 @@ def normalize(x: ArrayOrTensor): cross = np.cross else: - normalize = F.normalize # type: ignore + normalize = F.normalize cross = torch.cross with np.errstate(invalid="ignore"): # inf - inf = nan is ok here diff --git a/esm/utils/system.py b/esm/utils/system.py index c2800e57..3114b72c 100644 --- a/esm/utils/system.py +++ b/esm/utils/system.py @@ -40,6 +40,6 @@ def run_subprocess_with_errorcheck( ) except subprocess.CalledProcessError as e: raise RuntimeError( - f"Command failed with errorcode {e.returncode}." f"\n\n{e.stderr.decode()}" + f"Command failed with errorcode {e.returncode}.\n\n{e.stderr.decode()}" ) return p diff --git a/esm/widgets/components/results_visualizer.py b/esm/widgets/components/results_visualizer.py index 261c0a5d..95ff1f7d 100644 --- a/esm/widgets/components/results_visualizer.py +++ b/esm/widgets/components/results_visualizer.py @@ -38,7 +38,7 @@ def create_results_visualizer( # Sort structures by pTM samples = sorted( samples, - key=lambda item: (item.ptm.item() if item.ptm is not None else 0), + key=lambda item: item.ptm.item() if item.ptm is not None else 0, reverse=True, ) diff --git a/esm/widgets/components/sasa_prompt_selector.py b/esm/widgets/components/sasa_prompt_selector.py index 9c026500..7c23bf6f 100644 --- a/esm/widgets/components/sasa_prompt_selector.py +++ b/esm/widgets/components/sasa_prompt_selector.py @@ -18,8 +18,8 @@ def create_sasa_prompt_selector( with_title: bool = True, active_tag_callback: Callable[[], str] | None = None, ) -> widgets.Widget: - is_active_callback = ( - lambda: active_tag_callback() == tag if active_tag_callback else True + is_active_callback = lambda: ( + active_tag_callback() == tag if active_tag_callback else True ) if input_array is None: diff --git a/esm/widgets/components/secondary_structure_prompt_selector.py b/esm/widgets/components/secondary_structure_prompt_selector.py index b8007d69..0ca3d687 100644 --- a/esm/widgets/components/secondary_structure_prompt_selector.py +++ b/esm/widgets/components/secondary_structure_prompt_selector.py @@ -21,8 +21,8 @@ def create_secondary_structure_prompt_selector( ) -> widgets.Widget: ss3_categories = get_ss3_categories() - is_active_callback = ( - lambda: active_tag_callback() == tag if active_tag_callback else True + is_active_callback = lambda: ( + active_tag_callback() == tag if active_tag_callback else True ) if input_array is None: diff --git a/esm/widgets/components/sequence_prompt_selector.py b/esm/widgets/components/sequence_prompt_selector.py index c5e3526f..ccd6f303 100644 --- a/esm/widgets/components/sequence_prompt_selector.py +++ b/esm/widgets/components/sequence_prompt_selector.py @@ -20,8 +20,8 @@ def create_sequence_prompt_selector( ) -> widgets.Widget: sequence_length = len(full_sequence) - is_active_callback = ( - lambda: active_tag_callback() == tag if active_tag_callback else True + is_active_callback = lambda: ( + active_tag_callback() == tag if active_tag_callback else True ) range_slider = widgets.IntRangeSlider( diff --git a/esm/widgets/components/structure_prompt_selector.py b/esm/widgets/components/structure_prompt_selector.py index 13a9df78..28e4a3e1 100644 --- a/esm/widgets/components/structure_prompt_selector.py +++ b/esm/widgets/components/structure_prompt_selector.py @@ -28,8 +28,8 @@ def create_structure_prompt_selector( min_residue, max_residue = indexing.get_pdb_index_min_max(protein_chain) - is_active_callback = ( - lambda: active_tag_callback() == tag if active_tag_callback else True + is_active_callback = lambda: ( + active_tag_callback() == tag if active_tag_callback else True ) matrix_output = widgets.Output() diff --git a/esm/widgets/utils/drawing/colors.py b/esm/widgets/utils/drawing/colors.py index 373f5fe4..9043d459 100644 --- a/esm/widgets/utils/drawing/colors.py +++ b/esm/widgets/utils/drawing/colors.py @@ -22,7 +22,7 @@ def float_to_int(f): g = float_to_int(rgba[1]) b = float_to_int(rgba[2]) if len(rgba) > 3: - rgba = (r, g, b, rgba[3]) + rgba = (r, g, b, rgba[3]) # ty:ignore[index-out-of-bounds] else: rgba = (r, g, b) diff --git a/esm/widgets/utils/drawing/draw_function_annotations.py b/esm/widgets/utils/drawing/draw_function_annotations.py index 59e9f7cf..2ca55c9c 100644 --- a/esm/widgets/utils/drawing/draw_function_annotations.py +++ b/esm/widgets/utils/drawing/draw_function_annotations.py @@ -43,7 +43,7 @@ def draw_function_annotations( start=annotation.start - 1, # one index -> zero index end=annotation.end, label=label, - color=type_colors[entry_type], # type: ignore + color=type_colors[entry_type], strand=None, ) features.append(feature) diff --git a/esm/widgets/utils/prompting.py b/esm/widgets/utils/prompting.py index 1e89bb64..199e7ddc 100644 --- a/esm/widgets/utils/prompting.py +++ b/esm/widgets/utils/prompting.py @@ -280,14 +280,14 @@ def add_entry_to_ui(self, range_string): f"{range_string}" ) ) - entry_label.tag = range_string # type: ignore + entry_label.tag = range_string entry_container = widgets.HBox([entry_button, entry_label]) def delete_entry(b): self.entries_box.children = [ w for w in self.entries_box.children if w != entry_container ] - self.delete_prompt(entry_label.tag) # type: ignore + self.delete_prompt(entry_label.tag) self.redraw() for callback in self.delete_callbacks: callback() diff --git a/esm/widgets/utils/protein_import.py b/esm/widgets/utils/protein_import.py index c9d9dbb4..0d6f6632 100644 --- a/esm/widgets/utils/protein_import.py +++ b/esm/widgets/utils/protein_import.py @@ -114,7 +114,7 @@ def add_pdb_id(self, pdb_id: str, chain_id: str): def add_entry_to_ui(self, protein_id: str): entry_button = widgets.Button(description="Remove") entry_label = widgets.Label(value=protein_id) - entry_label.tag = protein_id # type: ignore + entry_label.tag = protein_id entry_container = widgets.HBox([entry_button, entry_label]) def delete_entry(b): diff --git a/esm/widgets/views/esm3_prompt_preview.py b/esm/widgets/views/esm3_prompt_preview.py index e9b9a5f2..60a57e1b 100644 --- a/esm/widgets/views/esm3_prompt_preview.py +++ b/esm/widgets/views/esm3_prompt_preview.py @@ -53,7 +53,7 @@ def text_to_sasa(sasa_text: str) -> list[int | float | None] | None: def function_annotations_to_text(annotations: list[FunctionAnnotation]) -> str: return "\n".join( [ - f"[{annotation.start-1}-{annotation.end-1}]: {annotation.label}" + f"[{annotation.start - 1}-{annotation.end - 1}]: {annotation.label}" for annotation in annotations ] ) diff --git a/pixi.lock b/pixi.lock index f4356818..baab98d9 100644 --- a/pixi.lock +++ b/pixi.lock @@ -74,7 +74,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/matplotlib-base-3.10.5-py312he3d6523_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.5-h2d0b736_3.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/nh3-0.2.21-py39h77e2912_1.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-2.3.2-py312h33ff503_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-1.26.4-py312heda63a1_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.3-h55fea9a_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.4.1-h7b32b05_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pandas-2.3.1-py312hf79963d_0.conda @@ -324,7 +324,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/matplotlib-base-3.10.5-py312h05635fa_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ncurses-6.5-h5e97a16_3.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/nh3-0.3.0-py39h24c5d98_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/numpy-2.3.2-py312h2f38b44_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/numpy-1.26.4-py312h8442bc7_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openjpeg-2.5.3-h889cd5d_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openssl-3.5.2-he92f556_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pandas-2.3.1-py312h98f7732_0.conda @@ -507,8 +507,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-9.0.1-he0572af_6.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.5-h2d0b736_3.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/nh3-0.2.21-py39h77e2912_1.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/nodejs-22.13.0-hf235a45_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-2.3.2-py312h33ff503_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-1.26.4-py312heda63a1_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.3-h55fea9a_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openldap-2.6.9-he970967_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.4.1-h7b32b05_0.conda @@ -518,7 +517,6 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/pixman-0.46.4-h537e5f6_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pkg-config-0.29.2-h4bc722e_1009.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-hb9d3cd8_1002.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/pyright-1.1.399-py312h66e93f0_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pyside6-6.9.0-py312h91f0f75_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.12.9-h9e4cc4f_1_cpython.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/python_abi-3.12-6_cp312.conda @@ -532,6 +530,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/statsmodels-0.14.5-py312h8b63200_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/tornado-6.5.1-py312h66e93f0_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/ty-0.0.49-h4e94fc0_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/ukkonen-1.0.1-py312h68727a3_5.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/unicodedata2-16.0.0-py312h66e93f0_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/wayland-1.24.0-h3e06ad9_0.conda @@ -558,7 +557,6 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libxtst-1.2.5-hb9d3cd8_3.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libxxf86vm-1.1.6-hb9d3cd8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/yaml-0.2.5-h280c20c_3.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstandard-0.23.0-py312h66e93f0_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb8e6e7a_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/_python_abi3_support-1.0-hd8ed1ab_1.conda @@ -836,8 +834,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/matplotlib-base-3.10.5-py312h05635fa_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ncurses-6.5-h5e97a16_3.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/nh3-0.3.0-py39h24c5d98_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/nodejs-24.4.1-hab9d20b_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/numpy-2.3.2-py312h2f38b44_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/numpy-1.26.4-py312h8442bc7_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openjpeg-2.5.3-h889cd5d_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openssl-3.5.2-he92f556_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pandas-2.3.1-py312h98f7732_0.conda @@ -845,7 +842,6 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pillow-11.3.0-py312h50aef2c_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pkg-config-0.29.2-hde07d2e_1009.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pthread-stubs-0.4-hd74edd7_1002.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pyright-1.1.399-py312hea69d52_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/python-3.12.11-hc22306f_0_cpython.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pyyaml-6.0.2-py312h998013c_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/qhull-2020.2-h420ef59_5.conda @@ -855,6 +851,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/statsmodels-0.14.5-py312hcde60ef_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/tk-8.6.13-h892fb3f_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/tornado-6.5.2-py312h163523d_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ty-0.0.49-hdfcc030_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ukkonen-1.0.1-py312h6142ec9_5.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/unicodedata2-16.0.0-py312hea69d52_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/xorg-libxau-1.0.12-h5505292_0.conda @@ -2142,44 +2139,28 @@ packages: - pkg:pypi/nh3?source=hash-mapping size: 621078 timestamp: 1741652643562 -- conda: https://conda.anaconda.org/conda-forge/linux-64/nodejs-22.13.0-hf235a45_0.conda - sha256: 925ea8839d6f26d0eb4204675b98a862803a9a9657fd36a4a22c4c29a479a911 - md5: 1f9efd96347aa008bd2c735d7d88fc75 +- conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-1.26.4-py312heda63a1_0.conda + sha256: fe3459c75cf84dcef6ef14efcc4adb0ade66038ddd27cadb894f34f4797687d8 + md5: d8285bea2a350f63fab23bf460221f3f depends: - - __glibc >=2.28,<3.0.a0 - - icu >=75.1,<76.0a0 - - libgcc >=13 - - libstdcxx >=13 - - libuv >=1.50.0,<2.0a0 - - libzlib >=1.3.1,<2.0a0 - - openssl >=3.4.1,<4.0a0 - - zlib - license: MIT - license_family: MIT - purls: [] - size: 21691794 - timestamp: 1741809786920 -- conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-2.3.2-py312h33ff503_0.conda - sha256: d54e52df67e0be7e5faa9e6f0efccea3d72f635a3159cc151c4668e5159f6ef3 - md5: 3f6efbc40eb13f019c856c410fa921d2 - depends: - - python - - libgcc >=14 - - __glibc >=2.17,<3.0.a0 - - libstdcxx >=14 - - libgcc >=14 - libblas >=3.9.0,<4.0a0 - - python_abi 3.12.* *_cp312 - libcblas >=3.9.0,<4.0a0 + - libgcc-ng >=12 - liblapack >=3.9.0,<4.0a0 + - libstdcxx-ng >=12 + - python >=3.12,<3.13.0a0 + - python_abi 3.12.* *_cp312 constrains: - numpy-base <0a0 license: BSD-3-Clause license_family: BSD purls: - - pkg:pypi/numpy?source=compressed-mapping - size: 8785045 - timestamp: 1753401550884 + - pkg:pypi/numpy?source=hash-mapping + run_exports: + weak: + - numpy >=1.26.4,<2.0a0 + size: 7484186 + timestamp: 1707225809722 - conda: https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.3-h55fea9a_1.conda sha256: 0b7396dacf988f0b859798711b26b6bc9c6161dca21bacfd778473da58730afa md5: 01243c4aaf71bde0297966125aea4706 @@ -2344,23 +2325,6 @@ packages: purls: [] size: 8252 timestamp: 1726802366959 -- conda: https://conda.anaconda.org/conda-forge/linux-64/pyright-1.1.399-py312h66e93f0_0.conda - sha256: 9857f51927cbe196ce7c0b4258504283a3492025a54dd7396282e718133eada2 - md5: 2fcde1e79f9d16a21acc7a1391c80216 - depends: - - __glibc >=2.17,<3.0.a0 - - libgcc >=13 - - nodeenv >=1.6.0 - - nodejs - - python >=3.12,<3.13.0a0 - - python_abi 3.12.* *_cp312 - - typing_extensions >=4.1 - license: MIT - license_family: MIT - purls: - - pkg:pypi/pyright?source=hash-mapping - size: 3543823 - timestamp: 1744274982054 - conda: https://conda.anaconda.org/conda-forge/linux-64/pyside6-6.9.0-py312h91f0f75_0.conda sha256: 4db931dccd8347140e79236378096d9a1b97b98bbd206d54cebd42491ad12535 md5: e3a335c7530a1d0c4db621914f00f9f7 @@ -2617,6 +2581,25 @@ packages: - pkg:pypi/tornado?source=hash-mapping size: 850902 timestamp: 1748003427956 +- conda: https://conda.anaconda.org/conda-forge/linux-64/ty-0.0.49-h4e94fc0_0.conda + noarch: python + sha256: 820b2baa740578d49b986cae19c9b74b29c14220ea2de5ca136227db625e24a2 + md5: f0de70fcdee2ba2da18db5ecfdc203c0 + depends: + - python + - __glibc >=2.17,<3.0.a0 + - libgcc >=14 + - _python_abi3_support 1.* + - cpython >=3.10 + constrains: + - __glibc >=2.17 + license: MIT + license_family: MIT + purls: + - pkg:pypi/ty?source=hash-mapping + run_exports: {} + size: 10273495 + timestamp: 1781322169980 - conda: https://conda.anaconda.org/conda-forge/linux-64/ukkonen-1.0.1-py312h68727a3_5.conda sha256: 9fb020083a7f4fee41f6ece0f4840f59739b3e249f157c8a407bb374ffb733b5 md5: f9664ee31aed96c85b7319ab0a693341 @@ -2946,18 +2929,6 @@ packages: purls: [] size: 85189 timestamp: 1753484064210 -- conda: https://conda.anaconda.org/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda - sha256: 5d7c0e5f0005f74112a34a7425179f4eb6e73c92f5d109e6af4ddeca407c92ab - md5: c9f075ab2f33b3bbee9e62d4ad0a6cd8 - depends: - - __glibc >=2.17,<3.0.a0 - - libgcc >=13 - - libzlib 1.3.1 hb9d3cd8_2 - license: Zlib - license_family: Other - purls: [] - size: 92286 - timestamp: 1727963153079 - conda: https://conda.anaconda.org/conda-forge/linux-64/zstandard-0.23.0-py312h66e93f0_1.conda sha256: b4fd6bd1cb87a183a8bbe85b4e87a1e7c51473309d0d82cd88d38fb021bcf41e md5: d28b82fcc8d1b462b595af4b15a6cdcf @@ -4885,32 +4856,16 @@ packages: - pkg:pypi/nh3?source=hash-mapping size: 624089 timestamp: 1752853325963 -- conda: https://conda.anaconda.org/conda-forge/osx-arm64/nodejs-24.4.1-hab9d20b_0.conda - sha256: c79d2c81f80a9adedc77362f2e8b10879ed0f9806deb6ba2464c1287a05f0b9b - md5: 463a537de602f8558604f27395b323d0 +- conda: https://conda.anaconda.org/conda-forge/osx-arm64/numpy-1.26.4-py312h8442bc7_0.conda + sha256: c8841d6d6f61fd70ca80682efbab6bdb8606dc77c68d8acabfbd7c222054f518 + md5: d83fc83d589e2625a3451c9a7e21047c depends: - - libcxx >=19 - - __osx >=11.0 - - openssl >=3.5.1,<4.0a0 - - libuv >=1.51.0,<2.0a0 - - icu >=75.1,<76.0a0 - - libzlib >=1.3.1,<2.0a0 - license: MIT - license_family: MIT - purls: [] - size: 17949155 - timestamp: 1752839389217 -- conda: https://conda.anaconda.org/conda-forge/osx-arm64/numpy-2.3.2-py312h2f38b44_0.conda - sha256: 581039072c18b2abd8dfcf7fe5c16a8fbb72e14821bad4817ca00dbb16f3bad3 - md5: c58a6fa1ee8edb9de10d0f5c91806193 - depends: - - python - - libcxx >=19 - - python 3.12.* *_cpython - - __osx >=11.0 - - liblapack >=3.9.0,<4.0a0 - libblas >=3.9.0,<4.0a0 - libcblas >=3.9.0,<4.0a0 + - libcxx >=16 + - liblapack >=3.9.0,<4.0a0 + - python >=3.12,<3.13.0a0 + - python >=3.12,<3.13.0a0 *_cpython - python_abi 3.12.* *_cp312 constrains: - numpy-base <0a0 @@ -4918,8 +4873,11 @@ packages: license_family: BSD purls: - pkg:pypi/numpy?source=hash-mapping - size: 6657726 - timestamp: 1753401542508 + run_exports: + weak: + - numpy >=1.26.4,<2.0a0 + size: 6073136 + timestamp: 1707226249608 - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openjpeg-2.5.3-h889cd5d_1.conda sha256: 6013916893fcd9bc97c479279cfe4616de7735ec566bad0ee41bc729e14d31b2 md5: ab581998c77c512d455a13befcddaac3 @@ -5054,23 +5012,6 @@ packages: purls: [] size: 8381 timestamp: 1726802424786 -- conda: https://conda.anaconda.org/conda-forge/osx-arm64/pyright-1.1.399-py312hea69d52_0.conda - sha256: 97a022cd83aaad61b347ba2d420f40f206c92c201b612544f5a16d82542900d8 - md5: 2ec7ce1f1637de2984b9a6d0362d07ec - depends: - - __osx >=11.0 - - nodeenv >=1.6.0 - - nodejs - - python >=3.12,<3.13.0a0 - - python >=3.12,<3.13.0a0 *_cpython - - python_abi 3.12.* *_cp312 - - typing_extensions >=4.1 - license: MIT - license_family: MIT - purls: - - pkg:pypi/pyright?source=hash-mapping - size: 3572220 - timestamp: 1744275049046 - conda: https://conda.anaconda.org/conda-forge/osx-arm64/python-3.12.11-hc22306f_0_cpython.conda sha256: cde8b944c2dc378a5afbc48028d0843583fd215493d5885a80f1b41de085552f md5: 9207ebad7cfbe2a4af0702c92fd031c4 @@ -5207,6 +5148,24 @@ packages: - pkg:pypi/tornado?source=hash-mapping size: 853490 timestamp: 1754732280524 +- conda: https://conda.anaconda.org/conda-forge/osx-arm64/ty-0.0.49-hdfcc030_0.conda + noarch: python + sha256: cd727a7e22fd445b38181038d68f947900ae3ddcd17a426d991b842735f63769 + md5: ccd9ff0122035dc50d199ea0484864f9 + depends: + - python + - __osx >=11.0 + - _python_abi3_support 1.* + - cpython >=3.10 + constrains: + - __osx >=11.0 + license: MIT + license_family: MIT + purls: + - pkg:pypi/ty?source=hash-mapping + run_exports: {} + size: 9251353 + timestamp: 1781322180954 - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ukkonen-1.0.1-py312h6142ec9_5.conda sha256: 1e4452b4a12d8a69c237f14b876fbf0cdc456914170b49ba805779c749c31eca md5: 2b485a809d1572cbe7f0ad9ee107e4b0 diff --git a/pyproject.toml b/pyproject.toml index 7ab6d609..e2c64e1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,12 +86,15 @@ upload-wheel = "python -m twine upload --repository pypi" [tool.pixi.feature.dev.dependencies] matplotlib = "*" +# Pin the lint/test env to numpy<2: the numpy 2.x type stubs trip ty false positives +# (savez/zip) that don't reflect runtime behavior. Runtime deps stay numpy-unconstrained. +numpy = ">=1.26.0,<2.0.0" pre-commit = "*" pytest = "*" pytest-cov = "*" pytest-xdist = "*" seaborn = "*" -pyright = "==1.1.399" +ty = "==0.0.49" [tool.pixi.feature.dev.tasks] lint-all = "pre-commit run --all-files --show-diff-on-failure" @@ -133,10 +136,37 @@ docstring-code-line-length = "dynamic" [tool.isort] known_third_party = ["wandb"] -[tool.pyright] -useLibraryCodeForTypes = true -reportPrivateImportUsage = false -typeCheckingMode = "basic" +[tool.ty.src] +# gitignore-style globs. Skip notebooks, tests, and the Modal-app tutorial whose deps +# (modal, abnumber) and dynamic decorators aren't resolvable in the lint env. +exclude = [ + "**/*.ipynb", + "**/*_test.py", + "**/test_*.py", + "**/tests/", + "**/conftest.py", + "cookbook/tutorials/binder_design.py", +] + +[tool.ty.rules] +# Strict LSP override-compatibility checking that pyright `basic` never enforced; low +# value for duck-typed ML code. +invalid-method-override = "ignore" +# flash_attn resolves only where it's installed (a GPU environment); the ignores it +# requires there read as unused/redundant in this CPU lint env where it's absent. +unused-ignore-comment = "ignore" +unused-type-ignore-comment = "ignore" + +[tool.ty.analysis] +# ty can't introspect optional/compiled deps absent from this lint env (flash_attn is +# GPU-only; zstd is a C-extension), so it falsely reports them unresolved. Treat as Any +# rather than disabling unresolved-import, so genuinely-broken first-party imports still +# surface. +replace-imports-with-any = [ + "flash_attn", + "flash_attn.**", + "zstd", +] [tool.importlinter] root_package = "esm"