rsortino committed on
Commit
28d8474
1 Parent(s): de6ffcd

Upload 6 files

Files changed (6)
  1. colorize.py +61 -0
  2. gradio_colorization.py +118 -0
  3. utils/data.py +47 -0
  4. utils/ddim.py +317 -0
  5. utils/diffusion.py +259 -0
  6. utils/model.py +43 -0
colorize.py ADDED
@@ -0,0 +1,61 @@
+ import random
+
+ import cv2
+ import einops
+ import numpy as np
+ import torch
+ from pytorch_lightning import seed_everything
+
+ from utils.data import HWC3, apply_color, resize_image
+ from utils.ddim import DDIMSampler
+ from utils.model import create_model, load_state_dict
+
+ model = create_model('./models/cldm_v21.yaml').cpu()
+ model.load_state_dict(load_state_dict(
+     'lightning_logs/version_6/checkpoints/colorizenet-sd21.ckpt', location='cuda'))
+ model = model.cuda()
+ ddim_sampler = DDIMSampler(model)
+
+
+ input_image = cv2.imread("sample_data/sample1_bw.jpg")
+ input_image = HWC3(input_image)
+ img = resize_image(input_image, resolution=512)
+ H, W, C = img.shape
+
+ num_samples = 1
+ control = torch.from_numpy(img.copy()).float().cuda() / 255.0
+ control = torch.stack([control for _ in range(num_samples)], dim=0)
+ control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+
+ # seed = random.randint(0, 65535)
+ seed = 1294574436
+ seed_everything(seed)
+ prompt = "Colorize this image"
+ n_prompt = ""
+ guess_mode = False
+ strength = 1.0
+ eta = 0.0
+ ddim_steps = 20
+ scale = 9.0
+
+ cond = {"c_concat": [control], "c_crossattn": [
+     model.get_learned_conditioning([prompt] * num_samples)]}
+ un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [
+     model.get_learned_conditioning([n_prompt] * num_samples)]}
+ shape = (4, H // 8, W // 8)
+
+ model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else (
+     [strength] * 13)
+ samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+                                              shape, cond, verbose=False, eta=eta,
+                                              unconditional_guidance_scale=scale,
+                                              unconditional_conditioning=un_cond)
+
+ x_samples = model.decode_first_stage(samples)
+ x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c')
+              * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+ results = [x_samples[i] for i in range(num_samples)]
+ colored_results = [apply_color(img, result) for result in results]
+ [cv2.imwrite(f"colorized_{i}.jpg", cv2.cvtColor(result, cv2.COLOR_RGB2BGR)) for i, result in enumerate(colored_results)]
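
The script above is a one-shot example with hard-coded paths. Not part of the commit, but as a minimal sketch of how the same steps could be wrapped into a reusable helper (it reuses the model, ddim_sampler and utils imports already set up above; the function name and defaults are illustrative only):

def colorize(image_path, prompt="Colorize this image", num_samples=1,
             ddim_steps=20, scale=9.0, eta=0.0, strength=1.0):
    # Prepare the grayscale control image exactly as in the script above.
    source = HWC3(cv2.imread(image_path))
    img = resize_image(source, resolution=512)
    H, W, _ = img.shape
    control = torch.from_numpy(img.copy()).float().cuda() / 255.0
    control = torch.stack([control] * num_samples, dim=0)
    control = einops.rearrange(control, 'b h w c -> b c h w').clone()

    cond = {"c_concat": [control],
            "c_crossattn": [model.get_learned_conditioning([prompt] * num_samples)]}
    un_cond = {"c_concat": [control],
               "c_crossattn": [model.get_learned_conditioning([""] * num_samples)]}
    model.control_scales = [strength] * 13

    with torch.no_grad():
        samples, _ = ddim_sampler.sample(ddim_steps, num_samples, (4, H // 8, W // 8),
                                         cond, verbose=False, eta=eta,
                                         unconditional_guidance_scale=scale,
                                         unconditional_conditioning=un_cond)
        x = model.decode_first_stage(samples)
    x = (einops.rearrange(x, 'b c h w -> b h w c') * 127.5 + 127.5)
    x = x.cpu().numpy().clip(0, 255).astype(np.uint8)
    # Keep the original luminance, take the chroma from the generated samples.
    return [apply_color(img, x[i]) for i in range(num_samples)]

A call such as colorize("sample_data/sample1_bw.jpg") would then reproduce the pipeline above for any input image without reloading the checkpoint.
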
gradio_colorization.py ADDED
@@ -0,0 +1,118 @@
+ from share import *
+ import config
+
+ import cv2
+ import einops
+ import gradio as gr
+ import numpy as np
+ import torch
+ import random
+
+ from pytorch_lightning import seed_everything
+ from annotator.util import resize_image, HWC3
+ from cldm.model import create_model, load_state_dict
+ from cldm.ddim_hacked import DDIMSampler
+
+
+ model = create_model('./models/cldm_v21.yaml').cpu()
+ model.load_state_dict(load_state_dict(
+     'lightning_logs/version_6/checkpoints/colorizenet-sd21.ckpt', location='cuda'))
+ model = model.cuda()
+ ddim_sampler = DDIMSampler(model)
+
+
+ def apply_color(image, color_map):
+     image = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
+     color_map = cv2.cvtColor(color_map, cv2.COLOR_RGB2LAB)
+
+     l, _, _ = cv2.split(image)
+     _, a, b = cv2.split(color_map)
+
+     merged = cv2.merge([l, a, b])
+
+     result = cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
+     return result
+
+
+ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta):
+     with torch.no_grad():
+         input_image = HWC3(input_image)
+         img = resize_image(input_image, image_resolution)
+         H, W, C = img.shape
+
+         control = torch.from_numpy(img.copy()).float().cuda() / 255.0
+         control = torch.stack([control for _ in range(num_samples)], dim=0)
+         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+         if seed == -1:
+             seed = random.randint(0, 65535)
+         seed_everything(seed)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=False)
+
+         cond = {"c_concat": [control], "c_crossattn": [
+             model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
+         un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [
+             model.get_learned_conditioning([n_prompt] * num_samples)]}
+         shape = (4, H // 8, W // 8)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=True)
+
+         model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else (
+             [strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+         samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+                                                      shape, cond, verbose=False, eta=eta,
+                                                      unconditional_guidance_scale=scale,
+                                                      unconditional_conditioning=un_cond)
+
+         if config.save_memory:
+             model.low_vram_shift(is_diffusing=False)
+
+         x_samples = model.decode_first_stage(samples)
+         x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c')
+                      * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
+
+         results = [x_samples[i] for i in range(num_samples)]
+         colored_results = [apply_color(img, result) for result in results]
+         return [img] + results + colored_results
+
+
+ block = gr.Blocks().queue()
+ with block:
+     with gr.Row():
+         gr.Markdown("## Colorize images with Stable Diffusion")
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(source='upload', type="numpy")
+             prompt = gr.Textbox(label="Prompt")
+             run_button = gr.Button(label="Run")
+             with gr.Accordion("Advanced options", open=False):
+                 num_samples = gr.Slider(
+                     label="Images", minimum=1, maximum=12, value=1, step=1)
+                 image_resolution = gr.Slider(
+                     label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+                 strength = gr.Slider(
+                     label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+                 guess_mode = gr.Checkbox(label='Guess Mode', value=False)
+                 ddim_steps = gr.Slider(
+                     label="Steps", minimum=1, maximum=100, value=20, step=1)
+                 scale = gr.Slider(label="Guidance Scale",
+                                   minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+                 seed = gr.Slider(label="Seed", minimum=-1,
+                                  maximum=2147483647, step=1, randomize=True)
+                 eta = gr.Number(label="eta (DDIM)", value=0.0)
+                 a_prompt = gr.Textbox(
+                     label="Added Prompt", value='best quality, natural colors')
+                 n_prompt = gr.Textbox(label="Negative Prompt",
+                                       value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
+         with gr.Column():
+             result_gallery = gr.Gallery(
+                 label='Output', show_label=False, elem_id="gallery").style(grid=3, height='auto')
+     ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution,
+            ddim_steps, guess_mode, strength, scale, seed, eta]
+     run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+
+
+ block.launch(server_name='0.0.0.0')
utils/data.py ADDED
@@ -0,0 +1,47 @@
+ import cv2
+ import numpy as np
+
+
+ # Data utils
+ def HWC3(x):
+     assert x.dtype == np.uint8
+     if x.ndim == 2:
+         x = x[:, :, None]
+     assert x.ndim == 3
+     H, W, C = x.shape
+     assert C == 1 or C == 3 or C == 4
+     if C == 3:
+         return x
+     if C == 1:
+         return np.concatenate([x, x, x], axis=2)
+     if C == 4:
+         color = x[:, :, 0:3].astype(np.float32)
+         alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+         y = color * alpha + 255.0 * (1.0 - alpha)
+         y = y.clip(0, 255).astype(np.uint8)
+         return y
+
+
+ def resize_image(input_image, resolution):
+     H, W, C = input_image.shape
+     H = float(H)
+     W = float(W)
+     k = float(resolution) / min(H, W)
+     H *= k
+     W *= k
+     H = int(np.round(H / 64.0)) * 64
+     W = int(np.round(W / 64.0)) * 64
+     img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
+     return img
+
+
+ def apply_color(image, color_map):
+     image = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
+     color_map = cv2.cvtColor(color_map, cv2.COLOR_RGB2LAB)
+
+     l, _, _ = cv2.split(image)
+     _, a, b = cv2.split(color_map)
+
+     merged = cv2.merge([l, a, b])
+
+     result = cv2.cvtColor(merged, cv2.COLOR_LAB2RGB)
+     return result
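
apply_color is the step that keeps the output a colorization rather than a full re-synthesis: the L (lightness) channel is taken from the input image and only the a/b chroma channels come from the generated color map, which is exactly how colorize.py uses apply_color(img, result). A small self-contained sketch, with placeholder file names and assuming uint8 RGB arrays:

import cv2
import numpy as np

from utils.data import HWC3, resize_image, apply_color

gray = HWC3(cv2.imread("input_bw.jpg", cv2.IMREAD_GRAYSCALE))     # (H, W) -> (H, W, 3)
gray = resize_image(gray, resolution=512)                          # sides snapped to multiples of 64

color_map = cv2.imread("generated_colors.jpg")                     # e.g. a decoded diffusion sample
color_map = cv2.resize(color_map, (gray.shape[1], gray.shape[0]))  # both arrays must share the same size

result = apply_color(gray, color_map)   # luminance from gray, chroma from color_map
assert result.shape == gray.shape and result.dtype == np.uint8
cv2.imwrite("colorized.jpg", cv2.cvtColor(result, cv2.COLOR_RGB2BGR))
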
utils/ddim.py ADDED
@@ -0,0 +1,317 @@
+ """SAMPLING ONLY."""
+
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+
+
+ class DDIMSampler(object):
+     def __init__(self, model, schedule="linear", **kwargs):
+         super().__init__()
+         self.model = model
+         self.ddpm_num_timesteps = model.num_timesteps
+         self.schedule = schedule
+
+     def register_buffer(self, name, attr):
+         if type(attr) == torch.Tensor:
+             if attr.device != torch.device("cuda"):
+                 attr = attr.to(torch.device("cuda"))
+         setattr(self, name, attr)
+
+     def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+         self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+                                                   num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
+         alphas_cumprod = self.model.alphas_cumprod
+         assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+         to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+         self.register_buffer('betas', to_torch(self.model.betas))
+         self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+         self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+         # calculations for diffusion q(x_t | x_{t-1}) and others
+         self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+         self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+         self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+         self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+         self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+         # ddim sampling parameters
+         ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+                                                                                    ddim_timesteps=self.ddim_timesteps,
+                                                                                    eta=ddim_eta, verbose=verbose)
+         self.register_buffer('ddim_sigmas', ddim_sigmas)
+         self.register_buffer('ddim_alphas', ddim_alphas)
+         self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+         self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+         sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+             (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+                 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+         self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+     @torch.no_grad()
+     def sample(self,
+                S,
+                batch_size,
+                shape,
+                conditioning=None,
+                callback=None,
+                normals_sequence=None,
+                img_callback=None,
+                quantize_x0=False,
+                eta=0.,
+                mask=None,
+                x0=None,
+                temperature=1.,
+                noise_dropout=0.,
+                score_corrector=None,
+                corrector_kwargs=None,
+                verbose=True,
+                x_T=None,
+                log_every_t=100,
+                unconditional_guidance_scale=1.,
+                unconditional_conditioning=None,  # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+                dynamic_threshold=None,
+                ucg_schedule=None,
+                **kwargs
+                ):
+         if conditioning is not None:
+             if isinstance(conditioning, dict):
+                 ctmp = conditioning[list(conditioning.keys())[0]]
+                 while isinstance(ctmp, list): ctmp = ctmp[0]
+                 cbs = ctmp.shape[0]
+                 if cbs != batch_size:
+                     print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+             elif isinstance(conditioning, list):
+                 for ctmp in conditioning:
+                     if ctmp.shape[0] != batch_size:
+                         print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
+
+             else:
+                 if conditioning.shape[0] != batch_size:
+                     print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+         self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+         # sampling
+         C, H, W = shape
+         size = (batch_size, C, H, W)
+         print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+         samples, intermediates = self.ddim_sampling(conditioning, size,
+                                                     callback=callback,
+                                                     img_callback=img_callback,
+                                                     quantize_denoised=quantize_x0,
+                                                     mask=mask, x0=x0,
+                                                     ddim_use_original_steps=False,
+                                                     noise_dropout=noise_dropout,
+                                                     temperature=temperature,
+                                                     score_corrector=score_corrector,
+                                                     corrector_kwargs=corrector_kwargs,
+                                                     x_T=x_T,
+                                                     log_every_t=log_every_t,
+                                                     unconditional_guidance_scale=unconditional_guidance_scale,
+                                                     unconditional_conditioning=unconditional_conditioning,
+                                                     dynamic_threshold=dynamic_threshold,
+                                                     ucg_schedule=ucg_schedule
+                                                     )
+         return samples, intermediates
+
+     @torch.no_grad()
+     def ddim_sampling(self, cond, shape,
+                       x_T=None, ddim_use_original_steps=False,
+                       callback=None, timesteps=None, quantize_denoised=False,
+                       mask=None, x0=None, img_callback=None, log_every_t=100,
+                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                       unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+                       ucg_schedule=None):
+         device = self.model.betas.device
+         b = shape[0]
+         if x_T is None:
+             img = torch.randn(shape, device=device)
+         else:
+             img = x_T
+
+         if timesteps is None:
+             timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+         elif timesteps is not None and not ddim_use_original_steps:
+             subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+             timesteps = self.ddim_timesteps[:subset_end]
+
+         intermediates = {'x_inter': [img], 'pred_x0': [img]}
+         time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+         total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+         print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+         iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+         for i, step in enumerate(iterator):
+             index = total_steps - i - 1
+             ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+             if mask is not None:
+                 assert x0 is not None
+                 img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
+                 img = img_orig * mask + (1. - mask) * img
+
+             if ucg_schedule is not None:
+                 assert len(ucg_schedule) == len(time_range)
+                 unconditional_guidance_scale = ucg_schedule[i]
+
+             outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+                                       quantize_denoised=quantize_denoised, temperature=temperature,
+                                       noise_dropout=noise_dropout, score_corrector=score_corrector,
+                                       corrector_kwargs=corrector_kwargs,
+                                       unconditional_guidance_scale=unconditional_guidance_scale,
+                                       unconditional_conditioning=unconditional_conditioning,
+                                       dynamic_threshold=dynamic_threshold)
+             img, pred_x0 = outs
+             if callback: callback(i)
+             if img_callback: img_callback(pred_x0, i)
+
+             if index % log_every_t == 0 or index == total_steps - 1:
+                 intermediates['x_inter'].append(img)
+                 intermediates['pred_x0'].append(pred_x0)
+
+         return img, intermediates
+
+     @torch.no_grad()
+     def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                       unconditional_guidance_scale=1., unconditional_conditioning=None,
+                       dynamic_threshold=None):
+         b, *_, device = *x.shape, x.device
+
+         if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+             model_output = self.model.apply_model(x, t, c)
+         else:
+             model_t = self.model.apply_model(x, t, c)
+             model_uncond = self.model.apply_model(x, t, unconditional_conditioning)
+             model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+
+         if self.model.parameterization == "v":
+             e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+         else:
+             e_t = model_output
+
+         if score_corrector is not None:
+             assert self.model.parameterization == "eps", 'not implemented'
+             e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+         alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+         alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+         sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+         sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+         # select parameters corresponding to the currently considered timestep
+         a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+         a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+         sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+         sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
+
+         # current prediction for x_0
+         if self.model.parameterization != "v":
+             pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+         else:
+             pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+         if quantize_denoised:
+             pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+         if dynamic_threshold is not None:
+             raise NotImplementedError()
+
+         # direction pointing to x_t
+         dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+         noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+         if noise_dropout > 0.:
+             noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+         x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+         return x_prev, pred_x0
+
+     @torch.no_grad()
+     def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+                unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
+         timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+         num_reference_steps = timesteps.shape[0]
+
+         assert t_enc <= num_reference_steps
+         num_steps = t_enc
+
+         if use_original_steps:
+             alphas_next = self.alphas_cumprod[:num_steps]
+             alphas = self.alphas_cumprod_prev[:num_steps]
+         else:
+             alphas_next = self.ddim_alphas[:num_steps]
+             alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+         x_next = x0
+         intermediates = []
+         inter_steps = []
+         for i in tqdm(range(num_steps), desc='Encoding Image'):
+             t = torch.full((x0.shape[0],), timesteps[i], device=self.model.device, dtype=torch.long)
+             if unconditional_guidance_scale == 1.:
+                 noise_pred = self.model.apply_model(x_next, t, c)
+             else:
+                 assert unconditional_conditioning is not None
+                 e_t_uncond, noise_pred = torch.chunk(
+                     self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+                                            torch.cat((unconditional_conditioning, c))), 2)
+                 noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+             xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+             weighted_noise_pred = alphas_next[i].sqrt() * (
+                     (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+             x_next = xt_weighted + weighted_noise_pred
+             if return_intermediates and i % (
+                     num_steps // return_intermediates) == 0 and i < num_steps - 1:
+                 intermediates.append(x_next)
+                 inter_steps.append(i)
+             elif return_intermediates and i >= num_steps - 2:
+                 intermediates.append(x_next)
+                 inter_steps.append(i)
+             if callback: callback(i)
+
+         out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+         if return_intermediates:
+             out.update({'intermediates': intermediates})
+         return x_next, out
+
+     @torch.no_grad()
+     def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+         # fast, but does not allow for exact reconstruction
+         # t serves as an index to gather the correct alphas
+         if use_original_steps:
+             sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+             sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+         else:
+             sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+             sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+         if noise is None:
+             noise = torch.randn_like(x0)
+         return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+                 extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
+     @torch.no_grad()
+     def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+                use_original_steps=False, callback=None):
+
+         timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+         timesteps = timesteps[:t_start]
+
+         time_range = np.flip(timesteps)
+         total_steps = timesteps.shape[0]
+         print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+         iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+         x_dec = x_latent
+         for i, step in enumerate(iterator):
+             index = total_steps - i - 1
+             ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+             x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+                                           unconditional_guidance_scale=unconditional_guidance_scale,
+                                           unconditional_conditioning=unconditional_conditioning)
+             if callback: callback(i)
+         return x_dec
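
For reference, the update computed in p_sample_ddim is the standard DDIM step from Song et al., arXiv:2010.02502 (the paper already cited in utils/diffusion.py). Writing a_t for the cumulative alpha selected at the current index and \epsilon for the guidance-mixed noise prediction e_t:

\hat{x}_0 = \frac{x_t - \sqrt{1 - a_t}\,\epsilon}{\sqrt{a_t}}, \qquad
x_{t-1} = \sqrt{a_{t-1}}\,\hat{x}_0 + \sqrt{1 - a_{t-1} - \sigma_t^2}\,\epsilon + \sigma_t z

with z drawn from a standard normal. make_ddim_sampling_parameters scales \sigma_t by eta, so the eta = 0.0 setting used in colorize.py makes the sampler deterministic.
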
utils/diffusion.py ADDED
@@ -0,0 +1,259 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from einops import repeat
+
+ from .model import instantiate_from_config
+
+
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+     if schedule == "linear":
+         betas = (
+             torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+         )
+
+     elif schedule == "cosine":
+         timesteps = (
+             torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+         )
+         alphas = timesteps / (1 + cosine_s) * np.pi / 2
+         alphas = torch.cos(alphas).pow(2)
+         alphas = alphas / alphas[0]
+         betas = 1 - alphas[1:] / alphas[:-1]
+         betas = np.clip(betas, a_min=0, a_max=0.999)
+
+     elif schedule == "sqrt_linear":
+         betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+     elif schedule == "sqrt":
+         betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+     else:
+         raise ValueError(f"schedule '{schedule}' unknown.")
+     return betas.numpy()
+
+
+ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
+     if ddim_discr_method == 'uniform':
+         c = num_ddpm_timesteps // num_ddim_timesteps
+         ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+     elif ddim_discr_method == 'quad':
+         ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+     else:
+         raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
+
+     # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+     # add one to get the final alpha values right (the ones from first scale to data during sampling)
+     steps_out = ddim_timesteps + 1
+     if verbose:
+         print(f'Selected timesteps for ddim sampler: {steps_out}')
+     return steps_out
+
+
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+     # select alphas for computing the variance schedule
+     alphas = alphacums[ddim_timesteps]
+     alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+     # according to the formula provided in https://arxiv.org/abs/2010.02502
+     sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+     if verbose:
+         print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
+         print(f'For the chosen value of eta, which is {eta}, '
+               f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+     return sigmas, alphas, alphas_prev
+
+
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+     """
+     Create a beta schedule that discretizes the given alpha_t_bar function,
+     which defines the cumulative product of (1-beta) over time from t = [0,1].
+     :param num_diffusion_timesteps: the number of betas to produce.
+     :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                       produces the cumulative product of (1-beta) up to that
+                       part of the diffusion process.
+     :param max_beta: the maximum beta to use; use values lower than 1 to
+                      prevent singularities.
+     """
+     betas = []
+     for i in range(num_diffusion_timesteps):
+         t1 = i / num_diffusion_timesteps
+         t2 = (i + 1) / num_diffusion_timesteps
+         betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+     return np.array(betas)
+
+
+ def extract_into_tensor(a, t, x_shape):
+     b, *_ = t.shape
+     out = a.gather(-1, t)
+     return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+ def checkpoint(func, inputs, params, flag):
+     """
+     Evaluate a function without caching intermediate activations, allowing for
+     reduced memory at the expense of extra compute in the backward pass.
+     :param func: the function to evaluate.
+     :param inputs: the argument sequence to pass to `func`.
+     :param params: a sequence of parameters `func` depends on but does not
+                    explicitly take as arguments.
+     :param flag: if False, disable gradient checkpointing.
+     """
+     if flag:
+         args = tuple(inputs) + tuple(params)
+         return CheckpointFunction.apply(func, len(inputs), *args)
+     else:
+         return func(*inputs)
+
+
+ class CheckpointFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, run_function, length, *args):
+         ctx.run_function = run_function
+         ctx.input_tensors = list(args[:length])
+         ctx.input_params = list(args[length:])
+         ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
+                                    "dtype": torch.get_autocast_gpu_dtype(),
+                                    "cache_enabled": torch.is_autocast_cache_enabled()}
+         with torch.no_grad():
+             output_tensors = ctx.run_function(*ctx.input_tensors)
+         return output_tensors
+
+     @staticmethod
+     def backward(ctx, *output_grads):
+         ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+         with torch.enable_grad(), \
+                 torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+             # Fixes a bug where the first op in run_function modifies the
+             # Tensor storage in place, which is not allowed for detach()'d
+             # Tensors.
+             shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+             output_tensors = ctx.run_function(*shallow_copies)
+         input_grads = torch.autograd.grad(
+             output_tensors,
+             ctx.input_tensors + ctx.input_params,
+             output_grads,
+             allow_unused=True,
+         )
+         del ctx.input_tensors
+         del ctx.input_params
+         del output_tensors
+         return (None, None) + input_grads
+
+
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+     """
+     Create sinusoidal timestep embeddings.
+     :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                       These may be fractional.
+     :param dim: the dimension of the output.
+     :param max_period: controls the minimum frequency of the embeddings.
+     :return: an [N x dim] Tensor of positional embeddings.
+     """
+     if not repeat_only:
+         half = dim // 2
+         freqs = torch.exp(
+             -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+         ).to(device=timesteps.device)
+         args = timesteps[:, None].float() * freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if dim % 2:
+             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+     else:
+         embedding = repeat(timesteps, 'b -> b d', d=dim)
+     return embedding
+
+
+ def zero_module(module):
+     """
+     Zero out the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().zero_()
+     return module
+
+
+ def scale_module(module, scale):
+     """
+     Scale the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().mul_(scale)
+     return module
+
+
+ def mean_flat(tensor):
+     """
+     Take the mean over all non-batch dimensions.
+     """
+     return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+ def normalization(channels):
+     """
+     Make a standard normalization layer.
+     :param channels: number of input channels.
+     :return: an nn.Module for normalization.
+     """
+     return GroupNorm32(32, channels)
+
+
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+ class SiLU(nn.Module):
+     def forward(self, x):
+         return x * torch.sigmoid(x)
+
+
+ class GroupNorm32(nn.GroupNorm):
+     def forward(self, x):
+         return super().forward(x.float()).type(x.dtype)
+
+
+ def conv_nd(dims, *args, **kwargs):
+     """
+     Create a 1D, 2D, or 3D convolution module.
+     """
+     if dims == 1:
+         return nn.Conv1d(*args, **kwargs)
+     elif dims == 2:
+         return nn.Conv2d(*args, **kwargs)
+     elif dims == 3:
+         return nn.Conv3d(*args, **kwargs)
+     raise ValueError(f"unsupported dimensions: {dims}")
+
+
+ def linear(*args, **kwargs):
+     """
+     Create a linear module.
+     """
+     return nn.Linear(*args, **kwargs)
+
+
+ def avg_pool_nd(dims, *args, **kwargs):
+     """
+     Create a 1D, 2D, or 3D average pooling module.
+     """
+     if dims == 1:
+         return nn.AvgPool1d(*args, **kwargs)
+     elif dims == 2:
+         return nn.AvgPool2d(*args, **kwargs)
+     elif dims == 3:
+         return nn.AvgPool3d(*args, **kwargs)
+     raise ValueError(f"unsupported dimensions: {dims}")
+
+
+ class HybridConditioner(nn.Module):
+
+     def __init__(self, c_concat_config, c_crossattn_config):
+         super().__init__()
+         self.concat_conditioner = instantiate_from_config(c_concat_config)
+         self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+     def forward(self, c_concat, c_crossattn):
+         c_concat = self.concat_conditioner(c_concat)
+         c_crossattn = self.crossattn_conditioner(c_crossattn)
+         return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+ def noise_like(shape, device, repeat=False):
+     repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+     noise = lambda: torch.randn(shape, device=device)
+     return repeat_noise() if repeat else noise()
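
A small self-contained sketch of how the schedule helpers above fit together, mirroring what DDIMSampler.make_schedule computes for the settings used in colorize.py (1000 DDPM steps, 20 DDIM steps, eta = 0). It assumes the repository root is on the path so utils is importable as a package and that the ldm/omegaconf dependencies are installed (utils.diffusion imports utils.model, which imports ldm at load time):

import numpy as np
import torch

from utils.diffusion import (make_beta_schedule, make_ddim_timesteps,
                             make_ddim_sampling_parameters, timestep_embedding)

betas = make_beta_schedule("linear", n_timestep=1000)      # (1000,) numpy array of beta_t
alphas_cumprod = np.cumprod(1.0 - betas)                   # cumulative alpha_bar_t used by the sampler

ddim_ts = make_ddim_timesteps("uniform", num_ddim_timesteps=20,
                              num_ddpm_timesteps=1000, verbose=False)   # 20 of the 1000 steps
sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(
    alphacums=alphas_cumprod, ddim_timesteps=ddim_ts, eta=0.0, verbose=False)
assert np.allclose(sigmas, 0.0)   # eta = 0 -> deterministic DDIM

# Sinusoidal embeddings fed to the UNet for a batch of four timesteps.
emb = timestep_embedding(torch.tensor([1, 251, 501, 751]), dim=320)
print(emb.shape)                  # torch.Size([4, 320])
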
utils/model.py ADDED
@@ -0,0 +1,43 @@
+ from omegaconf import OmegaConf
+ from ldm.util import instantiate_from_config
+ import importlib
+ import os
+ import torch
+
+
+ def create_model(config_path):
+     config = OmegaConf.load(config_path)
+     model = instantiate_from_config(config.model).cpu()
+     print(f'Loaded model config from [{config_path}]')
+     return model
+
+ def instantiate_from_config(config):
+     if "target" not in config:
+         if config == '__is_first_stage__':
+             return None
+         elif config == "__is_unconditional__":
+             return None
+         raise KeyError("Expected key `target` to instantiate.")
+     return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+ def get_obj_from_str(string, reload=False):
+     module, cls = string.rsplit(".", 1)
+     if reload:
+         module_imp = importlib.import_module(module)
+         importlib.reload(module_imp)
+     return getattr(importlib.import_module(module, package=None), cls)
+
+ def get_state_dict(d):
+     return d.get('state_dict', d)
+
+ def load_state_dict(ckpt_path, location='cpu'):
+     _, extension = os.path.splitext(ckpt_path)
+     if extension.lower() == ".safetensors":
+         import safetensors.torch
+         state_dict = safetensors.torch.load_file(ckpt_path, device=location)
+     else:
+         state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
+     state_dict = get_state_dict(state_dict)
+     print(f'Loaded state_dict from [{ckpt_path}]')
+     return state_dict
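
instantiate_from_config and get_obj_from_str are generic: they build any object whose dotted class path and constructor arguments appear in a config mapping, which is how create_model turns models/cldm_v21.yaml into a model instance. A minimal illustration with a plain dictionary (the target class here is arbitrary, and ldm/omegaconf must be installed because the module imports them at the top):

from utils.model import instantiate_from_config, get_obj_from_str

config = {"target": "torch.nn.Linear", "params": {"in_features": 8, "out_features": 4}}
layer = instantiate_from_config(config)           # equivalent to torch.nn.Linear(8, 4)
print(type(layer).__name__, layer.weight.shape)   # Linear torch.Size([4, 8])

cls = get_obj_from_str("torch.nn.ReLU")           # resolves the dotted path to the class itself
print(cls())                                      # ReLU()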