
OOM error when using "sequential" chain method on a GPU #899

Closed
fehiepsi opened this issue Jan 28, 2021 · 6 comments · Fixed by #904

Comments

@fehiepsi
Member

fehiepsi commented Jan 28, 2021

As reported by @PaoloRanzi81 in #539, host RAM may not be freed after a chain finishes its run, which leads to OOM errors for complicated models that train over several hours.

Some observations so far:

  • This seems not to happen with simple models
  • This seems not to happen on CPU and TPU
  • This seems not to happen when sampling over a short period of time (say, less than 1 hour)
  • This happens when progress_bar=False

Still being tested:

  • The default behavior, progress_bar=True
  • Using only one chain with a large number of samples
  • Simplifying the model to isolate which operator, handler,... causes the issue
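To help narrow down where host memory grows during the tests above, one option is to log the process's peak resident set size between chains. This is a minimal sketch using only the Python standard library; the per-chain work here is simulated with plain allocations, and in the real script each log call would wrap one sequential chain run:

```python
import resource


def log_rss(label):
    # ru_maxrss is the peak resident set size of this process:
    # kilobytes on Linux, bytes on macOS. It only ever grows, so a
    # steady jump chain after chain indicates host RAM that was
    # allocated and never released.
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    print('{}: peak RSS = {} (platform-dependent units)'.format(label, peak))
    return peak


# simulate per-chain work with plain allocations; in the real script
# this would bracket each sequential chain run
baseline = log_rss('before chains')
for chain in range(3):
    buf = bytearray(10 * 1024 * 1024)  # stand-in for one chain's samples
    del buf  # released here; a leak would keep peak RSS climbing
    log_rss('after chain {}'.format(chain))
```

Because `ru_maxrss` is a high-water mark, a leak shows up as monotonic growth across chains rather than a plateau after the first chain.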
@fehiepsi fehiepsi added the question label Jan 28, 2021
@fehiepsi fehiepsi mentioned this issue Jan 28, 2021
@PaoloRanzi81

Hi @fehiepsi

I did some testing with the new master branch.

  1. Bad news:
    Unfortunately, chain_method="sequential" is still consuming lots of RAM, and the OOM is still present. I have noticed:
    1.1. With the new NumPyro master version, GPU memory usage decreased. Is it possible that _laxmap frees up only GPU memory but not system RAM?
    1.2. The table attached below shows the results of the testing.
    1.3. The picture attached below, taken from GCP, shows the monotonic increase of the RAM.
    1.4. During my testing I was still using progress_bar=False; I have not tested progress_bar=True yet.
    1.5. Funny enough, with the same model and the same data set but the new master branch, a new JAX error pops up complaining about slow compilation (see attached picture).

  2. Good news:
    I asked my employer, and he kindly agreed to share the data set and model for research purposes. I can prepare a self-contained module by the end of next week if you are still interested. I will send the Dropbox link privately.

  3. Good news:
    I think the problem with your toy example was that it was too simple. Based on your example, I have created a self-contained example where the monotonic increase is clearly visible; the script is in a separate comment below. I built it by trial and error, matching my actual model as closely as possible. Forgive me if I have used numpyro.plate() in the wrong way; I am still learning.
    The good thing about the new toy example is that it reproduces the error. It is custom-made as follows:
    3.1. I set number_of_distribution = 10000 because it matches the GPU usage of my actual model (70%).
    3.2. I used 5 variables/features, as in my model.
    3.3. I put hyper-priors on each variable, as in my model.
    3.4. Lastly, the toy example uses all the MCMC settings of its CPU version (which works perfectly).
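As an aside on point 1.1 (GPU memory vs. system RAM): by default JAX preallocates the bulk of GPU memory, which can make GPU memory readings hard to interpret. Below is a hedged sketch of how to disable that preallocation so device usage can be observed separately from host RAM; these are standard JAX/XLA environment variables, but whether they are relevant to this particular leak is an assumption:

```python
import os

# These must be set before jax is imported; they are standard JAX/XLA
# settings, not NumPyro-specific ones.
# Disable the default preallocation of most of the GPU memory, so that
# nvidia-smi readings track actual usage rather than the pool size.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# Alternatively, cap the preallocated pool at a fraction of GPU memory:
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.5'

# import jax only after the environment is configured
# import jax
```

With preallocation disabled, a monotonic climb visible only in system RAM (not in device memory) would support the hypothesis that the leak is on the host side.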

(attachments: 01_numpyro_performance_master_version, 02_monotonic_increase)

@PaoloRanzi81

Forgot the attachment, oops:
(attachment: 03_new_message)

@PaoloRanzi81

Self-contained script for reproducing the OOM error

```python
# import libraries
import time

import jax
import jax.numpy as jnp
from jax.lib import xla_bridge
import numpy as np

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS


# start clocking time
start_time = time.time()

# force GPU usage
numpyro.set_platform(platform='gpu')

# print whether the default device is CPU, GPU, or TPU
print('Device used is {}'.format(xla_bridge.get_backend().platform))
jnp.ones(())  # trigger backend initialization
print(jax.local_devices())
print(jax.local_device_count())

# set the number of distributions
number_of_distribution = 10000


# define model: 5 features, each with hyper-priors on its value and coefficient
def model():
    # intercept
    a = numpyro.sample('a', dist.Normal(0., 5.).expand([number_of_distribution]))

    # hyper-priors
    mus, sigmas = {}, {}
    for name in ['x_01', 'beta_01', 'x_02', 'beta_02', 'x_03', 'beta_03',
                 'x_04', 'beta_04', 'x_05', 'beta_05']:
        mus[name] = numpyro.sample(
            name + '_mu', dist.Normal(0., 5.).expand([number_of_distribution]))
        sigmas[name] = numpyro.sample(
            name + '_sigma', dist.HalfNormal(5.).expand([number_of_distribution]))

    # generate distributions from the hyper-priors
    # (question from the author: is this the right way to use the
    # hyper-priors defined above?)
    values = {}
    with numpyro.plate('plate', number_of_distribution):
        # each base distribution already has batch shape
        # (number_of_distribution,), matching the plate dimension
        for name in mus:
            values[name] = numpyro.sample(name, dist.Normal(mus[name], sigmas[name]))

    # epsilon
    epsilon = numpyro.sample('epsilon', dist.HalfNormal(5.).expand([number_of_distribution]))

    # likelihood
    mu = (a
          + values['x_01'] * values['beta_01']
          + values['x_02'] * values['beta_02']
          + values['x_03'] * values['beta_03']
          + values['x_04'] * values['beta_04']
          + values['x_05'] * values['beta_05']
          + epsilon)

    # output
    numpyro.sample('y', dist.Normal(mu, 5.))


# set up and run MCMC
mcmc = MCMC(NUTS(model),
            num_warmup=55250,
            num_samples=1000,
            num_chains=8,
            # chain_method="vectorized",
            chain_method="sequential",
            # chain_method="parallel",
            progress_bar=False)
            # progress_bar=True)
mcmc.run(jax.random.PRNGKey(0))
samples = mcmc.get_samples()['y']

# print MCMC's output
mcmc.print_summary()
print(samples.shape)
print(samples.device_buffer.device())

# end time according to the computer clock
end_time = time.time()

# calculate total execution time
total_execution_time = np.round(end_time - start_time, 2)

# show run-time timestamps and total execution time
print('start time (unix timestamp): {}'.format(start_time))
print('end time (unix timestamp): {}'.format(end_time))
print('total execution time (seconds): {}'.format(total_execution_time))
```

@fehiepsi
Member Author

Thanks, @PaoloRanzi81! I did see memory increasing with progress_bar=False, but not with progress_bar=True. Probably lax.scan (similar to lax.map), which is used to collect samples when progress_bar=False, does not free RAM during its run. There is no reason to use progress_bar=False in your case. We will improve the docs to reflect this observation. Thanks!
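One pattern for keeping host memory bounded while chains run sequentially is to run them one at a time and copy each chain's samples to a plain host array before starting the next. The sketch below illustrates this with NumPy standing in for device arrays; with NumPyro/JAX one would run MCMC(..., num_chains=1) per chain and call jax.device_get on its samples. This is a hypothetical workaround sketch, not the fix that landed in #904:

```python
import numpy as np


def run_one_chain(rng, num_samples):
    # stand-in for a single MCMC chain; with NumPyro this would be an
    # MCMC(..., num_chains=1) run returning device-resident samples
    return rng.normal(size=(num_samples, 5))


rng = np.random.default_rng(0)
num_chains, num_samples = 4, 1000

host_samples = []
for chain in range(num_chains):
    device_samples = run_one_chain(rng, num_samples)
    # copy to a plain host array and drop the "device" reference, so at
    # most one chain's worth of device memory is alive at a time
    # (with JAX: host = np.asarray(jax.device_get(device_samples)))
    host_samples.append(np.asarray(device_samples))
    del device_samples

all_samples = np.stack(host_samples)  # shape (num_chains, num_samples, 5)
print(all_samples.shape)
```

The key design point is that references to device buffers are dropped as soon as a chain finishes, so the backend can reclaim that memory before the next chain starts.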

@PaoloRanzi81

Congratulations, @fehiepsi: you have caught the bug with 1 GPU!

I have tested my custom toy model, and the RAM OOM error does disappear when setting
progress_bar=True!

It is still not clear to me why such a small change made such a major improvement in RAM consumption... I had thought the progress_bar setting was purely cosmetic. Please remember to update the code/documentation so that other people are aware of this and do not waste time troubleshooting it.

I do still have the OOM error with my actual model; using progress_bar=True did not help, unfortunately. My actual model is far more complex than the custom toy model I built in NumPyro. Anyway, I am increasingly thinking that the OOM error I am still experiencing is more likely due to PyMC3 than to NumPyro or JAX. I will open an issue with PyMC3 asking for help debugging it.

You can close the issue. Thanks again for your help so far!

@fehiepsi
Member Author

Thanks, @PaoloRanzi81, for your effort in isolating the issue! We'll improve the documentation to reflect it.
