Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compile the functions needed by SMC before the worker processes are started #7472

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

EliasRas
Copy link
Contributor

@EliasRas EliasRas commented Aug 22, 2024

Description

Currently sample_smc can fail due to a NotImplementedError if it's used with a model defined using CustomDist. If a CustomDist is used, the overloads for e.g. _logprob are registered only in the main process. The issue exists only on Windows because the worker processes are spawned. In other systems where the default option is forking, everything works.

#7241 fixed the issue by registering the overloads manually. Although that would fix #7224, the approach might not be the best in the long run.

This PR moves some of the SMC kernel initialization (calculating initial_point and compiling *logp functions) from worker processes to the main process. This way the overloads are not needed in the worker processes.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7472.org.readthedocs.build/en/7472/

self.varlogp = self.model.varlogp
self.datalogp = self.model.datalogp

def initialize_rng(self, random_seed=None):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was it unnecessary to add this method? Didn't want to directly access SMC_KERNEL.rng since it's not initialized by just direct assignment.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't follow, can you explain again?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a new method SMC_KERNEL.initialize_rng which creates a new SMC_KERNEL.rng with given seed. It's just a convenience method for seeding the rng with a different seed in each worker process. Previously it wasn't necessary because the kernels were created in each process separately and rng is seeded during that. This PR creates the kernel before creating the worker processes and seeding has to be done separately.

I was wondering if adding a new method was unnecessary. The method doesn't do much after all. I could just do smc.rng = np.random.default_rng(seed=random_seed) in _sample_smc_int instead but I didn't want to interact with SMC_KERNEL.rng since I didn't see it used anywhere else outside of SMC_KERNEL.

Nitpicky? I agree.

@EliasRas
Copy link
Contributor Author

The test from #7241 is missing but I intend to add it later.

@ricardoV94
Copy link
Member

One thing we'll have to be careful that didn't matter before is to update the shared variables of the logp functions that define RNGs. Model with minibatch/Simulator have a stochastic logp. For those we will probably have to copy the pytensor function using the swap kwarg to provide new shared RNGs: https://pytensor.readthedocs.io/en/latest/tutorial/examples.html#copying-functions

@EliasRas
Copy link
Contributor Author

EliasRas commented Sep 2, 2024

Which shared variables are you referring to? I ran tests/distributions/test_simulator.py locally and there are some failures. I tried to look into test_custom_dist_sum_stat which fails but I couldn't find any SharedVariable.

Could you kick off the tests? I don't know if I'm doing something wrong on my computer which leads to failures in the tests.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
2 participants