-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdisplay_weights.py
executable file
·51 lines (40 loc) · 2.01 KB
/
display_weights.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import print_function
import sys
import argparse
import torch
import numpy as np
# Print arrays in full: numpy summarizes (with "...") any array whose element
# count exceeds `threshold`, so sys.maxsize effectively disables summarization.
np.set_printoptions(threshold=sys.maxsize)
def _parse_args(argv=None):
    """Parse this script's command-line options.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            means ``sys.argv[1:]`` (argparse's default).

    Returns:
        argparse.Namespace with ``model_path`` and ``chip_format``.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='Displays quantized weights after conversion')
    parser.add_argument('--model-path', type=str, default="./models/mnist_quantized_converted.pth",
                        help='MNIST trained model after quantization and conversion')
    parser.add_argument('--chip-format', action='store_true', default=False,
                        help='Plots weights in chip format')
    return parser.parse_args(argv)


def _to_numpy(tensor):
    """Return a CPU NumPy array holding the values of ``tensor``."""
    # detach(): .numpy() refuses tensors that require grad (e.g. a saved
    # nn.Parameter); cpu(): .numpy() only works on CPU tensors.
    tensor = tensor.detach().cpu()
    if getattr(tensor, 'is_quantized', False):
        # Quantized tensors cannot be converted with .numpy() directly.
        # NOTE(review): assumes the raw integer codes (not dequantized
        # floats) are what should be displayed -- confirm with the chip spec.
        tensor = tensor.int_repr()
    return tensor.numpy()


def _print_chip_format(p_name, p_val):
    """Print a weight array of ndim >= 2 in chip format.

    Axis 0 is treated as the output channel and axis 1 as the input channel.
    Every remaining axis is flattened, so each printed line holds one
    (output, input) kernel; a 2-D matrix prints one value per line.

    Args:
        p_name: Parameter name, echoed in the start-of-file marker.
        p_val: Non-empty NumPy array with at least two dimensions.
    """
    out_chan, in_chan = p_val.shape[:2]
    # Collapse the trailing kernel axes (none for 2-D weights) into one so
    # conv2d, conv1d and fully-connected weights all share this path.
    kernels = p_val.reshape(out_chan, in_chan, -1)
    print('// Start of file {}'.format(p_name))
    print('#RC {} {}'.format(in_chan, out_chan))
    for o_c in range(out_chan):
        print('// Column number {}'.format(o_c))
        for i_c in range(in_chan):
            print(' '.join(str(v) for v in kernels[o_c, i_c].tolist()))
    print('// End of file')


def main(argv=None):
    """Load a checkpoint and print every entry it contains.

    For each tensor entry this prints its name, shape and min/mean/max.
    It then prints either the full array or, with ``--chip-format`` and at
    least two dimensions, the chip-format text from ``_print_chip_format``.
    Non-tensor entries are printed as-is.

    Args:
        argv: Optional argument list, without the program name. ``None``
            parses ``sys.argv[1:]``, as the original script did.
    """
    args = _parse_args(argv)
    # map_location='cpu' lets checkpoints saved from a GPU load on a CPU-only
    # machine. NOTE(security): torch.load unpickles the file, which can run
    # arbitrary code -- only load checkpoints from trusted sources.
    checkpoint = torch.load(args.model_path, map_location='cpu')
    print('Found {} parameters: {}'.format(len(checkpoint), list(checkpoint)))
    for p_name, value in checkpoint.items():
        print('============================== {} =============================='.format(p_name))
        if not torch.is_tensor(value):
            # Non-tensor entries have no min/mean/max; show them verbatim.
            print(value)
        else:
            p_val = _to_numpy(value)
            if p_val.size:
                print('=========== {} | (min={:.2f}, mean={:.2f}, max={:.2f}) ==========='.format(
                    p_val.shape, p_val.min(), p_val.mean(), p_val.max()))
            else:
                # min()/max() raise on empty arrays, so only report the shape.
                print('=========== {} | (empty) ==========='.format(p_val.shape))
            if args.chip_format and p_val.ndim > 1 and p_val.size:
                _print_chip_format(p_name, p_val)
            else:
                print(p_val)
        print('==========================================================================')
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()