-
Notifications
You must be signed in to change notification settings - Fork 101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
_repr_ and _html_repr_ show '... and additional rows' message #1041
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -90,59 +90,108 @@ impl PyDataFrame { | |
} | ||
|
||
fn __repr__(&self, py: Python) -> PyDataFusionResult<String> { | ||
let df = self.df.as_ref().clone().limit(0, Some(10))?; | ||
// Get 11 rows to check if there are more than 10 | ||
let df = self.df.as_ref().clone().limit(0, Some(11))?; | ||
let batches = wait_for_future(py, df.collect())?; | ||
let batches_as_string = pretty::pretty_format_batches(&batches); | ||
let num_rows = batches.iter().map(|batch| batch.num_rows()).sum::<usize>(); | ||
|
||
// Flatten batches into a single batch for the first 10 rows | ||
let mut all_rows = Vec::new(); | ||
let mut total_rows = 0; | ||
|
||
for batch in &batches { | ||
let num_rows_to_take = if total_rows + batch.num_rows() > 10 { | ||
10 - total_rows | ||
} else { | ||
batch.num_rows() | ||
}; | ||
|
||
if num_rows_to_take > 0 { | ||
let sliced_batch = batch.slice(0, num_rows_to_take); | ||
all_rows.push(sliced_batch); | ||
total_rows += num_rows_to_take; | ||
} | ||
|
||
if total_rows >= 10 { | ||
break; | ||
} | ||
} | ||
|
||
let batches_as_string = pretty::pretty_format_batches(&all_rows); | ||
|
||
match batches_as_string { | ||
Ok(batch) => Ok(format!("DataFrame()\n{batch}")), | ||
Ok(batch) => { | ||
if num_rows > 10 { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. using + if has_more_rows { |
||
Ok(format!("DataFrame()\n{batch}\n... and additional rows")) | ||
} else { | ||
Ok(format!("DataFrame()\n{batch}")) | ||
} | ||
} | ||
Err(err) => Ok(format!("Error: {:?}", err.to_string())), | ||
} | ||
} | ||
|
||
|
||
|
||
fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> { | ||
let mut html_str = "<table border='1'>\n".to_string(); | ||
|
||
let df = self.df.as_ref().clone().limit(0, Some(10))?; | ||
|
||
// Limit to the first 11 rows | ||
let df = self.df.as_ref().clone().limit(0, Some(11))?; | ||
let batches = wait_for_future(py, df.collect())?; | ||
|
||
|
||
// If there are no rows, close the table and return | ||
if batches.is_empty() { | ||
html_str.push_str("</table>\n"); | ||
return Ok(html_str); | ||
} | ||
|
||
|
||
// Get schema for headers | ||
let schema = batches[0].schema(); | ||
|
||
let mut header = Vec::new(); | ||
for field in schema.fields() { | ||
header.push(format!("<th>{}</td>", field.name())); | ||
header.push(format!("<th>{}</th>", field.name())); | ||
} | ||
let header_str = header.join(""); | ||
html_str.push_str(&format!("<tr>{}</tr>\n", header_str)); | ||
|
||
for batch in batches { | ||
|
||
// Flatten rows and format them as HTML | ||
let mut total_rows = 0; | ||
for batch in &batches { | ||
total_rows += batch.num_rows(); | ||
let formatters = batch | ||
.columns() | ||
.iter() | ||
.map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default())) | ||
.map(|c| { | ||
c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string()))) | ||
}) | ||
.map(|c| c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))) | ||
.collect::<Result<Vec<_>, _>>()?; | ||
|
||
for row in 0..batch.num_rows() { | ||
|
||
let num_rows_to_render = if total_rows > 10 { 10 } else { batch.num_rows() }; | ||
|
||
for row in 0..num_rows_to_render { | ||
let mut cells = Vec::new(); | ||
for formatter in &formatters { | ||
cells.push(format!("<td>{}</td>", formatter.value(row))); | ||
} | ||
let row_str = cells.join(""); | ||
html_str.push_str(&format!("<tr>{}</tr>\n", row_str)); | ||
} | ||
} | ||
|
||
if total_rows >= 10 { | ||
break; | ||
} | ||
Comment on lines
+162
to
+183
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about simplifying to: let rows_remaining = 10 - total_rows;
let rows_in_batch = batch.num_rows().min(rows_remaining);
for row in 0..rows_in_batch {
html_str.push_str("<tr>");
for col in batch.columns() {
let formatter =
ArrayFormatter::try_new(col.as_ref(), &FormatOptions::default())?;
html_str.push_str("<td>");
html_str.push_str(&formatter.value(row).to_string());
html_str.push_str("</td>");
}
html_str.push_str("</tr>\n");
}
total_rows += rows_in_batch; Reasons:
Before: total_rows was updated before checking the row limit, which could result in processing extra rows unnecessarily.
Before: Each row was constructed using a Vec, and format!() was used for each cell.
Before:
Before: Explicit if total_rows >= 10 { break; } was used to stop processing. |
||
} | ||
|
||
if total_rows > 10 { | ||
html_str.push_str("<tr><td colspan=\"100%\">... and additional rows</td></tr>\n"); | ||
} | ||
|
||
html_str.push_str("</table>\n"); | ||
|
||
Ok(html_str) | ||
} | ||
|
||
|
||
/// Calculate summary statistics for a DataFrame | ||
fn describe(&self, py: Python) -> PyDataFusionResult<Self> { | ||
|
@@ -436,6 +485,16 @@ impl PyDataFrame { | |
Ok(Self::new(df)) | ||
} | ||
|
||
// Add column name handling that removes "?table?" prefix | ||
fn format_column_name(&self, name: &str) -> String { | ||
// Strip ?table? prefix if present | ||
if name.starts_with("?table?.") { | ||
name.trim_start_matches("?table?.").to_string() | ||
} else { | ||
name.to_string() | ||
} | ||
} | ||
|
||
/// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema | ||
fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> { | ||
let new_df = self | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can simplify
batches_as_string
to:This directly retrieves just the first 10 rows, eliminating the need for manual row tracking and slicing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did try this initiatially but calling
collect
twice led to a severe performance degradation. It used to take50ms
. With the manual slicing, it dropped to5ms
.You can check my initial suggestion for the same here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I ran this test to compare the performance:
and found no significant difference:
pr_1041 - is the branch with one
collect
amended_pr_1041 - is the branch with two
collect
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's weird. Maybe some artifact of my system settings? If there is no performance issues than I'll use your approach. But then why was the
_repr_html_
using batch manipulation at the first place? I took the idea from that function!There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hi @Spaarsh
Sorry, in my previous test, I overlooked to
maturin develop
for the Rust changes.In my retests, two collects does take about 53% (1935/1265) longer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh no issues. Thanks for corroborating my findings btw!