Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ZarrTrace #7540

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open

Add ZarrTrace #7540

wants to merge 9 commits into from

Conversation

lucianopaz
Copy link
Contributor

@lucianopaz lucianopaz commented Oct 16, 2024

Description

This PR is related to #7503. It specifically focuses on having a way to store intermediate trace results and the step methods sampling state somewhere (See task 2 of #7508).

To be honest, the current situation of the MultiTrace and NDArray backends is terrible. These backend classes have inconsistent signatures across subclasses, and it's very awkward to write new backends that adhere to them. McBackend was an attempt to make things sane again. As far as I understand, McBackend does support ways to dump samples to disk instead of holding them in memory using the ClickHouse database. However, I found the backend a bit detached from arviz and xarray, and it seemed to be tightly linked to protocol buffers, which made it harder for me to see how I could customize stuff.

These considerations brought me to the approach I'm pursuing in this PR: add a backend that uses zarr. Using zarr has the following benefits:

  1. xarray can read zarr stores directly making it possible to write InferenceData objects to disk directly almost without even having to call a converter.
  2. zarr works with hierarchically structured data. It's possible to store arrays for each variable inside of a group (e.g. posterior, observed_data) directly.
  3. zarr arrays handle numpy arrays nicely. Fixed sized binary data can be fit into zarr arrays seemlessly.
  4. zarr arrays also have the possibility of storing object dtyped arrays using the numcodec package. This makes it possible use the same store to hold sample stats warning objects and step methods sampling_state in the same place as the actual samples from the posterior.
  5. zarr hierarchies can use many different kinds of storage: they can be held in memory, saved as a directory structure, inside of a zip file, or even remotely on s3 buckets.
  6. It's also possible to write to the same zarr object concurrently from different processes or threads, as long as a synchronization object is provided.
  7. zarr also stores the data using a compressed binary representation. The actual compressor can be customized.
  8. zarr arrays are chunked. This means that they don't need to be loaded entirely onto memory, making it possible to leave a smaller memory footprint while sampling. Another benefit of chunking is that write operations on different chunks should be completely independent from each other.

Having stated all of these considerations I intend to:

  • Build a zarr trace backend
  • Have ZarrTrace integrate well with pymc.sample
  • WONT DO NOW Replace the MultiTrace and NDArray backend defaults with their Zarr counterparts
  • Handle sampling state information in the zarr backend
  • Document ZarrTrace
  • Buffer write operations in ZarrChain.record
  • Record sampling state information periodically during sampling
  • Make it possible to load the zarr trace backend and resume sampling from it.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7540.org.readthedocs.build/en/7540/

@lucianopaz lucianopaz added enhancements trace-backend Traces and ArviZ stuff major Include in major changes release notes section labels Oct 16, 2024
Copy link

codecov bot commented Oct 16, 2024

Codecov Report

Attention: Patch coverage is 92.87356% with 31 lines in your changes missing coverage. Please review.

Project coverage is 92.84%. Comparing base (0082409) to head (a087568).

Files with missing lines Patch % Lines
pymc/sampling/population.py 33.33% 10 Missing ⚠️
pymc/backends/zarr.py 96.91% 9 Missing ⚠️
pymc/sampling/parallel.py 78.78% 7 Missing ⚠️
pymc/sampling/mcmc.py 92.00% 4 Missing ⚠️
pymc/step_methods/state.py 94.11% 1 Missing ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7540      +/-   ##
==========================================
- Coverage   92.84%   92.84%   -0.01%     
==========================================
  Files         106      107       +1     
  Lines       17685    18105     +420     
==========================================
+ Hits        16420    16809     +389     
- Misses       1265     1296      +31     
Files with missing lines Coverage Δ
pymc/backends/__init__.py 92.50% <100.00%> (+0.83%) ⬆️
pymc/step_methods/compound.py 97.58% <100.00%> (+0.02%) ⬆️
pymc/step_methods/hmc/quadpotential.py 84.63% <100.00%> (ø)
pymc/util.py 82.86% <100.00%> (+1.01%) ⬆️
pymc/step_methods/state.py 95.52% <94.11%> (-2.60%) ⬇️
pymc/sampling/mcmc.py 87.80% <92.00%> (+0.57%) ⬆️
pymc/sampling/parallel.py 87.73% <78.78%> (-1.12%) ⬇️
pymc/backends/zarr.py 96.91% <96.91%> (ø)
pymc/sampling/population.py 70.83% <33.33%> (-3.85%) ⬇️
---- 🚨 Try these New Features:

@lucianopaz
Copy link
Contributor Author

This is an important issue to keep track of when we'll eventually want to read the zarr store and create an InferenceData object using xarray and arviz

Comment on lines 95 to 109
_dtype = np.dtype(dtype)
if np.issubdtype(_dtype, np.floating):
return (np.nan, _dtype, None)
elif np.issubdtype(_dtype, np.integer):
return (-1_000_000, _dtype, None)
elif np.issubdtype(_dtype, "bool"):
return (False, _dtype, None)
elif np.issubdtype(_dtype, "str"):
return ("", _dtype, None)
elif np.issubdtype(_dtype, "datetime64"):
return (np.datetime64(0, "Y"), _dtype, None)
elif np.issubdtype(_dtype, "timedelta64"):
return (np.timedelta64(0, "Y"), _dtype, None)
else:
return (None, _dtype, numcodecs.Pickle())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question from my own ignorance, since I don't understand so much how fill values are implemented. Are we just hoping that these fill values don't actually occur in the data?

If so, this seems especially perilous for bool 😅

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, they are supposed to be the initialisation values for the entries. When the sampler completes its run, all entries will be filled with the correct value. Zarr just needs you to tell it what value to give to unwritten places. In the storage, these entries are never actually written, they are produced when you ask for the concrete values in the array.
The dangerous part is that xarray is interpreting fill_value as an indicator of whether the actual value should be masked to nan. This seems to be because of the netcdf standard treats fill_value as something completely different.
To keep things as clean as possible, I’ll store the draw_idx of each chain in a separate group that should never be converted to xarray.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, that makes a lot more sense now, thanks for the explanation!

In case it's non-obvious to more than me, maybe it would be helpful to try to make this more self-evident. Perhaps by calling the function get_initial_fill_value_and_codec, or make some little comment that the fill value is used for initialization?

@michaelosthege
Copy link
Member

the current situation of the MultiTrace and NDArray backends is terrible. These backend classes have inconsistent signatures across subclasses, and it's very awkward to write new backends that adhere to them.

Yes, therefore I would recommend not to use them for any new implementation.

McBackend was an attempt to make things sane again. As far as I understand, McBackend does support ways to dump samples to disk instead of holding them in memory using the ClickHouse database.

Just to clarify:

  • NumPyBackend is the go-to for in memory situations
  • ClickHouseBackend is for storing on disk (in a database that may even sit on a different machine!)

It should be quite simple to implement a ZarrBackend with McBackend!
I would recommend to do that first, because McBackend's test suite already covers all (?) of the nasty edge cases.

However, I found the backend a bit detached from arviz and xarray, and it seemed to be tightly linked to protocol buffers, which made it harder for me to see how I could customize stuff.

Yes and No. I would say ArviZ is a first-class citizen, because Run.to_inferencedata() is in the base class.
There are two things which McBackend does not integrate tightly:

  1. xarray because it doesn't/didn't support sparse arrays (needed for sparse stats or variables with varying shape)
  2. InferenceData groups other than .posterior

I consider 2. the primary weakness, and it's the only reason why I don't use McBackend by default.
I see that you added "the other" groups as properties to the ZarrTrace. Maybe this is something we should do on a more abstract level? Have McBackend define the signature of InferenceData without requiring a specific implementation for it?

First we must find answers to:

  • How do prior/posterior/log_likelihood data points arrive?
  • Do they arrive in some kind of "sampling" process that may get parallelized or doesn't fit into memory?
  • If yes, should the storage backend even make a difference between MCMC and forward sampling?

Protocol buffers

I used them because they are convenient for specifying a data structure and not having to write serialization/deserialization code for it.

And they are only for the constant metadata {constant_data, observed_data, coords, names dtypes, ...} because this this needs to be serializable from a semi-clean data structure supporting all the weird data types that users may put into their coords (timestamps anybody?).

From the Python perspective this could also be done with zarr or xarray (they can serialize to binary), but can you serialize/deserialize that in another language?
The protobufs can be compiled to C++ or Rust to easily read/write run metadata from those languages too!

Is that tight integration?

The important design decision is not which implementation is used to serialize/deserialize metadata, but rather to freeze and detach these (meta)data from chains, draws and stats:

  1. Determine it before starting the costly MCMC
  2. Serialize it to its own "blob" of data. Think of {constant_data, observed_data, coords, names dtypes, ...} as the "header" section of a trace.

@lucianopaz
Copy link
Contributor Author

Thanks @michaelosthege for the feedback!

McBackend was an attempt to make things sane again. As far as I understand, McBackend does support ways to dump samples to disk instead of holding them in memory using the ClickHouse database.

Just to clarify:

* `NumPyBackend` is the go-to for _in memory_ situations

* `ClickHouseBackend` is for storing on disk (in a database that may even sit on a different machine!)

It should be quite simple to implement a ZarrBackend with McBackend! I would recommend to do that first, because McBackend's test suite already covers all (?) of the nasty edge cases.

I understand what the two backends for McBackend offer and that McBackend already has a test suite. Despite this, I'll try to argue in favor of writing something that's detached from McBackend.

However, I found the backend a bit detached from arviz and xarray, and it seemed to be tightly linked to protocol buffers, which made it harder for me to see how I could customize stuff.

Yes and No. I would say ArviZ is a first-class citizen, because Run.to_inferencedata() is in the base class.

The way I see this is that McBackend offers a signature to convert from a kind of storage (like MultiTrace) into another one (arviz.InferenceData). I understand that with this method, you guarantee that there should always be a method to go from an McBackend Run to arviz.InferenceData, but you have to handle a lot of transformation logic in this conversion (just like the extra conversion logic that's already in pymc.backends.arviz). In my opinion, this isn't tight integration. Having the data stored in native zarr makes it possible to generate xarray.Dataset objects with a simple xr.open_zarr(store, group) calls, and then these can be wrapped into an InferenceData object with a simple InferenceData(posterior=zarr_posterior, ...) (and potentially even into the future DataTree objects, since zarr hierarchies are already very much tree-like).

There are two things which McBackend does not integrate tightly:

1. `xarray` because it doesn't/didn't support sparse arrays (needed for sparse stats or variables with varying shape)

xarray does not support sparse arrays. At the moment, the posterior samples are initialized as "empty" zarr arrays (in practice, filled arrays with a fill_value). The nice thing about zarr arrays is that these filled, uninitialized places, don't take up almost any space because they aren't actually stored. If queried, their value gets set from the fill_value attribute. xarray still needs to figure out pydata/xarray#5475 though.

2. `InferenceData` groups other than `.posterior`

I consider 2. the primary weakness, and it's the only reason why I don't use McBackend by default. I see that you added "the other" groups as properties to the ZarrTrace.

The key thing is that I added these groups to the zarr hierarchy, having them as ZarrTrace properties is not necessary. By having them in a single shared zarr entity, they are stored almost like an InferenceData from zarr. I need to actually check if arviz has a from_zarr method, because that would be the direct conversion method from a ZarrTrace to an InferenceData object without having to add any extra conversion code.

Maybe this is something we should do on a more abstract level? Have McBackend define the signature of InferenceData without requiring a specific implementation for it?

First we must find answers to:

* How do prior/posterior/log_likelihood data points arrive?

* Do they arrive in some kind of "sampling" process that may get parallelized or doesn't fit into memory?

* If yes, should the storage backend even make a difference between MCMC and forward sampling?

I decided to only focus on MCMC for now, and I'm trying to make ZarrTrace handle concurrent writes to the zarr store from multiple processes during sampling. Having said that, it's almost effortless to add other groups to a zarr hierarchy, and the same store could house prior, prior_predictive, posterior_predictive and predictions as well without having to handle almost any extra logic.

Protocol buffers

I used them because they are convenient for specifying a data structure and not having to write serialization/deserialization code for it.

And they are only for the constant metadata {constant_data, observed_data, coords, names dtypes, ...} because this this needs to be serializable from a semi-clean data structure supporting all the weird data types that users may put into their coords (timestamps anybody?).

From the Python perspective this could also be done with zarr or xarray (they can serialize to binary), but can you serialize/deserialize that in another language? The protobufs can be compiled to C++ or Rust to easily read/write run metadata from those languages too!

Yes, you can deserialize almost all of the contents into C++ or Rust. zarr can be readable from python, Julia, C++, rust, javascript and Java. The only content that would not be readable in other languages would come from arrays with object dtype. At the moment, this is limited to two things:

  1. SamplerWarning
  2. StepMethodState

The latter isn't a problem in my opinion because it is related exclusively to the python pymc step methods, and I detached it to its own private group in the zarr hierarchy. The former might be more problematic, but since SamplerWarning is a dataclass, it could potentially be converted into a dictionary and then represented as a json object, which zarr can serialize without problems.

Having said that, there are other benefits that we would get if we were to rely on zarr directly, such as:

  • Offloading the maintenance cost of the storage backend code
  • Growing set of features that will become available to us as time goes by
  • Seamless compression of the arrays to save storage space
  • Integration with multiple on disk storage options that range from directory structure, zipfiles and multiple SQL and no-SQL databases
  • Integration with distributed or cloud storage like S3, Hadoop, Google Cloud Storage and Azure storage blob, and also fsspec.

I think that these added benefits plus the drop in maintenance costs in the long run warrant using zarr directly and not through a new backend for McBackend.

@maresb
Copy link
Contributor

maresb commented Oct 17, 2024

@lucianopaz, have you done some benchmarks with this yet (in particular with S3)? I'm a bit concerned that with (1, 1, ...) chunk size that I/O will be a bottleneck.

@lucianopaz lucianopaz force-pushed the zarr branch 2 times, most recently from 3206597 to 69bb2ac Compare October 17, 2024 12:54
@lucianopaz
Copy link
Contributor Author

@lucianopaz, have you done some benchmarks with this yet (in particular with S3)? I'm a bit concerned that with (1, 1, ...) chunk size that I/O will be a bottleneck.

No, I haven't. But I've made the chunksize customizable now via the draws_per_chunk parameter. @aseyboldt said that we could try to use a different chunk size depending on the dimensionality of the RV.

Anyway, my long term goal is to add something like checkpoints during sampling where the trace gets dumped into a file along with the sampling state of the step methods. I think that I'll eventually make the chunks align with that, so that we don't lose samples that were drawn before the checkpoint if sampling gets terminated afterwards (before having finished).

@lucianopaz
Copy link
Contributor Author

By the way, I've added a to_inferencedata method to the zarr trace. I had to do it because I wanted to ensure that the zarr store had consolidated metadata (if it didn't, xarray would complain) and because I needed to pass mask_and_scale=False to xarray.open_zarr (which arviz doesn't allow in from_zarr). Anyway, you can see for yourselves that the conversion code is extremely short because the stored data is already aligned with what arviz wants.

Copy link
Member

@OriolAbril OriolAbril left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I love this direction. I left a comment on ArviZ integration.

I also have more ideas of things that can be done to integrate better the sampling outputs with inferencedata but it might be better to address them in follow up PRs. Not having to go into the current ArviZ converter might help get this things off the ground. Many of these are around since #5160

Also, anything on ArviZ side that can help with this let me know


Better sample_stats. sample_stats doesn't necessarily need to restrict itself to having chain, draw dimensions in all its variables. the mass matrix could also go in there and a divergence_id even (with extra coordinate values or a multiindex to store the start and end points of divergences) which would complement the boolean diverging variable with chain, draw dimension.

samples in the unconstrained space. related to #6721 and to a lesser extent arviz-devs/arviz-base#8

pymc/backends/zarr.py Outdated Show resolved Hide resolved
Copy link
Member

@aseyboldt aseyboldt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I absolutely love this! :D

def setup(self, draws: int, chain: int, sampler_vars: Sequence[dict] | None): # type: ignore[override]
self.chain = chain

def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not check the source, but I think zarr will write the whole chunk each time we set a draw here, even if that chunk is not full yet. If that is indeed the case, we should be able to speed this up a lot if draws_per_chunk is >1 if we buffer draws_per_chunk draws, and the set values in one go.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, it would be a shame that zarr itself doesn't buffer the write until the chunk is filled.

@lucianopaz
Copy link
Contributor Author

I also have more ideas of things that can be done to integrate better the sampling outputs with inferencedata but it might be better to address them in follow up PRs. Not having to go into the current ArviZ converter might help get this things off the ground. Many of these are around since #5160

Great! Yes, let's try to address those in other PRs.

Also, anything on ArviZ side that can help with this let me know

I don't know if from_zarr might have to be updated a bit? How did you handle the fill_value from zarr? I ran into problems on my side and had to add mask_and_scale=False when I opened the group.

Better sample_stats. sample_stats doesn't necessarily need to restrict itself to having chain, draw dimensions in all its variables. the mass matrix could also go in there and a divergence_id even (with extra coordinate values or a multiindex to store the start and end points of divergences) which would complement the boolean diverging variable with chain, draw dimension.

I tried to stay close to how pymc is returning things right now. I agree that this could be greatly improved, but maybe we can do so in follow up iterations.

samples in the unconstrained space. related to #6721 and to a lesser extent arviz-devs/arviz-base#8

I'll try to add this. It doesn't seem to be difficult.

@lucianopaz lucianopaz force-pushed the zarr branch 2 times, most recently from 0c5e73b to ee0a36d Compare October 23, 2024 09:25
@aseyboldt
Copy link
Member

I just did a little experiment with strace, to see what zarr does when we write a value to a chunk.

I wrote a little script that creates the array, writes two chunks and then closes the file. With strace python the-script.py we can see all the syscalls it uses in between

import zarr
import numpy as np
import sys

store = zarr.DirectoryStore("zarr-test.zarr")
#store = zarr.LRUStoreCache(store, max_size=2**28)
data = zarr.open(store)
foo = data.array("foo", np.zeros((0, 0, 0)), chunks=(10, 10, 1))
foo.resize((1000, 10, 1))

# Mark the position in the code to make it easier to find the correct part
print("start--", flush=True, file=sys.stderr)

foo[0, 0, 0] = 1.0
print("mid--", flush=True, file=sys.stderr)
foo[1, 0, 0] = 2.0

print("done--", flush=True, file=sys.stderr)

The first write triggers this:

write(2, "start--\n", 8start--
)                = 8
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0", 0x7ffd56fa6d50) = -1 ENOENT (No such file or directory)
getpid()                                = 414823
futex(0x7b79166b0a88, FUTEX_WAKE_PRIVATE, 2147483647) = 0
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0", 0x7ffd56fa6cd0) = -1 ENOENT (No such file or directory)
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo", {st_mode=S_IFDIR|0755, st_size=14, ...}) = 0
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo", {st_mode=S_IFDIR|0755, st_size=14, ...}) = 0
fstat(3, {st_mode=S_IFCHR|0666, st_rdev=makedev(0x1, 0x9), ...}) = 0
read(3, "\231;\177Y\rp\267\323W?\16\212n\372\346\372", 16) = 16
openat(AT_FDCWD, "/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0.993b7f590d7047d3973f0e8a6efae6fa.partial", O_WRONLY|O_CREAT|O_TRUNC|O_CLOEXEC, 0666) = 4
fstat(4, {st_mode=S_IFREG|0644, st_size=0, ...}) = 0
ioctl(4, TCGETS, 0x7ffd56fa6c10)        = -1 ENOTTY (Inappropriate ioctl for device)
lseek(4, 0, SEEK_CUR)                   = 0
write(4, "\2\0011\10 \3\0\0 \3\0\0/\0\0\0\24\0\0\0\27\0\0\0\37\0\1\0\377\377F\37"..., 47) = 47
close(4)                                = 0
rename("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0.993b7f590d7047d3973f0e8a6efae6fa.partial", "/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0") = 0
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0.993b7f590d7047d3973f0e8a6efae6fa.partial", 0x7ffd56fa6d30) = -1 ENOENT (No such file or directory)

The second write triggers this:

write(2, "mid--\n", 6mid--
)                  = 6
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0", {st_mode=S_IFREG|0644, st_size=47, ...}) = 0
openat(AT_FDCWD, "/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0", O_RDONLY|O_CLOEXEC) = 4
fstat(4, {st_mode=S_IFREG|0644, st_size=47, ...}) = 0
ioctl(4, TCGETS, 0x7ffd56fa6c30)        = -1 ENOTTY (Inappropriate ioctl for device)
lseek(4, 0, SEEK_CUR)                   = 0
lseek(4, 0, SEEK_CUR)                   = 0
fstat(4, {st_mode=S_IFREG|0644, st_size=47, ...}) = 0
read(4, "\2\0011\10 \3\0\0 \3\0\0/\0\0\0\24\0\0\0\27\0\0\0\37\0\1\0\377\377F\37"..., 48) = 47
read(4, "", 1)                          = 0
close(4)                                = 0
getpid()                                = 414823
getpid()                                = 414823
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0", {st_mode=S_IFREG|0644, st_size=47, ...}) = 0
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo", {st_mode=S_IFDIR|0755, st_size=24, ...}) = 0
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo", {st_mode=S_IFDIR|0755, st_size=24, ...}) = 0
fstat(3, {st_mode=S_IFCHR|0666, st_rdev=makedev(0x1, 0x9), ...}) = 0
read(3, "q\313\267\231t\364\276E\235\223J4\354}\372\240", 16) = 16
openat(AT_FDCWD, "/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0.71cbb79974f44e459d934a34ec7dfaa0.partial", O_WRONLY|O_CREAT|O_TRUNC|O_CLOEXEC, 0666) = 4
fstat(4, {st_mode=S_IFREG|0644, st_size=0, ...}) = 0
ioctl(4, TCGETS, 0x7ffd56fa6c10)        = -1 ENOTTY (Inappropriate ioctl for device)
lseek(4, 0, SEEK_CUR)                   = 0
write(4, "\2\0011\10 \3\0\0 \3\0\0006\0\0\0\24\0\0\0\36\0\0\0\37\0\1\0\377\377F\37"..., 54) = 54
close(4)                                = 0
rename("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0.71cbb79974f44e459d934a34ec7dfaa0.partial", "/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0") = 0
stat("/home/adr/git/flow-experiment/zarr-test.zarr/foo/0.0.0.71cbb79974f44e459d934a34ec7dfaa0.partial", 0x7ffd56fa6d30) = -1 ENOENT (No such file or directory)
write(2, "done--\n", 7done--
)                 = 7

So there's a lot going on for each write to the array. If I read this correctly, for the second store it actually reads the chunk from the disc, then modifies the chunk with the indexing update, writes the new chunk to a temporary file and then replaces that with the original chunk file.

For the first write it skips reading in the chunk data, because there still is nothing there to read.

So I think if we want to get good performance from this, we should try to combine writes, by buffering draws_per_chunk items, and then writing them in one go.

@ricardoV94
Copy link
Member

ricardoV94 commented Nov 10, 2024

Can we keep this PR about adding Zarr trace only?

And not sampling state/ resuming

@ricardoV94
Copy link
Member

ricardoV94 commented Nov 10, 2024

Re flaky test, @bwengals was surprised by the magnitude of the error and there is an open issue for it.

If you're confident about there being no issue please close the associated issue. Perhaps as a separate PR to unblock other contributions, depending on how long this takes to get merged

@bwengals
Copy link
Contributor

Put a fix in for the flaky GP test here #7567

@lucianopaz
Copy link
Contributor Author

Can we keep this PR about adding Zarr trace only?

And not sampling state/ resuming

Yes, the only thing that I still need to figure out is what needs to be changed in the current ZarrTrace and ZarrChain classes to make sure that the sampling state/resuming will be workable without a big refactor of these two classes.

@lucianopaz
Copy link
Contributor Author

Put a fix in for the flaky GP test here #7567

Your fix looks great. I was just trying out stuff to see if the CI would build. I'll drop the commit with the test patch.

@lucianopaz lucianopaz force-pushed the zarr branch 2 times, most recently from ddf0ada to 002f890 Compare November 20, 2024 09:18
@lucianopaz lucianopaz marked this pull request as ready for review November 20, 2024 09:19
@lucianopaz
Copy link
Contributor Author

@ricardoV94, @OriolAbril, @aseyboldt, @maresb and @michaelosthege. I think that this is ready for a proper review. I'll try to handle the mambaforge sunset error in a different PR

@lucianopaz
Copy link
Contributor Author

As I said before, I think that this PR is no longer a draft, but I think that there are some choices that I made might be better to discuss here:

  1. Add zarr as a mandatory requirement. In all honesty, I don't think that this is a bad thing, and I chose to leave it like this for simplicity and because I would love to eventually deprecate MultiTrace et al. But as @ricardoV94 pointed out, I could make zarr an optional requirement and then add a bunch of conditional imports.
  2. Record the warmup draws and posterior draws in the same array only to split them at the end of sampling. I chose to do it like this because the other backends operated this way, but I think that it might be a waste. I think that it's possible and maybe even preferable to create the warmup and normal groups and have the ZarrChain.record determine where to put the draw. It would involve a bit more logic, but I think it's doable.
  3. How to handle report objects. I decided to ignore the report object that gets attached to the MultiTrace, but I could actually store that information inside a private group of the ZarrTrace. Something like ZarrTrace.root._sampling_state. Do you guys think that it's worth the effort?
  4. Add sampling_time and tuning_steps as arrays in _sampling_state instead of adding them as attrs of the groups. The latter option is how arviz stores this info, but I decided to make them full fledged arrays in the _sampling_state group. I don't think that it's worth the effort to duplicate the same timing and step information across multiple groups, and having it only in the _sampling_state makes it easier to keep track of it.

Comment on lines 108 to 109
# For some strange reason, spawn multiprocessing doesn't copy the rng
# seed sequence, so we have to rebuild it from scratch
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment can be removed/updated? It's the bug you found for numpy<2.0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you’re right. I’ll update the comment

)
self.vars = [var for var in vars if var.name in self.varnames]

self.fn = model.compile_fn(self.vars, inputs=model.value_vars, on_unused_input="ignore")
Copy link
Member

@ricardoV94 ricardoV94 Dec 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Allow passing the compiled function already and use pytensor.In/pytensor.Out and trust_input for the default as was done in b589ce8

May be worth to refactor into a helper

draws_per_chain = total_draws_per_chain - tuning_steps_per_chain

total_n_tune = tuning_steps_per_chain.sum()
total_draws = draws_per_chain.sum()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is this dealing with uneven draws, when sampling is Interrupted? MultiTrace was discarding draws from longer chains so everything is square

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the ZarrTrace groups are created, they all get a fill value and the final shape for each array. The fill value is not actually stored in the zarr.store, so only the actual draws take up any memory. That being said, when you open and load an unfinished run, the arrays in the trace will have the shape as if they had been completely sampled, but the unsampled draws will be set to the fill value.

Copy link
Member

@ricardoV94 ricardoV94 Dec 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the default value? Would it make sense to trim it like MultiTrace does?

Comment on lines +1237 to +1238
if isinstance(trace, ZarrChain):
trace.link_stepper(step)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could also be added to NDarray?

Comment on lines +1263 to +1264
if isinstance(trace, ZarrChain):
trace.record_sampling_state(step=step)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make this part of trace.close?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want to change the signature of close and I don't want to make trace necessarily have a reference to the step. We could add a record_sampling_state method to NDArray though

Copy link
Member

@ricardoV94 ricardoV94 Dec 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, and doesn't have to be done in this PR. Just please open an issue?

@ricardoV94
Copy link
Member

@lucianopaz PR looks good. Left some small comments

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancements major Include in major changes release notes section trace-backend Traces and ArviZ stuff
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants