From ee0e93c1b696d33e77200655fb4da3d5ea9a5fdd Mon Sep 17 00:00:00 2001
From: Folyd
Date: Thu, 19 Oct 2023 22:47:53 +0800
Subject: [PATCH] Add `Fields::remove()`, `Schema::remove_field()`, and `RecordBatch::remove_column()` APIs

---
 arrow-array/src/record_batch.rs | 46 ++++++++++++++++++++++++++++++++++++++++++++++
 arrow-schema/src/fields.rs      | 27 +++++++++++++++++++++++++++
 arrow-schema/src/schema.rs      | 20 ++++++++++++++++++++
 3 files changed, 93 insertions(+)

diff --git a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs
index 27804447fba6..88d596afdd33 100644
--- a/arrow-array/src/record_batch.rs
+++ b/arrow-array/src/record_batch.rs
@@ -20,6 +20,7 @@
 
 use crate::{new_empty_array, Array, ArrayRef, StructArray};
 use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaBuilder, SchemaRef};
+use std::mem;
 use std::ops::Index;
 use std::sync::Arc;
 
@@ -334,6 +335,51 @@ impl RecordBatch {
         &self.columns[..]
     }
 
+    /// Remove column by index and return it.
+    ///
+    /// Return `Some(ArrayRef)` if the column is removed, otherwise return `None`.
+    /// - Return `None` if the `index` is out of bounds
+    /// - Return `None` if the `index` is in bounds but the schema is shared (i.e. ref count > 1)
+    ///
+    /// ```
+    /// use std::sync::Arc;
+    /// use arrow_array::{BooleanArray, Int32Array, RecordBatch};
+    /// use arrow_schema::{DataType, Field, Schema};
+    /// let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]);
+    /// let bool_array = BooleanArray::from(vec![true, false, false, true, true]);
+    /// let schema = Schema::new(vec![
+    ///     Field::new("id", DataType::Int32, false),
+    ///     Field::new("bool", DataType::Boolean, false),
+    /// ]);
+    ///
+    /// let mut batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id_array), Arc::new(bool_array)]).unwrap();
+    ///
+    /// let removed_column = batch.remove_column(0).unwrap();
+    /// assert_eq!(removed_column.as_any().downcast_ref::<Int32Array>().unwrap(), &Int32Array::from(vec![1, 2, 3, 4, 5]));
+    /// assert_eq!(batch.num_columns(), 1);
+    /// ```
+    pub fn remove_column(&mut self, index: usize) -> Option<ArrayRef> {
+        if index >= self.num_columns() {
+            return None;
+        }
+
+        // Take the schema out so sole ownership is checked and claimed in one
+        // atomic step; `try_unwrap` hands the `Arc` back if it is still shared.
+        let taken = mem::replace(&mut self.schema, Arc::new(Schema::empty()));
+        match Arc::try_unwrap(taken) {
+            Ok(mut schema) => {
+                schema.fields.remove(index);
+                self.schema = Arc::new(schema);
+                Some(self.columns.remove(index))
+            }
+            Err(shared) => {
+                // Restore the untouched schema; the batch is left unchanged.
+                self.schema = shared;
+                None
+            }
+        }
+    }
+
     /// Return a new RecordBatch where each column is sliced
     /// according to `offset` and `length`
     ///
diff --git a/arrow-schema/src/fields.rs b/arrow-schema/src/fields.rs
index 07e9abeee56a..5793e8c2096d 100644
--- a/arrow-schema/src/fields.rs
+++ b/arrow-schema/src/fields.rs
@@ -83,6 +83,33 @@ impl Fields {
                 .zip(other.iter())
                 .all(|(a, b)| Arc::ptr_eq(a, b) || a.contains(b))
     }
+
+    /// Remove a field by index and return it.
+    ///
+    /// Returns `None` if `index` is out of bounds.
+    ///
+    /// ```
+    /// use arrow_schema::{DataType, Field, Fields};
+    /// let mut fields = Fields::from(vec![
+    ///     Field::new("a", DataType::Boolean, false),
+    ///     Field::new("b", DataType::Int8, false),
+    ///     Field::new("c", DataType::Utf8, false),
+    /// ]);
+    /// assert_eq!(fields.len(), 3);
+    /// assert_eq!(fields.remove(1).unwrap(), Field::new("b", DataType::Int8, false).into());
+    /// assert_eq!(fields.len(), 2);
+    /// ```
+    pub fn remove(&mut self, index: usize) -> Option<FieldRef> {
+        if index >= self.len() {
+            return None;
+        }
+
+        // Copy into a `Vec`, drop the entry, then install a freshly built `Arc`.
+        let mut new_fields = self.0.to_vec();
+        let field = new_fields.remove(index);
+        self.0 = Arc::from(new_fields);
+        Some(field)
+    }
 }
 
 impl Default for Fields {
diff --git a/arrow-schema/src/schema.rs b/arrow-schema/src/schema.rs
index a00e8a588757..61bc1cf227cb 100644
--- a/arrow-schema/src/schema.rs
+++ b/arrow-schema/src/schema.rs
@@ -327,6 +327,26 @@ impl Schema {
                 self.metadata.get(k).map(|v2| v1 == v2).unwrap_or_default()
             })
     }
+
+    /// Remove field by name and return it.
+    ///
+    /// Returns `None` if no field with the given `name` is found.
+    ///
+    /// ```
+    /// use arrow_schema::{DataType, Field, Schema};
+    /// let mut schema = Schema::new(vec![
+    ///     Field::new("a", DataType::Boolean, false),
+    ///     Field::new("b", DataType::Int8, false),
+    ///     Field::new("c", DataType::Utf8, false),
+    /// ]);
+    /// assert_eq!(schema.fields.len(), 3);
+    /// assert_eq!(schema.remove_field("b").unwrap(), Field::new("b", DataType::Int8, false).into());
+    /// assert_eq!(schema.fields.len(), 2);
+    /// ```
+    pub fn remove_field(&mut self, name: &str) -> Option<FieldRef> {
+        let (idx, _) = self.fields.find(name)?;
+        self.fields.remove(idx)
+    }
 }
 
 impl fmt::Display for Schema {