Address memory leak of postprocess_fn and lax.map #860
Conversation
Can confirm that a toy example which quickly caused an out-of-memory error no longer does, and that reported GPU memory usage is approximately constant and certainly no longer monotonically increasing.

```
0.2.7
0.4.1
local
gpu
Used GPU Memory MB 1,243.22
sample: 100%|██████████| 200/200 [09:13<00:00, 2.77s/it, 1023 steps of size 7.46e-07. acc. prob=0.81]
Used GPU Memory MB 3,727.16
sample: 100%|██████████| 100/100 [04:52<00:00, 2.92s/it, 1023 steps of size 7.46e-07. acc. prob=0.81]
Used GPU Memory MB 4,029.68
sample: 100%|██████████| 100/100 [04:54<00:00, 2.95s/it, 1023 steps of size 7.46e-07. acc. prob=0.82]
Used GPU Memory MB 3,994.62
sample: 100%|██████████| 100/100 [04:53<00:00, 2.94s/it, 1023 steps of size 7.46e-07. acc. prob=0.81]
Used GPU Memory MB 3,952.41
sample: 100%|██████████| 100/100 [04:59<00:00, 2.99s/it, 1023 steps of size 7.46e-07. acc. prob=0.81]
Used GPU Memory MB 3,887.33
sample: 100%|██████████| 100/100 [04:56<00:00, 2.96s/it, 1023 steps of size 7.46e-07. acc. prob=0.81]
Used GPU Memory MB 3,876.19
sample: 100%|██████████| 100/100 [04:51<00:00, 2.91s/it, 1023 steps of size 7.46e-07. acc. prob=0.81]
Used GPU Memory MB 3,801.28
sample: 100%|██████████| 100/100 [04:47<00:00, 2.88s/it, 1023 steps of size 7.46e-07. acc. prob=0.80]
Used GPU Memory MB 3,813.93
sample: 100%|██████████| 100/100 [04:47<00:00, 2.87s/it, 1023 steps of size 7.46e-07. acc. prob=0.80]
Used GPU Memory MB 3,770.61
sample: 100%|██████████| 100/100 [04:45<00:00, 2.85s/it, 1023 steps of size 7.46e-07. acc. prob=0.80]
```

The reproduction script:

```python
import os
# Memory management must be configured before JAX allocates; see
# https://github.com/pyro-ppl/numpyro/issues/539
# https://github.com/pyro-ppl/numpyro/issues/735
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import gc

import jax
import numpy as onp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo

numpyro.set_platform("gpu")
print(jax.__version__)                            # 0.2.3
print(numpyro.__version__)                        # 0.4.1
print(jax.config.FLAGS.jax_backend_target)        # local
print(jax.lib.xla_bridge.get_backend().platform)  # gpu

GPU_mem_state = None
try:
    nvmlInit()
    GPUhandle = nvmlDeviceGetHandleByIndex(0)

    def GPU_mem_state():
        # Report the device's used memory in MB via NVML
        info = nvmlDeviceGetMemoryInfo(GPUhandle)
        return "Used GPU Memory MB {:,}".format(onp.round(info.used / 1000000, 2))
except Exception:
    print("Can't initialise GPU, using CPU")

def test1():
    a = numpyro.sample("a", dist.Normal(0., 0.2), sample_shape=(365, 3300))
    b = numpyro.sample("b", dist.Normal(0., 0.2), sample_shape=(365, 3300))
    c = numpyro.sample("c", dist.Normal(0., 0.2), sample_shape=(365, 3300))
    d = numpyro.sample("d", dist.Normal(0., 0.2), sample_shape=(365, 3300))
    e = numpyro.sample("e", dist.Normal(0., 0.2), sample_shape=(365, 3300))

mcmc = MCMC(NUTS(test1), 100, 100)
for i in range(10):
    print("\n" + GPU_mem_state())
    mcmc.run(random.PRNGKey(i))
    samples = mcmc.get_samples()
    trace = [onp.atleast_1d(onp.asarray(f)) for f in samples]
    del samples
    # Reuse the warmup state so subsequent runs skip warmup
    mcmc._warmup_state = mcmc._last_state
    gc.collect()
```
Kudos on debugging and fixing this tricky issue! I hope this is the last memory leak we see due to caching, but if not, we should work with the JAX devs and request a cache eviction policy that doesn't hog all the available memory in these kinds of situations.

Thanks for reviewing, @rexdouglass and @neerajprad!
Resolves #448 and #539.

The reason for the leak is that a new `postprocess_fn` is created each time `.run` is called, so `lax.map(postprocess_fn, ...)` involves a new JAX cache each time it is run. With this change, `postprocess_fn` returns consistently across runs.

@rexdouglass Could you help me verify this change in a large model that you have? I'm not sure whether anything was left behind by the change.
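The caching behavior described above can be sketched with plain `jax.jit` (a minimal illustration, not the PR's actual code; `postprocess` here is a hypothetical stand-in for NumPyro's `postprocess_fn`). JAX keys its compilation cache on the Python function object, so a freshly created wrapper can never hit the cache left by the previous one:

```python
import jax
import jax.numpy as jnp

trace_count = 0

def postprocess(x):
    global trace_count
    trace_count += 1  # executed only while JAX traces the function
    return x * 2.0

# Reusing one jitted function object: traced once, cache hits thereafter.
f = jax.jit(postprocess)
for _ in range(3):
    f(jnp.ones(3))
assert trace_count == 1

# A new wrapper object each iteration: re-traced (and a new cache entry
# created) every time, which is how repeated `.run` calls accumulate memory.
for _ in range(3):
    jax.jit(lambda x: postprocess(x))(jnp.ones(3))
assert trace_count == 4
```

Creating `postprocess_fn` once and reusing it across `.run` calls keeps all iterations hitting the same compiled entry instead of growing the cache.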