-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathtrain_classifier.py
79 lines (59 loc) · 2.64 KB
/
train_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
from torch.autograd import Variable
from scipy import misc
from resnet import resnet
def save_checkpoint(state, filename='black_box_func.pth'):
    """Write ``state`` to disk at ``filename`` using ``torch.save``.

    ``state`` can be any object ``torch.save`` can serialize; the training
    script in this file passes the whole model object.
    """
    torch.save(obj=state, f=filename)
# The original .tar checkpoint filename seemed wrong, so checkpoints are saved as .pth instead.
#def load_checkpoint(net,optimizer,filename='black_box_func.pth'):
# checkpoint = torch.load(filename)
# net.load_state_dict(checkpoint['state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer'])
#return net,optimizer
# load_checkpoint (commented out above) is not used anywhere in this script.
def cifar10(root='./data', batch_size=4, num_workers=2, download=True):
    """Build the CIFAR-10 train and test DataLoaders.

    Every image is passed through ``ToTensor`` and then ``Normalize`` with a
    per-channel mean and std of 0.5, which maps pixel values from [0, 1] to
    [-1, 1].

    Args:
        root: Directory the dataset is read from / downloaded into.
        batch_size: Number of samples per batch, for both loaders.
        num_workers: Number of worker processes for each DataLoader.
        download: Whether to download the dataset into ``root`` if it is
            not already present.

    Returns:
        ``(trainloader, testloader, classes)``: the shuffled training
        loader, the unshuffled test loader, and the tuple of the ten class
        names, ordered to match the CIFAR-10 label indices 0-9.
    """
    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    # ToTensor yields values in [0, 1]; Normalize then computes
    # (x - 0.5) / 0.5 per channel, giving values in [-1, 1].
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root=root, train=True,
                                            download=download,
                                            transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=num_workers)

    testset = torchvision.datasets.CIFAR10(root=root, train=False,
                                           download=download,
                                           transform=transform)
    # The test set is not shuffled, so evaluation order is reproducible.
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers)
    return trainloader, testloader, classes
# Train a ResNet classifier on CIFAR-10 and save it as the "black box" model.
# NOTE(review): this runs at import time and uses DataLoader workers
# (num_workers=2 in cifar10). On platforms whose multiprocessing start method
# is "spawn" (Windows, recent macOS) this likely needs an
# `if __name__ == '__main__':` guard -- confirm before running there.
trainloader, testloader, classes = cifar10()
# Fall back to the CPU when no GPU is present instead of crashing on .cuda().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
black_box_func = resnet()
black_box_func = black_box_func.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(black_box_func.parameters())

num_epochs = 6  # number of passes over the training set
log_every = 100  # print running statistics every this many batches
for epoch in range(num_epochs):
    black_box_func.train()
    running_loss = 0.0  # sum of per-batch mean losses this epoch
    running_corrects = 0  # correctly classified samples this epoch
    seen = 0  # samples processed this epoch
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)

        # forward + backward + optimize
        optimizer.zero_grad()
        out = black_box_func(inputs)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()

        # Statistics use the pre-step outputs, detached from the graph.
        preds = out.detach().argmax(dim=1)
        running_corrects += (preds == labels).sum().item()
        seen += labels.size(0)
        # CrossEntropyLoss already averages over the batch, so the running
        # mean loss divides by the number of batches, not samples.
        running_loss += loss.item()
        if i % log_every == 0:
            print('Epoch = %d , Accuracy = %f, Loss = %f '
                  % (epoch + 1, running_corrects / seen, running_loss / (i + 1)))

# Saves the entire (pickled) module object, not just its state_dict, so
# loading it requires the `resnet` definition to be importable.
save_checkpoint(black_box_func, filename='black_box_func.pth')