Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add strict parameter to pl.concat(how='horizontal') #20019

Draft
wants to merge 22 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
6038130
added 'strict' as a keyword argument (default: False) to pl.concat an…
nimit Nov 27, 2024
d14049c
bug fix with keyword argument change in pl.concat
nimit Nov 27, 2024
ae2a335
Merge branch 'pola-rs:main' into strict-concat-19133
nimit Dec 1, 2024
6a0df66
rust changes
nimit Dec 1, 2024
ea847bf
fixed python errors with concat_df_horizontal, concat_lf_horizontal
nimit Dec 1, 2024
8f2883d
fixed exception type in python unit test for strict concatenation of …
nimit Dec 1, 2024
21d110a
build: Bump `chrono-tz` to `0.10` (#20094)
stinodego Dec 1, 2024
d85fc9c
chore(rust): Update AWS doc dependencies (#20095)
stinodego Dec 2, 2024
f00e6fd
docs(rust): Fix inconsistency between code and comment (#20070)
YichiZhang0613 Dec 2, 2024
a413c4a
fix: Only slice after sort when slice is smaller than frame length (#…
mcrumiller Dec 2, 2024
5a8cd16
fix: Return null instead of 0. for rolling_std when window contains a…
MarcoGorelli Dec 2, 2024
755ee47
build: Bump `atoi_simd` to version `0.16` (#20098)
stinodego Dec 2, 2024
af4d5a5
build: Bump `thiserror` to version `2` (#20097)
stinodego Dec 2, 2024
06b2ab6
build(rust): Fix path to `polars-dylib` crate in workspace (#20103)
stinodego Dec 2, 2024
511c219
build: Bump `fs4` to version `0.12` (#20101)
stinodego Dec 2, 2024
b276485
build: Bump `object_store` to version `0.11` (#20102)
stinodego Dec 2, 2024
bc3d320
build: Bump `memmap2` to version `0.9` (#20105)
stinodego Dec 2, 2024
8f72e20
refactor(rust): Replace custom `PushNode` trait with `Extend` (#20107)
nameexhaustion Dec 2, 2024
527af9c
build: Upgrade `sqlparser-rs` from version 0.49 to 0.52 (#20110)
alexander-beedie Dec 2, 2024
0d4029a
fixed LazyFrame test, Python & Rust documentation changes regarding a…
nimit Dec 2, 2024
90e65e8
Merge branch 'pola-rs:main' into strict-concat-19133
nimit Dec 2, 2024
6078964
doc fix, removed unnecessary generic
nimit Dec 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion crates/polars-core/src/frame/horizontal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,16 @@ impl DataFrame {
}
/// Concat [`DataFrame`]s horizontally.
/// Frames shorter than the tallest one are extended with null values, unless `strict` is set, in which case differing heights produce a `ShapeMismatch` error.
nimit marked this conversation as resolved.
Show resolved Hide resolved
pub fn concat_df_horizontal(dfs: &[DataFrame], check_duplicates: bool) -> PolarsResult<DataFrame> {
pub fn concat_df_horizontal<T>(
dfs: &[DataFrame],
check_duplicates: bool,
strict: T,
nimit marked this conversation as resolved.
Show resolved Hide resolved
) -> PolarsResult<DataFrame>
where
T: Into<Option<bool>>,
{
// Set this to true for v2.0.0 milestone
let strict = strict.into().unwrap_or(false);
let output_height = dfs
.iter()
.map(|df| df.height())
Expand All @@ -87,6 +96,11 @@ pub fn concat_df_horizontal(dfs: &[DataFrame], check_duplicates: bool) -> Polars

// if not all equal length, extend the DataFrame with nulls
nimit marked this conversation as resolved.
Show resolved Hide resolved
let dfs = if !dfs.iter().all(|df| df.height() == output_height) {
if strict {
return Err(
polars_err!(ShapeMismatch: "cannot concat dataframes with different heights in 'strict' mode"),
);
}
owned_df = dfs
.iter()
.cloned()
Expand Down
1 change: 1 addition & 0 deletions crates/polars-lazy/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ pub fn concat_lf_horizontal<L: AsRef<[LazyFrame]>>(

let options = HConcatOptions {
parallel: args.parallel,
strict: args.strict,
};
let lp = DslPlan::HConcat {
inputs: lfs.iter().map(|lf| lf.logical_plan.clone()).collect(),
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-mem-engine/src/executors/hconcat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ impl Executor for HConcatExec {
};

// Invariant of IR. Schema is already checked to contain no duplicates.
concat_df_horizontal(&dfs, false)
concat_df_horizontal(&dfs, false, self.options.strict)
}
}
4 changes: 4 additions & 0 deletions crates/polars-plan/src/plans/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ pub struct UnionOptions {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
/// Options attached to a horizontal concatenation (`HConcat`).
///
/// Built by `concat_lf_horizontal` from the caller-supplied `UnionArgs`.
pub struct HConcatOptions {
/// Copied from `UnionArgs::parallel`.
/// NOTE(review): presumably controls whether the inputs may be executed in
/// parallel — confirm against the executor.
pub parallel: bool,
/// Copied from `UnionArgs::strict` and forwarded to `concat_df_horizontal`.
/// When `true`, inputs of differing heights yield a `ShapeMismatch` error
/// instead of being extended with nulls.
pub strict: bool,
}

#[derive(Clone, Debug, PartialEq, Eq, Default, Hash)]
Expand Down Expand Up @@ -363,6 +364,7 @@ pub struct UnionArgs {
pub rechunk: bool,
pub to_supertypes: bool,
pub diagonal: bool,
pub strict: bool,
// If it is a union from a scan over multiple files.
pub from_partitioned_ds: bool,
}
Expand All @@ -375,6 +377,8 @@ impl Default for UnionArgs {
to_supertypes: false,
diagonal: false,
from_partitioned_ds: false,
// By default, strict should be true in v2.0.0
strict: false,
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-python/src/functions/eager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ pub fn concat_df_diagonal(dfs: &Bound<'_, PyAny>) -> PyResult<PyDataFrame> {
}

#[pyfunction]
pub fn concat_df_horizontal(dfs: &Bound<'_, PyAny>) -> PyResult<PyDataFrame> {
pub fn concat_df_horizontal(dfs: &Bound<'_, PyAny>, strict: bool) -> PyResult<PyDataFrame> {
let iter = dfs.iter()?;

let dfs = iter
Expand All @@ -88,6 +88,6 @@ pub fn concat_df_horizontal(dfs: &Bound<'_, PyAny>) -> PyResult<PyDataFrame> {
})
.collect::<PyResult<Vec<_>>>()?;

let df = functions::concat_df_horizontal(&dfs, true).map_err(PyPolarsErr::from)?;
let df = functions::concat_df_horizontal(&dfs, true, strict).map_err(PyPolarsErr::from)?;
Ok(df.into())
}
7 changes: 6 additions & 1 deletion crates/polars-python/src/functions/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,11 @@ pub fn concat_lf_diagonal(
}

#[pyfunction]
pub fn concat_lf_horizontal(lfs: &Bound<'_, PyAny>, parallel: bool) -> PyResult<PyLazyFrame> {
pub fn concat_lf_horizontal(
lfs: &Bound<'_, PyAny>,
parallel: bool,
strict: bool,
) -> PyResult<PyLazyFrame> {
let iter = lfs.iter()?;

let lfs = iter
Expand All @@ -339,6 +343,7 @@ pub fn concat_lf_horizontal(lfs: &Bound<'_, PyAny>, parallel: bool) -> PyResult<
rechunk: false, // No need to rechunk with horizontal concatenation
parallel,
to_supertypes: false,
strict,
..Default::default()
};
let lf = dsl::functions::concat_lf_horizontal(lfs, args).map_err(PyPolarsErr::from)?;
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-stream/src/nodes/zip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ impl ComputeNode for ZipNode {
for input_head in &mut self.input_heads {
out.push(input_head.take(common_size));
}
let out_df = concat_df_horizontal(&out, false)?;
let out_df = concat_df_horizontal(&out, false, None)?;
out.clear();

let morsel = Morsel::new(out_df, self.out_seq, source_token.clone());
Expand Down Expand Up @@ -320,7 +320,7 @@ impl ComputeNode for ZipNode {
for input_head in &mut self.input_heads {
out.push(input_head.consume_broadcast());
}
let out_df = concat_df_horizontal(&out, false)?;
let out_df = concat_df_horizontal(&out, false, None)?;
out.clear();

let morsel = Morsel::new(out_df, self.out_seq, source_token.clone());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
"r2"=> &[7, 8],
"r3"=> &[9, 10],
)?;
let df_horizontal_concat = polars::functions::concat_df_horizontal(&[df_h1, df_h2], true)?;
let df_horizontal_concat =
polars::functions::concat_df_horizontal(&[df_h1, df_h2], true, None)?;
println!("{}", &df_horizontal_concat);
// --8<-- [end:horizontal]
//
Expand All @@ -43,7 +44,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
"r1"=> &[5, 6, 7],
"r2"=> &[8, 9, 10],
)?;
let df_horizontal_concat = polars::functions::concat_df_horizontal(&[df_h1, df_h2], true)?;
let df_horizontal_concat =
polars::functions::concat_df_horizontal(&[df_h1, df_h2], true, None)?;
println!("{}", &df_horizontal_concat);
// --8<-- [end:horizontal_different_lengths]

Expand Down
6 changes: 5 additions & 1 deletion py-polars/polars/functions/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def concat(
how: ConcatMethod = "vertical",
rechunk: bool = False,
parallel: bool = True,
strict: bool = False,
) -> PolarsType:
"""
Combine multiple DataFrames, LazyFrames, or Series into a single object.
Expand Down Expand Up @@ -58,6 +59,8 @@ def concat(
parallel
Only relevant for LazyFrames. This determines if the concatenated
lazy computations may be executed in parallel.
strict
If True, raise an error when the DataFrames do not all have the same height. Only relevant when how=`horizontal`.
nimit marked this conversation as resolved.
Show resolved Hide resolved

Examples
--------
Expand Down Expand Up @@ -205,7 +208,7 @@ def concat(
)
).collect(no_optimization=True)
elif how == "horizontal":
out = wrap_df(plr.concat_df_horizontal(elems))
out = wrap_df(plr.concat_df_horizontal(elems, strict=strict))
else:
allowed = ", ".join(repr(m) for m in get_args(ConcatMethod))
msg = f"DataFrame `how` must be one of {{{allowed}}}, got {how!r}"
Expand Down Expand Up @@ -235,6 +238,7 @@ def concat(
plr.concat_lf_horizontal(
elems,
parallel=parallel,
strict=strict,
)
)
else:
Expand Down
19 changes: 19 additions & 0 deletions py-polars/tests/unit/functions/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,25 @@ def test_concat_lf_stack_overflow() -> None:
assert bar.collect().shape == (1001, 1)


def test_concat_horizontally_strict() -> None:
    """Horizontal concat rejects mismatched heights only when ``strict=True``."""
    a = pl.DataFrame({"a": [0, 1], "b": [1, 2]})
    b = pl.DataFrame({"c": [11], "d": [42]})

    # Eager path: the height check runs as part of the call itself.
    with pytest.raises(pl.exceptions.ShapeError):
        pl.concat([a, b], how="horizontal", strict=True)

    # Lazy path: concatenating LazyFrames only builds a query plan; the height
    # check happens when the plan is executed, so the query must be collected
    # for the error to surface.
    with pytest.raises(pl.exceptions.ShapeError):
        pl.concat([a.lazy(), b.lazy()], how="horizontal", strict=True).collect()

    # Non-strict mode pads the shorter frame with nulls.
    out = pl.concat([a, b], how="horizontal", strict=False)
    assert out.to_dict(as_series=False) == {
        "a": [0, 1],
        "b": [1, 2],
        "c": [11, None],
        "d": [42, None],
    }


def test_concat_vertically_relaxed() -> None:
a = pl.DataFrame(
data={"a": [1, 2, 3], "b": [True, False, None]},
Expand Down
Loading