Skip to content

Commit ac72152

Browse files
committed
feat: add input validation layer for early error detection
Add rfdiffusion/validation.py with validators for: - PDB file existence and ATOM record format - Contig string syntax (ranges, chain-residue specs) - Model checkpoint existence - Hotspot residue format (chain letter + number) - Diffuser config parameters (T, partial_T bounds) Validators are called in Sampler.initialize() and sample_init(), before GPU allocation and model loading, so users get clear error messages instead of cryptic tensor shape mismatches.
1 parent 9535f19 commit ac72152

2 files changed

Lines changed: 203 additions & 0 deletions

File tree

rfdiffusion/inference/model_runners.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@
1717
import string
1818

1919
from rfdiffusion.model_input_logger import pickle_function_call
20+
from rfdiffusion.validation import (
21+
validate_pdb_path,
22+
validate_checkpoint_path,
23+
validate_contig_string,
24+
validate_hotspot_res,
25+
validate_diffuser_config,
26+
)
2027
import sys
2128

2229
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
@@ -112,6 +119,14 @@ def initialize(self, conf: DictConfig) -> None:
112119
), "trb_save_ckpt_path is not the place to specify an input model. Specify in inference.ckpt_override_path"
113120
self._conf["inference"]["trb_save_ckpt_path"] = self.ckpt_path
114121

122+
# Validate inputs early, before GPU allocation and model loading
123+
validate_checkpoint_path(self.ckpt_path)
124+
if conf.inference.input_pdb is not None:
125+
validate_pdb_path(conf.inference.input_pdb)
126+
validate_diffuser_config(conf.diffuser)
127+
if conf.ppi.hotspot_res is not None:
128+
validate_hotspot_res(conf.ppi.hotspot_res)
129+
115130
#######################
116131
### Assemble Config ###
117132
#######################
@@ -313,6 +328,10 @@ def sample_init(self, return_forward_trajectory=False):
313328
### Generate specific contig ###
314329
################################
315330

331+
# Validate contig string before parsing
332+
if self.contig_conf.contigs is not None:
333+
validate_contig_string(self.contig_conf.contigs)
334+
316335
# Generate a specific contig from the range of possibilities specified at input
317336

318337
self.contig_map = self.construct_contig(self.target_feats)

rfdiffusion/validation.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
"""Input validation for RFdiffusion inference.
2+
3+
Catches common configuration and input errors early, before model loading
4+
and GPU allocation, so users get clear error messages instead of cryptic
5+
tensor shape mismatches deep in the forward pass.
6+
"""
7+
8+
import os
9+
import re
10+
import logging
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
class ValidationError(ValueError):
16+
"""Raised when input validation fails with a user-friendly message."""
17+
pass
18+
19+
20+
def validate_pdb_path(pdb_path: str) -> None:
21+
"""Validate that a PDB file exists and contains parseable ATOM records.
22+
23+
Args:
24+
pdb_path: Path to input PDB file.
25+
26+
Raises:
27+
ValidationError: If file doesn't exist or has no ATOM records.
28+
"""
29+
if not os.path.isfile(pdb_path):
30+
raise ValidationError(
31+
f"Input PDB file not found: {pdb_path}"
32+
)
33+
34+
has_atoms = False
35+
with open(pdb_path, "r") as f:
36+
for line in f:
37+
if line.startswith(("ATOM", "HETATM")) and len(line) >= 54:
38+
has_atoms = True
39+
try:
40+
float(line[30:38])
41+
float(line[38:46])
42+
float(line[46:54])
43+
except ValueError:
44+
raise ValidationError(
45+
f"Invalid coordinates in PDB line: {line.rstrip()}"
46+
)
47+
break
48+
49+
if not has_atoms:
50+
raise ValidationError(
51+
f"PDB file contains no ATOM/HETATM records: {pdb_path}"
52+
)
53+
54+
55+
def validate_contig_string(contigs: list) -> None:
56+
"""Validate contig string syntax before parsing.
57+
58+
Args:
59+
contigs: List of contig specification strings.
60+
61+
Raises:
62+
ValidationError: If contig syntax is invalid.
63+
"""
64+
if not contigs or not isinstance(contigs, (list, tuple)):
65+
raise ValidationError(
66+
"contigs must be a non-empty list of strings. "
67+
"Example: ['10-20/A5-50/0 30-40']"
68+
)
69+
70+
contig_str = contigs[0]
71+
if not isinstance(contig_str, str) or not contig_str.strip():
72+
raise ValidationError(
73+
f"Contig string must be a non-empty string, got: {contig_str!r}"
74+
)
75+
76+
for segment in contig_str.strip().split():
77+
for part in segment.split("/"):
78+
part = part.strip()
79+
if not part:
80+
continue
81+
# Chain break marker
82+
if part == "0":
83+
continue
84+
# Numeric range: "10-20" or "10"
85+
if part[0].isdigit():
86+
if "-" in part:
87+
pieces = part.split("-")
88+
if len(pieces) != 2:
89+
raise ValidationError(
90+
f"Invalid contig range format: '{part}'. "
91+
f"Expected 'N-M' (e.g., '10-20')."
92+
)
93+
try:
94+
lo, hi = int(pieces[0]), int(pieces[1])
95+
except ValueError:
96+
raise ValidationError(
97+
f"Non-integer values in contig range: '{part}'"
98+
)
99+
if lo < 0 or hi < 0:
100+
raise ValidationError(
101+
f"Negative value in contig range: '{part}'"
102+
)
103+
if lo > hi:
104+
raise ValidationError(
105+
f"Invalid contig range: '{part}' (start > end)"
106+
)
107+
# Chain-residue range: "A5-50" or "A5"
108+
elif part[0].isalpha():
109+
if not re.match(r"^[A-Za-z]\d+(-\d+)?$", part):
110+
logger.warning(f"Unusual contig segment: '{part}'")
111+
112+
113+
def validate_checkpoint_path(ckpt_path: str) -> None:
114+
"""Validate that a model checkpoint file exists.
115+
116+
Args:
117+
ckpt_path: Path to model checkpoint.
118+
119+
Raises:
120+
ValidationError: If checkpoint file doesn't exist.
121+
"""
122+
if not os.path.isfile(ckpt_path):
123+
raise ValidationError(
124+
f"Model checkpoint not found: {ckpt_path}. "
125+
f"Please download models following the README instructions."
126+
)
127+
128+
129+
def validate_hotspot_res(hotspot_res: list) -> None:
130+
"""Validate hotspot residue format (e.g., ['A50', 'B123']).
131+
132+
Args:
133+
hotspot_res: List of hotspot residue strings.
134+
135+
Raises:
136+
ValidationError: If format is invalid.
137+
"""
138+
if hotspot_res is None:
139+
return
140+
141+
for res in hotspot_res:
142+
if not isinstance(res, str) or len(res) < 2:
143+
raise ValidationError(
144+
f"Invalid hotspot residue format: {res!r}. "
145+
f"Expected format like 'A50' (chain letter + residue number)."
146+
)
147+
if not res[0].isalpha():
148+
raise ValidationError(
149+
f"Hotspot residue must start with a chain letter: {res!r}"
150+
)
151+
try:
152+
int(res[1:])
153+
except ValueError:
154+
raise ValidationError(
155+
f"Hotspot residue number must be an integer: {res!r}"
156+
)
157+
158+
159+
def validate_diffuser_config(diffuser_conf) -> None:
160+
"""Validate diffuser configuration parameters.
161+
162+
Args:
163+
diffuser_conf: Diffuser configuration object.
164+
165+
Raises:
166+
ValidationError: If parameters are out of valid range.
167+
"""
168+
T = getattr(diffuser_conf, "T", None)
169+
partial_T = getattr(diffuser_conf, "partial_T", None)
170+
171+
if T is not None and T < 1:
172+
raise ValidationError(
173+
f"diffuser.T must be >= 1, got {T}"
174+
)
175+
if partial_T is not None:
176+
if partial_T < 1:
177+
raise ValidationError(
178+
f"diffuser.partial_T must be >= 1, got {partial_T}"
179+
)
180+
if T is not None and partial_T > T:
181+
raise ValidationError(
182+
f"diffuser.partial_T ({partial_T}) cannot exceed "
183+
f"diffuser.T ({T})"
184+
)

0 commit comments

Comments
 (0)