Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,12 @@ It is important to set `data.load_saved_data=False` to ensure that cached data i
uv run train.py data.data_path=path/to/your/data.txt data.load_saved_data=False data.saved_data_path=path/to/your/saved_data.pkl
```

A colab notebook demonstrating an example of training using your own data is provided. This example uses a dummy dataset of three 200bp sequences with a single cell type "CELL_A".

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1aQJm91rmmS4do-B-iYTloDBjZU7hD4ud?usp=sharing)

along with a copy of the notebook at `notebooks/new_data_training_and_sequence_generation.ipynb`. This example was run on Google Colab using a T4 GPU.


## Contributors ✨

Expand Down
359 changes: 359 additions & 0 deletions notebooks/new_data_training_and_sequence_generation.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,359 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "4_QqB7EQOYi4"
},
"source": [
"# Cloning repository and installing dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"collapsed": true,
"id": "-rPa178i9Bwh",
"outputId": "a1a9aae4-b0f5-457c-f08f-b21e5a75cd87"
},
"outputs": [],
"source": [
"!git clone https://github.com/pinellolab/DNA-Diffusion.git && cd DNA-Diffusion && uv sync"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "YYeWyznJAkKA",
"outputId": "defbe375-7ccc-4a7e-8ba3-19c4ae7479a3"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/content/DNA-Diffusion\n"
]
}
],
"source": [
"%cd DNA-Diffusion"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vleK2Khu6hJK"
},
"source": [
"# Creating a Simulated Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ninf0bzF7kZz"
},
"source": [
"Below we provide an example of how to use a new dataset with the DNA-Diffusion library. The dummy dataset has 3 sequences with an associated cell type of \"CELL_A\". We demonstrate that using this dataset we can regenerate the associated file \"encode_data.pkl\" that is used to train the model."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "3ZoFvDOI8ggK"
},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"tags = ['CELL_A', 'CELL_A', 'CELL_A']\n",
"chr = [\"chr1\", \"chr2\", \"chr3\"]\n",
"\n",
"df = pd.DataFrame(columns=['chr', 'sequence', 'TAG'])\n",
"\n",
"for i, (tag, chromosome) in enumerate(zip(tags, chr)):\n",
" if i == 2:\n",
" sequence = \"A\" * 200\n",
" df.loc[i] = [chromosome, sequence, tag]\n",
" else:\n",
" sequence = ''.join(np.random.choice(list('ACGT'), size=200))\n",
" df.loc[i] = [chromosome, sequence, tag]\n",
"\n",
"df.to_csv('data/dummy_data.txt', index=False, sep='\\t')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sOv6pRP4_f41"
},
"source": [
"# Basic Training Example"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nYDLoBS9__IH"
},
"source": [
"Below we provide an example of the training using the debug flag. This will only train the model on a single sequence for a minimum of 5 epochs with a patience parameter of 2 epochs. We also show that the data file can be overrided within the CLI call to integrate the new dataset. It is important to set data.load_saved_data=False, so that the additional metadata used to train the model is regenerated."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ArMTlzed_V-E",
"outputId": "533f9def-d65a-46d7-aaeb-2157b251529b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"model:\n",
" _target_: src.dnadiffusion.models.unet.UNet\n",
" dim: 200\n",
" channels: 1\n",
" dim_mults:\n",
" - 1\n",
" - 2\n",
" - 4\n",
" resnet_block_groups: 4\n",
"data:\n",
" _target_: src.dnadiffusion.data.dataloader.get_dataset\n",
" data_path: data/dummy_data.txt\n",
" saved_data_path: data/encode_data.pkl\n",
" load_saved_data: false\n",
" debug: true\n",
"optimizer:\n",
" _target_: torch.optim.Adam\n",
" lr: 0.0001\n",
"diffusion:\n",
" _target_: src.dnadiffusion.models.diffusion.Diffusion\n",
" timesteps: 50\n",
" beta_start: 0.0001\n",
" beta_end: 0.2\n",
"training:\n",
" distributed: false\n",
" precision: float32\n",
" num_workers: 1\n",
" pin_memory: false\n",
" batch_size: 1\n",
" sample_batch_size: 1\n",
" num_epochs: 2200\n",
" min_epochs: 5\n",
" patience: 2\n",
" log_step: 50\n",
" sample_epoch: 50000\n",
" number_of_samples: 10\n",
" use_wandb: false\n",
"\n",
" 0% 0/2200 [00:00<?, ?it/s]/content/DNA-Diffusion/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:838: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
" return fn(*args, **kwargs)\n",
" 0% 9/2200 [01:51<6:44:30, 11.08s/it]Early stopping at epoch 9, Best val loss: 0.30490654706954956 achieved at epoch 7\n",
" 0% 9/2200 [01:51<7:33:13, 12.41s/it]\n"
]
}
],
"source": [
"!uv run train.py -cn train_debug data.data_path='data/dummy_data.txt' data.load_saved_data=False"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5ql3m5jQAv-F"
},
"source": [
"# Basic sequence generation example using the created checkpoint"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RgbgEGD0AwuL"
},
"source": [
"The model successfully trained and now we will use the checkpoint with the lowest validation loss to generate 1 sequence per cell type. Given that the training seed is not fixed, this is not a deterministic result and your validation loss may slightly vary from the cached example."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2_CWe2ij2ZK5",
"outputId": "6eeffc3f-ba89-4dfa-b29d-8b59ed2d0ccb"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"All available checkpoints: \n",
"['checkpoints/model_epoch5_step6_valloss_0.348039.pt', 'checkpoints/model_epoch7_step8_valloss_0.304907.pt']\n",
"\n",
"Using checkpoint with lowest validation loss: \n",
"checkpoints/model_epoch7_step8_valloss_0.304907.pt\n"
]
}
],
"source": [
"import os\n",
"\n",
"checkpoint_dir = \"checkpoints\"\n",
"files = os.listdir(checkpoint_dir)\n",
"file_paths = sorted([os.path.join(checkpoint_dir, f) for f in files if os.path.isfile(os.path.join(checkpoint_dir, f)) and \".gitkeep\" not in f])\n",
"best_checkpoint = file_paths[-1]\n",
"\n",
"print(f\"All available checkpoints: \\n{file_paths}\")\n",
"print(f\"\\nUsing checkpoint with lowest validation loss: \\n{best_checkpoint}\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6Q45JZoaAsMw",
"outputId": "f24f5378-bf3a-4a38-c1d0-e4200fdf6906"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"model:\n",
" _target_: src.dnadiffusion.models.unet.UNet\n",
" dim: 200\n",
" channels: 1\n",
" dim_mults:\n",
" - 1\n",
" - 2\n",
" - 4\n",
" resnet_block_groups: 4\n",
"data:\n",
" _target_: src.dnadiffusion.data.dataloader.get_dataset_for_sampling\n",
" data_path: data/dummy_data.txt\n",
" saved_data_path: data/encode_data.pkl\n",
" load_saved_data: true\n",
" debug: false\n",
" cell_types: null\n",
"diffusion:\n",
" _target_: src.dnadiffusion.models.diffusion.Diffusion\n",
" timesteps: 50\n",
" beta_start: 0.0001\n",
" beta_end: 0.2\n",
"sampling:\n",
" checkpoint_path: checkpoints/model_epoch7_step8_valloss_0.304907.pt\n",
" sample_batch_size: 1\n",
" number_of_samples: 1\n",
" guidance_scale: 1.0\n",
"\n",
"Loading checkpoint\n",
"Model sent to cuda\n",
"Found cell types: ['CELL_A']\n",
"Generating 1 samples for cell CELL_A\n",
"100% 1/1 [00:02<00:00, 2.40s/it]\n"
]
}
],
"source": [
"!uv run sample.py sampling.number_of_samples=1 sampling.sample_batch_size=1 sampling.checkpoint_path=$best_checkpoint data.data_path='data/dummy_data.txt'"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7HJm-9OJC49v"
},
"source": [
"# View Generated Sequences"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "G-NmKG6FC6Pi",
"outputId": "8e58089e-ee77-4f71-f0f0-276e953f8bd0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Displaying sequences from: data/outputs\n",
"\n",
"--- Cell Type: CELL (CELL_A.txt) ---\n",
"TATATTTATGTTAAATTCATGCTTATTTTATATTTTTTTTTTTTTTTGAGTAGTTATTGTATTTTTTATATATTGAAAATATTTTTTTTTTTTACAAAAATAAAATATAAATAACATTTGTAAAATGTTCTAAGTGTGTGTTGCATTTTAATATATAATATTTTTATTTGTAAATAAATAATAATTTTTATTTTTGTTTG\n",
"----------------------\n",
"\n"
]
}
],
"source": [
"import os\n",
"import subprocess\n",
"\n",
"def display_sequences(output_dir=\"data/outputs\"):\n",
" if not os.path.isdir(output_dir):\n",
" print(f\"Error: Directory '{output_dir}' not found.\")\n",
" return\n",
"\n",
" print(f\"Displaying sequences from: {output_dir}\\n\")\n",
"\n",
" for filename in sorted(os.listdir(output_dir)):\n",
" filepath = os.path.join(output_dir, filename)\n",
"\n",
" if os.path.isfile(filepath) and \"gitkeep\" not in filepath:\n",
" cell_type = filename.split('_')[0]\n",
" print(f\"--- Cell Type: {cell_type} ({filename}) ---\")\n",
" result = subprocess.run(['cat', filepath], capture_output=True, text=True, check=True)\n",
" print(result.stdout)\n",
" print(\"-\" * (len(cell_type) + 18) + \"\\n\")\n",
"\n",
"display_sequences()"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Loading
Loading