-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
46 lines (37 loc) · 1.54 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
import numpy as np
def inferece(y_pred, y):
    """Count the elements of ``y_pred`` that compare equal to ``y``.

    The two inputs are compared element by element with ``==`` and the
    number of matches is returned as a plain Python scalar.
    """
    matches = y_pred == y
    match_count = matches.sum()
    return match_count.item()
def _model_device(model):
    """Return the device of ``model``'s first parameter, or ``None``.

    ``None`` means no device could be inferred: ``model`` has no callable
    ``parameters`` attribute, or it yields no parameters.
    """
    parameters = getattr(model, "parameters", None)
    if not callable(parameters):
        return None
    first = next(iter(parameters()), None)
    return None if first is None else first.device


def train_epoch(epochs, train_loader,
                model, optimizer, loss_fn, threshold):
    """Train ``model`` for ``epochs`` passes over ``train_loader``.

    Despite the name, this runs the full multi-epoch loop. No validation
    pass and no learning-rate scheduler are run.

    Args:
        epochs: Number of passes over ``train_loader``.
        train_loader: Iterable yielding ``(batch_x, batch_y, y_)`` triples.
            ``batch_x.shape[0]`` is read as the batch size.
        model: Callable invoked as ``model(q1, q2)``; its output is compared
            against ``threshold`` and passed to ``loss_fn``.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        loss_fn: Callable invoked as ``loss_fn(similarity, y)``. The result
            must support ``backward()`` and hold a single value, because
            ``.item()`` is called on its absolute value.
        threshold: Outputs strictly greater than this count as a positive
            prediction when tallying matches against ``y``.

    Returns:
        A list with one entry per epoch: the sum over that epoch's batches
        of ``abs(loss)``.
    """
    # Batches are moved to the device the model lives on instead of
    # unconditionally to CUDA, so the loop also runs on CPU-only hosts.
    # A model without parameters falls back to CUDA when available.
    device = _model_device(model)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    loss_history = []
    for i in range(epochs):
        temp_loss = 0
        # NOTE(review): correct_total is accumulated but never reported or
        # returned -- decide whether to surface it or drop the tally.
        correct_total = 0
        batch_size = 0
        for batch_x, batch_y, y_ in train_loader:
            batch_size = batch_x.shape[0]
            q1, q2, y = batch_x.to(device), batch_y.to(device), y_.to(device)

            # Reset the gradients left over from the previous step.
            optimizer.zero_grad()

            # Forward pass and thresholded predictions.
            similarity = model(q1, q2)
            y_pred = (similarity > threshold).float()
            correct_total += inferece(y_pred, y)

            # Accumulate the magnitude of the batch loss for reporting.
            loss = loss_fn(similarity, y)
            temp_loss += torch.abs(loss).item()

            # Backward pass, then weight update.
            loss.backward()
            optimizer.step()

        loss_history.append(temp_loss)
        # NOTE(review): the denominator is the size of the LAST batch only,
        # not the number of samples or batches seen this epoch -- confirm
        # which average is intended. An empty loader leaves it at 0, so
        # report NaN instead of raising ZeroDivisionError.
        epoch_metric = temp_loss / batch_size if batch_size else float("nan")
        print(f"Epoch: {i}, train_loss MAE: {epoch_metric}")
    return loss_history