take back some of the changes
mfbalin committed Jan 18, 2024
1 parent 18d51ce commit fcf16d1
Showing 1 changed file with 34 additions and 19 deletions.
53 changes: 34 additions & 19 deletions examples/multigpu/graphbolt/node_classification.py
@@ -187,31 +187,46 @@ def train(
 
     model.train()
     total_loss = torch.tensor(0, dtype=torch.float, device=device)
-    for step, data in (
-        tqdm.tqdm(enumerate(train_dataloader))
-        if rank == 0
-        else enumerate(train_dataloader)
-    ):
-        # The input features are from the source nodes in the first
-        # layer's computation graph.
-        x = data.node_features["feat"]
+    ########################################################################
+    # (HIGHLIGHT) Use the Join context manager to handle uneven inputs.
+    #
+    # Distributed Data Parallel (DDP) training in PyTorch requires the
+    # number of inputs to be the same across all ranks; otherwise the
+    # program may error out or hang. To solve this, PyTorch provides the
+    # Join context manager. Please refer to
+    # https://pytorch.org/tutorials/advanced/generic_join.html for detailed
+    # information.
+    #
+    # Another method is to set `drop_uneven_inputs` to True in GraphBolt's
+    # DistributedItemSampler, which solves the problem by dropping the
+    # uneven inputs.
+    ########################################################################
+    with Join([model]):
+        for step, data in (
+            tqdm.tqdm(enumerate(train_dataloader))
+            if rank == 0
+            else enumerate(train_dataloader)
+        ):
+            # The input features are from the source nodes in the first
+            # layer's computation graph.
+            x = data.node_features["feat"]
 
-        # The ground truth labels are from the destination nodes
-        # in the last layer's computation graph.
-        y = data.labels
+            # The ground truth labels are from the destination nodes
+            # in the last layer's computation graph.
+            y = data.labels
 
-        blocks = data.blocks
+            blocks = data.blocks
 
-        y_hat = model(blocks, x)
+            y_hat = model(blocks, x)
 
-        # Compute loss.
-        loss = F.cross_entropy(y_hat, y)
+            # Compute loss.
+            loss = F.cross_entropy(y_hat, y)
 
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
 
-        total_loss += loss.detach()
+            total_loss += loss.detach()
 
     # Evaluate the model.
     if rank == 0:
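As the highlighted comment notes, an alternative to wrapping the loop in Join is to make the inputs even in the first place by passing `drop_uneven_inputs=True` to GraphBolt's DistributedItemSampler. A minimal sketch of that approach follows; the surrounding names (`train_set`, the batch size, and the datapipe wiring) are illustrative assumptions, not code from this commit.

    import dgl.graphbolt as gb

    # Each rank keeps only its own shard of the items. With
    # drop_uneven_inputs=True, trailing batches are dropped so that every
    # rank iterates over the same number of minibatches, and no rank
    # outlives the others inside the DDP training loop.
    datapipe = gb.DistributedItemSampler(
        train_set,  # hypothetical ItemSet of training seed nodes
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        drop_uneven_inputs=True,
    )

With even inputs guaranteed by the sampler, the `with Join([model]):` wrapper becomes unnecessary; the trade-off is that a few trailing samples may be skipped on some ranks each epoch.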
