forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 3
/
test_bundled_inputs.py
137 lines (112 loc) · 5.04 KB
/
test_bundled_inputs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#!/usr/bin/env python3
import io
import torch
import torch.utils.bundled_inputs
from torch.testing._internal.common_utils import TestCase, run_tests
def model_size(sm) -> int:
    """Return the number of bytes ``torch.jit.save`` writes for ``sm``.

    The module is serialized into an in-memory buffer, so nothing touches
    the filesystem.
    """
    buffer = io.BytesIO()
    torch.jit.save(sm, buffer)
    return len(buffer.getvalue())
def save_and_load(sm):
    """Round-trip ``sm`` through ``torch.jit.save`` / ``torch.jit.load``.

    Serialization goes through an in-memory buffer; the module produced by
    ``torch.jit.load`` is returned.
    """
    buffer = io.BytesIO()
    torch.jit.save(sm, buffer)
    # Rewind so that load reads from the start of what was just written.
    buffer.seek(0)
    return torch.jit.load(buffer)
class TestBundledInputs(TestCase):
    """Tests for attaching sample inputs to scripted modules via
    ``torch.utils.bundled_inputs``."""

    def test_single_tensors(self):
        """Bundle a variety of single-tensor samples and verify that the
        saved model stays small and the samples survive a save/load cycle."""
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        sm = torch.jit.script(SingleTensorModel())
        original_size = model_size(sm)
        get_expr = []
        samples = [
            # Tensor with small numel and small storage.
            (torch.tensor([1]),),
            # Tensor with large numel and small storage.
            (torch.tensor([[2, 3, 4]]).expand(1 << 16, -1)[:, ::2],),
            # Tensor with small numel and large storage.
            (torch.tensor(range(1 << 16))[-8:],),
            # Large zero tensor.
            (torch.zeros(1 << 16),),
            # Large channels-last ones tensor.
            (torch.ones(4, 8, 32, 32).contiguous(memory_format=torch.channels_last),),
            # Special encoding of random tensor.
            (torch.utils.bundled_inputs.bundle_randn(1 << 16),),
        ]
        # The bundle_randn sample is last; its values differ on every
        # inflation, so it is checked statistically instead of element-wise.
        random_idx = len(samples) - 1

        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm, samples, get_expr)
        # Uncomment to debug:
        # print(get_expr[0])
        # print(sm._generate_bundled_inputs.code)

        # Make sure the model only grew a little bit,
        # despite having nominally large bundled inputs.
        augmented_size = model_size(sm)
        self.assertLess(augmented_size, original_size + (1 << 12))

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
        self.assertEqual(len(inflated), len(samples))
        # The model is the identity, so running it on bundled input 0 must
        # hand back the very same tensor object.
        self.assertIs(loaded.run_on_bundled_input(0), inflated[0][0])

        for idx, inp in enumerate(inflated):
            self.assertIsInstance(inp, tuple)
            self.assertEqual(len(inp), 1)
            self.assertIsInstance(inp[0], torch.Tensor)
            if idx != random_idx:
                # Strides might be important for benchmarking.
                self.assertEqual(inp[0].stride(), samples[idx][0].stride())
                self.assertEqual(inp[0], samples[idx][0], exact_dtype=True)

        # This tensor is random, but with 100,000 trials,
        # mean and std had ranges of (-0.0154, 0.0144) and (0.9907, 1.0105).
        random_tensor = inflated[random_idx][0]
        self.assertEqual(random_tensor.shape, (1 << 16,))
        self.assertAlmostEqual(random_tensor.mean().item(), 0, delta=0.025)
        self.assertAlmostEqual(random_tensor.std().item(), 1, delta=0.02)

    def test_large_tensor_with_inflation(self):
        """A large tensor wrapped with ``bundle_large_tensor`` is accepted
        and comes back equal after a save/load cycle."""
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        sm = torch.jit.script(SingleTensorModel())
        sample_tensor = torch.randn(1 << 16)
        # We can store tensors with custom inflation functions regardless
        # of size, even if inflation is just the identity.
        sample = torch.utils.bundled_inputs.bundle_large_tensor(sample_tensor)
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm, [(sample,)])

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(len(inflated), 1)
        self.assertEqual(inflated[0][0], sample_tensor)

    def test_rejected_tensors(self):
        """Tensors that cannot be bundled compactly raise an exception whose
        message matches "Bundled input argument"."""
        def check_tensor(sample):
            # Need to define the class in this scope to get a fresh type for each run.
            class SingleTensorModel(torch.nn.Module):
                def forward(self, arg):
                    return arg

            sm = torch.jit.script(SingleTensorModel())
            with self.assertRaisesRegex(Exception, "Bundled input argument"):
                torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
                    sm, [(sample,)])

        # Plain old big tensor.
        check_tensor(torch.randn(1 << 16))
        # This tensor has two elements, but they're far apart in memory.
        # We currently cannot represent this compactly while preserving
        # the strides.
        small_sparse = torch.randn(2, 1 << 16)[:, 0:1]
        self.assertEqual(small_sparse.numel(), 2)
        check_tensor(small_sparse)

    def test_non_tensors(self):
        """Non-tensor arguments (a str and an int) can be bundled and are
        returned unchanged after a save/load cycle."""
        class StringAndIntModel(torch.nn.Module):
            def forward(self, fmt: str, num: int):
                return fmt.format(num)

        sm = torch.jit.script(StringAndIntModel())
        samples = [
            ("first {}", 1),
            ("second {}", 2),
        ]
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm, samples)

        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        self.assertEqual(inflated, samples)
        self.assertEqual(loaded.run_on_bundled_input(0), "first 1")
# Run the tests in this file when it is executed directly as a script.
if __name__ == '__main__':
    run_tests()