LangCell: Language-Cell Pre-training for Cell Identity Understanding
Paper • 2405.06708 • Published
A CLIP-style contrastive model that aligns biological text descriptions with gene-set representations, trained on MSigDB v2024.1 (human + mouse).
Given a text query like "type I interferon signaling", the model retrieves the corresponding gene set, and vice versa.
```
      TEXT SIDE                        GENE SET SIDE
─────────────────────           ──────────────────────────
"Genes up-regulated in           {STAT1, IRF7, ISG15,
 response to IFN-α..."            OAS1, MX1, IFIT1, ...}
          │                                │
          ▼                                ▼
BioLORD-2023 (frozen)           GSFM (fine-tuned, lr/10)
      [768-dim]                        [256-dim]
          │                                │
          ▼                                ▼
text_proj (trainable)            gene_proj (trainable)
   768 → 512 → 256                 256 → 256 → 256
          │                                │
          ▼                                ▼
    z_text [256]                     z_gene [256]
          │                                │
          └─────── L2-normalize ───────────┘
                        │
                        ▼
          InfoNCE loss (τ learnable)
```
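The contrastive head at the bottom of the diagram can be sketched as follows. This is an illustrative stand-alone implementation, not the repository's training code; the clamp range and symmetric form come from the training table below:

```python
import torch
import torch.nn.functional as F

def symmetric_infonce(z_text, z_gene, log_temperature):
    """Symmetric InfoNCE (NT-Xent) over a batch of paired embeddings.

    z_text, z_gene: [B, 256] L2-normalized projections.
    log_temperature: learnable scalar; tau is clamped to [0.01, 1.0].
    """
    tau = log_temperature.exp().clamp(0.01, 1.0)
    logits = (z_text @ z_gene.T) / tau            # [B, B] similarity matrix
    targets = torch.arange(z_text.size(0))        # positives on the diagonal
    loss_t2g = F.cross_entropy(logits, targets)   # text -> gene set
    loss_g2t = F.cross_entropy(logits.T, targets) # gene set -> text
    return 0.5 * (loss_t2g + loss_g2t)

# Toy usage with random embeddings
z_t = F.normalize(torch.randn(8, 256), dim=-1)
z_g = F.normalize(torch.randn(8, 256), dim=-1)
loss = symmetric_infonce(z_t, z_g, torch.zeros(1))
```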
| Component | Model | Dim | Training |
|---|---|---|---|
| Gene encoder | GSFM (MLP autoencoder, Set model) | 256 | Fine-tuned at 1/10 LR |
| Text encoder | BioLORD-2023 (MPNet-base) | 768 | Frozen |
| Gene projection | MLP: 256 → 256 → 256 + LayerNorm | 256 | Trained |
| Text projection | MLP: 768 → 512 → 256 + LayerNorm | 256 | Trained |
MSigDB v2024.1 provides 50,896 gene set-text pairs from the Molecular Signatures Database.
| Split | Collections | Pairs | Purpose |
|---|---|---|---|
| Train | C2, C5, C8, C1, M2, M5, M8, M1 | 38,622 | Curated, GO, cell type signatures |
| Val | C3, C4, M3 | 6,766 | Regulatory targets, computational |
| Test | H, C6, C7, MH | 5,508 | Hallmarks, oncogenic, immunologic |
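MSigDB distributes these collections as GMT files: one gene set per line, tab-separated, with the set name first, a description field second, and the member genes after that. A minimal parser might look like this (illustrative only, not the repo's `data_processing.py`):

```python
def parse_gmt(text):
    """Parse GMT content: name<TAB>description<TAB>gene1<TAB>gene2..."""
    sets = {}
    for line in text.strip().splitlines():
        name, _description, *genes = line.rstrip("\n").split("\t")
        sets[name] = genes
    return sets

# Toy GMT line (the description field is often a URL in MSigDB files)
example = "HALLMARK_APOPTOSIS\thttps://example.org\tCASP3\tCASP6\tTP53\tBAX"
parsed = parse_gmt(example)
```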
Each pair consists of a text description:

```
[Collection: H] [Species: human]
HALLMARK APOPTOSIS
Genes mediating programmed cell death by activation of caspases.
```

paired with its gene list:

```
["CASP3", "CASP6", "TP53", "BAX", ...]
```

Data augmentation: 20% gene dropout (randomly remove genes each epoch).
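The gene-dropout augmentation can be sketched like this (an illustrative implementation; the 20% rate matches the training table below, and the keep-at-least-one fallback is an assumption):

```python
import random

def gene_dropout(genes, p=0.2, rng=random):
    """Drop each gene independently with probability p, keeping at least one.

    Applied fresh each epoch, so the model sees varied subsets of the same
    gene set paired with the same text description.
    """
    kept = [g for g in genes if rng.random() >= p]
    return kept if kept else [rng.choice(genes)]  # assumed fallback

random.seed(0)
augmented = gene_dropout(["CASP3", "CASP6", "TP53", "BAX", "FAS", "BID"])
```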
The training recipe is based on ProtST (ICML 2023), adapted for gene sets:
| Parameter | Value |
|---|---|
| Loss | Symmetric InfoNCE (NT-Xent) |
| Temperature | 0.07 (learnable, clamped [0.01, 1.0]) |
| Batch size | 256 |
| LR (projections) | 1e-4 |
| LR (gene encoder) | 1e-5 (10x lower) |
| LR (text encoder) | 0 (frozen) |
| Optimizer | AdamW (weight_decay=0.01) |
| Schedule | 500-step warmup → cosine decay |
| Epochs | 50 (early stopping, patience=10) |
| Gene dropout | 20% |
| Max gene set size | 512 |
| Hardware | 1× T4 GPU (16GB) |
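The differential learning rates and schedule from the table can be sketched as follows. The exact parameter grouping and total step count are assumptions; the values (1e-4, 1/10 ratio, weight decay 0.01, 500-step warmup) come from the table above:

```python
import math
import torch

def build_optimizer(clip_model, gene_encoder, base_lr=1e-4):
    """AdamW with per-group LRs: projections at base_lr, gene encoder at base_lr/10.

    The frozen text encoder is simply excluded from the optimizer.
    """
    groups = [
        {"params": list(clip_model.parameters()), "lr": base_lr},
        {"params": list(gene_encoder.parameters()), "lr": base_lr / 10},
    ]
    return torch.optim.AdamW(groups, weight_decay=0.01)

def lr_scale(step, warmup=500, total=10_000):
    """Linear warmup for `warmup` steps, then cosine decay to zero.

    Intended for use with torch.optim.lr_scheduler.LambdaLR.
    """
    if step < warmup:
        return step / warmup
    progress = (step - warmup) / max(1, total - warmup)
    return 0.5 * (1 + math.cos(math.pi * progress))

# Toy check with stand-in modules for the two trainable components
opt = build_optimizer(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))
```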
```bash
pip install torch sentence-transformers huggingface_hub safetensors lightning
GIT_LFS_SKIP_SMUDGE=1 pip install "git+https://huggingface.co/maayanlab/gsfm"
```
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from gsfm import GSFM, Vocab
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download

# Load gene encoder + vocab
gene_encoder = GSFM.from_pretrained("maayanlab/gsfm-rummagene")
vocab = Vocab.from_pretrained("maayanlab/gsfm-rummagene")
gene_encoder.eval()

# Load text encoder
text_encoder = SentenceTransformer("FremyCompany/BioLORD-2023")

# Load projection heads
clip_path = hf_hub_download("AliSaadatV/GeneSetCLIP", "clip_model.pt")

class ProjectionHead(nn.Module):
    def __init__(self, d_in, d_h, d_out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_h), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(d_h, d_out), nn.LayerNorm(d_out))

    def forward(self, x):
        return self.net(x)

class GeneSetCLIP(nn.Module):
    def __init__(self):
        super().__init__()
        self.log_temperature = nn.Parameter(torch.zeros(1))
        self.text_proj = ProjectionHead(768, 512, 256)
        self.gene_proj = ProjectionHead(256, 256, 256)

clip_model = GeneSetCLIP()
clip_model.load_state_dict(torch.load(clip_path, map_location="cpu", weights_only=True))
clip_model.eval()

# --- Encode a gene set ---
genes = ["STAT1", "IRF7", "ISG15", "OAS1", "MX1", "IFIT1"]
gene_ids = torch.tensor([vocab(genes)])
with torch.no_grad():
    gene_emb = gene_encoder.encode(gene_ids)
    z_gene = F.normalize(clip_model.gene_proj(gene_emb), dim=-1)

# --- Encode text queries ---
queries = [
    "Interferon alpha response genes",
    "Apoptosis signaling",
    "Cell cycle regulation",
]
text_embs = text_encoder.encode(queries, convert_to_tensor=True)
with torch.no_grad():
    z_text = F.normalize(clip_model.text_proj(text_embs), dim=-1)

# --- Compute similarities ---
sims = (z_gene @ z_text.T).squeeze()
for q, s in zip(queries, sims):
    print(f"  {s.item():.3f}  {q}")
# Expected: highest similarity for "Interferon alpha response genes"
```
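To turn raw cosine similarities into a retrieval distribution, you can rescale by the learned temperature before a softmax. A small stand-alone helper with toy similarity values (the clamp range mirrors training; the helper name is ours):

```python
import torch

def rank_queries(sims, queries, log_temperature):
    """Rank text queries for one gene set.

    Scales cosine similarities by the learned temperature (clamped as
    during training) and softmaxes over the candidate queries.
    """
    tau = log_temperature.exp().clamp(0.01, 1.0)
    probs = torch.softmax(sims / tau, dim=-1)
    order = probs.argsort(descending=True)
    return [(queries[int(i)], probs[int(i)].item()) for i in order]

# Toy similarities: the first query is the best match
ranked = rank_queries(
    torch.tensor([0.62, 0.18, 0.05]),
    ["Interferon alpha response genes", "Apoptosis signaling", "Cell cycle regulation"],
    torch.zeros(()),
)
```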
```bash
python data_processing.py
```

This downloads all MSigDB GMT files and scrapes descriptions.

```bash
# Self-contained (downloads data from the Hub automatically)
python train_job.py

# Or with local data
python train.py
```
```python
from huggingface_hub import HfApi

# Submit as an HF Job with GPU
# See train_job.py for the self-contained script
```
| File | Description |
|---|---|
| `clip_model.pt` | Trained projection heads (text + gene) |
| `gene_encoder.pt` | Fine-tuned GSFM gene encoder |
| `config.json` | Training configuration |
| `vocab.json` | Gene symbol → token ID mapping |
| `test_results.json` | Evaluation metrics on the test set |
| `train_job.py` | Self-contained training script (for HF Jobs) |
| `train.py` | Modular training script |
| `data_processing.py` | MSigDB data download + processing |