diff --git a/README.md b/README.md index eaa0d9db..89bb7ee8 100644 --- a/README.md +++ b/README.md @@ -71,17 +71,6 @@ uv run data/master_dataset_and_filter.py ``` which will download all the necessary data and create a file `data/master_dataset.ftr` containing the full ~3.59 million dataset and a file `data/filtered_dataset.txt` containing the same subset of sequences as above. A rendered version of this code is provided at `notebooks/marimo_master_dataset_and_filter.ipynb`. - -To run data curation process as a notebook, a marimo notebook file can be found at `notebooks/marimo_master_dataset_and_filter.py`. - -This notebook can be opened/run with the following command: -```bash -uvx marimo edit notebooks/marimo_master_dataset_and_filter.py -``` - -All of the data processing files make use of uv to manage dependencies and so all libraries are installed when you run the above commands. See [uv documentation](https://docs.astral.sh/uv/guides/scripts/) for more information on how to run uv scripts and [marimo documentation](https://docs.marimo.io/) for more information on how to run marimo notebooks. - - ### Training To train the DNA-Diffusion model, we provide a basic config file for training the diffusion model on the same subset of chromatin accessible regions described in the data section above. @@ -91,6 +80,8 @@ To train the model call: uv run train.py ``` +This runs the model with our predefined config file `configs/train/default.yaml`, which is set to train the model for a minimum of 2000 epochs. The training script will save model checkpoints for the lowest 2 validation loss values in the `checkpoints/` directory. The path to this checkpoint will need to be updated in the sampling config file for sequence generation, as described in the Model Checkpoint section below. + We also provide a base config for debugging that will use a single sequence for training. You can override the default training script to use this debugging config by calling: ```bash @@ -100,7 +91,12 @@ uv run train.py -cn train_debug ### Model Checkpoint We have uploaded the model checkpoint to [HuggingFace](https://huggingface.co/ssenan/DNA-Diffusion). Below we provide an example script that handles downloading the model checkpoint and loading it for sequence generation. +If you would like to use a model checkpoint generated from the training script above, ensure you update the `checkpoint_path` within the config file `configs/sampling/default.yaml` to point to the location of the model checkpoint. By default, this is set to `checkpoints/model.safetensors`, so you will need to ensure that the model checkpoint is saved in this location. Both `pt` and `safetensors` formats are supported, so you can use either format for the model checkpoint. An example of overriding the checkpoint path from the command line is described in the sequence generation section below. + ### Sequence Generation + +#### Generate using Hugging Face Checkpoint + We provide a basic config file for generating sequences using the diffusion model resulting in 1000 sequences made per cell type. To generate sequences using the trained model, you can run the following command: ```bash @@ -119,6 +115,20 @@ Base generation utilizes a guidance scale 1.0, however this can be tuned within uv run sample_hf.py sampling.guidance_scale=7.0 sampling.number_of_samples=1 sampling.sample_batch_size=1 ``` +Both above examples will generate sequences for all cell types in the dataset. If you would like to generate sequences for a specific cell type, you can do so by specifying the `sampling.cell_type` parameter in the command line. For example, to generate a sequence for the K562 cell type, you can run: + +```bash +uv run sample_hf.py data.cell_types=K562 sampling.number_of_samples=1 sampling.sample_batch_size=1 +``` +or for both K562 and GM12878 cell types, you can run: + +```bash +uv run sample_hf.py 'data.cell_types="K562,GM12878"' sampling.number_of_samples=1 sampling.sample_batch_size=1 +``` +Cell types can be specified as a comma separated string or as a list. + +#### Generate using Local Checkpoint + If you would prefer to download the model checkpoint from Hugging Face and use it directly, you can run the following command to download the model and save it in the checkpoint directory: ```bash wget https://huggingface.co/ssenan/DNA-Diffusion/resolve/main/model.safetensors -O checkpoints/model.safetensors @@ -129,6 +139,11 @@ Then you can run the sampling script with the following command: uv run sample.py ``` +If you would like to override the checkpoint path from the command line, you can do so with the following command (replacing `checkpoints/model.pt` with the path to your model checkpoint): +```bash +uv run sample.py sampling.checkpoint_path=checkpoints/model.pt +``` + ## Examples ### Training Notebook @@ -152,14 +167,29 @@ Both examples were run on Google Colab using a T4 GPU. DNA-Diffusion is designed to be flexible and can be adapted to your own data. To use your own data, you will need to follow these steps: -* Prepare your data in the same format as the DHS Index dataset. The data should be a tab separated text file with the following columns: +* Prepare your data in the same format as the DHS Index dataset. The data should be a tab separated text file contains at least the following columns: * `chr`: the chromosome of the regulatory element (e.g. chr1, chr2, etc.) * `sequence`: the DNA sequence of the regulatory element * `TAG`: the cell type of the regulatory element (e.g. K562, hESCT0, HepG2, GM12878, etc.) -* It's expected that your sequences are 200bp long, however the model can be adapted to work with different sequence lengths by the dataloading code at `src/dnadiffusion/data/dataloader.py`. You can change the `sequence_length` parameter in the function `load_data` to the desired length, but keep in mind that the original model is trained on 200bp sequences and so the results may not be as good if you use a different length. +additional metadata columns like start, end, continuous accessibility are allowed but not required. + +* It's expected that your sequences are 200bp long, however the model can be adapted to work with different sequence lengths by the dataloading code at `src/dnadiffusion/data/dataloader.py`. You can change the `sequence_length` parameter in the function `load_data` to the desired length, but keep in mind that the original model is trained on 200bp sequences so the results may not be as good if you use a different length. * The model is designed to work with discrete class labels for the cell types, so you will need to ensure that your data is in the same format. If you have continuous labels, you can binarize them into discrete classes using a threshold or some other method. This value is contained within the `TAG` column of the dataset. +The data loading config can be found at `configs/data/default.yaml`, and you can override the default data loading config by passing the `data` parameter to the command line. For example, to use a custom data file, you can run: + +```bash +uv run train.py data.data_path=path/to/your/data.txt data.load_saved_data=False +``` + +It is important to set `data.load_saved_data=False` to ensure that cached data is not used, and instead is regenerated from the provided data file. This will ensure that the model is trained on your own data. This will overwrite the default pkl file, so if you would like to keep the original data, you can set `data.saved_data_path` to a different path. For example: + +```bash +uv run train.py data.data_path=path/to/your/data.txt data.load_saved_data=False data.saved_data_path=path/to/your/saved_data.pkl +``` + + ## Contributors ✨ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)): diff --git a/configs/data/sampling.yaml b/configs/data/sampling.yaml new file mode 100644 index 00000000..c0dd9d19 --- /dev/null +++ b/configs/data/sampling.yaml @@ -0,0 +1,6 @@ +_target_: src.dnadiffusion.data.dataloader.get_dataset_for_sampling +data_path: "data/K562_hESCT0_HepG2_GM12878_12k_sequences_per_group.txt" +saved_data_path: "data/encode_data.pkl" +load_saved_data: True +debug: False +cell_types: null # null means all cell types, or specify like ["K562", "HepG2"] diff --git a/configs/sample.yaml b/configs/sample.yaml index 5c9f43fc..9bf3cbe8 100644 --- a/configs/sample.yaml +++ b/configs/sample.yaml @@ -1,5 +1,5 @@ defaults: - model: unet - - data: default + - data: sampling - diffusion: default - sampling: default diff --git a/configs/sample_hf.yaml b/configs/sample_hf.yaml index dba6f058..a61c4f92 100644 --- a/configs/sample_hf.yaml +++ b/configs/sample_hf.yaml @@ -1,5 +1,5 @@ defaults: - model: unet_pretrained - - data: default + - data: sampling - diffusion: default - sampling: default_hf diff --git a/configs/sampling/default.yaml b/configs/sampling/default.yaml index 07b1683a..9e2cec14 100644 --- a/configs/sampling/default.yaml +++ b/configs/sampling/default.yaml @@ -1,4 +1,5 @@ -checkpoint_path: "model.safetensors" +checkpoint_path: "checkpoints/model.safetensors" +# checkpoint_path: "checkpoints/DNA-Diffusion.pt" sample_batch_size: 10 number_of_samples: 1000 guidance_scale: 1.0 diff --git a/notebooks/sequence_generation.ipynb b/notebooks/sequence_generation.ipynb index 3544959d..cbc9ab0b 100644 --- a/notebooks/sequence_generation.ipynb +++ b/notebooks/sequence_generation.ipynb @@ -13,12 +13,8 @@ "cell_type": "code", "execution_count": null, "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, "collapsed": true, - "id": "LAvoel8bZ72P", - "outputId": "52ac0db4-8d14-4c3a-c039-117cc10de35b" + "id": "LAvoel8bZ72P" }, "outputs": [], "source": [ @@ -27,13 +23,13 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "b17jNfakaODu", - "outputId": "08b22fc9-cbec-4835-ad65-c70451e0bb19" + "outputId": "99c9ae5e-fd9d-49d9-901d-513750eb4967" }, "outputs": [ { @@ -59,13 +55,13 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "8NVFthhyaXKG", - "outputId": "feb9a27f-be98-4853-d824-7ca65180ae55" + "outputId": "ed804d9a-2191-46e5-bc90-abd2e5d4da05" }, "outputs": [ { @@ -76,11 +72,12 @@ " _target_: src.dnadiffusion.models.pretrained_unet.PretrainedUNet.from_pretrained\n", " pretrained_model_name_or_path: ssenan/DNA-Diffusion\n", "data:\n", - " _target_: src.dnadiffusion.data.dataloader.get_dataset\n", + " _target_: src.dnadiffusion.data.dataloader.get_dataset_for_sampling\n", " data_path: data/K562_hESCT0_HepG2_GM12878_12k_sequences_per_group.txt\n", " saved_data_path: data/encode_data.pkl\n", " load_saved_data: true\n", " debug: false\n", + " cell_types: null\n", "diffusion:\n", " _target_: src.dnadiffusion.models.diffusion.Diffusion\n", " timesteps: 50\n", @@ -92,18 +89,18 @@ " number_of_samples: 10\n", " guidance_scale: 1.0\n", "\n", - "config.json: 100% 153/153 [00:00<00:00, 1.09MB/s]\n", - "model.safetensors: 100% 378M/378M [00:06<00:00, 57.2MB/s]\n", + "config.json: 100% 153/153 [00:00<00:00, 1.26MB/s]\n", + "model.safetensors: 100% 378M/378M [00:01<00:00, 308MB/s]\n", "Model sent to cuda\n", "Found cell types: ['GM12878_ENCLB441ZZZ', 'HepG2_ENCLB029COU', 'K562_ENCLB843GMH', 'hESCT0_ENCLB449ZZZ']\n", "Generating 10 samples for cell GM12878_ENCLB441ZZZ\n", - "100% 2/2 [00:09<00:00, 4.94s/it]\n", + "100% 2/2 [00:09<00:00, 4.87s/it]\n", "Generating 10 samples for cell HepG2_ENCLB029COU\n", - "100% 2/2 [00:09<00:00, 4.61s/it]\n", + "100% 2/2 [00:09<00:00, 4.64s/it]\n", "Generating 10 samples for cell K562_ENCLB843GMH\n", - "100% 2/2 [00:09<00:00, 4.69s/it]\n", + "100% 2/2 [00:09<00:00, 4.76s/it]\n", "Generating 10 samples for cell hESCT0_ENCLB449ZZZ\n", - "100% 2/2 [00:09<00:00, 4.79s/it]\n" + "100% 2/2 [00:09<00:00, 4.83s/it]\n" ] } ], @@ -113,13 +110,13 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "zQRv3dADa_3P", - "outputId": "111ed2c3-5cad-4b3c-ff09-11de2646ea17" + "outputId": "5e1930dc-f25d-48fd-edb5-0b6706a80128" }, "outputs": [ { @@ -129,55 +126,55 @@ "Displaying sequences from: data/outputs\n", "\n", "--- Cell Type: GM12878 (GM12878_ENCLB441ZZZ.txt) ---\n", - "TTCCTCTGTGCTTCTTAGACATGCATGAAAACTTGGATTGATCTTTTGCATGATTTCATGAAATGACTTCTCTGTTGCTTTCATATCCCTGACACAGAAAAGTTTCTGTGTGTCATGATTGCTAGGTGGAATTTCCCTTTGGGCAAGAGTTTCAGTTTGAGGAGCCATTCTTATACTTCCGGATGCATCTTTTCTATGGT\n", - "TATGTGTAGCACAGACTTGCATCGCCGAAACTTCATGATTCCGCTGAAACTCTGAGTGTCAATATACAGGTTGTCGTTTTGTTTGAGGAAGTGAACGAAACCGCGGTTTTGGTGTCAACGAAGCTGCAACTAGAGCTGGAGTTTCATTTTCTTCTTTTGGCATTTTAAATATCAGTAGAACTAGGTTATGGCAGCTAGTA\n", - "GTCTTACTACTTCCACATGTAGGGGATCACTCAAGCCTAGGGGAATTCCCTGAGGCCAGGAGGGACTTTCCAGACTGAAAGAGTGAAATGCAAGCCAGCTCTTACCTGAGATCTGAAGATGCAACTACACCAAAATAAACAGGTACATCATAGCTAGAAATGAGTGGGTTTTCCCAGCATTTACACACTTGGCTAAGCAT\n", - "TGACTGTAGGAGAGCCCAGGGTCTTACGAGGACGGGGAGATCATGCTGGTGTGAGAGAGCTGGCCCTCTGCTTAATCAGATCCTCTCTGAGAGGGAAACAGGAAGTGAGAACTGCTAAATCAAGTGCACGAAAGCCATGCAAAGCTGTGGTTTCAGTTTGAGAAAGTAAAAGTTGGGTGACCGTTAACAATGAGTCAAAT\n", - "GGTCTCTTACTAATCAACATGAAAGTATGTAAAGTAAACTAACCGATGGTTTGGAGGTGGCACAGAACCGGGTGGTTCATGTGACTGCTTCTTTCAAGCAGGAGAGCCTCTGTTCAGTTTCGGGCTGAAACCACTTTCACCCAGGGGTGCAGAAGACAGCCTCCCAGCTGTATGGAAATGGAGCAAATTTGCACATTTAA\n", - "CAGTAAGGGGGATCCACATCAGCATCCAAGATTTCACATTGCTACTGTGGTCACCATGAGACATCAGAAGCGATAGGGGGAGTTTCCCAAGGGACTTGGGCAGTCCCAGTGGGTACAGAGGGTGAAACGCAAGATCTAGGGAGCTCTCAGCCAGAAACCACCCTACCACCCCTTAACCCATAAACCTCACAATGTGTCAC\n", - "CAGTCTTTGAGAGATGTCTAACTGAAGCAACGCCTCTGTGATCTCTGGAGAATTTATTTTAAAACAGGGTGAAGTGGTTTTGAAAGAGTACATTCAAAAGAGAAAACAAACAAAAAAAAAGCCATTGAGTTTCGGTTCAGATTTTGGCAGGCATGTGAGTGCCTCACAATTTGGTCAAAGAACGAAACCAAAAAGACTTT\n", - "TGCATTTGGTTGTGTCAGTTACTAGTCCAAGGATCTAGAGGATTGTGGGTAGTAATAATGTGAAGTGTTGGTTCTAAAACTCAAACTAAAGGTTGTCACACAACTGCCTGGGAAACCCCTAGAACATGAGAATGAACAGAGATGAAACCGAAACAACGTTTTTTCTAGTTTTTTAAAAAGTTTCAATTCTATGAGAATAA\n", - "GTGTAGTCTTCTAAAGCTGGGGTACATTTGAGTCATCTATGCCACGTGGCAAGGACATTTCCAAACTCTTGAAATAATGAAAGTAAAACTAAGACAAAAGTTAAAAATTCAACTCCTTGCACAGGGGGAAGTATTGTGGTTTTCAACAAGGAAAACATTCTGAGTCATTTCTGCCATGACGAGAAGGGGAGCTGAAAAGC\n", - "TGAGACCACTGCTTGGTACACTCAACAGCCAGCCTGATTCACTCAGTTTCCGTTACTTCTCGGCTTTTAAATAGAGGAATAAAACCCAATTCGTTTAGTTTCACTTCCTGTTCTTTTCATTTAACACCTTCCTCCTTCCCCATGTTAGCAAACATCATTCTGAAAATCCCCTTCAGAAACATCTAACTTCAGCTGCCGAT\n", + "TTGAGTTGTTTGAATAGTAACATGTACATACATAGTTTTGGTTCCCATTTGGAGTTAAATCTATGCATGACACATTCATTTCATTTGCAATACTGGTTAGTCATTAATTAGACTGAAATTTGATTGGTACTTCACTTTCTATCTCTAGTTATTGTGGTTTCTGGTGGTGTGTGCTGATCAAGGGGAGAATACAGTGCACT\n", + "AAGGACTAACTAAAACCGAAAACACCTAAGATTTTGGAATTTCCCTCCTTGCTTTCTATGTGCCTAGCTTGGAGACCATGAGGGTGTTGTTGAAACGAGAAATACATGAGGGTAGACTTTCACCTGTCGTTTCCATGAGTAAAATGATACTAAAGTAAAGTTTGAAGACAACGAGTTGTGTTACATGCTGTTTCTTTCTC\n", + "ACCACTAACAAAAACTGTTTTCAACGGAACCATTTTAGTTTCAGAACTTACCACATCTATGTGTGAGTCAGCAGAGTGAAAAGCAAAATAGGAAATATTAAGAAACACGCTCAAATAAATGTGATGTTATGGGGGAAATCCTGACTTTGTCACATGAAAAGGAAATGCAGCGTCATGGAAATAACTACCACCTACAAAGT\n", + "TACTTTCAAAAGATATATGAAAGGAAACATATTGTCTGTTTCATTTTCATTTTCTGAATTCTATCTAGCCAACCATAGTTGAACTCACAGTGTTAGCAGGGTTTAACCCCTGATTTGCCCCATTTTCGGTATAGCTCTGCACAAACTCATGACTGCCTGTCTTCTTTCAGGCAAAAGCTTTGTTCAAATTAGAGCAACTG\n", + "AGTCACTTGCGCCTAAGTGCACACAATGAGAGTAGGTGCTGCTGGAGGCTTCTCCCTGCAGCTGTTTGGGCTGCATGGGACCCGCTGACCCCACTCACCTAAACCAGCGGGCCTTCAAGCCGCACTCAGCTTCAGTGATTTATTTTAATAAACTGAACCCCAGAAGCACAGAGAGAGAAACAAAGAAAAACAAAACATCT\n", + "GGATTGCTATGAGCTAAAATACTTTTCTACCACCTCAGAGCTGTTAACCACTTCCTGTTTTAATTTAAGCATCTGCCAACCACAAAAACAACCCAAGTAAAACCCCAAGTGCTGACTTTCATTTCTGCAACTACTCTTTCCCCTTGTTACAAGTATGTGATGTTGCTACTATTAAAAGACGCCAAATGAACAATCTGATT\n", + "GATGTTATTTGGGTCGCAGTTCCCTTCTGTAACAAGTCATTGTATGGTTTACCACCCTTTGTGAAGTCACAGAAACAACAAAGCTAACTTATGCATGCCTTCTGGAAAACCCTCTGAAATGACAGGGCATCATGACTTCCTTTCTTACTGAGCTTTGAAGTATTAATGAGCAACGTTGCTACATAACTGGCCGCTGACAG\n", + "AATCTCTCAACAGGTTACTTTCTTAGCTGTCTGTCTAAAACTGAGTTTCAGTTTCAGTCTCTAGAGTTGACACTGCTTCATCTTGGCAGCAACCAAAGAATCTACATCTTGAAGGCAATGGTGCATTTCAGCTGAGGGAGAGAGGTCATGTCAAAATCACACTGGATGTAGCATCGGGCAGACACATATACGAATAGATT\n", + "ACTGAAATGCCACAGTGGAGGTTTCACAGTGAGAAACCAAACTAGCTGAGTGACACCAGCCCATGGCCTCTGCCAACTGATGAAATGATGAGATGTGCTTAAAACATATGTCAGAAACTTTCTACTGAAGTGCAGCCTCAGAATGATAGTTTCAGTTCCTATTTTTGGTTTTATGTCTACATCTTATGCAGAATTATGCT\n", + "GTCTTTCTAGCTCATGTGCATTCATCATACAGAACTGAAAGAAACAGTTTACTGGCTGGGACCACATTTTAGTCAGTTTTGATTTTAAAATTATTTAGGGTTTCTAAGGTGAAAAACAAGTGTTTTCAGCCCTTAGAGCTGAAACATGAATACTAATTGCTCCATGGTTTGAAAACCTGTAAGCATTGCTGTATTGATCA\n", "-------------------------\n", "\n", "--- Cell Type: HepG2 (HepG2_ENCLB029COU.txt) ---\n", - "CTAAAGCTCACAATGTGTAAGTAAACAGAAGGAACCATGTGTGTTAAATGTTTGGATGTCCTGGATAATAAGCAATCATTCCAAAGGTCTGTAAGCTCTATGACATTCAATCCATGTTTGCTTTGTATTTTTTTAAGCTGCTTGAAAGGATGGGCTCATTCACTGAACTGTGAACTTCCTGGTTTCTACAGAGGTCACAT\n", - "GCTGTAGTGCAAGAGGAGAACATACCATTGCAGGACTCACTGCTCCTCTGTGTACAGAGTTCAGCCGTCTGTTTATGTTTCAGGGGAATGGGGTATGGATAGGCCAGGTCACAAGCAGCTGATACCCGAGACATTAATTCAAGTCTTCACAGCAGAGGATGGCTGAGTGGCTGGTCAGGCTGGAGAGGGCTGAGAGGCCT\n", - "CTTGATGCTTTTGTGCTGCTAACCATGAGCCTCTGTGTCAGGCTGCGTTTCCAGTGTTTAGGCTGAGGGAATTGGATAAATGGCAACTTTGAACAAAAGGCTTAATGATGTGACGTGGTTGGAGGGGTCCAAGGTGAATCTTGTGCGCTATGTGCAAGGTCTTTCTGAGCCTACATTCCCGGCAGTCGACCCTGGGCTGC\n", - "CAGAGTTGTACAATGCCACTGTCCACTTTAGCCCTTCTGCACCTTTGACTGTGTATCAGTGAACTTTGGTCTCTGTTTTAACTGCTCCTCATGTGTCTATGTACCTGAACAAATACATGACCTCTCTATGTACACATGTACTAAACAAATCCATGTGAGAAACAGGTGCCATGATACTACGGGCTGCACGCAGTTTACTG\n", - "TGTCTGCACAGCTCTGGGGCAGAGTCCATATCCTCAGGCCTCTCACTCTGTGGGACTGTTCAAACAGTCATTCAGGCCCGGTTAGGATAGGATGCCTCTGAAAACTTGAACTCTGGCACTGGGTGGCACAAAGTTCACTTTGCATACATAAAATAGATGTTTGCTCTTTGCATTCAAAGCCGCCTATACATAAAGCCACG\n", - "CTCTCTGTGTCATGCCCCTGTCTCAGATCCCTTCCCCACATGTGCATGGAGGGCCCTCTGTGCCCTCTGACTTTGCACCTTTGCTTTCCTGCTTGTGGCTGTCCTGGCCCACCAGCTCTTGCTGTAAAACAAACACTAATTGTCCAGAACACACACCTCTTCTCTGGTGGCACTGACTGCACAACCAGCATAGCAGATAA\n", - "TTGATGCTTTGCGCGTAGTCCCGGCCGCGCTGAGGGGGGGCCTGGGGAGGCGGGGAGGGGGCTGGCCACAGGCCCTGGGCCACCCTGAGCCCGCCACGACCTCTCCCAGCCCCGCTCTCCCTCCCCAGCCGGAGGGCGGCGCCCGGAGGAGCCTTGGTGCTGCGGTTTGGAGGCGCCCGCAGGTGTCTGGACTCCGGGCA\n", - "CCCTGGCCTTGAACACAGGTCTCTTGAGCAAAGCTCTCCAAACTGGTTGGTTGATGAATGTAGTAACATATGTTTATTCAGAGTCTTTCAACAGCTTTGTAAATTTTTGGAGACAGGGTCCAAAGTGCAAAGCCTCAAGTTTTCTCTACATATTTACTAAAGATAGTTGTTTGTACAAGAGGCTTACAGGAATAAATGCA\n", - "AGACCGGCCTGGGACCAAATCACAGAGCTCAAGGCCTCCTGCCCTGTGAATCTAGATCATAGAACTAAGAGTCTGTGAACTTTGGACTCAAATTTTTCCCATGGGAACAAGGACAAAGGCAAACCTGTGTGAATGAAAACTGAGCCATTCCTGGTGGCAACTGCTGAATGAATATTATAGCACATGGAGAGAACATACTT\n", - "GACGAGGAGCCTCTGGGGAAAAAACTGAGCTGGTAAGCATTTACCGGGCAATCCAGCTGGGCAGGGGAGGGCAACCCAGACTGCCTGTCCAATGCCACCACCTCAGGCCAACTGCCCACCAGACAGAGAGGGTCACTGAGAGCCATAGAGCAGGGGAAAGCACGTGGGAAACAAGTGGGCTGGCGGGGGAGTTCACTCTG\n", + "AATTCTGTCACTTACTGAATGCATGCTGGCCTCTCTTCTATTAACCGCAATGTTAGACAAGAGAACATGGTAAGTTAATGATTGTAATTTATTTACTGAATGAATAGATTGATGCCATTGTTTTTGAAGAGCAAAGTCCTCTGTGGGCACATATTGGATAAAGCCTCTCTGTACAACCTCTCTTCTGCATCTCAGTTTCC\n", + "GTCATGGGAAAAGGCTTTCGCTTTGTGTCGCTGGTGCAAAGGGCAGTCCCTGGAGGTGACAGGCCTCCACTGCAGCCTCCATTCACACGCAGGGCAGAGGCCATTTTCAGCCGCATCTTCCCAGTTCATTCCTGATCCAGAGGTCACATATGTTAATCATTAAACAGAGTAGTAGGTGCTTAGTACCTGACAGCTCTGCA\n", + "CCAATGTCCTCCCAGCAACTGCTAAACAAATATACTTTACAATGGCCCTCTGCCCTTTGAAGTCAGGTGGTCTAGGTTAATATTGAGGGCACTTGCTATCAGAAGGAAAGTCATCCCAGCCAGAGGCCGGTTTGCCCTTGCAACAGCCACCAGTGTTTGATTTACAGGGAAGCGAGGCCAAGGGACATGGAGAAGGCACT\n", + "GTAATCCAAGATAAAAAACATGGCCATGCAATAGCTCTACACAGTTGAAGTTCAAACACTGACCTCAAAGGATCCACATGGTCCATAGGCTGAGTGAGAGGCAGGGCATGGTGCCACAGATGACTTTGCCGTGGGCATCTGAGGTATCAATGCCACTTTTTTTTTTTTTTGAAACAAAGTCCAGCTTTGTTGCCCAGGCA\n", + "AGTCCATTTTGCTCATGAAGCCCTGTACTTTACAGCCCAAGGTGCAAGGGATGTCACCATCAGACAGTAAGAAGCACATAGAATTGTGTAACTCTGGTCAGAGGTGAGGTCAAAGTGCAGGTCAGGTTATTTTTCCATTGACCTCTTTGCTCAATGGATACCCTGTGCTATATATAGTCTTTCTCAATATTGCAAAAGCA\n", + "GTCCTAACAGATAGTACAAAATAAAAGTATGTGCAACGCTTGTTTGCAGAGAACAGTTTACCCACATGTTTGTTTAACTCTCCTGGCACCAGGCCAAAGTTCAATAAATATTGAGCAATATGACCTTTGGCAAGTAGTTCTCCATGACCTGTGAATCCAGGCATTTCCATGAGGACGCCTGGGGTACATACAAATATGTA\n", + "CTGGGACTTGCTAGTGGATGTTGGGAAGCCAGTAAGTACCCTCCCTGGCCTTTGCTCTCCTGACGCCCACCACCAGGAACCGTCTGAGTCTGGGTAAACAGAAGCCTCTGTGCATGGATGGCACATTTATAAGCAACAGTCAGTGTCATGCAGCAGCCCAGGGTGTGAAGTGGACATTTGCTCAGAACAAGCTGTCTGAA\n", + "AAATTATTCTTGTTCCTGCACAGGCAAAATTCCAAAGACCTGTTTGCTCTTGACAAACACAATGCACCATGTAAATAGCTCTTATGACAAATAAATAGAATCTTCTGTGAATCTGCCTTTCAAGTATATGGCATTATTTGTTCTGTTTGCATCTCTGTAAATATTTGGCCAGGCTCAGGTGTCTAAAAAGAACCGAATAT\n", + "CTGAATCTTCTCTTTGAGACAAACTTATTCTGGTTTTTAGCTCCTGAGCATGGAAGCAGCATTAACATTCGTGTTCTCTGGTTCAAGGTCAGGATGATGGTGAGTAAACAGGTTTGTTTGCTTGATGACAACACACTTTGTCCCTTACATTTTTTGTTTTCATGTTCCCTTTGTTTTGCACTGTTTTTCCAGAGTTCTTC\n", + "CCTGGCATGTGTGCATACATGTTGTCAGCTCTGGCTCAGAGTTGGTCAAAGTCCGACTGGGTGACATTGCCATGAGGAGCATGACAGGGCCTGGGGGTCATGCTCCCCGTGCCTTCCTGCCCCCTTGCGTAACTCTGCCGCCCCCACCAGCAGGGGGCCGGGGGCAGGGAGGGCCTGGGCAGGGCCCACATGGCGCTGGA\n", "-----------------------\n", "\n", "--- Cell Type: K562 (K562_ENCLB843GMH.txt) ---\n", - "GAGTAATAGCTGTACAGTAGCCCAGGGAGGACTCCTGATCCATGCTTCCTCCCTTTCTTCTTCCTGGGATATCTTTCCTAGCCACTTTTGTTTCCCACATGCAGTAGTCAGATAAGGTCCCGCGAGTTCTCCATCAGCATCTAACCACATTGACATGCCAGTGCCTATATGAACACACTGGCTTGTTTCATCCTTCTCGA\n", - "TCCTTCATCTGCTGCCTCCATCCTTCCTTCCTCTTCTCCCTCCCTTCCTCTTCCTCTGTCTTTCTCTCCTCCTTCTTCTTTCGCCTTCCCTTGCCCTCTGCCTCCTGGCCCCGCCGACCTCTACTGCCTTTATCTCTGGCCCCAGCCCTGCGACTATCCCGGAGATGGGGGCGGGTCTGCTGTAGGGACAGGTTCTCTGG\n", - "GTAGAGCTCTCTTTCCTCTTTCTGTAAAGTGTAAGTATCCCCAAAGGCACAGCGGCCTGAGGGAGATAAGGTTTGGGCCAGAGACTGGTGCCAAGCAGCTGGTATTCGTCACTATCAGCAACAGCCAGTGTTCCAGTAATGACAGAAATAAAAACATGGGAGTCACATCTATTCTATTTGTCTGGTGACATTCTTGTATA\n", - "CCTCCTTCAGGTTGTCCTTTCCTCCTCCCACAAGGCTGCCTCCTGTTGATGGTGGGAGAGGCTCTGCCCCTACATTCCAGGAAGAGGGGGAGGATCTGGTTTGAGTCTCCTGATAACCTGCCCGATAAGGTCGCAGGCGCTGAGGGAGAGAAGAGAAGGGAGGATCGGACCAGGGAGATAAGATGGAGGCAAGGGTCAGC\n", - "TAGATAAATCCTAGAAAGCCACCGCTAATTGCCGAACCCTTATCTGTCCCAGAAACCTTCTTGGAAGATTAGATGGAGATTATGAGTGATCAGGGGCAAGTGGGAGGGACGTGGACGGACAGCCCCACCATGCACAGCAGGCGAATAAAAACTGTGAGGACCATTTAAAAGATAGAGTAGACAGAAGAGACAACAAGGCT\n", - "CTTCACCCACACTGGGATAGCGTGAACCCGTGGAAGCATGTTTGTGTCTTATCTCCTCACAGACCTTGCTGTCTTGGACATGCTCTATCTTGTCAATTCCTGATAACTGTCGACAAGGTCAGGCTATGTCATAAGACGCTCTGACCGTGACCTTGACTTTACTCTAACTCTCGTTATCAGCAGAAAAATTAAGTTAAAAC\n", - "GACCAGACGTTATCAGCTTGGGCCACTTAGGAGGCTCCAAGGAATTCAGTGGTGTGCTGTTGTGCGGAGCCTGTGCTTACAGGCCAGATGCCACACAGCGCCAAGGCTCTTATCTGTCCTCCCCCGCCCAGCAGCTCCTCCCTGGATGCTAGACTGACAGGTTTACTAGCATGCTGTCTTTACCTGATTCGCCCATCTTC\n", - "GGACTCAAAGGCGTTCTTATAATTTGCGATATGTGCTCTCAAACATTCACACACAACAACCAAGACCTGCGTCAAAATTTGCAACAGAGCCAAGACACAGGAAGTAGGGAGGGCCGTGGCCCGAGGCTCAGGGTTCTCAGAAGGACCGCGCTGCGCCGGCTCCCTTGCCAGTGAGTGTCCTCCTGCTTGTAGGTGAGTAG\n", - "ATCACAAGGCAGGCATCTTATCTGTTTGGCTCCAGAAGTCATTTGAACCAAGACCTTCTGGTTCTAAGATAAGCAAGCCAGCAATGATAAGAAGATATTGCTTACGTCATTTCTCGTCTGTGGGGTTAGAGAAGACCATCTGGGACTGAGGAAATTGAGGCTGGGCGGAGAGACTTCAGTAAGTGAAGAAAAAAAAAAAA\n", - "TATGTGCCCAAGTTAGGGCTCATCTGTTTTCTAGCCCTGCAGTATAAGTCTTTTATCTCAGTGGCAGGCTGATTATACTATGTTGTCAGCACTTTTTCCACAAGAAAAATCTTGCTAGGTAAATGGCAGAAGCAAGCACCTTCCCCTCCATGCTCTGCCTGTGGATGCTATCTGTTGGTCTGTTCCTGTTCACATGCCTG\n", + "TAGATCCATCTCAAGGGCTTCCAGGATAAAATGAGCAAGCGACTGAGTAGACTATGTTACCTGCAAGCAAGCAGGGTGAATAGTCATGAACTGTCACCGTACAGCTGTATTCTGAATAAAGGGGCGCTCTGGTTAGCATCTCTGCTGCAGCTGAGGCCCCAGTTCTGAGGTCTTTGTATCACTTACCACCTTCTGAGTGC\n", + "TAGGGATGCGATGTAGGATATTAGGACAACTTTCACTGACTCATAACAGTAGACTGTGGTTTGACTCATTTTATCTCACCTGACATTGCTTTCCACAAACACACTCATATCTAGGTCTGGTGGAACTAAGAACAGGCACAAATAGATAAGAGGATAGATTTATTGCTTTCTTTCAGTCTGGGTCTGGAAGAACATGAGTA\n", + "TGGAAGATCATGGGGTCAGGGTTGTTTGATAGTTTATAAACCGCGAAGTATTCTCGAAACTAGCTTATCACCTAGAGACAGATGAGTCATACCTGTGCCTTGTCCGGCCTCTGGCCACTCAGCCACACCGAGTCTCCATATGGCAGAGGTCGAAGCTAGAAAATAAGCAAGCTCAACACGGAGCAGACACGTGCGAGGAA\n", + "AAGAGGTTTACTGCAGCTGTGAGGAGAAGGGCAAGTGAGCCTGCTGGAGTTAGGGCAGACCCTGCTATCTTAGCAGAAAGATTCAGGGCAGGGAAGCTGGCAGGTTTGTGTTAGCGTCTTATCAGGACTGTGCACACGCTTCAACACCTGAAGATTATTTATAGTGGCCACAGTCCGATAGCATTTCGCTGCAAGCTGAT\n", + "CACCTTCCTCCCCCTGCCTCCTCCTCTTTTCTAGTGATGAGTCATCATCTGTTGTTATGTCTTGACCTTCCTTATCTCAATGACTGGAACTCTGCATAGGAAGGGTCTCTTAGACAGGGGCCAAGCCCGGAGCTTCCACACTGGCCTTTATCTGAGCGCTGTCCTTACCAGGCTCTTGGTCTTGTGCTTCTTCCGTGCTA\n", + "ATCACCCCAGAGATGTCTTCTGAGTTGAGTGAGTGTGTATATGGCTGTGAGATAAGATAACAGTTTGTAAGGCAGAAACAGAGATAACATCACCCAGCTTATTGTGGCAGGAGGCATGGAGTTCTCTGCCCTGTGACGAGGGCAGCTGGGGACCTACTGACCTGGGATCGTTTACTAAGGCAGGACTTGCGTTCCTGCCT\n", + "TAAATACCACATCAGCAAATCCAACATGTGGACTCCTGGCTGAATAGGCCTTAGATCCACCACAAAGTGGGGGACTGCTCACACATGATGGGTTCAGATAGGGCCTAACTGCCACTCAGTCATTTCCTTCACTTCTGGCTTCTTTATTGCTTGATGTTCCTGTAAAAAGTGTCAAAATGGTCTGAAAGGCAGAGATAAAG\n", + "CTGTGTGTCAGCGGTGATGCTTCCACTGCCCACGAAGCCGTCTGGAGCAACTACCGCATGGGAGACCCAGGGCAGGAGATGGCAGGAGATGGGACGAGTCATGGCCCCTCCCTGCAAACCGCGAGAGTCACCGGGGCCCCGGAGCTGGGGCCGGACCGCAGCCTCACGCGCGGCTGTGGGCAGACCATGACAATTCACCT\n", + "ATCTTAATGATGAAATATATTCTGACTTATCAGCCAAAGGCGTGGGTGGACACTGACCTCACCAGAGACGGATAGAGCAAAGACTTATCATATCTTCACACATAGATAAGGATAATGCTGTCCTTGTGCTGCTAAAGATTATCTCTCAAGGTCAGAACTAACATTTCTCTAGACACAAGAAACATTCTCCCTAGTCCACA\n", + "TGGGCTCCTGATTTTGCATCTTATCTCTGTGTACAGTTATGGAGCAGTGGACCACAGGCCGCAGTTGCTCCCATGCTTCCACTATCTCACTCGCTGAAACCACGCGTAAGGCAATTCTCTATCTTCCGAGATAATGTGCTGAGGAAGACAAGGGTGGCCTGTGTAATCTATGGTGTCAGGTCAGCGTCCAGAACACAGGG\n", "----------------------\n", "\n", "--- Cell Type: hESCT0 (hESCT0_ENCLB449ZZZ.txt) ---\n", - "CACAGCTGAATACAAAAGCTCTCACAGTGGCGGAACTAAGGCTCCCTGCCCACACACAGCATCCAATGTCTCCTCTGCCTCTCAGCATTCATCCCCACAATAGCCATTGCTGTTAGGTGATGTAGATGGGGATTCCAGAGACAAATGCCTGTAGCTTCTAACTACATTCTTCTCCACTGGCAAGAGGGTTTGCAGGGTGC\n", - "TGTGTCTCCCAGGATGAAATGCCAAGCTTCACAATTGTGCGGCCGCGCCAGCTCACTGGGCTCCTGACAGCTGAGCCAAACTGAAATCAGCCCTGTGGCAGGTTAGCAGAGACATTCTATTCATTGTCTGTAACAATCTCCCAGTCATCTAAATGGGGCTAGAGGACAATCAGGACTGCAACAATGCTACCCTGACCGCT\n", - "ACTATGTAGAGAAATCTATCTTTTGTTGAAGGAATCTCCTAGGCTCTGGAGAGAGTTAGGGAAAAGCAGGAAGTTCTTGCAAGTATCTTATTACATTTCAAGCTATCAAGTGTGTCCTTTTTGAATGCAAATGCAGCAAGTACTTGGAGTCCTGAGGGAGTGGCAAATGTAAAAGCCTCTGTTGTGTGGACAGGAGAGTG\n", - "CTGTAGGGATGAGCTTAGGCAAAGTTCAACTCGGGAGGAGTAAAGCCTCTTAAATGCTTCCTTTACTCTGGGGCCCCCCCTGAGCTTCAGGTGAATAAGGAAGATTAGAGTCTCCTCCCAAGTCCGCTCCCATGGTGTGACATTACTAGCTCCCGGATGATACTCCAGTTTTCCCTGCGCCTCCCTGGCCTGGCCCCGCA\n", - "TACAAATTCCACAGAAGCTTGGCTAATGCCTCTCCAGGCTTCACTTATGCAAATCTGCAGGGAATCTGCCTGTTCTCTCTCCTTGGTTTCCTGTTACACTGGGTATTCCTTTTACTCCCTTCACAGTGGCGACTATGTTTGCAGGTGCACCACGTGGTTTGTGCTTAGGACTTCAGGAGGAAGGGGCATGTTGCACACTG\n", - "GAGCAGTAGGGCAGCAGGCCACAACCCAGCTTCAAGTTCCAAGCTCACAAATAATGAAGCACCCGGAACAACATCTGCATAAGTAAAGAGTTCTTCACTAAAGCATATAAAGCAAGCACTTGAGGGTGGTGCCTATGATGCAAGGTGCCTGAACAAGTTGCATCCAAGTTCTGAGGACACAATGGGGTGCGAGAGACTCT\n", - "CACTTTCCCAGCGCCTTCCAGCTTCTTACTCTGCGTGTGTGTTCTATCCCCATCTCTCCCACATGGGAGCTTTTAGCAGCTCAGCTCACCACAATGGGAAGCAAAGAAAAGAATGGCACCTGCAAAGGGGGCTTGTTCAGGAAAAGATGCCCTCTCTTGGTCAATTGGAAACGGAAAAAAATTGGCAATTTGGAAGGGGA\n", - "CTCTCCCAGTGGGGGCTTGACCAGAGCTTACAGAGGCTAATTATATCAGCAGGAAGATAAACAGGGGATGGGTCAGTGCTGTTGTTTTGCATAACAAAAGAGGAAGACAGCAGGGCTCTCCATGCAGCCCAGAGTCCTTAGGAGTCATTTCATTCCAGCTGGGGCACGTAGCGCACTTCTATAATTCAAGCACTAAACAC\n", - "TCCTGCGGTGGGTGGCTCCAACATCAGCTGTTTGCTCTTTCCCCTTAGTGGCCACCCCCTGCTGTGAGACCAGGCTGGTGCCTGAATCCAGGGGGAGGATCAAAGGGGACAAGGGAGGCTGTGGAGATTACTTCTCCTCAATTTGTTTACTTAAGTAGTTTTGAGCACAGTAGGCACTAATTGTTGACATTGGCTGTGGG\n", - "CTAGTAAGAATGAAGTGTTGGGAGCAGTTGTTTGAGGAATTGAGTGGGAGGTGAGGGAGAGAAGTAGCCTCTCAAGAACACATGAGCTGGAAAGAACCCTGGTTATTGTGATGCTTATCTTGCAGAAGGGCACAGCAGGAACACAAGCAGGTGAGAAGACACAAAAAGCTTTTCTTTGCTTCAATTAGCACTCCTCTGTA\n", + "GACCTAGATCTGACACCCCTCTAGCAAGCTCACAAGAGCGCACACCCTTGCCTGCAGCTTACTCACAGCTGGGATCCCTTCTCCTTGCATATGTGAATTTCTGCTATGAAGTCCTCGGTTCATGTTTTGCAGATTTGCAACAGGGCCCAACTGTTGTGATTGACACCTCCATTTTCTTTCTCATCTTAGGTTAAGGGCAT\n", + "ATGAGGCTTGAGCTCAGCAAAGCTTAGAAAGGAGAAAGAGCCTTCCCTGCCAATCCCAACTTGAATTTAGAATGCCTGTCTGCTCTTACAAAAGCCCAGGGAGGCCAGAGAAAAACTTGCTTCCAGGAACTGCAGGCATTCATGCTGACTCTCCCCAGCTAGTAATTAAGCTAAATATACCCTGATGAATAACAAAGAAG\n", + "CAGGGAGGAAAAAGGCATGCATGAAGTACAGAAGGAGCATCTTCCTCAGGGGGGAGTGAGCTGTGAGATTCTCTCCCACTGAGCCCTCCCAGCAAAGGCGGGACTTGAGTCAAATAGGCCCCCAAGCAAGGGGGAGAAGGGCCTGCCTGGAATTCATGGCCAGAAATCACCTGACACTGGGAAGGGGGCACGGTTTGATG\n", + "ACAACCCCTCCCTCTTCACAGGCATAAGTATTCAGGAATGCAGTGATTTCAAGACACAGGAGGAGAAGCAGACTGTTGCATTTGCATATGAAAAGAGAGCTGGGAAGATGTCCACAGGGGGTGTCTCTCTGGGAGTGGGGCGACAGTACAAGCTAGGCTGCCTTGGGCAGGGCAGGCAAGGAGGGAAGGGGAAAGAAGGG\n", + "CCTCACAATGTATCAGGCTCCCAGGTGGCTCTGTAACCTGGTAGCTGAACTAATACTTGCCACAACTGCACAAGTCCTCCCCTGCCCTTAGAATTTGCCCTTCAAGAGTTTCCTCTTTTGATTGCCTTTTCTCTTCTTGTTCAATTAACACTTAAGTGATTGGCCAGTTCCTCTGAAATGTTAACTACCTACTATGTTAG\n", + "ACTCCCTGAGGGCACTAAGCATGCCTGGACTTAGGGAAAGAGGGGTATAAAGGAAAAGGAAGGTCACACACTAGGGGTTCCTCTGATCCAGGTGATTGCCTTTCCCCTTGCGCCTTCCCAGCAATTAGATTGCTAATGTTGCCCAGCTGGAGTTCAGTGCTTAATGCCATGGGAATGGGATGTGGGTTTTAGCAGGAGGG\n", + "GGAACCTTGGATTCTTTGTCTCAGAGCTACTTACCCTGCACCACGTGTGGGCCCATAAAGCCTCCCTCAGTGCACAGTGTGAGGCTGCAAGCCTCCTCCAGGAGCAGGAACCTGGGGGGCCCCTTACAAATGCAGATTCCATCCCCTCTCTCTAGTACTCAATACATCTCAGTATATGGGCAAAGACCTGCACTTGTAAC\n", + "ACCTGGCAGCGAGTGTTAAGGATGCTGAATGGGGGCCACCAGCAGGTTAAAGGACATAGGAGAGAGATTGCCTTGGGAGATAGTTGCCTGAGATCCCACAGAGAGCTTGAGAGTCAAAGGAGTGGGTCCTGCTATGTGTGAGAATTGGCAGCGGGATCCAAGCTTCTCAAGAGAAAAGCAGTGTGATACCATCATTTCCA\n", + "CAGCCCTCCTTACAGGGGAGGAACTGGGATTTGTCCCCCCCACAGGTCCTCACATTGAGGCCATTGTATCTGTTCTGTGGGGAGTCTTTAGTTCTTTTGTTTTTTCCTTGGGGTGCAGGCAGGGGGGTGGGAAGCAGGATTTTTGTTTCCTCAGATGTTGCCATCAGAAGAGACTGAGGAGGATGCAGAAATAGGAAAAA\n", + "ACCCCACAGTGCCTGCTTACAGCTACAGTTTGTCTCCGGAAGGCCCACCTCCCAGCAGGATGTCATGAACGCAGGCTCCATGTCAGCAGCACCAAGGGTGCTGGTGGCTTTTAGGCCCACTAGGGCAGGGCAGGTGTGAAGGGGTCAGGAGCGCAGCAGTCTCGGAGCAGGGGATGAGGGGTGGAGTACGGTAGGCGACC\n", "------------------------\n", "\n" ] @@ -227,13 +224,13 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "DJ7HaDUNbVQz", - "outputId": "363b4d42-88d2-429c-d2b2-fb5799827bba" + "outputId": "f1c9d977-753e-423e-bf48-fac18d37ce6c" }, "outputs": [ { @@ -244,11 +241,12 @@ " _target_: src.dnadiffusion.models.pretrained_unet.PretrainedUNet.from_pretrained\n", " pretrained_model_name_or_path: ssenan/DNA-Diffusion\n", "data:\n", - " _target_: src.dnadiffusion.data.dataloader.get_dataset\n", + " _target_: src.dnadiffusion.data.dataloader.get_dataset_for_sampling\n", " data_path: data/K562_hESCT0_HepG2_GM12878_12k_sequences_per_group.txt\n", " saved_data_path: data/encode_data.pkl\n", " load_saved_data: true\n", " debug: false\n", + " cell_types: null\n", "diffusion:\n", " _target_: src.dnadiffusion.models.diffusion.Diffusion\n", " timesteps: 50\n", @@ -263,11 +261,11 @@ "Model sent to cuda\n", "Found cell types: ['GM12878_ENCLB441ZZZ', 'HepG2_ENCLB029COU', 'K562_ENCLB843GMH', 'hESCT0_ENCLB449ZZZ']\n", "Generating 1 samples for cell GM12878_ENCLB441ZZZ\n", - "100% 1/1 [00:02<00:00, 2.31s/it]\n", + "100% 1/1 [00:02<00:00, 2.35s/it]\n", "Generating 1 samples for cell HepG2_ENCLB029COU\n", - "100% 1/1 [00:01<00:00, 1.83s/it]\n", + "100% 1/1 [00:01<00:00, 1.86s/it]\n", "Generating 1 samples for cell K562_ENCLB843GMH\n", - "100% 1/1 [00:01<00:00, 1.84s/it]\n", + "100% 1/1 [00:01<00:00, 1.87s/it]\n", "Generating 1 samples for cell hESCT0_ENCLB449ZZZ\n", "100% 1/1 [00:01<00:00, 1.88s/it]\n" ] @@ -279,13 +277,13 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "RRgGvJyDPXMv", - "outputId": "f911ff6f-5939-4199-ce56-f6aaeeb93828" + "outputId": "b2b6fb29-db46-47e7-ae45-90b2b08845f5" }, "outputs": [ { @@ -295,19 +293,19 @@ "Displaying sequences from: data/outputs\n", "\n", "--- Cell Type: GM12878 (GM12878_ENCLB441ZZZ.txt) ---\n", - "AAATTAGTATGTGAATGATTATGAGCTAAAAGAAGTTTCTATTAGTCATTTCTGTGAATAGGATGACTCTTGGTTTTAATATCTGAAACCATAAGTTTCAGTTCCTGTTTTCGGATTTTGAGTATGACATGTATGATACTCTAAGAAAACTCTAGACTAGCCACATTTTATACACTTTTAATATGCATTCTCAGTAGAGT\n", + "ACTCTCCAGCGTTGGGTTGGGGATGTCTGCAGATCTGGGTAATGTGCTCGACAGTAAGATTGAAACTGAAACTGAAACTAGAAAAGAGGAACTGAAACCAGCAGCACTGAGAAAACCCCAGACAGAACATTAGTTTCAGTTGCGGTATGTAACTCATATGACTCTAAGCAGTTACACTTTTGGGCATGGAAGCCTAACTC\n", "-------------------------\n", "\n", "--- Cell Type: HepG2 (HepG2_ENCLB029COU.txt) ---\n", - "TGGTTTAAGCACTAGTGCACACTACTCTCTATCCAGAAGTTTGCTTTGATGTAATGTGTACATATTTCACGTGGACTCTGGACTTTGGACTCTGTACATGATAGTGTTCTCATATGTTTACTTTTCACATTCCAGAAAATAAATACTGCACAGTCTCCTCTATAACAGTGTGCACTTTGTATGCTCTTAAGCTTTGTCCT\n", + "TAGATACACCTGATGTACAAATATTCCATGCACATGTTCACATTCCCACAGTTAATAATTGCGCAAGAGATCAAAGTTCAGGTACTATAAATACTCCCCCTTGCACAATACTACTATTAGTTTTACAGACACAATGTAAATATTGAGCAATACACTCTAGAGGTCTGGAGTTTTAGCAGGGAACTTTTCTTTTAGGGAGT\n", "-----------------------\n", "\n", "--- Cell Type: K562 (K562_ENCLB843GMH.txt) ---\n", - "ACACACTCTTATCTCTTACTGCAGAAGATGGTTATGCGCCGCTATCTTGCTTATCACTTCTTGCACATCCTGATACTTCACACACCCTGCATCATATGATCCCACACTGCCTGCTTATCTATGACTGTTTCAATGCATGTTATCTCTCATGTCCAGACAACCAAGTGACCACAGGACCTTGCCAGCCTCAGAAACAGGCC\n", + "CTCTTGATTTGCCTCCTTGTCTTCCCTCCTCGGCCCCCTCCCTCTGGTCTTCTTCTTATCTCTCCTGTGGACCGTTATCTCTCGGGCCTGCATGCACCTTATCTGCTCACTGGCAGGCCTCCCTTATCTCTGATCTTGCATGTGCCACTGCCTCACAATCTTATATTCTACGTCACCCACAACACTCGCCTCGTCAATGG\n", "----------------------\n", "\n", "--- Cell Type: hESCT0 (hESCT0_ENCLB449ZZZ.txt) ---\n", - "AGCACCCTACAGGGCCTGTAAAACAGATTATGGCCTGCAGGGGCAGCCTTTGCACTTCCCAGCCACCCCTAAATGCTCTTCATTCAAATGCAGACGTTGGTCTTGAATGCCCCTCAGGAGGCCTTCCCCTCAGGGCAATGGGATCCCAAGAGGAGGCAGTTAAACAGCAGGAGAGGATGTTTTGCAGGGGATGGGTACTT\n", + "AAGGAGGCAGTCCTAAGGAAGGGAATTCAACAAGACATTAGTGTTCCATAGAGGAAGAAGATGGCAACATTCCTTCTCCCTGCTGGGCTCCTGAGACTAACAAAGGAAGGAGAGCAATGGGGATCAATTGGATTTCTGGGGTGCTTACCAATGGAAAAGCCTTGAGGGCTCCTGGGGCCACAGTAGCTTTACAACTCTGC\n", "------------------------\n", "\n" ] @@ -316,6 +314,243 @@ "source": [ "display_sequences()" ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "L7vfEie7qh7Z" + }, + "outputs": [], + "source": [ + "!rm data/outputs/*.txt" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6RnXeJAtmOYO" + }, + "source": [ + "# Generating sequences for a specific cell type" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pp6ubIOWpmfh" + }, + "source": [ + "The previous examples generate sequences for each cell type used to train the model. We can also generate sequences for a subset of the available cell types, which can be accomplished via CLI overrides. The desired cell types can be provides as comma separated string or list." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mRfZkCfUqA8s" + }, + "source": [ + "Generating just K562 sequences:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "c1VVdZDImQuQ", + "outputId": "727ad3ac-f7dc-48d4-af53-30b370d8de85" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model:\n", + " _target_: src.dnadiffusion.models.pretrained_unet.PretrainedUNet.from_pretrained\n", + " pretrained_model_name_or_path: ssenan/DNA-Diffusion\n", + "data:\n", + " _target_: src.dnadiffusion.data.dataloader.get_dataset_for_sampling\n", + " data_path: data/K562_hESCT0_HepG2_GM12878_12k_sequences_per_group.txt\n", + " saved_data_path: data/encode_data.pkl\n", + " load_saved_data: true\n", + " debug: false\n", + " cell_types: K562\n", + "diffusion:\n", + " _target_: src.dnadiffusion.models.diffusion.Diffusion\n", + " timesteps: 50\n", + " beta_start: 0.0001\n", + " beta_end: 0.2\n", + "sampling:\n", + " checkpoint_path: ssenan/DNA-Diffusion\n", + " sample_batch_size: 1\n", + " number_of_samples: 1\n", + " guidance_scale: 1.0\n", + "\n", + "Matched 'K562' to 'K562_ENCLB843GMH'\n", + "Model sent to cuda\n", + "Found cell types: ['K562_ENCLB843GMH']\n", + "Generating 1 samples for cell K562_ENCLB843GMH\n", + "100% 1/1 [00:02<00:00, 2.50s/it]\n", + "Displaying sequences from: data/outputs\n", + "\n", + "--- Cell Type: K562 (K562_ENCLB843GMH.txt) ---\n", + "ACTCAGGATCCTTTGTGAGTGTCTTTGGGGTCTGCTGTTATCTGCGGTTTCTGTGGCTAGATTCTCTCTTTTCAGAGGGTCAAGATGCGTCTGCTGATCAAGTCAGAAGAAGTGGGAGTGTAGGAGCTGCAAACTGAAAGCCTCTCTCGGATATGTGTGTTTTGAAGATACCGTGGAATACAGGAATGTGACATAGAGAA\n", + "----------------------\n", + "\n" + ] + } + ], + "source": [ + "!uv run sample_hf.py data.cell_types=K562 sampling.number_of_samples=1 sampling.sample_batch_size=1\n", + "display_sequences()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Um-McXxrqq_8" + }, + "source": [ + "Generating both K562 and GM12878 sequences using a string CLI override" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "b7urnWY1qqlG", + "outputId": "90cd15c6-8cae-4508-af80-dcf2f278fc96" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model:\n", + " _target_: src.dnadiffusion.models.pretrained_unet.PretrainedUNet.from_pretrained\n", + " pretrained_model_name_or_path: ssenan/DNA-Diffusion\n", + "data:\n", + " _target_: src.dnadiffusion.data.dataloader.get_dataset_for_sampling\n", + " data_path: data/K562_hESCT0_HepG2_GM12878_12k_sequences_per_group.txt\n", + " saved_data_path: data/encode_data.pkl\n", + " load_saved_data: true\n", + " debug: false\n", + " cell_types: K562,GM12878\n", + "diffusion:\n", + " _target_: src.dnadiffusion.models.diffusion.Diffusion\n", + " timesteps: 50\n", + " beta_start: 0.0001\n", + " beta_end: 0.2\n", + "sampling:\n", + " checkpoint_path: ssenan/DNA-Diffusion\n", + " sample_batch_size: 1\n", + " number_of_samples: 1\n", + " guidance_scale: 1.0\n", + "\n", + "Matched 'K562' to 'K562_ENCLB843GMH'\n", + "Matched 'GM12878' to 'GM12878_ENCLB441ZZZ'\n", + "Model sent to cuda\n", + "Found cell types: ['K562_ENCLB843GMH', 'GM12878_ENCLB441ZZZ']\n", + "Generating 1 samples for cell K562_ENCLB843GMH\n", + "100% 1/1 [00:02<00:00, 2.32s/it]\n", + "Generating 1 samples for cell GM12878_ENCLB441ZZZ\n", + "100% 1/1 [00:01<00:00, 1.87s/it]\n", + "Displaying sequences from: data/outputs\n", + "\n", + "--- Cell Type: GM12878 (GM12878_ENCLB441ZZZ.txt) ---\n", + "TGCTCTCTGCATGTGGGTAATTTGTTAAACTAATGCTCAACACTCACTATCTATGCAACTGCATGTATTGCGGTAGAACAGTTTCTGTGTTCACAAAAGCAGGAACTTGGCTTCTGTTGGCAGTACCCTGGGTGACTGAGGATGTGGGGGGTATTGTATGCTGTCATGCTGAAACCCACAGAAGACTCTGAGAAGGCCAG\n", + "-------------------------\n", + "\n", + "--- Cell Type: K562 (K562_ENCLB843GMH.txt) ---\n", + "TTCTGTCTTATCTTGGAGCCTGATATGTTTCTGGCTGCTCAAGATAATCCCTTGTCATTCTTTATCAGGTTGAACTGATACATGACACAGGCACTGTTCCACACCTTCATGTTTTAGGAGATAAAGGACAAGACGGGTCATGCGTTCCTGTTAACATGCCTGTTCTGCTGTATTTGTTCTAGAAAAACAAGATTCTTGAA\n", + "----------------------\n", + "\n" + ] + } + ], + "source": [ + "!uv run sample_hf.py 'data.cell_types=\"K562,GM12878\"' sampling.number_of_samples=1 sampling.sample_batch_size=1\n", + "display_sequences()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EtRo3GRKq1Ez" + }, + "source": [ + "Generating both K562 and GM12878 sequences using a list CLI override" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "nFeavUEhq0pK", + "outputId": "34ba307f-b16f-43ed-8127-90ddf80be507" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model:\n", + " _target_: src.dnadiffusion.models.pretrained_unet.PretrainedUNet.from_pretrained\n", + " pretrained_model_name_or_path: ssenan/DNA-Diffusion\n", + "data:\n", + " _target_: src.dnadiffusion.data.dataloader.get_dataset_for_sampling\n", + " data_path: data/K562_hESCT0_HepG2_GM12878_12k_sequences_per_group.txt\n", + " saved_data_path: data/encode_data.pkl\n", + " load_saved_data: true\n", + " debug: false\n", + " cell_types:\n", + " - K562\n", + " - GM12878\n", + "diffusion:\n", + " _target_: src.dnadiffusion.models.diffusion.Diffusion\n", + " timesteps: 50\n", + " beta_start: 0.0001\n", + " beta_end: 0.2\n", + "sampling:\n", + " checkpoint_path: ssenan/DNA-Diffusion\n", + " sample_batch_size: 1\n", + " number_of_samples: 1\n", + " guidance_scale: 1.0\n", + "\n", + "Matched 'K562' to 'K562_ENCLB843GMH'\n", + "Matched 'GM12878' to 'GM12878_ENCLB441ZZZ'\n", + "Model sent to cuda\n", + "Found cell types: ['K562_ENCLB843GMH', 'GM12878_ENCLB441ZZZ']\n", + "Generating 1 samples for cell K562_ENCLB843GMH\n", + "100% 1/1 [00:02<00:00, 2.34s/it]\n", + "Generating 1 samples for cell GM12878_ENCLB441ZZZ\n", + "100% 1/1 [00:01<00:00, 1.87s/it]\n", + "Displaying sequences from: data/outputs\n", + "\n", + "--- Cell Type: GM12878 (GM12878_ENCLB441ZZZ.txt) ---\n", + "CTGCAGAGCGGGAGTGCCGGATGCCTGCACCTAATTAACACGACCGGCACATCTTTCGCGGGAAATCCCCTGAAGTGCTGAGGTGGCAACCGAGAACAGGTCAGGCGAGCAGCCAGGAGGGCCGGTTGCACTCTGCCCCTGCCAGTGAGGCTGCCCTACAGCATGACCAGCCAGGCTGGGAAGCTGATAGAAGTGCTGGG\n", + "-------------------------\n", + "\n", + "--- Cell Type: K562 (K562_ENCLB843GMH.txt) ---\n", + "TACACCCAGACAAGCTCTCTATCCCTCACCTTTTCTCTTGCTGTGATGTTATCAGAGCCTCACAGCTTGTCAGAGCGAGAGAGCTCCCATATGGCTTGAAGACTGGGACGGATGCAGGAGTCCAGGGTGAGGTCAGCTGTGTACCAGTGATAAAGACACATTTTCCAGAAGGGCAGTGCAGTCTTGAACTGTGACAAAGC\n", + "----------------------\n", + "\n" + ] + } + ], + "source": [ + "!uv run sample_hf.py 'data.cell_types=[K562,GM12878]' sampling.number_of_samples=1 sampling.sample_batch_size=1\n", + "display_sequences()" + ] } ], "metadata": { diff --git a/sample.py b/sample.py index 458da7ae..6c0adfa6 100644 --- a/sample.py +++ b/sample.py @@ -15,18 +15,6 @@ def sample( number_of_samples: int, guidance_scale: float, ) -> None: - print(data) - """numeric_to_tag_dict, cell_num_list, cell_list = ( - data["numeric_to_tag"], - data["cell_types"], - list(data["tag_to_numeric"].keys()), - ) - """ - - numeric_to_tag_dict = data[-1] - cell_num_list = data[-2] - - # Load checkpoint print("Loading checkpoint") if checkpoint_path.endswith(".safetensors"): checkpoint_dict = ( @@ -38,22 +26,25 @@ def sample( if torch.cuda.is_available() else torch.load(checkpoint_path, map_location="cpu") ) - # Load unet state dict - model.model.load_state_dict(checkpoint_dict["model"]) + model.load_state_dict(checkpoint_dict["model"]) + + device = "cuda" if torch.cuda.is_available() else "cpu" + model = model.to(device) + print(f"Model sent to {device}") + + numeric_to_tag_dict = data[-1] + cell_num_list = data[-2] - # Send model to device - print("Sending model to device") - model = model.to("cuda") if torch.cuda.is_available() else model + print(f"Found cell types: {[numeric_to_tag_dict[i] for i in cell_num_list]}") - for i in cell_num_list: - print(f"Generating {number_of_samples} samples for cell {numeric_to_tag_dict[i]}") + for cell_type in cell_num_list: + print(f"Generating {number_of_samples} samples for cell {numeric_to_tag_dict[cell_type]}") create_sample( model, - cell_types=cell_num_list, + cell_type=cell_type, sample_bs=sample_batch_size, conditional_numeric_to_tag=numeric_to_tag_dict, number_of_samples=number_of_samples, - group_number=i, cond_weight_to_metric=guidance_scale, save_timesteps=False, save_dataframe=True, @@ -64,7 +55,6 @@ def sample( @hydra.main(config_path="configs", config_name="sample", version_base="1.3") def main(cfg: DictConfig) -> None: print(OmegaConf.to_yaml(cfg)) - sampling_setup = {**cfg.sampling} model = hydra.utils.instantiate(cfg.model) data = hydra.utils.instantiate(cfg.data) diffusion = hydra.utils.instantiate(cfg.diffusion, model=model) @@ -72,7 +62,10 @@ def main(cfg: DictConfig) -> None: sample( data=data, model=diffusion, - **sampling_setup, + checkpoint_path=cfg.sampling.checkpoint_path, + sample_batch_size=cfg.sampling.sample_batch_size, + number_of_samples=cfg.sampling.number_of_samples, + guidance_scale=cfg.sampling.guidance_scale, ) diff --git a/sample_hf.py b/sample_hf.py index 9763890f..2b7d6494 100644 --- a/sample_hf.py +++ b/sample_hf.py @@ -27,11 +27,10 @@ def sample( create_sample( model, - cell_types=cell_num_list, + cell_type=cell_type, sample_bs=sample_batch_size, conditional_numeric_to_tag=numeric_to_tag_dict, number_of_samples=number_of_samples, - group_number=cell_type, cond_weight_to_metric=guidance_scale, save_timesteps=False, save_dataframe=True, diff --git a/src/dnadiffusion/data/dataloader.py b/src/dnadiffusion/data/dataloader.py index d76c4d81..5797e5b7 100644 --- a/src/dnadiffusion/data/dataloader.py +++ b/src/dnadiffusion/data/dataloader.py @@ -45,6 +45,49 @@ def get_dataset( return train_data, val_data, cell_num_list, numeric_to_tag_dict +def get_dataset_for_sampling( + data_path: str, + saved_data_path: str, + load_saved_data: bool, + debug: bool, + output_path: str | None = None, + cell_types: str | list[str] | None = None, +) -> tuple[Dataset, Dataset, list[int], dict[int, str]]: + train_data, val_data, cell_num_list, numeric_to_tag_dict = get_dataset( + data_path, saved_data_path, load_saved_data, debug, output_path + ) + + if cell_types is None: + return train_data, val_data, cell_num_list, numeric_to_tag_dict + + if isinstance(cell_types, str): + if "," in cell_types: + cell_types = [ct.strip() for ct in cell_types.split(",")] + else: + cell_types = [cell_types] + + tag_to_numeric = {tag: num for num, tag in numeric_to_tag_dict.items()} + + filtered_cell_nums = [] + for cell_type_query in cell_types: + if cell_type_query in tag_to_numeric: + filtered_cell_nums.append(tag_to_numeric[cell_type_query]) + else: + matches = [tag for tag in tag_to_numeric.keys() if cell_type_query.lower() in tag.lower()] + if len(matches) == 1: + filtered_cell_nums.append(tag_to_numeric[matches[0]]) + print(f"Matched '{cell_type_query}' to '{matches[0]}'") + elif len(matches) > 1: + print(f"Warning: '{cell_type_query}' matches multiple cell types: {matches}. Please be more specific.") + else: + print(f"Warning: Cell type '{cell_type_query}' not found in dataset. Available types: {list(tag_to_numeric.keys())}") + + if not filtered_cell_nums: + raise ValueError(f"No valid cell types found. Available types: {list(tag_to_numeric.keys())}") + + return train_data, val_data, filtered_cell_nums, numeric_to_tag_dict + + def get_dataloader( dataset: Dataset, batch_size: int, diff --git a/src/dnadiffusion/utils/sample_util.py b/src/dnadiffusion/utils/sample_util.py index a2dd46da..422a3d78 100644 --- a/src/dnadiffusion/utils/sample_util.py +++ b/src/dnadiffusion/utils/sample_util.py @@ -8,11 +8,10 @@ def create_sample( model: torch.nn.Module, - cell_types: list[int], + cell_type: int, sample_bs: int, conditional_numeric_to_tag: dict, number_of_samples: int = 1000, - group_number: list | None = None, cond_weight_to_metric: int = 0, save_timesteps: bool = False, save_dataframe: bool = False, @@ -23,11 +22,7 @@ def create_sample( final_sequences = [] num_batches = number_of_samples // sample_bs for n_a in tqdm(range(num_batches)): - if group_number: - sampled = torch.from_numpy(np.array([group_number] * sample_bs)) - else: - sampled = torch.from_numpy(np.random.choice(cell_types, sample_bs)) - + sampled = torch.from_numpy(np.array([cell_type] * sample_bs)) classes = sampled.float().to(model.device) if generate_attention_maps: @@ -35,7 +30,7 @@ def create_sample( classes, (sample_bs, 1, 4, sequence_length), cond_weight_to_metric ) # save cross attention maps in a numpy array - np.save(f"cross_att_values_{conditional_numeric_to_tag[group_number]}.npy", cross_att_values) + np.save(f"cross_att_values_{conditional_numeric_to_tag[cell_type]}.npy", cross_att_values) else: sampled_images = model.sample(classes, (sample_bs, 1, 4, sequence_length), cond_weight_to_metric) @@ -60,7 +55,7 @@ def create_sample( if save_timesteps: # Saving dataframe containing sequences for each timestep pd.concat(final_sequences, ignore_index=True).to_csv( - f"data/outputs/{conditional_numeric_to_tag[group_number]}.txt", + f"data/outputs/{conditional_numeric_to_tag[cell_type]}.txt", header=True, sep="\t", index=False, @@ -69,6 +64,6 @@ def create_sample( if save_dataframe: # Saving list of sequences to txt file - with open(f"data/outputs/{conditional_numeric_to_tag[group_number]}.txt", "w") as f: + with open(f"data/outputs/{conditional_numeric_to_tag[cell_type]}.txt", "w") as f: f.write("\n".join(final_sequences)) return diff --git a/train.py b/train.py index e099fdec..49178d3d 100644 --- a/train.py +++ b/train.py @@ -125,11 +125,10 @@ def train( for i in cell_num_list: create_sample( model, - cell_types=cell_num_list, + cell_type=i, sample_bs=sample_batch_size, conditional_numeric_to_tag=numeric_to_tag_dict, number_of_samples=number_of_samples, - group_number=i, cond_weight_to_metric=1, save_timesteps=False, save_dataframe=True,