"""
This script is meant to be the simplest possible starting point for full finetuning a GPT model using lightning fabric with code (not CLI).
- no checkpoints
- no out dir
- no precision
- no resume
- no train/eval args (or any args in general)
- no logger (only to terminal)
- no grad accumulation
and no other fancy stuff.
To add all the above stuff, you can slowly add them in yourself by looking at the code in litgpt/finetune/full.py or the docs for litgpt/fabric.
"""
import os
import lightning as L
import torch
import torch.nn as nn
from litgpt.data import Alpaca
from litgpt.model import GPT, Config
from litgpt.tokenizer import Tokenizer
from litgpt.utils import num_parameters
# training params/args
SEED = 1337
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" # try also "stabilityai/stablelm-base-alpha-3b"!
BATCH_SIZE = 4
LR_WARMUP_STEPS = 100
MAX_STEPS = 601


def validate(model, val_dataloader):
    model.eval()
    loss = 0
    with torch.no_grad():
        for batch in val_dataloader:
            input_ids, targets = batch["input_ids"], batch["labels"]
            logits = model(input_ids)
            # shift so that each position predicts the *next* token, then flatten for cross entropy
            logits = logits[..., :-1, :].reshape(-1, logits.size(-1))
            targets = targets[..., 1:].reshape(-1)
            loss += nn.functional.cross_entropy(logits, targets)
    fabric.print(f"Validation loss: {loss / len(val_dataloader)}")


def train(fabric, model, optimizer, scheduler, train_dataloader, val_dataloader):
    for iter_num, batch in enumerate(train_dataloader):
        input_ids, targets = batch["input_ids"], batch["labels"]

        # get model preds (logits)
        logits = model(input_ids)

        # get loss: shift so that each position predicts the *next* token, then flatten
        logits = logits[..., :-1, :].reshape(-1, logits.size(-1))
        targets = targets[..., 1:].reshape(-1)
        loss = nn.functional.cross_entropy(logits, targets)

        # update weights
        fabric.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

        # print train loss every 100 steps
        if iter_num % 100 == 0:
            fabric.print(f"Train iter {iter_num} - loss {loss}")

        # validate every 300 steps
        if iter_num % 300 == 0:
            validate(model, val_dataloader)
            model.train()

        # stop after MAX_STEPS optimizer updates
        if iter_num + 1 >= MAX_STEPS:
            break
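

# Gradient accumulation is omitted on purpose. If you add it later, the usual Fabric pattern is
# roughly (a sketch; `accumulation_steps` is illustrative and not defined in this script):
#   is_accumulating = (iter_num + 1) % accumulation_steps != 0
#   with fabric.no_backward_sync(model, enabled=is_accumulating):
#       fabric.backward(loss)
#   if not is_accumulating:
#       optimizer.step()
#       optimizer.zero_grad()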


def main(fabric):
    fabric.seed_everything(SEED)

    # setup data: make the tokenizer and the dataloaders
    data = Alpaca()
    tokenizer = Tokenizer(checkpoint_dir=f"checkpoints/{MODEL_NAME}")
    data.connect(tokenizer=tokenizer, batch_size=BATCH_SIZE, max_seq_length=1024)
    data.setup()
    train_dataloader = data.train_dataloader()
    val_dataloader = data.val_dataloader()
    train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)

    # print how many steps are in an epoch
    fabric.print(f"Steps in an epoch: {len(train_dataloader)}")

    # setup model
    config = Config.from_file(f"checkpoints/{MODEL_NAME}/model_config.yaml")
    model = GPT(config)
    fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
    model = fabric.setup(model)
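    # note: GPT(config) starts from randomly initialized weights (this example deliberately skips
    # checkpoint handling); to finetune the downloaded pretrained weights you would load them here,
    # e.g. the way litgpt/finetune/full.py does with litgpt.utils.load_checkpoint(fabric, model, path)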

    # setup optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=0.02, betas=(0.9, 0.95))
    optimizer = fabric.setup_optimizers(optimizer)

    # setup lr scheduler
    scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / LR_WARMUP_STEPS)
    scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(MAX_STEPS - LR_WARMUP_STEPS))
    scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[LR_WARMUP_STEPS])
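    # net effect: the LR ramps linearly from 0 towards 3e-3 over the first LR_WARMUP_STEPS steps,
    # then follows a cosine decay over the remaining MAX_STEPS - LR_WARMUP_STEPS steps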

    # start training!
    train(fabric, model, optimizer, scheduler, train_dataloader, val_dataloader)
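
    # to keep the finetuned weights you could save them at this point, e.g. (sketch; the path is
    # illustrative): fabric.save("checkpoints/finetuned.pth", {"model": model})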


if __name__ == "__main__":
    # check that the model exists (downloaded to ./checkpoints/)
    if not os.path.exists(f"checkpoints/{MODEL_NAME}"):
        print(f"Model {MODEL_NAME} not found. Please download it with `litgpt download --repo_id {MODEL_NAME}`")
        exit()

    ### Setup and launch
    fabric = L.Fabric(devices="auto", strategy="auto")
    fabric.launch(main)
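
# Launch with: python full_finetune_example.py
# With devices="auto" and strategy="auto", Fabric picks the available accelerator automatically;
# on a machine with several GPUs it will typically run DDP, launching one process per GPU.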