diff --git a/eggstrain/src/execution/mod.rs b/eggstrain/src/execution/mod.rs
index 0492aff..c439a44 100644
--- a/eggstrain/src/execution/mod.rs
+++ b/eggstrain/src/execution/mod.rs
@@ -1,2 +1,4 @@
 pub mod operators;
 pub mod query_dag;
+pub mod record_buffer;
+pub mod record_table;
diff --git a/eggstrain/src/execution/operators/filter.rs b/eggstrain/src/execution/operators/filter.rs
index d771c18..286423e 100644
--- a/eggstrain/src/execution/operators/filter.rs
+++ b/eggstrain/src/execution/operators/filter.rs
@@ -18,7 +18,10 @@ pub(crate) struct Filter {
 
 /// TODO docs
 impl Filter {
-    pub(crate) fn new(predicate: Arc<dyn PhysicalExpr>, children: Vec<Arc<dyn ExecutionPlan>>) -> Self {
+    pub(crate) fn new(
+        predicate: Arc<dyn PhysicalExpr>,
+        children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Self {
         Self {
             predicate,
             children,
diff --git a/eggstrain/src/execution/operators/hash_join.rs b/eggstrain/src/execution/operators/hash_join.rs
new file mode 100644
index 0000000..f2b6667
--- /dev/null
+++ b/eggstrain/src/execution/operators/hash_join.rs
@@ -0,0 +1,257 @@
+use super::{BinaryOperator, Operator};
+use crate::execution::record_buffer::schema_to_fields;
+use crate::execution::record_table::RecordTable;
+use arrow::array::ArrayRef;
+use arrow::datatypes::{Schema, SchemaBuilder, SchemaRef};
+use arrow::record_batch::RecordBatch;
+use async_trait::async_trait;
+use datafusion::common::arrow::row::{Row, RowConverter, Rows};
+use datafusion::logical_expr::ColumnarValue;
+use datafusion::physical_expr::PhysicalExprRef;
+use datafusion::physical_plan::joins::utils::{build_join_schema, ColumnIndex};
+use datafusion::physical_plan::joins::HashJoinExec;
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion_common::hash_utils::create_hashes;
+use datafusion_common::{DataFusionError, JoinSide, JoinType, Result};
+use std::sync::Arc;
+use tokio::sync::broadcast;
+use tokio::sync::broadcast::error::RecvError;
+
+/// TODO docs
+pub struct HashJoin {
+    left_schema: SchemaRef,
+    right_schema: SchemaRef,
+    equate_on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
+    children: Vec<Arc<dyn ExecutionPlan>>,
+}
+
+/// TODO docs
+impl HashJoin {
+    pub(crate) fn new(hash_exec: &HashJoinExec) -> Self {
+        Self {
+            children: hash_exec.children(),
+            left_schema: hash_exec.left().schema(),
+            right_schema: hash_exec.right().schema(),
+            // Clone the join-key expression pairs out of the `HashJoinExec`
+            equate_on: hash_exec
+                .on()
+                .iter()
+                .map(|(left, right)| (left.clone(), right.clone()))
+                .collect(),
+        }
+    }
+
+    /// Given a [`RecordBatch`], hashes each row based on the join-key physical
+    /// expressions. The `LEFT` const generic selects whether the left- or
+    /// right-side expressions are evaluated.
+    ///
+    /// TODO docs
+    fn hash_batch<const LEFT: bool>(&self, batch: &RecordBatch) -> Result<Vec<u64>> {
+        let rows = batch.num_rows();
+
+        // A vector of columns, horizontally these are the join keys
+        let mut column_vals = Vec::with_capacity(self.equate_on.len());
+
+        for (left_eq, right_eq) in self.equate_on.iter() {
+            let eq = if LEFT { left_eq } else { right_eq };
+
+            // Extract a single column
+            let col_val = eq.evaluate(batch)?;
+            match col_val {
+                ColumnarValue::Array(arr) => column_vals.push(arr),
+                ColumnarValue::Scalar(s) => {
+                    return Err(DataFusionError::NotImplemented(format!(
+                        "Join physical expression scalar condition on {:#?} not implemented",
+                        s
+                    )));
+                }
+            }
+        }
+
+        // `create_hashes` fills a pre-sized buffer, one hash per row
+        let mut hashes = vec![0; rows];
+        // Both sides must hash with the same seeds so that the build and probe
+        // phases agree (assumes `ahash`, a DataFusion dependency, is available)
+        let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0);
+        create_hashes(&column_vals, &random_state, &mut hashes)?;
+        assert_eq!(hashes.len(), rows);
+
+        Ok(hashes)
+    }
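+
+    // Illustrative only (not part of this change): the `LEFT` const generic
+    // lets the build and probe phases share one hashing implementation:
+    //
+    //     let build_hashes = self.hash_batch::<true>(&left_batch)?;
+    //     let probe_hashes = self.hash_batch::<false>(&right_batch)?;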
+
+    /// Builds the Hash Table from the [`RecordBatch`]es coming from the left child.
+    ///
+    /// TODO docs
+    async fn build_table(
+        &self,
+        mut rx_left: broadcast::Receiver<RecordBatch>,
+    ) -> Result<RecordTable> {
+        // Take in all of the record batches from the left and create a hash table
+        let mut record_table = RecordTable::new(self.left_schema.clone());
+
+        loop {
+            match rx_left.recv().await {
+                Ok(batch) => {
+                    // TODO gather N batches and use rayon to insert all at once
+                    let hashes = self.hash_batch::<true>(&batch)?;
+                    record_table.insert_batch(batch, hashes)?;
+                }
+                Err(e) => match e {
+                    RecvError::Closed => break,
+                    RecvError::Lagged(_) => todo!(),
+                },
+            }
+        }
+
+        Ok(record_table)
+    }
+
+    pub fn build_join_schema(left: &Schema, right: &Schema) -> (Schema, Vec<ColumnIndex>) {
+        let (fields, column_indices): (SchemaBuilder, Vec<ColumnIndex>) = {
+            let left_fields = left.fields().iter().cloned().enumerate().map(|(index, f)| {
+                (
+                    f,
+                    ColumnIndex {
+                        index,
+                        side: JoinSide::Left,
+                    },
+                )
+            });
+            let right_fields = right.fields().iter().cloned().enumerate().map(|(index, f)| {
+                (
+                    f,
+                    ColumnIndex {
+                        index,
+                        side: JoinSide::Right,
+                    },
+                )
+            });
+
+            // left then right
+            left_fields.chain(right_fields).unzip()
+        };
+        (fields.finish(), column_indices)
+    }
+
+    /// Given a single batch (coming from the right child), probes the hash table and outputs a
+    /// [`RecordBatch`] for every tuple on the right that gets matched with a tuple in the hash table.
+    ///
+    /// Note: This is super inefficient since it's possible that we could emit a bunch of
+    /// [`RecordBatch`]es that have just 1 tuple in them.
+    ///
+    /// TODO This is a place for easy optimization.
+    ///
+    /// TODO only implements an inner join
+    async fn probe_table(
+        &self,
+        table: &RecordTable,
+        right_batch: RecordBatch,
+        tx: &broadcast::Sender<RecordBatch>,
+    ) -> Result<()> {
+        let hashes = self.hash_batch::<false>(&right_batch)?;
+
+        // let left_column_count = self.left_schema.fields().size();
+        // let right_column_count = self.right_schema.fields().size();
+        // let output_columns = left_column_count + right_column_count - self.equate_on.len();
+
+        let right_rows = table.buffer.record_batch_to_rows(right_batch)?;
+        for (row, &hash) in hashes.iter().enumerate() {
+            // For each of these hashes, check if it is in the table. A probe-side
+            // row with no match emits nothing for an inner join, so skip to the
+            // next row rather than returning early
+            let Some(records) = table.get_record_indices(hash) else {
+                continue;
+            };
+            assert!(!records.is_empty());
+
+            // Create a new schema that is the join of the two schemas
+            let left_schema: Schema = (*self.left_schema).clone();
+            let right_schema: Schema = (*self.right_schema).clone();
+
+            // let new_schema = Schema::try_merge(vec![left_schema, right_schema])?;
+            let (new_schema, _column_indices) =
+                HashJoin::build_join_schema(&left_schema, &right_schema);
+            let joined_schema: SchemaRef = Arc::new(new_schema);
+
+            let _row_converter = RowConverter::new(schema_to_fields(joined_schema.clone()))?;
+
+            // There are records associated with this hash value, so we need to emit things
+            for &record in records {
+                let _left_tuple = table.buffer.get(record).unwrap();
+                let _right_tuple: Row = right_rows.row(row);
+                todo!("Join the two tuples in some way, then append them to a `Rows`");
+            }
+
+            todo!("Convert the `Rows` back into a `RecordBatch` with `RowConverter::convert_rows`");
+            // Roughly, once the joined `Rows` exist (using `_column_indices` to
+            // lay out the output columns):
+            //     let out_columns: Vec<ArrayRef> = row_converter.convert_rows(&joined_rows)?;
+            //     todo!("Figure out names for each column");
+            //     let out_columns_iter = out_columns.into_iter().map(|col| ("name", col));
+            //     let joined_batch = RecordBatch::try_from_iter(out_columns_iter)?;
+            //     tx.send(joined_batch).expect("Unable to send the projected batch");
+        }
+
+        Ok(())
+    }
+}
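+
+// A sketch of the arrow row-format round trip that `probe_table` is building
+// toward (illustrative only; `schema` and `batch` are assumed to be in scope):
+//
+//     let converter = RowConverter::new(schema_to_fields(schema.clone()))?;
+//     let rows = converter.convert_columns(batch.columns())?; // columns -> rows
+//     let cols = converter.convert_rows(&rows)?;              // rows -> columns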
+
+/// TODO docs
+impl Operator for HashJoin {
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        self.children.clone()
+    }
+}
+
+/// TODO docs
+#[async_trait]
+impl BinaryOperator for HashJoin {
+    type InLeft = RecordBatch;
+    type InRight = RecordBatch;
+    type Out = RecordBatch;
+
+    fn into_binary(
+        self,
+    ) -> Arc<dyn BinaryOperator<InLeft = Self::InLeft, InRight = Self::InRight, Out = Self::Out>>
+    {
+        Arc::new(self)
+    }
+
+    async fn execute(
+        &self,
+        rx_left: broadcast::Receiver<Self::InLeft>,
+        mut rx_right: broadcast::Receiver<Self::InRight>,
+        tx: broadcast::Sender<Self::Out>,
+    ) {
+        // Phase 1: Build Phase
+        // TODO assign to its own tokio task
+        let record_table = self
+            .build_table(rx_left)
+            .await
+            .expect("Unable to build hash table");
+
+        // Phase 2: Probe Phase
+        loop {
+            match rx_right.recv().await {
+                Ok(batch) => {
+                    self.probe_table(&record_table, batch, &tx)
+                        .await
+                        .expect("Unable to probe hash table");
+                }
+                Err(e) => match e {
+                    RecvError::Closed => break,
+                    RecvError::Lagged(_) => todo!(),
+                },
+            }
+        }
+    }
+}
diff --git a/eggstrain/src/execution/operators/mod.rs b/eggstrain/src/execution/operators/mod.rs
index 603e02b..461332d 100644
--- a/eggstrain/src/execution/operators/mod.rs
+++ b/eggstrain/src/execution/operators/mod.rs
@@ -4,6 +4,7 @@ use std::sync::Arc;
 use tokio::sync::broadcast::{Receiver, Sender};
 
 pub mod filter;
+pub mod hash_join;
 pub mod project;
 
 /// Defines shared behavior for all operators
diff --git a/eggstrain/src/execution/query_dag.rs b/eggstrain/src/execution/query_dag.rs
index 354fb1f..dc3eb36 100644
--- a/eggstrain/src/execution/query_dag.rs
+++ b/eggstrain/src/execution/query_dag.rs
@@ -121,8 +121,7 @@ fn datafusion_execute(plan: Arc<dyn ExecutionPlan>, tx: broadcast::Sender<RecordBatch>) {
diff --git a/eggstrain/src/execution/record_buffer.rs b/eggstrain/src/execution/record_buffer.rs
new file mode 100644
--- /dev/null
+++ b/eggstrain/src/execution/record_buffer.rs
@@ -0,0 +1,114 @@
+use arrow::datatypes::SchemaRef;
+use arrow::record_batch::RecordBatch;
+use arrow::row::{Row, RowConverter, Rows, SortField};
+use datafusion_common::Result;
+
+/// Points to a row stored in a [`RecordBuffer`]: `index` selects the row
+/// group and `row` selects the row within that group.
+#[derive(Debug, Clone, Copy)]
+pub struct RecordIndex {
+    index: u32,
+    row: u32,
+}
+
+impl RecordIndex {
+    pub fn new(index: u32, row: u32) -> Self {
+        Self { index, row }
+    }
+
+    // Use functional style due to easy copying
+    pub fn with_row(&self, row: u32) -> Self {
+        Self {
+            index: self.index,
+            row,
+        }
+    }
+}
+
+pub fn schema_to_fields(schema: SchemaRef) -> Vec<SortField> {
+    schema
+        .fields()
+        .iter()
+        .map(|f| SortField::new(f.data_type().clone()))
+        .collect::<Vec<_>>()
+}
+
+pub struct RecordBuffer {
+    schema: SchemaRef,
+    converter: RowConverter,
+    inner: Vec<Rows>, // vector of row groups
+}
+
+impl RecordBuffer {
+    pub fn new(schema: SchemaRef) -> Self {
+        let fields = schema_to_fields(schema.clone());
+        Self {
+            schema,
+            converter: RowConverter::new(fields).expect("Unable to create a RowConverter"),
+            inner: vec![],
+        }
+    }
+
+    pub fn with_capacity(schema: SchemaRef, capacity: usize) -> Self {
+        let fields = schema_to_fields(schema.clone());
+        Self {
+            schema,
+            converter: RowConverter::new(fields).expect("Unable to create a RowConverter"),
+            inner: Vec::with_capacity(capacity),
+        }
+    }
+
+    pub fn converter(&self) -> &RowConverter {
+        &self.converter
+    }
+
+    pub fn record_batch_to_rows(&self, batch: RecordBatch) -> Result<Rows> {
+        // `Ok` to make use of `?` behavior
+        Ok(self.converter.convert_columns(batch.columns())?)
+    }
+
+    pub fn insert(&mut self, batch: RecordBatch) -> Result<RecordIndex> {
+        assert_eq!(
+            self.schema,
+            batch.schema(),
+            "Trying to insert a RecordBatch into a RecordBuffer with the incorrect schema"
+        );
+        assert!(
+            (self.inner.len() as u32) < u32::MAX - 1,
+            "Maximum size for a RecordBuffer is u32::MAX row groups"
+        );
+
+        let rows = self.record_batch_to_rows(batch)?;
+        self.inner.push(rows);
+
+        Ok(RecordIndex {
+            index: (self.inner.len() - 1) as u32,
+            row: 0,
+        })
+    }
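+
+    // Illustrative use of the index returned by `insert` (names below are
+    // assumptions, not part of this change): the returned index points at
+    // row 0 of the new row group, and `with_row` re-targets individual rows.
+    //
+    //     let base = buffer.insert(batch)?;  // row group, pointing at row 0
+    //     let third = base.with_row(2);      // same row group, row 2
+    //     let tuple = buffer.get(third);     // Option<Row<'_>>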
+
+    /// Retrieve the row group and row number associated with the RecordIndex
+    pub fn get_group(&self, index: RecordIndex) -> Option<(&Rows, u32)> {
+        if (index.index as usize) >= self.inner.len() {
+            return None;
+        }
+
+        Some((&self.inner[index.index as usize], index.row))
+    }
+
+    /// Retrieve the row / tuple associated with the RecordIndex
+    pub fn get(&self, index: RecordIndex) -> Option<Row<'_>> {
+        if (index.index as usize) >= self.inner.len() {
+            return None;
+        }
+
+        Some(self.inner[index.index as usize].row(index.row as usize))
+    }
+}
diff --git a/eggstrain/src/execution/record_table.rs b/eggstrain/src/execution/record_table.rs
new file mode 100644
index 0000000..540fd84
--- /dev/null
+++ b/eggstrain/src/execution/record_table.rs
@@ -0,0 +1,56 @@
+use super::record_buffer::{RecordBuffer, RecordIndex};
+use arrow::row::RowConverter;
+use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
+use datafusion_common::Result;
+use std::collections::HashMap; // TODO replace with a raw table
+
+pub struct RecordTable {
+    /// Maps a hash value to `RecordIndex`es into the `RecordBuffer`
+    inner: HashMap<u64, Vec<RecordIndex>>,
+    pub(crate) buffer: RecordBuffer,
+}
+
+impl RecordTable {
+    pub fn new(schema: SchemaRef) -> Self {
+        Self {
+            buffer: RecordBuffer::new(schema),
+            inner: HashMap::new(),
+        }
+    }
+
+    pub fn with_capacity(schema: SchemaRef, map_capacity: usize, buffer_capacity: usize) -> Self {
+        Self {
+            buffer: RecordBuffer::with_capacity(schema, buffer_capacity),
+            inner: HashMap::with_capacity(map_capacity),
+        }
+    }
+
+    pub fn converter(&self) -> &RowConverter {
+        self.buffer.converter()
+    }
+
+    pub fn insert_batch(&mut self, batch: RecordBatch, hashes: Vec<u64>) -> Result<()> {
+        assert_eq!(
+            batch.num_rows(),
+            hashes.len(),
+            "There should be an equal number of batch rows and hashed values"
+        );
+
+        // Points to the location of the base of the record batch
+        let base_record_id = self.buffer.insert(batch)?;
+
+        for (row, &hash) in hashes.iter().enumerate() {
+            // Insert the record into the hashtable bucket
+            self.inner
+                .entry(hash)
+                .or_default()
+                .push(base_record_id.with_row(row as u32))
+        }
+
+        Ok(())
+    }
+
+    pub fn get_record_indices(&self, hash: u64) -> Option<&Vec<RecordIndex>> {
+        self.inner.get(&hash)
+    }
+}
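+
+// Usage sketch (illustrative only; `schema`, `batch`, `hashes`, and
+// `some_hash` are assumed to be in scope):
+//
+//     let mut table = RecordTable::new(schema);
+//     table.insert_batch(batch, hashes)?; // one RecordIndex per row
+//     if let Some(indices) = table.get_record_indices(some_hash) {
+//         for &idx in indices {
+//             let tuple = table.buffer.get(idx); // a matched build-side row
+//         }
+//     }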