# Copyright 2022 Adobe Research. All rights reserved.
# To view a copy of the license, visit LICENSE.md.
import os
import tqdm
from pathlib import Path
import requests
from PIL import Image
from io import BytesIO
from clip_retrieval.clip_client import ClipClient
from typing import Callable, Optional
import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from accelerate.logging import get_logger
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.models.cross_attention import CrossAttention
from diffusers.utils.import_utils import is_xformers_available
if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None

logger = get_logger(__name__)
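
# The function below is bound onto individual CrossAttention layers (see the
# `change_attn` helpers further down), so that enabling xformers swaps in the
# Custom Diffusion attention processors defined in this file.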
def set_use_memory_efficient_attention_xformers(
    self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
    if use_memory_efficient_attention_xformers:
        if self.added_kv_proj_dim is not None:
            # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
            # which uses this type of cross attention ONLY because the attention mask of format
            # [0, ..., -10.000, ..., 0, ...,] is not supported
            raise NotImplementedError(
                "Memory efficient attention with `xformers` is currently not supported when"
                " `self.added_kv_proj_dim` is defined."
            )
        elif not is_xformers_available():
            raise ModuleNotFoundError(
                (
                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                    " xformers"
                ),
                name="xformers",
            )
        elif not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                " only available for GPU "
            )
        else:
            try:
                # Make sure we can run the memory efficient attention
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )
            except Exception as e:
                raise e
        processor = CustomDiffusionXFormersAttnProcessor(attention_op=attention_op)
    else:
        processor = CustomDiffusionAttnProcessor()
    self.set_processor(processor)
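

# With the bound method above in place, memory-efficient attention can be enabled
# through the usual diffusers entry point. A minimal sketch (assumes xformers and a
# CUDA device are available, and `pipe` is the CustomDiffusionPipeline defined below):
#
#   pipe.enable_xformers_memory_efficient_attention()
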
class CustomDiffusionAttnProcessor:
    def __call__(
        self,
        attn: CrossAttention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)
        crossattn = False
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            crossattn = True
            if attn.cross_attention_norm:
                encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        if crossattn:
            # In cross-attention, detach the key/value entries of the first text token so
            # that no gradient flows through that position; all other positions stay attached.
            detach = torch.ones_like(key)
            detach[:, :1, :] = detach[:, :1, :] * 0.
            key = detach * key + (1 - detach) * key.detach()
            value = detach * value + (1 - detach) * value.detach()
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
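

# Processor variant used when xformers is enabled: it applies the same first-token
# key/value detaching as CustomDiffusionAttnProcessor, but computes the attention with
# xformers.ops.memory_efficient_attention instead of explicit softmax attention.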
class CustomDiffusionXFormersAttnProcessor:
    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)
        crossattn = False
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            crossattn = True
            if attn.cross_attention_norm:
                encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        if crossattn:
            detach = torch.ones_like(key)
            detach[:, :1, :] = detach[:, :1, :] * 0.
            key = detach * key + (1 - detach) * key.detach()
            value = detach * value + (1 - detach) * value.detach()
        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()
        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
class CustomDiffusionPipeline(StableDiffusionPipeline):
    r"""
    Pipeline for the Custom Diffusion model.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
        modifier_token (`list`):
            List of new modifier tokens added or to be added to the text encoder.
        modifier_token_id (`list`):
            List of ids of the new modifier tokens added or to be added to the text encoder.
    """
    _optional_components = ["safety_checker", "feature_extractor", "modifier_token"]
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: SchedulerMixin,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = True,
        modifier_token: list = [],
        modifier_token_id: list = [],
    ):
        super().__init__(vae,
                         text_encoder,
                         tokenizer,
                         unet,
                         scheduler,
                         safety_checker,
                         feature_extractor,
                         requires_safety_checker)

        # change attn class
        def change_attn(unet):
            for layer in unet.children():
                if type(layer) == CrossAttention:
                    bound_method = set_use_memory_efficient_attention_xformers.__get__(layer, layer.__class__)
                    setattr(layer, 'set_use_memory_efficient_attention_xformers', bound_method)
                else:
                    change_attn(layer)

        change_attn(self.unet)
        self.unet.set_attn_processor(CustomDiffusionAttnProcessor())
        self.modifier_token = modifier_token
        self.modifier_token_id = modifier_token_id
    def add_token(self, initializer_token):
        initializer_token_id = []
        for modifier_token_, initializer_token_ in zip(self.modifier_token, initializer_token):
            # Add the placeholder token in tokenizer
            num_added_tokens = self.tokenizer.add_tokens(modifier_token_)
            if num_added_tokens == 0:
                raise ValueError(
                    f"The tokenizer already contains the token {modifier_token_}. Please pass a different"
                    " `modifier_token` that is not already in the tokenizer."
                )
            # Convert the initializer_token, placeholder_token to ids
            token_ids = self.tokenizer.encode([initializer_token_], add_special_tokens=False)
            # Check if initializer_token is a single token or a sequence of tokens
            if len(token_ids) > 1:
                raise ValueError("The initializer token must be a single token.")
            self.modifier_token_id.append(self.tokenizer.convert_tokens_to_ids(modifier_token_))
            initializer_token_id.append(token_ids[0])

        # Resize the token embeddings as we are adding new special tokens to the tokenizer
        self.text_encoder.resize_token_embeddings(len(self.tokenizer))

        # Initialise the newly added placeholder token with the embeddings of the initializer token
        token_embeds = self.text_encoder.get_input_embeddings().weight.data
        for (x, y) in zip(self.modifier_token_id, initializer_token_id):
            token_embeds[x] = token_embeds[y]
    def save_pretrained(self, save_path, freeze_model="crossattn_kv", save_text_encoder=False, all=False):
        if all:
            super().save_pretrained(save_path)
        else:
            delta_dict = {'unet': {}, 'modifier_token': {}}
            if self.modifier_token is not None:
                for i in range(len(self.modifier_token_id)):
                    learned_embeds = self.text_encoder.get_input_embeddings().concept_token[i]
                    delta_dict['modifier_token'][self.modifier_token[i]] = learned_embeds.detach().cpu()
            if save_text_encoder:
                delta_dict['text_encoder'] = self.text_encoder.state_dict()
            for name, params in self.unet.named_parameters():
                if freeze_model == "crossattn":
                    if 'attn2' in name:
                        delta_dict['unet'][name] = params.cpu().clone()
                elif freeze_model == "crossattn_kv":
                    if 'attn2.to_k' in name or 'attn2.to_v' in name:
                        delta_dict['unet'][name] = params.cpu().clone()
                else:
                    raise ValueError(
                        "freeze_model argument only supports crossattn_kv or crossattn"
                    )
            torch.save(delta_dict, save_path)
    def load_model(self, save_path, compress=False):
        st = torch.load(save_path)
        if 'text_encoder' in st:
            self.text_encoder.load_state_dict(st['text_encoder'])
        if 'modifier_token' in st:
            modifier_tokens = list(st['modifier_token'].keys())
            modifier_token_id = []
            for modifier_token in modifier_tokens:
                num_added_tokens = self.tokenizer.add_tokens(modifier_token)
                if num_added_tokens == 0:
                    raise ValueError(
                        f"The tokenizer already contains the token {modifier_token}. Please pass a different"
                        " `modifier_token` that is not already in the tokenizer."
                    )
                modifier_token_id.append(self.tokenizer.convert_tokens_to_ids(modifier_token))
            # Resize the token embeddings as we are adding new special tokens to the tokenizer
            self.text_encoder.resize_token_embeddings(len(self.tokenizer))
            token_embeds = self.text_encoder.get_input_embeddings().weight.data
            for i, id_ in enumerate(modifier_token_id):
                token_embeds[id_] = st['modifier_token'][modifier_tokens[i]]

        for name, params in self.unet.named_parameters():
            if 'attn2' in name:
                if compress and ('to_k' in name or 'to_v' in name):
                    params.data += st['unet'][name]['u'] @ st['unet'][name]['v']
                elif name in st['unet']:
                    params.data.copy_(st['unet'][f'{name}'])
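

# Minimal usage sketch for the pipeline above. The model name, delta path and the
# `<new1>` modifier token below are illustrative placeholders, not values fixed by
# this module:
#
#   pipe = CustomDiffusionPipeline.from_pretrained(
#       "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
#   ).to("cuda")
#   pipe.load_model("delta.bin")  # load a saved cross-attention delta + modifier tokens
#   image = pipe("<new1> cat playing with a ball", num_inference_steps=50).images[0]
#   image.save("sample.png")
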
def create_custom_diffusion(unet, freeze_model):
    for name, params in unet.named_parameters():
        if freeze_model == 'crossattn':
            if 'attn2' in name:
                params.requires_grad = True
                print(name)
            else:
                params.requires_grad = False
        elif freeze_model == "crossattn_kv":
            if 'attn2.to_k' in name or 'attn2.to_v' in name:
                params.requires_grad = True
                print(name)
            else:
                params.requires_grad = False
        else:
            raise ValueError(
                "freeze_model argument only supports crossattn_kv or crossattn"
            )

    # change attn class
    def change_attn(unet):
        for layer in unet.children():
            if type(layer) == CrossAttention:
                bound_method = set_use_memory_efficient_attention_xformers.__get__(layer, layer.__class__)
                setattr(layer, 'set_use_memory_efficient_attention_xformers', bound_method)
            else:
                change_attn(layer)

    change_attn(unet)
    unet.set_attn_processor(CustomDiffusionAttnProcessor())
    return unet
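

# Sketch of how `create_custom_diffusion` is typically used when setting up
# fine-tuning; the optimizer choice and learning rate here are illustrative only:
#
#   unet = create_custom_diffusion(unet, freeze_model="crossattn_kv")
#   params_to_optimize = [p for p in unet.parameters() if p.requires_grad]
#   optimizer = torch.optim.AdamW(params_to_optimize, lr=1e-5)
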
def retrieve(target_name, outpath, num_class_images):
    num_images = 2 * num_class_images
    client = ClipClient(url="https://knn.laion.ai/knn-service", indice_name="laion5B-L-14",
                        num_images=num_images, aesthetic_weight=0.1)
    if len(target_name.split()):
        target = '_'.join(target_name.split())
    else:
        target = target_name
    os.makedirs(f'{outpath}/{target}', exist_ok=True)
    if len(list(Path(f'{outpath}/{target}').iterdir())) >= num_class_images:
        return

    while True:
        print(target_name)
        results = client.query(text=target_name)
        if len(results) >= num_class_images or num_images > 1e4:
            break
        else:
            num_images = int(1.5 * num_images)
            client = ClipClient(url="https://knn.laion.ai/knn-service", indice_name="laion5B-L-14",
                                num_images=num_images, aesthetic_weight=0.1)
    count = 0
    urls = []
    captions = []
    pbar = tqdm.tqdm(desc='downloading real regularization images', total=num_class_images)
    for each in results:
        name = f'{outpath}/{target}/{count}.jpg'
        success = True
        try:
            img = requests.get(each['url'])
            success = True
        except Exception:
            success = False
        if success and img.status_code == 200:
            try:
                # keep only files that PIL can actually decode as images
                _ = Image.open(BytesIO(img.content))
                with open(name, 'wb') as f:
                    f.write(img.content)
                urls.append(each['url'])
                captions.append(each['caption'])
                count += 1
                pbar.update(1)
            except Exception:
                pass
        if count > num_class_images:
            break

    with open(f'{outpath}/caption.txt', 'w') as f:
        for each in captions:
            f.write(each.strip() + '\n')
    with open(f'{outpath}/urls.txt', 'w') as f:
        for each in urls:
            f.write(each.strip() + '\n')
    with open(f'{outpath}/images.txt', 'w') as f:
        for p in range(count):
            f.write(f'{outpath}/{target}/{p}.jpg' + '\n')
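

if __name__ == "__main__":
    # Minimal sketch for exercising `retrieve` directly. The concept name, output
    # directory and image count are illustrative placeholders; adjust them before
    # running, since this downloads images via the LAION retrieval service.
    retrieve("cat", "./real_reg/samples_cat", num_class_images=200)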