Add ZarrTrace #7540
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #7540 +/- ##
==========================================
- Coverage 92.84% 92.84% -0.01%
==========================================
Files 106 107 +1
Lines 17685 18105 +420
==========================================
+ Hits 16420 16809 +389
- Misses 1265 1296 +31
This is an important issue to keep track of when we'll eventually want to read the zarr store and create an
pymc/backends/zarr.py
Outdated
_dtype = np.dtype(dtype)
if np.issubdtype(_dtype, np.floating):
    return (np.nan, _dtype, None)
elif np.issubdtype(_dtype, np.integer):
    return (-1_000_000, _dtype, None)
elif np.issubdtype(_dtype, "bool"):
    return (False, _dtype, None)
elif np.issubdtype(_dtype, "str"):
    return ("", _dtype, None)
elif np.issubdtype(_dtype, "datetime64"):
    return (np.datetime64(0, "Y"), _dtype, None)
elif np.issubdtype(_dtype, "timedelta64"):
    return (np.timedelta64(0, "Y"), _dtype, None)
else:
    return (None, _dtype, numcodecs.Pickle())
Question from my own ignorance, since I don't understand so much how fill values are implemented. Are we just hoping that these fill values don't actually occur in the data?
If so, this seems especially perilous for bool 😅
No, they are supposed to be the initialisation values for the entries. When the sampler completes its run, all entries will be filled with the correct value. Zarr just needs you to tell it what value to give to unwritten places. In the storage, these entries are never actually written, they are produced when you ask for the concrete values in the array.
The dangerous part is that xarray is interpreting fill_value as an indicator of whether the actual value should be masked to nan. This seems to be because the netCDF standard treats fill_value as something completely different.
To keep things as clean as possible, I’ll store the draw_idx of each chain in a separate group that should never be converted to xarray.
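The difference between a fill value and a missing-data marker can be illustrated with a toy model. The sketch below is not zarr's implementation: `LazyArray` is a hypothetical stand-in that works per entry rather than per chunk. It stores only written entries and synthesizes the fill value on read, so an unwritten entry cannot be told apart from a stored `False` unless the number of written draws is tracked somewhere else, which is the role a separately stored draw_idx would play.

```python
class LazyArray:
    """Toy 1-D array that materializes a fill value for unwritten entries.

    Hypothetical stand-in for illustration only; zarr works on chunks,
    not on individual entries.
    """

    def __init__(self, length, fill_value):
        self.length = length
        self.fill_value = fill_value
        self._stored = {}  # only written entries take up space

    def __setitem__(self, index, value):
        if not 0 <= index < self.length:
            raise IndexError(index)
        self._stored[index] = value

    def __getitem__(self, index):
        if not 0 <= index < self.length:
            raise IndexError(index)
        # Unwritten entries are produced on demand from the fill value.
        return self._stored.get(index, self.fill_value)

    def to_list(self):
        return [self[i] for i in range(self.length)]


# A boolean variable with fill value False, as in the reviewed function.
diverging = LazyArray(length=5, fill_value=False)
draw_idx = 0  # tracked outside the array, like a separate bookkeeping group
for value in (False, True, False):
    diverging[draw_idx] = value
    draw_idx += 1

values = diverging.to_list()
# Entries at positions >= draw_idx were never written; they only look like False.
written = values[:draw_idx]
```

Slicing by `draw_idx` recovers the recorded draws without ever interpreting the fill value as a sentinel.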
Ah, that makes a lot more sense now, thanks for the explanation!
In case it's non-obvious to more than me, maybe it would be helpful to try to make this more self-evident. Perhaps by calling the function get_initial_fill_value_and_codec, or by adding a short comment that the fill value is used for initialization?
Yes, therefore I would recommend not to use them for any new implementation.
Just to clarify:
It should be quite simple to implement a
Yes and No. I would say ArviZ is a first-class citizen, because
I consider 2. the primary weakness, and it's the only reason why I don't use McBackend by default. First we must find answers to:
Protocol buffers I used them because they are convenient for specifying a data structure and not having to write serialization/deserialization code for it. And they are only for the constant metadata From the Python perspective this could also be done with Is that tight integration? The important design decision is not which implementation is used to serialize/deserialize metadata, but rather to freeze and detach these (meta)data from chains, draws and stats:
Thanks @michaelosthege for the feedback!
I understand what the two backends for McBackend offer and that McBackend already has a test suite. Despite this, I'll try to argue in favor of writing something that's detached from McBackend.
The way I see this is that McBackend offers a signature to convert from a kind of storage (like
The key thing is that I added these groups to the zarr hierarchy, having them as
I decided to only focus on MCMC for now, and I'm trying to make
Yes, you can deserialize almost all of the contents into C++ or Rust. zarr can be readable from python, Julia, C++, rust, javascript and Java. The only content that would not be readable in other languages would come from arrays with The latter isn't a problem in my opinion because it is related exclusively to the python pymc step methods, and I detached it to its own private group in the zarr hierarchy. The former might be more problematic, but since Having said that, there are other benefits that we would get if we were to rely on zarr directly, such as:
I think that these added benefits plus the drop in maintenance costs in the long run warrant using zarr directly and not through a new backend for McBackend.
@lucianopaz, have you done some benchmarks with this yet (in particular with S3)? I'm a bit concerned that with a (1, 1, ...) chunk size, I/O will be a bottleneck.
No, I haven't. But I've made the chunksize customizable now via the Anyway, my long term goal is to add something like checkpoints during sampling where the trace gets dumped into a file along with the sampling state of the step methods. I think that I'll eventually make the chunks align with that, so that we don't lose samples that were drawn before the checkpoint if sampling gets terminated afterwards (before having finished).
By the way, I've added a
I love this direction. I left a comment on ArviZ integration.
I also have more ideas of things that can be done to better integrate the sampling outputs with InferenceData, but it might be better to address them in follow-up PRs. Not having to go into the current ArviZ converter might help get these things off the ground. Many of these have been around since #5160
Also, if there is anything on the ArviZ side that can help with this, let me know.
Better sample_stats. sample_stats doesn't necessarily need to restrict itself to having chain, draw dimensions in all its variables. The mass matrix could also go in there, and even a divergence_id (with extra coordinate values or a multiindex to store the start and end points of divergences), which would complement the boolean diverging variable with chain, draw dimensions.
Samples in the unconstrained space. Related to #6721 and, to a lesser extent, arviz-devs/arviz-base#8
I absolutely love this! :D
pymc/backends/zarr.py
Outdated
def setup(self, draws: int, chain: int, sampler_vars: Sequence[dict] | None):  # type: ignore[override]
    self.chain = chain

def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
I did not check the source, but I think zarr will write the whole chunk each time we set a draw here, even if that chunk is not full yet. If that is indeed the case, and draws_per_chunk is >1, we should be able to speed this up a lot by buffering draws_per_chunk draws and then setting the values in one go.
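The buffering idea can be sketched independently of zarr. In the sketch below, `CountingArray` and `BufferedChain` are hypothetical names, not part of the PR: the first stands in for a chunked array and counts how many write calls it receives, the second collects draws_per_chunk draws and assigns them as one slice.

```python
class CountingArray:
    """Hypothetical stand-in for a chunked array that counts write calls."""

    def __init__(self, length):
        self.data = [None] * length
        self.write_calls = 0

    def __setitem__(self, key, values):
        self.write_calls += 1
        self.data[key] = values


class BufferedChain:
    """Collects draws in memory and writes them one chunk at a time."""

    def __init__(self, array, draws_per_chunk):
        self.array = array
        self.draws_per_chunk = draws_per_chunk
        self.buffer = []
        self.draw_idx = 0  # index of the next draw to be flushed

    def record(self, value):
        self.buffer.append(value)
        if len(self.buffer) == self.draws_per_chunk:
            self.flush()

    def flush(self):
        if not self.buffer:
            return
        stop = self.draw_idx + len(self.buffer)
        # A single slice assignment replaces len(self.buffer) element writes.
        self.array[self.draw_idx:stop] = self.buffer
        self.draw_idx = stop
        self.buffer = []


array = CountingArray(length=25)
chain = BufferedChain(array, draws_per_chunk=10)
for draw in range(25):
    chain.record(float(draw))
chain.flush()  # write the trailing partial chunk
```

With 25 draws and 10 draws per chunk, the array receives three write calls instead of 25. The trade-off is that draws still sitting in the buffer are lost if the process dies before a flush.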
Oh, it would be a shame if zarr itself doesn't buffer the write until the chunk is filled.
Great! Yes, let's try to address those in other PRs.
I don't know if
I tried to stay close to how pymc is returning things right now. I agree that this could be greatly improved, but maybe we can do so in follow up iterations.
I'll try to add this. It doesn't seem to be difficult.
I just did a little experiment with strace, to see what zarr does when we write a value to a chunk. I wrote a little script that creates the array, writes two chunks and then closes the file. With

import zarr
import numpy as np
import sys

store = zarr.DirectoryStore("zarr-test.zarr")
#store = zarr.LRUStoreCache(store, max_size=2**28)
data = zarr.open(store)
foo = data.array("foo", np.zeros((0, 0, 0)), chunks=(10, 10, 1))
foo.resize((1000, 10, 1))

# Mark the position in the code to make it easier to find the correct part
print("start--", flush=True, file=sys.stderr)
foo[0, 0, 0] = 1.0
print("mid--", flush=True, file=sys.stderr)
foo[1, 0, 0] = 2.0
print("done--", flush=True, file=sys.stderr)

The first write triggers this:
The second write triggers this:
So there's a lot going on for each write to the array. If I read this correctly, for the second store it actually reads the chunk from the disk, then modifies the chunk with the indexing update, writes the new chunk to a temporary file and then replaces the original chunk file with that. For the first write it skips reading in the chunk data, because there is still nothing there to read. So I think if we want to get good performance from this, we should try to combine writes, by buffering
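The sequence described above (read the chunk, modify it, write a temporary file, swap it into place) can be simulated with the standard library. This is only an illustration of that pattern, not zarr's code: `write_element` is a hypothetical helper, and JSON stands in for zarr's binary chunk encoding.

```python
import json
import os
import tempfile


def write_element(chunk_path, index, value, chunk_len, fill_value=0.0):
    """Update one element of a chunk file via read-modify-write."""
    if os.path.exists(chunk_path):
        # Existing chunk: the whole chunk is read back first.
        with open(chunk_path) as f:
            chunk = json.load(f)
    else:
        # First write: nothing to read, start from the fill value.
        chunk = [fill_value] * chunk_len
    chunk[index] = value
    # Write the full chunk to a temporary file, then swap it into place.
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(chunk_path))
    with os.fdopen(fd, "w") as f:
        json.dump(chunk, f)
    os.replace(tmp_path, chunk_path)


with tempfile.TemporaryDirectory() as root:
    path = os.path.join(root, "0.0.0")
    write_element(path, 0, 1.0, chunk_len=10)  # no read, chunk file is created
    write_element(path, 1, 2.0, chunk_len=10)  # read, modify, rewrite, replace
    with open(path) as f:
        result = json.load(f)
    leftover = os.listdir(root)
```

Every single-element update rewrites all `chunk_len` values, which is why combining several draws into one write reduces the amount of I/O.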
Can we keep this PR about adding Zarr trace only? And not sampling state/resuming
Re flaky test, @bwengals was surprised by the magnitude of the error and there is an open issue for it. If you're confident about there being no issue please close the associated issue. Perhaps as a separate PR to unblock other contributions, depending on how long this takes to get merged
Put a fix in for the flaky GP test here #7567
Yes, the only thing that I still need to figure out is what needs to be changed in the current
Your fix looks great. I was just trying out stuff to see if the CI would build. I'll drop the commit with the test patch.
@ricardoV94, @OriolAbril, @aseyboldt, @maresb and @michaelosthege. I think that this is ready for a proper review. I'll try to handle the mambaforge sunset error in a different PR.
As I said before, I think that this PR is no longer a draft, but there are some choices I made that might be better to discuss here:
# For some strange reason, spawn multiprocessing doesn't copy the rng
# seed sequence, so we have to rebuild it from scratch
Comment can be removed/updated? It's the bug you found for numpy<2.0?
Yes, you’re right. I’ll update the comment
)
self.vars = [var for var in vars if var.name in self.varnames]

self.fn = model.compile_fn(self.vars, inputs=model.value_vars, on_unused_input="ignore")
Allow passing an already compiled function, and use pytensor.In/pytensor.Out and trust_input for the default, as was done in b589ce8
It may be worth refactoring this into a helper
draws_per_chain = total_draws_per_chain - tuning_steps_per_chain

total_n_tune = tuning_steps_per_chain.sum()
total_draws = draws_per_chain.sum()
How is this dealing with uneven draws when sampling is interrupted? MultiTrace was discarding draws from longer chains so everything is square
When the ZarrTrace groups are created, they all get a fill value and the final shape for each array. The fill value is not actually stored in the zarr.store, so only the actual draws take up any memory. That being said, when you open and load an unfinished run, the arrays in the trace will have the shape as if they had been completely sampled, but the unsampled draws will be set to the fill value.
What is the default value? Would it make sense to trim it like MultiTrace does?
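One way such trimming could work is sketched below. It assumes a per-chain count of written draws is available (like the draw_idx mentioned earlier in this thread); `trim_to_common_length` is a hypothetical helper, and nothing here reflects what ZarrTrace or MultiTrace actually do internally.

```python
def trim_to_common_length(chains, draw_idx):
    """Keep only the draws that every chain has actually recorded.

    chains: list of per-chain lists, all padded to the full length.
    draw_idx: number of draws written so far for each chain.
    """
    n_common = min(draw_idx)
    return [chain[:n_common] for chain in chains]


nan = float("nan")  # fill value used for unsampled float draws
chains = [
    [0.1, 0.2, 0.3, nan, nan],  # chain 0 recorded 3 draws
    [0.5, 0.6, nan, nan, nan],  # chain 1 recorded 2 draws
]
trimmed = trim_to_common_length(chains, draw_idx=[3, 2])
```

The result is square across chains and contains no fill values, at the cost of discarding the extra draw from the longer chain.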
if isinstance(trace, ZarrChain):
    trace.link_stepper(step)
This could also be added to NDArray?
if isinstance(trace, ZarrChain):
    trace.record_sampling_state(step=step)
Should we make this part of trace.close?
I don't want to change the signature of close and I don't want to make trace necessarily have a reference to the step. We could add a record_sampling_state method to NDArray though
Sounds good, and doesn't have to be done in this PR. Just please open an issue?
@lucianopaz PR looks good. Left some small comments
Description
This PR is related to #7503. It specifically focuses on having a way to store intermediate trace results and the step methods sampling state somewhere (See task 2 of #7508).
To be honest, the current situation of the MultiTrace and NDArray backends is terrible. These backend classes have inconsistent signatures across subclasses, and it's very awkward to write new backends that adhere to them. McBackend was an attempt to make things sane again. As far as I understand, McBackend does support ways to dump samples to disk instead of holding them in memory using the ClickHouse database. However, I found the backend a bit detached from arviz and xarray, and it seemed to be tightly linked to protocol buffers, which made it harder for me to see how I could customize stuff.
These considerations brought me to the approach I'm pursuing in this PR: add a backend that uses zarr. Using zarr has the following benefits:
xarray can read zarr stores directly, making it possible to write InferenceData objects to disk directly, almost without even having to call a converter.
object dtyped arrays using the numcodecs package. This makes it possible to use the same store to hold sample stats warning objects and step methods' sampling_state in the same place as the actual samples from the posterior.
Having stated all of these considerations, I intend to:
ZarrTrace integrate well with pymc.sample
Replace the MultiTrace and NDArray backend defaults with their Zarr counterparts
ZarrTrace
ZarrChain.record
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7540.org.readthedocs.build/en/7540/