hbp5181 commited on
Commit
877c6aa
·
verified ·
1 Parent(s): 641633d

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +141 -79
train.py CHANGED
@@ -3,92 +3,154 @@ import pandas as pd
3
  import numpy as np
4
  from sklearn.model_selection import KFold
5
  from sklearn.metrics import mean_squared_error, r2_score
6
- from scipy.stats import pearsonr, ttest_ind
7
  from catboost import CatBoostRegressor
 
8
 
9
- # Load dataset, this should be specified for which model will be trained(eg., embedding only or including physical terms)
10
- data = pd.read_csv("embeddings/ESM2_interaction.csv")
 
 
 
 
 
 
 
 
11
 
12
- # Fill missing feature strings (Features are chosen based on what kind of mdoel will be trained.
13
- # Ligand and Receptor Features are ESM2 embeddings and Physical Features are PyRosetta Features
14
- for col in ["Ligand Features", "Receptor Features", "Physical Features"]:
15
- data[col] = data[col].fillna("")
16
 
17
- # Parse comma-separated floats
18
- for col in ["Ligand Features", "Receptor Features", "Physical Features"]:
19
- data[col] = data[col].apply(
20
- lambda s: [float(x) for x in str(s).split(",") if x.strip()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- # Build feature arrays
24
- X_ligand = np.vstack(data["Ligand Features"].values)
25
- X_receptor = np.vstack(data["Receptor Features"].values)
26
- # optional: X_physical = np.vstack(data["Physical Features"].values)
27
 
28
- # Convert KD(M) into log10 scale
29
- raw_y = data["KD(M)"].values
30
- y = np.log10(raw_y) # assumes all KD values are positive
 
 
31
 
 
32
  records = []
 
 
 
 
 
 
 
 
 
33
 
34
- # Repeat 5×5-fold CV, with and without physical features
35
- for repeat in range(1, 6):
36
- kf = KFold(n_splits=5, shuffle=True, random_state=repeat)
37
-
38
- for include_phys in (False, True):
39
- X_base = np.hstack([X_ligand, X_receptor])
40
- X_full = np.hstack([X_base, X_physical])
41
- X_data = X_full if include_phys else X_base
42
-
43
- for fold_idx, (train_idx, test_idx) in enumerate(kf.split(X_data), start=1):
44
- X_train, X_test = X_data[train_idx], X_data[test_idx]
45
- y_train, y_test = y[train_idx], y[test_idx]
46
-
47
- # Initialize with your chosen hyperparameters and GPU support
48
- model = CatBoostRegressor(
49
- iterations=2000,
50
- learning_rate=0.08,
51
- depth=4,
52
- verbose=500,
53
- task_type="GPU",
54
- devices="0"
55
- )
56
-
57
- # Train and time this fold
58
- model.fit(X_train, y_train)
59
-
60
- preds = model.predict(X_test)
61
- rmse = np.sqrt(mean_squared_error(y_test, preds))
62
- r2 = r2_score(y_test, preds)
63
- pcc = pearsonr(y_test, preds)[0]
64
-
65
- records.append({
66
- "repeat": repeat,
67
- "fold": fold_idx,
68
- "with_physical": include_phys,
69
- "pearson_r": pcc,
70
- "r2": r2,
71
- "rmse": rmse
72
- })
73
-
74
- # Aggregate metrics
75
- metrics_df = pd.DataFrame(records)
76
-
77
- # Save to CSV
78
- out_dir = "metrics"
79
- os.makedirs(out_dir, exist_ok=True)
80
- csv_path = os.path.join(out_dir, "InteractionMetrics.csv")
81
- metrics_df.to_csv(csv_path, index=False)
82
- print(f"All metrics saved to {csv_path}")
83
-
84
- # Conduct independent t tests for each metric
85
- results = {}
86
- for metric in ["pearson_r", "r2", "rmse"]:
87
- grp_with = metrics_df.loc[metrics_df.with_physical, metric]
88
- grp_without = metrics_df.loc[~metrics_df.with_physical, metric]
89
- t_stat, p_val = ttest_ind(grp_with, grp_without, equal_var=False)
90
- results[metric] = (t_stat, p_val)
91
-
92
- print("\nT test results comparing with vs without physical features:")
93
- for m, (t_stat, p_val) in results.items():
94
- print(f"{m} → t = {t_stat:.3f}, p = {p_val:.3f}")
 
3
  import numpy as np
4
  from sklearn.model_selection import KFold
5
  from sklearn.metrics import mean_squared_error, r2_score
6
+ from scipy.stats import pearsonr, spearmanr
7
  from catboost import CatBoostRegressor
8
+ import matplotlib.pyplot as plt
9
 
10
+ # Set publication-style fonts
11
+ plt.rcParams.update({
12
+ 'font.family': 'serif',
13
+ 'font.size': 13,
14
+ 'axes.labelsize': 14,
15
+ 'axes.titlesize': 14,
16
+ 'xtick.labelsize': 12,
17
+ 'ytick.labelsize': 12,
18
+ 'legend.fontsize': 12
19
+ })
20
 
21
+ # Load dataset
22
+ data = pd.read_csv("/storage/group/cdm8/default/BindPred/embeddings/Seq_Gen_updated.csv")
 
 
23
 
24
+ # Handle missing values
25
+ data['Ligand Features'] = data['Ligand Features'].fillna('')
26
+ data['Receptor Features'] = data['Receptor Features'].fillna('')
27
+
28
+ # Convert embedding strings to float lists
29
+ data['Ligand Features'] = data['Ligand Features'].apply(
30
+ lambda x: [float(i) for i in str(x).split(',') if i.strip()] if isinstance(x, str) else []
31
+ )
32
+ data['Receptor Features'] = data['Receptor Features'].apply(
33
+ lambda x: [float(i) for i in str(x).split(',') if i.strip()] if isinstance(x, str) else []
34
+ )
35
+
36
+ # Combine embeddings
37
+ data['Combined Features'] = data.apply(
38
+ lambda row: np.concatenate((row['Ligand Features'], row['Receptor Features']))
39
+ if len(row['Ligand Features']) > 0 and len(row['Receptor Features']) > 0 else np.array([]),
40
+ axis=1
41
+ )
42
+
43
+ # Filter valid rows
44
+ data = data[data['Combined Features'].apply(len) > 0]
45
+
46
+ # Check KD(M) column
47
+ if "KD(M)" not in data.columns or data["KD(M)"].isnull().any():
48
+ raise ValueError("Missing or NaN values in 'KD(M)' column.")
49
+
50
+ # Prepare features and log-transformed labels
51
+ X = np.vstack(data['Combined Features'])
52
+ y = np.log10(data['KD(M)'])
53
+
54
+ # Cross-validation
55
+ kf = KFold(n_splits=5, shuffle=True, random_state=42)
56
+ all_y_true = []
57
+ all_y_pred = []
58
+ test_indices_all = []
59
+
60
+ # Output directory
61
+ output_dir = "new_plt"
62
+ os.makedirs(output_dir, exist_ok=True)
63
+
64
+ for fold, (train_index, test_index) in enumerate(kf.split(X)):
65
+ X_train, X_test = X[train_index], X[test_index]
66
+ y_train, y_test = y.iloc[train_index], y.iloc[test_index]
67
+
68
+ model = CatBoostRegressor(
69
+ iterations=2000,
70
+ learning_rate=0.08,
71
+ depth=4,
72
+ verbose=500,
73
+ task_type="GPU",
74
+ devices='0'
75
  )
76
+ model.fit(X_train, y_train)
77
+ y_pred = model.predict(X_test)
78
+
79
+ all_y_true.extend(y_test)
80
+ all_y_pred.extend(y_pred)
81
+ test_indices_all.extend(test_index)
82
+
83
+ # Convert predictions to arrays
84
+ all_y_true = np.array(all_y_true)
85
+ all_y_pred = np.array(all_y_pred)
86
+
87
+ # Compute performance metrics
88
+ pcc, _ = pearsonr(all_y_true, all_y_pred)
89
+ srcc, _ = spearmanr(all_y_true, all_y_pred)
90
+ rmse = np.sqrt(mean_squared_error(all_y_true, all_y_pred))
91
+ r2 = r2_score(all_y_true, all_y_pred)
92
+
93
+ # Compute absolute error
94
+ errors = np.abs(all_y_true - all_y_pred)
95
+
96
+ # Plotting
97
+ plt.figure(figsize=(5, 5))
98
+ plt.title("ESM2 Embeddings", fontsize=15, pad=10)
99
+
100
+ sc = plt.scatter(
101
+ all_y_true,
102
+ all_y_pred,
103
+ s=25,
104
+ c=errors,
105
+ cmap='Reds',
106
+ alpha=0.9,
107
+ edgecolors='black',
108
+ linewidth=0.4,
109
+ marker='^' # triangle markers
110
+ )
111
+
112
+ # Diagonal reference line
113
+ plt.plot([-15, -2], [-15, -2], color='black', linestyle='--', linewidth=1)
114
+
115
+ # Axis setup
116
+ plt.xlabel("Experimental Log10(Kd)", fontsize=14, labelpad=10)
117
+ plt.ylabel("BindPred Prediction of Log10(Kd)", fontsize=14, labelpad=10)
118
+ plt.xlim(-15.0, -2.0)
119
+ plt.ylim(-15.0, -2.0)
120
+ plt.gca().set_aspect('equal', adjustable='box')
121
+
122
+ # Metrics box
123
+ plt.text(0.05, 0.95,
124
+ f"PCC: {pcc:.3f}\nRMSE: {rmse:.3f}\nR²: {r2:.3f}",
125
+ transform=plt.gca().transAxes,
126
+ fontsize=12,
127
+ verticalalignment='top',
128
+ horizontalalignment='left',
129
+ bbox=dict(facecolor='white', edgecolor='gray', boxstyle='round,pad=0.3'))
130
 
131
+ # Colorbar
132
+ cbar = plt.colorbar(sc)
133
+ cbar.set_label("Absolute Error", fontsize=12)
 
134
 
135
+ # Save plot
136
+ plt.tight_layout()
137
+ plt.savefig(os.path.join(output_dir, 'esm2_plot.png'), dpi=700)
138
+ plt.savefig(os.path.join(output_dir, 'esm2_plot.pdf'), dpi=700)
139
+ plt.show()
140
 
141
+ # Save prediction results to CSV
142
  records = []
143
+ for idx, test_idx in enumerate(test_indices_all):
144
+ row = data.iloc[test_idx]
145
+ record = {
146
+ "PDB_ID": row.get("PDB_ID", "NA"),
147
+ "Mutation": row.get("Mutation", "NA"),
148
+ "Actual_log10Kd": all_y_true[idx],
149
+ "Predicted_log10Kd": all_y_pred[idx]
150
+ }
151
+ records.append(record)
152
 
153
+ df_preds = pd.DataFrame(records)
154
+ csv_path = os.path.join(output_dir, "ESM2_predictions.csv")
155
+ df_preds.to_csv(csv_path, index=False)
156
+ print(f"Saved prediction results to {csv_path}")