Skip to content

Commit

Permalink
Merge pull request #587 from Tjev/table-refactor
Browse files Browse the repository at this point in the history
refactor(gooddata-sdk) - Switch up function definitions ordering

Reviewed-by: Lubo Slivka
             https://github.com/lupko
  • Loading branch information
gdgate authored Mar 8, 2024
2 parents 10a861b + 57171d0 commit 36155e7
Showing 1 changed file with 140 additions and 140 deletions.
280 changes: 140 additions & 140 deletions gooddata-sdk/gooddata_sdk/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,25 +272,6 @@ def create(cls, total: TotalDefinition, measures: list[VisualizationMetric]) ->
)


def _create_dimension(bucket: VisualizationBucket, measures_item_identifier: Optional[str] = None) -> TableDimension:
item_ids = [a.local_id for a in bucket.attributes]
if measures_item_identifier is not None:
item_ids.append(measures_item_identifier)
return TableDimension(
item_ids=item_ids,
idx=_GET_DIM_INDEX_OF_BUCKET_TYPE[bucket.type],
)


def _get_dim_idx_for_predicate(
dims: list[TableDimension], predicate: Callable[[TableDimension], bool]
) -> Optional[int]:
for dim_idx, dim in enumerate(dims):
if predicate(dim):
return dim_idx
return None


@frozen
class AttributeLocator:
attribute_identifier: str
Expand Down Expand Up @@ -367,6 +348,73 @@ def _create_data_col_locators(locators: list[VisualizationSortLocator]) -> list[
return converted_locators


def _get_dim_idx_for_predicate(
dims: list[TableDimension], predicate: Callable[[TableDimension], bool]
) -> Optional[int]:
for dim_idx, dim in enumerate(dims):
if predicate(dim):
return dim_idx
return None


def _append_attribute_sort_key(
dims: list[TableDimension], sort_item: VisualizationSort, sorting: list[list[SortKey]]
) -> None:
dim_idx = _get_dim_idx_for_predicate(dims, lambda x: sort_item.attribute_identifier in x.item_ids)
if dim_idx is None:
log_msg = (
f'attempting to sort by attribute with localId "{sort_item.attribute_identifier}" '
"but this attribute is not in any dimension."
)
logger.warning(log_msg)
return

sorting[dim_idx].append(
SortKeyAttribute(
sort_type=sort_item.type,
direction=sort_item.direction,
attribute_identifier=sort_item.attribute_identifier,
attribute_sort_type=sort_item.attribute_sort_type,
)
)


def _append_measure_sort_key(
measure_dim: Optional[TableDimension],
non_measure_dim_idx: Optional[int],
sort_item: VisualizationSort,
sorting: list[list[SortKey]],
) -> None:
if non_measure_dim_idx is None:
logger.warning(
"Trying to use measure sort in an execution that only contains dimension with MeasureGroup. "
"This is not valid sort. Measure sort is used to sort the non-measure dimension by values "
"from measure dimension. Skipping."
)
return

if not measure_dim:
logger.warning("Trying to use measure sort in an execution that does not contain MeasureGroup. Skipping.")
return

sorting[non_measure_dim_idx].append(
SortKeyValue(
sort_type=sort_item.type,
direction=sort_item.direction,
measure_dim_identifier=f"dim_{measure_dim.idx}",
data_column_locators=_create_data_col_locators(sort_item.locators),
)
)


def _merge_dims_with_sorting(dims: list[TableDimension], sorting: list[list[SortKey]]) -> list[TableDimension]:
for dim in dims:
dim_sorting = sorting[dim.idx]
if dim_sorting:
dim.sorting = [ds.to_dict() for ds in dim_sorting]
return dims


def _create_dims_with_sorts(dims: list[TableDimension], sorts: list[VisualizationSort]) -> list[TableDimension]:
"""
Places sorting into dimensions. Returns the same dimensions objects but with modified sorting.
Expand Down Expand Up @@ -406,71 +454,23 @@ def _create_dims_with_sorts(dims: list[TableDimension], sorts: list[Visualizatio
return _merge_dims_with_sorting(dims, sorting)


def _merge_dims_with_sorting(dims: list[TableDimension], sorting: list[list[SortKey]]) -> list[TableDimension]:
for dim in dims:
dim_sorting = sorting[dim.idx]
if dim_sorting:
dim.sorting = [ds.to_dict() for ds in dim_sorting]
return dims


def _append_measure_sort_key(
measure_dim: Optional[TableDimension],
non_measure_dim_idx: Optional[int],
sort_item: VisualizationSort,
sorting: list[list[SortKey]],
) -> None:
if non_measure_dim_idx is None:
logger.warning(
"Trying to use measure sort in an execution that only contains dimension with MeasureGroup. "
"This is not valid sort. Measure sort is used to sort the non-measure dimension by values "
"from measure dimension. Skipping."
)
return

if not measure_dim:
logger.warning("Trying to use measure sort in an execution that does not contain MeasureGroup. Skipping.")
return

sorting[non_measure_dim_idx].append(
SortKeyValue(
sort_type=sort_item.type,
direction=sort_item.direction,
measure_dim_identifier=f"dim_{measure_dim.idx}",
data_column_locators=_create_data_col_locators(sort_item.locators),
)
)


def _append_attribute_sort_key(
dims: list[TableDimension], sort_item: VisualizationSort, sorting: list[list[SortKey]]
) -> None:
dim_idx = _get_dim_idx_for_predicate(dims, lambda x: sort_item.attribute_identifier in x.item_ids)
if dim_idx is None:
log_msg = (
f'attempting to sort by attribute with localId "{sort_item.attribute_identifier}" '
"but this attribute is not in any dimension."
)
logger.warning(log_msg)
return

sorting[dim_idx].append(
SortKeyAttribute(
sort_type=sort_item.type,
direction=sort_item.direction,
attribute_identifier=sort_item.attribute_identifier,
attribute_sort_type=sort_item.attribute_sort_type,
)
)


def _vis_is_transposed(visualization: Visualization) -> bool:
controls = visualization.properties.get("controls")
if not controls:
return False
return controls.get("measureGroupDimension") == "rows"


def _create_dimension(bucket: VisualizationBucket, measures_item_identifier: Optional[str] = None) -> TableDimension:
item_ids = [a.local_id for a in bucket.attributes]
if measures_item_identifier is not None:
item_ids.append(measures_item_identifier)
return TableDimension(
item_ids=item_ids,
idx=_GET_DIM_INDEX_OF_BUCKET_TYPE[bucket.type],
)


def _create_dimensions(visualization: Visualization) -> list[TableDimension]:
measures_item_identifier = _MEASURE_GROUP_IDENTIFIER if visualization.metrics else None
row_bucket = visualization.get_bucket_of_type(BucketType.ROWS)
Expand Down Expand Up @@ -567,59 +567,22 @@ def update_compute_info(self, col_total_attr_id: str, row_total_attr_id: str) ->
self.column_subtotal_dimension_index = self.col_attr_ids.index(col_total_attr_id)


def _get_additional_totals(visualization: Visualization, dimensions: list[TableDimension]) -> list[TotalDefinition]:
"""Construct special cases of pivot table totals.
These special cases are -
1. Grand totals - is the value obtained from row and column totals.
2. Marginal totals - is the value obtained within the subgroups from row and column subtotals.
For `Grand Totals`, in Tiger AFM, you specify that you want total of totals with total that have only
"measureGroup" present.
For `Marginal Total`, in Tiger AFM, would need to iterate through both dimensions and obtain the missing
totalDimensions items based on the attribute and column identifiers order in buckets.
"""
totals: list[TotalDefinition] = []
row_bucket = visualization.get_bucket_of_type(BucketType.ROWS)
col_bucket = visualization.get_bucket_of_type(BucketType.COLS)

tci = TotalsComputeInfo(
row_attr_ids=[a.local_id for a in row_bucket.attributes],
col_attr_ids=[a.local_id for a in col_bucket.attributes],
measure_group_rows=[_MEASURE_GROUP_IDENTIFIER] if _MEASURE_GROUP_IDENTIFIER in dimensions[0].item_ids else [],
measure_group_cols=[_MEASURE_GROUP_IDENTIFIER] if _MEASURE_GROUP_IDENTIFIER in dimensions[1].item_ids else [],
)
for row_index, row_total in enumerate(row_bucket.totals):
tci.reset_to_defaults()
for col_index, col_total in enumerate(col_bucket.totals):
# Check for totals from same measure and type
if row_total.measure_id == col_total.measure_id and row_total.type == col_total.type:
tci.update_compute_info(col_total.attribute_id, row_total.attribute_id)

if tci.has_row_and_column_sub_totals:
totals.append(_extend_marginal_totals(col_index, row_total, tci))

if tci.has_row_subtotal_and_column_grand_total:
totals.append(_extend_marginal_totals_of_rows(row_index, row_total, tci))

if tci.has_column_subtotal_and_row_grand_total:
totals.append(_extend_marginal_totals_of_cols(row_index, row_total, tci))

if tci.has_row_and_column_grand_totals:
totals.append(_extend_grand_totals(row_index, row_total, tci))

return totals


def _extend_grand_totals(row_index: int, row_total: VisualizationTotal, tci: TotalsComputeInfo) -> TotalDefinition:
# Extend grand totals payload
row_dim = [TotalDimension(idx=0, items=tci.measure_group_rows)] if tci.measure_group_rows else []
col_dim = [TotalDimension(idx=1, items=tci.measure_group_cols)] if tci.measure_group_cols else []
def _extend_marginal_totals(col_index: int, row_total: VisualizationTotal, tci: TotalsComputeInfo) -> TotalDefinition:
# Extend marginal totals payload
return TotalDefinition(
local_id=_grand_total_local_identifier(row_total, row_index),
local_id=_marginal_total_local_identifier(row_total, col_index),
aggregation=row_total.type,
metric_local_id=row_total.measure_id,
total_dims=row_dim + col_dim,
total_dims=[
TotalDimension(
idx=0,
items=tci.row_attr_ids[: tci.row_dimension_index] + tci.measure_group_rows,
),
TotalDimension(
idx=1,
items=tci.col_attr_ids[: tci.column_dimension_index] + tci.measure_group_cols,
),
],
)


Expand Down Expand Up @@ -663,23 +626,60 @@ def _extend_marginal_totals_of_rows(
)


def _extend_marginal_totals(col_index: int, row_total: VisualizationTotal, tci: TotalsComputeInfo) -> TotalDefinition:
# Extend marginal totals payload
def _extend_grand_totals(row_index: int, row_total: VisualizationTotal, tci: TotalsComputeInfo) -> TotalDefinition:
# Extend grand totals payload
row_dim = [TotalDimension(idx=0, items=tci.measure_group_rows)] if tci.measure_group_rows else []
col_dim = [TotalDimension(idx=1, items=tci.measure_group_cols)] if tci.measure_group_cols else []
return TotalDefinition(
local_id=_marginal_total_local_identifier(row_total, col_index),
local_id=_grand_total_local_identifier(row_total, row_index),
aggregation=row_total.type,
metric_local_id=row_total.measure_id,
total_dims=[
TotalDimension(
idx=0,
items=tci.row_attr_ids[: tci.row_dimension_index] + tci.measure_group_rows,
),
TotalDimension(
idx=1,
items=tci.col_attr_ids[: tci.column_dimension_index] + tci.measure_group_cols,
),
],
total_dims=row_dim + col_dim,
)


def _get_additional_totals(visualization: Visualization, dimensions: list[TableDimension]) -> list[TotalDefinition]:
"""Construct special cases of pivot table totals.
These special cases are -
1. Grand totals - is the value obtained from row and column totals.
2. Marginal totals - is the value obtained within the subgroups from row and column subtotals.
For `Grand Totals`, in Tiger AFM, you specify that you want total of totals with total that have only
"measureGroup" present.
For `Marginal Total`, in Tiger AFM, would need to iterate through both dimensions and obtain the missing
totalDimensions items based on the attribute and column identifiers order in buckets.
"""
totals: list[TotalDefinition] = []
row_bucket = visualization.get_bucket_of_type(BucketType.ROWS)
col_bucket = visualization.get_bucket_of_type(BucketType.COLS)

tci = TotalsComputeInfo(
row_attr_ids=[a.local_id for a in row_bucket.attributes],
col_attr_ids=[a.local_id for a in col_bucket.attributes],
measure_group_rows=[_MEASURE_GROUP_IDENTIFIER] if _MEASURE_GROUP_IDENTIFIER in dimensions[0].item_ids else [],
measure_group_cols=[_MEASURE_GROUP_IDENTIFIER] if _MEASURE_GROUP_IDENTIFIER in dimensions[1].item_ids else [],
)
for row_index, row_total in enumerate(row_bucket.totals):
tci.reset_to_defaults()
for col_index, col_total in enumerate(col_bucket.totals):
# Check for totals from same measure and type
if row_total.measure_id == col_total.measure_id and row_total.type == col_total.type:
tci.update_compute_info(col_total.attribute_id, row_total.attribute_id)

if tci.has_row_and_column_sub_totals:
totals.append(_extend_marginal_totals(col_index, row_total, tci))

if tci.has_row_subtotal_and_column_grand_total:
totals.append(_extend_marginal_totals_of_rows(row_index, row_total, tci))

if tci.has_column_subtotal_and_row_grand_total:
totals.append(_extend_marginal_totals_of_cols(row_index, row_total, tci))

if tci.has_row_and_column_grand_totals:
totals.append(_extend_grand_totals(row_index, row_total, tci))

return totals


def _get_computable_totals(visualization: Visualization, dimensions: list[TableDimension]) -> list[TotalDefinition]:
Expand Down

0 comments on commit 36155e7

Please sign in to comment.