Address memory leak of postprocess_fn and lax.map #860

Merged · 2 commits merged into pyro-ppl:master on Jan 4, 2021

Conversation

@fehiepsi (Member) commented on Jan 4, 2021

Resolves #448 and #539.

import gc
from collections import Counter

import numpy as np

from numpyro import sample
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS


def model(y_obs):
    mu = sample('mu', dist.Normal(0., 1.))
    sigma = sample("sigma", dist.HalfCauchy(3.))
    y = sample("y", dist.Normal(mu, sigma), obs=y_obs)

kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1, num_samples=2)
for i in range(10):
    mcmc.run(random.PRNGKey(i), np.zeros((1,)))
    print("\nGC OBJECTS:")
    cnt = Counter()
    # force a collection; if nothing is leaking, the per-type counts
    # should not increase from one iteration to the next
    gc.collect()
    for x in gc.get_objects():
        if isinstance(x, list):
            if len(x) > 1:
                cnt[type(x[0])] += 1
    print(cnt.most_common(10))

prints

GC OBJECTS:
[(<class 'str'>, 1555), (<class 'jax.core.Var'>, 1502), (<class 'tuple'>, 263), (<class 'jax.core.JaxprEqn'>, 252), (<class 'int'>, 85), (<class 'jax.core.ShapedArray'>, 64), (<class 'jax.core.Literal'>, 52), (<class 'set'>, 29), (<class 'argparse._ArgumentGroup'>, 19), (<class 'IPython.core.magic_arguments.argument'>, 18)]

GC OBJECTS:
[(<class 'str'>, 1555), (<class 'jax.core.Var'>, 1502), (<class 'tuple'>, 263), (<class 'jax.core.JaxprEqn'>, 252), (<class 'int'>, 85), (<class 'jax.core.ShapedArray'>, 64), (<class 'jax.core.Literal'>, 52), (<class 'set'>, 29), (<class 'argparse._ArgumentGroup'>, 19), (<class 'IPython.core.magic_arguments.argument'>, 18)]

consistently.

The cause of the leak: a new postprocess_fn closure is created each time .run is called, so lax.map(postprocess_fn, ...) sees a new callable and adds a fresh entry to JAX's compilation cache on every run.
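A minimal sketch of the general pitfall (this uses jax.jit rather than NumPyro's lax.map call, and the names are illustrative, but the caching behavior is analogous: JAX keys its cache on the identity of the traced callable, so a fresh closure per call keeps adding entries):

import jax
import jax.numpy as jnp

def leaky_step(x):
    # A fresh lambda is a new callable on every call, so jit cannot reuse
    # the previous compilation; each call adds a new cache entry.
    return jax.jit(lambda v: v + 1.0)(x)

_add_one = jax.jit(lambda v: v + 1.0)  # created once, cached thereafter

def fixed_step(x):
    # Reusing the same callable lets repeat calls hit the cache.
    return _add_one(x)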

@rexdouglass Could you help me verify this change on a large model that you have? I'm not sure whether anything is still leaking after this change.

@rexdouglass commented:

Can confirm that a toy example that quickly caused an out-of-memory error no longer does, and that reported GPU memory usage is approximately constant and certainly no longer monotonically increasing.

jax: 0.2.7
numpyro: 0.4.1
jax_backend_target: local
platform: gpu

Used GPU Memory MB 1,243.22
sample: 100%|██████████| 200/200 [09:13<00:00,  2.77s/it, 1023 steps of size 7.46e-07. acc. prob=0.81]
  0%|          | 0/100 [00:00<?, ?it/s]
Used GPU Memory MB 3,727.16
sample: 100%|██████████| 100/100 [04:52<00:00,  2.92s/it, 1023 steps of size 7.46e-07. acc. prob=0.81]
  0%|          | 0/100 [00:00<?, ?it/s]
Used GPU Memory MB 4,029.68
sample: 100%|██████████| 100/100 [04:54<00:00,  2.95s/it, 1023 steps of size 7.46e-07. acc. prob=0.82]
  0%|          | 0/100 [00:00<?, ?it/s]
Used GPU Memory MB 3,994.62
sample: 100%|██████████| 100/100 [04:53<00:00,  2.94s/it, 1023 steps of size 7.46e-07. acc. prob=0.81]
  0%|          | 0/100 [00:00<?, ?it/s]
Used GPU Memory MB 3,952.41
sample: 100%|██████████| 100/100 [04:59<00:00,  2.99s/it, 1023 steps of size 7.46e-07. acc. prob=0.81]
  0%|          | 0/100 [00:00<?, ?it/s]
Used GPU Memory MB 3,887.33
sample: 100%|██████████| 100/100 [04:56<00:00,  2.96s/it, 1023 steps of size 7.46e-07. acc. prob=0.81]
  0%|          | 0/100 [00:00<?, ?it/s]
Used GPU Memory MB 3,876.19
sample: 100%|██████████| 100/100 [04:51<00:00,  2.91s/it, 1023 steps of size 7.46e-07. acc. prob=0.81]
  0%|          | 0/100 [00:00<?, ?it/s]
Used GPU Memory MB 3,801.28
sample: 100%|██████████| 100/100 [04:47<00:00,  2.88s/it, 1023 steps of size 7.46e-07. acc. prob=0.80]
  0%|          | 0/100 [00:00<?, ?it/s]
Used GPU Memory MB 3,813.93
sample: 100%|██████████| 100/100 [04:47<00:00,  2.87s/it, 1023 steps of size 7.46e-07. acc. prob=0.80]
  0%|          | 0/100 [00:00<?, ?it/s]
Used GPU Memory MB 3,770.61
sample: 100%|██████████| 100/100 [04:45<00:00,  2.85s/it, 1023 steps of size 7.46e-07. acc. prob=0.80]
import os
# Allocate memory on demand rather than preallocating the whole device;
# see https://github.com/pyro-ppl/numpyro/issues/539
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

# https://github.com/pyro-ppl/numpyro/issues/735
import gc

import jax
import numpy as onp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS
from pynvml import nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlInit

numpyro.set_platform("gpu")

print(jax.__version__)  # 0.2.7
print(numpyro.__version__)  # 0.4.1
print(jax.config.FLAGS.jax_backend_target)  # local
print(jax.lib.xla_bridge.get_backend().platform)  # gpu

GPU_mem_state = None
try:
    nvmlInit()
    GPUhandle = nvmlDeviceGetHandleByIndex(0)

    def GPU_mem_state():
        # Report total used device memory in MB via NVML.
        info = nvmlDeviceGetMemoryInfo(GPUhandle)
        return "Used GPU Memory MB {:,}".format(onp.round(info.used / 1000000, 2))
except Exception:
    print("Can't initialise GPU, using CPU")

def test1():
    # Five large latent sites to stress device memory.
    a = numpyro.sample("a", dist.Normal(0., 0.2), sample_shape=(365, 3300))
    b = numpyro.sample("b", dist.Normal(0., 0.2), sample_shape=(365, 3300))
    c = numpyro.sample("c", dist.Normal(0., 0.2), sample_shape=(365, 3300))
    d = numpyro.sample("d", dist.Normal(0., 0.2), sample_shape=(365, 3300))
    e = numpyro.sample("e", dist.Normal(0., 0.2), sample_shape=(365, 3300))

mcmc = MCMC(NUTS(test1), num_warmup=100, num_samples=100)
for i in range(10):
    if GPU_mem_state is not None:
        print("\n" + GPU_mem_state())
    mcmc.run(random.PRNGKey(i))
    samples = mcmc.get_samples()
    # Copy samples to host memory (iterate over values, not keys).
    trace = [onp.atleast_1d(onp.asarray(v)) for v in samples.values()]
    del samples
    # Reuse the last state as the warmup state so later runs skip warmup
    # (a private-attribute workaround).
    mcmc._warmup_state = mcmc._last_state
    gc.collect()
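Side note (a hedged sketch, assuming a NumPyro release that exposes these attributes): the private mcmc._warmup_state assignment above can be replaced with the public post_warmup_state property, which is the documented way to resume sampling without repeating warmup:

mcmc.run(random.PRNGKey(0))
mcmc.post_warmup_state = mcmc.last_state  # reuse the final state as the warmup state
mcmc.run(mcmc.post_warmup_state.rng_key)  # subsequent runs skip warmup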

@neerajprad (Member) left a comment:
Kudos on debugging and fixing this tricky issue! I hope this is the last memory leak we see due to caching, but if not, we should work with the JAX devs and request a cache eviction policy that doesn't hog all the available memory in these kinds of situations.
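For reference, a blunt manual eviction became possible in much newer JAX releases than the one used here (an assumption about your JAX version; jax 0.2.x has no such API):

import jax

jax.clear_caches()  # drop cached traces and compiled programs between runs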

@neerajprad neerajprad merged commit 956789b into pyro-ppl:master Jan 4, 2021
@fehiepsi (Member, Author) commented on Jan 4, 2021

Thanks for reviewing, @rexdouglass and @neerajprad !

Closes: Memory leak when running MCMC multiple times