diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 5016b0897d..c9a9bef500 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -658,25 +658,46 @@ def sample_numpyro_nuts( tic2 = datetime.now() print("Compilation time = ", tic2 - tic1, file=sys.stdout) - print("Sampling...", file=sys.stdout) - map_seed = jax.random.PRNGKey(random_seed) if chains > 1: map_seed = jax.random.split(map_seed, chains) - pmap_numpyro.run( - map_seed, - init_params=init_params, - extra_fields=( - "num_steps", - "potential_energy", - "energy", - "adapt_state.step_size", - "accept_prob", - "diverging", - ), + extra_fields=( + "num_steps", + "potential_energy", + "energy", + "adapt_state.step_size", + "accept_prob", + "diverging", ) + if tune > 0: + print("Warmup...", file=sys.stdout) + pmap_numpyro.warmup( + map_seed, + collect_warmup=True, + init_params=init_params, + extra_fields=extra_fields, + ) + + raw_mcmc_warmup_samples = pmap_numpyro.get_samples(group_by_chain=True) + warmup_sample_stats = _sample_stats_to_xarray(pmap_numpyro) + + print("Sampling...", file=sys.stdout) + pmap_numpyro.post_warmup_state = pmap_numpyro.last_state + pmap_numpyro.run( + pmap_numpyro.post_warmup_state.rng_key, + extra_fields=extra_fields, + ) + + else: + print("Sampling...", file=sys.stdout) + pmap_numpyro.run( + map_seed, + init_params=init_params, + extra_fields=extra_fields, + ) + raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True) tic3 = datetime.now() @@ -730,5 +751,13 @@ def sample_numpyro_nuts( dims=dims, attrs=make_attrs(attrs, library=numpyro), ) - az_trace = to_trace(posterior=posterior, **idata_kwargs) + + if tune > 0: + az_trace = to_trace( + posterior=posterior, + warmup_sample_stats=warmup_sample_stats, + **idata_kwargs, + ) + else: + az_trace = to_trace(posterior=posterior, **idata_kwargs) return az_trace