Commit
Merge pull request #134 from amosproj/feature/#106-Gaussian-Smoothing
#106 Gaussian Smoothing
kristen149 authored Jan 28, 2025
2 parents 06b9977 + 21a6c6a commit ad015b4
Showing 4 changed files with 232 additions and 0 deletions.
1 change: 1 addition & 0 deletions sdk/code-reference/pipelines/data_quality/data_manipulation/spark/gaussian_smoothing.md (new documentation page; path taken from the mkdocs.yml nav entry below)
@@ -0,0 +1 @@
::: src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.gaussian_smoothing
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -249,6 +249,7 @@ nav:
- Duplicate Detection: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/duplicate_detection.md
- Out of Range Value Filter: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/out_of_range_value_filter.md
- Flatline Filter: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/flatline_filter.md
- Gaussian Smoothing: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/gaussian_smoothing.md
- Dimensionality Reduction: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/dimensionality_reduction.md
- Interval Filtering: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/interval_filtering.md
- K-Sigma Anomaly Detection: sdk/code-reference/pipelines/data_quality/data_manipulation/spark/k_sigma_anomaly_detection.md
88 changes: 88 additions & 0 deletions src/sdk/python/rtdip_sdk/pipelines/data_quality/data_manipulation/spark/gaussian_smoothing.py (path taken from the module reference above and the test import below)
@@ -0,0 +1,88 @@
import numpy as np
from pyspark.sql.types import FloatType
from scipy.ndimage import gaussian_filter1d
from pyspark.sql import DataFrame as PySparkDataFrame, Window
from pyspark.sql import functions as F

from src.sdk.python.rtdip_sdk.pipelines._pipeline_utils.models import (
    Libraries,
    SystemType,
)
from ..interfaces import DataManipulationBaseInterface


class GaussianSmoothing(DataManipulationBaseInterface):
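    """
    Applies Gaussian smoothing to the value column of a PySpark DataFrame,
    either along time within each series ("temporal" mode) or across series
    at each timestamp ("spatial" mode).

    Parameters:
        df: Input PySpark DataFrame.
        sigma: Standard deviation of the Gaussian kernel; must be positive.
        mode: "temporal" or "spatial".
        id_col: Name of the series identifier column.
        timestamp_col: Name of the event time column.
        value_col: Name of the column to smooth.
    """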
    def __init__(
        self,
        df: PySparkDataFrame,
        sigma: float,
        mode: str = "temporal",
        id_col: str = "id",
        timestamp_col: str = "timestamp",
        value_col: str = "value",
    ) -> None:
        if not isinstance(df, PySparkDataFrame):
            raise TypeError("df must be a PySpark DataFrame")
        if not isinstance(sigma, (int, float)) or sigma <= 0:
            raise ValueError("sigma must be a positive number")
        if mode not in ["temporal", "spatial"]:
            raise ValueError("mode must be either 'temporal' or 'spatial'")

        if id_col not in df.columns:
            raise ValueError(f"Column {id_col} not found in DataFrame")
        if timestamp_col not in df.columns:
            raise ValueError(f"Column {timestamp_col} not found in DataFrame")
        if value_col not in df.columns:
            raise ValueError(f"Column {value_col} not found in DataFrame")

        self.df = df
        self.sigma = sigma
        self.mode = mode
        self.id_col = id_col
        self.timestamp_col = timestamp_col
        self.value_col = value_col

    @staticmethod
    def system_type():
        return SystemType.PYSPARK

    @staticmethod
    def libraries():
        libraries = Libraries()
        return libraries

    @staticmethod
    def settings() -> dict:
        return {}

    @staticmethod
    def create_gaussian_smoother(sigma_value):
        def apply_gaussian(values, index):
            # values is the ordered list of every value in the window
            # partition; index is this row's zero-based position in that list.
            if not values:
                return None
            values_array = np.array([float(v) for v in values])
            smoothed = gaussian_filter1d(values_array, sigma=sigma_value)
            # Return the smoothed value at the row's own position rather than
            # always taking the last element, so each row gets its own result.
            return float(smoothed[index])

        return apply_gaussian

    def filter(self) -> PySparkDataFrame:
        smooth_udf = F.udf(self.create_gaussian_smoother(self.sigma), FloatType())

        if self.mode == "temporal":
            # Smooth each series (id) along its own timeline.
            order_window = Window.partitionBy(self.id_col).orderBy(self.timestamp_col)
        else:  # spatial mode
            # Smooth across series (ids) at each individual timestamp.
            order_window = Window.partitionBy(self.timestamp_col).orderBy(self.id_col)

        # Every row sees the full ordered partition; row_number() (1-based)
        # locates the row within that list so the UDF can return the smoothed
        # value belonging to this row rather than a constant per partition.
        frame = order_window.rowsBetween(
            Window.unboundedPreceding, Window.unboundedFollowing
        )
        collect_list_expr = F.collect_list(F.col(self.value_col)).over(frame)
        index_expr = F.row_number().over(order_window) - 1

        return self.df.withColumn(
            self.value_col, smooth_udf(collect_list_expr, index_expr)
        )
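
For orientation, a minimal usage sketch of the new class (not part of the commit; the SparkSession setup and sample rows are illustrative, with column names taken from the tests below):

from pyspark.sql import SparkSession
from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.gaussian_smoothing import (
    GaussianSmoothing,
)

spark = SparkSession.builder.master("local[2]").getOrCreate()
df = spark.createDataFrame(
    [
        ("TagA", "2024-01-02 03:49:45.000", "Good", "0.13"),
        ("TagA", "2024-01-02 07:53:11.000", "Good", "0.12"),
    ],
    ["TagName", "EventTime", "Status", "Value"],
)

# Smooth each tag's values over time and replace the Value column in place.
smoothed_df = GaussianSmoothing(
    df=df,
    sigma=2.0,
    mode="temporal",
    id_col="TagName",
    timestamp_col="EventTime",
    value_col="Value",
).filter()
smoothed_df.show()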
142 changes: 142 additions & 0 deletions (new test module for GaussianSmoothing; its path is not shown in this capture)
@@ -0,0 +1,142 @@
# Copyright 2025 RTDIP
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
from pyspark.sql import SparkSession

from src.sdk.python.rtdip_sdk.pipelines.data_quality.data_manipulation.spark.gaussian_smoothing import (
    GaussianSmoothing,
)


@pytest.fixture(scope="session")
def spark_session():
    spark = (
        SparkSession.builder.master("local[2]")
        .appName("GaussianSmoothingTest")
        .getOrCreate()
    )
    yield spark
    spark.stop()


def test_gaussian_smoothing_temporal(spark_session: SparkSession):
    df = spark_session.createDataFrame(
        [
            ("A2PS64V0J.:ZUX09R", "2024-01-02 03:49:45.000", "Good", "0.129999995"),
            ("A2PS64V0J.:ZUX09R", "2024-01-02 07:53:11.000", "Good", "0.119999997"),
            ("A2PS64V0J.:ZUX09R", "2024-01-02 11:56:42.000", "Good", "0.129999995"),
            ("A2PS64V0J.:ZUX09R", "2024-01-02 16:00:12.000", "Good", "0.150000006"),
            ("A2PS64V0J.:ZUX09R", "2024-01-02 20:03:46.000", "Good", "0.340000004"),
        ],
        ["TagName", "EventTime", "Status", "Value"],
    )

    smoother = GaussianSmoothing(
        df=df,
        sigma=2.0,
        id_col="TagName",
        mode="temporal",
        timestamp_col="EventTime",
        value_col="Value",
    )
    result_df = smoother.filter()

    original_values = df.select("Value").collect()
    smoothed_values = result_df.select("Value").collect()

    assert (
        original_values != smoothed_values
    ), "Values should be smoothed and not identical"

    assert result_df.count() == df.count(), "Result should have same number of rows"


def test_gaussian_smoothing_spatial(spark_session: SparkSession):
    df = spark_session.createDataFrame(
        [
            ("A2PS64V0J.:ZUX09R", "2024-01-02 03:49:45.000", "Good", "0.129999995"),
            ("A2PS64V0J.:ZUX09R", "2024-01-02 07:53:11.000", "Good", "0.119999997"),
            ("A2PS64V0J.:ZUX09R", "2024-01-02 11:56:42.000", "Good", "0.129999995"),
            ("A2PS64V0J.:ZUX09R", "2024-01-02 16:00:12.000", "Good", "0.150000006"),
            ("A2PS64V0J.:ZUX09R", "2024-01-02 20:03:46.000", "Good", "0.340000004"),
        ],
        ["TagName", "EventTime", "Status", "Value"],
    )

    # Apply smoothing
    smoother = GaussianSmoothing(
        df=df,
        sigma=3.0,
        id_col="TagName",
        mode="spatial",
        timestamp_col="EventTime",
        value_col="Value",
    )
    result_df = smoother.filter()

    original_values = df.select("Value").collect()
    smoothed_values = result_df.select("Value").collect()

    assert (
        original_values != smoothed_values
    ), "Values should be smoothed and not identical"
    assert result_df.count() == df.count(), "Result should have same number of rows"


def test_gaussian_smoothing_large_data_set(spark_session: SparkSession):
    # Should not time out on a large input file
    base_path = os.path.dirname(__file__)
    file_path = os.path.join(base_path, "../../test_data.csv")

    df = spark_session.read.option("header", "true").csv(file_path)

    smoother = GaussianSmoothing(
        df=df,
        sigma=1,
        id_col="TagName",
        mode="temporal",
        timestamp_col="EventTime",
        value_col="Value",
    )

    actual_df = smoother.filter()
    assert (
        actual_df.count() == df.count()
    ), "Output should have same number of rows as input"


def test_gaussian_smoothing_invalid_mode(spark_session: SparkSession):
    # Create test data
    df = spark_session.createDataFrame(
        [
            ("A2PS64V0J.:ZUX09R", "2024-01-02 03:49:45.000", "Good", "0.129999995"),
            ("A2PS64V0J.:ZUX09R", "2024-01-02 07:53:11.000", "Good", "0.119999997"),
            ("A2PS64V0J.:ZUX09R", "2024-01-02 11:56:42.000", "Good", "0.129999995"),
            ("A2PS64V0J.:ZUX09R", "2024-01-02 16:00:12.000", "Good", "0.150000006"),
            ("A2PS64V0J.:ZUX09R", "2024-01-02 20:03:46.000", "Good", "0.340000004"),
        ],
        ["TagName", "EventTime", "Status", "Value"],
    )

    # Attempt to initialize with an invalid mode
    with pytest.raises(ValueError, match="mode must be either 'temporal' or 'spatial'"):
        GaussianSmoothing(
            df=df,
            sigma=2.0,
            id_col="TagName",
            mode="invalid_mode",  # Invalid mode
            timestamp_col="EventTime",
            value_col="Value",
        )
