
Move the warning of using the sequential chain method to the constructor #893

Merged: 7 commits merged into pyro-ppl:master on Jan 26, 2021

Conversation

fehiepsi (Member) commented:
Tried to address the memory leak reported by @PaoloRanzi81 in #539 for chain_method='sequential' on 1 GPU, but I haven't found the root cause yet. Still, I think it would be nice to raise the warning about there not being enough devices to run the parallel method as early as possible.

Also, use jax.local_device_count instead of jax.lib.xla_bridge.device_count to compute the number of available devices, as mentioned in the docs of pmap.
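
[Editor's note] A minimal sketch of the idea described above, assuming a hypothetical MCMCSketch class (this is not the actual numpyro source, and the warning text is paraphrased): the device check moves into the constructor so the warning fires as soon as the object is built.

import warnings

import jax

class MCMCSketch:  # hypothetical name, trimmed to the relevant check
    def __init__(self, num_chains=1, chain_method='parallel'):
        self.num_chains = num_chains
        self.chain_method = chain_method
        # jax.local_device_count() counts the devices pmap can actually use
        # on this host; jax.device_count() counts devices across all hosts.
        if chain_method == 'parallel' and num_chains > jax.local_device_count():
            # Fall back and warn at construction time, before sampling starts.
            self.chain_method = 'sequential'
            warnings.warn('There are not enough devices to run parallel chains;'
                          ' chains will be drawn sequentially.')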

fehiepsi mentioned this pull request on Jan 26, 2021
fehiepsi removed the easy label on Jan 26, 2021
-        states, last_state = lax.map(partial_map_fn, map_args)
-    elif chain_method == 'parallel':
+    if self.chain_method == 'sequential':
+        states, last_state = _laxmap(partial_map_fn, map_args)
fehiepsi (Member, Author) commented on Jan 26, 2021:

Using lax.map here is a bit faster but expensive in terms of memory, so I switched to a for loop (loop over chains; for each chain we still use lax.scan if progress_bar=False and jit(sample_fn) if progress_bar=True to draw samples; see the sketch after the benchmark). Here is the benchmark on GPU (on CPU, the performance is similar):

%%time
%%memit

import jax

import numpyro; numpyro.set_platform("gpu")
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model():
    numpyro.sample("x", dist.Normal(0, 1).expand([100]))

mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=10000, num_chains=8, progress_bar=False)
mcmc.run(jax.random.PRNGKey(0))
samples = mcmc.get_samples(group_by_chain=True)["x"]
print(samples.shape)

Using lax.map:

(8, 10000, 100)
peak memory: 5462.81 MiB, increment: 5415.39 MiB
CPU times: user 9min 23s, sys: 6.47 s, total: 9min 29s
Wall time: 9min 25s

Using the for-loop _laxmap:

(8, 10000, 100)
peak memory: 1795.53 MiB, increment: 1748.04 MiB
CPU times: user 9min 41s, sys: 5.9 s, total: 9min 47s
Wall time: 9min 44s
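
[Editor's note] For readers following along, a minimal sketch of the for-loop idea, under a hypothetical name (the PR's actual helper is _laxmap and differs in details, e.g. it jits the per-index lookup):

import jax
import jax.numpy as jnp

def laxmap_sketch(f, xs):  # hypothetical name; the PR's helper is _laxmap
    # Apply f to each leading-axis slice of the pytree xs with a Python loop,
    # so only one chain's intermediates are live at a time (lax.map keeps the
    # whole batch on device, hence the higher peak memory above).
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]
    ys = [f(jax.tree_util.tree_map(lambda leaf: leaf[i], xs)) for i in range(n)]
    # Stack the per-slice outputs back into one batched pytree.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *ys)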

cc @PaoloRanzi81 I'm not sure this will be enough to solve the OOM in your model, but I guess it could help a bit.

A project member replied:

This seems more than acceptable for the memory savings!

chain_method = 'sequential'
warnings.warn('There are not enough devices to run parallel chains: expected {} but got {}.'
              ' Chains will be drawn sequentially. If you are running MCMC in CPU,'
              ' consider to use `numpyro.set_host_device_count({})` at the beginning'
A project member commented on these lines:

nit: consider using...devices are available...
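
[Editor's note] As background for the warning text above, a minimal usage sketch of the suggested fix (numpyro.set_host_device_count is a real numpyro API; the chain count of 4 is illustrative):

import numpyro

# Must run before JAX initializes its backend, i.e. at the very top of the
# program: it splits the host CPU into 4 visible XLA devices.
numpyro.set_host_device_count(4)

import jax
print(jax.local_device_count())  # 4 on CPU when called early enough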

neerajprad merged commit 16edc9f into pyro-ppl:master on Jan 26, 2021