Suggested change to std/var preprocessing to improve precision #422
Comments
First, thanks for the really well-thought-out post and deep dive into the code. That's not easy to do, and I really appreciate it. (I also hope you're having fun, which it seems like you are :) ) My first thought is that [embedded code permalink: Lines 349 to 358 in ca57681]
For reference, Cubed implements the Welford algorithm, I think: https://github.com/cubed-dev/cubed/blob/745f564964933d6178666c62e71d92fbc0b3fd2b/cubed/array_api/statistical_functions.py#L209. This would be a useful reference.
Thanks @dcherian, I'm definitely having fun. I see your point about avoiding a first full pass through the data. I guess I was hoping that there was a way to pick any datapoint that's not a NaN in O(1) time, but I can see how that would be difficult to implement.

I'm not sure how well Welford's method (at least, as described in the blog post you linked) would adapt to parallelisation, because the calculation for each additional point relies on all those that have come before. However, the variant described in the Schubert and Gertz (2018) paper you linked would work for combining variances calculated on individual chunks. Cubed seems to be doing something along these lines, merging variance from subsets of data with more than one datapoint.

How much of a problem is touching the data within a chunk twice? I'm wondering whether applying the simpler two-pass solution (subtract the mean, then sum squares) within each chunk, and then combining the chunk results as described in Schubert and Gertz (2018), would increase precision with minimal loss of speed for the typical xarray use case. I'll run some tests for speed, but I think this is what the "proposed minibatch" in Schubert and Gertz describes. I'm keen to explore :)

It looks to me like implementing anything along these lines would require defining something new that can be applied blockwise to dask chunks. Would you agree, or have I missed something here? A rough sketch of the two-pass-per-chunk idea is below.
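A minimal sketch of that two-pass-per-chunk idea, assuming plain NumPy and (count, mean, M2) intermediates; the names `chunk_stats` and `combine_stats` are made up for illustration, and the combine step is the pairwise formula from Chan et al. that Schubert and Gertz (2018) also describe:

```python
import numpy as np
from functools import reduce

def chunk_stats(chunk):
    """Two passes over one chunk: mean first, then sum of squared deviations (M2)."""
    chunk = chunk[~np.isnan(chunk)]
    n = chunk.size
    mean = chunk.mean() if n else 0.0
    m2 = ((chunk - mean) ** 2).sum()  # second pass operates on small anomalies
    return n, mean, m2

def combine_stats(a, b):
    """Merge two (count, mean, M2) intermediates (pairwise/Chan update)."""
    n_a, mean_a, m2_a = a
    n_b, mean_b, m2_b = b
    n = n_a + n_b
    if n == 0:
        return 0, 0.0, 0.0
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta**2 * n_a * n_b / n
    return n, mean, m2

# Variance of the full array recovered from per-chunk intermediates.
rng = np.random.default_rng(0)
data = 35.0 + 1e-3 * rng.standard_normal(40_000)
chunks = np.array_split(data, 8)
n, mean, m2 = reduce(combine_stats, (chunk_stats(c) for c in chunks))
print(m2 / n, np.var(data))  # the two should agree closely
```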
🥳 !
In general these parallel algorithms are quite similar: apply something blockwise ("chunk"), combine intermediates ("combine"), and potentially apply a final transformation ("aggregate"). The phrase to look for is "incremental update". For example, the snippet at the end of https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm updates the estimate with each new value, and further down "Chan's algorithm" has formulae for combining two intermediate estimates.
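For reference, a minimal transcription of that online update (adapted from the linked Wikipedia section, not flox code); the final step turns the (count, mean, M2) state into a variance:

```python
def welford_update(state, x):
    """Fold one new value x into the running (count, mean, M2) state."""
    n, mean, m2 = state
    n += 1
    delta = x - mean
    mean += delta / n
    m2 += delta * (x - mean)  # note: uses the *updated* mean
    return n, mean, m2

def welford_finalize(state, ddof=0):
    n, mean, m2 = state
    return m2 / (n - ddof)

state = (0, 0.0, 0.0)
for x in [35.001, 34.999, 35.002, 35.000]:
    state = welford_update(state, x)
print(welford_finalize(state))  # population variance of the four values
```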
We already have this in the [embedded code permalink]. The principle behind Welford's algorithm seems to be to track the mean and variance instead of the sum and sum-of-squares that we track currently. I haven't read Schubert and Gertz, but I'd encourage you to experiment. If you open a PR I'm happy to help with comments or tiny code changes to unblock you.
Hi,
I've noticed that in a few rare situations, groupby and flox can return quite noisy standard deviations. In situations where the mean of an array is much larger than the standard deviation (such as deep ocean salinity, raised here), flox returns noisier values on dask arrays than on loaded numpy arrays. In extreme situations, the standard deviation of a dask array can contain NaNs from square-rooting negative variances.
I'm guessing it's the same idea as #386, in which case @dcherian has thought about this for much longer than I have. I've done a little bit of looking through the code, and could easily have missed something about how this works with neighbouring functions, but my thoughts on the potential problem and how it might be addressed are below.
Minimal complete verifiable example:
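A sketch of the kind of example meant here, assuming synthetic salinity-like data (mean around 35, spread around 1e-3), dask and flox installed, and an xarray groupby that dispatches the dask case to flox; the sizes, chunking, and values are illustrative:

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
# Mean much larger than the spread, as for deep-ocean salinity.
data = 35.0 + 1e-3 * rng.standard_normal(100_000)
labels = xr.DataArray(np.repeat(np.arange(10), 10_000), dims="x", name="labels")
da = xr.DataArray(data, dims="x")

# Eager (numpy) result
std_numpy = da.groupby(labels).std()

# Lazy (dask) result; with flox installed, xarray hands the grouped
# reduction on the dask array to flox.
std_dask = da.chunk({"x": 1_000}).groupby(labels).std().compute()

print(std_numpy.values)  # close to 1e-3
print(std_dask.values)   # can be noticeably noisier; NaN in extreme cases
```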
A functional workaround is to subtract the mean before calculating standard deviation:
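Continuing the sketch above, the workaround would look something like this; subtracting an approximate overall mean keeps the intermediate sums small, so the cancellation disappears:

```python
da_lazy = da.chunk({"x": 1_000})
offset = da_lazy.mean()  # any value close to the data works as the offset
std_workaround = (da_lazy - offset).groupby(labels).std().compute()
```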
My understanding is that the distinction comes from aggregate_npg.py improving precision by subtracting the first non-nan element of the array, a preprocessing step skipped by aggregations.py. This solution is probably not quite as stable as subtracting the mean, but the first element should be really close to the mean if the standard deviation is small, and it might be faster.
I'd suggest that, to improve precision and match the numpy-engine behaviour in aggregate_npg.py, the flox-engine implementation of nanstd, nanvar, std and var for dask arrays could have a preprocessor that looks something like this:
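Something along these lines (a sketch only; the helper name `_subtract_first` and its exact signature are made up for illustration, assuming plain NumPy inputs):

```python
import numpy as np

def _subtract_first(array, axis=-1):
    """Subtract the first non-NaN element along ``axis``.

    Variance and standard deviation are invariant to a constant shift, so the
    result is mathematically unchanged, but the intermediate sums stay small
    and catastrophic cancellation is avoided when the mean is much larger
    than the spread.
    """
    # argmax returns the first True, i.e. the first non-NaN position;
    # an all-NaN slice falls back to index 0 (and stays all-NaN after subtraction).
    first_idx = np.argmax(~np.isnan(array), axis=axis)
    first = np.take_along_axis(array, np.expand_dims(first_idx, axis=axis), axis=axis)
    return array - first
```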
and is included in the Aggregations definition like so:
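This wiring is hypothetical: flox's Aggregation constructor has no preprocess keyword today (that hook is exactly what is being proposed), and the remaining arguments are illustrative stand-ins rather than a copy of the real nanvar definition in aggregations.py:

```python
from flox.aggregations import Aggregation

nanvar = Aggregation(
    "nanvar",
    preprocess=_subtract_first,  # proposed new hook; does not exist in flox yet
    chunk=("nansum_of_squares", "nansum", "nanlen"),  # illustrative intermediates
    combine=("sum", "sum", "sum"),
    fill_value=0,
)
```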
It seems to work if `first_elements` is naively `array[0]` in the one-dimensional, no-NaNs case, but I'm not sure how to generalise it and apply `nanfirst` without the usual layers/wrappers. (aggregate_npg.py uses `first = _get_aggregate(engine).aggregate(group_idx, array, func="nanfirst", axis=axis)`, but I don't think this syntax translates to the flox/dask implementation.) If you can give me a few tips or examples to work from, then I'm happy to try to implement this behaviour.

Happy also to discuss alternatives, or to provide a pull request if that's easier to work with.
This is also my first time reading through the flox code in detail (it's really nicely written and documented, by the way; it was lovely to read), and one of the first times I've interacted with public GitHub repos, so I'd appreciate any feedback or corrections on what's useful to provide when describing issues.