Move the warning of using the sequential chain method to the constructor #893
Conversation
```python
states, last_state = lax.map(partial_map_fn, map_args)
elif chain_method == 'parallel':
if self.chain_method == 'sequential':
states, last_state = _laxmap(partial_map_fn, map_args)
```
Using `lax.map` here is a bit faster but expensive in terms of memory, so I switched to a for loop (looping over chains; for each chain we still use `lax.scan` if `progress_bar=False` and `jit(sample_fn)` if `progress_bar=True` to draw samples). See the benchmark on GPU below (on CPU the performance is similar):
```python
%%time
%%memit
import jax

import numpyro; numpyro.set_platform("gpu")
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model():
    numpyro.sample("x", dist.Normal(0, 1).expand([100]))

# 8 chains of 10,000 samples each (1,000 warmup steps), progress bar disabled
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=10000, progress_bar=False, num_chains=8)
mcmc.run(jax.random.PRNGKey(0))
samples = mcmc.get_samples(group_by_chain=True)["x"]
print(samples.shape)
```
using `lax.map`:
(8, 10000, 100)
peak memory: 5462.81 MiB, increment: 5415.39 MiB
CPU times: user 9min 23s, sys: 6.47 s, total: 9min 29s
Wall time: 9min 25s
while using the for loop `_laxmap`:
(8, 10000, 100)
peak memory: 1795.53 MiB, increment: 1748.04 MiB
CPU times: user 9min 41s, sys: 5.9 s, total: 9min 47s
Wall time: 9min 44s
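For reference, here is a minimal sketch of the two strategies being compared. It is not numpyro's actual implementation: `draw_chain` is a hypothetical stand-in for the per-chain sampling function, and the shapes simply mirror the benchmark above.

```python
import jax
import jax.numpy as jnp
from jax import lax

def draw_chain(rng_key):
    # Placeholder for "run one MCMC chain and return its samples" (hypothetical).
    return jax.random.normal(rng_key, (10000, 100))

rng_keys = jax.random.split(jax.random.PRNGKey(0), 8)  # one key per chain

# Strategy 1: lax.map over the leading axis of rng_keys (the approach replaced here);
# the benchmark above shows a much higher peak GPU memory with this.
samples_lax = lax.map(draw_chain, rng_keys)

# Strategy 2: a plain Python loop over chains (the `_laxmap`-style approach); chains run
# one after another, and the benchmark above shows a much lower peak GPU memory.
samples_loop = jnp.stack([draw_chain(key) for key in rng_keys])
```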
cc @PaoloRanzi81 I'm not sure if this will be enough to solve the OOM in your model, but I guess it could help a bit.
This seems more than acceptable for the memory savings!
numpyro/infer/mcmc.py (outdated)
```python
chain_method = 'sequential'
warnings.warn('There are not enough devices to run parallel chains: expected {} but got {}.'
              ' Chains will be drawn sequentially. If you are running MCMC in CPU,'
              ' consider to use `numpyro.set_host_device_count({})` at the beginning'
```
nit: consider using "... devices are available ..." in the warning message.
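For orientation, here is a rough sketch of the device-count check that this PR moves into the `MCMC` constructor. The names and surrounding logic are simplified assumptions, not the actual numpyro source, and the message uses one possible phrasing of the reviewer's nit:

```python
import warnings
import jax

num_chains = 8
chain_method = "parallel"

# If parallel chains are requested but fewer devices are available, fall back to
# sequential and warn as early as possible (i.e. at construction time).
if chain_method == "parallel" and jax.local_device_count() < num_chains:
    chain_method = "sequential"
    warnings.warn(
        "There are not enough devices to run parallel chains: expected {} but only {}"
        " devices are available. Chains will be drawn sequentially. If you are running"
        " MCMC on CPU, consider using `numpyro.set_host_device_count({})` at the"
        " beginning of your program.".format(
            num_chains, jax.local_device_count(), num_chains
        )
    )
```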
Tried to address the memory leak reported by @PaoloRanzi81 in #539 for `chain_method='sequential'` on 1 GPU, but haven't found the root problem yet. Still, I think it would be nice to raise the warning about not having enough devices to run the parallel method as early as possible. Also use `jax.local_device_count` instead of `jax.lib.xla_bridge.device_count` to calculate the number of available devices, as mentioned in the docs of `pmap`.
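As a usage sketch (the four-chain setup below is an assumed example, not from the PR), `numpyro.set_host_device_count` has to be called before JAX initializes its backend for `jax.local_device_count()` to report more than one device on a CPU-only machine:

```python
import numpyro

# Must run before any JAX computation so the CPU backend is created with 4 devices.
numpyro.set_host_device_count(4)

import jax
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

print(jax.local_device_count())  # 4 on a CPU-only machine after the call above

def model():
    numpyro.sample("x", dist.Normal(0, 1))

# With 4 local devices available, chain_method='parallel' no longer falls back to sequential.
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000, num_chains=4, chain_method="parallel")
mcmc.run(jax.random.PRNGKey(0))
```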