xfetusEDM2 provides tools for training and evaluating EDM2 diffusion models on the Open Fetal Planes Ultrasound dataset. The repository includes utilities for training models, generating synthetic images, and evaluating performance using FID metrics.
## 🔩 Installation
- Install NVIDIA drivers. Ensure your system has compatible NVIDIA drivers installed:

```bash
sudo apt install nvidia-driver-550  # update the NVIDIA driver
sudo reboot                         # reboot if on a local machine
```

- Check the supported PyTorch and CUDA versions at https://pypi.org/project/torch/#history (a runtime sanity check is sketched later in this list):
  - ⚠️ PyTorch 2.11.0 (released Mar 23, 2026): CUDA 12.6, CUDA 12.8, CUDA 13.0 (stable)
  - ⚠️ PyTorch 2.10.0 (released Jan 21, 2026): CUDA 12.6, CUDA 12.8
- Create a Python environment (using uv):

```bash
wget -qO- https://astral.sh/uv/install.sh | sh  # install uv
uv venv --python 3.11        # create a virtual environment at .venv
source .venv/bin/activate    # activate the virtual environment
uv sync --extra test --extra learning
uv pip list --verbose        # check installed versions
```

- Launch Jupyter locally:

```bash
uv run jupyter notebook
```
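- Optionally, sanity-check the GPU setup. The sketch below is a minimal check (not part of the repo) that the installed PyTorch build matches a working CUDA driver:

```python
import torch

# Report the installed PyTorch build and the CUDA version it was compiled against.
print("PyTorch:", torch.__version__)
print("CUDA (build):", torch.version.cuda)

# Confirm the driver actually exposes a usable GPU.
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
else:
    print("No CUDA device visible -- check the driver installation.")
```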
- Set up pre-commit hooks:

```bash
# generate the secrets baseline file
mkdir -p .github
detect-secrets scan > .github/.secrets.baseline

# run the pre-commit hooks
uv run pre-commit run -a
```

## 🏋️ Training

To train EDM2 on the fetal planes dataset, first download the dataset from https://zenodo.org/records/3904280 as shown in `data`.
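Before training, you can sanity-check the download by inspecting the metadata CSV that ships with the archive. This is an illustrative sketch: the semicolon delimiter and the `Plane` column are assumptions based on `FETAL_PLANES_DB_data.csv`, so verify them against your copy.

```python
import pandas as pd

# Load the metadata CSV from the extracted Zenodo archive.
# ASSUMPTION: semicolon-delimited file with a `Plane` column; adjust
# `sep` and the column names to match your downloaded copy.
df = pd.read_csv("FETAL_PLANES_DB/FETAL_PLANES_DB_data.csv", sep=";")

print(df.columns.tolist())          # available metadata fields
print(df["Plane"].value_counts())   # image counts per fetal-plane class
```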
Then, to train an xxs- or s-sized model (`edm2-img512-xxs`, `edm2-img512-s`; the presets follow EDM2's ImageNet-512 naming) on 1 or 8 GPUs (set `--nproc_per_node` accordingly; `--batch-gpu` caps the per-GPU batch size), run, for example, the following command from the root directory of this repo:
```bash
source .venv/bin/activate  # activate the virtual environment
torchrun --standalone --nproc_per_node=1 train_edm2.py \
    --outdir ~/scratch-volume/FETAL_PLANES_DB/OUTPUT_DIRECTORY \
    --data ~/scratch-volume/FETAL_PLANES_DB \
    --batch 1 \
    --preset edm2-img512-xxs \
    --batch-gpu=1
```

where `--data` points to `DATASET_LOCATION`, the root directory of the downloaded fetal planes dataset (here `~/scratch-volume/FETAL_PLANES_DB`), and `--outdir` points to `OUTPUT_DIRECTORY`, the location where `log.txt`, `stats.jsonl`, and the model checkpoints (e.g. `network-snapshot-0000000-0.050.pkl`) are saved.
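Training progress can be tracked through `stats.jsonl`, which accumulates one JSON record per logging interval. Below is a minimal sketch for inspecting it; the exact field names depend on the training run, so print the keys first:

```python
import json

# Read every logging record written so far by train_edm2.py.
with open("OUTPUT_DIRECTORY/stats.jsonl") as f:
    records = [json.loads(line) for line in f if line.strip()]

# Field names vary between setups; list them before plotting anything.
print("logged fields:", sorted(records[-1].keys()))
print("latest record:", records[-1])
```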
## 🎨 Image Generation

Once the model is trained, we generate 5k images per class with the following bash loop:
```bash
for class_idx in 0 1 2 3 4 5; do
    python generate_images.py \
        --preset=edm2-img512-s-guid-fid \
        --net_ckpt=./OUTPUT_DIRECTORY/training-state-0008519.pt \
        --gnet_ckpt=./OUTPUT_DIRECTORY/training-state-0001310.pt \
        --outdir=./OUTPUT_DIRECTORY/diffusion_samples_FETAL_cond_${class_idx} \
        --guidance 1.5 \
        --seeds=0-5000 \
        --class=${class_idx}
done
```

Generation requires two network checkpoints, the first trained for longer than the second. In this example the main network (`--net_ckpt`) is `training-state-0008519.pt` and the guiding network (`--gnet_ckpt`) is `training-state-0001310.pt`, but any pair of checkpoints can be used as long as `--net_ckpt` has been trained for longer. The `--guidance` flag controls the strength of the autoguidance and may need tuning for optimal performance. The `--outdir` flag sets where the generated images are saved.
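For intuition on what `--guidance` does: autoguidance extrapolates the main network's prediction away from the guiding network's prediction. The sketch below is illustrative only (`generate_images.py` applies this internally at every sampling step), and the function name and arguments are made up for the example:

```python
import torch

def apply_autoguidance(d_main: torch.Tensor, d_guide: torch.Tensor,
                       guidance: float = 1.5) -> torch.Tensor:
    """Blend two denoiser outputs in the autoguidance style.

    d_main:  prediction of the longer-trained network (--net_ckpt)
    d_guide: prediction of the shorter-trained guiding network (--gnet_ckpt)
    guidance=1.0 reduces to the main network alone; values > 1 push
    samples further away from the weaker model's prediction.
    """
    return d_guide + guidance * (d_main - d_guide)
```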
## 📊 FID Evaluation

Finally, to measure the FID of the generated images, use the following command:
```bash
python fid_measurement.py \
    --real_root ./DATASET_LOCATION \
    --csv_file ./DATASET_LOCATION/FETAL_PLANES_DB_data.csv \
    --fake_root ./OUTPUT_DIRECTORY/ \
    --split test \
    --batch_size 32 \
    --device cuda
```

where `--fake_root` points to the directory containing the generated images.
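For reference, FID is the Fréchet distance between two Gaussians fitted to Inception features of the real and generated images. A self-contained sketch of that final computation (illustrative; `fid_measurement.py` extracts the feature statistics itself):

```python
import numpy as np
from scipy import linalg

def frechet_distance(mu_r: np.ndarray, cov_r: np.ndarray,
                     mu_f: np.ndarray, cov_f: np.ndarray) -> float:
    """FID between Gaussians N(mu_r, cov_r) and N(mu_f, cov_f)."""
    # Matrix square root of the covariance product; tiny imaginary
    # components arising from numerical error are discarded.
    covmean, _ = linalg.sqrtm(cov_r @ cov_f, disp=False)
    covmean = covmean.real
    diff = mu_r - mu_f
    return float(diff @ diff + np.trace(cov_r + cov_f - 2.0 * covmean))
```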
## 🤝 Contributing

We welcome contributions from the community. Before submitting a PR, run:

```bash
uv run pre-commit run -a
```

This ensures the code formatting and linting checks pass.
To clone the repository, you may need to authorize a personal access token for use with single sign-on:

```bash
git clone https://github.com/xfetus/fetal-ultrasound-edm2.git
```