diff --git a/pyproject.toml b/pyproject.toml index 876db6e..bde5c86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,10 +39,10 @@ module-name = "tsdownsample._rust._tsdownsample_rs" # The path to place the comp # Linting [tool.ruff] -select = ["E", "F", "I"] line-length = 88 -extend-select = ["Q"] -ignore = ["E402", "F403"] +lint.select = ["E", "F", "I"] +lint.extend-select = ["Q"] +lint.ignore = ["E402", "F403"] # Formatting [tool.black] diff --git a/tests/test_tsdownsample.py b/tests/test_tsdownsample.py index 993faa6..3f5b244 100644 --- a/tests/test_tsdownsample.py +++ b/tests/test_tsdownsample.py @@ -44,11 +44,6 @@ def generate_rust_downsamplers() -> Iterable[AbstractDownsampler]: yield downsampler -def generate_rust_nan_downsamplers() -> Iterable[AbstractDownsampler]: - for downsampler in RUST_NAN_DOWNSAMPLERS: - yield downsampler - - def generate_all_downsamplers() -> Iterable[AbstractDownsampler]: for downsampler in RUST_DOWNSAMPLERS + RUST_NAN_DOWNSAMPLERS + OTHER_DOWNSAMPLERS: yield downsampler @@ -106,7 +101,7 @@ def test_rust_downsampler(downsampler: AbstractDownsampler): assert s_downsampled[-1] == len(arr) - 1 -@pytest.mark.parametrize("downsampler", generate_rust_nan_downsamplers()) +@pytest.mark.parametrize("downsampler", RUST_NAN_DOWNSAMPLERS) def test_rust_nan_downsampler(downsampler: AbstractRustNaNDownsampler): """Test the Rust NaN downsamplers.""" datapoints = generate_nan_datapoints() @@ -360,3 +355,41 @@ def test_nan_minmaxlttb_downsampler(): s_downsampled = NaNMinMaxLTTBDownsampler().downsample(arr, n_out=100) arr_downsampled = arr[s_downsampled] assert np.all(np.isnan(arr_downsampled[1:-1])) # first and last are not NaN + + +@pytest.mark.parametrize("downsampler", RUST_DOWNSAMPLERS) +def test_no_nans_omitted(downsampler: AbstractDownsampler): + n = 10_000 + y = np.arange(n, dtype=np.float64) + for i in range(1, 100): + y[i + 100] = np.nan + + s_downsampled = downsampler.downsample(y, n_out=1000) + assert np.all(~np.isnan(y[s_downsampled])) + s_downsampled = downsampler.downsample(y, n_out=1000, parallel=True) + assert np.all(~np.isnan(y[s_downsampled])) + + x = np.arange(n) + s_downsampled = downsampler.downsample(x, y, n_out=1000) + assert np.all(~np.isnan(y[s_downsampled])) + s_downsampled = downsampler.downsample(x, y, n_out=1000, parallel=True) + assert np.all(~np.isnan(y[s_downsampled])) + + +@pytest.mark.parametrize("downsampler", RUST_NAN_DOWNSAMPLERS) +def tests_nans_returned(downsampler: AbstractDownsampler): + n = 10_000 + y = np.arange(n, dtype=np.float64) + for i in range(1, 100): + y[i + 100] = np.nan + + s_downsampled = downsampler.downsample(y, n_out=1000) + assert np.any(np.isnan(y[s_downsampled])) + s_downsampled = downsampler.downsample(y, n_out=1000, parallel=True) + assert np.any(np.isnan(y[s_downsampled])) + + x = np.arange(n) + s_downsampled = downsampler.downsample(x, y, n_out=1000) + assert np.any(np.isnan(y[s_downsampled])) + s_downsampled = downsampler.downsample(x, y, n_out=1000, parallel=True) + assert np.any(np.isnan(y[s_downsampled])) diff --git a/tsdownsample/downsampling_interface.py b/tsdownsample/downsampling_interface.py index 9c05c54..e327976 100644 --- a/tsdownsample/downsampling_interface.py +++ b/tsdownsample/downsampling_interface.py @@ -335,6 +335,10 @@ def _switch_mod_with_x_and_y( # TIMEDELTA -> i64 (timedelta64 is viewed as int64) raise ValueError(f"Unsupported data type (for x): {x_dtype}") + def _prune_nans(self, sampled_idxs: np.ndarray, y: np.ndarray) -> np.ndarray: + """Remove all nan indices.""" + return sampled_idxs[~np.isnan(y[sampled_idxs])] + def _downsample( self, x: Union[np.ndarray, None], @@ -359,11 +363,11 @@ def _downsample( ## Viewing the x-data as different dtype (if necessary) if x is None: downsample_f = self._switch_mod_with_y(y.dtype, mod) - return downsample_f(y, n_out, **kwargs) + return self._prune_nans(downsample_f(y, n_out, **kwargs), y) x = self._view_x(x) ## Getting the appropriate downsample function downsample_f = self._switch_mod_with_x_and_y(x.dtype, y.dtype, mod) - return downsample_f(x, y, n_out, **kwargs) + return self._prune_nans(downsample_f(x, y, n_out, **kwargs), y) def downsample(self, *args, n_out: int, parallel: bool = False, **kwargs): """Downsample the data in x and y. @@ -400,6 +404,11 @@ def _downsample_func_prefix(self) -> str: """The prefix of the downsample functions in the rust module.""" return NAN_DOWNSAMPLE_F + ## Overriding the _prune_nans method to return the sampled indices without pruning + def _prune_nans(self, sampled_idxs: np.ndarray, y: np.ndarray) -> np.ndarray: + """Remove all nan indices.""" + return sampled_idxs + def _switch_mod_with_y( self, y_dtype: np.dtype, mod: ModuleType, downsample_func: Optional[str] = None ) -> Callable: