eval memory issue #5

Open
ramdhan1989 opened this issue Dec 21, 2024 · 6 comments

@ramdhan1989

ramdhan1989 commented Dec 21, 2024

Hi,
I can train the model on my dataset, but when I run eval mode I always get a memory error. Is there a solution?

```
bytes.
Traceback (most recent call last):
  File "/scratch1/rwibawa/cvit/ns/main_nowandb.py", line 31, in <module>
    app.run(main)
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/scratch1/rwibawa/cvit/ns/main_nowandb.py", line 26, in main
    eval.evaluate(FLAGS.config)
  File "/scratch1/rwibawa/cvit/ns/eval.py", line 81, in evaluate
    pred = model.apply(state.params, x, coords)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch1/rwibawa/cvit/ns/src/model.py", line 364, in __call__
    x = CrossAttnBlock(
        ^^^^^^^^^^^^^^^
  File "/scratch1/rwibawa/cvit/ns/src/model.py", line 133, in __call__
    x = nn.MultiHeadDotProductAttention(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/flax/linen/attention.py", line 674, in __call__
    x = self.attention_fn(*attn_args, **attn_kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/flax/linen/attention.py", line 266, in dot_product_attention
    attn_weights = dot_product_attention_weights(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/flax/linen/attention.py", line 132, in dot_product_attention_weights
    attn_weights = einsum('...qhd,...khd->...hqk', query, key)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 9747, in einsum
    return einsum(operands, contractions, precision,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 34359738368 bytes.
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```

@sifanexisted
Collaborator

Apologies for the delayed response. The OOM error occurs because CViT is trained with subsampled grid coordinates, but evaluation requires the full grid coordinates, which significantly increases GPU memory usage. We recommend reducing the batch size until the OOM error is resolved.

Hope this helps!
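Besides lowering the batch size, memory can also be cut by splitting the query coordinates into chunks: in a cross-attention decoder each query attends to the latent tokens independently, so per-chunk outputs can simply be concatenated. A minimal sketch against the `model.apply(state.params, x, coords)` call from the traceback above; the helper name, chunk size, and output layout are illustrative assumptions, not part of the repo:

```python
import jax.numpy as jnp

def evaluate_in_chunks(model, params, x, coords, chunk_size=2048):
    """Evaluate full-grid query coordinates in chunks.

    The decoder's cross-attention weights have shape
    (batch, heads, num_queries, num_tokens), so peak memory grows
    linearly with the number of query points passed in at once.
    """
    preds = []
    for start in range(0, coords.shape[0], chunk_size):
        chunk = coords[start:start + chunk_size]
        preds.append(model.apply(params, x, chunk))
    # Assumes predictions are laid out as (batch, num_queries, out_dim).
    return jnp.concatenate(preds, axis=1)
```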

@ramdhan1989
Author

Hi,
I have reduced the batch size and also increased the memory, but I still get an error. I wonder if there is something I missed in this error log?

```
A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):
  File "/scratch1/rwibawa/cvit/ns/main_nowandb.py", line 5, in <module>
    import train_nowandb
  File "/scratch1/rwibawa/cvit/ns/train_nowandb.py", line 8, in <module>
    from src.model import CVit
  File "/scratch1/rwibawa/cvit/ns/src/__init__.py", line 2, in <module>
    from . import data_pipeline
  File "/scratch1/rwibawa/cvit/ns/src/data_pipeline.py", line 13, in <module>
    import tensorflow as tf
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/tensorflow/__init__.py", line 467, in <module>
    importlib.import_module("keras.src.optimizers")
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/keras/__init__.py", line 2, in <module>
    from keras.api import DTypePolicy
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/keras/api/__init__.py", line 8, in <module>
    from keras.api import activations
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/keras/api/activations/__init__.py", line 7, in <module>
    from keras.src.activations import deserialize
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/keras/src/__init__.py", line 13, in <module>
    from keras.src import visualization
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/keras/src/visualization/__init__.py", line 2, in <module>
    from keras.src.visualization import plot_image_gallery
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/keras/src/visualization/plot_image_gallery.py", line 13, in <module>
    import matplotlib.pyplot as plt
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/matplotlib/__init__.py", line 129, in <module>
    from . import _api, _version, cbook, _docstring, rcsetup
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/matplotlib/rcsetup.py", line 27, in <module>
    from matplotlib.colors import Colormap, is_color_like
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/matplotlib/colors.py", line 56, in <module>
    from matplotlib import _api, _cm, cbook, scale
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/matplotlib/scale.py", line 22, in <module>
    from matplotlib.ticker import (
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/matplotlib/ticker.py", line 138, in <module>
    from matplotlib import transforms as mtransforms
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/matplotlib/transforms.py", line 49, in <module>
    from matplotlib._path import (
AttributeError: _ARRAY_API not found

A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):
  File "/scratch1/rwibawa/cvit/ns/main_nowandb.py", line 5, in <module>
    import train_nowandb
  File "/scratch1/rwibawa/cvit/ns/train_nowandb.py", line 8, in <module>
    from src.model import CVit
  File "/scratch1/rwibawa/cvit/ns/src/__init__.py", line 2, in <module>
    from . import data_pipeline
  File "/scratch1/rwibawa/cvit/ns/src/data_pipeline.py", line 13, in <module>
    import tensorflow as tf
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/tensorflow/__init__.py", line 467, in <module>
    importlib.import_module("keras.src.optimizers")
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/importlib/__init__.py", line 126, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/keras/__init__.py", line 2, in <module>
    from keras.api import DTypePolicy
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/keras/api/__init__.py", line 34, in <module>
    from keras.api import visualization
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/keras/api/visualization/__init__.py", line 11, in <module>
    from keras.src.visualization.plot_bounding_box_gallery import (
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/keras/src/visualization/plot_bounding_box_gallery.py", line 12, in <module>
    from matplotlib import patches  # For legend patches
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/matplotlib/__init__.py", line 129, in <module>
    from . import _api, _version, cbook, _docstring, rcsetup
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/matplotlib/rcsetup.py", line 27, in <module>
    from matplotlib.colors import Colormap, is_color_like
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/matplotlib/colors.py", line 56, in <module>
    from matplotlib import _api, _cm, cbook, scale
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/matplotlib/scale.py", line 22, in <module>
    from matplotlib.ticker import (
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/matplotlib/ticker.py", line 138, in <module>
    from matplotlib import transforms as mtransforms
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/matplotlib/transforms.py", line 49, in <module>
    from matplotlib._path import (
AttributeError: _ARRAY_API not found
/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/absl/flags/validators.py:254: UserWarning: Flag --config has a non-None default value; therefore, mark_flag_as_required will pass even if flag is not specified in the command line!
  mark_flag_as_required(flag_name, flag_values)
I1227 09:47:57.819218 139859223803712 checkpoint_manager.py:566] [process=0][thread=MainThread] CheckpointManager init: checkpointers=None, item_names=None, item_handlers=None, handler_registry=None
I1227 09:47:57.819641 139859223803712 composite_checkpoint_handler.py:499] Initialized registry DefaultCheckpointHandlerRegistry({('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonSaveArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7f3094115e50>, ('metrics', <class 'orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonRestoreArgs'>): <orbax.checkpoint._src.handlers.json_checkpoint_handler.JsonCheckpointHandler object at 0x7f3094115e50>}).
I1227 09:47:57.820122 139859223803712 abstract_checkpointer.py:35] orbax-checkpoint version: 0.10.3
I1227 09:47:57.820212 139859223803712 async_checkpointer.py:76] [process=0][thread=MainThread] Using barrier_sync_fn: <function get_barrier_sync_fn.. at 0x7f30745c23e0> timeout: 600 secs and primary_host=0 for async checkpoint writes
I1227 09:47:57.820417 139859223803712 multihost.py:293] [process=0][thread=MainThread] Skipping global process sync, barrier name: CheckpointManager:create_directory
I1227 09:47:57.825712 139842155443968 checkpoint.py:225] Read Metadata={'init_timestamp_nsecs': 1733541977503152670, 'commit_timestamp_nsecs': 1733541985220808221} from /scratch1/rwibawa/cvit/ns/checkpoints2/303400/CHECKPOINT_METADATA
I1227 09:47:57.845154 139859223803712 checkpoint_manager.py:1469] Found 14 checkpoint steps in /scratch1/rwibawa/cvit/ns/checkpoints2
I1227 09:47:57.845750 139859223803712 checkpoint_manager.py:1510] Saving root metadata
I1227 09:47:57.845834 139859223803712 multihost.py:293] [process=0][thread=MainThread] Skipping global process sync, barrier name: CheckpointManager:save_metadata
I1227 09:47:57.845890 139859223803712 checkpoint_manager.py:733] [process=0][thread=MainThread] CheckpointManager created, primary_host=0, CheckpointManagerOptions=CheckpointManagerOptions(save_interval_steps=1, max_to_keep=10, keep_time_interval=None, keep_period=None, should_keep_fn=None, best_fn=None, best_mode='max', keep_checkpoints_without_metrics=True, step_prefix=None, step_format_fixed_length=None, step_name_format=None, create=True, cleanup_tmp_directories=False, save_on_steps=frozenset(), single_host_load_and_broadcast=False, todelete_subdir=None, enable_background_delete=False, read_only=False, enable_async_checkpointing=True, async_options=None, multiprocessing_options=MultiprocessingOptions(primary_host=0, active_processes=None, barrier_sync_key_prefix=None), should_save_fn=None, file_options=FileOptions(path_permission_mode=None), save_root_metadata=True, temporary_path_class=None), root_directory=/scratch1/rwibawa/cvit/ns/checkpoints2: <orbax.checkpoint.checkpoint_manager.CheckpointManager object at 0x7f3074768090>
I1227 09:47:57.848666 139859223803712 checkpointer.py:236] Restoring checkpoint from /scratch1/rwibawa/cvit/ns/checkpoints2/334600.
I1227 09:47:57.849260 139859223803712 composite_checkpoint_handler.py:554] No entry found in handler registry for item: default and args with type: <class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardRestoreArgs'>. Falling back to global handler registry.
I1227 09:47:57.849389 139859223803712 base_pytree_checkpoint_handler.py:322] Created BasePyTreeCheckpointHandler: pytree_metadata_options=PyTreeMetadataOptions(support_rich_types=False)
I1227 09:47:57.849548 139859223803712 composite_checkpoint_handler.py:233] Deferred registration for item: "default". Adding handler <orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardCheckpointHandler object at 0x7f30942bfdd0> for item "default" and save args <class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardSaveArgs'> and restore args <class 'orbax.checkpoint._src.handlers.standard_checkpoint_handler.StandardRestoreArgs'> to _handler_registry.
I1227 09:48:02.480482 139859223803712 base_pytree_checkpoint_handler.py:111] [process=0] /jax/checkpoint/read/bytes_per_sec: 229.4 MiB/s (total bytes: 1.0 GiB) (time elapsed: 4 seconds) (per-host)
I1227 09:48:02.482426 139859223803712 checkpointer.py:239] Finished restoring checkpoint from /scratch1/rwibawa/cvit/ns/checkpoints2/334600.
I1227 09:48:02.482512 139859223803712 multihost.py:293] [process=0][thread=MainThread] Skipping global process sync, barrier name: Checkpointer:restore
I1227 09:48:02.482597 139859223803712 standard_logger.py:34] {'step': 334600, 'event_type': 'restore', 'directory': '/scratch1/rwibawa/cvit/ns/checkpoints2', 'checkpointer_start_time': 1735321677.8485606, 'checkpointer_duration_secs': 4.633980989456177, 'checkpoint_manager_start_time': 1735321677.8460107, 'checkpoint_manager_duration_secs': 4.636533737182617}
/scratch1/rwibawa/cvit/ns/ns_pipeline.py:76: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:278.)
  ux = torch.Tensor(f[key+'/ux'])
/scratch1/rwibawa/cvit/ns/ns_pipeline.py:88: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  mu = torch.tensor([torch.mean(data[:,:,:,:,0]), torch.mean(torch.tensor(visc))])
/scratch1/rwibawa/cvit/ns/ns_pipeline.py:89: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  std = torch.tensor([torch.mean(data[:,:,:,:,0]), torch.std(torch.tensor(visc))])
device gpu
device gpu
Total number of parameters: 92,601,089
/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
2024-12-27 09:49:33.465355: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3021] Can't reduce memory use below -3.34GiB (-3587639705 bytes) by rematerialization; only reduced to 33.07GiB (35504783376 bytes), down from 33.07GiB (35504783376 bytes) originally
2024-12-27 09:49:43.791355: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_0_bfc) ran out of memory trying to allocate 32.00GiB (rounded to 34359738368)requested by op
2024-12-27 09:49:43.792525: W external/xla/xla/tsl/framework/bfc_allocator.cc:512] ****************************
_________________________________________________________
E1227 09:49:43.792548 8670 pjrt_stream_executor_client.cc:3086] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 34359738368 bytes.
Traceback (most recent call last):
  File "/scratch1/rwibawa/cvit/ns/main_nowandb.py", line 31, in <module>
    app.run(main)
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/scratch1/rwibawa/cvit/ns/main_nowandb.py", line 26, in main
    eval.evaluate(FLAGS.config)
  File "/scratch1/rwibawa/cvit/ns/eval.py", line 81, in evaluate
    pred = model.apply(state.params, x, coords)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch1/rwibawa/cvit/ns/src/model.py", line 364, in __call__
    x = CrossAttnBlock(
        ^^^^^^^^^^^^^^^
  File "/scratch1/rwibawa/cvit/ns/src/model.py", line 133, in __call__
    x = nn.MultiHeadDotProductAttention(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/flax/linen/attention.py", line 674, in __call__
    x = self.attention_fn(*attn_args, **attn_kwargs)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/flax/linen/attention.py", line 266, in dot_product_attention
    attn_weights = dot_product_attention_weights(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/flax/linen/attention.py", line 132, in dot_product_attention_weights
    attn_weights = einsum('...qhd,...khd->...hqk', query, key)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home1/rwibawa/.local/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 9747, in einsum
    return einsum(operands, contractions, precision,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/spack/2206/apps/linux-centos7-x86_64_v3/gcc-11.3.0/python-3.11.3-gl2q3yz/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 34359738368 bytes.
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```

@sifanexisted
Collaborator

It seems the conflict might be due to using a NumPy version >= 2.0. I checked my environment, and my current NumPy version is 1.26.4. Could you please try downgrading your NumPy version and see if that resolves the issue?

Also, could you provide the batch size and the number of query points used for training, as well as the batch size used for testing?
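For anyone debugging the same mismatch, a quick check that the failing interpreter really picks up a 1.x NumPy (a user-site install can shadow the cluster's Python module); the 1.x threshold just mirrors the working 1.26.4 environment mentioned above:

```python
import numpy as np

# The "_ARRAY_API not found" AttributeErrors above come from extensions
# (matplotlib, tensorflow) compiled against NumPy 1.x being imported
# under NumPy 2.x. Downgrading restores the old ABI:
#     pip install "numpy<2"
print(np.__version__)
assert int(np.__version__.split(".")[0]) < 2, "NumPy 2.x is still active"
```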

@ramdhan1989
Author

ramdhan1989 commented Dec 29, 2024

Hi, I still get the same error with numpy 1.26.4. The batch size is 1, as shown below, and the coords size is (16384, 2); 16384 = 128 * 128.

```python
import ml_collections
import jax.numpy as jnp


def get_config():
    config = ml_collections.ConfigDict()
    config.mode = "eval"
    config.x_dim = [10, 10, 128, 128, 2]
    config.coords_dim = [1024, 2]
    config.seed = 42

    config.model = model = ml_collections.ConfigDict()
    model.patch_size = (1, 4, 4)
    model.grid_size = (128, 128)
    model.latent_dim = 512
    model.emb_dim = 768
    model.depth = 15
    model.num_heads = 12
    model.dec_emb_dim = 512
    model.dec_num_heads = 16
    model.dec_depth = 1
    model.num_mlp_layers = 1
    model.mlp_ratio = 2
    model.out_dim = 1
    model.embedding_type = "grid"

    config.dataset = dataset = ml_collections.ConfigDict()
    dataset.path = "/scratch1/rwibawa/cvit/NavierStokes-2D/"
    dataset.components = ["u", "vx", "vy"]
    dataset.prev_steps = 10
    dataset.pred_steps = 1
    dataset.downsample = 1
    dataset.train_samples = 6500
    dataset.test_samples = 10
    dataset.batch_size = 1  # 10
    dataset.num_query_points = 1024
    dataset.num_workers = 8

    config.lr = lr = ml_collections.ConfigDict()
    lr.init_value = 0.0
    lr.end_value = 1e-6
    lr.peak_value = 1e-3
    lr.decay_rate = 0.9
    lr.transition_steps = 5000
    lr.warmup_steps = 5000

    config.optim = optim = ml_collections.ConfigDict()
    optim.beta1 = 0.9
    optim.beta2 = 0.999
    optim.eps = 1e-8
    optim.weight_decay = 1e-5
    optim.clip_norm = 10.0

    config.training = training = ml_collections.ConfigDict()
    training.num_steps = 4 * 10**5

    config.logging = logging = ml_collections.ConfigDict()
    logging.log_interval = 200
    logging.eval_steps = 10

    config.saving = saving = ml_collections.ConfigDict()
    saving.save_interval = 200
    saving.num_keep_ckpts = 10

    config.eval = eval = ml_collections.ConfigDict()
    eval.rollout_steps = 4

    return config
```

Thank you
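A quick sanity check on the numbers posted here: with dec_num_heads = 16, a full 128 x 128 query grid (16384 points), and 1024 latent tokens from the (4, 4) spatial patching of the 128 x 128 grid, the decoder's float32 attention weights come to about 1 GiB per sample, so a true batch size of 1 should fit comfortably; the 32 GiB failure suggests the eval path is using a larger batch than the config value. A hedged back-of-envelope, with shapes assumed from the config above rather than verified against the code:

```python
# Hypothetical estimate of the decoder cross-attention weights,
# shape assumed to be (batch, heads, num_queries, num_latent_tokens):
batch, heads = 1, 16                # dec_num_heads from the config
queries = 128 * 128                 # full evaluation grid -> 16384 points
tokens = (128 // 4) * (128 // 4)    # (4, 4) patches on a 128 x 128 grid -> 1024
bytes_per_f32 = 4
print(batch * heads * queries * tokens * bytes_per_f32 / 2**30, "GiB")  # -> 1.0
```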

@sifanexisted
Collaborator

I noticed that in eval.py, the batch size is hardcoded as follows:

```python
test_loader = DataLoader(
    test_dataset, batch_size=32, shuffle=False, drop_last=True, num_workers=8
)
```

Did you reduce the batch size as well?
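That hardcoded 32 also matches the failed allocation exactly under the shape assumption above: 32 x 16 heads x 16384 queries x 1024 tokens x 4 bytes = 34,359,738,368 bytes, the 32 GiB from the tracebacks. A sketch of wiring the loader to the config instead of hardcoding the value, assuming eval.py has access to the same config object used elsewhere in the repo:

```python
from torch.utils.data import DataLoader

def make_test_loader(test_dataset, config):
    # Respect the configured eval batch size; the decoder's attention
    # memory scales linearly with it.
    return DataLoader(
        test_dataset,
        batch_size=config.dataset.batch_size,
        shuffle=False,
        drop_last=True,
        num_workers=config.dataset.num_workers,
    )
```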

@ramdhan1989
Author

Oh, you are right, I can run it now. Thanks for your help!
