diff --git a/src/tidypandas/__init__.py b/src/tidypandas/__init__.py index 6f46417..721de17 100644 --- a/src/tidypandas/__init__.py +++ b/src/tidypandas/__init__.py @@ -2,3 +2,8 @@ from tidypandas import series_utils from tidypandas import tidy_utils from tidypandas import tidy_accessor +from tidypandas.tidyselect import ( + starts_with, + ends_with, + contains + ) diff --git a/src/tidypandas/_unexported_utils.py b/src/tidypandas/_unexported_utils.py index 5e7566c..c181f14 100644 --- a/src/tidypandas/_unexported_utils.py +++ b/src/tidypandas/_unexported_utils.py @@ -8,6 +8,7 @@ import pandas as pd import inspect import pandas.api.types as dtypes +from tidypandas import tidyselect import warnings def _is_kwargable(func): @@ -58,6 +59,43 @@ def _is_string_or_string_list(x): res = False return res + +def _is_tidyselect_compatible(x): + ''' + _is_tidyselect_compatible(x) + + Check whether the input is a string or a tidyselect closure or a list of strings and/or tidyselect closures + + Parameters + ---------- + x : object + Any python object + + Returns + ------- + bool + True if input is a string or a list of strings + + Examples + -------- + >>> _is_tidyselect_compatible("bar") # True + >>> _is_tidyselect_compatible(["bar"]) # True + >>> _is_tidyselect_compatible(("bar",)) # False + >>> _is_tidyselect_compatible(["bar", 1]) # False + >>> _is_tidyselect_compatible(["bar", ends_with("_mean")]) # True + ''' + res = False + if isinstance(x, str): + res = True + elif callable(x) and hasattr(tidyselect, x.__name__.lstrip("_").rstrip("_")): + res = True + elif isinstance(x, list) and len(x) >= 1: + if all([isinstance(i, str) or (callable(i) and hasattr(tidyselect, i.__name__.lstrip("_").rstrip("_"))) for i in x]): + res = True + else: + res = False + + return res def _enlist(x): ''' @@ -243,6 +281,17 @@ def _flatten_strings(x): res = res.union(set(ele)) return list(res) +def _flatten_list(x): + res = [] + for ele in list(x): + if isinstance(ele, list): + res += ele + elif 
def _get_simplified_column_names(self, column_names):
    '''
    _get_simplified_column_names(self, column_names)

    Resolve a tidyselect compatible specification into a flat list of
    column names without repeats

    Parameters
    ----------
    column_names : str, tidyselect closure, or list of these
        Specification of the columns. Every callable entry is called with
        the tidyframe itself and replaced by what it returns.

    Returns
    -------
    list
        Resolved names, flattened, with duplicates dropped while keeping the
        order of first appearance.

    Notes
    -----
    The resolved list is handed to 'self._validate_column_names' before it is
    returned (presumably that method raises on invalid names -- confirm
    against its definition).
    '''
    assert _is_tidyselect_compatible(column_names),\
        "arg 'column_names' should be tidyselect compatible"

    # expand every closure against this frame, keep plain entries untouched
    resolved = []
    for spec in _enlist(column_names):
        resolved.append(spec(self) if callable(spec) else spec)

    # dict keys preserve insertion order, which gives an ordered de-duplication
    simplified = list(dict.fromkeys(_flatten_list(resolved)))

    self._validate_column_names(simplified)
    return simplified
+1200,7 @@ def relocate(self, column_names, before = None, after = None): >>> penguins_tidy.relocate(["island", "species"], after = "year") ''' - self._validate_column_names(column_names) - column_names = _enlist(column_names) + column_names = self._get_simplified_column_names(column_names) cn = self.colnames assert not ((before is not None) and (after is not None)),\ @@ -2063,18 +2079,13 @@ def _mutate_across(self # use column_names if column_names is not None: - assert isinstance(column_names, list) - assert all([isinstance(acol, str) for acol in column_names]) + # assert isinstance(column_names, list) + # assert all([isinstance(acol, str) for acol in column_names]) + column_names = self._get_simplified_column_names(column_names) # use predicate to assign appropriate column_names else: mask = list(self.__data.apply(predicate, axis = 0)) - assert all([isinstance(x, bool) for x in mask])(self.group_modify(lambda chunk: chunk.query(query) - , by = by - , is_pandas_udf = True - , preserve_row_order = True - , row_order_column_name = ro_name - ) - ) + assert all([isinstance(x, bool) for x in mask]) column_names = self.__data.columns[mask] # make a copy of the dataframe and apply mutate in order @@ -2644,8 +2655,9 @@ def summarise(self # use column_names if column_names is not None: - self._validate_column_names(column_names) - column_names = _enlist(column_names) + # self._validate_column_names(column_names) + # column_names = _enlist(column_names) + column_names = self._get_simplified_column_names(column_names) # use predicate to assign appropriate column_names else: mask = list(self.__data.apply(predicate, axis = 0)) @@ -3792,19 +3804,21 @@ def pivot_wider(self cn = self.colnames - assert _is_string_or_string_list(names_from),\ - "arg 'names_from' should be string or a list of strings" - names_from = _enlist(names_from) - assert _is_unique_list(names_from),\ - "arg 'names_from' should have unique strings" - assert set(names_from).issubset(cn),\ - "arg 'names_from' 
should be a subset of existing column names" - - assert _is_string_or_string_list(values_from),\ - "arg 'values_from' should be string or a list of strings" - values_from = _enlist(values_from) - assert set(values_from).issubset(cn),\ - "arg 'names_from' should have unique strings" + # assert _is_string_or_string_list(names_from),\ + # "arg 'names_from' should be string or a list of strings" + # names_from = _enlist(names_from) + # assert _is_unique_list(names_from),\ + # "arg 'names_from' should have unique strings" + # assert set(names_from).issubset(cn),\ + # "arg 'names_from' should be a subset of existing column names" + names_from = self._get_simplified_column_names(names_from) + + # assert _is_string_or_string_list(values_from),\ + # "arg 'values_from' should be string or a list of strings" + # values_from = _enlist(values_from) + # assert set(values_from).issubset(cn),\ + # "arg 'names_from' should have unique strings" + values_from = self._get_simplified_column_names(values_from) assert len(set(values_from).intersection(names_from)) == 0,\ ("arg 'names_from' and 'values_from' should not " "have common column names" @@ -3822,11 +3836,12 @@ def pivot_wider(self else: print("'id_cols' chosen: " + str(id_cols)) else: - assert _is_string_or_string_list(id_cols),\ - "arg 'id_cols' should be string or a list of strings" - id_cols = _enlist(id_cols) - assert _is_unique_list(id_cols),\ - "arg 'id_cols' should have unique strings" + # assert _is_string_or_string_list(id_cols),\ + # "arg 'id_cols' should be string or a list of strings" + # id_cols = _enlist(id_cols) + # assert _is_unique_list(id_cols),\ + # "arg 'id_cols' should have unique strings" + id_cols = self._get_simplified_column_names(id_cols) assert len(set(id_cols).intersection(names_values_from)) == 0,\ ("arg 'id_cols' should not have common names with either " "'names_from' or 'values_from'" @@ -3968,11 +3983,12 @@ def pivot_longer(self ''' # assertions cn = self.colnames - assert 
_is_string_or_string_list(cols),\ - "arg 'cols' should be a string or a list of strings" - cols = _enlist(cols) - assert set(cols).issubset(cn),\ - "arg 'cols' should be a subset of existing column names" + # assert _is_string_or_string_list(cols),\ + # "arg 'cols' should be a string or a list of strings" + # cols = _enlist(cols) + # assert set(cols).issubset(cn),\ + # "arg 'cols' should be a subset of existing column names" + cols = self._get_simplified_column_names(cols) assert isinstance(include, bool),\ "arg 'include' should be a bool" if not include: @@ -5001,8 +5017,9 @@ def drop_na(self, column_names = None): ''' cn = self.colnames if column_names is not None: - self._validate_column_names(column_names) - column_names = _enlist(column_names) + # self._validate_column_names(column_names) + # column_names = _enlist(column_names) + column_names = self._get_simplified_column_names(column_names) else: column_names = cn @@ -5272,13 +5289,13 @@ def unite(self, column_names, into, sep = "_", keep = False): >>> df.separate('col', into = ["col_1", "col_2"], sep = "_", strict = False) ''' - + column_names = self._get_simplified_column_names(column_names) def reduce_join(df, columns, sep): assert len(columns) > 1 slist = [df[x].astype(str) for x in columns] red_series = functools.reduce(lambda x, y: x + sep + y, slist[1:], slist[0]) return red_series.to_frame(name = into) - + joined = reduce_join(self.__data, column_names, sep) if not keep: @@ -5476,8 +5493,9 @@ def nest(self ''' cn = self.colnames - self._validate_column_names(column_names) - column_names = _enlist(column_names) + # self._validate_column_names(column_names) + # column_names = _enlist(column_names) + column_names = self._get_simplified_column_names(column_names) if not include: column_names = list(setlist(cn).difference(column_names)) by = list(setlist(cn).difference(column_names)) diff --git a/src/tidypandas/tidyselect.py b/src/tidypandas/tidyselect.py new file mode 100644 index 0000000..69b27a6 --- 
def starts_with(prefix):
    '''
    starts_with(prefix)

    Build a selector for the column names which start with a prefix

    Parameters
    ----------
    prefix : str
        Prefix to match. The empty string matches every column name.

    Returns
    -------
    callable
        A closure which takes an object exposing a 'colnames' attribute
        and returns the list of matching names in their original order.

    Examples
    --------
    >>> penguins_tidy.select(starts_with("bill"))
    '''
    assert(isinstance(prefix, str))

    # The closure name is part of the contract: _is_tidyselect_compatible
    # strips the underscores from '__name__' and looks the result up here.
    def _starts_with_(tidy_df, prefix=prefix):
        return [name for name in tidy_df.colnames if name.startswith(prefix)]
    return _starts_with_


def ends_with(suffix):
    '''
    ends_with(suffix)

    Build a selector for the column names which end with a suffix

    Parameters
    ----------
    suffix : str
        Suffix to match. The empty string matches every column name,
        in line with starts_with("") and contains("").

    Returns
    -------
    callable
        A closure which takes an object exposing a 'colnames' attribute
        and returns the list of matching names in their original order.

    Examples
    --------
    >>> penguins_tidy.select(ends_with("mm"))
    '''
    assert(isinstance(suffix, str))

    # str.endswith is used rather than slicing with a negative length:
    # x[-0:] is the whole string, so an empty suffix would never match.
    def _ends_with_(tidy_df, suffix=suffix):
        return [name for name in tidy_df.colnames if name.endswith(suffix)]
    return _ends_with_


def contains(pattern):
    '''
    contains(pattern)

    Build a selector for the column names which contain a substring

    Parameters
    ----------
    pattern : str
        Literal substring to look for (not a regular expression).
        The empty string matches every column name.

    Returns
    -------
    callable
        A closure which takes an object exposing a 'colnames' attribute
        and returns the list of matching names in their original order.

    Examples
    --------
    >>> penguins_tidy.select(contains("length"))
    '''
    assert(isinstance(pattern, str))

    def _contains_(tidy_df, pattern=pattern):
        return [name for name in tidy_df.colnames if pattern in name]
    return _contains_
list(set(penguins_data.columns).difference(["bill_length_mm", "flipper_length_mm"])) + , value_vars = ["bill_length_mm", "flipper_length_mm"] + , var_name = "name" + , value_name = "value" + , ignore_index = True + ) + ) + result = penguins_tidy.pivot_longer(cols=contains("length")).to_pandas() + assert_frame_equal_v2(expected, result) + + + def test_filter(penguins_data): penguins_tidy = tidyframe(penguins_data, copy=False) exp = penguins_tidy.filter(lambda x: x['bill_length_mm'] >= x['bill_length_mm'].mean(), by = 'species')