Generate_script.py
import torch
import torch.nn as nn
from torch.nn import functional as F
# Hyperparameters
block_size = 256
device = 'cuda' if torch.cuda.is_available() else 'cpu'
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2
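# Note: these values are assumed to match the configuration the checkpoint was
# trained with; load_state_dict() below will raise a size-mismatch error if they differ.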
# Define your model classes here
class Head(nn.Module):
    """ Implements one head of self attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-1, -2) * C**(-0.5)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        out = wei @ v
        return out

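# Shape walkthrough for one head (x is (B, T, C) with C = n_embd):
#   k, q, v : (B, T, head_size)   per-token linear projections
#   wei     : (B, T, T) = q @ k^T, scaled by C**-0.5 (note: scaled by the full
#             embedding width rather than the more common head_size**-0.5;
#             kept as-is to match the trained checkpoint)
#   mask    : tril[:T, :T] == 0 blanks out future positions, so token t only
#             attends to tokens <= t (causal attention)
#   out     : (B, T, head_size) = softmax(wei) @ v
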
class MultiHeadAttention(nn.Module):
    """ Multiple heads of self attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        out = self.dropout(out)
        return out

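# With n_head = 6 and head_size = n_embd // n_head = 64, the concatenated head
# outputs are (B, T, 6 * 64) = (B, T, 384), so self.proj maps back to n_embd
# while mixing information across heads.
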
class FeedForward(nn.Module):
    """ Feed-forward layer. Note: the feed-forward network acts on each token independently. """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

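# The 4x expansion (384 -> 1536 -> 384) follows the usual Transformer
# feed-forward sizing; ReLU is the non-linearity and dropout is applied to the
# projected output.
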
class Block(nn.Module):
    """ Transformer block (decoder-style): communication followed by computation """

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size=head_size)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = self.ln1(x)
        x = x + self.sa(x)
        x = self.ln2(x)
        x = x + self.ffwd(x)
        return x

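# Note: the residual here adds the attention/feed-forward output to the
# *normalized* activations (x = ln(x); x = x + sa(x)), which differs slightly
# from the canonical pre-norm form x = x + sa(ln(x)). It is kept as written so
# the architecture matches the checkpoint this script loads.
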
class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        # Final layer norm; defined here (and so present in the checkpoint's
        # parameter names) although forward() below does not apply it.
        self.lm_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_emb = self.token_embedding_table(idx)  # (B, T, C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))  # (T, C)
        x = token_emb + pos_emb  # (B, T, C)
        x = self.blocks(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens so the position
            # embedding table is never indexed out of range.
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # (B, C)
            probs = F.softmax(logits, dim=-1)  # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

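# generate() is plain autoregressive sampling: each step feeds the running
# sequence back through the model, takes the distribution over the next
# character from the last time step, and samples with torch.multinomial rather
# than taking the argmax, so repeated runs produce different text.
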
# Load the character mappings (assuming you saved them or can recreate them)
with open('Friends_script.txt', 'r', encoding='utf-8') as f:
    text = f.read()
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
decode = lambda l: ''.join([itos[i] for i in l])
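# Optional (a sketch, not part of the original script): an encode helper that
# mirrors decode, in case you want to seed generation with a prompt instead of
# the single zero token used below.
encode = lambda s: [stoi[c] for c in s]
# Example: context = torch.tensor([encode("Joey:")], dtype=torch.long, device=device)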
# Instantiate the model
model = BigramLanguageModel()
model.to(device)
# Load the model state
model_load_path = "C:/Users/bgpra/OneDrive/Documents/GitHub/Baby-GPT/friends_language_model.pth"
checkpoint = torch.load(model_load_path, map_location=torch.device(device))
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print(f"Model loaded from {model_load_path}")
# Generate text, starting from a single token with index 0 (the first character
# in the sorted vocabulary). Sampling is inference-only, so gradient tracking
# is disabled.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
with torch.no_grad():
    generated_text = decode(model.generate(context, max_new_tokens=50000)[0].tolist())
with open('C:/Users/bgpra/OneDrive/Documents/GitHub/Baby-GPT/Generated_script.txt', 'w', encoding='utf-8') as file:
    file.write(generated_text)
print("Script written into Generated_script.txt")