Skip to content

Commit

Permalink
feat: add cast to DataFrame (#916)
Browse files Browse the repository at this point in the history
* feat: add with_columns

* feat: add top level cast

* chore: improve docstring

---------

Co-authored-by: Tim Saucer <[email protected]>
  • Loading branch information
ion-elgreco and timsaucer authored Oct 21, 2024
1 parent 7cca028 commit 70c099a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
13 changes: 13 additions & 0 deletions python/datafusion/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from __future__ import annotations


from typing import Any, Iterable, List, Literal, TYPE_CHECKING
from datafusion.record_batch import RecordBatchStream
from typing_extensions import deprecated
Expand Down Expand Up @@ -267,6 +268,18 @@ def sort(self, *exprs: Expr | SortExpr) -> DataFrame:
exprs_raw = [sort_or_default(expr) for expr in exprs]
return DataFrame(self.df.sort(*exprs_raw))

def cast(self, mapping: dict[str, pa.DataType[Any]]) -> DataFrame:
    """Cast one or more columns to a different data type.

    Args:
        mapping: Dictionary keyed by column name, whose values are the data
            types the corresponding columns should be cast to.

    Returns:
        DataFrame after casting columns.
    """
    cast_exprs = []
    for column_name, target_type in mapping.items():
        # One cast expression is built per entry of the mapping.
        cast_exprs.append(Expr.column(column_name).cast(target_type))
    # The expressions are handed to ``with_columns`` as a single list argument.
    return self.with_columns(cast_exprs)

def limit(self, count: int, offset: int = 0) -> DataFrame:
"""Return a new :py:class:`DataFrame` with a limited number of rows.
Expand Down
9 changes: 9 additions & 0 deletions python/tests/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,15 @@ def test_with_columns(df):
assert result.column(6) == pa.array([5, 7, 9])


def test_cast(df):
    """Check the schema produced by casting columns "a" and "b".

    The resulting schema is expected to list "a" as float16, "b" as a list of
    uint32, and "c" as int64.
    """
    casted = df.cast({"a": pa.float16(), "b": pa.list_(pa.uint32())})

    expected_fields = [
        ("a", pa.float16()),
        ("b", pa.list_(pa.uint32())),
        ("c", pa.int64()),
    ]
    expected_schema = pa.schema(expected_fields)

    assert casted.schema() == expected_schema


def test_with_column_renamed(df):
df = df.with_column("c", column("a") + column("b")).with_column_renamed("c", "sum")

Expand Down

0 comments on commit 70c099a

Please sign in to comment.