-------------------
codepath
transcribe.py -> transcribe
  model.decode
    decoding.py -> decode
      decoding.py -> DecodingTask.run
        DecodingTask._get_audio_features (uses encoder only)
        --------------- Dichotomy here!
        DecodingTask._main_loop (uses decoder only)
          PyTorchInference.logits (uses decoder only)
            Whisper.install_kv_cache_hooks (uses decoder only)
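A minimal sketch of that split, using whisper's public API (model size and
audio path are placeholders):

import whisper

model = whisper.load_model("base")
audio = whisper.pad_or_trim(whisper.load_audio("audio.wav"))
mel = whisper.log_mel_spectrogram(audio).to(model.device)

# encoder-only half (what DecodingTask._get_audio_features wraps):
audio_features = model.embed_audio(mel.unsqueeze(0))

# the decoder-only half runs inside DecodingTask._main_loop; this public
# entry point exercises both halves (fp16=False keeps it CPU-friendly):
result = model.decode(mel, whisper.DecodingOptions(fp16=False))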
-----------------------
model loading info (from load_model in whisper/__init__.py):
import torch
from whisper.model import ModelDimensions, Whisper

checkpoint = torch.load(fp, map_location=device)       # fp: opened checkpoint file
...
dims = ModelDimensions(**checkpoint["dims"])           # architecture hyperparameters
model = Whisper(dims)                                  # build the empty model skeleton
model.load_state_dict(checkpoint["model_state_dict"])  # fill in the weights
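Quick way to confirm that checkpoint layout (path is a placeholder):

import torch

ckpt = torch.load("base.pt", map_location="cpu")
print(ckpt.keys())   # dict_keys(['dims', 'model_state_dict'])
print(ckpt["dims"])  # n_mels, n_audio_ctx, n_audio_state, ... (see dims below)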
-----------------------
dims: try setting the encoder / decoder dims to 0 separately (constructors from Whisper.__init__ in model.py):
self.encoder = AudioEncoder(
self.dims.n_mels,
self.dims.n_audio_ctx,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
)
self.decoder = TextDecoder(
self.dims.n_vocab,
self.dims.n_text_ctx,
self.dims.n_text_state,
self.dims.n_text_head,
self.dims.n_text_layer,
)
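Untested sketch of that idea: of the dims above, only the layer counts look
safely zeroable (conv channels / embedding sizes of 0 would break the
modules), so a decoder-only model might be built by zeroing n_audio_layer
and loading just the decoder.* weights:

from dataclasses import replace
import torch
from whisper.model import ModelDimensions, Whisper

checkpoint = torch.load("base.pt", map_location="cpu")  # placeholder path
dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(replace(dims, n_audio_layer=0))  # encoder keeps convs, 0 blocks
decoder_only = {k: v for k, v in checkpoint["model_state_dict"].items()
                if k.startswith("decoder.")}
model.load_state_dict(decoder_only, strict=False)  # tolerate missing encoder keys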
state dict keys (truncated dump): ['decoder.positional_embedding', 'encoder.positional_embedding', 'decoder.token_embedding.weight', 'decoder.blocks.0.mlp_ln.weight',
'decoder.blocks.0.mlp_ln.bias', 'decoder.blocks.0.mlp.0.weight', 'decoder.blocks.0.mlp.0.bias',
'decoder.blocks.0.mlp.2.weight', 'decoder.blocks.0.mlp.2.bias', 'decoder.blocks.0.attn_ln.weight',
'decoder.blocks.0.attn_ln.bias', 'decoder.blocks.0.attn.query.weight', 'decoder.blocks.0.attn.query.bias',
'decoder.blocks.0.attn.key.weight', 'decoder.blocks.0.attn.value.weight', 'decoder.blocks.0.attn.value.bias',
'decoder.blocks.0.attn.out.weight', 'decoder.blocks.0.attn.out.bias', 'decoder.blocks.0.cross_attn_ln.weight',
'decoder.blocks.0.cross_attn_ln.bias', 'decoder.blocks.0.cross_attn.query.weight', 'decoder.blocks.0.cross_attn.query.bias',
'decoder.blocks.0.cross_attn.key.weight', 'decoder.blocks.0.cross_attn.value.weight', 'decoder.blocks.0.cross_attn.value.bias',
'decoder.blocks.0.cross_attn.out.weight', 'decoder.blocks.0.cross_attn.out.bias', 'decoder.blocks.1.mlp_ln.weight', 'decoder.blocks.1.mlp_ln.bias',
'decoder.blocks.1.mlp.0.weight', 'decoder.blocks.1.mlp.0.bias', 'decoder.blocks.1.mlp.2.weight', 'decoder.blocks.1.mlp.2.bias',
'decoder.blocks.1.attn_ln.weight', 'decoder.blocks.1.attn_ln.bias', 'decoder.blocks.1.attn.query.weight',
'decoder.blocks.1.attn.query.bias', 'decoder.blocks.1.attn.key.weight', 'decoder.blocks.1.attn.value.weight',
'decoder.blocks.1.attn.value.bias', 'decoder.blocks.1.attn.out.weight', 'decoder.blocks.1.attn.out.bias', 'decoder.blocks.1.cross_attn_ln.weight',
'decoder.blocks.1.cross_attn_ln.bias', 'decoder.blocks.1.cross_attn.query.weight', 'decoder.blocks.1.cross_attn.query.bias',
'decoder.blocks.1.cross_attn.key.weight', 'decoder.blocks.1.cross_attn.value.weight', 'decoder.blocks.1.cross_attn.value.bias',
'decoder.blocks.1.cross_attn.out.weight', 'decoder.blocks.1.cross_attn.out.bias', 'decoder.blocks.2.mlp_ln.weight',
'decoder.blocks.2.mlp_ln.bias', 'decoder.blocks.2.mlp.0.weight', 'decoder.blocks.2.mlp.0.bias', 'decoder.blocks.2.mlp.2.weight',
'decoder.blocks.2.mlp.2.bias', 'decoder.blocks.2.attn_ln.weight', 'decoder.blocks.2.attn_ln.bias', 'decoder.blocks.2.attn.query.weight',
'decoder.blocks.2.attn.query.bias', 'decoder.blocks.2.attn.key.weight', 'decoder.blocks.2.attn.value.weight', 'decoder.blocks.2.attn.value.bias',
'decoder.blocks.2.attn.out.weight', 'decoder.blocks.2.attn.out.bias', 'decoder.blocks.2.cross_attn_ln.weight', 'decoder.blocks.2.cross_attn_ln.bias',
'decoder.blocks.2.cross_attn.query.weight', 'decoder.blocks.2.cross_attn.query.bias', 'decoder.blocks.2.cross_attn.key.weight', 'decoder.blocks.2.cross_attn.value.weight',
'decoder.blocks.2.cross_attn.value.bias', 'decoder.blocks.2.cross_attn.out.weight', 'decoder.blocks.2.cross_attn.out.bias', 'decoder.blocks.3.mlp_ln.weight',
'decoder.blocks.3.mlp_ln.bias', 'decoder.blocks.3.mlp.0.weight', 'decoder.blocks.3.mlp.0.bias', 'decoder.blocks.3.mlp.2.weight', 'decoder.blocks.3.mlp.2.bias',
'decoder.blocks.3.attn_ln.weight', 'decoder.blocks.3.attn_ln.bias', 'decoder.blocks.3.attn.query.weight', 'decoder.blocks.3.attn.query.bias',
'decoder.blocks.3.attn.key.weight', 'decoder.blocks.3.attn.value.weight', 'decoder.blocks.3.attn.value.bias', 'decoder.blocks.3.attn.out.weight',
'decoder.blocks.3.attn.out.bias', 'decoder.blocks.3.cross_attn_ln.weight', 'decoder.blocks.3.cross_attn_ln.bias', 'decoder.blocks.3.cross_attn.query.weight',
'decoder.blocks.3.cross_attn.query.bias', 'decoder.blocks.3.cross_attn.key.weight', 'decoder.blocks.3.cross_attn.value.weight', 'decoder.blocks.3.cross_attn.value.bias',
'decoder.blocks.3.cross_attn.out.weight', 'decoder.blocks.3.cross_attn.out.bias', 'decoder.blocks.4.mlp_ln.weight', 'decoder.blocks.4.mlp_ln.bias', 'decoder.blocks.4.mlp.0.weight',
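Those encoder. / decoder. prefixes make it easy to split a checkpoint in
two; a sketch (paths are placeholders):

import torch

sd = torch.load("base.pt", map_location="cpu")["model_state_dict"]
torch.save({k: v for k, v in sd.items() if k.startswith("encoder.")}, "base.encoder.pt")
torch.save({k: v for k, v in sd.items() if k.startswith("decoder.")}, "base.decoder.pt")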
----------------------- archived info
Decoder functions:
  inference: PyTorchInference
  inference.logits
  decoding.py -> decode (= model.decode in transcribe.py)
  model.logits
Encoder functions:
  model.embed_audio
  DecodingTask._get_audio_features
  decoding.py -> detect_language
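Sketch of driving the two halves directly through those entry points
(dummy inputs; the token id assumes the multilingual tokenizer):

import torch
import whisper

model = whisper.load_model("base", device="cpu")
mel = torch.zeros(1, model.dims.n_mels, 2 * model.dims.n_audio_ctx)  # silent dummy mel
tokens = torch.tensor([[50258]])  # <|startoftranscript|> in multilingual models

audio_features = model.embed_audio(mel)        # encoder half
logits = model.logits(tokens, audio_features)  # decoder half, shape (1, 1, n_vocab)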
-------------------
model.forward uses both the encoder and the decoder (but forward itself is unused by the transcribe/decode path).
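Reusing model/mel/tokens from the sketch above, forward should match the
two-step path:

out = model(mel, tokens)  # forward = decoder(tokens, encoder(mel))
assert torch.allclose(out, model.logits(tokens, model.embed_audio(mel)))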