Skip to content

Added Apple silicon MPS support#1

Open
hubin-keio wants to merge 5 commits into
oxpig:mainfrom
hubin-keio:mps-support
Open

Added Apple silicon MPS support#1
hubin-keio wants to merge 5 commits into
oxpig:mainfrom
hubin-keio:mps-support

Conversation

@hubin-keio

Copy link
Copy Markdown
  • Auto-detect device (mpscudacpu) in pretrained, pretrained_tap, flash_abb.py

    • Fix featurize() and to_pdbs() to respect the model's device rather than hardcoding cuda
    • Defer import torch.utils.checkpoint to avoid sympy.core partial-init error in IPython/Jupyter
    • Add pyproject.toml build-system declaration required by pip ≥ 26
    • Fall back to manual matmul+softmax when d_q ≠ d_v in FlashpointAttention — MPS's
      scaled_dot_product_attention silently returns wrong output shape in that case (d_q=52, d_v=168).
      The fallback is mathematically identical (scale=1.0, Q is pre-scaled).

    Tested on

    • macOS 15, Apple M-series, PyTorch 2.x, MPS backend
    • All three capabilities: pretrained, pretrained_tap, pretrained_sss

All five tensors pass.

┌────────────────────┬──────────┬───────────┐
│ Output │ max diff │ mean diff │
├────────────────────┼──────────┼───────────┤
│ coords (structure) │ 5.2e-05 │ 1.6e-06 │
├────────────────────┼──────────┼───────────┤
│ bb_coords │ 1.1e-05 │ 2.9e-06 │
├────────────────────┼──────────┼───────────┤
│ TAP scores │ 7.6e-05 │ 2.2e-05 │
├────────────────────┼──────────┼───────────┤
│ SSS embeddings │ 1.6e-04 │ 9.5e-06 │
├────────────────────┼──────────┼───────────┤
│ SSS mask │ 0 │ 0 │
└────────────────────┴──────────┴───────────┘

The differences are likely float32 rounding noise from hardware-level FP arithmetic differences between CUDA and MPS — roughly 5–7 orders of magnitude below the values themselves. The SDPA fallback is numerically equivalent.
outputs_npz.zip

hubin-keio and others added 4 commits June 6, 2026 14:32
pretrained.py: replace hardcoded device='cuda' default with auto-detection
  (cuda → mps → cpu) via _default_device() helper.
pretrained_tap.py: replace one-liner DEVICE constant with MPS-aware fallback.
model/flash_abb.py: fix featurize() default device and to_pdbs() device
  pass-through (inferred from stored output tensor so it works on any device).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
pip 21.3+ requires a build-system declaration for editable installs.
Without pyproject.toml, pip 22+ falls back to legacy mode inconsistently
and pip 26 rejects setup.py-only editable installs entirely.
Declares setuptools as the build backend; setup.py remains the source
of truth for package metadata.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Top-level `import torch.utils.checkpoint` pulls in PyTorch's symbolic
shapes machinery (torch.fx.experimental.symbolic_shapes), which accesses
sympy.core at class definition time. In IPython, earlier imports can leave
sympy partially initialised in sys.modules, causing AttributeError on
sympy.core when this import chain fires.

torch.utils.checkpoint is only needed inside get_checkpoint_fn(), which is
only reached during training (blocks_per_ckpt not None, grad enabled).
Deferring the import to that call site avoids the issue entirely and has
no effect on inference.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
MPS scaled_dot_product_attention requires query and value to have the
same head dimension. FlashpointAttention concatenates geometric point
encodings onto Q/K (d_q = c_ipa + n_qk_points*9 = 52) and spectral
position features onto V (d_v = c_ipa + n_v_points*3 + rel_pos_dim = 168),
so d_q != d_v. MPS silently returns output with d_q instead of d_v,
causing the downstream torch.split([16, 24, 128]) to fail on a tensor of
size 52.

Fix: check q.shape[-1] != v.shape[-1] and use torch.matmul + softmax
when they differ. CUDA is unaffected (dimensions match there too, so
the SDPA path is unchanged on CUDA and CPU with equal dims).

Applied to both flashpoint_attention.py copies (StructureModule and
fpa_transformer used by BERTCoords/SSS/TAP).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@hubin-keio hubin-keio changed the title Mps support Added Apple silicon MPS support Jun 6, 2026
@Ellmen

Ellmen commented Jun 6, 2026

Copy link
Copy Markdown
Collaborator

Thanks for the contribution! I encountered a similar issue when experimenting with FlashAttention which I believe requires dimensions to be powers of 2.

Currently, this solution will always default to manual attention which isn't ideal, especially if anyone wants to adapt this for long-context structure tasks. My preference would be to have use_manual only true on MPS and possibly try padding the query dimension to match d_v so that we retain efficeint SDP kernels.

flex_attention (torch.nn.attention.flex_attention) was added in PyTorch 2.5.
The import is present in both flashpoint_attention.py files but the symbols
are never called in the inference path (only a commented-out line references
them). Wrap in try/except so FlashABB loads on PyTorch 2.4 and earlier.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@hubin-keio

Copy link
Copy Markdown
Author

Agreed on non-ideal solution by defaulting to manual attention even for CUDA users. Perhaps we can put a gate on device: "if q.device.type == 'mps' and q.shape[-1] != v.shape[-1]: " instead of relying on d_q != d_v. Padding is not ideal either because it adds tensor allocation and triples the Q/K working size.

diff --git a/flash_abb/model/flashpoint_attention.py b/flash_abb/model/flashpoint_attention.py
@@ -207,7 +207,7 @@
-        if q.shape[-1] != v.shape[-1]:
+        if q.device.type == 'mps' and q.shape[-1] != v.shape[-1]:

diff --git a/flash_abb/model/fpa_transformer/flashpoint_attention.py b/flash_abb/model/fpa_transformer/flashpoint_attention.py
@@ -239,7 +239,7 @@
-        _use_manual = (q.shape[-1] != v.shape[-1])
+        _use_manual = (q.device.type == 'mps' and q.shape[-1] != v.shape[-1])

Happy to discuss further if you are interested in an MPS-optimized path. Personally I prefer to develop on a Mac laptop before move it to a CUDA device so I see the need for supporting MPS.

@Ellmen

Ellmen commented Jun 8, 2026

Copy link
Copy Markdown
Collaborator

I'm happy to merge this with that change and a comment indicating that efficient attention isn't currently available for MPS. I can also make an issue with a feature request for the MPS-optimized version.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants