1aurent committed
Commit
13531f3
1 Parent(s): 3ea49df

add enhancer app

.gitignore ADDED
@@ -0,0 +1,165 @@
+ gradio_cached_examples/
+
+ # https://github.com/github/gitignore/blob/main/Python.gitignore
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
README.md CHANGED
@@ -4,9 +4,31 @@ emoji: 🖼️🪄
  colorFrom: pink
  colorTo: indigo
  sdk: gradio
- sdk_version: 4.37.2
- app_file: app.py
+ sdk_version: 4.38.1
+ app_file: src/app.py
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Enhancer
+
+ ## Links
+
+ - https://blog.finegrain.ai/posts/reproducing-clarity-upscaler/
+ - https://github.com/finegrain-ai/refiners
+ - https://github.com/philz1337x/clarity-upscaler
+ - https://finegrain.ai/
+
+ ## Example image credits
+
+ - https://r2.clarityai.co/inputs/13_before.webp by [Clarity AI](https://clarityai.co/)
+ - https://unsplash.com/photos/L7EwHkq1B2s by [Kara Eads](https://unsplash.com/@karaeads)
+ - https://unsplash.com/photos/gtDYwUIr9Vg by [Melissa Walker Horn](https://unsplash.com/@eilivsonas)
+ - https://unsplash.com/photos/rW-I87aPY5Y by [Karina Vorozheeva](https://unsplash.com/@_k_arinn)
+ - https://unsplash.com/photos/jggQZkITXng by [Tadeusz Lakota](https://unsplash.com/@tadekl)
+ - https://unsplash.com/photos/hIaOPjYCEj4 by [KaroGraphix Photography](https://unsplash.com/@karographix)
+ - https://unsplash.com/photos/X53e51WfjlE by [Ryoji Iwata](https://unsplash.com/@ryoji__iwata)
+ - https://unsplash.com/photos/gJH8AqpiSEU by [Edgar.infocus](https://unsplash.com/@edgar_infocus)
+ - https://unsplash.com/photos/_XjW3oN8UOE by [Jeremy Wallace](https://unsplash.com/@jdanielw)
+
+ All Unsplash images are under the [Unsplash License](https://unsplash.com/license). \
+ All Unsplash images were downloaded in the "small" format (width=640px).
examples/clarity_bird.webp ADDED

Git LFS Details

  • SHA256: a1bf18d88b928ba178dda6c773e4e3327c08f17c743c8f124ab03f3ef7100a65
  • Pointer size: 130 Bytes
  • Size of remote file: 42.5 kB
examples/edgar-infocus-gJH8AqpiSEU-unsplash.jpg ADDED

Git LFS Details

  • SHA256: d458cb591d83eaed54f406f0ff625a640ed3c72c02fc4728f66ceb5cd354cc0b
  • Pointer size: 130 Bytes
  • Size of remote file: 49.4 kB
examples/jeremy-wallace-_XjW3oN8UOE-unsplash.jpg ADDED

Git LFS Details

  • SHA256: 04f17db7d49fd915237c7721f2723b43b4a0c53acfd737c3742b864993aac71e
  • Pointer size: 131 Bytes
  • Size of remote file: 118 kB
examples/kara-eads-L7EwHkq1B2s-unsplash.jpg ADDED

Git LFS Details

  • SHA256: f3c6c772b5ef805f9b317d3b9547940a990ca8c983689ef553dd76a8be49a398
  • Pointer size: 130 Bytes
  • Size of remote file: 66.5 kB
examples/karina-vorozheeva-rW-I87aPY5Y-unsplash.jpg ADDED

Git LFS Details

  • SHA256: 5e16ee5b0aae24133be45e7685dbabebbfafc17dfe08a7ceba3fc5b0f42a66dc
  • Pointer size: 130 Bytes
  • Size of remote file: 90 kB
examples/karographix-photography-hIaOPjYCEj4-unsplash.jpg ADDED

Git LFS Details

  • SHA256: 83860b586d31cf981df74b7579bbcef9eede87d294b47d3a88211bbc9af25501
  • Pointer size: 130 Bytes
  • Size of remote file: 66 kB
examples/melissa-walker-horn-gtDYwUIr9Vg-unsplash.jpg ADDED

Git LFS Details

  • SHA256: 1d1a7d7f5c5fff1ed335e0e1998f86fe2e4c20565f5a95dac72dc279cb60dae4
  • Pointer size: 130 Bytes
  • Size of remote file: 81.1 kB
examples/ryoji-iwata-X53e51WfjlE-unsplash.jpg ADDED

Git LFS Details

  • SHA256: 4ad5c2789d05d1b7e4727d820e5bc8c9707e76e2f10408966970752acf398e36
  • Pointer size: 131 Bytes
  • Size of remote file: 194 kB
examples/tadeusz-lakota-jggQZkITXng-unsplash.jpg ADDED

Git LFS Details

  • SHA256: 128df6ad73db9eaa6fa9befe9eb43c91d500da4cedbe6ec4188f3ba336beaf7c
  • Pointer size: 130 Bytes
  • Size of remote file: 64.7 kB
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ git+https://github.com/finegrain-ai/refiners@299217f45ab788bb7e670bcafb37a789a054461f
+ gradio_imageslider==0.0.20
+ spaces==0.28.3
+ numpy<2.0.0
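The pins above keep the build reproducible: refiners is locked to an exact commit, and `numpy<2.0.0` avoids the NumPy 2 ABI break. For a local run, a plain `pip install -r requirements.txt` in a fresh virtual environment should suffice; the `spaces` package only has an effect on ZeroGPU hardware and its `@spaces.GPU` decorator is documented to be a no-op elsewhere.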
src/app.py ADDED
@@ -0,0 +1,279 @@
+ import spaces
+ import torch
+
+
+ # see https://huggingface.co/spaces/zero-gpu-explorers/README/discussions/85
+ def my_arange(*args, **kwargs):
+     return torch.arange(*args, **kwargs)
+
+
+ torch.arange = my_arange
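A note on the patch above: per the linked zero-gpu-explorers discussion, `torch.arange` misbehaves under ZeroGPU, and the workaround is to rebind it to a plain Python function; unlike the C-implemented builtin, the wrapper is a regular function object that cooperates with the runtime patching ZeroGPU applies to torch APIs. The details live in that thread; the wrapper is behaviorally transparent otherwise.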
+
+ from pathlib import Path
+
+ import gradio as gr
+ from gradio_imageslider import ImageSlider
+ from huggingface_hub import hf_hub_download
+ from PIL import Image
+ from refiners.fluxion.utils import manual_seed
+ from refiners.foundationals.latent_diffusion import Solver, solvers
+
+ from enhancer import ESRGANUpscaler, ESRGANUpscalerCheckpoints
+
+ TITLE = """
+ <h1 align="center">Image Enhancer, implemented using refiners</h1>
+
+ <p>
+ <center>
+ <a style="font-size: 1.25rem;" href="https://blog.finegrain.ai/posts/reproducing-clarity-upscaler/" target="_blank">[blog post]</a>
+ <a style="font-size: 1.25rem;" href="https://github.com/finegrain-ai/refiners" target="_blank">[refiners]</a>
+ <a style="font-size: 1.25rem;" href="https://github.com/philz1337x/clarity-upscaler" target="_blank">[clarity-upscaler]</a>
+ <a style="font-size: 1.25rem;" href="https://finegrain.ai/" target="_blank">[finegrain]</a>
+ </center>
+ </p>
+ """
+
+ CHECKPOINTS = ESRGANUpscalerCheckpoints(
+     unet=Path(
+         hf_hub_download(
+             repo_id="refiners/juggernaut.reborn",
+             filename="unet.safetensors",
+             revision="948510aaf4c8e8e9b32b5a7c25736422253f7b93",
+         )
+     ),
+     clip_text_encoder=Path(
+         hf_hub_download(
+             repo_id="refiners/juggernaut.reborn",
+             filename="text_encoder.safetensors",
+             revision="948510aaf4c8e8e9b32b5a7c25736422253f7b93",
+         )
+     ),
+     lda=Path(
+         hf_hub_download(
+             repo_id="refiners/juggernaut.reborn",
+             filename="autoencoder.safetensors",
+             revision="948510aaf4c8e8e9b32b5a7c25736422253f7b93",
+         )
+     ),
+     controlnet_tile=Path(
+         hf_hub_download(
+             repo_id="refiners/controlnet.sd15.tile",
+             filename="model.safetensors",
+             revision="48ced6ff8bfa873a8976fa467c3629a240643387",
+         )
+     ),
+     esrgan=Path(
+         hf_hub_download(
+             repo_id="philz1337x/upscaler",
+             filename="4x-UltraSharp.pth",
+             revision="011deacac8270114eb7d2eeff4fe6fa9a837be70",
+         )
+     ),
+     negative_embedding=Path(
+         hf_hub_download(
+             repo_id="philz1337x/embeddings",
+             filename="JuggernautNegative-neg.pt",
+             revision="203caa7e9cc2bc225031a4021f6ab1ded283454a",
+         )
+     ),
+     negative_embedding_key="string_to_param.*",
+     loras={
+         "more_details": Path(
+             hf_hub_download(
+                 repo_id="philz1337x/loras",
+                 filename="more_details.safetensors",
+                 revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
+             )
+         ),
+         "sdxl_render": Path(
+             hf_hub_download(
+                 repo_id="philz1337x/loras",
+                 filename="SDXLrender_v2.0.safetensors",
+                 revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
+             )
+         ),
+     },
+ )
+
+ LORA_SCALES = {
+     "more_details": 0.5,
+     "sdxl_render": 1.0,
+ }
+
+ # initialize the enhancer, on the cpu
+ DEVICE_CPU = torch.device("cpu")
+ DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+ enhancer = ESRGANUpscaler(checkpoints=CHECKPOINTS, device=DEVICE_CPU, dtype=DTYPE)
+
+ # "move" the enhancer to the gpu, this is handled by Zero GPU
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ enhancer.to(device=DEVICE, dtype=DTYPE)
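The `DEVICE_CPU`/`DEVICE` split reflects, roughly, how ZeroGPU works: weights are materialized on the CPU at import time, and the `spaces` runtime arranges for the CUDA move to take effect when a GPU is actually attached for a `@spaces.GPU` call. On a machine with a resident GPU, the same two lines simply perform an ordinary CUDA transfer.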
+
+
+ @spaces.GPU
+ def process(
+     input_image: Image.Image,
+     prompt: str = "masterpiece, best quality, highres",
+     negative_prompt: str = "worst quality, low quality, normal quality",
+     seed: int = 42,
+     upscale_factor: int = 2,
+     controlnet_scale: float = 0.6,
+     controlnet_decay: float = 1.0,
+     condition_scale: int = 6,
+     tile_width: int = 112,
+     tile_height: int = 144,
+     denoise_strength: float = 0.35,
+     num_inference_steps: int = 18,
+     solver: str = "DDIM",
+ ) -> tuple[Image.Image, Image.Image]:
+     manual_seed(seed)
+
+     solver_type: type[Solver] = getattr(solvers, solver)
+
+     enhanced_image = enhancer.upscale(
+         image=input_image,
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         upscale_factor=upscale_factor,
+         controlnet_scale=controlnet_scale,
+         controlnet_scale_decay=controlnet_decay,
+         condition_scale=condition_scale,
+         tile_size=(tile_height, tile_width),
+         denoise_strength=denoise_strength,
+         num_inference_steps=num_inference_steps,
+         loras_scale=LORA_SCALES,
+         solver_type=solver_type,
+     )
+
+     return (input_image, enhanced_image)
+
+
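Because `@spaces.GPU` is a no-op off Spaces, `process` can also be exercised headlessly. A minimal sketch, assuming the repository's example images are available locally:

```python
from PIL import Image

image = Image.open("examples/clarity_bird.webp").convert("RGB")
before, after = process(image, seed=42, upscale_factor=2)
after.save("clarity_bird_2x.png")  # roughly 2x the input resolution
```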
+ with gr.Blocks() as demo:
+     gr.HTML(TITLE)
+
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(type="pil", label="Input Image")
+             run_button = gr.ClearButton(components=None, value="Enhance Image")
+         with gr.Column():
+             output_slider = ImageSlider(label="Before / After")
+             run_button.add(output_slider)
+
+     with gr.Accordion("Advanced Options", open=False):
+         prompt = gr.Textbox(
+             label="Prompt",
+             placeholder="masterpiece, best quality, highres",
+         )
+         negative_prompt = gr.Textbox(
+             label="Negative Prompt",
+             placeholder="worst quality, low quality, normal quality",
+         )
+         seed = gr.Slider(
+             minimum=0,
+             maximum=10_000,
+             value=42,
+             step=1,
+             label="Seed",
+         )
+         upscale_factor = gr.Slider(
+             minimum=1,
+             maximum=4,
+             value=2,
+             step=0.2,
+             label="Upscale Factor",
+         )
+         controlnet_scale = gr.Slider(
+             minimum=0,
+             maximum=1.5,
+             value=0.6,
+             step=0.1,
+             label="ControlNet Scale",
+         )
+         controlnet_decay = gr.Slider(
+             minimum=0.5,
+             maximum=1,
+             value=1.0,
+             step=0.025,
+             label="ControlNet Scale Decay",
+         )
+         condition_scale = gr.Slider(
+             minimum=2,
+             maximum=20,
+             value=6,
+             step=1,
+             label="Condition Scale",
+         )
+         tile_width = gr.Slider(
+             minimum=64,
+             maximum=200,
+             value=112,
+             step=1,
+             label="Latent Tile Width",
+         )
+         tile_height = gr.Slider(
+             minimum=64,
+             maximum=200,
+             value=144,
+             step=1,
+             label="Latent Tile Height",
+         )
+         denoise_strength = gr.Slider(
+             minimum=0,
+             maximum=1,
+             value=0.35,
+             step=0.1,
+             label="Denoise Strength",
+         )
+         num_inference_steps = gr.Slider(
+             minimum=1,
+             maximum=30,
+             value=18,
+             step=1,
+             label="Number of Inference Steps",
+         )
+         solver = gr.Radio(
+             choices=["DDIM", "DPMSolver"],
+             value="DDIM",
+             label="Solver",
+         )
+
+     run_button.click(
+         fn=process,
+         inputs=[
+             input_image,
+             prompt,
+             negative_prompt,
+             seed,
+             upscale_factor,
+             controlnet_scale,
+             controlnet_decay,
+             condition_scale,
+             tile_width,
+             tile_height,
+             denoise_strength,
+             num_inference_steps,
+             solver,
+         ],
+         outputs=output_slider,
+     )
+
+     gr.Examples(
+         examples=[
+             "examples/kara-eads-L7EwHkq1B2s-unsplash.jpg",
+             "examples/clarity_bird.webp",
+             "examples/edgar-infocus-gJH8AqpiSEU-unsplash.jpg",
+             "examples/jeremy-wallace-_XjW3oN8UOE-unsplash.jpg",
+             "examples/karina-vorozheeva-rW-I87aPY5Y-unsplash.jpg",
+             "examples/karographix-photography-hIaOPjYCEj4-unsplash.jpg",
+             "examples/melissa-walker-horn-gtDYwUIr9Vg-unsplash.jpg",
+             "examples/ryoji-iwata-X53e51WfjlE-unsplash.jpg",
+             "examples/tadeusz-lakota-jggQZkITXng-unsplash.jpg",
+         ],
+         inputs=[input_image],
+         outputs=output_slider,
+         fn=process,
+         cache_examples="lazy",
+         run_on_click=False,
+     )
+
+ demo.launch(share=False)
src/enhancer.py ADDED
@@ -0,0 +1,102 @@
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any
+
+ import torch
+ from PIL import Image
+ from refiners.foundationals.clip.concepts import ConceptExtender
+ from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
+     MultiUpscaler,
+     UpscalerCheckpoints,
+ )
+
+ from esrgan_model import UpscalerESRGAN
+
+
+ @dataclass(kw_only=True)
+ class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
+     esrgan: Path | None = None
+
+
+ class ESRGANUpscaler(MultiUpscaler):
+     def __init__(
+         self,
+         checkpoints: ESRGANUpscalerCheckpoints,
+         device: torch.device,
+         dtype: torch.dtype,
+     ) -> None:
+         super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
+         self.esrgan = self.load_esrgan(checkpoints.esrgan)
+
+     def to(self, device: torch.device, dtype: torch.dtype):
+         if self.esrgan is not None:  # the ESRGAN stage is optional, see the dataclass above
+             self.esrgan.to(device=device, dtype=dtype)
+         self.sd = self.sd.to(device=device, dtype=dtype)
+         self.device = device
+         self.dtype = dtype
+
+     def load_esrgan(self, path: Path | None) -> UpscalerESRGAN | None:
+         if path is None:
+             return None
+         return UpscalerESRGAN(path, device=self.device, dtype=self.dtype)
+
+     def load_negative_embedding(self, path: Path | None, key: str | None) -> str:
+         if path is None:
+             return ""
+
+         embeddings: torch.Tensor | dict[str, Any] = torch.load(  # type: ignore
+             path, weights_only=True, map_location=self.device
+         )
+
+         if isinstance(embeddings, dict):
+             assert (
+                 key is not None
+             ), "Key must be provided to access the negative embedding."
+             key_sequence = key.split(".")
+             for key in key_sequence:
+                 assert (
+                     key in embeddings
+                 ), f"Key {key} not found in the negative embedding dictionary. Available keys: {list(embeddings.keys())}"
+                 embeddings = embeddings[key]
+
+         assert isinstance(
+             embeddings, torch.Tensor
+         ), f"The negative embedding must be a tensor, found {type(embeddings)}."
+         assert (
+             embeddings.ndim == 2
+         ), f"The negative embedding must be a 2D tensor, found {embeddings.ndim}D tensor."
+
+         extender = ConceptExtender(self.sd.clip_text_encoder)
+         negative_embedding_token = ", "
+         for i, embedding in enumerate(embeddings):
+             embedding = embedding.to(device=self.device, dtype=self.dtype)
+             extender.add_concept(token=f"<{i}>", embedding=embedding)
+             negative_embedding_token += f"<{i}> "
+         extender.inject()
+
+         return negative_embedding_token
+
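For reference, a minimal sketch of the A1111-style textual-inversion layout this loader expects; the nested-dict walk matches the `negative_embedding_key="string_to_param.*"` passed from `app.py` (file name and shapes illustrative):

```python
import torch

# two concept vectors at SD1.5's CLIP width, stored under the key path "string_to_param.*"
fake_embedding = {"string_to_param": {"*": torch.randn(2, 768)}}
torch.save(fake_embedding, "fake-neg.pt")
# load_negative_embedding would register tokens "<0>" and "<1>" via ConceptExtender
# and return ", <0> <1> ", presumably for the caller to append to the negative prompt.
```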
+     def pre_upscale(
+         self,
+         image: Image.Image,
+         upscale_factor: float,
+         use_esrgan: bool = True,
+         use_esrgan_tiling: bool = True,
+         **_: Any,
+     ) -> Image.Image:
+         if self.esrgan is None or not use_esrgan:
+             return super().pre_upscale(image=image, upscale_factor=upscale_factor)
+
+         width, height = image.size
+
+         if use_esrgan_tiling:
+             image = self.esrgan.upscale_with_tiling(image)
+         else:
+             image = self.esrgan.upscale_without_tiling(image)
+
+         return image.resize(
+             size=(
+                 int(width * upscale_factor),
+                 int(height * upscale_factor),
+             ),
+             resample=Image.LANCZOS,
+         )
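Note the arithmetic in `pre_upscale`: the ESRGAN pass upscales by the model's own fixed ratio (4x for the 4x-UltraSharp weights this app downloads), and the final Lanczos resize then brings the image to the requested `upscale_factor`. For example, a 640x480 input with `upscale_factor=2` comes out of the ESRGAN at 2560x1920 and is resized down to 1280x960 before the diffusion stage refines it.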
src/esrgan_model.py ADDED
@@ -0,0 +1,1068 @@
+ # type: ignore
+ """
+ Modified from https://github.com/philz1337x/clarity-upscaler
+ which is a copy of https://github.com/AUTOMATIC1111/stable-diffusion-webui
+ which is a copy of https://github.com/victorca25/iNNfer
+ which is a copy of https://github.com/xinntao/ESRGAN
+ """
+
+ import math
+ import os
+ from collections import OrderedDict, namedtuple
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from PIL import Image
+
+ ####################
+ # RRDBNet Generator
+ ####################
+
+
+ class RRDBNet(nn.Module):
+     def __init__(
+         self,
+         in_nc,
+         out_nc,
+         nf,
+         nb,
+         nr=3,
+         gc=32,
+         upscale=4,
+         norm_type=None,
+         act_type="leakyrelu",
+         mode="CNA",
+         upsample_mode="upconv",
+         convtype="Conv2D",
+         finalact=None,
+         gaussian_noise=False,
+         plus=False,
+     ):
+         super(RRDBNet, self).__init__()
+         n_upscale = int(math.log(upscale, 2))
+         if upscale == 3:
+             n_upscale = 1
+
+         self.resrgan_scale = 0
+         if in_nc % 16 == 0:
+             self.resrgan_scale = 1
+         elif in_nc != 4 and in_nc % 4 == 0:
+             self.resrgan_scale = 2
+
+         fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
+         rb_blocks = [
+             RRDB(
+                 nf,
+                 nr,
+                 kernel_size=3,
+                 gc=32,
+                 stride=1,
+                 bias=1,
+                 pad_type="zero",
+                 norm_type=norm_type,
+                 act_type=act_type,
+                 mode="CNA",
+                 convtype=convtype,
+                 gaussian_noise=gaussian_noise,
+                 plus=plus,
+             )
+             for _ in range(nb)
+         ]
+         LR_conv = conv_block(
+             nf,
+             nf,
+             kernel_size=3,
+             norm_type=norm_type,
+             act_type=None,
+             mode=mode,
+             convtype=convtype,
+         )
+
+         if upsample_mode == "upconv":
+             upsample_block = upconv_block
+         elif upsample_mode == "pixelshuffle":
+             upsample_block = pixelshuffle_block
+         else:
+             raise NotImplementedError(f"upsample mode [{upsample_mode}] is not found")
+         if upscale == 3:
+             upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
+         else:
+             upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
+         HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
+         HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
+
+         outact = act(finalact) if finalact else None
+
+         self.model = sequential(
+             fea_conv,
+             ShortcutBlock(sequential(*rb_blocks, LR_conv)),
+             *upsampler,
+             HR_conv0,
+             HR_conv1,
+             outact,
+         )
+
+     def forward(self, x, outm=None):
+         if self.resrgan_scale == 1:
+             feat = pixel_unshuffle(x, scale=4)
+         elif self.resrgan_scale == 2:
+             feat = pixel_unshuffle(x, scale=2)
+         else:
+             feat = x
+
+         return self.model(feat)
+
+
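A quick shape check of the generator, as a minimal sketch assuming the standard 4x ESRGAN configuration (3-channel input/output, 64 features, 23 RRDB blocks) that `load_model` below reconstructs from a checkpoint:

```python
import torch

net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, upscale=4).eval()
with torch.no_grad():
    out = net(torch.zeros(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 3, 128, 128]) -- two 2x upconv stages
```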
+ class RRDB(nn.Module):
+     """
+     Residual in Residual Dense Block
+     (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
+     """
+
+     def __init__(
+         self,
+         nf,
+         nr=3,
+         kernel_size=3,
+         gc=32,
+         stride=1,
+         bias=1,
+         pad_type="zero",
+         norm_type=None,
+         act_type="leakyrelu",
+         mode="CNA",
+         convtype="Conv2D",
+         spectral_norm=False,
+         gaussian_noise=False,
+         plus=False,
+     ):
+         super(RRDB, self).__init__()
+         # This is for backwards compatibility with existing models
+         if nr == 3:
+             self.RDB1 = ResidualDenseBlock_5C(
+                 nf,
+                 kernel_size,
+                 gc,
+                 stride,
+                 bias,
+                 pad_type,
+                 norm_type,
+                 act_type,
+                 mode,
+                 convtype,
+                 spectral_norm=spectral_norm,
+                 gaussian_noise=gaussian_noise,
+                 plus=plus,
+             )
+             self.RDB2 = ResidualDenseBlock_5C(
+                 nf,
+                 kernel_size,
+                 gc,
+                 stride,
+                 bias,
+                 pad_type,
+                 norm_type,
+                 act_type,
+                 mode,
+                 convtype,
+                 spectral_norm=spectral_norm,
+                 gaussian_noise=gaussian_noise,
+                 plus=plus,
+             )
+             self.RDB3 = ResidualDenseBlock_5C(
+                 nf,
+                 kernel_size,
+                 gc,
+                 stride,
+                 bias,
+                 pad_type,
+                 norm_type,
+                 act_type,
+                 mode,
+                 convtype,
+                 spectral_norm=spectral_norm,
+                 gaussian_noise=gaussian_noise,
+                 plus=plus,
+             )
+         else:
+             RDB_list = [
+                 ResidualDenseBlock_5C(
+                     nf,
+                     kernel_size,
+                     gc,
+                     stride,
+                     bias,
+                     pad_type,
+                     norm_type,
+                     act_type,
+                     mode,
+                     convtype,
+                     spectral_norm=spectral_norm,
+                     gaussian_noise=gaussian_noise,
+                     plus=plus,
+                 )
+                 for _ in range(nr)
+             ]
+             self.RDBs = nn.Sequential(*RDB_list)
+
+     def forward(self, x):
+         if hasattr(self, "RDB1"):
+             out = self.RDB1(x)
+             out = self.RDB2(out)
+             out = self.RDB3(out)
+         else:
+             out = self.RDBs(x)
+         return out * 0.2 + x
+
+ class ResidualDenseBlock_5C(nn.Module):
+     """
+     Residual Dense Block
+     The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
+     Modified options that can be used:
+     - "Partial Convolution based Padding" arXiv:1811.11718
+     - "Spectral normalization" arXiv:1802.05957
+     - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. {Rakotonirina} and A. {Rasoanaivo}
+     """
+
+     def __init__(
+         self,
+         nf=64,
+         kernel_size=3,
+         gc=32,
+         stride=1,
+         bias=1,
+         pad_type="zero",
+         norm_type=None,
+         act_type="leakyrelu",
+         mode="CNA",
+         convtype="Conv2D",
+         spectral_norm=False,
+         gaussian_noise=False,
+         plus=False,
+     ):
+         super(ResidualDenseBlock_5C, self).__init__()
+
+         self.noise = GaussianNoise() if gaussian_noise else None
+         self.conv1x1 = conv1x1(nf, gc) if plus else None
+
+         self.conv1 = conv_block(
+             nf,
+             gc,
+             kernel_size,
+             stride,
+             bias=bias,
+             pad_type=pad_type,
+             norm_type=norm_type,
+             act_type=act_type,
+             mode=mode,
+             convtype=convtype,
+             spectral_norm=spectral_norm,
+         )
+         self.conv2 = conv_block(
+             nf + gc,
+             gc,
+             kernel_size,
+             stride,
+             bias=bias,
+             pad_type=pad_type,
+             norm_type=norm_type,
+             act_type=act_type,
+             mode=mode,
+             convtype=convtype,
+             spectral_norm=spectral_norm,
+         )
+         self.conv3 = conv_block(
+             nf + 2 * gc,
+             gc,
+             kernel_size,
+             stride,
+             bias=bias,
+             pad_type=pad_type,
+             norm_type=norm_type,
+             act_type=act_type,
+             mode=mode,
+             convtype=convtype,
+             spectral_norm=spectral_norm,
+         )
+         self.conv4 = conv_block(
+             nf + 3 * gc,
+             gc,
+             kernel_size,
+             stride,
+             bias=bias,
+             pad_type=pad_type,
+             norm_type=norm_type,
+             act_type=act_type,
+             mode=mode,
+             convtype=convtype,
+             spectral_norm=spectral_norm,
+         )
+         if mode == "CNA":
+             last_act = None
+         else:
+             last_act = act_type
+         self.conv5 = conv_block(
+             nf + 4 * gc,
+             nf,
+             3,
+             stride,
+             bias=bias,
+             pad_type=pad_type,
+             norm_type=norm_type,
+             act_type=last_act,
+             mode=mode,
+             convtype=convtype,
+             spectral_norm=spectral_norm,
+         )
+
+     def forward(self, x):
+         x1 = self.conv1(x)
+         x2 = self.conv2(torch.cat((x, x1), 1))
+         if self.conv1x1:
+             x2 = x2 + self.conv1x1(x)
+         x3 = self.conv3(torch.cat((x, x1, x2), 1))
+         x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
+         if self.conv1x1:
+             x4 = x4 + x2
+         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+         if self.noise:
+             return self.noise(x5.mul(0.2) + x)
+         else:
+             return x5 * 0.2 + x
+
+
+ ####################
+ # ESRGANplus
+ ####################
+
+
+ class GaussianNoise(nn.Module):
+     def __init__(self, sigma=0.1, is_relative_detach=False):
+         super().__init__()
+         self.sigma = sigma
+         self.is_relative_detach = is_relative_detach
+         self.noise = torch.tensor(0, dtype=torch.float)
+
+     def forward(self, x):
+         if self.training and self.sigma != 0:
+             # fixed: upstream passed dtype=x.device here by mistake
+             self.noise = self.noise.to(device=x.device, dtype=x.dtype)
+             scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
+             sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
+             x = x + sampled_noise
+         return x
+
+
+ def conv1x1(in_planes, out_planes, stride=1):
+     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+ ####################
+ # SRVGGNetCompact
+ ####################
+
+
+ class SRVGGNetCompact(nn.Module):
+     """A compact VGG-style network structure for super-resolution.
+     This class is copied from https://github.com/xinntao/Real-ESRGAN
+     """
+
+     def __init__(
+         self,
+         num_in_ch=3,
+         num_out_ch=3,
+         num_feat=64,
+         num_conv=16,
+         upscale=4,
+         act_type="prelu",
+     ):
+         super(SRVGGNetCompact, self).__init__()
+         self.num_in_ch = num_in_ch
+         self.num_out_ch = num_out_ch
+         self.num_feat = num_feat
+         self.num_conv = num_conv
+         self.upscale = upscale
+         self.act_type = act_type
+
+         self.body = nn.ModuleList()
+         # the first conv
+         self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
+         # the first activation
+         if act_type == "relu":
+             activation = nn.ReLU(inplace=True)
+         elif act_type == "prelu":
+             activation = nn.PReLU(num_parameters=num_feat)
+         elif act_type == "leakyrelu":
+             activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+         self.body.append(activation)
+
+         # the body structure
+         for _ in range(num_conv):
+             self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
+             # activation
+             if act_type == "relu":
+                 activation = nn.ReLU(inplace=True)
+             elif act_type == "prelu":
+                 activation = nn.PReLU(num_parameters=num_feat)
+             elif act_type == "leakyrelu":
+                 activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+             self.body.append(activation)
+
+         # the last conv
+         self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
+         # upsample
+         self.upsampler = nn.PixelShuffle(upscale)
+
+     def forward(self, x):
+         out = x
+         for i in range(0, len(self.body)):
+             out = self.body[i](out)
+
+         out = self.upsampler(out)
+         # add the nearest upsampled image, so that the network learns the residual
+         base = F.interpolate(x, scale_factor=self.upscale, mode="nearest")
+         out += base
+         return out
+
+
+ ####################
+ # Upsampler
+ ####################
+
+
+ class Upsample(nn.Module):
+     r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
+     The input data is assumed to be of the form
+     `minibatch x channels x [optional depth] x [optional height] x width`.
+     """
+
+     def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
+         super(Upsample, self).__init__()
+         if isinstance(scale_factor, tuple):
+             self.scale_factor = tuple(float(factor) for factor in scale_factor)
+         else:
+             self.scale_factor = float(scale_factor) if scale_factor else None
+         self.mode = mode
+         self.size = size
+         self.align_corners = align_corners
+
+     def forward(self, x):
+         return nn.functional.interpolate(
+             x,
+             size=self.size,
+             scale_factor=self.scale_factor,
+             mode=self.mode,
+             align_corners=self.align_corners,
+         )
+
+     def extra_repr(self):
+         if self.scale_factor is not None:
+             info = f"scale_factor={self.scale_factor}"
+         else:
+             info = f"size={self.size}"
+         info += f", mode={self.mode}"
+         return info
+
+
+ def pixel_unshuffle(x, scale):
+     """Pixel unshuffle.
+     Args:
+         x (Tensor): Input feature with shape (b, c, hh, hw).
+         scale (int): Downsample ratio.
+     Returns:
+         Tensor: the pixel unshuffled feature.
+     """
+     b, c, hh, hw = x.size()
+     out_channel = c * (scale**2)
+     assert hh % scale == 0 and hw % scale == 0
+     h = hh // scale
+     w = hw // scale
+     x_view = x.view(b, c, h, scale, w, scale)
+     return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
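A worked example of the space-to-channel trade implemented above:

```python
import torch

x = torch.zeros(1, 3, 64, 64)
print(pixel_unshuffle(x, scale=2).shape)  # torch.Size([1, 12, 32, 32])
print(pixel_unshuffle(x, scale=4).shape)  # torch.Size([1, 48, 16, 16])
```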
+
+ def pixelshuffle_block(
+     in_nc,
+     out_nc,
+     upscale_factor=2,
+     kernel_size=3,
+     stride=1,
+     bias=True,
+     pad_type="zero",
+     norm_type=None,
+     act_type="relu",
+     convtype="Conv2D",
+ ):
+     """
+     Pixel shuffle layer
+     (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
+     Neural Network, CVPR17)
+     """
+     conv = conv_block(
+         in_nc,
+         out_nc * (upscale_factor**2),
+         kernel_size,
+         stride,
+         bias=bias,
+         pad_type=pad_type,
+         norm_type=None,
+         act_type=None,
+         convtype=convtype,
+     )
+     pixel_shuffle = nn.PixelShuffle(upscale_factor)
+
+     n = norm(norm_type, out_nc) if norm_type else None
+     a = act(act_type) if act_type else None
+     return sequential(conv, pixel_shuffle, n, a)
+
+
+ def upconv_block(
+     in_nc,
+     out_nc,
+     upscale_factor=2,
+     kernel_size=3,
+     stride=1,
+     bias=True,
+     pad_type="zero",
+     norm_type=None,
+     act_type="relu",
+     mode="nearest",
+     convtype="Conv2D",
+ ):
+     """Upconv layer"""
+     upscale_factor = (1, upscale_factor, upscale_factor) if convtype == "Conv3D" else upscale_factor
+     upsample = Upsample(scale_factor=upscale_factor, mode=mode)
+     conv = conv_block(
+         in_nc,
+         out_nc,
+         kernel_size,
+         stride,
+         bias=bias,
+         pad_type=pad_type,
+         norm_type=norm_type,
+         act_type=act_type,
+         convtype=convtype,
+     )
+     return sequential(upsample, conv)
+
+
+ ####################
+ # Basic blocks
+ ####################
+
+
+ def make_layer(basic_block, num_basic_block, **kwarg):
+     """Make layers by stacking the same blocks.
+     Args:
+         basic_block (nn.module): nn.module class for basic block. (block)
+         num_basic_block (int): number of blocks. (n_layers)
+     Returns:
+         nn.Sequential: Stacked blocks in nn.Sequential.
+     """
+     layers = []
+     for _ in range(num_basic_block):
+         layers.append(basic_block(**kwarg))
+     return nn.Sequential(*layers)
+
+
+ def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
+     """activation helper"""
+     act_type = act_type.lower()
+     if act_type == "relu":
+         layer = nn.ReLU(inplace)
+     elif act_type in ("leakyrelu", "lrelu"):
+         layer = nn.LeakyReLU(neg_slope, inplace)
+     elif act_type == "prelu":
+         layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
+     elif act_type == "tanh":  # [-1, 1] range output
+         layer = nn.Tanh()
+     elif act_type == "sigmoid":  # [0, 1] range output
+         layer = nn.Sigmoid()
+     else:
+         raise NotImplementedError(f"activation layer [{act_type}] is not found")
+     return layer
+
+
+ class Identity(nn.Module):
+     def __init__(self, *kwargs):
+         super(Identity, self).__init__()
+
+     def forward(self, x, *kwargs):
+         return x
+
+
+ def norm(norm_type, nc):
+     """Return a normalization layer"""
+     norm_type = norm_type.lower()
+     if norm_type == "batch":
+         layer = nn.BatchNorm2d(nc, affine=True)
+     elif norm_type == "instance":
+         layer = nn.InstanceNorm2d(nc, affine=False)
+     elif norm_type == "none":
+         # fixed: upstream defined an unused inner function here and left `layer` unbound
+         layer = Identity()
+     else:
+         raise NotImplementedError(f"normalization layer [{norm_type}] is not found")
+     return layer
+
+
+ def pad(pad_type, padding):
+     """padding layer helper"""
+     pad_type = pad_type.lower()
+     if padding == 0:
+         return None
+     if pad_type == "reflect":
+         layer = nn.ReflectionPad2d(padding)
+     elif pad_type == "replicate":
+         layer = nn.ReplicationPad2d(padding)
+     elif pad_type == "zero":
+         layer = nn.ZeroPad2d(padding)
+     else:
+         raise NotImplementedError(f"padding layer [{pad_type}] is not implemented")
+     return layer
+
+
+ def get_valid_padding(kernel_size, dilation):
+     kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
+     padding = (kernel_size - 1) // 2
+     return padding
+
+
+ class ShortcutBlock(nn.Module):
+     """Elementwise sum the output of a submodule to its input"""
+
+     def __init__(self, submodule):
+         super(ShortcutBlock, self).__init__()
+         self.sub = submodule
+
+     def forward(self, x):
+         output = x + self.sub(x)
+         return output
+
+     def __repr__(self):
+         return "Identity + \n|" + self.sub.__repr__().replace("\n", "\n|")
+
+
+ def sequential(*args):
+     """Flatten Sequential. It unwraps nn.Sequential."""
+     if len(args) == 1:
+         if isinstance(args[0], OrderedDict):
+             raise NotImplementedError("sequential does not support OrderedDict input.")
+         return args[0]  # No sequential is needed.
+     modules = []
+     for module in args:
+         if isinstance(module, nn.Sequential):
+             for submodule in module.children():
+                 modules.append(submodule)
+         elif isinstance(module, nn.Module):
+             modules.append(module)
+     return nn.Sequential(*modules)
+
+
+ def conv_block(
+     in_nc,
+     out_nc,
+     kernel_size,
+     stride=1,
+     dilation=1,
+     groups=1,
+     bias=True,
+     pad_type="zero",
+     norm_type=None,
+     act_type="relu",
+     mode="CNA",
+     convtype="Conv2D",
+     spectral_norm=False,
+ ):
+     """Conv layer with padding, normalization, activation"""
+     assert mode in ["CNA", "NAC", "CNAC"], f"Wrong conv mode [{mode}]"
+     padding = get_valid_padding(kernel_size, dilation)
+     p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
+     padding = padding if pad_type == "zero" else 0
+
+     if convtype == "PartialConv2D":
+         # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
+         from torchvision.ops import PartialConv2d
+
+         c = PartialConv2d(
+             in_nc,
+             out_nc,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             bias=bias,
+             groups=groups,
+         )
+     elif convtype == "DeformConv2D":
+         from torchvision.ops import DeformConv2d  # not tested
+
+         c = DeformConv2d(
+             in_nc,
+             out_nc,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             bias=bias,
+             groups=groups,
+         )
+     elif convtype == "Conv3D":
+         c = nn.Conv3d(
+             in_nc,
+             out_nc,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             bias=bias,
+             groups=groups,
+         )
+     else:
+         c = nn.Conv2d(
+             in_nc,
+             out_nc,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             bias=bias,
+             groups=groups,
+         )
+
+     if spectral_norm:
+         c = nn.utils.spectral_norm(c)
+
+     a = act(act_type) if act_type else None
+     if "CNA" in mode:
+         n = norm(norm_type, out_nc) if norm_type else None
+         return sequential(p, c, n, a)
+     elif mode == "NAC":
+         if norm_type is None and act_type is not None:
+             a = act(act_type, inplace=False)
+         n = norm(norm_type, in_nc) if norm_type else None
+         return sequential(n, a, p, c)
+
+
+ def load_models(
+     model_path: Path,
+     command_path: str = None,
+ ) -> list:
+     """
+     A one-and-done loader that tries to find the desired models in the specified directories.
+
+     @param model_path: The location to store/find models in.
+     @param command_path: A command-line argument to search for models in first.
+     @return: A list of paths containing the desired model(s)
+     """
+     output = []
+
+     try:
+         places = []
+         if command_path is not None and command_path != model_path:
+             pretrained_path = os.path.join(command_path, "experiments/pretrained_models")
+             if os.path.exists(pretrained_path):
+                 print(f"Appending path: {pretrained_path}")
+                 places.append(pretrained_path)
+             elif os.path.exists(command_path):
+                 places.append(command_path)
+
+         places.append(model_path)
+
+     except Exception:
+         pass
+
+     return output
+
+
+ def mod2normal(state_dict):
+     # this code is copied from https://github.com/victorca25/iNNfer
+     if "conv_first.weight" in state_dict:
+         crt_net = {}
+         items = list(state_dict)
+
+         crt_net["model.0.weight"] = state_dict["conv_first.weight"]
+         crt_net["model.0.bias"] = state_dict["conv_first.bias"]
+
+         for k in items.copy():
+             if "RDB" in k:
+                 ori_k = k.replace("RRDB_trunk.", "model.1.sub.")
+                 if ".weight" in k:
+                     ori_k = ori_k.replace(".weight", ".0.weight")
+                 elif ".bias" in k:
+                     ori_k = ori_k.replace(".bias", ".0.bias")
+                 crt_net[ori_k] = state_dict[k]
+                 items.remove(k)
+
+         crt_net["model.1.sub.23.weight"] = state_dict["trunk_conv.weight"]
+         crt_net["model.1.sub.23.bias"] = state_dict["trunk_conv.bias"]
+         crt_net["model.3.weight"] = state_dict["upconv1.weight"]
+         crt_net["model.3.bias"] = state_dict["upconv1.bias"]
+         crt_net["model.6.weight"] = state_dict["upconv2.weight"]
+         crt_net["model.6.bias"] = state_dict["upconv2.bias"]
+         crt_net["model.8.weight"] = state_dict["HRconv.weight"]
+         crt_net["model.8.bias"] = state_dict["HRconv.bias"]
+         crt_net["model.10.weight"] = state_dict["conv_last.weight"]
+         crt_net["model.10.bias"] = state_dict["conv_last.bias"]
+         state_dict = crt_net
+     return state_dict
+
+
+ def resrgan2normal(state_dict, nb=23):
+     # this code is copied from https://github.com/victorca25/iNNfer
+     if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
+         re8x = 0
+         crt_net = {}
+         items = list(state_dict)
+
+         crt_net["model.0.weight"] = state_dict["conv_first.weight"]
+         crt_net["model.0.bias"] = state_dict["conv_first.bias"]
+
+         for k in items.copy():
+             if "rdb" in k:
+                 ori_k = k.replace("body.", "model.1.sub.")
+                 ori_k = ori_k.replace(".rdb", ".RDB")
+                 if ".weight" in k:
+                     ori_k = ori_k.replace(".weight", ".0.weight")
+                 elif ".bias" in k:
+                     ori_k = ori_k.replace(".bias", ".0.bias")
+                 crt_net[ori_k] = state_dict[k]
+                 items.remove(k)
+
+         crt_net[f"model.1.sub.{nb}.weight"] = state_dict["conv_body.weight"]
+         crt_net[f"model.1.sub.{nb}.bias"] = state_dict["conv_body.bias"]
+         crt_net["model.3.weight"] = state_dict["conv_up1.weight"]
+         crt_net["model.3.bias"] = state_dict["conv_up1.bias"]
+         crt_net["model.6.weight"] = state_dict["conv_up2.weight"]
+         crt_net["model.6.bias"] = state_dict["conv_up2.bias"]
+
+         if "conv_up3.weight" in state_dict:
+             # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
+             re8x = 3
+             crt_net["model.9.weight"] = state_dict["conv_up3.weight"]
+             crt_net["model.9.bias"] = state_dict["conv_up3.bias"]
+
+         crt_net[f"model.{8+re8x}.weight"] = state_dict["conv_hr.weight"]
+         crt_net[f"model.{8+re8x}.bias"] = state_dict["conv_hr.bias"]
+         crt_net[f"model.{10+re8x}.weight"] = state_dict["conv_last.weight"]
+         crt_net[f"model.{10+re8x}.bias"] = state_dict["conv_last.bias"]
+
+         state_dict = crt_net
+     return state_dict
+
+
+ def infer_params(state_dict):
+     # this code is copied from https://github.com/victorca25/iNNfer
+     scale2x = 0
+     scalemin = 6
+     n_uplayer = 0
+     plus = False
+
+     for block in list(state_dict):
+         parts = block.split(".")
+         n_parts = len(parts)
+         if n_parts == 5 and parts[2] == "sub":
+             nb = int(parts[3])
+         elif n_parts == 3:
+             part_num = int(parts[1])
+             if part_num > scalemin and parts[0] == "model" and parts[2] == "weight":
+                 scale2x += 1
+             if part_num > n_uplayer:
+                 n_uplayer = part_num
+                 out_nc = state_dict[block].shape[0]
+         if not plus and "conv1x1" in block:
+             plus = True
+
+     nf = state_dict["model.0.weight"].shape[0]
+     in_nc = state_dict["model.0.weight"].shape[1]
+     out_nc = out_nc
+     scale = 2**scale2x
+
+     return in_nc, out_nc, nf, nb, plus, scale
+
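To make the `infer_params` heuristic concrete: in the old `model.N` layout, only the HR conv (`model.8`) and final conv (`model.10`) weights sit above `scalemin = 6` for a 4x network, so `scale2x` effectively counts the upsampling stages. A sketch using a freshly built generator, whose state dict is already in that layout:

```python
net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, upscale=4)
print(infer_params(net.state_dict()))  # (3, 3, 64, 23, False, 4)
```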
+
+ # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L64
+ Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
+
+
+ # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L67
+ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
+     w = image.width
+     h = image.height
+
+     non_overlap_width = tile_w - overlap
+     non_overlap_height = tile_h - overlap
+
+     cols = math.ceil((w - overlap) / non_overlap_width)
+     rows = math.ceil((h - overlap) / non_overlap_height)
+
+     dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
+     dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
+
+     grid = Grid([], tile_w, tile_h, w, h, overlap)
+     for row in range(rows):
+         row_images = []
+
+         y = int(row * dy)
+
+         if y + tile_h >= h:
+             y = h - tile_h
+
+         for col in range(cols):
+             x = int(col * dx)
+
+             if x + tile_w >= w:
+                 x = w - tile_w
+
+             tile = image.crop((x, y, x + tile_w, y + tile_h))
+
+             row_images.append([x, tile_w, tile])
+
+         grid.tiles.append([y, tile_h, row_images])
+
+     return grid
+
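A quick check of the tile geometry with the default 512px tiles and 64px overlap:

```python
from PIL import Image

grid = split_grid(Image.new("RGB", (1280, 960)))
print(len(grid.tiles))        # 2 rows:    ceil((960 - 64) / 448)
print(len(grid.tiles[0][2]))  # 3 columns: ceil((1280 - 64) / 448)
```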
+
+ # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L104
+ def combine_grid(grid):
+     def make_mask_image(r):
+         r = r * 255 / grid.overlap
+         r = r.astype(np.uint8)
+         return Image.fromarray(r, "L")
+
+     mask_w = make_mask_image(
+         np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0)
+     )
+     mask_h = make_mask_image(
+         np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1)
+     )
+
+     combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
+     for y, h, row in grid.tiles:
+         combined_row = Image.new("RGB", (grid.image_w, h))
+         for x, w, tile in row:
+             if x == 0:
+                 combined_row.paste(tile, (0, 0))
+                 continue
+
+             combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
+             combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
+
+         if y == 0:
+             combined_image.paste(combined_row, (0, 0))
+             continue
+
+         combined_image.paste(
+             combined_row.crop((0, 0, combined_row.width, grid.overlap)),
+             (0, y),
+             mask=mask_h,
+         )
+         combined_image.paste(
+             combined_row.crop((0, grid.overlap, combined_row.width, h)),
+             (0, y + grid.overlap),
+         )
+
+     return combined_image
+
+
+ class UpscalerESRGAN:
+     def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
+         self.device = device
+         self.dtype = dtype
+         self.model_path = model_path
+         self.model = self.load_model(model_path)
+
+     def __call__(self, img: Image.Image) -> Image.Image:
+         return self.upscale_without_tiling(img)
+
+     def to(self, device: torch.device, dtype: torch.dtype):
+         self.device = device
+         self.dtype = dtype
+         self.model.to(device=device, dtype=dtype)
+
+     def load_model(self, path: Path) -> SRVGGNetCompact | RRDBNet:
+         filename = path
+         state_dict = torch.load(filename, weights_only=True, map_location=self.device)
+
+         if "params_ema" in state_dict:
+             state_dict = state_dict["params_ema"]
+         elif "params" in state_dict:
+             state_dict = state_dict["params"]
+             # fixed: test the file name, not the Path object, for substring membership
+             num_conv = 16 if "realesr-animevideov3" in filename.name else 32
+             model = SRVGGNetCompact(
+                 num_in_ch=3,
+                 num_out_ch=3,
+                 num_feat=64,
+                 num_conv=num_conv,
+                 upscale=4,
+                 act_type="prelu",
+             )
+             model.load_state_dict(state_dict)
+             model.eval()
+             return model
+
+         if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
+             nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename.name else 23
+             state_dict = resrgan2normal(state_dict, nb)
+         elif "conv_first.weight" in state_dict:
+             state_dict = mod2normal(state_dict)
+         elif "model.0.weight" not in state_dict:
+             raise Exception("The file is not a recognized ESRGAN model.")
+
+         in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
+
+         model = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
+         model.load_state_dict(state_dict)
+         model.eval()
+
+         return model
+
+     def upscale_without_tiling(self, img: Image.Image) -> Image.Image:
+         img = np.array(img)
+         img = img[:, :, ::-1]
+         img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
+         img = torch.from_numpy(img).float()
+         img = img.unsqueeze(0).to(device=self.device, dtype=self.dtype)
+         with torch.no_grad():
+             output = self.model(img)
+         output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
+         output = 255.0 * np.moveaxis(output, 0, 2)
+         output = output.astype(np.uint8)
+         output = output[:, :, ::-1]
+         return Image.fromarray(output, "RGB")
+
+     # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/esrgan_model.py#L208
+     def upscale_with_tiling(self, img: Image.Image) -> Image.Image:
+         grid = split_grid(img)
+         newtiles = []
+         scale_factor = 1
+
+         for y, h, row in grid.tiles:
+             newrow = []
+             for tiledata in row:
+                 x, w, tile = tiledata
+
+                 output = self.upscale_without_tiling(tile)
+                 scale_factor = output.width // tile.width
+
+                 newrow.append([x * scale_factor, w * scale_factor, output])
+             newtiles.append([y * scale_factor, h * scale_factor, newrow])
+
+         newgrid = Grid(
+             newtiles,
+             grid.tile_w * scale_factor,
+             grid.tile_h * scale_factor,
+             grid.image_w * scale_factor,
+             grid.image_h * scale_factor,
+             grid.overlap * scale_factor,
+         )
+         output = combine_grid(newgrid)
+         return output
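Finally, a minimal end-to-end sketch of this module used on its own (the weights path is illustrative; the app obtains 4x-UltraSharp via `hf_hub_download`):

```python
from pathlib import Path

import torch
from PIL import Image

upscaler = UpscalerESRGAN(
    Path("4x-UltraSharp.pth"),
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    dtype=torch.float32,
)
image = Image.open("input.jpg").convert("RGB")
output = upscaler.upscale_with_tiling(image)  # 4x output, blended from 512px tiles
output.save("input_4x.jpg")
```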