diff --git a/README.md b/README.md index c73c8fc7..455a3df9 100644 --- a/README.md +++ b/README.md @@ -193,6 +193,12 @@ It is important to set `data.load_saved_data=False` to ensure that cached data i uv run train.py data.data_path=path/to/your/data.txt data.load_saved_data=False data.saved_data_path=path/to/your/saved_data.pkl ``` +A Colab notebook demonstrating an example of training using your own data is provided. This example uses a dummy dataset of three 200bp sequences with a single cell type "CELL_A". + +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1aQJm91rmmS4do-B-iYTloDBjZU7hD4ud?usp=sharing) + +A copy of the notebook is also available at `notebooks/new_data_training_and_sequence_generation.ipynb`. This example was run on Google Colab using a T4 GPU. + ## Contributors ✨ diff --git a/notebooks/new_data_training_and_sequence_generation.ipynb b/notebooks/new_data_training_and_sequence_generation.ipynb new file mode 100644 index 00000000..edb044e1 --- /dev/null +++ b/notebooks/new_data_training_and_sequence_generation.ipynb @@ -0,0 +1,359 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "4_QqB7EQOYi4" + }, + "source": [ + "# Cloning repository and installing dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "collapsed": true, + "id": "-rPa178i9Bwh", + "outputId": "a1a9aae4-b0f5-457c-f08f-b21e5a75cd87" + }, + "outputs": [], + "source": [ + "!git clone https://github.com/pinellolab/DNA-Diffusion.git && cd DNA-Diffusion && uv sync" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "YYeWyznJAkKA", + "outputId": "defbe375-7ccc-4a7e-8ba3-19c4ae7479a3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/content/DNA-Diffusion\n" + ] + } + ], + "source": 
[ + "%cd DNA-Diffusion" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vleK2Khu6hJK" + }, + "source": [ + "# Creating a Simulated Dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ninf0bzF7kZz" + }, + "source": [ + "Below we provide an example of how to use a new dataset with the DNA-Diffusion library. The dummy dataset has 3 sequences with an associated cell type of \"CELL_A\". We demonstrate that using this dataset we can regenerate the associated file \"encode_data.pkl\" that is used to train the model." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "3ZoFvDOI8ggK" + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "tags = ['CELL_A', 'CELL_A', 'CELL_A']\n", + "chr = [\"chr1\", \"chr2\", \"chr3\"]\n", + "\n", + "df = pd.DataFrame(columns=['chr', 'sequence', 'TAG'])\n", + "\n", + "for i, (tag, chromosome) in enumerate(zip(tags, chr)):\n", + " if i == 2:\n", + " sequence = \"A\" * 200\n", + " df.loc[i] = [chromosome, sequence, tag]\n", + " else:\n", + " sequence = ''.join(np.random.choice(list('ACGT'), size=200))\n", + " df.loc[i] = [chromosome, sequence, tag]\n", + "\n", + "df.to_csv('data/dummy_data.txt', index=False, sep='\\t')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sOv6pRP4_f41" + }, + "source": [ + "# Basic Training Example" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nYDLoBS9__IH" + }, + "source": [ + "Below we provide an example of the training using the debug flag. This will only train the model on a single sequence for a minimum of 5 epochs with a patience parameter of 2 epochs. We also show that the data file can be overridden within the CLI call to integrate the new dataset. It is important to set data.load_saved_data=False, so that the additional metadata used to train the model is regenerated." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ArMTlzed_V-E", + "outputId": "533f9def-d65a-46d7-aaeb-2157b251529b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model:\n", + " _target_: src.dnadiffusion.models.unet.UNet\n", + " dim: 200\n", + " channels: 1\n", + " dim_mults:\n", + " - 1\n", + " - 2\n", + " - 4\n", + " resnet_block_groups: 4\n", + "data:\n", + " _target_: src.dnadiffusion.data.dataloader.get_dataset\n", + " data_path: data/dummy_data.txt\n", + " saved_data_path: data/encode_data.pkl\n", + " load_saved_data: false\n", + " debug: true\n", + "optimizer:\n", + " _target_: torch.optim.Adam\n", + " lr: 0.0001\n", + "diffusion:\n", + " _target_: src.dnadiffusion.models.diffusion.Diffusion\n", + " timesteps: 50\n", + " beta_start: 0.0001\n", + " beta_end: 0.2\n", + "training:\n", + " distributed: false\n", + " precision: float32\n", + " num_workers: 1\n", + " pin_memory: false\n", + " batch_size: 1\n", + " sample_batch_size: 1\n", + " num_epochs: 2200\n", + " min_epochs: 5\n", + " patience: 2\n", + " log_step: 50\n", + " sample_epoch: 50000\n", + " number_of_samples: 10\n", + " use_wandb: false\n", + "\n", + " 0% 0/2200 [00:00 tuple[Dataset, Dataset, list[int], dict[int, str]]: encode_data = load_data( data_path, saved_data_path, load_saved_data, - output_path, ) if debug: x_data = encode_data["X_train"][:1] @@ -50,11 +48,13 @@ def get_dataset_for_sampling( saved_data_path: str, load_saved_data: bool, debug: bool, - output_path: str | None = None, cell_types: str | list[str] | None = None, ) -> tuple[Dataset, Dataset, list[int], dict[int, str]]: train_data, val_data, cell_num_list, numeric_to_tag_dict = get_dataset( - data_path, saved_data_path, load_saved_data, debug, output_path + data_path, + saved_data_path, + load_saved_data, + debug, ) if cell_types is None: @@ -80,7 +80,9 @@ def get_dataset_for_sampling( elif 
len(matches) > 1: print(f"Warning: '{cell_type_query}' matches multiple cell types: {matches}. Please be more specific.") else: - print(f"Warning: Cell type '{cell_type_query}' not found in dataset. Available types: {list(tag_to_numeric.keys())}") + print( + f"Warning: Cell type '{cell_type_query}' not found in dataset. Available types: {list(tag_to_numeric.keys())}" + ) if not filtered_cell_nums: raise ValueError(f"No valid cell types found. Available types: {list(tag_to_numeric.keys())}") @@ -116,7 +118,6 @@ def load_data( data_path: str, saved_data_path: str, load_saved_data: bool, - output_path: str | None = None, sequence_length: int = 200, ): # Preprocessing data @@ -125,6 +126,7 @@ def load_data( encode_data = pickle.load(f) else: + output_path = saved_data_path encode_data = preprocess_data(data_path, output_path) # Creating sequence dataset @@ -136,7 +138,9 @@ def load_data( # Create test dataset using chr1 val_df = encode_data["validation_df"] - val_test_seq = np.array([one_hot_encode(x, nucleotides, sequence_length) for x in val_df["sequence"] if "N" not in x]) + val_test_seq = np.array( + [one_hot_encode(x, nucleotides, sequence_length) for x in val_df["sequence"] if "N" not in x] + ) X_val = np.array([x.T.tolist() for x in val_test_seq]) X_val[X_val == 0] = -1 diff --git a/train.py b/train.py index 49178d3d..81e949a4 100644 --- a/train.py +++ b/train.py @@ -52,7 +52,7 @@ def train( val_dl, _ = get_dataloader(val_data, local_batch_size, num_workers, distributed, pin_memory) # Metrics - if rank_0 == 0 and use_wandb: + if rank_0 and use_wandb: init_wandb() global_step = 0