
GeneSetCLIP: Contrastive Pretraining for Gene Set–Text Alignment

A CLIP-style contrastive model that aligns biological text descriptions with gene-set representations, trained on MSigDB v2024.1 (human + mouse).

Given a text query like "type I interferon signaling", the model retrieves the corresponding gene set, and vice versa.

Architecture

TEXT SIDE                               GENE SET SIDE
─────────────────────                   ──────────────────────────
"Genes up-regulated in                  {STAT1, IRF7, ISG15,
 response to IFN-α..."                   OAS1, MX1, IFIT1, ...}
        │                                        │
        ▼                                        ▼
 BioLORD-2023 (frozen)                  GSFM (fine-tuned, lr/10)
 [768-dim]                              [256-dim]
        │                                        │
        ▼                                        ▼
 text_proj (trainable)                  gene_proj (trainable)
 768 → 512 → 256                        256 → 256 → 256
        │                                        │
        ▼                                        ▼
   z_text [256]                         z_gene [256]
        │                                        │
        └────── L2-normalize ────────────────────┘
                      │
                      ▼
              InfoNCE loss (τ learnable)
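The loss at the bottom of the diagram is the symmetric InfoNCE (NT-Xent) objective from the training recipe. A minimal PyTorch sketch, assuming L2-normalized embeddings and the learnable, clamped temperature used during training (the function name is illustrative, not part of the released code):

```python
import torch
import torch.nn.functional as F

def symmetric_info_nce(z_text, z_gene, log_temperature):
    """Symmetric InfoNCE over a batch of L2-normalized pairs.
    Row i of z_text matches row i of z_gene; every other pairing
    in the batch serves as an in-batch negative."""
    tau = log_temperature.exp().clamp(0.01, 1.0)   # learnable, clamped
    logits = (z_text @ z_gene.T) / tau             # [B, B] cosine / tau
    targets = torch.arange(z_text.size(0))         # positives on diagonal
    loss_t2g = F.cross_entropy(logits, targets)    # text -> gene set
    loss_g2t = F.cross_entropy(logits.T, targets)  # gene set -> text
    return 0.5 * (loss_t2g + loss_g2t)

# toy batch: 4 pairs of 256-dim embeddings
torch.manual_seed(0)
z_t = F.normalize(torch.randn(4, 256), dim=-1)
z_g = F.normalize(torch.randn(4, 256), dim=-1)
loss = symmetric_info_nce(z_t, z_g, torch.log(torch.tensor(0.07)))
```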

Components

| Component | Model | Dim | Training |
|---|---|---|---|
| Gene encoder | GSFM (MLP autoencoder, set model) | 256 | Fine-tuned at 1/10 LR |
| Text encoder | BioLORD-2023 (MPNet-base) | 768 | Frozen |
| Gene projection | MLP: 256 → 256 → 256 + LayerNorm | 256 | Trained |
| Text projection | MLP: 768 → 512 → 256 + LayerNorm | 256 | Trained |

Why these encoders?

  • GSFM: Purpose-built gene-set encoder from Ma'ayan Lab. Takes variable-length gene sets as input (multi-hot encoding → MLP), producing permutation-invariant 256-dim embeddings. Pretrained on Rummagene (gene sets from PubMed tables).
  • BioLORD-2023: Ontology-grounded biomedical sentence embeddings. Trained on UMLS concept name-synonym pairs + LLM-generated definitions, which are structurally identical to MSigDB gene set descriptions (name + definition anchored in GO/KEGG/Reactome).
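The multi-hot step can be illustrated in a few lines. The six-gene vocabulary below is purely hypothetical (GSFM's real vocabulary ships in vocab.json), but it shows why the resulting representation is permutation-invariant: gene order never reaches the MLP.

```python
import torch

# Hypothetical 6-gene vocabulary standing in for GSFM's real vocab
vocab = {"STAT1": 0, "IRF7": 1, "ISG15": 2, "OAS1": 3, "MX1": 4, "IFIT1": 5}

def multi_hot(genes, vocab):
    """Order-insensitive multi-hot vector: each known gene flips one
    position to 1, so any permutation of the set maps to the same vector."""
    v = torch.zeros(len(vocab))
    for g in genes:
        if g in vocab:          # out-of-vocabulary genes are dropped
            v[vocab[g]] = 1.0
    return v

a = multi_hot(["STAT1", "MX1", "IRF7"], vocab)
b = multi_hot(["MX1", "IRF7", "STAT1"], vocab)  # same set, shuffled
```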

Training Data

MSigDB v2024.1: 50,896 gene set–text pairs from the Molecular Signatures Database.

| Split | Collections | Pairs | Purpose |
|---|---|---|---|
| Train | C2, C5, C8, C1, M2, M5, M8, M1 | 38,622 | Curated, GO, cell type signatures |
| Val | C3, C4, M3 | 6,766 | Regulatory targets, computational |
| Test | H, C6, C7, MH | 5,508 | Hallmarks, oncogenic, immunologic |

Each pair consists of:

  • Text: [Collection: H] [Species: human]\nHALLMARK APOPTOSIS\nGenes mediating programmed cell death by activation of caspases.
  • Genes: ["CASP3", "CASP6", "TP53", "BAX", ...]
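A hypothetical helper (format_pair_text is not part of the released code) assembling the text side in exactly this layout:

```python
def format_pair_text(collection, species, name, description):
    """Assemble the text side of a training pair:
    bracketed collection/species tags, then name, then description."""
    return f"[Collection: {collection}] [Species: {species}]\n{name}\n{description}"

text = format_pair_text(
    "H", "human", "HALLMARK APOPTOSIS",
    "Genes mediating programmed cell death by activation of caspases.")
```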

Data augmentation: 20% gene dropout (randomly remove genes each epoch).
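The dropout augmentation is simple enough to sketch directly; a minimal version, assuming the set is never allowed to become empty (that guard is my assumption, not documented behavior):

```python
import random

def gene_dropout(genes, p=0.2, rng=random):
    """Drop each gene independently with probability p, re-sampled every
    epoch. Keeps at least one gene so the encoder never sees an empty set
    (assumed safeguard, not confirmed by the training code)."""
    kept = [g for g in genes if rng.random() >= p]
    return kept if kept else [rng.choice(genes)]

random.seed(0)
augmented = gene_dropout(["CASP3", "CASP6", "TP53", "BAX", "STAT1"])
```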

Training Recipe

Based on ProtST (ICML 2023), adapted for gene sets:

| Parameter | Value |
|---|---|
| Loss | Symmetric InfoNCE (NT-Xent) |
| Temperature | 0.07 (learnable, clamped to [0.01, 1.0]) |
| Batch size | 256 |
| LR (projections) | 1e-4 |
| LR (gene encoder) | 1e-5 (10× lower) |
| LR (text encoder) | 0 (frozen) |
| Optimizer | AdamW (weight_decay=0.01) |
| Schedule | 500-step warmup → cosine decay |
| Epochs | 50 (early stopping, patience=10) |
| Gene dropout | 20% |
| Max gene set size | 512 |
| Hardware | 1× T4 GPU (16 GB) |
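The warmup-then-cosine schedule maps step number to learning rate. A sketch of that function; the 20,000 total-step count is an assumed illustrative value, not taken from the actual run:

```python
import math

def lr_at(step, base_lr=1e-4, warmup=500, total=20000):
    """Linear warmup for `warmup` steps, then cosine decay to zero
    over the remaining steps. `total` is an assumed value here."""
    if step < warmup:
        return base_lr * step / warmup          # linear ramp 0 -> base_lr
    progress = (step - warmup) / max(1, total - warmup)
    return base_lr * 0.5 * (1 + math.cos(math.pi * progress))
```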

Quick Start

Installation

pip install torch sentence-transformers huggingface_hub safetensors lightning
GIT_LFS_SKIP_SMUDGE=1 pip install "git+https://huggingface.co/maayanlab/gsfm"

Inference

import torch
import torch.nn as nn
import torch.nn.functional as F
from gsfm import GSFM, Vocab
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download

# Load gene encoder + vocab
gene_encoder = GSFM.from_pretrained("maayanlab/gsfm-rummagene")
vocab = Vocab.from_pretrained("maayanlab/gsfm-rummagene")
gene_encoder.eval()

# Load text encoder
text_encoder = SentenceTransformer("FremyCompany/BioLORD-2023")

# Load projection heads
clip_path = hf_hub_download("AliSaadatV/GeneSetCLIP", "clip_model.pt")

class ProjectionHead(nn.Module):
    def __init__(self, d_in, d_h, d_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_h), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(d_h, d_out), nn.LayerNorm(d_out))
    def forward(self, x): return self.net(x)

class GeneSetCLIP(nn.Module):
    def __init__(self):
        super().__init__()
        self.log_temperature = nn.Parameter(torch.zeros(1))
        self.text_proj = ProjectionHead(768, 512, 256)
        self.gene_proj = ProjectionHead(256, 256, 256)

clip_model = GeneSetCLIP()
clip_model.load_state_dict(torch.load(clip_path, map_location="cpu", weights_only=True))
clip_model.eval()

# --- Encode a gene set ---
genes = ["STAT1", "IRF7", "ISG15", "OAS1", "MX1", "IFIT1"]
gene_ids = torch.tensor([vocab(genes)])
with torch.no_grad():
    gene_emb = gene_encoder.encode(gene_ids)
    z_gene = F.normalize(clip_model.gene_proj(gene_emb), dim=-1)

# --- Encode text queries ---
queries = [
    "Interferon alpha response genes",
    "Apoptosis signaling",
    "Cell cycle regulation",
]
text_embs = text_encoder.encode(queries, convert_to_tensor=True)
with torch.no_grad():
    z_text = F.normalize(clip_model.text_proj(text_embs), dim=-1)

# --- Compute similarities ---
sims = (z_gene @ z_text.T).squeeze()
for q, s in zip(queries, sims):
    print(f"  {s.item():.3f}  {q}")
# Expected: highest similarity for "Interferon alpha response genes"

Training from Scratch

1. Process MSigDB data

python data_processing.py

This downloads all MSigDB GMT files and scrapes descriptions.

2. Train

# Self-contained (downloads data from Hub automatically)
python train_job.py

# Or with local data
python train.py

3. On HF Jobs (GPU)

from huggingface_hub import HfApi
# Submit as HF Job with GPU
# See train_job.py for the self-contained script

Downstream Applications

  1. Zero-shot gene set annotation: Embed a gene list from an experiment → find nearest text descriptions
  2. Cross-modal search: Text query → gene sets, or gene list → pathway descriptions
  3. Gene set similarity: Compare gene sets via embedding cosine similarity (captures functional similarity beyond gene overlap)
  4. Cell type annotation: Embed cell marker gene sets → match to cell type text descriptions
  5. Biological RAG: Use MSigDB embeddings as retrieval corpus for LLM-based reasoning
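Applications 1 and 2 both reduce to ranking a corpus by cosine similarity in the shared 256-dim space. A sketch using random stand-in embeddings and made-up labels (in practice the corpus would hold projected MSigDB embeddings):

```python
import torch
import torch.nn.functional as F

def annotate(z_query, z_corpus, labels, k=3):
    """Rank corpus embeddings against one query embedding. All inputs
    are assumed L2-normalized, so dot product equals cosine similarity."""
    sims = z_corpus @ z_query           # [N] cosine similarities
    top = torch.topk(sims, k)           # indices sorted by similarity
    return [(labels[i], sims[i].item()) for i in top.indices]

torch.manual_seed(0)
z_q = F.normalize(torch.randn(256), dim=-1)             # stand-in gene-set emb
z_c = F.normalize(torch.randn(5, 256), dim=-1)          # stand-in text corpus
hits = annotate(z_q, z_c, ["apoptosis", "IFN-alpha response", "cell cycle",
                           "glycolysis", "DNA repair"], k=2)
```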

Key References

  • ProtST (ICML 2023) – Protein-text contrastive alignment
  • MoleculeSTM (Nature MI 2024) – Molecule-text alignment
  • LangCell – Cell-text contrastive with MSigDB pathways
  • BioLORD-2023 (JAMIA 2024) – Biomedical sentence embeddings
  • Set Transformer – Permutation-invariant set encoding

Files

| File | Description |
|---|---|
| clip_model.pt | Trained projection heads (text + gene) |
| gene_encoder.pt | Fine-tuned GSFM gene encoder |
| config.json | Training configuration |
| vocab.json | Gene symbol → token ID mapping |
| test_results.json | Evaluation metrics on test set |
| train_job.py | Self-contained training script (for HF Jobs) |
| train.py | Modular training script |
| data_processing.py | MSigDB data download + processing |

License
