Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wjsi committed Apr 28, 2022
1 parent 22ca9e4 commit 5069080
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 54 deletions.
56 changes: 36 additions & 20 deletions mars/dataframe/groupby/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,34 +701,36 @@ def _do_custom_agg(op, custom_reduction, *input_objs):
out = op.outputs[0]
for group_key in input_objs[0].groups.keys():
group_objs = [o.get_group(group_key) for o in input_objs]

agg_done = False
if op.stage == OperandStage.map:
result = custom_reduction.pre(group_objs[0])
res_tuple = custom_reduction.pre(group_objs[0])
agg_done = custom_reduction.pre_with_agg
if not isinstance(result, tuple):
result = (result,)
if not isinstance(res_tuple, tuple):
res_tuple = (res_tuple,)
else:
result = group_objs
res_tuple = group_objs

if not agg_done:
result = custom_reduction.agg(*result)
if not isinstance(result, tuple):
result = (result,)
res_tuple = custom_reduction.agg(*res_tuple)
if not isinstance(res_tuple, tuple):
res_tuple = (res_tuple,)

if op.stage == OperandStage.agg:
result = custom_reduction.post(*result)
if not isinstance(result, tuple):
result = (result,)

if out.ndim == 2:
if result[0].ndim == 1:
result = tuple(r.to_frame().T for r in result)
if op.stage == OperandStage.agg:
result = tuple(r.astype(out.dtypes) for r in result)
else:
result = tuple(xdf.Series(r) for r in result)
res_tuple = custom_reduction.post(*res_tuple)
if not isinstance(res_tuple, tuple):
res_tuple = (res_tuple,)

new_res_list = []
for r in res_tuple:
if out.ndim == 2 and r.ndim == 1:
r = r.to_frame().T
elif out.ndim < 2:
if getattr(r, "ndim", 0) == 2:
r = r.iloc[0, :]
else:
r = xdf.Series(r)

for r in result:
if len(input_objs[0].grouper.names) == 1:
r.index = xdf.Index(
[group_key], name=input_objs[0].grouper.names[0]
Expand All @@ -737,7 +739,21 @@ def _do_custom_agg(op, custom_reduction, *input_objs):
r.index = xdf.MultiIndex.from_tuples(
[group_key], names=input_objs[0].grouper.names
)
results.append(result)

if op.groupby_params.get("selection"):
# correct columns for groupby-selection-agg paradigms
selection = op.groupby_params["selection"]
r.columns = [selection] if input_objs[0].ndim == 1 else selection

if out.ndim == 2 and op.stage == OperandStage.agg:
dtype_cols = set(out.dtypes.index) & set(r.columns)
conv_dtypes = {
k: v for k, v in out.dtypes.items() if k in dtype_cols
}
r = r.astype(conv_dtypes)
new_res_list.append(r)

results.append(tuple(new_res_list))
if not results and op.stage == OperandStage.agg:
empty_df = pd.DataFrame(
[], columns=out.dtypes.index, index=out.index_value.to_pandas()[:0]
Expand Down
21 changes: 21 additions & 0 deletions mars/dataframe/groupby/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,3 +476,24 @@ def test_groupby_fill():
assert len(r.chunks) == 4
assert r.shape == (len(s1),)
assert r.chunks[0].shape == (np.nan,)


def test_groupby_nunique():
    """Tiling a multi-key groupby-nunique should fuse into one agg chunk."""
    # The "three" column is mostly NaN so nunique has to cope with
    # missing values inside each ("one", "two") group.
    raw = pd.DataFrame(
        {
            "one": [1, 1, 1, 1, 1, 1, 1, 1],
            "two": [1, 1, 1, 2, 2, 2, 3, 3],
            "three": [10, np.nan, np.nan, np.nan, 20, np.nan, np.nan, np.nan],
        }
    )
    mdf = md.DataFrame(raw, chunk_size=3)

    tiled = tile(mdf.groupby(["one", "two"]).nunique())
    # Despite three input chunks, the aggregation tree collapses to a
    # single output chunk driven by a DataFrameGroupByAgg operand.
    assert len(tiled.chunks) == 1
    assert isinstance(tiled.chunks[0].op, DataFrameGroupByAgg)
19 changes: 11 additions & 8 deletions mars/dataframe/groupby/tests/test_groupby_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,13 +1241,16 @@ def test_groupby_nunique(setup):
# test with as_index=False
mdf = md.DataFrame(df1, chunk_size=13)
if _agg_size_as_frame:
res = mdf.groupby("b", as_index=False)["a"].nunique().execute().fetch()
expected = df1.groupby("b", as_index=False)["a"].nunique()
pd.testing.assert_frame_equal(
mdf.groupby("b", as_index=False)["a"]
.nunique()
.execute()
.fetch()
.sort_values(by="b", ignore_index=True),
df1.groupby("b", as_index=False)["a"]
.nunique()
.sort_values(by="b", ignore_index=True),
res.sort_values(by="b", ignore_index=True),
expected.sort_values(by="b", ignore_index=True),
)

res = mdf.groupby("b", as_index=False)[["a", "c"]].nunique().execute().fetch()
expected = df1.groupby("b", as_index=False)[["a", "c"]].nunique()
pd.testing.assert_frame_equal(
res.sort_values(by="b", ignore_index=True),
expected.sort_values(by="b", ignore_index=True),
)
44 changes: 18 additions & 26 deletions mars/dataframe/reduction/nunique.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,61 +41,53 @@ def __init__(
self._dropna = dropna
self._use_arrow_dtype = use_arrow_dtype

@staticmethod
def _drop_duplicates_to_arrow(v, explode=False):
def _drop_duplicates(self, xdf, value, explode=False):
if explode:
v = v.explode()
try:
return ArrowListArray([v.drop_duplicates().to_numpy()])
except pa.ArrowInvalid:
# fallback due to diverse dtypes
return [v.drop_duplicates().to_list()]
value = value.explode()

if not self._use_arrow_dtype or xdf is cudf:
return [value.drop_duplicates().to_numpy()]
else:
try:
return ArrowListArray([value.drop_duplicates().to_numpy()])
except pa.ArrowInvalid:
# fallback due to diverse dtypes
return [value.drop_duplicates().to_numpy()]

def pre(self, in_data): # noqa: W0221 # pylint: disable=arguments-differ
xdf = cudf if self.is_gpu() else pd
if isinstance(in_data, xdf.Series):
unique_values = in_data.drop_duplicates()
unique_values = self._drop_duplicates(xdf, in_data)
return xdf.Series(unique_values, name=in_data.name)
else:
if self._axis == 0:
data = dict()
for d, v in in_data.iteritems():
if not self._use_arrow_dtype or xdf is cudf:
data[d] = [v.drop_duplicates().to_list()]
else:
data[d] = self._drop_duplicates_to_arrow(v)
data[d] = self._drop_duplicates(xdf, v)
df = xdf.DataFrame(data)
else:
df = xdf.DataFrame(columns=[0])
for d, v in in_data.iterrows():
if not self._use_arrow_dtype or xdf is cudf:
df.loc[d] = [v.drop_duplicates().to_list()]
else:
df.loc[d] = self._drop_duplicates_to_arrow(v)
df.loc[d] = self._drop_duplicates(xdf, v)
return df

def agg(self, in_data): # noqa: W0221 # pylint: disable=arguments-differ
xdf = cudf if self.is_gpu() else pd
if isinstance(in_data, xdf.Series):
unique_values = in_data.explode().drop_duplicates()
unique_values = self._drop_duplicates(xdf, in_data, explode=True)
return xdf.Series(unique_values, name=in_data.name)
else:
if self._axis == 0:
data = dict()
for d, v in in_data.iteritems():
if not self._use_arrow_dtype or xdf is cudf:
data[d] = [v.explode().drop_duplicates().to_list()]
else:
if self._use_arrow_dtype and xdf is not cudf:
v = pd.Series(v.to_numpy())
data[d] = self._drop_duplicates_to_arrow(v, explode=True)
data[d] = self._drop_duplicates(xdf, v, explode=True)
df = xdf.DataFrame(data)
else:
df = xdf.DataFrame(columns=[0])
for d, v in in_data.iterrows():
if not self._use_arrow_dtype or xdf is cudf:
df.loc[d] = [v.explode().drop_duplicates().to_list()]
else:
df.loc[d] = self._drop_duplicates_to_arrow(v, explode=True)
df.loc[d] = self._drop_duplicates(xdf, v, explode=True)
return df

def post(self, in_data): # noqa: W0221 # pylint: disable=arguments-differ
Expand Down

0 comments on commit 5069080

Please sign in to comment.