-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgpt-czech-poet-userfriendly.py
executable file
·72 lines (57 loc) · 1.76 KB
/
gpt-czech-poet-userfriendly.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env python3
#coding: utf-8
import sys
import logging
# Configure the root logger once for the whole script: INFO level and above,
# each message prefixed with a "YYYY-MM-DD HH:MM:SS" timestamp.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s %(message)s',
)
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Hugging Face Hub id of the poem model; shared by the tokenizer and the
# weights so the two can never be loaded from different checkpoints.
MODEL_NAME = "jinymusim/gpt-czech-poet"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# One-letter metre codes mapped to the Czech names printed for the user
# (J -> "jamb" / iamb, T -> "trochej" / trochee); codes not listed here
# are left unexpanded by the lookup in the generation loop.
metra = dict(
    J='jamb',
    T='trochej',
)
def _parse_header(header, fallback_schema, meters):
    """Split the first generated line into ``(schema, year, meter)``.

    The header is expected to hold exactly three fields separated by ``' # '``.
    Fields are whitespace-stripped, and a meter code found in *meters* is
    expanded to its full name (unknown codes are kept as-is). When the header
    does not have exactly three fields, return ``fallback_schema`` with ``'?'``
    placeholders for year and meter.
    """
    fields = header.split(' # ')
    if len(fields) != 3:
        return fallback_schema, '?', '?'
    schema, year, meter = (field.strip() for field in fields)
    return schema, year, meters.get(meter, meter)


def _clean_line(line):
    """Return the displayable text of one generated poem line.

    Keeps the stripped text between the first and second ``'#'``; a line
    without any ``'#'`` is returned unchanged.
    NOTE(review): assumes the verse text is the second '#'-separated field of
    the model's output format — confirm against the model card.
    """
    fields = line.split('#')
    if len(fields) < 2:
        return line
    return fields[1].strip()


# Interactive loop: read a rhyme scheme, generate a poem that continues it,
# and print the poem with a short description of its scheme, metre and year.
while True:
    # End the session quietly on Ctrl-D / Ctrl-C instead of a traceback.
    try:
        poet_start = input('Zadej rýmové schéma, např. AABB nebo ABCABC: ')
    except (EOFError, KeyboardInterrupt):
        print()
        break
    # Normalise the scheme; an empty answer falls back to the default AABB.
    poet_start = poet_start.strip().upper() or 'AABB'

    # Tokenize via __call__ so an attention mask is available; passing it
    # explicitly avoids generate()'s "attention mask not set" warning.
    encoded = tokenizer(poet_start, return_tensors='pt')
    out = model.generate(
        encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        max_length=256,
        num_beams=8,
        no_repeat_ngram_size=2,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # First decoded line is the metadata header; the rest are poem lines.
    decoded_cont = tokenizer.decode(out[0], skip_special_tokens=True)
    header, *poem = decoded_cont.split('\n')
    schema, year, meter = _parse_header(header, poet_start, metra)

    print()
    print(f'''Zde je vygenerovaná báseň s rýmovým schématem {schema}, používající metrum typu {meter}, ve stylu roku {year}:''')
    print()
    for line in poem:
        print(_clean_line(line))
    print()