Added Apple silicon MPS support #1
pretrained.py: replace hardcoded device='cuda' default with auto-detection (cuda → mps → cpu) via _default_device() helper. pretrained_tap.py: replace one-liner DEVICE constant with MPS-aware fallback. model/flash_abb.py: fix featurize() default device and to_pdbs() device pass-through (inferred from stored output tensor so it works on any device). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
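The auto-detection described in this commit can be sketched roughly as follows. This is a minimal illustration, not the PR's actual `_default_device()` helper: the function names `pick_device` and `default_device` are hypothetical, and the priority order follows the commit message (cuda, then mps, then cpu).

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Return the preferred device name given which backends are usable.

    Hypothetical helper: the priority (cuda, then mps, then cpu) follows
    the order stated in the commit message.
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def default_device() -> str:
    """Query PyTorch for backend availability and pick a device."""
    # torch is imported here so the ordering logic above can be
    # exercised without PyTorch installed.
    import torch

    # torch.backends.mps exists from PyTorch 1.12; guard for older builds.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = mps_backend is not None and mps_backend.is_available()
    return pick_device(torch.cuda.is_available(), mps_ok)
```

A caller would then use something like `device = default_device()` as the default in place of a hardcoded `'cuda'`.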
pip 21.3+ requires a build-system declaration for editable installs. Without pyproject.toml, pip 22+ falls back to legacy mode inconsistently and pip 26 rejects setup.py-only editable installs entirely. Declares setuptools as the build backend; setup.py remains the source of truth for package metadata. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
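For reference, a minimal build-system declaration of the kind this commit describes might look like the fragment below. The actual contents of the PR's pyproject.toml are not shown in this thread, so the unpinned `setuptools` requirement is an assumption.

```toml
# Minimal pyproject.toml sketch: declares the build backend only.
# Package metadata stays in setup.py, as the commit message states.
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
```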
Top-level `import torch.utils.checkpoint` pulls in PyTorch's symbolic shapes machinery (torch.fx.experimental.symbolic_shapes), which accesses sympy.core at class definition time. In IPython, earlier imports can leave sympy partially initialised in sys.modules, causing AttributeError on sympy.core when this import chain fires. torch.utils.checkpoint is only needed inside get_checkpoint_fn(), which is only reached during training (blocks_per_ckpt not None, grad enabled). Deferring the import to that call site avoids the issue entirely and has no effect on inference. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
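The deferred-import pattern this commit describes can be sketched as below. The signature and the pass-through branch are simplified, hypothetical stand-ins; the real `get_checkpoint_fn()` in the repository may differ. The point is only that the import statement moves from module top level into the branch that needs it.

```python
def get_checkpoint_fn(blocks_per_ckpt, grad_enabled):
    """Return a callable with the shape fn(function, *args).

    Simplified, hypothetical version: the pass-through branch stands in
    for the inference path, which never needs torch.utils.checkpoint.
    """
    if blocks_per_ckpt is None or not grad_enabled:
        # Inference: call the wrapped function directly, no import needed.
        return lambda function, *args: function(*args)

    # Training only: import at the call site so that merely importing
    # this module does not pull in torch.utils.checkpoint (and, through
    # it, the symbolic-shapes machinery mentioned in the commit).
    import torch.utils.checkpoint

    return torch.utils.checkpoint.checkpoint
```

Calling `get_checkpoint_fn(None, False)` never triggers the import, which is why inference is unaffected.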
MPS scaled_dot_product_attention requires query and value to have the same head dimension. FlashpointAttention concatenates geometric point encodings onto Q/K (d_q = c_ipa + n_qk_points*9 = 52) and spectral position features onto V (d_v = c_ipa + n_v_points*3 + rel_pos_dim = 168), so d_q != d_v. MPS silently returns output with d_q instead of d_v, causing the downstream torch.split([16, 24, 128]) to fail on a tensor of size 52. Fix: check q.shape[-1] != v.shape[-1] and use torch.matmul + softmax when they differ. CUDA is unaffected (dimensions match there too, so the SDPA path is unchanged on CUDA and CPU with equal dims). Applied to both flashpoint_attention.py copies (StructureModule and fpa_transformer used by BERTCoords/SSS/TAP). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Thanks for the contribution! I encountered a similar issue when experimenting with FlashAttention, which I believe requires dimensions to be powers of 2. Currently, this solution will always default to manual attention, which isn't ideal, especially if anyone wants to adapt this for long-context structure tasks. My preference would be to have use_manual be true only on MPS, and possibly to try padding the query dimension to match d_v so that we retain efficient SDP kernels.
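The padding idea rests on the fact that zero-padding the feature dimension of both Q and K leaves every attention logit unchanged. A short derivation, assuming Q and K are padded with zeros from d_q to d_v and that the scale is passed explicitly (the PR description states scale=1.0 with pre-scaled Q), so it does not depend on the padded width:

```latex
\tilde{q}_i = \begin{bmatrix} q_i \\ \mathbf{0} \end{bmatrix} \in \mathbb{R}^{d_v},
\qquad
\tilde{k}_j = \begin{bmatrix} k_j \\ \mathbf{0} \end{bmatrix} \in \mathbb{R}^{d_v}
\quad\Longrightarrow\quad
\tilde{q}_i^{\top} \tilde{k}_j
  = q_i^{\top} k_j + \mathbf{0}^{\top} \mathbf{0}
  = q_i^{\top} k_j
```

The softmax weights, and therefore the attention output, are identical; the cost is that Q and K grow from width d_q = 52 to d_v = 168.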
flex_attention (torch.nn.attention.flex_attention) was added in PyTorch 2.5. The import is present in both flashpoint_attention.py files but the symbols are never called in the inference path (only a commented-out line references them). Wrap in try/except so FlashABB loads on PyTorch 2.4 and earlier. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
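A guarded import of this kind might look like the sketch below. The thread does not list which symbols the files import, so `flex_attention` and `create_block_mask` are assumed here, as is the `FLEX_ATTENTION_AVAILABLE` flag, which is purely illustrative.

```python
try:
    # Available from PyTorch 2.5 onward.
    from torch.nn.attention.flex_attention import (
        create_block_mask,
        flex_attention,
    )
except ImportError:
    # Older PyTorch (or no PyTorch at all): keep the names defined so the
    # module still imports; any code path that needs them must check first.
    create_block_mask = None
    flex_attention = None

# Illustrative flag (not from the PR) for callers that want to branch.
FLEX_ATTENTION_AVAILABLE = flex_attention is not None
```

Because the inference path never calls these symbols, leaving them as `None` on older versions is harmless there.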
Agreed that defaulting to manual attention even for CUDA users is not ideal. Perhaps we can gate on the device with "if q.device.type == 'mps' and q.shape[-1] != v.shape[-1]:" instead of relying on d_q != d_v alone. Padding is not ideal either because it adds a tensor allocation and roughly triples the Q/K working size.
diff --git a/flash_abb/model/flashpoint_attention.py b/flash_abb/model/flashpoint_attention.py
@@ -207,7 +207,7 @@
- if q.shape[-1] != v.shape[-1]:
+ if q.device.type == 'mps' and q.shape[-1] != v.shape[-1]:
diff --git a/flash_abb/model/fpa_transformer/flashpoint_attention.py b/flash_abb/model/fpa_transformer/flashpoint_attention.py
@@ -239,7 +239,7 @@
- _use_manual = (q.shape[-1] != v.shape[-1])
+ _use_manual = (q.device.type == 'mps' and q.shape[-1] != v.shape[-1])
Happy to discuss further if you are interested in an MPS-optimized path. Personally, I prefer to develop on a Mac laptop before moving to a CUDA device, so I see the need for supporting MPS.
I'm happy to merge this with that change and a comment indicating that efficient attention isn't currently available for MPS. I can also open an issue with a feature request for the MPS-optimized version.
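Putting the agreed change together, a device-gated fallback could be sketched as follows. The function names are hypothetical, and the sketch assumes no attention mask and no dropout; the real `FlashpointAttention` forward pass may handle both. The `scale` keyword of `scaled_dot_product_attention` requires PyTorch 2.1 or later.

```python
def use_manual_attention(device_type: str, d_q: int, d_v: int) -> bool:
    """Decide whether to bypass scaled_dot_product_attention.

    Efficient SDPA kernels are not currently available on MPS when the
    query and value head dimensions differ, so only that case falls back.
    """
    return device_type == "mps" and d_q != d_v


def attention(q, k, v, scale=1.0):
    """Attention over (..., L, d_q), (..., S, d_q), (..., S, d_v) tensors."""
    # torch is imported here so the gating predicate above can be used
    # without PyTorch installed.
    import torch
    import torch.nn.functional as F

    if use_manual_attention(q.device.type, q.shape[-1], v.shape[-1]):
        # Manual path: softmax(Q K^T * scale) V
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        weights = torch.softmax(scores, dim=-1)
        return torch.matmul(weights, v)

    # CUDA, CPU, and MPS with matching dims keep the fused SDPA path.
    return F.scaled_dot_product_attention(q, k, v, scale=scale)
```

With d_q = 52 and d_v = 168, only tensors on an MPS device take the manual branch; CUDA and CPU continue to use SDPA.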
- Auto-detect device (`mps→cuda→cpu`) in `pretrained`, `pretrained_tap`, `flash_abb.py`
- Fix `featurize()` and `to_pdbs()` to respect the model's device rather than hardcoding `cuda`
- Defer `import torch.utils.checkpoint` to avoid `sympy.core` partial-init error in IPython/Jupyter
- Add `pyproject.toml` build-system declaration required by pip ≥ 26
- Fall back to `matmul` + `softmax` when `d_q ≠ d_v` in `FlashpointAttention`: MPS's `scaled_dot_product_attention` silently returns the wrong output shape in that case (d_q=52, d_v=168).

The fallback is mathematically identical (`scale=1.0`, Q is pre-scaled).

Tested on `pretrained`, `pretrained_tap`, `pretrained_sss`. All five tensors pass.
┌────────────────────┬──────────┬───────────┐
│ Output │ max diff │ mean diff │
├────────────────────┼──────────┼───────────┤
│ coords (structure) │ 5.2e-05 │ 1.6e-06 │
├────────────────────┼──────────┼───────────┤
│ bb_coords │ 1.1e-05 │ 2.9e-06 │
├────────────────────┼──────────┼───────────┤
│ TAP scores │ 7.6e-05 │ 2.2e-05 │
├────────────────────┼──────────┼───────────┤
│ SSS embeddings │ 1.6e-04 │ 9.5e-06 │
├────────────────────┼──────────┼───────────┤
│ SSS mask │ 0 │ 0 │
└────────────────────┴──────────┴───────────┘
The differences are likely float32 rounding noise from hardware-level floating-point arithmetic differences between CUDA and MPS, roughly 5–7 orders of magnitude below the values themselves. The SDPA fallback is numerically equivalent to within this tolerance.
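A comparison like the one in the table could be reproduced with a small helper along these lines. This is a sketch under assumptions: it takes "max diff" and "mean diff" to mean the maximum and mean absolute element-wise difference, and the function names and the idea of one `.npz` file per device are illustrative rather than taken from the PR.

```python
import numpy as np


def diff_stats(a, b):
    """Return (max, mean) of the absolute element-wise difference."""
    # Compare in float64 so the comparison itself adds no rounding noise.
    delta = np.abs(np.asarray(a, dtype=np.float64) - np.asarray(b, dtype=np.float64))
    return float(delta.max()), float(delta.mean())


def compare_npz(path_a, path_b):
    """Compare arrays that share a key in two .npz archives.

    The file layout (one archive per device, matching keys) is assumed.
    """
    with np.load(path_a) as arrays_a, np.load(path_b) as arrays_b:
        shared = sorted(set(arrays_a.files) & set(arrays_b.files))
        return {key: diff_stats(arrays_a[key], arrays_b[key]) for key in shared}
```

For example, `compare_npz("cuda_outputs.npz", "mps_outputs.npz")` (hypothetical file names) would return one `(max, mean)` pair per shared output.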
outputs_npz.zip