-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathadaptive_diffusion.py
executable file
·115 lines (99 loc) · 4.02 KB
/
adaptive_diffusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import numpy as np
import json
import os
import pandas as pd
def _get_idx_from_list(list_, idx):
    """Return the elements of ``list_`` selected by ``idx`` as a new list.

    The elements are picked straight from the Python list, so each returned
    item is the very same object that was passed in. (Converting the list to
    a NumPy array first would coerce the items: mixed lists turn into
    strings, and array-like items are stacked into one multi-dimensional
    array.)

    Args:
        list_: The list to select from. Any value that is not a ``list`` is
            treated as a single element and wrapped into a one-item list.
        idx: A boolean mask with one entry per element, a sequence/array of
            integer positions, a single integer position, or a ``slice``.

    Returns:
        A list with the selected elements, in selection order.

    Raises:
        IndexError: If a boolean mask does not have exactly one entry per
            element, if an integer position is out of range, or if ``idx``
            is neither boolean nor integer.
    """
    if not isinstance(list_, list):
        list_ = [list_]
    if isinstance(idx, slice):
        return list_[idx]

    selector = np.atleast_1d(np.asarray(idx))
    if selector.dtype == bool:
        if selector.shape != (len(list_),):
            raise IndexError(
                f"boolean mask of shape {selector.shape} does not match "
                f"a list of length {len(list_)}"
            )
        positions = np.flatnonzero(selector)
    elif np.issubdtype(selector.dtype, np.integer):
        positions = selector.ravel()
    elif selector.size == 0:
        # An empty selection (e.g. ``[]``) has no meaningful dtype.
        return []
    else:
        raise IndexError("idx must be a boolean mask or integer positions")

    return [list_[int(position)] for position in positions]
def generate_batch(lst, batch_size):
    """Yield consecutive slices of ``lst`` holding at most ``batch_size`` items.

    Every batch has ``batch_size`` items except possibly the last one, which
    holds whatever remains. An empty ``lst`` yields nothing.

    Args:
        lst: A sliceable sequence with a length (list, tuple, str, ...).
        batch_size: Maximum number of items per batch; must be at least 1.

    Yields:
        Slices of ``lst``, in order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. Because this is a
            generator, the error surfaces when iteration starts.
    """
    # A negative step would make ``range`` silently produce no batches at
    # all, hiding the mistake from the caller; fail loudly instead.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    for start in range(0, len(lst), batch_size):
        yield lst[start: start + batch_size]
class AdaptiveDiffusionPipeline:
    """Generate with a student model and refine only the low-scoring outputs.

    Every prompt is first rendered by ``student``. ``estimator`` scores each
    (prompt, image) pair, and the images whose score falls below a
    precomputed percentile of the student's score distribution are passed
    to ``teacher`` for refinement.

    Attributes are plain references to the collaborators given to the
    constructor; the only interface this class relies on is:

    * ``student(prompt=..., **kwargs)`` and ``teacher(prompt=..., image=...,
      **kwargs)`` return an object exposing the generated images
      (``.images`` is used during calibration, ``['images']`` at call time).
    * ``estimator.score(prompt, image)`` returns a comparable score.
    """

    # Percentiles of the student's score distribution computed during
    # calibration; ``__call__``'s ``k`` must be one of the stored keys.
    _PERCENTILES = (10, 20, 30, 40, 50, 60, 70, 80, 90)

    def __init__(self, estimator, student, teacher):
        """Store the collaborators; percentiles stay unset until calibrated."""
        self.estimator = estimator
        # Filled by ``calc_score_percentiles``: {percentile (int): score}.
        self.score_percentiles = None
        self.student = student
        self.teacher = teacher

    def calc_score_percentiles(
        self,
        file_path,
        n_samples=500,
        b_size=1,
        prompts_path=None,
        **kwargs,
    ):
        """Load the score percentiles from ``file_path`` or compute and save them.

        If ``file_path`` exists it is read as a JSON object mapping
        percentile to score. Otherwise the first ``n_samples`` entries of the
        ``caption`` column of the CSV at ``prompts_path`` are rendered by the
        student in batches of ``b_size``, each image is scored, and the
        percentiles are written to ``file_path`` as JSON.

        Args:
            file_path: Location of the JSON cache to read or create.
            n_samples: Number of prompts used for calibration.
            b_size: Number of prompts sent to the student per call.
            prompts_path: CSV file with a ``caption`` column; required only
                when ``file_path`` does not exist yet.
            **kwargs: Forwarded unchanged to the student.

        Raises:
            ValueError: If the cache is missing and ``prompts_path`` is not
                given, or if no prompt could be scored.
        """
        if os.path.exists(file_path):
            print(f'Loading score percentiles from {file_path}')
            with open(file_path, encoding="utf-8") as f:
                data = json.load(f)
            # JSON object keys are always strings; restore the integer keys.
            self.score_percentiles = {int(key): value for key, value in data.items()}
            return

        if prompts_path is None:
            raise ValueError(
                f"{file_path} does not exist, so prompts_path is required "
                "to calculate the score percentiles"
            )

        print(f'Calculating score percentiles on {n_samples} samples from {prompts_path} and saving as {file_path}')
        prompts = list(pd.read_csv(prompts_path)['caption'])[:n_samples]
        scores = []
        for batch in generate_batch(prompts, b_size):
            student_out = self.student(prompt=batch, **kwargs).images
            for caption, image in zip(batch, student_out):
                scores.append(self.estimator.score(caption, image))
        if not scores:
            raise ValueError(f"No prompts were scored; check {prompts_path} and n_samples")

        # Plain floats keep the result JSON-serialisable whatever NumPy
        # scalar type ``np.percentile`` returns, so the dump below cannot
        # fail half-way and leave a truncated cache file behind.
        self.score_percentiles = {
            k: float(np.percentile(scores, k)) for k in self._PERCENTILES
        }
        with open(file_path, "w", encoding="utf-8") as fp:
            json.dump(self.score_percentiles, fp)

    def __call__(
        self,
        prompt,
        num_inference_steps_student=2,
        student_guidance=0.0,
        num_inference_steps_teacher=4,
        teacher_guidance=8.0,
        sigma=0.4,
        k=50,
        seed=0,
        **kwargs
    ):
        """Generate one image per prompt, refining the low-scoring ones.

        Args:
            prompt: A single prompt string or a sequence of prompts.
            num_inference_steps_student: Step count passed to the student.
            student_guidance: ``guidance_scale`` passed to the student.
            num_inference_steps_teacher: Number of refinement steps the
                teacher is meant to spend on each selected image.
            teacher_guidance: ``guidance_scale`` passed to the teacher.
            sigma: ``strength`` passed to the teacher; must be non-zero.
            k: Percentile whose score is used as the refinement threshold.
            seed: Seed of the random generator handed to the student.
            **kwargs: Forwarded unchanged to both student and teacher.

        Returns:
            The student's images, with each image that scored below the
            threshold replaced by the teacher's refinement. Images stay in
            the same order as the prompts. When nothing is refined the
            student's output object is returned as is; otherwise a list.

        Raises:
            RuntimeError: If the percentiles have not been loaded/computed.
            KeyError: If percentile ``k`` is not available.
        """
        # Step 0. Configuration
        if self.score_percentiles is None:
            raise RuntimeError(
                "Score percentiles are not set; call calc_score_percentiles() first"
            )
        if k not in self.score_percentiles:
            raise KeyError(
                f"Percentile {k!r} is not available; choose one of {sorted(self.score_percentiles)}"
            )
        chosen_threshold = self.score_percentiles[k]
        # Use CUDA whenever it is present; only fall back to the CPU on
        # machines without a GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        generator = torch.Generator(device=device).manual_seed(seed)
        # NOTE(review): presumably the teacher is an img2img-style pipeline
        # that runs about ``strength * num_inference_steps`` steps, so this
        # yields ~``num_inference_steps_teacher`` real steps - confirm.
        num_all_steps = int(num_inference_steps_teacher / sigma + 1)
        # A bare string is one prompt, not a sequence of characters.
        prompts = [prompt] if isinstance(prompt, str) else list(prompt)

        # Step 1. Student prediction
        student_out = self.student(prompt=prompt,
                                   num_inference_steps=num_inference_steps_student,
                                   generator=generator,
                                   guidance_scale=student_guidance,
                                   **kwargs)['images']

        # Step 2. Score estimation
        reward = [
            self.estimator.score(p, student_img)
            for p, student_img in zip(prompts, student_out)
        ]
        improve_positions = [
            position for position, score in enumerate(reward)
            if score < chosen_threshold
        ]

        # Step 3. Adaptive selection and improvement
        if improve_positions:
            improved_out = self.teacher(
                prompt=[prompts[position] for position in improve_positions],
                image=[student_out[position] for position in improve_positions],
                num_inference_steps=num_all_steps,
                guidance_scale=teacher_guidance,
                strength=sigma,
                **kwargs
            )['images']
            # Put every refined image back at the position of its prompt so
            # the output stays aligned with the input order.
            final_out = list(student_out)
            for position, improved_img in zip(improve_positions, improved_out):
                final_out[position] = improved_img
        else:
            final_out = student_out

        # Average cost per prompt: every prompt pays the student steps, only
        # the refined share pays the teacher steps.
        total_number_generation_steps = num_inference_steps_student + len(improve_positions) / len(
            prompts) * num_inference_steps_teacher
        print(f'Total number of generation steps: {int(total_number_generation_steps)}')
        return final_out