Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding tidyselect i.e. select using a mix of string colnames and predefined predicates #40

Merged
merged 1 commit into from
Aug 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/tidypandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,8 @@
from tidypandas import series_utils
from tidypandas import tidy_utils
from tidypandas import tidy_accessor
from tidypandas.tidyselect import (
starts_with,
ends_with,
contains
)
49 changes: 49 additions & 0 deletions src/tidypandas/_unexported_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pandas as pd
import inspect
import pandas.api.types as dtypes
from tidypandas import tidyselect
import warnings

def _is_kwargable(func):
Expand Down Expand Up @@ -58,6 +59,43 @@ def _is_string_or_string_list(x):
res = False

return res

def _is_tidyselect_compatible(x):
'''
_is_tidyselect_compatible(x)

Check whether the input is a string or a tidyselect closure or a list of strings and/or tidyselect closures

Parameters
----------
x : object
Any python object

Returns
-------
bool
True if input is a string or a list of strings

Examples
--------
>>> _is_tidyselect_compatible("bar") # True
>>> _is_tidyselect_compatible(["bar"]) # True
>>> _is_tidyselect_compatible(("bar",)) # False
>>> _is_tidyselect_compatible(["bar", 1]) # False
>>> _is_tidyselect_compatible(["bar", ends_with("_mean")]) # True
'''
res = False
if isinstance(x, str):
res = True
elif callable(x) and hasattr(tidyselect, x.__name__.lstrip("_").rstrip("_")):
res = True
elif isinstance(x, list) and len(x) >= 1:
if all([isinstance(i, str) or (callable(i) and hasattr(tidyselect, i.__name__.lstrip("_").rstrip("_"))) for i in x]):
res = True
else:
res = False

return res

def _enlist(x):
'''
Expand Down Expand Up @@ -243,6 +281,17 @@ def _flatten_strings(x):
res = res.union(set(ele))
return list(res)

def _flatten_list(x):
res = []
for ele in list(x):
if isinstance(ele, list):
res += ele
elif _is_nested(ele):
res += _flatten_list(ele)
else:
res.append(ele)
return res

def _nested_is_unique(x):
res = list()
x = list(x)
Expand Down
108 changes: 63 additions & 45 deletions src/tidypandas/tidyframe_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_is_kwargable,
_is_valid_colname,
_is_string_or_string_list,
_is_tidyselect_compatible,
_enlist,
_get_unique_names,
_is_unique_list,
Expand All @@ -32,8 +33,14 @@
_coerce_pdf,
_is_nested,
_flatten_strings,
_flatten_list,
_nested_is_unique
)
from tidypandas.tidyselect import (
starts_with,
ends_with,
contains
)
import tidypandas.format as tidy_fmt


Expand Down Expand Up @@ -679,6 +686,16 @@ def _validate_column_names(self, column_names):
)

return None

def _get_simplified_column_names(self, column_names):
assert _is_tidyselect_compatible(column_names),\
"arg 'column_names' should be tidyselect compatible"
column_names = _enlist(column_names)
column_names = [x(self) if callable(x) else x for x in column_names]
column_names = _flatten_list(column_names)
column_names = list(dict.fromkeys(column_names))
self._validate_column_names(column_names)
return column_names

def _clean_order_by(self, order_by):

Expand Down Expand Up @@ -839,7 +856,7 @@ def add_row_number(self, by = None, name = 'row_number'):
raise Exception("'name' should not start with an underscore")

if name in cn:
raise Expection("'name' should not be an existing column name.")
raise Exception("'name' should not be an existing column name.")

if by is None:
po = self.__data.assign(**{name : np.arange(nr)})
Expand Down Expand Up @@ -1128,8 +1145,8 @@ def select(self, column_names = None, predicate = None, include = True):
)
column_names = list(np.array(cn)[col_bool_list])
else:
self._validate_column_names(column_names)
column_names = _enlist(column_names)
column_names = self._get_simplified_column_names(column_names)


if not include:
column_names = list(setlist(cn).difference(column_names))
Expand Down Expand Up @@ -1183,8 +1200,7 @@ def relocate(self, column_names, before = None, after = None):
>>> penguins_tidy.relocate(["island", "species"], after = "year")
'''

self._validate_column_names(column_names)
column_names = _enlist(column_names)
column_names = self._get_simplified_column_names(column_names)
cn = self.colnames

assert not ((before is not None) and (after is not None)),\
Expand Down Expand Up @@ -2063,18 +2079,13 @@ def _mutate_across(self

# use column_names
if column_names is not None:
assert isinstance(column_names, list)
assert all([isinstance(acol, str) for acol in column_names])
# assert isinstance(column_names, list)
# assert all([isinstance(acol, str) for acol in column_names])
column_names = self._get_simplified_column_names(column_names)
# use predicate to assign appropriate column_names
else:
mask = list(self.__data.apply(predicate, axis = 0))
assert all([isinstance(x, bool) for x in mask])(self.group_modify(lambda chunk: chunk.query(query)
, by = by
, is_pandas_udf = True
, preserve_row_order = True
, row_order_column_name = ro_name
)
)
assert all([isinstance(x, bool) for x in mask])
column_names = self.__data.columns[mask]

# make a copy of the dataframe and apply mutate in order
Expand Down Expand Up @@ -2644,8 +2655,9 @@ def summarise(self

# use column_names
if column_names is not None:
self._validate_column_names(column_names)
column_names = _enlist(column_names)
# self._validate_column_names(column_names)
# column_names = _enlist(column_names)
column_names = self._get_simplified_column_names(column_names)
# use predicate to assign appropriate column_names
else:
mask = list(self.__data.apply(predicate, axis = 0))
Expand Down Expand Up @@ -3792,19 +3804,21 @@ def pivot_wider(self

cn = self.colnames

assert _is_string_or_string_list(names_from),\
"arg 'names_from' should be string or a list of strings"
names_from = _enlist(names_from)
assert _is_unique_list(names_from),\
"arg 'names_from' should have unique strings"
assert set(names_from).issubset(cn),\
"arg 'names_from' should be a subset of existing column names"

assert _is_string_or_string_list(values_from),\
"arg 'values_from' should be string or a list of strings"
values_from = _enlist(values_from)
assert set(values_from).issubset(cn),\
"arg 'names_from' should have unique strings"
# assert _is_string_or_string_list(names_from),\
# "arg 'names_from' should be string or a list of strings"
# names_from = _enlist(names_from)
# assert _is_unique_list(names_from),\
# "arg 'names_from' should have unique strings"
# assert set(names_from).issubset(cn),\
# "arg 'names_from' should be a subset of existing column names"
names_from = self._get_simplified_column_names(names_from)

# assert _is_string_or_string_list(values_from),\
# "arg 'values_from' should be string or a list of strings"
# values_from = _enlist(values_from)
# assert set(values_from).issubset(cn),\
# "arg 'names_from' should have unique strings"
values_from = self._get_simplified_column_names(values_from)
assert len(set(values_from).intersection(names_from)) == 0,\
("arg 'names_from' and 'values_from' should not "
"have common column names"
Expand All @@ -3822,11 +3836,12 @@ def pivot_wider(self
else:
print("'id_cols' chosen: " + str(id_cols))
else:
assert _is_string_or_string_list(id_cols),\
"arg 'id_cols' should be string or a list of strings"
id_cols = _enlist(id_cols)
assert _is_unique_list(id_cols),\
"arg 'id_cols' should have unique strings"
# assert _is_string_or_string_list(id_cols),\
# "arg 'id_cols' should be string or a list of strings"
# id_cols = _enlist(id_cols)
# assert _is_unique_list(id_cols),\
# "arg 'id_cols' should have unique strings"
id_cols = self._get_simplified_column_names(id_cols)
assert len(set(id_cols).intersection(names_values_from)) == 0,\
("arg 'id_cols' should not have common names with either "
"'names_from' or 'values_from'"
Expand Down Expand Up @@ -3968,11 +3983,12 @@ def pivot_longer(self
'''
# assertions
cn = self.colnames
assert _is_string_or_string_list(cols),\
"arg 'cols' should be a string or a list of strings"
cols = _enlist(cols)
assert set(cols).issubset(cn),\
"arg 'cols' should be a subset of existing column names"
# assert _is_string_or_string_list(cols),\
# "arg 'cols' should be a string or a list of strings"
# cols = _enlist(cols)
# assert set(cols).issubset(cn),\
# "arg 'cols' should be a subset of existing column names"
cols = self._get_simplified_column_names(cols)
assert isinstance(include, bool),\
"arg 'include' should be a bool"
if not include:
Expand Down Expand Up @@ -5001,8 +5017,9 @@ def drop_na(self, column_names = None):
'''
cn = self.colnames
if column_names is not None:
self._validate_column_names(column_names)
column_names = _enlist(column_names)
# self._validate_column_names(column_names)
# column_names = _enlist(column_names)
column_names = self._get_simplified_column_names(column_names)
else:
column_names = cn

Expand Down Expand Up @@ -5272,13 +5289,13 @@ def unite(self, column_names, into, sep = "_", keep = False):
>>> df.separate('col', into = ["col_1", "col_2"], sep = "_", strict = False)
'''


column_names = self._get_simplified_column_names(column_names)
def reduce_join(df, columns, sep):
assert len(columns) > 1
slist = [df[x].astype(str) for x in columns]
red_series = functools.reduce(lambda x, y: x + sep + y, slist[1:], slist[0])
return red_series.to_frame(name = into)

joined = reduce_join(self.__data, column_names, sep)

if not keep:
Expand Down Expand Up @@ -5476,8 +5493,9 @@ def nest(self
'''

cn = self.colnames
self._validate_column_names(column_names)
column_names = _enlist(column_names)
# self._validate_column_names(column_names)
# column_names = _enlist(column_names)
column_names = self._get_simplified_column_names(column_names)
if not include:
column_names = list(setlist(cn).difference(column_names))
by = list(setlist(cn).difference(column_names))
Expand Down
28 changes: 28 additions & 0 deletions src/tidypandas/tidyselect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
def starts_with(prefix):
assert(isinstance(prefix, str))
def _starts_with_(tidy_df, prefix=prefix):
cn = tidy_df.colnames
sel_cn = list(filter(lambda x: x[0:len(prefix)] == prefix, cn))
return sel_cn
return _starts_with_


def ends_with(suffix):
assert(isinstance(suffix, str))
def _ends_with_(tidy_df, suffix=suffix):
cn = tidy_df.colnames
sel_cn = list(filter(lambda x: x[-len(suffix):] == suffix, cn))
return sel_cn
return _ends_with_

def contains(pattern):
assert(isinstance(pattern, str))
def _contains_(tidy_df, pattern=pattern):
cn = tidy_df.colnames
sel_cn = list(filter(lambda x: x.find(pattern) > -1, cn))
return sel_cn
return _contains_




33 changes: 33 additions & 0 deletions tests/test_tidypandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@

from tidypandas import tidyframe
from tidypandas.tidy_utils import simplify
from tidypandas.tidyselect import (
starts_with,
ends_with,
contains
)

@pytest.fixture
def penguins_data():
Expand Down Expand Up @@ -212,6 +217,34 @@ def test_select(penguins_data):
result = penguins_tidy.select(['sex', 'species'], include = False).to_pandas()
assert_frame_equal_v2(expected, result)

def test_tidyselect(penguins_data):
penguins_tidy = tidyframe(penguins_data, copy=False)

expected = penguins_data[["bill_length_mm", "bill_depth_mm"]]
result = penguins_tidy.select(starts_with("bill")).to_pandas()
assert_frame_equal_v2(expected, result)

expected = penguins_data[["bill_length_mm", "bill_depth_mm", "flipper_length_mm"]]
result = penguins_tidy.select(ends_with("mm")).to_pandas()
assert_frame_equal_v2(expected, result)

expected = penguins_data[["bill_length_mm", "flipper_length_mm"]]
result = penguins_tidy.select(contains("length")).to_pandas()
assert_frame_equal_v2(expected, result)

expected = (penguins_data
.melt(id_vars = list(set(penguins_data.columns).difference(["bill_length_mm", "flipper_length_mm"]))
, value_vars = ["bill_length_mm", "flipper_length_mm"]
, var_name = "name"
, value_name = "value"
, ignore_index = True
)
)
result = penguins_tidy.pivot_longer(cols=contains("length")).to_pandas()
assert_frame_equal_v2(expected, result)



def test_filter(penguins_data):
penguins_tidy = tidyframe(penguins_data, copy=False)
exp = penguins_tidy.filter(lambda x: x['bill_length_mm'] >= x['bill_length_mm'].mean(), by = 'species')
Expand Down