_repr_ and _html_repr_ show '... and additional rows' message #1041

Open · wants to merge 1 commit into base: main
95 changes: 77 additions & 18 deletions src/dataframe.rs
@@ -90,59 +90,108 @@ impl PyDataFrame {
}

     fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
-        let df = self.df.as_ref().clone().limit(0, Some(10))?;
+        // Get 11 rows to check if there are more than 10
+        let df = self.df.as_ref().clone().limit(0, Some(11))?;
         let batches = wait_for_future(py, df.collect())?;
-        let batches_as_string = pretty::pretty_format_batches(&batches);
+        let num_rows = batches.iter().map(|batch| batch.num_rows()).sum::<usize>();
+
+        // Flatten batches into a single batch for the first 10 rows
+        let mut all_rows = Vec::new();
+        let mut total_rows = 0;
+
+        for batch in &batches {
+            let num_rows_to_take = if total_rows + batch.num_rows() > 10 {
+                10 - total_rows
+            } else {
+                batch.num_rows()
+            };
+
+            if num_rows_to_take > 0 {
+                let sliced_batch = batch.slice(0, num_rows_to_take);
+                all_rows.push(sliced_batch);
+                total_rows += num_rows_to_take;
+            }
+
+            if total_rows >= 10 {
+                break;
+            }
+        }
+
+        let batches_as_string = pretty::pretty_format_batches(&all_rows);
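The new flow, fetch 11 rows but render at most 10 and use the extra row to decide whether a truncation notice is needed, can be sketched in Python (plain lists stand in for Arrow RecordBatches; `preview_batches` is a hypothetical helper, not the project's API):

```python
PREVIEW_ROWS = 10

def preview_batches(batches, preview_rows=PREVIEW_ROWS):
    """Take at most preview_rows rows across batches and report whether more
    rows exist. Assumes the caller fetched preview_rows + 1 rows, mirroring
    limit(0, Some(11)) in the Rust code."""
    shown, taken = [], 0
    for batch in batches:
        take = min(len(batch), preview_rows - taken)
        if take > 0:
            shown.append(batch[:take])  # batch.slice(0, take) in Arrow
            taken += take
        if taken >= preview_rows:
            break
    has_more = sum(len(b) for b in batches) > preview_rows
    return shown, has_more

# Two batches totalling 15 rows: only 10 rows survive, and has_more is set.
shown, has_more = preview_batches([list(range(7)), list(range(7, 15))])
```

Slicing the already-collected batches this way is what avoids the second `collect` discussed below.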

Comment on lines +93 to +121 (@kosiew, Contributor, Mar 4, 2025):
You can simplify batches_as_string to:

        // First get just the first 10 rows
        let preview_df = self.df.as_ref().clone().limit(0, Some(10))?;
        let preview_batches = wait_for_future(py, preview_df.collect())?;

        // Check if there are more rows by trying to get the 11th row
        let has_more_rows = {
            let check_df = self.df.as_ref().clone().limit(10, Some(1))?;
            let check_batch = wait_for_future(py, check_df.collect())?;
            !check_batch.is_empty()
        };

        let batches_as_string = pretty::pretty_format_batches(&preview_batches);

This directly retrieves just the first 10 rows, eliminating the need for manual row tracking and slicing.

The PR author replied:
I did try this initially, but calling collect twice led to a severe performance degradation. It used to take 50 ms; with the manual slicing, it dropped to 5 ms.

You can check my initial suggestion for the same here

@kosiew (Contributor) replied:

> calling collect twice led to a severe performance degradation

I ran this test to compare the performance:

import pyarrow as pa
from datafusion import SessionContext
import time


def run_dataframe_repr_long() -> None:
    ctx = SessionContext()
    # Create a DataFrame with more than 10 rows
    batch = pa.RecordBatch.from_arrays(
        [
            pa.array(list(range(15))),
            pa.array([x * 2 for x in range(15)]),
            pa.array([x * 3 for x in range(15)]),
        ],
        names=["a", "b", "c"],
    )
    df = ctx.create_dataframe([[batch]])

    output = repr(df)


def average_runtime(func, runs=100):
    total_time = 0
    for _ in range(runs):
        start_time = time.time()
        func()
        end_time = time.time()
        total_time += end_time - start_time
    return total_time / runs


average_time = average_runtime(run_dataframe_repr_long)
print(f"Average runtime over {100} runs: {average_time:.6f} seconds")

and found no significant difference:

[benchmark results screenshot]

pr_1041 is the branch with one collect; amended_pr_1041 is the branch with two collects.
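As an aside, the manual timing loop in the benchmark above can be replaced with the standard-library timeit module, which uses a monotonic high-resolution clock; a sketch with a stand-in workload (not part of the PR):

```python
import timeit

def average_runtime(func, runs=100):
    """Average seconds per call over `runs` invocations of func."""
    return timeit.timeit(func, number=runs) / runs

# Stand-in workload in place of run_dataframe_repr_long:
avg = average_runtime(lambda: sum(x * x for x in range(1_000)))
print(f"Average runtime over 100 runs: {avg:.6f} seconds")
```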

The PR author replied:

That's weird. Maybe it is some artifact of my system settings? If there are no performance issues, then I'll use your approach. But then why was _repr_html_ using batch manipulation in the first place? I took the idea from that function!

@kosiew (Contributor) replied on Mar 5, 2025:

Hi @Spaarsh,

Sorry, in my previous test I overlooked running maturin develop for the Rust changes.

In my retests, two collects do take about 53% (1935/1265) longer.

[benchmark results screenshot]

The PR author replied:

Oh, no issues. Thanks for corroborating my findings, by the way!

         match batches_as_string {
-            Ok(batch) => Ok(format!("DataFrame()\n{batch}")),
+            Ok(batch) => {
+                if num_rows > 10 {
@kosiew (Contributor) commented: using has_more_rows from above:

+                if has_more_rows {
+                    Ok(format!("DataFrame()\n{batch}\n... and additional rows"))
+                } else {
+                    Ok(format!("DataFrame()\n{batch}"))
+                }
+            }
             Err(err) => Ok(format!("Error: {:?}", err.to_string())),
         }
     }



     fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
         let mut html_str = "<table border='1'>\n".to_string();

-        let df = self.df.as_ref().clone().limit(0, Some(10))?;
+        // Limit to the first 11 rows
+        let df = self.df.as_ref().clone().limit(0, Some(11))?;
         let batches = wait_for_future(py, df.collect())?;

         // If there are no rows, close the table and return
         if batches.is_empty() {
             html_str.push_str("</table>\n");
             return Ok(html_str);
         }

         // Get schema for headers
         let schema = batches[0].schema();

         let mut header = Vec::new();
         for field in schema.fields() {
-            header.push(format!("<th>{}</td>", field.name()));
+            header.push(format!("<th>{}</th>", field.name()));
         }
         let header_str = header.join("");
         html_str.push_str(&format!("<tr>{}</tr>\n", header_str));

-        for batch in batches {
+        // Flatten rows and format them as HTML
+        let mut total_rows = 0;
+        for batch in &batches {
+            total_rows += batch.num_rows();
             let formatters = batch
                 .columns()
                 .iter()
                 .map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default()))
-                .map(|c| {
-                    c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))
-                })
+                .map(|c| c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string()))))
                 .collect::<Result<Vec<_>, _>>()?;

-            for row in 0..batch.num_rows() {
+            let num_rows_to_render = if total_rows > 10 { 10 } else { batch.num_rows() };
+
+            for row in 0..num_rows_to_render {
                 let mut cells = Vec::new();
                 for formatter in &formatters {
                     cells.push(format!("<td>{}</td>", formatter.value(row)));
                 }
                 let row_str = cells.join("");
                 html_str.push_str(&format!("<tr>{}</tr>\n", row_str));
             }
-        }
+
+            if total_rows >= 10 {
+                break;
+            }
Comment on lines +162 to +183 (@kosiew, Contributor):
How about simplifying to:

            let rows_remaining = 10 - total_rows;
            let rows_in_batch = batch.num_rows().min(rows_remaining);

            for row in 0..rows_in_batch {
                html_str.push_str("<tr>");
                for col in batch.columns() {
                    let formatter =
                        ArrayFormatter::try_new(col.as_ref(), &FormatOptions::default())?;
                    html_str.push_str("<td>");
                    html_str.push_str(&formatter.value(row).to_string());
                    html_str.push_str("</td>");
                }
                html_str.push_str("</tr>\n");
            }

            total_rows += rows_in_batch;

Reasons:

1. More accurate row limiting:
   - Before: total_rows was updated before checking the row limit, which could result in processing extra rows unnecessarily.
   - After: rows_remaining = 10 - total_rows ensures that we never exceed the row limit.

2. Avoids redundant Vec allocation:
   - Before: each row was constructed using a Vec, and format!() was used for each cell.
   - After: elements are appended directly to html_str, eliminating unnecessary heap allocations.

3. Simplified and more efficient row processing:
   - Before: .map() and .collect() created a list of ArrayFormatters before processing rows.
   - After: values are retrieved and formatted inside the loop, reducing redundant processing.

4. Avoids an unnecessary break condition:
   - Before: an explicit if total_rows >= 10 { break; } was used to stop processing.
   - After: the min(rows_remaining, batch.num_rows()) logic naturally prevents extra iterations.

+        }
+
+        if total_rows > 10 {
+            html_str.push_str("<tr><td colspan=\"100%\">... and additional rows</td></tr>\n");
+        }

         html_str.push_str("</table>\n");

         Ok(html_str)
     }
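The reviewer's rows_remaining / min() arithmetic can be sketched in Python to show how it caps output without an explicit break inside the row loop (lists of row-tuples stand in for RecordBatches; `render_rows` is a hypothetical helper, not the project's API):

```python
def render_rows(batches, row_limit=10):
    """Build HTML <tr> rows, never exceeding row_limit data rows in total."""
    html, total_rows = [], 0
    for batch in batches:
        rows_remaining = row_limit - total_rows
        rows_in_batch = min(len(batch), rows_remaining)
        for row in batch[:rows_in_batch]:
            html.append("<tr>" + "".join(f"<td>{v}</td>" for v in row) + "</tr>")
        total_rows += rows_in_batch
    # Truncation notice, matching the PR's "... and additional rows" row
    if sum(len(b) for b in batches) > row_limit:
        html.append('<tr><td colspan="100%">... and additional rows</td></tr>')
    return "\n".join(html)

# 15 single-column rows across two batches: 10 data rows plus the notice.
out = render_rows([[(i,) for i in range(8)], [(i,) for i in range(8, 15)]])
```

Because rows_in_batch can never exceed rows_remaining, the cap holds even when a batch boundary falls mid-limit, which is the edge case the original total_rows check mishandled.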


     /// Calculate summary statistics for a DataFrame
     fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
@@ -436,6 +485,16 @@ impl PyDataFrame {
         Ok(Self::new(df))
     }

+    // Add column name handling that removes "?table?" prefix
+    fn format_column_name(&self, name: &str) -> String {
+        // Strip ?table? prefix if present
+        if name.starts_with("?table?.") {
+            name.trim_start_matches("?table?.").to_string()
+        } else {
+            name.to_string()
+        }
+    }
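The prefix stripping added above can be sketched as a standalone Python function (a hypothetical free-function mirror of the method, not the project's API):

```python
def format_column_name(name: str) -> str:
    """Strip the "?table?." qualifier from a column name if present;
    otherwise return the name unchanged."""
    prefix = "?table?."
    return name[len(prefix):] if name.startswith(prefix) else name
```

Note one difference from the Rust code: trim_start_matches removes the pattern repeatedly, so "?table?.?table?.c" would lose both prefixes there, while this sketch removes only the first.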

     /// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema
     fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
         let new_df = self