From 865d94393a326990aca4fafb76438178855bfd72 Mon Sep 17 00:00:00 2001 From: Theodore Vasiloudis Date: Thu, 16 Jan 2025 22:23:30 +0000 Subject: [PATCH] [GSProcessing] Enforce ordering for node label processing during classification --- .../config/label_config_base.py | 24 ++ .../graphstorm_processing/constants.py | 10 + .../data_transformations/dist_label_loader.py | 44 ++- .../dist_category_transformation.py | 18 +- .../dist_label_transformation.py | 11 +- .../distributed_executor.py | 10 +- .../dist_heterogeneous_loader.py | 281 +++++++++++---- .../graph_loaders/row_count_utils.py | 7 +- .../graph_loaders/schema_utils.py | 6 +- .../repartition_files.py | 5 +- .../tests/test_dist_heterogenous_loader.py | 320 +++++++++++++++--- .../tests/test_dist_label_loader.py | 34 +- 12 files changed, 621 insertions(+), 149 deletions(-) diff --git a/graphstorm-processing/graphstorm_processing/config/label_config_base.py b/graphstorm-processing/graphstorm_processing/config/label_config_base.py index 000f80c9ca..0e3c878633 100644 --- a/graphstorm-processing/graphstorm_processing/config/label_config_base.py +++ b/graphstorm-processing/graphstorm_processing/config/label_config_base.py @@ -18,6 +18,8 @@ import logging from typing import Any, Dict, Optional +from graphstorm_processing.constants import VALID_TASK_TYPES + class LabelConfig(abc.ABC): """Basic class for label config""" @@ -55,6 +57,9 @@ def __init__(self, config_dict: Dict[str, Any]): self._mask_field_names = None def _sanity_check(self): + assert ( + self._task_type in VALID_TASK_TYPES + ), f"Invalid task type {self._task_type}, must be one of {VALID_TASK_TYPES}" if self._label_column == "": assert self._task_type == "link_prediction", ( "When no label column is specified, the task type must be link_prediction, " @@ -83,6 +88,25 @@ def _sanity_check(self): assert all(isinstance(x, str) for x in self._mask_field_names) assert len(self._mask_field_names) == 3 + def __repr__(self) -> str: + """Formal object representation for debugging""" + return ( + f"{self.__class__.__name__}(label_column={self._label_column!r}, " + f"task_type={self._task_type!r}, separator={self._separator!r}, " + f"multilabel={self._multilabel!r}, split={self._split!r}, " + f"custom_split_filenames={self._custom_split_filenames!r}, " + f"mask_field_names={self._mask_field_names!r})" + ) + + def __str__(self) -> str: + """Informal object representation for readability""" + task_desc = f"{self._task_type} task" + if self._label_column: + task_desc += f" on column '{self._label_column}'" + if self._multilabel: + task_desc += f" (multilabel with separator '{self._separator}')" + return task_desc + @property def label_column(self) -> str: """The name of the column storing the target label property value.""" diff --git a/graphstorm-processing/graphstorm_processing/constants.py b/graphstorm-processing/graphstorm_processing/constants.py index a732306ab8..405b4b5547 100644 --- a/graphstorm-processing/graphstorm_processing/constants.py +++ b/graphstorm-processing/graphstorm_processing/constants.py @@ -66,6 +66,16 @@ NODE_MAPPING_STR = "orig" NODE_MAPPING_INT = "new" +################# Reserved columns ################ +DATA_SPLIT_SET_MASK_COL = "GSP-SAMPLE-SET-MASK" + +################# Supported task types ############## +VALID_TASK_TYPES = { + "classification", + "regression", + "link_prediction", +} + ################# Supported execution envs ############## class ExecutionEnv(Enum): diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_label_loader.py 
b/graphstorm-processing/graphstorm_processing/data_transformations/dist_label_loader.py index a85299947d..53dda8de27 100644 --- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_label_loader.py +++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_label_loader.py @@ -16,9 +16,10 @@ from dataclasses import dataclass from math import fsum +from typing import Optional from pyspark.sql import DataFrame, SparkSession -from pyspark.sql.types import FloatType +from pyspark.sql.types import NumericType from graphstorm_processing.config.label_config_base import LabelConfig from graphstorm_processing.data_transformations.dist_transformations import ( @@ -100,11 +101,33 @@ class DistLabelLoader: The SparkSession to use for processing. """ - def __init__(self, label_config: LabelConfig, spark: SparkSession) -> None: + def __init__( + self, label_config: LabelConfig, spark: SparkSession, order_col: Optional[str] = None + ) -> None: self.label_config = label_config self.label_column = label_config.label_column self.spark = spark self.label_map: dict[str, int] = {} + self.order_col = order_col + + def __str__(self) -> str: + """Informal object representation for readability""" + return ( + f"DistLabelLoader(label_column='{self.label_column}', " + f"task_type='{self.label_config.task_type}', " + f"multilabel={self.label_config.multilabel}, " + f"order_col={self.order_col!r})" + ) + + def __repr__(self) -> str: + """Formal object representation for debugging""" + return ( + f"DistLabelLoader(" + f"label_config={self.label_config!r}, " + f"spark={self.spark!r}, " + f"order_col={self.order_col!r}, " + f"label_map={self.label_map!r})" + ) def process_label(self, input_df: DataFrame) -> DataFrame: """Transforms the label column in the input DataFrame to conform to GraphStorm expectations. 
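Reviewer note: a minimal usage sketch of the new constructor signature (not part of the patch). It assumes a live SparkSession `spark`; the config dict and column names mirror the unit tests in test_dist_label_loader.py and are otherwise illustrative.

    from graphstorm_processing.config.label_config_base import LabelConfig
    from graphstorm_processing.constants import NODE_MAPPING_INT
    from graphstorm_processing.data_transformations.dist_label_loader import DistLabelLoader

    classification_config = {
        "column": "name",
        "type": "classification",
        "split_rate": {"train": 0.8, "val": 0.1, "test": 0.1},
    }
    # Node DF carries both the raw label and the integer node id used for ordering
    names_df = spark.createDataFrame(
        [("mark", 0), ("john", 1), ("tara", 2), (None, 3)],
        schema=["name", NODE_MAPPING_INT],
    )

    label_loader = DistLabelLoader(
        LabelConfig(classification_config), spark, order_col=NODE_MAPPING_INT
    )
    # For classification, process_label now returns labels sorted by order_col
    transformed_labels = label_loader.process_label(names_df)
    str_to_int_map = label_loader.label_map  # original label -> encoded int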
@@ -134,6 +157,7 @@ def process_label(self, input_df: DataFrame) -> DataFrame: label_type = input_df.schema[self.label_column].dataType if self.label_config.task_type == "classification": + assert self.order_col, f"{self.order_col} must be provided for classification tasks" if self.label_config.multilabel: assert self.label_config.separator label_transformer = DistMultiLabelTransformation( @@ -141,16 +165,24 @@ def process_label(self, input_df: DataFrame) -> DataFrame: ) else: label_transformer = DistSingleLabelTransformation( - [self.label_config.label_column], self.spark + [self.label_config.label_column], + self.spark, ) - transformed_label = label_transformer.apply(input_df).select(self.label_column) + transformed_label = label_transformer.apply(input_df) + if self.order_col: + assert self.order_col in transformed_label.columns, ( + f"{self.order_col=} needs to be part of transformed " + f"label DF, got {transformed_label.columns=}" + ) + transformed_label = transformed_label.sort(self.order_col).cache() + self.label_map = label_transformer.value_map return transformed_label elif self.label_config.task_type == "regression": - if not isinstance(label_type, FloatType): + if not isinstance(label_type, NumericType): raise RuntimeError( - "Data type for regression should be FloatType, " + "Data type for regression should be a NumericType, " f"got {label_type} for {self.label_column}" ) return input_df.select(self.label_column) diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py index 761189aacb..3d8cb06202 100644 --- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py +++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_category_transformation.py @@ -313,8 +313,12 @@ def __init__(self, cols: Sequence[str], separator: str) -> None: def get_transformation_name() -> str: return "DistMultiCategoryTransformation" - def apply(self, input_df: DataFrame) -> DataFrame: + def apply(self, input_df: DataFrame, return_all_cols: bool = False) -> DataFrame: col_datatype = input_df.schema[self.multi_column].dataType + if return_all_cols: + original_cols = {*input_df.columns} - {self.multi_column} + else: + original_cols = {} is_array_col = False if col_datatype.typeName() == "array": assert isinstance(col_datatype, ArrayType) @@ -326,13 +330,19 @@ def apply(self, input_df: DataFrame) -> DataFrame: is_array_col = True + # Parquet input might come with arrays already, CSV will need splitting if is_array_col: - list_df = input_df.select(self.multi_column).alias(self.multi_column) + multi_column = F.col(self.multi_column) else: - list_df = input_df.select( - F.split(F.col(self.multi_column), self.separator).alias(self.multi_column) + multi_column = F.split(F.col(self.multi_column), self.separator).alias( + self.multi_column ) + list_df = input_df.select( + multi_column, + *original_cols, + ) + distinct_category_counts = ( list_df.withColumn(SINGLE_CATEGORY_COL, F.explode(F.col(self.multi_column))) .groupBy(SINGLE_CATEGORY_COL) diff --git a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_label_transformation.py b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_label_transformation.py index 840654dcd8..66974f6bb4 100644 
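Reviewer note: the `.sort(self.order_col)` plus cache added above is what keeps the collected label values aligned with the masks and features written for the same nodes; a shuffled Spark DataFrame gives no ordering guarantee on collect. A minimal sketch of the failure mode being guarded against, assuming a live SparkSession `spark` (names are illustrative, not part of the patch):

    import pyspark.sql.functions as F

    df = spark.range(1000).withColumn("label", (F.col("id") % 2).cast("string"))
    # After a shuffle/repartition, the row order seen by collect()/toPandas()
    # generally no longer matches the original id order...
    shuffled = df.repartition(64)
    # ...so anything that must stay aligned with other per-node outputs is
    # re-sorted by the integer node id before being collected and written:
    ordered_pd = shuffled.sort("id").toPandas()
    assert ordered_pd["id"].is_monotonic_increasing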
--- a/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_label_transformation.py +++ b/graphstorm-processing/graphstorm_processing/data_transformations/dist_transformations/dist_label_transformation.py @@ -45,6 +45,7 @@ def __init__(self, cols: Sequence[str], spark: SparkSession) -> None: def apply(self, input_df: DataFrame) -> DataFrame: assert self.spark + original_cols = {*input_df.columns} - {self.label_column} processed_col_name = self.label_column + "_processed" str_indexer = StringIndexer( @@ -63,13 +64,15 @@ def apply(self, input_df: DataFrame) -> DataFrame: # Labels that were missing and were assigned the value numLabels by the StringIndexer # are converted to None - long_class_label = indexed_df.select(F.col(self.label_column).cast("long")).select( + long_class_label = indexed_df.select( F.when( F.col(self.label_column) == len(str_indexer_model.labelsArray[0]), # type: ignore F.lit(None), ) .otherwise(F.col(self.label_column)) - .alias(self.label_column) + .cast("long") + .alias(self.label_column), + *original_cols, ) # Get a mapping from original label to encoded value @@ -112,7 +115,7 @@ def __init__(self, cols: Sequence[str], separator: str) -> None: super().__init__(cols, separator) self.label_column = cols[0] - def apply(self, input_df: DataFrame) -> DataFrame: - multi_cat_df = super().apply(input_df) + def apply(self, input_df: DataFrame, return_all_cols=True) -> DataFrame: + multi_cat_df = super().apply(input_df, return_all_cols=return_all_cols) return multi_cat_df diff --git a/graphstorm-processing/graphstorm_processing/distributed_executor.py b/graphstorm-processing/graphstorm_processing/distributed_executor.py index ee6858b544..5d3e2e9144 100644 --- a/graphstorm-processing/graphstorm_processing/distributed_executor.py +++ b/graphstorm-processing/graphstorm_processing/distributed_executor.py @@ -264,6 +264,7 @@ def __init__( output_prefix=self.output_prefix, precomputed_transformations=self.precomputed_transformations, ) + self.loader = DistHeterogeneousGraphLoader( self.spark, loader_config, @@ -287,17 +288,18 @@ def _upload_output_files(self, loader: DistHeterogeneousGraphLoader, force=False bucket, s3_prefix = s3_utils.extract_bucket_and_key(self.output_prefix) s3 = boto3.resource("s3") - output_files = os.listdir(loader.output_path) + output_files = os.listdir(loader.local_meta_output_path) for output_file in output_files: s3.meta.client.upload_file( - f"{os.path.join(loader.output_path, output_file)}", + f"{os.path.join(loader.local_meta_output_path, output_file)}", bucket, f"{s3_prefix}/{output_file}", ) def run(self) -> None: """ - Executes the Spark processing job. + Executes the Spark processing job, optional repartition job, and uploads any metadata files + if needed. 
""" logging.info("Performing data processing with PySpark...") @@ -355,7 +357,7 @@ def run(self) -> None: # If any of the metadata modification took place, write an updated metadata file if updated_metadata: updated_meta_path = os.path.join( - self.loader.output_path, "updated_row_counts_metadata.json" + self.loader.local_meta_output_path, "updated_row_counts_metadata.json" ) with open( updated_meta_path, diff --git a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py index 5cea423e56..5da6462a28 100644 --- a/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py +++ b/graphstorm-processing/graphstorm_processing/graph_loaders/dist_heterogeneous_loader.py @@ -24,7 +24,11 @@ from dataclasses import dataclass from time import perf_counter from typing import Any, Dict, Optional, Set, Tuple +from uuid import uuid4 +import numpy as np +import pyarrow as pa +from pyarrow import parquet as pq from pyspark import RDD from pyspark.sql import Row, SparkSession, DataFrame, functions as F from pyspark.sql.types import ( @@ -40,6 +44,8 @@ from numpy.random import default_rng from graphstorm_processing.constants import ( + DATA_SPLIT_SET_MASK_COL, + FilesystemType, MIN_VALUE, MAX_VALUE, VALUE_COUNTS, @@ -166,21 +172,21 @@ def __init__( spark: SparkSession, loader_config: HeterogeneousLoaderConfig, ): - self.output_path = loader_config.local_metadata_output_path + self.local_meta_output_path = loader_config.local_metadata_output_path self._data_configs = loader_config.data_configs self.feature_configs: list[FeatureConfig] = [] # TODO: Pass as an argument? if loader_config.input_prefix.startswith("s3://"): - self.filesystem_type = "s3" + self.filesystem_type = FilesystemType.S3 else: assert os.path.isabs(loader_config.input_prefix), "We expect an absolute path" - self.filesystem_type = "local" + self.filesystem_type = FilesystemType.LOCAL self.spark: SparkSession = spark self.add_reverse_edges = loader_config.add_reverse_edges # Remove trailing slash in s3 paths - if self.filesystem_type == "s3": + if self.filesystem_type == FilesystemType.S3: self.input_prefix = s3_utils.s3_path_remove_trailing(loader_config.input_prefix) self.output_prefix = s3_utils.s3_path_remove_trailing(loader_config.output_prefix) else: @@ -254,6 +260,8 @@ def __init__( self.skip_train_masks = False self.pre_computed_transformations = loader_config.precomputed_transformations + print(f"Output prefix once inside DGHL: {self.output_prefix}") + def process_and_write_graph_data( self, data_configs: Mapping[str, Sequence[StructureConfig]] ) -> ProcessedGraphRepresentation: @@ -343,12 +351,14 @@ def process_and_write_graph_data( metadata_dict["graph_info"] = self._finalize_graphinfo_dict(metadata_dict) # The metadata dict is written to disk as a JSON file - with open(os.path.join(self.output_path, "metadata.json"), "w", encoding="utf-8") as f: + with open( + os.path.join(self.local_meta_output_path, "metadata.json"), "w", encoding="utf-8" + ) as f: json.dump(metadata_dict, f, indent=4) # Write the transformations file with open( - os.path.join(self.output_path, TRANSFORMATIONS_FILENAME), + os.path.join(self.local_meta_output_path, TRANSFORMATIONS_FILENAME), "w", encoding="utf-8", ) as f: @@ -358,7 +368,7 @@ def process_and_write_graph_data( # name did not fit Parquet requirements if len(self.column_substitutions) > 0: with open( - os.path.join(self.output_path, 
"column_substitutions.json"), + os.path.join(self.local_meta_output_path, "column_substitutions.json"), "w", encoding="utf-8", ) as f: @@ -545,7 +555,7 @@ def _write_df( NotImplementedError If an output format other than "csv" or "parquet" is requested. """ - if self.filesystem_type == "s3": + if self.filesystem_type == FilesystemType.S3: output_bucket, output_prefix = s3_utils.extract_bucket_and_key(full_output_path) else: output_bucket = "" @@ -559,6 +569,7 @@ def _write_df( if out_format == "parquet": # Write to parquet input_df = self._replace_special_chars_in_cols(input_df) + # TODO: Remove overwrite mode input_df.write.mode("overwrite").parquet(os.path.join(full_output_path, "parquet")) prefix_with_format = os.path.join(output_prefix, "parquet") elif out_format == "csv": @@ -589,30 +600,26 @@ def csv_row(data: Row): # So we first get the full paths, then the common prefix, # then strip the common prefix from the full paths, # to leave paths relative to where the metadata will be written. - if self.filesystem_type == "s3": + if self.filesystem_type == FilesystemType.S3: object_key_list = s3_utils.list_s3_objects(output_bucket, prefix_with_format) else: object_key_list = [ os.path.join(prefix_with_format, f) for f in os.listdir(prefix_with_format) ] + # Ensure key list is sorted, to maintain any order in DF + object_key_list.sort() + assert ( object_key_list ), f"No files found written under: {output_bucket}/{prefix_with_format}" - # Only include data files and strip the common output path prefix from the key filtered_key_list = [] - if self.filesystem_type == "s3": - # Get the S3 key prefix without the bucket - common_prefix = self.output_prefix.split("/", maxsplit=3)[3] - else: - common_prefix = self.output_prefix for key in object_key_list: if key.endswith(".csv") or key.endswith(".parquet"): - chars_to_skip = len(common_prefix) - key_without_prefix = key[chars_to_skip:].lstrip("/") - filtered_key_list.append(key_without_prefix) + # Only include data files and strip the common output path prefix from the key + filtered_key_list.append(self._strip_common_prefix(key)) logging.info( "Wrote %d files to %s, (%d requested)", @@ -622,6 +629,68 @@ def csv_row(data: Row): ) return filtered_key_list + @staticmethod + def _create_metadata_entry(path_list: Sequence[str]) -> dict: + return { + "format": {"name": FORMAT_NAME, "delimiter": DELIMITER}, + "data": path_list, + } + + # TODO: Enable writing table to multiple Parquet files + def _write_pyarrow_table(self, pa_table: pa.Table, out_path: str) -> list[str]: + """Writes a single PyArrow Table to Parquet. + + Parameters + ---------- + pa_table : pa.Table + PyArrow Table to write to storage. + out_path : str + Full filepath to write. Can be an S3 URI or local path. + + Returns + ------- + list[str] + The list of files written, paths being relative to ``self.output_prefix``. + """ + + # TODO: Remove once we support writing multiple files + assert out_path.endswith(".parquet") + if self.filesystem_type == FilesystemType.LOCAL: + os.makedirs(os.path.dirname(out_path), exist_ok=True) + + pq.write_table( + pa_table, + where=out_path, + ) + print(f"{out_path=}") + + if self.filesystem_type == FilesystemType.S3: + _, out_path = s3_utils.extract_bucket_and_key(out_path) + + return [self._strip_common_prefix(out_path)] + + def _strip_common_prefix(self, full_path: str) -> str: + """Strips the common prefix from the full path. + + Parameters + ---------- + full_path : str + Full path to the file, including the common prefix. 
+ + Returns + ------- + str + The path without the common prefix. + """ + if self.filesystem_type == FilesystemType.S3: + # Get the S3 key prefix without the bucket + common_prefix = self.output_prefix.split("/", maxsplit=3)[3] + else: + common_prefix = self.output_prefix + + key_without_prefix = full_path.replace(common_prefix, "").lstrip("/") + return key_without_prefix + def _add_node_mappings_to_metadata(self, metadata_dict: Dict) -> Dict: """ Adds node mappings to the metadata dict that is eventually written to disk. @@ -1212,8 +1281,17 @@ def _process_node_labels( self.graph_info["task_type"] = ( "node_class" if label_conf.task_type == "classification" else "node_regression" ) + # We only should re-order for classification + if label_conf.task_type == "classification": + order_col = NODE_MAPPING_INT + assert ( + order_col in nodes_df.columns + ), f"Order column '{order_col}' not found in node dataframe, {nodes_df.columns=}" + else: + order_col = None + self.graph_info["is_multilabel"] = label_conf.multilabel - node_label_loader = DistLabelLoader(label_conf, self.spark) + node_label_loader = DistLabelLoader(label_conf, self.spark, order_col) logging.info( "Processing label data for node type %s, label col: %s...", node_type, @@ -1224,10 +1302,36 @@ def _process_node_labels( self.graph_info["label_map"] = node_label_loader.label_map label_output_path = ( - f"{self.output_prefix}/node_data/" f"{node_type}-label-{label_conf.label_column}" + f"{self.output_prefix}/node_data/{node_type}-label-{label_conf.label_column}" ) - path_list = self._write_df(transformed_label, label_output_path) + if label_conf.task_type == "classification": + assert order_col, "An order column is needed to process classification labels." + # For classification we need to order the DF, collect to Pandas + # and write to storage directly + logging.info( + "Collecting and sorting label data for node type '%s', label col: '%s'...", + node_type, + label_conf.label_column, + ) + # The presence of order_col ensures transformed_label DF comes in ordered + # but do we want to double-check before writing? 
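# (Reviewer sketch, not part of the patch: if we do want the double-check mentioned
#  above, one cheap option after the toPandas() call below would be to assert that
#  the collected frame is still sorted by the order column, e.g.
#      assert transformed_label_pd[order_col].is_monotonic_increasing, (
#          "Label rows lost their node-id ordering before writing"
#      )
#  assuming order_col holds the integer node ids.)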
+ transformed_label_pd = transformed_label.select( + label_conf.label_column, order_col + ).toPandas() + + # Write to parquet using zero-copy column values from Pandas DF + path_list = self._write_pyarrow_table( + pa.Table.from_arrays( + [transformed_label_pd[label_conf.label_column].values], + names=[label_conf.label_column], + ), + f"{label_output_path}/part-00000-{uuid4()}.parquet", + ) + else: + path_list = self._write_df( + transformed_label.select(label_conf.label_column), label_output_path + ) label_metadata_dict = { "format": {"name": FORMAT_NAME, "delimiter": DELIMITER}, @@ -1261,6 +1365,7 @@ def _process_node_labels( ) else: custom_split_filenames = None + label_split_dicts = self._create_split_files( nodes_df, label_conf.label_column, @@ -1268,6 +1373,7 @@ def _process_node_labels( split_masks_output_prefix, custom_split_filenames, mask_field_names=label_conf.mask_field_names, + order_col=order_col, ) node_type_label_metadata.update(label_split_dicts) @@ -1915,6 +2021,7 @@ def _create_split_files( custom_split_file: Optional[CustomSplit] = None, seed: Optional[int] = None, mask_field_names: Optional[tuple[str, str, str]] = None, + order_col: Optional[str] = None, ) -> Dict: """ Given an input dataframe and a list of split rates or a list of custom split files @@ -1949,35 +2056,83 @@ def _create_split_files( The metadata dict elements for the train/test/val masks, to be added to the caller's edge/node type metadata. """ - # If the user did not provide a split rate we use a default + + # Use custom column names if requested + if mask_field_names: + mask_names = mask_field_names + else: + mask_names = ("train_mask", "val_mask", "test_mask") + + # TODO: Make this an argument to the write functions to not rely on function scope? split_metadata = {} + + def write_masks_numpy(np_mask_arrays: Sequence[np.ndarray]): + """Write the 3 mask files to storage from numpy arrays.""" + for mask_name, mask_vals in zip(mask_names, np_mask_arrays): + mask_full_outpath = f"{output_path}-{mask_name}.parquet" + path_list = self._write_pyarrow_table( + pa.Table.from_arrays([mask_vals], names=[mask_name]), + mask_full_outpath, + ) + + split_metadata[mask_name] = self._create_metadata_entry(path_list) + + def write_masks_spark(mask_dfs: Sequence[DataFrame]): + """Write the 3 mask files to storage from Spark DataFrames.""" + for mask_name, mask_df in zip(mask_names, mask_dfs): + out_path_list = self._write_df( + mask_df.select(F.col(mask_name).cast(ByteType()).alias(mask_name)), + f"{output_path}-{mask_name}", + ) + split_metadata[mask_name] = self._create_metadata_entry(out_path_list) + if not custom_split_file: - mask_dfs = self._create_split_files_split_rates( - input_df, label_column, split_rates, seed, mask_field_names + # No custom split file, we create masks according to split rates + masks_single_df = self._create_split_files_split_rates( + input_df, + label_column, + split_rates, + seed, + order_col, ) + + if order_col: + # Classification case, masks had to be re-ordered to maintain same order as labels + combined_masks_pandas = masks_single_df.select( + [DATA_SPLIT_SET_MASK_COL, label_column] + ).toPandas() + # Convert mask column of [x, x, x] 0/1 lists to 3 numpy 0-dim arrays + # NOTE: We request zero-copy numpy conversion but it's not guaranteed, see + # https://pandas.pydata.org/docs/reference/api/pandas.Series.to_numpy.html#pandas.Series.to_numpy + mask_array: np.ndarray = np.stack( + combined_masks_pandas[DATA_SPLIT_SET_MASK_COL].to_numpy(copy=False) + ) + train_mask_np = mask_array[:, 
0].astype(np.int8) + val_mask_np = mask_array[:, 1].astype(np.int8) + test_mask_np = mask_array[:, 2].astype(np.int8) + write_masks_numpy([train_mask_np, val_mask_np, test_mask_np]) + else: + # Regression/LP case, no requirement to maintain order + # TODO: Ensure order is maintained for regression labels + train_mask_df = masks_single_df.select( + F.col(DATA_SPLIT_SET_MASK_COL)[0].alias(mask_names[0]) + ) + val_mask_df = masks_single_df.select( + F.col(DATA_SPLIT_SET_MASK_COL)[1].alias(mask_names[1]) + ) + test_mask_df = masks_single_df.select( + F.col(DATA_SPLIT_SET_MASK_COL)[2].alias(mask_names[2]) + ) + write_masks_spark([train_mask_df, val_mask_df, test_mask_df]) else: mask_dfs = self._create_split_files_custom_split( input_df, custom_split_file, mask_field_names ) + write_masks_spark(mask_dfs) - def create_metadata_entry(path_list): - return { - "format": {"name": FORMAT_NAME, "delimiter": DELIMITER}, - "data": path_list, - } - - if mask_field_names is not None: - mask_names = mask_field_names - else: - mask_names = ("train_mask", "val_mask", "test_mask") - - # Write each mask DF to disk with appropriate name - for mask_name, mask_df in zip(mask_names, mask_dfs): - out_path_list = self._write_df( - mask_df.select(F.col(mask_name).cast(ByteType()).alias(mask_name)), - f"{output_path}-{mask_name}", - ) - split_metadata[mask_name] = create_metadata_entry(out_path_list) + assert split_metadata.keys() == { + *mask_names + }, "We expect the produced metadata to contain all mask entries" return split_metadata @@ -1987,8 +2142,8 @@ def _create_split_files_split_rates( label_column: str, split_rates: Optional[SplitRates], seed: Optional[int], - mask_field_names: Optional[tuple[str, str, str]] = None, - ) -> tuple[DataFrame, DataFrame, DataFrame]: + order_col: Optional[str] = None, + ) -> DataFrame: """ Creates the train/val/test mask dataframe based on split rates. @@ -2005,15 +2160,14 @@ def _create_split_files_split_rates( If None, a default split rate of 0.8:0.1:0.1 is used. seed: Optional[int] An optional random seed for reproducibility. - mask_field_names: Optional[tuple[str, str, str]] - An optional tuple of field names to use for the split masks. - If not provided, the default field names "train_mask", - "val_mask", and "test_mask" are used. + order_col: Optional[str] + A column to order the output by. Required for classification tasks. Returns ------- - tuple[DataFrame, DataFrame, DataFrame] - Train/val/test mask DataFrames. + DataFrame + DataFrame containing train/val/test masks as single column named as + `constants.DATA_SPLIT_SET_MASK_COL` """ if split_rates is None: split_rates = SplitRates(train_rate=0.8, val_rate=0.1, test_rate=0.1) @@ -2043,28 +2197,25 @@ def multinomial_sample(label_col: str) -> Sequence[int]: return [0, 0, 0] return rng.multinomial(1, split_list).tolist() - group_col_name = "sample_boolean_mask" # TODO: Ensure uniqueness of column? 
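# (Reviewer note, not part of the patch: rng.multinomial(1, split_list) draws a single
#  trial over the train/val/test buckets, i.e. a one-hot [train, val, test] vector.
#  Standalone illustration:
#      import numpy as np
#      rng = np.random.default_rng(42)
#      rng.multinomial(1, [0.8, 0.1, 0.1]).tolist()  # e.g. [1, 0, 0]
#  Rows whose label is missing return [0, 0, 0] above, so they land in no mask,
#  which is what check_mask_nan_distribution verifies in the tests.)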
- - # TODO: Use PandasUDF and check if it is faster than UDF + # Note: Using PandasUDF here only led to much worse performance split_group = F.udf(multinomial_sample, ArrayType(IntegerType())) # Convert label col to string and apply UDF # to create one-hot vector indicating train/test/val membership input_col = F.col(label_column).astype("string") if label_column else F.lit("dummy") - int_group_df = input_df.select(split_group(input_col).alias(group_col_name)) + int_group_df = input_df.select( + split_group(input_col).alias(DATA_SPLIT_SET_MASK_COL), *input_df.columns + ) - # We cache because we re-use this DF 3 times - int_group_df.cache() - # Use custom column names if requested - if mask_field_names: - mask_names = mask_field_names - else: - mask_names = ("train_mask", "val_mask", "test_mask") + if order_col: + assert ( + order_col in input_df.columns + ), f"Order column {order_col} not found in {int_group_df.columns}" + int_group_df = int_group_df.orderBy(order_col) - train_mask_df = int_group_df.select(F.col(group_col_name)[0].alias(mask_names[0])) - val_mask_df = int_group_df.select(F.col(group_col_name)[1].alias(mask_names[1])) - test_mask_df = int_group_df.select(F.col(group_col_name)[2].alias(mask_names[2])) + # We cache because we re-use this DF + int_group_df.cache() - return train_mask_df, val_mask_df, test_mask_df + return int_group_df def _create_split_files_custom_split( self, diff --git a/graphstorm-processing/graphstorm_processing/graph_loaders/row_count_utils.py b/graphstorm-processing/graphstorm_processing/graph_loaders/row_count_utils.py index 63c1c2e3d1..b2b72c0c75 100644 --- a/graphstorm-processing/graphstorm_processing/graph_loaders/row_count_utils.py +++ b/graphstorm-processing/graphstorm_processing/graph_loaders/row_count_utils.py @@ -24,6 +24,7 @@ import pyarrow.parquet as pq from pyarrow import fs +from graphstorm_processing.constants import FilesystemType from ..data_transformations import s3_utils # pylint: disable=relative-beyond-top-level @@ -43,11 +44,11 @@ class ParquetRowCounter: The filesystem type. Can be 'local' or 's3'. """ - def __init__(self, metadata_dict: dict, output_prefix: str, filesystem_type: str): + def __init__(self, metadata_dict: dict, output_prefix: str, filesystem_type: FilesystemType): self.output_prefix = output_prefix self.filesystem_type = filesystem_type self.metadata_dict = metadata_dict - if self.filesystem_type == "s3": + if self.filesystem_type == FilesystemType.S3: output_bucket, _ = s3_utils.extract_bucket_and_key(output_prefix) bucket_region = s3_utils.get_bucket_region(output_bucket) # Increase default retries because we are likely to run into @@ -106,7 +107,7 @@ def get_row_count_for_parquet_file(self, relative_parquet_file_path: str) -> int int The number of rows in the Parquet file. 
""" - if self.filesystem_type == "s3": + if self.filesystem_type == FilesystemType.S3: file_bucket, file_key = s3_utils.extract_bucket_and_key( self.output_prefix, relative_parquet_file_path ) diff --git a/graphstorm-processing/graphstorm_processing/graph_loaders/schema_utils.py b/graphstorm-processing/graphstorm_processing/graph_loaders/schema_utils.py index 19ed03869d..9f11c46adc 100644 --- a/graphstorm-processing/graphstorm_processing/graph_loaders/schema_utils.py +++ b/graphstorm-processing/graphstorm_processing/graph_loaders/schema_utils.py @@ -22,9 +22,9 @@ from pyspark.sql.types import StructType, StructField, StringType, DataType, DoubleType -from ..config.config_parser import EdgeConfig, NodeConfig -from ..config.label_config_base import LabelConfig -from ..config.feature_config_base import FeatureConfig +from graphstorm_processing.config.config_parser import EdgeConfig, NodeConfig +from graphstorm_processing.config.label_config_base import LabelConfig +from graphstorm_processing.config.feature_config_base import FeatureConfig def parse_edge_file_schema(edge_config: EdgeConfig) -> StructType: diff --git a/graphstorm-processing/graphstorm_processing/repartition_files.py b/graphstorm-processing/graphstorm_processing/repartition_files.py index 2f48a61d99..f9d4d17d67 100644 --- a/graphstorm-processing/graphstorm_processing/repartition_files.py +++ b/graphstorm-processing/graphstorm_processing/repartition_files.py @@ -89,7 +89,10 @@ def __init__( self.input_prefix = input_prefix[5:] if input_prefix.startswith("s3://") else input_prefix self.filesystem_type = filesystem_type if self.filesystem_type == FilesystemType.S3: - self.bucket = self.input_prefix.split("/")[1] + # Expected input is bucket/path/to/file, no s3:// prefix + self.bucket = self.input_prefix.split("/")[0] + if not region: + region = s3_utils.get_bucket_region(self.bucket) self.pyarrow_fs = fs.S3FileSystem( region=region, retry_strategy=fs.AwsDefaultS3RetryStrategy(max_attempts=10), diff --git a/graphstorm-processing/tests/test_dist_heterogenous_loader.py b/graphstorm-processing/tests/test_dist_heterogenous_loader.py index f9a8521f75..871a8196ac 100644 --- a/graphstorm-processing/tests/test_dist_heterogenous_loader.py +++ b/graphstorm-processing/tests/test_dist_heterogenous_loader.py @@ -20,8 +20,12 @@ import os import shutil import tempfile +from uuid import uuid4 +import numpy as np +import pandas as pd from numpy.testing import assert_allclose +from pandas.testing import assert_frame_equal from pyspark.sql import SparkSession, DataFrame import pyspark.sql.functions as F import pyarrow.parquet as pq @@ -45,8 +49,8 @@ from graphstorm_processing.config.config_conversion import GConstructConfigConverter from graphstorm_processing.constants import ( COLUMN_NAME, - MIN_VALUE, MAX_VALUE, + MIN_VALUE, VALUE_COUNTS, TRANSFORMATIONS_FILENAME, ) @@ -234,7 +238,7 @@ def verify_integ_test_output( for node_type in metadata["node_type"]: nrows = pq.read_table( os.path.join( - loader.output_path, + loader.local_meta_output_path, os.path.dirname(metadata["raw_id_mappings"][node_type]["data"][0]), ) ).num_rows @@ -252,7 +256,7 @@ def verify_integ_test_output( for edge_type in metadata["edge_type"]: nrows = pq.read_table( os.path.join( - loader.output_path, + loader.local_meta_output_path, os.path.dirname(metadata["edges"][edge_type]["data"][0]), ) ).num_rows @@ -288,7 +292,7 @@ def test_load_dist_heterogen_node_class(dghl_loader: DistHeterogeneousGraphLoade dghl_loader.load() with open( - os.path.join(dghl_loader.output_path, 
"metadata.json"), "r", encoding="utf-8" + os.path.join(dghl_loader.local_meta_output_path, "metadata.json"), "r", encoding="utf-8" ) as mfile: metadata = json.load(mfile) @@ -314,7 +318,7 @@ def test_load_dist_heterogen_node_class(dghl_loader: DistHeterogeneousGraphLoade assert metadata["node_data"][node_type].keys() == expected_node_data[node_type] with open( - os.path.join(dghl_loader.output_path, TRANSFORMATIONS_FILENAME), + os.path.join(dghl_loader.local_meta_output_path, TRANSFORMATIONS_FILENAME), "r", encoding="utf-8", ) as transformation_file: @@ -331,7 +335,7 @@ def test_load_dist_hgl_without_labels( dghl_loader_no_label.load() with open( - os.path.join(dghl_loader_no_label.output_path, "metadata.json"), + os.path.join(dghl_loader_no_label.local_meta_output_path, "metadata.json"), "r", encoding="utf-8", ) as mfile: @@ -410,7 +414,9 @@ def test_create_all_mapppings_from_edges( assert len(dghl_loader.node_mapping_paths) == 4 for node_type, mapping_files in dghl_loader.node_mapping_paths.items(): - files_with_prefix = [os.path.join(dghl_loader.output_path, x) for x in mapping_files] + files_with_prefix = [ + os.path.join(dghl_loader.local_meta_output_path, x) for x in mapping_files + ] mapping_count = spark.read.parquet(*files_with_prefix).count() assert mapping_count == expected_node_counts[node_type] @@ -515,6 +521,46 @@ def create_edges_df_num_label(spark: SparkSession, missing_data_points: int) -> return df +def create_nodes_df_num_labels( + spark: SparkSession, total_data_points=NUM_DATAPOINTS, missing_data_points=0 +) -> DataFrame: + """Create a nodes DF with a numeric and a string label for testing. + + Returned DF schema: + + NODE_MAPPING_STR: node_str_ids, unique per node (uuid4), + NODE_MAPPING_INT: node_int_ids, unique per node, + NUM_LABEL_COL: numeric labels, 0-9 range, + STR_LABEL_COL: string labels, 10 random string labels, + """ + node_str_ids = [str(uuid4()) for _ in range(total_data_points)] + node_int_ids = np.arange(total_data_points) + + rng = np.random.default_rng(42) + # Create random numerical labels with values 0-9 + num_labels = rng.integers(10, size=total_data_points) + # Create random string labels from a pool of 10 uuid labels + str_label_vals = [str(uuid4()) for _ in range(10)] + str_labels = [str_label_vals[rng.integers(10)] for _ in range(total_data_points)] + + # Set certain number of labels to be missing + for _ in range(missing_data_points): + str_labels[rng.integers(total_data_points)] = None + num_labels[rng.integers(total_data_points)] = None + + pandas_df = pd.DataFrame.from_dict( + { + NODE_MAPPING_STR: node_str_ids, + NODE_MAPPING_INT: node_int_ids, + NUM_LABEL_COL: num_labels, + STR_LABEL_COL: str_labels, + } + ) + df = spark.createDataFrame(pandas_df) + + return df + + def ensure_masks_are_correct( train_mask_df: DataFrame, val_mask_df: DataFrame, @@ -782,7 +828,7 @@ def test_process_node_labels_multitask( spark: SparkSession, dghl_loader: DistHeterogeneousGraphLoader ): """Test processing multi-task link prediction and regression edge labels""" - nodes_df = create_edges_df_num_label(spark, 0) + nodes_df = create_nodes_df_num_labels(spark) class_split_rates = {"train": 0.6, "val": 0.3, "test": 0.1} reg_split_rates = {"train": 0.7, "val": 0.2, "test": 0.1} @@ -797,6 +843,7 @@ def test_process_node_labels_multitask( "split_rate": class_split_rates, "mask_field_names": class_mask_names, } + reg_mask_names = [ f"train_mask_{NUM_LABEL_COL}", f"val_mask_{NUM_LABEL_COL}", @@ -938,38 +985,45 @@ def test_update_label_properties_multilabel( def 
test_node_custom_label(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp_path): """Test using custom label splits for nodes""" - data = [(i,) for i in range(1, 11)] - # Create DataFrame - nodes_df = spark.createDataFrame(data, ["orig"]) - - train_df = spark.createDataFrame([(i,) for i in range(1, 6)], ["mask_id"]) - val_df = spark.createDataFrame([(i,) for i in range(6, 9)], ["mask_id"]) - test_df = spark.createDataFrame([(i,) for i in range(9, 11)], ["mask_id"]) + nodes_df = create_nodes_df_num_labels(spark, total_data_points=11) + # create_split_files_custom_split expects the NODE_MAPPING_STR values to + # match those provided in the custom mask files, in this case numbers 1-11. + # create_nodes_df_num_labels creates NODE_MAPPING_INT from 0-10, so we increment by one + # and assign to the NODE_MAPPING_STR column instead + nodes_df = nodes_df.withColumn(NODE_MAPPING_STR, F.col(NODE_MAPPING_INT) + 1) + + mask_col_name = "mask_id" + # Create test membership maps, which correspond to values of NODE_MAPPING_STR + train_df = spark.createDataFrame([(i,) for i in range(1, 6)], [mask_col_name]) + val_df = spark.createDataFrame([(i,) for i in range(6, 9)], [mask_col_name]) + test_df = spark.createDataFrame([(i,) for i in range(9, 11)], [mask_col_name]) train_df.repartition(1).write.parquet(f"{tmp_path}/train.parquet") val_df.repartition(1).write.parquet(f"{tmp_path}/val.parquet") test_df.repartition(1).write.parquet(f"{tmp_path}/test.parquet") - config_dict = { - "column": "orig", + label_config_dict = { + "column": NUM_LABEL_COL, "type": "classification", "split_rate": {"train": 0.8, "val": 0.1, "test": 0.1}, "custom_split_filenames": { "train": [f"{tmp_path}/train.parquet"], "valid": [f"{tmp_path}/val.parquet"], "test": [f"{tmp_path}/test.parquet"], - "column": ["mask_id"], + "column": [mask_col_name], }, } dghl_loader.input_prefix = "" - label_configs = [NodeLabelConfig(config_dict)] - label_metadata_dicts = dghl_loader._process_node_labels(label_configs, nodes_df, "orig") + label_configs = [NodeLabelConfig(label_config_dict)] + label_metadata_dicts = dghl_loader._process_node_labels( + label_configs, nodes_df, "dummy-node-type" + ) assert label_metadata_dicts.keys() == { "train_mask", "test_mask", "val_mask", - "orig", + NUM_LABEL_COL, } train_mask_df, val_mask_df, test_mask_df = read_masks_from_disk( @@ -1038,10 +1092,13 @@ def test_edge_custom_label(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp def test_node_custom_label_multitask(spark, dghl_loader: DistHeterogeneousGraphLoader, tmp_path): """Test using custom label splits for nodes""" - data = [(i,) for i in range(0, 1200)] - # Create DataFrame - nodes_df = spark.createDataFrame(data, ["orig"]) + nodes_df = create_nodes_df_num_labels(spark, total_data_points=1200) + # create_split_files_custom_split expects the NODE_MAPPING_STR values to + # match those provided in the custom mask files, in this case numbers 1-11. 
+ # create_nodes_df_num_labels creates NODE_MAPPING_INT from 0-10, so we increment by one + # and assign to the NODE_MAPPING_STR column instead + nodes_df = nodes_df.withColumn(NODE_MAPPING_STR, F.col(NODE_MAPPING_INT) + 1) train_df = spark.createDataFrame([(i,) for i in range(1, 1000)], ["mask_id"]) val_df = spark.createDataFrame([(i,) for i in range(1001, 1100)], ["mask_id"]) @@ -1057,7 +1114,7 @@ def test_node_custom_label_multitask(spark, dghl_loader: DistHeterogeneousGraphL ] # Will only do custom data split although provided split rate config_dict = { - "column": "orig", + "column": NUM_LABEL_COL, "type": "classification", "split_rate": {"train": 0.8, "val": 0.1, "test": 0.1}, "custom_split_filenames": { @@ -1074,7 +1131,7 @@ def test_node_custom_label_multitask(spark, dghl_loader: DistHeterogeneousGraphL "reg_test_mask", ] class_config_dict_split_rate = { - "column": "orig", + "column": STR_LABEL_COL, "type": "classification", "split_rate": {"train": 0.8, "val": 0.1, "test": 0.1}, "mask_field_names": class_mask_names_split_rate, @@ -1084,18 +1141,22 @@ def test_node_custom_label_multitask(spark, dghl_loader: DistHeterogeneousGraphL NodeLabelConfig(config_dict), NodeLabelConfig(class_config_dict_split_rate), ] - label_metadata_dicts = dghl_loader._process_node_labels(label_configs, nodes_df, "orig") + label_metadata_dicts = dghl_loader._process_node_labels( + label_configs, nodes_df, "dummy-node-type" + ) assert label_metadata_dicts.keys() == { *class_mask_names, *class_mask_names_split_rate, - "orig", + STR_LABEL_COL, + NUM_LABEL_COL, } train_mask_df, val_mask_df, test_mask_df = read_masks_from_disk( spark, dghl_loader, label_metadata_dicts, class_mask_names ) + # Check totals for masks train_total_ones = train_mask_df.agg(F.sum("custom_split_train_mask")).collect()[0][0] val_total_ones = val_mask_df.agg(F.sum("custom_split_val_mask")).collect()[0][0] test_total_ones = test_mask_df.agg(F.sum("custom_split_test_mask")).collect()[0][0] @@ -1103,26 +1164,22 @@ def test_node_custom_label_multitask(spark, dghl_loader: DistHeterogeneousGraphL assert val_total_ones == 99 assert test_total_ones == 99 - # Check the order of the train_mask_df - train_mask_df = train_mask_df.withColumn("order_check_id", F.monotonically_increasing_id()) - val_mask_df = val_mask_df.withColumn("order_check_id", F.monotonically_increasing_id()) - test_mask_df = test_mask_df.withColumn("order_check_id", F.monotonically_increasing_id()) - train_mask_df = train_mask_df.filter( - (F.col("order_check_id") > 0) & (F.col("order_check_id") < 1000) - ).drop("order_check_id") - val_mask_df = val_mask_df.filter( - (F.col("order_check_id") > 1000) & (F.col("order_check_id") < 1100) - ).drop("order_check_id") - test_mask_df = test_mask_df.filter( - (F.col("order_check_id") > 1100) & (F.col("order_check_id") < 1300) - ).drop("order_check_id") + # Check order of masks + train_mask_pd = train_mask_df.toPandas() + val_mask_pd = val_mask_df.toPandas() + test_mask_pd = test_mask_df.toPandas() - train_unique_rows = train_mask_df.distinct().collect() - assert len(train_unique_rows) == 1 and all(value == 1 for value in train_unique_rows[0]) - val_unique_rows = val_mask_df.distinct().collect() - assert len(val_unique_rows) == 1 and all(value == 1 for value in val_unique_rows[0]) - test_unique_rows = test_mask_df.distinct().collect() - assert len(test_unique_rows) == 1 and all(value == 1 for value in test_unique_rows[0]) + # We already know there's 999 ones in the train mask, so we can check the first 999 + # are 1 and the rest are 
guaranteed to be 0. NOTE: iloc ranges are zero-indexed + assert [1] * 999 == train_mask_pd["custom_split_train_mask"].iloc[:999].tolist() + + # We already know there's 99 ones in the val mask, so we can check the 99 + # after the train masks 1's are 1 and the rest will be 0 + assert [1] * 99 == val_mask_pd["custom_split_val_mask"].iloc[1000:1099].tolist() + + # Similarly, we know there's 99 ones in the test mask, so we can check the 99 values + # after the val masks 1's are 1 and the rest will be 0 + assert [1] * 99 == test_mask_pd["custom_split_test_mask"].iloc[1100:1199].tolist() # Check classification mask correctness train_mask_df, val_mask_df, test_mask_df = read_masks_from_disk( @@ -1248,3 +1305,172 @@ def test_edge_custom_label_multitask(spark, dghl_loader: DistHeterogeneousGraphL lp_mask_names, 3000, ) + + +def check_mask_nan_distribution( + df: pd.DataFrame, label_column: str, mask_name: str +) -> tuple[bool, str]: + """ + Check distribution of mask values (0/1) for NaN labels in a specific mask. + + Args: + df: DataFrame containing labels and masks + label_column: Name of the label column + mask_name: Name of the mask column to check + + Returns: + tuple: (has_error: bool, error_message: str) + has_error will be true if at least one NaN label has a mask value of 1. + error_message will contain information about the NaN label count and + the mask value distribution. + If no NaN labels correspond with 1 in the mask, error_message will be an empty string. + """ + nan_labels = df[label_column].isna() + mask_values = df[mask_name] + + # Count distribution of mask values for NaN labels + nan_dist = mask_values[nan_labels].value_counts().sort_index() + + if 1 in nan_dist: + return ( + True, + f"Found {nan_dist[1]} NaN labels with {mask_name}=1 " + f"(distribution of {mask_name} values for NaN labels: {nan_dist.to_dict()})", + ) + return (False, "") + + +def test_strip_common_prefix(dghl_loader: DistHeterogeneousGraphLoader): + """Test stripping common prefix from file paths.""" + stripped_path = dghl_loader._strip_common_prefix(f"{dghl_loader.output_prefix}/path/to/file") + + assert stripped_path == "path/to/file" + + stripped_path = dghl_loader._strip_common_prefix("/path/to/file") + assert stripped_path == "path/to/file" + + +def test_node_dist_label_order_partitioned( + spark: SparkSession, + dghl_loader: DistHeterogeneousGraphLoader, +): + """Test that label and mask order is maintained after label processing. + + NOTE: Tests DGHL code together with DistLabelLoader code because their results are coupled. + """ + label_col = STR_LABEL_COL + + # Create a Pandas DF with a label column with 10k "zero", 10k "one", 10k None rows + num_datapoints = 10**4 + ids = list(range(3 * num_datapoints)) + data_zeros = ["zero" for _ in range(num_datapoints)] + data_ones = ["one" for _ in range(num_datapoints)] + data_nan = [None for _ in range(num_datapoints)] + data = data_zeros + data_ones + data_nan + # Create DF with label data that contains "zero", "one", None values + # and a set of unique IDs that we treat as strings + pandas_input = pd.DataFrame.from_dict({label_col: data, NODE_MAPPING_STR: ids}) + # We shuffle the rows so that "zero", "one" and None values are mixed and not continuous + pandas_shuffled = pandas_input.sample(frac=1, random_state=42).reset_index(drop=True) + # Then we assign a sequential numerical ID that we use as an order identifier + # DGHL by default uses `NODE_MAPPING_INT` for the name of this column, so we + # use it here as well. 
+ order_col = NODE_MAPPING_INT + pandas_shuffled[order_col] = ids + names_df = spark.createDataFrame(pandas_shuffled) + + # Consistently shuffle the DF to multiple partitions + names_df_repart = names_df.repartition(64) + + assert names_df_repart.rdd.getNumPartitions() == 64 + # Now we re-order by the order column, this way we have multiple partitions, + # but the incoming data are ordered by node id. + # This emulates the input DF that would result from the str-to-int node id + # mapping + names_df_repart = names_df_repart.sort(order_col) + + # Convert the partitioned/shuffled DF to pandas for test verification + names_df_repart_pd = names_df_repart.toPandas() + + classification_config_dict = { + "column": STR_LABEL_COL, + "type": "classification", + "split_rate": {"train": 0.8, "val": 0.1, "test": 0.1}, + } + label_configs = [ + NodeLabelConfig(classification_config_dict), + ] + + # Process the label, create masks and write the output DFs, + # we will read it from disk to emulate real downstream scenario + label_metadata_dicts = dghl_loader._process_node_labels( + label_configs, names_df_repart, "dummy-node-type" + ) + + assert label_metadata_dicts.keys() == { + STR_LABEL_COL, + "train_mask", + "val_mask", + "test_mask", + } + + # Apply transformation in Pandas to check against Spark, + # ensuring we use the same replacements as DGHL loader applied + label_map = dghl_loader.graph_info["label_map"] + expected_transformed_pd = names_df_repart_pd.replace( + { + "zero": label_map["zero"], + "one": label_map["one"], + } + ) + + def read_pandas_from_relative_list(file_list) -> pd.DataFrame: + full_paths = [f"{dghl_loader.output_prefix}/{single_file}" for single_file in file_list] + return pq.read_table(*full_paths).to_pandas() + + label_files = label_metadata_dicts[STR_LABEL_COL]["data"] + # These are the transformed label value, read in as a Pandas DF + actual_transformed_label_pd: pd.DataFrame = read_pandas_from_relative_list(label_files) + + # Expect the label values to match those in our local Pandas conversion, in-order + assert_frame_equal( + actual_transformed_label_pd.loc[:, [label_col]], + expected_transformed_pd.loc[:, [label_col]], + check_dtype=False, + ) + + # Now let's test the mask transformation, read the produced mask values as pandas DFs + train_mask = read_pandas_from_relative_list(label_metadata_dicts["train_mask"]["data"]) + val_mask = read_pandas_from_relative_list(label_metadata_dicts["val_mask"]["data"]) + test_mask = read_pandas_from_relative_list(label_metadata_dicts["test_mask"]["data"]) + + mask_names = ["train_mask", "val_mask", "test_mask"] + masks_and_names = zip([train_mask, val_mask, test_mask], mask_names) + + # Add mask values to the label pandas DF as individual columns + # This allows us to easily check label values against mask values + for mask, mask_name in masks_and_names: + actual_transformed_label_pd[mask_name] = mask + + # Check every mask value against the label values, and report errors + # if there exists a mask that has value 1 in a location where the label is NaN + errors = [] + for mask_name in mask_names: + has_error, error_msg = check_mask_nan_distribution( + actual_transformed_label_pd, + label_col, + mask_name, + ) + if has_error: + errors.append(error_msg) + + # TODO: Check the approximated numbers of 1s we expected in the masks + + # If any issues were found, raise error with grouped values as info + if errors: + # Perform the groupby operation for a human-friendly printout + grouped_values = 
actual_transformed_label_pd.groupby([label_col], dropna=False).agg( + {i: "value_counts" for i in ["train_mask", "val_mask", "test_mask"]} + ) + print(grouped_values) + raise ValueError("\n".join(errors)) diff --git a/graphstorm-processing/tests/test_dist_label_loader.py b/graphstorm-processing/tests/test_dist_label_loader.py index ce5e6b0fe7..e700d65773 100644 --- a/graphstorm-processing/tests/test_dist_label_loader.py +++ b/graphstorm-processing/tests/test_dist_label_loader.py @@ -17,11 +17,12 @@ import numpy as np from pyspark.sql import DataFrame, SparkSession -from pyspark.sql.types import StructField, StructType, StringType +from pyspark.sql.types import IntegerType, StructField, StructType, StringType from graphstorm_processing.data_transformations.dist_label_loader import DistLabelLoader from graphstorm_processing.config.label_config_base import LabelConfig +from graphstorm_processing.constants import NODE_MAPPING_INT def test_dist_classification_label(spark: SparkSession, check_df_schema): @@ -34,15 +35,17 @@ def test_dist_classification_label(spark: SparkSession, check_df_schema): } data = [ - ("mark",), - ("john",), - ("tara",), - ("jen",), - (None,), + ("mark", 0), + ("john", 1), + ("tara", 2), + ("jen", 3), + (None, 4), ] - names_df = spark.createDataFrame(data, schema=[label_col]) + names_df = spark.createDataFrame(data, schema=[label_col, NODE_MAPPING_INT]) - label_transformer = DistLabelLoader(LabelConfig(classification_config), spark) + label_transformer = DistLabelLoader( + LabelConfig(classification_config), spark, order_col=NODE_MAPPING_INT + ) transformed_labels = label_transformer.process_label(names_df) @@ -104,12 +107,19 @@ def test_dist_multilabel_classification(spark: SparkSession, check_df_schema): "separator": "|", } - data = [("1|2",), ("3|4",), ("5|6",), ("7|8",), ("NaN",)] + data = [("1|2", 0), ("3|4", 1), ("5|6", 2), ("7|8", 3), ("NaN", 4)] - schema = StructType([StructField("ratings", StringType(), True)]) + schema = StructType( + [ + StructField(label_col, StringType(), True), + StructField(NODE_MAPPING_INT, IntegerType(), True), + ] + ) label_df = spark.createDataFrame(data, schema=schema) - label_transformer = DistLabelLoader(LabelConfig(multilabel_config), spark) + label_transformer = DistLabelLoader( + LabelConfig(multilabel_config), spark, order_col=NODE_MAPPING_INT + ) transformed_labels = label_transformer.process_label(label_df) @@ -117,7 +127,7 @@ def test_dist_multilabel_classification(spark: SparkSession, check_df_schema): assert set(label_map.keys()) == {"1", "2", "3", "4", "5", "6", "7", "8", "NaN"} - check_df_schema(transformed_labels) + check_df_schema(transformed_labels.select(label_col)) transformed_rows = transformed_labels.collect()