diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index 384b17878..718ebf69d 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import os
+import re
from typing import Any
import pyarrow as pa
@@ -1245,13 +1246,17 @@ def add_with_parameter(df_internal, value: Any) -> DataFrame:
def test_dataframe_repr_html(df) -> None:
output = df._repr_html_()
- ref_html = """
-
a
b
c
-
1
4
8
-
2
5
5
-
3
6
8
-
- """
+ # Since we've added a fair bit of processing to the html output, lets just verify
+ # the values we are expecting in the table exist. Use regex and ignore everything
+ # between the
and
. We also don't want the closing > on the
+ # td and th segments because that is where the formatting data is written.
- # Ignore whitespace just to make this test look cleaner
- assert output.replace(" ", "") == ref_html.replace(" ", "")
+ headers = ["a", "b", "c"]
+ headers = [f"
" for inner in body_data for v in inner]
+ body_pattern = "(.*?)".join(body_lines)
+ assert len(re.findall(body_pattern, output, re.DOTALL)) == 1
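
To see why the relaxed patterns above still pin down the table contents, here is a minimal, self-contained sketch. The sample HTML is hypothetical and not the renderer's exact output (the real output also carries the expandable-cell markup):

```python
import re

# Hypothetical fragment of the new output: opening tags now carry attributes,
# so the test matches "<th"/"<td" plus a lazy wildcard instead of a literal tag.
sample = (
    "<table><thead><tr>"
    "<th style='background-color: #f2f2f2;'>a</th>"
    "<th style='background-color: #f2f2f2;'>b</th>"
    "</tr></thead><tbody>"
    "<tr><td style='padding: 8px;'>1</td><td style='padding: 8px;'>4</td></tr>"
    "</tbody></table>"
)

headers = [f"<th(.*?)>{v}" for v in ["a", "b"]]
header_pattern = "(.*?)".join(headers)

# re.DOTALL lets "." span newlines in the real multi-line output.
assert len(re.findall(header_pattern, sample, re.DOTALL)) == 1
```
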
diff --git a/src/dataframe.rs b/src/dataframe.rs
index 243e2e14f..be10b8c28 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -31,9 +31,11 @@ use datafusion::common::UnnestOptions;
use datafusion::config::{CsvOptions, TableParquetOptions};
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
use datafusion::datasource::TableProvider;
+use datafusion::error::DataFusionError;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
use datafusion::prelude::*;
+use futures::{StreamExt, TryStreamExt};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
@@ -70,6 +72,9 @@ impl PyTableProvider {
PyTable::new(table_provider)
}
}
+const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 * 1024 * 1024; // 2 MB
+const MIN_TABLE_ROWS_TO_DISPLAY: usize = 20;
+const MAX_LENGTH_CELL_WITHOUT_MINIMIZE: usize = 25;
/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
@@ -111,56 +116,151 @@ impl PyDataFrame {
}
     fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
-        let df = self.df.as_ref().clone().limit(0, Some(10))?;
-        let batches = wait_for_future(py, df.collect())?;
-        let batches_as_string = pretty::pretty_format_batches(&batches);
-        match batches_as_string {
-            Ok(batch) => Ok(format!("DataFrame()\n{batch}")),
-            Err(err) => Ok(format!("Error: {:?}", err.to_string())),
-        }
+        let (batches, has_more) = wait_for_future(
+            py,
+            collect_record_batches_to_display(self.df.as_ref().clone(), 10, 10),
+        )?;
+        if batches.is_empty() {
+            // This should not be reached, but do it for safety since we index into the vector below
+            return Ok("No data to display".to_string());
+        }
+
+        let batches_as_displ =
+            pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?;
+
+        let additional_str = match has_more {
+            true => "\nData truncated.",
+            false => "",
+        };
+
+        Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}"))
     }
 
     fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
-        let mut html_str = "<table border='1'>\n".to_string();
-
-        let df = self.df.as_ref().clone().limit(0, Some(10))?;
-        let batches = wait_for_future(py, df.collect())?;
-
-        if batches.is_empty() {
-            html_str.push_str("</table>\n");
-            return Ok(html_str);
-        }
+        let (batches, has_more) = wait_for_future(
+            py,
+            collect_record_batches_to_display(
+                self.df.as_ref().clone(),
+                MIN_TABLE_ROWS_TO_DISPLAY,
+                usize::MAX,
+            ),
+        )?;
+        if batches.is_empty() {
+            // This should not be reached, but do it for safety since we index into the vector below
+            return Ok("No data to display".to_string());
+        }
+
+        let table_uuid = uuid::Uuid::new_v4().to_string();
+
+        let mut html_str = "
+            <style>
+            .expandable-container { display: inline-block; max-width: 200px; }
+            .expandable { white-space: nowrap; overflow: hidden; text-overflow: ellipsis; display: block; }
+            .full-text { display: none; white-space: normal; }
+            .expand-btn { cursor: pointer; color: blue; text-decoration: underline; border: none; background: none; font-size: inherit; display: block; margin-top: 5px; }
+            </style>
+            <div style=\"width: 100%; max-width: 1000px; max-height: 300px; overflow: auto; border: 1px solid #ccc;\">
+            <table style=\"border-collapse: collapse; min-width: 100%\">
+            <thead>\n".to_string();
+
let schema = batches[0].schema();
let mut header = Vec::new();
for field in schema.fields() {
- header.push(format!("
{}", field.name()));
+ header.push(format!("
{}
", field.name()));
}
let header_str = header.join("");
- html_str.push_str(&format!("
{}
\n", header_str));
-
- for batch in batches {
- let formatters = batch
- .columns()
- .iter()
- .map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default()))
- .map(|c| {
- c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))
- })
- .collect::, _>>()?;
-
- for row in 0..batch.num_rows() {
+ html_str.push_str(&format!("
{}
\n", header_str));
+
+ let batch_formatters = batches
+ .iter()
+ .map(|batch| {
+ batch
+ .columns()
+ .iter()
+ .map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default()))
+ .map(|c| {
+ c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))
+ })
+                    .collect::<Result<Vec<_>, _>>()
+ })
+            .collect::<Result<Vec<_>, _>>()?;
+
+ let rows_per_batch = batches.iter().map(|batch| batch.num_rows());
+
+ // We need to build up row by row for html
+ let mut table_row = 0;
+ for (batch_formatter, num_rows_in_batch) in batch_formatters.iter().zip(rows_per_batch) {
+ for batch_row in 0..num_rows_in_batch {
+ table_row += 1;
let mut cells = Vec::new();
-                for formatter in &formatters {
-                    cells.push(format!("<td>{}</td>", formatter.value(row)));
+                for (col, formatter) in batch_formatter.iter().enumerate() {
+                    let cell_data = formatter.value(batch_row).to_string();
+                    // From testing, primitive data types do not typically get larger than 21 characters
+                    if cell_data.len() > MAX_LENGTH_CELL_WITHOUT_MINIMIZE {
+                        let short_cell_data = &cell_data[0..MAX_LENGTH_CELL_WITHOUT_MINIMIZE];
+                        cells.push(format!("
+                            <td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>
+                                <div class=\"expandable-container\">
+                                    <span class=\"expandable\" id=\"{table_uuid}-min-text-{table_row}-{col}\">{short_cell_data}</span>
+                                    <span class=\"full-text\" id=\"{table_uuid}-full-text-{table_row}-{col}\">{cell_data}</span>
+                                    <button class=\"expand-btn\" onclick=\"toggleDataFrameCellText('{table_uuid}',{table_row},{col})\">...</button>
+                                </div>
+                            </td>"));
+                    } else {
+                        cells.push(format!("<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>", formatter.value(batch_row)));
+                    }
}
let row_str = cells.join("");
html_str.push_str(&format!("
{}
\n", row_str));
}
}
+ html_str.push_str("
\n");
+
+ html_str.push_str("
+
+ ");
- html_str.push_str("\n");
+ if has_more {
+ html_str.push_str("Data truncated due to size.");
+ }
Ok(html_str)
}
@@ -771,3 +871,83 @@ fn record_batch_into_schema(
RecordBatch::try_new(schema, data_arrays)
}
+
+/// This is a helper function to collect record batches to display from executing a DataFrame.
+/// It additionally returns a bool, which indicates if there are more record batches available.
+/// We do this so we can determine if we should indicate to the user that the data has been
+/// truncated. This collects until we have achieved both of the following conditions:
+///
+/// - We have collected our minimum number of rows
+/// - We have reached our limit, either data size or maximum number of rows
+///
+/// Otherwise it will return when the stream is exhausted. If you want a specific number of
+/// rows, set min_rows == max_rows.
+async fn collect_record_batches_to_display(
+ df: DataFrame,
+ min_rows: usize,
+ max_rows: usize,
+) -> Result<(Vec<RecordBatch>, bool), DataFusionError> {
+ let partitioned_stream = df.execute_stream_partitioned().await?;
+ let mut stream = futures::stream::iter(partitioned_stream).flatten();
+ let mut size_estimate_so_far = 0;
+ let mut rows_so_far = 0;
+ let mut record_batches = Vec::default();
+ let mut has_more = false;
+
+ while (size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY && rows_so_far < max_rows)
+ || rows_so_far < min_rows
+ {
+ let mut rb = match stream.next().await {
+ None => {
+ break;
+ }
+ Some(Ok(r)) => r,
+ Some(Err(e)) => return Err(e),
+ };
+
+ let mut rows_in_rb = rb.num_rows();
+ if rows_in_rb > 0 {
+ size_estimate_so_far += rb.get_array_memory_size();
+
+ if size_estimate_so_far > MAX_TABLE_BYTES_TO_DISPLAY {
+ let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / size_estimate_so_far as f32;
+ let total_rows = rows_in_rb + rows_so_far;
+
+ let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize;
+ if reduced_row_num < min_rows {
+ reduced_row_num = min_rows.min(total_rows);
+ }
+
+ let limited_rows_this_rb = reduced_row_num - rows_so_far;
+ if limited_rows_this_rb < rows_in_rb {
+ rows_in_rb = limited_rows_this_rb;
+ rb = rb.slice(0, limited_rows_this_rb);
+ has_more = true;
+ }
+ }
+
+ if rows_in_rb + rows_so_far > max_rows {
+ rb = rb.slice(0, max_rows - rows_so_far);
+ has_more = true;
+ }
+
+ rows_so_far += rb.num_rows();
+ record_batches.push(rb);
+ }
+ }
+
+ if record_batches.is_empty() {
+ return Ok((Vec::default(), false));
+ }
+
+ if !has_more {
+ // Data was not already truncated, so check to see if more record batches remain
+ has_more = match stream.try_next().await {
+ Ok(None) => false, // reached end
+ Ok(Some(_)) => true,
+ Err(_) => false, // Stream disconnected
+ };
+ }
+
+ Ok((record_batches, has_more))
+}
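
The size-based truncation above scales the number of retained rows by how far the memory estimate overshoots the 2 MB budget, flooring at `min_rows`. A rough Python mirror of that arithmetic, for illustration only (the function name and shape are ours, not part of the patch):

```python
MAX_TABLE_BYTES_TO_DISPLAY = 2 * 1024 * 1024  # same 2 MB budget as the Rust constant

def rows_to_keep(rows_so_far: int, rows_in_rb: int,
                 size_estimate_so_far: int, min_rows: int) -> int:
    """How many rows of the current record batch survive the size cap."""
    if size_estimate_so_far <= MAX_TABLE_BYTES_TO_DISPLAY:
        return rows_in_rb  # under budget: keep the whole batch
    # Scale the total row count down in proportion to the overshoot...
    ratio = MAX_TABLE_BYTES_TO_DISPLAY / size_estimate_so_far
    total_rows = rows_in_rb + rows_so_far
    reduced = round(total_rows * ratio)
    # ...but never drop below the minimum the caller asked to display.
    if reduced < min_rows:
        reduced = min(min_rows, total_rows)
    return min(max(reduced - rows_so_far, 0), rows_in_rb)

# One 4 MB batch of 1000 rows blows the 2 MB budget: roughly half survives.
print(rows_to_keep(0, 1000, 4 * 1024 * 1024, 20))  # -> 500
```

This is also why `__repr__` passes `(10, 10)` (exactly ten rows) while `_repr_html_` passes `(MIN_TABLE_ROWS_TO_DISPLAY, usize::MAX)` and lets the byte budget decide the upper bound.
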
diff --git a/src/utils.rs b/src/utils.rs
index 999aad755..3487de21b 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -42,7 +42,7 @@ pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
#[inline]
pub(crate) fn get_global_ctx() -> &'static SessionContext {
     static CTX: OnceLock<SessionContext> = OnceLock::new();
- CTX.get_or_init(|| SessionContext::new())
+ CTX.get_or_init(SessionContext::new)
}
/// Utility to collect rust futures with GIL released