tiwarylab/redial

REDIAL

REDIAL stands for RNA Embedding perturbation Diagnostics for Language models.

This repository contains a small Python package, redial, for extracting RNA language-model embeddings and for estimating RNA contact maps from how those embeddings change under point mutations.

The package currently wraps two RNA models:

  • RNAFMEmbeddings: loads RNA-FM with fm.pretrained.rna_fm_t12().
  • StructRFMEmbeddings: loads StructRFM through structRFM.infer.structRFM_infer.

It also provides a shared model interface and a mutation-based contact predictor.

Folder Layout

redial/
|-- README.md
`-- redial/
    |-- __init__.py
    |-- ContactPredictor.py
    |-- RNALanguageModel.py
    |-- RNAFMEmbeddings.py
    `-- StructRFMEmbeddings.py

Modules

RNALanguageModel.py

Defines the common interface used by the embedding wrappers:

  • Selects cuda when available, otherwise cpu.
  • Exposes encode(sequence, layer=...) for sequence embeddings.
  • Exposes decode(embedding) for decoding embeddings to logits.
  • Exposes short_circuit(layers=...) for replacing selected transformer layers with identity layers.
  • Provides perplexity(sequence, logits) for computing reconstruction perplexity from decoded logits.
  • Provides model_params, model, tokenizer, and detokenizer.
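
As a rough illustration of what perplexity(sequence, logits) measures, the sketch below uses the usual definition: the exponential of the mean negative log-likelihood of the true nucleotides. It is not the package's implementation. The four-letter VOCAB mapping and the list-of-rows layout for logits are assumptions made for the example; the real tokenizers include special tokens and the real method may differ in detail.

```python
import math

# Assumed 4-letter vocabulary, for illustration only. The tokenizers used by
# RNA-FM and StructRFM contain special tokens and further symbols.
VOCAB = {"A": 0, "C": 1, "G": 2, "U": 3}


def perplexity(sequence, logits):
    """Return exp(mean negative log-likelihood) of the true nucleotides.

    `logits` holds one row of len(VOCAB) floats per sequence position.
    """
    if not sequence or len(sequence) != len(logits):
        raise ValueError("need a non-empty sequence and one logit row per base")
    total_nll = 0.0
    for base, row in zip(sequence, logits):
        # Numerically stable log-sum-exp: subtract the row maximum first.
        m = max(row)
        log_norm = m + math.log(sum(math.exp(x - m) for x in row))
        total_nll += log_norm - row[VOCAB[base]]
    return math.exp(total_nll / len(sequence))


# Uniform logits carry no information about the sequence, so the perplexity
# is (up to rounding) the vocabulary size.
uniform = [[0.0, 0.0, 0.0, 0.0] for _ in "ACGU"]
print(perplexity("ACGU", uniform))
```

A value close to 1 means the decoded logits put almost all probability on the original nucleotides; larger values mean a poorer reconstruction.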

RNAFMEmbeddings.py

Wrapper around RNA-FM.

Model parameters:

{
    "max_length": 1024,
    "dim": 640,
    "layer": 12,
}

encode(sequence, layer=[12]) returns a NumPy array with shape:

num_requested_layers x sequence_length x 640

The wrapper strips RNA-FM special tokens before returning embeddings.

StructRFMEmbeddings.py

Wrapper around StructRFM.

Default checkpoint path:

/home/dteng/rna_db/structRFM_checkpoint

Model parameters:

{
    "max_length": 514,
    "dim": 768,
    "layer": 12,
    "num_attention_heads": 12,
    "device": torch.device(...),
}

encode(sequence, layer=[12]) returns a NumPy array with shape:

num_requested_layers x sequence_length x 768

The wrapper calls extract_raw_feature(sequence, return_all=True) and strips BOS/EOS tokens before returning embeddings.
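
The shape bookkeeping for both wrappers can be sketched with a dummy array. The layout assumed here is hypothetical: a stack of hidden states of shape (13, sequence_length + 2, dim), with index 0 as the input embedding and index 12 as the last transformer layer. The structure that extract_raw_feature actually returns may differ, and select_and_strip is an illustrative helper, not a function in this package.

```python
import numpy as np


def select_and_strip(all_layers, layer):
    """Pick the requested layers and drop the first and last token positions.

    all_layers: array of shape (num_layers + 1, seq_len + 2, dim); this
    layout is an assumption made for the example.
    """
    picked = all_layers[list(layer)]   # (num_requested_layers, seq_len + 2, dim)
    return picked[:, 1:-1, :]          # remove the BOS and EOS positions


sequence = "ACGUACGU"
dummy = np.zeros((13, len(sequence) + 2, 768))  # placeholder hidden states

print(select_and_strip(dummy, layer=[12]).shape)     # (1, 8, 768)
print(select_and_strip(dummy, layer=[6, 12]).shape)  # (2, 8, 768)
```

The leading axis follows the order of the requested layers, which is why the usage examples read the last layer as embeddings[0] when layer=[12].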

ContactPredictor.py

Implements embedding perturbation contact prediction:

  1. Encode the original RNA sequence.
  2. For each position, mutate the nucleotide to the other three bases.
  3. Re-encode each mutated sequence.
  4. Measure embedding changes with an L2 norm.
  5. Optionally apply average product correction, or APC.
  6. Symmetrize and normalize the contact map to [0, 1].
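
The six steps above can be sketched end to end with NumPy. This is a minimal illustration under stated assumptions, not the code in ContactPredictor.py: toy_encode is a stand-in encoder rather than an RNA language model, the three substitutions at each position are combined by taking their mean (the aggregation rule is an assumption), APC is the standard row-mean times column-mean over overall-mean correction, and normalization is min-max scaling.

```python
import numpy as np

BASES = "ACGU"


def toy_encode(sequence, dim=8):
    """Stand-in encoder (NOT an RNA language model); returns (len, dim).

    Each position mixes one-hot features of all positions with a weight that
    decays with distance, so a mutation at i also shifts nearby embeddings.
    """
    rng = np.random.default_rng(0)  # fixed seed: same projection on every call
    proj = rng.normal(size=(len(BASES), dim))
    onehot = np.eye(len(BASES))[[BASES.index(b) for b in sequence]]
    idx = np.arange(len(sequence))
    mix = np.exp(-np.abs(idx[:, None] - idx[None, :]))
    return mix @ onehot @ proj


def apc(m):
    """Average product correction: m_ij - row_mean_i * col_mean_j / mean."""
    row = m.mean(axis=1, keepdims=True)
    col = m.mean(axis=0, keepdims=True)
    return m - row * col / m.mean()


def predict_contacts(sequence, encode=toy_encode, do_apc=True):
    n = len(sequence)
    ref = encode(sequence)                                   # step 1
    raw = np.zeros((n, n))
    for i, base in enumerate(sequence):
        for alt in BASES:
            if alt == base:
                continue
            mutated = sequence[:i] + alt + sequence[i + 1:]  # step 2
            delta = encode(mutated) - ref                    # step 3
            raw[i] += np.linalg.norm(delta, axis=-1)         # step 4
    raw /= len(BASES) - 1  # mean over the three substitutions (assumption)
    if do_apc:
        raw = apc(raw)                                       # step 5
    sym = 0.5 * (raw + raw.T)                                # step 6
    span = sym.max() - sym.min()
    return (sym - sym.min()) / span if span > 0 else np.zeros_like(sym)


contact = predict_contacts("ACGUACGU")
print(contact.shape)  # (8, 8)
```

Row i of the raw matrix records how strongly every position j responds to mutating position i, so symmetrizing averages the "i perturbs j" and "j perturbs i" signals.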

predict(sequence, layer=12, do_apc=True) returns a NumPy array with shape:

sequence_length x sequence_length

batch_predict(sequences, model) runs prediction for multiple sequences and returns a dictionary keyed by sequence.

Dependencies

The package depends on the model libraries used by the wrappers, plus NumPy and PyTorch. In this repository, the relevant dependencies are listed in the top-level requirements.txt:

numpy
torch>=2.0.1
structRFM
transformers[torch]

RNAFMEmbeddings also imports fm, so the RNA-FM package must be installed in the active Python environment.

StructRFMEmbeddings requires a valid StructRFM checkpoint. The default path is an absolute path under /home/dteng, so unless that directory exists on your machine, pass your own path:

from redial.StructRFMEmbeddings import StructRFMEmbeddings

model = StructRFMEmbeddings(pretrained_path="/path/to/structRFM_checkpoint")

Usage

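Run the examples from the repository root (the outer redial directory, which contains the redial package), or make sure that directory is on PYTHONPATH, so that imports of the form from redial.X import ... resolve.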
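
If setting PYTHONPATH is inconvenient, for example in a notebook, the same effect can be had from inside Python by extending sys.path before the first import. The path below is a placeholder, not a real location.

```python
import sys
from pathlib import Path

# Placeholder for the repository root, i.e. the directory that contains the
# redial package folder. Replace it with the real path on your machine.
repo_root = Path("/path/to/redial")

# Prepend it so that "import redial" resolves to this checkout.
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))
```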

Extract RNA-FM Embeddings

from redial.RNAFMEmbeddings import RNAFMEmbeddings

model = RNAFMEmbeddings()
sequence = "ACGUACGU"

embeddings = model.encode(sequence, layer=[12])
last_layer = embeddings[0]

print(last_layer.shape)  # (8, 640)

Extract StructRFM Embeddings

from redial.StructRFMEmbeddings import StructRFMEmbeddings

model = StructRFMEmbeddings(pretrained_path="/home/dteng/rna_db/structRFM_checkpoint")
sequence = "ACGUACGU"

embeddings = model.encode(sequence, layer=[12])
last_layer = embeddings[0]

print(last_layer.shape)  # (8, 768)

Decode Embeddings and Compute Perplexity

from redial.RNAFMEmbeddings import RNAFMEmbeddings

model = RNAFMEmbeddings()
sequence = "ACGUACGU"

embedding = model.encode(sequence, layer=[12])
logits = model.decode(embedding)
perplexity = model.perplexity(sequence, logits)

print(perplexity)

Short-Circuit Transformer Layers

from redial.RNAFMEmbeddings import RNAFMEmbeddings

model = RNAFMEmbeddings()
model.short_circuit(layers=[0, 1])

embeddings = model.encode("ACGUACGU", layer=[12])

Layer indices passed to short_circuit are zero-based and must be less than model.model_params["layer"].
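
The idea behind short-circuiting, including the zero-based index check, can be shown with plain Python callables standing in for transformer layers. This is a sketch only: the helper below returns a patched copy, whereas model.short_circuit modifies the model in place, and the real implementation operates on PyTorch modules rather than functions.

```python
def identity(x):
    """Pass the input through unchanged, like an identity layer."""
    return x


def short_circuit(layer_stack, layers):
    """Return a copy of layer_stack with the given zero-based layers bypassed."""
    n = len(layer_stack)
    bad = [i for i in layers if not 0 <= i < n]
    if bad:
        raise IndexError(f"layer indices {bad} out of range for {n} layers")
    patched = list(layer_stack)
    for i in layers:
        patched[i] = identity
    return patched


def run(layer_stack, x):
    """Apply every layer in order."""
    for layer in layer_stack:
        x = layer(x)
    return x


# Toy "transformer": twelve layers that each add 1 to their input.
stack = [lambda x: x + 1 for _ in range(12)]
patched = short_circuit(stack, layers=[0, 1])

print(run(stack, 0), run(patched, 0))  # prints: 12 10
```

With twelve layers the valid indices are 0 through 11, so passing 12 raises an error in this sketch.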

Predict a Contact Map

from redial.ContactPredictor import ContactPredictor
from redial.RNAFMEmbeddings import RNAFMEmbeddings

model = RNAFMEmbeddings()
predictor = ContactPredictor(model=model)

contact = predictor.predict("ACGUACGU", layer=12, do_apc=True)

print(contact.shape)  # (8, 8)
print(contact.min(), contact.max())  # normalized to [0, 1]

Batch Prediction

from redial.ContactPredictor import batch_predict
from redial.RNAFMEmbeddings import RNAFMEmbeddings

model = RNAFMEmbeddings()
results = batch_predict(["ACGUACGU", "GGCAUU"], model=model)

Notes

  • ContactPredictor.predict is computationally expensive. For a sequence of length L, it performs one original encoding plus up to 3 * L mutated-sequence encodings.
  • RNA-FM supports sequences up to length 1024 in this wrapper.
  • StructRFM supports sequences up to length 514 in this wrapper.
  • ContactPredictor.py has a __main__ debugging block that expects ../2_comparison/database.json and writes .npy files under rnafm/ and structrfm/ relative to the current working directory.
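
To make the cost in the first note concrete, the sketch below enumerates the point substitutions for a sequence and counts the resulting encoder calls. Skipping characters outside A, C, G, U is an assumption made here so that the count never exceeds 3 * L; how the package treats such characters is not stated.

```python
BASES = "ACGU"


def single_mutants(sequence):
    """Yield (position, new_base, mutated_sequence) for each point substitution."""
    for i, base in enumerate(sequence):
        if base not in BASES:
            continue  # assumption: non-ACGU characters are left unmutated
        for alt in BASES:
            if alt != base:
                yield i, alt, sequence[:i] + alt + sequence[i + 1:]


sequence = "ACGUACGU"
n_mutants = sum(1 for _ in single_mutants(sequence))

# Encoder calls: 1 original + 3 * 8 mutants.
print(n_mutants + 1)  # prints: 25
```

At the RNA-FM limit of 1024 nucleotides the same count gives 3073 encoder calls for a single contact map, which is worth keeping in mind when batching many sequences.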
