# aa_online_tree_pre.py
# Online pre-training script for Tree models (see pretrn_online below).
import VitMach
import math
import torch
from torch import nn
import torchvision
import numpy as np
import pickle
import utilities
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import Subset
import ssl
# WARNING(review): this disables TLS certificate verification for every HTTPS
# connection the stdlib opens in this process. get_data() constructs every
# dataset with download=False, so this looks unnecessary -- consider removing.
ssl._create_default_https_context = ssl._create_unverified_context

# Shared activation / similarity modules (not referenced elsewhere in this file).
relu = nn.ReLU()
softmax = nn.Softmax(dim=1)  # normalizes along dim 1
cos = nn.CosineSimilarity(dim=0)  # similarity computed along dim 0

# Loss modules; all use reduction='sum' (summed, not averaged).
NLL = nn.NLLLoss(reduction='sum')
mse = nn.MSELoss(reduction='sum')
bce = nn.BCELoss(reduction='sum')
def get_data(shuf=False, data=2, max_iter=None,
             tiny_imagenet_root='C:/Users/nalon/Documents/PythonScripts/tiny-imagenet-200'):
    """Build train/test DataLoaders for one of several image datasets.

    Every dataset is read from disk only (``download=False``) and converted
    with ``ToTensor``.

    Args:
        shuf: Whether to shuffle the training loader.
        data: Dataset selector -- 0 MNIST, 1 FashionMNIST, 2 CIFAR10,
            3 SVHN; any other value loads tiny-imagenet-200 via ImageFolder.
        max_iter: If not None, keep only the first ``max_iter`` training
            examples. The training loader uses batch_size=1, so this also
            caps the number of training iterations. None keeps all examples.
        tiny_imagenet_root: Directory holding the tiny-imagenet-200 ``train``
            and ``test`` folders; only used when ``data`` is not 0-3.

    Returns:
        (train_loader, test_loader): a training loader with batch_size=1 and
        an unshuffled test loader with batch_size=10000.
    """
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    datasets = torchvision.datasets
    # Every branch binds both trainset and testset; previously only CIFAR10
    # and tiny-imagenet did, so MNIST/FashionMNIST/SVHN hit UnboundLocalError.
    if data == 0:
        trainset = datasets.MNIST(root='./data', train=True, download=False, transform=transform)
        testset = datasets.MNIST(root='./data', train=False, download=False, transform=transform)
    elif data == 1:
        trainset = datasets.FashionMNIST(root='./data', train=True, download=False, transform=transform)
        testset = datasets.FashionMNIST(root='./data', train=False, download=False, transform=transform)
    elif data == 2:
        trainset = datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
        testset = datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
    elif data == 3:
        trainset = datasets.SVHN(root='./data', split='train', download=False, transform=transform)
        testset = datasets.SVHN(root='./data', split='test', download=False, transform=transform)
    else:
        # NOTE(review): ImageFolder derives labels from sub-directory names;
        # confirm the tiny-imagenet test folder is laid out per class.
        trainset = datasets.ImageFolder(f'{tiny_imagenet_root}/train', transform=transform)
        testset = datasets.ImageFolder(f'{tiny_imagenet_root}/test', transform=transform)
    if max_iter is not None:
        # Takes a prefix of the dataset; with shuf=True only that prefix is shuffled.
        n_keep = min(max(max_iter, 0), len(trainset))
        trainset = torch.utils.data.Subset(trainset, range(n_keep))
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=shuf)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=10000, shuffle=False)
    return train_loader, test_loader
def pretrn_online(arch=8, data=0, dev='cuda', max_iter=200, num_seeds=1, alpha=200, in_dim=64, in_chn=3,
                  shuf=False, wt_up=None, chnls=None):
    """Pre-train Tree models online (one training image at a time), one per seed.

    For each seed ``s`` in ``range(num_seeds)``:
    1. seed torch with ``s``;
    2. build a fresh ``Tree.Tree`` on ``dev``;
    3. call ``model.update_wts`` on at most ``max_iter`` training images from
       ``get_data(shuf=shuf, data=data)``;
    4. save the model with ``torch.save`` to
       ``models/preTree_arch{arch}_data{data}_seed{s}``.

    Everything runs under ``torch.no_grad()``.

    Args:
        arch, alpha, in_dim, in_chn: Forwarded to ``Tree.Tree`` (as ``arch``,
            ``alpha``, ``in_dim``, ``in_chnls``).
        data: Dataset selector passed to ``get_data``.
        dev: Device the model and images are moved to.
        max_iter: Maximum number of training images fed per seed.
        num_seeds: Number of independently seeded models to train and save.
        shuf: Whether the training loader shuffles.
        wt_up: Forwarded to ``Tree.Tree`` as ``wt_up``.
        chnls: Forwarded to ``Tree.Tree`` as ``chnls``.
            ``shuf``, ``wt_up`` and ``chnls`` were previously read from module
            globals (``shuf``, ``wtupType``, ``chnls``) that were never defined.

    Raises:
        ValueError: If ``wt_up`` or ``chnls`` is not supplied.
    """
    if wt_up is None or chnls is None:
        raise ValueError('pretrn_online needs wt_up and chnls (forwarded to Tree.Tree)')
    import os
    # NOTE(review): the original used ``Tree`` without importing it; this assumes
    # a project-local Tree.py module -- confirm.
    import Tree

    os.makedirs('models', exist_ok=True)
    with torch.no_grad():
        for s in range(num_seeds):
            # Make each "seed" actually determine the model init / shuffle order.
            torch.manual_seed(s)
            model = Tree.Tree(in_dim=in_dim, in_chnls=in_chn, wt_up=wt_up, alpha=alpha,
                              arch=arch, chnls=chnls).to(dev)
            # max_iter is enforced by the loop below, not passed to get_data.
            train_loader, _ = get_data(shuf=shuf, data=data)
            for batch_idx, (images, _labels) in enumerate(train_loader):
                if batch_idx >= max_iter:
                    break
                model.update_wts(images.to(dev))
            torch.save(model, f'models/preTree_arch{arch}_data{data}_seed{s}')
# Script entry: data=4 falls through to get_data's tiny-imagenet (ImageFolder) branch.
# NOTE(review): this call executes at import time; an `if __name__ == "__main__":`
# guard would stop importers from launching a training run.
# NOTE(review): in_chn=4 while ImageFolder's default loader yields RGB images -- confirm.
pretrn_online(
    arch=6,
    data=4,
    dev='cuda',
    max_iter=2000,
    num_seeds=1,
    in_dim=64,
    in_chn=4,
)