Keywords: Atrial fibrillation, ESUS, Hypergraph learning, Pre-training, Transfer learning
This repo explores Atrial Fibrillation (AF) prediction in ESUS patients using pre-training + transfer learning on hypergraphs. We first learn patient representations on a large stroke cohort (AI-RESPECT, n = 7,780), then transfer compact embeddings to a smaller ESUS cohort (n = 510) for downstream AF prediction with lightweight ML models.
- Goal: Predict whether an ESUS patient will develop AF (binary classification).
- Challenges: small target cohort, high-dimensional diagnostic features (ICD), risk of overfitting.
- Inputs:
  - Baseline clinical features (58 dims)
  - Diagnostic features (ICD; up to 1,529 dims for ESUS)
- Output: AF risk (0/1)
We represent patient data as a hypergraph: nodes = diagnostic features; hyperedges = patient visits/encounters. We then learn hyperedge (patient) embeddings and transfer them to ESUS.
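To make the construction concrete, here is a minimal sketch that builds a feature-by-patient incidence matrix; the ICD-style codes are made-up placeholders, not cohort data:

```python
import numpy as np

# Toy records: each patient (hyperedge) is a set of diagnostic
# features (nodes). Codes here are placeholders, not real cohort data.
visits = {
    "patient_1": {"I48.0", "I10", "E11.9"},
    "patient_2": {"I10", "E78.5"},
    "patient_3": {"I48.0", "E78.5", "E11.9"},
}

# Index nodes (diagnostic features) consistently across all patients.
nodes = sorted(set().union(*visits.values()))
node_idx = {code: i for i, code in enumerate(nodes)}

# Incidence matrix H: H[i, j] = 1 iff feature i appears in patient j's record.
H = np.zeros((len(nodes), len(visits)), dtype=np.int8)
for j, codes in enumerate(visits.values()):
    for code in codes:
        H[node_idx[code], j] = 1
```

Each column of `H` is one hyperedge (patient); each row, one diagnostic feature shared across patients.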
Concatenate clinical + diagnostic features:

x_i = x_{i,b} ⊕ x_{i,d}

- Trains directly on ESUS; simple but prone to overfitting in small-N, high-D settings.
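A minimal sketch of this concatenation, using the README's dimensions with synthetic arrays:

```python
import numpy as np

rng = np.random.default_rng(0)
n_patients = 4

# Placeholder dimensions from the README: 58 baseline clinical features
# and up to 1,529 diagnostic (ICD/medication) features for ESUS.
x_baseline = rng.normal(size=(n_patients, 58))
x_diagnostic = rng.integers(0, 2, size=(n_patients, 1529)).astype(float)

# x_i = x_{i,b} (+) x_{i,d}: concatenate along the feature axis.
x = np.concatenate([x_baseline, x_diagnostic], axis=1)
```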
Pre-train a hypergraph transformer on AI-RESPECT (labeled PSCI task) and transfer:
- Learn final hyperedge embedding as a 32-D patient vector.
- Build ESUS features as

  x_i = x_{i,tr} ⊕ x_{i,b}

  and train AF classifiers (LR/RF/GB).
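The downstream step can be sketched with scikit-learn as below; the data, embedding values, and hyperparameters are synthetic placeholders, not the repo's actual pipeline:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-ins: 58-D clinical features with a binary label,
# plus a 32-D transferred patient embedding.
rng = np.random.default_rng(0)
X_clin, y = make_classification(n_samples=300, n_features=58, random_state=0)
X_emb = rng.normal(size=(300, 32))

# x_i = x_{i,tr} (+) x_{i,b}
X = np.concatenate([X_emb, X_clin], axis=1)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, stratify=y, random_state=0)

models = {
    "LR": make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000)),
    "RF": RandomForestClassifier(n_estimators=50, random_state=0),
    "GB": GradientBoostingClassifier(random_state=0),
}
aucs = {}
for name, model in models.items():
    model.fit(X_tr, y_tr)
    aucs[name] = roc_auc_score(y_te, model.predict_proba(X_te)[:, 1])
```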
Pre-train on AI-RESPECT without labels via two components, then transfer:
- Hypergraph View Augmentation (genSim):
  - Node masking biased by duplication; hyperedge selection via Gumbel-Softmax.
  - Consistency objectives across two augmented views: L_genSim = L_hyper + L_sim.
- Triplet Contrastive Learning (Trip):
  - Node-level, hyperedge-level, and membership-level contrasts across augmented graphs.
- Total loss: L_total = L_genSim + L_n + L_e + L_m (equal weights).
- Extract a 32-D patient embedding and concatenate with clinical features as above.
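As an illustration of the equal-weight combination only (not the repo's actual loss functions), a numpy sketch with dummy view embeddings and a margin-based triplet contrast:

```python
import numpy as np

rng = np.random.default_rng(0)

def triplet_loss(anchor, positive, negative, margin=1.0):
    """Margin-based triplet loss: pull each anchor toward its positive
    (same node/edge in the other augmented view), push it from a negative."""
    d_pos = np.linalg.norm(anchor - positive, axis=1)
    d_neg = np.linalg.norm(anchor - negative, axis=1)
    return np.maximum(0.0, d_pos - d_neg + margin).mean()

# Dummy 32-D embeddings from two augmented views, plus negatives.
z1, z2 = rng.normal(size=(2, 16, 32))
neg = rng.normal(size=(16, 32))

# Cross-view consistency term (stand-in for L_genSim = L_hyper + L_sim).
l_gensim = float(np.mean((z1 - z2) ** 2))

# Node-, hyperedge-, and membership-level contrasts (same functional form here).
l_n = triplet_loss(z1, z2, neg)
l_e = triplet_loss(z2, z1, neg)
l_m = triplet_loss(z1, z2, neg[::-1])

# Equal-weight combination, as stated in the README.
l_total = l_gensim + l_n + l_e + l_m
```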
Overview of our proposed framework for AF prediction in ESUS patients
Hypergraphs naturally capture many-to-many relations between features and visits, enabling attention-based message passing:
- Within-hyperedge (V→E) and within-node (E→V) self-attention propagate information between features and patient encounters.
- Produces compact, expressive patient embeddings for downstream AF prediction.
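A toy numpy sketch of the two-stage aggregation; the scalar scoring vector `w` stands in for learned attention parameters, and this is far simpler than actual transformer layers:

```python
import numpy as np

def masked_softmax(scores, mask):
    """Softmax over the last axis, restricted to positions where mask == 1."""
    scores = np.where(mask.astype(bool), scores, -1e9)
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    e = e * mask
    return e / np.clip(e.sum(axis=-1, keepdims=True), 1e-12, None)

rng = np.random.default_rng(0)
n_nodes, n_edges, d = 6, 3, 8
X = rng.normal(size=(n_nodes, d))                          # node embeddings
H = (rng.random((n_nodes, n_edges)) < 0.5).astype(float)   # incidence matrix
H[:, H.sum(axis=0) == 0] = 1.0                             # avoid empty edges

w = rng.normal(size=(d,))  # stand-in for learned attention parameters

# V -> E: each hyperedge attends over its member nodes.
node_scores = X @ w
A_ve = masked_softmax(np.tile(node_scores, (n_edges, 1)), H.T)  # (E, V)
E = A_ve @ X                                                    # edge embeddings

# E -> V: each node attends over the hyperedges it belongs to.
edge_scores = E @ w
A_ev = masked_softmax(np.tile(edge_scores, (n_nodes, 1)), H)    # (V, E)
X_new = A_ev @ E                                                # updated nodes
```

The final V→E pass yields the per-patient (hyperedge) embeddings used downstream.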
This project leverages electronic health record (EHR) data from the Emory Healthcare System, combining two cohorts for pre-training and transfer learning:
- ESUS Dataset (Target Cohort)
  - 510 patients diagnosed with Embolic Stroke of Undetermined Source (ESUS) between Jan 1, 2015 – Dec 13, 2023
  - 107 developed post-stroke AF as a first occurrence
  - Inclusion criteria: ≥18 years old, no prior stroke within 5 years before 2015, and no history of AF before the index stroke
  - Features:
    - 58 baseline clinical variables: demographics, biomarkers, echocardiographic and ECG features, comorbidities
    - 1,529 diagnostic features: 990 ICD-based + 539 medication-related
- AI-RESPECT Dataset (Pre-training Cohort)
  - 7,780 stroke patients diagnosed between Jan 1, 2012 – Dec 31, 2021
  - 1,735 developed post-stroke cognitive impairment (PSCI)
  - Inclusion criteria: stroke diagnosis with no prior history of cognitive impairment
  - Features: 2,595 diagnostic features with broader coverage of stroke-related medical history
Across both datasets, 1,494 diagnostic features overlap, covering 97.7% of ESUS diagnostic features, enabling effective pre-training transfer.
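The 97.7% coverage figure follows from a simple set intersection; the synthetic feature IDs below are sized only to reproduce the stated counts:

```python
# Hypothetical feature-ID sets standing in for the two cohorts' vocabularies
# (1,529 ESUS features, 2,595 AI-RESPECT features, 1,494 shared).
esus_features = {f"feat_{i}" for i in range(1529)}
respect_features = {f"feat_{i}" for i in range(35, 35 + 2595)}

shared = esus_features & respect_features
coverage = len(shared) / len(esus_features)  # fraction of ESUS features covered
```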
⚠️ Note: Both ESUS and AI-RESPECT datasets are institutional EHR data and not publicly available due to patient privacy. Access requires proper institutional approvals and IRB clearance.
- `assets/`: Supporting materials for the project.
- `downstream_predicting/`: Scripts for downstream ML training and validation.
- `supervised_pretraining/`: Code for supervised representation learning on hypergraphs.
  - `src/`: Source code for supervised training (layers, models, preprocessing, training).
  - `outputs/`: Results and pretrained supervised embeddings.
- `unsupervised_pretraining/`: Code for unsupervised representation learning on hypergraphs.
  - `data/`: Dataset folders for unsupervised experiments.
    - `pyg_data/`: Processed data for PyG (PyTorch Geometric).
    - `raw_data/`: Raw input data.
  - `src/`: Source code for unsupervised hypergraph pretraining.
  - `edge_representation/`: Output folder for edge-level embeddings.
  - `outputs/`: Results and embeddings from unsupervised pretraining.
  - `scripts/`: Preprocessing, hypergraph construction, augmentation, and training scripts.
- `requirements.txt`: Python dependency list.
Navigate to the `unsupervised_pretraining/scripts/` folder to generate initial hyperedge features:
- Script: `gen_feat.sh`
- Description: Runs random walks + Word2Vec to produce node embeddings and hyperedge feature representations.
After generating features, compute overlapness and homogeneity metrics:
- Script: `gen_overlap_homogeneity.sh`
- Description: Generates `overlapness` and `homogeneity` files for each dataset.
- Note: Place the generated files in the same folder as the corresponding dataset.
Once the overlapness and homogeneity files are ready, run the training scripts to obtain low-dimensional hyperedge feature representations:
- Scripts:
  - `spicd3.sh` – train on the separate ICD-3 dataset
  - `spicd4.sh` – train on the separate ICD-4 dataset
  - `cbicd3.sh` – train on the combined ICD-3 dataset
  - `cbicd4.sh` – train on the combined ICD-4 dataset
- Description: Each script calls the transfer learning framework to pretrain and fine-tune on hypergraph datasets, producing compressed hyperedge feature representations.
As in the unsupervised setting, begin by generating features for each dataset.
Navigate to the `supervised_pretraining` directory and run `train.py` to obtain low-dimensional feature representations for hyperedges.
Two complementary runners are provided:
Script: `/downstream_predicting/External_validation.py`

- Train on main cohort (pick one embedding set per run):

```bash
# Supervised embedding
python External_validation.py --train-main --main-baseline data/main/baseline.csv --embed-supervised data/main/supervised.csv --models-dir outputs/main/models --figs-dir outputs/main/figs

# Unsupervised embedding(s) (CSV or NPY; if multiple are given, the first is used for training)
python External_validation.py --train-main --main-baseline data/main/baseline.csv --embed-unsupervised data/main/unsup_a.csv data/main/unsup_b.csv --models-dir outputs/main/models --figs-dir outputs/main/figs
```

- Predict on external cohort with saved models:

```bash
python External_validation.py --predict-external --external-baseline data/external/baseline.csv --external-embed-unsupervised data/external/unsup_a.npy data/external/unsup_b.npy --models-dir outputs/main/models --figs-dir outputs/external/figs
```

- Train + external validation in one pass:

```bash
python External_validation.py --train-main --predict-external --main-baseline data/main/baseline.csv --embed-supervised data/main/supervised.csv --external-baseline data/external/baseline.csv --external-embed-supervised data/external/supervised.csv --models-dir outputs/main/models --figs-dir outputs/external/figs
```

Outputs:
- `outputs/main/models/{LR,RF,GB}.pkl`: full trained pipelines
- `outputs/external/figs/<dataset>.png`: overlaid ROC curves
Script: `/downstream_predicting/ML_prediction.py`

```bash
python ML_prediction.py --baseline data/main/baseline.csv --supervised data/main/supervised.csv --unsupervised data/main/unsup_a.csv data/main/unsup_b.csv --results outputs/cv/results.csv --outer-splits 5 --inner-splits 3 --seed 42
```

Outputs:
- `outputs/cv/results.csv` with rows:

```
dataset,model,auc_mean,auc_std,f1_mean,f1_std
```
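A small stdlib sketch for summarizing such a results file, e.g. picking the best model per dataset by mean AUC; the rows below are placeholder values, not real experiment output:

```python
import csv
import io

# Synthetic rows in the format written by ML_prediction.py (placeholder values).
raw = """dataset,model,auc_mean,auc_std,f1_mean,f1_std
supervised,LR,0.71,0.04,0.45,0.06
supervised,RF,0.74,0.03,0.48,0.05
unsup_a,LR,0.69,0.05,0.41,0.07
unsup_a,GB,0.73,0.04,0.47,0.05
"""

# Keep, for each dataset, the row with the highest mean AUC.
best = {}
for row in csv.DictReader(io.StringIO(raw)):
    key = row["dataset"]
    if key not in best or float(row["auc_mean"]) > float(best[key]["auc_mean"]):
        best[key] = row
```

For a real run, replace `io.StringIO(raw)` with `open("outputs/cv/results.csv")`.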
Notes:
- A `target` column must be present in the baseline CSV.
- Option A (`External_validation.py`) accepts unsupervised embeddings as CSV or NPY; Option B (`ML_prediction.py`) expects CSV.
- To compare multiple embeddings with Option A, run multiple times with different `--models-dir` values.
