Skip to content

Commit c5f79ae

Browse files
authored
Parallel plot refactor (#247)
* Fix duplicate plotting in CRISPRessoBatch aggregate * Refactor mulltiprocessing plots in CRISPRessoBatch * Refactor multiprocessing plots in CRISPRessoCORE * Refactor multiprocessing plots for CRISPRessoAggregate
1 parent 4ed5e24 commit c5f79ae

4 files changed

Lines changed: 161 additions & 390 deletions

File tree

CRISPResso2/CRISPRessoAggregateCORE.py

Lines changed: 31 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import glob
1010
from copy import deepcopy
1111
from concurrent.futures import ProcessPoolExecutor, wait
12+
from functools import partial
1213
import sys
1314
import argparse
1415
import numpy as np
@@ -18,7 +19,7 @@
1819
from CRISPResso2 import CRISPRessoShared
1920
from CRISPResso2 import CRISPRessoPlot
2021
from CRISPResso2 import CRISPRessoReport
21-
from CRISPResso2.CRISPRessoMultiProcessing import get_max_processes
22+
from CRISPResso2.CRISPRessoMultiProcessing import get_max_processes, run_plot
2223

2324

2425
import logging
@@ -108,6 +109,13 @@ def main():
108109
process_pool = ProcessPoolExecutor(n_processes)
109110
process_results = []
110111

112+
plot = partial(
113+
run_plot,
114+
num_processes=n_processes,
115+
process_pool=process_pool,
116+
process_results=process_results,
117+
)
118+
111119
#glob returns paths including the original prefix
112120
all_files = []
113121
for prefix in args.prefix:
@@ -491,13 +499,10 @@ def main():
491499
'quantification_window_idxs': include_idxs,
492500
'group_column': 'Folder',
493501
}
494-
if n_processes > 1:
495-
process_results.append(process_pool.submit(
496-
CRISPRessoPlot.plot_nucleotide_quilt,
497-
**nucleotide_quilt_input,
498-
))
499-
else:
500-
CRISPRessoPlot.plot_nucleotide_quilt(**nucleotide_quilt_input)
502+
plot(
503+
CRISPRessoPlot.plot_nucleotide_quilt,
504+
nucleotide_quilt_input,
505+
)
501506

502507
plot_name = os.path.basename(this_window_nuc_pct_quilt_plot_name)
503508
window_nuc_pct_quilt_plot_names.append(plot_name)
@@ -529,13 +534,10 @@ def main():
529534
'quantification_window_idxs': include_idxs,
530535
'group_column': 'Folder',
531536
}
532-
if n_processes > 1:
533-
process_results.append(process_pool.submit(
534-
CRISPRessoPlot.plot_nucleotide_quilt,
535-
**nucleotide_quilt_input,
536-
))
537-
else:
538-
CRISPRessoPlot.plot_nucleotide_quilt(**nucleotide_quilt_input)
537+
plot(
538+
CRISPRessoPlot.plot_nucleotide_quilt,
539+
nucleotide_quilt_input,
540+
)
539541

540542
plot_name = os.path.basename(this_nuc_pct_quilt_plot_name)
541543
nuc_pct_quilt_plot_names.append(plot_name)
@@ -571,13 +573,10 @@ def main():
571573
'quantification_window_idxs': consensus_include_idxs,
572574
'group_column': 'Folder',
573575
}
574-
if n_processes > 1:
575-
process_results.append(process_pool.submit(
576-
CRISPRessoPlot.plot_nucleotide_quilt,
577-
**nucleotide_quilt_input,
578-
))
579-
else:
580-
CRISPRessoPlot.plot_nucleotide_quilt(**nucleotide_quilt_input)
576+
plot(
577+
CRISPRessoPlot.plot_nucleotide_quilt,
578+
nucleotide_quilt_input,
579+
)
581580

582581
plot_name = os.path.basename(this_nuc_pct_quilt_plot_name)
583582
nuc_pct_quilt_plot_names.append(plot_name)
@@ -633,15 +632,10 @@ def main():
633632
'plot_path': plot_path,
634633
'title': modification_type,
635634
}
636-
if n_processes > 1:
637-
process_results.append(process_pool.submit(
638-
CRISPRessoPlot.plot_allele_modification_heatmap,
639-
**allele_modification_heatmap_input,
640-
))
641-
else:
642-
CRISPRessoPlot.plot_allele_modification_heatmap(
643-
**allele_modification_heatmap_input,
644-
)
635+
plot(
636+
CRISPRessoPlot.plot_allele_modification_heatmap,
637+
allele_modification_heatmap_input,
638+
)
645639

646640
crispresso2_info['results']['general_plots']['allele_modification_heatmap_plot_names'].append(plot_name)
647641
crispresso2_info['results']['general_plots']['allele_modification_heatmap_plot_paths'][plot_name] = plot_path
@@ -668,15 +662,10 @@ def main():
668662
'plot_path': plot_path,
669663
'title': modification_type,
670664
}
671-
if n_processes > 1:
672-
process_results.append(process_pool.submit(
673-
CRISPRessoPlot.plot_allele_modification_line,
674-
**allele_modification_line_input,
675-
))
676-
else:
677-
CRISPRessoPlot.plot_allele_modification_line(
678-
**allele_modification_line_input
679-
)
665+
plot(
666+
CRISPRessoPlot.plot_allele_modification_line,
667+
allele_modification_line_input,
668+
)
680669
crispresso2_info['results']['general_plots']['allele_modification_line_plot_names'].append(plot_name)
681670
crispresso2_info['results']['general_plots']['allele_modification_line_plot_paths'][plot_name] = plot_path
682671
crispresso2_info['results']['general_plots']['allele_modification_line_plot_titles'][plot_name] = 'CRISPRessoAggregate {0} Across Samples for {1}'.format(
@@ -778,13 +767,7 @@ def main():
778767
'save_png': save_png,
779768
'cutoff': args.min_reads_for_inclusion,
780769
}
781-
if n_processes > 1:
782-
process_results.append(process_pool.submit(
783-
CRISPRessoPlot.plot_reads_total,
784-
**reads_total_input,
785-
))
786-
else:
787-
CRISPRessoPlot.plot_reads_total(**reads_total_input)
770+
plot(CRISPRessoPlot.plot_reads_total, reads_total_input)
788771

789772
plot_name = os.path.basename(plot_root)
790773
crispresso2_info['results']['general_plots']['summary_plot_root'] = plot_name
@@ -801,13 +784,7 @@ def main():
801784
'save_png': save_png,
802785
'cutoff': args.min_reads_for_inclusion,
803786
}
804-
if n_processes > 1:
805-
process_results.append(process_pool.submit(
806-
CRISPRessoPlot.plot_unmod_mod_pcts,
807-
**unmod_mod_pcts_input,
808-
))
809-
else:
810-
CRISPRessoPlot.plot_unmod_mod_pcts(**unmod_mod_pcts_input)
787+
plot(CRISPRessoPlot.plot_unmod_mod_pcts, unmod_mod_pcts_input)
811788

812789
plot_name = os.path.basename(plot_root)
813790
crispresso2_info['results']['general_plots']['summary_plot_root'] = plot_name

CRISPResso2/CRISPRessoBatchCORE.py

Lines changed: 39 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import os
99
from copy import deepcopy
1010
from concurrent.futures import ProcessPoolExecutor, wait
11+
from functools import partial
1112
import sys
1213
import traceback
1314
from datetime import datetime
@@ -337,6 +338,13 @@ def main():
337338
process_results = []
338339
process_pool = ProcessPoolExecutor(n_processes_for_batch)
339340

341+
plot = partial(
342+
CRISPRessoMultiProcessing.run_plot,
343+
num_processes=n_processes_for_batch,
344+
process_results=process_results,
345+
process_pool=process_pool,
346+
)
347+
340348
window_nuc_pct_quilt_plot_names = []
341349
nuc_pct_quilt_plot_names = []
342350
window_nuc_conv_plot_names = []
@@ -559,15 +567,10 @@ def main():
559567
'sgRNA_intervals': sub_sgRNA_intervals,
560568
'quantification_window_idxs': include_idxs,
561569
}
562-
if n_processes_for_batch > 1:
563-
process_results.append(process_pool.submit(
564-
CRISPRessoPlot.plot_nucleotide_quilt,
565-
**nucleotide_quilt_input,
566-
))
567-
else:
568-
CRISPRessoPlot.plot_nucleotide_quilt(
569-
**nucleotide_quilt_input,
570-
)
570+
plot(
571+
CRISPRessoPlot.plot_nucleotide_quilt,
572+
nucleotide_quilt_input,
573+
)
571574
plot_name = os.path.basename(this_window_nuc_pct_quilt_plot_name)
572575
window_nuc_pct_quilt_plot_names.append(plot_name)
573576
crispresso2_info['results']['general_plots']['summary_plot_titles'][plot_name] = 'sgRNA: ' + sgRNA + ' Amplicon: ' + amplicon_name
@@ -587,15 +590,10 @@ def main():
587590
'sgRNA_intervals': sub_sgRNA_intervals,
588591
'quantification_window_idxs': include_idxs,
589592
}
590-
if n_processes_for_batch > 1:
591-
process_results.append(process_pool.submit(
592-
CRISPRessoPlot.plot_conversion_map,
593-
**conversion_map_input,
594-
))
595-
else:
596-
CRISPRessoPlot.plot_conversion_map(
597-
**conversion_map_input,
598-
)
593+
plot(
594+
CRISPRessoPlot.plot_conversion_map,
595+
conversion_map_input,
596+
)
599597
plot_name = os.path.basename(this_window_nuc_conv_plot_name)
600598
window_nuc_conv_plot_names.append(plot_name)
601599
crispresso2_info['results']['general_plots']['summary_plot_titles'][plot_name] = 'sgRNA: ' + sgRNA + ' Amplicon: ' + amplicon_name
@@ -617,15 +615,10 @@ def main():
617615
'sgRNA_intervals': consensus_sgRNA_intervals,
618616
'quantification_window_idxs': include_idxs,
619617
}
620-
if n_processes_for_batch > 1:
621-
process_results.append(process_pool.submit(
622-
CRISPRessoPlot.plot_nucleotide_quilt,
623-
**nucleotide_plot_input,
624-
))
625-
else:
626-
CRISPRessoPlot.plot_nucleotide_quilt(
627-
**nucleotide_plot_input,
628-
)
618+
plot(
619+
CRISPRessoPlot.plot_nucleotide_quilt,
620+
nucleotide_quilt_input,
621+
)
629622
plot_name = os.path.basename(this_nuc_pct_quilt_plot_name)
630623
nuc_pct_quilt_plot_names.append(plot_name)
631624
crispresso2_info['results']['general_plots']['summary_plot_titles'][plot_name] = 'Amplicon: ' + amplicon_name
@@ -644,15 +637,10 @@ def main():
644637
'sgRNA_intervals': consensus_sgRNA_intervals,
645638
'quantification_window_idxs': include_idxs,
646639
}
647-
if n_processes_for_batch > 1:
648-
process_results.append(process_pool.submit(
649-
CRISPRessoPlot.plot_conversion_map,
650-
**conversion_map_input,
651-
))
652-
else:
653-
CRISPRessoPlot.plot_conversion_map(
654-
**conversion_map_input,
655-
)
640+
plot(
641+
CRISPRessoPlot.plot_conversion_map,
642+
conversion_map_input,
643+
)
656644
plot_name = os.path.basename(this_nuc_conv_plot_name)
657645
nuc_conv_plot_names.append(plot_name)
658646
crispresso2_info['results']['general_plots']['summary_plot_titles'][plot_name] = 'Amplicon: ' + amplicon_name
@@ -671,15 +659,10 @@ def main():
671659
'fig_filename_root': this_nuc_pct_quilt_plot_name,
672660
'save_also_png': save_png,
673661
}
674-
if n_processes_for_batch > 1:
675-
process_results.append(process_pool.submit(
676-
CRISPRessoPlot.plot_nucleotide_quilt,
677-
**nucleotide_quilt_input,
678-
))
679-
else:
680-
CRISPRessoPlot.plot_nucleotide_quilt(
681-
**nucleotide_quilt_input,
682-
)
662+
plot(
663+
CRISPRessoPlot.plot_nucleotide_quilt,
664+
nucleotide_quilt_input,
665+
)
683666
plot_name = os.path.basename(this_nuc_pct_quilt_plot_name)
684667
nuc_pct_quilt_plot_names.append(plot_name)
685668
crispresso2_info['results']['general_plots']['summary_plot_labels'][plot_name] = 'Composition of each base for the amplicon ' + amplicon_name
@@ -693,15 +676,10 @@ def main():
693676
'conversion_nuc_to': args.conversion_nuc_to,
694677
'save_also_png': save_png,
695678
}
696-
if n_processes_for_batch > 1:
697-
process_results.append(process_pool.submit(
698-
CRISPRessoPlot.plot_conversion_map,
699-
**conversion_map_input,
700-
))
701-
else:
702-
CRISPRessoPlot.plot_conversion_map(
703-
**conversion_map_input,
704-
)
679+
plot(
680+
CRISPRessoPlot.plot_conversion_map,
681+
conversion_map_input,
682+
)
705683
plot_name = os.path.basename(this_nuc_conv_plot_name)
706684
nuc_conv_plot_names.append(plot_name)
707685
crispresso2_info['results']['general_plots']['summary_plot_labels'][plot_name] = args.conversion_nuc_from + '->' + args.conversion_nuc_to +' conversion rates for the amplicon ' + amplicon_name
@@ -756,15 +734,10 @@ def main():
756734
'plot_path': plot_path,
757735
'title': modification_type,
758736
}
759-
if n_processes_for_batch > 1:
760-
process_results.append(process_pool.submit(
761-
CRISPRessoPlot.plot_allele_modification_heatmap,
762-
**allele_modification_heatmap_input,
763-
))
764-
else:
765-
CRISPRessoPlot.plot_allele_modification_heatmap(
766-
**allele_modification_heatmap_input,
767-
)
737+
plot(
738+
CRISPRessoPlot.plot_allele_modification_heatmap,
739+
allele_modification_heatmap_input,
740+
)
768741

769742
crispresso2_info['results']['general_plots']['allele_modification_heatmap_plot_names'].append(plot_name)
770743
crispresso2_info['results']['general_plots']['allele_modification_heatmap_plot_paths'][plot_name] = plot_path
@@ -791,13 +764,9 @@ def main():
791764
'plot_path': plot_path,
792765
'title': modification_type,
793766
}
794-
if n_processes_for_batch > 1:
795-
process_results.append(process_pool.submit(
796-
CRISPRessoPlot.plot_allele_modification_line,
797-
**allele_modification_line_input,
798-
))
799-
CRISPRessoPlot.plot_allele_modification_line(
800-
**allele_modification_line_input,
767+
plot(
768+
CRISPRessoPlot.plot_allele_modification_line,
769+
allele_modification_line_input,
801770
)
802771

803772
crispresso2_info['results']['general_plots']['allele_modification_line_plot_names'].append(plot_name)

0 commit comments

Comments
 (0)