Paper | Video | Interactive interface
Implementation of Segment-Factorized Full-Song Generation on Symbolic Piano Music, 39th Conference on Neural Information Processing Systems (NeurIPS 2025) Workshop: AI for Music.
We highly recommend interacting with the model through the interactive interface linked above, where you can create music collaboratively with the model. This repository provides only a non-interactive command-line interface (CLI).
Install the package in editable mode:
pip install -e .
First, download the pretrained SFS model checkpoint from here and put it in the pretrained_ckpt directory.
To generate 2 samples with a given segment structure and composition order:
python generate.py --segments A4B8C8D8B8C8E8 --compose_order 2 0 1 3 4 5 6 -n 2
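The --segments string appears to describe the song form: each uppercase letter is a segment label (repeated letters, such as the second B and C, presumably mark recurring sections) and the number after it is that segment's length in bars. --compose_order then lists segment indices in the order the model generates them. The sketch below is a hypothetical helper, not part of this repository, and the grammar it parses is an assumption inferred from the example command; generate.py may accept other forms.

```python
import re

# Hypothetical grammar: one uppercase label letter followed by a bar count.
SEGMENT_PATTERN = re.compile(r"([A-Z])(\d+)")


def parse_segments(spec: str) -> list[tuple[str, int]]:
    """Split a string such as "A4B8C8" into [("A", 4), ("B", 8), ("C", 8)]."""
    pairs = SEGMENT_PATTERN.findall(spec)
    # Reject input containing characters the pattern did not consume.
    if "".join(label + bars for label, bars in pairs) != spec:
        raise ValueError(f"unrecognized segment string: {spec!r}")
    return [(label, int(bars)) for label, bars in pairs]


def order_segments(
    segments: list[tuple[str, int]], compose_order: list[int]
) -> list[tuple[int, str, int]]:
    """Return (index, label, bars) triples in the assumed generation order."""
    if sorted(compose_order) != list(range(len(segments))):
        raise ValueError("compose_order must be a permutation of segment indices")
    return [(i, *segments[i]) for i in compose_order]


if __name__ == "__main__":
    segments = parse_segments("A4B8C8D8B8C8E8")
    print("total bars:", sum(bars for _, bars in segments))
    for index, label, bars in order_segments(segments, [2, 0, 1, 3, 4, 5, 6]):
        print(f"segment {index}: {label} ({bars} bars)")
```

Under this reading, the example command describes a 52-bar song in which segment C (index 2) is composed first, followed by A, B, and then the remaining segments in order.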
To generate using a given seed MIDI:
python generate.py --segments A4B8C8D8B8C8E8 --compose_order 2 0 1 3 4 5 6 -n 2 --seed_midi <path/to/seed.mid>
The results are saved in the generated directory.
To specify the SFS model checkpoint, use the --ckpt argument.
To train the models from scratch, follow these steps:
Set up the dataset
Create the dataset/synced_midi directory and put all MIDI files in it. The dataset structure should look like this:
dataset/
└── synced_midi/
    ├── file1.mid
    ├── file2.mid
    └── ...
Run the following command to preprocess the dataset:
python process_dataset.py --num_processes <num processes>
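If your MIDI files are spread across nested folders, a small script can flatten them into dataset/synced_midi before you run the preprocessing command. This is a hypothetical helper, not part of the repository; the naming scheme and accepting both .mid and .midi suffixes are assumptions (the README only shows .mid files, so check what process_dataset.py actually reads).

```python
import shutil
from pathlib import Path

# Assumption: both suffixes are usable; the README only shows *.mid files.
MIDI_SUFFIXES = {".mid", ".midi"}


def collect_midi(src_dir, dest_dir="dataset/synced_midi"):
    """Copy every MIDI file under src_dir (recursively) into one flat folder.

    The relative path is folded into the file name (a/b/x.mid -> a_b_x.mid)
    so files with the same name in different folders do not overwrite each
    other. dest_dir should not be located inside src_dir.
    """
    src, dest = Path(src_dir), Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    copied = []
    # sorted() materializes the listing before any copy happens.
    for path in sorted(src.rglob("*")):
        if path.is_file() and path.suffix.lower() in MIDI_SUFFIXES:
            target = dest / "_".join(path.relative_to(src).parts)
            shutil.copy2(path, target)
            copied.append(target)
    return copied
```

For example, collect_midi("path/to/raw_midi") would fill dataset/synced_midi relative to the current working directory.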
Set up utilities for logging audio to wandb
Make sure fluidsynth and ffmpeg are installed. For fluidsynth to work, you need a soundfont file, and the SOUNDFONT_PATH environment variable must point to it.
export SOUNDFONT_PATH="<path to soundfont>"
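A missing tool or an unset SOUNDFONT_PATH may only surface once training tries to log audio, so it can help to check the prerequisites up front. The sketch below is a hypothetical standalone check, not part of the repository. It only verifies that the fluidsynth and ffmpeg executables are on PATH and that SOUNDFONT_PATH names an existing file; if the project talks to fluidsynth through a Python binding rather than the executable, the executable check is only a rough proxy.

```python
import os
import shutil
from pathlib import Path


def check_audio_logging_setup(env=None):
    """Return a list of problems; an empty list means the basic setup looks fine."""
    env = os.environ if env is None else env
    problems = []
    # Look for the command-line executables on PATH.
    for tool in ("fluidsynth", "ffmpeg"):
        if shutil.which(tool) is None:
            problems.append(f"{tool} not found on PATH")
    soundfont = env.get("SOUNDFONT_PATH")
    if not soundfont:
        problems.append("SOUNDFONT_PATH is not set")
    elif not Path(soundfont).is_file():
        problems.append(f"SOUNDFONT_PATH points to a missing file: {soundfont}")
    return problems


if __name__ == "__main__":
    issues = check_audio_logging_setup()
    print("\n".join(issues) if issues else "audio logging prerequisites found")
```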
Log in to wandb
wandb login
Train VAE embedder
python train.py config/model/vae.yaml config/dataset/tokens.yaml --num_workers <num workers>
Unwrap the VAE embedder checkpoint. This creates a .safetensors file from the .ckpt (Lightning module) file saved during training.
python unwrap_lightning_module.py wandb/<run_name>/files/checkpoints/<checkpoint_name>.ckpt
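A PyTorch Lightning .ckpt file keeps the trained weights under its "state_dict" key, with each parameter name prefixed by the attribute path of the submodule inside the LightningModule. Unwrapping therefore amounts to pulling out the inner module's weights and dropping that prefix. The sketch below illustrates the idea on a plain dict; it is not the repository's unwrap_lightning_module.py, and the "model." prefix is a hypothetical attribute name.

```python
from pathlib import Path


def unwrap_state_dict(checkpoint: dict, prefix: str = "model.") -> dict:
    """Return the weights stored under `prefix`, with the prefix removed.

    `checkpoint` is the dict loaded from a Lightning .ckpt file; its weights
    live in checkpoint["state_dict"]. Keys without the prefix (for example,
    other submodules) are dropped. The "model." default is hypothetical.
    """
    state_dict = checkpoint["state_dict"]
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


def safetensors_path_for(ckpt_path) -> Path:
    """Mirror the README's naming: <checkpoint_name>.ckpt -> <checkpoint_name>.safetensors."""
    return Path(ckpt_path).with_suffix(".safetensors")
```

With PyTorch and safetensors installed, the input dict would come from torch.load(ckpt_path, map_location="cpu") (recent PyTorch versions may require weights_only=False for Lightning checkpoints, which is only safe for files you trust), and the result could be written with safetensors.torch.save_file(weights, str(safetensors_path_for(ckpt_path))).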
Calculate embeddings for each bar of the dataset using the VAE embedder.
python embed.py config/model/vae.yaml config/dataset/tokens.yaml --ckpt_path wandb/<run_name>/files/checkpoints/<checkpoint_name>.safetensors --output_name bar_embedding
Train SFS model
python train.py config/model/segment_full_song.yaml config/dataset/segment_full_song.yaml --num_workers <num workers> --bar_embedder_ckpt_path wandb/<run_name>/files/checkpoints/<checkpoint_name>.safetensors
Unwrap the SFS model checkpoint.
python unwrap_lightning_module.py wandb/<run_name>/files/checkpoints/<checkpoint_name>.ckpt