From 83f032a3e657cccb76e1295437deaf31d3c0fa3b Mon Sep 17 00:00:00 2001 From: "Daniel C. Jones" Date: Fri, 7 Jun 2024 17:26:18 -0700 Subject: [PATCH] First pass at factorization model. Fix some minor bugs. Output factorization. Partial latent space mixture model implementation. Use mixture model when factoring. Add genewise offset parameters. Attempts to debug the factorization model. First attempt at new poisson matrix factorization model. Purge old stuff from the NB model. Sampling r Fix bug in multinomial sampling. A failed attempt at inducing sparsity. Revert "A failed attempt at inducing sparsity." This reverts commit 57b59dcc8003ce3f9f5529b7abaac1d22d960af3. Getting mixture model on phi working. Tweak beta prior. Gradual switch to gamma-gamma prior. Sampling phi and theta. Sampling r using the crt trick. Sampling s theta Sampling s phi Remove some dead code, fix a bug. Solve collapsing components by using NB marginal. Fix forgetting to sqrt variance. Fix kmeans initialization getting lost. Include missing prior means on s params. Lots of debugging junk. Figure out reasonable priors, delete some diagnostic shit. Optimization Tinkering with priors. Minor optimization. Add arguments to set some hyperparameters. Fix up clumsy rebase. Remove unecessary layer dimension from foreground_counts array. Faster sort on transcript repo. Remove unused counts array. PCA before Kmeans when initializing components assignments Remove some dead code. Debugging and tinkering with initialization. A possible implementation of 'effective volume' parameters. Different effective volume scheme, disabled for now. Add argument to exclude genes with regex, with xenium defaults. Dirichlet distributed theta, which seems to get better results. Also fixing phi dispersion parameter for now. Not ideal, but seems to get the best results. Option to write factorization parameters. Pin sprs at version 0.11.1 until compilation problems are resolved. See: https://github.com/rust-ml/linfa/issues/307 Fix compile errors from rebase. Quantile transformation of z coordinates. Revert "Quantile transformation of z coordinates." This reverts commit cc39ad0223cb62ad50be0000860f61e07fb6af9b. Add volume term when sampling z, fix log_effective_volume not updating with --no-cell-scales Implement partial factorization scheme. Use more unfactored genes by default. Fix rebase, bump version. Remove some more debugging code. Replace some problematic dependencies. --- Cargo.toml | 11 +- src/main.rs | 142 +++- src/output.rs | 454 +++++++----- src/sampler.rs | 1302 +++++++++++++++++++---------------- src/sampler/math.rs | 116 +--- src/sampler/transcripts.rs | 59 +- src/sampler/voxelsampler.rs | 19 +- 7 files changed, 1229 insertions(+), 874 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index cf763ed..f42bf4a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "proseg" description = "Probabilistic cell segmentation for in situ spatial transcriptomics" -version = "1.1.9" +version = "2.0.0" edition = "2021" authors = ["Daniel C. Jones "] repository = "https://github.com/dcjones/proseg" @@ -38,10 +38,7 @@ itertools = "0.12.1" json = "0.12.4" kiddo = "4.2.0" libm = "0.2.7" -linfa = "0.7.0" -linfa-clustering = "0.7.0" -ndarray = { version = "0.15.6", features = ["rayon"] } -ndarray-conv = "0.2.0" +ndarray = { version = "0.16.1", features = ["rayon", "matrixmultiply-threading"] } num-traits = "0.2.17" numeric_literals = "0.2.0" parquet = "52.2.0" @@ -49,4 +46,8 @@ petgraph = "0.6.3" rand = "0.8.5" rand_distr = "0.4.3" rayon = "1.7.0" +regex = "1.10.6" thread_local = "1.1.7" +faer = "0.19.4" +faer-ext = { version = "0.3.0", features = ["ndarray"] } +clustering = { version = "0.2.1", features = ["parallel"] } diff --git a/src/main.rs b/src/main.rs index d921953..2e428b5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,6 +13,7 @@ use hull::convex_hull_area; use indicatif::{ProgressBar, ProgressStyle}; use itertools::Itertools; use rayon::current_num_threads; +use regex::Regex; use sampler::transcripts::{ coordinate_span, estimate_full_area, filter_cellfree_transcripts, read_transcripts_csv, CellIndex, Transcript, BACKGROUND_CELL, @@ -61,6 +62,10 @@ struct Args { #[arg(long, default_value_t = false)] merfish: bool, + /// Regex pattern matching names of genes/features to be excluded + #[arg(long, default_value = None)] + excluded_genes: Option, + /// Initialize with cell assignments rather than nucleus assignments #[arg(long, default_value_t = false)] use_cell_initialization: bool, @@ -123,7 +128,7 @@ struct Args { /// Ignore the z coordinate if any, treating the data as 2D #[arg(long, default_value_t = false)] - no_z_coord: bool, + ignore_z_coord: bool, /// Filter out transcripts with quality values below this threshold #[arg(long, default_value_t = 0.0_f32)] @@ -139,6 +144,10 @@ struct Args { #[arg(long, default_value_t = 10)] ncomponents: usize, + /// Dimenionality of the latent space + #[arg(long, default_value_t = 100)] + nhidden: usize, + /// Number of z-axis layers used to model background expression #[arg(long, default_value_t = 4)] nbglayers: usize, @@ -257,10 +266,9 @@ struct Args { #[arg(long, value_enum, default_value_t = OutputFormat::Infer)] output_rates_fmt: OutputFormat, - /// Output per-component parameter values - #[arg(long, default_value = None)] - output_component_params: Option, - + // /// Output per-component parameter values + // #[arg(long, default_value = None)] + // output_component_params: Option, #[arg(long, value_enum, default_value_t = OutputFormat::Infer)] output_component_params_fmt: OutputFormat, @@ -292,6 +300,20 @@ struct Args { #[arg(long, value_enum, default_value_t = OutputFormat::Infer)] output_gene_metadata_fmt: OutputFormat, + /// Output cell metagene rates + #[arg(long, default_value=None)] + output_metagene_rates: Option, + + #[arg(long, value_enum, default_value_t = OutputFormat::Infer)] + output_metagene_rates_fmt: OutputFormat, + + /// Output metagene loadings + #[arg(long, default_value=None)] + output_metagene_loadings: Option, + + #[arg(long, value_enum, default_value_t = OutputFormat::Infer)] + output_metagene_loadings_fmt: OutputFormat, + /// Output a table of each voxel in each cell #[arg(long, default_value=None)] output_cell_voxels: Option, @@ -323,6 +345,30 @@ struct Args { /// Use connectivity checks to prevent cells from having any disconnected voxels #[arg(long, default_value_t = true)] enforce_connectivity: bool, + + #[arg(long, default_value_t = 300)] + nunfactored: usize, + + /// Disable factorization model and use genes directly + #[arg(long, default_value_t = false)] + no_factorization: bool, + + /// Disable cell scale factors + #[arg(long, default_value_t = false)] + no_cell_scales: bool, + + // Hyperparameters + #[arg(long, default_value_t = 1.0)] + hyperparam_e_phi: f32, + + #[arg(long, default_value_t = 1.0)] + hyperparam_f_phi: f32, + + #[arg(long, default_value_t = 1.0)] + hyperparam_neg_mu_phi: f32, + + #[arg(long, default_value_t = 0.1)] + hyperparam_tau_phi: f32, } fn set_xenium_presets(args: &mut Args) { @@ -339,6 +385,9 @@ fn set_xenium_presets(args: &mut Args) { args.cell_id_unassigned .get_or_insert(String::from("UNASSIGNED")); args.qv_column.get_or_insert(String::from("qv")); + args.excluded_genes.get_or_insert(String::from( + "^(Deprecated|NegControl|Unassigned|Intergenic)", + )); // newer xenium data does have a fov column args.fov_column.get_or_insert(String::from("fov_name")); @@ -406,9 +455,10 @@ fn set_visiumhd_presets(args: &mut Args) { args.z_column.get_or_insert(String::from("z")); // ignored args.cell_id_column.get_or_insert(String::from("cell")); args.cell_id_unassigned.get_or_insert(String::from("0")); - args.initial_voxel_size = Some(4.0); + args.initial_voxel_size = Some(1.0); args.voxel_layers = 1; - args.no_z_coord = true; + args.nbglayers = 1; + args.ignore_z_coord = true; // TODO: This is the resolution on the one dataset I have. It probably // doesn't generalize. @@ -506,8 +556,11 @@ fn main() { mut cell_assignments, mut nucleus_population) = */ + let excluded_genes = args.excluded_genes.map(|pat| Regex::new(&pat).unwrap()); + let mut dataset = read_transcripts_csv( &args.transcript_csv, + excluded_genes, &expect_arg(args.gene_column, "gene-column"), args.transcript_id_column, args.compartment_column, @@ -522,10 +575,22 @@ fn main() { &expect_arg(args.y_column, "y-column"), &expect_arg(args.z_column, "z-column"), args.min_qv, - args.no_z_coord, + args.ignore_z_coord, args.coordinate_scale.unwrap_or(1.0), ); + if !args.no_factorization { + dataset.select_unfactored_genes(args.nunfactored); + } + + // let cd3e_idx = dataset + // .transcript_names + // .iter() + // .position(|gene| gene == "CXCL6") + // .unwrap(); + // dbg!(cd3e_idx); + // panic!(); + // Warn if any nucleus has extremely high population, which is likely // an error interpreting the file. dataset.nucleus_population.iter().for_each(|&p| { @@ -654,6 +719,8 @@ fn main() { Some(args.burnin_dispersion) }, + use_cell_scales: !args.no_cell_scales, + min_cell_volume, μ_μ_volume: (2.0 * mean_nucleus_area * zspan).ln(), @@ -661,12 +728,26 @@ fn main() { α_σ_volume: 0.1, β_σ_volume: 0.1, - e_r: 1.0, + use_factorization: !args.no_factorization, + + // TODO: mean/var ratio is always 1/fφ, but that doesn't seem like the whole + // story. Seems like it needs to change as a function of the dimensionality + // of the latent space. - e_h: 1.0, - f_h: 1.0, + // I also don't know if this "severe prior" approach is going to work in + // the long run because we may have far more cells. Needs more testing. + // eφ: 1000.0, + // fφ: 1.0, - γ: 1.0, + // μφ: -20.0, + // τφ: 0.1, + αθ: 1e-1, + + eφ: args.hyperparam_e_phi, + fφ: args.hyperparam_f_phi, + + μφ: -args.hyperparam_neg_mu_phi, + τφ: args.hyperparam_tau_phi, α_bg: 1.0, β_bg: 1.0, @@ -692,6 +773,8 @@ fn main() { σ_z_diffusion_proposal: 0.2 * zspan, σ_z_diffusion: 0.2 * zspan, + τv: 10.0, + zmin, zmax, @@ -709,6 +792,8 @@ fn main() { &dataset.nucleus_population, &dataset.cell_assignments, args.ncomponents, + args.nhidden, + args.nunfactored, args.nbglayers, ncells, ngenes, @@ -736,7 +821,9 @@ fn main() { initial_voxel_size, chunk_size, )); - sampler.borrow_mut().initialize(&priors, &mut params); + sampler + .borrow_mut() + .initialize(&priors, &mut params, &dataset.transcripts); let mut total_steps = 0; @@ -856,12 +943,12 @@ fn main() { ¶ms, &dataset.transcript_names, ); - write_component_params( - &args.output_component_params, - args.output_component_params_fmt, - ¶ms, - &dataset.transcript_names, - ); + // write_component_params( + // &args.output_component_params, + // args.output_component_params_fmt, + // ¶ms, + // &dataset.transcript_names, + // ); write_cell_metadata( &args.output_cell_metadata, args.output_cell_metadata_fmt, @@ -890,6 +977,17 @@ fn main() { &dataset.transcript_names, &ecounts, ); + write_metagene_rates( + &args.output_metagene_rates, + args.output_metagene_rates_fmt, + ¶ms.φ, + ); + write_metagene_loadings( + &args.output_metagene_loadings, + args.output_metagene_loadings_fmt, + &dataset.transcript_names, + ¶ms.θ, + ); write_voxels( &args.output_cell_voxels, args.output_cell_voxels_fmt, @@ -935,6 +1033,7 @@ fn run_hexbin_sampler( for _ in 0..niter { // sampler.check_perimeter_bounds(priors); + // let t0 = std::time::Instant::now(); if sample_cell_regions { // let t0 = std::time::Instant::now(); for _ in 0..local_steps_per_iter { @@ -942,16 +1041,17 @@ fn run_hexbin_sampler( priors, params, &mut proposal_stats, - transcripts, hillclimb, &mut uncertainty, ); } // println!("Sample cell regions: {:?}", t0.elapsed()); } + // println!("Sample cell regions: {:?}", t0.elapsed()); + // let t0 = std::time::Instant::now(); sampler.sample_global_params(priors, params, transcripts, &mut uncertainty, burnin); - // println!("Sample parameters: {:?}", t0.elapsed()); + // println!("Sample global parameters: {:?}", t0.elapsed()); let nassigned = params.nassigned(); let nforeground = params.nforeground(); diff --git a/src/output.rs b/src/output.rs index 45bb6c0..88c62f6 100644 --- a/src/output.rs +++ b/src/output.rs @@ -1,7 +1,7 @@ use arrow::array::RecordBatch; -use arrow::datatypes::{Schema, Field, DataType}; -use arrow::error::ArrowError; use arrow::csv; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::error::ArrowError; use parquet::errors::ParquetError; use parquet::arrow::ArrowWriter; use parquet::file::properties::WriterProperties; @@ -9,22 +9,18 @@ use parquet::basic::{Compression::ZSTD, ZstdLevel}; use flate2::write::GzEncoder; use flate2::Compression; use geo::MultiPolygon; -use ndarray::{Array1, Array2, Axis, Zip}; +use ndarray::{Array1, Array2, Axis}; use std::fs::File; use std::io::Write; use std::sync::Arc; -use crate::schemas::{transcript_metadata_schema, OutputFormat}; use super::sampler::transcripts::Transcript; use super::sampler::transcripts::BACKGROUND_CELL; use super::sampler::voxelsampler::VoxelSampler; use super::sampler::{ModelParams, TranscriptState}; +use crate::schemas::{transcript_metadata_schema, OutputFormat}; -pub fn write_table( - filename: &str, - fmt: OutputFormat, - batch: &RecordBatch, -) { +pub fn write_table(filename: &str, fmt: OutputFormat, batch: &RecordBatch) { let fmt = match fmt { OutputFormat::Infer => infer_format_from_filename(filename), _ => fmt, @@ -55,23 +51,15 @@ pub fn write_table( } } -fn write_table_csv( - output: &mut W, - batch: &RecordBatch, -) -> Result<(), ArrowError> +fn write_table_csv(output: &mut W, batch: &RecordBatch) -> Result<(), ArrowError> where W: std::io::Write, { - let mut writer = csv::WriterBuilder::new() - .with_header(true) - .build(output); + let mut writer = csv::WriterBuilder::new().with_header(true).build(output); writer.write(batch) } -fn write_table_parquet( - output: &mut W, - batch: &RecordBatch, -) -> Result<(), ParquetError> +fn write_table_parquet(output: &mut W, batch: &RecordBatch) -> Result<(), ParquetError> where W: std::io::Write + Send, { @@ -109,9 +97,8 @@ pub fn write_counts( let schema = Schema::new( transcript_names .iter() - .map(|name| { - Field::new(name, DataType::UInt32, false) - }).collect::>() + .map(|name| Field::new(name, DataType::UInt32, false)) + .collect::>(), ); let mut columns: Vec> = Vec::new(); @@ -121,10 +108,7 @@ pub fn write_counts( )); } - let batch = RecordBatch::try_new( - Arc::new(schema), - columns - ).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); write_table(output_counts, output_counts_fmt, &batch); } @@ -140,9 +124,8 @@ pub fn write_expected_counts( let schema = Schema::new( transcript_names .iter() - .map(|name| { - Field::new(name, DataType::Float32, false) - }).collect::>() + .map(|name| Field::new(name, DataType::Float32, false)) + .collect::>(), ); let mut columns: Vec> = Vec::new(); @@ -152,93 +135,148 @@ pub fn write_expected_counts( )); } - let batch = RecordBatch::try_new( - Arc::new(schema), - columns - ).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); - write_table( - output_expected_counts, - output_expected_counts_fmt, - &batch, - ); + write_table(output_expected_counts, output_expected_counts_fmt, &batch); } } -pub fn write_rates( - output_rates: &Option, - output_rates_fmt: OutputFormat, - params: &ModelParams, - transcript_names: &[String], + +pub fn write_metagene_rates( + output_metagene_rates: &Option, + output_metagene_rates_fmt: OutputFormat, + φ: &Array2 ) { - if let Some(output_rates) = output_rates { + if let Some(output_metagene_rates) = output_metagene_rates { + let k = φ.shape()[1]; let schema = Schema::new( - transcript_names - .iter() - .map(|name| { - Field::new(name, DataType::Float32, false) - }).collect::>() + (0..k) + .map(|name| Field::new(format!("phi{}", name), DataType::Float32, false)) + .collect::>(), ); let mut columns: Vec> = Vec::new(); - for row in params.λ.rows() { + for row in φ.columns() { columns.push(Arc::new( row.iter().cloned().collect::(), )); } - let batch = RecordBatch::try_new( - Arc::new(schema), - columns - ).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); - write_table(output_rates, output_rates_fmt, &batch); + write_table(output_metagene_rates, output_metagene_rates_fmt, &batch); } } -pub fn write_component_params( - output_component_params: &Option, - output_component_params_fmt: OutputFormat, - params: &ModelParams, + +pub fn write_metagene_loadings( + output_metagene_rates: &Option, + output_metagene_rates_fmt: OutputFormat, transcript_names: &[String], + θ: &Array2 ) { - if let Some(output_component_params) = output_component_params { - // What does this look like: rows for each gene, columns for α1, β1, α2, β2, etc. - let α = ¶ms.r; - let φ = ¶ms.φ; - let β = φ.map(|φ| (-φ).exp()); - - let ncomponents = params.ncomponents(); - - let mut fields = Vec::new(); - fields.push(Field::new("gene", DataType::Utf8, false)); - for i in 0..ncomponents { - fields.push(Field::new(&format!("α_{}", i), DataType::Float32, false)); - fields.push(Field::new(&format!("β_{}", i), DataType::Float32, false)); + if let Some(output_metagene_rates) = output_metagene_rates { + let k = θ.shape()[1]; + + let mut schema = vec![ + Field::new("gene", DataType::Utf8, false), + ]; + + for i in 0..k { + schema.push(Field::new(format!("theta{}", i), DataType::Float32, false)); } - let schema = Schema::new(fields); + + let schema = Schema::new(schema); let mut columns: Vec> = Vec::new(); - columns.push(Arc::new(arrow::array::StringArray::from(transcript_names.iter().cloned().collect::>()))); - Zip::from(α.rows()).and(β.rows()).for_each(|α, β| { - columns.push(Arc::new(α.iter().cloned().collect::())); - columns.push(Arc::new(β.iter().cloned().collect::())); - }); + columns.push( + Arc::new( + transcript_names.iter().map(|gene| Some(gene)).collect::() + ) + ); + + for row in θ.columns() { + columns.push(Arc::new( + row.iter().cloned().collect::(), + )); + } + + let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); + + write_table(output_metagene_rates, output_metagene_rates_fmt, &batch); + } +} - let batch = RecordBatch::try_new( - Arc::new(schema), - columns - ).unwrap(); - write_table( - output_component_params, - output_component_params_fmt, - &batch, +pub fn write_rates( + output_rates: &Option, + output_rates_fmt: OutputFormat, + params: &ModelParams, + transcript_names: &[String], +) { + if let Some(output_rates) = output_rates { + let schema = Schema::new( + transcript_names + .iter() + .map(|name| Field::new(name, DataType::Float32, false)) + .collect::>(), ); + + let mut columns: Vec> = Vec::new(); + for row in params.λ.columns() { + columns.push(Arc::new( + row.iter().cloned().collect::(), + )); + } + + let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); + + write_table(output_rates, output_rates_fmt, &batch); } } +// pub fn write_component_params( +// output_component_params: &Option, +// output_component_params_fmt: OutputFormat, +// params: &ModelParams, +// transcript_names: &[String], +// ) { +// if let Some(output_component_params) = output_component_params { +// // What does this look like: rows for each gene, columns for α1, β1, α2, β2, etc. +// let α = ¶ms.r; +// let φ = ¶ms.φ; +// let β = φ.map(|φ| (-φ).exp()); +// let ncomponents = params.ncomponents(); +// +// let mut fields = Vec::new(); +// fields.push(Field::new("gene", DataType::Utf8, false)); +// for i in 0..ncomponents { +// fields.push(Field::new(&format!("α_{}", i), DataType::Float32, false)); +// fields.push(Field::new(&format!("β_{}", i), DataType::Float32, false)); +// } +// let schema = Schema::new(fields); +// +// let mut columns: Vec> = Vec::new(); +// columns.push(Arc::new(arrow::array::StringArray::from( +// transcript_names.iter().cloned().collect::>(), +// ))); +// +// Zip::from(α.rows()).and(β.rows()).for_each(|α, β| { +// columns.push(Arc::new( +// α.iter().cloned().collect::(), +// )); +// columns.push(Arc::new( +// β.iter().cloned().collect::(), +// )); +// }); +// +// let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); +// +// write_table(output_component_params, output_component_params_fmt, &batch); +// } +// } + // Assign cells to fovs by finding the most common transcript fov of the // assigned transcripts. fn cell_fov_vote( @@ -279,6 +317,7 @@ pub fn write_cell_metadata( fovs: &[u32], fov_names: &[String], ) { + // TODO: write factorization let ncells = cell_centroids.len(); let nfovs = fov_names.len(); let cell_fovs = cell_fov_vote(ncells, nfovs, cell_assignments, fovs); @@ -292,44 +331,78 @@ pub fn write_cell_metadata( Field::new("fov", DataType::Utf8, true), Field::new("cluster", DataType::UInt16, false), Field::new("volume", DataType::Float32, false), + Field::new("scale", DataType::Float32, false), Field::new("population", DataType::UInt64, false), ]); let columns: Vec> = vec![ - Arc::new((0..params.ncells() as u32).collect::()), - Arc::new(cell_centroids.iter().map(|(x, _, _)| *x).collect::()), - Arc::new(cell_centroids.iter().map(|(_, y, _)| *y).collect::()), - Arc::new(cell_centroids.iter().map(|(_, _, z)| *z).collect::()), Arc::new( - cell_fovs.iter().map( - |fov| { + cell_centroids + .iter() + .map(|(x, _, _)| *x) + .collect::(), + ), + Arc::new( + cell_centroids + .iter() + .map(|(_, y, _)| *y) + .collect::(), + ), + Arc::new( + cell_centroids + .iter() + .map(|(_, _, z)| *z) + .collect::(), + ), + Arc::new( + cell_fovs + .iter() + .map(|fov| { if *fov == u32::MAX { None } else { Some(fov_names[*fov as usize].clone()) } - }, - ).collect::()), - Arc::new(params.z.iter().map(|&z| z as u16).collect::()), - Arc::new(params.cell_volume.iter().cloned().collect::()), - Arc::new(params.cell_population.iter().map(|&p| p as u64).collect::()) + }) + .collect::(), + ), + Arc::new( + params + .z + .iter() + .map(|&z| z as u16) + .collect::(), + ), + Arc::new( + params + .cell_volume + .iter() + .cloned() + .collect::(), + ), + Arc::new( + params + .cell_scale + .iter() + .cloned() + .collect::(), + ), + Arc::new( + params + .cell_population + .iter() + .map(|&p| p as u64) + .collect::(), + ), ]; - let batch = RecordBatch::try_new( - Arc::new(schema), - columns - ).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); - write_table( - output_cell_metadata, - output_cell_metadata_fmt, - &batch, - ); + write_table(output_cell_metadata, output_cell_metadata_fmt, &batch); } } - #[allow(clippy::too_many_arguments)] pub fn write_transcript_metadata( output_transcript_metadata: &Option, @@ -351,64 +424,86 @@ pub fn write_transcript_metadata( let columns: Vec> = vec![ Arc::new( - transcripts.iter().map(|t| t.transcript_id).collect::() + transcripts + .iter() + .map(|t| t.transcript_id) + .collect::(), ), Arc::new( - transcript_positions.iter().map(|(x, _, _)| *x).collect::() + transcript_positions + .iter() + .map(|(x, _, _)| *x) + .collect::(), ), Arc::new( - transcript_positions.iter().map(|(_, y, _)| *y).collect::() + transcript_positions + .iter() + .map(|(_, y, _)| *y) + .collect::(), ), Arc::new( - transcript_positions.iter().map(|(_, _, z)| *z).collect::() + transcript_positions + .iter() + .map(|(_, _, z)| *z) + .collect::(), ), Arc::new( - transcripts.iter().map(|t| t.x).collect::() + transcripts + .iter() + .map(|t| t.x) + .collect::(), ), Arc::new( - transcripts.iter().map(|t| t.y).collect::() + transcripts + .iter() + .map(|t| t.y) + .collect::(), ), Arc::new( - transcripts.iter().map(|t| t.z).collect::() + transcripts + .iter() + .map(|t| t.z) + .collect::(), ), Arc::new( transcripts .iter() .map(|t| Some(transcript_names[t.gene as usize].clone())) - .collect::() - ), - Arc::new( - qvs.iter().cloned().collect::() + .collect::(), ), + Arc::new(qvs.iter().cloned().collect::()), Arc::new( fovs.iter() .map(|fov| Some(fov_names[*fov as usize].clone())) - .collect::() + .collect::(), ), Arc::new( - cell_assignments.iter().map(|(cell, _)| *cell).collect::() + cell_assignments + .iter() + .map(|(cell, _)| *cell) + .collect::(), ), Arc::new( - cell_assignments.iter().map(|(_, pr)| *pr).collect::() + cell_assignments + .iter() + .map(|(_, pr)| *pr) + .collect::(), ), Arc::new( transcript_state .iter() .map(|&s| (s == TranscriptState::Background) as u8) - .collect::() + .collect::(), ), Arc::new( transcript_state .iter() .map(|&s| (s == TranscriptState::Confusion) as u8) - .collect::() + .collect::(), ), ]; - let batch = RecordBatch::try_new( - Arc::new(schema), - columns - ).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); write_table( output_transcript_metadata, @@ -425,6 +520,7 @@ pub fn write_gene_metadata( transcript_names: &[String], expected_counts: &Array2, ) { + // TODO: write factorization if let Some(output_gene_metadata) = output_gene_metadata { let mut schema_fields = vec![ Field::new("gene", DataType::Utf8, false), @@ -435,7 +531,10 @@ pub fn write_gene_metadata( let mut columns: Vec> = vec![ Arc::new( - transcript_names.iter().map(|s| Some(s.clone())).collect::() + transcript_names + .iter() + .map(|s| Some(s.clone())) + .collect::(), ), Arc::new( params @@ -443,71 +542,73 @@ pub fn write_gene_metadata( .sum_axis(Axis(1)) .iter() .map(|x| *x as u64) - .collect::() + .collect::(), ), Arc::new( expected_counts .sum_axis(Axis(1)) - .iter().cloned() - .collect::() + .iter() + .cloned() + .collect::(), ), // Arc::new(array::Float32Array::from_values( // params.r.iter().cloned(), // )) ]; - // cell type dispersions - for i in 0..params.ncomponents() { - schema_fields.push(Field::new( - &format!("dispersion_{}", i), - DataType::Float32, - false, - )); - columns.push(Arc::new( - params.r.row(i).iter().cloned().collect::() - )); - } - - // cell type rates - for i in 0..params.ncomponents() { - schema_fields.push(Field::new(&format!("λ_{}", i), DataType::Float32, false)); - - let mut λ_component = Array1::::from_elem(params.ngenes(), 0_f32); - let mut count = 0; - Zip::from(¶ms.z) - .and(params.λ.columns()) - .for_each(|&z, λ| { - if i == z as usize { - Zip::from(&mut λ_component).and(λ).for_each(|a, b| *a += b); - count += 1; - } - }); - λ_component /= count as f32; - - columns.push(Arc::new( - λ_component.iter().cloned().collect::() - )); - } + // // cell type dispersions + // schema_fields.push(Field::new("dispersion", DataType::Float32, false)); + // columns.push(Arc::new( + // params + // .r + // .iter() + // .cloned() + // .collect::(), + // )); + + // // cell type rates + // for i in 0..params.ncomponents() { + // schema_fields.push(Field::new(&format!("λ_{}", i), DataType::Float32, false)); + // + + // let mut λ_component = Array1::::from_elem(params.ngenes(), 0_f32); + // let mut count = 0; + // Zip::from(¶ms.z) + // .and(params.λ.columns()) + // .for_each(|&z, λ| { + // if i == z as usize { + // Zip::from(&mut λ_component).and(λ).for_each(|a, b| *a += b); + // count += 1; + // } + // }); + // λ_component /= count as f32; + // + + // columns.push(Arc::new( + // λ_component + // .iter() + // .cloned() + // .collect::(), + // )); + // } // background rates for i in 0..params.nlayers() { schema_fields.push(Field::new(format!("λ_bg_{}", i), DataType::Float32, false)); columns.push(Arc::new( - params.λ_bg.column(i).iter().cloned().collect::() + params + .λ_bg + .column(i) + .iter() + .cloned() + .collect::(), )); } let schema = Schema::new(schema_fields); - let batch = RecordBatch::try_new( - Arc::new(schema), - columns - ).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); - write_table( - output_gene_metadata, - output_gene_metadata_fmt, - &batch, - ); + write_table(output_gene_metadata, output_gene_metadata_fmt, &batch); } } @@ -557,10 +658,7 @@ pub fn write_voxels( Arc::new(z1s.iter().cloned().collect::()), ]; - let batch = RecordBatch::try_new( - Arc::new(schema), - columns - ).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); write_table(output_voxels, output_voxels_fmt, &batch); } diff --git a/src/sampler.rs b/src/sampler.rs index e444fc1..616feba 100644 --- a/src/sampler.rs +++ b/src/sampler.rs @@ -1,28 +1,29 @@ mod connectivity; -pub mod voxelsampler; mod math; -pub mod polyagamma; +mod polyagamma; mod polygons; mod sampleset; pub mod transcripts; +pub mod voxelsampler; +// use super::output::{write_expected_counts, OutputFormat}; +use crate::hull::convex_hull_area; +use clustering::kmeans; use core::fmt::Debug; +use faer_ext::IntoFaer; use flate2::write::GzEncoder; use flate2::Compression; -use crate::hull::convex_hull_area; use itertools::{izip, Itertools}; -use libm::{lgammaf, log1pf}; -use linfa::traits::{Fit, Predict}; -use linfa::DatasetBase; -use linfa_clustering::KMeans; +use libm::lgammaf; use math::{ - logistic, lognormal_logpdf, negbin_logpmf_fast, normal_pdf, normal_x2_logpdf, normal_x2_pdf, - rand_crt, LogFactorial, LogGammaPlus, + lognormal_logpdf, negbin_logpmf, normal_logpdf, normal_x2_logpdf, normal_x2_pdf, odds_to_prob, + rand_crt, randn, }; -use ndarray::{Array1, Array2, Array3, Axis, Zip}; +use ndarray::linalg::general_mat_mul; +use ndarray::{s, Array1, Array2, Axis, Zip}; use polyagamma::PolyaGamma; use rand::{thread_rng, Rng}; -use rand_distr::{Dirichlet, Distribution, Gamma, Normal, StandardNormal}; +use rand_distr::{Binomial, Distribution, Gamma, Normal, StandardNormal}; use rayon::prelude::*; use std::cell::RefCell; use std::collections::HashMap; @@ -35,6 +36,11 @@ use transcripts::{CellIndex, Transcript, BACKGROUND_CELL}; // use std::time::Instant; +// use std::any::type_name; +// fn print_type_of(_: &T) { +// println!("{}", std::any::type_name::()) +// } + // Bounding perimeter as some multiple of the perimiter of a sphere with the // same volume. This of course is all on a lattice, so it's approximate. // `eta` is the scaling factor between number of mismatching neighbors on the @@ -72,6 +78,8 @@ pub struct ModelPriors { pub dispersion: Option, pub burnin_dispersion: Option, + pub use_cell_scales: bool, + pub min_cell_volume: f32, // params for normal prior @@ -82,14 +90,18 @@ pub struct ModelPriors { pub α_σ_volume: f32, pub β_σ_volume: f32, - // gamma rate prior - pub e_r: f32, + pub use_factorization: bool, - pub e_h: f32, - pub f_h: f32, + // dirichlet prior on θ + pub αθ: f32, - // normal precision parameter for β - pub γ: f32, + // gamma prior on rφ + pub eφ: f32, + pub fφ: f32, + + // log-normal prior on sφ + pub μφ: f32, + pub τφ: f32, // gamma prior for background rates pub α_bg: f32, @@ -119,6 +131,9 @@ pub struct ModelPriors { pub σ_z_diffusion_proposal: f32, pub σ_z_diffusion: f32, + // prior precision on effective log cell volume + pub τv: f32, + // bounds on z coordinate pub zmin: f32, pub zmax: f32, @@ -128,6 +143,7 @@ pub struct ModelPriors { } // Model global parameters. +#[allow(non_snake_case)] pub struct ModelParams { pub transcript_positions: Vec<(f32, f32, f32)>, proposed_transcript_positions: Vec<(f32, f32, f32)>, @@ -142,12 +158,15 @@ pub struct ModelParams { pub cell_population: Vec, - // per-cell volumes + // [ncells] per-cell volumes pub cell_volume: Array1, - pub cell_log_volume: Array1, + pub log_cell_volume: Array1, - // per-component volumes - pub component_volume: Array1, + // [ncells] cell_volume * cell_scale + pub effective_cell_volume: Array1, + + // [ncells] per-cell "effective" volume scaling factor + pub cell_scale: Array1, // area of the convex hull containing all transcripts full_layer_volume: f32, @@ -159,11 +178,8 @@ pub struct ModelParams { pub transcript_state: Array1, pub prev_transcript_state: Array1, - // [ngenes, ncells, nlayers] transcripts counts - pub counts: Array3, - - // [ncells, ngenes, nlayers] foreground transcripts counts - foreground_counts: Array3, + // [ncells, ngenes] foreground transcripts counts + foreground_counts: Array2, // [ngenes] background transcripts counts confusion_counts: Array1, @@ -174,45 +190,76 @@ pub struct ModelParams { // [ngenes, nlayers] total gene occourance counts pub total_gene_counts: Array2, - // Not parameters, but needed for sampling global params - logfactorial: LogFactorial, + // [ncells, nhidden] + pub cell_latent_counts: Array2, + + // [ngenes, nhidden] + pub gene_latent_counts: Array2, + + // Thread local [ngenes, nhidden] matrices for accumulation + pub gene_latent_counts_tl: ThreadLocal>>, - // TODO: This needs to be an matrix I guess! - loggammaplus: Array2, + // [nhidden] + pub latent_counts: Array1, - pub z: Array1, // assignment of cells to components + // [nhidden] thread local storage for sampling latent counts + pub multinomial_rates: ThreadLocal>>, + pub multinomial_sample: ThreadLocal>>, - component_population: Array1, // number of cells assigned to each component + // [ncells, ncomponents] space for sampling component assignments + pub z_probs: ThreadLocal>>, - // thread-local space used for sampling z - z_probs: ThreadLocal>>, + // [ncells] assignment of cells to components + pub z: Array1, - π: Vec, // mixing proportions over components + // [ncomponents] component probabilities + pub π: Array1, + + // [ncomponents] number of cells assigned to each component + component_population: Array1, + + // [ncomponents] total volume of each component + component_volume: Array1, + + // [ncomponents, nhidden] + component_latent_counts: Array2, μ_volume: Array1, // volume dist mean param by component σ_volume: Array1, // volume dist std param by component - // Prior on NB dispersion parameters - h: f32, + // [ncells, nhidden]: cell ψ parameter in the latent space + pub φ: Array2, + + // [nhidden]: precompute φ_k.dot(cell_volume) + φ_v_dot: Array1, - // [ncells, ngenes] Polya-gamma samples, used for sampling NB rates - ω: Array2, + // For components as well??? + // [ncells, nhidden] aux CRT variables for sampling rφ + // pub lφ: Array2, - // [ncomponents, ngenes] NB logit(p) parameters - pub φ: Array2, + // [ncells, nhidden] aux CRT variables for sampling rφ + pub lφ: Array2, + + // [ncells, nhidden] aux PolyaGamma variables for sampling sφ + pub ωφ: Array2, + + // [ncomponents, nhidden] φ gamma shape parameters + pub rφ: Array2, - // [ncomponents, ngenes] Parameters for sampling φ - μ_φ: Array2, - σ_φ: Array2, + // for precomputing lgamma(rφ) + lgamma_rφ: Array2, - // [ncomponents, ngenes] NB r parameters. - pub r: Array2, + // [ncomponents, nhidden] φ gamma scale parameters + pub sφ: Array2, - // [ncomponents, ngenes] gamama parameters for sampling r - uv: Array2<(u32, f32)>, + // Size of the upper block of θ that is the identity matrix + nunfactored: usize, - // Precomputing lgamma(r) - lgamma_r: Array2, + // [ngenes, nhidden]: gene loadings in the latent space + pub θ: Array2, + + // [nhidden]: Sums across the first axis of θ + pub θksum: Array1, // // [ncomponents, ngenes] NB p parameters. // θ: Array2, @@ -223,7 +270,7 @@ pub struct ModelParams { // // log(1 - ods_to_prob(θ)) // log1mp: Array2, - // [ngenes, ncells] Poisson rates + // [ncells, ngenes] Poisson rates pub λ: Array2, // [ngenes, nlayers] background rate: rate at which halucinate transcripts @@ -241,6 +288,7 @@ impl ModelParams { // initialize model parameters, with random cell assignments // and other parameterz unninitialized. #[allow(clippy::too_many_arguments)] + #[allow(non_snake_case)] pub fn new( priors: &ModelPriors, full_layer_volume: f32, @@ -252,6 +300,8 @@ impl ModelParams { init_cell_population: &[usize], prior_seg_cell_assignment: &[u32], ncomponents: usize, + nhidden: usize, + nunfactored: usize, nlayers: usize, ncells: usize, ngenes: usize, @@ -264,9 +314,12 @@ impl ModelParams { if initial_perturbation_sd > 0.0 { let mut rng = rand::thread_rng(); for pos in &mut transcript_positions { - pos.0 += rng.sample::(StandardNormal) * initial_perturbation_sd; - pos.1 += rng.sample::(StandardNormal) * initial_perturbation_sd; - pos.2 += rng.sample::(StandardNormal) * initial_perturbation_sd; + pos.0 += + rng.sample::(StandardNormal) * initial_perturbation_sd; + pos.1 += + rng.sample::(StandardNormal) * initial_perturbation_sd; + pos.2 += + rng.sample::(StandardNormal) * initial_perturbation_sd; } } @@ -274,77 +327,103 @@ impl ModelParams { let accept_proposed_transcript_positions = vec![false; transcripts.len()]; let transcript_position_updates = vec![(0, 0, 0, 0); transcripts.len()]; - let r = Array2::::from_elem((ncomponents, ngenes), 100.0_f32); + let nhidden = if priors.use_factorization { + nhidden + nunfactored + } else { + ngenes + }; + + // TODO: save some space when not using factorization + let lφ = Array2::::zeros((ncells, nhidden)); + let ωφ = Array2::::zeros((ncells, nhidden)); + let rφ = Array2::::from_elem((ncomponents, nhidden), 1.0_f32); + let lgamma_rφ = Array2::::from_elem((ncomponents, nhidden), 1.0_f32); + let sφ = Array2::::from_elem((ncomponents, nhidden), 1.0_f32); let cell_volume = Array1::::zeros(ncells); - let cell_log_volume = Array1::::zeros(ncells); - let h = 10.0; + let log_cell_volume = Array1::::zeros(ncells); + let cell_scale = Array1::::from_elem(ncells, 1.0_f32); + let effective_cell_volume = cell_volume.clone(); // compute initial counts - let mut counts = Array3::::from_elem((ngenes, ncells, nlayers), 0); + let mut counts = Array2::::zeros((ngenes, ncells)); let mut total_gene_counts = Array2::::from_elem((ngenes, nlayers), 0); for (i, &j) in init_cell_assignments.iter().enumerate() { let gene = transcripts[i].gene as usize; let layer = ((transcripts[i].z - z0) / layer_depth) as usize; if j != BACKGROUND_CELL { - counts[[gene, j as usize, layer]] += 1; + counts[[gene, j as usize]] += 1.0; } total_gene_counts[[gene, layer]] += 1; } // initial component assignments - let norm_constant = 1e4; - let mut init_samples = counts - .sum_axis(Axis(2)) - .map(|&x| (x as f32)) - .reversed_axes(); - init_samples.rows_mut().into_iter().for_each(|mut row| { - let rowsum = row.sum(); - row.mapv_inplace(|x| (norm_constant * (x / rowsum)).ln_1p()); + let norm_constant = 1e2; + counts.columns_mut().into_iter().for_each(|mut col| { + let colsum = col.sum(); + col.mapv_inplace(|x| (norm_constant * (x / colsum)).ln_1p()); + // col.mapv_inplace(|x| x.ln_1p()); }); - let init_samples = DatasetBase::from(init_samples); - - // log1p transformed counts - // let init_samples = - // DatasetBase::from(counts.sum_axis(Axis(2)).map(|&x| (x as f32).ln_1p()).reversed_axes()); - - let rng = rand::thread_rng(); - let model = KMeans::params_with_rng(ncomponents, rng) - .tolerance(1e-1) - .fit(&init_samples) - .expect("kmeans failed to converge"); - - let z = model.predict(&init_samples).map(|&x| x as u32); - - // let rng = rand::thread_rng(); - // let model = GaussianMixtureModel::params_with_rng(ncomponents, rng) - // .tolerance(1e-1) - // .fit(&init_samples) - // .expect("gmm failed to converge"); - - // let z = model.predict(&init_samples).map(|&x| x as u32); - - // // initial component assignments - // let mut rng = rand::thread_rng(); - // let z = (0..ncells) - // .map(|_| rng.gen_range(0..ncomponents) as u32) - // .collect::>() - // .into(); - - let uv = Array2::<(u32, f32)>::from_elem((ncomponents, ngenes), (0_u32, 0_f32)); - let lgamma_r = Array2::::from_elem((ncomponents, ngenes), 0.0); - let loggammaplus = - Array2::::from_elem((ncomponents, ngenes), LogGammaPlus::default()); - let ω = Array2::::from_elem((ncells, ngenes), 0.0); - let φ = Array2::::from_elem((ncomponents, ngenes), 0.0); - let μ_φ = Array2::::from_elem((ncomponents, ngenes), 0.0); - let σ_φ = Array2::::from_elem((ncomponents, ngenes), 0.0); - - let component_volume = Array1::::from_elem(ncomponents, 0.0); + + // subtract mean + for mut counts_i in counts.rows_mut() { + let mean = counts_i.mean().unwrap(); + counts_i -= mean; + } + + let counts = counts.view().into_faer(); + let counts_svd = counts.thin_svd(); + + // truncate the svd for kmeans (I hope it's sorted...) + let svd_dim = counts_svd.v().shape().1.min(100); + let embedding: Vec> = counts_svd + .v() + .row_iter() + .map(|row| row.get(0..svd_dim).iter().cloned().collect()) + .collect(); + + let kmeans_results = kmeans(ncomponents, &embedding, 100); + + let z: Array1 = kmeans_results + .membership + .iter() + .map(|z_c| *z_c as u32) + .collect(); + + let z_probs = ThreadLocal::new(); + + let π = Array1::::from_elem(ncomponents, 1.0 / ncomponents as f32); + + let mut rng = rand::thread_rng(); + let φ = Array2::::from_shape_simple_fn((ncells, nhidden), || randn(&mut rng).exp()); + let φ_v_dot = Array1::::zeros(nhidden); + let mut θ = + Array2::::from_shape_simple_fn((ngenes, nhidden), || randn(&mut rng).exp()); + let θksum = Array1::::zeros(nhidden); + + θ.fill(0.0); + if !priors.use_factorization { + θ.diag_mut().fill(1.0); + } + + // fix the upper block of θ to the identity matrix + if nunfactored > 0 { + let mut θunfac = θ.slice_mut(s![0..nunfactored, 0..nunfactored]); + θunfac.diag_mut().fill(1.0); + } + let transcript_state = Array1::::from_elem(transcripts.len(), TranscriptState::Foreground); let prev_transcript_state = Array1::::from_elem(transcripts.len(), TranscriptState::Foreground); + let foreground_counts = Array2::::from_elem((ncells, ngenes), 0); + let confusion_counts = Array1::::from_elem(ngenes, 0); + let background_counts = Array2::::from_elem((ngenes, nlayers), 0); + let cell_latent_counts = Array2::::from_elem((ncells, nhidden), 0); + let gene_latent_counts = Array2::::from_elem((ngenes, nhidden), 0); + let gene_latent_counts_tl = ThreadLocal::new(); + let latent_counts = Array1::::from_elem(nhidden, 0); + ModelParams { transcript_positions, proposed_transcript_positions, @@ -356,74 +435,54 @@ impl ModelParams { cell_assignment_time: vec![0; init_cell_assignments.len()], cell_population: init_cell_population.to_vec(), cell_volume, - cell_log_volume, - component_volume, + log_cell_volume, + cell_scale, + effective_cell_volume, full_layer_volume, z0, layer_depth, transcript_state, prev_transcript_state, - counts, - foreground_counts: Array3::::from_elem((ncells, ngenes, nlayers), 0), - confusion_counts: Array1::::from_elem(ngenes, 0), - background_counts: Array2::::from_elem((ngenes, nlayers), 0), + foreground_counts, + confusion_counts, + background_counts, total_gene_counts, - logfactorial: LogFactorial::new(), - loggammaplus, + cell_latent_counts, + gene_latent_counts, + gene_latent_counts_tl, + latent_counts, + multinomial_rates: ThreadLocal::new(), + multinomial_sample: ThreadLocal::new(), + z_probs, z, + π, component_population: Array1::::from_elem(ncomponents, 0), - z_probs: ThreadLocal::new(), - π: vec![1_f32 / (ncomponents as f32); ncomponents], + component_volume: Array1::::from_elem(ncomponents, 0.0), + component_latent_counts: Array2::::from_elem((ncomponents, nhidden), 0), μ_volume: Array1::::from_elem(ncomponents, priors.μ_μ_volume), σ_volume: Array1::::from_elem(ncomponents, priors.σ_μ_volume), - h, - ω, φ, - μ_φ, - σ_φ, - r, - uv, - lgamma_r, + φ_v_dot, + nunfactored, + θ, + θksum, + lφ, + ωφ, + sφ, + rφ, + lgamma_rφ, // θ: Array2::::from_elem((ncomponents, ngenes), 0.1), - λ: Array2::::from_elem((ngenes, ncells), 0.1), + λ: Array2::::from_elem((ncells, ngenes), 0.1), λ_bg: Array2::::from_elem((ngenes, nlayers), 0.0), λ_c: Array1::::from_elem(ngenes, 1e-4), t: 0, } } - pub fn ncomponents(&self) -> usize { - self.π.len() - } - - fn zlayer(&self, z: f32) -> usize { - let layer = ((z - self.z0) / self.layer_depth).max(0.0) as usize; - layer.min(self.nlayers() - 1) - } - - fn recompute_counts(&mut self, transcripts: &[Transcript]) { - self.counts.fill(0); - for (i, &j) in self.cell_assignments.iter().enumerate() { - let gene = transcripts[i].gene as usize; - if j != BACKGROUND_CELL { - let layer = self.zlayer(self.transcript_positions[i].2); - self.counts[[gene, j as usize, layer]] += 1; - } - } - - self.check_counts(transcripts); - } - - fn check_counts(&self, transcripts: &[Transcript]) { - for (i, (transcript, &assignment)) in - transcripts.iter().zip(&self.cell_assignments).enumerate() - { - let layer = self.zlayer(self.transcript_positions[i].2); - if assignment != BACKGROUND_CELL { - assert!(self.counts[[transcript.gene as usize, assignment as usize, layer]] > 0); - } - } - } + // fn zlayer(&self, z: f32) -> usize { + // let layer = ((z - self.z0) / self.layer_depth).max(0.0) as usize; + // layer.min(self.nlayers() - 1) + // } pub fn nforeground(&self) -> usize { self.foreground_counts.iter().map(|x| *x as usize).sum() @@ -450,7 +509,7 @@ impl ModelParams { pub fn log_likelihood(&self, priors: &ModelPriors) -> f32 { // iterate over cells - let mut ll = Zip::from(self.λ.columns()) + let mut ll = Zip::from(self.λ.rows()) .and(&self.cell_volume) // .and(self.counts.axis_iter(Axis(1))) .and(self.foreground_counts.axis_iter(Axis(0))) @@ -475,11 +534,6 @@ impl ModelParams { }) - λ * cell_volume }); - // if part < -57983890000.0 { - // dbg!(part, cell_volume, cs.sum()); - // dbg!(λ); - // panic!(); - // } accum + part }); @@ -895,18 +949,19 @@ pub trait Proposal { }); } else { let volume_diff = self.old_cell_volume_delta(); + let a_c = params.cell_scale[old_cell as usize]; let prev_volume = params.cell_volume[old_cell as usize]; let new_volume = prev_volume + volume_diff; // normalization term difference - δ += Zip::from(params.λ.column(old_cell as usize)) - .fold(0.0, |acc, &λ| acc - λ * volume_diff); + δ += Zip::from(params.λ.row(old_cell as usize)) + .fold(0.0, |acc, &λ| acc - λ * a_c * volume_diff); Zip::from(self.gene_count().rows()) .and(params.λ_bg.rows()) .and(¶ms.λ_c) - .and(params.λ.column(old_cell as usize)) + .and(params.λ.row(old_cell as usize)) .for_each(|gene_counts, λ_bg, &λ_c, λ| { Zip::from(gene_counts).and(λ_bg).for_each(|&count, &λ_bg| { if count > 0 { @@ -938,19 +993,20 @@ pub trait Proposal { }); } else { let volume_diff = self.new_cell_volume_delta(); + let a_c = params.cell_scale[new_cell as usize]; let prev_volume = params.cell_volume[new_cell as usize]; let new_volume = prev_volume + volume_diff; // normalization term difference - δ += Zip::from(params.λ.column(new_cell as usize)) - .fold(0.0, |acc, &λ| acc - λ * volume_diff); + δ += Zip::from(params.λ.row(new_cell as usize)) + .fold(0.0, |acc, &λ| acc - λ * a_c * volume_diff); // add in new cell likelihood terms Zip::from(self.gene_count().rows()) .and(params.λ_bg.rows()) .and(¶ms.λ_c) - .and(params.λ.column(new_cell as usize)) + .and(params.λ.row(new_cell as usize)) .for_each(|gene_counts, λ_bg, &λ_c, λ| { Zip::from(gene_counts).and(λ_bg).for_each(|&count, &λ_bg| { if count > 0 { @@ -977,14 +1033,6 @@ pub trait Proposal { if (hillclimb && δ > 0.0) || (!hillclimb && logu < δ + self.log_weight()) { self.accept(); - // TODO: debugging - // if from_background && !to_background { - // dbg!( - // self.log_weight(), - // δ, - // self.gene_count().sum(), - // ); - // } } else { self.reject(); } @@ -997,19 +1045,34 @@ where Self: Sync, { // fn generate_proposals<'b, 'c>(&'b mut self, params: &ModelParams) -> &'c mut [P] where 'b: 'c; - fn initialize(&mut self, priors: &ModelPriors, params: &mut ModelParams) { - Zip::from(&mut params.cell_log_volume) - .and(¶ms.cell_volume) - .into_par_iter() - .with_min_len(50) - .for_each(|(log_volume, &volume)| { - *log_volume = volume.ln(); + fn initialize( + &mut self, + priors: &ModelPriors, + params: &mut ModelParams, + transcripts: &Vec, + ) { + Zip::from(&mut params.φ_v_dot) + .and(params.φ.axis_iter(Axis(1))) + .for_each(|φ_v_dot_k, φ_k| { + *φ_v_dot_k = φ_k.dot(¶ms.cell_volume); + }); + + Zip::from(&mut params.θksum) + .and(params.θ.axis_iter(Axis(1))) + .for_each(|θksum, θ_k| { + *θksum = θ_k.sum(); }); // get to a reasonably high probability assignment - for _ in 0..40 { - self.sample_component_nb_params(priors, params, true); + for _ in 0..20 { + self.sample_transcript_state(priors, params, transcripts, &mut Option::None); + self.compute_counts(priors, params, transcripts); + self.sample_factor_model(priors, params, false, true); + self.sample_background_rates(priors, params); + self.sample_confusion_rates(priors, params); } + + // panic!("Finished initializing"); } fn repopulate_proposals(&mut self, priors: &ModelPriors, params: &ModelParams); @@ -1036,7 +1099,6 @@ where priors: &ModelPriors, params: &mut ModelParams, stats: &mut ProposalStats, - transcripts: &[Transcript], hillclimb: bool, uncertainty: &mut Option<&mut UncertaintyTracker>, ) { @@ -1048,13 +1110,12 @@ where self.proposals_mut() .par_iter_mut() .for_each(|p| p.evaluate(priors, params, hillclimb)); - self.apply_accepted_proposals(stats, transcripts, priors, params, uncertainty); + self.apply_accepted_proposals(stats, priors, params, uncertainty); } fn apply_accepted_proposals( &mut self, stats: &mut ProposalStats, - transcripts: &[Transcript], priors: &ModelPriors, params: &mut ModelParams, uncertainty: &mut Option<&mut UncertaintyTracker>, @@ -1117,15 +1178,6 @@ where cell_volume += proposal.old_cell_volume_delta(); cell_volume = cell_volume.max(priors.min_cell_volume); params.cell_volume[old_cell as usize] = cell_volume; - - for &i in proposal.transcripts() { - let gene = transcripts[i].gene; - let layer = ((params.transcript_positions[i].2 - params.z0) - / params.layer_depth) - .max(0.0) as usize; - let layer = layer.min(params.nlayers() - 1); - params.counts[[gene as usize, old_cell as usize, layer]] -= 1; - } } if new_cell != BACKGROUND_CELL { @@ -1135,15 +1187,6 @@ where cell_volume += proposal.new_cell_volume_delta(); cell_volume = cell_volume.max(priors.min_cell_volume); params.cell_volume[new_cell as usize] = cell_volume; - - for &i in proposal.transcripts() { - let gene = transcripts[i].gene; - let layer = ((params.transcript_positions[i].2 - params.z0) - / params.layer_depth) - .max(0.0) as usize; - let layer = layer.min(params.nlayers() - 1); - params.counts[[gene as usize, new_cell as usize, layer]] += 1; - } } } @@ -1158,13 +1201,11 @@ where uncertainty: &mut Option<&mut UncertaintyTracker>, burnin: bool, ) { - let mut rng = thread_rng(); - // let t0 = Instant::now(); self.sample_volume_params(priors, params); // println!(" Sample volume params: {:?}", t0.elapsed()); - // // Sample background/foreground counts + // Sample background/foreground counts // let t0 = Instant::now(); self.sample_transcript_state(priors, params, transcripts, uncertainty); // println!(" Sample transcript states: {:?}", t0.elapsed()); @@ -1174,41 +1215,12 @@ where // println!(" Compute counts: {:?}", t0.elapsed()); // let t0 = Instant::now(); - self.sample_component_nb_params(priors, params, burnin); - // println!(" Sample nb params: {:?}", t0.elapsed()); - - // Sample λ - // let t0 = Instant::now(); - self.sample_rates(priors, params); - // println!(" Sample λ: {:?}", t0.elapsed()); - - // TODO: - // This is the most expensive part. We could sample this less frequently, - // but we should try to optimize as much as possible. - // Ideas: - // - Main bottlneck is computing log(p) and log(1-p). I don't see - // anything obvious to do about that. + self.sample_factor_model(priors, params, true, burnin); + // println!(" Sample factor model: {:?}", t0.elapsed()); - // Sample z // let t0 = Instant::now(); - self.sample_component_assignments(priors, params); - // println!(" Sample z: {:?}", t0.elapsed()); - - // sample π - let mut α = vec![1_f32; params.ncomponents()]; - for z_i in params.z.iter() { - α[*z_i as usize] += 1.0; - } - - if α.len() == 1 { - params.π.clear(); - params.π.push(1.0); - } else { - params.π.clear(); - params - .π - .extend(Dirichlet::new(&α).unwrap().sample(&mut rng).iter()); - } + // self.sample_z(params); + // println!(" sample_z: {:?}", t0.elapsed()); // let t0 = Instant::now(); self.sample_background_rates(priors, params); @@ -1216,7 +1228,6 @@ where // let t0 = Instant::now(); self.sample_confusion_rates(priors, params); - // TODO: disabling confusion to see if it actually does anything // println!(" Sample confusion rates: {:?}", t0.elapsed()); // let t0 = Instant::now(); @@ -1251,7 +1262,7 @@ where let layer = ((position.2 - params.z0) / params.layer_depth).max(0.0) as usize; let layer = layer.min(nlayers - 1); - let λ_cell = params.λ[[gene, cell as usize]]; + let λ_cell = params.λ[[cell as usize, gene]]; let λ_bg = params.λ_bg[[gene, layer]]; let λ_c = params.λ_c[gene]; let λ = λ_cell + λ_bg + λ_c; @@ -1301,6 +1312,15 @@ where params: &mut ModelParams, transcripts: &Vec, ) { + // TODO: Ok, this is the bottleneck now. It can't be trivially parallelized + // because we are accumulating these big matrices. + + // These matrices are also problematic because they use so much memory. + // Ideally we'd find a solution where we can both accumulate in parallel + // and be sparse. + + // I think the thing to do is + let nlayers = params.nlayers(); params.confusion_counts.fill(0_u32); params.background_counts.fill(0_u32); @@ -1323,300 +1343,504 @@ where params.confusion_counts[gene] += 1; } TranscriptState::Foreground => { - params.foreground_counts[[cell as usize, gene, layer]] += 1; + params.foreground_counts[[cell as usize, gene]] += 1; } } }); - - // dbg!(params.background_counts.sum()); - // dbg!(params.confusion_counts.sum()); } - fn sample_component_nb_params( - &mut self, - priors: &ModelPriors, - params: &mut ModelParams, - burnin: bool, - ) { - // total component area - // let mut component_cell_area = vec![0_f32; params.ncomponents()]; - params.component_volume.fill(0.0); - params - .cell_volume - .iter() - .zip(¶ms.z) - .for_each(|(volume, z_i)| { - params.component_volume[*z_i as usize] += *volume; - }); + fn sample_latent_counts(&mut self, params: &mut ModelParams) { + params.cell_latent_counts.fill(0); - // Sample ω + // zero out thread local gene latent counts + for x in params.gene_latent_counts_tl.iter_mut() { + x.borrow_mut().fill(0); + } - // let rmin = params.r.iter().min_by(|a, b| a.partial_cmp(b).unwrap()); - // let rmax = params.r.iter().max_by(|a, b| a.partial_cmp(b).unwrap()); - // dbg!(rmin, rmax); + let ngenes = params.foreground_counts.shape()[1]; + let nhidden = params.cell_latent_counts.shape()[1]; - // let φmin = params.φ.iter().min_by(|a, b| a.partial_cmp(b).unwrap()); - // let φmax = params.φ.iter().max_by(|a, b| a.partial_cmp(b).unwrap()); - // dbg!(φmin, φmax); + Zip::from(params.cell_latent_counts.outer_iter_mut()) // for every cell + .and(params.φ.outer_iter()) + .and(params.foreground_counts.outer_iter()) + .par_for_each(|mut cell_latent_counts_c, φ_c, x_c| { + let mut rng = thread_rng(); + let mut multinomial_rates = params + .multinomial_rates + .get_or(|| RefCell::new(Array1::zeros(nhidden))) + .borrow_mut(); - // let vmin = params.cell_volume.iter().min_by(|a, b| a.partial_cmp(b).unwrap()); - // let vmax = params.cell_volume.iter().max_by(|a, b| a.partial_cmp(b).unwrap()); - // dbg!(vmin, vmax); + let mut multinomial_sample = params + .multinomial_sample + .get_or(|| RefCell::new(Array1::zeros(nhidden))) + .borrow_mut(); - // let t0 = Instant::now(); - Zip::from(params.ω.rows_mut()) // for every cell - .and(params.foreground_counts.axis_iter(Axis(0))) - .and(¶ms.cell_log_volume) - .and(¶ms.z) - .par_for_each(|ωs, cs, &logv, &z| { - let mut rng = thread_rng(); - Zip::from(cs.axis_iter(Axis(0))) // for every gene - .and(ωs) - .and(params.φ.row(z as usize)) - .and(params.r.row(z as usize)) - .for_each(|c, ω, φ, &r| { - *ω = PolyaGamma::new(c.sum() as f32 + r, logv + φ).sample(&mut rng); - }); - }); - // println!(" Sample ω: {:?}", t0.elapsed()); + let mut gene_latent_counts_tl = params + .gene_latent_counts_tl + .get_or(|| RefCell::new(Array2::zeros((ngenes, nhidden)))) + .borrow_mut(); - // Compute parameters to sample φ - // let t0 = Instant::now(); - params.μ_φ.fill(0.0); - params.σ_φ.fill(0.0); - Zip::from(params.ω.rows()) // for every cell - .and(params.foreground_counts.axis_iter(Axis(0))) // for each cell - .and(¶ms.cell_log_volume) - .and(¶ms.z) - .for_each(|ωs, cs, &logv, &z| { - Zip::from(params.μ_φ.row_mut(z as usize)) // for every gene - .and(params.σ_φ.row_mut(z as usize)) - .and(params.r.row(z as usize)) - .and(ωs) - .and(cs.axis_iter(Axis(0))) - .for_each(|μ, σ, &r, &ω, c| { - *σ += ω; - *μ += (c.sum() as f32 - r) / 2.0 - ω * logv; + Zip::indexed(x_c) // for every gene + .and(params.θ.outer_iter()) + .for_each(|g, &x_cg, θ_g| { + if x_cg == 0 { + return; + } + + multinomial_sample.fill(0); + + // rates: normalized element-wise product + multinomial_rates.assign(&φ_c); + *multinomial_rates *= &θ_g; + let rate_norm = multinomial_rates.sum(); + *multinomial_rates /= rate_norm; + + // multinomial sampling + { + let mut ρ = 1.0; + let mut s = x_cg as u32; + for (p, x) in + izip!(multinomial_rates.iter(), multinomial_sample.iter_mut()) + { + if ρ > 0.0 { + *x = Binomial::new(s as u64, ((*p / ρ) as f64).min(1.0)) + .unwrap() + .sample(&mut rng) + as u32; + } + s -= *x; + ρ = ρ - *p; + + if s == 0 { + break; + } + } + } + + // add to cell marginal + cell_latent_counts_c.scaled_add(1, &multinomial_sample); + + // add to gene marginal + let mut gene_latent_counts_g = gene_latent_counts_tl.row_mut(g); + gene_latent_counts_g.scaled_add(1, &multinomial_sample); }); }); - Zip::from(&mut params.σ_φ).for_each(|σ| *σ = (priors.γ + *σ).recip()); + // accumulate from thread local matrices + params.gene_latent_counts.fill(0); + for x in params.gene_latent_counts_tl.iter_mut() { + params.gene_latent_counts.scaled_add(1, &x.borrow()); + } + + // marginal count along the hidden axis + Zip::from(&mut params.latent_counts) + .and(params.gene_latent_counts.columns()) + .for_each(|lc, glc| { + *lc = glc.sum(); + }); - Zip::from(&mut params.μ_φ) - .and(¶ms.σ_φ) - .for_each(|μ, σ| *μ *= σ); - // println!(" Compute φ parameters: {:?}", t0.elapsed()); + let count = params.latent_counts.sum(); + assert!(params.gene_latent_counts.sum() == count); + assert!(params.cell_latent_counts.sum() == count); - // Sample φ - // let t0 = Instant::now(); - Zip::from(&mut params.φ) - .and(¶ms.μ_φ) - .and(¶ms.σ_φ) - .for_each(|φ, &μ, &σ| { - let mut rng = thread_rng(); - *φ = Normal::new(μ, σ.sqrt()).unwrap().sample(&mut rng); + // compute component-wise counts + params.component_population.fill(0); + params.component_volume.fill(0.0); + params.component_latent_counts.fill(0); + Zip::from(¶ms.z) + .and(¶ms.cell_volume) + .and(params.cell_latent_counts.rows()) + .for_each(|z_c, v_c, x_c| { + let z_c = *z_c as usize; + params.component_population[z_c] += 1; + params.component_volume[z_c] += *v_c; + params + .component_latent_counts + .row_mut(z_c) + .scaled_add(1, &x_c); }); - // println!(" Sample φ: {:?}", t0.elapsed()); - // }); - // Zip::from(&mut params.ω) - // .and(¶ms.foreground_counts) + // dbg!(¶ms.component_population); + } - // Sample θ + // This is indended just for debugging + // fn write_cell_latent_counts(&self, params: &ModelParams, filename: &str) { + // let file = File::create(filename).unwrap(); + // let mut encoder = GzEncoder::new(file, Compression::default()); + + // // header + // for i in 0..params.cell_latent_counts.shape()[1] { + // if i != 0 { + // write!(encoder, ",").unwrap(); + // } + // write!(encoder, "metagene{}", i).unwrap(); + // } + // writeln!(encoder).unwrap(); + + // for x_c in params.cell_latent_counts.rows() { + // for (k, x_ck) in x_c.iter().enumerate() { + // if k != 0 { + // write!(encoder, ",").unwrap(); + // } + // write!(encoder, "{}", *x_ck).unwrap(); + // } + // writeln!(encoder).unwrap(); + // } + // } + + fn sample_factor_model( + &mut self, + priors: &ModelPriors, + params: &mut ModelParams, + sample_z: bool, + burnin: bool, + ) { // let t0 = Instant::now(); - // Zip::from(params.θ.columns_mut()) - // .and(params.r.columns()) - // .and(params.component_counts.rows()) - // .par_for_each(|θs, rs, cs| { - // let mut rng = thread_rng(); - // for (θ, &c, &r, &a) in izip!(θs, cs, rs, ¶ms.component_volume) { - // *θ = prob_to_odds( - // Beta::new(priors.α_θ + c as f32, priors.β_θ + a * r) - // .unwrap() - // .sample(&mut rng), - // ); - - // // *θ = θ.max(1e-6).min(1e6); - // } - // }); - // println!(" Sample θ: {:?}", t0.elapsed()); + self.sample_latent_counts(params); + // println!(" sample_latent_counts: {:?}", t0.elapsed()); - // dbg!( - // params.θ.iter().min_by(|a, b| a.partial_cmp(b).unwrap()), - // params.θ.iter().max_by(|a, b| a.partial_cmp(b).unwrap()), - // ); + if sample_z { + // let t0 = Instant::now(); + self.sample_z(params); + // println!(" sample_z: {:?}", t0.elapsed()); + } + self.sample_π(params); - // dbg!(params.h); + if priors.use_factorization { + // let t0 = Instant::now(); + self.sample_θ(priors, params); + // println!(" sample_θ: {:?}", t0.elapsed()); + } - // Sample r // let t0 = Instant::now(); + self.sample_φ(params); + // println!(" sample_φ: {:?}", t0.elapsed()); - // TODO: Need some argument to determine dispersion sampling - // behavior. - - fn set_constant_dispersion(params: &mut ModelParams, dispersion: f32) { - params.r.fill(dispersion); - Zip::from(¶ms.r) - .and(&mut params.lgamma_r) - .and(&mut params.loggammaplus) - .par_for_each(|r, lgamma_r, loggammaplus| { - *lgamma_r = lgammaf(*r); - loggammaplus.reset(*r); - }); - } - + // let t0 = Instant::now(); if let Some(dispersion) = priors.dispersion { - set_constant_dispersion(params, dispersion); + params.rφ.fill(dispersion); } else if burnin && priors.burnin_dispersion.is_some() { let dispersion = priors.burnin_dispersion.unwrap(); - set_constant_dispersion(params, dispersion); + params.rφ.fill(dispersion); } else { - // for each gene - params.uv.fill((0_u32, 0_f32)); - Zip::from(params.r.columns_mut()) - .and(params.lgamma_r.columns_mut()) - .and(params.loggammaplus.columns_mut()) - .and(params.φ.columns()) - .and(params.foreground_counts.axis_iter(Axis(1))) - .and(params.uv.columns_mut()) - .par_for_each(|rs, lgamma_rs, loggammaplus, φs, cs, mut uv| { - let mut rng = thread_rng(); - - // iterate over cells computing u and v - Zip::from(¶ms.z) - .and(cs.axis_iter(Axis(0))) - .and(¶ms.cell_volume) - .for_each(|&z, c, &vol| { - let z = z as usize; - let c = c.sum(); - let r = rs[z]; - let φ = φs[z]; - let ψ = φ + vol.ln(); - let δv = -ψ - log1pf((-ψ).exp()); - - let uv_z = uv[z]; - - if (uv[z].1 + δv).is_infinite() { - dbg!(uv[z], ψ, φ, vol); - } + self.sample_rφ(priors, params); + } + // println!(" sample_rφ: {:?}", t0.elapsed()); - uv[z] = (uv_z.0 + rand_crt(&mut rng, c as u32, r), uv_z.1 + δv); + // let t0 = Instant::now(); + self.sample_ωck(params); + self.sample_sφ(priors, params); - assert!(uv[z].1.is_finite()); - }); + if priors.use_cell_scales { + self.sample_cell_scales(priors, params); + } else { + params.effective_cell_volume.assign(¶ms.cell_volume); + } - // iterate over components sampling r - Zip::from(rs) - .and(lgamma_rs) - .and(loggammaplus) - .and(uv) - .for_each(|r, lgamma_r, loggammaplus, uv| { - let dist = - Gamma::new(priors.e_r + uv.0 as f32, (params.h - uv.1).recip()); - if dist.is_err() { - dbg!(uv.0, uv.1, params.h); - } - let dist = dist.unwrap(); - *r = dist.sample(&mut rng); + // let t0 = Instant::now(); + self.compute_rates(params); + // println!(" compute_rates: {:?}", t0.elapsed()); + } - // if *r < 0.001 { - // dbg!(uv.0, uv.1, params.h, *r); - // } + fn sample_z(&mut self, params: &mut ModelParams) { + let ncomponents = params.π.shape()[0]; - // *r = Gamma::new(priors.e_r + uv.0 as f32, (params.h - uv.1).recip()) - // .unwrap() - // .sample(&mut rng); + // precompute lgamma(rφ) + Zip::from(&mut params.lgamma_rφ) + .and(¶ms.rφ) + .par_for_each(|lgamma_r_tk, r_tk| { + *lgamma_r_tk = lgammaf(*r_tk); + }); - assert!(r.is_finite()); + Zip::from(&mut params.z) // for each cell + .and(params.φ.rows()) + .and(params.cell_latent_counts.rows()) + .and(¶ms.effective_cell_volume) + .and(¶ms.cell_volume) + .par_for_each(|z_c, φ_c, x_c, ev_c, v_c| { + let mut rng = rand::thread_rng(); + let log_v_c = v_c.ln(); + let mut z_probs = params + .z_probs + .get_or(|| RefCell::new(vec![0_f64; ncomponents])) + .borrow_mut(); - // TODO: Without this, things get kind of fucky. - // *r = r.min(200.0).max(2e-4); - *r = r.max(2e-4); - // *r = r.min(50.0); - // *r = r.max(1e-5); + // compute probability of φ_c under every component - *lgamma_r = lgammaf(*r); - loggammaplus.reset(*r); + // for every component + let mut z_probs_sum = 0.0; + for (z_probs_t, π_t, r_t, lgamma_r_t, s_t, μ_vol_c, σ_vol_c) in izip!( + z_probs.iter_mut(), + params.π.iter(), + params.rφ.rows(), + params.lgamma_rφ.rows(), + params.sφ.rows(), + ¶ms.μ_volume, + ¶ms.σ_volume + ) { + *z_probs_t = π_t.ln() as f64; + + // for every hidden dim + Zip::from(r_t) + .and(lgamma_r_t) + .and(s_t) + .and(¶ms.θksum) + .and(x_c) + .for_each(|r_tk, lgamma_r_tk, s_tk, θ_k_sum, x_ck| { + let p = odds_to_prob(*s_tk * *ev_c * *θ_k_sum); + let lp = negbin_logpmf(*r_tk, *lgamma_r_tk, p, *x_ck) as f64; + *z_probs_t += lp; }); - // // self.cell_areas.slice(0..self.ncells) + *z_probs_t += normal_logpdf(*μ_vol_c, *σ_vol_c, log_v_c) as f64; + } + + for z_probs_t in z_probs.iter_mut() { + *z_probs_t = z_probs_t.exp(); + z_probs_sum += *z_probs_t; + } + + if !z_probs_sum.is_finite() { + dbg!(&z_probs, &φ_c, z_probs_sum); + } + + // cumulative probabilities in-place + z_probs.iter_mut().fold(0.0, |mut acc, x| { + acc += *x / z_probs_sum; + *x = acc; + acc + }); + + let u = rng.gen::(); + *z_c = z_probs.partition_point(|x| *x < u) as u32; + }); + } - // let u = Zip::from(counts.axis_iter(Axis(0))).fold(0, |accum, cs| { - // let c = cs.sum(); - // accum + rand_crt(&mut rng, c, *r) - // }); + fn sample_π(&mut self, params: &mut ModelParams) { + let mut rng = rand::thread_rng(); + let mut π_sum = 0.0; + Zip::from(&mut params.π) + .and(¶ms.component_population) + .for_each(|π_t, pop_t| { + *π_t = Gamma::new(1.0 + *pop_t as f32, 1.0) + .unwrap() + .sample(&mut rng); + π_sum += *π_t; + }); - // let v = - // Zip::from(¶ms.z) - // .and(¶ms.cell_volume) - // .fold(0.0, |accum, z, a| { - // let w = θs[*z as usize]; - // accum + log1pf(-odds_to_prob(w * *a)) - // }); + // normalize to get dirichlet posterior + params.π.iter_mut().for_each(|π_t| *π_t /= π_sum); + } - // // dbg!(u, v); + fn compute_rates(&mut self, params: &mut ModelParams) { + general_mat_mul(1.0, ¶ms.φ, ¶ms.θ.t(), 0.0, &mut params.λ); + } - // *r = Gamma::new(priors.e_r + u as f32, (params.h - v).recip()) - // .unwrap() - // .sample(&mut rng); + fn sample_φ(&mut self, params: &mut ModelParams) { + Zip::from(params.φ.outer_iter_mut()) // for each cell + .and(¶ms.z) + .and(params.cell_latent_counts.outer_iter()) + .and(¶ms.effective_cell_volume) + .par_for_each(|φ_c, z_c, x_c, v_c| { + let z_c = *z_c as usize; + let mut rng = thread_rng(); + Zip::from(φ_c) // for each latent dim + .and(¶ms.θksum) + .and(x_c) + .and(¶ms.rφ.row(z_c)) + .and(¶ms.sφ.row(z_c)) + .for_each(|φ_ck, θ_k_sum, x_ck, r_k, s_k| { + let shape = *r_k + *x_ck as f32; + let scale = s_k / (1.0 + s_k * v_c * *θ_k_sum); + *φ_ck = Gamma::new(shape, scale).unwrap().sample(&mut rng); + }); + }); - // assert!(r.is_finite()); + Zip::from(&mut params.φ_v_dot) + .and(params.φ.axis_iter(Axis(1))) + .for_each(|φ_v_dot_k, φ_k| { + *φ_v_dot_k = φ_k.dot(¶ms.effective_cell_volume); + }); + } - // *r = r.min(200.0).max(1e-5); + fn sample_θ(&mut self, priors: &ModelPriors, params: &mut ModelParams) { + let mut θfac = params + .θ + .slice_mut(s![params.nunfactored.., params.nunfactored..]); + let gene_latent_counts_fac = params + .gene_latent_counts + .slice(s![params.nunfactored.., params.nunfactored..]); + + // Sampling with Dirichlet prior on θ (I think Gamma makes more + // sense, but this is an alternative to consider) + Zip::from(θfac.axis_iter_mut(Axis(1))) + .and(gene_latent_counts_fac.axis_iter(Axis(1))) + .for_each(|mut θ_k, x_k| { + let mut rng = thread_rng(); - // *lgamma_r = lgammaf(*r); - // loggammaplus.reset(*r); + // dirichlet sampling by normalizing gammas + Zip::from(&mut θ_k).and(x_k).for_each(|θ_gk, x_gk| { + *θ_gk = Gamma::new(priors.αθ + *x_gk as f32, 1.0) + .unwrap() + .sample(&mut rng); }); - } - // params.h = 0.1; - params.h = Gamma::new( - priors.e_h * (1_f32 + params.r.len() as f32), - (priors.f_h + params.r.sum()).recip(), - ) - .unwrap() - .sample(&mut thread_rng()); - // dbg!(params.h); + let θsum = θ_k.sum(); + θ_k *= θsum.recip(); + }); + + Zip::from(&mut params.θksum) + .and(params.θ.axis_iter(Axis(1))) + .for_each(|θksum, θ_k| { + *θksum = θ_k.sum(); + }); } - fn sample_rates(&mut self, _priors: &ModelPriors, params: &mut ModelParams) { - // loop over genes - Zip::from(params.λ.rows_mut()) - .and(params.foreground_counts.axis_iter(Axis(1))) - .and(params.φ.columns()) - .and(params.r.columns()) - .par_for_each(|mut λs, cs, φs, rs| { + fn sample_rφ(&mut self, priors: &ModelPriors, params: &mut ModelParams) { + Zip::from(params.lφ.outer_iter_mut()) // for every cell + .and(¶ms.z) + .and(params.cell_latent_counts.outer_iter()) + .par_for_each(|l_c, z_c, x_c| { let mut rng = thread_rng(); - // loop over cells - for (λ, &z, cs, cell_volume) in - izip!(&mut λs, ¶ms.z, cs.outer_iter(), ¶ms.cell_volume) - { - let z = z as usize; - - let c = cs.sum(); - let φ = φs[z]; - // dbg!(φ, φ.exp(), (-φ).exp()); - // let ψ = φ + cell_volume.ln(); - let r = rs[z]; - - let α = r + c as f32; - // let β = (-φ).exp() / cell_volume + 1.0; - let β0 = (-φ).exp(); - let β = β0 + cell_volume; - - *λ = Gamma::new(α, β.recip()).unwrap().sample(&mut rng); - // .max(1e-14); - - // if c > 10 { - // // TODO: How far off is than from just count divided by volume? - // dbg!( - // *λ, - // r, - // c as f32 / *cell_volume, *cell_volume); - // } + Zip::from(l_c) // for each hidden dim + .and(x_c) + .and(¶ms.rφ.row(*z_c as usize)) + .for_each(|l_ck, x_ck, r_k| { + *l_ck = rand_crt(&mut rng, *x_ck, *r_k); + }); + }); - assert!(λ.is_finite()); - } + Zip::indexed(params.rφ.outer_iter_mut()) // for each component + .and(params.sφ.outer_iter()) + .par_for_each(|t, r_t, s_t| { + let mut rng = thread_rng(); + Zip::from(r_t) // each hidden dim + .and(s_t) + .and(params.lφ.axis_iter(Axis(1))) + .and(¶ms.θksum) + .for_each(|r_tk, s_tk, l_k, θ_k_sum| { + // summing elements of lφ in component t + let lsum = l_k + .iter() + .zip(¶ms.z) + .filter(|(_l_ck, z_c)| **z_c as usize == t) + .map(|(l_ck, _z_c)| *l_ck) + .sum::(); + + let shape = priors.eφ + lsum as f32; + + let scale_inv = (1.0 / priors.fφ) + + params + .z + .iter() + .zip(¶ms.effective_cell_volume) + .filter(|(z_c, _v_c)| **z_c as usize == t) + .map(|(_z_c, v_c)| (*s_tk * v_c * *θ_k_sum).ln_1p()) + .sum::(); + let scale = scale_inv.recip(); + *r_tk = Gamma::new(shape, scale).unwrap().sample(&mut rng); + *r_tk = r_tk.max(2e-4); + }); + }); + + // dbg!(¶ms.rφ); + } + + fn sample_ωck(&mut self, params: &mut ModelParams) { + // sample ω ~ PolyaGamma + // let t0 = Instant::now(); + Zip::from(params.ωφ.outer_iter_mut()) // for every cell + .and(¶ms.z) + .and(params.cell_latent_counts.outer_iter()) + .and(¶ms.effective_cell_volume) + .par_for_each(|ω_c, z_c, x_c, v_c| { + let mut rng = thread_rng(); + Zip::from(ω_c) // for every hidden dim + .and(x_c) + .and(params.rφ.row(*z_c as usize)) + .and(params.sφ.row(*z_c as usize)) + .and(¶ms.θksum) + .for_each(|ω_ck, x_ck, r_k, s_k, θ_k_sum| { + let ε = (*s_k * *v_c * θ_k_sum).ln(); + *ω_ck = PolyaGamma::new(*x_ck as f32 + *r_k, ε).sample(&mut rng); + }); + }); + // println!(" sample_sφ/PolyaGamma: {:?}", t0.elapsed()); + } + + fn sample_sφ(&mut self, priors: &ModelPriors, params: &mut ModelParams) { + // sample s ~ LogNormal + // let t0 = Instant::now(); + Zip::indexed(params.sφ.outer_iter_mut()) // for every component + .and(params.rφ.outer_iter()) + .par_for_each(|t, s_t, r_t| { + let mut rng = thread_rng(); + Zip::from(s_t) // for every hidden dim + .and(r_t) + .and(¶ms.θksum) + .and(params.ωφ.axis_iter(Axis(1))) + .and(params.cell_latent_counts.axis_iter(Axis(1))) + .for_each(|s_tk, r_tk, θ_k_sum, ω_k, x_k| { + // TODO: This would be fatser if we went pre-computing μ, σ [ncomponets, nhidden] matrices + // by processing cells, rather than doing this filtering thing. + let τ = priors.τφ + + ω_k + .iter() + .zip(¶ms.z) + .filter(|(_ω_ck, z_c)| **z_c as usize == t) + .map(|(ω_ck, _z_c)| *ω_ck) + .sum::(); + let σ2 = τ.recip(); + let μ = σ2 + * (priors.μφ * priors.τφ + + Zip::from(x_k) // for every cell + .and(¶ms.z) + .and(ω_k) + .and(¶ms.effective_cell_volume) + .fold(0.0, |acc, x_ck, z_c, ω_ck, v_c| { + if *z_c as usize == t { + acc + (*x_ck as f32 - *r_tk) / 2.0 + - ω_ck * (v_c * *θ_k_sum).ln() + } else { + acc + } + })); + *s_tk = (μ + σ2.sqrt() * randn(&mut rng)).exp(); + }); + }); + // println!(" sample_sφ/LogNormal: {:?}", t0.elapsed()); + } + + fn sample_cell_scales(&mut self, priors: &ModelPriors, params: &mut ModelParams) { + // for each cell + Zip::from(&mut params.cell_scale) + .and(&mut params.effective_cell_volume) + .and(¶ms.log_cell_volume) + .and(params.cell_latent_counts.outer_iter()) + .and(¶ms.z) + .and(params.ωφ.outer_iter()) + .par_for_each(|a_c, eff_v_c, &log_v_c, x_c, &z_c, ω_c| { + let mut rng = thread_rng(); + let z_c = z_c as usize; + + let τ = priors.τv + ω_c.sum(); + let σ2 = τ.recip(); + + // for each hidden dim + let μ = σ2 + * Zip::from(¶ms.θksum) + .and(x_c) + .and(params.rφ.row(z_c)) + .and(ω_c) + .and(params.sφ.row(z_c)) + .fold(0.0, |acc, &θ_k_sum, x_ck, &r_tk, &ω_ck, &s_tk| { + acc + (*x_ck as f32 - r_tk) / 2.0 + - ω_ck * ((s_tk * θ_k_sum).ln() + log_v_c) + }); + + let log_a_c = μ + σ2.sqrt() * randn(&mut rng); + *a_c = log_a_c.exp(); + *eff_v_c = (log_a_c + log_v_c).exp(); }); } @@ -1660,80 +1884,8 @@ where }); } - fn sample_component_assignments(&mut self, _priors: &ModelPriors, params: &mut ModelParams) { - let ncomponents = params.ncomponents(); - - // loop over cells - Zip::from(params.foreground_counts.axis_iter(Axis(0))) - .and(&mut params.z) - .and(¶ms.cell_log_volume) - .par_for_each(|cs, z_i, cell_log_volume| { - let mut z_probs = params - .z_probs - .get_or(|| RefCell::new(vec![0_f64; ncomponents])) - .borrow_mut(); - - // loop over components - for (zp, π, φs, rs, lgamma_r, loggammaplus, &μ_volume, &σ_volume) in izip!( - z_probs.iter_mut(), - ¶ms.π, - params.φ.rows(), - params.r.rows(), - params.lgamma_r.rows(), - params.loggammaplus.rows(), - ¶ms.μ_volume, - ¶ms.σ_volume - ) { - // sum over genes - *zp = (*π as f64) - * (Zip::from(cs.axis_iter(Axis(0))) - .and(rs) - .and(φs) - .and(lgamma_r) - .and(loggammaplus) - .fold(0_f32, |accum, cs, &r, φ, &lgamma_r, lgammaplus| { - let ψ = φ + cell_log_volume; - let c = cs.iter().map(|&x| x as u32).sum(); // sum counts across layers - accum - + negbin_logpmf_fast( - r, - lgamma_r, - lgammaplus.eval(c), - logistic(ψ), - c, - params.logfactorial.eval(c), - ) - }) as f64) - .exp(); - - *zp *= normal_pdf(μ_volume, σ_volume, *cell_log_volume).exp() as f64; - } - - // z_probs.iter_mut().enumerate().for_each(|(j, zp)| { - // *zp = (self.params.π[j] as f64) * - // negbin_logpmf(r, lgamma_r, p, k) - // // (self.params.cell_logprob_fast(j as usize, *cell_area, &cs, &clfs) as f64).exp(); - // }); - - let z_prob_sum = z_probs.iter().sum::(); - - assert!(z_prob_sum.is_finite()); - - // cumulative probabilities in-place - z_probs.iter_mut().fold(0.0, |mut acc, x| { - acc += *x / z_prob_sum; - *x = acc; - acc - }); - - let rng = &mut thread_rng(); - let u = rng.gen::(); - *z_i = z_probs.partition_point(|x| *x < u) as u32; - }); - } - fn sample_volume_params(&mut self, priors: &ModelPriors, params: &mut ModelParams) { - Zip::from(&mut params.cell_log_volume) + Zip::from(&mut params.log_cell_volume) .and(¶ms.cell_volume) .into_par_iter() .with_min_len(50) @@ -1742,13 +1894,11 @@ where }); // compute sample means - params.component_population.fill(0_u32); params.μ_volume.fill(0_f32); Zip::from(¶ms.z) - .and(¶ms.cell_log_volume) + .and(¶ms.log_cell_volume) .for_each(|&z, &log_volume| { params.μ_volume[z as usize] += log_volume; - params.component_population[z as usize] += 1; }); // dbg!(¶ms.component_population); @@ -1775,7 +1925,7 @@ where // compute sample variances params.σ_volume.fill(0_f32); Zip::from(¶ms.z) - .and(¶ms.cell_log_volume) + .and(¶ms.log_cell_volume) .for_each(|&z, &log_volume| { params.σ_volume[z as usize] += (params.μ_volume[z as usize] - log_volume).powi(2); }); @@ -1897,7 +2047,7 @@ where let λ_prev = if cell_prev == BACKGROUND_CELL { 0.0 } else { - params.λ[[gene, cell_prev as usize]] + params.λ_c[gene] + params.λ[[cell_prev as usize, gene]] + params.λ_c[gene] } + params.λ_bg[[gene, layer_prev]]; let layer_new = @@ -1907,7 +2057,7 @@ where let λ_new = if cell_new == BACKGROUND_CELL { 0.0 } else { - params.λ[[gene, cell_new as usize]] + params.λ_c[gene] + params.λ[[cell_new as usize, gene]] + params.λ_c[gene] } + params.λ_bg[[gene, layer_new]]; let ln_λ_diff = λ_new.ln() - λ_prev.ln(); @@ -1960,9 +2110,12 @@ where transcripts: &Vec, uncertainty: &mut Option<&mut UncertaintyTracker>, ) { + // let t0 = Instant::now(); self.propose_eval_transcript_positions(priors, params, transcripts); + // println!(" REPO: proposals {:?}", t0.elapsed()); // Update position and compute cell and layer changes for updates + // let t0 = Instant::now(); params .transcript_position_updates .par_iter_mut() @@ -1988,34 +2141,26 @@ where } }, ); + // println!(" REPO: update positions {:?}", t0.elapsed()); // Update counts and cell_population + // let t0 = Instant::now(); params .transcript_position_updates .iter() - .zip(transcripts) .zip(¶ms.accept_proposed_transcript_positions) .zip(&mut params.cell_assignments) .zip(&mut params.cell_assignment_time) .enumerate() .for_each( - |( - i, - ((((update, transcript), &accept), cell_assignment), cell_assignment_time), - )| { - let (cell_prev, cell_new, layer_prev, layer_new) = *update; + |(i, (((update, &accept), cell_assignment), cell_assignment_time))| { + let (cell_prev, cell_new, _layer_prev, _layer_new) = *update; if accept { - let gene = transcript.gene as usize; if cell_prev != BACKGROUND_CELL { - assert!( - params.counts[[gene, cell_prev as usize, layer_prev as usize]] > 0 - ); - params.counts[[gene, cell_prev as usize, layer_prev as usize]] -= 1; assert!(params.cell_population[cell_prev as usize] > 0); params.cell_population[cell_prev as usize] -= 1; } if cell_new != BACKGROUND_CELL { - params.counts[[gene, cell_new as usize, layer_new as usize]] += 1; params.cell_population[cell_new as usize] += 1; } @@ -2036,10 +2181,13 @@ where } }, ); + // println!(" REPO: update counts {:?}", t0.elapsed()); + // let t0 = Instant::now(); self.update_transcript_positions( ¶ms.accept_proposed_transcript_positions, ¶ms.transcript_positions, ); + // println!(" REPO: voxel sampler update positions {:?}", t0.elapsed()); } } diff --git a/src/sampler/math.rs b/src/sampler/math.rs index 5e31f9d..1b70e37 100644 --- a/src/sampler/math.rs +++ b/src/sampler/math.rs @@ -1,24 +1,24 @@ -// use libm::{lgammaf, erff}; use libm::lgammaf; use rand::rngs::ThreadRng; use rand::Rng; +use rand_distr::StandardNormal; // pub fn logit(p: f32) -> f32 { // return p.ln() - (1.0 - p).ln(); // } -pub fn logistic(x: f32) -> f32 { - 1.0 / (1.0 + (-x).exp()) +// pub fn sq(x: f32) -> f32 { +// x * x +// } + +pub fn odds_to_prob(q: f32) -> f32 { + q / (1.0 + q) } pub fn relerr(a: f32, b: f32) -> f32 { ((a - b) / a).abs() } -pub fn lfact(k: u32) -> f32 { - lgammaf(k as f32 + 1.0) -} - // // Partial Student-T log-pdf (just the terms that don't cancel out when doing MH sampling) // pub fn studentt_logpdf_part(σ2: f32, df: f32, x2: f32) -> f32 { // return -((df + 1.0) / 2.0) * ((x2 / σ2) / df).ln_1p(); @@ -35,30 +35,27 @@ pub fn normal_x2_logpdf(σ: f32, x2: f32) -> f32 { -x2 / (2.0 * σ.powi(2)) - σ.ln() - LN_SQRT_TWO_PI } -// pub fn negbin_logpmf(r: f32, lgamma_r: f32, p: f32, k: u32) -> f32 { -// let k_ln_factorial = lgammaf(k as f32 + 1.0); -// let lgamma_rpk = lgammaf(r + k as f32); -// return negbin_logpmf_fast(r, lgamma_r, lgamma_rpk, p, k, k_ln_factorial); -// } - -// const SQRT2: f32 = 1.4142135623730951_f32; - -// pub fn normal_cdf(μ: f32, σ: f32, x: f32) -> f32 { -// return 0.5 * (1.0 + erff((x - μ) / (SQRT2 * σ))); +// pub fn gamma_logpdf(shape: f32, scale: f32, x: f32) -> f32 { +// return +// -lgammaf(shape) +// - shape * scale.ln() +// + (shape - 1.0) * x.ln() +// - x / scale; // } -pub fn normal_pdf(μ: f32, σ: f32, x: f32) -> f32 { - let xμ = x - μ; - (-xμ.powi(2) / (2.0 * σ.powi(2))).exp() / (σ * SQRT_TWO_PI) +pub fn rand_crt(rng: &mut ThreadRng, n: u32, r: f32) -> u32 { + (0..n) + .map(|t| rng.gen_bool(r as f64 / (r as f64 + t as f64)) as u32) + .sum() } -pub fn lognormal_logpdf(μ: f32, σ: f32, x: f32) -> f32 { - let xln = x.ln(); - -LN_SQRT_TWO_PI - σ.ln() - xln - ((xln - μ) / σ).powi(2) / 2.0 +pub fn negbin_logpmf(r: f32, lgamma_r: f32, p: f32, k: u32) -> f32 { + let k_ln_factorial = lgammaf(k as f32 + 1.0); + let lgamma_rpk = lgammaf(r + k as f32); + return negbin_logpmf_fast(r, lgamma_r, lgamma_rpk, p, k, k_ln_factorial); } -// Negative binomial log probability function with capacity for precomputing some values. -pub fn negbin_logpmf_fast( +fn negbin_logpmf_fast( r: f32, lgamma_r: f32, lgamma_rpk: f32, @@ -82,68 +79,21 @@ pub fn negbin_logpmf_fast( } } -pub fn rand_crt(rng: &mut ThreadRng, n: u32, r: f32) -> u32 { - (0..n) - .map(|t| rng.gen_bool(r as f64 / (r as f64 + t as f64)) as u32) - .sum() -} - -// log-factorial with precomputed values for small numbers -pub struct LogFactorial { - values: Vec, -} - -impl LogFactorial { - fn new_with_n(n: usize) -> Self { - LogFactorial { - values: Vec::from_iter((0..n as u32).map(lfact)), - } - } +// const SQRT2: f32 = 1.4142135623730951_f32; - pub fn new() -> Self { - LogFactorial::new_with_n(100) - } +// pub fn normal_cdf(μ: f32, σ: f32, x: f32) -> f32 { +// return 0.5 * (1.0 + erff((x - μ) / (SQRT2 * σ))); +// } - pub fn eval(&self, k: u32) -> f32 { - self.values - .get(k as usize) - .map_or_else(|| lfact(k), |&value| value) - } +pub fn normal_logpdf(μ: f32, σ: f32, x: f32) -> f32 { + -LN_SQRT_TWO_PI - σ.ln() - ((x - μ) / σ).powi(2) / 2.0 } -// Partially memoized lgamma(r + k), memoized over k. -#[derive(Clone)] -pub struct LogGammaPlus { - r: f32, - values: Vec, +pub fn lognormal_logpdf(μ: f32, σ: f32, x: f32) -> f32 { + let xln = x.ln(); + -LN_SQRT_TWO_PI - σ.ln() - xln - ((xln - μ) / σ).powi(2) / 2.0 } -impl LogGammaPlus { - fn new_with_n(r: f32, n: usize) -> Self { - LogGammaPlus { - r, - values: Vec::from_iter((0..n as u32).map(|x| lgammaf(r + x as f32))), - } - } - - pub fn new(r: f32) -> Self { - LogGammaPlus::new_with_n(r, 100) - } - - pub fn default() -> Self { - LogGammaPlus::new(0.0) - } - - pub fn reset(&mut self, r: f32) { - self.values.iter_mut().enumerate().for_each(|(k, v)| { - *v = lgammaf(r + k as f32); - }); - self.r = r; - } - - pub fn eval(&self, k: u32) -> f32 { - self.values - .get(k as usize) - .map_or_else(|| lgammaf(self.r + k as f32), |&value| value) - } +pub fn randn(rng: &mut ThreadRng) -> f32 { + return rng.sample::(StandardNormal); } diff --git a/src/sampler/transcripts.rs b/src/sampler/transcripts.rs index 2b6534d..cac6203 100644 --- a/src/sampler/transcripts.rs +++ b/src/sampler/transcripts.rs @@ -4,8 +4,9 @@ use flate2::read::MultiGzDecoder; use itertools::izip; use kiddo::float::kdtree::KdTree; use kiddo::SquaredEuclidean; -use ndarray::Array2; +use ndarray::{Array1, Array2}; use parquet::arrow::arrow_reader::{ParquetRecordBatchReader, ParquetRecordBatchReaderBuilder}; +use regex::Regex; use std::collections::HashMap; use std::fs::File; use std::str; @@ -38,9 +39,36 @@ pub struct TranscriptDataset { pub fov_names: Vec, } +impl TranscriptDataset { + pub fn select_unfactored_genes(&mut self, _nunfactored: usize) { + // Current heuristic is just to select the highest expression genes. + let mut gene_counts = Array1::::zeros(self.transcript_names.len()); + for transcript in self.transcripts.iter() { + gene_counts[transcript.gene as usize] += 1; + } + + let mut ord = (0..self.transcript_names.len()).collect::>(); + ord.sort_unstable_by(|&i, &j| gene_counts[i].cmp(&gene_counts[j]).reverse()); + + let mut rev_ord = vec![0; ord.len()]; + for (i, j) in ord.iter().enumerate() { + rev_ord[*j] = i; + } + + self.transcript_names = ord + .iter() + .map(|&i| self.transcript_names[i].clone()) + .collect(); + for transcript in self.transcripts.iter_mut() { + transcript.gene = rev_ord[transcript.gene as usize] as u32; + } + } +} + #[allow(clippy::too_many_arguments)] pub fn read_transcripts_csv( path: &str, + excluded_genes: Option, transcript_column: &str, id_column: Option, compartment_column: Option, @@ -55,7 +83,7 @@ pub fn read_transcripts_csv( y_column: &str, z_column: &str, min_qv: f32, - no_z_column: bool, + ignore_z_column: bool, coordinate_scale: f32, ) -> TranscriptDataset { let fmt = infer_format_from_filename(path); @@ -65,6 +93,7 @@ pub fn read_transcripts_csv( let mut rdr = csv::Reader::from_path(path).unwrap(); read_transcripts_csv_xyz( &mut rdr, + excluded_genes, transcript_column, id_column, compartment_column, @@ -79,7 +108,7 @@ pub fn read_transcripts_csv( y_column, z_column, min_qv, - no_z_column, + ignore_z_column, coordinate_scale, ) } @@ -87,6 +116,7 @@ pub fn read_transcripts_csv( let mut rdr = csv::Reader::from_reader(MultiGzDecoder::new(File::open(path).unwrap())); read_transcripts_csv_xyz( &mut rdr, + excluded_genes, transcript_column, id_column, compartment_column, @@ -101,12 +131,13 @@ pub fn read_transcripts_csv( y_column, z_column, min_qv, - no_z_column, + ignore_z_column, coordinate_scale, ) } OutputFormat::Parquet => read_xenium_transcripts_parquet( path, + excluded_genes, transcript_column, &id_column.unwrap(), &compartment_column.unwrap(), @@ -119,7 +150,7 @@ pub fn read_transcripts_csv( y_column, z_column, min_qv, - no_z_column, + ignore_z_column, coordinate_scale, ), OutputFormat::Infer => panic!("Could not infer format of file '{}'", path), @@ -188,7 +219,7 @@ fn postprocess_cell_assignments( #[allow(clippy::too_many_arguments)] fn read_transcripts_csv_xyz( rdr: &mut csv::Reader, - + excluded_genes: Option, transcript_column: &str, id_column: Option, compartment_column: Option, @@ -277,6 +308,12 @@ where let transcript_name = &row[transcript_col]; + if let Some(excluded_genes) = &excluded_genes { + if excluded_genes.is_match(transcript_name) { + continue; + } + } + let gene = if let Some(gene) = transcript_name_map.get(transcript_name) { *gene } else { @@ -388,6 +425,7 @@ where #[allow(clippy::too_many_arguments)] fn read_xenium_transcripts_parquet( filename: &str, + excluded_genes: Option, transcript_col_name: &str, id_col_name: &str, compartment_col_name: &str, @@ -421,6 +459,7 @@ fn read_xenium_transcripts_parquet( read_xenium_transcripts_parquet_str_type::( rdr, schema, + excluded_genes, transcript_col_name, id_col_name, compartment_col_name, @@ -441,6 +480,7 @@ fn read_xenium_transcripts_parquet( read_xenium_transcripts_parquet_str_type::( rdr, schema, + excluded_genes, transcript_col_name, id_col_name, compartment_col_name, @@ -465,6 +505,7 @@ fn read_xenium_transcripts_parquet( fn read_xenium_transcripts_parquet_str_type( rdr: ParquetRecordBatchReader, schema: arrow::datatypes::Schema, + excluded_genes: Option, transcript_col_name: &str, id_col_name: &str, compartment_col_name: &str, @@ -587,6 +628,12 @@ where continue; } + if let Some(excluded_genes) = &excluded_genes { + if excluded_genes.is_match(transcript) { + continue; + } + } + let fov = match fov_map.get(fov) { Some(fov) => *fov, None => { diff --git a/src/sampler/voxelsampler.rs b/src/sampler/voxelsampler.rs index 46118db..8806889 100644 --- a/src/sampler/voxelsampler.rs +++ b/src/sampler/voxelsampler.rs @@ -9,7 +9,7 @@ use super::{chunkquad, perimeter_bound, ModelParams, ModelPriors, Proposal, Samp // use arrow; use geo::geometry::{MultiPolygon, Polygon}; use itertools::Itertools; -use ndarray::Array2; +use ndarray::{Array2, Zip}; use rand::{thread_rng, Rng}; use rayon::prelude::*; use std::cell::RefCell; @@ -500,8 +500,6 @@ impl VoxelSampler { } } - params.recompute_counts(transcripts); - // for (transcript, &cell) in transcripts.iter().zip(params.cell_assignments.iter()) { // // let position = clip_z_position( // // (transcript.x, transcript.y, transcript.z), zmin, zmax); @@ -550,6 +548,7 @@ impl VoxelSampler { sampler.recompute_cell_population(); sampler.recompute_cell_perimeter(); sampler.recompute_cell_volume(priors, params); + params.effective_cell_volume.assign(¶ms.cell_volume); sampler.populate_mismatches(); sampler.update_transcript_positions( &vec![true; transcripts.len()], @@ -705,6 +704,14 @@ impl VoxelSampler { assert!(*cell_volume > 0.0); *cell_volume = cell_volume.max(priors.min_cell_volume); } + + Zip::from(&mut params.log_cell_volume) + .and(¶ms.cell_volume) + .into_par_iter() + .with_min_len(50) + .for_each(|(log_volume, &volume)| { + *log_volume = volume.ln(); + }); } fn recompute_cell_population(&mut self) { @@ -1407,6 +1414,7 @@ impl Sampler for VoxelSampler { } fn update_transcript_positions(&mut self, updated: &[bool], positions: &[(f32, f32, f32)]) { + // let t0 = Instant::now(); self.transcript_voxels .par_iter_mut() .zip(positions) @@ -1417,9 +1425,12 @@ impl Sampler for VoxelSampler { *voxel = self.chunkquad.layout.world_pos_to_voxel(position); } }); + // println!(" REPO: compute voxels {:?}", t0.elapsed()); + // let t0 = Instant::now(); self.transcript_voxel_ord - .par_sort_unstable_by_key(|&t| self.transcript_voxels[t]); + .par_sort_by_key(|&t| self.transcript_voxels[t]); + // println!(" REPO: sort on voxel assignment {:?}", t0.elapsed()); } // fn update_transcript_position(&mut self, i: usize, prev_pos: (f32, f32, f32), new_pos: (f32, f32, f32)) {