This Flower app performs federated training of a foundation model (FM) on tokenized EHR data.
git clone https://github.com/bbj-lab/fms-ehrs-flwr
cd fms-ehrs-flwr
mkdir -p logs
# pip install uv
uv venv --python=$(which python3) venv
source venv/bin/activate
uv pip install --torch-backend=cu128 --link-mode=copy -e .

You can develop code on a single GPU using the gpudev partition, which
generally has good availability. (This is not efficient, so I would use it
primarily for troubleshooting/debugging.)
systemd-run --scope --user tmux new -s gpuq
srun -p gpudev \
--gres=gpu:1 \
--cpus-per-task=3 \
--time=8:00:00 \
--job-name=flwr \
--pty bash -i
source venv/bin/activate
flwr run . minimal-gpudev 2>&1 | tee logs/${SLURM_JOB_ID}-flwr.stdout

There's a second configuration that runs 3 GPUs on the gpuq partition:

jid=$(sbatch --parsable slurm.sh)

The pyproject.toml file provides some configurable options for our app:
[tool.flwr.app.config]
data-version = "W++" # the version of the data we're using
gradient-accumulation-steps = 2 # waiting to aggregate gradients is a multiplier on the effective batch size
home-dir = "/gpfs/data/bbj-lab/users/burkh4rt" # parent of "data-mimic" and "data-ucmc"
local-epochs = 1 # how long to let individual workers run before aggregating
lr = 2e-4 # learning rate
max-seq-length = 1024
model-dir = "mdls" # where in the home-dir to store trained models
num-server-rounds = 10 # number of federated training rounds to run
per-device-eval-batch-size = 4
per-device-train-batch-size = 4
setup-version = "1e_4b_10r_0002lr"

Logs from the runs are placed into the logs directory. Each one should record
all configuration settings used for the run, for example,
[2025-10-16T14:02:01-0500] context.run_config={
'data-version': 'W++',
'gradient-accumulation-steps': 2,
'home-dir': '/gpfs/data/bbj-lab/users/burkh4rt',
'local-epochs': 1,
'lr': 0.0002,
'max-seq-length': 1024,
'model-dir': 'mdls',
'num-server-rounds': 10,
'per-device-eval-batch-size': 4,
'per-device-train-batch-size': 4,
'setup-version': '1e_4b_10r_0002lr'
}
and a summary of performance on the evaluation set after each round:
[SUMMARY]
Run finished 3 round(s) in 9800.81s
History (loss, distributed):
round 1: 0.931810670185353
round 2: 0.8875063137023168
round 3: 0.881442553112436
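As the comment on gradient-accumulation-steps notes, waiting to aggregate gradients multiplies the effective batch size. A quick sketch of what that implies for the settings above (values copied from the config; this is illustrative arithmetic, not code from the app):

```python
# values from [tool.flwr.app.config]
per_device_train_batch_size = 4
gradient_accumulation_steps = 2

# each client accumulates gradients over several micro-batches before
# stepping the optimizer, multiplying the effective batch size
effective_batch_per_client = (
    per_device_train_batch_size * gradient_accumulation_steps
)
print(effective_batch_per_client)  # 8
```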
We've slightly modified flwr's implementation of
FedAvg to save a copy of the aggregated model
after each training round. These checkpoints are currently written to the
model-dir entry from our configuration, inside the home-dir folder.
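For reference, FedAvg combines client updates by a weighted average over each client's number of local training examples. A minimal pure-Python sketch of that aggregation (illustrative only; the app's actual code subclasses flwr's FedAvg strategy and works on model tensors):

```python
def fedavg(client_params, num_examples):
    """Weighted average of per-client parameter lists, weighted by
    the number of examples each client trained on (FedAvg)."""
    total = sum(num_examples)
    n_params = len(client_params[0])
    return [
        sum(p[i] * n for p, n in zip(client_params, num_examples)) / total
        for i in range(n_params)
    ]

# two clients with two scalar "parameters" each; client 2 has 3x the data,
# so the average is pulled toward its values
aggregated = fedavg([[1.0, 2.0], [3.0, 4.0]], [1, 3])
print(aggregated)  # [2.5, 3.5]
```

Our modification simply writes the aggregated parameters out to model-dir after each such round.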
Running nvtop on the node where the job is running
(srun --jobid=$jid --pty nvtop) should show you something like this:
