Suggested change to std/var preprocessing to improve precision #422

Open
jemmajeffree opened this issue Feb 17, 2025 · 3 comments
Comments

@jemmajeffree

Hi,
I've noticed that in a few rare situations, groupby and flox can return quite noisy standard deviations. In situations where the mean of an array is much larger than the standard deviation (such as deep ocean salinity, raised here), flox returns noisier values on dask arrays than on loaded numpy arrays. In extreme situations, the standard deviation of a dask array can contain NaNs from square-rooting negative variances.

I'm guessing it's the same idea as #386, in which case @dcherian has thought about this for much longer than I have. I've done a little bit of looking through the code, and could easily have missed something about how this works with neighbouring functions, but my thoughts on the potential problem and how it might be addressed are below.

Minimal complete verifiable example:

import numpy as np
import xarray as xr

l = 12000
np.random.seed(1)
# huge mean with relatively small variability
test_data = xr.DataArray(
    np.random.uniform(0, 1, l) / 100 + 1000000, dims=('time',)
).assign_coords({'month': xr.DataArray(np.arange(l) % 12, dims=('time',))})

# with numpy arrays returns reasonable and consistent values
test_data.groupby('month').std('time')
# array([0.00283648, 0.00281895, 0.00287791, 0.00287652, 0.00287337,
#        0.00287037, 0.00289802, 0.00289441, 0.00285839, 0.00296478,
#        0.00284787, 0.00292089])

# using lazy computation/dask
dask_test_data = test_data.chunk({'time':100})
dask_test_data.groupby('month').std('time').load()
# array([0.01118034, 0.01118034, 0.01118034, 0.01581139, 0.        ,
#        0.01581139, 0.01118034, 0.01118034, 0.01118034,        nan,
#               nan, 0.        ])

A functional workaround is to subtract the mean before calculating standard deviation:

(dask_test_data.groupby('month')-dask_test_data.groupby('month').mean('time')).groupby('month').std('time').load()

My understanding is that the distinction comes from aggregate_npg.py improving precision by subtracting the first non-nan element of the array, a preprocessing step skipped by aggregations.py. This solution is probably not quite as stable as subtracting the mean, but the first element should be really close to the mean if the standard deviation is small, and it might be faster.
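To illustrate the kind of cancellation I think is happening (this is just a standalone numpy sketch of the floating-point issue, not flox's actual code path):

import numpy as np

np.random.seed(1)
x = np.random.uniform(0, 1, 10000) / 100 + 1000000  # mean ~1e6, std ~3e-3

# One-pass formula: var = E[x^2] - E[x]^2. With x**2 around 1e12, float64
# rounding error (on the order of 1e-4 at that magnitude) swamps the true
# variance (~1e-5), so the subtraction is noise and can even go negative.
one_pass = np.mean(x**2) - np.mean(x) ** 2

# Shifting by any value close to the mean (here the first element) before
# squaring keeps the magnitudes small, and the same formula becomes accurate.
shifted = x - x[0]
shifted_one_pass = np.mean(shifted**2) - np.mean(shifted) ** 2

print(one_pass, shifted_one_pass, np.var(x))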

I’d suggest that, to improve precision and match the numpy engine behaviour in aggregate_npg.py, the flox engine implementation for dask arrays of nanstd, nanvar, std and var could have a preprocessor that looks something like this:

def var_std_preprocess(array, axis):  # Not sure of naming conventions, sorry
    """Subtract the first value of the array from the whole array,
    to improve numerical precision of nanstd, nanvar, std, var.

    Adapted from argreduce_preprocess and _var_std_wrapper in aggregate_npg.py
    """
    import dask.array  # Copied from argreduce_preprocess, but maybe these shouldn't be within the function?
    import numpy as np  # For either this function or argreduce_preprocess?

    # NEXT LINE IS PSEUDOCODE; I'm not entirely sure how to apply it lazily.
    # If it doesn't cost anything speed-wise, then it's probably better to use the mean. Happy to run some time tests on either.
    first_elements = nanfirst(array, axis)

    def subtract_first(array_, first_elements_):
        return array_ - first_elements_

    return dask.array.map_blocks(
        subtract_first,
        array,
        first_elements,
        dtype=array.dtype,
        meta=array._meta,
        name="groupby-var_std-preprocess",
    )

and is included in the Aggregations definition like so:

nanstd = Aggregation(
    "nanstd",
    preprocess=var_std_preprocess, #UPDATED LINE
    chunk=("nansum_of_squares", "nansum", "nanlen"),
    combine=("sum", "sum", "sum"),
    finalize=_std_finalize,
    fill_value=0,
    final_fill_value=np.nan,
    dtypes=(None, None, np.intp),
    final_dtype=np.floating,
)

It seems to work if first_elements is naively array[0] in the one-dimensional, no-NaNs case, but I'm not sure how to generalise it and apply nanfirst without the usual layers/wrappers. (aggregate_npg.py uses first = _get_aggregate(engine).aggregate(group_idx, array, func="nanfirst", axis=axis), but I don't think that syntax translates to the flox/dask implementation.) If you can give me a few tips or examples to work from, then I'm happy to try to implement this behaviour.
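For what it's worth, ignoring grouping and flox's wrappers entirely, a plain-numpy version of "first non-NaN element along an axis" might look like the sketch below (naive_nanfirst is just a made-up name for illustration):

import numpy as np

def naive_nanfirst(array, axis):
    # argmax returns the index of the first True, i.e. the first non-NaN
    # entry along `axis`; an all-NaN slice silently falls back to index 0
    # and would need proper handling in practice.
    idx = np.expand_dims(np.argmax(~np.isnan(array), axis=axis), axis)
    return np.take_along_axis(array, idx, axis=axis)

a = np.array([[np.nan, 2.0, 3.0],
              [4.0, np.nan, 6.0]])
print(naive_nanfirst(a, axis=1))  # [[2.], [4.]]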

Happy also to discuss alternatives, or to provide a pull request if that's easier to work with.

This is also my first time reading through the flox code in detail (it's really nicely written and documented, by the way; it was lovely to read), and one of the first times I've interacted with public GitHub repos, so I'd appreciate any feedback or corrections on what's useful to provide when describing issues.

@dcherian
Collaborator

dcherian commented Feb 17, 2025

First, thanks for the really well-thought-out post and deep dive into the code. That's not easy to do. I really appreciate it. (I also hope you're having fun, which it seems like you are :) )

My first thought is that nanfirst is one full pass through the data which we should try hard to avoid (if we can). To that end, the idea in #386 is to use Welford's method (though the second post in that thread links to a paper that does something different). I would first understand what intermediates we need to track during the compute and then modify

flox/flox/aggregations.py

Lines 349 to 358 in ca57681

var = Aggregation(
    "var",
    chunk=("sum_of_squares", "sum", "nanlen"),
    combine=("sum", "sum", "sum"),
    finalize=_var_finalize,
    fill_value=0,
    final_fill_value=np.nan,
    dtypes=(None, None, np.intp),
    final_dtype=np.floating,
)
to match. Is this something you are up for exploring?

Cubed implements the Welford algorithm, I think: https://github.com/cubed-dev/cubed/blob/745f564964933d6178666c62e71d92fbc0b3fd2b/cubed/array_api/statistical_functions.py#L209. That would be a useful reference.

@jemmajeffree
Author

Thanks @dcherian, I'm definitely having fun.

I see your point about avoiding a full first pass through the data. I guess I was hoping there was a way to pick any datapoint that's not a NaN in O(1) time, but I can see how that would be difficult to implement.

I'm not sure how well Welford's method (at least, the version described in the blog post you linked) would adapt to parallelization, because the calculation for each additional point relies on all those that have come before. However, the variant described in the Schubert and Gertz (2018) paper you linked would work for combining variances calculated on individual chunks. Cubed seems to be doing something along these lines, merging variances from subsets of data containing more than one datapoint.

How much of a problem is touching the data within a chunk twice? I'm wondering whether applying the simpler two-pass solution (subtract the mean, then sum the squares) within each chunk, and then combining those intermediates as described in Schubert and Gertz (2018), would increase precision with minimal loss of speed for the typical xarray use case. I'll run some speed tests, but I think this is what the "proposed minibatch" in Schubert and Gertz is describing (a rough sketch of the per-chunk step is below).
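Concretely, and purely as a sketch (chunk_moments is a made-up name, not anything in flox), I imagine the per-chunk step producing the (count, mean, M2) triple that those merging formulae operate on:

import numpy as np

def chunk_moments(chunk, axis):
    # Two passes over one in-memory chunk: the mean first, then the sum of
    # squared deviations from it (M2). These three intermediates are what
    # the Schubert and Gertz (2018) style formulae know how to merge.
    n = np.sum(~np.isnan(chunk), axis=axis)
    mean = np.nanmean(chunk, axis=axis)
    M2 = np.nansum((chunk - np.expand_dims(mean, axis)) ** 2, axis=axis)
    return n, mean, M2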

I'm keen to explore :) It's looking to me like implementing anything along these lines would require defining something new that can be applied blockwise to dask chunks — would you agree or have I missed something here?

@dcherian
Collaborator

dcherian commented Feb 17, 2025

Thanks @dcherian, I'm definitely having fun.

🥳 !

because calculations for each additional point rely on all those that have come before.

In general these parallel algorithms are quite similar: apply something blockwise ("chunk"), combine intermediates ("combine"), and potentially apply a final transformation ("aggregate"). The phrase to look for is "incremental update".

For example, the snippet at the end of https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm updates the estimate for the new value, and scrolling down "Chan's algorithm" has formulae for combining two intermediate estimates.
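For concreteness, the pairwise combine from that section boils down to something like this (a quick sketch of Chan et al.'s formulae, not flox code):

def combine_moments(n_a, mean_a, M2_a, n_b, mean_b, M2_b):
    # Merge two (count, mean, M2) triples, where M2 is the sum of squared
    # deviations from that subset's mean (Chan et al. pairwise update).
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    M2 = M2_a + M2_b + delta**2 * n_a * n_b / n
    return n, mean, M2

# The variance of the merged set is then M2 / n (or M2 / (n - 1) for ddof=1).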

It's looking to me like implementing anything along these lines would require defining something new that can be applied blockwise to dask chunks

We already have this in the chunk parameter to Aggregation. In many cases it is simply a string for convenience, but it can be a function. Or at least it should accept any function, though I see there are no examples of that sort.

The principle behind Welford's algorithm seems to be to track the mean and variance instead of the sum and sum-of-squares we track currently. I haven't read Schubert and Gertz, but I'd encourage you to experiment. If you open a PR I'm happy to help with comments or tiny code changes to unblock you.
