| from __future__ import annotations |
|
|
| import csv, re |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Dict, Optional, Tuple, Any, List |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import joblib |
| import xgboost as xgb |
|
|
| from transformers import EsmModel, EsmTokenizer, AutoModelForMaskedLM |
| from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer |
| from lightning.pytorch import seed_everything |
| seed_everything(1986) |
|
|
| |
| # Best-model manifest: schema and parsing |
| |
| @dataclass(frozen=True) |
| class BestRow: |
| property_key: str |
| best_wt: Optional[str] |
| best_smiles: Optional[str] |
| task_type: str |
| thr_wt: Optional[float] |
| thr_smiles: Optional[float] |
|
|
|
|
| def _clean(s: str) -> str: |
| return (s or "").strip() |
|
|
| def _none_if_dash(s: str) -> Optional[str]: |
| s = _clean(s) |
| if s in {"", "-", "—", "NA", "N/A"}: |
| return None |
| return s |
|
|
| def _float_or_none(s: str) -> Optional[float]: |
| s = _clean(s) |
| if s in {"", "-", "—", "NA", "N/A"}: |
| return None |
| return float(s) |
|
|
| def normalize_property_key(name: str) -> str: |
| n = name.strip().lower() |
| n = re.sub(r"\s*\(.*?\)\s*", "", n) |
| n = n.replace("-", "_").replace(" ", "_") |
|
|
| if "permeability" in n and "pampa" not in n and "caco" not in n: |
| return "permeability_penetrance" |
| if n == "binding_affinity": |
| return "binding_affinity" |
| if n in {"halflife", "half_life"}: |
| return "halflife" |
| if n == "non_fouling": |
| return "nf" |
| return n |
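| # Illustrative mappings (the raw names below are hypothetical manifest entries): |
| #   "Half-Life"        -> "halflife" |
| #   "Non-fouling"      -> "nf" |
| #   "Binding Affinity" -> "binding_affinity" |
| #   "Permeability"     -> "permeability_penetrance" |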
|
|
|
|
| def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]: |
|     """ |
|     Parse the manifest CSV mapping each property to its best WT / SMILES model. |
| |
|     Expected columns, followed by an example row: |
|         Properties, Best_Model_WT, Best_Model_SMILES, Type, Threshold_WT, Threshold_SMILES |
|         Hemolysis, SVM, XGB, Classifier, 0.2801, 0.2223 |
|     """ |
| p = Path(path) |
| out: Dict[str, BestRow] = {} |
|
|
| with p.open("r", newline="") as f: |
| reader = csv.reader(f) |
| header = None |
| for raw in reader: |
| if not raw or all(_clean(x) == "" for x in raw): |
| continue |
| while raw and _clean(raw[-1]) == "": |
| raw = raw[:-1] |
|
|
| if header is None: |
| header = [h.strip() for h in raw] |
| continue |
|
|
| if len(raw) < len(header): |
| raw = raw + [""] * (len(header) - len(raw)) |
| rec = dict(zip(header, raw)) |
|
|
| prop_raw = _clean(rec.get("Properties", "")) |
| if not prop_raw: |
| continue |
| prop_key = normalize_property_key(prop_raw) |
|
|
| row = BestRow( |
| property_key=prop_key, |
| best_wt=_none_if_dash(rec.get("Best_Model_WT", "")), |
| best_smiles=_none_if_dash(rec.get("Best_Model_SMILES", "")), |
| task_type=_clean(rec.get("Type", "Classifier")), |
| thr_wt=_float_or_none(rec.get("Threshold_WT", "")), |
| thr_smiles=_float_or_none(rec.get("Threshold_SMILES", "")), |
| ) |
| out[prop_key] = row |
|
|
| return out |
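| # A row like the docstring example above parses to (sketch; values hypothetical): |
| #   {"hemolysis": BestRow(property_key="hemolysis", best_wt="SVM", best_smiles="XGB", |
| #                         task_type="Classifier", thr_wt=0.2801, thr_smiles=0.2223)} |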
|
|
|
|
| MODEL_ALIAS = { |
| "SVM": "svm_gpu", |
| "SVR": "svr", |
| "ENET": "enet_gpu", |
| "CNN": "cnn", |
| "MLP": "mlp", |
| "TRANSFORMER": "transformer", |
| "XGB": "xgb", |
| "XGB_REG": "xgb_reg", |
| "POOLED": "pooled", |
| "UNPOOLED": "unpooled", |
| "TRANSFORMER_WT_LOG": "transformer_wt_log", |
| } |
| def canon_model(label: Optional[str]) -> Optional[str]: |
| if label is None: |
| return None |
| k = label.strip().upper() |
| return MODEL_ALIAS.get(k, label.strip().lower()) |
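| # Illustrative lookups ("NewModel" is a hypothetical, unlisted label): |
| #   canon_model("SVM")      -> "svm_gpu" |
| #   canon_model("Pooled")   -> "pooled" |
| #   canon_model("NewModel") -> "newmodel"   (unknown labels fall back to lower-case) |
| #   canon_model(None)       -> None |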
|
|
|
|
| |
| # Locating and loading trained model artifacts |
| |
| def find_best_artifact(model_dir: Path) -> Path: |
| for pat in ["best_model.json", "best_model.pt", "best_model*.joblib"]: |
| hits = sorted(model_dir.glob(pat)) |
| if hits: |
| return hits[0] |
| raise FileNotFoundError(f"No best_model artifact found in {model_dir}") |
|
|
| def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path]: |
| art = find_best_artifact(model_dir) |
|
|
| if art.suffix == ".json": |
| booster = xgb.Booster() |
| |
| booster.load_model(str(art)) |
| return "xgb", booster, art |
|
|
| if art.suffix == ".joblib": |
| obj = joblib.load(art) |
| return "joblib", obj, art |
|
|
| if art.suffix == ".pt": |
| ckpt = torch.load(art, map_location=device, weights_only=False) |
| return "torch_ckpt", ckpt, art |
|
|
| raise ValueError(f"Unknown artifact type: {art}") |
|
|
|
|
| |
| # Neural-network prediction heads over token embeddings |
| |
| class MaskedMeanPool(nn.Module): |
| def forward(self, X, M): |
| Mf = M.unsqueeze(-1).float() |
| denom = Mf.sum(dim=1).clamp(min=1.0) |
| return (X * Mf).sum(dim=1) / denom |
|
|
| class MLPHead(nn.Module): |
| def __init__(self, in_dim, hidden=512, dropout=0.1): |
| super().__init__() |
| self.pool = MaskedMeanPool() |
| self.net = nn.Sequential( |
| nn.Linear(in_dim, hidden), |
| nn.GELU(), |
| nn.Dropout(dropout), |
| nn.Linear(hidden, 1), |
| ) |
| def forward(self, X, M): |
| z = self.pool(X, M) |
| return self.net(z).squeeze(-1) |
|
|
| class CNNHead(nn.Module): |
| def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1): |
| super().__init__() |
| blocks = [] |
| ch = in_ch |
| for _ in range(layers): |
| blocks += [nn.Conv1d(ch, c, kernel_size=k, padding=k//2), |
| nn.GELU(), |
| nn.Dropout(dropout)] |
| ch = c |
| self.conv = nn.Sequential(*blocks) |
| self.head = nn.Linear(c, 1) |
|
|
| def forward(self, X, M): |
| Xc = X.transpose(1, 2) |
| Y = self.conv(Xc).transpose(1, 2) |
| Mf = M.unsqueeze(-1).float() |
| denom = Mf.sum(dim=1).clamp(min=1.0) |
| pooled = (Y * Mf).sum(dim=1) / denom |
| return self.head(pooled).squeeze(-1) |
|
|
| class TransformerHead(nn.Module): |
| def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1): |
| super().__init__() |
| self.proj = nn.Linear(in_dim, d_model) |
| enc_layer = nn.TransformerEncoderLayer( |
| d_model=d_model, nhead=nhead, dim_feedforward=ff, |
| dropout=dropout, batch_first=True, activation="gelu" |
| ) |
| self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers) |
| self.head = nn.Linear(d_model, 1) |
|
|
| def forward(self, X, M): |
| pad_mask = ~M |
| Z = self.proj(X) |
| Z = self.enc(Z, src_key_padding_mask=pad_mask) |
| Mf = M.unsqueeze(-1).float() |
| denom = Mf.sum(dim=1).clamp(min=1.0) |
| pooled = (Z * Mf).sum(dim=1) / denom |
| return self.head(pooled).squeeze(-1) |
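| # Shape sketch shared by MLPHead / CNNHead / TransformerHead (sizes illustrative): |
| #   X: (B, L, H) token embeddings, M: (B, L) bool mask  ->  output: (B,) raw score |
| # With ESM2-650M features H = 1280; inference below runs with B = 1. |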
|
|
| def _infer_in_dim_from_sd(sd: dict, model_name: str) -> int: |
| if model_name == "mlp": |
| return int(sd["net.0.weight"].shape[1]) |
| if model_name == "cnn": |
| return int(sd["conv.0.weight"].shape[1]) |
| if model_name == "transformer": |
| return int(sd["proj.weight"].shape[1]) |
| raise ValueError(model_name) |
|
|
| def _infer_num_layers_from_sd(sd: dict, prefix: str = "enc.layers.") -> int: |
|     # Collect distinct layer indices from keys like "enc.layers.<i>.<param>". |
| idxs = set() |
| for k in sd.keys(): |
| if k.startswith(prefix): |
| rest = k[len(prefix):] |
| m = re.match(r"(\d+)\.", rest) |
| if m: |
| idxs.add(int(m.group(1))) |
| return (max(idxs) + 1) if idxs else 1 |
|
|
| def _infer_transformer_arch_from_sd(sd: dict) -> Tuple[int, int, int]: |
| """ |
| Returns (d_model, layers, ff) inferred from weights. |
| - d_model from proj.weight (shape: [d_model, in_dim]) |
| - layers from count of enc.layers.* |
| - ff from enc.layers.0.linear1.weight (shape: [ff, d_model]) |
| """ |
| if "proj.weight" not in sd: |
| raise KeyError("Missing proj.weight in state_dict; cannot infer transformer d_model.") |
| d_model = int(sd["proj.weight"].shape[0]) |
| layers = _infer_num_layers_from_sd(sd, prefix="enc.layers.") |
| if "enc.layers.0.linear1.weight" in sd: |
| ff = int(sd["enc.layers.0.linear1.weight"].shape[0]) |
| else: |
| ff = 4 * d_model |
| return d_model, layers, ff |
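| # Example inference from hypothetical weight shapes: |
| #   sd["proj.weight"].shape == (256, 1280)                 -> d_model = 256 |
| #   keys enc.layers.0.* and enc.layers.1.* present         -> layers  = 2 |
| #   sd["enc.layers.0.linear1.weight"].shape == (512, 256)  -> ff      = 512 |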
|
|
| def _pick_nhead(d_model: int) -> int: |
|     # Prefer the largest head count that evenly divides d_model. |
| for h in (8, 6, 4, 3, 2, 1): |
| if d_model % h == 0: |
| return h |
| return 1 |
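| # Illustrative picks: |
| #   _pick_nhead(256) -> 8, _pick_nhead(250) -> 2, _pick_nhead(7) -> 1 |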
|
|
| def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module: |
| params = ckpt["best_params"] |
| sd = ckpt["state_dict"] |
| in_dim = int(ckpt.get("in_dim", _infer_in_dim_from_sd(sd, model_name))) |
| dropout = float(params.get("dropout", 0.1)) |
|
|
| if model_name == "mlp": |
| model = MLPHead(in_dim=in_dim, hidden=int(params["hidden"]), dropout=dropout) |
| elif model_name == "cnn": |
| model = CNNHead(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]), |
| layers=int(params["layers"]), dropout=dropout) |
| elif model_name == "transformer": |
|         # Use explicit hyperparameters when present; otherwise infer the architecture from the weights. |
| d_model = params.get("d_model") or params.get("hidden") or params.get("hidden_dim") |
|
|
| if d_model is None: |
| d_model_i, layers_i, ff_i = _infer_transformer_arch_from_sd(sd) |
| nhead_i = _pick_nhead(d_model_i) |
| model = TransformerHead( |
| in_dim=in_dim, |
| d_model=int(d_model_i), |
| nhead=int(params.get("nhead", nhead_i)), |
| layers=int(params.get("layers", layers_i)), |
| ff=int(params.get("ff", ff_i)), |
| dropout=float(params.get("dropout", dropout)), |
| ) |
| else: |
| d_model = int(d_model) |
| model = TransformerHead( |
| in_dim=in_dim, |
| d_model=d_model, |
| nhead=int(params.get("nhead", _pick_nhead(d_model))), |
| layers=int(params.get("layers", 2)), |
| ff=int(params.get("ff", 4 * d_model)), |
| dropout=dropout |
| ) |
| else: |
| raise ValueError(f"Unknown NN model_name={model_name}") |
|
|
| model.load_state_dict(sd) |
| model.to(device) |
| model.eval() |
| return model |
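| # Expected checkpoint layout, based on the keys read above (values hypothetical): |
| #   { |
| #       "best_params": {"hidden": 512, "dropout": 0.1, ...},   # tuned hyperparameters |
| #       "state_dict":  <head state_dict>, |
| #       "in_dim":      1280,   # optional; inferred from the weights when absent |
| #   } |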
|
|
|
|
| |
| # Binding-affinity cross-attention models |
| |
| def affinity_to_class(y: float) -> int: |
|     # 0 = High (>= 9), 1 = Moderate (7-9), 2 = Low (< 7); see `names` in predict_binding_affinity. |
| if y >= 9.0: return 0 |
| if y < 7.0: return 2 |
| return 1 |
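| # Illustrative mapping (values hypothetical): |
| #   affinity_to_class(9.3) -> 0 (High), affinity_to_class(8.0) -> 1 (Moderate), |
| #   affinity_to_class(6.5) -> 2 (Low) |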
|
|
| class CrossAttnPooled(nn.Module): |
| def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1): |
| super().__init__() |
| self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden)) |
| self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden)) |
|
|
| self.layers = nn.ModuleList([]) |
| for _ in range(n_layers): |
| self.layers.append(nn.ModuleDict({ |
| "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False), |
| "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False), |
| "n1t": nn.LayerNorm(hidden), |
| "n2t": nn.LayerNorm(hidden), |
| "n1b": nn.LayerNorm(hidden), |
| "n2b": nn.LayerNorm(hidden), |
| "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)), |
| "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)), |
| })) |
|
|
| self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout)) |
| self.reg = nn.Linear(hidden, 1) |
| self.cls = nn.Linear(hidden, 3) |
|
|
| def forward(self, t_vec, b_vec): |
| t = self.t_proj(t_vec).unsqueeze(0) |
| b = self.b_proj(b_vec).unsqueeze(0) |
| for L in self.layers: |
| t_attn, _ = L["attn_tb"](t, b, b) |
| t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1) |
| t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1) |
|
|
| b_attn, _ = L["attn_bt"](b, t, t) |
| b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1) |
| b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1) |
|
|
| z = torch.cat([t[0], b[0]], dim=-1) |
| h = self.shared(z) |
| return self.reg(h).squeeze(-1), self.cls(h) |
|
|
| class CrossAttnUnpooled(nn.Module): |
| def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1): |
| super().__init__() |
| self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden)) |
| self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden)) |
|
|
| self.layers = nn.ModuleList([]) |
| for _ in range(n_layers): |
| self.layers.append(nn.ModuleDict({ |
| "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True), |
| "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True), |
| "n1t": nn.LayerNorm(hidden), |
| "n2t": nn.LayerNorm(hidden), |
| "n1b": nn.LayerNorm(hidden), |
| "n2b": nn.LayerNorm(hidden), |
| "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)), |
| "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)), |
| })) |
|
|
| self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout)) |
| self.reg = nn.Linear(hidden, 1) |
| self.cls = nn.Linear(hidden, 3) |
|
|
| def _masked_mean(self, X, M): |
| Mf = M.unsqueeze(-1).float() |
| denom = Mf.sum(dim=1).clamp(min=1.0) |
| return (X * Mf).sum(dim=1) / denom |
|
|
| def forward(self, T, Mt, B, Mb): |
| T = self.t_proj(T) |
| Bx = self.b_proj(B) |
| kp_t = ~Mt |
| kp_b = ~Mb |
|
|
| for L in self.layers: |
| T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b) |
| T = L["n1t"](T + T_attn) |
| T = L["n2t"](T + L["fft"](T)) |
|
|
| B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t) |
| Bx = L["n1b"](Bx + B_attn) |
| Bx = L["n2b"](Bx + L["ffb"](Bx)) |
|
|
| t_pool = self._masked_mean(T, Mt) |
| b_pool = self._masked_mean(Bx, Mb) |
| z = torch.cat([t_pool, b_pool], dim=-1) |
| h = self.shared(z) |
| return self.reg(h).squeeze(-1), self.cls(h) |
|
|
| def load_binding_model(best_model_pt: Path, pooled_or_unpooled: str, device: torch.device) -> nn.Module: |
| ckpt = torch.load(best_model_pt, map_location=device, weights_only=False) |
| params = ckpt["best_params"] |
| sd = ckpt["state_dict"] |
|
|
|     # The first Linear in t_proj / b_proj has weight shape (hidden, H_in); read input dims from it. |
| Ht = int(sd["t_proj.0.weight"].shape[1]) |
| Hb = int(sd["b_proj.0.weight"].shape[1]) |
|
|
| common = dict( |
| Ht=Ht, Hb=Hb, |
| hidden=int(params["hidden_dim"]), |
| n_heads=int(params["n_heads"]), |
| n_layers=int(params["n_layers"]), |
| dropout=float(params["dropout"]), |
| ) |
|
|
| if pooled_or_unpooled == "pooled": |
| model = CrossAttnPooled(**common) |
| elif pooled_or_unpooled == "unpooled": |
| model = CrossAttnUnpooled(**common) |
| else: |
| raise ValueError(pooled_or_unpooled) |
|
|
| model.load_state_dict(sd) |
| model.to(device).eval() |
| return model |
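| # Call signatures (both variants return (regression_score, class_logits); shapes illustrative): |
| #   pooled:   model(t_vec, b_vec)    with t_vec (1, Ht), b_vec (1, Hb) |
| #   unpooled: model(T, Mt, B, Mb)    with T (1, Lt, Ht), Mt (1, Lt) bool, B (1, Lb, Hb), Mb (1, Lb) bool |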
|
|
|
|
| |
| # Embedders: PeptideCLM for SMILES, ESM2 for amino-acid sequences |
| |
| def _safe_isin(ids: torch.Tensor, test_ids: torch.Tensor) -> torch.Tensor: |
|     """ |
|     Use torch.isin when available; otherwise fall back to a broadcast comparison |
|     (older PyTorch releases do not provide torch.isin). |
|     """ |
| if hasattr(torch, "isin"): |
| return torch.isin(ids, test_ids) |
|     # Fallback: (B, L, 1) == (1, 1, K) -> (B, L, K), then reduce with any() over the last dim. |
| |
| return (ids.unsqueeze(-1) == test_ids.view(1, 1, -1)).any(dim=-1) |
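| # Example (hypothetical tensors): |
| #   _safe_isin(torch.tensor([[1, 2, 3]]), torch.tensor([2])) -> tensor([[False, True, False]]) |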
| |
| class SMILESEmbedder: |
| """ |
| PeptideCLM RoFormer embeddings for SMILES. |
| - pooled(): mean over tokens where attention_mask==1 AND token_id not in SPECIAL_IDS |
| - unpooled(): returns token embeddings filtered to valid tokens (specials removed), |
|                  plus an all-True mask of length Li (the tokens are already filtered). |
| """ |
| def __init__( |
| self, |
| device: torch.device, |
| vocab_path: str, |
| splits_path: str, |
| clm_name: str = "aaronfeller/PeptideCLM-23M-all", |
| max_len: int = 512, |
| use_cache: bool = True, |
| ): |
| self.device = device |
| self.max_len = max_len |
| self.use_cache = use_cache |
|
|
| self.tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path) |
| self.model = AutoModelForMaskedLM.from_pretrained(clm_name).roformer.to(device).eval() |
|
|
| self.special_ids = self._get_special_ids(self.tokenizer) |
| self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long) |
| if len(self.special_ids) else None) |
|
|
| self._cache_pooled: Dict[str, torch.Tensor] = {} |
| self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {} |
|
|
| @staticmethod |
| def _get_special_ids(tokenizer) -> List[int]: |
| cand = [ |
| getattr(tokenizer, "pad_token_id", None), |
| getattr(tokenizer, "cls_token_id", None), |
| getattr(tokenizer, "sep_token_id", None), |
| getattr(tokenizer, "bos_token_id", None), |
| getattr(tokenizer, "eos_token_id", None), |
| getattr(tokenizer, "mask_token_id", None), |
| ] |
| return sorted({int(x) for x in cand if x is not None}) |
|
|
| def _tokenize(self, smiles_list: List[str]) -> Dict[str, torch.Tensor]: |
| tok = self.tokenizer( |
| smiles_list, |
| return_tensors="pt", |
| padding=True, |
| truncation=True, |
| max_length=self.max_len, |
| ) |
| for k in tok: |
| tok[k] = tok[k].to(self.device) |
| if "attention_mask" not in tok: |
| tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device) |
| return tok |
|
|
| @torch.no_grad() |
| def pooled(self, smiles: str) -> torch.Tensor: |
| s = smiles.strip() |
| if self.use_cache and s in self._cache_pooled: |
| return self._cache_pooled[s] |
|
|
| tok = self._tokenize([s]) |
| ids = tok["input_ids"] |
| attn = tok["attention_mask"].bool() |
|
|
| out = self.model(input_ids=ids, attention_mask=tok["attention_mask"]) |
| h = out.last_hidden_state |
|
|
| valid = attn |
| if self.special_ids_t is not None and self.special_ids_t.numel() > 0: |
| valid = valid & (~_safe_isin(ids, self.special_ids_t)) |
|
|
| vf = valid.unsqueeze(-1).float() |
| summed = (h * vf).sum(dim=1) |
| denom = vf.sum(dim=1).clamp(min=1e-9) |
| pooled = summed / denom |
|
|
| if self.use_cache: |
| self._cache_pooled[s] = pooled |
| return pooled |
|
|
| @torch.no_grad() |
| def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]: |
| """ |
| Returns: |
| X: (1, Li, H) float32 on device |
| M: (1, Li) bool on device |
| where Li excludes padding + special tokens. |
| """ |
| s = smiles.strip() |
| if self.use_cache and s in self._cache_unpooled: |
| return self._cache_unpooled[s] |
|
|
| tok = self._tokenize([s]) |
| ids = tok["input_ids"] |
| attn = tok["attention_mask"].bool() |
|
|
| out = self.model(input_ids=ids, attention_mask=tok["attention_mask"]) |
| h = out.last_hidden_state |
|
|
| valid = attn |
| if self.special_ids_t is not None and self.special_ids_t.numel() > 0: |
| valid = valid & (~_safe_isin(ids, self.special_ids_t)) |
|
|
|         # Batch size is 1: index with the boolean mask to drop padding and special tokens. |
| keep = valid[0] |
| X = h[:, keep, :] |
| M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device) |
|
|
| if self.use_cache: |
| self._cache_unpooled[s] = (X, M) |
| return X, M |
|
|
|
|
| class WTEmbedder: |
| """ |
| ESM2 embeddings for AA sequences. |
| - pooled(): mean over tokens where attention_mask==1 AND token_id not in {CLS, EOS, PAD,...} |
| - unpooled(): returns token embeddings filtered to valid tokens (specials removed), |
|                  plus an all-True mask of length Li (the tokens are already filtered). |
| """ |
| def __init__( |
| self, |
| device: torch.device, |
| esm_name: str = "facebook/esm2_t33_650M_UR50D", |
| max_len: int = 1022, |
| use_cache: bool = True, |
| ): |
| self.device = device |
| self.max_len = max_len |
| self.use_cache = use_cache |
|
|
| self.tokenizer = EsmTokenizer.from_pretrained(esm_name) |
| self.model = EsmModel.from_pretrained(esm_name, add_pooling_layer=False).to(device).eval() |
|
|
| self.special_ids = self._get_special_ids(self.tokenizer) |
| self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long) |
| if len(self.special_ids) else None) |
|
|
| self._cache_pooled: Dict[str, torch.Tensor] = {} |
| self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {} |
|
|
| @staticmethod |
| def _get_special_ids(tokenizer) -> List[int]: |
| cand = [ |
| getattr(tokenizer, "pad_token_id", None), |
| getattr(tokenizer, "cls_token_id", None), |
| getattr(tokenizer, "sep_token_id", None), |
| getattr(tokenizer, "bos_token_id", None), |
| getattr(tokenizer, "eos_token_id", None), |
| getattr(tokenizer, "mask_token_id", None), |
| ] |
| return sorted({int(x) for x in cand if x is not None}) |
|
|
| def _tokenize(self, seq_list: List[str]) -> Dict[str, torch.Tensor]: |
| tok = self.tokenizer( |
| seq_list, |
| return_tensors="pt", |
| padding=True, |
| truncation=True, |
| max_length=self.max_len, |
| ) |
| tok = {k: v.to(self.device) for k, v in tok.items()} |
| if "attention_mask" not in tok: |
| tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device) |
| return tok |
|
|
| @torch.no_grad() |
| def pooled(self, seq: str) -> torch.Tensor: |
| s = seq.strip() |
| if self.use_cache and s in self._cache_pooled: |
| return self._cache_pooled[s] |
|
|
| tok = self._tokenize([s]) |
| ids = tok["input_ids"] |
| attn = tok["attention_mask"].bool() |
|
|
| out = self.model(**tok) |
| h = out.last_hidden_state |
|
|
| valid = attn |
| if self.special_ids_t is not None and self.special_ids_t.numel() > 0: |
| valid = valid & (~_safe_isin(ids, self.special_ids_t)) |
|
|
| vf = valid.unsqueeze(-1).float() |
| summed = (h * vf).sum(dim=1) |
| denom = vf.sum(dim=1).clamp(min=1e-9) |
| pooled = summed / denom |
|
|
| if self.use_cache: |
| self._cache_pooled[s] = pooled |
| return pooled |
|
|
| @torch.no_grad() |
| def unpooled(self, seq: str) -> Tuple[torch.Tensor, torch.Tensor]: |
| """ |
| Returns: |
| X: (1, Li, H) float32 on device |
| M: (1, Li) bool on device |
| where Li excludes padding + special tokens. |
| """ |
| s = seq.strip() |
| if self.use_cache and s in self._cache_unpooled: |
| return self._cache_unpooled[s] |
|
|
| tok = self._tokenize([s]) |
| ids = tok["input_ids"] |
| attn = tok["attention_mask"].bool() |
|
|
| out = self.model(**tok) |
| h = out.last_hidden_state |
|
|
| valid = attn |
| if self.special_ids_t is not None and self.special_ids_t.numel() > 0: |
| valid = valid & (~_safe_isin(ids, self.special_ids_t)) |
|
|
| keep = valid[0] |
| X = h[:, keep, :] |
| M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device) |
|
|
| if self.use_cache: |
| self._cache_unpooled[s] = (X, M) |
| return X, M |
|
|
|
|
|
|
| |
| # Top-level predictor tying together the manifest, embedders, and trained models |
| |
| class PeptiVersePredictor: |
| """ |
| - loads best models from training_classifiers/ |
| - computes embeddings as needed (pooled/unpooled) |
| - supports: xgb, joblib(ENET/SVM/SVR), NN(mlp/cnn/transformer), binding pooled/unpooled. |
| """ |
| def __init__( |
| self, |
| manifest_path: str | Path, |
| classifier_weight_root: str | Path, |
| esm_name="facebook/esm2_t33_650M_UR50D", |
| clm_name="aaronfeller/PeptideCLM-23M-all", |
| smiles_vocab="tokenizer/new_vocab.txt", |
| smiles_splits="tokenizer/new_splits.txt", |
| device: Optional[str] = None, |
| ): |
| self.root = Path(classifier_weight_root) |
| self.training_root = self.root / "training_classifiers" |
| self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) |
|
|
| self.manifest = read_best_manifest_csv(manifest_path) |
|
|
| self.wt_embedder = WTEmbedder(self.device) |
| self.smiles_embedder = SMILESEmbedder(self.device, clm_name=clm_name, |
| vocab_path=str(self.root / smiles_vocab), |
| splits_path=str(self.root / smiles_splits)) |
|
|
| self.models: Dict[Tuple[str, str], Any] = {} |
| self.meta: Dict[Tuple[str, str], Dict[str, Any]] = {} |
|
|
| self._load_all_best_models() |
|
|
| def _resolve_dir(self, prop_key: str, model_name: str, mode: str) -> Path: |
| |
| disk_prop = "half_life" if prop_key == "halflife" else prop_key |
| base = self.training_root / disk_prop |
|
|
|         # Half-life: explicit folder-specific labels (xgb_wt_log / xgb_smiles) are used verbatim. |
| if prop_key == "halflife" and model_name in {"xgb_wt_log", "xgb_smiles"}: |
| d = base / model_name |
| if d.exists(): |
| return d |
|
|
|         # Half-life WT "transformer" weights live in the "transformer_wt_log" folder. |
| if prop_key == "halflife" and mode == "wt" and model_name == "transformer": |
| d = base / "transformer_wt_log" |
| if d.exists(): |
| return d |
|         # Half-life "xgb" maps to "xgb_wt_log" (WT) or "xgb_smiles" (SMILES). |
| if prop_key == "halflife" and model_name == "xgb": |
| d = base / ("xgb_wt_log" if mode == "wt" else "xgb_smiles") |
| if d.exists(): |
| return d |
|
|
| candidates = [ |
| base / f"{model_name}_{mode}", |
| base / model_name, |
| ] |
| if mode == "wt": |
| candidates += [base / f"{model_name}_wt"] |
| if mode == "smiles": |
| candidates += [base / f"{model_name}_smiles"] |
|
|
| for d in candidates: |
| if d.exists(): |
| return d |
|
|
| raise FileNotFoundError( |
| f"Cannot find model directory for {prop_key} {model_name} {mode}. Tried: {candidates}" |
| ) |
|
|
|
|
| def _load_all_best_models(self): |
| for prop_key, row in self.manifest.items(): |
| for mode, label, thr in [ |
| ("wt", row.best_wt, row.thr_wt), |
| ("smiles", row.best_smiles, row.thr_smiles), |
| ]: |
| m = canon_model(label) |
| if m is None: |
| continue |
|
|
| |
| if prop_key == "binding_affinity": |
|                     # For binding affinity the manifest label names the variant ("pooled" / "unpooled"). |
| pooled_or_unpooled = m |
| folder = f"wt_{mode}_{pooled_or_unpooled}" |
| model_dir = self.training_root / "binding_affinity" / folder |
| art = find_best_artifact(model_dir) |
| if art.suffix != ".pt": |
| raise RuntimeError(f"Binding model expected best_model.pt, got {art}") |
| model = load_binding_model(art, pooled_or_unpooled=pooled_or_unpooled, device=self.device) |
| self.models[(prop_key, mode)] = model |
| self.meta[(prop_key, mode)] = { |
| "task_type": "Regression", |
| "threshold": None, |
| "artifact": str(art), |
| "model_name": pooled_or_unpooled, |
| } |
| continue |
|
|
| model_dir = self._resolve_dir(prop_key, m, mode) |
| kind, obj, art = load_artifact(model_dir, self.device) |
|
|
| if kind in {"xgb", "joblib"}: |
| self.models[(prop_key, mode)] = obj |
| else: |
|                     # Torch checkpoints: reduce label variants (e.g. "transformer_wt_log") to a base architecture. |
| arch = m |
| if arch.startswith("transformer"): |
| arch = "transformer" |
| elif arch.startswith("mlp"): |
| arch = "mlp" |
| elif arch.startswith("cnn"): |
| arch = "cnn" |
|
|
| self.models[(prop_key, mode)] = build_torch_model_from_ckpt(arch, obj, self.device) |
|
|
| self.meta[(prop_key, mode)] = { |
| "task_type": row.task_type, |
| "threshold": thr, |
| "artifact": str(art), |
| "model_name": m, |
| "kind": kind, |
| } |
|
|
|
|
| def _get_features_for_model(self, prop_key: str, mode: str, input_str: str): |
| """ |
| Returns either: |
| - pooled np array shape (1,H) for xgb/joblib |
| - unpooled torch tensors (X,M) for NN |
| """ |
| model = self.models[(prop_key, mode)] |
| meta = self.meta[(prop_key, mode)] |
| kind = meta.get("kind", None) |
| model_name = meta.get("model_name", "") |
|
|
| if prop_key == "binding_affinity": |
| raise RuntimeError("Use predict_binding_affinity().") |
|
|
|         # NN checkpoints consume unpooled token embeddings plus a mask. |
| if kind == "torch_ckpt": |
| if mode == "wt": |
| X, M = self.wt_embedder.unpooled(input_str) |
| else: |
| X, M = self.smiles_embedder.unpooled(input_str) |
| return X, M |
|
|
|         # xgb / joblib models consume a pooled (1, H) feature vector. |
| if mode == "wt": |
| v = self.wt_embedder.pooled(input_str) |
| else: |
| v = self.smiles_embedder.pooled(input_str) |
| feats = v.detach().cpu().numpy().astype(np.float32) |
| feats = np.nan_to_num(feats, nan=0.0) |
| feats = np.clip(feats, np.finfo(np.float32).min, np.finfo(np.float32).max) |
| return feats |
|
|
| def predict_property(self, prop_key: str, mode: str, input_str: str) -> Dict[str, Any]: |
| """ |
| mode: "wt" for AA sequence input, "smiles" for SMILES input |
| Returns dict with score + label if classifier threshold exists. |
| """ |
| if (prop_key, mode) not in self.models: |
| raise KeyError(f"No model loaded for ({prop_key}, {mode}). Check manifest and folders.") |
|
|
| meta = self.meta[(prop_key, mode)] |
| model = self.models[(prop_key, mode)] |
| task_type = meta["task_type"].lower() |
| thr = meta.get("threshold", None) |
| kind = meta.get("kind", None) |
|
|
| if prop_key == "binding_affinity": |
| raise RuntimeError("Use predict_binding_affinity().") |
|
|
|         # Neural-network heads: raw score from unpooled embeddings. |
| if kind == "torch_ckpt": |
| X, M = self._get_features_for_model(prop_key, mode, input_str) |
| with torch.no_grad(): |
| y = model(X, M).squeeze().float().cpu().item() |
|             # Half-life WT models predict log1p(half-life); invert with expm1. |
| model_name = meta.get("model_name", "") |
| if ( |
| prop_key == "halflife" |
| and mode == "wt" |
| and model_name in {"xgb_wt_log", "transformer_wt_log"} |
| ): |
| y = float(np.expm1(y)) |
| if task_type == "classifier": |
| prob = float(1.0 / (1.0 + np.exp(-y))) |
| out = {"property": prop_key, "mode": mode, "score": prob} |
| if thr is not None: |
| out["label"] = int(prob >= float(thr)) |
| out["threshold"] = float(thr) |
| return out |
| else: |
| return {"property": prop_key, "mode": mode, "score": float(y)} |
|
|
| if kind == "xgb": |
| feats = self._get_features_for_model(prop_key, mode, input_str) |
| dmat = xgb.DMatrix(feats) |
| pred = float(model.predict(dmat)[0]) |
|
|
|             # Same log1p inversion for the half-life WT xgb model. |
| model_name = meta.get("model_name", "") |
| if ( |
| prop_key == "halflife" |
| and mode == "wt" |
| and model_name in {"xgb_wt_log", "transformer_wt_log"} |
| ): |
| pred = float(np.expm1(pred)) |
|
|
|             out = {"property": prop_key, "mode": mode, "score": pred} |
|             if task_type == "classifier" and thr is not None: |
|                 out["label"] = int(pred >= float(thr)) |
|                 out["threshold"] = float(thr) |
|             return out |
|
|
|         # scikit-learn style estimators (SVM / SVR / ENet) loaded via joblib. |
| if kind == "joblib": |
| feats = self._get_features_for_model(prop_key, mode, input_str) |
|             # Prefer predict_proba; otherwise squash decision_function through a sigmoid. |
| if task_type == "classifier": |
| if hasattr(model, "predict_proba"): |
| pred = float(model.predict_proba(feats)[:, 1][0]) |
| else: |
| if hasattr(model, "decision_function"): |
| logit = float(model.decision_function(feats)[0]) |
| pred = float(1.0 / (1.0 + np.exp(-logit))) |
| else: |
| pred = float(model.predict(feats)[0]) |
| out = {"property": prop_key, "mode": mode, "score": pred} |
| if thr is not None: |
| out["label"] = int(pred >= float(thr)) |
| out["threshold"] = float(thr) |
| return out |
| else: |
| pred = float(model.predict(feats)[0]) |
| return {"property": prop_key, "mode": mode, "score": pred} |
|
|
| raise RuntimeError(f"Unknown model kind={kind}") |
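|     # Illustrative return values (scores hypothetical): |
|     #   classifier: {"property": "hemolysis", "mode": "wt", "score": 0.91, "label": 1, "threshold": 0.2801} |
|     #   regression: {"property": "halflife", "mode": "wt", "score": 3.7} |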
|
|
| def predict_binding_affinity(self, mode: str, target_seq: str, binder_str: str) -> Dict[str, Any]: |
| """ |
| mode: "wt" (binder is AA sequence) -> wt_wt_(pooled|unpooled) |
| "smiles" (binder is SMILES) -> wt_smiles_(pooled|unpooled) |
| """ |
| prop_key = "binding_affinity" |
| if (prop_key, mode) not in self.models: |
| raise KeyError(f"No binding model loaded for ({prop_key}, {mode}).") |
|
|
| model = self.models[(prop_key, mode)] |
| pooled_or_unpooled = self.meta[(prop_key, mode)]["model_name"] |
|
|
|         # Pooled variant feeds single-vector embeddings; unpooled feeds per-token embeddings with masks. |
| if pooled_or_unpooled == "pooled": |
| t_vec = self.wt_embedder.pooled(target_seq) |
| if mode == "wt": |
| b_vec = self.wt_embedder.pooled(binder_str) |
| else: |
| b_vec = self.smiles_embedder.pooled(binder_str) |
| with torch.no_grad(): |
| reg, logits = model(t_vec, b_vec) |
| affinity = float(reg.squeeze().cpu().item()) |
| cls_logit = int(torch.argmax(logits, dim=-1).cpu().item()) |
| cls_thr = affinity_to_class(affinity) |
| else: |
| T, Mt = self.wt_embedder.unpooled(target_seq) |
| if mode == "wt": |
| B, Mb = self.wt_embedder.unpooled(binder_str) |
| else: |
| B, Mb = self.smiles_embedder.unpooled(binder_str) |
| with torch.no_grad(): |
| reg, logits = model(T, Mt, B, Mb) |
| affinity = float(reg.squeeze().cpu().item()) |
| cls_logit = int(torch.argmax(logits, dim=-1).cpu().item()) |
| cls_thr = affinity_to_class(affinity) |
|
|
| names = {0: "High (≥9)", 1: "Moderate (7-9)", 2: "Low (<7)"} |
| return { |
| "property": "binding_affinity", |
| "mode": mode, |
| "affinity": affinity, |
| "class_by_threshold": names[cls_thr], |
| "class_by_logits": names[cls_logit], |
| "binding_model": pooled_or_unpooled, |
| } |
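|     # Illustrative output (numbers hypothetical): |
|     #   {"property": "binding_affinity", "mode": "wt", "affinity": 8.4, |
|     #    "class_by_threshold": "Moderate (7-9)", "class_by_logits": "Moderate (7-9)", |
|     #    "binding_model": "pooled"} |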
|
|
|
|
| if __name__ == "__main__": |
| predictor = PeptiVersePredictor( |
| manifest_path="basic_models.txt", |
| classifier_weight_root="./" |
| ) |
| print(predictor.predict_property("hemolysis", "wt", "GIGAVLKVLTTGLPALISWIKRKRQQ")) |
| print(predictor.predict_binding_affinity("wt", target_seq="...", binder_str="...")) |
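|     # A SMILES-mode call looks the same (sketch; assumes a SMILES model is listed for the property): |
|     # print(predictor.predict_property("hemolysis", "smiles", "NCC(=O)N[C@H](CS)C(=O)O")) |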
|
|
| |
| """ |
| device = torch.device("cuda:0") |
| |
| wt = WTEmbedder(device) |
| sm = SMILESEmbedder(device, |
|     vocab_path="./tokenizer/new_vocab.txt", |
| splits_path="./tokenizer/new_splits.txt" |
| ) |
| |
| p = wt.pooled("GIGAVLKVLTTGLPALISWIKRKRQQ") # (1,1280) |
| X, M = wt.unpooled("GIGAVLKVLTTGLPALISWIKRKRQQ") # (1,Li,1280), (1,Li) |
| |
| p2 = sm.pooled("NCC(=O)N[C@H](CS)C(=O)O") # (1,H_smiles) |
| X2, M2 = sm.unpooled("NCC(=O)N[C@H](CS)C(=O)O") # (1,Li,H_smiles), (1,Li) |
| """ |
|
|