[GSProcessing] Fix ParquetRowCounter bug when different types had same-name features

thvasilo committed Jan 18, 2025
1 parent 3491340 commit f8d3f10
Showing 2 changed files with 322 additions and 7 deletions.
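At its core this is the classic Python shared-reference pitfall: if two types end up holding the same feature dictionary object (for example because both declare a feature named feature1), mutating the dict through one type's reference also mutates it for the other. Whether the real metadata dict aliases entries this way depends on how it is built upstream; the following is a minimal, hypothetical sketch of the failure mode that the defensive copies introduced below guard against (names are illustrative, not from the codebase):

# Hypothetical illustration of the shared-reference pitfall this commit guards against.
shared_feature = {"format": {"name": "parquet"}, "data": ["part-00000.parquet"]}
node_data = {
    "type1": {"feature1": shared_feature},
    "type2": {"feature1": shared_feature},  # same dict object, not a copy
}

# Writing through one alias silently updates the other type as well:
node_data["type1"]["feature1"]["row_counts"] = [10, 15]
assert "row_counts" in node_data["type2"]["feature1"]  # both aliases mutated

# Taking a copy before mutating keeps the per-type entries independent:
independent = dict(node_data["type1"]["feature1"])
independent["row_counts"] = [99]
assert node_data["type2"]["feature1"]["row_counts"] == [10, 15]  # unaffected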
@@ -184,8 +184,13 @@ def _add_counts_for_graph_structure(
all_entries_counts = [] # type: List[Sequence[int]]
for type_value in self.metadata_dict[edge_or_node_type_key]:
logging.info("Getting counts for %s, %s", top_level_key, type_value)
relative_file_list = self.metadata_dict[top_level_key][type_value]["data"]
# Get the data dictionary for this type and create a copy
# to avoid modifying shared references
type_data_dict = self.metadata_dict[top_level_key][type_value].copy()
relative_file_list = type_data_dict["data"]
type_row_counts = self.get_row_counts_for_parquet_files(relative_file_list)

# Store the row counts directly in the metadata dictionary
self.metadata_dict[top_level_key][type_value]["row_counts"] = type_row_counts
all_entries_counts.append(type_row_counts)

@@ -229,10 +234,15 @@ def _add_counts_for_features(self, top_level_key: str, edge_or_node_type_key: st
self.metadata_dict[top_level_key].keys(),
)
continue
for feature_name, feature_data_dict in self.metadata_dict[top_level_key][
type_name
].items():
relative_file_list = feature_data_dict["data"] # type: Sequence[str]

# Create a new dictionary for this type's features
type_features = self.metadata_dict[top_level_key][type_name]

for feature_name, feature_data in type_features.items():
# Create a copy of the feature data to avoid modifying shared references
feature_data_dict = feature_data.copy()
relative_file_list = feature_data_dict["data"]

logging.info(
"Getting counts for %s, type: %s, feature: %s",
top_level_key,
@@ -247,7 +257,12 @@ def _add_counts_for_features(self, top_level_key: str, edge_or_node_type_key: st
feature_name,
feature_row_counts,
)
# Store the row counts in both places
feature_data_dict["row_counts"] = feature_row_counts
self.metadata_dict[top_level_key][type_name][feature_name][
"row_counts"
] = feature_row_counts

features_per_type_counts.append(feature_row_counts)
all_feature_counts.append(features_per_type_counts)

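Note that dict.copy() produces a shallow copy: feature_data_dict gets its own top-level mapping, so assigning "row_counts" on it does not write through to self.metadata_dict. That is why the hunk above stores the counts in both places. A small standalone sketch of that semantics:

# dict.copy() is shallow: the copy has its own top-level mapping,
# so new keys added to the copy never appear on the source dict.
source = {"data": ["part-00000.parquet"]}
working = source.copy()
working["row_counts"] = [10, 15]
assert "row_counts" not in source  # the metadata entry stays untouched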
@@ -405,6 +420,14 @@ def verify_metadata_match(graph_meta: Dict[str, Dict]) -> bool:
True if all row counts match for each type, False otherwise.
"""
logging.info("Verifying features and structure row counts match...")

# Check if required keys exist
required_keys = ["edge_data", "edges", "node_data"]
missing_keys = [key for key in required_keys if key not in graph_meta]
if missing_keys:
logging.error("Missing required keys in metadata: %s", missing_keys)
return False

all_edge_counts_match = ParquetRowCounter.verify_features_and_graph_structure_match(
graph_meta["edge_data"], graph_meta["edges"]
)
@@ -422,8 +445,6 @@ def verify_metadata_match(graph_meta: Dict[str, Dict]) -> bool:
or not all_edge_data_counts_match
):
all_match = False
# TODO: Should we create a file as indication
# downstream that repartitioning is necessary?
logging.info(
"Some edge/node row counts do not match, "
"will need to re-partition before creating distributed graph."
294 changes: 294 additions & 0 deletions graphstorm-processing/tests/test_row_count_utils.py
@@ -0,0 +1,294 @@
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import json
import shutil
import pytest

import pyarrow as pa
import pyarrow.parquet as pq

from graphstorm_processing.graph_loaders.row_count_utils import (
ParquetRowCounter,
verify_metadata_match,
)
from graphstorm_processing.constants import FilesystemType

# pylint: disable=redefined-outer-name

_ROOT = os.path.abspath(os.path.dirname(__file__))
TEMP_DATA_PREFIX = os.path.join(_ROOT, "resources/row_counting/generated_parquet/")


def create_feature_table(col_name: str, num_rows: int, feature_start_val: int) -> pa.Table:
"""Creates a test PyArrow table."""
feature: pa.Array = pa.array(range(feature_start_val, feature_start_val + num_rows))
return pa.table([feature], names=[col_name])


@pytest.fixture(scope="module")
def test_metadata():
"""Create and return the test metadata structure."""
return {
"node_type": ["type1", "type2"],
"edge_type": ["type1:edge:type2"],
"node_data": {
"type1": {
"feature1": {
"format": {"name": "parquet"},
"data": [
"node_data/type1-feature1/part-00000.parquet",
"node_data/type1-feature1/part-00001.parquet",
],
},
"feature2": {
"format": {"name": "parquet"},
"data": [
"node_data/type1-feature2/part-00000.parquet",
"node_data/type1-feature2/part-00001.parquet",
],
},
},
"type2": {
"feature1": {
"format": {"name": "parquet"},
"data": [
"node_data/type2-feature1/part-00000.parquet",
"node_data/type2-feature1/part-00001.parquet",
],
}
},
},
"edge_data": {
"type1:edge:type2": {
"weight": {
"format": {"name": "parquet"},
"data": [
"edge_data/type1_edge_type2-weight/part-00000.parquet",
"edge_data/type1_edge_type2-weight/part-00001.parquet",
],
}
}
},
"raw_id_mappings": {
"type1": {
"format": {"name": "parquet"},
"data": [
"raw_id_mappings/type1/part-00000.parquet",
"raw_id_mappings/type1/part-00001.parquet",
],
},
"type2": {
"format": {"name": "parquet"},
"data": [
"raw_id_mappings/type2/part-00000.parquet",
"raw_id_mappings/type2/part-00001.parquet",
],
},
},
"edges": {
"type1:edge:type2": {
"format": {"name": "parquet"},
"data": [
"edges/type1_edge_type2/part-00000.parquet",
"edges/type1_edge_type2/part-00001.parquet",
],
}
},
}


@pytest.fixture(scope="module", autouse=True)
def create_test_files_fixture(test_metadata):
"""Creates test files with known row counts."""
if os.path.exists(TEMP_DATA_PREFIX):
shutil.rmtree(TEMP_DATA_PREFIX)
os.makedirs(TEMP_DATA_PREFIX)

# Write metadata
with open(os.path.join(TEMP_DATA_PREFIX, "metadata.json"), "w", encoding="utf-8") as f:
json.dump(test_metadata, f)

# Create directory structure and files
for path_type in [
"node_data",
"edge_data",
"edges",
"raw_id_mappings",
]: # Added raw_id_mappings
os.makedirs(os.path.join(TEMP_DATA_PREFIX, path_type))

# Create node data files
for type_name in ["type1", "type2"]:
for feature in ["feature1", "feature2"]:
feature_path = os.path.join(TEMP_DATA_PREFIX, "node_data", f"{type_name}-{feature}")
os.makedirs(feature_path, exist_ok=True)

# Create files with different row counts
pq.write_table(
create_feature_table(feature, 10, 0),
os.path.join(feature_path, "part-00000.parquet"),
)
pq.write_table(
create_feature_table(feature, 15, 10),
os.path.join(feature_path, "part-00001.parquet"),
)

# Create raw ID mapping files
for type_name in ["type1", "type2"]:
mapping_path = os.path.join(TEMP_DATA_PREFIX, "raw_id_mappings", type_name)
os.makedirs(mapping_path)
# Create mapping files with the same row counts as other files
pq.write_table(
create_feature_table("id", 10, 0), os.path.join(mapping_path, "part-00000.parquet")
)
pq.write_table(
create_feature_table("id", 15, 10), os.path.join(mapping_path, "part-00001.parquet")
)

# Create edge data files
edge_feat_path = os.path.join(TEMP_DATA_PREFIX, "edge_data", "type1_edge_type2-weight")
os.makedirs(edge_feat_path)
pq.write_table(
create_feature_table("weight", 10, 0), os.path.join(edge_feat_path, "part-00000.parquet")
)
pq.write_table(
create_feature_table("weight", 15, 10), os.path.join(edge_feat_path, "part-00001.parquet")
)

# Create edge structure files
edge_path = os.path.join(TEMP_DATA_PREFIX, "edges", "type1_edge_type2")
os.makedirs(edge_path)
pq.write_table(
create_feature_table("edge", 10, 0), os.path.join(edge_path, "part-00000.parquet")
)
pq.write_table(
create_feature_table("edge", 15, 10), os.path.join(edge_path, "part-00001.parquet")
)

yield TEMP_DATA_PREFIX

# Cleanup
shutil.rmtree(TEMP_DATA_PREFIX)


@pytest.fixture(scope="module")
def row_counter(test_metadata):
"""Create a ParquetRowCounter instance."""
return ParquetRowCounter(test_metadata, TEMP_DATA_PREFIX, FilesystemType.LOCAL)


def test_row_counter_initialization(row_counter, test_metadata):
"""Test counter initialization."""
assert row_counter.metadata_dict == test_metadata
assert row_counter.output_prefix == TEMP_DATA_PREFIX
assert row_counter.filesystem_type == FilesystemType.LOCAL


def test_get_row_count_for_single_file(row_counter):
"""Test counting rows in a single file."""
count = row_counter.get_row_count_for_parquet_file(
"node_data/type1-feature1/part-00000.parquet"
)
assert count == 10


def test_get_row_counts_for_multiple_files(row_counter):
"""Test counting rows across multiple files."""
counts = row_counter.get_row_counts_for_parquet_files(
[
"node_data/type1-feature1/part-00000.parquet",
"node_data/type1-feature1/part-00001.parquet",
]
)
assert counts == [10, 15]


def test_add_counts_to_metadata(row_counter, test_metadata):
"""Test adding row counts to metadata."""
updated_metadata = row_counter.add_row_counts_to_metadata(test_metadata)

# Check edge counts
assert "row_counts" in updated_metadata["edges"]["type1:edge:type2"]
assert updated_metadata["edges"]["type1:edge:type2"]["row_counts"] == [10, 15]

# Check node feature counts for both types
assert "row_counts" in updated_metadata["node_data"]["type1"]["feature1"]
assert updated_metadata["node_data"]["type1"]["feature1"]["row_counts"] == [10, 15]
assert "row_counts" in updated_metadata["node_data"]["type1"]["feature2"]
assert "row_counts" in updated_metadata["node_data"]["type2"]["feature1"]


def test_verify_features_and_structure_match():
"""Test verification of feature and structure row counts."""
structure_meta = {"type1": {"row_counts": [10, 15], "data": ["file1.parquet", "file2.parquet"]}}

# Test matching counts
feature_meta = {
"type1": {"feature1": {"row_counts": [10, 15], "data": ["feat1.parquet", "feat2.parquet"]}}
}
assert ParquetRowCounter.verify_features_and_graph_structure_match(feature_meta, structure_meta)

# Test mismatched counts
feature_meta["type1"]["feature1"]["row_counts"] = [10, 16]
assert not ParquetRowCounter.verify_features_and_graph_structure_match(
feature_meta, structure_meta
)


def test_verify_all_features_match():
"""Test verification that all features for a type have matching counts."""
feature_meta = {
"type1": {
"feature1": {"row_counts": [10, 15], "data": ["feat1.parquet", "feat2.parquet"]},
"feature2": {"row_counts": [10, 15], "data": ["feat3.parquet", "feat4.parquet"]},
}
}

# Test matching counts
assert ParquetRowCounter.verify_all_features_match(feature_meta)

# Test mismatched counts
feature_meta["type1"]["feature2"]["row_counts"] = [10, 16]
assert not ParquetRowCounter.verify_all_features_match(feature_meta)


def test_shared_feature_names(row_counter, test_metadata):
"""Test handling of shared feature names across different types."""
updated_metadata = row_counter.add_row_counts_to_metadata(test_metadata)

# Verify both types have row counts for feature1
assert "row_counts" in updated_metadata["node_data"]["type1"]["feature1"]
assert "row_counts" in updated_metadata["node_data"]["type2"]["feature1"]

# Verify the counts are independent
type1_counts = updated_metadata["node_data"]["type1"]["feature1"]["row_counts"]
type2_counts = updated_metadata["node_data"]["type2"]["feature1"]["row_counts"]
assert type1_counts == type2_counts # Should be [10, 15] for both


def test_verify_metadata_match(row_counter, test_metadata):
"""Test the full metadata verification function."""
updated_metadata = row_counter.add_row_counts_to_metadata(test_metadata)

# Test with correct metadata
assert verify_metadata_match(updated_metadata)

# Test with corrupted metadata
corrupted_metadata = updated_metadata.copy()
corrupted_metadata["node_data"]["type1"]["feature1"]["row_counts"] = [10, 16]
assert not verify_metadata_match(corrupted_metadata)
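The suite is self-contained: the module-scoped fixture generates its Parquet inputs under tests/resources/row_counting/generated_parquet/ and removes them on teardown, so it should run locally with a plain pytest invocation, e.g. pytest graphstorm-processing/tests/test_row_count_utils.py (assuming the graphstorm-processing package and pyarrow are installed).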
