-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
353 lines (296 loc) · 12.8 KB
/
dataset.py
File metadata and controls
353 lines (296 loc) · 12.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
"""
PDBbind / CD cross-dock dataset loader.
Handles:
- PDBbind time-split (train/val/test)
- CD cross-dock test set (Apo-Holo and Holo-Holo pairs)
Each item returns a featurised protein-ligand complex ready for the
DiffBindFR score network.
"""
import os
import math
import json
import random
import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset
from typing import Optional
try:
from rdkit import Chem
from rdkit.Chem import AllChem
RDKIT_AVAILABLE = True
except ImportError:
RDKIT_AVAILABLE = False
# ──────────────────────────────────────────────────────────────────────────────
# Protein featurisation helpers
# ──────────────────────────────────────────────────────────────────────────────
# Three-letter residue name -> one-letter code for the 20 standard amino acids.
AA_3TO1 = dict(
    ALA="A", ARG="R", ASN="N", ASP="D", CYS="C",
    GLN="Q", GLU="E", GLY="G", HIS="H", ILE="I",
    LEU="L", LYS="K", MET="M", PHE="F", PRO="P",
    SER="S", THR="T", TRP="W", TYR="Y", VAL="V",
)
# Alphabetically sorted one-letter codes, followed by "X" for anything that is
# not one of the 20 standard residues (21 types in total).
AA_TYPES = [*sorted(set(AA_3TO1.values())), "X"]
# One-letter code -> position in AA_TYPES (used as the one-hot class index).
AA_TO_IDX = dict(zip(AA_TYPES, range(len(AA_TYPES))))
def rbf_encode(d: Tensor, d_min: float = 0., d_max: float = 20., n: int = 16) -> Tensor:
    """
    Expand distances into Gaussian radial-basis features.

    `n` Gaussian centres are spaced evenly over [d_min, d_max] (both ends
    included) and all share the width (d_max - d_min) / n.

    Args:
        d: tensor of distances, any shape.
        d_min: position of the first centre.
        d_max: position of the last centre.
        n: number of centres, i.e. size of the appended feature axis.

    Returns:
        Tensor shaped like `d` with one trailing axis of size `n`; each entry
        is exp(-(d - centre)^2 / (2 * width^2)).
    """
    width = (d_max - d_min) / n
    mu = torch.linspace(d_min, d_max, n, device=d.device)
    # Broadcast every distance against every centre.
    offset = d.unsqueeze(-1) - mu
    return torch.exp(-(offset ** 2) / (2 * width ** 2))
def parse_pocket_from_pdb(pdb_path: str, radius: float = 10.0, ligand_center=None):
    """
    Parse protein residues from a PDB file, optionally cropped to a pocket.

    Only the first model of the file is read: the further models of a
    multi-model file describe the same residues again and would otherwise be
    emitted as duplicate entries.

    Args:
        pdb_path: path of the PDB file to read.
        radius: pocket cut-off in Å; a residue is kept when its CA atom is
            strictly closer than this to `ligand_center`.
        ligand_center: 3 coordinates (sequence, array or Tensor) of the pocket
            centre, or None to return every parsed residue without cropping.

    Returns:
        A list with one dict per residue, in file order, with keys:
            aa   - one-letter code ("X" for non-standard residue names)
            ca   - CA position, float tensor [3]
            n    - N position, float tensor [3]
            c    - C position, float tensor [3]
            cb   - CB position, float tensor [3]
            chi1 - always None (placeholder, not computed here)
        Hetero residues and residues without a CA atom are skipped.
        An empty list is returned when the file cannot be parsed.
    """
    from Bio.PDB import PDBParser

    parser = PDBParser(QUIET=True)
    try:
        structure = parser.get_structure("prot", pdb_path)
    except Exception:
        # Deliberately best-effort: callers treat an empty list as
        # "skip this complex", so any parser failure maps to [].
        return []

    def atom_xyz(res, atom_name):
        # Raises KeyError when the residue has no atom with that name.
        return torch.tensor(res[atom_name].get_vector().get_array(), dtype=torch.float)

    first_model = next(iter(structure), None)
    chains = first_model if first_model is not None else ()

    all_residues = []
    for chain in chains:
        for res in chain:
            if res.id[0] != " ":  # skip HETATM
                continue
            aa = AA_3TO1.get(res.resname.strip(), "X")
            try:
                ca = atom_xyz(res, "CA")
            except KeyError:
                # No CA atom: the residue cannot be placed, drop it.
                continue
            try:
                n_pos = atom_xyz(res, "N")
                c_pos = atom_xyz(res, "C")
            except KeyError:
                # If either backbone atom is missing, both collapse onto CA.
                n_pos = ca.clone()
                c_pos = ca.clone()
            try:
                cb = atom_xyz(res, "CB")
            except KeyError:
                # No CB atom present: fall back to the CA position.
                cb = ca.clone()
            all_residues.append({
                "aa": aa,
                "ca": ca,
                "n": n_pos,
                "c": c_pos,
                "cb": cb,
                "chi1": None,
            })

    if ligand_center is None or len(all_residues) == 0:
        return all_residues

    # Filter to pocket
    if isinstance(ligand_center, Tensor):
        lc = ligand_center
    else:
        lc = torch.tensor(ligand_center, dtype=torch.float)
    return [r for r in all_residues if (r["ca"] - lc).norm() < radius]
def featurise_pocket(residues: list, k_neighbors: int = 30) -> Optional[dict]:
    """
    Convert a list of pocket residues into GVP-ready tensors.

    Args:
        residues: list of dicts, each providing "aa" (one-letter code) and the
            float tensors "ca", "n", "c", "cb" of shape [3].
        k_neighbors: maximum number of nearest neighbours (by CA distance)
            each residue is connected to; capped at N - 1.

    Returns:
        None when `residues` is empty, otherwise a dict with:
            node_s:     [N, 27]   one-hot residue type (21) + 6 zeros reserved
                                  for sin/cos of phi, psi, omega
            node_v:     [N, 3, 3] unit vectors N→CA, CA→C, CA→CB
            edge_s:     [E, 20]   RBF of the CA-CA distance (16 + 4 centres)
            edge_v:     [E, 1, 3] unit displacement from source to target CA
            edge_index: [2, E]    (source, target) residue indices, E = N * k
            ca_pos:     [N, 3]
            n_residues: N
    """
    N = len(residues)
    if N == 0:
        return None

    # Node scalar features: one-hot AA (21) + backbone angles sines/cosines (6).
    # Unknown codes map to index 20, the "X" class.
    aa_idx = torch.tensor([AA_TO_IDX.get(r["aa"], 20) for r in residues], dtype=torch.long)
    aa_onehot = torch.nn.functional.one_hot(aa_idx, num_classes=21).float()
    # Backbone dihedrals are not computed here; the slots are left at zero.
    dihedrals = torch.zeros(N, 6)
    node_s = torch.cat([aa_onehot, dihedrals], dim=-1)  # [N, 27]

    # Node vector features: normalised backbone bond directions
    ca = torch.stack([r["ca"] for r in residues])  # [N, 3]
    n = torch.stack([r["n"] for r in residues])
    c = torch.stack([r["c"] for r in residues])
    cb = torch.stack([r["cb"] for r in residues])

    def safe_norm(v):
        # The clamp keeps zero-length vectors (atom collapsed onto CA) at zero
        # instead of producing NaN.
        return v / v.norm(dim=-1, keepdim=True).clamp(min=1e-8)

    n_to_ca = safe_norm(ca - n)
    ca_to_c = safe_norm(c - ca)
    ca_to_cb = safe_norm(cb - ca)
    node_v = torch.stack([n_to_ca, ca_to_c, ca_to_cb], dim=1)  # [N, 3, 3]

    # Build k-NN edges on CA distance.
    k = max(0, min(k_neighbors, N - 1))
    dist_mat = torch.cdist(ca, ca)  # [N, N]
    src = torch.arange(N).unsqueeze(1).expand(-1, k).reshape(-1)
    if k > 0:
        # Exclude self by masking the diagonal rather than dropping the first
        # top-k hit: with coincident CA coordinates or a not-exactly-zero
        # self-distance, "first hit" is not guaranteed to be the residue itself.
        masked = dist_mat.clone()
        masked.fill_diagonal_(float("inf"))
        dst = masked.topk(k, dim=1, largest=False).indices.reshape(-1)
    else:
        dst = torch.zeros(0, dtype=torch.long)
    edge_index = torch.stack([src, dst])  # [2, E]

    edge_dist = dist_mat[src, dst]  # [E]
    # Fine (16 centres) plus coarse (4 centres) RBF expansion of the distance.
    edge_s = torch.cat(
        [rbf_encode(edge_dist), rbf_encode(edge_dist, n=4)], dim=-1
    )  # [E, 20]

    rel_pos = ca[dst] - ca[src]  # [E, 3]
    disp = rel_pos / (edge_dist.unsqueeze(-1) + 1e-8)
    edge_v = disp.unsqueeze(1)  # [E, 1, 3]

    return dict(
        node_s=node_s,
        node_v=node_v,
        edge_s=edge_s,
        edge_v=edge_v,
        edge_index=edge_index,
        ca_pos=ca,
        n_residues=N,
    )
# ──────────────────────────────────────────────────────────────────────────────
# Ligand torsion graph
# ──────────────────────────────────────────────────────────────────────────────
def get_torsion_bonds(mol):
    """
    Return list of (i, j, k, l) atom indices defining rotatable bonds.

    A bond (j, k) is rotatable if it is a single bond, not in a ring,
    and both j and k have at least one further neighbour.  `i` is the first
    other neighbour of j and `l` the first other neighbour of k, in the order
    the molecule reports them.

    Args:
        mol: molecule object exposing the RDKit `Mol` accessors used below
            (GetBonds, GetAtomWithIdx, and the bond/atom getters).

    Returns:
        List of 4-tuples of atom indices, one per rotatable bond, in bond order.
    """
    # NOTE(review): neighbours are not filtered by element, so on a molecule
    # with explicit hydrogens a bond to e.g. a methyl carbon also qualifies.
    # Confirm this is intended for molecules loaded with removeHs=False.
    rot_bonds = []
    for bond in mol.GetBonds():
        if bond.IsInRing() or bond.GetBondTypeAsDouble() != 1.0:
            continue
        j = bond.GetBeginAtomIdx()
        k = bond.GetEndAtomIdx()
        # Need at least one neighbour on each side (excluding j↔k)
        j_nbrs = [a.GetIdx() for a in mol.GetAtomWithIdx(j).GetNeighbors() if a.GetIdx() != k]
        k_nbrs = [a.GetIdx() for a in mol.GetAtomWithIdx(k).GetNeighbors() if a.GetIdx() != j]
        if j_nbrs and k_nbrs:
            rot_bonds.append((j_nbrs[0], j, k, k_nbrs[0]))
    return rot_bonds
# ──────────────────────────────────────────────────────────────────────────────
# Dataset
# ──────────────────────────────────────────────────────────────────────────────
class PDBBindDataset(Dataset):
    """
    PDBbind v2020 dataset using the time-split convention.

    Each item corresponds to one protein-ligand complex stored as
    `<data_dir>/<id>/<id>_protein.pdb` and `<data_dir>/<id>/<id>_ligand.sdf`.
    The protein structure may be the crystal Holo state (for redocking)
    or an Apo/cross-dock structure.

    Args:
        data_dir: root directory holding one sub-directory per complex.
        split: key looked up in `split_file` (train / val / test).  It is
            ignored when `split_file` is None.
        split_file: JSON file mapping split name -> list of complex ids.
            When None, every sub-directory of `data_dir` is used, whatever
            `split` says.
        pocket_radius: CA-to-ligand-centre cut-off passed to the pocket parser.
        max_residues: pockets with more residues are cropped to the closest ones.
        max_atoms: stored on the instance; not enforced by this class.
        augment: stored on the instance; not used by this class.

    `__getitem__` returns None for a complex that cannot be loaded, so a
    collate function that filters None is needed downstream.
    """

    def __init__(
        self,
        data_dir: str,
        split: str = "train",  # train / val / test
        split_file: Optional[str] = None,
        pocket_radius: float = 10.0,
        max_residues: int = 80,
        max_atoms: int = 100,
        augment: bool = True,
    ):
        self.data_dir = data_dir
        self.pocket_radius = pocket_radius
        self.max_residues = max_residues
        self.max_atoms = max_atoms
        self.augment = augment

        # Load split indices
        if split_file is not None:
            with open(split_file) as f:
                self.ids = json.load(f)[split]
        else:
            # Fallback: list all subdirectories
            self.ids = [
                d for d in os.listdir(data_dir)
                if os.path.isdir(os.path.join(data_dir, d))
            ]
        self.ids = sorted(self.ids)

    def __len__(self):
        """Return the number of complex ids in the selected split."""
        return len(self.ids)

    def __getitem__(self, idx: int):
        """
        Load and featurise the complex at position `idx`.

        Returns:
            None when the protein/ligand files are missing, the ligand cannot
            be read or has no conformer, or no pocket residue is found.
            Otherwise a dict with `pdb_id`, the pocket features under
            `prot_*` keys, the ligand graph (`lig_atom_feats`,
            `lig_bond_feats`, `lig_edge_index`), ligand coordinates
            (`lig_pos_crystal` and an independent copy `lig_pos`), the
            rotatable-bond index `tor_edge_index` [2, d], `lig_torsions` [d]
            (zeros) and empty side-chain placeholders.

        Raises:
            ImportError: if RDKit could not be imported by this module.
        """
        pdb_id = self.ids[idx]
        base = os.path.join(self.data_dir, pdb_id)
        prot_path = os.path.join(base, f"{pdb_id}_protein.pdb")
        lig_path = os.path.join(base, f"{pdb_id}_ligand.sdf")
        if not (os.path.exists(prot_path) and os.path.exists(lig_path)):
            return None

        if not RDKIT_AVAILABLE:
            # Fail with an actionable message instead of a NameError on `Chem`.
            raise ImportError("RDKit is required to load ligands but could not be imported")

        # Load ligand.  An empty or unreadable SDF is skipped like any other
        # unusable complex rather than aborting the whole epoch.
        try:
            supplier = Chem.SDMolSupplier(lig_path, removeHs=False, sanitize=True)
            mol = supplier[0] if len(supplier) > 0 else None
        except (OSError, IndexError):
            return None
        if mol is None:
            return None
        try:
            conf = mol.GetConformer()
            lig_pos = torch.tensor(conf.GetPositions(), dtype=torch.float)
        except Exception:
            # Best effort: a ligand without usable 3D coordinates is skipped.
            return None
        ligand_center = lig_pos.mean(dim=0)

        # Load protein pocket
        residues = parse_pocket_from_pdb(prot_path, self.pocket_radius, ligand_center)
        if len(residues) == 0:
            return None
        if len(residues) > self.max_residues:
            # Keep the residues whose CA is closest to the ligand centre,
            # preserving their original order.
            ca_stack = torch.stack([r["ca"] for r in residues])
            dists = (ca_stack - ligand_center).norm(dim=-1)
            keep = dists.topk(self.max_residues, largest=False).indices.tolist()
            residues = [residues[i] for i in sorted(keep)]
        pocket_feats = featurise_pocket(residues)
        if pocket_feats is None:
            return None

        # Featurise ligand
        from models.ligand_encoder import mol_to_graph
        lig_graph = mol_to_graph(mol)

        # Get torsion bonds; only the two central atoms of each torsion are kept.
        rot_bonds = get_torsion_bonds(mol)
        if rot_bonds:
            tor_i = torch.tensor([b[1] for b in rot_bonds], dtype=torch.long)
            tor_j = torch.tensor([b[2] for b in rot_bonds], dtype=torch.long)
            tor_edge_index = torch.stack([tor_i, tor_j])  # [2, d]
            # Torsion angles are initialised to zero, not measured from the
            # crystal coordinates.
            lig_torsions = torch.zeros(len(rot_bonds))
        else:
            tor_edge_index = torch.zeros(2, 0, dtype=torch.long)
            lig_torsions = torch.zeros(0)

        return dict(
            pdb_id=pdb_id,
            # Protein
            **{f"prot_{k}": v for k, v in pocket_feats.items()},
            # Ligand
            lig_atom_feats=lig_graph["atom_feats"],
            lig_bond_feats=lig_graph["bond_feats"],
            lig_edge_index=lig_graph["edge_index"],
            lig_pos_crystal=lig_pos,
            lig_pos=lig_pos.clone(),
            # Torsions
            tor_edge_index=tor_edge_index,
            lig_torsions=lig_torsions,
            # Placeholders for sc torsions (populated per-residue)
            sc_torsions=torch.zeros(0),
            sc_residue_idx=torch.zeros(0, dtype=torch.long),
        )
class CDCrossDockDataset(Dataset):
    """
    CD cross-dock test set: pairs of (apo/holo receptor, holo ligand).

    The pairs are read either from a pre-built JSON file (any path ending in
    ".json", loaded as-is) or from a CSV file, which becomes one dict per row.
    The CSV is expected to carry the columns
        receptor_pdb, receptor_chain, ligand_pdb, ligand_chain
    although this class does not itself read or validate them.

    Args:
        pairs_file: path of the JSON or CSV file listing the cross-dock pairs.
        pdb_dir: directory holding the structures; stored on the instance.
        pocket_radius: stored on the instance.
        max_residues: stored on the instance.

    Item loading is not implemented: `__getitem__` raises NotImplementedError.
    """

    def __init__(
        self,
        pairs_file: str,
        pdb_dir: str,
        pocket_radius: float = 10.0,
        max_residues: int = 80,
    ):
        self.pdb_dir = pdb_dir
        self.pocket_radius = pocket_radius
        self.max_residues = max_residues

        if pairs_file.endswith(".json"):
            with open(pairs_file) as f:
                self.pairs = json.load(f)
        else:
            # pandas is only needed for the CSV route, so a JSON pairs file
            # works without it being installed.
            import pandas as pd
            self.pairs = pd.read_csv(pairs_file).to_dict("records")

    def __len__(self):
        """Return the number of cross-dock pairs."""
        return len(self.pairs)

    def __getitem__(self, idx: int):
        """
        Placeholder for loading the pair at position `idx`.

        Raises:
            IndexError: if `idx` is outside the list of pairs.
            NotImplementedError: always, for a valid `idx`; path resolution
                depends on the local layout of the CD test set.
        """
        # Look the pair up first so an invalid index still fails as IndexError.
        _ = self.pairs[idx]
        # Load receptor (apo or other-holo) + crystal ligand
        # Implementation mirrors PDBBindDataset but with explicit pair paths
        # (omitted for brevity – follows same featurization pipeline)
        raise NotImplementedError("Implement path resolution for your CD test set layout")