DistributedDataParallel #26

Open
LanceGe opened this issue Jul 17, 2020 · 8 comments

LanceGe commented Jul 17, 2020

Is there a way to make SupContrast work with DistributedDataParallel? By default, each worker can only see its own sub-batch, so the inter-sub-batch relationships between samples cannot be utilized.

HobbitLong (Owner) commented

You can use all_gather to gather the features together. The caveat is that you need to manually propagate gradients through the all_gather op, as it doesn't backpropagate automatically.
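
For reference, a minimal sketch of one common way to do that: wrap all_gather in a custom torch.autograd.Function whose backward all-reduces the incoming gradients. The class name GatherLayer and the whole wrapper are illustrative assumptions, not part of SupContrast:

    import torch
    import torch.distributed as dist

    class GatherLayer(torch.autograd.Function):
        """all_gather with a manual backward pass, so gradients reach every rank."""

        @staticmethod
        def forward(ctx, x):
            # Collect a copy of x from every process in the default group.
            gathered = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
            dist.all_gather(gathered, x)
            return tuple(gathered)

        @staticmethod
        def backward(ctx, *grads):
            # Each rank only knows the gradient of its own loss w.r.t. the
            # gathered tensors, so sum the contributions across ranks and
            # keep the slice that corresponds to this rank's local input.
            all_grads = torch.stack(grads)
            dist.all_reduce(all_grads)
            return all_grads[dist.get_rank()]

    # Usage: features = torch.cat(GatherLayer.apply(features), dim=0)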

LanceGe (Author) commented Jul 18, 2020

> You can use all_gather to gather the features together. The caveat is that you need to manually propagate gradients through the all_gather op, as it doesn't backpropagate automatically.

I finally made it work with the help of diffdist, which provides a differentiable all_gather wrapper.

ShijianXu commented

Hi, can you share your code for how to implement this? I am not familiar with all_gather and related operations. Thanks a lot.

LanceGe (Author) commented Jul 21, 2020

> Hi, can you share your code for how to implement this? I am not familiar with all_gather and related operations. Thanks a lot.

First, install diffdist.
Then put the following snippet before calling the criterion:

    import torch
    import torch.distributed as dist
    import diffdist.functional as distops

    world_size = dist.get_world_size()

    # Gather the feature tensors from every rank. diffdist's all_gather is
    # differentiable, so gradients flow back to each rank's local features.
    features = distops.all_gather(
        gather_list=[torch.zeros_like(features) for _ in range(world_size)],
        tensor=features,
        next_backprop=None,
        inplace=True,
    )
    features = torch.cat(features)

    # Gather the labels the same way so they stay aligned with the features.
    labels = distops.all_gather(
        gather_list=[torch.zeros_like(labels) for _ in range(world_size)],
        tensor=labels,
        next_backprop=None,
        inplace=True,
    )
    labels = torch.cat(labels)
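
For context, a short sketch of how the gathered tensors might then be used. This assumes the usual SupContrast setup where criterion is SupConLoss, features already has the [batch_size, n_views, feat_dim] shape it expects, and optimizer is whatever optimizer the training script uses:

    # Compute the loss on the full cross-rank batch and backpropagate as usual;
    # gradients flow back through diffdist's differentiable all_gather to each rank.
    loss = criterion(features, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()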

ShijianXu commented

Thank you for your quick reply.

So then I can simply compute the loss as usual and backpropagate the gradients?

LanceGe (Author) commented Jul 21, 2020 via email

ShijianXu commented

OK. Anyway, thanks a lot.

ShijianXu commented

Just for reference, this seems to be a reliable solution.
