REDIAL stands for RNA Embedding perturbation Diagnostics for Language models.
This repository contains a small Python package, redial, for extracting RNA language-model embeddings and using embedding perturbations to estimate RNA contact maps.
The package currently wraps two RNA models:
- RNAFMEmbeddings: loads RNA-FM with fm.pretrained.rna_fm_t12().
- StructRFMEmbeddings: loads StructRFM through structRFM.infer.structRFM_infer.
It also provides a shared model interface and a mutation-based contact predictor.
redial/
|-- README.md
`-- redial/
    |-- __init__.py
    |-- ContactPredictor.py
    |-- RNALanguageModel.py
    |-- RNAFMEmbeddings.py
    `-- StructRFMEmbeddings.py
RNALanguageModel.py defines the common interface used by the embedding wrappers:
- Selects cuda when available, otherwise cpu.
- Exposes encode(sequence, layer=...) for sequence embeddings.
- Exposes decode(embedding) for decoding embeddings to logits.
- Exposes short_circuit(layers=...) for replacing selected transformer layers with identity layers.
- Provides perplexity(sequence, logits) for computing reconstruction perplexity from decoded logits.
- Provides model_params, model, tokenizer, and detokenizer.
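
The perplexity helper is described here only at the interface level. The following is a minimal sketch of one common definition, assuming perplexity is the exponential of the mean per-position cross-entropy between the input nucleotides and the decoded logits. The function name sketch_perplexity, the vocabulary order ACGU, and the list-of-rows logits layout are illustrative assumptions and may not match the package's implementation or tokenizer.

```python
import math

# Hypothetical vocabulary order; the real tokenizer's mapping may differ.
VOCAB = "ACGU"


def sketch_perplexity(sequence, logits):
    """Return exp(mean cross-entropy) of `sequence` under per-position `logits`.

    `logits` holds one row per position; each row has one score per
    symbol in VOCAB.
    """
    if len(sequence) != len(logits):
        raise ValueError("need exactly one logit row per nucleotide")
    total_nll = 0.0
    for base, row in zip(sequence, logits):
        # Numerically stable log-softmax: subtract the row maximum first.
        row_max = max(row)
        log_norm = row_max + math.log(sum(math.exp(x - row_max) for x in row))
        total_nll += log_norm - row[VOCAB.index(base)]
    return math.exp(total_nll / len(sequence))


# Uniform logits carry no information, so the perplexity equals the
# vocabulary size (up to floating-point rounding).
uniform = [[0.0, 0.0, 0.0, 0.0] for _ in "ACGU"]
print(sketch_perplexity("ACGU", uniform))
```

Under this definition a perplexity near 1 means the decoded logits reconstruct the input almost perfectly, while a value near the vocabulary size means they are close to uninformative.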
RNAFMEmbeddings is a wrapper around RNA-FM.
Model parameters:
{
    "max_length": 1024,
    "dim": 640,
    "layer": 12,
}

encode(sequence, layer=[12]) returns a NumPy array with shape:
num_requested_layers x sequence_length x 640
The wrapper strips RNA-FM special tokens before returning embeddings.
StructRFMEmbeddings is a wrapper around StructRFM.
Default checkpoint path:
/home/dteng/rna_db/structRFM_checkpoint
Model parameters:
{
    "max_length": 514,
    "dim": 768,
    "layer": 12,
    "num_attention_heads": 12,
    "device": torch.device(...),
}

encode(sequence, layer=[12]) returns a NumPy array with shape:
num_requested_layers x sequence_length x 768
The wrapper calls extract_raw_feature(sequence, return_all=True) and strips BOS/EOS tokens before returning embeddings.
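
The layer selection and BOS/EOS stripping described above can be sketched with plain NumPy. This is an illustration under stated assumptions, not the wrapper's code: it assumes the raw features arrive as one array per layer output, each of shape (sequence_length + 2, dim), with index 0 holding the embedding-layer output and the special tokens in the first and last positions. The helper name select_and_strip and the random stand-in data are hypothetical.

```python
import numpy as np


def select_and_strip(hidden_states, layers):
    """Stack the requested layers and drop the first and last token.

    `hidden_states` is a sequence of arrays, one per layer output, each of
    shape (sequence_length + 2, dim). Index 0 is assumed to be the
    embedding-layer output, so index 12 is the last of 12 transformer layers.
    """
    stacked = np.stack([np.asarray(hidden_states[i]) for i in layers])
    # Position 0 is assumed to be BOS and position -1 EOS.
    return stacked[:, 1:-1, :]


# Random stand-ins for 13 hidden states of an 8-nucleotide sequence.
rng = np.random.default_rng(0)
fake_states = [rng.normal(size=(8 + 2, 768)) for _ in range(13)]
embeddings = select_and_strip(fake_states, layers=[12])
print(embeddings.shape)  # (1, 8, 768)
```

The leading axis matches the number of requested layers, which is why the usage examples index embeddings[0] to get a single layer.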
ContactPredictor implements embedding-perturbation contact prediction:
- Encode the original RNA sequence.
- For each position, mutate the nucleotide to the other three bases.
- Re-encode each mutated sequence.
- Measure embedding changes with an L2 norm.
- Optionally apply average product correction, or APC.
- Symmetrize and normalize the contact map to [0, 1].
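
The steps above can be sketched end to end with a toy encoder in place of a real language model. Several details are assumptions, because the list does not pin them down: the three substitutions at each position are averaged into one row, entry (i, j) is the L2 norm of the embedding change at position j after mutating position i, and the final normalization is min-max scaling. The names toy_encode, apc, and perturbation_contact_map are hypothetical and are not the package's API.

```python
import numpy as np

BASES = "ACGU"


def toy_encode(sequence):
    """Stand-in encoder, not a language model: one-hot bases blended with neighbors."""
    one_hot = np.array([[float(b == base) for b in BASES] for base in sequence])
    mixed = one_hot.copy()
    mixed[1:] += 0.5 * one_hot[:-1]   # contribution from the left neighbor
    mixed[:-1] += 0.5 * one_hot[1:]   # contribution from the right neighbor
    return mixed  # shape: (sequence_length, 4)


def apc(matrix):
    """Average product correction: subtract row_mean * col_mean / overall_mean."""
    overall = matrix.mean()
    if overall == 0:
        return matrix
    row_mean = matrix.mean(axis=1, keepdims=True)
    col_mean = matrix.mean(axis=0, keepdims=True)
    return matrix - row_mean * col_mean / overall


def perturbation_contact_map(sequence, encode, do_apc=True):
    reference = encode(sequence)
    rows = []
    for i, base in enumerate(sequence):
        changes = []
        for alt in BASES:
            if alt == base:
                continue
            mutated = sequence[:i] + alt + sequence[i + 1:]
            # L2 norm of the embedding change at every position j.
            changes.append(np.linalg.norm(encode(mutated) - reference, axis=-1))
        # Assumption: average the substitutions at position i into one row.
        rows.append(np.mean(changes, axis=0))
    scores = np.stack(rows)
    if do_apc:
        scores = apc(scores)
    scores = 0.5 * (scores + scores.T)        # symmetrize
    low, high = scores.min(), scores.max()
    if high == low:
        return np.zeros_like(scores)
    return (scores - low) / (high - low)      # min-max normalize to [0, 1]


contact = perturbation_contact_map("ACGUACGU", toy_encode)
print(contact.shape)  # (8, 8)
```

With a real wrapper, the encode argument would be a function that returns one layer's sequence_length x dim embedding. The toy encoder only mixes adjacent positions, so its map is informative only near the diagonal.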
predict(sequence, layer=12, do_apc=True) returns a NumPy array with shape:
sequence_length x sequence_length
batch_predict(sequences, model) runs prediction for multiple sequences and returns a dictionary keyed by sequence.
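
Because the result is keyed by sequence, repeated input sequences collapse into a single entry. The sketch below illustrates that behavior with a stand-in predictor. The function sketch_batch_predict and the zero_map helper are hypothetical, and whether the real batch_predict skips recomputation for duplicates is not stated here.

```python
def sketch_batch_predict(sequences, predict):
    """Map each distinct sequence to its predicted contact map."""
    results = {}
    for sequence in sequences:
        if sequence not in results:  # skip sequences already predicted
            results[sequence] = predict(sequence)
    return results


def zero_map(sequence):
    """Stand-in predictor that returns an L x L grid of zeros."""
    n = len(sequence)
    return [[0.0] * n for _ in range(n)]


results = sketch_batch_predict(["ACGUACGU", "GGCAUU", "GGCAUU"], zero_map)
print({seq: len(contact) for seq, contact in results.items()})
```

Callers that need one result per input, including duplicates, can look each input sequence up in the returned dictionary.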
The package depends on the model libraries used by the wrappers, plus NumPy and PyTorch. In this repository, the relevant dependencies are listed in the top-level requirements.txt:
numpy
torch>=2.0.1
structRFM
transformers[torch]
RNAFMEmbeddings also imports fm, so the RNA-FM package must be installed in the active Python environment.
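
A quick way to confirm that the active environment can see these modules before loading a model is to query the import system. This is a small standard-library sketch; the module list is taken from the dependency notes above, and the helper name missing_modules is hypothetical.

```python
import importlib.util

# Import names taken from the dependency notes above.
REQUIRED_MODULES = ["numpy", "torch", "fm", "structRFM", "transformers"]


def missing_modules(names):
    """Return the top-level module names that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(REQUIRED_MODULES)
if missing:
    print("Not importable:", ", ".join(missing))
else:
    print("All model dependencies are importable.")
```

The check only confirms that a module can be located. It does not verify versions or that a StructRFM checkpoint is present.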
StructRFMEmbeddings requires a valid StructRFM checkpoint. Use the default path or pass another path:
from redial.StructRFMEmbeddings import StructRFMEmbeddings
model = StructRFMEmbeddings(pretrained_path="/path/to/structRFM_checkpoint")

Run examples from the redial directory, or make sure the parent directory of redial is on PYTHONPATH.
from redial.RNAFMEmbeddings import RNAFMEmbeddings
model = RNAFMEmbeddings()
sequence = "ACGUACGU"
embeddings = model.encode(sequence, layer=[12])
last_layer = embeddings[0]
print(last_layer.shape)  # (8, 640)

from redial.StructRFMEmbeddings import StructRFMEmbeddings
model = StructRFMEmbeddings(pretrained_path="/home/dteng/rna_db/structRFM_checkpoint")
sequence = "ACGUACGU"
embeddings = model.encode(sequence, layer=[12])
last_layer = embeddings[0]
print(last_layer.shape)  # (8, 768)

from redial.RNAFMEmbeddings import RNAFMEmbeddings
model = RNAFMEmbeddings()
sequence = "ACGUACGU"
embedding = model.encode(sequence, layer=[12])
logits = model.decode(embedding)
perplexity = model.perplexity(sequence, logits)
print(perplexity)

from redial.RNAFMEmbeddings import RNAFMEmbeddings
model = RNAFMEmbeddings()
model.short_circuit(layers=[0, 1])
embeddings = model.encode("ACGUACGU", layer=[12])

Layer indices passed to short_circuit are zero-based and must be less than model.model_params["layer"].
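
The idea behind short-circuiting, and the zero-based index rule, can be illustrated with a toy layer stack built from plain Python callables. This is a conceptual sketch only: the real wrapper operates on the model's transformer layers, and the names identity, sketch_short_circuit, and run are hypothetical.

```python
def identity(x):
    """Pass the input through unchanged."""
    return x


def sketch_short_circuit(layer_stack, layers):
    """Return a copy of `layer_stack` with the given zero-based layers bypassed."""
    num_layers = len(layer_stack)
    for index in layers:
        if not 0 <= index < num_layers:
            raise IndexError(f"layer {index} is outside 0..{num_layers - 1}")
    bypassed = set(layers)
    return [identity if i in bypassed else layer
            for i, layer in enumerate(layer_stack)]


def run(layer_stack, x):
    """Apply every layer in order."""
    for layer in layer_stack:
        x = layer(x)
    return x


# Toy stack of 12 "layers" that each add 1 to their input.
stack = [lambda x: x + 1 for _ in range(12)]
print(run(stack, 0))                               # 12
print(run(sketch_short_circuit(stack, [0, 1]), 0)) # 10
```

In this toy, index 12 is rejected for a 12-layer stack, mirroring the rule that indices must be less than the layer count.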
from redial.ContactPredictor import ContactPredictor
from redial.RNAFMEmbeddings import RNAFMEmbeddings
model = RNAFMEmbeddings()
predictor = ContactPredictor(model=model)
contact = predictor.predict("ACGUACGU", layer=12, do_apc=True)
print(contact.shape) # (8, 8)
print(contact.min(), contact.max())  # normalized to [0, 1]

from redial.ContactPredictor import batch_predict
from redial.RNAFMEmbeddings import RNAFMEmbeddings
model = RNAFMEmbeddings()
results = batch_predict(["ACGUACGU", "GGCAUU"], model=model)

- ContactPredictor.predict is computationally expensive. For a sequence of length L, it performs one original encoding plus up to 3 * L mutated-sequence encodings.
- RNA-FM supports sequences up to length 1024 in this wrapper.
- StructRFM supports sequences up to length 514 in this wrapper.
- ContactPredictor.py has a __main__ debugging block that expects ../2_comparison/database.json and writes .npy files under rnafm/ and structrfm/ relative to the current working directory.
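
The cost and length notes above translate into a quick pre-flight check. The sketch below uses the limits as quoted; whether those limits count special tokens is not stated, so treat the check as approximate. The names MAX_LENGTH, max_encodings, and check_length are hypothetical helpers, not part of the package.

```python
# Length limits quoted in the notes above.
MAX_LENGTH = {"RNA-FM": 1024, "StructRFM": 514}


def max_encodings(sequence):
    """Upper bound on encoder calls: 1 original + 3 mutants per position."""
    return 1 + 3 * len(sequence)


def check_length(sequence, model_name):
    """Raise ValueError if the sequence exceeds the wrapper's quoted limit."""
    limit = MAX_LENGTH[model_name]
    if len(sequence) > limit:
        raise ValueError(
            f"{model_name} wrapper is limited to {limit} nucleotides, "
            f"got {len(sequence)}"
        )


sequence = "ACGUACGU"
check_length(sequence, "StructRFM")
print(max_encodings(sequence))  # 25
```

The encoder-call count grows linearly with sequence length, so a 500-nucleotide sequence needs up to 1501 encodings.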