Skip to content

Commit 447cffb

Browse files
committed
restore gene mean calculation during model initialization
1 parent f128816 commit 447cffb

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

src/perturbo/models/_model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,14 @@ def setup_mudata(
214214
log_cpm = np.log(library_size / 1e6)
215215
mdata[modalities.rna_layer].obs[size_factor_key] = log_cpm - log_cpm.mean()
216216

217+
# add gene mean estimate (legacy, for simulator)
218+
gene_mean_key = "_gene_mean"
219+
rna_adata = mdata[modalities.rna_layer]
220+
mean_counts = np.mean(rna_adata.X, axis=0)
221+
if isinstance(mean_counts, np.matrix): # occurs when summing sparse array
222+
mean_counts = mean_counts.A1
223+
rna_adata.var["_gene_mean"] = mean_counts
224+
217225
# add indices to enable pyro subsampling of local vars
218226
mdata[modalities.rna_layer].obs = mdata[modalities.rna_layer].obs.assign(_ind_x=lambda x: np.arange(len(x)))
219227
index_field = fields.MuDataNumericalObsField(

0 commit comments

Comments
 (0)