diff --git a/pymc/smc/kernels.py b/pymc/smc/kernels.py index 6576b425dfb..dec7e52685d 100644 --- a/pymc/smc/kernels.py +++ b/pymc/smc/kernels.py @@ -33,7 +33,6 @@ compile_pymc, floatX, join_nonshared_inputs, - make_shared_replacements, ) from pymc.sampling.forward import draw from pymc.step_methods.metropolis import MultivariateNormalProposal @@ -168,7 +167,7 @@ def __init__( raise ValueError(f"Threshold value {threshold} must be between 0 and 1") self.threshold = threshold self.model = model - self.rng = np.random.default_rng(seed=random_seed) + self.initialize_rng(random_seed=random_seed) self.model = modelcontext(model) self.variables = self.model.value_vars @@ -186,6 +185,21 @@ def __init__( self.resampling_indexes = None self.weights = np.ones(self.draws) / self.draws + self.varlogp = self.model.varlogp + self.datalogp = self.model.datalogp + + def initialize_rng(self, random_seed=None): + """ + Initialize random number generator. + + Parameters + ---------- + random_seed : int, array_like of int, RandomState or Generator, optional + Value used to initialize the random number generator. + """ + + self.rng = np.random.default_rng(seed=random_seed) + def initialize_population(self) -> dict[str, np.ndarray]: """Create an initial population from the prior distribution""" sys.stdout.write(" ") # see issue #5828 @@ -212,15 +226,22 @@ def initialize_population(self) -> dict[str, np.ndarray]: return cast(dict[str, np.ndarray], dict_prior) - def _initialize_kernel(self): - """Create variables and logp function necessary to run SMC kernel + def _initialize_kernel(self, initial_point=None): + """ + Create variables and logp function necessary to run SMC kernel This method should not be overwritten. If needed, use `setup_kernel` instead. + Parameters + ---------- + initial_point : dict, optional + Dictionary that contains initial values for model variables. """ - # Create dictionary that stores original variables shape and size - initial_point = self.model.initial_point(random_seed=self.rng.integers(2**30)) + + if initial_point is None: + # Create dictionary that stores original variables shape and size + initial_point = self.model.initial_point(random_seed=self.rng.integers(2**30)) for v in self.variables: self.var_info[v.name] = (initial_point[v.name].shape, initial_point[v.name].size) # Create particles bijection map @@ -237,14 +258,8 @@ def _initialize_kernel(self): self.tempered_posterior = np.array(floatX(population)) # Initialize prior and likelihood log probabilities - shared = make_shared_replacements(initial_point, self.variables, self.model) - - self.prior_logp_func = _logp_forw( - initial_point, [self.model.varlogp], self.variables, shared - ) - self.likelihood_logp_func = _logp_forw( - initial_point, [self.model.datalogp], self.variables, shared - ) + self.prior_logp_func = _logp_forw(initial_point, [self.varlogp], self.variables, {}) + self.likelihood_logp_func = _logp_forw(initial_point, [self.datalogp], self.variables, {}) priors = [self.prior_logp_func(sample) for sample in self.tempered_posterior] likelihoods = [self.likelihood_logp_func(sample) for sample in self.tempered_posterior] diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 03e64f94c10..dbc1e8faead 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -197,6 +197,16 @@ def sample_smc( random_seed = _get_seeds_per_chain(random_state=random_seed, chains=chains) model = modelcontext(model) + smc = kernel( + draws=draws, + start=start, + model=model, + **kernel_kwargs, + ) + initial_points = [ + model.initial_point(random_seed=np.random.default_rng(seed=seed).integers(2**30)) + for seed in random_seed + ] _log = logging.getLogger(__name__) _log.info("Initializing SMC sampler...") @@ -205,16 +215,9 @@ def sample_smc( f"in {cores} job{'s' if cores > 1 else ''}" ) - params = ( - draws, - kernel, - start, - model, - ) - t1 = time.time() - results = run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores) + results = run_chains(chains, progressbar, smc, random_seed, initial_points, cores) ( traces, @@ -303,41 +306,21 @@ def _save_sample_stats( def _sample_smc_int( - draws, - kernel, - start, - model, + smc, random_seed, + initial_point, chain, progress_dict, task_id, - **kernel_kwargs, ): """Run one SMC instance.""" - in_out_pickled = isinstance(model, bytes) + in_out_pickled = isinstance(smc, bytes) if in_out_pickled: # function was called in multiprocessing context, deserialize first - (draws, kernel, start, model) = map( - cloudpickle.loads, - ( - draws, - kernel, - start, - model, - ), - ) - - kernel_kwargs = {key: cloudpickle.loads(value) for key, value in kernel_kwargs.items()} - - smc = kernel( - draws=draws, - start=start, - model=model, - random_seed=random_seed, - **kernel_kwargs, - ) + smc = cloudpickle.loads(smc) - smc._initialize_kernel() + smc.initialize_rng(random_seed) + smc._initialize_kernel(initial_point) smc.setup_kernel() stage = 0 @@ -367,7 +350,7 @@ def _sample_smc_int( return results -def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): +def run_chains(chains, progressbar, smc, random_seed, initial_points, cores): with CustomProgress( TextColumn("{task.description}"), SpinnerColumn(), @@ -383,9 +366,8 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): # main process and our worker functions _progress = manager.dict() - # "manually" (de)serialize params before/after multiprocessing - params = tuple(cloudpickle.dumps(p) for p in params) - kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()} + # "manually" (de)serialize kernel before/after multiprocessing + smc = cloudpickle.dumps(smc) with ProcessPoolExecutor(max_workers=cores) as executor: for c in range(chains): # iterate over the jobs we need to run @@ -394,12 +376,12 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores): futures.append( executor.submit( _sample_smc_int, - *params, + smc, random_seed[c], + initial_points[c], c, _progress, task_id, - **kernel_kwargs, ) ) diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index 84a53695581..fcb0a534159 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -134,6 +134,21 @@ def test_unobserved_categorical(self): assert np.all(np.median(trace["mu"], axis=0) == [1, 2]) + def test_parallel_custom(self): + def _logp(value, mu): + return -((value - mu) ** 2) + + def _random(mu, rng=None, size=None): + return rng.normal(loc=mu, scale=1, size=size) + + def _dist(mu, size=None): + return pm.Normal.dist(mu, 1, size=size) + + with pm.Model(): + mu = pm.CustomDist("mu", 0, logp=_logp, dist=_dist) + pm.CustomDist("y", mu, logp=_logp, class_name="", random=_random, observed=[1, 2]) + pm.sample_smc(draws=6, cores=2) + def test_marginal_likelihood(self): """ Verifies that the log marginal likelihood function