diff --git a/Cargo.toml b/Cargo.toml index cf763ed..f42bf4a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "proseg" description = "Probabilistic cell segmentation for in situ spatial transcriptomics" -version = "1.1.9" +version = "2.0.0" edition = "2021" authors = ["Daniel C. Jones "] repository = "https://github.com/dcjones/proseg" @@ -38,10 +38,7 @@ itertools = "0.12.1" json = "0.12.4" kiddo = "4.2.0" libm = "0.2.7" -linfa = "0.7.0" -linfa-clustering = "0.7.0" -ndarray = { version = "0.15.6", features = ["rayon"] } -ndarray-conv = "0.2.0" +ndarray = { version = "0.16.1", features = ["rayon", "matrixmultiply-threading"] } num-traits = "0.2.17" numeric_literals = "0.2.0" parquet = "52.2.0" @@ -49,4 +46,8 @@ petgraph = "0.6.3" rand = "0.8.5" rand_distr = "0.4.3" rayon = "1.7.0" +regex = "1.10.6" thread_local = "1.1.7" +faer = "0.19.4" +faer-ext = { version = "0.3.0", features = ["ndarray"] } +clustering = { version = "0.2.1", features = ["parallel"] } diff --git a/src/main.rs b/src/main.rs index d921953..2e428b5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,6 +13,7 @@ use hull::convex_hull_area; use indicatif::{ProgressBar, ProgressStyle}; use itertools::Itertools; use rayon::current_num_threads; +use regex::Regex; use sampler::transcripts::{ coordinate_span, estimate_full_area, filter_cellfree_transcripts, read_transcripts_csv, CellIndex, Transcript, BACKGROUND_CELL, @@ -61,6 +62,10 @@ struct Args { #[arg(long, default_value_t = false)] merfish: bool, + /// Regex pattern matching names of genes/features to be excluded + #[arg(long, default_value = None)] + excluded_genes: Option, + /// Initialize with cell assignments rather than nucleus assignments #[arg(long, default_value_t = false)] use_cell_initialization: bool, @@ -123,7 +128,7 @@ struct Args { /// Ignore the z coordinate if any, treating the data as 2D #[arg(long, default_value_t = false)] - no_z_coord: bool, + ignore_z_coord: bool, /// Filter out transcripts with quality values below this threshold #[arg(long, default_value_t = 0.0_f32)] @@ -139,6 +144,10 @@ struct Args { #[arg(long, default_value_t = 10)] ncomponents: usize, + /// Dimenionality of the latent space + #[arg(long, default_value_t = 100)] + nhidden: usize, + /// Number of z-axis layers used to model background expression #[arg(long, default_value_t = 4)] nbglayers: usize, @@ -257,10 +266,9 @@ struct Args { #[arg(long, value_enum, default_value_t = OutputFormat::Infer)] output_rates_fmt: OutputFormat, - /// Output per-component parameter values - #[arg(long, default_value = None)] - output_component_params: Option, - + // /// Output per-component parameter values + // #[arg(long, default_value = None)] + // output_component_params: Option, #[arg(long, value_enum, default_value_t = OutputFormat::Infer)] output_component_params_fmt: OutputFormat, @@ -292,6 +300,20 @@ struct Args { #[arg(long, value_enum, default_value_t = OutputFormat::Infer)] output_gene_metadata_fmt: OutputFormat, + /// Output cell metagene rates + #[arg(long, default_value=None)] + output_metagene_rates: Option, + + #[arg(long, value_enum, default_value_t = OutputFormat::Infer)] + output_metagene_rates_fmt: OutputFormat, + + /// Output metagene loadings + #[arg(long, default_value=None)] + output_metagene_loadings: Option, + + #[arg(long, value_enum, default_value_t = OutputFormat::Infer)] + output_metagene_loadings_fmt: OutputFormat, + /// Output a table of each voxel in each cell #[arg(long, default_value=None)] output_cell_voxels: Option, @@ -323,6 +345,30 @@ struct Args { /// Use connectivity checks to prevent cells from having any disconnected voxels #[arg(long, default_value_t = true)] enforce_connectivity: bool, + + #[arg(long, default_value_t = 300)] + nunfactored: usize, + + /// Disable factorization model and use genes directly + #[arg(long, default_value_t = false)] + no_factorization: bool, + + /// Disable cell scale factors + #[arg(long, default_value_t = false)] + no_cell_scales: bool, + + // Hyperparameters + #[arg(long, default_value_t = 1.0)] + hyperparam_e_phi: f32, + + #[arg(long, default_value_t = 1.0)] + hyperparam_f_phi: f32, + + #[arg(long, default_value_t = 1.0)] + hyperparam_neg_mu_phi: f32, + + #[arg(long, default_value_t = 0.1)] + hyperparam_tau_phi: f32, } fn set_xenium_presets(args: &mut Args) { @@ -339,6 +385,9 @@ fn set_xenium_presets(args: &mut Args) { args.cell_id_unassigned .get_or_insert(String::from("UNASSIGNED")); args.qv_column.get_or_insert(String::from("qv")); + args.excluded_genes.get_or_insert(String::from( + "^(Deprecated|NegControl|Unassigned|Intergenic)", + )); // newer xenium data does have a fov column args.fov_column.get_or_insert(String::from("fov_name")); @@ -406,9 +455,10 @@ fn set_visiumhd_presets(args: &mut Args) { args.z_column.get_or_insert(String::from("z")); // ignored args.cell_id_column.get_or_insert(String::from("cell")); args.cell_id_unassigned.get_or_insert(String::from("0")); - args.initial_voxel_size = Some(4.0); + args.initial_voxel_size = Some(1.0); args.voxel_layers = 1; - args.no_z_coord = true; + args.nbglayers = 1; + args.ignore_z_coord = true; // TODO: This is the resolution on the one dataset I have. It probably // doesn't generalize. @@ -506,8 +556,11 @@ fn main() { mut cell_assignments, mut nucleus_population) = */ + let excluded_genes = args.excluded_genes.map(|pat| Regex::new(&pat).unwrap()); + let mut dataset = read_transcripts_csv( &args.transcript_csv, + excluded_genes, &expect_arg(args.gene_column, "gene-column"), args.transcript_id_column, args.compartment_column, @@ -522,10 +575,22 @@ fn main() { &expect_arg(args.y_column, "y-column"), &expect_arg(args.z_column, "z-column"), args.min_qv, - args.no_z_coord, + args.ignore_z_coord, args.coordinate_scale.unwrap_or(1.0), ); + if !args.no_factorization { + dataset.select_unfactored_genes(args.nunfactored); + } + + // let cd3e_idx = dataset + // .transcript_names + // .iter() + // .position(|gene| gene == "CXCL6") + // .unwrap(); + // dbg!(cd3e_idx); + // panic!(); + // Warn if any nucleus has extremely high population, which is likely // an error interpreting the file. dataset.nucleus_population.iter().for_each(|&p| { @@ -654,6 +719,8 @@ fn main() { Some(args.burnin_dispersion) }, + use_cell_scales: !args.no_cell_scales, + min_cell_volume, μ_μ_volume: (2.0 * mean_nucleus_area * zspan).ln(), @@ -661,12 +728,26 @@ fn main() { α_σ_volume: 0.1, β_σ_volume: 0.1, - e_r: 1.0, + use_factorization: !args.no_factorization, + + // TODO: mean/var ratio is always 1/fφ, but that doesn't seem like the whole + // story. Seems like it needs to change as a function of the dimensionality + // of the latent space. - e_h: 1.0, - f_h: 1.0, + // I also don't know if this "severe prior" approach is going to work in + // the long run because we may have far more cells. Needs more testing. + // eφ: 1000.0, + // fφ: 1.0, - γ: 1.0, + // μφ: -20.0, + // τφ: 0.1, + αθ: 1e-1, + + eφ: args.hyperparam_e_phi, + fφ: args.hyperparam_f_phi, + + μφ: -args.hyperparam_neg_mu_phi, + τφ: args.hyperparam_tau_phi, α_bg: 1.0, β_bg: 1.0, @@ -692,6 +773,8 @@ fn main() { σ_z_diffusion_proposal: 0.2 * zspan, σ_z_diffusion: 0.2 * zspan, + τv: 10.0, + zmin, zmax, @@ -709,6 +792,8 @@ fn main() { &dataset.nucleus_population, &dataset.cell_assignments, args.ncomponents, + args.nhidden, + args.nunfactored, args.nbglayers, ncells, ngenes, @@ -736,7 +821,9 @@ fn main() { initial_voxel_size, chunk_size, )); - sampler.borrow_mut().initialize(&priors, &mut params); + sampler + .borrow_mut() + .initialize(&priors, &mut params, &dataset.transcripts); let mut total_steps = 0; @@ -856,12 +943,12 @@ fn main() { ¶ms, &dataset.transcript_names, ); - write_component_params( - &args.output_component_params, - args.output_component_params_fmt, - ¶ms, - &dataset.transcript_names, - ); + // write_component_params( + // &args.output_component_params, + // args.output_component_params_fmt, + // ¶ms, + // &dataset.transcript_names, + // ); write_cell_metadata( &args.output_cell_metadata, args.output_cell_metadata_fmt, @@ -890,6 +977,17 @@ fn main() { &dataset.transcript_names, &ecounts, ); + write_metagene_rates( + &args.output_metagene_rates, + args.output_metagene_rates_fmt, + ¶ms.φ, + ); + write_metagene_loadings( + &args.output_metagene_loadings, + args.output_metagene_loadings_fmt, + &dataset.transcript_names, + ¶ms.θ, + ); write_voxels( &args.output_cell_voxels, args.output_cell_voxels_fmt, @@ -935,6 +1033,7 @@ fn run_hexbin_sampler( for _ in 0..niter { // sampler.check_perimeter_bounds(priors); + // let t0 = std::time::Instant::now(); if sample_cell_regions { // let t0 = std::time::Instant::now(); for _ in 0..local_steps_per_iter { @@ -942,16 +1041,17 @@ fn run_hexbin_sampler( priors, params, &mut proposal_stats, - transcripts, hillclimb, &mut uncertainty, ); } // println!("Sample cell regions: {:?}", t0.elapsed()); } + // println!("Sample cell regions: {:?}", t0.elapsed()); + // let t0 = std::time::Instant::now(); sampler.sample_global_params(priors, params, transcripts, &mut uncertainty, burnin); - // println!("Sample parameters: {:?}", t0.elapsed()); + // println!("Sample global parameters: {:?}", t0.elapsed()); let nassigned = params.nassigned(); let nforeground = params.nforeground(); diff --git a/src/output.rs b/src/output.rs index 45bb6c0..88c62f6 100644 --- a/src/output.rs +++ b/src/output.rs @@ -1,7 +1,7 @@ use arrow::array::RecordBatch; -use arrow::datatypes::{Schema, Field, DataType}; -use arrow::error::ArrowError; use arrow::csv; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::error::ArrowError; use parquet::errors::ParquetError; use parquet::arrow::ArrowWriter; use parquet::file::properties::WriterProperties; @@ -9,22 +9,18 @@ use parquet::basic::{Compression::ZSTD, ZstdLevel}; use flate2::write::GzEncoder; use flate2::Compression; use geo::MultiPolygon; -use ndarray::{Array1, Array2, Axis, Zip}; +use ndarray::{Array1, Array2, Axis}; use std::fs::File; use std::io::Write; use std::sync::Arc; -use crate::schemas::{transcript_metadata_schema, OutputFormat}; use super::sampler::transcripts::Transcript; use super::sampler::transcripts::BACKGROUND_CELL; use super::sampler::voxelsampler::VoxelSampler; use super::sampler::{ModelParams, TranscriptState}; +use crate::schemas::{transcript_metadata_schema, OutputFormat}; -pub fn write_table( - filename: &str, - fmt: OutputFormat, - batch: &RecordBatch, -) { +pub fn write_table(filename: &str, fmt: OutputFormat, batch: &RecordBatch) { let fmt = match fmt { OutputFormat::Infer => infer_format_from_filename(filename), _ => fmt, @@ -55,23 +51,15 @@ pub fn write_table( } } -fn write_table_csv( - output: &mut W, - batch: &RecordBatch, -) -> Result<(), ArrowError> +fn write_table_csv(output: &mut W, batch: &RecordBatch) -> Result<(), ArrowError> where W: std::io::Write, { - let mut writer = csv::WriterBuilder::new() - .with_header(true) - .build(output); + let mut writer = csv::WriterBuilder::new().with_header(true).build(output); writer.write(batch) } -fn write_table_parquet( - output: &mut W, - batch: &RecordBatch, -) -> Result<(), ParquetError> +fn write_table_parquet(output: &mut W, batch: &RecordBatch) -> Result<(), ParquetError> where W: std::io::Write + Send, { @@ -109,9 +97,8 @@ pub fn write_counts( let schema = Schema::new( transcript_names .iter() - .map(|name| { - Field::new(name, DataType::UInt32, false) - }).collect::>() + .map(|name| Field::new(name, DataType::UInt32, false)) + .collect::>(), ); let mut columns: Vec> = Vec::new(); @@ -121,10 +108,7 @@ pub fn write_counts( )); } - let batch = RecordBatch::try_new( - Arc::new(schema), - columns - ).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); write_table(output_counts, output_counts_fmt, &batch); } @@ -140,9 +124,8 @@ pub fn write_expected_counts( let schema = Schema::new( transcript_names .iter() - .map(|name| { - Field::new(name, DataType::Float32, false) - }).collect::>() + .map(|name| Field::new(name, DataType::Float32, false)) + .collect::>(), ); let mut columns: Vec> = Vec::new(); @@ -152,93 +135,148 @@ pub fn write_expected_counts( )); } - let batch = RecordBatch::try_new( - Arc::new(schema), - columns - ).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); - write_table( - output_expected_counts, - output_expected_counts_fmt, - &batch, - ); + write_table(output_expected_counts, output_expected_counts_fmt, &batch); } } -pub fn write_rates( - output_rates: &Option, - output_rates_fmt: OutputFormat, - params: &ModelParams, - transcript_names: &[String], + +pub fn write_metagene_rates( + output_metagene_rates: &Option, + output_metagene_rates_fmt: OutputFormat, + φ: &Array2 ) { - if let Some(output_rates) = output_rates { + if let Some(output_metagene_rates) = output_metagene_rates { + let k = φ.shape()[1]; let schema = Schema::new( - transcript_names - .iter() - .map(|name| { - Field::new(name, DataType::Float32, false) - }).collect::>() + (0..k) + .map(|name| Field::new(format!("phi{}", name), DataType::Float32, false)) + .collect::>(), ); let mut columns: Vec> = Vec::new(); - for row in params.λ.rows() { + for row in φ.columns() { columns.push(Arc::new( row.iter().cloned().collect::(), )); } - let batch = RecordBatch::try_new( - Arc::new(schema), - columns - ).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); - write_table(output_rates, output_rates_fmt, &batch); + write_table(output_metagene_rates, output_metagene_rates_fmt, &batch); } } -pub fn write_component_params( - output_component_params: &Option, - output_component_params_fmt: OutputFormat, - params: &ModelParams, + +pub fn write_metagene_loadings( + output_metagene_rates: &Option, + output_metagene_rates_fmt: OutputFormat, transcript_names: &[String], + θ: &Array2 ) { - if let Some(output_component_params) = output_component_params { - // What does this look like: rows for each gene, columns for α1, β1, α2, β2, etc. - let α = ¶ms.r; - let φ = ¶ms.φ; - let β = φ.map(|φ| (-φ).exp()); - - let ncomponents = params.ncomponents(); - - let mut fields = Vec::new(); - fields.push(Field::new("gene", DataType::Utf8, false)); - for i in 0..ncomponents { - fields.push(Field::new(&format!("α_{}", i), DataType::Float32, false)); - fields.push(Field::new(&format!("β_{}", i), DataType::Float32, false)); + if let Some(output_metagene_rates) = output_metagene_rates { + let k = θ.shape()[1]; + + let mut schema = vec![ + Field::new("gene", DataType::Utf8, false), + ]; + + for i in 0..k { + schema.push(Field::new(format!("theta{}", i), DataType::Float32, false)); } - let schema = Schema::new(fields); + + let schema = Schema::new(schema); let mut columns: Vec> = Vec::new(); - columns.push(Arc::new(arrow::array::StringArray::from(transcript_names.iter().cloned().collect::>()))); - Zip::from(α.rows()).and(β.rows()).for_each(|α, β| { - columns.push(Arc::new(α.iter().cloned().collect::())); - columns.push(Arc::new(β.iter().cloned().collect::())); - }); + columns.push( + Arc::new( + transcript_names.iter().map(|gene| Some(gene)).collect::() + ) + ); + + for row in θ.columns() { + columns.push(Arc::new( + row.iter().cloned().collect::(), + )); + } + + let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); + + write_table(output_metagene_rates, output_metagene_rates_fmt, &batch); + } +} - let batch = RecordBatch::try_new( - Arc::new(schema), - columns - ).unwrap(); - write_table( - output_component_params, - output_component_params_fmt, - &batch, +pub fn write_rates( + output_rates: &Option, + output_rates_fmt: OutputFormat, + params: &ModelParams, + transcript_names: &[String], +) { + if let Some(output_rates) = output_rates { + let schema = Schema::new( + transcript_names + .iter() + .map(|name| Field::new(name, DataType::Float32, false)) + .collect::>(), ); + + let mut columns: Vec> = Vec::new(); + for row in params.λ.columns() { + columns.push(Arc::new( + row.iter().cloned().collect::(), + )); + } + + let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); + + write_table(output_rates, output_rates_fmt, &batch); } } +// pub fn write_component_params( +// output_component_params: &Option, +// output_component_params_fmt: OutputFormat, +// params: &ModelParams, +// transcript_names: &[String], +// ) { +// if let Some(output_component_params) = output_component_params { +// // What does this look like: rows for each gene, columns for α1, β1, α2, β2, etc. +// let α = ¶ms.r; +// let φ = ¶ms.φ; +// let β = φ.map(|φ| (-φ).exp()); +// let ncomponents = params.ncomponents(); +// +// let mut fields = Vec::new(); +// fields.push(Field::new("gene", DataType::Utf8, false)); +// for i in 0..ncomponents { +// fields.push(Field::new(&format!("α_{}", i), DataType::Float32, false)); +// fields.push(Field::new(&format!("β_{}", i), DataType::Float32, false)); +// } +// let schema = Schema::new(fields); +// +// let mut columns: Vec> = Vec::new(); +// columns.push(Arc::new(arrow::array::StringArray::from( +// transcript_names.iter().cloned().collect::>(), +// ))); +// +// Zip::from(α.rows()).and(β.rows()).for_each(|α, β| { +// columns.push(Arc::new( +// α.iter().cloned().collect::(), +// )); +// columns.push(Arc::new( +// β.iter().cloned().collect::(), +// )); +// }); +// +// let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); +// +// write_table(output_component_params, output_component_params_fmt, &batch); +// } +// } + // Assign cells to fovs by finding the most common transcript fov of the // assigned transcripts. fn cell_fov_vote( @@ -279,6 +317,7 @@ pub fn write_cell_metadata( fovs: &[u32], fov_names: &[String], ) { + // TODO: write factorization let ncells = cell_centroids.len(); let nfovs = fov_names.len(); let cell_fovs = cell_fov_vote(ncells, nfovs, cell_assignments, fovs); @@ -292,44 +331,78 @@ pub fn write_cell_metadata( Field::new("fov", DataType::Utf8, true), Field::new("cluster", DataType::UInt16, false), Field::new("volume", DataType::Float32, false), + Field::new("scale", DataType::Float32, false), Field::new("population", DataType::UInt64, false), ]); let columns: Vec> = vec![ - Arc::new((0..params.ncells() as u32).collect::()), - Arc::new(cell_centroids.iter().map(|(x, _, _)| *x).collect::()), - Arc::new(cell_centroids.iter().map(|(_, y, _)| *y).collect::()), - Arc::new(cell_centroids.iter().map(|(_, _, z)| *z).collect::()), Arc::new( - cell_fovs.iter().map( - |fov| { + cell_centroids + .iter() + .map(|(x, _, _)| *x) + .collect::(), + ), + Arc::new( + cell_centroids + .iter() + .map(|(_, y, _)| *y) + .collect::(), + ), + Arc::new( + cell_centroids + .iter() + .map(|(_, _, z)| *z) + .collect::(), + ), + Arc::new( + cell_fovs + .iter() + .map(|fov| { if *fov == u32::MAX { None } else { Some(fov_names[*fov as usize].clone()) } - }, - ).collect::()), - Arc::new(params.z.iter().map(|&z| z as u16).collect::()), - Arc::new(params.cell_volume.iter().cloned().collect::()), - Arc::new(params.cell_population.iter().map(|&p| p as u64).collect::()) + }) + .collect::(), + ), + Arc::new( + params + .z + .iter() + .map(|&z| z as u16) + .collect::(), + ), + Arc::new( + params + .cell_volume + .iter() + .cloned() + .collect::(), + ), + Arc::new( + params + .cell_scale + .iter() + .cloned() + .collect::(), + ), + Arc::new( + params + .cell_population + .iter() + .map(|&p| p as u64) + .collect::(), + ), ]; - let batch = RecordBatch::try_new( - Arc::new(schema), - columns - ).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); - write_table( - output_cell_metadata, - output_cell_metadata_fmt, - &batch, - ); + write_table(output_cell_metadata, output_cell_metadata_fmt, &batch); } } - #[allow(clippy::too_many_arguments)] pub fn write_transcript_metadata( output_transcript_metadata: &Option, @@ -351,64 +424,86 @@ pub fn write_transcript_metadata( let columns: Vec> = vec![ Arc::new( - transcripts.iter().map(|t| t.transcript_id).collect::() + transcripts + .iter() + .map(|t| t.transcript_id) + .collect::(), ), Arc::new( - transcript_positions.iter().map(|(x, _, _)| *x).collect::() + transcript_positions + .iter() + .map(|(x, _, _)| *x) + .collect::(), ), Arc::new( - transcript_positions.iter().map(|(_, y, _)| *y).collect::() + transcript_positions + .iter() + .map(|(_, y, _)| *y) + .collect::(), ), Arc::new( - transcript_positions.iter().map(|(_, _, z)| *z).collect::() + transcript_positions + .iter() + .map(|(_, _, z)| *z) + .collect::(), ), Arc::new( - transcripts.iter().map(|t| t.x).collect::() + transcripts + .iter() + .map(|t| t.x) + .collect::(), ), Arc::new( - transcripts.iter().map(|t| t.y).collect::() + transcripts + .iter() + .map(|t| t.y) + .collect::(), ), Arc::new( - transcripts.iter().map(|t| t.z).collect::() + transcripts + .iter() + .map(|t| t.z) + .collect::(), ), Arc::new( transcripts .iter() .map(|t| Some(transcript_names[t.gene as usize].clone())) - .collect::() - ), - Arc::new( - qvs.iter().cloned().collect::() + .collect::(), ), + Arc::new(qvs.iter().cloned().collect::()), Arc::new( fovs.iter() .map(|fov| Some(fov_names[*fov as usize].clone())) - .collect::() + .collect::(), ), Arc::new( - cell_assignments.iter().map(|(cell, _)| *cell).collect::() + cell_assignments + .iter() + .map(|(cell, _)| *cell) + .collect::(), ), Arc::new( - cell_assignments.iter().map(|(_, pr)| *pr).collect::() + cell_assignments + .iter() + .map(|(_, pr)| *pr) + .collect::(), ), Arc::new( transcript_state .iter() .map(|&s| (s == TranscriptState::Background) as u8) - .collect::() + .collect::(), ), Arc::new( transcript_state .iter() .map(|&s| (s == TranscriptState::Confusion) as u8) - .collect::() + .collect::(), ), ]; - let batch = RecordBatch::try_new( - Arc::new(schema), - columns - ).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); write_table( output_transcript_metadata, @@ -425,6 +520,7 @@ pub fn write_gene_metadata( transcript_names: &[String], expected_counts: &Array2, ) { + // TODO: write factorization if let Some(output_gene_metadata) = output_gene_metadata { let mut schema_fields = vec![ Field::new("gene", DataType::Utf8, false), @@ -435,7 +531,10 @@ pub fn write_gene_metadata( let mut columns: Vec> = vec![ Arc::new( - transcript_names.iter().map(|s| Some(s.clone())).collect::() + transcript_names + .iter() + .map(|s| Some(s.clone())) + .collect::(), ), Arc::new( params @@ -443,71 +542,73 @@ pub fn write_gene_metadata( .sum_axis(Axis(1)) .iter() .map(|x| *x as u64) - .collect::() + .collect::(), ), Arc::new( expected_counts .sum_axis(Axis(1)) - .iter().cloned() - .collect::() + .iter() + .cloned() + .collect::(), ), // Arc::new(array::Float32Array::from_values( // params.r.iter().cloned(), // )) ]; - // cell type dispersions - for i in 0..params.ncomponents() { - schema_fields.push(Field::new( - &format!("dispersion_{}", i), - DataType::Float32, - false, - )); - columns.push(Arc::new( - params.r.row(i).iter().cloned().collect::() - )); - } - - // cell type rates - for i in 0..params.ncomponents() { - schema_fields.push(Field::new(&format!("λ_{}", i), DataType::Float32, false)); - - let mut λ_component = Array1::::from_elem(params.ngenes(), 0_f32); - let mut count = 0; - Zip::from(¶ms.z) - .and(params.λ.columns()) - .for_each(|&z, λ| { - if i == z as usize { - Zip::from(&mut λ_component).and(λ).for_each(|a, b| *a += b); - count += 1; - } - }); - λ_component /= count as f32; - - columns.push(Arc::new( - λ_component.iter().cloned().collect::() - )); - } + // // cell type dispersions + // schema_fields.push(Field::new("dispersion", DataType::Float32, false)); + // columns.push(Arc::new( + // params + // .r + // .iter() + // .cloned() + // .collect::(), + // )); + + // // cell type rates + // for i in 0..params.ncomponents() { + // schema_fields.push(Field::new(&format!("λ_{}", i), DataType::Float32, false)); + // + + // let mut λ_component = Array1::::from_elem(params.ngenes(), 0_f32); + // let mut count = 0; + // Zip::from(¶ms.z) + // .and(params.λ.columns()) + // .for_each(|&z, λ| { + // if i == z as usize { + // Zip::from(&mut λ_component).and(λ).for_each(|a, b| *a += b); + // count += 1; + // } + // }); + // λ_component /= count as f32; + // + + // columns.push(Arc::new( + // λ_component + // .iter() + // .cloned() + // .collect::(), + // )); + // } // background rates for i in 0..params.nlayers() { schema_fields.push(Field::new(format!("λ_bg_{}", i), DataType::Float32, false)); columns.push(Arc::new( - params.λ_bg.column(i).iter().cloned().collect::() + params + .λ_bg + .column(i) + .iter() + .cloned() + .collect::(), )); } let schema = Schema::new(schema_fields); - let batch = RecordBatch::try_new( - Arc::new(schema), - columns - ).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); - write_table( - output_gene_metadata, - output_gene_metadata_fmt, - &batch, - ); + write_table(output_gene_metadata, output_gene_metadata_fmt, &batch); } } @@ -557,10 +658,7 @@ pub fn write_voxels( Arc::new(z1s.iter().cloned().collect::()), ]; - let batch = RecordBatch::try_new( - Arc::new(schema), - columns - ).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema), columns).unwrap(); write_table(output_voxels, output_voxels_fmt, &batch); } diff --git a/src/sampler.rs b/src/sampler.rs index e444fc1..616feba 100644 --- a/src/sampler.rs +++ b/src/sampler.rs @@ -1,28 +1,29 @@ mod connectivity; -pub mod voxelsampler; mod math; -pub mod polyagamma; +mod polyagamma; mod polygons; mod sampleset; pub mod transcripts; +pub mod voxelsampler; +// use super::output::{write_expected_counts, OutputFormat}; +use crate::hull::convex_hull_area; +use clustering::kmeans; use core::fmt::Debug; +use faer_ext::IntoFaer; use flate2::write::GzEncoder; use flate2::Compression; -use crate::hull::convex_hull_area; use itertools::{izip, Itertools}; -use libm::{lgammaf, log1pf}; -use linfa::traits::{Fit, Predict}; -use linfa::DatasetBase; -use linfa_clustering::KMeans; +use libm::lgammaf; use math::{ - logistic, lognormal_logpdf, negbin_logpmf_fast, normal_pdf, normal_x2_logpdf, normal_x2_pdf, - rand_crt, LogFactorial, LogGammaPlus, + lognormal_logpdf, negbin_logpmf, normal_logpdf, normal_x2_logpdf, normal_x2_pdf, odds_to_prob, + rand_crt, randn, }; -use ndarray::{Array1, Array2, Array3, Axis, Zip}; +use ndarray::linalg::general_mat_mul; +use ndarray::{s, Array1, Array2, Axis, Zip}; use polyagamma::PolyaGamma; use rand::{thread_rng, Rng}; -use rand_distr::{Dirichlet, Distribution, Gamma, Normal, StandardNormal}; +use rand_distr::{Binomial, Distribution, Gamma, Normal, StandardNormal}; use rayon::prelude::*; use std::cell::RefCell; use std::collections::HashMap; @@ -35,6 +36,11 @@ use transcripts::{CellIndex, Transcript, BACKGROUND_CELL}; // use std::time::Instant; +// use std::any::type_name; +// fn print_type_of(_: &T) { +// println!("{}", std::any::type_name::()) +// } + // Bounding perimeter as some multiple of the perimiter of a sphere with the // same volume. This of course is all on a lattice, so it's approximate. // `eta` is the scaling factor between number of mismatching neighbors on the @@ -72,6 +78,8 @@ pub struct ModelPriors { pub dispersion: Option, pub burnin_dispersion: Option, + pub use_cell_scales: bool, + pub min_cell_volume: f32, // params for normal prior @@ -82,14 +90,18 @@ pub struct ModelPriors { pub α_σ_volume: f32, pub β_σ_volume: f32, - // gamma rate prior - pub e_r: f32, + pub use_factorization: bool, - pub e_h: f32, - pub f_h: f32, + // dirichlet prior on θ + pub αθ: f32, - // normal precision parameter for β - pub γ: f32, + // gamma prior on rφ + pub eφ: f32, + pub fφ: f32, + + // log-normal prior on sφ + pub μφ: f32, + pub τφ: f32, // gamma prior for background rates pub α_bg: f32, @@ -119,6 +131,9 @@ pub struct ModelPriors { pub σ_z_diffusion_proposal: f32, pub σ_z_diffusion: f32, + // prior precision on effective log cell volume + pub τv: f32, + // bounds on z coordinate pub zmin: f32, pub zmax: f32, @@ -128,6 +143,7 @@ pub struct ModelPriors { } // Model global parameters. +#[allow(non_snake_case)] pub struct ModelParams { pub transcript_positions: Vec<(f32, f32, f32)>, proposed_transcript_positions: Vec<(f32, f32, f32)>, @@ -142,12 +158,15 @@ pub struct ModelParams { pub cell_population: Vec, - // per-cell volumes + // [ncells] per-cell volumes pub cell_volume: Array1, - pub cell_log_volume: Array1, + pub log_cell_volume: Array1, - // per-component volumes - pub component_volume: Array1, + // [ncells] cell_volume * cell_scale + pub effective_cell_volume: Array1, + + // [ncells] per-cell "effective" volume scaling factor + pub cell_scale: Array1, // area of the convex hull containing all transcripts full_layer_volume: f32, @@ -159,11 +178,8 @@ pub struct ModelParams { pub transcript_state: Array1, pub prev_transcript_state: Array1, - // [ngenes, ncells, nlayers] transcripts counts - pub counts: Array3, - - // [ncells, ngenes, nlayers] foreground transcripts counts - foreground_counts: Array3, + // [ncells, ngenes] foreground transcripts counts + foreground_counts: Array2, // [ngenes] background transcripts counts confusion_counts: Array1, @@ -174,45 +190,76 @@ pub struct ModelParams { // [ngenes, nlayers] total gene occourance counts pub total_gene_counts: Array2, - // Not parameters, but needed for sampling global params - logfactorial: LogFactorial, + // [ncells, nhidden] + pub cell_latent_counts: Array2, + + // [ngenes, nhidden] + pub gene_latent_counts: Array2, + + // Thread local [ngenes, nhidden] matrices for accumulation + pub gene_latent_counts_tl: ThreadLocal>>, - // TODO: This needs to be an matrix I guess! - loggammaplus: Array2, + // [nhidden] + pub latent_counts: Array1, - pub z: Array1, // assignment of cells to components + // [nhidden] thread local storage for sampling latent counts + pub multinomial_rates: ThreadLocal>>, + pub multinomial_sample: ThreadLocal>>, - component_population: Array1, // number of cells assigned to each component + // [ncells, ncomponents] space for sampling component assignments + pub z_probs: ThreadLocal>>, - // thread-local space used for sampling z - z_probs: ThreadLocal>>, + // [ncells] assignment of cells to components + pub z: Array1, - π: Vec, // mixing proportions over components + // [ncomponents] component probabilities + pub π: Array1, + + // [ncomponents] number of cells assigned to each component + component_population: Array1, + + // [ncomponents] total volume of each component + component_volume: Array1, + + // [ncomponents, nhidden] + component_latent_counts: Array2, μ_volume: Array1, // volume dist mean param by component σ_volume: Array1, // volume dist std param by component - // Prior on NB dispersion parameters - h: f32, + // [ncells, nhidden]: cell ψ parameter in the latent space + pub φ: Array2, + + // [nhidden]: precompute φ_k.dot(cell_volume) + φ_v_dot: Array1, - // [ncells, ngenes] Polya-gamma samples, used for sampling NB rates - ω: Array2, + // For components as well??? + // [ncells, nhidden] aux CRT variables for sampling rφ + // pub lφ: Array2, - // [ncomponents, ngenes] NB logit(p) parameters - pub φ: Array2, + // [ncells, nhidden] aux CRT variables for sampling rφ + pub lφ: Array2, + + // [ncells, nhidden] aux PolyaGamma variables for sampling sφ + pub ωφ: Array2, + + // [ncomponents, nhidden] φ gamma shape parameters + pub rφ: Array2, - // [ncomponents, ngenes] Parameters for sampling φ - μ_φ: Array2, - σ_φ: Array2, + // for precomputing lgamma(rφ) + lgamma_rφ: Array2, - // [ncomponents, ngenes] NB r parameters. - pub r: Array2, + // [ncomponents, nhidden] φ gamma scale parameters + pub sφ: Array2, - // [ncomponents, ngenes] gamama parameters for sampling r - uv: Array2<(u32, f32)>, + // Size of the upper block of θ that is the identity matrix + nunfactored: usize, - // Precomputing lgamma(r) - lgamma_r: Array2, + // [ngenes, nhidden]: gene loadings in the latent space + pub θ: Array2, + + // [nhidden]: Sums across the first axis of θ + pub θksum: Array1, // // [ncomponents, ngenes] NB p parameters. // θ: Array2, @@ -223,7 +270,7 @@ pub struct ModelParams { // // log(1 - ods_to_prob(θ)) // log1mp: Array2, - // [ngenes, ncells] Poisson rates + // [ncells, ngenes] Poisson rates pub λ: Array2, // [ngenes, nlayers] background rate: rate at which halucinate transcripts @@ -241,6 +288,7 @@ impl ModelParams { // initialize model parameters, with random cell assignments // and other parameterz unninitialized. #[allow(clippy::too_many_arguments)] + #[allow(non_snake_case)] pub fn new( priors: &ModelPriors, full_layer_volume: f32, @@ -252,6 +300,8 @@ impl ModelParams { init_cell_population: &[usize], prior_seg_cell_assignment: &[u32], ncomponents: usize, + nhidden: usize, + nunfactored: usize, nlayers: usize, ncells: usize, ngenes: usize, @@ -264,9 +314,12 @@ impl ModelParams { if initial_perturbation_sd > 0.0 { let mut rng = rand::thread_rng(); for pos in &mut transcript_positions { - pos.0 += rng.sample::(StandardNormal) * initial_perturbation_sd; - pos.1 += rng.sample::(StandardNormal) * initial_perturbation_sd; - pos.2 += rng.sample::(StandardNormal) * initial_perturbation_sd; + pos.0 += + rng.sample::(StandardNormal) * initial_perturbation_sd; + pos.1 += + rng.sample::(StandardNormal) * initial_perturbation_sd; + pos.2 += + rng.sample::(StandardNormal) * initial_perturbation_sd; } } @@ -274,77 +327,103 @@ impl ModelParams { let accept_proposed_transcript_positions = vec![false; transcripts.len()]; let transcript_position_updates = vec![(0, 0, 0, 0); transcripts.len()]; - let r = Array2::::from_elem((ncomponents, ngenes), 100.0_f32); + let nhidden = if priors.use_factorization { + nhidden + nunfactored + } else { + ngenes + }; + + // TODO: save some space when not using factorization + let lφ = Array2::::zeros((ncells, nhidden)); + let ωφ = Array2::::zeros((ncells, nhidden)); + let rφ = Array2::::from_elem((ncomponents, nhidden), 1.0_f32); + let lgamma_rφ = Array2::::from_elem((ncomponents, nhidden), 1.0_f32); + let sφ = Array2::::from_elem((ncomponents, nhidden), 1.0_f32); let cell_volume = Array1::::zeros(ncells); - let cell_log_volume = Array1::::zeros(ncells); - let h = 10.0; + let log_cell_volume = Array1::::zeros(ncells); + let cell_scale = Array1::::from_elem(ncells, 1.0_f32); + let effective_cell_volume = cell_volume.clone(); // compute initial counts - let mut counts = Array3::::from_elem((ngenes, ncells, nlayers), 0); + let mut counts = Array2::::zeros((ngenes, ncells)); let mut total_gene_counts = Array2::::from_elem((ngenes, nlayers), 0); for (i, &j) in init_cell_assignments.iter().enumerate() { let gene = transcripts[i].gene as usize; let layer = ((transcripts[i].z - z0) / layer_depth) as usize; if j != BACKGROUND_CELL { - counts[[gene, j as usize, layer]] += 1; + counts[[gene, j as usize]] += 1.0; } total_gene_counts[[gene, layer]] += 1; } // initial component assignments - let norm_constant = 1e4; - let mut init_samples = counts - .sum_axis(Axis(2)) - .map(|&x| (x as f32)) - .reversed_axes(); - init_samples.rows_mut().into_iter().for_each(|mut row| { - let rowsum = row.sum(); - row.mapv_inplace(|x| (norm_constant * (x / rowsum)).ln_1p()); + let norm_constant = 1e2; + counts.columns_mut().into_iter().for_each(|mut col| { + let colsum = col.sum(); + col.mapv_inplace(|x| (norm_constant * (x / colsum)).ln_1p()); + // col.mapv_inplace(|x| x.ln_1p()); }); - let init_samples = DatasetBase::from(init_samples); - - // log1p transformed counts - // let init_samples = - // DatasetBase::from(counts.sum_axis(Axis(2)).map(|&x| (x as f32).ln_1p()).reversed_axes()); - - let rng = rand::thread_rng(); - let model = KMeans::params_with_rng(ncomponents, rng) - .tolerance(1e-1) - .fit(&init_samples) - .expect("kmeans failed to converge"); - - let z = model.predict(&init_samples).map(|&x| x as u32); - - // let rng = rand::thread_rng(); - // let model = GaussianMixtureModel::params_with_rng(ncomponents, rng) - // .tolerance(1e-1) - // .fit(&init_samples) - // .expect("gmm failed to converge"); - - // let z = model.predict(&init_samples).map(|&x| x as u32); - - // // initial component assignments - // let mut rng = rand::thread_rng(); - // let z = (0..ncells) - // .map(|_| rng.gen_range(0..ncomponents) as u32) - // .collect::>() - // .into(); - - let uv = Array2::<(u32, f32)>::from_elem((ncomponents, ngenes), (0_u32, 0_f32)); - let lgamma_r = Array2::::from_elem((ncomponents, ngenes), 0.0); - let loggammaplus = - Array2::::from_elem((ncomponents, ngenes), LogGammaPlus::default()); - let ω = Array2::::from_elem((ncells, ngenes), 0.0); - let φ = Array2::::from_elem((ncomponents, ngenes), 0.0); - let μ_φ = Array2::::from_elem((ncomponents, ngenes), 0.0); - let σ_φ = Array2::::from_elem((ncomponents, ngenes), 0.0); - - let component_volume = Array1::::from_elem(ncomponents, 0.0); + + // subtract mean + for mut counts_i in counts.rows_mut() { + let mean = counts_i.mean().unwrap(); + counts_i -= mean; + } + + let counts = counts.view().into_faer(); + let counts_svd = counts.thin_svd(); + + // truncate the svd for kmeans (I hope it's sorted...) + let svd_dim = counts_svd.v().shape().1.min(100); + let embedding: Vec> = counts_svd + .v() + .row_iter() + .map(|row| row.get(0..svd_dim).iter().cloned().collect()) + .collect(); + + let kmeans_results = kmeans(ncomponents, &embedding, 100); + + let z: Array1 = kmeans_results + .membership + .iter() + .map(|z_c| *z_c as u32) + .collect(); + + let z_probs = ThreadLocal::new(); + + let π = Array1::::from_elem(ncomponents, 1.0 / ncomponents as f32); + + let mut rng = rand::thread_rng(); + let φ = Array2::::from_shape_simple_fn((ncells, nhidden), || randn(&mut rng).exp()); + let φ_v_dot = Array1::::zeros(nhidden); + let mut θ = + Array2::::from_shape_simple_fn((ngenes, nhidden), || randn(&mut rng).exp()); + let θksum = Array1::::zeros(nhidden); + + θ.fill(0.0); + if !priors.use_factorization { + θ.diag_mut().fill(1.0); + } + + // fix the upper block of θ to the identity matrix + if nunfactored > 0 { + let mut θunfac = θ.slice_mut(s![0..nunfactored, 0..nunfactored]); + θunfac.diag_mut().fill(1.0); + } + let transcript_state = Array1::::from_elem(transcripts.len(), TranscriptState::Foreground); let prev_transcript_state = Array1::::from_elem(transcripts.len(), TranscriptState::Foreground); + let foreground_counts = Array2::::from_elem((ncells, ngenes), 0); + let confusion_counts = Array1::::from_elem(ngenes, 0); + let background_counts = Array2::::from_elem((ngenes, nlayers), 0); + let cell_latent_counts = Array2::::from_elem((ncells, nhidden), 0); + let gene_latent_counts = Array2::::from_elem((ngenes, nhidden), 0); + let gene_latent_counts_tl = ThreadLocal::new(); + let latent_counts = Array1::::from_elem(nhidden, 0); + ModelParams { transcript_positions, proposed_transcript_positions, @@ -356,74 +435,54 @@ impl ModelParams { cell_assignment_time: vec![0; init_cell_assignments.len()], cell_population: init_cell_population.to_vec(), cell_volume, - cell_log_volume, - component_volume, + log_cell_volume, + cell_scale, + effective_cell_volume, full_layer_volume, z0, layer_depth, transcript_state, prev_transcript_state, - counts, - foreground_counts: Array3::::from_elem((ncells, ngenes, nlayers), 0), - confusion_counts: Array1::::from_elem(ngenes, 0), - background_counts: Array2::::from_elem((ngenes, nlayers), 0), + foreground_counts, + confusion_counts, + background_counts, total_gene_counts, - logfactorial: LogFactorial::new(), - loggammaplus, + cell_latent_counts, + gene_latent_counts, + gene_latent_counts_tl, + latent_counts, + multinomial_rates: ThreadLocal::new(), + multinomial_sample: ThreadLocal::new(), + z_probs, z, + π, component_population: Array1::::from_elem(ncomponents, 0), - z_probs: ThreadLocal::new(), - π: vec![1_f32 / (ncomponents as f32); ncomponents], + component_volume: Array1::::from_elem(ncomponents, 0.0), + component_latent_counts: Array2::::from_elem((ncomponents, nhidden), 0), μ_volume: Array1::::from_elem(ncomponents, priors.μ_μ_volume), σ_volume: Array1::::from_elem(ncomponents, priors.σ_μ_volume), - h, - ω, φ, - μ_φ, - σ_φ, - r, - uv, - lgamma_r, + φ_v_dot, + nunfactored, + θ, + θksum, + lφ, + ωφ, + sφ, + rφ, + lgamma_rφ, // θ: Array2::::from_elem((ncomponents, ngenes), 0.1), - λ: Array2::::from_elem((ngenes, ncells), 0.1), + λ: Array2::::from_elem((ncells, ngenes), 0.1), λ_bg: Array2::::from_elem((ngenes, nlayers), 0.0), λ_c: Array1::::from_elem(ngenes, 1e-4), t: 0, } } - pub fn ncomponents(&self) -> usize { - self.π.len() - } - - fn zlayer(&self, z: f32) -> usize { - let layer = ((z - self.z0) / self.layer_depth).max(0.0) as usize; - layer.min(self.nlayers() - 1) - } - - fn recompute_counts(&mut self, transcripts: &[Transcript]) { - self.counts.fill(0); - for (i, &j) in self.cell_assignments.iter().enumerate() { - let gene = transcripts[i].gene as usize; - if j != BACKGROUND_CELL { - let layer = self.zlayer(self.transcript_positions[i].2); - self.counts[[gene, j as usize, layer]] += 1; - } - } - - self.check_counts(transcripts); - } - - fn check_counts(&self, transcripts: &[Transcript]) { - for (i, (transcript, &assignment)) in - transcripts.iter().zip(&self.cell_assignments).enumerate() - { - let layer = self.zlayer(self.transcript_positions[i].2); - if assignment != BACKGROUND_CELL { - assert!(self.counts[[transcript.gene as usize, assignment as usize, layer]] > 0); - } - } - } + // fn zlayer(&self, z: f32) -> usize { + // let layer = ((z - self.z0) / self.layer_depth).max(0.0) as usize; + // layer.min(self.nlayers() - 1) + // } pub fn nforeground(&self) -> usize { self.foreground_counts.iter().map(|x| *x as usize).sum() @@ -450,7 +509,7 @@ impl ModelParams { pub fn log_likelihood(&self, priors: &ModelPriors) -> f32 { // iterate over cells - let mut ll = Zip::from(self.λ.columns()) + let mut ll = Zip::from(self.λ.rows()) .and(&self.cell_volume) // .and(self.counts.axis_iter(Axis(1))) .and(self.foreground_counts.axis_iter(Axis(0))) @@ -475,11 +534,6 @@ impl ModelParams { }) - λ * cell_volume }); - // if part < -57983890000.0 { - // dbg!(part, cell_volume, cs.sum()); - // dbg!(λ); - // panic!(); - // } accum + part }); @@ -895,18 +949,19 @@ pub trait Proposal { }); } else { let volume_diff = self.old_cell_volume_delta(); + let a_c = params.cell_scale[old_cell as usize]; let prev_volume = params.cell_volume[old_cell as usize]; let new_volume = prev_volume + volume_diff; // normalization term difference - δ += Zip::from(params.λ.column(old_cell as usize)) - .fold(0.0, |acc, &λ| acc - λ * volume_diff); + δ += Zip::from(params.λ.row(old_cell as usize)) + .fold(0.0, |acc, &λ| acc - λ * a_c * volume_diff); Zip::from(self.gene_count().rows()) .and(params.λ_bg.rows()) .and(¶ms.λ_c) - .and(params.λ.column(old_cell as usize)) + .and(params.λ.row(old_cell as usize)) .for_each(|gene_counts, λ_bg, &λ_c, λ| { Zip::from(gene_counts).and(λ_bg).for_each(|&count, &λ_bg| { if count > 0 { @@ -938,19 +993,20 @@ pub trait Proposal { }); } else { let volume_diff = self.new_cell_volume_delta(); + let a_c = params.cell_scale[new_cell as usize]; let prev_volume = params.cell_volume[new_cell as usize]; let new_volume = prev_volume + volume_diff; // normalization term difference - δ += Zip::from(params.λ.column(new_cell as usize)) - .fold(0.0, |acc, &λ| acc - λ * volume_diff); + δ += Zip::from(params.λ.row(new_cell as usize)) + .fold(0.0, |acc, &λ| acc - λ * a_c * volume_diff); // add in new cell likelihood terms Zip::from(self.gene_count().rows()) .and(params.λ_bg.rows()) .and(¶ms.λ_c) - .and(params.λ.column(new_cell as usize)) + .and(params.λ.row(new_cell as usize)) .for_each(|gene_counts, λ_bg, &λ_c, λ| { Zip::from(gene_counts).and(λ_bg).for_each(|&count, &λ_bg| { if count > 0 { @@ -977,14 +1033,6 @@ pub trait Proposal { if (hillclimb && δ > 0.0) || (!hillclimb && logu < δ + self.log_weight()) { self.accept(); - // TODO: debugging - // if from_background && !to_background { - // dbg!( - // self.log_weight(), - // δ, - // self.gene_count().sum(), - // ); - // } } else { self.reject(); } @@ -997,19 +1045,34 @@ where Self: Sync, { // fn generate_proposals<'b, 'c>(&'b mut self, params: &ModelParams) -> &'c mut [P] where 'b: 'c; - fn initialize(&mut self, priors: &ModelPriors, params: &mut ModelParams) { - Zip::from(&mut params.cell_log_volume) - .and(¶ms.cell_volume) - .into_par_iter() - .with_min_len(50) - .for_each(|(log_volume, &volume)| { - *log_volume = volume.ln(); + fn initialize( + &mut self, + priors: &ModelPriors, + params: &mut ModelParams, + transcripts: &Vec, + ) { + Zip::from(&mut params.φ_v_dot) + .and(params.φ.axis_iter(Axis(1))) + .for_each(|φ_v_dot_k, φ_k| { + *φ_v_dot_k = φ_k.dot(¶ms.cell_volume); + }); + + Zip::from(&mut params.θksum) + .and(params.θ.axis_iter(Axis(1))) + .for_each(|θksum, θ_k| { + *θksum = θ_k.sum(); }); // get to a reasonably high probability assignment - for _ in 0..40 { - self.sample_component_nb_params(priors, params, true); + for _ in 0..20 { + self.sample_transcript_state(priors, params, transcripts, &mut Option::None); + self.compute_counts(priors, params, transcripts); + self.sample_factor_model(priors, params, false, true); + self.sample_background_rates(priors, params); + self.sample_confusion_rates(priors, params); } + + // panic!("Finished initializing"); } fn repopulate_proposals(&mut self, priors: &ModelPriors, params: &ModelParams); @@ -1036,7 +1099,6 @@ where priors: &ModelPriors, params: &mut ModelParams, stats: &mut ProposalStats, - transcripts: &[Transcript], hillclimb: bool, uncertainty: &mut Option<&mut UncertaintyTracker>, ) { @@ -1048,13 +1110,12 @@ where self.proposals_mut() .par_iter_mut() .for_each(|p| p.evaluate(priors, params, hillclimb)); - self.apply_accepted_proposals(stats, transcripts, priors, params, uncertainty); + self.apply_accepted_proposals(stats, priors, params, uncertainty); } fn apply_accepted_proposals( &mut self, stats: &mut ProposalStats, - transcripts: &[Transcript], priors: &ModelPriors, params: &mut ModelParams, uncertainty: &mut Option<&mut UncertaintyTracker>, @@ -1117,15 +1178,6 @@ where cell_volume += proposal.old_cell_volume_delta(); cell_volume = cell_volume.max(priors.min_cell_volume); params.cell_volume[old_cell as usize] = cell_volume; - - for &i in proposal.transcripts() { - let gene = transcripts[i].gene; - let layer = ((params.transcript_positions[i].2 - params.z0) - / params.layer_depth) - .max(0.0) as usize; - let layer = layer.min(params.nlayers() - 1); - params.counts[[gene as usize, old_cell as usize, layer]] -= 1; - } } if new_cell != BACKGROUND_CELL { @@ -1135,15 +1187,6 @@ where cell_volume += proposal.new_cell_volume_delta(); cell_volume = cell_volume.max(priors.min_cell_volume); params.cell_volume[new_cell as usize] = cell_volume; - - for &i in proposal.transcripts() { - let gene = transcripts[i].gene; - let layer = ((params.transcript_positions[i].2 - params.z0) - / params.layer_depth) - .max(0.0) as usize; - let layer = layer.min(params.nlayers() - 1); - params.counts[[gene as usize, new_cell as usize, layer]] += 1; - } } } @@ -1158,13 +1201,11 @@ where uncertainty: &mut Option<&mut UncertaintyTracker>, burnin: bool, ) { - let mut rng = thread_rng(); - // let t0 = Instant::now(); self.sample_volume_params(priors, params); // println!(" Sample volume params: {:?}", t0.elapsed()); - // // Sample background/foreground counts + // Sample background/foreground counts // let t0 = Instant::now(); self.sample_transcript_state(priors, params, transcripts, uncertainty); // println!(" Sample transcript states: {:?}", t0.elapsed()); @@ -1174,41 +1215,12 @@ where // println!(" Compute counts: {:?}", t0.elapsed()); // let t0 = Instant::now(); - self.sample_component_nb_params(priors, params, burnin); - // println!(" Sample nb params: {:?}", t0.elapsed()); - - // Sample λ - // let t0 = Instant::now(); - self.sample_rates(priors, params); - // println!(" Sample λ: {:?}", t0.elapsed()); - - // TODO: - // This is the most expensive part. We could sample this less frequently, - // but we should try to optimize as much as possible. - // Ideas: - // - Main bottlneck is computing log(p) and log(1-p). I don't see - // anything obvious to do about that. + self.sample_factor_model(priors, params, true, burnin); + // println!(" Sample factor model: {:?}", t0.elapsed()); - // Sample z // let t0 = Instant::now(); - self.sample_component_assignments(priors, params); - // println!(" Sample z: {:?}", t0.elapsed()); - - // sample π - let mut α = vec![1_f32; params.ncomponents()]; - for z_i in params.z.iter() { - α[*z_i as usize] += 1.0; - } - - if α.len() == 1 { - params.π.clear(); - params.π.push(1.0); - } else { - params.π.clear(); - params - .π - .extend(Dirichlet::new(&α).unwrap().sample(&mut rng).iter()); - } + // self.sample_z(params); + // println!(" sample_z: {:?}", t0.elapsed()); // let t0 = Instant::now(); self.sample_background_rates(priors, params); @@ -1216,7 +1228,6 @@ where // let t0 = Instant::now(); self.sample_confusion_rates(priors, params); - // TODO: disabling confusion to see if it actually does anything // println!(" Sample confusion rates: {:?}", t0.elapsed()); // let t0 = Instant::now(); @@ -1251,7 +1262,7 @@ where let layer = ((position.2 - params.z0) / params.layer_depth).max(0.0) as usize; let layer = layer.min(nlayers - 1); - let λ_cell = params.λ[[gene, cell as usize]]; + let λ_cell = params.λ[[cell as usize, gene]]; let λ_bg = params.λ_bg[[gene, layer]]; let λ_c = params.λ_c[gene]; let λ = λ_cell + λ_bg + λ_c; @@ -1301,6 +1312,15 @@ where params: &mut ModelParams, transcripts: &Vec, ) { + // TODO: Ok, this is the bottleneck now. It can't be trivially parallelized + // because we are accumulating these big matrices. + + // These matrices are also problematic because they use so much memory. + // Ideally we'd find a solution where we can both accumulate in parallel + // and be sparse. + + // I think the thing to do is + let nlayers = params.nlayers(); params.confusion_counts.fill(0_u32); params.background_counts.fill(0_u32); @@ -1323,300 +1343,504 @@ where params.confusion_counts[gene] += 1; } TranscriptState::Foreground => { - params.foreground_counts[[cell as usize, gene, layer]] += 1; + params.foreground_counts[[cell as usize, gene]] += 1; } } }); - - // dbg!(params.background_counts.sum()); - // dbg!(params.confusion_counts.sum()); } - fn sample_component_nb_params( - &mut self, - priors: &ModelPriors, - params: &mut ModelParams, - burnin: bool, - ) { - // total component area - // let mut component_cell_area = vec![0_f32; params.ncomponents()]; - params.component_volume.fill(0.0); - params - .cell_volume - .iter() - .zip(¶ms.z) - .for_each(|(volume, z_i)| { - params.component_volume[*z_i as usize] += *volume; - }); + fn sample_latent_counts(&mut self, params: &mut ModelParams) { + params.cell_latent_counts.fill(0); - // Sample ω + // zero out thread local gene latent counts + for x in params.gene_latent_counts_tl.iter_mut() { + x.borrow_mut().fill(0); + } - // let rmin = params.r.iter().min_by(|a, b| a.partial_cmp(b).unwrap()); - // let rmax = params.r.iter().max_by(|a, b| a.partial_cmp(b).unwrap()); - // dbg!(rmin, rmax); + let ngenes = params.foreground_counts.shape()[1]; + let nhidden = params.cell_latent_counts.shape()[1]; - // let φmin = params.φ.iter().min_by(|a, b| a.partial_cmp(b).unwrap()); - // let φmax = params.φ.iter().max_by(|a, b| a.partial_cmp(b).unwrap()); - // dbg!(φmin, φmax); + Zip::from(params.cell_latent_counts.outer_iter_mut()) // for every cell + .and(params.φ.outer_iter()) + .and(params.foreground_counts.outer_iter()) + .par_for_each(|mut cell_latent_counts_c, φ_c, x_c| { + let mut rng = thread_rng(); + let mut multinomial_rates = params + .multinomial_rates + .get_or(|| RefCell::new(Array1::zeros(nhidden))) + .borrow_mut(); - // let vmin = params.cell_volume.iter().min_by(|a, b| a.partial_cmp(b).unwrap()); - // let vmax = params.cell_volume.iter().max_by(|a, b| a.partial_cmp(b).unwrap()); - // dbg!(vmin, vmax); + let mut multinomial_sample = params + .multinomial_sample + .get_or(|| RefCell::new(Array1::zeros(nhidden))) + .borrow_mut(); - // let t0 = Instant::now(); - Zip::from(params.ω.rows_mut()) // for every cell - .and(params.foreground_counts.axis_iter(Axis(0))) - .and(¶ms.cell_log_volume) - .and(¶ms.z) - .par_for_each(|ωs, cs, &logv, &z| { - let mut rng = thread_rng(); - Zip::from(cs.axis_iter(Axis(0))) // for every gene - .and(ωs) - .and(params.φ.row(z as usize)) - .and(params.r.row(z as usize)) - .for_each(|c, ω, φ, &r| { - *ω = PolyaGamma::new(c.sum() as f32 + r, logv + φ).sample(&mut rng); - }); - }); - // println!(" Sample ω: {:?}", t0.elapsed()); + let mut gene_latent_counts_tl = params + .gene_latent_counts_tl + .get_or(|| RefCell::new(Array2::zeros((ngenes, nhidden)))) + .borrow_mut(); - // Compute parameters to sample φ - // let t0 = Instant::now(); - params.μ_φ.fill(0.0); - params.σ_φ.fill(0.0); - Zip::from(params.ω.rows()) // for every cell - .and(params.foreground_counts.axis_iter(Axis(0))) // for each cell - .and(¶ms.cell_log_volume) - .and(¶ms.z) - .for_each(|ωs, cs, &logv, &z| { - Zip::from(params.μ_φ.row_mut(z as usize)) // for every gene - .and(params.σ_φ.row_mut(z as usize)) - .and(params.r.row(z as usize)) - .and(ωs) - .and(cs.axis_iter(Axis(0))) - .for_each(|μ, σ, &r, &ω, c| { - *σ += ω; - *μ += (c.sum() as f32 - r) / 2.0 - ω * logv; + Zip::indexed(x_c) // for every gene + .and(params.θ.outer_iter()) + .for_each(|g, &x_cg, θ_g| { + if x_cg == 0 { + return; + } + + multinomial_sample.fill(0); + + // rates: normalized element-wise product + multinomial_rates.assign(&φ_c); + *multinomial_rates *= &θ_g; + let rate_norm = multinomial_rates.sum(); + *multinomial_rates /= rate_norm; + + // multinomial sampling + { + let mut ρ = 1.0; + let mut s = x_cg as u32; + for (p, x) in + izip!(multinomial_rates.iter(), multinomial_sample.iter_mut()) + { + if ρ > 0.0 { + *x = Binomial::new(s as u64, ((*p / ρ) as f64).min(1.0)) + .unwrap() + .sample(&mut rng) + as u32; + } + s -= *x; + ρ = ρ - *p; + + if s == 0 { + break; + } + } + } + + // add to cell marginal + cell_latent_counts_c.scaled_add(1, &multinomial_sample); + + // add to gene marginal + let mut gene_latent_counts_g = gene_latent_counts_tl.row_mut(g); + gene_latent_counts_g.scaled_add(1, &multinomial_sample); }); }); - Zip::from(&mut params.σ_φ).for_each(|σ| *σ = (priors.γ + *σ).recip()); + // accumulate from thread local matrices + params.gene_latent_counts.fill(0); + for x in params.gene_latent_counts_tl.iter_mut() { + params.gene_latent_counts.scaled_add(1, &x.borrow()); + } + + // marginal count along the hidden axis + Zip::from(&mut params.latent_counts) + .and(params.gene_latent_counts.columns()) + .for_each(|lc, glc| { + *lc = glc.sum(); + }); - Zip::from(&mut params.μ_φ) - .and(¶ms.σ_φ) - .for_each(|μ, σ| *μ *= σ); - // println!(" Compute φ parameters: {:?}", t0.elapsed()); + let count = params.latent_counts.sum(); + assert!(params.gene_latent_counts.sum() == count); + assert!(params.cell_latent_counts.sum() == count); - // Sample φ - // let t0 = Instant::now(); - Zip::from(&mut params.φ) - .and(¶ms.μ_φ) - .and(¶ms.σ_φ) - .for_each(|φ, &μ, &σ| { - let mut rng = thread_rng(); - *φ = Normal::new(μ, σ.sqrt()).unwrap().sample(&mut rng); + // compute component-wise counts + params.component_population.fill(0); + params.component_volume.fill(0.0); + params.component_latent_counts.fill(0); + Zip::from(¶ms.z) + .and(¶ms.cell_volume) + .and(params.cell_latent_counts.rows()) + .for_each(|z_c, v_c, x_c| { + let z_c = *z_c as usize; + params.component_population[z_c] += 1; + params.component_volume[z_c] += *v_c; + params + .component_latent_counts + .row_mut(z_c) + .scaled_add(1, &x_c); }); - // println!(" Sample φ: {:?}", t0.elapsed()); - // }); - // Zip::from(&mut params.ω) - // .and(¶ms.foreground_counts) + // dbg!(¶ms.component_population); + } - // Sample θ + // This is indended just for debugging + // fn write_cell_latent_counts(&self, params: &ModelParams, filename: &str) { + // let file = File::create(filename).unwrap(); + // let mut encoder = GzEncoder::new(file, Compression::default()); + + // // header + // for i in 0..params.cell_latent_counts.shape()[1] { + // if i != 0 { + // write!(encoder, ",").unwrap(); + // } + // write!(encoder, "metagene{}", i).unwrap(); + // } + // writeln!(encoder).unwrap(); + + // for x_c in params.cell_latent_counts.rows() { + // for (k, x_ck) in x_c.iter().enumerate() { + // if k != 0 { + // write!(encoder, ",").unwrap(); + // } + // write!(encoder, "{}", *x_ck).unwrap(); + // } + // writeln!(encoder).unwrap(); + // } + // } + + fn sample_factor_model( + &mut self, + priors: &ModelPriors, + params: &mut ModelParams, + sample_z: bool, + burnin: bool, + ) { // let t0 = Instant::now(); - // Zip::from(params.θ.columns_mut()) - // .and(params.r.columns()) - // .and(params.component_counts.rows()) - // .par_for_each(|θs, rs, cs| { - // let mut rng = thread_rng(); - // for (θ, &c, &r, &a) in izip!(θs, cs, rs, ¶ms.component_volume) { - // *θ = prob_to_odds( - // Beta::new(priors.α_θ + c as f32, priors.β_θ + a * r) - // .unwrap() - // .sample(&mut rng), - // ); - - // // *θ = θ.max(1e-6).min(1e6); - // } - // }); - // println!(" Sample θ: {:?}", t0.elapsed()); + self.sample_latent_counts(params); + // println!(" sample_latent_counts: {:?}", t0.elapsed()); - // dbg!( - // params.θ.iter().min_by(|a, b| a.partial_cmp(b).unwrap()), - // params.θ.iter().max_by(|a, b| a.partial_cmp(b).unwrap()), - // ); + if sample_z { + // let t0 = Instant::now(); + self.sample_z(params); + // println!(" sample_z: {:?}", t0.elapsed()); + } + self.sample_π(params); - // dbg!(params.h); + if priors.use_factorization { + // let t0 = Instant::now(); + self.sample_θ(priors, params); + // println!(" sample_θ: {:?}", t0.elapsed()); + } - // Sample r // let t0 = Instant::now(); + self.sample_φ(params); + // println!(" sample_φ: {:?}", t0.elapsed()); - // TODO: Need some argument to determine dispersion sampling - // behavior. - - fn set_constant_dispersion(params: &mut ModelParams, dispersion: f32) { - params.r.fill(dispersion); - Zip::from(¶ms.r) - .and(&mut params.lgamma_r) - .and(&mut params.loggammaplus) - .par_for_each(|r, lgamma_r, loggammaplus| { - *lgamma_r = lgammaf(*r); - loggammaplus.reset(*r); - }); - } - + // let t0 = Instant::now(); if let Some(dispersion) = priors.dispersion { - set_constant_dispersion(params, dispersion); + params.rφ.fill(dispersion); } else if burnin && priors.burnin_dispersion.is_some() { let dispersion = priors.burnin_dispersion.unwrap(); - set_constant_dispersion(params, dispersion); + params.rφ.fill(dispersion); } else { - // for each gene - params.uv.fill((0_u32, 0_f32)); - Zip::from(params.r.columns_mut()) - .and(params.lgamma_r.columns_mut()) - .and(params.loggammaplus.columns_mut()) - .and(params.φ.columns()) - .and(params.foreground_counts.axis_iter(Axis(1))) - .and(params.uv.columns_mut()) - .par_for_each(|rs, lgamma_rs, loggammaplus, φs, cs, mut uv| { - let mut rng = thread_rng(); - - // iterate over cells computing u and v - Zip::from(¶ms.z) - .and(cs.axis_iter(Axis(0))) - .and(¶ms.cell_volume) - .for_each(|&z, c, &vol| { - let z = z as usize; - let c = c.sum(); - let r = rs[z]; - let φ = φs[z]; - let ψ = φ + vol.ln(); - let δv = -ψ - log1pf((-ψ).exp()); - - let uv_z = uv[z]; - - if (uv[z].1 + δv).is_infinite() { - dbg!(uv[z], ψ, φ, vol); - } + self.sample_rφ(priors, params); + } + // println!(" sample_rφ: {:?}", t0.elapsed()); - uv[z] = (uv_z.0 + rand_crt(&mut rng, c as u32, r), uv_z.1 + δv); + // let t0 = Instant::now(); + self.sample_ωck(params); + self.sample_sφ(priors, params); - assert!(uv[z].1.is_finite()); - }); + if priors.use_cell_scales { + self.sample_cell_scales(priors, params); + } else { + params.effective_cell_volume.assign(¶ms.cell_volume); + } - // iterate over components sampling r - Zip::from(rs) - .and(lgamma_rs) - .and(loggammaplus) - .and(uv) - .for_each(|r, lgamma_r, loggammaplus, uv| { - let dist = - Gamma::new(priors.e_r + uv.0 as f32, (params.h - uv.1).recip()); - if dist.is_err() { - dbg!(uv.0, uv.1, params.h); - } - let dist = dist.unwrap(); - *r = dist.sample(&mut rng); + // let t0 = Instant::now(); + self.compute_rates(params); + // println!(" compute_rates: {:?}", t0.elapsed()); + } - // if *r < 0.001 { - // dbg!(uv.0, uv.1, params.h, *r); - // } + fn sample_z(&mut self, params: &mut ModelParams) { + let ncomponents = params.π.shape()[0]; - // *r = Gamma::new(priors.e_r + uv.0 as f32, (params.h - uv.1).recip()) - // .unwrap() - // .sample(&mut rng); + // precompute lgamma(rφ) + Zip::from(&mut params.lgamma_rφ) + .and(¶ms.rφ) + .par_for_each(|lgamma_r_tk, r_tk| { + *lgamma_r_tk = lgammaf(*r_tk); + }); - assert!(r.is_finite()); + Zip::from(&mut params.z) // for each cell + .and(params.φ.rows()) + .and(params.cell_latent_counts.rows()) + .and(¶ms.effective_cell_volume) + .and(¶ms.cell_volume) + .par_for_each(|z_c, φ_c, x_c, ev_c, v_c| { + let mut rng = rand::thread_rng(); + let log_v_c = v_c.ln(); + let mut z_probs = params + .z_probs + .get_or(|| RefCell::new(vec![0_f64; ncomponents])) + .borrow_mut(); - // TODO: Without this, things get kind of fucky. - // *r = r.min(200.0).max(2e-4); - *r = r.max(2e-4); - // *r = r.min(50.0); - // *r = r.max(1e-5); + // compute probability of φ_c under every component - *lgamma_r = lgammaf(*r); - loggammaplus.reset(*r); + // for every component + let mut z_probs_sum = 0.0; + for (z_probs_t, π_t, r_t, lgamma_r_t, s_t, μ_vol_c, σ_vol_c) in izip!( + z_probs.iter_mut(), + params.π.iter(), + params.rφ.rows(), + params.lgamma_rφ.rows(), + params.sφ.rows(), + ¶ms.μ_volume, + ¶ms.σ_volume + ) { + *z_probs_t = π_t.ln() as f64; + + // for every hidden dim + Zip::from(r_t) + .and(lgamma_r_t) + .and(s_t) + .and(¶ms.θksum) + .and(x_c) + .for_each(|r_tk, lgamma_r_tk, s_tk, θ_k_sum, x_ck| { + let p = odds_to_prob(*s_tk * *ev_c * *θ_k_sum); + let lp = negbin_logpmf(*r_tk, *lgamma_r_tk, p, *x_ck) as f64; + *z_probs_t += lp; }); - // // self.cell_areas.slice(0..self.ncells) + *z_probs_t += normal_logpdf(*μ_vol_c, *σ_vol_c, log_v_c) as f64; + } + + for z_probs_t in z_probs.iter_mut() { + *z_probs_t = z_probs_t.exp(); + z_probs_sum += *z_probs_t; + } + + if !z_probs_sum.is_finite() { + dbg!(&z_probs, &φ_c, z_probs_sum); + } + + // cumulative probabilities in-place + z_probs.iter_mut().fold(0.0, |mut acc, x| { + acc += *x / z_probs_sum; + *x = acc; + acc + }); + + let u = rng.gen::(); + *z_c = z_probs.partition_point(|x| *x < u) as u32; + }); + } - // let u = Zip::from(counts.axis_iter(Axis(0))).fold(0, |accum, cs| { - // let c = cs.sum(); - // accum + rand_crt(&mut rng, c, *r) - // }); + fn sample_π(&mut self, params: &mut ModelParams) { + let mut rng = rand::thread_rng(); + let mut π_sum = 0.0; + Zip::from(&mut params.π) + .and(¶ms.component_population) + .for_each(|π_t, pop_t| { + *π_t = Gamma::new(1.0 + *pop_t as f32, 1.0) + .unwrap() + .sample(&mut rng); + π_sum += *π_t; + }); - // let v = - // Zip::from(¶ms.z) - // .and(¶ms.cell_volume) - // .fold(0.0, |accum, z, a| { - // let w = θs[*z as usize]; - // accum + log1pf(-odds_to_prob(w * *a)) - // }); + // normalize to get dirichlet posterior + params.π.iter_mut().for_each(|π_t| *π_t /= π_sum); + } - // // dbg!(u, v); + fn compute_rates(&mut self, params: &mut ModelParams) { + general_mat_mul(1.0, ¶ms.φ, ¶ms.θ.t(), 0.0, &mut params.λ); + } - // *r = Gamma::new(priors.e_r + u as f32, (params.h - v).recip()) - // .unwrap() - // .sample(&mut rng); + fn sample_φ(&mut self, params: &mut ModelParams) { + Zip::from(params.φ.outer_iter_mut()) // for each cell + .and(¶ms.z) + .and(params.cell_latent_counts.outer_iter()) + .and(¶ms.effective_cell_volume) + .par_for_each(|φ_c, z_c, x_c, v_c| { + let z_c = *z_c as usize; + let mut rng = thread_rng(); + Zip::from(φ_c) // for each latent dim + .and(¶ms.θksum) + .and(x_c) + .and(¶ms.rφ.row(z_c)) + .and(¶ms.sφ.row(z_c)) + .for_each(|φ_ck, θ_k_sum, x_ck, r_k, s_k| { + let shape = *r_k + *x_ck as f32; + let scale = s_k / (1.0 + s_k * v_c * *θ_k_sum); + *φ_ck = Gamma::new(shape, scale).unwrap().sample(&mut rng); + }); + }); - // assert!(r.is_finite()); + Zip::from(&mut params.φ_v_dot) + .and(params.φ.axis_iter(Axis(1))) + .for_each(|φ_v_dot_k, φ_k| { + *φ_v_dot_k = φ_k.dot(¶ms.effective_cell_volume); + }); + } - // *r = r.min(200.0).max(1e-5); + fn sample_θ(&mut self, priors: &ModelPriors, params: &mut ModelParams) { + let mut θfac = params + .θ + .slice_mut(s![params.nunfactored.., params.nunfactored..]); + let gene_latent_counts_fac = params + .gene_latent_counts + .slice(s![params.nunfactored.., params.nunfactored..]); + + // Sampling with Dirichlet prior on θ (I think Gamma makes more + // sense, but this is an alternative to consider) + Zip::from(θfac.axis_iter_mut(Axis(1))) + .and(gene_latent_counts_fac.axis_iter(Axis(1))) + .for_each(|mut θ_k, x_k| { + let mut rng = thread_rng(); - // *lgamma_r = lgammaf(*r); - // loggammaplus.reset(*r); + // dirichlet sampling by normalizing gammas + Zip::from(&mut θ_k).and(x_k).for_each(|θ_gk, x_gk| { + *θ_gk = Gamma::new(priors.αθ + *x_gk as f32, 1.0) + .unwrap() + .sample(&mut rng); }); - } - // params.h = 0.1; - params.h = Gamma::new( - priors.e_h * (1_f32 + params.r.len() as f32), - (priors.f_h + params.r.sum()).recip(), - ) - .unwrap() - .sample(&mut thread_rng()); - // dbg!(params.h); + let θsum = θ_k.sum(); + θ_k *= θsum.recip(); + }); + + Zip::from(&mut params.θksum) + .and(params.θ.axis_iter(Axis(1))) + .for_each(|θksum, θ_k| { + *θksum = θ_k.sum(); + }); } - fn sample_rates(&mut self, _priors: &ModelPriors, params: &mut ModelParams) { - // loop over genes - Zip::from(params.λ.rows_mut()) - .and(params.foreground_counts.axis_iter(Axis(1))) - .and(params.φ.columns()) - .and(params.r.columns()) - .par_for_each(|mut λs, cs, φs, rs| { + fn sample_rφ(&mut self, priors: &ModelPriors, params: &mut ModelParams) { + Zip::from(params.lφ.outer_iter_mut()) // for every cell + .and(¶ms.z) + .and(params.cell_latent_counts.outer_iter()) + .par_for_each(|l_c, z_c, x_c| { let mut rng = thread_rng(); - // loop over cells - for (λ, &z, cs, cell_volume) in - izip!(&mut λs, ¶ms.z, cs.outer_iter(), ¶ms.cell_volume) - { - let z = z as usize; - - let c = cs.sum(); - let φ = φs[z]; - // dbg!(φ, φ.exp(), (-φ).exp()); - // let ψ = φ + cell_volume.ln(); - let r = rs[z]; - - let α = r + c as f32; - // let β = (-φ).exp() / cell_volume + 1.0; - let β0 = (-φ).exp(); - let β = β0 + cell_volume; - - *λ = Gamma::new(α, β.recip()).unwrap().sample(&mut rng); - // .max(1e-14); - - // if c > 10 { - // // TODO: How far off is than from just count divided by volume? - // dbg!( - // *λ, - // r, - // c as f32 / *cell_volume, *cell_volume); - // } + Zip::from(l_c) // for each hidden dim + .and(x_c) + .and(¶ms.rφ.row(*z_c as usize)) + .for_each(|l_ck, x_ck, r_k| { + *l_ck = rand_crt(&mut rng, *x_ck, *r_k); + }); + }); - assert!(λ.is_finite()); - } + Zip::indexed(params.rφ.outer_iter_mut()) // for each component + .and(params.sφ.outer_iter()) + .par_for_each(|t, r_t, s_t| { + let mut rng = thread_rng(); + Zip::from(r_t) // each hidden dim + .and(s_t) + .and(params.lφ.axis_iter(Axis(1))) + .and(¶ms.θksum) + .for_each(|r_tk, s_tk, l_k, θ_k_sum| { + // summing elements of lφ in component t + let lsum = l_k + .iter() + .zip(¶ms.z) + .filter(|(_l_ck, z_c)| **z_c as usize == t) + .map(|(l_ck, _z_c)| *l_ck) + .sum::(); + + let shape = priors.eφ + lsum as f32; + + let scale_inv = (1.0 / priors.fφ) + + params + .z + .iter() + .zip(¶ms.effective_cell_volume) + .filter(|(z_c, _v_c)| **z_c as usize == t) + .map(|(_z_c, v_c)| (*s_tk * v_c * *θ_k_sum).ln_1p()) + .sum::(); + let scale = scale_inv.recip(); + *r_tk = Gamma::new(shape, scale).unwrap().sample(&mut rng); + *r_tk = r_tk.max(2e-4); + }); + }); + + // dbg!(¶ms.rφ); + } + + fn sample_ωck(&mut self, params: &mut ModelParams) { + // sample ω ~ PolyaGamma + // let t0 = Instant::now(); + Zip::from(params.ωφ.outer_iter_mut()) // for every cell + .and(¶ms.z) + .and(params.cell_latent_counts.outer_iter()) + .and(¶ms.effective_cell_volume) + .par_for_each(|ω_c, z_c, x_c, v_c| { + let mut rng = thread_rng(); + Zip::from(ω_c) // for every hidden dim + .and(x_c) + .and(params.rφ.row(*z_c as usize)) + .and(params.sφ.row(*z_c as usize)) + .and(¶ms.θksum) + .for_each(|ω_ck, x_ck, r_k, s_k, θ_k_sum| { + let ε = (*s_k * *v_c * θ_k_sum).ln(); + *ω_ck = PolyaGamma::new(*x_ck as f32 + *r_k, ε).sample(&mut rng); + }); + }); + // println!(" sample_sφ/PolyaGamma: {:?}", t0.elapsed()); + } + + fn sample_sφ(&mut self, priors: &ModelPriors, params: &mut ModelParams) { + // sample s ~ LogNormal + // let t0 = Instant::now(); + Zip::indexed(params.sφ.outer_iter_mut()) // for every component + .and(params.rφ.outer_iter()) + .par_for_each(|t, s_t, r_t| { + let mut rng = thread_rng(); + Zip::from(s_t) // for every hidden dim + .and(r_t) + .and(¶ms.θksum) + .and(params.ωφ.axis_iter(Axis(1))) + .and(params.cell_latent_counts.axis_iter(Axis(1))) + .for_each(|s_tk, r_tk, θ_k_sum, ω_k, x_k| { + // TODO: This would be fatser if we went pre-computing μ, σ [ncomponets, nhidden] matrices + // by processing cells, rather than doing this filtering thing. + let τ = priors.τφ + + ω_k + .iter() + .zip(¶ms.z) + .filter(|(_ω_ck, z_c)| **z_c as usize == t) + .map(|(ω_ck, _z_c)| *ω_ck) + .sum::(); + let σ2 = τ.recip(); + let μ = σ2 + * (priors.μφ * priors.τφ + + Zip::from(x_k) // for every cell + .and(¶ms.z) + .and(ω_k) + .and(¶ms.effective_cell_volume) + .fold(0.0, |acc, x_ck, z_c, ω_ck, v_c| { + if *z_c as usize == t { + acc + (*x_ck as f32 - *r_tk) / 2.0 + - ω_ck * (v_c * *θ_k_sum).ln() + } else { + acc + } + })); + *s_tk = (μ + σ2.sqrt() * randn(&mut rng)).exp(); + }); + }); + // println!(" sample_sφ/LogNormal: {:?}", t0.elapsed()); + } + + fn sample_cell_scales(&mut self, priors: &ModelPriors, params: &mut ModelParams) { + // for each cell + Zip::from(&mut params.cell_scale) + .and(&mut params.effective_cell_volume) + .and(¶ms.log_cell_volume) + .and(params.cell_latent_counts.outer_iter()) + .and(¶ms.z) + .and(params.ωφ.outer_iter()) + .par_for_each(|a_c, eff_v_c, &log_v_c, x_c, &z_c, ω_c| { + let mut rng = thread_rng(); + let z_c = z_c as usize; + + let τ = priors.τv + ω_c.sum(); + let σ2 = τ.recip(); + + // for each hidden dim + let μ = σ2 + * Zip::from(¶ms.θksum) + .and(x_c) + .and(params.rφ.row(z_c)) + .and(ω_c) + .and(params.sφ.row(z_c)) + .fold(0.0, |acc, &θ_k_sum, x_ck, &r_tk, &ω_ck, &s_tk| { + acc + (*x_ck as f32 - r_tk) / 2.0 + - ω_ck * ((s_tk * θ_k_sum).ln() + log_v_c) + }); + + let log_a_c = μ + σ2.sqrt() * randn(&mut rng); + *a_c = log_a_c.exp(); + *eff_v_c = (log_a_c + log_v_c).exp(); }); } @@ -1660,80 +1884,8 @@ where }); } - fn sample_component_assignments(&mut self, _priors: &ModelPriors, params: &mut ModelParams) { - let ncomponents = params.ncomponents(); - - // loop over cells - Zip::from(params.foreground_counts.axis_iter(Axis(0))) - .and(&mut params.z) - .and(¶ms.cell_log_volume) - .par_for_each(|cs, z_i, cell_log_volume| { - let mut z_probs = params - .z_probs - .get_or(|| RefCell::new(vec![0_f64; ncomponents])) - .borrow_mut(); - - // loop over components - for (zp, π, φs, rs, lgamma_r, loggammaplus, &μ_volume, &σ_volume) in izip!( - z_probs.iter_mut(), - ¶ms.π, - params.φ.rows(), - params.r.rows(), - params.lgamma_r.rows(), - params.loggammaplus.rows(), - ¶ms.μ_volume, - ¶ms.σ_volume - ) { - // sum over genes - *zp = (*π as f64) - * (Zip::from(cs.axis_iter(Axis(0))) - .and(rs) - .and(φs) - .and(lgamma_r) - .and(loggammaplus) - .fold(0_f32, |accum, cs, &r, φ, &lgamma_r, lgammaplus| { - let ψ = φ + cell_log_volume; - let c = cs.iter().map(|&x| x as u32).sum(); // sum counts across layers - accum - + negbin_logpmf_fast( - r, - lgamma_r, - lgammaplus.eval(c), - logistic(ψ), - c, - params.logfactorial.eval(c), - ) - }) as f64) - .exp(); - - *zp *= normal_pdf(μ_volume, σ_volume, *cell_log_volume).exp() as f64; - } - - // z_probs.iter_mut().enumerate().for_each(|(j, zp)| { - // *zp = (self.params.π[j] as f64) * - // negbin_logpmf(r, lgamma_r, p, k) - // // (self.params.cell_logprob_fast(j as usize, *cell_area, &cs, &clfs) as f64).exp(); - // }); - - let z_prob_sum = z_probs.iter().sum::(); - - assert!(z_prob_sum.is_finite()); - - // cumulative probabilities in-place - z_probs.iter_mut().fold(0.0, |mut acc, x| { - acc += *x / z_prob_sum; - *x = acc; - acc - }); - - let rng = &mut thread_rng(); - let u = rng.gen::(); - *z_i = z_probs.partition_point(|x| *x < u) as u32; - }); - } - fn sample_volume_params(&mut self, priors: &ModelPriors, params: &mut ModelParams) { - Zip::from(&mut params.cell_log_volume) + Zip::from(&mut params.log_cell_volume) .and(¶ms.cell_volume) .into_par_iter() .with_min_len(50) @@ -1742,13 +1894,11 @@ where }); // compute sample means - params.component_population.fill(0_u32); params.μ_volume.fill(0_f32); Zip::from(¶ms.z) - .and(¶ms.cell_log_volume) + .and(¶ms.log_cell_volume) .for_each(|&z, &log_volume| { params.μ_volume[z as usize] += log_volume; - params.component_population[z as usize] += 1; }); // dbg!(¶ms.component_population); @@ -1775,7 +1925,7 @@ where // compute sample variances params.σ_volume.fill(0_f32); Zip::from(¶ms.z) - .and(¶ms.cell_log_volume) + .and(¶ms.log_cell_volume) .for_each(|&z, &log_volume| { params.σ_volume[z as usize] += (params.μ_volume[z as usize] - log_volume).powi(2); }); @@ -1897,7 +2047,7 @@ where let λ_prev = if cell_prev == BACKGROUND_CELL { 0.0 } else { - params.λ[[gene, cell_prev as usize]] + params.λ_c[gene] + params.λ[[cell_prev as usize, gene]] + params.λ_c[gene] } + params.λ_bg[[gene, layer_prev]]; let layer_new = @@ -1907,7 +2057,7 @@ where let λ_new = if cell_new == BACKGROUND_CELL { 0.0 } else { - params.λ[[gene, cell_new as usize]] + params.λ_c[gene] + params.λ[[cell_new as usize, gene]] + params.λ_c[gene] } + params.λ_bg[[gene, layer_new]]; let ln_λ_diff = λ_new.ln() - λ_prev.ln(); @@ -1960,9 +2110,12 @@ where transcripts: &Vec, uncertainty: &mut Option<&mut UncertaintyTracker>, ) { + // let t0 = Instant::now(); self.propose_eval_transcript_positions(priors, params, transcripts); + // println!(" REPO: proposals {:?}", t0.elapsed()); // Update position and compute cell and layer changes for updates + // let t0 = Instant::now(); params .transcript_position_updates .par_iter_mut() @@ -1988,34 +2141,26 @@ where } }, ); + // println!(" REPO: update positions {:?}", t0.elapsed()); // Update counts and cell_population + // let t0 = Instant::now(); params .transcript_position_updates .iter() - .zip(transcripts) .zip(¶ms.accept_proposed_transcript_positions) .zip(&mut params.cell_assignments) .zip(&mut params.cell_assignment_time) .enumerate() .for_each( - |( - i, - ((((update, transcript), &accept), cell_assignment), cell_assignment_time), - )| { - let (cell_prev, cell_new, layer_prev, layer_new) = *update; + |(i, (((update, &accept), cell_assignment), cell_assignment_time))| { + let (cell_prev, cell_new, _layer_prev, _layer_new) = *update; if accept { - let gene = transcript.gene as usize; if cell_prev != BACKGROUND_CELL { - assert!( - params.counts[[gene, cell_prev as usize, layer_prev as usize]] > 0 - ); - params.counts[[gene, cell_prev as usize, layer_prev as usize]] -= 1; assert!(params.cell_population[cell_prev as usize] > 0); params.cell_population[cell_prev as usize] -= 1; } if cell_new != BACKGROUND_CELL { - params.counts[[gene, cell_new as usize, layer_new as usize]] += 1; params.cell_population[cell_new as usize] += 1; } @@ -2036,10 +2181,13 @@ where } }, ); + // println!(" REPO: update counts {:?}", t0.elapsed()); + // let t0 = Instant::now(); self.update_transcript_positions( ¶ms.accept_proposed_transcript_positions, ¶ms.transcript_positions, ); + // println!(" REPO: voxel sampler update positions {:?}", t0.elapsed()); } } diff --git a/src/sampler/math.rs b/src/sampler/math.rs index 5e31f9d..1b70e37 100644 --- a/src/sampler/math.rs +++ b/src/sampler/math.rs @@ -1,24 +1,24 @@ -// use libm::{lgammaf, erff}; use libm::lgammaf; use rand::rngs::ThreadRng; use rand::Rng; +use rand_distr::StandardNormal; // pub fn logit(p: f32) -> f32 { // return p.ln() - (1.0 - p).ln(); // } -pub fn logistic(x: f32) -> f32 { - 1.0 / (1.0 + (-x).exp()) +// pub fn sq(x: f32) -> f32 { +// x * x +// } + +pub fn odds_to_prob(q: f32) -> f32 { + q / (1.0 + q) } pub fn relerr(a: f32, b: f32) -> f32 { ((a - b) / a).abs() } -pub fn lfact(k: u32) -> f32 { - lgammaf(k as f32 + 1.0) -} - // // Partial Student-T log-pdf (just the terms that don't cancel out when doing MH sampling) // pub fn studentt_logpdf_part(σ2: f32, df: f32, x2: f32) -> f32 { // return -((df + 1.0) / 2.0) * ((x2 / σ2) / df).ln_1p(); @@ -35,30 +35,27 @@ pub fn normal_x2_logpdf(σ: f32, x2: f32) -> f32 { -x2 / (2.0 * σ.powi(2)) - σ.ln() - LN_SQRT_TWO_PI } -// pub fn negbin_logpmf(r: f32, lgamma_r: f32, p: f32, k: u32) -> f32 { -// let k_ln_factorial = lgammaf(k as f32 + 1.0); -// let lgamma_rpk = lgammaf(r + k as f32); -// return negbin_logpmf_fast(r, lgamma_r, lgamma_rpk, p, k, k_ln_factorial); -// } - -// const SQRT2: f32 = 1.4142135623730951_f32; - -// pub fn normal_cdf(μ: f32, σ: f32, x: f32) -> f32 { -// return 0.5 * (1.0 + erff((x - μ) / (SQRT2 * σ))); +// pub fn gamma_logpdf(shape: f32, scale: f32, x: f32) -> f32 { +// return +// -lgammaf(shape) +// - shape * scale.ln() +// + (shape - 1.0) * x.ln() +// - x / scale; // } -pub fn normal_pdf(μ: f32, σ: f32, x: f32) -> f32 { - let xμ = x - μ; - (-xμ.powi(2) / (2.0 * σ.powi(2))).exp() / (σ * SQRT_TWO_PI) +pub fn rand_crt(rng: &mut ThreadRng, n: u32, r: f32) -> u32 { + (0..n) + .map(|t| rng.gen_bool(r as f64 / (r as f64 + t as f64)) as u32) + .sum() } -pub fn lognormal_logpdf(μ: f32, σ: f32, x: f32) -> f32 { - let xln = x.ln(); - -LN_SQRT_TWO_PI - σ.ln() - xln - ((xln - μ) / σ).powi(2) / 2.0 +pub fn negbin_logpmf(r: f32, lgamma_r: f32, p: f32, k: u32) -> f32 { + let k_ln_factorial = lgammaf(k as f32 + 1.0); + let lgamma_rpk = lgammaf(r + k as f32); + return negbin_logpmf_fast(r, lgamma_r, lgamma_rpk, p, k, k_ln_factorial); } -// Negative binomial log probability function with capacity for precomputing some values. -pub fn negbin_logpmf_fast( +fn negbin_logpmf_fast( r: f32, lgamma_r: f32, lgamma_rpk: f32, @@ -82,68 +79,21 @@ pub fn negbin_logpmf_fast( } } -pub fn rand_crt(rng: &mut ThreadRng, n: u32, r: f32) -> u32 { - (0..n) - .map(|t| rng.gen_bool(r as f64 / (r as f64 + t as f64)) as u32) - .sum() -} - -// log-factorial with precomputed values for small numbers -pub struct LogFactorial { - values: Vec, -} - -impl LogFactorial { - fn new_with_n(n: usize) -> Self { - LogFactorial { - values: Vec::from_iter((0..n as u32).map(lfact)), - } - } +// const SQRT2: f32 = 1.4142135623730951_f32; - pub fn new() -> Self { - LogFactorial::new_with_n(100) - } +// pub fn normal_cdf(μ: f32, σ: f32, x: f32) -> f32 { +// return 0.5 * (1.0 + erff((x - μ) / (SQRT2 * σ))); +// } - pub fn eval(&self, k: u32) -> f32 { - self.values - .get(k as usize) - .map_or_else(|| lfact(k), |&value| value) - } +pub fn normal_logpdf(μ: f32, σ: f32, x: f32) -> f32 { + -LN_SQRT_TWO_PI - σ.ln() - ((x - μ) / σ).powi(2) / 2.0 } -// Partially memoized lgamma(r + k), memoized over k. -#[derive(Clone)] -pub struct LogGammaPlus { - r: f32, - values: Vec, +pub fn lognormal_logpdf(μ: f32, σ: f32, x: f32) -> f32 { + let xln = x.ln(); + -LN_SQRT_TWO_PI - σ.ln() - xln - ((xln - μ) / σ).powi(2) / 2.0 } -impl LogGammaPlus { - fn new_with_n(r: f32, n: usize) -> Self { - LogGammaPlus { - r, - values: Vec::from_iter((0..n as u32).map(|x| lgammaf(r + x as f32))), - } - } - - pub fn new(r: f32) -> Self { - LogGammaPlus::new_with_n(r, 100) - } - - pub fn default() -> Self { - LogGammaPlus::new(0.0) - } - - pub fn reset(&mut self, r: f32) { - self.values.iter_mut().enumerate().for_each(|(k, v)| { - *v = lgammaf(r + k as f32); - }); - self.r = r; - } - - pub fn eval(&self, k: u32) -> f32 { - self.values - .get(k as usize) - .map_or_else(|| lgammaf(self.r + k as f32), |&value| value) - } +pub fn randn(rng: &mut ThreadRng) -> f32 { + return rng.sample::(StandardNormal); } diff --git a/src/sampler/transcripts.rs b/src/sampler/transcripts.rs index 2b6534d..cac6203 100644 --- a/src/sampler/transcripts.rs +++ b/src/sampler/transcripts.rs @@ -4,8 +4,9 @@ use flate2::read::MultiGzDecoder; use itertools::izip; use kiddo::float::kdtree::KdTree; use kiddo::SquaredEuclidean; -use ndarray::Array2; +use ndarray::{Array1, Array2}; use parquet::arrow::arrow_reader::{ParquetRecordBatchReader, ParquetRecordBatchReaderBuilder}; +use regex::Regex; use std::collections::HashMap; use std::fs::File; use std::str; @@ -38,9 +39,36 @@ pub struct TranscriptDataset { pub fov_names: Vec, } +impl TranscriptDataset { + pub fn select_unfactored_genes(&mut self, _nunfactored: usize) { + // Current heuristic is just to select the highest expression genes. + let mut gene_counts = Array1::::zeros(self.transcript_names.len()); + for transcript in self.transcripts.iter() { + gene_counts[transcript.gene as usize] += 1; + } + + let mut ord = (0..self.transcript_names.len()).collect::>(); + ord.sort_unstable_by(|&i, &j| gene_counts[i].cmp(&gene_counts[j]).reverse()); + + let mut rev_ord = vec![0; ord.len()]; + for (i, j) in ord.iter().enumerate() { + rev_ord[*j] = i; + } + + self.transcript_names = ord + .iter() + .map(|&i| self.transcript_names[i].clone()) + .collect(); + for transcript in self.transcripts.iter_mut() { + transcript.gene = rev_ord[transcript.gene as usize] as u32; + } + } +} + #[allow(clippy::too_many_arguments)] pub fn read_transcripts_csv( path: &str, + excluded_genes: Option, transcript_column: &str, id_column: Option, compartment_column: Option, @@ -55,7 +83,7 @@ pub fn read_transcripts_csv( y_column: &str, z_column: &str, min_qv: f32, - no_z_column: bool, + ignore_z_column: bool, coordinate_scale: f32, ) -> TranscriptDataset { let fmt = infer_format_from_filename(path); @@ -65,6 +93,7 @@ pub fn read_transcripts_csv( let mut rdr = csv::Reader::from_path(path).unwrap(); read_transcripts_csv_xyz( &mut rdr, + excluded_genes, transcript_column, id_column, compartment_column, @@ -79,7 +108,7 @@ pub fn read_transcripts_csv( y_column, z_column, min_qv, - no_z_column, + ignore_z_column, coordinate_scale, ) } @@ -87,6 +116,7 @@ pub fn read_transcripts_csv( let mut rdr = csv::Reader::from_reader(MultiGzDecoder::new(File::open(path).unwrap())); read_transcripts_csv_xyz( &mut rdr, + excluded_genes, transcript_column, id_column, compartment_column, @@ -101,12 +131,13 @@ pub fn read_transcripts_csv( y_column, z_column, min_qv, - no_z_column, + ignore_z_column, coordinate_scale, ) } OutputFormat::Parquet => read_xenium_transcripts_parquet( path, + excluded_genes, transcript_column, &id_column.unwrap(), &compartment_column.unwrap(), @@ -119,7 +150,7 @@ pub fn read_transcripts_csv( y_column, z_column, min_qv, - no_z_column, + ignore_z_column, coordinate_scale, ), OutputFormat::Infer => panic!("Could not infer format of file '{}'", path), @@ -188,7 +219,7 @@ fn postprocess_cell_assignments( #[allow(clippy::too_many_arguments)] fn read_transcripts_csv_xyz( rdr: &mut csv::Reader, - + excluded_genes: Option, transcript_column: &str, id_column: Option, compartment_column: Option, @@ -277,6 +308,12 @@ where let transcript_name = &row[transcript_col]; + if let Some(excluded_genes) = &excluded_genes { + if excluded_genes.is_match(transcript_name) { + continue; + } + } + let gene = if let Some(gene) = transcript_name_map.get(transcript_name) { *gene } else { @@ -388,6 +425,7 @@ where #[allow(clippy::too_many_arguments)] fn read_xenium_transcripts_parquet( filename: &str, + excluded_genes: Option, transcript_col_name: &str, id_col_name: &str, compartment_col_name: &str, @@ -421,6 +459,7 @@ fn read_xenium_transcripts_parquet( read_xenium_transcripts_parquet_str_type::( rdr, schema, + excluded_genes, transcript_col_name, id_col_name, compartment_col_name, @@ -441,6 +480,7 @@ fn read_xenium_transcripts_parquet( read_xenium_transcripts_parquet_str_type::( rdr, schema, + excluded_genes, transcript_col_name, id_col_name, compartment_col_name, @@ -465,6 +505,7 @@ fn read_xenium_transcripts_parquet( fn read_xenium_transcripts_parquet_str_type( rdr: ParquetRecordBatchReader, schema: arrow::datatypes::Schema, + excluded_genes: Option, transcript_col_name: &str, id_col_name: &str, compartment_col_name: &str, @@ -587,6 +628,12 @@ where continue; } + if let Some(excluded_genes) = &excluded_genes { + if excluded_genes.is_match(transcript) { + continue; + } + } + let fov = match fov_map.get(fov) { Some(fov) => *fov, None => { diff --git a/src/sampler/voxelsampler.rs b/src/sampler/voxelsampler.rs index 46118db..8806889 100644 --- a/src/sampler/voxelsampler.rs +++ b/src/sampler/voxelsampler.rs @@ -9,7 +9,7 @@ use super::{chunkquad, perimeter_bound, ModelParams, ModelPriors, Proposal, Samp // use arrow; use geo::geometry::{MultiPolygon, Polygon}; use itertools::Itertools; -use ndarray::Array2; +use ndarray::{Array2, Zip}; use rand::{thread_rng, Rng}; use rayon::prelude::*; use std::cell::RefCell; @@ -500,8 +500,6 @@ impl VoxelSampler { } } - params.recompute_counts(transcripts); - // for (transcript, &cell) in transcripts.iter().zip(params.cell_assignments.iter()) { // // let position = clip_z_position( // // (transcript.x, transcript.y, transcript.z), zmin, zmax); @@ -550,6 +548,7 @@ impl VoxelSampler { sampler.recompute_cell_population(); sampler.recompute_cell_perimeter(); sampler.recompute_cell_volume(priors, params); + params.effective_cell_volume.assign(¶ms.cell_volume); sampler.populate_mismatches(); sampler.update_transcript_positions( &vec![true; transcripts.len()], @@ -705,6 +704,14 @@ impl VoxelSampler { assert!(*cell_volume > 0.0); *cell_volume = cell_volume.max(priors.min_cell_volume); } + + Zip::from(&mut params.log_cell_volume) + .and(¶ms.cell_volume) + .into_par_iter() + .with_min_len(50) + .for_each(|(log_volume, &volume)| { + *log_volume = volume.ln(); + }); } fn recompute_cell_population(&mut self) { @@ -1407,6 +1414,7 @@ impl Sampler for VoxelSampler { } fn update_transcript_positions(&mut self, updated: &[bool], positions: &[(f32, f32, f32)]) { + // let t0 = Instant::now(); self.transcript_voxels .par_iter_mut() .zip(positions) @@ -1417,9 +1425,12 @@ impl Sampler for VoxelSampler { *voxel = self.chunkquad.layout.world_pos_to_voxel(position); } }); + // println!(" REPO: compute voxels {:?}", t0.elapsed()); + // let t0 = Instant::now(); self.transcript_voxel_ord - .par_sort_unstable_by_key(|&t| self.transcript_voxels[t]); + .par_sort_by_key(|&t| self.transcript_voxels[t]); + // println!(" REPO: sort on voxel assignment {:?}", t0.elapsed()); } // fn update_transcript_position(&mut self, i: usize, prev_pos: (f32, f32, f32), new_pos: (f32, f32, f32)) {