
fms-ehrs-flwr

This Flower app performs federated training of a foundation model (FM) on tokenized electronic health record (EHR) data.

Install

git clone https://github.com/bbj-lab/fms-ehrs-flwr
cd fms-ehrs-flwr
mkdir -p logs
# pip install uv
uv venv --python=$(which python3) venv
source venv/bin/activate
uv pip install --torch-backend=cu128 --link-mode=copy -e .

Interactive run

You can develop code on a single GPU using the gpudev partition, which generally has good availability. (This is not efficient, so use it primarily for troubleshooting and debugging.)

systemd-run --scope --user tmux new -s gpuq
srun -p gpudev \
  --gres=gpu:1 \
  --cpus-per-task=3 \
  --time=8:00:00 \
  --job-name=flwr \
  --pty bash -i

source venv/bin/activate
flwr run . minimal-gpudev 2>&1 | tee logs/${SLURM_JOB_ID}-flwr.stddout

Slurm

There's a second configuration that runs on 3 GPUs on the gpuq partition.

jid=$(sbatch --parsable slurm.sh)
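For reference, a sketch of what slurm.sh might contain, extrapolated from the 3-GPU gpuq setup and the interactive commands above (the actual script in the repo may differ; the CPU count and time limit here are assumptions):

```shell
#!/bin/bash
#SBATCH --partition=gpuq
#SBATCH --gres=gpu:3
#SBATCH --cpus-per-task=9      # assumed: 3 CPUs per GPU, as in the interactive run
#SBATCH --time=24:00:00        # assumed time limit
#SBATCH --job-name=flwr
#SBATCH --output=logs/%j-flwr.stdout

source venv/bin/activate
# the federation name for the gpuq setup is defined in pyproject.toml
flwr run . 2>&1 | tee logs/${SLURM_JOB_ID}-flwr.stdout
```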

Configuration

The pyproject.toml file provides some configurable options for our app:

[tool.flwr.app.config]
    data-version                = "W++" # the version of the data we're using
    gradient-accumulation-steps = 2 # waiting to aggregate gradients is a multiplier on the effective batch size
    home-dir                    = "/gpfs/data/bbj-lab/users/burkh4rt" # parent of "data-mimic" and "data-ucmc"
    local-epochs                = 1 # how long to let individual workers run before aggregating
    lr                          = 2e-4 # learning rate
    max-seq-length              = 1024
    model-dir                   = "mdls" # where in the home-dir to store trained models
    num-server-rounds           = 10 # number of federated training rounds
    per-device-eval-batch-size  = 4
    per-device-train-batch-size = 4
    setup-version               = "1e_4b_10r_0002lr"
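These options interact: gradient-accumulation-steps multiplies the effective batch size, as does the number of GPUs. A quick illustration using the values above (the 3-GPU count comes from the gpuq Slurm setup, not from pyproject.toml):

```python
# Effective batch size implied by the configuration above.
per_device_train_batch_size = 4
gradient_accumulation_steps = 2
num_gpus = 3  # from the gpuq configuration; gpudev uses 1

effective_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * num_gpus
)
print(effective_batch_size)  # 24
```

So raising gradient-accumulation-steps trades memory for a larger effective batch without changing per-device-train-batch-size.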

Outputs

Logs from the runs are placed into the logs directory. Each one should record all configuration settings used for the run, for example,

[2025-10-16T14:02:01-0500] context.run_config={
    'data-version': 'W++',
    'gradient-accumulation-steps': 2,
    'home-dir': '/gpfs/data/bbj-lab/users/burkh4rt',
    'local-epochs': 1,
    'lr': 0.0002,
    'max-seq-length': 1024,
    'model-dir': 'mdls',
    'num-server-rounds': 10,
    'per-device-eval-batch-size': 4,
    'per-device-train-batch-size': 4,
    'setup-version': '1e_4b_10r_0002lr'
    }

and a summary of performance on the evaluation set after each round:

[SUMMARY]
Run finished 3 round(s) in 9800.81s
History (loss, distributed):
    round 1: 0.931810670185353
    round 2: 0.8875063137023168
    round 3: 0.881442553112436
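If the distributed loss reported here is a mean token-level cross-entropy (an assumption about the training objective, not something the summary states), exponentiating it gives a perplexity that can be easier to compare across runs:

```python
import math

# Round-wise losses copied from the summary above.
losses = {1: 0.931810670185353, 2: 0.8875063137023168, 3: 0.881442553112436}

# Perplexity = exp(mean cross-entropy), assuming the loss is token-level NLL.
for rnd, loss in losses.items():
    print(f"round {rnd}: perplexity = {math.exp(loss):.3f}")
```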

We've slightly modified flwr's implementation of FedAvg to save a copy of the aggregated model after each training round. Currently, these are saved under the model-dir subdirectory (from the configuration) of the home-dir folder.
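For context, FedAvg's aggregation step is just a data-weighted average of each client's parameters. A minimal NumPy sketch of the idea (not the repo's actual modified implementation):

```python
import numpy as np

def fedavg(client_params, num_examples):
    """Data-weighted average of per-client parameter lists (FedAvg)."""
    weights = np.asarray(num_examples) / sum(num_examples)
    return [
        sum(w * layer for w, layer in zip(weights, layers))
        for layers in zip(*client_params)  # iterate layer-by-layer
    ]

# Two clients, one "layer" each, with client 2 holding 3x the data.
a = [np.array([0.0, 0.0])]
b = [np.array([4.0, 8.0])]
print(fedavg([a, b], [1, 3]))  # [array([3., 6.])]
```

The modification described above would hook in after this averaging step, serializing the aggregated parameters to disk before the next round begins.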

Monitoring

Running nvtop on the node running the job (srun --jobid=$jid --pty nvtop) shows live per-GPU utilization and memory usage for the run.
