diff --git a/src/fairchem/core/common/distutils.py b/src/fairchem/core/common/distutils.py index 54eb48347d..a79faa5a4e 100644 --- a/src/fairchem/core/common/distutils.py +++ b/src/fairchem/core/common/distutils.py @@ -69,13 +69,12 @@ def setup(config) -> None: f"Init: {config['init_method']}, {config['world_size']}, {config['rank']}" ) - # ensures GPU0 does not have extra context/higher peak memory + assign_device_for_local_rank(config["cpu"], config["local_rank"]) + logging.info( - f"local rank: {config['local_rank']}, visible devices: {os.environ['CUDA_VISIBLE_DEVICES']}" + f"local rank: {config['local_rank']}, rank: {config['rank']}, visible devices: {os.environ['CUDA_VISIBLE_DEVICES']}" ) - assign_device_for_local_rank(config["cpu"], config["local_rank"]) - dist.init_process_group( backend="nccl", init_method=config["init_method"],