kevinwang676 committed on
Commit
fb4fac3
1 Parent(s): 4d4f2d3

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +1 -0
  2. DiffSynth_Studio.py +15 -0
  3. LICENSE +201 -0
  4. README.md +117 -13
  5. diffsynth/__init__.py +6 -0
  6. diffsynth/controlnets/__init__.py +2 -0
  7. diffsynth/controlnets/controlnet_unit.py +53 -0
  8. diffsynth/controlnets/processors.py +51 -0
  9. diffsynth/data/__init__.py +1 -0
  10. diffsynth/data/video.py +148 -0
  11. diffsynth/extensions/ESRGAN/__init__.py +118 -0
  12. diffsynth/extensions/FastBlend/__init__.py +63 -0
  13. diffsynth/extensions/FastBlend/api.py +397 -0
  14. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  15. diffsynth/extensions/FastBlend/data.py +146 -0
  16. diffsynth/extensions/FastBlend/patch_match.py +298 -0
  17. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  18. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  19. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  20. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  21. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  22. diffsynth/extensions/RIFE/__init__.py +241 -0
  23. diffsynth/models/__init__.py +814 -0
  24. diffsynth/models/attention.py +89 -0
  25. diffsynth/models/downloader.py +28 -0
  26. diffsynth/models/hunyuan_dit.py +451 -0
  27. diffsynth/models/hunyuan_dit_text_encoder.py +161 -0
  28. diffsynth/models/kolors_text_encoder.py +1363 -0
  29. diffsynth/models/sd3_dit.py +797 -0
  30. diffsynth/models/sd3_text_encoder.py +0 -0
  31. diffsynth/models/sd3_vae_decoder.py +80 -0
  32. diffsynth/models/sd3_vae_encoder.py +94 -0
  33. diffsynth/models/sd_controlnet.py +587 -0
  34. diffsynth/models/sd_ipadapter.py +56 -0
  35. diffsynth/models/sd_lora.py +60 -0
  36. diffsynth/models/sd_motion.py +198 -0
  37. diffsynth/models/sd_text_encoder.py +320 -0
  38. diffsynth/models/sd_unet.py +0 -0
  39. diffsynth/models/sd_vae_decoder.py +332 -0
  40. diffsynth/models/sd_vae_encoder.py +278 -0
  41. diffsynth/models/sdxl_ipadapter.py +121 -0
  42. diffsynth/models/sdxl_motion.py +103 -0
  43. diffsynth/models/sdxl_text_encoder.py +757 -0
  44. diffsynth/models/sdxl_unet.py +0 -0
  45. diffsynth/models/sdxl_vae_decoder.py +15 -0
  46. diffsynth/models/sdxl_vae_encoder.py +15 -0
  47. diffsynth/models/svd_image_encoder.py +504 -0
  48. diffsynth/models/svd_unet.py +0 -0
  49. diffsynth/models/svd_vae_decoder.py +577 -0
  50. diffsynth/models/svd_vae_encoder.py +138 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt filter=lfs diff=lfs merge=lfs -text
DiffSynth_Studio.py ADDED
@@ -0,0 +1,15 @@
+ # Set web page format
+ import streamlit as st
+ st.set_page_config(layout="wide")
+ # Disable virtual VRAM on Windows systems
+ import torch
+ torch.cuda.set_per_process_memory_fraction(0.999, 0)
+
+
+ st.markdown("""
+ # DiffSynth Studio
+
+ [Source Code](https://github.com/Artiprocher/DiffSynth-Studio)
+
+ Welcome to DiffSynth Studio.
+ """)
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2023] [Zhongjie Duan]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,13 +1,117 @@
- ---
- title: Diffutoon
- emoji: 🏃
- colorFrom: green
- colorTo: yellow
- sdk: gradio
- sdk_version: 4.39.0
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # DiffSynth Studio
+
+
+ ## Introduction
+
+ DiffSynth Studio is a diffusion engine. We have restructured the model architectures, including the Text Encoder, UNet, and VAE, maintaining compatibility with models from the open-source community while improving computational performance. We provide many interesting features. Enjoy the magic of diffusion models!
+
+ So far, DiffSynth Studio supports the following models:
+
+ * [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
+ * [Kolors](https://huggingface.co/Kwai-Kolors/Kolors)
+ * [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
+ * [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
+ * [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
+ * [RIFE](https://github.com/hzwer/ECCV2022-RIFE)
+ * [ESRGAN](https://github.com/xinntao/ESRGAN)
+ * [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter)
+ * [AnimateDiff](https://github.com/guoyww/animatediff/)
+ * [ControlNet](https://github.com/lllyasviel/ControlNet)
+ * [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ * [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
+
+ ## News
+
+ - **June 21, 2024.** 🔥🔥🔥 We propose ExVideo, a post-tuning technique aimed at enhancing the capability of video generation models. We have extended Stable Video Diffusion to generate long videos of up to 128 frames.
+   - [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
+   - The source code is released in this repo. See [`examples/ExVideo`](./examples/ExVideo/).
+   - Models are released on [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
+   - The technical report is released on [arXiv](https://arxiv.org/abs/2406.14130).
+   - You can try ExVideo in this [Demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1)!
+
+ - **June 13, 2024.** DiffSynth Studio has been transferred to ModelScope. The developers have transitioned from "I" to "we". Of course, I will still participate in development and maintenance.
+
+ - **Jan 29, 2024.** We propose Diffutoon, a fantastic solution for toon shading.
+   - [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
+   - The source code is released in this project.
+   - The technical report (IJCAI 2024) is released on [arXiv](https://arxiv.org/abs/2401.16224).
+
+ - **Dec 8, 2023.** We decided to develop a new project aimed at unlocking the potential of diffusion models, especially in video synthesis. Development of this project has started.
+
+ - **Nov 15, 2023.** We propose FastBlend, a powerful video deflickering algorithm.
+   - The sd-webui extension is released on [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
+   - Demo videos are shown on Bilibili, covering three tasks.
+     - [Video deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
+     - [Video interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
+     - [Image-driven video rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
+   - The technical report is released on [arXiv](https://arxiv.org/abs/2311.09265).
+   - An unofficial ComfyUI extension developed by other users is released on [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
+
+ - **Oct 1, 2023.** We released an early version of this project, namely FastSDXL, as a first attempt at building a diffusion engine.
+   - The source code is released on [GitHub](https://github.com/Artiprocher/FastSDXL).
+   - FastSDXL includes a trainable OLSS scheduler for efficiency improvement.
+     - The original repo of OLSS is [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
+     - The technical report (CIKM 2023) is released on [arXiv](https://arxiv.org/abs/2305.14677).
+     - A demo video is shown on [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
+     - Since OLSS requires additional training, we don't implement it in this project.
+
+ - **Aug 29, 2023.** We propose DiffSynth, a video synthesis framework.
+   - [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
+   - The source code is released in [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
+   - The technical report (ECML PKDD 2024) is released on [arXiv](https://arxiv.org/abs/2308.03463).
+
+ ## Installation
+
+ ```
+ git clone https://github.com/modelscope/DiffSynth-Studio.git
+ cd DiffSynth-Studio
+ pip install -e .
+ ```
+
+ ## Usage (in Python code)
+
+ The Python examples are in [`examples`](./examples/). We provide an overview here.
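Editor's note: as a minimal, self-contained illustration, the data utilities added in this commit can be exercised on their own. This is a sketch, not one of the shipped examples; `input.mp4`, the 512x512 target size, and the output path are placeholder assumptions, and the package's dependencies (torch, imageio, Pillow, tqdm, etc.) must be installed.

```python
# Sketch: re-encode a clip with the helpers from diffsynth/data/video.py.
from diffsynth import VideoData, save_video

video = VideoData(video_file="input.mp4", height=512, width=512)  # frames are center-cropped and resized
frames = video.raw_data()                                         # list of PIL.Image frames
save_video(frames, "output.mp4", fps=30)                          # written with imageio
```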
+ ### Long Video Synthesis
+
+ We trained an extended video synthesis model that can generate 128 frames. See [`examples/ExVideo`](./examples/ExVideo/).
+
+ https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
+
+ ### Image Synthesis
+
+ Generate high-resolution images by breaking the resolution limits of diffusion models! See [`examples/image_synthesis`](./examples/image_synthesis/).
+
+ LoRA fine-tuning is supported in [`examples/train`](./examples/train/).
+
+ |Model|Example|
+ |-|-|
+ |Stable Diffusion|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|
+ |Stable Diffusion XL|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|
+ |Stable Diffusion 3|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)|
+ |Kolors|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/53ef6f41-da11-4701-8665-9f64392607bf)|
+ |Hunyuan-DiT|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/60b022c8-df3f-4541-95ab-bf39f2fa8bb5)|
+
+ ### Toon Shading
+
+ Render realistic videos in a flat, toon-shaded style, with video editing features enabled. See [`examples/Diffutoon`](./examples/Diffutoon/).
+
+ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
+
+ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/20528af5-5100-474a-8cdc-440b9efdd86c
+
+ ### Video Stylization
+
+ Video stylization without video models. See [`examples/diffsynth`](./examples/diffsynth/).
+
+ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
+
+ ## Usage (in WebUI)
+
+ ```
+ python -m streamlit run DiffSynth_Studio.py
+ ```
+
+ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/93085557-73f3-4eee-a205-9829591ef954
diffsynth/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .data import *
+ from .models import *
+ from .prompts import *
+ from .schedulers import *
+ from .pipelines import *
+ from .controlnets import *
diffsynth/controlnets/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
+ from .processors import Annotator
diffsynth/controlnets/controlnet_unit.py ADDED
@@ -0,0 +1,53 @@
1
+ import torch
2
+ import numpy as np
3
+ from .processors import Processor_id
4
+
5
+
6
+ class ControlNetConfigUnit:
7
+ def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
8
+ self.processor_id = processor_id
9
+ self.model_path = model_path
10
+ self.scale = scale
11
+
12
+
13
+ class ControlNetUnit:
14
+ def __init__(self, processor, model, scale=1.0):
15
+ self.processor = processor
16
+ self.model = model
17
+ self.scale = scale
18
+
19
+
20
+ class MultiControlNetManager:
21
+ def __init__(self, controlnet_units=[]):
22
+ self.processors = [unit.processor for unit in controlnet_units]
23
+ self.models = [unit.model for unit in controlnet_units]
24
+ self.scales = [unit.scale for unit in controlnet_units]
25
+
26
+ def process_image(self, image, processor_id=None):
27
+ if processor_id is None:
28
+ processed_image = [processor(image) for processor in self.processors]
29
+ else:
30
+ processed_image = [self.processors[processor_id](image)]
31
+ processed_image = torch.concat([
32
+ torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
33
+ for image_ in processed_image
34
+ ], dim=0)
35
+ return processed_image
36
+
37
+ def __call__(
38
+ self,
39
+ sample, timestep, encoder_hidden_states, conditionings,
40
+ tiled=False, tile_size=64, tile_stride=32
41
+ ):
42
+ res_stack = None
43
+ for conditioning, model, scale in zip(conditionings, self.models, self.scales):
44
+ res_stack_ = model(
45
+ sample, timestep, encoder_hidden_states, conditioning,
46
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
47
+ )
48
+ res_stack_ = [res * scale for res in res_stack_]
49
+ if res_stack is None:
50
+ res_stack = res_stack_
51
+ else:
52
+ res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
53
+ return res_stack
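Editor's note (usage sketch, not code from this commit): `MultiControlNetManager` aggregates several `ControlNetUnit`s, preprocesses the conditioning image with each unit's processor, and sums the scaled residual stacks returned by each ControlNet model. The loaded ControlNet modules and the input image below are placeholders; in practice they are supplied by the pipeline's model manager.

```python
from PIL import Image
from diffsynth.controlnets import Annotator, ControlNetUnit, MultiControlNetManager

# Placeholders: canny_controlnet / depth_controlnet stand for already-loaded
# ControlNet modules; "frame.png" stands for the conditioning image.
manager = MultiControlNetManager([
    ControlNetUnit(Annotator("canny"), canny_controlnet, scale=1.0),
    ControlNetUnit(Annotator("depth"), depth_controlnet, scale=0.5),
])

conditioning = manager.process_image(Image.open("frame.png").convert("RGB"))
# conditioning has shape (num_units, 3, H, W) with values in [0, 1]; during denoising
# the manager is called as manager(latents, timestep, text_embeddings, conditioning)
# and returns the summed, scaled residual stack consumed by the UNet.
```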
diffsynth/controlnets/processors.py ADDED
@@ -0,0 +1,51 @@
1
+ from typing_extensions import Literal, TypeAlias
2
+ import warnings
3
+ with warnings.catch_warnings():
4
+ warnings.simplefilter("ignore")
5
+ from controlnet_aux.processor import (
6
+ CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
7
+ )
8
+
9
+
10
+ Processor_id: TypeAlias = Literal[
11
+ "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
12
+ ]
13
+
14
+ class Annotator:
15
+ def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'):
16
+ if processor_id == "canny":
17
+ self.processor = CannyDetector()
18
+ elif processor_id == "depth":
19
+ self.processor = MidasDetector.from_pretrained(model_path).to(device)
20
+ elif processor_id == "softedge":
21
+ self.processor = HEDdetector.from_pretrained(model_path).to(device)
22
+ elif processor_id == "lineart":
23
+ self.processor = LineartDetector.from_pretrained(model_path).to(device)
24
+ elif processor_id == "lineart_anime":
25
+ self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
26
+ elif processor_id == "openpose":
27
+ self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
28
+ elif processor_id == "tile":
29
+ self.processor = None
30
+ else:
31
+ raise ValueError(f"Unsupported processor_id: {processor_id}")
32
+
33
+ self.processor_id = processor_id
34
+ self.detect_resolution = detect_resolution
35
+
36
+ def __call__(self, image):
37
+ width, height = image.size
38
+ if self.processor_id == "openpose":
39
+ kwargs = {
40
+ "include_body": True,
41
+ "include_hand": True,
42
+ "include_face": True
43
+ }
44
+ else:
45
+ kwargs = {}
46
+ if self.processor is not None:
47
+ detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
48
+ image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
49
+ image = image.resize((width, height))
50
+ return image
51
+
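Editor's note (usage sketch): `Annotator` hides the individual `controlnet_aux` detectors behind one interface and resizes the result back to the input resolution. The canny detector needs no pretrained weights; the others load from `model_path` (default `models/Annotators`). The input file below is a placeholder.

```python
from PIL import Image
from diffsynth.controlnets import Annotator

annotator = Annotator("canny")                   # stateless edge detector, no weights needed
image = Image.open("frame.png").convert("RGB")   # placeholder input image
control_image = annotator(image)                 # PIL image, same size as the input
```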
diffsynth/data/__init__.py ADDED
@@ -0,0 +1 @@
+ from .video import VideoData, save_video, save_frames
diffsynth/data/video.py ADDED
@@ -0,0 +1,148 @@
1
+ import imageio, os
2
+ import numpy as np
3
+ from PIL import Image
4
+ from tqdm import tqdm
5
+
6
+
7
+ class LowMemoryVideo:
8
+ def __init__(self, file_name):
9
+ self.reader = imageio.get_reader(file_name)
10
+
11
+ def __len__(self):
12
+ return self.reader.count_frames()
13
+
14
+ def __getitem__(self, item):
15
+ return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
16
+
17
+ def __del__(self):
18
+ self.reader.close()
19
+
20
+
21
+ def split_file_name(file_name):
22
+ result = []
23
+ number = -1
24
+ for i in file_name:
25
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
26
+ if number == -1:
27
+ number = 0
28
+ number = number*10 + ord(i) - ord("0")
29
+ else:
30
+ if number != -1:
31
+ result.append(number)
32
+ number = -1
33
+ result.append(i)
34
+ if number != -1:
35
+ result.append(number)
36
+ result = tuple(result)
37
+ return result
38
+
39
+
40
+ def search_for_images(folder):
41
+ file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
42
+ file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
43
+ file_list = [i[1] for i in sorted(file_list)]
44
+ file_list = [os.path.join(folder, i) for i in file_list]
45
+ return file_list
46
+
47
+
48
+ class LowMemoryImageFolder:
49
+ def __init__(self, folder, file_list=None):
50
+ if file_list is None:
51
+ self.file_list = search_for_images(folder)
52
+ else:
53
+ self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
54
+
55
+ def __len__(self):
56
+ return len(self.file_list)
57
+
58
+ def __getitem__(self, item):
59
+ return Image.open(self.file_list[item]).convert("RGB")
60
+
61
+ def __del__(self):
62
+ pass
63
+
64
+
65
+ def crop_and_resize(image, height, width):
66
+ image = np.array(image)
67
+ image_height, image_width, _ = image.shape
68
+ if image_height / image_width < height / width:
69
+ croped_width = int(image_height / height * width)
70
+ left = (image_width - croped_width) // 2
71
+ image = image[:, left: left+croped_width]
72
+ image = Image.fromarray(image).resize((width, height))
73
+ else:
74
+ croped_height = int(image_width / width * height)
75
+ left = (image_height - croped_height) // 2
76
+ image = image[left: left+croped_height, :]
77
+ image = Image.fromarray(image).resize((width, height))
78
+ return image
79
+
80
+
81
+ class VideoData:
82
+ def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
83
+ if video_file is not None:
84
+ self.data_type = "video"
85
+ self.data = LowMemoryVideo(video_file, **kwargs)
86
+ elif image_folder is not None:
87
+ self.data_type = "images"
88
+ self.data = LowMemoryImageFolder(image_folder, **kwargs)
89
+ else:
90
+ raise ValueError("Cannot open video or image folder")
91
+ self.length = None
92
+ self.set_shape(height, width)
93
+
94
+ def raw_data(self):
95
+ frames = []
96
+ for i in range(self.__len__()):
97
+ frames.append(self.__getitem__(i))
98
+ return frames
99
+
100
+ def set_length(self, length):
101
+ self.length = length
102
+
103
+ def set_shape(self, height, width):
104
+ self.height = height
105
+ self.width = width
106
+
107
+ def __len__(self):
108
+ if self.length is None:
109
+ return len(self.data)
110
+ else:
111
+ return self.length
112
+
113
+ def shape(self):
114
+ if self.height is not None and self.width is not None:
115
+ return self.height, self.width
116
+ else:
117
+ width, height = self.__getitem__(0).size  # PIL images expose .size (width, height), not .shape
118
+ return height, width
119
+
120
+ def __getitem__(self, item):
121
+ frame = self.data.__getitem__(item)
122
+ width, height = frame.size
123
+ if self.height is not None and self.width is not None:
124
+ if self.height != height or self.width != width:
125
+ frame = crop_and_resize(frame, self.height, self.width)
126
+ return frame
127
+
128
+ def __del__(self):
129
+ pass
130
+
131
+ def save_images(self, folder):
132
+ os.makedirs(folder, exist_ok=True)
133
+ for i in tqdm(range(self.__len__()), desc="Saving images"):
134
+ frame = self.__getitem__(i)
135
+ frame.save(os.path.join(folder, f"{i}.png"))
136
+
137
+
138
+ def save_video(frames, save_path, fps, quality=9):
139
+ writer = imageio.get_writer(save_path, fps=fps, quality=quality)
140
+ for frame in tqdm(frames, desc="Saving video"):
141
+ frame = np.array(frame)
142
+ writer.append_data(frame)
143
+ writer.close()
144
+
145
+ def save_frames(frames, save_path):
146
+ os.makedirs(save_path, exist_ok=True)
147
+ for i, frame in enumerate(tqdm(frames, desc="Saving images")):
148
+ frame.save(os.path.join(save_path, f"{i}.png"))
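Editor's note (sketch with placeholder paths): `VideoData` also works on a folder of numbered frames; `search_for_images` sorts the files by the numbers embedded in their names, and `save_frames` writes PIL frames back to disk as `0.png`, `1.png`, and so on.

```python
from diffsynth import VideoData, save_frames

frames = VideoData(image_folder="rendered_frames", height=512, width=512)  # placeholder folder of .png/.jpg frames
print(len(frames), frames.shape())                # frame count and (height, width)
save_frames(frames.raw_data(), "resized_frames")  # writes 0.png, 1.png, ...
```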
diffsynth/extensions/ESRGAN/__init__.py ADDED
@@ -0,0 +1,118 @@
1
+ import torch
2
+ from einops import repeat
3
+ from PIL import Image
4
+ import numpy as np
5
+
6
+
7
+ class ResidualDenseBlock(torch.nn.Module):
8
+
9
+ def __init__(self, num_feat=64, num_grow_ch=32):
10
+ super(ResidualDenseBlock, self).__init__()
11
+ self.conv1 = torch.nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
12
+ self.conv2 = torch.nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
13
+ self.conv3 = torch.nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
14
+ self.conv4 = torch.nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
15
+ self.conv5 = torch.nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
16
+ self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
17
+
18
+ def forward(self, x):
19
+ x1 = self.lrelu(self.conv1(x))
20
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
21
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
22
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
23
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
24
+ return x5 * 0.2 + x
25
+
26
+
27
+ class RRDB(torch.nn.Module):
28
+
29
+ def __init__(self, num_feat, num_grow_ch=32):
30
+ super(RRDB, self).__init__()
31
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
32
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
33
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
34
+
35
+ def forward(self, x):
36
+ out = self.rdb1(x)
37
+ out = self.rdb2(out)
38
+ out = self.rdb3(out)
39
+ return out * 0.2 + x
40
+
41
+
42
+ class RRDBNet(torch.nn.Module):
43
+
44
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32):
45
+ super(RRDBNet, self).__init__()
46
+ self.conv_first = torch.nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
47
+ self.body = torch.nn.Sequential(*[RRDB(num_feat=num_feat, num_grow_ch=num_grow_ch) for _ in range(num_block)])
48
+ self.conv_body = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
49
+ # upsample
50
+ self.conv_up1 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
51
+ self.conv_up2 = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
52
+ self.conv_hr = torch.nn.Conv2d(num_feat, num_feat, 3, 1, 1)
53
+ self.conv_last = torch.nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
54
+ self.lrelu = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True)
55
+
56
+ def forward(self, x):
57
+ feat = x
58
+ feat = self.conv_first(feat)
59
+ body_feat = self.conv_body(self.body(feat))
60
+ feat = feat + body_feat
61
+ # upsample
62
+ feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
63
+ feat = self.lrelu(self.conv_up1(feat))
64
+ feat = repeat(feat, "B C H W -> B C (H 2) (W 2)")
65
+ feat = self.lrelu(self.conv_up2(feat))
66
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
67
+ return out
68
+
69
+
70
+ class ESRGAN(torch.nn.Module):
71
+ def __init__(self, model):
72
+ super().__init__()
73
+ self.model = model
74
+
75
+ @staticmethod
76
+ def from_pretrained(model_path):
77
+ model = RRDBNet()
78
+ state_dict = torch.load(model_path, map_location="cpu")["params_ema"]
79
+ model.load_state_dict(state_dict)
80
+ model.eval()
81
+ return ESRGAN(model)
82
+
83
+ def process_image(self, image):
84
+ image = torch.Tensor(np.array(image, dtype=np.float32) / 255).permute(2, 0, 1)
85
+ return image
86
+
87
+ def process_images(self, images):
88
+ images = [self.process_image(image) for image in images]
89
+ images = torch.stack(images)
90
+ return images
91
+
92
+ def decode_images(self, images):
93
+ images = (images.permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
94
+ images = [Image.fromarray(image) for image in images]
95
+ return images
96
+
97
+ @torch.no_grad()
98
+ def upscale(self, images, batch_size=4, progress_bar=lambda x:x):
99
+ # Preprocess
100
+ input_tensor = self.process_images(images)
101
+
102
+ # Interpolate
103
+ output_tensor = []
104
+ for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
105
+ batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
106
+ batch_input_tensor = input_tensor[batch_id: batch_id_]
107
+ batch_input_tensor = batch_input_tensor.to(
108
+ device=self.model.conv_first.weight.device,
109
+ dtype=self.model.conv_first.weight.dtype)
110
+ batch_output_tensor = self.model(batch_input_tensor)
111
+ output_tensor.append(batch_output_tensor.cpu())
112
+
113
+ # Output
114
+ output_tensor = torch.concat(output_tensor, dim=0)
115
+
116
+ # To images
117
+ output_images = self.decode_images(output_tensor)
118
+ return output_images
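Editor's note (usage sketch): `ESRGAN.from_pretrained` expects an RRDBNet checkpoint whose weights are stored under the `params_ema` key, and `upscale` enlarges a list of PIL images by a factor of 4. The checkpoint and image paths below are placeholders; move the model to a GPU first if you want CUDA inference.

```python
from PIL import Image
from diffsynth.extensions.ESRGAN import ESRGAN

esrgan = ESRGAN.from_pretrained("models/ESRGAN/RRDBNet_x4.pth")  # placeholder checkpoint path
low_res = [Image.open("frame.png").convert("RGB")]               # placeholder input frame(s)
high_res = esrgan.upscale(low_res, batch_size=4)                 # list of PIL images, 4x larger
```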
diffsynth/extensions/FastBlend/__init__.py ADDED
@@ -0,0 +1,63 @@
1
+ from .runners.fast import TableManager, PyramidPatchMatcher
2
+ from PIL import Image
3
+ import numpy as np
4
+ import cupy as cp
5
+
6
+
7
+ class FastBlendSmoother:
8
+ def __init__(self):
9
+ self.batch_size = 8
10
+ self.window_size = 64
11
+ self.ebsynth_config = {
12
+ "minimum_patch_size": 5,
13
+ "threads_per_block": 8,
14
+ "num_iter": 5,
15
+ "gpu_id": 0,
16
+ "guide_weight": 10.0,
17
+ "initialize": "identity",
18
+ "tracking_window_size": 0,
19
+ }
20
+
21
+ @staticmethod
22
+ def from_model_manager(model_manager):
23
+ # TODO: fetch GPU ID from model_manager
24
+ return FastBlendSmoother()
25
+
26
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
27
+ frames_guide = [np.array(frame) for frame in frames_guide]
28
+ frames_style = [np.array(frame) for frame in frames_style]
29
+ table_manager = TableManager()
30
+ patch_match_engine = PyramidPatchMatcher(
31
+ image_height=frames_style[0].shape[0],
32
+ image_width=frames_style[0].shape[1],
33
+ channel=3,
34
+ **ebsynth_config
35
+ )
36
+ # left part
37
+ table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
38
+ table_l = table_manager.remapping_table_to_blending_table(table_l)
39
+ table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
40
+ # right part
41
+ table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
42
+ table_r = table_manager.remapping_table_to_blending_table(table_r)
43
+ table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
44
+ # merge
45
+ frames = []
46
+ for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
47
+ weight_m = -1
48
+ weight = weight_l + weight_m + weight_r
49
+ frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
50
+ frames.append(frame)
51
+ frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
52
+ return frames
53
+
54
+ def __call__(self, rendered_frames, original_frames=None, **kwargs):
55
+ frames = self.run(
56
+ original_frames, rendered_frames,
57
+ self.batch_size, self.window_size, self.ebsynth_config
58
+ )
59
+ mempool = cp.get_default_memory_pool()
60
+ pinned_mempool = cp.get_default_pinned_memory_pool()
61
+ mempool.free_all_blocks()
62
+ pinned_mempool.free_all_blocks()
63
+ return frames
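Editor's note (usage sketch): `FastBlendSmoother` deflickers a stylized frame sequence by blending it according to the motion of the original guide frames; it needs an NVIDIA GPU because the patch matcher runs CuPy kernels. The two video paths below are placeholders, and the clips are assumed to have matching length and resolution.

```python
from diffsynth import VideoData
from diffsynth.extensions.FastBlend import FastBlendSmoother

guide_frames = VideoData(video_file="original.mp4").raw_data()   # placeholder guide video
styled_frames = VideoData(video_file="stylized.mp4").raw_data()  # placeholder stylized video

smoother = FastBlendSmoother()                                   # batch_size=8, window_size=64 by default
smooth_frames = smoother(styled_frames, original_frames=guide_frames)
```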
diffsynth/extensions/FastBlend/api.py ADDED
@@ -0,0 +1,397 @@
1
+ from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
2
+ from .data import VideoData, get_video_fps, save_video, search_for_images
3
+ import os
4
+ import gradio as gr
5
+
6
+
7
+ def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
8
+ frames_guide = VideoData(video_guide, video_guide_folder)
9
+ frames_style = VideoData(video_style, video_style_folder)
10
+ message = ""
11
+ if len(frames_guide) < len(frames_style):
12
+ message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n"
13
+ frames_style.set_length(len(frames_guide))
14
+ elif len(frames_guide) > len(frames_style):
15
+ message += f"The number of frames mismatches. Only the first {len(frames_style)} frames of guide video will be used.\n"
16
+ frames_guide.set_length(len(frames_style))
17
+ height_guide, width_guide = frames_guide.shape()
18
+ height_style, width_style = frames_style.shape()
19
+ if height_guide != height_style or width_guide != width_style:
20
+ message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n"
21
+ frames_style.set_shape(height_guide, width_guide)
22
+ return frames_guide, frames_style, message
23
+
24
+
25
+ def smooth_video(
26
+ video_guide,
27
+ video_guide_folder,
28
+ video_style,
29
+ video_style_folder,
30
+ mode,
31
+ window_size,
32
+ batch_size,
33
+ tracking_window_size,
34
+ output_path,
35
+ fps,
36
+ minimum_patch_size,
37
+ num_iter,
38
+ guide_weight,
39
+ initialize,
40
+ progress = None,
41
+ ):
42
+ # input
43
+ frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
44
+ if len(message) > 0:
45
+ print(message)
46
+ # output
47
+ if output_path == "":
48
+ if video_style is None:
49
+ output_path = os.path.join(video_style_folder, "output")
50
+ else:
51
+ output_path = os.path.join(os.path.split(video_style)[0], "output")
52
+ os.makedirs(output_path, exist_ok=True)
53
+ print("No valid output_path. Your video will be saved here:", output_path)
54
+ elif not os.path.exists(output_path):
55
+ os.makedirs(output_path, exist_ok=True)
56
+ print("Your video will be saved here:", output_path)
57
+ frames_path = os.path.join(output_path, "frames")
58
+ video_path = os.path.join(output_path, "video.mp4")
59
+ os.makedirs(frames_path, exist_ok=True)
60
+ # process
61
+ if mode == "Fast" or mode == "Balanced":
62
+ tracking_window_size = 0
63
+ ebsynth_config = {
64
+ "minimum_patch_size": minimum_patch_size,
65
+ "threads_per_block": 8,
66
+ "num_iter": num_iter,
67
+ "gpu_id": 0,
68
+ "guide_weight": guide_weight,
69
+ "initialize": initialize,
70
+ "tracking_window_size": tracking_window_size,
71
+ }
72
+ if mode == "Fast":
73
+ FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
74
+ elif mode == "Balanced":
75
+ BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
76
+ elif mode == "Accurate":
77
+ AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
78
+ # output
79
+ try:
80
+ fps = int(fps)
81
+ except:
82
+ fps = get_video_fps(video_style) if video_style is not None else 30
83
+ print("Fps:", fps)
84
+ print("Saving video...")
85
+ video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
86
+ print("Success!")
87
+ print("Your frames are here:", frames_path)
88
+ print("Your video is here:", video_path)
89
+ return output_path, fps, video_path
90
+
91
+
92
+ class KeyFrameMatcher:
93
+ def __init__(self):
94
+ pass
95
+
96
+ def extract_number_from_filename(self, file_name):
97
+ result = []
98
+ number = -1
99
+ for i in file_name:
100
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
101
+ if number == -1:
102
+ number = 0
103
+ number = number*10 + ord(i) - ord("0")
104
+ else:
105
+ if number != -1:
106
+ result.append(number)
107
+ number = -1
108
+ if number != -1:
109
+ result.append(number)
110
+ result = tuple(result)
111
+ return result
112
+
113
+ def extract_number_from_filenames(self, file_names):
114
+ numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
115
+ min_length = min(len(i) for i in numbers)
116
+ for i in range(min_length-1, -1, -1):
117
+ if len(set(number[i] for number in numbers))==len(file_names):
118
+ return [number[i] for number in numbers]
119
+ return list(range(len(file_names)))
120
+
121
+ def match_using_filename(self, file_names_a, file_names_b):
122
+ file_names_b_set = set(file_names_b)
123
+ matched_file_name = []
124
+ for file_name in file_names_a:
125
+ if file_name not in file_names_b_set:
126
+ matched_file_name.append(None)
127
+ else:
128
+ matched_file_name.append(file_name)
129
+ return matched_file_name
130
+
131
+ def match_using_numbers(self, file_names_a, file_names_b):
132
+ numbers_a = self.extract_number_from_filenames(file_names_a)
133
+ numbers_b = self.extract_number_from_filenames(file_names_b)
134
+ numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
135
+ matched_file_name = []
136
+ for number in numbers_a:
137
+ if number in numbers_b_dict:
138
+ matched_file_name.append(numbers_b_dict[number])
139
+ else:
140
+ matched_file_name.append(None)
141
+ return matched_file_name
142
+
143
+ def match_filenames(self, file_names_a, file_names_b):
144
+ matched_file_name = self.match_using_filename(file_names_a, file_names_b)
145
+ if sum([i is not None for i in matched_file_name]) > 0:
146
+ return matched_file_name
147
+ matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
148
+ return matched_file_name
149
+
150
+
151
+ def detect_frames(frames_path, keyframes_path):
152
+ if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
153
+ return "Please input the directory of guide video and rendered frames"
154
+ elif not os.path.exists(frames_path):
155
+ return "Please input the directory of guide video"
156
+ elif not os.path.exists(keyframes_path):
157
+ return "Please input the directory of rendered frames"
158
+ frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
159
+ keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
160
+ if len(frames)==0:
161
+ return f"No images detected in {frames_path}"
162
+ if len(keyframes)==0:
163
+ return f"No images detected in {keyframes_path}"
164
+ matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
165
+ max_filename_length = max([len(i) for i in frames])
166
+ if sum([i is not None for i in matched_keyframes])==0:
167
+ message = ""
168
+ for frame, matched_keyframe in zip(frames, matched_keyframes):
169
+ message += frame + " " * (max_filename_length - len(frame) + 1)
170
+ message += "--> No matched keyframes\n"
171
+ else:
172
+ message = ""
173
+ for frame, matched_keyframe in zip(frames, matched_keyframes):
174
+ message += frame + " " * (max_filename_length - len(frame) + 1)
175
+ if matched_keyframe is None:
176
+ message += "--> [to be rendered]\n"
177
+ else:
178
+ message += f"--> {matched_keyframe}\n"
179
+ return message
180
+
181
+
182
+ def check_input_for_interpolating(frames_path, keyframes_path):
183
+ # search for images
184
+ frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
185
+ keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
186
+ # match frames
187
+ matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
188
+ file_list = [file_name for file_name in matched_keyframes if file_name is not None]
189
+ index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
190
+ frames_guide = VideoData(None, frames_path)
191
+ frames_style = VideoData(None, keyframes_path, file_list=file_list)
192
+ # match shape
193
+ message = ""
194
+ height_guide, width_guide = frames_guide.shape()
195
+ height_style, width_style = frames_style.shape()
196
+ if height_guide != height_style or width_guide != width_style:
197
+ message += f"The shape of frames mismatches. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n"
198
+ frames_style.set_shape(height_guide, width_guide)
199
+ return frames_guide, frames_style, index_style, message
200
+
201
+
202
+ def interpolate_video(
203
+ frames_path,
204
+ keyframes_path,
205
+ output_path,
206
+ fps,
207
+ batch_size,
208
+ tracking_window_size,
209
+ minimum_patch_size,
210
+ num_iter,
211
+ guide_weight,
212
+ initialize,
213
+ progress = None,
214
+ ):
215
+ # input
216
+ frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
217
+ if len(message) > 0:
218
+ print(message)
219
+ # output
220
+ if output_path == "":
221
+ output_path = os.path.join(keyframes_path, "output")
222
+ os.makedirs(output_path, exist_ok=True)
223
+ print("No valid output_path. Your video will be saved here:", output_path)
224
+ elif not os.path.exists(output_path):
225
+ os.makedirs(output_path, exist_ok=True)
226
+ print("Your video will be saved here:", output_path)
227
+ output_frames_path = os.path.join(output_path, "frames")
228
+ output_video_path = os.path.join(output_path, "video.mp4")
229
+ os.makedirs(output_frames_path, exist_ok=True)
230
+ # process
231
+ ebsynth_config = {
232
+ "minimum_patch_size": minimum_patch_size,
233
+ "threads_per_block": 8,
234
+ "num_iter": num_iter,
235
+ "gpu_id": 0,
236
+ "guide_weight": guide_weight,
237
+ "initialize": initialize,
238
+ "tracking_window_size": tracking_window_size
239
+ }
240
+ if len(index_style)==1:
241
+ InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
242
+ else:
243
+ InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
244
+ try:
245
+ fps = int(fps)
246
+ except:
247
+ fps = 30
248
+ print("Fps:", fps)
249
+ print("Saving video...")
250
+ video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
251
+ print("Success!")
252
+ print("Your frames are here:", output_frames_path)
253
+ print("Your video is here:", video_path)
254
+ return output_path, fps, video_path
255
+
256
+
257
+ def on_ui_tabs():
258
+ with gr.Blocks(analytics_enabled=False) as ui_component:
259
+ with gr.Tab("Blend"):
260
+ gr.Markdown("""
261
+ # Blend
262
+
263
+ Given a guide video and a style video, this algorithm will make the style video fluent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see the example. Note that this extension doesn't support long videos. Please use short videos (e.g., several seconds). The algorithm is mainly designed for 512*512 resolution. Please use a larger `Minimum patch size` for higher resolution.
264
+ """)
265
+ with gr.Row():
266
+ with gr.Column():
267
+ with gr.Tab("Guide video"):
268
+ video_guide = gr.Video(label="Guide video")
269
+ with gr.Tab("Guide video (images format)"):
270
+ video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
271
+ with gr.Column():
272
+ with gr.Tab("Style video"):
273
+ video_style = gr.Video(label="Style video")
274
+ with gr.Tab("Style video (images format)"):
275
+ video_style_folder = gr.Textbox(label="Style video (images format)", value="")
276
+ with gr.Column():
277
+ output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
278
+ fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
279
+ video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
280
+ btn = gr.Button(value="Blend")
281
+ with gr.Row():
282
+ with gr.Column():
283
+ gr.Markdown("# Settings")
284
+ mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
285
+ window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
286
+ batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
287
+ tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
288
+ gr.Markdown("## Advanced Settings")
289
+ minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
290
+ num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
291
+ guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
292
+ initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
293
+ with gr.Column():
294
+ gr.Markdown("""
295
+ # Reference
296
+
297
+ * Output directory: the directory to save the video.
298
+ * Inference mode
299
+
300
+ |Mode|Time|Memory|Quality|Frame by frame output|Description|
301
+ |-|-|-|-|-|-|
302
+ |Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires much RAM but is fast.|
303
+ |Balanced|■■|■|■■|Yes|Blend the frames naively.|
304
+ |Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. When [batch size] >= [sliding window size] * 2 + 1, the performance is the best.|
305
+
306
+ * Sliding window size: the algorithm blends the frames within a sliding window. If the size is n, each frame is blended with the previous n frames and the next n frames. A large sliding window makes the video more fluent but can smear details.
307
+ * Batch size: a larger batch size makes the program faster but requires more VRAM.
308
+ * Tracking window size (only for accurate mode): the size of the window in which the algorithm tracks moving objects. Empirically, 1 is enough.
309
+ * Advanced settings
310
+ * Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
311
+ * Number of iterations: the number of iterations of patch matching. (Default: 5)
312
+ * Guide weight: a parameter that determines how strongly the motion features are applied to the style video. (Default: 10)
313
+ * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
314
+ """)
315
+ btn.click(
316
+ smooth_video,
317
+ inputs=[
318
+ video_guide,
319
+ video_guide_folder,
320
+ video_style,
321
+ video_style_folder,
322
+ mode,
323
+ window_size,
324
+ batch_size,
325
+ tracking_window_size,
326
+ output_path,
327
+ fps,
328
+ minimum_patch_size,
329
+ num_iter,
330
+ guide_weight,
331
+ initialize
332
+ ],
333
+ outputs=[output_path, fps, video_output]
334
+ )
335
+ with gr.Tab("Interpolate"):
336
+ gr.Markdown("""
337
+ # Interpolate
338
+
339
+ Given a guide video and some rendered keyframes, this algorithm will render the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see the example. The algorithm is experimental and is only tested for 512*512 resolution.
340
+ """)
341
+ with gr.Row():
342
+ with gr.Column():
343
+ with gr.Row():
344
+ with gr.Column():
345
+ video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
346
+ with gr.Column():
347
+ rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
348
+ with gr.Row():
349
+ detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
350
+ video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
351
+ rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
352
+ with gr.Column():
353
+ output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
354
+ fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
355
+ video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
356
+ btn_ = gr.Button(value="Interpolate")
357
+ with gr.Row():
358
+ with gr.Column():
359
+ gr.Markdown("# Settings")
360
+ batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
361
+ tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
362
+ gr.Markdown("## Advanced Settings")
363
+ minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
364
+ num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
365
+ guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
366
+ initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
367
+ with gr.Column():
368
+ gr.Markdown("""
369
+ # Reference
370
+
371
+ * Output directory: the directory to save the video.
372
+ * Batch size: a larger batch size makes the program faster but requires more VRAM.
373
+ * Tracking window size: the size of the window in which our algorithm tracks moving objects. Empirically, 1 is enough.
374
+ * Advanced settings
375
+ * Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than the one used for blending. (Default: 15)**
376
+ * Number of iterations: the number of iterations of patch matching. (Default: 5)
377
+ * Guide weight: a parameter that determines how much of the motion feature is applied to the style video. (Default: 10)
378
+ * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
379
+ """)
380
+ btn_.click(
381
+ interpolate_video,
382
+ inputs=[
383
+ video_guide_folder_,
384
+ rendered_keyframes_,
385
+ output_path_,
386
+ fps_,
387
+ batch_size_,
388
+ tracking_window_size_,
389
+ minimum_patch_size_,
390
+ num_iter_,
391
+ guide_weight_,
392
+ initialize_,
393
+ ],
394
+ outputs=[output_path_, fps_, video_output_]
395
+ )
396
+
397
+ return [(ui_component, "FastBlend", "FastBlend_ui")]
diffsynth/extensions/FastBlend/cupy_kernels.py ADDED
@@ -0,0 +1,119 @@
1
+ import cupy as cp
2
+
3
+ remapping_kernel = cp.RawKernel(r'''
4
+ extern "C" __global__
5
+ void remap(
6
+ const int height,
7
+ const int width,
8
+ const int channel,
9
+ const int patch_size,
10
+ const int pad_size,
11
+ const float* source_style,
12
+ const int* nnf,
13
+ float* target_style
14
+ ) {
15
+ const int r = (patch_size - 1) / 2;
16
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
17
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
18
+ if (x >= height or y >= width) return;
19
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
20
+ const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
21
+ const int min_px = x < r ? -x : -r;
22
+ const int max_px = x + r > height - 1 ? height - 1 - x : r;
23
+ const int min_py = y < r ? -y : -r;
24
+ const int max_py = y + r > width - 1 ? width - 1 - y : r;
25
+ int num = 0;
26
+ for (int px = min_px; px <= max_px; px++){
27
+ for (int py = min_py; py <= max_py; py++){
28
+ const int nid = (x + px) * width + y + py;
29
+ const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
30
+ const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
31
+ if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue;
32
+ const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
33
+ num++;
34
+ for (int c = 0; c < channel; c++){
35
+ target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
36
+ }
37
+ }
38
+ }
39
+ for (int c = 0; c < channel; c++){
40
+ target_style[z + pid * channel + c] /= num;
41
+ }
42
+ }
43
+ ''', 'remap')
44
+
45
+
46
+ patch_error_kernel = cp.RawKernel(r'''
47
+ extern "C" __global__
48
+ void patch_error(
49
+ const int height,
50
+ const int width,
51
+ const int channel,
52
+ const int patch_size,
53
+ const int pad_size,
54
+ const float* source,
55
+ const int* nnf,
56
+ const float* target,
57
+ float* error
58
+ ) {
59
+ const int r = (patch_size - 1) / 2;
60
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
61
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
62
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
63
+ if (x >= height or y >= width) return;
64
+ const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
65
+ const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
66
+ float e = 0;
67
+ for (int px = -r; px <= r; px++){
68
+ for (int py = -r; py <= r; py++){
69
+ const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
70
+ const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
71
+ for (int c = 0; c < channel; c++){
72
+ const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
73
+ e += diff * diff;
74
+ }
75
+ }
76
+ }
77
+ error[blockIdx.z * height * width + x * width + y] = e;
78
+ }
79
+ ''', 'patch_error')
80
+
81
+
82
+ pairwise_patch_error_kernel = cp.RawKernel(r'''
83
+ extern "C" __global__
84
+ void pairwise_patch_error(
85
+ const int height,
86
+ const int width,
87
+ const int channel,
88
+ const int patch_size,
89
+ const int pad_size,
90
+ const float* source_a,
91
+ const int* nnf_a,
92
+ const float* source_b,
93
+ const int* nnf_b,
94
+ float* error
95
+ ) {
96
+ const int r = (patch_size - 1) / 2;
97
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
98
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
99
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
100
+ if (x >= height or y >= width) return;
101
+ const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
102
+ const int x_a = nnf_a[z_nnf + 0];
103
+ const int y_a = nnf_a[z_nnf + 1];
104
+ const int x_b = nnf_b[z_nnf + 0];
105
+ const int y_b = nnf_b[z_nnf + 1];
106
+ float e = 0;
107
+ for (int px = -r; px <= r; px++){
108
+ for (int py = -r; py <= r; py++){
109
+ const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
110
+ const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
111
+ for (int c = 0; c < channel; c++){
112
+ const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
113
+ e += diff * diff;
114
+ }
115
+ }
116
+ }
117
+ error[blockIdx.z * height * width + x * width + y] = e;
118
+ }
119
+ ''', 'pairwise_patch_error')
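For orientation, each RawKernel above is launched with an 8x8 thread block per image tile and the batch index on the third grid dimension. Below is a minimal, hypothetical launch of `remapping_kernel` on toy shapes with an identity NNF; the real call is `PatchMatcher.apply_nnf_to_image` in patch_match.py, shown later in this diff.

import cupy as cp

height, width, channel, patch_size = 32, 32, 3, 5      # toy sizes, purely illustrative
pad = (patch_size - 1) // 2
batch, tpb = 1, 8                                       # one image, 8 threads per block axis
source = cp.random.rand(batch, height + 2 * pad, width + 2 * pad, channel).astype(cp.float32)
target = cp.zeros_like(source)
# Identity NNF: every pixel points at itself, so remapping reproduces the source region.
nnf = cp.stack([
    cp.repeat(cp.arange(height), width).reshape(height, width),
    cp.tile(cp.arange(width), height).reshape(height, width),
], axis=2).astype(cp.int32)[None]
grid = ((height + tpb - 1) // tpb, (width + tpb - 1) // tpb, batch)
remapping_kernel(grid, (tpb, tpb), (height, width, channel, patch_size, pad, source, nnf, target))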
diffsynth/extensions/FastBlend/data.py ADDED
@@ -0,0 +1,146 @@
1
+ import imageio, os
2
+ import numpy as np
3
+ from PIL import Image
4
+
5
+
6
+ def read_video(file_name):
7
+ reader = imageio.get_reader(file_name)
8
+ video = []
9
+ for frame in reader:
10
+ frame = np.array(frame)
11
+ video.append(frame)
12
+ reader.close()
13
+ return video
14
+
15
+
16
+ def get_video_fps(file_name):
17
+ reader = imageio.get_reader(file_name)
18
+ fps = reader.get_meta_data()["fps"]
19
+ reader.close()
20
+ return fps
21
+
22
+
23
+ def save_video(frames_path, video_path, num_frames, fps):
24
+ writer = imageio.get_writer(video_path, fps=fps, quality=9)
25
+ for i in range(num_frames):
26
+ frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
27
+ writer.append_data(frame)
28
+ writer.close()
29
+ return video_path
30
+
31
+
32
+ class LowMemoryVideo:
33
+ def __init__(self, file_name):
34
+ self.reader = imageio.get_reader(file_name)
35
+
36
+ def __len__(self):
37
+ return self.reader.count_frames()
38
+
39
+ def __getitem__(self, item):
40
+ return np.array(self.reader.get_data(item))
41
+
42
+ def __del__(self):
43
+ self.reader.close()
44
+
45
+
46
+ def split_file_name(file_name):
47
+ result = []
48
+ number = -1
49
+ for i in file_name:
50
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
51
+ if number == -1:
52
+ number = 0
53
+ number = number*10 + ord(i) - ord("0")
54
+ else:
55
+ if number != -1:
56
+ result.append(number)
57
+ number = -1
58
+ result.append(i)
59
+ if number != -1:
60
+ result.append(number)
61
+ result = tuple(result)
62
+ return result
63
+
64
+
65
+ def search_for_images(folder):
66
+ file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
67
+ file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
68
+ file_list = [i[1] for i in sorted(file_list)]
69
+ file_list = [os.path.join(folder, i) for i in file_list]
70
+ return file_list
71
+
72
+
73
+ def read_images(folder):
74
+ file_list = search_for_images(folder)
75
+ frames = [np.array(Image.open(i)) for i in file_list]
76
+ return frames
77
+
78
+
79
+ class LowMemoryImageFolder:
80
+ def __init__(self, folder, file_list=None):
81
+ if file_list is None:
82
+ self.file_list = search_for_images(folder)
83
+ else:
84
+ self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
85
+
86
+ def __len__(self):
87
+ return len(self.file_list)
88
+
89
+ def __getitem__(self, item):
90
+ return np.array(Image.open(self.file_list[item]))
91
+
92
+ def __del__(self):
93
+ pass
94
+
95
+
96
+ class VideoData:
97
+ def __init__(self, video_file, image_folder, **kwargs):
98
+ if video_file is not None:
99
+ self.data_type = "video"
100
+ self.data = LowMemoryVideo(video_file, **kwargs)
101
+ elif image_folder is not None:
102
+ self.data_type = "images"
103
+ self.data = LowMemoryImageFolder(image_folder, **kwargs)
104
+ else:
105
+ raise ValueError("Cannot open video or image folder")
106
+ self.length = None
107
+ self.height = None
108
+ self.width = None
109
+
110
+ def raw_data(self):
111
+ frames = []
112
+ for i in range(self.__len__()):
113
+ frames.append(self.__getitem__(i))
114
+ return frames
115
+
116
+ def set_length(self, length):
117
+ self.length = length
118
+
119
+ def set_shape(self, height, width):
120
+ self.height = height
121
+ self.width = width
122
+
123
+ def __len__(self):
124
+ if self.length is None:
125
+ return len(self.data)
126
+ else:
127
+ return self.length
128
+
129
+ def shape(self):
130
+ if self.height is not None and self.width is not None:
131
+ return self.height, self.width
132
+ else:
133
+ height, width, _ = self.__getitem__(0).shape
134
+ return height, width
135
+
136
+ def __getitem__(self, item):
137
+ frame = self.data.__getitem__(item)
138
+ height, width, _ = frame.shape
139
+ if self.height is not None and self.width is not None:
140
+ if self.height != height or self.width != width:
141
+ frame = Image.fromarray(frame).resize((self.width, self.height))
142
+ frame = np.array(frame)
143
+ return frame
144
+
145
+ def __del__(self):
146
+ pass
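A short usage sketch of the helpers above (the file paths are hypothetical): `VideoData` wraps either a video file or an image folder, reads frames lazily, and resizes them on access once `set_shape` is called.

video = VideoData("input_video.mp4", None)   # or VideoData(None, "rendered_frames/")
video.set_shape(512, 512)                    # frames are resized to 512x512 on access
frames = video.raw_data()                    # list of H x W x 3 uint8 arrays
fps = get_video_fps("input_video.mp4")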
diffsynth/extensions/FastBlend/patch_match.py ADDED
@@ -0,0 +1,298 @@
1
+ from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
2
+ import numpy as np
3
+ import cupy as cp
4
+ import cv2
5
+
6
+
7
+ class PatchMatcher:
8
+ def __init__(
9
+ self, height, width, channel, minimum_patch_size,
10
+ threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
11
+ random_search_steps=3, random_search_range=4,
12
+ use_mean_target_style=False, use_pairwise_patch_error=False,
13
+ tracking_window_size=0
14
+ ):
15
+ self.height = height
16
+ self.width = width
17
+ self.channel = channel
18
+ self.minimum_patch_size = minimum_patch_size
19
+ self.threads_per_block = threads_per_block
20
+ self.num_iter = num_iter
21
+ self.gpu_id = gpu_id
22
+ self.guide_weight = guide_weight
23
+ self.random_search_steps = random_search_steps
24
+ self.random_search_range = random_search_range
25
+ self.use_mean_target_style = use_mean_target_style
26
+ self.use_pairwise_patch_error = use_pairwise_patch_error
27
+ self.tracking_window_size = tracking_window_size
28
+
29
+ self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
30
+ self.pad_size = self.patch_size_list[0] // 2
31
+ self.grid = (
32
+ (height + threads_per_block - 1) // threads_per_block,
33
+ (width + threads_per_block - 1) // threads_per_block
34
+ )
35
+ self.block = (threads_per_block, threads_per_block)
36
+
37
+ def pad_image(self, image):
38
+ return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))
39
+
40
+ def unpad_image(self, image):
41
+ return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]
42
+
43
+ def apply_nnf_to_image(self, nnf, source):
44
+ batch_size = source.shape[0]
45
+ target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
46
+ remapping_kernel(
47
+ self.grid + (batch_size,),
48
+ self.block,
49
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
50
+ )
51
+ return target
52
+
53
+ def get_patch_error(self, source, nnf, target):
54
+ batch_size = source.shape[0]
55
+ error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
56
+ patch_error_kernel(
57
+ self.grid + (batch_size,),
58
+ self.block,
59
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
60
+ )
61
+ return error
62
+
63
+ def get_pairwise_patch_error(self, source, nnf):
64
+ batch_size = source.shape[0]//2
65
+ error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
66
+ source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
67
+ source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
68
+ pairwise_patch_error_kernel(
69
+ self.grid + (batch_size,),
70
+ self.block,
71
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
72
+ )
73
+ error = error.repeat(2, axis=0)
74
+ return error
75
+
76
+ def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
77
+ error_guide = self.get_patch_error(source_guide, nnf, target_guide)
78
+ if self.use_mean_target_style:
79
+ target_style = self.apply_nnf_to_image(nnf, source_style)
80
+ target_style = target_style.mean(axis=0, keepdims=True)
81
+ target_style = target_style.repeat(source_guide.shape[0], axis=0)
82
+ if self.use_pairwise_patch_error:
83
+ error_style = self.get_pairwise_patch_error(source_style, nnf)
84
+ else:
85
+ error_style = self.get_patch_error(source_style, nnf, target_style)
86
+ error = error_guide * self.guide_weight + error_style
87
+ return error
88
+
89
+ def clamp_bound(self, nnf):
90
+ nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
91
+ nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
92
+ return nnf
93
+
94
+ def random_step(self, nnf, r):
95
+ batch_size = nnf.shape[0]
96
+ step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
97
+ upd_nnf = self.clamp_bound(nnf + step)
98
+ return upd_nnf
99
+
100
+ def neighboor_step(self, nnf, d):
101
+ if d==0:
102
+ upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
103
+ upd_nnf[:, :, :, 0] += 1
104
+ elif d==1:
105
+ upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
106
+ upd_nnf[:, :, :, 1] += 1
107
+ elif d==2:
108
+ upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
109
+ upd_nnf[:, :, :, 0] -= 1
110
+ elif d==3:
111
+ upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
112
+ upd_nnf[:, :, :, 1] -= 1
113
+ upd_nnf = self.clamp_bound(upd_nnf)
114
+ return upd_nnf
115
+
116
+ def shift_nnf(self, nnf, d):
117
+ if d>0:
118
+ d = min(nnf.shape[0], d)
119
+ upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
120
+ else:
121
+ d = max(-nnf.shape[0], d)
122
+ upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
123
+ return upd_nnf
124
+
125
+ def track_step(self, nnf, d):
126
+ if self.use_pairwise_patch_error:
127
+ upd_nnf = cp.zeros_like(nnf)
128
+ upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
129
+ upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
130
+ else:
131
+ upd_nnf = self.shift_nnf(nnf, d)
132
+ return upd_nnf
133
+
134
+ def C(self, n, m):
135
+ # not used
136
+ c = 1
137
+ for i in range(1, n+1):
138
+ c *= i
139
+ for i in range(1, m+1):
140
+ c //= i
141
+ for i in range(1, n-m+1):
142
+ c //= i
143
+ return c
144
+
145
+ def bezier_step(self, nnf, r):
146
+ # not used
147
+ n = r * 2 - 1
148
+ upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
149
+ for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
150
+ if d>0:
151
+ ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
152
+ elif d<0:
153
+ ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
154
+ upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
155
+ upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
156
+ return upd_nnf
157
+
158
+ def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
159
+ upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
160
+ upd_idx = (upd_err < err)
161
+ nnf[upd_idx] = upd_nnf[upd_idx]
162
+ err[upd_idx] = upd_err[upd_idx]
163
+ return nnf, err
164
+
165
+ def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
166
+ for d in cp.random.permutation(4):
167
+ upd_nnf = self.neighboor_step(nnf, d)
168
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
169
+ return nnf, err
170
+
171
+ def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
172
+ for i in range(self.random_search_steps):
173
+ upd_nnf = self.random_step(nnf, self.random_search_range)
174
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
175
+ return nnf, err
176
+
177
+ def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
178
+ for d in range(1, self.tracking_window_size + 1):
179
+ upd_nnf = self.track_step(nnf, d)
180
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
181
+ upd_nnf = self.track_step(nnf, -d)
182
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
183
+ return nnf, err
184
+
185
+ def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
186
+ nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
187
+ nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
188
+ nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
189
+ return nnf, err
190
+
191
+ def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
192
+ with cp.cuda.Device(self.gpu_id):
193
+ source_guide = self.pad_image(source_guide)
194
+ target_guide = self.pad_image(target_guide)
195
+ source_style = self.pad_image(source_style)
196
+ for it in range(self.num_iter):
197
+ self.patch_size = self.patch_size_list[it]
198
+ target_style = self.apply_nnf_to_image(nnf, source_style)
199
+ err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
200
+ nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
201
+ target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
202
+ return nnf, target_style
203
+
204
+
205
+ class PyramidPatchMatcher:
206
+ def __init__(
207
+ self, image_height, image_width, channel, minimum_patch_size,
208
+ threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
209
+ use_mean_target_style=False, use_pairwise_patch_error=False,
210
+ tracking_window_size=0,
211
+ initialize="identity"
212
+ ):
213
+ maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
214
+ self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
215
+ self.pyramid_heights = []
216
+ self.pyramid_widths = []
217
+ self.patch_matchers = []
218
+ self.minimum_patch_size = minimum_patch_size
219
+ self.num_iter = num_iter
220
+ self.gpu_id = gpu_id
221
+ self.initialize = initialize
222
+ for level in range(self.pyramid_level):
223
+ height = image_height//(2**(self.pyramid_level - 1 - level))
224
+ width = image_width//(2**(self.pyramid_level - 1 - level))
225
+ self.pyramid_heights.append(height)
226
+ self.pyramid_widths.append(width)
227
+ self.patch_matchers.append(PatchMatcher(
228
+ height, width, channel, minimum_patch_size=minimum_patch_size,
229
+ threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
230
+ use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
231
+ tracking_window_size=tracking_window_size
232
+ ))
233
+
234
+ def resample_image(self, images, level):
235
+ height, width = self.pyramid_heights[level], self.pyramid_widths[level]
236
+ images = images.get()
237
+ images_resample = []
238
+ for image in images:
239
+ image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
240
+ images_resample.append(image_resample)
241
+ images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
242
+ return images_resample
243
+
244
+ def initialize_nnf(self, batch_size):
245
+ if self.initialize == "random":
246
+ height, width = self.pyramid_heights[0], self.pyramid_widths[0]
247
+ nnf = cp.stack([
248
+ cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
249
+ cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
250
+ ], axis=3)
251
+ elif self.initialize == "identity":
252
+ height, width = self.pyramid_heights[0], self.pyramid_widths[0]
253
+ nnf = cp.stack([
254
+ cp.repeat(cp.arange(height), width).reshape(height, width),
255
+ cp.tile(cp.arange(width), height).reshape(height, width)
256
+ ], axis=2)
257
+ nnf = cp.stack([nnf] * batch_size)
258
+ else:
259
+ raise NotImplementedError()
260
+ return nnf
261
+
262
+ def update_nnf(self, nnf, level):
263
+ # upscale
264
+ nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
265
+ nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
266
+ nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
267
+ # check if scale is 2
268
+ height, width = self.pyramid_heights[level], self.pyramid_widths[level]
269
+ if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
270
+ nnf = nnf.get().astype(np.float32)
271
+ nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
272
+ nnf = cp.array(np.stack(nnf), dtype=cp.int32)
273
+ nnf = self.patch_matchers[level].clamp_bound(nnf)
274
+ return nnf
275
+
276
+ def apply_nnf_to_image(self, nnf, image):
277
+ with cp.cuda.Device(self.gpu_id):
278
+ image = self.patch_matchers[-1].pad_image(image)
279
+ image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
280
+ return image
281
+
282
+ def estimate_nnf(self, source_guide, target_guide, source_style):
283
+ with cp.cuda.Device(self.gpu_id):
284
+ if not isinstance(source_guide, cp.ndarray):
285
+ source_guide = cp.array(source_guide, dtype=cp.float32)
286
+ if not isinstance(target_guide, cp.ndarray):
287
+ target_guide = cp.array(target_guide, dtype=cp.float32)
288
+ if not isinstance(source_style, cp.ndarray):
289
+ source_style = cp.array(source_style, dtype=cp.float32)
290
+ for level in range(self.pyramid_level):
291
+ nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
292
+ source_guide_ = self.resample_image(source_guide, level)
293
+ target_guide_ = self.resample_image(target_guide, level)
294
+ source_style_ = self.resample_image(source_style, level)
295
+ nnf, target_style = self.patch_matchers[level].estimate_nnf(
296
+ source_guide_, target_guide_, source_style_, nnf
297
+ )
298
+ return nnf.get(), target_style.get()
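Used on its own (and assuming CuPy with a CUDA device), the pyramid matcher takes stacked NumPy batches of guide and style frames. A minimal sketch with random toy frames of the call that the runners below build on:

import numpy as np

guide = (np.random.rand(2, 256, 256, 3) * 255).astype(np.float32)   # toy guide frames
style = (np.random.rand(2, 256, 256, 3) * 255).astype(np.float32)   # toy style frames
matcher = PyramidPatchMatcher(image_height=256, image_width=256, channel=3, minimum_patch_size=5)
# Remap the style frames from the source viewpoints onto the target guide frames.
nnf, remapped = matcher.estimate_nnf(source_guide=guide, target_guide=guide, source_style=style)
print(nnf.shape, remapped.shape)  # (2, 256, 256, 2) (2, 256, 256, 3)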
diffsynth/extensions/FastBlend/runners/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .accurate import AccurateModeRunner
2
+ from .fast import FastModeRunner
3
+ from .balanced import BalancedModeRunner
4
+ from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
diffsynth/extensions/FastBlend/runners/accurate.py ADDED
@@ -0,0 +1,35 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class AccurateModeRunner:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
13
+ patch_match_engine = PyramidPatchMatcher(
14
+ image_height=frames_style[0].shape[0],
15
+ image_width=frames_style[0].shape[1],
16
+ channel=3,
17
+ use_mean_target_style=True,
18
+ **ebsynth_config
19
+ )
20
+ # run
21
+ n = len(frames_style)
22
+ for target in tqdm(range(n), desc=desc):
23
+ l, r = max(target - window_size, 0), min(target + window_size + 1, n)
24
+ remapped_frames = []
25
+ for i in range(l, r, batch_size):
26
+ j = min(i + batch_size, r)
27
+ source_guide = np.stack([frames_guide[source] for source in range(i, j)])
28
+ target_guide = np.stack([frames_guide[target]] * (j - i))
29
+ source_style = np.stack([frames_style[source] for source in range(i, j)])
30
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
31
+ remapped_frames.append(target_style)
32
+ frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
33
+ frame = frame.clip(0, 255).astype("uint8")
34
+ if save_path is not None:
35
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
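A hypothetical invocation of the runner above: `frames_guide` and `frames_style` are equally long lists of H x W x 3 uint8 arrays (for example from `VideoData.raw_data()`), `save_path` is an existing directory, and `ebsynth_config` is forwarded to `PyramidPatchMatcher`.

ebsynth_config = {
    "minimum_patch_size": 5, "num_iter": 5, "guide_weight": 10.0,
    "tracking_window_size": 0, "initialize": "identity",
}
AccurateModeRunner().run(
    frames_guide, frames_style,
    batch_size=8, window_size=15,
    ebsynth_config=ebsynth_config, save_path="output_frames",  # hypothetical output directory
)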
diffsynth/extensions/FastBlend/runners/balanced.py ADDED
@@ -0,0 +1,46 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class BalancedModeRunner:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
13
+ patch_match_engine = PyramidPatchMatcher(
14
+ image_height=frames_style[0].shape[0],
15
+ image_width=frames_style[0].shape[1],
16
+ channel=3,
17
+ **ebsynth_config
18
+ )
19
+ # tasks
20
+ n = len(frames_style)
21
+ tasks = []
22
+ for target in range(n):
23
+ for source in range(target - window_size, target + window_size + 1):
24
+ if source >= 0 and source < n and source != target:
25
+ tasks.append((source, target))
26
+ # run
27
+ frames = [(None, 1) for i in range(n)]
28
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
29
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
30
+ source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
31
+ target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
32
+ source_style = np.stack([frames_style[source] for source, target in tasks_batch])
33
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
34
+ for (source, target), result in zip(tasks_batch, target_style):
35
+ frame, weight = frames[target]
36
+ if frame is None:
37
+ frame = frames_style[target]
38
+ frames[target] = (
39
+ frame * (weight / (weight + 1)) + result / (weight + 1),
40
+ weight + 1
41
+ )
42
+ if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
43
+ frame = frame.clip(0, 255).astype("uint8")
44
+ if save_path is not None:
45
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
46
+ frames[target] = (None, 1)
diffsynth/extensions/FastBlend/runners/fast.py ADDED
@@ -0,0 +1,141 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import functools, os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class TableManager:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def task_list(self, n):
13
+ tasks = []
14
+ max_level = 1
15
+ while (1<<max_level)<=n:
16
+ max_level += 1
17
+ for i in range(n):
18
+ j = i
19
+ for level in range(max_level):
20
+ if i&(1<<level):
21
+ continue
22
+ j |= 1<<level
23
+ if j>=n:
24
+ break
25
+ meta_data = {
26
+ "source": i,
27
+ "target": j,
28
+ "level": level + 1
29
+ }
30
+ tasks.append(meta_data)
31
+ tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
32
+ return tasks
33
+
34
+ def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
35
+ n = len(frames_guide)
36
+ tasks = self.task_list(n)
37
+ remapping_table = [[(frames_style[i], 1)] for i in range(n)]
38
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
39
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
40
+ source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
41
+ target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
42
+ source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
43
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
44
+ for task, result in zip(tasks_batch, target_style):
45
+ target, level = task["target"], task["level"]
46
+ if len(remapping_table[target])==level:
47
+ remapping_table[target].append((result, 1))
48
+ else:
49
+ frame, weight = remapping_table[target][level]
50
+ remapping_table[target][level] = (
51
+ frame * (weight / (weight + 1)) + result / (weight + 1),
52
+ weight + 1
53
+ )
54
+ return remapping_table
55
+
56
+ def remapping_table_to_blending_table(self, table):
57
+ for i in range(len(table)):
58
+ for j in range(1, len(table[i])):
59
+ frame_1, weight_1 = table[i][j-1]
60
+ frame_2, weight_2 = table[i][j]
61
+ frame = (frame_1 + frame_2) / 2
62
+ weight = weight_1 + weight_2
63
+ table[i][j] = (frame, weight)
64
+ return table
65
+
66
+ def tree_query(self, leftbound, rightbound):
67
+ node_list = []
68
+ node_index = rightbound
69
+ while node_index>=leftbound:
70
+ node_level = 0
71
+ while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
72
+ node_level += 1
73
+ node_list.append((node_index, node_level))
74
+ node_index -= 1<<node_level
75
+ return node_list
76
+
77
+ def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
78
+ n = len(blending_table)
79
+ tasks = []
80
+ frames_result = []
81
+ for target in range(n):
82
+ node_list = self.tree_query(max(target-window_size, 0), target)
83
+ for source, level in node_list:
84
+ if source!=target:
85
+ meta_data = {
86
+ "source": source,
87
+ "target": target,
88
+ "level": level
89
+ }
90
+ tasks.append(meta_data)
91
+ else:
92
+ frames_result.append(blending_table[target][level])
93
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
94
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
95
+ source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
96
+ target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
97
+ source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
98
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
99
+ for task, frame_2 in zip(tasks_batch, target_style):
100
+ source, target, level = task["source"], task["target"], task["level"]
101
+ frame_1, weight_1 = frames_result[target]
102
+ weight_2 = blending_table[source][level][1]
103
+ weight = weight_1 + weight_2
104
+ frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
105
+ frames_result[target] = (frame, weight)
106
+ return frames_result
107
+
108
+
109
+ class FastModeRunner:
110
+ def __init__(self):
111
+ pass
112
+
113
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
114
+ frames_guide = frames_guide.raw_data()
115
+ frames_style = frames_style.raw_data()
116
+ table_manager = TableManager()
117
+ patch_match_engine = PyramidPatchMatcher(
118
+ image_height=frames_style[0].shape[0],
119
+ image_width=frames_style[0].shape[1],
120
+ channel=3,
121
+ **ebsynth_config
122
+ )
123
+ # left part
124
+ table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="Fast Mode Step 1/4")
125
+ table_l = table_manager.remapping_table_to_blending_table(table_l)
126
+ table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 2/4")
127
+ # right part
128
+ table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="Fast Mode Step 3/4")
129
+ table_r = table_manager.remapping_table_to_blending_table(table_r)
130
+ table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
131
+ # merge
132
+ frames = []
133
+ for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
134
+ weight_m = -1
135
+ weight = weight_l + weight_m + weight_r
136
+ frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
137
+ frames.append(frame)
138
+ frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
139
+ if save_path is not None:
140
+ for target, frame in enumerate(frames):
141
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
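The "tree-like data structure" mentioned in the UI reference is the binary-lifting table built by `TableManager` above: a node `(index, level)` stores the blend of the `2**level` frames ending at `index`, and `tree_query` decomposes any window into a handful of such nodes. A hand-checked example of that decomposition:

# tree_query(leftbound=0, rightbound=10) peels off the largest power-of-two block
# that still fits inside [0, 10], working from the right end:
#   (10, 0) covers frame  10
#   ( 9, 1) covers frames 8..9
#   ( 7, 3) covers frames 0..7
# so an 11-frame window is answered from just 3 precomputed table entries.
print(TableManager().tree_query(0, 10))  # [(10, 0), (9, 1), (7, 3)]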
diffsynth/extensions/FastBlend/runners/interpolation.py ADDED
@@ -0,0 +1,121 @@
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class InterpolationModeRunner:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def get_index_dict(self, index_style):
13
+ index_dict = {}
14
+ for i, index in enumerate(index_style):
15
+ index_dict[index] = i
16
+ return index_dict
17
+
18
+ def get_weight(self, l, m, r):
19
+ weight_l, weight_r = abs(m - r), abs(m - l)
20
+ if weight_l + weight_r == 0:
21
+ weight_l, weight_r = 0.5, 0.5
22
+ else:
23
+ weight_l, weight_r = weight_l / (weight_l + weight_r), weight_r / (weight_l + weight_r)
24
+ return weight_l, weight_r
25
+
26
+ def get_task_group(self, index_style, n):
27
+ task_group = []
28
+ index_style = sorted(index_style)
29
+ # first frame
30
+ if index_style[0]>0:
31
+ tasks = []
32
+ for m in range(index_style[0]):
33
+ tasks.append((index_style[0], m, index_style[0]))
34
+ task_group.append(tasks)
35
+ # middle frames
36
+ for l, r in zip(index_style[:-1], index_style[1:]):
37
+ tasks = []
38
+ for m in range(l, r):
39
+ tasks.append((l, m, r))
40
+ task_group.append(tasks)
41
+ # last frame
42
+ tasks = []
43
+ for m in range(index_style[-1], n):
44
+ tasks.append((index_style[-1], m, index_style[-1]))
45
+ task_group.append(tasks)
46
+ return task_group
47
+
48
+ def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
49
+ patch_match_engine = PyramidPatchMatcher(
50
+ image_height=frames_style[0].shape[0],
51
+ image_width=frames_style[0].shape[1],
52
+ channel=3,
53
+ use_mean_target_style=False,
54
+ use_pairwise_patch_error=True,
55
+ **ebsynth_config
56
+ )
57
+ # task
58
+ index_dict = self.get_index_dict(index_style)
59
+ task_group = self.get_task_group(index_style, len(frames_guide))
60
+ # run
61
+ for tasks in task_group:
62
+ index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
63
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
64
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
65
+ source_guide, target_guide, source_style = [], [], []
66
+ for l, m, r in tasks_batch:
67
+ # l -> m
68
+ source_guide.append(frames_guide[l])
69
+ target_guide.append(frames_guide[m])
70
+ source_style.append(frames_style[index_dict[l]])
71
+ # r -> m
72
+ source_guide.append(frames_guide[r])
73
+ target_guide.append(frames_guide[m])
74
+ source_style.append(frames_style[index_dict[r]])
75
+ source_guide = np.stack(source_guide)
76
+ target_guide = np.stack(target_guide)
77
+ source_style = np.stack(source_style)
78
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
79
+ if save_path is not None:
80
+ for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
81
+ weight_l, weight_r = self.get_weight(l, m, r)
82
+ frame = frame_l * weight_l + frame_r * weight_r
83
+ frame = frame.clip(0, 255).astype("uint8")
84
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
85
+
86
+
87
+ class InterpolationModeSingleFrameRunner:
88
+ def __init__(self):
89
+ pass
90
+
91
+ def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
92
+ # check input
93
+ tracking_window_size = ebsynth_config["tracking_window_size"]
94
+ if tracking_window_size * 2 >= batch_size:
95
+ raise ValueError("batch_size should be larger than track_window_size * 2")
96
+ frame_style = frames_style[0]
97
+ frame_guide = frames_guide[index_style[0]]
98
+ patch_match_engine = PyramidPatchMatcher(
99
+ image_height=frame_style.shape[0],
100
+ image_width=frame_style.shape[1],
101
+ channel=3,
102
+ **ebsynth_config
103
+ )
104
+ # run
105
+ frame_id, n = 0, len(frames_guide)
106
+ for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
107
+ if i + batch_size > n:
108
+ l, r = max(n - batch_size, 0), n
109
+ else:
110
+ l, r = i, i + batch_size
111
+ source_guide = np.stack([frame_guide] * (r-l))
112
+ target_guide = np.stack([frames_guide[i] for i in range(l, r)])
113
+ source_style = np.stack([frame_style] * (r-l))
114
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
115
+ for i, frame in zip(range(l, r), target_style):
116
+ if i==frame_id:
117
+ frame = frame.clip(0, 255).astype("uint8")
118
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
119
+ frame_id += 1
120
+ if r < n and r-frame_id <= tracking_window_size:
121
+ break
diffsynth/extensions/RIFE/__init__.py ADDED
@@ -0,0 +1,241 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+
8
+ def warp(tenInput, tenFlow, device):
9
+ backwarp_tenGrid = {}
10
+ k = (str(tenFlow.device), str(tenFlow.size()))
11
+ if k not in backwarp_tenGrid:
12
+ tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
13
+ 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
14
+ tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
15
+ 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
16
+ backwarp_tenGrid[k] = torch.cat(
17
+ [tenHorizontal, tenVertical], 1).to(device)
18
+
19
+ tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
20
+ tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
21
+
22
+ g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
23
+ return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
24
+
25
+
26
+ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
27
+ return nn.Sequential(
28
+ nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
29
+ padding=padding, dilation=dilation, bias=True),
30
+ nn.PReLU(out_planes)
31
+ )
32
+
33
+
34
+ class IFBlock(nn.Module):
35
+ def __init__(self, in_planes, c=64):
36
+ super(IFBlock, self).__init__()
37
+ self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),)
38
+ self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
39
+ self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
40
+ self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
41
+ self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
42
+ self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1))
43
+ self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1))
44
+
45
+ def forward(self, x, flow, scale=1):
46
+ x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
47
+ flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
48
+ feat = self.conv0(torch.cat((x, flow), 1))
49
+ feat = self.convblock0(feat) + feat
50
+ feat = self.convblock1(feat) + feat
51
+ feat = self.convblock2(feat) + feat
52
+ feat = self.convblock3(feat) + feat
53
+ flow = self.conv1(feat)
54
+ mask = self.conv2(feat)
55
+ flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
56
+ mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
57
+ return flow, mask
58
+
59
+
60
+ class IFNet(nn.Module):
61
+ def __init__(self):
62
+ super(IFNet, self).__init__()
63
+ self.block0 = IFBlock(7+4, c=90)
64
+ self.block1 = IFBlock(7+4, c=90)
65
+ self.block2 = IFBlock(7+4, c=90)
66
+ self.block_tea = IFBlock(10+4, c=90)
67
+
68
+ def forward(self, x, scale_list=[4, 2, 1], training=False):
69
+ if training == False:
70
+ channel = x.shape[1] // 2
71
+ img0 = x[:, :channel]
72
+ img1 = x[:, channel:]
73
+ flow_list = []
74
+ merged = []
75
+ mask_list = []
76
+ warped_img0 = img0
77
+ warped_img1 = img1
78
+ flow = (x[:, :4]).detach() * 0
79
+ mask = (x[:, :1]).detach() * 0
80
+ block = [self.block0, self.block1, self.block2]
81
+ for i in range(3):
82
+ f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
83
+ f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
84
+ flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
85
+ mask = mask + (m0 + (-m1)) / 2
86
+ mask_list.append(mask)
87
+ flow_list.append(flow)
88
+ warped_img0 = warp(img0, flow[:, :2], device=x.device)
89
+ warped_img1 = warp(img1, flow[:, 2:4], device=x.device)
90
+ merged.append((warped_img0, warped_img1))
91
+ '''
92
+ c0 = self.contextnet(img0, flow[:, :2])
93
+ c1 = self.contextnet(img1, flow[:, 2:4])
94
+ tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
95
+ res = tmp[:, 1:4] * 2 - 1
96
+ '''
97
+ for i in range(3):
98
+ mask_list[i] = torch.sigmoid(mask_list[i])
99
+ merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
100
+ return flow_list, mask_list[2], merged
101
+
102
+ def state_dict_converter(self):
103
+ return IFNetStateDictConverter()
104
+
105
+
106
+ class IFNetStateDictConverter:
107
+ def __init__(self):
108
+ pass
109
+
110
+ def from_diffusers(self, state_dict):
111
+ state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()}
112
+ return state_dict_
113
+
114
+ def from_civitai(self, state_dict):
115
+ return self.from_diffusers(state_dict)
116
+
117
+
118
+ class RIFEInterpolater:
119
+ def __init__(self, model, device="cuda"):
120
+ self.model = model
121
+ self.device = device
122
+ # IFNet does not support float16, so keep float32 here
123
+ self.torch_dtype = torch.float32
124
+
125
+ @staticmethod
126
+ def from_model_manager(model_manager):
127
+ return RIFEInterpolater(model_manager.RIFE, device=model_manager.device)
128
+
129
+ def process_image(self, image):
130
+ width, height = image.size
131
+ if width % 32 != 0 or height % 32 != 0:
132
+ width = (width + 31) // 32 * 32
133
+ height = (height + 31) // 32 * 32
134
+ image = image.resize((width, height))
135
+ image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
136
+ return image
137
+
138
+ def process_images(self, images):
139
+ images = [self.process_image(image) for image in images]
140
+ images = torch.stack(images)
141
+ return images
142
+
143
+ def decode_images(self, images):
144
+ images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
145
+ images = [Image.fromarray(image) for image in images]
146
+ return images
147
+
148
+ def add_interpolated_images(self, images, interpolated_images):
149
+ output_images = []
150
+ for image, interpolated_image in zip(images, interpolated_images):
151
+ output_images.append(image)
152
+ output_images.append(interpolated_image)
153
+ output_images.append(images[-1])
154
+ return output_images
155
+
156
+
157
+ @torch.no_grad()
158
+ def interpolate_(self, images, scale=1.0):
159
+ input_tensor = self.process_images(images)
160
+ input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1)
161
+ input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype)
162
+ flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale])
163
+ output_images = self.decode_images(merged[2].cpu())
164
+ if output_images[0].size != images[0].size:
165
+ output_images = [image.resize(images[0].size) for image in output_images]
166
+ return output_images
167
+
168
+
169
+ @torch.no_grad()
170
+ def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1, progress_bar=lambda x:x):
171
+ # Preprocess
172
+ processed_images = self.process_images(images)
173
+
174
+ for iter in range(num_iter):
175
+ # Input
176
+ input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1)
177
+
178
+ # Interpolate
179
+ output_tensor = []
180
+ for batch_id in progress_bar(range(0, input_tensor.shape[0], batch_size)):
181
+ batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
182
+ batch_input_tensor = input_tensor[batch_id: batch_id_]
183
+ batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
184
+ flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
185
+ output_tensor.append(merged[2].cpu())
186
+
187
+ # Output
188
+ output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1)
189
+ processed_images = self.add_interpolated_images(processed_images, output_tensor)
190
+ processed_images = torch.stack(processed_images)
191
+
192
+ # To images
193
+ output_images = self.decode_images(processed_images)
194
+ if output_images[0].size != images[0].size:
195
+ output_images = [image.resize(images[0].size) for image in output_images]
196
+ return output_images
197
+
198
+
199
+ class RIFESmoother(RIFEInterpolater):
200
+ def __init__(self, model, device="cuda"):
201
+ super(RIFESmoother, self).__init__(model, device=device)
202
+
203
+ @staticmethod
204
+ def from_model_manager(model_manager):
205
+ return RIFESmoother(model_manager.RIFE, device=model_manager.device)
206
+
207
+ def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
208
+ output_tensor = []
209
+ for batch_id in range(0, input_tensor.shape[0], batch_size):
210
+ batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
211
+ batch_input_tensor = input_tensor[batch_id: batch_id_]
212
+ batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
213
+ flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
214
+ output_tensor.append(merged[2].cpu())
215
+ output_tensor = torch.concat(output_tensor, dim=0)
216
+ return output_tensor
217
+
218
+ @torch.no_grad()
219
+ def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs):
220
+ # Preprocess
221
+ processed_images = self.process_images(rendered_frames)
222
+
223
+ for iter in range(num_iter):
224
+ # Input
225
+ input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
226
+
227
+ # Interpolate
228
+ output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
229
+
230
+ # Blend
231
+ input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
232
+ output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
233
+
234
+ # Add to frames
235
+ processed_images[1:-1] = output_tensor
236
+
237
+ # To images
238
+ output_images = self.decode_images(processed_images)
239
+ if output_images[0].size != rendered_frames[0].size:
240
+ output_images = [image.resize(rendered_frames[0].size) for image in output_images]
241
+ return output_images
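A minimal, hypothetical sketch of frame interpolation with the classes above, assuming `ifnet` is an `IFNet` whose weights are already loaded (for example from the RIFE `flownet.pkl` referenced in the model list below) and a CUDA device is available:

from PIL import Image

frames = [Image.open(f"frames/{i:05d}.png") for i in range(8)]        # hypothetical inputs
interpolater = RIFEInterpolater(ifnet.eval().to("cuda"), device="cuda")
# Each iteration inserts one interpolated frame between every consecutive pair.
new_frames = interpolater.interpolate(frames, scale=1.0, batch_size=4, num_iter=1)
for i, frame in enumerate(new_frames):
    frame.save(f"interpolated/{i:05d}.png")                           # hypothetical output dir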
diffsynth/models/__init__.py ADDED
@@ -0,0 +1,814 @@
1
+ import torch, os, json
2
+ from safetensors import safe_open
3
+ from typing_extensions import Literal, TypeAlias
4
+ from typing import List
5
+
6
+ from .downloader import download_from_huggingface, download_from_modelscope
7
+
8
+ from .sd_text_encoder import SDTextEncoder
9
+ from .sd_unet import SDUNet
10
+ from .sd_vae_encoder import SDVAEEncoder
11
+ from .sd_vae_decoder import SDVAEDecoder
12
+ from .sd_lora import SDLoRA
13
+
14
+ from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
15
+ from .sdxl_unet import SDXLUNet
16
+ from .sdxl_vae_decoder import SDXLVAEDecoder
17
+ from .sdxl_vae_encoder import SDXLVAEEncoder
18
+
19
+ from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
20
+ from .sd3_dit import SD3DiT
21
+ from .sd3_vae_decoder import SD3VAEDecoder
22
+ from .sd3_vae_encoder import SD3VAEEncoder
23
+
24
+ from .sd_controlnet import SDControlNet
25
+
26
+ from .sd_motion import SDMotionModel
27
+ from .sdxl_motion import SDXLMotionModel
28
+
29
+ from .svd_image_encoder import SVDImageEncoder
30
+ from .svd_unet import SVDUNet
31
+ from .svd_vae_decoder import SVDVAEDecoder
32
+ from .svd_vae_encoder import SVDVAEEncoder
33
+
34
+ from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
35
+ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
36
+
37
+ from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
38
+ from .hunyuan_dit import HunyuanDiT
39
+ from .kolors_text_encoder import ChatGLMModel
40
+
41
+
42
+ preset_models_on_huggingface = {
43
+ "HunyuanDiT": [
44
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
45
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
46
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
47
+ ("Tencent-Hunyuan/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
48
+ ],
49
+ "stable-video-diffusion-img2vid-xt": [
50
+ ("stabilityai/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
51
+ ],
52
+ "ExVideo-SVD-128f-v1": [
53
+ ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
54
+ ],
55
+ }
56
+ preset_models_on_modelscope = {
57
+ # Hunyuan DiT
58
+ "HunyuanDiT": [
59
+ ("modelscope/HunyuanDiT", "t2i/clip_text_encoder/pytorch_model.bin", "models/HunyuanDiT/t2i/clip_text_encoder"),
60
+ ("modelscope/HunyuanDiT", "t2i/mt5/pytorch_model.bin", "models/HunyuanDiT/t2i/mt5"),
61
+ ("modelscope/HunyuanDiT", "t2i/model/pytorch_model_ema.pt", "models/HunyuanDiT/t2i/model"),
62
+ ("modelscope/HunyuanDiT", "t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin", "models/HunyuanDiT/t2i/sdxl-vae-fp16-fix"),
63
+ ],
64
+ # Stable Video Diffusion
65
+ "stable-video-diffusion-img2vid-xt": [
66
+ ("AI-ModelScope/stable-video-diffusion-img2vid-xt", "svd_xt.safetensors", "models/stable_video_diffusion"),
67
+ ],
68
+ # ExVideo
69
+ "ExVideo-SVD-128f-v1": [
70
+ ("ECNU-CILab/ExVideo-SVD-128f-v1", "model.fp16.safetensors", "models/stable_video_diffusion"),
71
+ ],
72
+ # Stable Diffusion
73
+ "StableDiffusion_v15": [
74
+ ("AI-ModelScope/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors", "models/stable_diffusion"),
75
+ ],
76
+ "DreamShaper_8": [
77
+ ("sd_lora/dreamshaper_8", "dreamshaper_8.safetensors", "models/stable_diffusion"),
78
+ ],
79
+ "AingDiffusion_v12": [
80
+ ("sd_lora/aingdiffusion_v12", "aingdiffusion_v12.safetensors", "models/stable_diffusion"),
81
+ ],
82
+ "Flat2DAnimerge_v45Sharp": [
83
+ ("sd_lora/Flat-2D-Animerge", "flat2DAnimerge_v45Sharp.safetensors", "models/stable_diffusion"),
84
+ ],
85
+ # Textual Inversion
86
+ "TextualInversion_VeryBadImageNegative_v1.3": [
87
+ ("sd_lora/verybadimagenegative_v1.3", "verybadimagenegative_v1.3.pt", "models/textual_inversion"),
88
+ ],
89
+ # Stable Diffusion XL
90
+ "StableDiffusionXL_v1": [
91
+ ("AI-ModelScope/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors", "models/stable_diffusion_xl"),
92
+ ],
93
+ "BluePencilXL_v200": [
94
+ ("sd_lora/bluePencilXL_v200", "bluePencilXL_v200.safetensors", "models/stable_diffusion_xl"),
95
+ ],
96
+ "StableDiffusionXL_Turbo": [
97
+ ("AI-ModelScope/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors", "models/stable_diffusion_xl_turbo"),
98
+ ],
99
+ # Stable Diffusion 3
100
+ "StableDiffusion3": [
101
+ ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips_t5xxlfp16.safetensors", "models/stable_diffusion_3"),
102
+ ],
103
+ "StableDiffusion3_without_T5": [
104
+ ("AI-ModelScope/stable-diffusion-3-medium", "sd3_medium_incl_clips.safetensors", "models/stable_diffusion_3"),
105
+ ],
106
+ # ControlNet
107
+ "ControlNet_v11f1p_sd15_depth": [
108
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11f1p_sd15_depth.pth", "models/ControlNet"),
109
+ ("sd_lora/Annotators", "dpt_hybrid-midas-501f0c75.pt", "models/Annotators")
110
+ ],
111
+ "ControlNet_v11p_sd15_softedge": [
112
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_softedge.pth", "models/ControlNet"),
113
+ ("sd_lora/Annotators", "ControlNetHED.pth", "models/Annotators")
114
+ ],
115
+ "ControlNet_v11f1e_sd15_tile": [
116
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11f1e_sd15_tile.pth", "models/ControlNet")
117
+ ],
118
+ "ControlNet_v11p_sd15_lineart": [
119
+ ("AI-ModelScope/ControlNet-v1-1", "control_v11p_sd15_lineart.pth", "models/ControlNet"),
120
+ ("sd_lora/Annotators", "sk_model.pth", "models/Annotators"),
121
+ ("sd_lora/Annotators", "sk_model2.pth", "models/Annotators")
122
+ ],
123
+ # AnimateDiff
124
+ "AnimateDiff_v2": [
125
+ ("Shanghai_AI_Laboratory/animatediff", "mm_sd_v15_v2.ckpt", "models/AnimateDiff"),
126
+ ],
127
+ "AnimateDiff_xl_beta": [
128
+ ("Shanghai_AI_Laboratory/animatediff", "mm_sdxl_v10_beta.ckpt", "models/AnimateDiff"),
129
+ ],
130
+ # RIFE
131
+ "RIFE": [
132
+ ("Damo_XR_Lab/cv_rife_video-frame-interpolation", "flownet.pkl", "models/RIFE"),
133
+ ],
134
+ # Beautiful Prompt
135
+ "BeautifulPrompt": [
136
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
137
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "generation_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
138
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "model.safetensors", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
139
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "special_tokens_map.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
140
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
141
+ ("AI-ModelScope/pai-bloom-1b1-text2prompt-sd", "tokenizer_config.json", "models/BeautifulPrompt/pai-bloom-1b1-text2prompt-sd"),
142
+ ],
143
+ # Translator
144
+ "opus-mt-zh-en": [
145
+ ("moxying/opus-mt-zh-en", "config.json", "models/translator/opus-mt-zh-en"),
146
+ ("moxying/opus-mt-zh-en", "generation_config.json", "models/translator/opus-mt-zh-en"),
147
+ ("moxying/opus-mt-zh-en", "metadata.json", "models/translator/opus-mt-zh-en"),
148
+ ("moxying/opus-mt-zh-en", "pytorch_model.bin", "models/translator/opus-mt-zh-en"),
149
+ ("moxying/opus-mt-zh-en", "source.spm", "models/translator/opus-mt-zh-en"),
150
+ ("moxying/opus-mt-zh-en", "target.spm", "models/translator/opus-mt-zh-en"),
151
+ ("moxying/opus-mt-zh-en", "tokenizer_config.json", "models/translator/opus-mt-zh-en"),
152
+ ("moxying/opus-mt-zh-en", "vocab.json", "models/translator/opus-mt-zh-en"),
153
+ ],
154
+ # IP-Adapter
155
+ "IP-Adapter-SD": [
156
+ ("AI-ModelScope/IP-Adapter", "models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion/image_encoder"),
157
+ ("AI-ModelScope/IP-Adapter", "models/ip-adapter_sd15.bin", "models/IpAdapter/stable_diffusion"),
158
+ ],
159
+ "IP-Adapter-SDXL": [
160
+ ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
161
+ ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
162
+ ],
163
+ # Kolors
164
+ "Kolors": [
165
+ ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
166
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
167
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
168
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
169
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
170
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
171
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
172
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
173
+ ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
174
+ ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
175
+ ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
176
+ ],
177
+ "SDXL-vae-fp16-fix": [
178
+ ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix")
179
+ ],
180
+ }
181
+ Preset_model_id: TypeAlias = Literal[
182
+ "HunyuanDiT",
183
+ "stable-video-diffusion-img2vid-xt",
184
+ "ExVideo-SVD-128f-v1",
185
+ "StableDiffusion_v15",
186
+ "DreamShaper_8",
187
+ "AingDiffusion_v12",
188
+ "Flat2DAnimerge_v45Sharp",
189
+ "TextualInversion_VeryBadImageNegative_v1.3",
190
+ "StableDiffusionXL_v1",
191
+ "BluePencilXL_v200",
192
+ "StableDiffusionXL_Turbo",
193
+ "ControlNet_v11f1p_sd15_depth",
194
+ "ControlNet_v11p_sd15_softedge",
195
+ "ControlNet_v11f1e_sd15_tile",
196
+ "ControlNet_v11p_sd15_lineart",
197
+ "AnimateDiff_v2",
198
+ "AnimateDiff_xl_beta",
199
+ "RIFE",
200
+ "BeautifulPrompt",
201
+ "opus-mt-zh-en",
202
+ "IP-Adapter-SD",
203
+ "IP-Adapter-SDXL",
204
+ "StableDiffusion3",
205
+ "StableDiffusion3_without_T5",
206
+ "Kolors",
207
+ "SDXL-vae-fp16-fix",
208
+ ]
209
+ Preset_model_website: TypeAlias = Literal[
210
+ "HuggingFace",
211
+ "ModelScope",
212
+ ]
213
+ website_to_preset_models = {
214
+ "HuggingFace": preset_models_on_huggingface,
215
+ "ModelScope": preset_models_on_modelscope,
216
+ }
217
+ website_to_download_fn = {
218
+ "HuggingFace": download_from_huggingface,
219
+ "ModelScope": download_from_modelscope,
220
+ }
221
+
222
+
223
+ def download_models(
224
+ model_id_list: List[Preset_model_id] = [],
225
+ downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
226
+ ):
227
+ downloaded_files = []
228
+ for model_id in model_id_list:
229
+ for website in downloading_priority:
230
+ if model_id in website_to_preset_models[website]:
231
+ for model_id, origin_file_path, local_dir in website_to_preset_models[website][model_id]:
232
+ # Skip the file if it has already been downloaded in this run.
233
+ file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
234
+ if file_to_download in downloaded_files:
235
+ continue
236
+ # Download
237
+ website_to_download_fn[website](model_id, origin_file_path, local_dir)
238
+ if os.path.basename(origin_file_path) in os.listdir(local_dir):
239
+ downloaded_files.append(file_to_download)
240
+ return downloaded_files
241
+
242
+
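A minimal usage sketch for the preset downloader above, assuming the package is installed so that diffsynth.models is importable and that ModelScope or HuggingFace is reachable; the preset ID is one of the literals in Preset_model_id:

from diffsynth.models import download_models  # assumes the package layout shown in this diff

# Try ModelScope first, then HuggingFace (the default downloading_priority).
files = download_models(model_id_list=["AnimateDiff_v2"])
print(files)  # expected: ["models/AnimateDiff/mm_sd_v15_v2.ckpt"]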
243
+ class ModelManager:
244
+ def __init__(
245
+ self,
246
+ torch_dtype=torch.float16,
247
+ device="cuda",
248
+ model_id_list: List[Preset_model_id] = [],
249
+ downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
250
+ file_path_list: List[str] = [],
251
+ ):
252
+ self.torch_dtype = torch_dtype
253
+ self.device = device
254
+ self.model = {}
255
+ self.model_path = {}
256
+ self.textual_inversion_dict = {}
257
+ downloaded_files = download_models(model_id_list, downloading_priority)
258
+ self.load_models(downloaded_files + file_path_list)
259
+
260
+ def load_model_from_origin(
261
+ self,
262
+ download_from: Preset_model_website = "ModelScope",
263
+ model_id = "",
264
+ origin_file_path = "",
265
+ local_dir = ""
266
+ ):
267
+ website_to_download_fn[download_from](model_id, origin_file_path, local_dir)
268
+ file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
269
+ self.load_model(file_to_download)
270
+
271
+ def is_stable_video_diffusion(self, state_dict):
272
+ param_name = "model.diffusion_model.output_blocks.9.1.time_stack.0.norm_in.weight"
273
+ return param_name in state_dict
274
+
275
+ def is_RIFE(self, state_dict):
276
+ param_name = "block_tea.convblock3.0.1.weight"
277
+ return param_name in state_dict or ("module." + param_name) in state_dict
278
+
279
+ def is_beautiful_prompt(self, state_dict):
280
+ param_name = "transformer.h.9.self_attention.query_key_value.weight"
281
+ return param_name in state_dict
282
+
283
+ def is_stable_diffusion_xl(self, state_dict):
284
+ param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
285
+ return param_name in state_dict
286
+
287
+ def is_stable_diffusion(self, state_dict):
288
+ if self.is_stable_diffusion_xl(state_dict):
289
+ return False
290
+ param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
291
+ return param_name in state_dict
292
+
293
+ def is_controlnet(self, state_dict):
294
+ param_name = "control_model.time_embed.0.weight"
295
+ return param_name in state_dict
296
+
297
+ def is_animatediff(self, state_dict):
298
+ param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
299
+ return param_name in state_dict
300
+
301
+ def is_animatediff_xl(self, state_dict):
302
+ param_name = "up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight"
303
+ return param_name in state_dict
304
+
305
+ def is_sd_lora(self, state_dict):
306
+ param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
307
+ return param_name in state_dict
308
+
309
+ def is_translator(self, state_dict):
310
+ param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
311
+ return param_name in state_dict and len(state_dict) == 258
312
+
313
+ def is_ipadapter(self, state_dict):
314
+ return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([3072, 1024])
315
+
316
+ def is_ipadapter_image_encoder(self, state_dict):
317
+ param_name = "vision_model.encoder.layers.31.self_attn.v_proj.weight"
318
+ return param_name in state_dict and len(state_dict) == 521
319
+
320
+ def is_ipadapter_xl(self, state_dict):
321
+ return "image_proj" in state_dict and "ip_adapter" in state_dict and state_dict["image_proj"]["proj.weight"].shape == torch.Size([8192, 1280])
322
+
323
+ def is_ipadapter_xl_image_encoder(self, state_dict):
324
+ param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
325
+ return param_name in state_dict and len(state_dict) == 777
326
+
327
+ def is_hunyuan_dit_clip_text_encoder(self, state_dict):
328
+ param_name = "bert.encoder.layer.23.attention.output.dense.weight"
329
+ return param_name in state_dict
330
+
331
+ def is_hunyuan_dit_t5_text_encoder(self, state_dict):
332
+ param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
333
+ param_name_ = "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
334
+ return param_name in state_dict and param_name_ in state_dict
335
+
336
+ def is_hunyuan_dit(self, state_dict):
337
+ param_name = "final_layer.adaLN_modulation.1.weight"
338
+ return param_name in state_dict
339
+
340
+ def is_diffusers_vae(self, state_dict):
341
+ param_name = "quant_conv.weight"
342
+ return param_name in state_dict
343
+
344
+ def is_ExVideo_StableVideoDiffusion(self, state_dict):
345
+ param_name = "blocks.185.positional_embedding.embeddings"
346
+ return param_name in state_dict
347
+
348
+ def is_stable_diffusion_3(self, state_dict):
349
+ param_names = [
350
+ "text_encoders.clip_l.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight",
351
+ "text_encoders.clip_g.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight",
352
+ "model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.weight",
353
+ "first_stage_model.encoder.mid.block_2.norm2.weight",
354
+ "first_stage_model.decoder.mid.block_2.norm2.weight",
355
+ ]
356
+ for param_name in param_names:
357
+ if param_name not in state_dict:
358
+ return False
359
+ return True
360
+
361
+ def is_stable_diffusion_3_t5(self, state_dict):
362
+ param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
363
+ return param_name in state_dict
364
+
365
+ def is_kolors_text_encoder(self, file_path):
366
+ file_list = os.listdir(file_path)
367
+ if "config.json" in file_list:
368
+ try:
369
+ with open(os.path.join(file_path, "config.json"), "r") as f:
370
+ config = json.load(f)
371
+ if config.get("model_type") == "chatglm":
372
+ return True
373
+ except Exception:
374
+ pass
375
+ return False
376
+
377
+ def is_kolors_unet(self, state_dict):
378
+ return "up_blocks.2.resnets.2.time_emb_proj.weight" in state_dict and "encoder_hid_proj.weight" in state_dict
379
+
380
+ def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
381
+ component_dict = {
382
+ "image_encoder": SVDImageEncoder,
383
+ "unet": SVDUNet,
384
+ "vae_decoder": SVDVAEDecoder,
385
+ "vae_encoder": SVDVAEEncoder,
386
+ }
387
+ if components is None:
388
+ components = ["image_encoder", "unet", "vae_decoder", "vae_encoder"]
389
+ for component in components:
390
+ if component == "unet":
391
+ self.model[component] = component_dict[component](add_positional_conv=add_positional_conv)
392
+ self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False)
393
+ else:
394
+ self.model[component] = component_dict[component]()
395
+ self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
396
+ self.model[component].to(self.torch_dtype).to(self.device)
397
+ self.model_path[component] = file_path
398
+
399
+ def load_stable_diffusion(self, state_dict, components=None, file_path=""):
400
+ component_dict = {
401
+ "text_encoder": SDTextEncoder,
402
+ "unet": SDUNet,
403
+ "vae_decoder": SDVAEDecoder,
404
+ "vae_encoder": SDVAEEncoder,
405
+ "refiner": SDXLUNet,
406
+ }
407
+ if components is None:
408
+ components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
409
+ for component in components:
410
+ if component == "text_encoder":
411
+ # Add additional token embeddings to text encoder
412
+ token_embeddings = [state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"]]
413
+ for keyword in self.textual_inversion_dict:
414
+ _, embeddings = self.textual_inversion_dict[keyword]
415
+ token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
416
+ token_embeddings = torch.concat(token_embeddings, dim=0)
417
+ state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
418
+ self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
419
+ self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
420
+ self.model[component].to(self.torch_dtype).to(self.device)
421
+ else:
422
+ self.model[component] = component_dict[component]()
423
+ self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
424
+ self.model[component].to(self.torch_dtype).to(self.device)
425
+ self.model_path[component] = file_path
426
+
427
+ def load_stable_diffusion_xl(self, state_dict, components=None, file_path=""):
428
+ component_dict = {
429
+ "text_encoder": SDXLTextEncoder,
430
+ "text_encoder_2": SDXLTextEncoder2,
431
+ "unet": SDXLUNet,
432
+ "vae_decoder": SDXLVAEDecoder,
433
+ "vae_encoder": SDXLVAEEncoder,
434
+ }
435
+ if components is None:
436
+ components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
437
+ for component in components:
438
+ self.model[component] = component_dict[component]()
439
+ self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
440
+ if component in ["vae_decoder", "vae_encoder"]:
441
+ # These two models output NaN when float16 is enabled.
442
+ # The precision problem happens in the last three resnet blocks.
443
+ # I do not know how to solve this problem.
444
+ self.model[component].to(torch.float32).to(self.device)
445
+ else:
446
+ self.model[component].to(self.torch_dtype).to(self.device)
447
+ self.model_path[component] = file_path
448
+
449
+ def load_controlnet(self, state_dict, file_path=""):
450
+ component = "controlnet"
451
+ if component not in self.model:
452
+ self.model[component] = []
453
+ self.model_path[component] = []
454
+ model = SDControlNet()
455
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
456
+ model.to(self.torch_dtype).to(self.device)
457
+ self.model[component].append(model)
458
+ self.model_path[component].append(file_path)
459
+
460
+ def load_animatediff(self, state_dict, file_path=""):
461
+ component = "motion_modules"
462
+ model = SDMotionModel()
463
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
464
+ model.to(self.torch_dtype).to(self.device)
465
+ self.model[component] = model
466
+ self.model_path[component] = file_path
467
+
468
+ def load_animatediff_xl(self, state_dict, file_path=""):
469
+ component = "motion_modules_xl"
470
+ model = SDXLMotionModel()
471
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
472
+ model.to(self.torch_dtype).to(self.device)
473
+ self.model[component] = model
474
+ self.model_path[component] = file_path
475
+
476
+ def load_beautiful_prompt(self, state_dict, file_path=""):
477
+ component = "beautiful_prompt"
478
+ from transformers import AutoModelForCausalLM
479
+ model_folder = os.path.dirname(file_path)
480
+ model = AutoModelForCausalLM.from_pretrained(
481
+ model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
482
+ ).to(self.device).eval()
483
+ self.model[component] = model
484
+ self.model_path[component] = file_path
485
+
486
+ def load_RIFE(self, state_dict, file_path=""):
487
+ component = "RIFE"
488
+ from ..extensions.RIFE import IFNet
489
+ model = IFNet().eval()
490
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
491
+ model.to(torch.float32).to(self.device)
492
+ self.model[component] = model
493
+ self.model_path[component] = file_path
494
+
495
+ def load_sd_lora(self, state_dict, alpha):
496
+ SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device)
497
+ SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device)
498
+
499
+ def load_translator(self, state_dict, file_path=""):
500
+ # This model is lightweight, so we do not place it on the GPU.
501
+ component = "translator"
502
+ from transformers import AutoModelForSeq2SeqLM
503
+ model_folder = os.path.dirname(file_path)
504
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_folder).eval()
505
+ self.model[component] = model
506
+ self.model_path[component] = file_path
507
+
508
+ def load_ipadapter(self, state_dict, file_path=""):
509
+ component = "ipadapter"
510
+ model = SDIpAdapter()
511
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
512
+ model.to(self.torch_dtype).to(self.device)
513
+ self.model[component] = model
514
+ self.model_path[component] = file_path
515
+
516
+ def load_ipadapter_image_encoder(self, state_dict, file_path=""):
517
+ component = "ipadapter_image_encoder"
518
+ model = IpAdapterCLIPImageEmbedder()
519
+ model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
520
+ model.to(self.torch_dtype).to(self.device)
521
+ self.model[component] = model
522
+ self.model_path[component] = file_path
523
+
524
+ def load_ipadapter_xl(self, state_dict, file_path=""):
525
+ component = "ipadapter_xl"
526
+ model = SDXLIpAdapter()
527
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
528
+ model.to(self.torch_dtype).to(self.device)
529
+ self.model[component] = model
530
+ self.model_path[component] = file_path
531
+
532
+ def load_ipadapter_xl_image_encoder(self, state_dict, file_path=""):
533
+ component = "ipadapter_xl_image_encoder"
534
+ model = IpAdapterXLCLIPImageEmbedder()
535
+ model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
536
+ model.to(self.torch_dtype).to(self.device)
537
+ self.model[component] = model
538
+ self.model_path[component] = file_path
539
+
540
+ def load_hunyuan_dit_clip_text_encoder(self, state_dict, file_path=""):
541
+ component = "hunyuan_dit_clip_text_encoder"
542
+ model = HunyuanDiTCLIPTextEncoder()
543
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
544
+ model.to(self.torch_dtype).to(self.device)
545
+ self.model[component] = model
546
+ self.model_path[component] = file_path
547
+
548
+ def load_hunyuan_dit_t5_text_encoder(self, state_dict, file_path=""):
549
+ component = "hunyuan_dit_t5_text_encoder"
550
+ model = HunyuanDiTT5TextEncoder()
551
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
552
+ model.to(self.torch_dtype).to(self.device)
553
+ self.model[component] = model
554
+ self.model_path[component] = file_path
555
+
556
+ def load_hunyuan_dit(self, state_dict, file_path=""):
557
+ component = "hunyuan_dit"
558
+ model = HunyuanDiT()
559
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
560
+ model.to(self.torch_dtype).to(self.device)
561
+ self.model[component] = model
562
+ self.model_path[component] = file_path
563
+
564
+ def load_diffusers_vae(self, state_dict, file_path=""):
565
+ # TODO: detect SD and SDXL
566
+ component = "vae_encoder"
567
+ model = SDXLVAEEncoder()
568
+ model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
569
+ model.to(torch.float32).to(self.device)
570
+ self.model[component] = model
571
+ self.model_path[component] = file_path
572
+ component = "vae_decoder"
573
+ model = SDXLVAEDecoder()
574
+ model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
575
+ model.to(torch.float32).to(self.device)
576
+ self.model[component] = model
577
+ self.model_path[component] = file_path
578
+
579
+ def load_ExVideo_StableVideoDiffusion(self, state_dict, file_path=""):
580
+ unet_state_dict = self.model["unet"].state_dict()
581
+ self.model["unet"].to("cpu")
582
+ del self.model["unet"]
583
+ add_positional_conv = state_dict["blocks.185.positional_embedding.embeddings"].shape[0]
584
+ self.model["unet"] = SVDUNet(add_positional_conv=add_positional_conv)
585
+ self.model["unet"].load_state_dict(unet_state_dict, strict=False)
586
+ self.model["unet"].load_state_dict(state_dict, strict=False)
587
+ self.model["unet"].to(self.torch_dtype).to(self.device)
588
+
589
+ def load_stable_diffusion_3(self, state_dict, components=None, file_path=""):
590
+ component_dict = {
591
+ "sd3_text_encoder_1": SD3TextEncoder1,
592
+ "sd3_text_encoder_2": SD3TextEncoder2,
593
+ "sd3_text_encoder_3": SD3TextEncoder3,
594
+ "sd3_dit": SD3DiT,
595
+ "sd3_vae_decoder": SD3VAEDecoder,
596
+ "sd3_vae_encoder": SD3VAEEncoder,
597
+ }
598
+ if components is None:
599
+ components = ["sd3_text_encoder_1", "sd3_text_encoder_2", "sd3_text_encoder_3", "sd3_dit", "sd3_vae_decoder", "sd3_vae_encoder"]
600
+ for component in components:
601
+ if component == "sd3_text_encoder_3":
602
+ if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" not in state_dict:
603
+ continue
604
+ if component == "sd3_text_encoder_1":
605
+ # Add additional token embeddings to text encoder
606
+ token_embeddings = [state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"]]
607
+ for keyword in self.textual_inversion_dict:
608
+ _, embeddings = self.textual_inversion_dict[keyword]
609
+ token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
610
+ token_embeddings = torch.concat(token_embeddings, dim=0)
611
+ state_dict["text_encoders.clip_l.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
612
+ self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
613
+ self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
614
+ self.model[component].to(self.torch_dtype).to(self.device)
615
+ else:
616
+ self.model[component] = component_dict[component]()
617
+ self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
618
+ self.model[component].to(self.torch_dtype).to(self.device)
619
+ self.model_path[component] = file_path
620
+
621
+ def load_stable_diffusion_3_t5(self, state_dict, file_path=""):
622
+ component = "sd3_text_encoder_3"
623
+ model = SD3TextEncoder3()
624
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
625
+ model.to(self.torch_dtype).to(self.device)
626
+ self.model[component] = model
627
+ self.model_path[component] = file_path
628
+
629
+ def load_kolors_text_encoder(self, state_dict=None, file_path=""):
630
+ component = "kolors_text_encoder"
631
+ model = ChatGLMModel.from_pretrained(file_path, torch_dtype=self.torch_dtype)
632
+ model = model.to(dtype=self.torch_dtype, device=self.device)
633
+ self.model[component] = model
634
+ self.model_path[component] = file_path
635
+
636
+ def load_kolors_unet(self, state_dict, file_path=""):
637
+ component = "kolors_unet"
638
+ model = SDXLUNet(is_kolors=True)
639
+ model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
640
+ model.to(self.torch_dtype).to(self.device)
641
+ self.model[component] = model
642
+ self.model_path[component] = file_path
643
+
644
+ def search_for_embeddings(self, state_dict):
645
+ embeddings = []
646
+ for k in state_dict:
647
+ if isinstance(state_dict[k], torch.Tensor):
648
+ embeddings.append(state_dict[k])
649
+ elif isinstance(state_dict[k], dict):
650
+ embeddings += self.search_for_embeddings(state_dict[k])
651
+ return embeddings
652
+
653
+ def load_textual_inversions(self, folder):
654
+ # Store additional tokens here
655
+ self.textual_inversion_dict = {}
656
+
657
+ # Load every textual inversion file
658
+ for file_name in os.listdir(folder):
659
+ if os.path.isdir(os.path.join(folder, file_name)) or \
660
+ not (file_name.endswith(".bin") or \
661
+ file_name.endswith(".safetensors") or \
662
+ file_name.endswith(".pth") or \
663
+ file_name.endswith(".pt")):
664
+ continue
665
+ keyword = os.path.splitext(file_name)[0]
666
+ state_dict = load_state_dict(os.path.join(folder, file_name))
667
+
668
+ # Search for embeddings
669
+ for embeddings in self.search_for_embeddings(state_dict):
670
+ if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
671
+ tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
672
+ self.textual_inversion_dict[keyword] = (tokens, embeddings)
673
+ break
674
+
675
+ def load_model(self, file_path, components=None, lora_alphas=[]):
676
+ if os.path.isdir(file_path):
677
+ if self.is_kolors_text_encoder(file_path):
678
+ self.load_kolors_text_encoder(file_path=file_path)
679
+ return
680
+ state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
681
+ if self.is_stable_video_diffusion(state_dict):
682
+ self.load_stable_video_diffusion(state_dict, file_path=file_path)
683
+ elif self.is_animatediff(state_dict):
684
+ self.load_animatediff(state_dict, file_path=file_path)
685
+ elif self.is_animatediff_xl(state_dict):
686
+ self.load_animatediff_xl(state_dict, file_path=file_path)
687
+ elif self.is_controlnet(state_dict):
688
+ self.load_controlnet(state_dict, file_path=file_path)
689
+ elif self.is_stable_diffusion_xl(state_dict):
690
+ self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
691
+ elif self.is_stable_diffusion(state_dict):
692
+ self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
693
+ elif self.is_sd_lora(state_dict):
694
+ self.load_sd_lora(state_dict, alpha=lora_alphas.pop(0))
695
+ elif self.is_beautiful_prompt(state_dict):
696
+ self.load_beautiful_prompt(state_dict, file_path=file_path)
697
+ elif self.is_RIFE(state_dict):
698
+ self.load_RIFE(state_dict, file_path=file_path)
699
+ elif self.is_translator(state_dict):
700
+ self.load_translator(state_dict, file_path=file_path)
701
+ elif self.is_ipadapter(state_dict):
702
+ self.load_ipadapter(state_dict, file_path=file_path)
703
+ elif self.is_ipadapter_image_encoder(state_dict):
704
+ self.load_ipadapter_image_encoder(state_dict, file_path=file_path)
705
+ elif self.is_ipadapter_xl(state_dict):
706
+ self.load_ipadapter_xl(state_dict, file_path=file_path)
707
+ elif self.is_ipadapter_xl_image_encoder(state_dict):
708
+ self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
709
+ elif self.is_hunyuan_dit_clip_text_encoder(state_dict):
710
+ self.load_hunyuan_dit_clip_text_encoder(state_dict, file_path=file_path)
711
+ elif self.is_hunyuan_dit_t5_text_encoder(state_dict):
712
+ self.load_hunyuan_dit_t5_text_encoder(state_dict, file_path=file_path)
713
+ elif self.is_hunyuan_dit(state_dict):
714
+ self.load_hunyuan_dit(state_dict, file_path=file_path)
715
+ elif self.is_diffusers_vae(state_dict):
716
+ self.load_diffusers_vae(state_dict, file_path=file_path)
717
+ elif self.is_ExVideo_StableVideoDiffusion(state_dict):
718
+ self.load_ExVideo_StableVideoDiffusion(state_dict, file_path=file_path)
719
+ elif self.is_stable_diffusion_3(state_dict):
720
+ self.load_stable_diffusion_3(state_dict, components=components, file_path=file_path)
721
+ elif self.is_stable_diffusion_3_t5(state_dict):
722
+ self.load_stable_diffusion_3_t5(state_dict, file_path=file_path)
723
+ elif self.is_kolors_unet(state_dict):
724
+ self.load_kolors_unet(state_dict, file_path=file_path)
725
+
726
+ def load_models(self, file_path_list, lora_alphas=[]):
727
+ for file_path in file_path_list:
728
+ self.load_model(file_path, lora_alphas=lora_alphas)
729
+
730
+ def to(self, device):
731
+ for component in self.model:
732
+ if isinstance(self.model[component], list):
733
+ for model in self.model[component]:
734
+ model.to(device)
735
+ else:
736
+ self.model[component].to(device)
737
+ torch.cuda.empty_cache()
738
+
739
+ def get_model_with_model_path(self, model_path):
740
+ for component in self.model_path:
741
+ if isinstance(self.model_path[component], str):
742
+ if os.path.samefile(self.model_path[component], model_path):
743
+ return self.model[component]
744
+ elif isinstance(self.model_path[component], list):
745
+ for i, model_path_ in enumerate(self.model_path[component]):
746
+ if os.path.samefile(model_path_, model_path):
747
+ return self.model[component][i]
748
+ raise ValueError(f"Please load model {model_path} before you use it.")
749
+
750
+ def __getattr__(self, __name):
751
+ if __name in self.model:
752
+ return self.model[__name]
753
+ else:
754
+ return super().__getattribute__(__name)
755
+
756
+
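A hedged end-to-end sketch of driving the ModelManager defined above; the local checkpoint paths are hypothetical and only need to match one of the formats the is_* detectors recognize, and the example assumes a CUDA device is available:

import torch
from diffsynth.models import ModelManager  # assumes the package layout shown in this diff

model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_textual_inversions("models/textual_inversion")
model_manager.load_models([
    "models/stable_diffusion/v1-5-pruned-emaonly.safetensors",   # hypothetical local path
    "models/ControlNet/control_v11f1p_sd15_depth.pth",           # hypothetical local path
])
# load_model() sniffs each state dict and dispatches to the matching load_* method,
# so these attributes exist only after the corresponding files have been loaded.
unet = model_manager.unet              # resolved through __getattr__
controlnet = model_manager.controlnet  # a list, since several ControlNets can be loaded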
757
+ def load_state_dict(file_path, torch_dtype=None):
758
+ if file_path.endswith(".safetensors"):
759
+ return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
760
+ else:
761
+ return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
762
+
763
+
764
+ def load_state_dict_from_safetensors(file_path, torch_dtype=None):
765
+ state_dict = {}
766
+ with safe_open(file_path, framework="pt", device="cpu") as f:
767
+ for k in f.keys():
768
+ state_dict[k] = f.get_tensor(k)
769
+ if torch_dtype is not None:
770
+ state_dict[k] = state_dict[k].to(torch_dtype)
771
+ return state_dict
772
+
773
+
774
+ def load_state_dict_from_bin(file_path, torch_dtype=None):
775
+ state_dict = torch.load(file_path, map_location="cpu")
776
+ if torch_dtype is not None:
777
+ for i in state_dict:
778
+ if isinstance(state_dict[i], torch.Tensor):
779
+ state_dict[i] = state_dict[i].to(torch_dtype)
780
+ return state_dict
781
+
782
+
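The state-dict loaders above are also handy on their own; a small hedged example with a hypothetical path:

import torch
from diffsynth.models import load_state_dict  # assumes the package layout shown in this diff

state_dict = load_state_dict(
    "models/stable_diffusion/v1-5-pruned-emaonly.safetensors",  # hypothetical local file
    torch_dtype=torch.float16,
)
print(len(state_dict), next(iter(state_dict)))  # tensor count and the first key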
783
+ def search_parameter(param, state_dict):
784
+ for name, param_ in state_dict.items():
785
+ if param.numel() == param_.numel():
786
+ if param.shape == param_.shape:
787
+ if torch.dist(param, param_) < 1e-6:
788
+ return name
789
+ else:
790
+ if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
791
+ return name
792
+ return None
793
+
794
+
795
+ def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
796
+ matched_keys = set()
797
+ with torch.no_grad():
798
+ for name in source_state_dict:
799
+ rename = search_parameter(source_state_dict[name], target_state_dict)
800
+ if rename is not None:
801
+ print(f'"{name}": "{rename}",')
802
+ matched_keys.add(rename)
803
+ elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
804
+ length = source_state_dict[name].shape[0] // 3
805
+ rename = []
806
+ for i in range(3):
807
+ rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
808
+ if None not in rename:
809
+ print(f'"{name}": {rename},')
810
+ for rename_ in rename:
811
+ matched_keys.add(rename_)
812
+ for name in target_state_dict:
813
+ if name not in matched_keys:
814
+ print("Cannot find", name, target_state_dict[name].shape)
diffsynth/models/attention.py ADDED
@@ -0,0 +1,89 @@
1
+ import torch
2
+ from einops import rearrange
3
+
4
+
5
+ def low_version_attention(query, key, value, attn_bias=None):
6
+ scale = 1 / query.shape[-1] ** 0.5
7
+ query = query * scale
8
+ attn = torch.matmul(query, key.transpose(-2, -1))
9
+ if attn_bias is not None:
10
+ attn = attn + attn_bias
11
+ attn = attn.softmax(-1)
12
+ return attn @ value
13
+
14
+
15
+ class Attention(torch.nn.Module):
16
+
17
+ def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
18
+ super().__init__()
19
+ dim_inner = head_dim * num_heads
20
+ kv_dim = kv_dim if kv_dim is not None else q_dim
21
+ self.num_heads = num_heads
22
+ self.head_dim = head_dim
23
+
24
+ self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
25
+ self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
26
+ self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
27
+ self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
28
+
29
+ def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
30
+ batch_size = q.shape[0]
31
+ ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
32
+ ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
33
+ ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
34
+ hidden_states = hidden_states + scale * ip_hidden_states
35
+ return hidden_states
36
+
37
+ def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
38
+ if encoder_hidden_states is None:
39
+ encoder_hidden_states = hidden_states
40
+
41
+ batch_size = encoder_hidden_states.shape[0]
42
+
43
+ q = self.to_q(hidden_states)
44
+ k = self.to_k(encoder_hidden_states)
45
+ v = self.to_v(encoder_hidden_states)
46
+
47
+ q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
48
+ k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
49
+ v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
50
+
51
+ if qkv_preprocessor is not None:
52
+ q, k, v = qkv_preprocessor(q, k, v)
53
+
54
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
55
+ if ipadapter_kwargs is not None:
56
+ hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
57
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
58
+ hidden_states = hidden_states.to(q.dtype)
59
+
60
+ hidden_states = self.to_out(hidden_states)
61
+
62
+ return hidden_states
63
+
64
+ def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
65
+ if encoder_hidden_states is None:
66
+ encoder_hidden_states = hidden_states
67
+
68
+ q = self.to_q(hidden_states)
69
+ k = self.to_k(encoder_hidden_states)
70
+ v = self.to_v(encoder_hidden_states)
71
+
72
+ q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
73
+ k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
74
+ v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
75
+
76
+ if attn_mask is not None:
77
+ hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
78
+ else:
79
+ import xformers.ops as xops
80
+ hidden_states = xops.memory_efficient_attention(q, k, v)
81
+ hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
82
+
83
+ hidden_states = hidden_states.to(q.dtype)
84
+ hidden_states = self.to_out(hidden_states)
85
+
86
+ return hidden_states
87
+
88
+ def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
89
+ return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
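A minimal shape check for the Attention module above; the sizes are arbitrary and it assumes the file is importable as diffsynth.models.attention:

import torch
from diffsynth.models.attention import Attention

attn = Attention(q_dim=320, num_heads=8, head_dim=40, kv_dim=768,
                 bias_q=True, bias_kv=True, bias_out=True)
x = torch.randn(2, 64, 320)        # (batch, image tokens, q_dim)
context = torch.randn(2, 77, 768)  # (batch, text tokens, kv_dim) for cross-attention
out = attn(x, encoder_hidden_states=context)
print(out.shape)                   # torch.Size([2, 64, 320])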
diffsynth/models/downloader.py ADDED
@@ -0,0 +1,28 @@
1
+ from huggingface_hub import hf_hub_download
2
+ from modelscope import snapshot_download
3
+ import os, shutil
4
+
5
+
6
+ def download_from_modelscope(model_id, origin_file_path, local_dir):
7
+ os.makedirs(local_dir, exist_ok=True)
8
+ if os.path.basename(origin_file_path) in os.listdir(local_dir):
9
+ print(f"{os.path.basename(origin_file_path)} has been already in {local_dir}.")
10
+ return
11
+ else:
12
+ print(f"Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
13
+ snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
14
+ downloaded_file_path = os.path.join(local_dir, origin_file_path)
15
+ target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
16
+ if downloaded_file_path != target_file_path:
17
+ shutil.move(downloaded_file_path, target_file_path)
18
+ shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
19
+
20
+
21
+ def download_from_huggingface(model_id, origin_file_path, local_dir):
22
+ os.makedirs(local_dir, exist_ok=True)
23
+ if os.path.basename(origin_file_path) in os.listdir(local_dir):
24
+ print(f"{os.path.basename(origin_file_path)} has been already in {local_dir}.")
25
+ return
26
+ else:
27
+ print(f"Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
28
+ hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
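Both downloaders share the same signature; a hedged example that fetches one ControlNet weight listed in the preset tables earlier in this diff:

from diffsynth.models.downloader import download_from_modelscope  # assumed importable as above

download_from_modelscope(
    "AI-ModelScope/ControlNet-v1-1",
    "control_v11f1p_sd15_depth.pth",
    "models/ControlNet",
)
# The file ends up at models/ControlNet/control_v11f1p_sd15_depth.pth,
# and the call is a no-op if it is already there.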
diffsynth/models/hunyuan_dit.py ADDED
@@ -0,0 +1,451 @@
1
+ from .attention import Attention
2
+ from .tiler import TileWorker
3
+ from einops import repeat, rearrange
4
+ import math
5
+ import torch
6
+
7
+
8
+ class HunyuanDiTRotaryEmbedding(torch.nn.Module):
9
+
10
+ def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=True):
11
+ super().__init__()
12
+ self.q_norm = torch.nn.LayerNorm((q_norm_shape,), elementwise_affine=True, eps=1e-06)
13
+ self.k_norm = torch.nn.LayerNorm((k_norm_shape,), elementwise_affine=True, eps=1e-06)
14
+ self.rotary_emb_on_k = rotary_emb_on_k
15
+ self.k_cache, self.v_cache = [], []
16
+
17
+ def reshape_for_broadcast(self, freqs_cis, x):
18
+ ndim = x.ndim
19
+ shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
20
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
21
+
22
+ def rotate_half(self, x):
23
+ x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
24
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
25
+
26
+ def apply_rotary_emb(self, xq, xk, freqs_cis):
27
+ xk_out = None
28
+ cos, sin = self.reshape_for_broadcast(freqs_cis, xq)
29
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
30
+ xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
31
+ if xk is not None:
32
+ xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
33
+ return xq_out, xk_out
34
+
35
+ def forward(self, q, k, v, freqs_cis_img, to_cache=False):
36
+ # norm
37
+ q = self.q_norm(q)
38
+ k = self.k_norm(k)
39
+
40
+ # RoPE
41
+ if self.rotary_emb_on_k:
42
+ q, k = self.apply_rotary_emb(q, k, freqs_cis_img)
43
+ else:
44
+ q, _ = self.apply_rotary_emb(q, None, freqs_cis_img)
45
+
46
+ if to_cache:
47
+ self.k_cache.append(k)
48
+ self.v_cache.append(v)
49
+ elif len(self.k_cache) > 0 and len(self.v_cache) > 0:
50
+ k = torch.concat([k] + self.k_cache, dim=2)
51
+ v = torch.concat([v] + self.v_cache, dim=2)
52
+ self.k_cache, self.v_cache = [], []
53
+ return q, k, v
54
+
55
+
56
+ class FP32_Layernorm(torch.nn.LayerNorm):
57
+ def forward(self, inputs):
58
+ origin_dtype = inputs.dtype
59
+ return torch.nn.functional.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).to(origin_dtype)
60
+
61
+
62
+ class FP32_SiLU(torch.nn.SiLU):
63
+ def forward(self, inputs):
64
+ origin_dtype = inputs.dtype
65
+ return torch.nn.functional.silu(inputs.float(), inplace=False).to(origin_dtype)
66
+
67
+
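A small sketch of what the FP32_* wrappers above provide: the op is computed in float32 while the caller's dtype is preserved (half precision here). It assumes the module is importable as diffsynth.models.hunyuan_dit:

import torch
from diffsynth.models.hunyuan_dit import FP32_Layernorm, FP32_SiLU

norm = FP32_Layernorm((1408,), eps=1e-6, elementwise_affine=True)
silu = FP32_SiLU()
x = torch.randn(1, 16, 1408, dtype=torch.float16)
# Outputs come back in float16, but the normalization and activation ran in float32.
print(norm(x).dtype, silu(x).dtype)  # torch.float16 torch.float16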
68
+ class HunyuanDiTFinalLayer(torch.nn.Module):
69
+ def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_size=2, out_channels=8):
70
+ super().__init__()
71
+ self.norm_final = torch.nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
72
+ self.linear = torch.nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
73
+ self.adaLN_modulation = torch.nn.Sequential(
74
+ FP32_SiLU(),
75
+ torch.nn.Linear(condition_dim, 2 * final_hidden_size, bias=True)
76
+ )
77
+
78
+ def modulate(self, x, shift, scale):
79
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
80
+
81
+ def forward(self, hidden_states, condition_emb):
82
+ shift, scale = self.adaLN_modulation(condition_emb).chunk(2, dim=1)
83
+ hidden_states = self.modulate(self.norm_final(hidden_states), shift, scale)
84
+ hidden_states = self.linear(hidden_states)
85
+ return hidden_states
86
+
87
+
88
+ class HunyuanDiTBlock(torch.nn.Module):
89
+
90
+ def __init__(
91
+ self,
92
+ hidden_dim=1408,
93
+ condition_dim=1408,
94
+ num_heads=16,
95
+ mlp_ratio=4.3637,
96
+ text_dim=1024,
97
+ skip_connection=False
98
+ ):
99
+ super().__init__()
100
+ self.norm1 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
101
+ self.rota1 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads)
102
+ self.attn1 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
103
+ self.norm2 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
104
+ self.rota2 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads, rotary_emb_on_k=False)
105
+ self.attn2 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, kv_dim=text_dim, bias_q=True, bias_kv=True, bias_out=True)
106
+ self.norm3 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
107
+ self.modulation = torch.nn.Sequential(FP32_SiLU(), torch.nn.Linear(condition_dim, hidden_dim, bias=True))
108
+ self.mlp = torch.nn.Sequential(
109
+ torch.nn.Linear(hidden_dim, int(hidden_dim*mlp_ratio), bias=True),
110
+ torch.nn.GELU(approximate="tanh"),
111
+ torch.nn.Linear(int(hidden_dim*mlp_ratio), hidden_dim, bias=True)
112
+ )
113
+ if skip_connection:
114
+ self.skip_norm = FP32_Layernorm((hidden_dim * 2,), eps=1e-6, elementwise_affine=True)
115
+ self.skip_linear = torch.nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
116
+ else:
117
+ self.skip_norm, self.skip_linear = None, None
118
+
119
+ def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img, residual=None, to_cache=False):
120
+ # Long Skip Connection
121
+ if self.skip_norm is not None and self.skip_linear is not None:
122
+ hidden_states = torch.cat([hidden_states, residual], dim=-1)
123
+ hidden_states = self.skip_norm(hidden_states)
124
+ hidden_states = self.skip_linear(hidden_states)
125
+
126
+ # Self-Attention
127
+ shift_msa = self.modulation(condition_emb).unsqueeze(dim=1)
128
+ attn_input = self.norm1(hidden_states) + shift_msa
129
+ hidden_states = hidden_states + self.attn1(attn_input, qkv_preprocessor=lambda q, k, v: self.rota1(q, k, v, freq_cis_img, to_cache=to_cache))
130
+
131
+ # Cross-Attention
132
+ attn_input = self.norm3(hidden_states)
133
+ hidden_states = hidden_states + self.attn2(attn_input, text_emb, qkv_preprocessor=lambda q, k, v: self.rota2(q, k, v, freq_cis_img))
134
+
135
+ # FFN Layer
136
+ mlp_input = self.norm2(hidden_states)
137
+ hidden_states = hidden_states + self.mlp(mlp_input)
138
+ return hidden_states
139
+
140
+
141
+ class AttentionPool(torch.nn.Module):
142
+ def __init__(self, spacial_dim, embed_dim, num_heads, output_dim = None):
143
+ super().__init__()
144
+ self.positional_embedding = torch.nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
145
+ self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
146
+ self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
147
+ self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
148
+ self.c_proj = torch.nn.Linear(embed_dim, output_dim or embed_dim)
149
+ self.num_heads = num_heads
150
+
151
+ def forward(self, x):
152
+ x = x.permute(1, 0, 2) # NLC -> LNC
153
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
154
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
155
+ x, _ = torch.nn.functional.multi_head_attention_forward(
156
+ query=x[:1], key=x, value=x,
157
+ embed_dim_to_check=x.shape[-1],
158
+ num_heads=self.num_heads,
159
+ q_proj_weight=self.q_proj.weight,
160
+ k_proj_weight=self.k_proj.weight,
161
+ v_proj_weight=self.v_proj.weight,
162
+ in_proj_weight=None,
163
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
164
+ bias_k=None,
165
+ bias_v=None,
166
+ add_zero_attn=False,
167
+ dropout_p=0,
168
+ out_proj_weight=self.c_proj.weight,
169
+ out_proj_bias=self.c_proj.bias,
170
+ use_separate_proj_weight=True,
171
+ training=self.training,
172
+ need_weights=False
173
+ )
174
+ return x.squeeze(0)
175
+
176
+
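A quick shape check for AttentionPool, which pools the 256-token T5 sequence into a single 1024-dimensional vector; the sizes mirror the defaults used in HunyuanDiT below:

import torch
from diffsynth.models.hunyuan_dit import AttentionPool  # assumed importable as above

pool = AttentionPool(spacial_dim=256, embed_dim=2048, num_heads=8, output_dim=1024)
t5_sequence = torch.randn(2, 256, 2048)  # (batch, T5 tokens, T5 hidden dim)
print(pool(t5_sequence).shape)           # torch.Size([2, 1024])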
177
+ class PatchEmbed(torch.nn.Module):
178
+ def __init__(
179
+ self,
180
+ patch_size=(2, 2),
181
+ in_chans=4,
182
+ embed_dim=1408,
183
+ bias=True,
184
+ ):
185
+ super().__init__()
186
+ self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
187
+
188
+ def forward(self, x):
189
+ x = self.proj(x)
190
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
191
+ return x
192
+
193
+
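A patchification sanity check; the latent size is illustrative (a 1024x1024 image gives a 128x128 latent with a typical 8x VAE):

import torch
from diffsynth.models.hunyuan_dit import PatchEmbed  # assumed importable as above

patch_embedder = PatchEmbed(patch_size=(2, 2), in_chans=4, embed_dim=1408)
latents = torch.randn(1, 4, 128, 128)
tokens = patch_embedder(latents)
print(tokens.shape)  # torch.Size([1, 4096, 1408]): (128/2) * (128/2) patch tokens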
194
+ def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
195
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
196
+ if not repeat_only:
197
+ half = dim // 2
198
+ freqs = torch.exp(
199
+ -math.log(max_period)
200
+ * torch.arange(start=0, end=half, dtype=torch.float32)
201
+ / half
202
+ ).to(device=t.device) # size: [dim/2], an exponentially decaying curve
203
+ args = t[:, None].float() * freqs[None]
204
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
205
+ if dim % 2:
206
+ embedding = torch.cat(
207
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
208
+ )
209
+ else:
210
+ embedding = repeat(t, "b -> b d", d=dim)
211
+ return embedding
212
+
213
+
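For intuition, the sinusoidal embedding above maps scalar timesteps to fixed-size vectors (first half cosines, second half sines):

import torch
from diffsynth.models.hunyuan_dit import timestep_embedding  # assumed importable as above

t = torch.tensor([0, 500, 999])
emb = timestep_embedding(t, dim=256)
print(emb.shape)  # torch.Size([3, 256])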
214
+ class TimestepEmbedder(torch.nn.Module):
215
+ def __init__(self, hidden_size=1408, frequency_embedding_size=256):
216
+ super().__init__()
217
+ self.mlp = torch.nn.Sequential(
218
+ torch.nn.Linear(frequency_embedding_size, hidden_size, bias=True),
219
+ torch.nn.SiLU(),
220
+ torch.nn.Linear(hidden_size, hidden_size, bias=True),
221
+ )
222
+ self.frequency_embedding_size = frequency_embedding_size
223
+
224
+ def forward(self, t):
225
+ t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
226
+ t_emb = self.mlp(t_freq)
227
+ return t_emb
228
+
229
+
230
+ class HunyuanDiT(torch.nn.Module):
231
+ def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4, out_channels=8, hidden_dim=1408, text_dim=1024, t5_dim=2048, text_length=77, t5_length=256):
232
+ super().__init__()
233
+
234
+ # Embedders
235
+ self.text_emb_padding = torch.nn.Parameter(torch.randn(text_length + t5_length, text_dim, dtype=torch.float32))
236
+ self.t5_embedder = torch.nn.Sequential(
237
+ torch.nn.Linear(t5_dim, t5_dim * 4, bias=True),
238
+ FP32_SiLU(),
239
+ torch.nn.Linear(t5_dim * 4, text_dim, bias=True),
240
+ )
241
+ self.t5_pooler = AttentionPool(t5_length, t5_dim, num_heads=8, output_dim=1024)
242
+ self.style_embedder = torch.nn.Parameter(torch.randn(hidden_dim))
243
+ self.patch_embedder = PatchEmbed(in_chans=in_channels)
244
+ self.timestep_embedder = TimestepEmbedder()
245
+ self.extra_embedder = torch.nn.Sequential(
246
+ torch.nn.Linear(256 * 6 + 1024 + hidden_dim, hidden_dim * 4),
247
+ FP32_SiLU(),
248
+ torch.nn.Linear(hidden_dim * 4, hidden_dim),
249
+ )
250
+
251
+ # Transformer blocks
252
+ self.num_layers_down = num_layers_down
253
+ self.num_layers_up = num_layers_up
254
+ self.blocks = torch.nn.ModuleList(
255
+ [HunyuanDiTBlock(skip_connection=False) for _ in range(num_layers_down)] + \
256
+ [HunyuanDiTBlock(skip_connection=True) for _ in range(num_layers_up)]
257
+ )
258
+
259
+ # Output layers
260
+ self.final_layer = HunyuanDiTFinalLayer()
261
+ self.out_channels = out_channels
262
+
263
+ def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5):
264
+ text_emb_mask = text_emb_mask.bool()
265
+ text_emb_mask_t5 = text_emb_mask_t5.bool()
266
+ text_emb_t5 = self.t5_embedder(text_emb_t5)
267
+ text_emb = torch.cat([text_emb, text_emb_t5], dim=1)
268
+ text_emb_mask = torch.cat([text_emb_mask, text_emb_mask_t5], dim=-1)
269
+ text_emb = torch.where(text_emb_mask.unsqueeze(2), text_emb, self.text_emb_padding.to(text_emb))
270
+ return text_emb
271
+
272
+ def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, batch_size):
273
+ # Text embedding
274
+ pooled_text_emb_t5 = self.t5_pooler(text_emb_t5)
275
+
276
+ # Timestep embedding
277
+ timestep_emb = self.timestep_embedder(timestep)
278
+
279
+ # Size embedding
280
+ size_emb = timestep_embedding(size_emb.view(-1), 256).to(dtype)
281
+ size_emb = size_emb.view(-1, 6 * 256)
282
+
283
+ # Style embedding
284
+ style_emb = repeat(self.style_embedder, "D -> B D", B=batch_size)
285
+
286
+ # Concatenate all extra vectors
287
+ extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1)
288
+ condition_emb = timestep_emb + self.extra_embedder(extra_emb)
289
+
290
+ return condition_emb
291
+
292
+ def unpatchify(self, x, h, w):
293
+ return rearrange(x, "B (H W) (P Q C) -> B C (H P) (W Q)", H=h, W=w, P=2, Q=2)
294
+
295
+ def build_mask(self, data, is_bound):
296
+ _, _, H, W = data.shape
297
+ h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
298
+ w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
299
+ border_width = (H + W) // 4
300
+ pad = torch.ones_like(h) * border_width
301
+ mask = torch.stack([
302
+ pad if is_bound[0] else h + 1,
303
+ pad if is_bound[1] else H - h,
304
+ pad if is_bound[2] else w + 1,
305
+ pad if is_bound[3] else W - w
306
+ ]).min(dim=0).values
307
+ mask = mask.clip(1, border_width)
308
+ mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
309
+ mask = rearrange(mask, "H W -> 1 H W")
310
+ return mask
311
+
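build_mask produces the feathered blending weights used by the tiled forward pass below; it does not touch self, so a hedged unbound call on a tiny tensor is enough to see the shape and value range:

import torch
from diffsynth.models.hunyuan_dit import HunyuanDiT  # assumed importable as above

tile = torch.zeros(1, 4, 8, 8)
mask = HunyuanDiT.build_mask(None, tile, is_bound=(True, False, True, False))
print(mask.shape, float(mask.min()), float(mask.max()))  # torch.Size([1, 8, 8]), values in (0, 1]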
312
+ def tiled_block_forward(self, block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, torch_dtype, data_device, computation_device, tile_size, tile_stride):
313
+ B, C, H, W = hidden_states.shape
314
+
315
+ weight = torch.zeros((1, 1, H, W), dtype=torch_dtype, device=data_device)
316
+ values = torch.zeros((B, C, H, W), dtype=torch_dtype, device=data_device)
317
+
318
+ # Split tasks
319
+ tasks = []
320
+ for h in range(0, H, tile_stride):
321
+ for w in range(0, W, tile_stride):
322
+ if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
323
+ continue
324
+ h_, w_ = h + tile_size, w + tile_size
325
+ if h_ > H: h, h_ = H - tile_size, H
326
+ if w_ > W: w, w_ = W - tile_size, W
327
+ tasks.append((h, h_, w, w_))
328
+
329
+ # Run
330
+ for hl, hr, wl, wr in tasks:
331
+ hidden_states_batch = hidden_states[:, :, hl:hr, wl:wr].to(computation_device)
332
+ hidden_states_batch = rearrange(hidden_states_batch, "B C H W -> B (H W) C")
333
+ if residual is not None:
334
+ residual_batch = residual[:, :, hl:hr, wl:wr].to(computation_device)
335
+ residual_batch = rearrange(residual_batch, "B C H W -> B (H W) C")
336
+ else:
337
+ residual_batch = None
338
+
339
+ # Forward
340
+ hidden_states_batch = block(hidden_states_batch, condition_emb, text_emb, freq_cis_img, residual_batch).to(data_device)
341
+ hidden_states_batch = rearrange(hidden_states_batch, "B (H W) C -> B C H W", H=hr-hl)
342
+
343
+ mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
344
+ values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
345
+ weight[:, :, hl:hr, wl:wr] += mask
346
+ values /= weight
347
+ return values
348
+
349
+ def forward(
350
+ self, hidden_states, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5, timestep, size_emb, freq_cis_img,
351
+ tiled=False, tile_size=64, tile_stride=32,
352
+ to_cache=False,
353
+ use_gradient_checkpointing=False,
354
+ ):
355
+ # Embeddings
356
+ text_emb = self.prepare_text_emb(text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5)
357
+ condition_emb = self.prepare_extra_emb(text_emb_t5, timestep, size_emb, hidden_states.dtype, hidden_states.shape[0])
358
+
359
+ # Input
360
+ height, width = hidden_states.shape[-2], hidden_states.shape[-1]
361
+ hidden_states = self.patch_embedder(hidden_states)
362
+
363
+ # Blocks
364
+ def create_custom_forward(module):
365
+ def custom_forward(*inputs):
366
+ return module(*inputs)
367
+ return custom_forward
368
+ if tiled:
369
+ hidden_states = rearrange(hidden_states, "B (H W) C -> B C H W", H=height//2)
370
+ residuals = []
371
+ for block_id, block in enumerate(self.blocks):
372
+ residual = residuals.pop() if block_id >= self.num_layers_down else None
373
+ hidden_states = self.tiled_block_forward(
374
+ block, hidden_states, condition_emb, text_emb, freq_cis_img, residual,
375
+ torch_dtype=hidden_states.dtype, data_device=hidden_states.device, computation_device=hidden_states.device,
376
+ tile_size=tile_size, tile_stride=tile_stride
377
+ )
378
+ if block_id < self.num_layers_down - 2:
379
+ residuals.append(hidden_states)
380
+ hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C")
381
+ else:
382
+ residuals = []
383
+ for block_id, block in enumerate(self.blocks):
384
+ residual = residuals.pop() if block_id >= self.num_layers_down else None
385
+ if self.training and use_gradient_checkpointing:
386
+ hidden_states = torch.utils.checkpoint.checkpoint(
387
+ create_custom_forward(block),
388
+ hidden_states, condition_emb, text_emb, freq_cis_img, residual,
389
+ use_reentrant=False,
390
+ )
391
+ else:
392
+ hidden_states = block(hidden_states, condition_emb, text_emb, freq_cis_img, residual, to_cache=to_cache)
393
+ if block_id < self.num_layers_down - 2:
394
+ residuals.append(hidden_states)
395
+
396
+ # Output
397
+ hidden_states = self.final_layer(hidden_states, condition_emb)
398
+ hidden_states = self.unpatchify(hidden_states, height//2, width//2)
399
+ hidden_states, _ = hidden_states.chunk(2, dim=1)
400
+ return hidden_states
401
+
402
+ def state_dict_converter(self):
403
+ return HunyuanDiTStateDictConverter()
404
+
405
+
406
+
407
+ class HunyuanDiTStateDictConverter():
408
+ def __init__(self):
409
+ pass
410
+
411
+ def from_diffusers(self, state_dict):
412
+ state_dict_ = {}
413
+ for name, param in state_dict.items():
414
+ name_ = name
415
+ name_ = name_.replace(".default_modulation.", ".modulation.")
416
+ name_ = name_.replace(".mlp.fc1.", ".mlp.0.")
417
+ name_ = name_.replace(".mlp.fc2.", ".mlp.2.")
418
+ name_ = name_.replace(".attn1.q_norm.", ".rota1.q_norm.")
419
+ name_ = name_.replace(".attn2.q_norm.", ".rota2.q_norm.")
420
+ name_ = name_.replace(".attn1.k_norm.", ".rota1.k_norm.")
421
+ name_ = name_.replace(".attn2.k_norm.", ".rota2.k_norm.")
422
+ name_ = name_.replace(".q_proj.", ".to_q.")
423
+ name_ = name_.replace(".out_proj.", ".to_out.")
424
+ name_ = name_.replace("text_embedding_padding", "text_emb_padding")
425
+ name_ = name_.replace("mlp_t5.0.", "t5_embedder.0.")
426
+ name_ = name_.replace("mlp_t5.2.", "t5_embedder.2.")
427
+ name_ = name_.replace("pooler.", "t5_pooler.")
428
+ name_ = name_.replace("x_embedder.", "patch_embedder.")
429
+ name_ = name_.replace("t_embedder.", "timestep_embedder.")
430
+ name_ = name_.replace("t5_pooler.to_q.", "t5_pooler.q_proj.")
431
+ name_ = name_.replace("style_embedder.weight", "style_embedder")
432
+ if ".kv_proj." in name_:
433
+ param_k = param[:param.shape[0]//2]
434
+ param_v = param[param.shape[0]//2:]
435
+ state_dict_[name_.replace(".kv_proj.", ".to_k.")] = param_k
436
+ state_dict_[name_.replace(".kv_proj.", ".to_v.")] = param_v
437
+ elif ".Wqkv." in name_:
438
+ param_q = param[:param.shape[0]//3]
439
+ param_k = param[param.shape[0]//3:param.shape[0]//3*2]
440
+ param_v = param[param.shape[0]//3*2:]
441
+ state_dict_[name_.replace(".Wqkv.", ".to_q.")] = param_q
442
+ state_dict_[name_.replace(".Wqkv.", ".to_k.")] = param_k
443
+ state_dict_[name_.replace(".Wqkv.", ".to_v.")] = param_v
444
+ elif "style_embedder" in name_:
445
+ state_dict_[name_] = param.squeeze()
446
+ else:
447
+ state_dict_[name_] = param
448
+ return state_dict_
449
+
450
+ def from_civitai(self, state_dict):
451
+ return self.from_diffusers(state_dict)
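The `tiled_block_forward` method above accumulates each overlapping tile's output weighted by the feathering mask from `build_mask`, then divides by the summed weights so tile seams blend smoothly. A minimal standalone sketch of this accumulate-then-normalize pattern (not part of the uploaded file; the identity "block" and constant mask are illustrative stand-ins for the real transformer block and feathered ramp):

import torch

H, W, tile, stride = 8, 8, 4, 2
x = torch.arange(H * W, dtype=torch.float32).reshape(1, 1, H, W)
values = torch.zeros_like(x)
weight = torch.zeros(1, 1, H, W)
for hl in range(0, H - tile + 1, stride):
    for wl in range(0, W - tile + 1, stride):
        tile_out = x[:, :, hl:hl+tile, wl:wl+tile]   # stand-in for block(...) on one tile
        mask = torch.ones(1, tile, tile)             # the real code uses build_mask's feathered ramp
        values[:, :, hl:hl+tile, wl:wl+tile] += tile_out * mask
        weight[:, :, hl:hl+tile, wl:wl+tile] += mask
values /= weight
assert torch.allclose(values, x)  # blending an identity transform reconstructs the input

With the feathered mask from `build_mask`, tiles fade out toward their interior edges, so overlapping regions are a weighted average of neighboring tiles rather than a hard seam.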

diffsynth/models/hunyuan_dit_text_encoder.py ADDED
@@ -0,0 +1,161 @@
1
+ from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
2
+ import torch
3
+
4
+
5
+
6
+ class HunyuanDiTCLIPTextEncoder(BertModel):
7
+ def __init__(self):
8
+ config = BertConfig(
9
+ _name_or_path = "",
10
+ architectures = ["BertModel"],
11
+ attention_probs_dropout_prob = 0.1,
12
+ bos_token_id = 0,
13
+ classifier_dropout = None,
14
+ directionality = "bidi",
15
+ eos_token_id = 2,
16
+ hidden_act = "gelu",
17
+ hidden_dropout_prob = 0.1,
18
+ hidden_size = 1024,
19
+ initializer_range = 0.02,
20
+ intermediate_size = 4096,
21
+ layer_norm_eps = 1e-12,
22
+ max_position_embeddings = 512,
23
+ model_type = "bert",
24
+ num_attention_heads = 16,
25
+ num_hidden_layers = 24,
26
+ output_past = True,
27
+ pad_token_id = 0,
28
+ pooler_fc_size = 768,
29
+ pooler_num_attention_heads = 12,
30
+ pooler_num_fc_layers = 3,
31
+ pooler_size_per_head = 128,
32
+ pooler_type = "first_token_transform",
33
+ position_embedding_type = "absolute",
34
+ torch_dtype = "float32",
35
+ transformers_version = "4.37.2",
36
+ type_vocab_size = 2,
37
+ use_cache = True,
38
+ vocab_size = 47020
39
+ )
40
+ super().__init__(config, add_pooling_layer=False)
41
+ self.eval()
42
+
43
+ def forward(self, input_ids, attention_mask, clip_skip=1):
44
+ input_shape = input_ids.size()
45
+
46
+ batch_size, seq_length = input_shape
47
+ device = input_ids.device
48
+
49
+ past_key_values_length = 0
50
+
51
+ if attention_mask is None:
52
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
53
+
54
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
55
+
56
+ embedding_output = self.embeddings(
57
+ input_ids=input_ids,
58
+ position_ids=None,
59
+ token_type_ids=None,
60
+ inputs_embeds=None,
61
+ past_key_values_length=0,
62
+ )
63
+ encoder_outputs = self.encoder(
64
+ embedding_output,
65
+ attention_mask=extended_attention_mask,
66
+ head_mask=None,
67
+ encoder_hidden_states=None,
68
+ encoder_attention_mask=None,
69
+ past_key_values=None,
70
+ use_cache=False,
71
+ output_attentions=False,
72
+ output_hidden_states=True,
73
+ return_dict=True,
74
+ )
75
+ all_hidden_states = encoder_outputs.hidden_states
76
+ prompt_emb = all_hidden_states[-clip_skip]
77
+ if clip_skip > 1:
78
+ mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std()
79
+ prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
80
+ return prompt_emb
81
+
82
+ def state_dict_converter(self):
83
+ return HunyuanDiTCLIPTextEncoderStateDictConverter()
84
+
85
+
86
+
87
+ class HunyuanDiTT5TextEncoder(T5EncoderModel):
88
+ def __init__(self):
89
+ config = T5Config(
90
+ _name_or_path = "../HunyuanDiT/t2i/mt5",
91
+ architectures = ["MT5ForConditionalGeneration"],
92
+ classifier_dropout = 0.0,
93
+ d_ff = 5120,
94
+ d_kv = 64,
95
+ d_model = 2048,
96
+ decoder_start_token_id = 0,
97
+ dense_act_fn = "gelu_new",
98
+ dropout_rate = 0.1,
99
+ eos_token_id = 1,
100
+ feed_forward_proj = "gated-gelu",
101
+ initializer_factor = 1.0,
102
+ is_encoder_decoder = True,
103
+ is_gated_act = True,
104
+ layer_norm_epsilon = 1e-06,
105
+ model_type = "t5",
106
+ num_decoder_layers = 24,
107
+ num_heads = 32,
108
+ num_layers = 24,
109
+ output_past = True,
110
+ pad_token_id = 0,
111
+ relative_attention_max_distance = 128,
112
+ relative_attention_num_buckets = 32,
113
+ tie_word_embeddings = False,
114
+ tokenizer_class = "T5Tokenizer",
115
+ transformers_version = "4.37.2",
116
+ use_cache = True,
117
+ vocab_size = 250112
118
+ )
119
+ super().__init__(config)
120
+ self.eval()
121
+
122
+ def forward(self, input_ids, attention_mask, clip_skip=1):
123
+ outputs = super().forward(
124
+ input_ids=input_ids,
125
+ attention_mask=attention_mask,
126
+ output_hidden_states=True,
127
+ )
128
+ prompt_emb = outputs.hidden_states[-clip_skip]
129
+ if clip_skip > 1:
130
+ mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std()
131
+ prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
132
+ return prompt_emb
133
+
134
+ def state_dict_converter(self):
135
+ return HunyuanDiTT5TextEncoderStateDictConverter()
136
+
137
+
138
+
139
+ class HunyuanDiTCLIPTextEncoderStateDictConverter():
140
+ def __init__(self):
141
+ pass
142
+
143
+ def from_diffusers(self, state_dict):
144
+ state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")}
145
+ return state_dict_
146
+
147
+ def from_civitai(self, state_dict):
148
+ return self.from_diffusers(state_dict)
149
+
150
+
151
+ class HunyuanDiTT5TextEncoderStateDictConverter():
152
+ def __init__(self):
153
+ pass
154
+
155
+ def from_diffusers(self, state_dict):
156
+ state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")}
157
+ state_dict_["shared.weight"] = state_dict["shared.weight"]
158
+ return state_dict_
159
+
160
+ def from_civitai(self, state_dict):
161
+ return self.from_diffusers(state_dict)
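Both text encoders above apply the same clip_skip trick: an earlier hidden state is selected and, when clip_skip > 1, rescaled to the mean and standard deviation of the final hidden state. A minimal sketch of that renormalization, with random tensors standing in for the encoder's per-layer hidden states (shapes are illustrative):

import torch

def renormalize_clip_skip(hidden_states, clip_skip=1):
    # Pick an earlier layer's output, then match the last layer's statistics.
    prompt_emb = hidden_states[-clip_skip]
    if clip_skip > 1:
        mean, std = hidden_states[-1].mean(), hidden_states[-1].std()
        prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
    return prompt_emb

hidden_states = [torch.randn(1, 77, 1024) for _ in range(25)]  # dummy layer outputs
prompt_emb = renormalize_clip_skip(hidden_states, clip_skip=2)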
diffsynth/models/kolors_text_encoder.py ADDED
@@ -0,0 +1,1363 @@
1
+ """
2
+ This model is copied from https://github.com/Kwai-Kolors/Kolors/tree/master/kolors/models.
3
+ We didn't modify this model.
4
+ The tensor operation is performed in the prompter.
5
+ """
6
+
7
+
8
+ """ PyTorch ChatGLM model. """
9
+
10
+ import math
11
+ import copy
12
+ import warnings
13
+ import re
14
+ import sys
15
+
16
+ import torch
17
+ import torch.utils.checkpoint
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+ from torch.nn import CrossEntropyLoss, LayerNorm
21
+ from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
22
+ from torch.nn.utils import skip_init
23
+ from typing import Optional, Tuple, Union, List, Callable, Dict, Any
24
+ from copy import deepcopy
25
+
26
+ from transformers.modeling_outputs import (
27
+ BaseModelOutputWithPast,
28
+ CausalLMOutputWithPast,
29
+ SequenceClassifierOutputWithPast,
30
+ )
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.utils import logging
33
+ from transformers.generation.logits_process import LogitsProcessor
34
+ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
35
+ from transformers import PretrainedConfig
36
+
37
+
38
+
39
+ class ChatGLMConfig(PretrainedConfig):
40
+ model_type = "chatglm"
41
+ def __init__(
42
+ self,
43
+ num_layers=28,
44
+ padded_vocab_size=65024,
45
+ hidden_size=4096,
46
+ ffn_hidden_size=13696,
47
+ kv_channels=128,
48
+ num_attention_heads=32,
49
+ seq_length=2048,
50
+ hidden_dropout=0.0,
51
+ classifier_dropout=None,
52
+ attention_dropout=0.0,
53
+ layernorm_epsilon=1e-5,
54
+ rmsnorm=True,
55
+ apply_residual_connection_post_layernorm=False,
56
+ post_layer_norm=True,
57
+ add_bias_linear=False,
58
+ add_qkv_bias=False,
59
+ bias_dropout_fusion=True,
60
+ multi_query_attention=False,
61
+ multi_query_group_num=1,
62
+ apply_query_key_layer_scaling=True,
63
+ attention_softmax_in_fp32=True,
64
+ fp32_residual_connection=False,
65
+ quantization_bit=0,
66
+ pre_seq_len=None,
67
+ prefix_projection=False,
68
+ **kwargs
69
+ ):
70
+ self.num_layers = num_layers
71
+ self.vocab_size = padded_vocab_size
72
+ self.padded_vocab_size = padded_vocab_size
73
+ self.hidden_size = hidden_size
74
+ self.ffn_hidden_size = ffn_hidden_size
75
+ self.kv_channels = kv_channels
76
+ self.num_attention_heads = num_attention_heads
77
+ self.seq_length = seq_length
78
+ self.hidden_dropout = hidden_dropout
79
+ self.classifier_dropout = classifier_dropout
80
+ self.attention_dropout = attention_dropout
81
+ self.layernorm_epsilon = layernorm_epsilon
82
+ self.rmsnorm = rmsnorm
83
+ self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
84
+ self.post_layer_norm = post_layer_norm
85
+ self.add_bias_linear = add_bias_linear
86
+ self.add_qkv_bias = add_qkv_bias
87
+ self.bias_dropout_fusion = bias_dropout_fusion
88
+ self.multi_query_attention = multi_query_attention
89
+ self.multi_query_group_num = multi_query_group_num
90
+ self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
91
+ self.attention_softmax_in_fp32 = attention_softmax_in_fp32
92
+ self.fp32_residual_connection = fp32_residual_connection
93
+ self.quantization_bit = quantization_bit
94
+ self.pre_seq_len = pre_seq_len
95
+ self.prefix_projection = prefix_projection
96
+ super().__init__(**kwargs)
97
+
98
+
99
+
100
+ # flags required to enable jit fusion kernels
101
+
102
+ if sys.platform != 'darwin':
103
+ torch._C._jit_set_profiling_mode(False)
104
+ torch._C._jit_set_profiling_executor(False)
105
+ torch._C._jit_override_can_fuse_on_cpu(True)
106
+ torch._C._jit_override_can_fuse_on_gpu(True)
107
+
108
+ logger = logging.get_logger(__name__)
109
+
110
+ _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
111
+ _CONFIG_FOR_DOC = "ChatGLM6BConfig"
112
+
113
+ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
114
+ "THUDM/chatglm3-6b-base",
115
+ # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
116
+ ]
117
+
118
+
119
+ def default_init(cls, *args, **kwargs):
120
+ return cls(*args, **kwargs)
121
+
122
+
123
+ class InvalidScoreLogitsProcessor(LogitsProcessor):
124
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
125
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
126
+ scores.zero_()
127
+ scores[..., 5] = 5e4
128
+ return scores
129
+
130
+
131
+ class PrefixEncoder(torch.nn.Module):
132
+ """
133
+ The torch.nn model to encode the prefix
134
+ Input shape: (batch-size, prefix-length)
135
+ Output shape: (batch-size, prefix-length, 2*layers*hidden)
136
+ """
137
+
138
+ def __init__(self, config: ChatGLMConfig):
139
+ super().__init__()
140
+ self.prefix_projection = config.prefix_projection
141
+ if self.prefix_projection:
142
+ # Use a two-layer MLP to encode the prefix
143
+ kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2
144
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
145
+ self.trans = torch.nn.Sequential(
146
+ torch.nn.Linear(kv_size, config.hidden_size),
147
+ torch.nn.Tanh(),
148
+ torch.nn.Linear(config.hidden_size, kv_size)
149
+ )
150
+ else:
151
+ self.embedding = torch.nn.Embedding(config.pre_seq_len,
152
+ config.num_layers * config.kv_channels * config.multi_query_group_num * 2)
153
+
154
+ def forward(self, prefix: torch.Tensor):
155
+ if self.prefix_projection:
156
+ prefix_tokens = self.embedding(prefix)
157
+ past_key_values = self.trans(prefix_tokens)
158
+ else:
159
+ past_key_values = self.embedding(prefix)
160
+ return past_key_values
161
+
162
+
163
+ def split_tensor_along_last_dim(
164
+ tensor: torch.Tensor,
165
+ num_partitions: int,
166
+ contiguous_split_chunks: bool = False,
167
+ ) -> List[torch.Tensor]:
168
+ """Split a tensor along its last dimension.
169
+
170
+ Arguments:
171
+ tensor: input tensor.
172
+ num_partitions: number of partitions to split the tensor
173
+ contiguous_split_chunks: If True, make each chunk contiguous
174
+ in memory.
175
+
176
+ Returns:
177
+ A list of Tensors
178
+ """
179
+ # Get the size and dimension.
180
+ last_dim = tensor.dim() - 1
181
+ last_dim_size = tensor.size()[last_dim] // num_partitions
182
+ # Split.
183
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
184
+ # Note: torch.split does not create contiguous tensors by default.
185
+ if contiguous_split_chunks:
186
+ return tuple(chunk.contiguous() for chunk in tensor_list)
187
+
188
+ return tensor_list
189
+
190
+
191
+ class RotaryEmbedding(nn.Module):
192
+ def __init__(self, dim, original_impl=False, device=None, dtype=None):
193
+ super().__init__()
194
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
195
+ self.register_buffer("inv_freq", inv_freq)
196
+ self.dim = dim
197
+ self.original_impl = original_impl
198
+
199
+ def forward_impl(
200
+ self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
201
+ ):
202
+ """Enhanced Transformer with Rotary Position Embedding.
203
+
204
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
205
+ transformers/rope/__init__.py. MIT License:
206
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
207
+ """
208
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
209
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))
210
+
211
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
212
+ seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)
213
+
214
+ # Calculate the product of position index and $\theta_i$
215
+ idx_theta = torch.outer(seq_idx, theta).float()
216
+
217
+ cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
218
+
219
+ # this is to mimic the behaviour of complex32, else we will get different results
220
+ if dtype in (torch.float16, torch.bfloat16, torch.int8):
221
+ cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
222
+ return cache
223
+
224
+ def forward(self, max_seq_len, offset=0):
225
+ return self.forward_impl(
226
+ max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
227
+ )
228
+
229
+
230
+ @torch.jit.script
231
+ def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
232
+ # x: [sq, b, np, hn]
233
+ sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
234
+ rot_dim = rope_cache.shape[-2] * 2
235
+ x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
236
+ # truncate to support variable sizes
237
+ rope_cache = rope_cache[:sq]
238
+ xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
239
+ rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
240
+ x_out2 = torch.stack(
241
+ [
242
+ xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
243
+ xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
244
+ ],
245
+ -1,
246
+ )
247
+ x_out2 = x_out2.flatten(3)
248
+ return torch.cat((x_out2, x_pass), dim=-1)
249
+
250
+
251
+ class RMSNorm(torch.nn.Module):
252
+ def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
253
+ super().__init__()
254
+ self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
255
+ self.eps = eps
256
+
257
+ def forward(self, hidden_states: torch.Tensor):
258
+ input_dtype = hidden_states.dtype
259
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
260
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
261
+
262
+ return (self.weight * hidden_states).to(input_dtype)
263
+
264
+
265
+ class CoreAttention(torch.nn.Module):
266
+ def __init__(self, config: ChatGLMConfig, layer_number):
267
+ super(CoreAttention, self).__init__()
268
+
269
+ self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
270
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
271
+ if self.apply_query_key_layer_scaling:
272
+ self.attention_softmax_in_fp32 = True
273
+ self.layer_number = max(1, layer_number)
274
+
275
+ projection_size = config.kv_channels * config.num_attention_heads
276
+
277
+ # Per attention head and per partition values.
278
+ self.hidden_size_per_partition = projection_size
279
+ self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
280
+ self.num_attention_heads_per_partition = config.num_attention_heads
281
+
282
+ coeff = None
283
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
284
+ if self.apply_query_key_layer_scaling:
285
+ coeff = self.layer_number
286
+ self.norm_factor *= coeff
287
+ self.coeff = coeff
288
+
289
+ self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
290
+
291
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
292
+ pytorch_major_version = int(torch.__version__.split('.')[0])
293
+ if pytorch_major_version >= 2:
294
+ query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
295
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
296
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
297
+ is_causal=True)
298
+ else:
299
+ if attention_mask is not None:
300
+ attention_mask = ~attention_mask
301
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
302
+ attention_mask)
303
+ context_layer = context_layer.permute(2, 0, 1, 3)
304
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
305
+ context_layer = context_layer.reshape(*new_context_layer_shape)
306
+ else:
307
+ # Raw attention scores
308
+
309
+ # [b, np, sq, sk]
310
+ output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
311
+
312
+ # [sq, b, np, hn] -> [sq, b * np, hn]
313
+ query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
314
+ # [sk, b, np, hn] -> [sk, b * np, hn]
315
+ key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
316
+
317
+ # preallocating input tensor: [b * np, sq, sk]
318
+ matmul_input_buffer = torch.empty(
319
+ output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
320
+ device=query_layer.device
321
+ )
322
+
323
+ # Raw attention scores. [b * np, sq, sk]
324
+ matmul_result = torch.baddbmm(
325
+ matmul_input_buffer,
326
+ query_layer.transpose(0, 1), # [b * np, sq, hn]
327
+ key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
328
+ beta=0.0,
329
+ alpha=(1.0 / self.norm_factor),
330
+ )
331
+
332
+ # change view to [b, np, sq, sk]
333
+ attention_scores = matmul_result.view(*output_size)
334
+
335
+ # ===========================
336
+ # Attention probs and dropout
337
+ # ===========================
338
+
339
+ # attention scores and attention mask [b, np, sq, sk]
340
+ if self.attention_softmax_in_fp32:
341
+ attention_scores = attention_scores.float()
342
+ if self.coeff is not None:
343
+ attention_scores = attention_scores * self.coeff
344
+ if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
345
+ attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
346
+ device=attention_scores.device, dtype=torch.bool)
347
+ attention_mask.tril_()
348
+ attention_mask = ~attention_mask
349
+ if attention_mask is not None:
350
+ attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
351
+ attention_probs = F.softmax(attention_scores, dim=-1)
352
+ attention_probs = attention_probs.type_as(value_layer)
353
+
354
+ # This is actually dropping out entire tokens to attend to, which might
355
+ # seem a bit unusual, but is taken from the original Transformer paper.
356
+ attention_probs = self.attention_dropout(attention_probs)
357
+ # =========================
358
+ # Context layer. [sq, b, hp]
359
+ # =========================
360
+
361
+ # value_layer -> context layer.
362
+ # [sk, b, np, hn] --> [b, np, sq, hn]
363
+
364
+ # context layer shape: [b, np, sq, hn]
365
+ output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
366
+ # change view [sk, b * np, hn]
367
+ value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
368
+ # change view [b * np, sq, sk]
369
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
370
+ # matmul: [b * np, sq, hn]
371
+ context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
372
+ # change view [b, np, sq, hn]
373
+ context_layer = context_layer.view(*output_size)
374
+ # [b, np, sq, hn] --> [sq, b, np, hn]
375
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
376
+ # [sq, b, np, hn] --> [sq, b, hp]
377
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
378
+ context_layer = context_layer.view(*new_context_layer_shape)
379
+
380
+ return context_layer
381
+
382
+
383
+ class SelfAttention(torch.nn.Module):
384
+ """Parallel self-attention layer abstract class.
385
+
386
+ Self-attention layer takes input with size [s, b, h]
387
+ and returns output of the same size.
388
+ """
389
+
390
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
391
+ super(SelfAttention, self).__init__()
392
+ self.layer_number = max(1, layer_number)
393
+
394
+ self.projection_size = config.kv_channels * config.num_attention_heads
395
+
396
+ # Per attention head and per partition values.
397
+ self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
398
+ self.num_attention_heads_per_partition = config.num_attention_heads
399
+
400
+ self.multi_query_attention = config.multi_query_attention
401
+ self.qkv_hidden_size = 3 * self.projection_size
402
+ if self.multi_query_attention:
403
+ self.num_multi_query_groups_per_partition = config.multi_query_group_num
404
+ self.qkv_hidden_size = (
405
+ self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
406
+ )
407
+ self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
408
+ bias=config.add_bias_linear or config.add_qkv_bias,
409
+ device=device, **_config_to_kwargs(config)
410
+ )
411
+
412
+ self.core_attention = CoreAttention(config, self.layer_number)
413
+
414
+ # Output.
415
+ self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
416
+ device=device, **_config_to_kwargs(config)
417
+ )
418
+
419
+ def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
420
+ if self.multi_query_attention:
421
+ num_attention_heads = self.num_multi_query_groups_per_partition
422
+ else:
423
+ num_attention_heads = self.num_attention_heads_per_partition
424
+ return torch.empty(
425
+ inference_max_sequence_len,
426
+ batch_size,
427
+ num_attention_heads,
428
+ self.hidden_size_per_attention_head,
429
+ dtype=dtype,
430
+ device=device,
431
+ )
432
+
433
+ def forward(
434
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
435
+ ):
436
+ # hidden_states: [sq, b, h]
437
+
438
+ # =================================================
439
+ # Pre-allocate memory for key-values for inference.
440
+ # =================================================
441
+ # =====================
442
+ # Query, Key, and Value
443
+ # =====================
444
+
445
+ # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
446
+ mixed_x_layer = self.query_key_value(hidden_states)
447
+
448
+ if self.multi_query_attention:
449
+ (query_layer, key_layer, value_layer) = mixed_x_layer.split(
450
+ [
451
+ self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
452
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
453
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
454
+ ],
455
+ dim=-1,
456
+ )
457
+ query_layer = query_layer.view(
458
+ query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
459
+ )
460
+ key_layer = key_layer.view(
461
+ key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
462
+ )
463
+ value_layer = value_layer.view(
464
+ value_layer.size()[:-1]
465
+ + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
466
+ )
467
+ else:
468
+ new_tensor_shape = mixed_x_layer.size()[:-1] + \
469
+ (self.num_attention_heads_per_partition,
470
+ 3 * self.hidden_size_per_attention_head)
471
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
472
+
473
+ # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
474
+ (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
475
+
476
+ # apply relative positional encoding (rotary embedding)
477
+ if rotary_pos_emb is not None:
478
+ query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
479
+ key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
480
+
481
+ # adjust key and value for inference
482
+ if kv_cache is not None:
483
+ cache_k, cache_v = kv_cache
484
+ key_layer = torch.cat((cache_k, key_layer), dim=0)
485
+ value_layer = torch.cat((cache_v, value_layer), dim=0)
486
+ if use_cache:
487
+ kv_cache = (key_layer, value_layer)
488
+ else:
489
+ kv_cache = None
490
+
491
+ if self.multi_query_attention:
492
+ key_layer = key_layer.unsqueeze(-2)
493
+ key_layer = key_layer.expand(
494
+ -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
495
+ )
496
+ key_layer = key_layer.contiguous().view(
497
+ key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
498
+ )
499
+ value_layer = value_layer.unsqueeze(-2)
500
+ value_layer = value_layer.expand(
501
+ -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
502
+ )
503
+ value_layer = value_layer.contiguous().view(
504
+ value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
505
+ )
506
+
507
+ # ==================================
508
+ # core attention computation
509
+ # ==================================
510
+
511
+ context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
512
+
513
+ # =================
514
+ # Output. [sq, b, h]
515
+ # =================
516
+
517
+ output = self.dense(context_layer)
518
+
519
+ return output, kv_cache
520
+
521
+
522
+ def _config_to_kwargs(args):
523
+ common_kwargs = {
524
+ "dtype": args.torch_dtype,
525
+ }
526
+ return common_kwargs
527
+
528
+
529
+ class MLP(torch.nn.Module):
530
+ """MLP.
531
+
532
+ MLP will take the input with h hidden state, project it to 4*h
533
+ hidden dimension, perform nonlinear transformation, and project the
534
+ state back into h hidden dimension.
535
+ """
536
+
537
+ def __init__(self, config: ChatGLMConfig, device=None):
538
+ super(MLP, self).__init__()
539
+
540
+ self.add_bias = config.add_bias_linear
541
+
542
+ # Project to 4h. If using swiglu, double the output width; see https://arxiv.org/pdf/2002.05202.pdf
543
+ self.dense_h_to_4h = nn.Linear(
544
+ config.hidden_size,
545
+ config.ffn_hidden_size * 2,
546
+ bias=self.add_bias,
547
+ device=device,
548
+ **_config_to_kwargs(config)
549
+ )
550
+
551
+ def swiglu(x):
552
+ x = torch.chunk(x, 2, dim=-1)
553
+ return F.silu(x[0]) * x[1]
554
+
555
+ self.activation_func = swiglu
556
+
557
+ # Project back to h.
558
+ self.dense_4h_to_h = nn.Linear(
559
+ config.ffn_hidden_size,
560
+ config.hidden_size,
561
+ bias=self.add_bias,
562
+ device=device,
563
+ **_config_to_kwargs(config)
564
+ )
565
+
566
+ def forward(self, hidden_states):
567
+ # [s, b, 4hp]
568
+ intermediate_parallel = self.dense_h_to_4h(hidden_states)
569
+ intermediate_parallel = self.activation_func(intermediate_parallel)
570
+ # [s, b, h]
571
+ output = self.dense_4h_to_h(intermediate_parallel)
572
+ return output
573
+
574
+
575
+ class GLMBlock(torch.nn.Module):
576
+ """A single transformer layer.
577
+
578
+ Transformer layer takes input with size [s, b, h] and returns an
579
+ output of the same size.
580
+ """
581
+
582
+ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
583
+ super(GLMBlock, self).__init__()
584
+ self.layer_number = layer_number
585
+
586
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
587
+
588
+ self.fp32_residual_connection = config.fp32_residual_connection
589
+
590
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
591
+ # Layernorm on the input data.
592
+ self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
593
+ dtype=config.torch_dtype)
594
+
595
+ # Self attention.
596
+ self.self_attention = SelfAttention(config, layer_number, device=device)
597
+ self.hidden_dropout = config.hidden_dropout
598
+
599
+ # Layernorm on the attention output
600
+ self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
601
+ dtype=config.torch_dtype)
602
+
603
+ # MLP
604
+ self.mlp = MLP(config, device=device)
605
+
606
+ def forward(
607
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
608
+ ):
609
+ # hidden_states: [s, b, h]
610
+
611
+ # Layer norm at the beginning of the transformer layer.
612
+ layernorm_output = self.input_layernorm(hidden_states)
613
+ # Self attention.
614
+ attention_output, kv_cache = self.self_attention(
615
+ layernorm_output,
616
+ attention_mask,
617
+ rotary_pos_emb,
618
+ kv_cache=kv_cache,
619
+ use_cache=use_cache
620
+ )
621
+
622
+ # Residual connection.
623
+ if self.apply_residual_connection_post_layernorm:
624
+ residual = layernorm_output
625
+ else:
626
+ residual = hidden_states
627
+
628
+ layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
629
+ layernorm_input = residual + layernorm_input
630
+
631
+ # Layer norm post the self attention.
632
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
633
+
634
+ # MLP.
635
+ mlp_output = self.mlp(layernorm_output)
636
+
637
+ # Second residual connection.
638
+ if self.apply_residual_connection_post_layernorm:
639
+ residual = layernorm_output
640
+ else:
641
+ residual = layernorm_input
642
+
643
+ output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
644
+ output = residual + output
645
+
646
+ return output, kv_cache
647
+
648
+
649
+ class GLMTransformer(torch.nn.Module):
650
+ """Transformer class."""
651
+
652
+ def __init__(self, config: ChatGLMConfig, device=None):
653
+ super(GLMTransformer, self).__init__()
654
+
655
+ self.fp32_residual_connection = config.fp32_residual_connection
656
+ self.post_layer_norm = config.post_layer_norm
657
+
658
+ # Number of layers.
659
+ self.num_layers = config.num_layers
660
+
661
+ # Transformer layers.
662
+ def build_layer(layer_number):
663
+ return GLMBlock(config, layer_number, device=device)
664
+
665
+ self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
666
+
667
+ if self.post_layer_norm:
668
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
669
+ # Final layer norm before output.
670
+ self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
671
+ dtype=config.torch_dtype)
672
+
673
+ self.gradient_checkpointing = False
674
+
675
+ def _get_layer(self, layer_number):
676
+ return self.layers[layer_number]
677
+
678
+ def forward(
679
+ self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
680
+ use_cache: Optional[bool] = True,
681
+ output_hidden_states: Optional[bool] = False,
682
+ ):
683
+ if not kv_caches:
684
+ kv_caches = [None for _ in range(self.num_layers)]
685
+ presents = () if use_cache else None
686
+ if self.gradient_checkpointing and self.training:
687
+ if use_cache:
688
+ logger.warning_once(
689
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
690
+ )
691
+ use_cache = False
692
+
693
+ all_self_attentions = None
694
+ all_hidden_states = () if output_hidden_states else None
695
+ for index in range(self.num_layers):
696
+ if output_hidden_states:
697
+ all_hidden_states = all_hidden_states + (hidden_states,)
698
+
699
+ layer = self._get_layer(index)
700
+ if self.gradient_checkpointing and self.training:
701
+ layer_ret = torch.utils.checkpoint.checkpoint(
702
+ layer,
703
+ hidden_states,
704
+ attention_mask,
705
+ rotary_pos_emb,
706
+ kv_caches[index],
707
+ use_cache
708
+ )
709
+ else:
710
+ layer_ret = layer(
711
+ hidden_states,
712
+ attention_mask,
713
+ rotary_pos_emb,
714
+ kv_cache=kv_caches[index],
715
+ use_cache=use_cache
716
+ )
717
+ hidden_states, kv_cache = layer_ret
718
+ if use_cache:
719
+ presents = presents + (kv_cache,)
720
+
721
+ if output_hidden_states:
722
+ all_hidden_states = all_hidden_states + (hidden_states,)
723
+
724
+ # Final layer norm.
725
+ if self.post_layer_norm:
726
+ hidden_states = self.final_layernorm(hidden_states)
727
+
728
+ return hidden_states, presents, all_hidden_states, all_self_attentions
729
+
730
+
731
+ class ChatGLMPreTrainedModel(PreTrainedModel):
732
+ """
733
+ An abstract class to handle weights initialization and
734
+ a simple interface for downloading and loading pretrained models.
735
+ """
736
+
737
+ is_parallelizable = False
738
+ supports_gradient_checkpointing = True
739
+ config_class = ChatGLMConfig
740
+ base_model_prefix = "transformer"
741
+ _no_split_modules = ["GLMBlock"]
742
+
743
+ def _init_weights(self, module: nn.Module):
744
+ """Initialize the weights."""
745
+ return
746
+
747
+ def get_masks(self, input_ids, past_key_values, padding_mask=None):
748
+ batch_size, seq_length = input_ids.shape
749
+ full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
750
+ full_attention_mask.tril_()
751
+ past_length = 0
752
+ if past_key_values:
753
+ past_length = past_key_values[0][0].shape[0]
754
+ if past_length:
755
+ full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
756
+ device=input_ids.device), full_attention_mask), dim=-1)
757
+ if padding_mask is not None:
758
+ full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
759
+ if not past_length and padding_mask is not None:
760
+ full_attention_mask -= padding_mask.unsqueeze(-1) - 1
761
+ full_attention_mask = (full_attention_mask < 0.5).bool()
762
+ full_attention_mask.unsqueeze_(1)
763
+ return full_attention_mask
764
+
765
+ def get_position_ids(self, input_ids, device):
766
+ batch_size, seq_length = input_ids.shape
767
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
768
+ return position_ids
769
+
770
+ def _set_gradient_checkpointing(self, module, value=False):
771
+ if isinstance(module, GLMTransformer):
772
+ module.gradient_checkpointing = value
773
+
774
+
775
+ class Embedding(torch.nn.Module):
776
+ """Language model embeddings."""
777
+
778
+ def __init__(self, config: ChatGLMConfig, device=None):
779
+ super(Embedding, self).__init__()
780
+
781
+ self.hidden_size = config.hidden_size
782
+ # Word embeddings (parallel).
783
+ self.word_embeddings = nn.Embedding(
784
+ config.padded_vocab_size,
785
+ self.hidden_size,
786
+ dtype=config.torch_dtype,
787
+ device=device
788
+ )
789
+ self.fp32_residual_connection = config.fp32_residual_connection
790
+
791
+ def forward(self, input_ids):
792
+ # Embeddings.
793
+ words_embeddings = self.word_embeddings(input_ids)
794
+ embeddings = words_embeddings
795
+ # Data format change to avoid explicit transposes: [b s h] --> [s b h].
796
+ embeddings = embeddings.transpose(0, 1).contiguous()
797
+ # If the input flag for fp32 residual connection is set, convert to float.
798
+ if self.fp32_residual_connection:
799
+ embeddings = embeddings.float()
800
+ return embeddings
801
+
802
+
803
+ class ChatGLMModel(ChatGLMPreTrainedModel):
804
+ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
805
+ super().__init__(config)
806
+ if empty_init:
807
+ init_method = skip_init
808
+ else:
809
+ init_method = default_init
810
+ init_kwargs = {}
811
+ if device is not None:
812
+ init_kwargs["device"] = device
813
+ self.embedding = init_method(Embedding, config, **init_kwargs)
814
+ self.num_layers = config.num_layers
815
+ self.multi_query_group_num = config.multi_query_group_num
816
+ self.kv_channels = config.kv_channels
817
+
818
+ # Rotary positional embeddings
819
+ self.seq_length = config.seq_length
820
+ rotary_dim = (
821
+ config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
822
+ )
823
+
824
+ self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
825
+ dtype=config.torch_dtype)
826
+ self.encoder = init_method(GLMTransformer, config, **init_kwargs)
827
+ self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
828
+ dtype=config.torch_dtype, **init_kwargs)
829
+ self.pre_seq_len = config.pre_seq_len
830
+ self.prefix_projection = config.prefix_projection
831
+ if self.pre_seq_len is not None:
832
+ for param in self.parameters():
833
+ param.requires_grad = False
834
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
835
+ self.prefix_encoder = PrefixEncoder(config)
836
+ self.dropout = torch.nn.Dropout(0.1)
837
+
838
+ def get_input_embeddings(self):
839
+ return self.embedding.word_embeddings
840
+
841
+ def get_prompt(self, batch_size, device, dtype=torch.half):
842
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
843
+ past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
844
+ past_key_values = past_key_values.view(
845
+ batch_size,
846
+ self.pre_seq_len,
847
+ self.num_layers * 2,
848
+ self.multi_query_group_num,
849
+ self.kv_channels
850
+ )
851
+ # seq_len, b, nh, hidden_size
852
+ past_key_values = self.dropout(past_key_values)
853
+ past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
854
+ return past_key_values
855
+
856
+ def forward(
857
+ self,
858
+ input_ids,
859
+ position_ids: Optional[torch.Tensor] = None,
860
+ attention_mask: Optional[torch.BoolTensor] = None,
861
+ full_attention_mask: Optional[torch.BoolTensor] = None,
862
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
863
+ inputs_embeds: Optional[torch.Tensor] = None,
864
+ use_cache: Optional[bool] = None,
865
+ output_hidden_states: Optional[bool] = None,
866
+ return_dict: Optional[bool] = None,
867
+ ):
868
+ output_hidden_states = (
869
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
870
+ )
871
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
872
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
873
+
874
+ batch_size, seq_length = input_ids.shape
875
+
876
+ if inputs_embeds is None:
877
+ inputs_embeds = self.embedding(input_ids)
878
+
879
+ if self.pre_seq_len is not None:
880
+ if past_key_values is None:
881
+ past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
882
+ dtype=inputs_embeds.dtype)
883
+ if attention_mask is not None:
884
+ attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
885
+ attention_mask], dim=-1)
886
+
887
+ if full_attention_mask is None:
888
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
889
+ full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
890
+
891
+ # Rotary positional embeddings
892
+ rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
893
+ if position_ids is not None:
894
+ rotary_pos_emb = rotary_pos_emb[position_ids]
895
+ else:
896
+ rotary_pos_emb = rotary_pos_emb[None, :seq_length]
897
+ rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
898
+
899
+ # Run encoder.
900
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
901
+ inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
902
+ kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
903
+ )
904
+
905
+ if not return_dict:
906
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
907
+
908
+ return BaseModelOutputWithPast(
909
+ last_hidden_state=hidden_states,
910
+ past_key_values=presents,
911
+ hidden_states=all_hidden_states,
912
+ attentions=all_self_attentions,
913
+ )
914
+
915
+ def quantize(self, weight_bit_width: int):
916
+ from .quantization import quantize
917
+ quantize(self.encoder, weight_bit_width)
918
+ return self
919
+
920
+
921
+ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
922
+ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
923
+ super().__init__(config)
924
+
925
+ self.max_sequence_length = config.max_length
926
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
927
+ self.config = config
928
+ self.quantized = False
929
+
930
+ if self.config.quantization_bit:
931
+ self.quantize(self.config.quantization_bit, empty_init=True)
932
+
933
+ def _update_model_kwargs_for_generation(
934
+ self,
935
+ outputs: ModelOutput,
936
+ model_kwargs: Dict[str, Any],
937
+ is_encoder_decoder: bool = False,
938
+ standardize_cache_format: bool = False,
939
+ ) -> Dict[str, Any]:
940
+ # update past_key_values
941
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
942
+ outputs, standardize_cache_format=standardize_cache_format
943
+ )
944
+
945
+ # update attention mask
946
+ if "attention_mask" in model_kwargs:
947
+ attention_mask = model_kwargs["attention_mask"]
948
+ model_kwargs["attention_mask"] = torch.cat(
949
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
950
+ )
951
+
952
+ # update position ids
953
+ if "position_ids" in model_kwargs:
954
+ position_ids = model_kwargs["position_ids"]
955
+ new_position_id = position_ids[..., -1:].clone()
956
+ new_position_id += 1
957
+ model_kwargs["position_ids"] = torch.cat(
958
+ [position_ids, new_position_id], dim=-1
959
+ )
960
+
961
+ model_kwargs["is_first_forward"] = False
962
+ return model_kwargs
963
+
964
+ def prepare_inputs_for_generation(
965
+ self,
966
+ input_ids: torch.LongTensor,
967
+ past_key_values: Optional[torch.Tensor] = None,
968
+ attention_mask: Optional[torch.Tensor] = None,
969
+ position_ids: Optional[torch.Tensor] = None,
970
+ use_cache: Optional[bool] = None,
971
+ is_first_forward: bool = True,
972
+ **kwargs
973
+ ) -> dict:
974
+ # only last token for input_ids if past is not None
975
+ if position_ids is None:
976
+ position_ids = self.get_position_ids(input_ids, device=input_ids.device)
977
+ if not is_first_forward:
978
+ if past_key_values is not None:
979
+ position_ids = position_ids[..., -1:]
980
+ input_ids = input_ids[:, -1:]
981
+ return {
982
+ "input_ids": input_ids,
983
+ "past_key_values": past_key_values,
984
+ "position_ids": position_ids,
985
+ "attention_mask": attention_mask,
986
+ "return_last_logit": True,
987
+ "use_cache": use_cache
988
+ }
989
+
990
+ def forward(
991
+ self,
992
+ input_ids: Optional[torch.Tensor] = None,
993
+ position_ids: Optional[torch.Tensor] = None,
994
+ attention_mask: Optional[torch.Tensor] = None,
995
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
996
+ inputs_embeds: Optional[torch.Tensor] = None,
997
+ labels: Optional[torch.Tensor] = None,
998
+ use_cache: Optional[bool] = None,
999
+ output_attentions: Optional[bool] = None,
1000
+ output_hidden_states: Optional[bool] = None,
1001
+ return_dict: Optional[bool] = None,
1002
+ return_last_logit: Optional[bool] = False,
1003
+ ):
1004
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1005
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1006
+
1007
+ transformer_outputs = self.transformer(
1008
+ input_ids=input_ids,
1009
+ position_ids=position_ids,
1010
+ attention_mask=attention_mask,
1011
+ past_key_values=past_key_values,
1012
+ inputs_embeds=inputs_embeds,
1013
+ use_cache=use_cache,
1014
+ output_hidden_states=output_hidden_states,
1015
+ return_dict=return_dict,
1016
+ )
1017
+
1018
+ hidden_states = transformer_outputs[0]
1019
+ if return_last_logit:
1020
+ hidden_states = hidden_states[-1:]
1021
+ lm_logits = self.transformer.output_layer(hidden_states)
1022
+ lm_logits = lm_logits.transpose(0, 1).contiguous()
1023
+
1024
+ loss = None
1025
+ if labels is not None:
1026
+ lm_logits = lm_logits.to(torch.float32)
1027
+
1028
+ # Shift so that tokens < n predict n
1029
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1030
+ shift_labels = labels[..., 1:].contiguous()
1031
+ # Flatten the tokens
1032
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
1033
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1034
+
1035
+ lm_logits = lm_logits.to(hidden_states.dtype)
1036
+ loss = loss.to(hidden_states.dtype)
1037
+
1038
+ if not return_dict:
1039
+ output = (lm_logits,) + transformer_outputs[1:]
1040
+ return ((loss,) + output) if loss is not None else output
1041
+
1042
+ return CausalLMOutputWithPast(
1043
+ loss=loss,
1044
+ logits=lm_logits,
1045
+ past_key_values=transformer_outputs.past_key_values,
1046
+ hidden_states=transformer_outputs.hidden_states,
1047
+ attentions=transformer_outputs.attentions,
1048
+ )
1049
+
1050
+ @staticmethod
1051
+ def _reorder_cache(
1052
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
1053
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
1054
+ """
1055
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1056
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1057
+ beam_idx at every generation step.
1058
+
1059
+ Output shares the same memory storage as `past`.
1060
+ """
1061
+ return tuple(
1062
+ (
1063
+ layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
1064
+ layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
1065
+ )
1066
+ for layer_past in past
1067
+ )
1068
+
1069
+ def process_response(self, output, history):
1070
+ content = ""
1071
+ history = deepcopy(history)
1072
+ for response in output.split("<|assistant|>"):
1073
+ metadata, content = response.split("\n", maxsplit=1)
1074
+ if not metadata.strip():
1075
+ content = content.strip()
1076
+ history.append({"role": "assistant", "metadata": metadata, "content": content})
1077
+ content = content.replace("[[训练时间]]", "2023年")
1078
+ else:
1079
+ history.append({"role": "assistant", "metadata": metadata, "content": content})
1080
+ if history[0]["role"] == "system" and "tools" in history[0]:
1081
+ content = "\n".join(content.split("\n")[1:-1])
1082
+ def tool_call(**kwargs):
1083
+ return kwargs
1084
+ parameters = eval(content)
1085
+ content = {"name": metadata.strip(), "parameters": parameters}
1086
+ else:
1087
+ content = {"name": metadata.strip(), "content": content}
1088
+ return content, history
1089
+
1090
+ @torch.inference_mode()
1091
+ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
1092
+ max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
1093
+ **kwargs):
1094
+ if history is None:
1095
+ history = []
1096
+ if logits_processor is None:
1097
+ logits_processor = LogitsProcessorList()
1098
+ logits_processor.append(InvalidScoreLogitsProcessor())
1099
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1100
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1101
+ inputs = tokenizer.build_chat_input(query, history=history, role=role)
1102
+ inputs = inputs.to(self.device)
1103
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
1104
+ tokenizer.get_command("<|observation|>")]
1105
+ outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
1106
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1107
+ response = tokenizer.decode(outputs)
1108
+ history.append({"role": role, "content": query})
1109
+ response, history = self.process_response(response, history)
1110
+ return response, history
1111
+
1112
+ @torch.inference_mode()
1113
+ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user",
1114
+ past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
1115
+ logits_processor=None, return_past_key_values=False, **kwargs):
1116
+ if history is None:
1117
+ history = []
1118
+ if logits_processor is None:
1119
+ logits_processor = LogitsProcessorList()
1120
+ logits_processor.append(InvalidScoreLogitsProcessor())
1121
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
1122
+ tokenizer.get_command("<|observation|>")]
1123
+ gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
1124
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1125
+ if past_key_values is None:
1126
+ inputs = tokenizer.build_chat_input(query, history=history, role=role)
1127
+ else:
1128
+ inputs = tokenizer.build_chat_input(query, role=role)
1129
+ inputs = inputs.to(self.device)
1130
+ if past_key_values is not None:
1131
+ past_length = past_key_values[0][0].shape[0]
1132
+ if self.transformer.pre_seq_len is not None:
1133
+ past_length -= self.transformer.pre_seq_len
1134
+ inputs.position_ids += past_length
1135
+ attention_mask = inputs.attention_mask
1136
+ attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
1137
+ inputs['attention_mask'] = attention_mask
1138
+ history.append({"role": role, "content": query})
1139
+ for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
1140
+ eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
1141
+ **gen_kwargs):
1142
+ if return_past_key_values:
1143
+ outputs, past_key_values = outputs
1144
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1145
+ response = tokenizer.decode(outputs)
1146
+ if response and response[-1] != "�":
1147
+ response, new_history = self.process_response(response, history)
1148
+ if return_past_key_values:
1149
+ yield response, new_history, past_key_values
1150
+ else:
1151
+ yield response, new_history
1152
+
1153
+ @torch.inference_mode()
1154
+ def stream_generate(
1155
+ self,
1156
+ input_ids,
1157
+ generation_config: Optional[GenerationConfig] = None,
1158
+ logits_processor: Optional[LogitsProcessorList] = None,
1159
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
1160
+ prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
1161
+ return_past_key_values=False,
1162
+ **kwargs,
1163
+ ):
1164
+ batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
1165
+
1166
+ if generation_config is None:
1167
+ generation_config = self.generation_config
1168
+ generation_config = copy.deepcopy(generation_config)
1169
+ model_kwargs = generation_config.update(**kwargs)
1170
+ model_kwargs["use_cache"] = generation_config.use_cache
1171
+ bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
1172
+
1173
+ if isinstance(eos_token_id, int):
1174
+ eos_token_id = [eos_token_id]
1175
+ eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
1176
+
1177
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
1178
+ if has_default_max_length and generation_config.max_new_tokens is None:
1179
+ warnings.warn(
1180
+ f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
1181
+ "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
1182
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
1183
+ UserWarning,
1184
+ )
1185
+ elif generation_config.max_new_tokens is not None:
1186
+ generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
1187
+ if not has_default_max_length:
1188
+ logger.warn(
1189
+ f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
1190
+ f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
1191
+ "Please refer to the documentation for more information. "
1192
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
1193
+ UserWarning,
1194
+ )
1195
+
1196
+ if input_ids_seq_length >= generation_config.max_length:
1197
+ input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
1198
+ logger.warning(
1199
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
1200
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
1201
+ " increasing `max_new_tokens`."
1202
+ )
1203
+
1204
+ # 2. Set generation parameters if not already defined
1205
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
1206
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
1207
+
1208
+ logits_processor = self._get_logits_processor(
1209
+ generation_config=generation_config,
1210
+ input_ids_seq_length=input_ids_seq_length,
1211
+ encoder_input_ids=input_ids,
1212
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
1213
+ logits_processor=logits_processor,
1214
+ )
1215
+
1216
+ stopping_criteria = self._get_stopping_criteria(
1217
+ generation_config=generation_config, stopping_criteria=stopping_criteria
1218
+ )
1219
+ logits_warper = self._get_logits_warper(generation_config)
1220
+
1221
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
1222
+ scores = None
1223
+ while True:
1224
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1225
+ # forward pass to get next token
1226
+ outputs = self(
1227
+ **model_inputs,
1228
+ return_dict=True,
1229
+ output_attentions=False,
1230
+ output_hidden_states=False,
1231
+ )
1232
+
1233
+ next_token_logits = outputs.logits[:, -1, :]
1234
+
1235
+ # pre-process distribution
1236
+ next_token_scores = logits_processor(input_ids, next_token_logits)
1237
+ next_token_scores = logits_warper(input_ids, next_token_scores)
1238
+
1239
+ # sample
1240
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
1241
+ if generation_config.do_sample:
1242
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1243
+ else:
1244
+ next_tokens = torch.argmax(probs, dim=-1)
1245
+ # update generated ids, model inputs, and length for next step
1246
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
1247
+ model_kwargs = self._update_model_kwargs_for_generation(
1248
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
1249
+ )
1250
+ unfinished_sequences = unfinished_sequences.mul(
1251
+ next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
1252
+ )
1253
+ if return_past_key_values:
1254
+ yield input_ids, outputs.past_key_values
1255
+ else:
1256
+ yield input_ids
1257
+ # stop when each sentence is finished, or if we exceed the maximum length
1258
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
1259
+ break
1260
+
1261
+ def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
1262
+ if bits == 0:
1263
+ return
1264
+
1265
+ from .quantization import quantize
1266
+
1267
+ if self.quantized:
1268
+ logger.info("Already quantized.")
1269
+ return self
1270
+
1271
+ self.quantized = True
1272
+
1273
+ self.config.quantization_bit = bits
1274
+
1275
+ self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device,
1276
+ **kwargs)
1277
+ return self
1278
+
1279
+
1280
+ class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
1281
+ def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
1282
+ super().__init__(config)
1283
+
1284
+ self.num_labels = config.num_labels
1285
+ self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
1286
+
1287
+ self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
1288
+ if config.classifier_dropout is not None:
1289
+ self.dropout = nn.Dropout(config.classifier_dropout)
1290
+ else:
1291
+ self.dropout = None
1292
+ self.config = config
1293
+
1294
+ if self.config.quantization_bit:
1295
+ self.quantize(self.config.quantization_bit, empty_init=True)
1296
+
1297
+ def forward(
1298
+ self,
1299
+ input_ids: Optional[torch.LongTensor] = None,
1300
+ position_ids: Optional[torch.LongTensor] = None,
1301
+ attention_mask: Optional[torch.Tensor] = None,
1302
+ full_attention_mask: Optional[torch.Tensor] = None,
1303
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1304
+ inputs_embeds: Optional[torch.LongTensor] = None,
1305
+ labels: Optional[torch.LongTensor] = None,
1306
+ use_cache: Optional[bool] = None,
1307
+ output_hidden_states: Optional[bool] = None,
1308
+ return_dict: Optional[bool] = None,
1309
+ ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
1310
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1311
+
1312
+ transformer_outputs = self.transformer(
1313
+ input_ids=input_ids,
1314
+ position_ids=position_ids,
1315
+ attention_mask=attention_mask,
1316
+ full_attention_mask=full_attention_mask,
1317
+ past_key_values=past_key_values,
1318
+ inputs_embeds=inputs_embeds,
1319
+ use_cache=use_cache,
1320
+ output_hidden_states=output_hidden_states,
1321
+ return_dict=return_dict,
1322
+ )
1323
+
1324
+ hidden_states = transformer_outputs[0]
1325
+ pooled_hidden_states = hidden_states[-1]
1326
+ if self.dropout is not None:
1327
+ pooled_hidden_states = self.dropout(pooled_hidden_states)
1328
+ logits = self.classifier_head(pooled_hidden_states)
1329
+
1330
+ loss = None
1331
+ if labels is not None:
1332
+ if self.config.problem_type is None:
1333
+ if self.num_labels == 1:
1334
+ self.config.problem_type = "regression"
1335
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1336
+ self.config.problem_type = "single_label_classification"
1337
+ else:
1338
+ self.config.problem_type = "multi_label_classification"
1339
+
1340
+ if self.config.problem_type == "regression":
1341
+ loss_fct = MSELoss()
1342
+ if self.num_labels == 1:
1343
+ loss = loss_fct(logits.squeeze().float(), labels.squeeze())
1344
+ else:
1345
+ loss = loss_fct(logits.float(), labels)
1346
+ elif self.config.problem_type == "single_label_classification":
1347
+ loss_fct = CrossEntropyLoss()
1348
+ loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
1349
+ elif self.config.problem_type == "multi_label_classification":
1350
+ loss_fct = BCEWithLogitsLoss()
1351
+ loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))
1352
+
1353
+ if not return_dict:
1354
+ output = (logits,) + transformer_outputs[1:]
1355
+ return ((loss,) + output) if loss is not None else output
1356
+
1357
+ return SequenceClassifierOutputWithPast(
1358
+ loss=loss,
1359
+ logits=logits,
1360
+ past_key_values=transformer_outputs.past_key_values,
1361
+ hidden_states=transformer_outputs.hidden_states,
1362
+ attentions=transformer_outputs.attentions,
1363
+ )
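The chat and stream_chat methods above implement the ChatGLM3-style conversational API. A minimal usage sketch, assuming the model and a matching tokenizer (one that provides build_chat_input and get_command, as ChatGLM3 tokenizers do) are loaded as transformers remote-code checkpoints; the repo id below is illustrative of the upstream model this code is derived from:

# Hedged usage sketch for the chat API defined above (repo id is illustrative).
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True).half().cuda().eval()

# Single-turn chat: returns the decoded response and the updated history list.
response, history = model.chat(tokenizer, "Hello", history=[])

# Streaming chat: yields progressively longer partial responses while sampling.
for partial, history in model.stream_chat(tokenizer, "Tell me more", history=history):
    pass  # e.g. print(partial, end="\r")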
diffsynth/models/sd3_dit.py ADDED
@@ -0,0 +1,797 @@
1
+ import torch
2
+ from einops import rearrange
3
+ from .svd_unet import TemporalTimesteps
4
+ from .tiler import TileWorker
5
+
6
+
7
+
8
+ class PatchEmbed(torch.nn.Module):
9
+ def __init__(self, patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192):
10
+ super().__init__()
11
+ self.pos_embed_max_size = pos_embed_max_size
12
+ self.patch_size = patch_size
13
+
14
+ self.proj = torch.nn.Conv2d(in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size)
15
+ self.pos_embed = torch.nn.Parameter(torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, 1536))
16
+
17
+ def cropped_pos_embed(self, height, width):
18
+ height = height // self.patch_size
19
+ width = width // self.patch_size
20
+ top = (self.pos_embed_max_size - height) // 2
21
+ left = (self.pos_embed_max_size - width) // 2
22
+ spatial_pos_embed = self.pos_embed[:, top : top + height, left : left + width, :].flatten(1, 2)
23
+ return spatial_pos_embed
24
+
25
+ def forward(self, latent):
26
+ height, width = latent.shape[-2:]
27
+ latent = self.proj(latent)
28
+ latent = latent.flatten(2).transpose(1, 2)
29
+ pos_embed = self.cropped_pos_embed(height, width)
30
+ return latent + pos_embed
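# A minimal shape check for PatchEmbed above (illustrative sizes; assumes this class
# is importable from this module). A 16-channel 128x128 latent is split into 2x2
# patches, giving (128/2) * (128/2) = 4096 tokens of width 1536, and the positional
# table is center-cropped from the 192x192 grid to match.
import torch
embed = PatchEmbed(patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192)
latent = torch.randn(1, 16, 128, 128)
tokens = embed(latent)
print(tokens.shape)  # torch.Size([1, 4096, 1536])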
31
+
32
+
33
+
34
+ class TimestepEmbeddings(torch.nn.Module):
35
+ def __init__(self, dim_in, dim_out):
36
+ super().__init__()
37
+ self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
38
+ self.timestep_embedder = torch.nn.Sequential(
39
+ torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
40
+ )
41
+
42
+ def forward(self, timestep, dtype):
43
+ time_emb = self.time_proj(timestep).to(dtype)
44
+ time_emb = self.timestep_embedder(time_emb)
45
+ return time_emb
46
+
47
+
48
+
49
+ class AdaLayerNorm(torch.nn.Module):
50
+ def __init__(self, dim, single=False):
51
+ super().__init__()
52
+ self.single = single
53
+ self.linear = torch.nn.Linear(dim, dim * (2 if single else 6))
54
+ self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
55
+
56
+ def forward(self, x, emb):
57
+ emb = self.linear(torch.nn.functional.silu(emb))
58
+ if self.single:
59
+ scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
60
+ x = self.norm(x) * (1 + scale) + shift
61
+ return x
62
+ else:
63
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
64
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
65
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
66
+
67
+
68
+
69
+ class JointAttention(torch.nn.Module):
70
+ def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
71
+ super().__init__()
72
+ self.num_heads = num_heads
73
+ self.head_dim = head_dim
74
+ self.only_out_a = only_out_a
75
+
76
+ self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
77
+ self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
78
+
79
+ self.a_to_out = torch.nn.Linear(dim_a, dim_a)
80
+ if not only_out_a:
81
+ self.b_to_out = torch.nn.Linear(dim_b, dim_b)
82
+
83
+ def forward(self, hidden_states_a, hidden_states_b):
84
+ batch_size = hidden_states_a.shape[0]
85
+
86
+ qkv = torch.concat([self.a_to_qkv(hidden_states_a), self.b_to_qkv(hidden_states_b)], dim=1)
87
+ qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
88
+ q, k, v = qkv.chunk(3, dim=1)
89
+
90
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
91
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
92
+ hidden_states = hidden_states.to(q.dtype)
93
+ hidden_states_a, hidden_states_b = hidden_states[:, :hidden_states_a.shape[1]], hidden_states[:, hidden_states_a.shape[1]:]
94
+ hidden_states_a = self.a_to_out(hidden_states_a)
95
+ if self.only_out_a:
96
+ return hidden_states_a
97
+ else:
98
+ hidden_states_b = self.b_to_out(hidden_states_b)
99
+ return hidden_states_a, hidden_states_b
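# Hedged sketch of JointAttention above (illustrative sizes): the image stream (a) and
# the text stream (b) are projected to q/k/v, concatenated along the sequence axis,
# attended jointly, then split back and projected out per stream.
import torch
attn = JointAttention(dim_a=1536, dim_b=1536, num_heads=24, head_dim=64)
image_tokens = torch.randn(1, 4096, 1536)  # e.g. a 128x128 latent after PatchEmbed
text_tokens = torch.randn(1, 154, 1536)    # e.g. the prompt after context_embedder
out_a, out_b = attn(image_tokens, text_tokens)
print(out_a.shape, out_b.shape)  # torch.Size([1, 4096, 1536]) torch.Size([1, 154, 1536])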
100
+
101
+
102
+
103
+ class JointTransformerBlock(torch.nn.Module):
104
+ def __init__(self, dim, num_attention_heads):
105
+ super().__init__()
106
+ self.norm1_a = AdaLayerNorm(dim)
107
+ self.norm1_b = AdaLayerNorm(dim)
108
+
109
+ self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
110
+
111
+ self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
112
+ self.ff_a = torch.nn.Sequential(
113
+ torch.nn.Linear(dim, dim*4),
114
+ torch.nn.GELU(approximate="tanh"),
115
+ torch.nn.Linear(dim*4, dim)
116
+ )
117
+
118
+ self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
119
+ self.ff_b = torch.nn.Sequential(
120
+ torch.nn.Linear(dim, dim*4),
121
+ torch.nn.GELU(approximate="tanh"),
122
+ torch.nn.Linear(dim*4, dim)
123
+ )
124
+
125
+
126
+ def forward(self, hidden_states_a, hidden_states_b, temb):
127
+ norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
128
+ norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
129
+
130
+ # Attention
131
+ attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)
132
+
133
+ # Part A
134
+ hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
135
+ norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
136
+ hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
137
+
138
+ # Part B
139
+ hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
140
+ norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
141
+ hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
142
+
143
+ return hidden_states_a, hidden_states_b
144
+
145
+
146
+
147
+ class JointTransformerFinalBlock(torch.nn.Module):
148
+ def __init__(self, dim, num_attention_heads):
149
+ super().__init__()
150
+ self.norm1_a = AdaLayerNorm(dim)
151
+ self.norm1_b = AdaLayerNorm(dim, single=True)
152
+
153
+ self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True)
154
+
155
+ self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
156
+ self.ff_a = torch.nn.Sequential(
157
+ torch.nn.Linear(dim, dim*4),
158
+ torch.nn.GELU(approximate="tanh"),
159
+ torch.nn.Linear(dim*4, dim)
160
+ )
161
+
162
+
163
+ def forward(self, hidden_states_a, hidden_states_b, temb):
164
+ norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
165
+ norm_hidden_states_b = self.norm1_b(hidden_states_b, emb=temb)
166
+
167
+ # Attention
168
+ attn_output_a = self.attn(norm_hidden_states_a, norm_hidden_states_b)
169
+
170
+ # Part A
171
+ hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
172
+ norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
173
+ hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
174
+
175
+ return hidden_states_a, hidden_states_b
176
+
177
+
178
+
179
+ class SD3DiT(torch.nn.Module):
180
+ def __init__(self):
181
+ super().__init__()
182
+ self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192)
183
+ self.time_embedder = TimestepEmbeddings(256, 1536)
184
+ self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, 1536), torch.nn.SiLU(), torch.nn.Linear(1536, 1536))
185
+ self.context_embedder = torch.nn.Linear(4096, 1536)
186
+ self.blocks = torch.nn.ModuleList([JointTransformerBlock(1536, 24) for _ in range(23)] + [JointTransformerFinalBlock(1536, 24)])
187
+ self.norm_out = AdaLayerNorm(1536, single=True)
188
+ self.proj_out = torch.nn.Linear(1536, 64)
189
+
190
+ def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size=128, tile_stride=64):
191
+ # Due to the global positional embedding, we cannot implement layer-wise tiled forward.
192
+ hidden_states = TileWorker().tiled_forward(
193
+ lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb),
194
+ hidden_states,
195
+ tile_size,
196
+ tile_stride,
197
+ tile_device=hidden_states.device,
198
+ tile_dtype=hidden_states.dtype
199
+ )
200
+ return hidden_states
201
+
202
+ def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tiled=False, tile_size=128, tile_stride=64, use_gradient_checkpointing=False):
203
+ if tiled:
204
+ return self.tiled_forward(hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size, tile_stride)
205
+ conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
206
+ prompt_emb = self.context_embedder(prompt_emb)
207
+
208
+ height, width = hidden_states.shape[-2:]
209
+ hidden_states = self.pos_embedder(hidden_states)
210
+
211
+ def create_custom_forward(module):
212
+ def custom_forward(*inputs):
213
+ return module(*inputs)
214
+ return custom_forward
215
+
216
+ for block in self.blocks:
217
+ if self.training and use_gradient_checkpointing:
218
+ hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
219
+ create_custom_forward(block),
220
+ hidden_states, prompt_emb, conditioning,
221
+ use_reentrant=False,
222
+ )
223
+ else:
224
+ hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)
225
+
226
+ hidden_states = self.norm_out(hidden_states, conditioning)
227
+ hidden_states = self.proj_out(hidden_states)
228
+ hidden_states = rearrange(hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
229
+ return hidden_states
230
+
231
+ def state_dict_converter(self):
232
+ return SD3DiTStateDictConverter()
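# Hedged end-to-end shape sketch for SD3DiT (randomly initialised weights; the latent
# size and the 154-token context length are illustrative). The model maps a 16-channel
# latent plus token-level and pooled text embeddings back to a tensor of the latent shape.
import torch
dit = SD3DiT().eval()
latent = torch.randn(1, 16, 64, 64)       # VAE latent
timestep = torch.tensor([500.0])          # one diffusion timestep per sample
prompt_emb = torch.randn(1, 154, 4096)    # token-level text embeddings
pooled_prompt_emb = torch.randn(1, 2048)  # pooled text embeddings
with torch.no_grad():
    out = dit(latent, timestep, prompt_emb, pooled_prompt_emb)
print(out.shape)  # torch.Size([1, 16, 64, 64])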
233
+
234
+
235
+
236
+ class SD3DiTStateDictConverter:
237
+ def __init__(self):
238
+ pass
239
+
240
+ def from_diffusers(self, state_dict):
241
+ rename_dict = {
242
+ "context_embedder": "context_embedder",
243
+ "pos_embed.pos_embed": "pos_embedder.pos_embed",
244
+ "pos_embed.proj": "pos_embedder.proj",
245
+ "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
246
+ "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
247
+ "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
248
+ "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
249
+ "norm_out.linear": "norm_out.linear",
250
+ "proj_out": "proj_out",
251
+
252
+ "norm1.linear": "norm1_a.linear",
253
+ "norm1_context.linear": "norm1_b.linear",
254
+ "attn.to_q": "attn.a_to_q",
255
+ "attn.to_k": "attn.a_to_k",
256
+ "attn.to_v": "attn.a_to_v",
257
+ "attn.to_out.0": "attn.a_to_out",
258
+ "attn.add_q_proj": "attn.b_to_q",
259
+ "attn.add_k_proj": "attn.b_to_k",
260
+ "attn.add_v_proj": "attn.b_to_v",
261
+ "attn.to_add_out": "attn.b_to_out",
262
+ "ff.net.0.proj": "ff_a.0",
263
+ "ff.net.2": "ff_a.2",
264
+ "ff_context.net.0.proj": "ff_b.0",
265
+ "ff_context.net.2": "ff_b.2",
266
+ }
267
+ state_dict_ = {}
268
+ for name, param in state_dict.items():
269
+ if name in rename_dict:
270
+ if name == "pos_embed.pos_embed":
271
+ param = param.reshape((1, 192, 192, 1536))
272
+ state_dict_[rename_dict[name]] = param
273
+ elif name.endswith(".weight") or name.endswith(".bias"):
274
+ suffix = ".weight" if name.endswith(".weight") else ".bias"
275
+ prefix = name[:-len(suffix)]
276
+ if prefix in rename_dict:
277
+ state_dict_[rename_dict[prefix] + suffix] = param
278
+ elif prefix.startswith("transformer_blocks."):
279
+ names = prefix.split(".")
280
+ names[0] = "blocks"
281
+ middle = ".".join(names[2:])
282
+ if middle in rename_dict:
283
+ name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
284
+ state_dict_[name_] = param
285
+ return state_dict_
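# Illustrative sketch of the key renaming performed by from_diffusers above: a per-block
# parameter name in the Diffusers layout is rewritten into this module's naming scheme.
import torch
converter = SD3DiTStateDictConverter()
dummy = {"transformer_blocks.3.attn.to_q.weight": torch.zeros(1536, 1536)}
print(list(converter.from_diffusers(dummy).keys()))  # ['blocks.3.attn.a_to_q.weight']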
286
+
287
+ def from_civitai(self, state_dict):
288
+ rename_dict = {
289
+ "model.diffusion_model.context_embedder.bias": "context_embedder.bias",
290
+ "model.diffusion_model.context_embedder.weight": "context_embedder.weight",
291
+ "model.diffusion_model.final_layer.linear.bias": "proj_out.bias",
292
+ "model.diffusion_model.final_layer.linear.weight": "proj_out.weight",
293
+ "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias": "blocks.0.norm1_b.linear.bias",
294
+ "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.weight": "blocks.0.norm1_b.linear.weight",
295
+ "model.diffusion_model.joint_blocks.0.context_block.attn.proj.bias": "blocks.0.attn.b_to_out.bias",
296
+ "model.diffusion_model.joint_blocks.0.context_block.attn.proj.weight": "blocks.0.attn.b_to_out.weight",
297
+ "model.diffusion_model.joint_blocks.0.context_block.attn.qkv.bias": ['blocks.0.attn.b_to_q.bias', 'blocks.0.attn.b_to_k.bias', 'blocks.0.attn.b_to_v.bias'],
298
+ "model.diffusion_model.joint_blocks.0.context_block.attn.qkv.weight": ['blocks.0.attn.b_to_q.weight', 'blocks.0.attn.b_to_k.weight', 'blocks.0.attn.b_to_v.weight'],
299
+ "model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.bias": "blocks.0.ff_b.0.bias",
300
+ "model.diffusion_model.joint_blocks.0.context_block.mlp.fc1.weight": "blocks.0.ff_b.0.weight",
301
+ "model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.bias": "blocks.0.ff_b.2.bias",
302
+ "model.diffusion_model.joint_blocks.0.context_block.mlp.fc2.weight": "blocks.0.ff_b.2.weight",
303
+ "model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.bias": "blocks.0.norm1_a.linear.bias",
304
+ "model.diffusion_model.joint_blocks.0.x_block.adaLN_modulation.1.weight": "blocks.0.norm1_a.linear.weight",
305
+ "model.diffusion_model.joint_blocks.0.x_block.attn.proj.bias": "blocks.0.attn.a_to_out.bias",
306
+ "model.diffusion_model.joint_blocks.0.x_block.attn.proj.weight": "blocks.0.attn.a_to_out.weight",
307
+ "model.diffusion_model.joint_blocks.0.x_block.attn.qkv.bias": ['blocks.0.attn.a_to_q.bias', 'blocks.0.attn.a_to_k.bias', 'blocks.0.attn.a_to_v.bias'],
308
+ "model.diffusion_model.joint_blocks.0.x_block.attn.qkv.weight": ['blocks.0.attn.a_to_q.weight', 'blocks.0.attn.a_to_k.weight', 'blocks.0.attn.a_to_v.weight'],
309
+ "model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.bias": "blocks.0.ff_a.0.bias",
310
+ "model.diffusion_model.joint_blocks.0.x_block.mlp.fc1.weight": "blocks.0.ff_a.0.weight",
311
+ "model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.bias": "blocks.0.ff_a.2.bias",
312
+ "model.diffusion_model.joint_blocks.0.x_block.mlp.fc2.weight": "blocks.0.ff_a.2.weight",
313
+ "model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.bias": "blocks.1.norm1_b.linear.bias",
314
+ "model.diffusion_model.joint_blocks.1.context_block.adaLN_modulation.1.weight": "blocks.1.norm1_b.linear.weight",
315
+ "model.diffusion_model.joint_blocks.1.context_block.attn.proj.bias": "blocks.1.attn.b_to_out.bias",
316
+ "model.diffusion_model.joint_blocks.1.context_block.attn.proj.weight": "blocks.1.attn.b_to_out.weight",
317
+ "model.diffusion_model.joint_blocks.1.context_block.attn.qkv.bias": ['blocks.1.attn.b_to_q.bias', 'blocks.1.attn.b_to_k.bias', 'blocks.1.attn.b_to_v.bias'],
318
+ "model.diffusion_model.joint_blocks.1.context_block.attn.qkv.weight": ['blocks.1.attn.b_to_q.weight', 'blocks.1.attn.b_to_k.weight', 'blocks.1.attn.b_to_v.weight'],
319
+ "model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.bias": "blocks.1.ff_b.0.bias",
320
+ "model.diffusion_model.joint_blocks.1.context_block.mlp.fc1.weight": "blocks.1.ff_b.0.weight",
321
+ "model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.bias": "blocks.1.ff_b.2.bias",
322
+ "model.diffusion_model.joint_blocks.1.context_block.mlp.fc2.weight": "blocks.1.ff_b.2.weight",
323
+ "model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.bias": "blocks.1.norm1_a.linear.bias",
324
+ "model.diffusion_model.joint_blocks.1.x_block.adaLN_modulation.1.weight": "blocks.1.norm1_a.linear.weight",
325
+ "model.diffusion_model.joint_blocks.1.x_block.attn.proj.bias": "blocks.1.attn.a_to_out.bias",
326
+ "model.diffusion_model.joint_blocks.1.x_block.attn.proj.weight": "blocks.1.attn.a_to_out.weight",
327
+ "model.diffusion_model.joint_blocks.1.x_block.attn.qkv.bias": ['blocks.1.attn.a_to_q.bias', 'blocks.1.attn.a_to_k.bias', 'blocks.1.attn.a_to_v.bias'],
328
+ "model.diffusion_model.joint_blocks.1.x_block.attn.qkv.weight": ['blocks.1.attn.a_to_q.weight', 'blocks.1.attn.a_to_k.weight', 'blocks.1.attn.a_to_v.weight'],
329
+ "model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.bias": "blocks.1.ff_a.0.bias",
330
+ "model.diffusion_model.joint_blocks.1.x_block.mlp.fc1.weight": "blocks.1.ff_a.0.weight",
331
+ "model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.bias": "blocks.1.ff_a.2.bias",
332
+ "model.diffusion_model.joint_blocks.1.x_block.mlp.fc2.weight": "blocks.1.ff_a.2.weight",
333
+ "model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.bias": "blocks.10.norm1_b.linear.bias",
334
+ "model.diffusion_model.joint_blocks.10.context_block.adaLN_modulation.1.weight": "blocks.10.norm1_b.linear.weight",
335
+ "model.diffusion_model.joint_blocks.10.context_block.attn.proj.bias": "blocks.10.attn.b_to_out.bias",
336
+ "model.diffusion_model.joint_blocks.10.context_block.attn.proj.weight": "blocks.10.attn.b_to_out.weight",
337
+ "model.diffusion_model.joint_blocks.10.context_block.attn.qkv.bias": ['blocks.10.attn.b_to_q.bias', 'blocks.10.attn.b_to_k.bias', 'blocks.10.attn.b_to_v.bias'],
338
+ "model.diffusion_model.joint_blocks.10.context_block.attn.qkv.weight": ['blocks.10.attn.b_to_q.weight', 'blocks.10.attn.b_to_k.weight', 'blocks.10.attn.b_to_v.weight'],
339
+ "model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.bias": "blocks.10.ff_b.0.bias",
340
+ "model.diffusion_model.joint_blocks.10.context_block.mlp.fc1.weight": "blocks.10.ff_b.0.weight",
341
+ "model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.bias": "blocks.10.ff_b.2.bias",
342
+ "model.diffusion_model.joint_blocks.10.context_block.mlp.fc2.weight": "blocks.10.ff_b.2.weight",
343
+ "model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.bias": "blocks.10.norm1_a.linear.bias",
344
+ "model.diffusion_model.joint_blocks.10.x_block.adaLN_modulation.1.weight": "blocks.10.norm1_a.linear.weight",
345
+ "model.diffusion_model.joint_blocks.10.x_block.attn.proj.bias": "blocks.10.attn.a_to_out.bias",
346
+ "model.diffusion_model.joint_blocks.10.x_block.attn.proj.weight": "blocks.10.attn.a_to_out.weight",
347
+ "model.diffusion_model.joint_blocks.10.x_block.attn.qkv.bias": ['blocks.10.attn.a_to_q.bias', 'blocks.10.attn.a_to_k.bias', 'blocks.10.attn.a_to_v.bias'],
348
+ "model.diffusion_model.joint_blocks.10.x_block.attn.qkv.weight": ['blocks.10.attn.a_to_q.weight', 'blocks.10.attn.a_to_k.weight', 'blocks.10.attn.a_to_v.weight'],
349
+ "model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.bias": "blocks.10.ff_a.0.bias",
350
+ "model.diffusion_model.joint_blocks.10.x_block.mlp.fc1.weight": "blocks.10.ff_a.0.weight",
351
+ "model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.bias": "blocks.10.ff_a.2.bias",
352
+ "model.diffusion_model.joint_blocks.10.x_block.mlp.fc2.weight": "blocks.10.ff_a.2.weight",
353
+ "model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.bias": "blocks.11.norm1_b.linear.bias",
354
+ "model.diffusion_model.joint_blocks.11.context_block.adaLN_modulation.1.weight": "blocks.11.norm1_b.linear.weight",
355
+ "model.diffusion_model.joint_blocks.11.context_block.attn.proj.bias": "blocks.11.attn.b_to_out.bias",
356
+ "model.diffusion_model.joint_blocks.11.context_block.attn.proj.weight": "blocks.11.attn.b_to_out.weight",
357
+ "model.diffusion_model.joint_blocks.11.context_block.attn.qkv.bias": ['blocks.11.attn.b_to_q.bias', 'blocks.11.attn.b_to_k.bias', 'blocks.11.attn.b_to_v.bias'],
358
+ "model.diffusion_model.joint_blocks.11.context_block.attn.qkv.weight": ['blocks.11.attn.b_to_q.weight', 'blocks.11.attn.b_to_k.weight', 'blocks.11.attn.b_to_v.weight'],
359
+ "model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.bias": "blocks.11.ff_b.0.bias",
360
+ "model.diffusion_model.joint_blocks.11.context_block.mlp.fc1.weight": "blocks.11.ff_b.0.weight",
361
+ "model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.bias": "blocks.11.ff_b.2.bias",
362
+ "model.diffusion_model.joint_blocks.11.context_block.mlp.fc2.weight": "blocks.11.ff_b.2.weight",
363
+ "model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.bias": "blocks.11.norm1_a.linear.bias",
364
+ "model.diffusion_model.joint_blocks.11.x_block.adaLN_modulation.1.weight": "blocks.11.norm1_a.linear.weight",
365
+ "model.diffusion_model.joint_blocks.11.x_block.attn.proj.bias": "blocks.11.attn.a_to_out.bias",
366
+ "model.diffusion_model.joint_blocks.11.x_block.attn.proj.weight": "blocks.11.attn.a_to_out.weight",
367
+ "model.diffusion_model.joint_blocks.11.x_block.attn.qkv.bias": ['blocks.11.attn.a_to_q.bias', 'blocks.11.attn.a_to_k.bias', 'blocks.11.attn.a_to_v.bias'],
368
+ "model.diffusion_model.joint_blocks.11.x_block.attn.qkv.weight": ['blocks.11.attn.a_to_q.weight', 'blocks.11.attn.a_to_k.weight', 'blocks.11.attn.a_to_v.weight'],
369
+ "model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.bias": "blocks.11.ff_a.0.bias",
370
+ "model.diffusion_model.joint_blocks.11.x_block.mlp.fc1.weight": "blocks.11.ff_a.0.weight",
371
+ "model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.bias": "blocks.11.ff_a.2.bias",
372
+ "model.diffusion_model.joint_blocks.11.x_block.mlp.fc2.weight": "blocks.11.ff_a.2.weight",
373
+ "model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.bias": "blocks.12.norm1_b.linear.bias",
374
+ "model.diffusion_model.joint_blocks.12.context_block.adaLN_modulation.1.weight": "blocks.12.norm1_b.linear.weight",
375
+ "model.diffusion_model.joint_blocks.12.context_block.attn.proj.bias": "blocks.12.attn.b_to_out.bias",
376
+ "model.diffusion_model.joint_blocks.12.context_block.attn.proj.weight": "blocks.12.attn.b_to_out.weight",
377
+ "model.diffusion_model.joint_blocks.12.context_block.attn.qkv.bias": ['blocks.12.attn.b_to_q.bias', 'blocks.12.attn.b_to_k.bias', 'blocks.12.attn.b_to_v.bias'],
378
+ "model.diffusion_model.joint_blocks.12.context_block.attn.qkv.weight": ['blocks.12.attn.b_to_q.weight', 'blocks.12.attn.b_to_k.weight', 'blocks.12.attn.b_to_v.weight'],
379
+ "model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.bias": "blocks.12.ff_b.0.bias",
380
+ "model.diffusion_model.joint_blocks.12.context_block.mlp.fc1.weight": "blocks.12.ff_b.0.weight",
381
+ "model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.bias": "blocks.12.ff_b.2.bias",
382
+ "model.diffusion_model.joint_blocks.12.context_block.mlp.fc2.weight": "blocks.12.ff_b.2.weight",
383
+ "model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.bias": "blocks.12.norm1_a.linear.bias",
384
+ "model.diffusion_model.joint_blocks.12.x_block.adaLN_modulation.1.weight": "blocks.12.norm1_a.linear.weight",
385
+ "model.diffusion_model.joint_blocks.12.x_block.attn.proj.bias": "blocks.12.attn.a_to_out.bias",
386
+ "model.diffusion_model.joint_blocks.12.x_block.attn.proj.weight": "blocks.12.attn.a_to_out.weight",
387
+ "model.diffusion_model.joint_blocks.12.x_block.attn.qkv.bias": ['blocks.12.attn.a_to_q.bias', 'blocks.12.attn.a_to_k.bias', 'blocks.12.attn.a_to_v.bias'],
388
+ "model.diffusion_model.joint_blocks.12.x_block.attn.qkv.weight": ['blocks.12.attn.a_to_q.weight', 'blocks.12.attn.a_to_k.weight', 'blocks.12.attn.a_to_v.weight'],
389
+ "model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.bias": "blocks.12.ff_a.0.bias",
390
+ "model.diffusion_model.joint_blocks.12.x_block.mlp.fc1.weight": "blocks.12.ff_a.0.weight",
391
+ "model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.bias": "blocks.12.ff_a.2.bias",
392
+ "model.diffusion_model.joint_blocks.12.x_block.mlp.fc2.weight": "blocks.12.ff_a.2.weight",
393
+ "model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.bias": "blocks.13.norm1_b.linear.bias",
394
+ "model.diffusion_model.joint_blocks.13.context_block.adaLN_modulation.1.weight": "blocks.13.norm1_b.linear.weight",
395
+ "model.diffusion_model.joint_blocks.13.context_block.attn.proj.bias": "blocks.13.attn.b_to_out.bias",
396
+ "model.diffusion_model.joint_blocks.13.context_block.attn.proj.weight": "blocks.13.attn.b_to_out.weight",
397
+ "model.diffusion_model.joint_blocks.13.context_block.attn.qkv.bias": ['blocks.13.attn.b_to_q.bias', 'blocks.13.attn.b_to_k.bias', 'blocks.13.attn.b_to_v.bias'],
398
+ "model.diffusion_model.joint_blocks.13.context_block.attn.qkv.weight": ['blocks.13.attn.b_to_q.weight', 'blocks.13.attn.b_to_k.weight', 'blocks.13.attn.b_to_v.weight'],
399
+ "model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.bias": "blocks.13.ff_b.0.bias",
400
+ "model.diffusion_model.joint_blocks.13.context_block.mlp.fc1.weight": "blocks.13.ff_b.0.weight",
401
+ "model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.bias": "blocks.13.ff_b.2.bias",
402
+ "model.diffusion_model.joint_blocks.13.context_block.mlp.fc2.weight": "blocks.13.ff_b.2.weight",
403
+ "model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.bias": "blocks.13.norm1_a.linear.bias",
404
+ "model.diffusion_model.joint_blocks.13.x_block.adaLN_modulation.1.weight": "blocks.13.norm1_a.linear.weight",
405
+ "model.diffusion_model.joint_blocks.13.x_block.attn.proj.bias": "blocks.13.attn.a_to_out.bias",
406
+ "model.diffusion_model.joint_blocks.13.x_block.attn.proj.weight": "blocks.13.attn.a_to_out.weight",
407
+ "model.diffusion_model.joint_blocks.13.x_block.attn.qkv.bias": ['blocks.13.attn.a_to_q.bias', 'blocks.13.attn.a_to_k.bias', 'blocks.13.attn.a_to_v.bias'],
408
+ "model.diffusion_model.joint_blocks.13.x_block.attn.qkv.weight": ['blocks.13.attn.a_to_q.weight', 'blocks.13.attn.a_to_k.weight', 'blocks.13.attn.a_to_v.weight'],
409
+ "model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.bias": "blocks.13.ff_a.0.bias",
410
+ "model.diffusion_model.joint_blocks.13.x_block.mlp.fc1.weight": "blocks.13.ff_a.0.weight",
411
+ "model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.bias": "blocks.13.ff_a.2.bias",
412
+ "model.diffusion_model.joint_blocks.13.x_block.mlp.fc2.weight": "blocks.13.ff_a.2.weight",
413
+ "model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.bias": "blocks.14.norm1_b.linear.bias",
414
+ "model.diffusion_model.joint_blocks.14.context_block.adaLN_modulation.1.weight": "blocks.14.norm1_b.linear.weight",
415
+ "model.diffusion_model.joint_blocks.14.context_block.attn.proj.bias": "blocks.14.attn.b_to_out.bias",
416
+ "model.diffusion_model.joint_blocks.14.context_block.attn.proj.weight": "blocks.14.attn.b_to_out.weight",
417
+ "model.diffusion_model.joint_blocks.14.context_block.attn.qkv.bias": ['blocks.14.attn.b_to_q.bias', 'blocks.14.attn.b_to_k.bias', 'blocks.14.attn.b_to_v.bias'],
418
+ "model.diffusion_model.joint_blocks.14.context_block.attn.qkv.weight": ['blocks.14.attn.b_to_q.weight', 'blocks.14.attn.b_to_k.weight', 'blocks.14.attn.b_to_v.weight'],
419
+ "model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.bias": "blocks.14.ff_b.0.bias",
420
+ "model.diffusion_model.joint_blocks.14.context_block.mlp.fc1.weight": "blocks.14.ff_b.0.weight",
421
+ "model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.bias": "blocks.14.ff_b.2.bias",
422
+ "model.diffusion_model.joint_blocks.14.context_block.mlp.fc2.weight": "blocks.14.ff_b.2.weight",
423
+ "model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.bias": "blocks.14.norm1_a.linear.bias",
424
+ "model.diffusion_model.joint_blocks.14.x_block.adaLN_modulation.1.weight": "blocks.14.norm1_a.linear.weight",
425
+ "model.diffusion_model.joint_blocks.14.x_block.attn.proj.bias": "blocks.14.attn.a_to_out.bias",
426
+ "model.diffusion_model.joint_blocks.14.x_block.attn.proj.weight": "blocks.14.attn.a_to_out.weight",
427
+ "model.diffusion_model.joint_blocks.14.x_block.attn.qkv.bias": ['blocks.14.attn.a_to_q.bias', 'blocks.14.attn.a_to_k.bias', 'blocks.14.attn.a_to_v.bias'],
428
+ "model.diffusion_model.joint_blocks.14.x_block.attn.qkv.weight": ['blocks.14.attn.a_to_q.weight', 'blocks.14.attn.a_to_k.weight', 'blocks.14.attn.a_to_v.weight'],
429
+ "model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.bias": "blocks.14.ff_a.0.bias",
430
+ "model.diffusion_model.joint_blocks.14.x_block.mlp.fc1.weight": "blocks.14.ff_a.0.weight",
431
+ "model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.bias": "blocks.14.ff_a.2.bias",
432
+ "model.diffusion_model.joint_blocks.14.x_block.mlp.fc2.weight": "blocks.14.ff_a.2.weight",
433
+ "model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.bias": "blocks.15.norm1_b.linear.bias",
434
+ "model.diffusion_model.joint_blocks.15.context_block.adaLN_modulation.1.weight": "blocks.15.norm1_b.linear.weight",
435
+ "model.diffusion_model.joint_blocks.15.context_block.attn.proj.bias": "blocks.15.attn.b_to_out.bias",
436
+ "model.diffusion_model.joint_blocks.15.context_block.attn.proj.weight": "blocks.15.attn.b_to_out.weight",
437
+ "model.diffusion_model.joint_blocks.15.context_block.attn.qkv.bias": ['blocks.15.attn.b_to_q.bias', 'blocks.15.attn.b_to_k.bias', 'blocks.15.attn.b_to_v.bias'],
438
+ "model.diffusion_model.joint_blocks.15.context_block.attn.qkv.weight": ['blocks.15.attn.b_to_q.weight', 'blocks.15.attn.b_to_k.weight', 'blocks.15.attn.b_to_v.weight'],
439
+ "model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.bias": "blocks.15.ff_b.0.bias",
440
+ "model.diffusion_model.joint_blocks.15.context_block.mlp.fc1.weight": "blocks.15.ff_b.0.weight",
441
+ "model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.bias": "blocks.15.ff_b.2.bias",
442
+ "model.diffusion_model.joint_blocks.15.context_block.mlp.fc2.weight": "blocks.15.ff_b.2.weight",
443
+ "model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.bias": "blocks.15.norm1_a.linear.bias",
444
+ "model.diffusion_model.joint_blocks.15.x_block.adaLN_modulation.1.weight": "blocks.15.norm1_a.linear.weight",
445
+ "model.diffusion_model.joint_blocks.15.x_block.attn.proj.bias": "blocks.15.attn.a_to_out.bias",
446
+ "model.diffusion_model.joint_blocks.15.x_block.attn.proj.weight": "blocks.15.attn.a_to_out.weight",
447
+ "model.diffusion_model.joint_blocks.15.x_block.attn.qkv.bias": ['blocks.15.attn.a_to_q.bias', 'blocks.15.attn.a_to_k.bias', 'blocks.15.attn.a_to_v.bias'],
448
+ "model.diffusion_model.joint_blocks.15.x_block.attn.qkv.weight": ['blocks.15.attn.a_to_q.weight', 'blocks.15.attn.a_to_k.weight', 'blocks.15.attn.a_to_v.weight'],
449
+ "model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.bias": "blocks.15.ff_a.0.bias",
450
+ "model.diffusion_model.joint_blocks.15.x_block.mlp.fc1.weight": "blocks.15.ff_a.0.weight",
451
+ "model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.bias": "blocks.15.ff_a.2.bias",
452
+ "model.diffusion_model.joint_blocks.15.x_block.mlp.fc2.weight": "blocks.15.ff_a.2.weight",
453
+ "model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.bias": "blocks.16.norm1_b.linear.bias",
454
+ "model.diffusion_model.joint_blocks.16.context_block.adaLN_modulation.1.weight": "blocks.16.norm1_b.linear.weight",
455
+ "model.diffusion_model.joint_blocks.16.context_block.attn.proj.bias": "blocks.16.attn.b_to_out.bias",
456
+ "model.diffusion_model.joint_blocks.16.context_block.attn.proj.weight": "blocks.16.attn.b_to_out.weight",
457
+ "model.diffusion_model.joint_blocks.16.context_block.attn.qkv.bias": ['blocks.16.attn.b_to_q.bias', 'blocks.16.attn.b_to_k.bias', 'blocks.16.attn.b_to_v.bias'],
458
+ "model.diffusion_model.joint_blocks.16.context_block.attn.qkv.weight": ['blocks.16.attn.b_to_q.weight', 'blocks.16.attn.b_to_k.weight', 'blocks.16.attn.b_to_v.weight'],
459
+ "model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.bias": "blocks.16.ff_b.0.bias",
460
+ "model.diffusion_model.joint_blocks.16.context_block.mlp.fc1.weight": "blocks.16.ff_b.0.weight",
461
+ "model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.bias": "blocks.16.ff_b.2.bias",
462
+ "model.diffusion_model.joint_blocks.16.context_block.mlp.fc2.weight": "blocks.16.ff_b.2.weight",
463
+ "model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.bias": "blocks.16.norm1_a.linear.bias",
464
+ "model.diffusion_model.joint_blocks.16.x_block.adaLN_modulation.1.weight": "blocks.16.norm1_a.linear.weight",
465
+ "model.diffusion_model.joint_blocks.16.x_block.attn.proj.bias": "blocks.16.attn.a_to_out.bias",
466
+ "model.diffusion_model.joint_blocks.16.x_block.attn.proj.weight": "blocks.16.attn.a_to_out.weight",
467
+ "model.diffusion_model.joint_blocks.16.x_block.attn.qkv.bias": ['blocks.16.attn.a_to_q.bias', 'blocks.16.attn.a_to_k.bias', 'blocks.16.attn.a_to_v.bias'],
468
+ "model.diffusion_model.joint_blocks.16.x_block.attn.qkv.weight": ['blocks.16.attn.a_to_q.weight', 'blocks.16.attn.a_to_k.weight', 'blocks.16.attn.a_to_v.weight'],
469
+ "model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.bias": "blocks.16.ff_a.0.bias",
470
+ "model.diffusion_model.joint_blocks.16.x_block.mlp.fc1.weight": "blocks.16.ff_a.0.weight",
471
+ "model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.bias": "blocks.16.ff_a.2.bias",
472
+ "model.diffusion_model.joint_blocks.16.x_block.mlp.fc2.weight": "blocks.16.ff_a.2.weight",
473
+ "model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.bias": "blocks.17.norm1_b.linear.bias",
474
+ "model.diffusion_model.joint_blocks.17.context_block.adaLN_modulation.1.weight": "blocks.17.norm1_b.linear.weight",
475
+ "model.diffusion_model.joint_blocks.17.context_block.attn.proj.bias": "blocks.17.attn.b_to_out.bias",
476
+ "model.diffusion_model.joint_blocks.17.context_block.attn.proj.weight": "blocks.17.attn.b_to_out.weight",
477
+ "model.diffusion_model.joint_blocks.17.context_block.attn.qkv.bias": ['blocks.17.attn.b_to_q.bias', 'blocks.17.attn.b_to_k.bias', 'blocks.17.attn.b_to_v.bias'],
478
+ "model.diffusion_model.joint_blocks.17.context_block.attn.qkv.weight": ['blocks.17.attn.b_to_q.weight', 'blocks.17.attn.b_to_k.weight', 'blocks.17.attn.b_to_v.weight'],
479
+ "model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.bias": "blocks.17.ff_b.0.bias",
480
+ "model.diffusion_model.joint_blocks.17.context_block.mlp.fc1.weight": "blocks.17.ff_b.0.weight",
481
+ "model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.bias": "blocks.17.ff_b.2.bias",
482
+ "model.diffusion_model.joint_blocks.17.context_block.mlp.fc2.weight": "blocks.17.ff_b.2.weight",
483
+ "model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.bias": "blocks.17.norm1_a.linear.bias",
484
+ "model.diffusion_model.joint_blocks.17.x_block.adaLN_modulation.1.weight": "blocks.17.norm1_a.linear.weight",
485
+ "model.diffusion_model.joint_blocks.17.x_block.attn.proj.bias": "blocks.17.attn.a_to_out.bias",
486
+ "model.diffusion_model.joint_blocks.17.x_block.attn.proj.weight": "blocks.17.attn.a_to_out.weight",
487
+ "model.diffusion_model.joint_blocks.17.x_block.attn.qkv.bias": ['blocks.17.attn.a_to_q.bias', 'blocks.17.attn.a_to_k.bias', 'blocks.17.attn.a_to_v.bias'],
488
+ "model.diffusion_model.joint_blocks.17.x_block.attn.qkv.weight": ['blocks.17.attn.a_to_q.weight', 'blocks.17.attn.a_to_k.weight', 'blocks.17.attn.a_to_v.weight'],
489
+ "model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.bias": "blocks.17.ff_a.0.bias",
490
+ "model.diffusion_model.joint_blocks.17.x_block.mlp.fc1.weight": "blocks.17.ff_a.0.weight",
491
+ "model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.bias": "blocks.17.ff_a.2.bias",
492
+ "model.diffusion_model.joint_blocks.17.x_block.mlp.fc2.weight": "blocks.17.ff_a.2.weight",
493
+ "model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.bias": "blocks.18.norm1_b.linear.bias",
494
+ "model.diffusion_model.joint_blocks.18.context_block.adaLN_modulation.1.weight": "blocks.18.norm1_b.linear.weight",
495
+ "model.diffusion_model.joint_blocks.18.context_block.attn.proj.bias": "blocks.18.attn.b_to_out.bias",
496
+ "model.diffusion_model.joint_blocks.18.context_block.attn.proj.weight": "blocks.18.attn.b_to_out.weight",
497
+ "model.diffusion_model.joint_blocks.18.context_block.attn.qkv.bias": ['blocks.18.attn.b_to_q.bias', 'blocks.18.attn.b_to_k.bias', 'blocks.18.attn.b_to_v.bias'],
498
+ "model.diffusion_model.joint_blocks.18.context_block.attn.qkv.weight": ['blocks.18.attn.b_to_q.weight', 'blocks.18.attn.b_to_k.weight', 'blocks.18.attn.b_to_v.weight'],
499
+ "model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.bias": "blocks.18.ff_b.0.bias",
500
+ "model.diffusion_model.joint_blocks.18.context_block.mlp.fc1.weight": "blocks.18.ff_b.0.weight",
501
+ "model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.bias": "blocks.18.ff_b.2.bias",
502
+ "model.diffusion_model.joint_blocks.18.context_block.mlp.fc2.weight": "blocks.18.ff_b.2.weight",
503
+ "model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.bias": "blocks.18.norm1_a.linear.bias",
504
+ "model.diffusion_model.joint_blocks.18.x_block.adaLN_modulation.1.weight": "blocks.18.norm1_a.linear.weight",
505
+ "model.diffusion_model.joint_blocks.18.x_block.attn.proj.bias": "blocks.18.attn.a_to_out.bias",
506
+ "model.diffusion_model.joint_blocks.18.x_block.attn.proj.weight": "blocks.18.attn.a_to_out.weight",
507
+ "model.diffusion_model.joint_blocks.18.x_block.attn.qkv.bias": ['blocks.18.attn.a_to_q.bias', 'blocks.18.attn.a_to_k.bias', 'blocks.18.attn.a_to_v.bias'],
508
+ "model.diffusion_model.joint_blocks.18.x_block.attn.qkv.weight": ['blocks.18.attn.a_to_q.weight', 'blocks.18.attn.a_to_k.weight', 'blocks.18.attn.a_to_v.weight'],
509
+ "model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.bias": "blocks.18.ff_a.0.bias",
510
+ "model.diffusion_model.joint_blocks.18.x_block.mlp.fc1.weight": "blocks.18.ff_a.0.weight",
511
+ "model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.bias": "blocks.18.ff_a.2.bias",
512
+ "model.diffusion_model.joint_blocks.18.x_block.mlp.fc2.weight": "blocks.18.ff_a.2.weight",
513
+ "model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.bias": "blocks.19.norm1_b.linear.bias",
514
+ "model.diffusion_model.joint_blocks.19.context_block.adaLN_modulation.1.weight": "blocks.19.norm1_b.linear.weight",
515
+ "model.diffusion_model.joint_blocks.19.context_block.attn.proj.bias": "blocks.19.attn.b_to_out.bias",
516
+ "model.diffusion_model.joint_blocks.19.context_block.attn.proj.weight": "blocks.19.attn.b_to_out.weight",
517
+ "model.diffusion_model.joint_blocks.19.context_block.attn.qkv.bias": ['blocks.19.attn.b_to_q.bias', 'blocks.19.attn.b_to_k.bias', 'blocks.19.attn.b_to_v.bias'],
518
+ "model.diffusion_model.joint_blocks.19.context_block.attn.qkv.weight": ['blocks.19.attn.b_to_q.weight', 'blocks.19.attn.b_to_k.weight', 'blocks.19.attn.b_to_v.weight'],
519
+ "model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.bias": "blocks.19.ff_b.0.bias",
520
+ "model.diffusion_model.joint_blocks.19.context_block.mlp.fc1.weight": "blocks.19.ff_b.0.weight",
521
+ "model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.bias": "blocks.19.ff_b.2.bias",
522
+ "model.diffusion_model.joint_blocks.19.context_block.mlp.fc2.weight": "blocks.19.ff_b.2.weight",
523
+ "model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.bias": "blocks.19.norm1_a.linear.bias",
524
+ "model.diffusion_model.joint_blocks.19.x_block.adaLN_modulation.1.weight": "blocks.19.norm1_a.linear.weight",
525
+ "model.diffusion_model.joint_blocks.19.x_block.attn.proj.bias": "blocks.19.attn.a_to_out.bias",
526
+ "model.diffusion_model.joint_blocks.19.x_block.attn.proj.weight": "blocks.19.attn.a_to_out.weight",
527
+ "model.diffusion_model.joint_blocks.19.x_block.attn.qkv.bias": ['blocks.19.attn.a_to_q.bias', 'blocks.19.attn.a_to_k.bias', 'blocks.19.attn.a_to_v.bias'],
528
+ "model.diffusion_model.joint_blocks.19.x_block.attn.qkv.weight": ['blocks.19.attn.a_to_q.weight', 'blocks.19.attn.a_to_k.weight', 'blocks.19.attn.a_to_v.weight'],
529
+ "model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.bias": "blocks.19.ff_a.0.bias",
530
+ "model.diffusion_model.joint_blocks.19.x_block.mlp.fc1.weight": "blocks.19.ff_a.0.weight",
531
+ "model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.bias": "blocks.19.ff_a.2.bias",
532
+ "model.diffusion_model.joint_blocks.19.x_block.mlp.fc2.weight": "blocks.19.ff_a.2.weight",
533
+ "model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.bias": "blocks.2.norm1_b.linear.bias",
534
+ "model.diffusion_model.joint_blocks.2.context_block.adaLN_modulation.1.weight": "blocks.2.norm1_b.linear.weight",
535
+ "model.diffusion_model.joint_blocks.2.context_block.attn.proj.bias": "blocks.2.attn.b_to_out.bias",
536
+ "model.diffusion_model.joint_blocks.2.context_block.attn.proj.weight": "blocks.2.attn.b_to_out.weight",
537
+ "model.diffusion_model.joint_blocks.2.context_block.attn.qkv.bias": ['blocks.2.attn.b_to_q.bias', 'blocks.2.attn.b_to_k.bias', 'blocks.2.attn.b_to_v.bias'],
538
+ "model.diffusion_model.joint_blocks.2.context_block.attn.qkv.weight": ['blocks.2.attn.b_to_q.weight', 'blocks.2.attn.b_to_k.weight', 'blocks.2.attn.b_to_v.weight'],
539
+ "model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.bias": "blocks.2.ff_b.0.bias",
540
+ "model.diffusion_model.joint_blocks.2.context_block.mlp.fc1.weight": "blocks.2.ff_b.0.weight",
541
+ "model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.bias": "blocks.2.ff_b.2.bias",
542
+ "model.diffusion_model.joint_blocks.2.context_block.mlp.fc2.weight": "blocks.2.ff_b.2.weight",
543
+ "model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.bias": "blocks.2.norm1_a.linear.bias",
544
+ "model.diffusion_model.joint_blocks.2.x_block.adaLN_modulation.1.weight": "blocks.2.norm1_a.linear.weight",
545
+ "model.diffusion_model.joint_blocks.2.x_block.attn.proj.bias": "blocks.2.attn.a_to_out.bias",
546
+ "model.diffusion_model.joint_blocks.2.x_block.attn.proj.weight": "blocks.2.attn.a_to_out.weight",
547
+ "model.diffusion_model.joint_blocks.2.x_block.attn.qkv.bias": ['blocks.2.attn.a_to_q.bias', 'blocks.2.attn.a_to_k.bias', 'blocks.2.attn.a_to_v.bias'],
548
+ "model.diffusion_model.joint_blocks.2.x_block.attn.qkv.weight": ['blocks.2.attn.a_to_q.weight', 'blocks.2.attn.a_to_k.weight', 'blocks.2.attn.a_to_v.weight'],
549
+ "model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.bias": "blocks.2.ff_a.0.bias",
550
+ "model.diffusion_model.joint_blocks.2.x_block.mlp.fc1.weight": "blocks.2.ff_a.0.weight",
551
+ "model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.bias": "blocks.2.ff_a.2.bias",
552
+ "model.diffusion_model.joint_blocks.2.x_block.mlp.fc2.weight": "blocks.2.ff_a.2.weight",
553
+ "model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.bias": "blocks.20.norm1_b.linear.bias",
554
+ "model.diffusion_model.joint_blocks.20.context_block.adaLN_modulation.1.weight": "blocks.20.norm1_b.linear.weight",
555
+ "model.diffusion_model.joint_blocks.20.context_block.attn.proj.bias": "blocks.20.attn.b_to_out.bias",
556
+ "model.diffusion_model.joint_blocks.20.context_block.attn.proj.weight": "blocks.20.attn.b_to_out.weight",
557
+ "model.diffusion_model.joint_blocks.20.context_block.attn.qkv.bias": ['blocks.20.attn.b_to_q.bias', 'blocks.20.attn.b_to_k.bias', 'blocks.20.attn.b_to_v.bias'],
558
+ "model.diffusion_model.joint_blocks.20.context_block.attn.qkv.weight": ['blocks.20.attn.b_to_q.weight', 'blocks.20.attn.b_to_k.weight', 'blocks.20.attn.b_to_v.weight'],
559
+ "model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.bias": "blocks.20.ff_b.0.bias",
560
+ "model.diffusion_model.joint_blocks.20.context_block.mlp.fc1.weight": "blocks.20.ff_b.0.weight",
561
+ "model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.bias": "blocks.20.ff_b.2.bias",
562
+ "model.diffusion_model.joint_blocks.20.context_block.mlp.fc2.weight": "blocks.20.ff_b.2.weight",
563
+ "model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.bias": "blocks.20.norm1_a.linear.bias",
564
+ "model.diffusion_model.joint_blocks.20.x_block.adaLN_modulation.1.weight": "blocks.20.norm1_a.linear.weight",
565
+ "model.diffusion_model.joint_blocks.20.x_block.attn.proj.bias": "blocks.20.attn.a_to_out.bias",
566
+ "model.diffusion_model.joint_blocks.20.x_block.attn.proj.weight": "blocks.20.attn.a_to_out.weight",
567
+ "model.diffusion_model.joint_blocks.20.x_block.attn.qkv.bias": ['blocks.20.attn.a_to_q.bias', 'blocks.20.attn.a_to_k.bias', 'blocks.20.attn.a_to_v.bias'],
568
+ "model.diffusion_model.joint_blocks.20.x_block.attn.qkv.weight": ['blocks.20.attn.a_to_q.weight', 'blocks.20.attn.a_to_k.weight', 'blocks.20.attn.a_to_v.weight'],
569
+ "model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.bias": "blocks.20.ff_a.0.bias",
570
+ "model.diffusion_model.joint_blocks.20.x_block.mlp.fc1.weight": "blocks.20.ff_a.0.weight",
571
+ "model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.bias": "blocks.20.ff_a.2.bias",
572
+ "model.diffusion_model.joint_blocks.20.x_block.mlp.fc2.weight": "blocks.20.ff_a.2.weight",
573
+ "model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.bias": "blocks.21.norm1_b.linear.bias",
574
+ "model.diffusion_model.joint_blocks.21.context_block.adaLN_modulation.1.weight": "blocks.21.norm1_b.linear.weight",
575
+ "model.diffusion_model.joint_blocks.21.context_block.attn.proj.bias": "blocks.21.attn.b_to_out.bias",
576
+ "model.diffusion_model.joint_blocks.21.context_block.attn.proj.weight": "blocks.21.attn.b_to_out.weight",
577
+ "model.diffusion_model.joint_blocks.21.context_block.attn.qkv.bias": ['blocks.21.attn.b_to_q.bias', 'blocks.21.attn.b_to_k.bias', 'blocks.21.attn.b_to_v.bias'],
578
+ "model.diffusion_model.joint_blocks.21.context_block.attn.qkv.weight": ['blocks.21.attn.b_to_q.weight', 'blocks.21.attn.b_to_k.weight', 'blocks.21.attn.b_to_v.weight'],
579
+ "model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.bias": "blocks.21.ff_b.0.bias",
580
+ "model.diffusion_model.joint_blocks.21.context_block.mlp.fc1.weight": "blocks.21.ff_b.0.weight",
581
+ "model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.bias": "blocks.21.ff_b.2.bias",
582
+ "model.diffusion_model.joint_blocks.21.context_block.mlp.fc2.weight": "blocks.21.ff_b.2.weight",
583
+ "model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.bias": "blocks.21.norm1_a.linear.bias",
584
+ "model.diffusion_model.joint_blocks.21.x_block.adaLN_modulation.1.weight": "blocks.21.norm1_a.linear.weight",
585
+ "model.diffusion_model.joint_blocks.21.x_block.attn.proj.bias": "blocks.21.attn.a_to_out.bias",
586
+ "model.diffusion_model.joint_blocks.21.x_block.attn.proj.weight": "blocks.21.attn.a_to_out.weight",
587
+ "model.diffusion_model.joint_blocks.21.x_block.attn.qkv.bias": ['blocks.21.attn.a_to_q.bias', 'blocks.21.attn.a_to_k.bias', 'blocks.21.attn.a_to_v.bias'],
588
+ "model.diffusion_model.joint_blocks.21.x_block.attn.qkv.weight": ['blocks.21.attn.a_to_q.weight', 'blocks.21.attn.a_to_k.weight', 'blocks.21.attn.a_to_v.weight'],
589
+ "model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.bias": "blocks.21.ff_a.0.bias",
590
+ "model.diffusion_model.joint_blocks.21.x_block.mlp.fc1.weight": "blocks.21.ff_a.0.weight",
591
+ "model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.bias": "blocks.21.ff_a.2.bias",
592
+ "model.diffusion_model.joint_blocks.21.x_block.mlp.fc2.weight": "blocks.21.ff_a.2.weight",
593
+ "model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.bias": "blocks.22.norm1_b.linear.bias",
594
+ "model.diffusion_model.joint_blocks.22.context_block.adaLN_modulation.1.weight": "blocks.22.norm1_b.linear.weight",
595
+ "model.diffusion_model.joint_blocks.22.context_block.attn.proj.bias": "blocks.22.attn.b_to_out.bias",
596
+ "model.diffusion_model.joint_blocks.22.context_block.attn.proj.weight": "blocks.22.attn.b_to_out.weight",
597
+ "model.diffusion_model.joint_blocks.22.context_block.attn.qkv.bias": ['blocks.22.attn.b_to_q.bias', 'blocks.22.attn.b_to_k.bias', 'blocks.22.attn.b_to_v.bias'],
598
+ "model.diffusion_model.joint_blocks.22.context_block.attn.qkv.weight": ['blocks.22.attn.b_to_q.weight', 'blocks.22.attn.b_to_k.weight', 'blocks.22.attn.b_to_v.weight'],
599
+ "model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.bias": "blocks.22.ff_b.0.bias",
600
+ "model.diffusion_model.joint_blocks.22.context_block.mlp.fc1.weight": "blocks.22.ff_b.0.weight",
601
+ "model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.bias": "blocks.22.ff_b.2.bias",
602
+ "model.diffusion_model.joint_blocks.22.context_block.mlp.fc2.weight": "blocks.22.ff_b.2.weight",
603
+ "model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.bias": "blocks.22.norm1_a.linear.bias",
604
+ "model.diffusion_model.joint_blocks.22.x_block.adaLN_modulation.1.weight": "blocks.22.norm1_a.linear.weight",
605
+ "model.diffusion_model.joint_blocks.22.x_block.attn.proj.bias": "blocks.22.attn.a_to_out.bias",
606
+ "model.diffusion_model.joint_blocks.22.x_block.attn.proj.weight": "blocks.22.attn.a_to_out.weight",
607
+ "model.diffusion_model.joint_blocks.22.x_block.attn.qkv.bias": ['blocks.22.attn.a_to_q.bias', 'blocks.22.attn.a_to_k.bias', 'blocks.22.attn.a_to_v.bias'],
608
+ "model.diffusion_model.joint_blocks.22.x_block.attn.qkv.weight": ['blocks.22.attn.a_to_q.weight', 'blocks.22.attn.a_to_k.weight', 'blocks.22.attn.a_to_v.weight'],
609
+ "model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.bias": "blocks.22.ff_a.0.bias",
610
+ "model.diffusion_model.joint_blocks.22.x_block.mlp.fc1.weight": "blocks.22.ff_a.0.weight",
611
+ "model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.bias": "blocks.22.ff_a.2.bias",
612
+ "model.diffusion_model.joint_blocks.22.x_block.mlp.fc2.weight": "blocks.22.ff_a.2.weight",
613
+ "model.diffusion_model.joint_blocks.23.context_block.attn.qkv.bias": ['blocks.23.attn.b_to_q.bias', 'blocks.23.attn.b_to_k.bias', 'blocks.23.attn.b_to_v.bias'],
614
+ "model.diffusion_model.joint_blocks.23.context_block.attn.qkv.weight": ['blocks.23.attn.b_to_q.weight', 'blocks.23.attn.b_to_k.weight', 'blocks.23.attn.b_to_v.weight'],
615
+ "model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.bias": "blocks.23.norm1_a.linear.bias",
616
+ "model.diffusion_model.joint_blocks.23.x_block.adaLN_modulation.1.weight": "blocks.23.norm1_a.linear.weight",
617
+ "model.diffusion_model.joint_blocks.23.x_block.attn.proj.bias": "blocks.23.attn.a_to_out.bias",
618
+ "model.diffusion_model.joint_blocks.23.x_block.attn.proj.weight": "blocks.23.attn.a_to_out.weight",
619
+ "model.diffusion_model.joint_blocks.23.x_block.attn.qkv.bias": ['blocks.23.attn.a_to_q.bias', 'blocks.23.attn.a_to_k.bias', 'blocks.23.attn.a_to_v.bias'],
620
+ "model.diffusion_model.joint_blocks.23.x_block.attn.qkv.weight": ['blocks.23.attn.a_to_q.weight', 'blocks.23.attn.a_to_k.weight', 'blocks.23.attn.a_to_v.weight'],
621
+ "model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.bias": "blocks.23.ff_a.0.bias",
622
+ "model.diffusion_model.joint_blocks.23.x_block.mlp.fc1.weight": "blocks.23.ff_a.0.weight",
623
+ "model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.bias": "blocks.23.ff_a.2.bias",
624
+ "model.diffusion_model.joint_blocks.23.x_block.mlp.fc2.weight": "blocks.23.ff_a.2.weight",
625
+ "model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.bias": "blocks.3.norm1_b.linear.bias",
626
+ "model.diffusion_model.joint_blocks.3.context_block.adaLN_modulation.1.weight": "blocks.3.norm1_b.linear.weight",
627
+ "model.diffusion_model.joint_blocks.3.context_block.attn.proj.bias": "blocks.3.attn.b_to_out.bias",
628
+ "model.diffusion_model.joint_blocks.3.context_block.attn.proj.weight": "blocks.3.attn.b_to_out.weight",
629
+ "model.diffusion_model.joint_blocks.3.context_block.attn.qkv.bias": ['blocks.3.attn.b_to_q.bias', 'blocks.3.attn.b_to_k.bias', 'blocks.3.attn.b_to_v.bias'],
630
+ "model.diffusion_model.joint_blocks.3.context_block.attn.qkv.weight": ['blocks.3.attn.b_to_q.weight', 'blocks.3.attn.b_to_k.weight', 'blocks.3.attn.b_to_v.weight'],
631
+ "model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.bias": "blocks.3.ff_b.0.bias",
632
+ "model.diffusion_model.joint_blocks.3.context_block.mlp.fc1.weight": "blocks.3.ff_b.0.weight",
633
+ "model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.bias": "blocks.3.ff_b.2.bias",
634
+ "model.diffusion_model.joint_blocks.3.context_block.mlp.fc2.weight": "blocks.3.ff_b.2.weight",
635
+ "model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.bias": "blocks.3.norm1_a.linear.bias",
636
+ "model.diffusion_model.joint_blocks.3.x_block.adaLN_modulation.1.weight": "blocks.3.norm1_a.linear.weight",
637
+ "model.diffusion_model.joint_blocks.3.x_block.attn.proj.bias": "blocks.3.attn.a_to_out.bias",
638
+ "model.diffusion_model.joint_blocks.3.x_block.attn.proj.weight": "blocks.3.attn.a_to_out.weight",
639
+ "model.diffusion_model.joint_blocks.3.x_block.attn.qkv.bias": ['blocks.3.attn.a_to_q.bias', 'blocks.3.attn.a_to_k.bias', 'blocks.3.attn.a_to_v.bias'],
640
+ "model.diffusion_model.joint_blocks.3.x_block.attn.qkv.weight": ['blocks.3.attn.a_to_q.weight', 'blocks.3.attn.a_to_k.weight', 'blocks.3.attn.a_to_v.weight'],
641
+ "model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.bias": "blocks.3.ff_a.0.bias",
642
+ "model.diffusion_model.joint_blocks.3.x_block.mlp.fc1.weight": "blocks.3.ff_a.0.weight",
643
+ "model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.bias": "blocks.3.ff_a.2.bias",
644
+ "model.diffusion_model.joint_blocks.3.x_block.mlp.fc2.weight": "blocks.3.ff_a.2.weight",
645
+ "model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.bias": "blocks.4.norm1_b.linear.bias",
646
+ "model.diffusion_model.joint_blocks.4.context_block.adaLN_modulation.1.weight": "blocks.4.norm1_b.linear.weight",
647
+ "model.diffusion_model.joint_blocks.4.context_block.attn.proj.bias": "blocks.4.attn.b_to_out.bias",
648
+ "model.diffusion_model.joint_blocks.4.context_block.attn.proj.weight": "blocks.4.attn.b_to_out.weight",
649
+ "model.diffusion_model.joint_blocks.4.context_block.attn.qkv.bias": ['blocks.4.attn.b_to_q.bias', 'blocks.4.attn.b_to_k.bias', 'blocks.4.attn.b_to_v.bias'],
650
+ "model.diffusion_model.joint_blocks.4.context_block.attn.qkv.weight": ['blocks.4.attn.b_to_q.weight', 'blocks.4.attn.b_to_k.weight', 'blocks.4.attn.b_to_v.weight'],
651
+ "model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.bias": "blocks.4.ff_b.0.bias",
652
+ "model.diffusion_model.joint_blocks.4.context_block.mlp.fc1.weight": "blocks.4.ff_b.0.weight",
653
+ "model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.bias": "blocks.4.ff_b.2.bias",
654
+ "model.diffusion_model.joint_blocks.4.context_block.mlp.fc2.weight": "blocks.4.ff_b.2.weight",
655
+ "model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.bias": "blocks.4.norm1_a.linear.bias",
656
+ "model.diffusion_model.joint_blocks.4.x_block.adaLN_modulation.1.weight": "blocks.4.norm1_a.linear.weight",
657
+ "model.diffusion_model.joint_blocks.4.x_block.attn.proj.bias": "blocks.4.attn.a_to_out.bias",
658
+ "model.diffusion_model.joint_blocks.4.x_block.attn.proj.weight": "blocks.4.attn.a_to_out.weight",
659
+ "model.diffusion_model.joint_blocks.4.x_block.attn.qkv.bias": ['blocks.4.attn.a_to_q.bias', 'blocks.4.attn.a_to_k.bias', 'blocks.4.attn.a_to_v.bias'],
660
+ "model.diffusion_model.joint_blocks.4.x_block.attn.qkv.weight": ['blocks.4.attn.a_to_q.weight', 'blocks.4.attn.a_to_k.weight', 'blocks.4.attn.a_to_v.weight'],
661
+ "model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.bias": "blocks.4.ff_a.0.bias",
662
+ "model.diffusion_model.joint_blocks.4.x_block.mlp.fc1.weight": "blocks.4.ff_a.0.weight",
663
+ "model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.bias": "blocks.4.ff_a.2.bias",
664
+ "model.diffusion_model.joint_blocks.4.x_block.mlp.fc2.weight": "blocks.4.ff_a.2.weight",
665
+ "model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.bias": "blocks.5.norm1_b.linear.bias",
666
+ "model.diffusion_model.joint_blocks.5.context_block.adaLN_modulation.1.weight": "blocks.5.norm1_b.linear.weight",
667
+ "model.diffusion_model.joint_blocks.5.context_block.attn.proj.bias": "blocks.5.attn.b_to_out.bias",
668
+ "model.diffusion_model.joint_blocks.5.context_block.attn.proj.weight": "blocks.5.attn.b_to_out.weight",
669
+ "model.diffusion_model.joint_blocks.5.context_block.attn.qkv.bias": ['blocks.5.attn.b_to_q.bias', 'blocks.5.attn.b_to_k.bias', 'blocks.5.attn.b_to_v.bias'],
670
+ "model.diffusion_model.joint_blocks.5.context_block.attn.qkv.weight": ['blocks.5.attn.b_to_q.weight', 'blocks.5.attn.b_to_k.weight', 'blocks.5.attn.b_to_v.weight'],
671
+ "model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.bias": "blocks.5.ff_b.0.bias",
672
+ "model.diffusion_model.joint_blocks.5.context_block.mlp.fc1.weight": "blocks.5.ff_b.0.weight",
673
+ "model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.bias": "blocks.5.ff_b.2.bias",
674
+ "model.diffusion_model.joint_blocks.5.context_block.mlp.fc2.weight": "blocks.5.ff_b.2.weight",
675
+ "model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.bias": "blocks.5.norm1_a.linear.bias",
676
+ "model.diffusion_model.joint_blocks.5.x_block.adaLN_modulation.1.weight": "blocks.5.norm1_a.linear.weight",
677
+ "model.diffusion_model.joint_blocks.5.x_block.attn.proj.bias": "blocks.5.attn.a_to_out.bias",
678
+ "model.diffusion_model.joint_blocks.5.x_block.attn.proj.weight": "blocks.5.attn.a_to_out.weight",
679
+ "model.diffusion_model.joint_blocks.5.x_block.attn.qkv.bias": ['blocks.5.attn.a_to_q.bias', 'blocks.5.attn.a_to_k.bias', 'blocks.5.attn.a_to_v.bias'],
680
+ "model.diffusion_model.joint_blocks.5.x_block.attn.qkv.weight": ['blocks.5.attn.a_to_q.weight', 'blocks.5.attn.a_to_k.weight', 'blocks.5.attn.a_to_v.weight'],
681
+ "model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.bias": "blocks.5.ff_a.0.bias",
682
+ "model.diffusion_model.joint_blocks.5.x_block.mlp.fc1.weight": "blocks.5.ff_a.0.weight",
683
+ "model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.bias": "blocks.5.ff_a.2.bias",
684
+ "model.diffusion_model.joint_blocks.5.x_block.mlp.fc2.weight": "blocks.5.ff_a.2.weight",
685
+ "model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.bias": "blocks.6.norm1_b.linear.bias",
686
+ "model.diffusion_model.joint_blocks.6.context_block.adaLN_modulation.1.weight": "blocks.6.norm1_b.linear.weight",
687
+ "model.diffusion_model.joint_blocks.6.context_block.attn.proj.bias": "blocks.6.attn.b_to_out.bias",
688
+ "model.diffusion_model.joint_blocks.6.context_block.attn.proj.weight": "blocks.6.attn.b_to_out.weight",
689
+ "model.diffusion_model.joint_blocks.6.context_block.attn.qkv.bias": ['blocks.6.attn.b_to_q.bias', 'blocks.6.attn.b_to_k.bias', 'blocks.6.attn.b_to_v.bias'],
690
+ "model.diffusion_model.joint_blocks.6.context_block.attn.qkv.weight": ['blocks.6.attn.b_to_q.weight', 'blocks.6.attn.b_to_k.weight', 'blocks.6.attn.b_to_v.weight'],
691
+ "model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.bias": "blocks.6.ff_b.0.bias",
692
+ "model.diffusion_model.joint_blocks.6.context_block.mlp.fc1.weight": "blocks.6.ff_b.0.weight",
693
+ "model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.bias": "blocks.6.ff_b.2.bias",
694
+ "model.diffusion_model.joint_blocks.6.context_block.mlp.fc2.weight": "blocks.6.ff_b.2.weight",
695
+ "model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.bias": "blocks.6.norm1_a.linear.bias",
696
+ "model.diffusion_model.joint_blocks.6.x_block.adaLN_modulation.1.weight": "blocks.6.norm1_a.linear.weight",
697
+ "model.diffusion_model.joint_blocks.6.x_block.attn.proj.bias": "blocks.6.attn.a_to_out.bias",
698
+ "model.diffusion_model.joint_blocks.6.x_block.attn.proj.weight": "blocks.6.attn.a_to_out.weight",
699
+ "model.diffusion_model.joint_blocks.6.x_block.attn.qkv.bias": ['blocks.6.attn.a_to_q.bias', 'blocks.6.attn.a_to_k.bias', 'blocks.6.attn.a_to_v.bias'],
700
+ "model.diffusion_model.joint_blocks.6.x_block.attn.qkv.weight": ['blocks.6.attn.a_to_q.weight', 'blocks.6.attn.a_to_k.weight', 'blocks.6.attn.a_to_v.weight'],
701
+ "model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.bias": "blocks.6.ff_a.0.bias",
702
+ "model.diffusion_model.joint_blocks.6.x_block.mlp.fc1.weight": "blocks.6.ff_a.0.weight",
703
+ "model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.bias": "blocks.6.ff_a.2.bias",
704
+ "model.diffusion_model.joint_blocks.6.x_block.mlp.fc2.weight": "blocks.6.ff_a.2.weight",
705
+ "model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.bias": "blocks.7.norm1_b.linear.bias",
706
+ "model.diffusion_model.joint_blocks.7.context_block.adaLN_modulation.1.weight": "blocks.7.norm1_b.linear.weight",
707
+ "model.diffusion_model.joint_blocks.7.context_block.attn.proj.bias": "blocks.7.attn.b_to_out.bias",
708
+ "model.diffusion_model.joint_blocks.7.context_block.attn.proj.weight": "blocks.7.attn.b_to_out.weight",
709
+ "model.diffusion_model.joint_blocks.7.context_block.attn.qkv.bias": ['blocks.7.attn.b_to_q.bias', 'blocks.7.attn.b_to_k.bias', 'blocks.7.attn.b_to_v.bias'],
710
+ "model.diffusion_model.joint_blocks.7.context_block.attn.qkv.weight": ['blocks.7.attn.b_to_q.weight', 'blocks.7.attn.b_to_k.weight', 'blocks.7.attn.b_to_v.weight'],
711
+ "model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.bias": "blocks.7.ff_b.0.bias",
712
+ "model.diffusion_model.joint_blocks.7.context_block.mlp.fc1.weight": "blocks.7.ff_b.0.weight",
713
+ "model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.bias": "blocks.7.ff_b.2.bias",
714
+ "model.diffusion_model.joint_blocks.7.context_block.mlp.fc2.weight": "blocks.7.ff_b.2.weight",
715
+ "model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.bias": "blocks.7.norm1_a.linear.bias",
716
+ "model.diffusion_model.joint_blocks.7.x_block.adaLN_modulation.1.weight": "blocks.7.norm1_a.linear.weight",
717
+ "model.diffusion_model.joint_blocks.7.x_block.attn.proj.bias": "blocks.7.attn.a_to_out.bias",
718
+ "model.diffusion_model.joint_blocks.7.x_block.attn.proj.weight": "blocks.7.attn.a_to_out.weight",
719
+ "model.diffusion_model.joint_blocks.7.x_block.attn.qkv.bias": ['blocks.7.attn.a_to_q.bias', 'blocks.7.attn.a_to_k.bias', 'blocks.7.attn.a_to_v.bias'],
720
+ "model.diffusion_model.joint_blocks.7.x_block.attn.qkv.weight": ['blocks.7.attn.a_to_q.weight', 'blocks.7.attn.a_to_k.weight', 'blocks.7.attn.a_to_v.weight'],
721
+ "model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.bias": "blocks.7.ff_a.0.bias",
722
+ "model.diffusion_model.joint_blocks.7.x_block.mlp.fc1.weight": "blocks.7.ff_a.0.weight",
723
+ "model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.bias": "blocks.7.ff_a.2.bias",
724
+ "model.diffusion_model.joint_blocks.7.x_block.mlp.fc2.weight": "blocks.7.ff_a.2.weight",
725
+ "model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.bias": "blocks.8.norm1_b.linear.bias",
726
+ "model.diffusion_model.joint_blocks.8.context_block.adaLN_modulation.1.weight": "blocks.8.norm1_b.linear.weight",
727
+ "model.diffusion_model.joint_blocks.8.context_block.attn.proj.bias": "blocks.8.attn.b_to_out.bias",
728
+ "model.diffusion_model.joint_blocks.8.context_block.attn.proj.weight": "blocks.8.attn.b_to_out.weight",
729
+ "model.diffusion_model.joint_blocks.8.context_block.attn.qkv.bias": ['blocks.8.attn.b_to_q.bias', 'blocks.8.attn.b_to_k.bias', 'blocks.8.attn.b_to_v.bias'],
730
+ "model.diffusion_model.joint_blocks.8.context_block.attn.qkv.weight": ['blocks.8.attn.b_to_q.weight', 'blocks.8.attn.b_to_k.weight', 'blocks.8.attn.b_to_v.weight'],
731
+ "model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.bias": "blocks.8.ff_b.0.bias",
732
+ "model.diffusion_model.joint_blocks.8.context_block.mlp.fc1.weight": "blocks.8.ff_b.0.weight",
733
+ "model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.bias": "blocks.8.ff_b.2.bias",
734
+ "model.diffusion_model.joint_blocks.8.context_block.mlp.fc2.weight": "blocks.8.ff_b.2.weight",
735
+ "model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.bias": "blocks.8.norm1_a.linear.bias",
736
+ "model.diffusion_model.joint_blocks.8.x_block.adaLN_modulation.1.weight": "blocks.8.norm1_a.linear.weight",
737
+ "model.diffusion_model.joint_blocks.8.x_block.attn.proj.bias": "blocks.8.attn.a_to_out.bias",
738
+ "model.diffusion_model.joint_blocks.8.x_block.attn.proj.weight": "blocks.8.attn.a_to_out.weight",
739
+ "model.diffusion_model.joint_blocks.8.x_block.attn.qkv.bias": ['blocks.8.attn.a_to_q.bias', 'blocks.8.attn.a_to_k.bias', 'blocks.8.attn.a_to_v.bias'],
740
+ "model.diffusion_model.joint_blocks.8.x_block.attn.qkv.weight": ['blocks.8.attn.a_to_q.weight', 'blocks.8.attn.a_to_k.weight', 'blocks.8.attn.a_to_v.weight'],
741
+ "model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.bias": "blocks.8.ff_a.0.bias",
742
+ "model.diffusion_model.joint_blocks.8.x_block.mlp.fc1.weight": "blocks.8.ff_a.0.weight",
743
+ "model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.bias": "blocks.8.ff_a.2.bias",
744
+ "model.diffusion_model.joint_blocks.8.x_block.mlp.fc2.weight": "blocks.8.ff_a.2.weight",
745
+ "model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.bias": "blocks.9.norm1_b.linear.bias",
746
+ "model.diffusion_model.joint_blocks.9.context_block.adaLN_modulation.1.weight": "blocks.9.norm1_b.linear.weight",
747
+ "model.diffusion_model.joint_blocks.9.context_block.attn.proj.bias": "blocks.9.attn.b_to_out.bias",
748
+ "model.diffusion_model.joint_blocks.9.context_block.attn.proj.weight": "blocks.9.attn.b_to_out.weight",
749
+ "model.diffusion_model.joint_blocks.9.context_block.attn.qkv.bias": ['blocks.9.attn.b_to_q.bias', 'blocks.9.attn.b_to_k.bias', 'blocks.9.attn.b_to_v.bias'],
750
+ "model.diffusion_model.joint_blocks.9.context_block.attn.qkv.weight": ['blocks.9.attn.b_to_q.weight', 'blocks.9.attn.b_to_k.weight', 'blocks.9.attn.b_to_v.weight'],
751
+ "model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.bias": "blocks.9.ff_b.0.bias",
752
+ "model.diffusion_model.joint_blocks.9.context_block.mlp.fc1.weight": "blocks.9.ff_b.0.weight",
753
+ "model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.bias": "blocks.9.ff_b.2.bias",
754
+ "model.diffusion_model.joint_blocks.9.context_block.mlp.fc2.weight": "blocks.9.ff_b.2.weight",
755
+ "model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.bias": "blocks.9.norm1_a.linear.bias",
756
+ "model.diffusion_model.joint_blocks.9.x_block.adaLN_modulation.1.weight": "blocks.9.norm1_a.linear.weight",
757
+ "model.diffusion_model.joint_blocks.9.x_block.attn.proj.bias": "blocks.9.attn.a_to_out.bias",
758
+ "model.diffusion_model.joint_blocks.9.x_block.attn.proj.weight": "blocks.9.attn.a_to_out.weight",
759
+ "model.diffusion_model.joint_blocks.9.x_block.attn.qkv.bias": ['blocks.9.attn.a_to_q.bias', 'blocks.9.attn.a_to_k.bias', 'blocks.9.attn.a_to_v.bias'],
760
+ "model.diffusion_model.joint_blocks.9.x_block.attn.qkv.weight": ['blocks.9.attn.a_to_q.weight', 'blocks.9.attn.a_to_k.weight', 'blocks.9.attn.a_to_v.weight'],
761
+ "model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.bias": "blocks.9.ff_a.0.bias",
762
+ "model.diffusion_model.joint_blocks.9.x_block.mlp.fc1.weight": "blocks.9.ff_a.0.weight",
763
+ "model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.bias": "blocks.9.ff_a.2.bias",
764
+ "model.diffusion_model.joint_blocks.9.x_block.mlp.fc2.weight": "blocks.9.ff_a.2.weight",
765
+ "model.diffusion_model.pos_embed": "pos_embedder.pos_embed",
766
+ "model.diffusion_model.t_embedder.mlp.0.bias": "time_embedder.timestep_embedder.0.bias",
767
+ "model.diffusion_model.t_embedder.mlp.0.weight": "time_embedder.timestep_embedder.0.weight",
768
+ "model.diffusion_model.t_embedder.mlp.2.bias": "time_embedder.timestep_embedder.2.bias",
769
+ "model.diffusion_model.t_embedder.mlp.2.weight": "time_embedder.timestep_embedder.2.weight",
770
+ "model.diffusion_model.x_embedder.proj.bias": "pos_embedder.proj.bias",
771
+ "model.diffusion_model.x_embedder.proj.weight": "pos_embedder.proj.weight",
772
+ "model.diffusion_model.y_embedder.mlp.0.bias": "pooled_text_embedder.0.bias",
773
+ "model.diffusion_model.y_embedder.mlp.0.weight": "pooled_text_embedder.0.weight",
774
+ "model.diffusion_model.y_embedder.mlp.2.bias": "pooled_text_embedder.2.bias",
775
+ "model.diffusion_model.y_embedder.mlp.2.weight": "pooled_text_embedder.2.weight",
776
+
777
+ "model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.weight": "blocks.23.norm1_b.linear.weight",
778
+ "model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.bias": "blocks.23.norm1_b.linear.bias",
779
+ "model.diffusion_model.final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
780
+ "model.diffusion_model.final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
781
+ }
782
+ state_dict_ = {}
783
+ for name in state_dict:
784
+ if name in rename_dict:
785
+ param = state_dict[name]
786
+ if name.startswith("model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1."):
787
+ param = torch.concat([param[1536:], param[:1536]], axis=0)
788
+ elif name.startswith("model.diffusion_model.final_layer.adaLN_modulation.1."):
789
+ param = torch.concat([param[1536:], param[:1536]], axis=0)
790
+ elif name == "model.diffusion_model.pos_embed":
791
+ param = param.reshape((1, 192, 192, 1536))
792
+ if isinstance(rename_dict[name], str):
793
+ state_dict_[rename_dict[name]] = param
794
+ else:
795
+ name_ = rename_dict[name][0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
796
+ state_dict_[name_] = param
797
+ return state_dict_
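The tail of this rename table closes out the SD3 DiT state-dict converter: whenever a source key maps to a list of q/k/v names, the converter keeps the tensor fused and only rewrites the first name to a_to_qkv / b_to_qkv (see the replace() call just above). Below is a standalone sketch of that branch on one toy entry, not the library API; the 1536 width is taken from the adaLN slicing above.

```python
import torch

# One fused qkv weight from the source checkpoint (toy data, real shape pattern).
state_dict = {
    "model.diffusion_model.joint_blocks.3.x_block.attn.qkv.weight": torch.randn(3 * 1536, 1536),
}
# The rename table maps it to a list of per-projection names.
rename_dict = {
    "model.diffusion_model.joint_blocks.3.x_block.attn.qkv.weight": [
        "blocks.3.attn.a_to_q.weight",
        "blocks.3.attn.a_to_k.weight",
        "blocks.3.attn.a_to_v.weight",
    ],
}

state_dict_ = {}
for name, param in state_dict.items():
    target = rename_dict[name]
    if isinstance(target, str):
        state_dict_[target] = param
    else:
        # Keep q/k/v fused: only the first target name is used, rewritten to *_to_qkv.
        fused = target[0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
        state_dict_[fused] = param

print(list(state_dict_))  # ['blocks.3.attn.a_to_qkv.weight'], tensor of shape (4608, 1536)
```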
diffsynth/models/sd3_text_encoder.py ADDED
The diff for this file is too large to render. See raw diff
 
diffsynth/models/sd3_vae_decoder.py ADDED
@@ -0,0 +1,80 @@
1
+ import torch
2
+ from .sd_vae_decoder import VAEAttentionBlock, SDVAEDecoderStateDictConverter
3
+ from .sd_unet import ResnetBlock, UpSampler
4
+ from .tiler import TileWorker
5
+
6
+
7
+
8
+ class SD3VAEDecoder(torch.nn.Module):
9
+ def __init__(self):
10
+ super().__init__()
11
+ self.scaling_factor = 1.5305 # Different from SD 1.x
12
+ self.shift_factor = 0.0609 # Different from SD 1.x
13
+ self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x
14
+
15
+ self.blocks = torch.nn.ModuleList([
16
+ # UNetMidBlock2D
17
+ ResnetBlock(512, 512, eps=1e-6),
18
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
19
+ ResnetBlock(512, 512, eps=1e-6),
20
+ # UpDecoderBlock2D
21
+ ResnetBlock(512, 512, eps=1e-6),
22
+ ResnetBlock(512, 512, eps=1e-6),
23
+ ResnetBlock(512, 512, eps=1e-6),
24
+ UpSampler(512),
25
+ # UpDecoderBlock2D
26
+ ResnetBlock(512, 512, eps=1e-6),
27
+ ResnetBlock(512, 512, eps=1e-6),
28
+ ResnetBlock(512, 512, eps=1e-6),
29
+ UpSampler(512),
30
+ # UpDecoderBlock2D
31
+ ResnetBlock(512, 256, eps=1e-6),
32
+ ResnetBlock(256, 256, eps=1e-6),
33
+ ResnetBlock(256, 256, eps=1e-6),
34
+ UpSampler(256),
35
+ # UpDecoderBlock2D
36
+ ResnetBlock(256, 128, eps=1e-6),
37
+ ResnetBlock(128, 128, eps=1e-6),
38
+ ResnetBlock(128, 128, eps=1e-6),
39
+ ])
40
+
41
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
42
+ self.conv_act = torch.nn.SiLU()
43
+ self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
44
+
45
+ def tiled_forward(self, sample, tile_size=64, tile_stride=32):
46
+ hidden_states = TileWorker().tiled_forward(
47
+ lambda x: self.forward(x),
48
+ sample,
49
+ tile_size,
50
+ tile_stride,
51
+ tile_device=sample.device,
52
+ tile_dtype=sample.dtype
53
+ )
54
+ return hidden_states
55
+
56
+ def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
57
+ # For VAE Decoder, we do not need to apply the tiler on each layer.
58
+ if tiled:
59
+ return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
60
+
61
+ # 1. pre-process
62
+ hidden_states = sample / self.scaling_factor + self.shift_factor
63
+ hidden_states = self.conv_in(hidden_states)
64
+ time_emb = None
65
+ text_emb = None
66
+ res_stack = None
67
+
68
+ # 2. blocks
69
+ for i, block in enumerate(self.blocks):
70
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
71
+
72
+ # 3. output
73
+ hidden_states = self.conv_norm_out(hidden_states)
74
+ hidden_states = self.conv_act(hidden_states)
75
+ hidden_states = self.conv_out(hidden_states)
76
+
77
+ return hidden_states
78
+
79
+ def state_dict_converter(self):
80
+ return SDVAEDecoderStateDictConverter()
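SD3VAEDecoder above mirrors the SD 1.x decoder but takes 16 latent channels and applies the SD3-specific scaling_factor/shift_factor; its three UpSampler stages give 8x spatial upsampling, and tiled=True routes the whole forward through TileWorker for lower peak memory. A minimal shape sketch follows, assuming the package imports as laid out in this repository and using randomly initialized weights purely to illustrate tensor shapes.

```python
import torch
from diffsynth.models.sd3_vae_decoder import SD3VAEDecoder  # import path assumed from this repo layout

decoder = SD3VAEDecoder().eval()       # randomly initialized here; only shapes are illustrated
latents = torch.randn(1, 16, 64, 64)   # 16-channel latents for a 512x512 image (1/8 resolution)
with torch.no_grad():
    image = decoder(latents)           # -> (1, 3, 512, 512) after three 2x upsampling stages
    # For large resolutions, decoder(latents, tiled=True, tile_size=..., tile_stride=...)
    # trades speed for lower peak VRAM by decoding overlapping tiles.
```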
diffsynth/models/sd3_vae_encoder.py ADDED
@@ -0,0 +1,94 @@
1
+ import torch
2
+ from .sd_unet import ResnetBlock, DownSampler
3
+ from .sd_vae_encoder import VAEAttentionBlock, SDVAEEncoderStateDictConverter
4
+ from .tiler import TileWorker
5
+ from einops import rearrange
6
+
7
+
8
+ class SD3VAEEncoder(torch.nn.Module):
9
+ def __init__(self):
10
+ super().__init__()
11
+ self.scaling_factor = 1.5305 # Different from SD 1.x
12
+ self.shift_factor = 0.0609 # Different from SD 1.x
13
+ self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
14
+
15
+ self.blocks = torch.nn.ModuleList([
16
+ # DownEncoderBlock2D
17
+ ResnetBlock(128, 128, eps=1e-6),
18
+ ResnetBlock(128, 128, eps=1e-6),
19
+ DownSampler(128, padding=0, extra_padding=True),
20
+ # DownEncoderBlock2D
21
+ ResnetBlock(128, 256, eps=1e-6),
22
+ ResnetBlock(256, 256, eps=1e-6),
23
+ DownSampler(256, padding=0, extra_padding=True),
24
+ # DownEncoderBlock2D
25
+ ResnetBlock(256, 512, eps=1e-6),
26
+ ResnetBlock(512, 512, eps=1e-6),
27
+ DownSampler(512, padding=0, extra_padding=True),
28
+ # DownEncoderBlock2D
29
+ ResnetBlock(512, 512, eps=1e-6),
30
+ ResnetBlock(512, 512, eps=1e-6),
31
+ # UNetMidBlock2D
32
+ ResnetBlock(512, 512, eps=1e-6),
33
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
34
+ ResnetBlock(512, 512, eps=1e-6),
35
+ ])
36
+
37
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
38
+ self.conv_act = torch.nn.SiLU()
39
+ self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)
40
+
41
+ def tiled_forward(self, sample, tile_size=64, tile_stride=32):
42
+ hidden_states = TileWorker().tiled_forward(
43
+ lambda x: self.forward(x),
44
+ sample,
45
+ tile_size,
46
+ tile_stride,
47
+ tile_device=sample.device,
48
+ tile_dtype=sample.dtype
49
+ )
50
+ return hidden_states
51
+
52
+ def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
53
+ # For the VAE encoder, we do not need to apply the tiler on each layer.
54
+ if tiled:
55
+ return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
56
+
57
+ # 1. pre-process
58
+ hidden_states = self.conv_in(sample)
59
+ time_emb = None
60
+ text_emb = None
61
+ res_stack = None
62
+
63
+ # 2. blocks
64
+ for i, block in enumerate(self.blocks):
65
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
66
+
67
+ # 3. output
68
+ hidden_states = self.conv_norm_out(hidden_states)
69
+ hidden_states = self.conv_act(hidden_states)
70
+ hidden_states = self.conv_out(hidden_states)
71
+ hidden_states = hidden_states[:, :16]
72
+ hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor
73
+
74
+ return hidden_states
75
+
76
+ def encode_video(self, sample, batch_size=8):
77
+ B = sample.shape[0]
78
+ hidden_states = []
79
+
80
+ for i in range(0, sample.shape[2], batch_size):
81
+
82
+ j = min(i + batch_size, sample.shape[2])
83
+ sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
84
+
85
+ hidden_states_batch = self(sample_batch)
86
+ hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
87
+
88
+ hidden_states.append(hidden_states_batch)
89
+
90
+ hidden_states = torch.concat(hidden_states, dim=2)
91
+ return hidden_states
92
+
93
+ def state_dict_converter(self):
94
+ return SDVAEEncoderStateDictConverter()
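SD3VAEEncoder is the mirror image: frames go in at full resolution and come out as 16-channel latents at 1/8 scale (the 32-channel conv_out is sliced to its first 16 channels before shift/scale), and encode_video() chunks the time axis so long clips are not encoded in one pass. A minimal shape sketch under the same assumptions (import path inferred from this file layout, random weights):

```python
import torch
from diffsynth.models.sd3_vae_encoder import SD3VAEEncoder  # import path assumed from this repo layout

encoder = SD3VAEEncoder().eval()          # randomly initialized; only shapes are illustrated
video = torch.randn(1, 3, 17, 128, 128)   # (B, C, T, H, W)
with torch.no_grad():
    # encode_video() processes the T axis in chunks of batch_size frames,
    # then concatenates the per-chunk latents along dim=2.
    latents = encoder.encode_video(video, batch_size=8)  # -> (1, 16, 17, 16, 16)
```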
diffsynth/models/sd_controlnet.py ADDED
@@ -0,0 +1,587 @@
1
+ import torch
2
+ from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
3
+ from .tiler import TileWorker
4
+
5
+
6
+ class ControlNetConditioningLayer(torch.nn.Module):
7
+ def __init__(self, channels = (3, 16, 32, 96, 256, 320)):
8
+ super().__init__()
9
+ self.blocks = torch.nn.ModuleList([])
10
+ self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1))
11
+ self.blocks.append(torch.nn.SiLU())
12
+ for i in range(1, len(channels) - 2):
13
+ self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1))
14
+ self.blocks.append(torch.nn.SiLU())
15
+ self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2))
16
+ self.blocks.append(torch.nn.SiLU())
17
+ self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1))
18
+
19
+ def forward(self, conditioning):
20
+ for block in self.blocks:
21
+ conditioning = block(conditioning)
22
+ return conditioning
23
+
24
+
25
+ class SDControlNet(torch.nn.Module):
26
+ def __init__(self, global_pool=False):
27
+ super().__init__()
28
+ self.time_proj = Timesteps(320)
29
+ self.time_embedding = torch.nn.Sequential(
30
+ torch.nn.Linear(320, 1280),
31
+ torch.nn.SiLU(),
32
+ torch.nn.Linear(1280, 1280)
33
+ )
34
+ self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
35
+
36
+ self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
37
+
38
+ self.blocks = torch.nn.ModuleList([
39
+ # CrossAttnDownBlock2D
40
+ ResnetBlock(320, 320, 1280),
41
+ AttentionBlock(8, 40, 320, 1, 768),
42
+ PushBlock(),
43
+ ResnetBlock(320, 320, 1280),
44
+ AttentionBlock(8, 40, 320, 1, 768),
45
+ PushBlock(),
46
+ DownSampler(320),
47
+ PushBlock(),
48
+ # CrossAttnDownBlock2D
49
+ ResnetBlock(320, 640, 1280),
50
+ AttentionBlock(8, 80, 640, 1, 768),
51
+ PushBlock(),
52
+ ResnetBlock(640, 640, 1280),
53
+ AttentionBlock(8, 80, 640, 1, 768),
54
+ PushBlock(),
55
+ DownSampler(640),
56
+ PushBlock(),
57
+ # CrossAttnDownBlock2D
58
+ ResnetBlock(640, 1280, 1280),
59
+ AttentionBlock(8, 160, 1280, 1, 768),
60
+ PushBlock(),
61
+ ResnetBlock(1280, 1280, 1280),
62
+ AttentionBlock(8, 160, 1280, 1, 768),
63
+ PushBlock(),
64
+ DownSampler(1280),
65
+ PushBlock(),
66
+ # DownBlock2D
67
+ ResnetBlock(1280, 1280, 1280),
68
+ PushBlock(),
69
+ ResnetBlock(1280, 1280, 1280),
70
+ PushBlock(),
71
+ # UNetMidBlock2DCrossAttn
72
+ ResnetBlock(1280, 1280, 1280),
73
+ AttentionBlock(8, 160, 1280, 1, 768),
74
+ ResnetBlock(1280, 1280, 1280),
75
+ PushBlock()
76
+ ])
77
+
78
+ self.controlnet_blocks = torch.nn.ModuleList([
79
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
80
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
81
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
82
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
83
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
84
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
85
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
86
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
87
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
88
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
89
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
90
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
91
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
92
+ ])
93
+
94
+ self.global_pool = global_pool
95
+
96
+ def forward(
97
+ self,
98
+ sample, timestep, encoder_hidden_states, conditioning,
99
+ tiled=False, tile_size=64, tile_stride=32,
100
+ ):
101
+ # 1. time
102
+ time_emb = self.time_proj(timestep[None]).to(sample.dtype)
103
+ time_emb = self.time_embedding(time_emb)
104
+ time_emb = time_emb.repeat(sample.shape[0], 1)
105
+
106
+ # 2. pre-process
107
+ height, width = sample.shape[2], sample.shape[3]
108
+ hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning)
109
+ text_emb = encoder_hidden_states
110
+ res_stack = [hidden_states]
111
+
112
+ # 3. blocks
113
+ for i, block in enumerate(self.blocks):
114
+ if tiled and not isinstance(block, PushBlock):
115
+ _, _, inter_height, _ = hidden_states.shape
116
+ resize_scale = inter_height / height
117
+ hidden_states = TileWorker().tiled_forward(
118
+ lambda x: block(x, time_emb, text_emb, res_stack)[0],
119
+ hidden_states,
120
+ int(tile_size * resize_scale),
121
+ int(tile_stride * resize_scale),
122
+ tile_device=hidden_states.device,
123
+ tile_dtype=hidden_states.dtype
124
+ )
125
+ else:
126
+ hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
127
+
128
+ # 4. ControlNet blocks
129
+ controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
130
+
131
+ # pool
132
+ if self.global_pool:
133
+ controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]
134
+
135
+ return controlnet_res_stack
136
+
137
+ def state_dict_converter(self):
138
+ return SDControlNetStateDictConverter()
139
+
140
+
141
+ class SDControlNetStateDictConverter:
142
+ def __init__(self):
143
+ pass
144
+
145
+ def from_diffusers(self, state_dict):
146
+ # architecture
147
+ block_types = [
148
+ 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
149
+ 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
150
+ 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
151
+ 'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock',
152
+ 'ResnetBlock', 'AttentionBlock', 'ResnetBlock',
153
+ 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler',
154
+ 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
155
+ 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
156
+ 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock'
157
+ ]
158
+
159
+ # controlnet_rename_dict
160
+ controlnet_rename_dict = {
161
+ "controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
162
+ "controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
163
+ "controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
164
+ "controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
165
+ "controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
166
+ "controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
167
+ "controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
168
+ "controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
169
+ "controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
170
+ "controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
171
+ "controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
172
+ "controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
173
+ "controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
174
+ "controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
175
+ "controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
176
+ "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
177
+ }
178
+
179
+ # Rename each parameter
180
+ name_list = sorted([name for name in state_dict])
181
+ rename_dict = {}
182
+ block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
183
+ last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
184
+ for name in name_list:
185
+ names = name.split(".")
186
+ if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
187
+ pass
188
+ elif name in controlnet_rename_dict:
189
+ names = controlnet_rename_dict[name].split(".")
190
+ elif names[0] == "controlnet_down_blocks":
191
+ names[0] = "controlnet_blocks"
192
+ elif names[0] == "controlnet_mid_block":
193
+ names = ["controlnet_blocks", "12", names[-1]]
194
+ elif names[0] in ["time_embedding", "add_embedding"]:
195
+ if names[0] == "add_embedding":
196
+ names[0] = "add_time_embedding"
197
+ names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
198
+ elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
199
+ if names[0] == "mid_block":
200
+ names.insert(1, "0")
201
+ block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
202
+ block_type_with_id = ".".join(names[:4])
203
+ if block_type_with_id != last_block_type_with_id[block_type]:
204
+ block_id[block_type] += 1
205
+ last_block_type_with_id[block_type] = block_type_with_id
206
+ while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
207
+ block_id[block_type] += 1
208
+ block_type_with_id = ".".join(names[:4])
209
+ names = ["blocks", str(block_id[block_type])] + names[4:]
210
+ if "ff" in names:
211
+ ff_index = names.index("ff")
212
+ component = ".".join(names[ff_index:ff_index+3])
213
+ component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
214
+ names = names[:ff_index] + [component] + names[ff_index+3:]
215
+ if "to_out" in names:
216
+ names.pop(names.index("to_out") + 1)
217
+ else:
218
+ raise ValueError(f"Unknown parameters: {name}")
219
+ rename_dict[name] = ".".join(names)
220
+
221
+ # Convert state_dict
222
+ state_dict_ = {}
223
+ for name, param in state_dict.items():
224
+ if ".proj_in." in name or ".proj_out." in name:
225
+ param = param.squeeze()
226
+ if rename_dict[name] in [
227
+ "controlnet_blocks.1.bias", "controlnet_blocks.2.bias", "controlnet_blocks.3.bias", "controlnet_blocks.5.bias", "controlnet_blocks.6.bias",
228
+ "controlnet_blocks.8.bias", "controlnet_blocks.9.bias", "controlnet_blocks.10.bias", "controlnet_blocks.11.bias", "controlnet_blocks.12.bias"
229
+ ]:
230
+ continue
231
+ state_dict_[rename_dict[name]] = param
232
+ return state_dict_
233
+
234
+ def from_civitai(self, state_dict):
235
+ if "mid_block.resnets.1.time_emb_proj.weight" in state_dict:
236
+ # For controlnets in diffusers format
237
+ return self.from_diffusers(state_dict)
238
+ rename_dict = {
239
+ "control_model.time_embed.0.weight": "time_embedding.0.weight",
240
+ "control_model.time_embed.0.bias": "time_embedding.0.bias",
241
+ "control_model.time_embed.2.weight": "time_embedding.2.weight",
242
+ "control_model.time_embed.2.bias": "time_embedding.2.bias",
243
+ "control_model.input_blocks.0.0.weight": "conv_in.weight",
244
+ "control_model.input_blocks.0.0.bias": "conv_in.bias",
245
+ "control_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight",
246
+ "control_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias",
247
+ "control_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight",
248
+ "control_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias",
249
+ "control_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight",
250
+ "control_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias",
251
+ "control_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight",
252
+ "control_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias",
253
+ "control_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight",
254
+ "control_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias",
255
+ "control_model.input_blocks.1.1.norm.weight": "blocks.1.norm.weight",
256
+ "control_model.input_blocks.1.1.norm.bias": "blocks.1.norm.bias",
257
+ "control_model.input_blocks.1.1.proj_in.weight": "blocks.1.proj_in.weight",
258
+ "control_model.input_blocks.1.1.proj_in.bias": "blocks.1.proj_in.bias",
259
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.1.transformer_blocks.0.attn1.to_q.weight",
260
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.1.transformer_blocks.0.attn1.to_k.weight",
261
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.1.transformer_blocks.0.attn1.to_v.weight",
262
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.1.transformer_blocks.0.attn1.to_out.weight",
263
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.1.transformer_blocks.0.attn1.to_out.bias",
264
+ "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.1.transformer_blocks.0.act_fn.proj.weight",
265
+ "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.1.transformer_blocks.0.act_fn.proj.bias",
266
+ "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.1.transformer_blocks.0.ff.weight",
267
+ "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.1.transformer_blocks.0.ff.bias",
268
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.1.transformer_blocks.0.attn2.to_q.weight",
269
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.1.transformer_blocks.0.attn2.to_k.weight",
270
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.1.transformer_blocks.0.attn2.to_v.weight",
271
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.1.transformer_blocks.0.attn2.to_out.weight",
272
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.1.transformer_blocks.0.attn2.to_out.bias",
273
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.1.transformer_blocks.0.norm1.weight",
274
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.1.transformer_blocks.0.norm1.bias",
275
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.1.transformer_blocks.0.norm2.weight",
276
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.1.transformer_blocks.0.norm2.bias",
277
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.1.transformer_blocks.0.norm3.weight",
278
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.1.transformer_blocks.0.norm3.bias",
279
+ "control_model.input_blocks.1.1.proj_out.weight": "blocks.1.proj_out.weight",
280
+ "control_model.input_blocks.1.1.proj_out.bias": "blocks.1.proj_out.bias",
281
+ "control_model.input_blocks.2.0.in_layers.0.weight": "blocks.3.norm1.weight",
282
+ "control_model.input_blocks.2.0.in_layers.0.bias": "blocks.3.norm1.bias",
283
+ "control_model.input_blocks.2.0.in_layers.2.weight": "blocks.3.conv1.weight",
284
+ "control_model.input_blocks.2.0.in_layers.2.bias": "blocks.3.conv1.bias",
285
+ "control_model.input_blocks.2.0.emb_layers.1.weight": "blocks.3.time_emb_proj.weight",
286
+ "control_model.input_blocks.2.0.emb_layers.1.bias": "blocks.3.time_emb_proj.bias",
287
+ "control_model.input_blocks.2.0.out_layers.0.weight": "blocks.3.norm2.weight",
288
+ "control_model.input_blocks.2.0.out_layers.0.bias": "blocks.3.norm2.bias",
289
+ "control_model.input_blocks.2.0.out_layers.3.weight": "blocks.3.conv2.weight",
290
+ "control_model.input_blocks.2.0.out_layers.3.bias": "blocks.3.conv2.bias",
291
+ "control_model.input_blocks.2.1.norm.weight": "blocks.4.norm.weight",
292
+ "control_model.input_blocks.2.1.norm.bias": "blocks.4.norm.bias",
293
+ "control_model.input_blocks.2.1.proj_in.weight": "blocks.4.proj_in.weight",
294
+ "control_model.input_blocks.2.1.proj_in.bias": "blocks.4.proj_in.bias",
295
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.4.transformer_blocks.0.attn1.to_q.weight",
296
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.4.transformer_blocks.0.attn1.to_k.weight",
297
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.4.transformer_blocks.0.attn1.to_v.weight",
298
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.4.transformer_blocks.0.attn1.to_out.weight",
299
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.4.transformer_blocks.0.attn1.to_out.bias",
300
+ "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.4.transformer_blocks.0.act_fn.proj.weight",
301
+ "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.4.transformer_blocks.0.act_fn.proj.bias",
302
+ "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.4.transformer_blocks.0.ff.weight",
303
+ "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.4.transformer_blocks.0.ff.bias",
304
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.4.transformer_blocks.0.attn2.to_q.weight",
305
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.4.transformer_blocks.0.attn2.to_k.weight",
306
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.4.transformer_blocks.0.attn2.to_v.weight",
307
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.4.transformer_blocks.0.attn2.to_out.weight",
308
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.4.transformer_blocks.0.attn2.to_out.bias",
309
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.4.transformer_blocks.0.norm1.weight",
310
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.4.transformer_blocks.0.norm1.bias",
311
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.4.transformer_blocks.0.norm2.weight",
312
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.4.transformer_blocks.0.norm2.bias",
313
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.4.transformer_blocks.0.norm3.weight",
314
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.4.transformer_blocks.0.norm3.bias",
315
+ "control_model.input_blocks.2.1.proj_out.weight": "blocks.4.proj_out.weight",
316
+ "control_model.input_blocks.2.1.proj_out.bias": "blocks.4.proj_out.bias",
317
+ "control_model.input_blocks.3.0.op.weight": "blocks.6.conv.weight",
318
+ "control_model.input_blocks.3.0.op.bias": "blocks.6.conv.bias",
319
+ "control_model.input_blocks.4.0.in_layers.0.weight": "blocks.8.norm1.weight",
320
+ "control_model.input_blocks.4.0.in_layers.0.bias": "blocks.8.norm1.bias",
321
+ "control_model.input_blocks.4.0.in_layers.2.weight": "blocks.8.conv1.weight",
322
+ "control_model.input_blocks.4.0.in_layers.2.bias": "blocks.8.conv1.bias",
323
+ "control_model.input_blocks.4.0.emb_layers.1.weight": "blocks.8.time_emb_proj.weight",
324
+ "control_model.input_blocks.4.0.emb_layers.1.bias": "blocks.8.time_emb_proj.bias",
325
+ "control_model.input_blocks.4.0.out_layers.0.weight": "blocks.8.norm2.weight",
326
+ "control_model.input_blocks.4.0.out_layers.0.bias": "blocks.8.norm2.bias",
327
+ "control_model.input_blocks.4.0.out_layers.3.weight": "blocks.8.conv2.weight",
328
+ "control_model.input_blocks.4.0.out_layers.3.bias": "blocks.8.conv2.bias",
329
+ "control_model.input_blocks.4.0.skip_connection.weight": "blocks.8.conv_shortcut.weight",
330
+ "control_model.input_blocks.4.0.skip_connection.bias": "blocks.8.conv_shortcut.bias",
331
+ "control_model.input_blocks.4.1.norm.weight": "blocks.9.norm.weight",
332
+ "control_model.input_blocks.4.1.norm.bias": "blocks.9.norm.bias",
333
+ "control_model.input_blocks.4.1.proj_in.weight": "blocks.9.proj_in.weight",
334
+ "control_model.input_blocks.4.1.proj_in.bias": "blocks.9.proj_in.bias",
335
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.9.transformer_blocks.0.attn1.to_q.weight",
336
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.9.transformer_blocks.0.attn1.to_k.weight",
337
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.9.transformer_blocks.0.attn1.to_v.weight",
338
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.9.transformer_blocks.0.attn1.to_out.weight",
339
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.9.transformer_blocks.0.attn1.to_out.bias",
340
+ "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.9.transformer_blocks.0.act_fn.proj.weight",
341
+ "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.9.transformer_blocks.0.act_fn.proj.bias",
342
+ "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.9.transformer_blocks.0.ff.weight",
343
+ "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.9.transformer_blocks.0.ff.bias",
344
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.9.transformer_blocks.0.attn2.to_q.weight",
345
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.9.transformer_blocks.0.attn2.to_k.weight",
346
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.9.transformer_blocks.0.attn2.to_v.weight",
347
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.9.transformer_blocks.0.attn2.to_out.weight",
348
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.9.transformer_blocks.0.attn2.to_out.bias",
349
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.9.transformer_blocks.0.norm1.weight",
350
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.9.transformer_blocks.0.norm1.bias",
351
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.9.transformer_blocks.0.norm2.weight",
352
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.9.transformer_blocks.0.norm2.bias",
353
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.9.transformer_blocks.0.norm3.weight",
354
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.9.transformer_blocks.0.norm3.bias",
355
+ "control_model.input_blocks.4.1.proj_out.weight": "blocks.9.proj_out.weight",
356
+ "control_model.input_blocks.4.1.proj_out.bias": "blocks.9.proj_out.bias",
357
+ "control_model.input_blocks.5.0.in_layers.0.weight": "blocks.11.norm1.weight",
358
+ "control_model.input_blocks.5.0.in_layers.0.bias": "blocks.11.norm1.bias",
359
+ "control_model.input_blocks.5.0.in_layers.2.weight": "blocks.11.conv1.weight",
360
+ "control_model.input_blocks.5.0.in_layers.2.bias": "blocks.11.conv1.bias",
361
+ "control_model.input_blocks.5.0.emb_layers.1.weight": "blocks.11.time_emb_proj.weight",
362
+ "control_model.input_blocks.5.0.emb_layers.1.bias": "blocks.11.time_emb_proj.bias",
363
+ "control_model.input_blocks.5.0.out_layers.0.weight": "blocks.11.norm2.weight",
364
+ "control_model.input_blocks.5.0.out_layers.0.bias": "blocks.11.norm2.bias",
365
+ "control_model.input_blocks.5.0.out_layers.3.weight": "blocks.11.conv2.weight",
366
+ "control_model.input_blocks.5.0.out_layers.3.bias": "blocks.11.conv2.bias",
367
+ "control_model.input_blocks.5.1.norm.weight": "blocks.12.norm.weight",
368
+ "control_model.input_blocks.5.1.norm.bias": "blocks.12.norm.bias",
369
+ "control_model.input_blocks.5.1.proj_in.weight": "blocks.12.proj_in.weight",
370
+ "control_model.input_blocks.5.1.proj_in.bias": "blocks.12.proj_in.bias",
371
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.12.transformer_blocks.0.attn1.to_q.weight",
372
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.12.transformer_blocks.0.attn1.to_k.weight",
373
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.12.transformer_blocks.0.attn1.to_v.weight",
374
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.12.transformer_blocks.0.attn1.to_out.weight",
375
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.12.transformer_blocks.0.attn1.to_out.bias",
376
+ "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.12.transformer_blocks.0.act_fn.proj.weight",
377
+ "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.12.transformer_blocks.0.act_fn.proj.bias",
378
+ "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.12.transformer_blocks.0.ff.weight",
379
+ "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.12.transformer_blocks.0.ff.bias",
380
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.12.transformer_blocks.0.attn2.to_q.weight",
381
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.12.transformer_blocks.0.attn2.to_k.weight",
382
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.12.transformer_blocks.0.attn2.to_v.weight",
383
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.12.transformer_blocks.0.attn2.to_out.weight",
384
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.12.transformer_blocks.0.attn2.to_out.bias",
385
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.12.transformer_blocks.0.norm1.weight",
386
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.12.transformer_blocks.0.norm1.bias",
387
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.12.transformer_blocks.0.norm2.weight",
388
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.12.transformer_blocks.0.norm2.bias",
389
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.12.transformer_blocks.0.norm3.weight",
390
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.12.transformer_blocks.0.norm3.bias",
391
+ "control_model.input_blocks.5.1.proj_out.weight": "blocks.12.proj_out.weight",
392
+ "control_model.input_blocks.5.1.proj_out.bias": "blocks.12.proj_out.bias",
393
+ "control_model.input_blocks.6.0.op.weight": "blocks.14.conv.weight",
394
+ "control_model.input_blocks.6.0.op.bias": "blocks.14.conv.bias",
395
+ "control_model.input_blocks.7.0.in_layers.0.weight": "blocks.16.norm1.weight",
396
+ "control_model.input_blocks.7.0.in_layers.0.bias": "blocks.16.norm1.bias",
397
+ "control_model.input_blocks.7.0.in_layers.2.weight": "blocks.16.conv1.weight",
398
+ "control_model.input_blocks.7.0.in_layers.2.bias": "blocks.16.conv1.bias",
399
+ "control_model.input_blocks.7.0.emb_layers.1.weight": "blocks.16.time_emb_proj.weight",
400
+ "control_model.input_blocks.7.0.emb_layers.1.bias": "blocks.16.time_emb_proj.bias",
401
+ "control_model.input_blocks.7.0.out_layers.0.weight": "blocks.16.norm2.weight",
402
+ "control_model.input_blocks.7.0.out_layers.0.bias": "blocks.16.norm2.bias",
403
+ "control_model.input_blocks.7.0.out_layers.3.weight": "blocks.16.conv2.weight",
404
+ "control_model.input_blocks.7.0.out_layers.3.bias": "blocks.16.conv2.bias",
405
+ "control_model.input_blocks.7.0.skip_connection.weight": "blocks.16.conv_shortcut.weight",
406
+ "control_model.input_blocks.7.0.skip_connection.bias": "blocks.16.conv_shortcut.bias",
407
+ "control_model.input_blocks.7.1.norm.weight": "blocks.17.norm.weight",
408
+ "control_model.input_blocks.7.1.norm.bias": "blocks.17.norm.bias",
409
+ "control_model.input_blocks.7.1.proj_in.weight": "blocks.17.proj_in.weight",
410
+ "control_model.input_blocks.7.1.proj_in.bias": "blocks.17.proj_in.bias",
411
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.17.transformer_blocks.0.attn1.to_q.weight",
412
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.17.transformer_blocks.0.attn1.to_k.weight",
413
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.17.transformer_blocks.0.attn1.to_v.weight",
414
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.17.transformer_blocks.0.attn1.to_out.weight",
415
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.17.transformer_blocks.0.attn1.to_out.bias",
416
+ "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.17.transformer_blocks.0.act_fn.proj.weight",
417
+ "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.17.transformer_blocks.0.act_fn.proj.bias",
418
+ "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.17.transformer_blocks.0.ff.weight",
419
+ "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.17.transformer_blocks.0.ff.bias",
420
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.17.transformer_blocks.0.attn2.to_q.weight",
421
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.17.transformer_blocks.0.attn2.to_k.weight",
422
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.17.transformer_blocks.0.attn2.to_v.weight",
423
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.17.transformer_blocks.0.attn2.to_out.weight",
424
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.17.transformer_blocks.0.attn2.to_out.bias",
425
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.17.transformer_blocks.0.norm1.weight",
426
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.17.transformer_blocks.0.norm1.bias",
427
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.17.transformer_blocks.0.norm2.weight",
428
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.17.transformer_blocks.0.norm2.bias",
429
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.17.transformer_blocks.0.norm3.weight",
430
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.17.transformer_blocks.0.norm3.bias",
431
+ "control_model.input_blocks.7.1.proj_out.weight": "blocks.17.proj_out.weight",
432
+ "control_model.input_blocks.7.1.proj_out.bias": "blocks.17.proj_out.bias",
433
+ "control_model.input_blocks.8.0.in_layers.0.weight": "blocks.19.norm1.weight",
434
+ "control_model.input_blocks.8.0.in_layers.0.bias": "blocks.19.norm1.bias",
435
+ "control_model.input_blocks.8.0.in_layers.2.weight": "blocks.19.conv1.weight",
436
+ "control_model.input_blocks.8.0.in_layers.2.bias": "blocks.19.conv1.bias",
437
+ "control_model.input_blocks.8.0.emb_layers.1.weight": "blocks.19.time_emb_proj.weight",
438
+ "control_model.input_blocks.8.0.emb_layers.1.bias": "blocks.19.time_emb_proj.bias",
439
+ "control_model.input_blocks.8.0.out_layers.0.weight": "blocks.19.norm2.weight",
440
+ "control_model.input_blocks.8.0.out_layers.0.bias": "blocks.19.norm2.bias",
441
+ "control_model.input_blocks.8.0.out_layers.3.weight": "blocks.19.conv2.weight",
442
+ "control_model.input_blocks.8.0.out_layers.3.bias": "blocks.19.conv2.bias",
443
+ "control_model.input_blocks.8.1.norm.weight": "blocks.20.norm.weight",
444
+ "control_model.input_blocks.8.1.norm.bias": "blocks.20.norm.bias",
445
+ "control_model.input_blocks.8.1.proj_in.weight": "blocks.20.proj_in.weight",
446
+ "control_model.input_blocks.8.1.proj_in.bias": "blocks.20.proj_in.bias",
447
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.20.transformer_blocks.0.attn1.to_q.weight",
448
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.20.transformer_blocks.0.attn1.to_k.weight",
449
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.20.transformer_blocks.0.attn1.to_v.weight",
450
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.20.transformer_blocks.0.attn1.to_out.weight",
451
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.20.transformer_blocks.0.attn1.to_out.bias",
452
+ "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.20.transformer_blocks.0.act_fn.proj.weight",
453
+ "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.20.transformer_blocks.0.act_fn.proj.bias",
454
+ "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.20.transformer_blocks.0.ff.weight",
455
+ "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.20.transformer_blocks.0.ff.bias",
456
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.20.transformer_blocks.0.attn2.to_q.weight",
457
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.20.transformer_blocks.0.attn2.to_k.weight",
458
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.20.transformer_blocks.0.attn2.to_v.weight",
459
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.20.transformer_blocks.0.attn2.to_out.weight",
460
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.20.transformer_blocks.0.attn2.to_out.bias",
461
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.20.transformer_blocks.0.norm1.weight",
462
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.20.transformer_blocks.0.norm1.bias",
463
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.20.transformer_blocks.0.norm2.weight",
464
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.20.transformer_blocks.0.norm2.bias",
465
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.20.transformer_blocks.0.norm3.weight",
466
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.20.transformer_blocks.0.norm3.bias",
467
+ "control_model.input_blocks.8.1.proj_out.weight": "blocks.20.proj_out.weight",
468
+ "control_model.input_blocks.8.1.proj_out.bias": "blocks.20.proj_out.bias",
469
+ "control_model.input_blocks.9.0.op.weight": "blocks.22.conv.weight",
470
+ "control_model.input_blocks.9.0.op.bias": "blocks.22.conv.bias",
471
+ "control_model.input_blocks.10.0.in_layers.0.weight": "blocks.24.norm1.weight",
472
+ "control_model.input_blocks.10.0.in_layers.0.bias": "blocks.24.norm1.bias",
473
+ "control_model.input_blocks.10.0.in_layers.2.weight": "blocks.24.conv1.weight",
474
+ "control_model.input_blocks.10.0.in_layers.2.bias": "blocks.24.conv1.bias",
475
+ "control_model.input_blocks.10.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight",
476
+ "control_model.input_blocks.10.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias",
477
+ "control_model.input_blocks.10.0.out_layers.0.weight": "blocks.24.norm2.weight",
478
+ "control_model.input_blocks.10.0.out_layers.0.bias": "blocks.24.norm2.bias",
479
+ "control_model.input_blocks.10.0.out_layers.3.weight": "blocks.24.conv2.weight",
480
+ "control_model.input_blocks.10.0.out_layers.3.bias": "blocks.24.conv2.bias",
481
+ "control_model.input_blocks.11.0.in_layers.0.weight": "blocks.26.norm1.weight",
482
+ "control_model.input_blocks.11.0.in_layers.0.bias": "blocks.26.norm1.bias",
483
+ "control_model.input_blocks.11.0.in_layers.2.weight": "blocks.26.conv1.weight",
484
+ "control_model.input_blocks.11.0.in_layers.2.bias": "blocks.26.conv1.bias",
485
+ "control_model.input_blocks.11.0.emb_layers.1.weight": "blocks.26.time_emb_proj.weight",
486
+ "control_model.input_blocks.11.0.emb_layers.1.bias": "blocks.26.time_emb_proj.bias",
487
+ "control_model.input_blocks.11.0.out_layers.0.weight": "blocks.26.norm2.weight",
488
+ "control_model.input_blocks.11.0.out_layers.0.bias": "blocks.26.norm2.bias",
489
+ "control_model.input_blocks.11.0.out_layers.3.weight": "blocks.26.conv2.weight",
490
+ "control_model.input_blocks.11.0.out_layers.3.bias": "blocks.26.conv2.bias",
491
+ "control_model.zero_convs.0.0.weight": "controlnet_blocks.0.weight",
492
+ "control_model.zero_convs.0.0.bias": "controlnet_blocks.0.bias",
493
+ "control_model.zero_convs.1.0.weight": "controlnet_blocks.1.weight",
494
+ "control_model.zero_convs.1.0.bias": "controlnet_blocks.0.bias",
495
+ "control_model.zero_convs.2.0.weight": "controlnet_blocks.2.weight",
496
+ "control_model.zero_convs.2.0.bias": "controlnet_blocks.0.bias",
497
+ "control_model.zero_convs.3.0.weight": "controlnet_blocks.3.weight",
498
+ "control_model.zero_convs.3.0.bias": "controlnet_blocks.0.bias",
499
+ "control_model.zero_convs.4.0.weight": "controlnet_blocks.4.weight",
500
+ "control_model.zero_convs.4.0.bias": "controlnet_blocks.4.bias",
501
+ "control_model.zero_convs.5.0.weight": "controlnet_blocks.5.weight",
502
+ "control_model.zero_convs.5.0.bias": "controlnet_blocks.4.bias",
503
+ "control_model.zero_convs.6.0.weight": "controlnet_blocks.6.weight",
504
+ "control_model.zero_convs.6.0.bias": "controlnet_blocks.4.bias",
505
+ "control_model.zero_convs.7.0.weight": "controlnet_blocks.7.weight",
506
+ "control_model.zero_convs.7.0.bias": "controlnet_blocks.7.bias",
507
+ "control_model.zero_convs.8.0.weight": "controlnet_blocks.8.weight",
508
+ "control_model.zero_convs.8.0.bias": "controlnet_blocks.7.bias",
509
+ "control_model.zero_convs.9.0.weight": "controlnet_blocks.9.weight",
510
+ "control_model.zero_convs.9.0.bias": "controlnet_blocks.7.bias",
511
+ "control_model.zero_convs.10.0.weight": "controlnet_blocks.10.weight",
512
+ "control_model.zero_convs.10.0.bias": "controlnet_blocks.7.bias",
513
+ "control_model.zero_convs.11.0.weight": "controlnet_blocks.11.weight",
514
+ "control_model.zero_convs.11.0.bias": "controlnet_blocks.7.bias",
515
+ "control_model.input_hint_block.0.weight": "controlnet_conv_in.blocks.0.weight",
516
+ "control_model.input_hint_block.0.bias": "controlnet_conv_in.blocks.0.bias",
517
+ "control_model.input_hint_block.2.weight": "controlnet_conv_in.blocks.2.weight",
518
+ "control_model.input_hint_block.2.bias": "controlnet_conv_in.blocks.2.bias",
519
+ "control_model.input_hint_block.4.weight": "controlnet_conv_in.blocks.4.weight",
520
+ "control_model.input_hint_block.4.bias": "controlnet_conv_in.blocks.4.bias",
521
+ "control_model.input_hint_block.6.weight": "controlnet_conv_in.blocks.6.weight",
522
+ "control_model.input_hint_block.6.bias": "controlnet_conv_in.blocks.6.bias",
523
+ "control_model.input_hint_block.8.weight": "controlnet_conv_in.blocks.8.weight",
524
+ "control_model.input_hint_block.8.bias": "controlnet_conv_in.blocks.8.bias",
525
+ "control_model.input_hint_block.10.weight": "controlnet_conv_in.blocks.10.weight",
526
+ "control_model.input_hint_block.10.bias": "controlnet_conv_in.blocks.10.bias",
527
+ "control_model.input_hint_block.12.weight": "controlnet_conv_in.blocks.12.weight",
528
+ "control_model.input_hint_block.12.bias": "controlnet_conv_in.blocks.12.bias",
529
+ "control_model.input_hint_block.14.weight": "controlnet_conv_in.blocks.14.weight",
530
+ "control_model.input_hint_block.14.bias": "controlnet_conv_in.blocks.14.bias",
531
+ "control_model.middle_block.0.in_layers.0.weight": "blocks.28.norm1.weight",
532
+ "control_model.middle_block.0.in_layers.0.bias": "blocks.28.norm1.bias",
533
+ "control_model.middle_block.0.in_layers.2.weight": "blocks.28.conv1.weight",
534
+ "control_model.middle_block.0.in_layers.2.bias": "blocks.28.conv1.bias",
535
+ "control_model.middle_block.0.emb_layers.1.weight": "blocks.28.time_emb_proj.weight",
536
+ "control_model.middle_block.0.emb_layers.1.bias": "blocks.28.time_emb_proj.bias",
537
+ "control_model.middle_block.0.out_layers.0.weight": "blocks.28.norm2.weight",
538
+ "control_model.middle_block.0.out_layers.0.bias": "blocks.28.norm2.bias",
539
+ "control_model.middle_block.0.out_layers.3.weight": "blocks.28.conv2.weight",
540
+ "control_model.middle_block.0.out_layers.3.bias": "blocks.28.conv2.bias",
541
+ "control_model.middle_block.1.norm.weight": "blocks.29.norm.weight",
542
+ "control_model.middle_block.1.norm.bias": "blocks.29.norm.bias",
543
+ "control_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight",
544
+ "control_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias",
545
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.29.transformer_blocks.0.attn1.to_q.weight",
546
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.29.transformer_blocks.0.attn1.to_k.weight",
547
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.29.transformer_blocks.0.attn1.to_v.weight",
548
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.29.transformer_blocks.0.attn1.to_out.weight",
549
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.29.transformer_blocks.0.attn1.to_out.bias",
550
+ "control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.29.transformer_blocks.0.act_fn.proj.weight",
551
+ "control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.29.transformer_blocks.0.act_fn.proj.bias",
552
+ "control_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.29.transformer_blocks.0.ff.weight",
553
+ "control_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.29.transformer_blocks.0.ff.bias",
554
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.29.transformer_blocks.0.attn2.to_q.weight",
555
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.29.transformer_blocks.0.attn2.to_k.weight",
556
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.29.transformer_blocks.0.attn2.to_v.weight",
557
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.29.transformer_blocks.0.attn2.to_out.weight",
558
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.29.transformer_blocks.0.attn2.to_out.bias",
559
+ "control_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.29.transformer_blocks.0.norm1.weight",
560
+ "control_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.29.transformer_blocks.0.norm1.bias",
561
+ "control_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.29.transformer_blocks.0.norm2.weight",
562
+ "control_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.29.transformer_blocks.0.norm2.bias",
563
+ "control_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.29.transformer_blocks.0.norm3.weight",
564
+ "control_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.29.transformer_blocks.0.norm3.bias",
565
+ "control_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight",
566
+ "control_model.middle_block.1.proj_out.bias": "blocks.29.proj_out.bias",
567
+ "control_model.middle_block.2.in_layers.0.weight": "blocks.30.norm1.weight",
568
+ "control_model.middle_block.2.in_layers.0.bias": "blocks.30.norm1.bias",
569
+ "control_model.middle_block.2.in_layers.2.weight": "blocks.30.conv1.weight",
570
+ "control_model.middle_block.2.in_layers.2.bias": "blocks.30.conv1.bias",
571
+ "control_model.middle_block.2.emb_layers.1.weight": "blocks.30.time_emb_proj.weight",
572
+ "control_model.middle_block.2.emb_layers.1.bias": "blocks.30.time_emb_proj.bias",
573
+ "control_model.middle_block.2.out_layers.0.weight": "blocks.30.norm2.weight",
574
+ "control_model.middle_block.2.out_layers.0.bias": "blocks.30.norm2.bias",
575
+ "control_model.middle_block.2.out_layers.3.weight": "blocks.30.conv2.weight",
576
+ "control_model.middle_block.2.out_layers.3.bias": "blocks.30.conv2.bias",
577
+ "control_model.middle_block_out.0.weight": "controlnet_blocks.12.weight",
578
+ "control_model.middle_block_out.0.bias": "controlnet_blocks.7.bias",
579
+ }
580
+ state_dict_ = {}
581
+ for name in state_dict:
582
+ if name in rename_dict:
583
+ param = state_dict[name]
584
+ if ".proj_in." in name or ".proj_out." in name:
585
+ param = param.squeeze()
586
+ state_dict_[rename_dict[name]] = param
587
+ return state_dict_
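The converter above only renames checkpoint keys into this repo's block naming and reshapes the 1x1-convolution projection weights so they fit linear layers. Below is a minimal, self-contained sketch of that rename-and-squeeze pattern; the key is taken from the table above, while the tensor shape is an illustrative placeholder rather than a real checkpoint value.

import torch

# Illustrative one-entry subset of the rename table above.
rename_dict = {"control_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight"}

# A 1x1 conv weight stored as (out_channels, in_channels, 1, 1); shape is illustrative.
state_dict = {"control_model.middle_block.1.proj_out.weight": torch.randn(1280, 1280, 1, 1)}

state_dict_ = {}
for name, param in state_dict.items():
    if name in rename_dict:
        if ".proj_in." in name or ".proj_out." in name:
            param = param.squeeze()  # drop the trailing 1x1 spatial dims so a Linear layer can load it
        state_dict_[rename_dict[name]] = param

print(state_dict_["blocks.29.proj_out.weight"].shape)  # torch.Size([1280, 1280])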
diffsynth/models/sd_ipadapter.py ADDED
@@ -0,0 +1,56 @@
1
+ from .svd_image_encoder import SVDImageEncoder
2
+ from .sdxl_ipadapter import IpAdapterImageProjModel, IpAdapterModule, SDXLIpAdapterStateDictConverter
3
+ from transformers import CLIPImageProcessor
4
+ import torch
5
+
6
+
7
+ class IpAdapterCLIPImageEmbedder(SVDImageEncoder):
8
+ def __init__(self):
9
+ super().__init__()
10
+ self.image_processor = CLIPImageProcessor()
11
+
12
+ def forward(self, image):
13
+ pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
14
+ pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
15
+ return super().forward(pixel_values)
16
+
17
+
18
+ class SDIpAdapter(torch.nn.Module):
19
+ def __init__(self):
20
+ super().__init__()
21
+ shape_list = [(768, 320)] * 2 + [(768, 640)] * 2 + [(768, 1280)] * 5 + [(768, 640)] * 3 + [(768, 320)] * 3 + [(768, 1280)] * 1
22
+ self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
23
+ self.image_proj = IpAdapterImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
24
+ self.set_full_adapter()
25
+
26
+ def set_full_adapter(self):
27
+ block_ids = [1, 4, 9, 12, 17, 20, 40, 43, 46, 50, 53, 56, 60, 63, 66, 29]
28
+ self.call_block_id = {(i, 0): j for j, i in enumerate(block_ids)}
29
+
30
+ def set_less_adapter(self):
31
+ # IP-Adapter for SD v1.5 doesn't support this feature.
32
+ self.set_full_adapter()
33
+
34
+ def forward(self, hidden_states, scale=1.0):
35
+ hidden_states = self.image_proj(hidden_states)
36
+ hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
37
+ ip_kv_dict = {}
38
+ for (block_id, transformer_id) in self.call_block_id:
39
+ ipadapter_id = self.call_block_id[(block_id, transformer_id)]
40
+ ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
41
+ if block_id not in ip_kv_dict:
42
+ ip_kv_dict[block_id] = {}
43
+ ip_kv_dict[block_id][transformer_id] = {
44
+ "ip_k": ip_k,
45
+ "ip_v": ip_v,
46
+ "scale": scale
47
+ }
48
+ return ip_kv_dict
49
+
50
+ def state_dict_converter(self):
51
+ return SDIpAdapterStateDictConverter()
52
+
53
+
54
+ class SDIpAdapterStateDictConverter(SDXLIpAdapterStateDictConverter):
55
+ def __init__(self):
56
+ pass
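For orientation, the 16 (cross-attention dim, hidden dim) pairs in shape_list line up one-to-one with the 16 UNet block ids registered in set_full_adapter, and forward packs each module's extra key/value tensors plus the scale into a nested dict keyed by (block_id, transformer_id). The following is a standalone sanity-check sketch, not part of the library API.

# One IpAdapterModule per attention location listed in set_full_adapter.
shape_list = [(768, 320)] * 2 + [(768, 640)] * 2 + [(768, 1280)] * 5 + [(768, 640)] * 3 + [(768, 320)] * 3 + [(768, 1280)] * 1
block_ids = [1, 4, 9, 12, 17, 20, 40, 43, 46, 50, 53, 56, 60, 63, 66, 29]
call_block_id = {(i, 0): j for j, i in enumerate(block_ids)}
assert len(shape_list) == len(block_ids) == len(call_block_id) == 16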
diffsynth/models/sd_lora.py ADDED
@@ -0,0 +1,60 @@
1
+ import torch
2
+ from .sd_unet import SDUNetStateDictConverter, SDUNet
3
+ from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder
4
+
5
+
6
+ class SDLoRA:
7
+ def __init__(self):
8
+ pass
9
+
10
+ def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"):
11
+ special_keys = {
12
+ "down.blocks": "down_blocks",
13
+ "up.blocks": "up_blocks",
14
+ "mid.block": "mid_block",
15
+ "proj.in": "proj_in",
16
+ "proj.out": "proj_out",
17
+ "transformer.blocks": "transformer_blocks",
18
+ "to.q": "to_q",
19
+ "to.k": "to_k",
20
+ "to.v": "to_v",
21
+ "to.out": "to_out",
22
+ }
23
+ state_dict_ = {}
24
+ for key in state_dict:
25
+ if ".lora_up" not in key:
26
+ continue
27
+ if not key.startswith(lora_prefix):
28
+ continue
29
+ weight_up = state_dict[key].to(device=device, dtype=torch.float16)
31
+ weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device=device, dtype=torch.float16)
31
+ if len(weight_up.shape) == 4:
32
+ weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
33
+ weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
34
+ lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
35
+ else:
36
+ lora_weight = alpha * torch.mm(weight_up, weight_down)
37
+ target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
38
+ for special_key in special_keys:
39
+ target_name = target_name.replace(special_key, special_keys[special_key])
40
+ state_dict_[target_name] = lora_weight.cpu()
41
+ return state_dict_
42
+
43
+ def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"):
44
+ state_dict_unet = unet.state_dict()
45
+ state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device)
46
+ state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora)
47
+ if len(state_dict_lora) > 0:
48
+ for name in state_dict_lora:
49
+ state_dict_unet[name] += state_dict_lora[name].to(device=device)
50
+ unet.load_state_dict(state_dict_unet)
51
+
52
+ def add_lora_to_text_encoder(self, text_encoder: SDTextEncoder, state_dict_lora, alpha=1.0, device="cuda"):
53
+ state_dict_text_encoder = text_encoder.state_dict()
54
+ state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device)
55
+ state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora)
56
+ if len(state_dict_lora) > 0:
57
+ for name in state_dict_lora:
58
+ state_dict_text_encoder[name] += state_dict_lora[name].to(device=device)
59
+ text_encoder.load_state_dict(state_dict_text_encoder)
60
+
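The merge performed by convert_state_dict is the standard LoRA update: the base weight receives alpha times the product of the low-rank "up" and "down" matrices. A standalone sketch with illustrative shapes (not the repo's API):

import torch

out_dim, in_dim, rank, alpha = 320, 768, 4, 1.0
weight = torch.randn(out_dim, in_dim)       # base linear weight
weight_up = torch.randn(out_dim, rank)      # LoRA "up" matrix
weight_down = torch.randn(rank, in_dim)     # LoRA "down" matrix

lora_weight = alpha * torch.mm(weight_up, weight_down)  # low-rank update
merged = weight + lora_weight               # same additive merge as add_lora_to_unet
print(merged.shape)  # torch.Size([320, 768])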
diffsynth/models/sd_motion.py ADDED
@@ -0,0 +1,198 @@
1
+ from .sd_unet import SDUNet, Attention, GEGLU
2
+ import torch
3
+ from einops import rearrange, repeat
4
+
5
+
6
+ class TemporalTransformerBlock(torch.nn.Module):
7
+
8
+ def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32):
9
+ super().__init__()
10
+
11
+ # 1. Self-Attn
12
+ self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
13
+ self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
14
+ self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
15
+
16
+ # 2. Cross-Attn
17
+ self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
18
+ self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
19
+ self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
20
+
21
+ # 3. Feed-forward
22
+ self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True)
23
+ self.act_fn = GEGLU(dim, dim * 4)
24
+ self.ff = torch.nn.Linear(dim * 4, dim)
25
+
26
+
27
+ def forward(self, hidden_states, batch_size=1):
28
+
29
+ # 1. Self-Attention
30
+ norm_hidden_states = self.norm1(hidden_states)
31
+ norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
32
+ attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]])
33
+ attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
34
+ hidden_states = attn_output + hidden_states
35
+
36
+ # 2. Cross-Attention
37
+ norm_hidden_states = self.norm2(hidden_states)
38
+ norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
39
+ attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]])
40
+ attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
41
+ hidden_states = attn_output + hidden_states
42
+
43
+ # 3. Feed-forward
44
+ norm_hidden_states = self.norm3(hidden_states)
45
+ ff_output = self.act_fn(norm_hidden_states)
46
+ ff_output = self.ff(ff_output)
47
+ hidden_states = ff_output + hidden_states
48
+
49
+ return hidden_states
50
+
51
+
52
+ class TemporalBlock(torch.nn.Module):
53
+
54
+ def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
55
+ super().__init__()
56
+ inner_dim = num_attention_heads * attention_head_dim
57
+
58
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
59
+ self.proj_in = torch.nn.Linear(in_channels, inner_dim)
60
+
61
+ self.transformer_blocks = torch.nn.ModuleList([
62
+ TemporalTransformerBlock(
63
+ inner_dim,
64
+ num_attention_heads,
65
+ attention_head_dim
66
+ )
67
+ for d in range(num_layers)
68
+ ])
69
+
70
+ self.proj_out = torch.nn.Linear(inner_dim, in_channels)
71
+
72
+ def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1):
73
+ batch, _, height, width = hidden_states.shape
74
+ residual = hidden_states
75
+
76
+ hidden_states = self.norm(hidden_states)
77
+ inner_dim = hidden_states.shape[1]
78
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
79
+ hidden_states = self.proj_in(hidden_states)
80
+
81
+ for block in self.transformer_blocks:
82
+ hidden_states = block(
83
+ hidden_states,
84
+ batch_size=batch_size
85
+ )
86
+
87
+ hidden_states = self.proj_out(hidden_states)
88
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
89
+ hidden_states = hidden_states + residual
90
+
91
+ return hidden_states, time_emb, text_emb, res_stack
92
+
93
+
94
+ class SDMotionModel(torch.nn.Module):
95
+ def __init__(self):
96
+ super().__init__()
97
+ self.motion_modules = torch.nn.ModuleList([
98
+ TemporalBlock(8, 40, 320, eps=1e-6),
99
+ TemporalBlock(8, 40, 320, eps=1e-6),
100
+ TemporalBlock(8, 80, 640, eps=1e-6),
101
+ TemporalBlock(8, 80, 640, eps=1e-6),
102
+ TemporalBlock(8, 160, 1280, eps=1e-6),
103
+ TemporalBlock(8, 160, 1280, eps=1e-6),
104
+ TemporalBlock(8, 160, 1280, eps=1e-6),
105
+ TemporalBlock(8, 160, 1280, eps=1e-6),
106
+ TemporalBlock(8, 160, 1280, eps=1e-6),
107
+ TemporalBlock(8, 160, 1280, eps=1e-6),
108
+ TemporalBlock(8, 160, 1280, eps=1e-6),
109
+ TemporalBlock(8, 160, 1280, eps=1e-6),
110
+ TemporalBlock(8, 160, 1280, eps=1e-6),
111
+ TemporalBlock(8, 160, 1280, eps=1e-6),
112
+ TemporalBlock(8, 160, 1280, eps=1e-6),
113
+ TemporalBlock(8, 80, 640, eps=1e-6),
114
+ TemporalBlock(8, 80, 640, eps=1e-6),
115
+ TemporalBlock(8, 80, 640, eps=1e-6),
116
+ TemporalBlock(8, 40, 320, eps=1e-6),
117
+ TemporalBlock(8, 40, 320, eps=1e-6),
118
+ TemporalBlock(8, 40, 320, eps=1e-6),
119
+ ])
120
+ self.call_block_id = {
121
+ 1: 0,
122
+ 4: 1,
123
+ 9: 2,
124
+ 12: 3,
125
+ 17: 4,
126
+ 20: 5,
127
+ 24: 6,
128
+ 26: 7,
129
+ 29: 8,
130
+ 32: 9,
131
+ 34: 10,
132
+ 36: 11,
133
+ 40: 12,
134
+ 43: 13,
135
+ 46: 14,
136
+ 50: 15,
137
+ 53: 16,
138
+ 56: 17,
139
+ 60: 18,
140
+ 63: 19,
141
+ 66: 20
142
+ }
143
+
144
+ def forward(self):
145
+ pass
146
+
147
+ def state_dict_converter(self):
148
+ return SDMotionModelStateDictConverter()
149
+
150
+
151
+ class SDMotionModelStateDictConverter:
152
+ def __init__(self):
153
+ pass
154
+
155
+ def from_diffusers(self, state_dict):
156
+ rename_dict = {
157
+ "norm": "norm",
158
+ "proj_in": "proj_in",
159
+ "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
160
+ "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
161
+ "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
162
+ "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
163
+ "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
164
+ "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
165
+ "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
166
+ "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
167
+ "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
168
+ "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
169
+ "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
170
+ "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
171
+ "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
172
+ "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
173
+ "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
174
+ "proj_out": "proj_out",
175
+ }
176
+ name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
177
+ name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
178
+ name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
179
+ state_dict_ = {}
180
+ last_prefix, module_id = "", -1
181
+ for name in name_list:
182
+ names = name.split(".")
183
+ prefix_index = names.index("temporal_transformer") + 1
184
+ prefix = ".".join(names[:prefix_index])
185
+ if prefix != last_prefix:
186
+ last_prefix = prefix
187
+ module_id += 1
188
+ middle_name = ".".join(names[prefix_index:-1])
189
+ suffix = names[-1]
190
+ if "pos_encoder" in names:
191
+ rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
192
+ else:
193
+ rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
194
+ state_dict_[rename] = state_dict[name]
195
+ return state_dict_
196
+
197
+ def from_civitai(self, state_dict):
198
+ return self.from_diffusers(state_dict)
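The key step in TemporalTransformerBlock is the einops rearrange that folds the frame axis into the attention sequence, so self-attention mixes information across frames at every spatial location before folding back. A standalone sketch of that reshaping, with illustrative shapes:

import torch
from einops import rearrange

batch_size, frames, height_width, channels = 1, 16, 64, 320
hidden_states = torch.randn(batch_size * frames, height_width, channels)   # "(b f) h c"

temporal = rearrange(hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
assert temporal.shape == (batch_size * height_width, frames, channels)     # attention runs over f

restored = rearrange(temporal, "(b h) f c -> (b f) h c", b=batch_size)
assert restored.shape == hidden_states.shape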
diffsynth/models/sd_text_encoder.py ADDED
@@ -0,0 +1,320 @@
1
+ import torch
2
+ from .attention import Attention
3
+
4
+
5
+ class CLIPEncoderLayer(torch.nn.Module):
6
+ def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
7
+ super().__init__()
8
+ self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
9
+ self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
10
+ self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
11
+ self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
12
+ self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
13
+
14
+ self.use_quick_gelu = use_quick_gelu
15
+
16
+ def quickGELU(self, x):
17
+ return x * torch.sigmoid(1.702 * x)
18
+
19
+ def forward(self, hidden_states, attn_mask=None):
20
+ residual = hidden_states
21
+
22
+ hidden_states = self.layer_norm1(hidden_states)
23
+ hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
24
+ hidden_states = residual + hidden_states
25
+
26
+ residual = hidden_states
27
+ hidden_states = self.layer_norm2(hidden_states)
28
+ hidden_states = self.fc1(hidden_states)
29
+ if self.use_quick_gelu:
30
+ hidden_states = self.quickGELU(hidden_states)
31
+ else:
32
+ hidden_states = torch.nn.functional.gelu(hidden_states)
33
+ hidden_states = self.fc2(hidden_states)
34
+ hidden_states = residual + hidden_states
35
+
36
+ return hidden_states
37
+
38
+
39
+ class SDTextEncoder(torch.nn.Module):
40
+ def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
41
+ super().__init__()
42
+
43
+ # token_embedding
44
+ self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
45
+
46
+ # position_embeds (This is a fixed tensor)
47
+ self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
48
+
49
+ # encoders
50
+ self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
51
+
52
+ # attn_mask
53
+ self.attn_mask = self.attention_mask(max_position_embeddings)
54
+
55
+ # final_layer_norm
56
+ self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
57
+
58
+ def attention_mask(self, length):
59
+ mask = torch.empty(length, length)
60
+ mask.fill_(float("-inf"))
61
+ mask.triu_(1)
62
+ return mask
63
+
64
+ def forward(self, input_ids, clip_skip=1):
65
+ embeds = self.token_embedding(input_ids) + self.position_embeds
66
+ attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
67
+ for encoder_id, encoder in enumerate(self.encoders):
68
+ embeds = encoder(embeds, attn_mask=attn_mask)
69
+ if encoder_id + clip_skip == len(self.encoders):
70
+ break
71
+ embeds = self.final_layer_norm(embeds)
72
+ return embeds
73
+
74
+ def state_dict_converter(self):
75
+ return SDTextEncoderStateDictConverter()
76
+
77
+
78
+ class SDTextEncoderStateDictConverter:
79
+ def __init__(self):
80
+ pass
81
+
82
+ def from_diffusers(self, state_dict):
83
+ rename_dict = {
84
+ "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
85
+ "text_model.embeddings.position_embedding.weight": "position_embeds",
86
+ "text_model.final_layer_norm.weight": "final_layer_norm.weight",
87
+ "text_model.final_layer_norm.bias": "final_layer_norm.bias"
88
+ }
89
+ attn_rename_dict = {
90
+ "self_attn.q_proj": "attn.to_q",
91
+ "self_attn.k_proj": "attn.to_k",
92
+ "self_attn.v_proj": "attn.to_v",
93
+ "self_attn.out_proj": "attn.to_out",
94
+ "layer_norm1": "layer_norm1",
95
+ "layer_norm2": "layer_norm2",
96
+ "mlp.fc1": "fc1",
97
+ "mlp.fc2": "fc2",
98
+ }
99
+ state_dict_ = {}
100
+ for name in state_dict:
101
+ if name in rename_dict:
102
+ param = state_dict[name]
103
+ if name == "text_model.embeddings.position_embedding.weight":
104
+ param = param.reshape((1, param.shape[0], param.shape[1]))
105
+ state_dict_[rename_dict[name]] = param
106
+ elif name.startswith("text_model.encoder.layers."):
107
+ param = state_dict[name]
108
+ names = name.split(".")
109
+ layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
110
+ name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
111
+ state_dict_[name_] = param
112
+ return state_dict_
113
+
114
+ def from_civitai(self, state_dict):
115
+ rename_dict = {
116
+ "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
117
+ "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
118
+ "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
119
+ "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
120
+ "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
121
+ "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
122
+ "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
123
+ "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
124
+ "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
125
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
126
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
127
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
128
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
129
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
130
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
131
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
132
+ "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
133
+ "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
134
+ "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
135
+ "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
136
+ "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
137
+ "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
138
+ "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
139
+ "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
140
+ "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
141
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
142
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
143
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
144
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
145
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
146
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
147
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
148
+ "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
149
+ "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
150
+ "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
151
+ "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
152
+ "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
153
+ "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
154
+ "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
155
+ "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
156
+ "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
157
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
158
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
159
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
160
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
161
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
162
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
163
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
164
+ "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
165
+ "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias",
166
+ "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight",
167
+ "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias",
168
+ "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight",
169
+ "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias",
170
+ "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight",
171
+ "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias",
172
+ "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight",
173
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias",
174
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight",
175
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias",
176
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight",
177
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias",
178
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight",
179
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias",
180
+ "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight",
181
+ "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
182
+ "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
183
+ "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
184
+ "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
185
+ "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
186
+ "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
187
+ "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
188
+ "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
189
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
190
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
191
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
192
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
193
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
194
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
195
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
196
+ "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
197
+ "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
198
+ "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
199
+ "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
200
+ "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
201
+ "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
202
+ "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
203
+ "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
204
+ "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
205
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
206
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
207
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
208
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
209
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
210
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
211
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
212
+ "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
213
+ "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
214
+ "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
215
+ "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
216
+ "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
217
+ "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
218
+ "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
219
+ "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
220
+ "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
221
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
222
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
223
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
224
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
225
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
226
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
227
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
228
+ "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
229
+ "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
230
+ "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
231
+ "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
232
+ "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
233
+ "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
234
+ "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
235
+ "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
236
+ "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
237
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
238
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
239
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
240
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
241
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
242
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
243
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
244
+ "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
245
+ "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
246
+ "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
247
+ "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
248
+ "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
249
+ "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
250
+ "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
251
+ "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
252
+ "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
253
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
254
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
255
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
256
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
257
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
258
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
259
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
260
+ "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
261
+ "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
262
+ "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
263
+ "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
264
+ "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
265
+ "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
266
+ "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
267
+ "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
268
+ "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
269
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
270
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
271
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
272
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
273
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
274
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
275
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
276
+ "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
277
+ "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
278
+ "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
279
+ "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
280
+ "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
281
+ "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
282
+ "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
283
+ "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
284
+ "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
285
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
286
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
287
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
288
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
289
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
290
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
291
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
292
+ "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
293
+ "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
294
+ "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
295
+ "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
296
+ "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
297
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
298
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
299
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
300
+ "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
301
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
302
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
303
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
304
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
305
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
306
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
307
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
308
+ "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
309
+ "cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias",
310
+ "cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight",
311
+ "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds"
312
+ }
313
+ state_dict_ = {}
314
+ for name in state_dict:
315
+ if name in rename_dict:
316
+ param = state_dict[name]
317
+ if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight":
318
+ param = param.reshape((1, param.shape[0], param.shape[1]))
319
+ state_dict_[rename_dict[name]] = param
320
+ return state_dict_
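Note on the converter above: Civitai/LDM checkpoints store the CLIP position-embedding table flat, and the reshape step turns it into a batched tensor that can be added directly to token embeddings. A toy check of that single step (the tensor below is synthetic, not a real checkpoint weight):

import torch

# SD 1.x CLIP ViT-L/14 uses 77 positions x 768 dims (zeros here, for illustration only).
param = torch.zeros(77, 768)
param = param.reshape((1, param.shape[0], param.shape[1]))
print(param.shape)  # torch.Size([1, 77, 768])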
diffsynth/models/sd_unet.py ADDED
The diff for this file is too large to render. See raw diff
 
diffsynth/models/sd_vae_decoder.py ADDED
@@ -0,0 +1,332 @@
1
+ import torch
2
+ from .attention import Attention
3
+ from .sd_unet import ResnetBlock, UpSampler
4
+ from .tiler import TileWorker
5
+
6
+
7
+ class VAEAttentionBlock(torch.nn.Module):
8
+
9
+ def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
10
+ super().__init__()
11
+ inner_dim = num_attention_heads * attention_head_dim
12
+
13
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
14
+
15
+ self.transformer_blocks = torch.nn.ModuleList([
16
+ Attention(
17
+ inner_dim,
18
+ num_attention_heads,
19
+ attention_head_dim,
20
+ bias_q=True,
21
+ bias_kv=True,
22
+ bias_out=True
23
+ )
24
+ for d in range(num_layers)
25
+ ])
26
+
27
+ def forward(self, hidden_states, time_emb, text_emb, res_stack):
28
+ batch, _, height, width = hidden_states.shape
29
+ residual = hidden_states
30
+
31
+ hidden_states = self.norm(hidden_states)
32
+ inner_dim = hidden_states.shape[1]
33
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
34
+
35
+ for block in self.transformer_blocks:
36
+ hidden_states = block(hidden_states)
37
+
38
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
39
+ hidden_states = hidden_states + residual
40
+
41
+ return hidden_states, time_emb, text_emb, res_stack
42
+
43
+
44
+ class SDVAEDecoder(torch.nn.Module):
45
+ def __init__(self):
46
+ super().__init__()
47
+ self.scaling_factor = 0.18215
48
+ self.post_quant_conv = torch.nn.Conv2d(4, 4, kernel_size=1)
49
+ self.conv_in = torch.nn.Conv2d(4, 512, kernel_size=3, padding=1)
50
+
51
+ self.blocks = torch.nn.ModuleList([
52
+ # UNetMidBlock2D
53
+ ResnetBlock(512, 512, eps=1e-6),
54
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
55
+ ResnetBlock(512, 512, eps=1e-6),
56
+ # UpDecoderBlock2D
57
+ ResnetBlock(512, 512, eps=1e-6),
58
+ ResnetBlock(512, 512, eps=1e-6),
59
+ ResnetBlock(512, 512, eps=1e-6),
60
+ UpSampler(512),
61
+ # UpDecoderBlock2D
62
+ ResnetBlock(512, 512, eps=1e-6),
63
+ ResnetBlock(512, 512, eps=1e-6),
64
+ ResnetBlock(512, 512, eps=1e-6),
65
+ UpSampler(512),
66
+ # UpDecoderBlock2D
67
+ ResnetBlock(512, 256, eps=1e-6),
68
+ ResnetBlock(256, 256, eps=1e-6),
69
+ ResnetBlock(256, 256, eps=1e-6),
70
+ UpSampler(256),
71
+ # UpDecoderBlock2D
72
+ ResnetBlock(256, 128, eps=1e-6),
73
+ ResnetBlock(128, 128, eps=1e-6),
74
+ ResnetBlock(128, 128, eps=1e-6),
75
+ ])
76
+
77
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-5)
78
+ self.conv_act = torch.nn.SiLU()
79
+ self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
80
+
81
+ def tiled_forward(self, sample, tile_size=64, tile_stride=32):
82
+ hidden_states = TileWorker().tiled_forward(
83
+ lambda x: self.forward(x),
84
+ sample,
85
+ tile_size,
86
+ tile_stride,
87
+ tile_device=sample.device,
88
+ tile_dtype=sample.dtype
89
+ )
90
+ return hidden_states
91
+
92
+ def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
93
+ # For VAE Decoder, we do not need to apply the tiler on each layer.
94
+ if tiled:
95
+ return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
96
+
97
+ # 1. pre-process
98
+ sample = sample / self.scaling_factor
99
+ hidden_states = self.post_quant_conv(sample)
100
+ hidden_states = self.conv_in(hidden_states)
101
+ time_emb = None
102
+ text_emb = None
103
+ res_stack = None
104
+
105
+ # 2. blocks
106
+ for i, block in enumerate(self.blocks):
107
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
108
+
109
+ # 3. output
110
+ hidden_states = self.conv_norm_out(hidden_states)
111
+ hidden_states = self.conv_act(hidden_states)
112
+ hidden_states = self.conv_out(hidden_states)
113
+
114
+ return hidden_states
115
+
116
+ def state_dict_converter(self):
117
+ return SDVAEDecoderStateDictConverter()
118
+
119
+
120
+ class SDVAEDecoderStateDictConverter:
121
+ def __init__(self):
122
+ pass
123
+
124
+ def from_diffusers(self, state_dict):
125
+ # architecture
126
+ block_types = [
127
+ 'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock',
128
+ 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
129
+ 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
130
+ 'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
131
+ 'ResnetBlock', 'ResnetBlock', 'ResnetBlock'
132
+ ]
133
+
134
+ # Rename each parameter
135
+ local_rename_dict = {
136
+ "post_quant_conv": "post_quant_conv",
137
+ "decoder.conv_in": "conv_in",
138
+ "decoder.mid_block.attentions.0.group_norm": "blocks.1.norm",
139
+ "decoder.mid_block.attentions.0.to_q": "blocks.1.transformer_blocks.0.to_q",
140
+ "decoder.mid_block.attentions.0.to_k": "blocks.1.transformer_blocks.0.to_k",
141
+ "decoder.mid_block.attentions.0.to_v": "blocks.1.transformer_blocks.0.to_v",
142
+ "decoder.mid_block.attentions.0.to_out.0": "blocks.1.transformer_blocks.0.to_out",
143
+ "decoder.mid_block.resnets.0.norm1": "blocks.0.norm1",
144
+ "decoder.mid_block.resnets.0.conv1": "blocks.0.conv1",
145
+ "decoder.mid_block.resnets.0.norm2": "blocks.0.norm2",
146
+ "decoder.mid_block.resnets.0.conv2": "blocks.0.conv2",
147
+ "decoder.mid_block.resnets.1.norm1": "blocks.2.norm1",
148
+ "decoder.mid_block.resnets.1.conv1": "blocks.2.conv1",
149
+ "decoder.mid_block.resnets.1.norm2": "blocks.2.norm2",
150
+ "decoder.mid_block.resnets.1.conv2": "blocks.2.conv2",
151
+ "decoder.conv_norm_out": "conv_norm_out",
152
+ "decoder.conv_out": "conv_out",
153
+ }
154
+ name_list = sorted([name for name in state_dict])
155
+ rename_dict = {}
156
+ block_id = {"ResnetBlock": 2, "DownSampler": 2, "UpSampler": 2}
157
+ last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
158
+ for name in name_list:
159
+ names = name.split(".")
160
+ name_prefix = ".".join(names[:-1])
161
+ if name_prefix in local_rename_dict:
162
+ rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
163
+ elif name.startswith("decoder.up_blocks"):
164
+ block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
165
+ block_type_with_id = ".".join(names[:5])
166
+ if block_type_with_id != last_block_type_with_id[block_type]:
167
+ block_id[block_type] += 1
168
+ last_block_type_with_id[block_type] = block_type_with_id
169
+ while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
170
+ block_id[block_type] += 1
171
+ block_type_with_id = ".".join(names[:5])
172
+ names = ["blocks", str(block_id[block_type])] + names[5:]
173
+ rename_dict[name] = ".".join(names)
174
+
175
+ # Convert state_dict
176
+ state_dict_ = {}
177
+ for name, param in state_dict.items():
178
+ if name in rename_dict:
179
+ state_dict_[rename_dict[name]] = param
180
+ return state_dict_
181
+
182
+ def from_civitai(self, state_dict):
183
+ rename_dict = {
184
+ "first_stage_model.decoder.conv_in.bias": "conv_in.bias",
185
+ "first_stage_model.decoder.conv_in.weight": "conv_in.weight",
186
+ "first_stage_model.decoder.conv_out.bias": "conv_out.bias",
187
+ "first_stage_model.decoder.conv_out.weight": "conv_out.weight",
188
+ "first_stage_model.decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
189
+ "first_stage_model.decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
190
+ "first_stage_model.decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
191
+ "first_stage_model.decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
192
+ "first_stage_model.decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
193
+ "first_stage_model.decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
194
+ "first_stage_model.decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
195
+ "first_stage_model.decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
196
+ "first_stage_model.decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
197
+ "first_stage_model.decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
198
+ "first_stage_model.decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
199
+ "first_stage_model.decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
200
+ "first_stage_model.decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
201
+ "first_stage_model.decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
202
+ "first_stage_model.decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
203
+ "first_stage_model.decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
204
+ "first_stage_model.decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
205
+ "first_stage_model.decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
206
+ "first_stage_model.decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
207
+ "first_stage_model.decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
208
+ "first_stage_model.decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
209
+ "first_stage_model.decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
210
+ "first_stage_model.decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
211
+ "first_stage_model.decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
212
+ "first_stage_model.decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
213
+ "first_stage_model.decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
214
+ "first_stage_model.decoder.norm_out.bias": "conv_norm_out.bias",
215
+ "first_stage_model.decoder.norm_out.weight": "conv_norm_out.weight",
216
+ "first_stage_model.decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
217
+ "first_stage_model.decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
218
+ "first_stage_model.decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
219
+ "first_stage_model.decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
220
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
221
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
222
+ "first_stage_model.decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
223
+ "first_stage_model.decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
224
+ "first_stage_model.decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
225
+ "first_stage_model.decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
226
+ "first_stage_model.decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
227
+ "first_stage_model.decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
228
+ "first_stage_model.decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
229
+ "first_stage_model.decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
230
+ "first_stage_model.decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
231
+ "first_stage_model.decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
232
+ "first_stage_model.decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
233
+ "first_stage_model.decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
234
+ "first_stage_model.decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
235
+ "first_stage_model.decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
236
+ "first_stage_model.decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
237
+ "first_stage_model.decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
238
+ "first_stage_model.decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
239
+ "first_stage_model.decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
240
+ "first_stage_model.decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
241
+ "first_stage_model.decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
242
+ "first_stage_model.decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
243
+ "first_stage_model.decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
244
+ "first_stage_model.decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
245
+ "first_stage_model.decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
246
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
247
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
248
+ "first_stage_model.decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
249
+ "first_stage_model.decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
250
+ "first_stage_model.decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
251
+ "first_stage_model.decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
252
+ "first_stage_model.decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
253
+ "first_stage_model.decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
254
+ "first_stage_model.decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
255
+ "first_stage_model.decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
256
+ "first_stage_model.decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
257
+ "first_stage_model.decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
258
+ "first_stage_model.decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
259
+ "first_stage_model.decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
260
+ "first_stage_model.decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
261
+ "first_stage_model.decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
262
+ "first_stage_model.decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
263
+ "first_stage_model.decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
264
+ "first_stage_model.decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
265
+ "first_stage_model.decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
266
+ "first_stage_model.decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
267
+ "first_stage_model.decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
268
+ "first_stage_model.decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
269
+ "first_stage_model.decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
270
+ "first_stage_model.decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
271
+ "first_stage_model.decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
272
+ "first_stage_model.decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
273
+ "first_stage_model.decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
274
+ "first_stage_model.decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
275
+ "first_stage_model.decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
276
+ "first_stage_model.decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
277
+ "first_stage_model.decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
278
+ "first_stage_model.decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
279
+ "first_stage_model.decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
280
+ "first_stage_model.decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
281
+ "first_stage_model.decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
282
+ "first_stage_model.decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
283
+ "first_stage_model.decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
284
+ "first_stage_model.decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
285
+ "first_stage_model.decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
286
+ "first_stage_model.decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
287
+ "first_stage_model.decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
288
+ "first_stage_model.decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
289
+ "first_stage_model.decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
290
+ "first_stage_model.decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
291
+ "first_stage_model.decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
292
+ "first_stage_model.decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
293
+ "first_stage_model.decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
294
+ "first_stage_model.decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
295
+ "first_stage_model.decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
296
+ "first_stage_model.decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
297
+ "first_stage_model.decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
298
+ "first_stage_model.decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
299
+ "first_stage_model.decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
300
+ "first_stage_model.decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
301
+ "first_stage_model.decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
302
+ "first_stage_model.decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
303
+ "first_stage_model.decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
304
+ "first_stage_model.decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
305
+ "first_stage_model.decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
306
+ "first_stage_model.decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
307
+ "first_stage_model.decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
308
+ "first_stage_model.decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
309
+ "first_stage_model.decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
310
+ "first_stage_model.decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
311
+ "first_stage_model.decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
312
+ "first_stage_model.decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
313
+ "first_stage_model.decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
314
+ "first_stage_model.decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
315
+ "first_stage_model.decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
316
+ "first_stage_model.decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
317
+ "first_stage_model.decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
318
+ "first_stage_model.decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
319
+ "first_stage_model.decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
320
+ "first_stage_model.decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
321
+ "first_stage_model.decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
322
+ "first_stage_model.post_quant_conv.bias": "post_quant_conv.bias",
323
+ "first_stage_model.post_quant_conv.weight": "post_quant_conv.weight",
324
+ }
325
+ state_dict_ = {}
326
+ for name in state_dict:
327
+ if name in rename_dict:
328
+ param = state_dict[name]
329
+ if "transformer_blocks" in rename_dict[name]:
330
+ param = param.squeeze()
331
+ state_dict_[rename_dict[name]] = param
332
+ return state_dict_
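The decoder above maps 4-channel latents back to RGB at 8x spatial resolution (three UpSampler stages), and its forward accepts tiled=True to route large inputs through TileWorker. A minimal usage sketch with randomly initialized weights; loading real VAE weights through the converters is left as a comment because the checkpoint source is an assumption:

import torch
from diffsynth.models.sd_vae_decoder import SDVAEDecoder

decoder = SDVAEDecoder().eval()
# With real weights, a Diffusers-style VAE state dict would be remapped first, e.g.:
# decoder.load_state_dict(decoder.state_dict_converter().from_diffusers(vae_state_dict))
latents = torch.randn(1, 4, 64, 64)  # 64x64 latents decode to a 512x512 image
with torch.no_grad():
    image = decoder(latents)
print(image.shape)  # torch.Size([1, 3, 512, 512])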
diffsynth/models/sd_vae_encoder.py ADDED
@@ -0,0 +1,278 @@
1
+ import torch
2
+ from .sd_unet import ResnetBlock, DownSampler
3
+ from .sd_vae_decoder import VAEAttentionBlock
4
+ from .tiler import TileWorker
5
+ from einops import rearrange
6
+
7
+
8
+ class SDVAEEncoder(torch.nn.Module):
9
+ def __init__(self):
10
+ super().__init__()
11
+ self.scaling_factor = 0.18215
12
+ self.quant_conv = torch.nn.Conv2d(8, 8, kernel_size=1)
13
+ self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
14
+
15
+ self.blocks = torch.nn.ModuleList([
16
+ # DownEncoderBlock2D
17
+ ResnetBlock(128, 128, eps=1e-6),
18
+ ResnetBlock(128, 128, eps=1e-6),
19
+ DownSampler(128, padding=0, extra_padding=True),
20
+ # DownEncoderBlock2D
21
+ ResnetBlock(128, 256, eps=1e-6),
22
+ ResnetBlock(256, 256, eps=1e-6),
23
+ DownSampler(256, padding=0, extra_padding=True),
24
+ # DownEncoderBlock2D
25
+ ResnetBlock(256, 512, eps=1e-6),
26
+ ResnetBlock(512, 512, eps=1e-6),
27
+ DownSampler(512, padding=0, extra_padding=True),
28
+ # DownEncoderBlock2D
29
+ ResnetBlock(512, 512, eps=1e-6),
30
+ ResnetBlock(512, 512, eps=1e-6),
31
+ # UNetMidBlock2D
32
+ ResnetBlock(512, 512, eps=1e-6),
33
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
34
+ ResnetBlock(512, 512, eps=1e-6),
35
+ ])
36
+
37
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
38
+ self.conv_act = torch.nn.SiLU()
39
+ self.conv_out = torch.nn.Conv2d(512, 8, kernel_size=3, padding=1)
40
+
41
+ def tiled_forward(self, sample, tile_size=64, tile_stride=32):
42
+ hidden_states = TileWorker().tiled_forward(
43
+ lambda x: self.forward(x),
44
+ sample,
45
+ tile_size,
46
+ tile_stride,
47
+ tile_device=sample.device,
48
+ tile_dtype=sample.dtype
49
+ )
50
+ return hidden_states
51
+
52
+ def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
53
+ # For VAE Encoder, we do not need to apply the tiler on each layer.
54
+ if tiled:
55
+ return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
56
+
57
+ # 1. pre-process
58
+ hidden_states = self.conv_in(sample)
59
+ time_emb = None
60
+ text_emb = None
61
+ res_stack = None
62
+
63
+ # 2. blocks
64
+ for i, block in enumerate(self.blocks):
65
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
66
+
67
+ # 3. output
68
+ hidden_states = self.conv_norm_out(hidden_states)
69
+ hidden_states = self.conv_act(hidden_states)
70
+ hidden_states = self.conv_out(hidden_states)
71
+ hidden_states = self.quant_conv(hidden_states)
72
+ hidden_states = hidden_states[:, :4]
73
+ hidden_states *= self.scaling_factor
74
+
75
+ return hidden_states
76
+
77
+ def encode_video(self, sample, batch_size=8):
78
+ B = sample.shape[0]
79
+ hidden_states = []
80
+
81
+ for i in range(0, sample.shape[2], batch_size):
82
+
83
+ j = min(i + batch_size, sample.shape[2])
84
+ sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
85
+
86
+ hidden_states_batch = self(sample_batch)
87
+ hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
88
+
89
+ hidden_states.append(hidden_states_batch)
90
+
91
+ hidden_states = torch.concat(hidden_states, dim=2)
92
+ return hidden_states
93
+
94
+ def state_dict_converter(self):
95
+ return SDVAEEncoderStateDictConverter()
96
+
97
+
98
+ class SDVAEEncoderStateDictConverter:
99
+ def __init__(self):
100
+ pass
101
+
102
+ def from_diffusers(self, state_dict):
103
+ # architecture
104
+ block_types = [
105
+ 'ResnetBlock', 'ResnetBlock', 'DownSampler',
106
+ 'ResnetBlock', 'ResnetBlock', 'DownSampler',
107
+ 'ResnetBlock', 'ResnetBlock', 'DownSampler',
108
+ 'ResnetBlock', 'ResnetBlock',
109
+ 'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock'
110
+ ]
111
+
112
+ # Rename each parameter
113
+ local_rename_dict = {
114
+ "quant_conv": "quant_conv",
115
+ "encoder.conv_in": "conv_in",
116
+ "encoder.mid_block.attentions.0.group_norm": "blocks.12.norm",
117
+ "encoder.mid_block.attentions.0.to_q": "blocks.12.transformer_blocks.0.to_q",
118
+ "encoder.mid_block.attentions.0.to_k": "blocks.12.transformer_blocks.0.to_k",
119
+ "encoder.mid_block.attentions.0.to_v": "blocks.12.transformer_blocks.0.to_v",
120
+ "encoder.mid_block.attentions.0.to_out.0": "blocks.12.transformer_blocks.0.to_out",
121
+ "encoder.mid_block.resnets.0.norm1": "blocks.11.norm1",
122
+ "encoder.mid_block.resnets.0.conv1": "blocks.11.conv1",
123
+ "encoder.mid_block.resnets.0.norm2": "blocks.11.norm2",
124
+ "encoder.mid_block.resnets.0.conv2": "blocks.11.conv2",
125
+ "encoder.mid_block.resnets.1.norm1": "blocks.13.norm1",
126
+ "encoder.mid_block.resnets.1.conv1": "blocks.13.conv1",
127
+ "encoder.mid_block.resnets.1.norm2": "blocks.13.norm2",
128
+ "encoder.mid_block.resnets.1.conv2": "blocks.13.conv2",
129
+ "encoder.conv_norm_out": "conv_norm_out",
130
+ "encoder.conv_out": "conv_out",
131
+ }
132
+ name_list = sorted([name for name in state_dict])
133
+ rename_dict = {}
134
+ block_id = {"ResnetBlock": -1, "DownSampler": -1, "UpSampler": -1}
135
+ last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
136
+ for name in name_list:
137
+ names = name.split(".")
138
+ name_prefix = ".".join(names[:-1])
139
+ if name_prefix in local_rename_dict:
140
+ rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
141
+ elif name.startswith("encoder.down_blocks"):
142
+ block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
143
+ block_type_with_id = ".".join(names[:5])
144
+ if block_type_with_id != last_block_type_with_id[block_type]:
145
+ block_id[block_type] += 1
146
+ last_block_type_with_id[block_type] = block_type_with_id
147
+ while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
148
+ block_id[block_type] += 1
149
+ block_type_with_id = ".".join(names[:5])
150
+ names = ["blocks", str(block_id[block_type])] + names[5:]
151
+ rename_dict[name] = ".".join(names)
152
+
153
+ # Convert state_dict
154
+ state_dict_ = {}
155
+ for name, param in state_dict.items():
156
+ if name in rename_dict:
157
+ state_dict_[rename_dict[name]] = param
158
+ return state_dict_
159
+
160
+ def from_civitai(self, state_dict):
161
+ rename_dict = {
162
+ "first_stage_model.encoder.conv_in.bias": "conv_in.bias",
163
+ "first_stage_model.encoder.conv_in.weight": "conv_in.weight",
164
+ "first_stage_model.encoder.conv_out.bias": "conv_out.bias",
165
+ "first_stage_model.encoder.conv_out.weight": "conv_out.weight",
166
+ "first_stage_model.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
167
+ "first_stage_model.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
168
+ "first_stage_model.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
169
+ "first_stage_model.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
170
+ "first_stage_model.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
171
+ "first_stage_model.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
172
+ "first_stage_model.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
173
+ "first_stage_model.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
174
+ "first_stage_model.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
175
+ "first_stage_model.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
176
+ "first_stage_model.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
177
+ "first_stage_model.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
178
+ "first_stage_model.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
179
+ "first_stage_model.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
180
+ "first_stage_model.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
181
+ "first_stage_model.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
182
+ "first_stage_model.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
183
+ "first_stage_model.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
184
+ "first_stage_model.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
185
+ "first_stage_model.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
186
+ "first_stage_model.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
187
+ "first_stage_model.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
188
+ "first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
189
+ "first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
190
+ "first_stage_model.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
191
+ "first_stage_model.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
192
+ "first_stage_model.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
193
+ "first_stage_model.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
194
+ "first_stage_model.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
195
+ "first_stage_model.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
196
+ "first_stage_model.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
197
+ "first_stage_model.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
198
+ "first_stage_model.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
199
+ "first_stage_model.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
200
+ "first_stage_model.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
201
+ "first_stage_model.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
202
+ "first_stage_model.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
203
+ "first_stage_model.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
204
+ "first_stage_model.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
205
+ "first_stage_model.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
206
+ "first_stage_model.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
207
+ "first_stage_model.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
208
+ "first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
209
+ "first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
210
+ "first_stage_model.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
211
+ "first_stage_model.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
212
+ "first_stage_model.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
213
+ "first_stage_model.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
214
+ "first_stage_model.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
215
+ "first_stage_model.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
216
+ "first_stage_model.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
217
+ "first_stage_model.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
218
+ "first_stage_model.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
219
+ "first_stage_model.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
220
+ "first_stage_model.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
221
+ "first_stage_model.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
222
+ "first_stage_model.encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
223
+ "first_stage_model.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
224
+ "first_stage_model.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
225
+ "first_stage_model.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
226
+ "first_stage_model.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
227
+ "first_stage_model.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
228
+ "first_stage_model.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
229
+ "first_stage_model.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
230
+ "first_stage_model.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
231
+ "first_stage_model.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
232
+ "first_stage_model.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
233
+ "first_stage_model.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
234
+ "first_stage_model.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
235
+ "first_stage_model.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
236
+ "first_stage_model.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
237
+ "first_stage_model.encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
238
+ "first_stage_model.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
239
+ "first_stage_model.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
240
+ "first_stage_model.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
241
+ "first_stage_model.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
242
+ "first_stage_model.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
243
+ "first_stage_model.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
244
+ "first_stage_model.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
245
+ "first_stage_model.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
246
+ "first_stage_model.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
247
+ "first_stage_model.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
248
+ "first_stage_model.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
249
+ "first_stage_model.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
250
+ "first_stage_model.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
251
+ "first_stage_model.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
252
+ "first_stage_model.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
253
+ "first_stage_model.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
254
+ "first_stage_model.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
255
+ "first_stage_model.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
256
+ "first_stage_model.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
257
+ "first_stage_model.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
258
+ "first_stage_model.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
259
+ "first_stage_model.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
260
+ "first_stage_model.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
261
+ "first_stage_model.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
262
+ "first_stage_model.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
263
+ "first_stage_model.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
264
+ "first_stage_model.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
265
+ "first_stage_model.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
266
+ "first_stage_model.encoder.norm_out.bias": "conv_norm_out.bias",
267
+ "first_stage_model.encoder.norm_out.weight": "conv_norm_out.weight",
268
+ "first_stage_model.quant_conv.bias": "quant_conv.bias",
269
+ "first_stage_model.quant_conv.weight": "quant_conv.weight",
270
+ }
271
+ state_dict_ = {}
272
+ for name in state_dict:
273
+ if name in rename_dict:
274
+ param = state_dict[name]
275
+ if "transformer_blocks" in rename_dict[name]:
276
+ param = param.squeeze()
277
+ state_dict_[rename_dict[name]] = param
278
+ return state_dict_
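encode_video above chunks a video tensor along the time axis, encodes each chunk of frames as a batch of images, and concatenates the latents back into (B, C, T, H, W) form. A rough sketch with random weights; the clip length and resolution are illustrative assumptions:

import torch
from diffsynth.models.sd_vae_encoder import SDVAEEncoder

encoder = SDVAEEncoder().eval()
video = torch.randn(1, 3, 8, 256, 256)  # (batch, channels, frames, height, width)
with torch.no_grad():
    latents = encoder.encode_video(video, batch_size=4)  # two chunks of 4 frames
print(latents.shape)  # torch.Size([1, 4, 8, 32, 32])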
diffsynth/models/sdxl_ipadapter.py ADDED
@@ -0,0 +1,121 @@
1
+ from .svd_image_encoder import SVDImageEncoder
2
+ from transformers import CLIPImageProcessor
3
+ import torch
4
+
5
+
6
+ class IpAdapterXLCLIPImageEmbedder(SVDImageEncoder):
7
+ def __init__(self):
8
+ super().__init__(embed_dim=1664, encoder_intermediate_size=8192, projection_dim=1280, num_encoder_layers=48, num_heads=16, head_dim=104)
9
+ self.image_processor = CLIPImageProcessor()
10
+
11
+ def forward(self, image):
12
+ pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
13
+ pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
14
+ return super().forward(pixel_values)
15
+
16
+
17
+ class IpAdapterImageProjModel(torch.nn.Module):
18
+ def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4):
19
+ super().__init__()
20
+ self.cross_attention_dim = cross_attention_dim
21
+ self.clip_extra_context_tokens = clip_extra_context_tokens
22
+ self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
23
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
24
+
25
+ def forward(self, image_embeds):
26
+ clip_extra_context_tokens = self.proj(image_embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
27
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
28
+ return clip_extra_context_tokens
29
+
30
+
31
+ class IpAdapterModule(torch.nn.Module):
32
+ def __init__(self, input_dim, output_dim):
33
+ super().__init__()
34
+ self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
35
+ self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
36
+
37
+ def forward(self, hidden_states):
38
+ ip_k = self.to_k_ip(hidden_states)
39
+ ip_v = self.to_v_ip(hidden_states)
40
+ return ip_k, ip_v
41
+
42
+
43
+ class SDXLIpAdapter(torch.nn.Module):
44
+ def __init__(self):
45
+ super().__init__()
46
+ shape_list = [(2048, 640)] * 4 + [(2048, 1280)] * 50 + [(2048, 640)] * 6 + [(2048, 1280)] * 10
47
+ self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
48
+ self.image_proj = IpAdapterImageProjModel()
49
+ self.set_full_adapter()
50
+
51
+ def set_full_adapter(self):
52
+ map_list = sum([
53
+ [(7, i) for i in range(2)],
54
+ [(10, i) for i in range(2)],
55
+ [(15, i) for i in range(10)],
56
+ [(18, i) for i in range(10)],
57
+ [(25, i) for i in range(10)],
58
+ [(28, i) for i in range(10)],
59
+ [(31, i) for i in range(10)],
60
+ [(35, i) for i in range(2)],
61
+ [(38, i) for i in range(2)],
62
+ [(41, i) for i in range(2)],
63
+ [(21, i) for i in range(10)],
64
+ ], [])
65
+ self.call_block_id = {i: j for j, i in enumerate(map_list)}
66
+
67
+ def set_less_adapter(self):
68
+ map_list = sum([
69
+ [(7, i) for i in range(2)],
70
+ [(10, i) for i in range(2)],
71
+ [(15, i) for i in range(10)],
72
+ [(18, i) for i in range(10)],
73
+ [(25, i) for i in range(10)],
74
+ [(28, i) for i in range(10)],
75
+ [(31, i) for i in range(10)],
76
+ [(35, i) for i in range(2)],
77
+ [(38, i) for i in range(2)],
78
+ [(41, i) for i in range(2)],
79
+ [(21, i) for i in range(10)],
80
+ ], [])
81
+ self.call_block_id = {i: j for j, i in enumerate(map_list) if j>=34 and j<44}
82
+
83
+ def forward(self, hidden_states, scale=1.0):
84
+ hidden_states = self.image_proj(hidden_states)
85
+ hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
86
+ ip_kv_dict = {}
87
+ for (block_id, transformer_id) in self.call_block_id:
88
+ ipadapter_id = self.call_block_id[(block_id, transformer_id)]
89
+ ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
90
+ if block_id not in ip_kv_dict:
91
+ ip_kv_dict[block_id] = {}
92
+ ip_kv_dict[block_id][transformer_id] = {
93
+ "ip_k": ip_k,
94
+ "ip_v": ip_v,
95
+ "scale": scale
96
+ }
97
+ return ip_kv_dict
98
+
99
+ def state_dict_converter(self):
100
+ return SDXLIpAdapterStateDictConverter()
101
+
102
+
103
+ class SDXLIpAdapterStateDictConverter:
104
+ def __init__(self):
105
+ pass
106
+
107
+ def from_diffusers(self, state_dict):
108
+ state_dict_ = {}
109
+ for name in state_dict["ip_adapter"]:
110
+ names = name.split(".")
111
+ layer_id = str(int(names[0]) // 2)
112
+ name_ = ".".join(["ipadapter_modules"] + [layer_id] + names[1:])
113
+ state_dict_[name_] = state_dict["ip_adapter"][name]
114
+ for name in state_dict["image_proj"]:
115
+ name_ = "image_proj." + name
116
+ state_dict_[name_] = state_dict["image_proj"][name]
117
+ return state_dict_
118
+
119
+ def from_civitai(self, state_dict):
120
+ return self.from_diffusers(state_dict)
121
+
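SDXLIpAdapter.forward projects a pooled CLIP image embedding into 4 extra context tokens and returns, per UNet block, the IP key/value tensors plus a scale for each attention layer. A rough sketch with random weights; the 1280-dim embedding below stands in for IpAdapterXLCLIPImageEmbedder output and is an assumption:

import torch
from diffsynth.models.sdxl_ipadapter import SDXLIpAdapter

ip_adapter = SDXLIpAdapter().eval()
image_embeds = torch.randn(1, 1280)  # pooled image embedding (illustrative)
with torch.no_grad():
    ip_kv_dict = ip_adapter(image_embeds, scale=0.6)
# Keys are UNet block ids; each maps transformer ids to {"ip_k", "ip_v", "scale"}.
print(sorted(ip_kv_dict))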
diffsynth/models/sdxl_motion.py ADDED
@@ -0,0 +1,103 @@
1
+ from .sd_motion import TemporalBlock
2
+ import torch
3
+
4
+
5
+
6
+ class SDXLMotionModel(torch.nn.Module):
7
+ def __init__(self):
8
+ super().__init__()
9
+ self.motion_modules = torch.nn.ModuleList([
10
+ TemporalBlock(8, 320//8, 320, eps=1e-6),
11
+ TemporalBlock(8, 320//8, 320, eps=1e-6),
12
+
13
+ TemporalBlock(8, 640//8, 640, eps=1e-6),
14
+ TemporalBlock(8, 640//8, 640, eps=1e-6),
15
+
16
+ TemporalBlock(8, 1280//8, 1280, eps=1e-6),
17
+ TemporalBlock(8, 1280//8, 1280, eps=1e-6),
18
+
19
+ TemporalBlock(8, 1280//8, 1280, eps=1e-6),
20
+ TemporalBlock(8, 1280//8, 1280, eps=1e-6),
21
+ TemporalBlock(8, 1280//8, 1280, eps=1e-6),
22
+
23
+ TemporalBlock(8, 640//8, 640, eps=1e-6),
24
+ TemporalBlock(8, 640//8, 640, eps=1e-6),
25
+ TemporalBlock(8, 640//8, 640, eps=1e-6),
26
+
27
+ TemporalBlock(8, 320//8, 320, eps=1e-6),
28
+ TemporalBlock(8, 320//8, 320, eps=1e-6),
29
+ TemporalBlock(8, 320//8, 320, eps=1e-6),
30
+ ])
31
+ self.call_block_id = {
32
+ 0: 0,
33
+ 2: 1,
34
+ 7: 2,
35
+ 10: 3,
36
+ 15: 4,
37
+ 18: 5,
38
+ 25: 6,
39
+ 28: 7,
40
+ 31: 8,
41
+ 35: 9,
42
+ 38: 10,
43
+ 41: 11,
44
+ 44: 12,
45
+ 46: 13,
46
+ 48: 14,
47
+ }
48
+
49
+ def forward(self):
50
+ pass
51
+
52
+ def state_dict_converter(self):
53
+ return SDMotionModelStateDictConverter()
54
+
55
+
56
+ class SDMotionModelStateDictConverter:
57
+ def __init__(self):
58
+ pass
59
+
60
+ def from_diffusers(self, state_dict):
61
+ rename_dict = {
62
+ "norm": "norm",
63
+ "proj_in": "proj_in",
64
+ "transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
65
+ "transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
66
+ "transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
67
+ "transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
68
+ "transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
69
+ "transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
70
+ "transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
71
+ "transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
72
+ "transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
73
+ "transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
74
+ "transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
75
+ "transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
76
+ "transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
77
+ "transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
78
+ "transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
79
+ "proj_out": "proj_out",
80
+ }
81
+ name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
82
+ name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
83
+ name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
84
+ state_dict_ = {}
85
+ last_prefix, module_id = "", -1
86
+ for name in name_list:
87
+ names = name.split(".")
88
+ prefix_index = names.index("temporal_transformer") + 1
89
+ prefix = ".".join(names[:prefix_index])
90
+ if prefix != last_prefix:
91
+ last_prefix = prefix
92
+ module_id += 1
93
+ middle_name = ".".join(names[prefix_index:-1])
94
+ suffix = names[-1]
95
+ if "pos_encoder" in names:
96
+ rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
97
+ else:
98
+ rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
99
+ state_dict_[rename] = state_dict[name]
100
+ return state_dict_
101
+
102
+ def from_civitai(self, state_dict):
103
+ return self.from_diffusers(state_dict)
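The converter above walks AnimateDiff-style keys in down/mid/up order and starts a new motion_modules index whenever the temporal_transformer prefix changes. A toy illustration with a fabricated key name (not taken from a real checkpoint):

import torch
from diffsynth.models.sdxl_motion import SDMotionModelStateDictConverter

converter = SDMotionModelStateDictConverter()
fake_state_dict = {
    "down_blocks.0.motion_modules.0.temporal_transformer.proj_in.weight": torch.zeros(320, 320),
    "down_blocks.0.motion_modules.0.temporal_transformer.proj_in.bias": torch.zeros(320),
}
renamed = converter.from_diffusers(fake_state_dict)
print(sorted(renamed))  # ['motion_modules.0.proj_in.bias', 'motion_modules.0.proj_in.weight']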
diffsynth/models/sdxl_text_encoder.py ADDED
@@ -0,0 +1,757 @@
1
+ import torch
2
+ from .sd_text_encoder import CLIPEncoderLayer
3
+
4
+
5
+ class SDXLTextEncoder(torch.nn.Module):
6
+ def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=11, encoder_intermediate_size=3072):
7
+ super().__init__()
8
+
9
+ # token_embedding
10
+ self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
11
+
12
+ # position_embeds (This is a fixed tensor)
13
+ self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
14
+
15
+ # encoders
16
+ self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
17
+
18
+ # attn_mask
19
+ self.attn_mask = self.attention_mask(max_position_embeddings)
20
+
21
+ # The text encoder is different from the one in Stable Diffusion 1.x.
22
+ # It does not include final_layer_norm.
23
+
24
+ def attention_mask(self, length):
25
+ mask = torch.empty(length, length)
26
+ mask.fill_(float("-inf"))
27
+ mask.triu_(1)
28
+ return mask
29
+
30
+ def forward(self, input_ids, clip_skip=1):
31
+ embeds = self.token_embedding(input_ids) + self.position_embeds
32
+ attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
33
+ for encoder_id, encoder in enumerate(self.encoders):
34
+ embeds = encoder(embeds, attn_mask=attn_mask)
35
+ if encoder_id + clip_skip == len(self.encoders):
36
+ break
37
+ return embeds
38
+
39
+ def state_dict_converter(self):
40
+ return SDXLTextEncoderStateDictConverter()
41
+
42
+
43
+ class SDXLTextEncoder2(torch.nn.Module):
44
+ def __init__(self, embed_dim=1280, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=32, encoder_intermediate_size=5120):
45
+ super().__init__()
46
+
47
+ # token_embedding
48
+ self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
49
+
50
+ # position_embeds (This is a fixed tensor)
51
+ self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
52
+
53
+ # encoders
54
+ self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=20, head_dim=64, use_quick_gelu=False) for _ in range(num_encoder_layers)])
55
+
56
+ # attn_mask
57
+ self.attn_mask = self.attention_mask(max_position_embeddings)
58
+
59
+ # final_layer_norm
60
+ self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
61
+
62
+ # text_projection
63
+ self.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
64
+
65
+ def attention_mask(self, length):
66
+ mask = torch.empty(length, length)
67
+ mask.fill_(float("-inf"))
68
+ mask.triu_(1)
69
+ return mask
70
+
71
+ def forward(self, input_ids, clip_skip=2):
72
+ embeds = self.token_embedding(input_ids) + self.position_embeds
73
+ attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
74
+ for encoder_id, encoder in enumerate(self.encoders):
75
+ embeds = encoder(embeds, attn_mask=attn_mask)
76
+ if encoder_id + clip_skip == len(self.encoders):
77
+ hidden_states = embeds
78
+ embeds = self.final_layer_norm(embeds)
79
+ pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
80
+ pooled_embeds = self.text_projection(pooled_embeds)
81
+ return pooled_embeds, hidden_states
82
+
83
+ def state_dict_converter(self):
84
+ return SDXLTextEncoder2StateDictConverter()
85
+
86
+
87
+ class SDXLTextEncoderStateDictConverter:
88
+ def __init__(self):
89
+ pass
90
+
91
+ def from_diffusers(self, state_dict):
92
+ rename_dict = {
93
+ "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
94
+ "text_model.embeddings.position_embedding.weight": "position_embeds",
95
+ "text_model.final_layer_norm.weight": "final_layer_norm.weight",
96
+ "text_model.final_layer_norm.bias": "final_layer_norm.bias"
97
+ }
98
+ attn_rename_dict = {
99
+ "self_attn.q_proj": "attn.to_q",
100
+ "self_attn.k_proj": "attn.to_k",
101
+ "self_attn.v_proj": "attn.to_v",
102
+ "self_attn.out_proj": "attn.to_out",
103
+ "layer_norm1": "layer_norm1",
104
+ "layer_norm2": "layer_norm2",
105
+ "mlp.fc1": "fc1",
106
+ "mlp.fc2": "fc2",
107
+ }
108
+ state_dict_ = {}
109
+ for name in state_dict:
110
+ if name in rename_dict:
111
+ param = state_dict[name]
112
+ if name == "text_model.embeddings.position_embedding.weight":
113
+ param = param.reshape((1, param.shape[0], param.shape[1]))
114
+ state_dict_[rename_dict[name]] = param
115
+ elif name.startswith("text_model.encoder.layers."):
116
+ param = state_dict[name]
117
+ names = name.split(".")
118
+ layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
119
+ name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
120
+ state_dict_[name_] = param
121
+ return state_dict_
122
+
123
+ def from_civitai(self, state_dict):
124
+ rename_dict = {
125
+ "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight": "position_embeds",
126
+ "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
127
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
128
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
129
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
130
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
131
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
132
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
133
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
134
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
135
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
136
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
137
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
138
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
139
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
140
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
141
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
142
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
143
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
144
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
145
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
146
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
147
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
148
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
149
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
150
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
151
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
152
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
153
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
154
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
155
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
156
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
157
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
158
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
159
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
160
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
161
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
162
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
163
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
164
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
165
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
166
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
167
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
168
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
169
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
170
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
171
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
172
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
173
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
174
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
175
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
176
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
177
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
178
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
179
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
180
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
181
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
182
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
183
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
184
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
185
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
186
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
187
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
188
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
189
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
190
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
191
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
192
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
193
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
194
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
195
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
196
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
197
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
198
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
199
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
200
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
201
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
202
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
203
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
204
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
205
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
206
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
207
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
208
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
209
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
210
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
211
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
212
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
213
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
214
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
215
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
216
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
217
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
218
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
219
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
220
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
221
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
222
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
223
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
224
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
225
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
226
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
227
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
228
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
229
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
230
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
231
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
232
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
233
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
234
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
235
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
236
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
237
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
238
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
239
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
240
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
241
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
242
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
243
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
244
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
245
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
246
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
247
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
248
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
249
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
250
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
251
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
252
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
253
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
254
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
255
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
256
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
257
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
258
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
259
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
260
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
261
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
262
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
263
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
264
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
265
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
266
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
267
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
268
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
269
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
270
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
271
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
272
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
273
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
274
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
275
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
276
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
277
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
278
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
279
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
280
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
281
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
282
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
283
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
284
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
285
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
286
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
287
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
288
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
289
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
290
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
291
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
292
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
293
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
294
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
295
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
296
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
297
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
298
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
299
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
300
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
301
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
302
+ "conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
303
+ }
304
+ state_dict_ = {}
305
+ for name in state_dict:
306
+ if name in rename_dict:
307
+ param = state_dict[name]
308
+ if name == "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight":
309
+ param = param.reshape((1, param.shape[0], param.shape[1]))
310
+ state_dict_[rename_dict[name]] = param
311
+ return state_dict_
312
+
+
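+ # Converter for SDXL's second text encoder (the larger OpenCLIP-based model), which
+ # additionally carries the `text_projection` weight used for the pooled prompt embedding.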
+ class SDXLTextEncoder2StateDictConverter:
+     def __init__(self):
+         pass
+
+     def from_diffusers(self, state_dict):
+         rename_dict = {
+             "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
+             "text_model.embeddings.position_embedding.weight": "position_embeds",
+             "text_model.final_layer_norm.weight": "final_layer_norm.weight",
+             "text_model.final_layer_norm.bias": "final_layer_norm.bias",
+             "text_projection.weight": "text_projection.weight"
+         }
+         attn_rename_dict = {
+             "self_attn.q_proj": "attn.to_q",
+             "self_attn.k_proj": "attn.to_k",
+             "self_attn.v_proj": "attn.to_v",
+             "self_attn.out_proj": "attn.to_out",
+             "layer_norm1": "layer_norm1",
+             "layer_norm2": "layer_norm2",
+             "mlp.fc1": "fc1",
+             "mlp.fc2": "fc2",
+         }
+         state_dict_ = {}
+         for name in state_dict:
+             if name in rename_dict:
+                 param = state_dict[name]
+                 if name == "text_model.embeddings.position_embedding.weight":
+                     param = param.reshape((1, param.shape[0], param.shape[1]))
+                 state_dict_[rename_dict[name]] = param
+             elif name.startswith("text_model.encoder.layers."):
+                 param = state_dict[name]
+                 names = name.split(".")
+                 layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
+                 name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
+                 state_dict_[name_] = param
+         return state_dict_
+
+ def from_civitai(self, state_dict):
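+         # CivitAI / single-file checkpoints store this encoder in OpenCLIP layout
+         # ("conditioner.embedders.1.model.transformer.resblocks.{i}.*"). The fused
+         # attention projections ("attn.in_proj_weight" / "attn.in_proj_bias") map to a
+         # list of three target names (to_q, to_k, to_v); the fused tensor is split into
+         # three equal chunks along the first dimension when this table is applied.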
352
+ rename_dict = {
353
+ "conditioner.embedders.1.model.ln_final.bias": "final_layer_norm.bias",
354
+ "conditioner.embedders.1.model.ln_final.weight": "final_layer_norm.weight",
355
+ "conditioner.embedders.1.model.positional_embedding": "position_embeds",
356
+ "conditioner.embedders.1.model.token_embedding.weight": "token_embedding.weight",
357
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": ['encoders.0.attn.to_q.bias', 'encoders.0.attn.to_k.bias', 'encoders.0.attn.to_v.bias'],
358
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": ['encoders.0.attn.to_q.weight', 'encoders.0.attn.to_k.weight', 'encoders.0.attn.to_v.weight'],
359
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "encoders.0.attn.to_out.bias",
360
+ "conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "encoders.0.attn.to_out.weight",
361
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": "encoders.0.layer_norm1.bias",
362
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "encoders.0.layer_norm1.weight",
363
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": "encoders.0.layer_norm2.bias",
364
+ "conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "encoders.0.layer_norm2.weight",
365
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "encoders.0.fc1.bias",
366
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "encoders.0.fc1.weight",
367
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "encoders.0.fc2.bias",
368
+ "conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "encoders.0.fc2.weight",
369
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": ['encoders.1.attn.to_q.bias', 'encoders.1.attn.to_k.bias', 'encoders.1.attn.to_v.bias'],
370
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": ['encoders.1.attn.to_q.weight', 'encoders.1.attn.to_k.weight', 'encoders.1.attn.to_v.weight'],
371
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "encoders.1.attn.to_out.bias",
372
+ "conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "encoders.1.attn.to_out.weight",
373
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "encoders.1.layer_norm1.bias",
374
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "encoders.1.layer_norm1.weight",
375
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "encoders.1.layer_norm2.bias",
376
+ "conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "encoders.1.layer_norm2.weight",
377
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "encoders.1.fc1.bias",
378
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "encoders.1.fc1.weight",
379
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "encoders.1.fc2.bias",
380
+ "conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "encoders.1.fc2.weight",
381
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": ['encoders.10.attn.to_q.bias', 'encoders.10.attn.to_k.bias', 'encoders.10.attn.to_v.bias'],
382
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": ['encoders.10.attn.to_q.weight', 'encoders.10.attn.to_k.weight', 'encoders.10.attn.to_v.weight'],
383
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "encoders.10.attn.to_out.bias",
384
+ "conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": "encoders.10.attn.to_out.weight",
385
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "encoders.10.layer_norm1.bias",
386
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "encoders.10.layer_norm1.weight",
387
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": "encoders.10.layer_norm2.bias",
388
+ "conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "encoders.10.layer_norm2.weight",
389
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": "encoders.10.fc1.bias",
390
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "encoders.10.fc1.weight",
391
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "encoders.10.fc2.bias",
392
+ "conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "encoders.10.fc2.weight",
393
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": ['encoders.11.attn.to_q.bias', 'encoders.11.attn.to_k.bias', 'encoders.11.attn.to_v.bias'],
394
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": ['encoders.11.attn.to_q.weight', 'encoders.11.attn.to_k.weight', 'encoders.11.attn.to_v.weight'],
395
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": "encoders.11.attn.to_out.bias",
396
+ "conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "encoders.11.attn.to_out.weight",
397
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "encoders.11.layer_norm1.bias",
398
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "encoders.11.layer_norm1.weight",
399
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": "encoders.11.layer_norm2.bias",
400
+ "conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "encoders.11.layer_norm2.weight",
401
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "encoders.11.fc1.bias",
402
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "encoders.11.fc1.weight",
403
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "encoders.11.fc2.bias",
404
+ "conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "encoders.11.fc2.weight",
405
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": ['encoders.12.attn.to_q.bias', 'encoders.12.attn.to_k.bias', 'encoders.12.attn.to_v.bias'],
406
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": ['encoders.12.attn.to_q.weight', 'encoders.12.attn.to_k.weight', 'encoders.12.attn.to_v.weight'],
407
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "encoders.12.attn.to_out.bias",
408
+ "conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "encoders.12.attn.to_out.weight",
409
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "encoders.12.layer_norm1.bias",
410
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "encoders.12.layer_norm1.weight",
411
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "encoders.12.layer_norm2.bias",
412
+ "conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "encoders.12.layer_norm2.weight",
413
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "encoders.12.fc1.bias",
414
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "encoders.12.fc1.weight",
415
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "encoders.12.fc2.bias",
416
+ "conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "encoders.12.fc2.weight",
417
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": ['encoders.13.attn.to_q.bias', 'encoders.13.attn.to_k.bias', 'encoders.13.attn.to_v.bias'],
418
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": ['encoders.13.attn.to_q.weight', 'encoders.13.attn.to_k.weight', 'encoders.13.attn.to_v.weight'],
419
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "encoders.13.attn.to_out.bias",
420
+ "conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "encoders.13.attn.to_out.weight",
421
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "encoders.13.layer_norm1.bias",
422
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "encoders.13.layer_norm1.weight",
423
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "encoders.13.layer_norm2.bias",
424
+ "conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "encoders.13.layer_norm2.weight",
425
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "encoders.13.fc1.bias",
426
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": "encoders.13.fc1.weight",
427
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "encoders.13.fc2.bias",
428
+ "conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": "encoders.13.fc2.weight",
429
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": ['encoders.14.attn.to_q.bias', 'encoders.14.attn.to_k.bias', 'encoders.14.attn.to_v.bias'],
430
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": ['encoders.14.attn.to_q.weight', 'encoders.14.attn.to_k.weight', 'encoders.14.attn.to_v.weight'],
431
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "encoders.14.attn.to_out.bias",
432
+ "conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "encoders.14.attn.to_out.weight",
433
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "encoders.14.layer_norm1.bias",
434
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "encoders.14.layer_norm1.weight",
435
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "encoders.14.layer_norm2.bias",
436
+ "conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "encoders.14.layer_norm2.weight",
437
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "encoders.14.fc1.bias",
438
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": "encoders.14.fc1.weight",
439
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "encoders.14.fc2.bias",
440
+ "conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "encoders.14.fc2.weight",
441
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": ['encoders.15.attn.to_q.bias', 'encoders.15.attn.to_k.bias', 'encoders.15.attn.to_v.bias'],
442
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": ['encoders.15.attn.to_q.weight', 'encoders.15.attn.to_k.weight', 'encoders.15.attn.to_v.weight'],
443
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": "encoders.15.attn.to_out.bias",
444
+ "conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "encoders.15.attn.to_out.weight",
445
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "encoders.15.layer_norm1.bias",
446
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "encoders.15.layer_norm1.weight",
447
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "encoders.15.layer_norm2.bias",
448
+ "conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "encoders.15.layer_norm2.weight",
449
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "encoders.15.fc1.bias",
450
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "encoders.15.fc1.weight",
451
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "encoders.15.fc2.bias",
452
+ "conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "encoders.15.fc2.weight",
453
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": ['encoders.16.attn.to_q.bias', 'encoders.16.attn.to_k.bias', 'encoders.16.attn.to_v.bias'],
454
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": ['encoders.16.attn.to_q.weight', 'encoders.16.attn.to_k.weight', 'encoders.16.attn.to_v.weight'],
455
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "encoders.16.attn.to_out.bias",
456
+ "conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "encoders.16.attn.to_out.weight",
457
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "encoders.16.layer_norm1.bias",
458
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "encoders.16.layer_norm1.weight",
459
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "encoders.16.layer_norm2.bias",
460
+ "conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": "encoders.16.layer_norm2.weight",
461
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "encoders.16.fc1.bias",
462
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "encoders.16.fc1.weight",
463
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": "encoders.16.fc2.bias",
464
+ "conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": "encoders.16.fc2.weight",
465
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": ['encoders.17.attn.to_q.bias', 'encoders.17.attn.to_k.bias', 'encoders.17.attn.to_v.bias'],
466
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": ['encoders.17.attn.to_q.weight', 'encoders.17.attn.to_k.weight', 'encoders.17.attn.to_v.weight'],
467
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "encoders.17.attn.to_out.bias",
468
+ "conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "encoders.17.attn.to_out.weight",
469
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "encoders.17.layer_norm1.bias",
470
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "encoders.17.layer_norm1.weight",
471
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "encoders.17.layer_norm2.bias",
472
+ "conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "encoders.17.layer_norm2.weight",
473
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "encoders.17.fc1.bias",
474
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "encoders.17.fc1.weight",
475
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": "encoders.17.fc2.bias",
476
+ "conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "encoders.17.fc2.weight",
477
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": ['encoders.18.attn.to_q.bias', 'encoders.18.attn.to_k.bias', 'encoders.18.attn.to_v.bias'],
478
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": ['encoders.18.attn.to_q.weight', 'encoders.18.attn.to_k.weight', 'encoders.18.attn.to_v.weight'],
479
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "encoders.18.attn.to_out.bias",
480
+ "conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "encoders.18.attn.to_out.weight",
481
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "encoders.18.layer_norm1.bias",
482
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "encoders.18.layer_norm1.weight",
483
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "encoders.18.layer_norm2.bias",
484
+ "conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "encoders.18.layer_norm2.weight",
485
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "encoders.18.fc1.bias",
486
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "encoders.18.fc1.weight",
487
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "encoders.18.fc2.bias",
488
+ "conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "encoders.18.fc2.weight",
489
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": ['encoders.19.attn.to_q.bias', 'encoders.19.attn.to_k.bias', 'encoders.19.attn.to_v.bias'],
490
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": ['encoders.19.attn.to_q.weight', 'encoders.19.attn.to_k.weight', 'encoders.19.attn.to_v.weight'],
491
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "encoders.19.attn.to_out.bias",
492
+ "conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "encoders.19.attn.to_out.weight",
493
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "encoders.19.layer_norm1.bias",
494
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "encoders.19.layer_norm1.weight",
495
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": "encoders.19.layer_norm2.bias",
496
+ "conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "encoders.19.layer_norm2.weight",
497
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "encoders.19.fc1.bias",
498
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "encoders.19.fc1.weight",
499
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "encoders.19.fc2.bias",
500
+ "conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": "encoders.19.fc2.weight",
501
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": ['encoders.2.attn.to_q.bias', 'encoders.2.attn.to_k.bias', 'encoders.2.attn.to_v.bias'],
502
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": ['encoders.2.attn.to_q.weight', 'encoders.2.attn.to_k.weight', 'encoders.2.attn.to_v.weight'],
503
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "encoders.2.attn.to_out.bias",
504
+ "conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "encoders.2.attn.to_out.weight",
505
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "encoders.2.layer_norm1.bias",
506
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "encoders.2.layer_norm1.weight",
507
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "encoders.2.layer_norm2.bias",
508
+ "conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "encoders.2.layer_norm2.weight",
509
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "encoders.2.fc1.bias",
510
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "encoders.2.fc1.weight",
511
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "encoders.2.fc2.bias",
512
+ "conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": "encoders.2.fc2.weight",
513
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": ['encoders.20.attn.to_q.bias', 'encoders.20.attn.to_k.bias', 'encoders.20.attn.to_v.bias'],
514
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": ['encoders.20.attn.to_q.weight', 'encoders.20.attn.to_k.weight', 'encoders.20.attn.to_v.weight'],
515
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "encoders.20.attn.to_out.bias",
516
+ "conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": "encoders.20.attn.to_out.weight",
517
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "encoders.20.layer_norm1.bias",
518
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "encoders.20.layer_norm1.weight",
519
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "encoders.20.layer_norm2.bias",
520
+ "conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "encoders.20.layer_norm2.weight",
521
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "encoders.20.fc1.bias",
522
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "encoders.20.fc1.weight",
523
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "encoders.20.fc2.bias",
524
+ "conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "encoders.20.fc2.weight",
525
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": ['encoders.21.attn.to_q.bias', 'encoders.21.attn.to_k.bias', 'encoders.21.attn.to_v.bias'],
526
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": ['encoders.21.attn.to_q.weight', 'encoders.21.attn.to_k.weight', 'encoders.21.attn.to_v.weight'],
527
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": "encoders.21.attn.to_out.bias",
528
+ "conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "encoders.21.attn.to_out.weight",
529
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "encoders.21.layer_norm1.bias",
530
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "encoders.21.layer_norm1.weight",
531
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": "encoders.21.layer_norm2.bias",
532
+ "conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "encoders.21.layer_norm2.weight",
533
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "encoders.21.fc1.bias",
534
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "encoders.21.fc1.weight",
535
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "encoders.21.fc2.bias",
536
+ "conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "encoders.21.fc2.weight",
537
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": ['encoders.22.attn.to_q.bias', 'encoders.22.attn.to_k.bias', 'encoders.22.attn.to_v.bias'],
538
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": ['encoders.22.attn.to_q.weight', 'encoders.22.attn.to_k.weight', 'encoders.22.attn.to_v.weight'],
539
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": "encoders.22.attn.to_out.bias",
540
+ "conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "encoders.22.attn.to_out.weight",
541
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "encoders.22.layer_norm1.bias",
542
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "encoders.22.layer_norm1.weight",
543
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "encoders.22.layer_norm2.bias",
544
+ "conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "encoders.22.layer_norm2.weight",
545
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "encoders.22.fc1.bias",
546
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "encoders.22.fc1.weight",
547
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "encoders.22.fc2.bias",
548
+ "conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "encoders.22.fc2.weight",
549
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": ['encoders.23.attn.to_q.bias', 'encoders.23.attn.to_k.bias', 'encoders.23.attn.to_v.bias'],
550
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": ['encoders.23.attn.to_q.weight', 'encoders.23.attn.to_k.weight', 'encoders.23.attn.to_v.weight'],
551
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": "encoders.23.attn.to_out.bias",
552
+ "conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "encoders.23.attn.to_out.weight",
553
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "encoders.23.layer_norm1.bias",
554
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "encoders.23.layer_norm1.weight",
555
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "encoders.23.layer_norm2.bias",
556
+ "conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "encoders.23.layer_norm2.weight",
557
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "encoders.23.fc1.bias",
558
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "encoders.23.fc1.weight",
559
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "encoders.23.fc2.bias",
560
+ "conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": "encoders.23.fc2.weight",
561
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": ['encoders.24.attn.to_q.bias', 'encoders.24.attn.to_k.bias', 'encoders.24.attn.to_v.bias'],
562
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": ['encoders.24.attn.to_q.weight', 'encoders.24.attn.to_k.weight', 'encoders.24.attn.to_v.weight'],
563
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "encoders.24.attn.to_out.bias",
564
+ "conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "encoders.24.attn.to_out.weight",
565
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "encoders.24.layer_norm1.bias",
566
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "encoders.24.layer_norm1.weight",
567
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "encoders.24.layer_norm2.bias",
568
+ "conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "encoders.24.layer_norm2.weight",
569
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "encoders.24.fc1.bias",
570
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": "encoders.24.fc1.weight",
571
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "encoders.24.fc2.bias",
572
+ "conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "encoders.24.fc2.weight",
573
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": ['encoders.25.attn.to_q.bias', 'encoders.25.attn.to_k.bias', 'encoders.25.attn.to_v.bias'],
574
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": ['encoders.25.attn.to_q.weight', 'encoders.25.attn.to_k.weight', 'encoders.25.attn.to_v.weight'],
575
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "encoders.25.attn.to_out.bias",
576
+ "conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "encoders.25.attn.to_out.weight",
577
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "encoders.25.layer_norm1.bias",
578
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "encoders.25.layer_norm1.weight",
579
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "encoders.25.layer_norm2.bias",
580
+ "conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "encoders.25.layer_norm2.weight",
581
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "encoders.25.fc1.bias",
582
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "encoders.25.fc1.weight",
583
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "encoders.25.fc2.bias",
584
+ "conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "encoders.25.fc2.weight",
585
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": ['encoders.26.attn.to_q.bias', 'encoders.26.attn.to_k.bias', 'encoders.26.attn.to_v.bias'],
586
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": ['encoders.26.attn.to_q.weight', 'encoders.26.attn.to_k.weight', 'encoders.26.attn.to_v.weight'],
587
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "encoders.26.attn.to_out.bias",
588
+ "conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "encoders.26.attn.to_out.weight",
589
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "encoders.26.layer_norm1.bias",
590
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "encoders.26.layer_norm1.weight",
591
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "encoders.26.layer_norm2.bias",
592
+ "conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": "encoders.26.layer_norm2.weight",
593
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "encoders.26.fc1.bias",
594
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "encoders.26.fc1.weight",
595
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "encoders.26.fc2.bias",
596
+ "conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": "encoders.26.fc2.weight",
597
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": ['encoders.27.attn.to_q.bias', 'encoders.27.attn.to_k.bias', 'encoders.27.attn.to_v.bias'],
598
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": ['encoders.27.attn.to_q.weight', 'encoders.27.attn.to_k.weight', 'encoders.27.attn.to_v.weight'],
599
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "encoders.27.attn.to_out.bias",
600
+ "conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": "encoders.27.attn.to_out.weight",
601
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "encoders.27.layer_norm1.bias",
602
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "encoders.27.layer_norm1.weight",
603
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "encoders.27.layer_norm2.bias",
604
+ "conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "encoders.27.layer_norm2.weight",
605
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "encoders.27.fc1.bias",
606
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "encoders.27.fc1.weight",
607
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": "encoders.27.fc2.bias",
608
+ "conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "encoders.27.fc2.weight",
609
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": ['encoders.28.attn.to_q.bias', 'encoders.28.attn.to_k.bias', 'encoders.28.attn.to_v.bias'],
610
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": ['encoders.28.attn.to_q.weight', 'encoders.28.attn.to_k.weight', 'encoders.28.attn.to_v.weight'],
611
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "encoders.28.attn.to_out.bias",
612
+ "conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "encoders.28.attn.to_out.weight",
613
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "encoders.28.layer_norm1.bias",
614
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "encoders.28.layer_norm1.weight",
615
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "encoders.28.layer_norm2.bias",
616
+ "conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "encoders.28.layer_norm2.weight",
617
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "encoders.28.fc1.bias",
618
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "encoders.28.fc1.weight",
619
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "encoders.28.fc2.bias",
620
+ "conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "encoders.28.fc2.weight",
621
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": ['encoders.29.attn.to_q.bias', 'encoders.29.attn.to_k.bias', 'encoders.29.attn.to_v.bias'],
622
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": ['encoders.29.attn.to_q.weight', 'encoders.29.attn.to_k.weight', 'encoders.29.attn.to_v.weight'],
623
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "encoders.29.attn.to_out.bias",
624
+ "conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "encoders.29.attn.to_out.weight",
625
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "encoders.29.layer_norm1.bias",
626
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "encoders.29.layer_norm1.weight",
627
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": "encoders.29.layer_norm2.bias",
628
+ "conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "encoders.29.layer_norm2.weight",
629
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "encoders.29.fc1.bias",
630
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "encoders.29.fc1.weight",
631
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "encoders.29.fc2.bias",
632
+ "conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "encoders.29.fc2.weight",
633
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": ['encoders.3.attn.to_q.bias', 'encoders.3.attn.to_k.bias', 'encoders.3.attn.to_v.bias'],
634
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": ['encoders.3.attn.to_q.weight', 'encoders.3.attn.to_k.weight', 'encoders.3.attn.to_v.weight'],
635
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "encoders.3.attn.to_out.bias",
636
+ "conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "encoders.3.attn.to_out.weight",
637
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "encoders.3.layer_norm1.bias",
638
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "encoders.3.layer_norm1.weight",
639
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "encoders.3.layer_norm2.bias",
640
+ "conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "encoders.3.layer_norm2.weight",
641
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": "encoders.3.fc1.bias",
642
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "encoders.3.fc1.weight",
643
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "encoders.3.fc2.bias",
644
+ "conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": "encoders.3.fc2.weight",
645
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": ['encoders.30.attn.to_q.bias', 'encoders.30.attn.to_k.bias', 'encoders.30.attn.to_v.bias'],
646
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": ['encoders.30.attn.to_q.weight', 'encoders.30.attn.to_k.weight', 'encoders.30.attn.to_v.weight'],
647
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "encoders.30.attn.to_out.bias",
648
+ "conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": "encoders.30.attn.to_out.weight",
649
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "encoders.30.layer_norm1.bias",
650
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "encoders.30.layer_norm1.weight",
651
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "encoders.30.layer_norm2.bias",
652
+ "conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "encoders.30.layer_norm2.weight",
653
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "encoders.30.fc1.bias",
654
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "encoders.30.fc1.weight",
655
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "encoders.30.fc2.bias",
656
+ "conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "encoders.30.fc2.weight",
657
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": ['encoders.31.attn.to_q.bias', 'encoders.31.attn.to_k.bias', 'encoders.31.attn.to_v.bias'],
658
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": ['encoders.31.attn.to_q.weight', 'encoders.31.attn.to_k.weight', 'encoders.31.attn.to_v.weight'],
659
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": "encoders.31.attn.to_out.bias",
660
+ "conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "encoders.31.attn.to_out.weight",
661
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "encoders.31.layer_norm1.bias",
662
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "encoders.31.layer_norm1.weight",
663
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": "encoders.31.layer_norm2.bias",
664
+ "conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "encoders.31.layer_norm2.weight",
665
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "encoders.31.fc1.bias",
666
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "encoders.31.fc1.weight",
667
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "encoders.31.fc2.bias",
668
+ "conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "encoders.31.fc2.weight",
669
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": ['encoders.4.attn.to_q.bias', 'encoders.4.attn.to_k.bias', 'encoders.4.attn.to_v.bias'],
670
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": ['encoders.4.attn.to_q.weight', 'encoders.4.attn.to_k.weight', 'encoders.4.attn.to_v.weight'],
671
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "encoders.4.attn.to_out.bias",
672
+ "conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "encoders.4.attn.to_out.weight",
673
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "encoders.4.layer_norm1.bias",
674
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "encoders.4.layer_norm1.weight",
675
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "encoders.4.layer_norm2.bias",
676
+ "conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "encoders.4.layer_norm2.weight",
677
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "encoders.4.fc1.bias",
678
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": "encoders.4.fc1.weight",
679
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "encoders.4.fc2.bias",
680
+ "conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "encoders.4.fc2.weight",
681
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": ['encoders.5.attn.to_q.bias', 'encoders.5.attn.to_k.bias', 'encoders.5.attn.to_v.bias'],
682
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": ['encoders.5.attn.to_q.weight', 'encoders.5.attn.to_k.weight', 'encoders.5.attn.to_v.weight'],
683
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": "encoders.5.attn.to_out.bias",
684
+ "conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "encoders.5.attn.to_out.weight",
685
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": "encoders.5.layer_norm1.bias",
686
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "encoders.5.layer_norm1.weight",
687
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "encoders.5.layer_norm2.bias",
688
+ "conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "encoders.5.layer_norm2.weight",
689
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "encoders.5.fc1.bias",
690
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "encoders.5.fc1.weight",
691
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "encoders.5.fc2.bias",
692
+ "conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "encoders.5.fc2.weight",
693
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": ['encoders.6.attn.to_q.bias', 'encoders.6.attn.to_k.bias', 'encoders.6.attn.to_v.bias'],
694
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": ['encoders.6.attn.to_q.weight', 'encoders.6.attn.to_k.weight', 'encoders.6.attn.to_v.weight'],
695
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "encoders.6.attn.to_out.bias",
696
+ "conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "encoders.6.attn.to_out.weight",
697
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "encoders.6.layer_norm1.bias",
698
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "encoders.6.layer_norm1.weight",
699
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "encoders.6.layer_norm2.bias",
700
+ "conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "encoders.6.layer_norm2.weight",
701
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "encoders.6.fc1.bias",
702
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "encoders.6.fc1.weight",
703
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "encoders.6.fc2.bias",
704
+ "conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "encoders.6.fc2.weight",
705
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": ['encoders.7.attn.to_q.bias', 'encoders.7.attn.to_k.bias', 'encoders.7.attn.to_v.bias'],
706
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": ['encoders.7.attn.to_q.weight', 'encoders.7.attn.to_k.weight', 'encoders.7.attn.to_v.weight'],
707
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "encoders.7.attn.to_out.bias",
708
+ "conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "encoders.7.attn.to_out.weight",
709
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "encoders.7.layer_norm1.bias",
710
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "encoders.7.layer_norm1.weight",
711
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "encoders.7.layer_norm2.bias",
712
+ "conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "encoders.7.layer_norm2.weight",
713
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "encoders.7.fc1.bias",
714
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "encoders.7.fc1.weight",
715
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": "encoders.7.fc2.bias",
716
+ "conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "encoders.7.fc2.weight",
717
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": ['encoders.8.attn.to_q.bias', 'encoders.8.attn.to_k.bias', 'encoders.8.attn.to_v.bias'],
718
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": ['encoders.8.attn.to_q.weight', 'encoders.8.attn.to_k.weight', 'encoders.8.attn.to_v.weight'],
719
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "encoders.8.attn.to_out.bias",
720
+ "conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": "encoders.8.attn.to_out.weight",
721
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "encoders.8.layer_norm1.bias",
722
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": "encoders.8.layer_norm1.weight",
723
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "encoders.8.layer_norm2.bias",
724
+ "conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "encoders.8.layer_norm2.weight",
725
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "encoders.8.fc1.bias",
726
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "encoders.8.fc1.weight",
727
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "encoders.8.fc2.bias",
728
+ "conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "encoders.8.fc2.weight",
729
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": ['encoders.9.attn.to_q.bias', 'encoders.9.attn.to_k.bias', 'encoders.9.attn.to_v.bias'],
730
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": ['encoders.9.attn.to_q.weight', 'encoders.9.attn.to_k.weight', 'encoders.9.attn.to_v.weight'],
731
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "encoders.9.attn.to_out.bias",
732
+ "conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "encoders.9.attn.to_out.weight",
733
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "encoders.9.layer_norm1.bias",
734
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "encoders.9.layer_norm1.weight",
735
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "encoders.9.layer_norm2.bias",
736
+ "conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "encoders.9.layer_norm2.weight",
737
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "encoders.9.fc1.bias",
738
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "encoders.9.fc1.weight",
739
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "encoders.9.fc2.bias",
740
+ "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "encoders.9.fc2.weight",
741
+ "conditioner.embedders.1.model.text_projection": "text_projection.weight",
742
+ }
743
+ state_dict_ = {}
744
+ for name in state_dict:
745
+ if name in rename_dict:
746
+ param = state_dict[name]
747
+ if name == "conditioner.embedders.1.model.positional_embedding":
748
+ param = param.reshape((1, param.shape[0], param.shape[1]))
749
+ elif name == "conditioner.embedders.1.model.text_projection":
750
+ param = param.T
751
+ if isinstance(rename_dict[name], str):
752
+ state_dict_[rename_dict[name]] = param
753
+ else:
754
+ length = param.shape[0] // 3
755
+ for i, rename in enumerate(rename_dict[name]):
756
+ state_dict_[rename] = param[i*length: i*length+length]
757
+ return state_dict_
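The list-valued entries above map OpenCLIP's fused in_proj tensors onto separate to_q / to_k / to_v parameters; the loop at the end of from_civitai slices each fused tensor into three equal chunks along dim 0 (and also reshapes the positional embedding and transposes text_projection). Below is a stand-alone sketch of that split, not part of the repository; the 1280 width is illustrative, since the converter derives the chunk length from param.shape[0] // 3 and therefore works for any width.

import torch

# Check that the slicing used in the converter is equivalent to torch.chunk.
fused = torch.randn(3 * 1280, 1280)                      # fused q/k/v projection weight (illustrative size)
length = fused.shape[0] // 3                             # same rule as in the loop above
chunks = [fused[i * length: i * length + length] for i in range(3)]
q, k, v = torch.chunk(fused, 3, dim=0)
assert all(torch.equal(a, b) for a, b in zip(chunks, (q, k, v)))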
diffsynth/models/sdxl_unet.py ADDED
The diff for this file is too large to render. See raw diff
 
diffsynth/models/sdxl_vae_decoder.py ADDED
@@ -0,0 +1,15 @@
1
+ from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter
2
+
3
+
4
+ class SDXLVAEDecoder(SDVAEDecoder):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.scaling_factor = 0.13025
8
+
9
+ def state_dict_converter(self):
10
+ return SDXLVAEDecoderStateDictConverter()
11
+
12
+
13
+ class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
14
+ def __init__(self):
15
+ super().__init__()
diffsynth/models/sdxl_vae_encoder.py ADDED
@@ -0,0 +1,15 @@
1
+ from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
2
+
3
+
4
+ class SDXLVAEEncoder(SDVAEEncoder):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.scaling_factor = 0.13025
8
+
9
+ def state_dict_converter(self):
10
+ return SDXLVAEEncoderStateDictConverter()
11
+
12
+
13
+ class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
14
+ def __init__(self):
15
+ super().__init__()
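The two SDXL VAE wrappers above differ from their SD counterparts only in scaling_factor: 0.13025 for the SDXL VAE versus 0.18215 for the SD 1.x VAE. The snippet below is only an illustration of where such a factor conventionally enters the latent round trip (scale after encoding, unscale before decoding); how the base SDVAEEncoder / SDVAEDecoder classes apply it is defined in sd_vae_encoder.py / sd_vae_decoder.py, which are not shown at this point in the diff.

import torch

# Illustration only: conventional use of a VAE scaling factor around the latent space.
SDXL_SCALING_FACTOR = 0.13025                            # value set by the classes above
raw_latents = torch.randn(1, 4, 128, 128)                # stand-in for a VAE encoder output
diffusion_latents = raw_latents * SDXL_SCALING_FACTOR    # space the denoiser operates in
decoder_input = diffusion_latents / SDXL_SCALING_FACTOR  # undo the scaling before decoding
assert torch.allclose(decoder_input, raw_latents)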
diffsynth/models/svd_image_encoder.py ADDED
@@ -0,0 +1,504 @@
1
+ import torch
2
+ from .sd_text_encoder import CLIPEncoderLayer
3
+
4
+
5
+ class CLIPVisionEmbeddings(torch.nn.Module):
6
+ def __init__(self, embed_dim=1280, image_size=224, patch_size=14, num_channels=3):
7
+ super().__init__()
8
+
9
+ # class_embedding (a learned tensor, independent of the input)
10
+ self.class_embedding = torch.nn.Parameter(torch.randn(1, 1, embed_dim))
11
+
12
+ # patch_embeds (computed from the input image by a strided convolution)
13
+ self.patch_embedding = torch.nn.Conv2d(in_channels=num_channels, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, bias=False)
14
+
15
+ # position_embeds (a learned tensor, independent of the input)
16
+ self.position_embeds = torch.nn.Parameter(torch.zeros(1, (image_size // patch_size) ** 2 + 1, embed_dim))
17
+
18
+ def forward(self, pixel_values):
19
+ batch_size = pixel_values.shape[0]
20
+ patch_embeds = self.patch_embedding(pixel_values)
21
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
22
+ class_embeds = self.class_embedding.repeat(batch_size, 1, 1)
23
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + self.position_embeds
24
+ return embeddings
25
+
26
+
27
+ class SVDImageEncoder(torch.nn.Module):
28
+ def __init__(self, embed_dim=1280, layer_norm_eps=1e-5, num_encoder_layers=32, encoder_intermediate_size=5120, projection_dim=1024, num_heads=16, head_dim=80):
29
+ super().__init__()
30
+ self.embeddings = CLIPVisionEmbeddings(embed_dim=embed_dim)
31
+ self.pre_layernorm = torch.nn.LayerNorm(embed_dim, eps=layer_norm_eps)
32
+ self.encoders = torch.nn.ModuleList([
33
+ CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=num_heads, head_dim=head_dim, use_quick_gelu=False)
34
+ for _ in range(num_encoder_layers)])
35
+ self.post_layernorm = torch.nn.LayerNorm(embed_dim, eps=layer_norm_eps)
36
+ self.visual_projection = torch.nn.Linear(embed_dim, projection_dim, bias=False)
37
+
38
+ def forward(self, pixel_values):
39
+ embeds = self.embeddings(pixel_values)
40
+ embeds = self.pre_layernorm(embeds)
41
+ for encoder_id, encoder in enumerate(self.encoders):
42
+ embeds = encoder(embeds)
43
+ embeds = self.post_layernorm(embeds[:, 0, :])
44
+ embeds = self.visual_projection(embeds)
45
+ return embeds
46
+
47
+ def state_dict_converter(self):
48
+ return SVDImageEncoderStateDictConverter()
49
+
50
+
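A minimal usage sketch for the encoder defined above, which follows the ViT-H/14-style CLIP vision recipe (224x224 input, 14x14 patches, 32 blocks of width 1280, class token projected to 1024 dims). The import path assumes this repository's package layout; in practice the pipelines load pretrained weights through the state-dict converter below rather than running with random initialization.

import torch
from diffsynth.models.svd_image_encoder import SVDImageEncoder

encoder = SVDImageEncoder()                 # randomly initialised here; weights normally come from a checkpoint
pixel_values = torch.randn(1, 3, 224, 224)  # one CLIP-preprocessed conditioning frame
with torch.no_grad():
    image_embeds = encoder(pixel_values)
print(image_embeds.shape)                   # torch.Size([1, 1024]): projected class-token embedding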
51
+ class SVDImageEncoderStateDictConverter:
52
+ def __init__(self):
53
+ pass
54
+
55
+ def from_diffusers(self, state_dict):
56
+ rename_dict = {
57
+ "vision_model.embeddings.patch_embedding.weight": "embeddings.patch_embedding.weight",
58
+ "vision_model.embeddings.class_embedding": "embeddings.class_embedding",
59
+ "vision_model.embeddings.position_embedding.weight": "embeddings.position_embeds",
60
+ "vision_model.pre_layrnorm.weight": "pre_layernorm.weight",
61
+ "vision_model.pre_layrnorm.bias": "pre_layernorm.bias",
62
+ "vision_model.post_layernorm.weight": "post_layernorm.weight",
63
+ "vision_model.post_layernorm.bias": "post_layernorm.bias",
64
+ "visual_projection.weight": "visual_projection.weight"
65
+ }
66
+ attn_rename_dict = {
67
+ "self_attn.q_proj": "attn.to_q",
68
+ "self_attn.k_proj": "attn.to_k",
69
+ "self_attn.v_proj": "attn.to_v",
70
+ "self_attn.out_proj": "attn.to_out",
71
+ "layer_norm1": "layer_norm1",
72
+ "layer_norm2": "layer_norm2",
73
+ "mlp.fc1": "fc1",
74
+ "mlp.fc2": "fc2",
75
+ }
76
+ state_dict_ = {}
77
+ for name in state_dict:
78
+ if name in rename_dict:
79
+ param = state_dict[name]
80
+ if name == "vision_model.embeddings.class_embedding":
81
+ param = state_dict[name].view(1, 1, -1)
82
+ elif name == "vision_model.embeddings.position_embedding.weight":
83
+ param = state_dict[name].unsqueeze(0)
84
+ state_dict_[rename_dict[name]] = param
85
+ elif name.startswith("vision_model.encoder.layers."):
86
+ param = state_dict[name]
87
+ names = name.split(".")
88
+ layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
89
+ name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
90
+ state_dict_[name_] = param
91
+ return state_dict_
92
+
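For the transformer blocks, from_diffusers above does not enumerate every key; it rebuilds the target name from the layer index plus the attention/MLP lookup table. A stand-alone replay of that parsing for one hypothetical Transformers-style key:

# Replays the name parsing from from_diffusers for a single hypothetical key.
attn_rename_dict = {"self_attn.q_proj": "attn.to_q"}      # subset of the table above
name = "vision_model.encoder.layers.3.self_attn.q_proj.weight"
names = name.split(".")
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
print(".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]))
# -> encoders.3.attn.to_q.weight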
93
+ def from_civitai(self, state_dict):
94
+ rename_dict = {
95
+ "conditioner.embedders.0.open_clip.model.visual.class_embedding": "embeddings.class_embedding",
96
+ "conditioner.embedders.0.open_clip.model.visual.conv1.weight": "embeddings.patch_embedding.weight",
97
+ "conditioner.embedders.0.open_clip.model.visual.ln_post.bias": "post_layernorm.bias",
98
+ "conditioner.embedders.0.open_clip.model.visual.ln_post.weight": "post_layernorm.weight",
99
+ "conditioner.embedders.0.open_clip.model.visual.ln_pre.bias": "pre_layernorm.bias",
100
+ "conditioner.embedders.0.open_clip.model.visual.ln_pre.weight": "pre_layernorm.weight",
101
+ "conditioner.embedders.0.open_clip.model.visual.positional_embedding": "embeddings.position_embeds",
102
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.in_proj_bias": ['encoders.0.attn.to_q.bias', 'encoders.0.attn.to_k.bias', 'encoders.0.attn.to_v.bias'],
103
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.in_proj_weight": ['encoders.0.attn.to_q.weight', 'encoders.0.attn.to_k.weight', 'encoders.0.attn.to_v.weight'],
104
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.out_proj.bias": "encoders.0.attn.to_out.bias",
105
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.out_proj.weight": "encoders.0.attn.to_out.weight",
106
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_1.bias": "encoders.0.layer_norm1.bias",
107
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_1.weight": "encoders.0.layer_norm1.weight",
108
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_2.bias": "encoders.0.layer_norm2.bias",
109
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_2.weight": "encoders.0.layer_norm2.weight",
110
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_fc.bias": "encoders.0.fc1.bias",
111
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_fc.weight": "encoders.0.fc1.weight",
112
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_proj.bias": "encoders.0.fc2.bias",
113
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_proj.weight": "encoders.0.fc2.weight",
114
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.in_proj_bias": ['encoders.1.attn.to_q.bias', 'encoders.1.attn.to_k.bias', 'encoders.1.attn.to_v.bias'],
115
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.in_proj_weight": ['encoders.1.attn.to_q.weight', 'encoders.1.attn.to_k.weight', 'encoders.1.attn.to_v.weight'],
116
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.out_proj.bias": "encoders.1.attn.to_out.bias",
117
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.out_proj.weight": "encoders.1.attn.to_out.weight",
118
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_1.bias": "encoders.1.layer_norm1.bias",
119
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_1.weight": "encoders.1.layer_norm1.weight",
120
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_2.bias": "encoders.1.layer_norm2.bias",
121
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_2.weight": "encoders.1.layer_norm2.weight",
122
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_fc.bias": "encoders.1.fc1.bias",
123
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_fc.weight": "encoders.1.fc1.weight",
124
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_proj.bias": "encoders.1.fc2.bias",
125
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_proj.weight": "encoders.1.fc2.weight",
126
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.in_proj_bias": ['encoders.10.attn.to_q.bias', 'encoders.10.attn.to_k.bias', 'encoders.10.attn.to_v.bias'],
127
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.in_proj_weight": ['encoders.10.attn.to_q.weight', 'encoders.10.attn.to_k.weight', 'encoders.10.attn.to_v.weight'],
128
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.out_proj.bias": "encoders.10.attn.to_out.bias",
129
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.out_proj.weight": "encoders.10.attn.to_out.weight",
130
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_1.bias": "encoders.10.layer_norm1.bias",
131
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_1.weight": "encoders.10.layer_norm1.weight",
132
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_2.bias": "encoders.10.layer_norm2.bias",
133
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_2.weight": "encoders.10.layer_norm2.weight",
134
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_fc.bias": "encoders.10.fc1.bias",
135
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_fc.weight": "encoders.10.fc1.weight",
136
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_proj.bias": "encoders.10.fc2.bias",
137
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_proj.weight": "encoders.10.fc2.weight",
138
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.in_proj_bias": ['encoders.11.attn.to_q.bias', 'encoders.11.attn.to_k.bias', 'encoders.11.attn.to_v.bias'],
139
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.in_proj_weight": ['encoders.11.attn.to_q.weight', 'encoders.11.attn.to_k.weight', 'encoders.11.attn.to_v.weight'],
140
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.out_proj.bias": "encoders.11.attn.to_out.bias",
141
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.out_proj.weight": "encoders.11.attn.to_out.weight",
142
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_1.bias": "encoders.11.layer_norm1.bias",
143
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_1.weight": "encoders.11.layer_norm1.weight",
144
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_2.bias": "encoders.11.layer_norm2.bias",
145
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_2.weight": "encoders.11.layer_norm2.weight",
146
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_fc.bias": "encoders.11.fc1.bias",
147
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_fc.weight": "encoders.11.fc1.weight",
148
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_proj.bias": "encoders.11.fc2.bias",
149
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_proj.weight": "encoders.11.fc2.weight",
150
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.in_proj_bias": ['encoders.12.attn.to_q.bias', 'encoders.12.attn.to_k.bias', 'encoders.12.attn.to_v.bias'],
151
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.in_proj_weight": ['encoders.12.attn.to_q.weight', 'encoders.12.attn.to_k.weight', 'encoders.12.attn.to_v.weight'],
152
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.out_proj.bias": "encoders.12.attn.to_out.bias",
153
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.out_proj.weight": "encoders.12.attn.to_out.weight",
154
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_1.bias": "encoders.12.layer_norm1.bias",
155
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_1.weight": "encoders.12.layer_norm1.weight",
156
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_2.bias": "encoders.12.layer_norm2.bias",
157
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_2.weight": "encoders.12.layer_norm2.weight",
158
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_fc.bias": "encoders.12.fc1.bias",
159
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_fc.weight": "encoders.12.fc1.weight",
160
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_proj.bias": "encoders.12.fc2.bias",
161
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_proj.weight": "encoders.12.fc2.weight",
162
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.in_proj_bias": ['encoders.13.attn.to_q.bias', 'encoders.13.attn.to_k.bias', 'encoders.13.attn.to_v.bias'],
163
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.in_proj_weight": ['encoders.13.attn.to_q.weight', 'encoders.13.attn.to_k.weight', 'encoders.13.attn.to_v.weight'],
164
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.out_proj.bias": "encoders.13.attn.to_out.bias",
165
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.out_proj.weight": "encoders.13.attn.to_out.weight",
166
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_1.bias": "encoders.13.layer_norm1.bias",
167
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_1.weight": "encoders.13.layer_norm1.weight",
168
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_2.bias": "encoders.13.layer_norm2.bias",
169
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_2.weight": "encoders.13.layer_norm2.weight",
170
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_fc.bias": "encoders.13.fc1.bias",
171
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_fc.weight": "encoders.13.fc1.weight",
172
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_proj.bias": "encoders.13.fc2.bias",
173
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_proj.weight": "encoders.13.fc2.weight",
174
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.in_proj_bias": ['encoders.14.attn.to_q.bias', 'encoders.14.attn.to_k.bias', 'encoders.14.attn.to_v.bias'],
175
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.in_proj_weight": ['encoders.14.attn.to_q.weight', 'encoders.14.attn.to_k.weight', 'encoders.14.attn.to_v.weight'],
176
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.out_proj.bias": "encoders.14.attn.to_out.bias",
177
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.out_proj.weight": "encoders.14.attn.to_out.weight",
178
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_1.bias": "encoders.14.layer_norm1.bias",
179
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_1.weight": "encoders.14.layer_norm1.weight",
180
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_2.bias": "encoders.14.layer_norm2.bias",
181
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_2.weight": "encoders.14.layer_norm2.weight",
182
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_fc.bias": "encoders.14.fc1.bias",
183
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_fc.weight": "encoders.14.fc1.weight",
184
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_proj.bias": "encoders.14.fc2.bias",
185
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_proj.weight": "encoders.14.fc2.weight",
186
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.in_proj_bias": ['encoders.15.attn.to_q.bias', 'encoders.15.attn.to_k.bias', 'encoders.15.attn.to_v.bias'],
187
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.in_proj_weight": ['encoders.15.attn.to_q.weight', 'encoders.15.attn.to_k.weight', 'encoders.15.attn.to_v.weight'],
188
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.out_proj.bias": "encoders.15.attn.to_out.bias",
189
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.out_proj.weight": "encoders.15.attn.to_out.weight",
190
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_1.bias": "encoders.15.layer_norm1.bias",
191
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_1.weight": "encoders.15.layer_norm1.weight",
192
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_2.bias": "encoders.15.layer_norm2.bias",
193
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_2.weight": "encoders.15.layer_norm2.weight",
194
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_fc.bias": "encoders.15.fc1.bias",
195
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_fc.weight": "encoders.15.fc1.weight",
196
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_proj.bias": "encoders.15.fc2.bias",
197
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_proj.weight": "encoders.15.fc2.weight",
198
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.in_proj_bias": ['encoders.16.attn.to_q.bias', 'encoders.16.attn.to_k.bias', 'encoders.16.attn.to_v.bias'],
199
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.in_proj_weight": ['encoders.16.attn.to_q.weight', 'encoders.16.attn.to_k.weight', 'encoders.16.attn.to_v.weight'],
200
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.out_proj.bias": "encoders.16.attn.to_out.bias",
201
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.out_proj.weight": "encoders.16.attn.to_out.weight",
202
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_1.bias": "encoders.16.layer_norm1.bias",
203
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_1.weight": "encoders.16.layer_norm1.weight",
204
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_2.bias": "encoders.16.layer_norm2.bias",
205
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_2.weight": "encoders.16.layer_norm2.weight",
206
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_fc.bias": "encoders.16.fc1.bias",
207
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_fc.weight": "encoders.16.fc1.weight",
208
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_proj.bias": "encoders.16.fc2.bias",
209
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_proj.weight": "encoders.16.fc2.weight",
210
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.in_proj_bias": ['encoders.17.attn.to_q.bias', 'encoders.17.attn.to_k.bias', 'encoders.17.attn.to_v.bias'],
211
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.in_proj_weight": ['encoders.17.attn.to_q.weight', 'encoders.17.attn.to_k.weight', 'encoders.17.attn.to_v.weight'],
212
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.out_proj.bias": "encoders.17.attn.to_out.bias",
213
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.out_proj.weight": "encoders.17.attn.to_out.weight",
214
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_1.bias": "encoders.17.layer_norm1.bias",
215
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_1.weight": "encoders.17.layer_norm1.weight",
216
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_2.bias": "encoders.17.layer_norm2.bias",
217
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_2.weight": "encoders.17.layer_norm2.weight",
218
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_fc.bias": "encoders.17.fc1.bias",
219
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_fc.weight": "encoders.17.fc1.weight",
220
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_proj.bias": "encoders.17.fc2.bias",
221
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_proj.weight": "encoders.17.fc2.weight",
222
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.in_proj_bias": ['encoders.18.attn.to_q.bias', 'encoders.18.attn.to_k.bias', 'encoders.18.attn.to_v.bias'],
223
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.in_proj_weight": ['encoders.18.attn.to_q.weight', 'encoders.18.attn.to_k.weight', 'encoders.18.attn.to_v.weight'],
224
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.out_proj.bias": "encoders.18.attn.to_out.bias",
225
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.out_proj.weight": "encoders.18.attn.to_out.weight",
226
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_1.bias": "encoders.18.layer_norm1.bias",
227
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_1.weight": "encoders.18.layer_norm1.weight",
228
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_2.bias": "encoders.18.layer_norm2.bias",
229
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_2.weight": "encoders.18.layer_norm2.weight",
230
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_fc.bias": "encoders.18.fc1.bias",
231
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_fc.weight": "encoders.18.fc1.weight",
232
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_proj.bias": "encoders.18.fc2.bias",
233
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_proj.weight": "encoders.18.fc2.weight",
234
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.in_proj_bias": ['encoders.19.attn.to_q.bias', 'encoders.19.attn.to_k.bias', 'encoders.19.attn.to_v.bias'],
235
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.in_proj_weight": ['encoders.19.attn.to_q.weight', 'encoders.19.attn.to_k.weight', 'encoders.19.attn.to_v.weight'],
236
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.out_proj.bias": "encoders.19.attn.to_out.bias",
237
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.out_proj.weight": "encoders.19.attn.to_out.weight",
238
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_1.bias": "encoders.19.layer_norm1.bias",
239
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_1.weight": "encoders.19.layer_norm1.weight",
240
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_2.bias": "encoders.19.layer_norm2.bias",
241
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_2.weight": "encoders.19.layer_norm2.weight",
242
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_fc.bias": "encoders.19.fc1.bias",
243
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_fc.weight": "encoders.19.fc1.weight",
244
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_proj.bias": "encoders.19.fc2.bias",
245
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_proj.weight": "encoders.19.fc2.weight",
246
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.in_proj_bias": ['encoders.2.attn.to_q.bias', 'encoders.2.attn.to_k.bias', 'encoders.2.attn.to_v.bias'],
247
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.in_proj_weight": ['encoders.2.attn.to_q.weight', 'encoders.2.attn.to_k.weight', 'encoders.2.attn.to_v.weight'],
248
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.out_proj.bias": "encoders.2.attn.to_out.bias",
249
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.out_proj.weight": "encoders.2.attn.to_out.weight",
250
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_1.bias": "encoders.2.layer_norm1.bias",
251
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_1.weight": "encoders.2.layer_norm1.weight",
252
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_2.bias": "encoders.2.layer_norm2.bias",
253
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_2.weight": "encoders.2.layer_norm2.weight",
254
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_fc.bias": "encoders.2.fc1.bias",
255
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_fc.weight": "encoders.2.fc1.weight",
256
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_proj.bias": "encoders.2.fc2.bias",
257
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_proj.weight": "encoders.2.fc2.weight",
258
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.in_proj_bias": ['encoders.20.attn.to_q.bias', 'encoders.20.attn.to_k.bias', 'encoders.20.attn.to_v.bias'],
259
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.in_proj_weight": ['encoders.20.attn.to_q.weight', 'encoders.20.attn.to_k.weight', 'encoders.20.attn.to_v.weight'],
260
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.out_proj.bias": "encoders.20.attn.to_out.bias",
261
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.out_proj.weight": "encoders.20.attn.to_out.weight",
262
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_1.bias": "encoders.20.layer_norm1.bias",
263
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_1.weight": "encoders.20.layer_norm1.weight",
264
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_2.bias": "encoders.20.layer_norm2.bias",
265
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_2.weight": "encoders.20.layer_norm2.weight",
266
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_fc.bias": "encoders.20.fc1.bias",
267
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_fc.weight": "encoders.20.fc1.weight",
268
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_proj.bias": "encoders.20.fc2.bias",
269
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_proj.weight": "encoders.20.fc2.weight",
270
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.in_proj_bias": ['encoders.21.attn.to_q.bias', 'encoders.21.attn.to_k.bias', 'encoders.21.attn.to_v.bias'],
271
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.in_proj_weight": ['encoders.21.attn.to_q.weight', 'encoders.21.attn.to_k.weight', 'encoders.21.attn.to_v.weight'],
272
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.out_proj.bias": "encoders.21.attn.to_out.bias",
273
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.out_proj.weight": "encoders.21.attn.to_out.weight",
274
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_1.bias": "encoders.21.layer_norm1.bias",
275
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_1.weight": "encoders.21.layer_norm1.weight",
276
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_2.bias": "encoders.21.layer_norm2.bias",
277
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_2.weight": "encoders.21.layer_norm2.weight",
278
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_fc.bias": "encoders.21.fc1.bias",
279
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_fc.weight": "encoders.21.fc1.weight",
280
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_proj.bias": "encoders.21.fc2.bias",
281
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_proj.weight": "encoders.21.fc2.weight",
282
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.in_proj_bias": ['encoders.22.attn.to_q.bias', 'encoders.22.attn.to_k.bias', 'encoders.22.attn.to_v.bias'],
283
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.in_proj_weight": ['encoders.22.attn.to_q.weight', 'encoders.22.attn.to_k.weight', 'encoders.22.attn.to_v.weight'],
284
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.out_proj.bias": "encoders.22.attn.to_out.bias",
285
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.out_proj.weight": "encoders.22.attn.to_out.weight",
286
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_1.bias": "encoders.22.layer_norm1.bias",
287
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_1.weight": "encoders.22.layer_norm1.weight",
288
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_2.bias": "encoders.22.layer_norm2.bias",
289
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_2.weight": "encoders.22.layer_norm2.weight",
290
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_fc.bias": "encoders.22.fc1.bias",
291
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_fc.weight": "encoders.22.fc1.weight",
292
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_proj.bias": "encoders.22.fc2.bias",
293
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_proj.weight": "encoders.22.fc2.weight",
294
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.in_proj_bias": ['encoders.23.attn.to_q.bias', 'encoders.23.attn.to_k.bias', 'encoders.23.attn.to_v.bias'],
295
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.in_proj_weight": ['encoders.23.attn.to_q.weight', 'encoders.23.attn.to_k.weight', 'encoders.23.attn.to_v.weight'],
296
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.out_proj.bias": "encoders.23.attn.to_out.bias",
297
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.out_proj.weight": "encoders.23.attn.to_out.weight",
298
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_1.bias": "encoders.23.layer_norm1.bias",
299
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_1.weight": "encoders.23.layer_norm1.weight",
300
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_2.bias": "encoders.23.layer_norm2.bias",
301
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_2.weight": "encoders.23.layer_norm2.weight",
302
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_fc.bias": "encoders.23.fc1.bias",
303
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_fc.weight": "encoders.23.fc1.weight",
304
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_proj.bias": "encoders.23.fc2.bias",
305
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_proj.weight": "encoders.23.fc2.weight",
306
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.in_proj_bias": ['encoders.24.attn.to_q.bias', 'encoders.24.attn.to_k.bias', 'encoders.24.attn.to_v.bias'],
307
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.in_proj_weight": ['encoders.24.attn.to_q.weight', 'encoders.24.attn.to_k.weight', 'encoders.24.attn.to_v.weight'],
308
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.out_proj.bias": "encoders.24.attn.to_out.bias",
309
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.out_proj.weight": "encoders.24.attn.to_out.weight",
310
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_1.bias": "encoders.24.layer_norm1.bias",
311
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_1.weight": "encoders.24.layer_norm1.weight",
312
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_2.bias": "encoders.24.layer_norm2.bias",
313
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_2.weight": "encoders.24.layer_norm2.weight",
314
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_fc.bias": "encoders.24.fc1.bias",
315
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_fc.weight": "encoders.24.fc1.weight",
316
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_proj.bias": "encoders.24.fc2.bias",
317
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_proj.weight": "encoders.24.fc2.weight",
318
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.in_proj_bias": ['encoders.25.attn.to_q.bias', 'encoders.25.attn.to_k.bias', 'encoders.25.attn.to_v.bias'],
319
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.in_proj_weight": ['encoders.25.attn.to_q.weight', 'encoders.25.attn.to_k.weight', 'encoders.25.attn.to_v.weight'],
320
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.out_proj.bias": "encoders.25.attn.to_out.bias",
321
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.out_proj.weight": "encoders.25.attn.to_out.weight",
322
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_1.bias": "encoders.25.layer_norm1.bias",
323
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_1.weight": "encoders.25.layer_norm1.weight",
324
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_2.bias": "encoders.25.layer_norm2.bias",
325
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_2.weight": "encoders.25.layer_norm2.weight",
326
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_fc.bias": "encoders.25.fc1.bias",
327
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_fc.weight": "encoders.25.fc1.weight",
328
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_proj.bias": "encoders.25.fc2.bias",
329
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_proj.weight": "encoders.25.fc2.weight",
330
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.in_proj_bias": ['encoders.26.attn.to_q.bias', 'encoders.26.attn.to_k.bias', 'encoders.26.attn.to_v.bias'],
331
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.in_proj_weight": ['encoders.26.attn.to_q.weight', 'encoders.26.attn.to_k.weight', 'encoders.26.attn.to_v.weight'],
332
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.out_proj.bias": "encoders.26.attn.to_out.bias",
333
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.out_proj.weight": "encoders.26.attn.to_out.weight",
334
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_1.bias": "encoders.26.layer_norm1.bias",
335
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_1.weight": "encoders.26.layer_norm1.weight",
336
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_2.bias": "encoders.26.layer_norm2.bias",
337
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_2.weight": "encoders.26.layer_norm2.weight",
338
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_fc.bias": "encoders.26.fc1.bias",
339
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_fc.weight": "encoders.26.fc1.weight",
340
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_proj.bias": "encoders.26.fc2.bias",
341
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_proj.weight": "encoders.26.fc2.weight",
342
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.in_proj_bias": ['encoders.27.attn.to_q.bias', 'encoders.27.attn.to_k.bias', 'encoders.27.attn.to_v.bias'],
343
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.in_proj_weight": ['encoders.27.attn.to_q.weight', 'encoders.27.attn.to_k.weight', 'encoders.27.attn.to_v.weight'],
344
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.out_proj.bias": "encoders.27.attn.to_out.bias",
345
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.out_proj.weight": "encoders.27.attn.to_out.weight",
346
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_1.bias": "encoders.27.layer_norm1.bias",
347
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_1.weight": "encoders.27.layer_norm1.weight",
348
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_2.bias": "encoders.27.layer_norm2.bias",
349
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_2.weight": "encoders.27.layer_norm2.weight",
350
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_fc.bias": "encoders.27.fc1.bias",
351
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_fc.weight": "encoders.27.fc1.weight",
352
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_proj.bias": "encoders.27.fc2.bias",
353
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_proj.weight": "encoders.27.fc2.weight",
354
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.in_proj_bias": ['encoders.28.attn.to_q.bias', 'encoders.28.attn.to_k.bias', 'encoders.28.attn.to_v.bias'],
355
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.in_proj_weight": ['encoders.28.attn.to_q.weight', 'encoders.28.attn.to_k.weight', 'encoders.28.attn.to_v.weight'],
356
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.out_proj.bias": "encoders.28.attn.to_out.bias",
357
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.out_proj.weight": "encoders.28.attn.to_out.weight",
358
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_1.bias": "encoders.28.layer_norm1.bias",
359
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_1.weight": "encoders.28.layer_norm1.weight",
360
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_2.bias": "encoders.28.layer_norm2.bias",
361
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_2.weight": "encoders.28.layer_norm2.weight",
362
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_fc.bias": "encoders.28.fc1.bias",
363
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_fc.weight": "encoders.28.fc1.weight",
364
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_proj.bias": "encoders.28.fc2.bias",
365
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_proj.weight": "encoders.28.fc2.weight",
366
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.in_proj_bias": ['encoders.29.attn.to_q.bias', 'encoders.29.attn.to_k.bias', 'encoders.29.attn.to_v.bias'],
367
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.in_proj_weight": ['encoders.29.attn.to_q.weight', 'encoders.29.attn.to_k.weight', 'encoders.29.attn.to_v.weight'],
368
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.out_proj.bias": "encoders.29.attn.to_out.bias",
369
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.out_proj.weight": "encoders.29.attn.to_out.weight",
370
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_1.bias": "encoders.29.layer_norm1.bias",
371
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_1.weight": "encoders.29.layer_norm1.weight",
372
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_2.bias": "encoders.29.layer_norm2.bias",
373
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_2.weight": "encoders.29.layer_norm2.weight",
374
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_fc.bias": "encoders.29.fc1.bias",
375
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_fc.weight": "encoders.29.fc1.weight",
376
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_proj.bias": "encoders.29.fc2.bias",
377
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_proj.weight": "encoders.29.fc2.weight",
378
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.in_proj_bias": ['encoders.3.attn.to_q.bias', 'encoders.3.attn.to_k.bias', 'encoders.3.attn.to_v.bias'],
379
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.in_proj_weight": ['encoders.3.attn.to_q.weight', 'encoders.3.attn.to_k.weight', 'encoders.3.attn.to_v.weight'],
380
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.out_proj.bias": "encoders.3.attn.to_out.bias",
381
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.out_proj.weight": "encoders.3.attn.to_out.weight",
382
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_1.bias": "encoders.3.layer_norm1.bias",
383
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_1.weight": "encoders.3.layer_norm1.weight",
384
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_2.bias": "encoders.3.layer_norm2.bias",
385
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_2.weight": "encoders.3.layer_norm2.weight",
386
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_fc.bias": "encoders.3.fc1.bias",
387
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_fc.weight": "encoders.3.fc1.weight",
388
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_proj.bias": "encoders.3.fc2.bias",
389
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_proj.weight": "encoders.3.fc2.weight",
390
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.in_proj_bias": ['encoders.30.attn.to_q.bias', 'encoders.30.attn.to_k.bias', 'encoders.30.attn.to_v.bias'],
391
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.in_proj_weight": ['encoders.30.attn.to_q.weight', 'encoders.30.attn.to_k.weight', 'encoders.30.attn.to_v.weight'],
392
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.out_proj.bias": "encoders.30.attn.to_out.bias",
393
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.out_proj.weight": "encoders.30.attn.to_out.weight",
394
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_1.bias": "encoders.30.layer_norm1.bias",
395
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_1.weight": "encoders.30.layer_norm1.weight",
396
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_2.bias": "encoders.30.layer_norm2.bias",
397
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_2.weight": "encoders.30.layer_norm2.weight",
398
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_fc.bias": "encoders.30.fc1.bias",
399
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_fc.weight": "encoders.30.fc1.weight",
400
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_proj.bias": "encoders.30.fc2.bias",
401
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_proj.weight": "encoders.30.fc2.weight",
402
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.in_proj_bias": ['encoders.31.attn.to_q.bias', 'encoders.31.attn.to_k.bias', 'encoders.31.attn.to_v.bias'],
403
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.in_proj_weight": ['encoders.31.attn.to_q.weight', 'encoders.31.attn.to_k.weight', 'encoders.31.attn.to_v.weight'],
404
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.out_proj.bias": "encoders.31.attn.to_out.bias",
405
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.out_proj.weight": "encoders.31.attn.to_out.weight",
406
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_1.bias": "encoders.31.layer_norm1.bias",
407
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_1.weight": "encoders.31.layer_norm1.weight",
408
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_2.bias": "encoders.31.layer_norm2.bias",
409
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_2.weight": "encoders.31.layer_norm2.weight",
410
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_fc.bias": "encoders.31.fc1.bias",
411
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_fc.weight": "encoders.31.fc1.weight",
412
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_proj.bias": "encoders.31.fc2.bias",
413
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_proj.weight": "encoders.31.fc2.weight",
414
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.in_proj_bias": ['encoders.4.attn.to_q.bias', 'encoders.4.attn.to_k.bias', 'encoders.4.attn.to_v.bias'],
415
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.in_proj_weight": ['encoders.4.attn.to_q.weight', 'encoders.4.attn.to_k.weight', 'encoders.4.attn.to_v.weight'],
416
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.out_proj.bias": "encoders.4.attn.to_out.bias",
417
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.out_proj.weight": "encoders.4.attn.to_out.weight",
418
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_1.bias": "encoders.4.layer_norm1.bias",
419
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_1.weight": "encoders.4.layer_norm1.weight",
420
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_2.bias": "encoders.4.layer_norm2.bias",
421
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_2.weight": "encoders.4.layer_norm2.weight",
422
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_fc.bias": "encoders.4.fc1.bias",
423
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_fc.weight": "encoders.4.fc1.weight",
424
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_proj.bias": "encoders.4.fc2.bias",
425
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_proj.weight": "encoders.4.fc2.weight",
426
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.in_proj_bias": ['encoders.5.attn.to_q.bias', 'encoders.5.attn.to_k.bias', 'encoders.5.attn.to_v.bias'],
427
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.in_proj_weight": ['encoders.5.attn.to_q.weight', 'encoders.5.attn.to_k.weight', 'encoders.5.attn.to_v.weight'],
428
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.out_proj.bias": "encoders.5.attn.to_out.bias",
429
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.out_proj.weight": "encoders.5.attn.to_out.weight",
430
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_1.bias": "encoders.5.layer_norm1.bias",
431
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_1.weight": "encoders.5.layer_norm1.weight",
432
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_2.bias": "encoders.5.layer_norm2.bias",
433
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_2.weight": "encoders.5.layer_norm2.weight",
434
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_fc.bias": "encoders.5.fc1.bias",
435
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_fc.weight": "encoders.5.fc1.weight",
436
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_proj.bias": "encoders.5.fc2.bias",
437
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_proj.weight": "encoders.5.fc2.weight",
438
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.in_proj_bias": ['encoders.6.attn.to_q.bias', 'encoders.6.attn.to_k.bias', 'encoders.6.attn.to_v.bias'],
439
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.in_proj_weight": ['encoders.6.attn.to_q.weight', 'encoders.6.attn.to_k.weight', 'encoders.6.attn.to_v.weight'],
440
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.out_proj.bias": "encoders.6.attn.to_out.bias",
441
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.out_proj.weight": "encoders.6.attn.to_out.weight",
442
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_1.bias": "encoders.6.layer_norm1.bias",
443
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_1.weight": "encoders.6.layer_norm1.weight",
444
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_2.bias": "encoders.6.layer_norm2.bias",
445
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_2.weight": "encoders.6.layer_norm2.weight",
446
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_fc.bias": "encoders.6.fc1.bias",
447
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_fc.weight": "encoders.6.fc1.weight",
448
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_proj.bias": "encoders.6.fc2.bias",
449
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_proj.weight": "encoders.6.fc2.weight",
450
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.in_proj_bias": ['encoders.7.attn.to_q.bias', 'encoders.7.attn.to_k.bias', 'encoders.7.attn.to_v.bias'],
451
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.in_proj_weight": ['encoders.7.attn.to_q.weight', 'encoders.7.attn.to_k.weight', 'encoders.7.attn.to_v.weight'],
452
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.out_proj.bias": "encoders.7.attn.to_out.bias",
453
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.out_proj.weight": "encoders.7.attn.to_out.weight",
454
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_1.bias": "encoders.7.layer_norm1.bias",
455
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_1.weight": "encoders.7.layer_norm1.weight",
456
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_2.bias": "encoders.7.layer_norm2.bias",
457
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_2.weight": "encoders.7.layer_norm2.weight",
458
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_fc.bias": "encoders.7.fc1.bias",
459
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_fc.weight": "encoders.7.fc1.weight",
460
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_proj.bias": "encoders.7.fc2.bias",
461
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_proj.weight": "encoders.7.fc2.weight",
462
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.in_proj_bias": ['encoders.8.attn.to_q.bias', 'encoders.8.attn.to_k.bias', 'encoders.8.attn.to_v.bias'],
463
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.in_proj_weight": ['encoders.8.attn.to_q.weight', 'encoders.8.attn.to_k.weight', 'encoders.8.attn.to_v.weight'],
464
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.out_proj.bias": "encoders.8.attn.to_out.bias",
465
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.out_proj.weight": "encoders.8.attn.to_out.weight",
466
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_1.bias": "encoders.8.layer_norm1.bias",
467
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_1.weight": "encoders.8.layer_norm1.weight",
468
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_2.bias": "encoders.8.layer_norm2.bias",
469
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_2.weight": "encoders.8.layer_norm2.weight",
470
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_fc.bias": "encoders.8.fc1.bias",
471
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_fc.weight": "encoders.8.fc1.weight",
472
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_proj.bias": "encoders.8.fc2.bias",
473
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_proj.weight": "encoders.8.fc2.weight",
474
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.in_proj_bias": ['encoders.9.attn.to_q.bias', 'encoders.9.attn.to_k.bias', 'encoders.9.attn.to_v.bias'],
475
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.in_proj_weight": ['encoders.9.attn.to_q.weight', 'encoders.9.attn.to_k.weight', 'encoders.9.attn.to_v.weight'],
476
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.out_proj.bias": "encoders.9.attn.to_out.bias",
477
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.out_proj.weight": "encoders.9.attn.to_out.weight",
478
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_1.bias": "encoders.9.layer_norm1.bias",
479
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_1.weight": "encoders.9.layer_norm1.weight",
480
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_2.bias": "encoders.9.layer_norm2.bias",
481
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_2.weight": "encoders.9.layer_norm2.weight",
482
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_fc.bias": "encoders.9.fc1.bias",
483
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_fc.weight": "encoders.9.fc1.weight",
484
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_proj.bias": "encoders.9.fc2.bias",
485
+ "conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_proj.weight": "encoders.9.fc2.weight",
486
+ "conditioner.embedders.0.open_clip.model.visual.proj": "visual_projection.weight",
487
+ }
488
+ state_dict_ = {}
489
+ for name in state_dict:
490
+ if name in rename_dict:
491
+ param = state_dict[name]
492
+ if name == "conditioner.embedders.0.open_clip.model.visual.class_embedding":
493
+ param = param.reshape((1, 1, param.shape[0]))
494
+ elif name == "conditioner.embedders.0.open_clip.model.visual.positional_embedding":
495
+ param = param.reshape((1, param.shape[0], param.shape[1]))
496
+ elif name == "conditioner.embedders.0.open_clip.model.visual.proj":
497
+ param = param.T
498
+ if isinstance(rename_dict[name], str):
499
+ state_dict_[rename_dict[name]] = param
500
+ else:
501
+ length = param.shape[0] // 3
502
+ for i, rename in enumerate(rename_dict[name]):
503
+ state_dict_[rename] = param[i*length: i*length+length]
504
+ return state_dict_
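Note: the from_civitai converter above maps OpenCLIP's fused attention projections onto separate q/k/v tensors by slicing the in_proj parameter into thirds along dim 0 (and transposes the final visual projection). A minimal, self-contained sketch of that splitting step follows; the helper name and the 1280-wide example tensor are illustrative assumptions, not part of the diff.

import torch

def split_fused_qkv(param: torch.Tensor):
    # Fused OpenCLIP layout stacks [q; k; v] along dim 0, so each part is one third of the rows.
    length = param.shape[0] // 3
    return param[0:length], param[length:2 * length], param[2 * length:3 * length]

fused_weight = torch.randn(3 * 1280, 1280)  # hypothetical ViT-bigG-style fused in_proj_weight
q_w, k_w, v_w = split_fused_qkv(fused_weight)
assert q_w.shape == k_w.shape == v_w.shape == (1280, 1280)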
diffsynth/models/svd_unet.py ADDED
The diff for this file is too large to render. See raw diff
diffsynth/models/svd_vae_decoder.py ADDED
@@ -0,0 +1,577 @@
1
+ import torch
2
+ from .attention import Attention
3
+ from .sd_unet import ResnetBlock, UpSampler
4
+ from .tiler import TileWorker
5
+ from einops import rearrange, repeat
6
+
7
+
8
+ class VAEAttentionBlock(torch.nn.Module):
9
+
10
+ def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
11
+ super().__init__()
12
+ inner_dim = num_attention_heads * attention_head_dim
13
+
14
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
15
+
16
+ self.transformer_blocks = torch.nn.ModuleList([
17
+ Attention(
18
+ inner_dim,
19
+ num_attention_heads,
20
+ attention_head_dim,
21
+ bias_q=True,
22
+ bias_kv=True,
23
+ bias_out=True
24
+ )
25
+ for d in range(num_layers)
26
+ ])
27
+
28
+ def forward(self, hidden_states, time_emb, text_emb, res_stack):
29
+ batch, _, height, width = hidden_states.shape
30
+ residual = hidden_states
31
+
32
+ hidden_states = self.norm(hidden_states)
33
+ inner_dim = hidden_states.shape[1]
34
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
35
+
36
+ for block in self.transformer_blocks:
37
+ hidden_states = block(hidden_states)
38
+
39
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
40
+ hidden_states = hidden_states + residual
41
+
42
+ return hidden_states, time_emb, text_emb, res_stack
43
+
44
+
45
+ class TemporalResnetBlock(torch.nn.Module):
46
+
47
+ def __init__(self, in_channels, out_channels, groups=32, eps=1e-5):
48
+ super().__init__()
49
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
50
+ self.conv1 = torch.nn.Conv3d(in_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0))
51
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
52
+ self.conv2 = torch.nn.Conv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0))
53
+ self.nonlinearity = torch.nn.SiLU()
54
+ self.mix_factor = torch.nn.Parameter(torch.Tensor([0.5]))
55
+
56
+ def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
57
+ x_spatial = hidden_states
58
+ x = rearrange(hidden_states, "T C H W -> 1 C T H W")
59
+ x = self.norm1(x)
60
+ x = self.nonlinearity(x)
61
+ x = self.conv1(x)
62
+ x = self.norm2(x)
63
+ x = self.nonlinearity(x)
64
+ x = self.conv2(x)
65
+ x_temporal = hidden_states + x[0].permute(1, 0, 2, 3)
66
+ alpha = torch.sigmoid(self.mix_factor)
67
+ hidden_states = alpha * x_temporal + (1 - alpha) * x_spatial
68
+ return hidden_states, time_emb, text_emb, res_stack
69
+
70
+
71
+ class SVDVAEDecoder(torch.nn.Module):
72
+ def __init__(self):
73
+ super().__init__()
74
+ self.scaling_factor = 0.18215
75
+ self.conv_in = torch.nn.Conv2d(4, 512, kernel_size=3, padding=1)
76
+
77
+ self.blocks = torch.nn.ModuleList([
78
+ # UNetMidBlock
79
+ ResnetBlock(512, 512, eps=1e-6),
80
+ TemporalResnetBlock(512, 512, eps=1e-6),
81
+ VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
82
+ ResnetBlock(512, 512, eps=1e-6),
83
+ TemporalResnetBlock(512, 512, eps=1e-6),
84
+ # UpDecoderBlock
85
+ ResnetBlock(512, 512, eps=1e-6),
86
+ TemporalResnetBlock(512, 512, eps=1e-6),
87
+ ResnetBlock(512, 512, eps=1e-6),
88
+ TemporalResnetBlock(512, 512, eps=1e-6),
89
+ ResnetBlock(512, 512, eps=1e-6),
90
+ TemporalResnetBlock(512, 512, eps=1e-6),
91
+ UpSampler(512),
92
+ # UpDecoderBlock
93
+ ResnetBlock(512, 512, eps=1e-6),
94
+ TemporalResnetBlock(512, 512, eps=1e-6),
95
+ ResnetBlock(512, 512, eps=1e-6),
96
+ TemporalResnetBlock(512, 512, eps=1e-6),
97
+ ResnetBlock(512, 512, eps=1e-6),
98
+ TemporalResnetBlock(512, 512, eps=1e-6),
99
+ UpSampler(512),
100
+ # UpDecoderBlock
101
+ ResnetBlock(512, 256, eps=1e-6),
102
+ TemporalResnetBlock(256, 256, eps=1e-6),
103
+ ResnetBlock(256, 256, eps=1e-6),
104
+ TemporalResnetBlock(256, 256, eps=1e-6),
105
+ ResnetBlock(256, 256, eps=1e-6),
106
+ TemporalResnetBlock(256, 256, eps=1e-6),
107
+ UpSampler(256),
108
+ # UpDecoderBlock
109
+ ResnetBlock(256, 128, eps=1e-6),
110
+ TemporalResnetBlock(128, 128, eps=1e-6),
111
+ ResnetBlock(128, 128, eps=1e-6),
112
+ TemporalResnetBlock(128, 128, eps=1e-6),
113
+ ResnetBlock(128, 128, eps=1e-6),
114
+ TemporalResnetBlock(128, 128, eps=1e-6),
115
+ ])
116
+
117
+ self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-5)
118
+ self.conv_act = torch.nn.SiLU()
119
+ self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
120
+ self.time_conv_out = torch.nn.Conv3d(3, 3, kernel_size=(3, 1, 1), padding=(1, 0, 0))
121
+
122
+
123
+ def forward(self, sample):
124
+ # 1. pre-process
125
+ hidden_states = rearrange(sample, "C T H W -> T C H W")
126
+ hidden_states = hidden_states / self.scaling_factor
127
+ hidden_states = self.conv_in(hidden_states)
128
+ time_emb, text_emb, res_stack = None, None, None
129
+
130
+ # 2. blocks
131
+ for i, block in enumerate(self.blocks):
132
+ hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
133
+
134
+ # 3. output
135
+ hidden_states = self.conv_norm_out(hidden_states)
136
+ hidden_states = self.conv_act(hidden_states)
137
+ hidden_states = self.conv_out(hidden_states)
138
+ hidden_states = rearrange(hidden_states, "T C H W -> C T H W")
139
+ hidden_states = self.time_conv_out(hidden_states)
140
+
141
+ return hidden_states
142
+
143
+
144
+ def build_mask(self, data, is_bound):
145
+ _, T, H, W = data.shape
146
+ t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
147
+ h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
148
+ w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
149
+ border_width = (T + H + W) // 6
150
+ pad = torch.ones_like(t) * border_width
151
+ mask = torch.stack([
152
+ pad if is_bound[0] else t + 1,
153
+ pad if is_bound[1] else T - t,
154
+ pad if is_bound[2] else h + 1,
155
+ pad if is_bound[3] else H - h,
156
+ pad if is_bound[4] else w + 1,
157
+ pad if is_bound[5] else W - w
158
+ ]).min(dim=0).values
159
+ mask = mask.clip(1, border_width)
160
+ mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
161
+ mask = rearrange(mask, "T H W -> 1 T H W")
162
+ return mask
163
+
164
+
165
+ def decode_video(
166
+ self, sample,
167
+ batch_time=8, batch_height=128, batch_width=128,
168
+ stride_time=4, stride_height=32, stride_width=32,
169
+ progress_bar=lambda x:x
170
+ ):
171
+ sample = sample.permute(1, 0, 2, 3)
172
+ data_device = sample.device
173
+ computation_device = self.conv_in.weight.device
174
+ torch_dtype = sample.dtype
175
+ _, T, H, W = sample.shape
176
+
177
+ weight = torch.zeros((1, T, H*8, W*8), dtype=torch_dtype, device=data_device)
178
+ values = torch.zeros((3, T, H*8, W*8), dtype=torch_dtype, device=data_device)
179
+
180
+ # Split tasks
181
+ tasks = []
182
+ for t in range(0, T, stride_time):
183
+ for h in range(0, H, stride_height):
184
+ for w in range(0, W, stride_width):
185
+ if (t-stride_time >= 0 and t-stride_time+batch_time >= T)\
186
+ or (h-stride_height >= 0 and h-stride_height+batch_height >= H)\
187
+ or (w-stride_width >= 0 and w-stride_width+batch_width >= W):
188
+ continue
189
+ tasks.append((t, t+batch_time, h, h+batch_height, w, w+batch_width))
190
+
191
+ # Run
192
+ for tl, tr, hl, hr, wl, wr in progress_bar(tasks):
193
+ sample_batch = sample[:, tl:tr, hl:hr, wl:wr].to(computation_device)
194
+ sample_batch = self.forward(sample_batch).to(data_device)
195
+ mask = self.build_mask(sample_batch, is_bound=(tl==0, tr>=T, hl==0, hr>=H, wl==0, wr>=W))
196
+ values[:, tl:tr, hl*8:hr*8, wl*8:wr*8] += sample_batch * mask
197
+ weight[:, tl:tr, hl*8:hr*8, wl*8:wr*8] += mask
198
+ values /= weight
199
+ return values
200
+
201
+
202
+ def state_dict_converter(self):
203
+ return SVDVAEDecoderStateDictConverter()
204
+
205
+
206
+ class SVDVAEDecoderStateDictConverter:
207
+ def __init__(self):
208
+ pass
209
+
210
+ def from_diffusers(self, state_dict):
211
+ static_rename_dict = {
212
+ "decoder.conv_in": "conv_in",
213
+ "decoder.mid_block.attentions.0.group_norm": "blocks.2.norm",
214
+ "decoder.mid_block.attentions.0.to_q": "blocks.2.transformer_blocks.0.to_q",
215
+ "decoder.mid_block.attentions.0.to_k": "blocks.2.transformer_blocks.0.to_k",
216
+ "decoder.mid_block.attentions.0.to_v": "blocks.2.transformer_blocks.0.to_v",
217
+ "decoder.mid_block.attentions.0.to_out.0": "blocks.2.transformer_blocks.0.to_out",
218
+ "decoder.up_blocks.0.upsamplers.0.conv": "blocks.11.conv",
219
+ "decoder.up_blocks.1.upsamplers.0.conv": "blocks.18.conv",
220
+ "decoder.up_blocks.2.upsamplers.0.conv": "blocks.25.conv",
221
+ "decoder.conv_norm_out": "conv_norm_out",
222
+ "decoder.conv_out": "conv_out",
223
+ "decoder.time_conv_out": "time_conv_out"
224
+ }
225
+ prefix_rename_dict = {
226
+ "decoder.mid_block.resnets.0.spatial_res_block": "blocks.0",
227
+ "decoder.mid_block.resnets.0.temporal_res_block": "blocks.1",
228
+ "decoder.mid_block.resnets.0.time_mixer": "blocks.1",
229
+ "decoder.mid_block.resnets.1.spatial_res_block": "blocks.3",
230
+ "decoder.mid_block.resnets.1.temporal_res_block": "blocks.4",
231
+ "decoder.mid_block.resnets.1.time_mixer": "blocks.4",
232
+
233
+ "decoder.up_blocks.0.resnets.0.spatial_res_block": "blocks.5",
234
+ "decoder.up_blocks.0.resnets.0.temporal_res_block": "blocks.6",
235
+ "decoder.up_blocks.0.resnets.0.time_mixer": "blocks.6",
236
+ "decoder.up_blocks.0.resnets.1.spatial_res_block": "blocks.7",
237
+ "decoder.up_blocks.0.resnets.1.temporal_res_block": "blocks.8",
238
+ "decoder.up_blocks.0.resnets.1.time_mixer": "blocks.8",
239
+ "decoder.up_blocks.0.resnets.2.spatial_res_block": "blocks.9",
240
+ "decoder.up_blocks.0.resnets.2.temporal_res_block": "blocks.10",
241
+ "decoder.up_blocks.0.resnets.2.time_mixer": "blocks.10",
242
+
243
+ "decoder.up_blocks.1.resnets.0.spatial_res_block": "blocks.12",
244
+ "decoder.up_blocks.1.resnets.0.temporal_res_block": "blocks.13",
245
+ "decoder.up_blocks.1.resnets.0.time_mixer": "blocks.13",
246
+ "decoder.up_blocks.1.resnets.1.spatial_res_block": "blocks.14",
247
+ "decoder.up_blocks.1.resnets.1.temporal_res_block": "blocks.15",
248
+ "decoder.up_blocks.1.resnets.1.time_mixer": "blocks.15",
249
+ "decoder.up_blocks.1.resnets.2.spatial_res_block": "blocks.16",
250
+ "decoder.up_blocks.1.resnets.2.temporal_res_block": "blocks.17",
251
+ "decoder.up_blocks.1.resnets.2.time_mixer": "blocks.17",
252
+
253
+ "decoder.up_blocks.2.resnets.0.spatial_res_block": "blocks.19",
254
+ "decoder.up_blocks.2.resnets.0.temporal_res_block": "blocks.20",
255
+ "decoder.up_blocks.2.resnets.0.time_mixer": "blocks.20",
256
+ "decoder.up_blocks.2.resnets.1.spatial_res_block": "blocks.21",
257
+ "decoder.up_blocks.2.resnets.1.temporal_res_block": "blocks.22",
258
+ "decoder.up_blocks.2.resnets.1.time_mixer": "blocks.22",
259
+ "decoder.up_blocks.2.resnets.2.spatial_res_block": "blocks.23",
260
+ "decoder.up_blocks.2.resnets.2.temporal_res_block": "blocks.24",
261
+ "decoder.up_blocks.2.resnets.2.time_mixer": "blocks.24",
262
+
263
+ "decoder.up_blocks.3.resnets.0.spatial_res_block": "blocks.26",
264
+ "decoder.up_blocks.3.resnets.0.temporal_res_block": "blocks.27",
265
+ "decoder.up_blocks.3.resnets.0.time_mixer": "blocks.27",
266
+ "decoder.up_blocks.3.resnets.1.spatial_res_block": "blocks.28",
267
+ "decoder.up_blocks.3.resnets.1.temporal_res_block": "blocks.29",
268
+ "decoder.up_blocks.3.resnets.1.time_mixer": "blocks.29",
269
+ "decoder.up_blocks.3.resnets.2.spatial_res_block": "blocks.30",
270
+ "decoder.up_blocks.3.resnets.2.temporal_res_block": "blocks.31",
271
+ "decoder.up_blocks.3.resnets.2.time_mixer": "blocks.31",
272
+ }
273
+ suffix_rename_dict = {
274
+ "norm1.weight": "norm1.weight",
275
+ "conv1.weight": "conv1.weight",
276
+ "norm2.weight": "norm2.weight",
277
+ "conv2.weight": "conv2.weight",
278
+ "conv_shortcut.weight": "conv_shortcut.weight",
279
+ "norm1.bias": "norm1.bias",
280
+ "conv1.bias": "conv1.bias",
281
+ "norm2.bias": "norm2.bias",
282
+ "conv2.bias": "conv2.bias",
283
+ "conv_shortcut.bias": "conv_shortcut.bias",
284
+ "mix_factor": "mix_factor",
285
+ }
286
+
287
+ state_dict_ = {}
288
+ for name in static_rename_dict:
289
+ state_dict_[static_rename_dict[name] + ".weight"] = state_dict[name + ".weight"]
290
+ state_dict_[static_rename_dict[name] + ".bias"] = state_dict[name + ".bias"]
291
+ for prefix_name in prefix_rename_dict:
292
+ for suffix_name in suffix_rename_dict:
293
+ name = prefix_name + "." + suffix_name
294
+ name_ = prefix_rename_dict[prefix_name] + "." + suffix_rename_dict[suffix_name]
295
+ if name in state_dict:
296
+ state_dict_[name_] = state_dict[name]
297
+
298
+ return state_dict_
299
+
300
+
301
+ def from_civitai(self, state_dict):
302
+ rename_dict = {
303
+ "first_stage_model.decoder.conv_in.bias": "conv_in.bias",
304
+ "first_stage_model.decoder.conv_in.weight": "conv_in.weight",
305
+ "first_stage_model.decoder.conv_out.bias": "conv_out.bias",
306
+ "first_stage_model.decoder.conv_out.time_mix_conv.bias": "time_conv_out.bias",
307
+ "first_stage_model.decoder.conv_out.time_mix_conv.weight": "time_conv_out.weight",
308
+ "first_stage_model.decoder.conv_out.weight": "conv_out.weight",
309
+ "first_stage_model.decoder.mid.attn_1.k.bias": "blocks.2.transformer_blocks.0.to_k.bias",
310
+ "first_stage_model.decoder.mid.attn_1.k.weight": "blocks.2.transformer_blocks.0.to_k.weight",
311
+ "first_stage_model.decoder.mid.attn_1.norm.bias": "blocks.2.norm.bias",
312
+ "first_stage_model.decoder.mid.attn_1.norm.weight": "blocks.2.norm.weight",
313
+ "first_stage_model.decoder.mid.attn_1.proj_out.bias": "blocks.2.transformer_blocks.0.to_out.bias",
314
+ "first_stage_model.decoder.mid.attn_1.proj_out.weight": "blocks.2.transformer_blocks.0.to_out.weight",
315
+ "first_stage_model.decoder.mid.attn_1.q.bias": "blocks.2.transformer_blocks.0.to_q.bias",
316
+ "first_stage_model.decoder.mid.attn_1.q.weight": "blocks.2.transformer_blocks.0.to_q.weight",
317
+ "first_stage_model.decoder.mid.attn_1.v.bias": "blocks.2.transformer_blocks.0.to_v.bias",
318
+ "first_stage_model.decoder.mid.attn_1.v.weight": "blocks.2.transformer_blocks.0.to_v.weight",
319
+ "first_stage_model.decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
320
+ "first_stage_model.decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
321
+ "first_stage_model.decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
322
+ "first_stage_model.decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
323
+ "first_stage_model.decoder.mid.block_1.mix_factor": "blocks.1.mix_factor",
324
+ "first_stage_model.decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
325
+ "first_stage_model.decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
326
+ "first_stage_model.decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
327
+ "first_stage_model.decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
328
+ "first_stage_model.decoder.mid.block_1.time_stack.in_layers.0.bias": "blocks.1.norm1.bias",
329
+ "first_stage_model.decoder.mid.block_1.time_stack.in_layers.0.weight": "blocks.1.norm1.weight",
330
+ "first_stage_model.decoder.mid.block_1.time_stack.in_layers.2.bias": "blocks.1.conv1.bias",
331
+ "first_stage_model.decoder.mid.block_1.time_stack.in_layers.2.weight": "blocks.1.conv1.weight",
332
+ "first_stage_model.decoder.mid.block_1.time_stack.out_layers.0.bias": "blocks.1.norm2.bias",
333
+ "first_stage_model.decoder.mid.block_1.time_stack.out_layers.0.weight": "blocks.1.norm2.weight",
334
+ "first_stage_model.decoder.mid.block_1.time_stack.out_layers.3.bias": "blocks.1.conv2.bias",
335
+ "first_stage_model.decoder.mid.block_1.time_stack.out_layers.3.weight": "blocks.1.conv2.weight",
336
+ "first_stage_model.decoder.mid.block_2.conv1.bias": "blocks.3.conv1.bias",
337
+ "first_stage_model.decoder.mid.block_2.conv1.weight": "blocks.3.conv1.weight",
338
+ "first_stage_model.decoder.mid.block_2.conv2.bias": "blocks.3.conv2.bias",
339
+ "first_stage_model.decoder.mid.block_2.conv2.weight": "blocks.3.conv2.weight",
340
+ "first_stage_model.decoder.mid.block_2.mix_factor": "blocks.4.mix_factor",
341
+ "first_stage_model.decoder.mid.block_2.norm1.bias": "blocks.3.norm1.bias",
342
+ "first_stage_model.decoder.mid.block_2.norm1.weight": "blocks.3.norm1.weight",
343
+ "first_stage_model.decoder.mid.block_2.norm2.bias": "blocks.3.norm2.bias",
344
+ "first_stage_model.decoder.mid.block_2.norm2.weight": "blocks.3.norm2.weight",
345
+ "first_stage_model.decoder.mid.block_2.time_stack.in_layers.0.bias": "blocks.4.norm1.bias",
346
+ "first_stage_model.decoder.mid.block_2.time_stack.in_layers.0.weight": "blocks.4.norm1.weight",
347
+ "first_stage_model.decoder.mid.block_2.time_stack.in_layers.2.bias": "blocks.4.conv1.bias",
348
+ "first_stage_model.decoder.mid.block_2.time_stack.in_layers.2.weight": "blocks.4.conv1.weight",
349
+ "first_stage_model.decoder.mid.block_2.time_stack.out_layers.0.bias": "blocks.4.norm2.bias",
350
+ "first_stage_model.decoder.mid.block_2.time_stack.out_layers.0.weight": "blocks.4.norm2.weight",
351
+ "first_stage_model.decoder.mid.block_2.time_stack.out_layers.3.bias": "blocks.4.conv2.bias",
352
+ "first_stage_model.decoder.mid.block_2.time_stack.out_layers.3.weight": "blocks.4.conv2.weight",
353
+ "first_stage_model.decoder.norm_out.bias": "conv_norm_out.bias",
354
+ "first_stage_model.decoder.norm_out.weight": "conv_norm_out.weight",
355
+ "first_stage_model.decoder.up.0.block.0.conv1.bias": "blocks.26.conv1.bias",
356
+ "first_stage_model.decoder.up.0.block.0.conv1.weight": "blocks.26.conv1.weight",
357
+ "first_stage_model.decoder.up.0.block.0.conv2.bias": "blocks.26.conv2.bias",
358
+ "first_stage_model.decoder.up.0.block.0.conv2.weight": "blocks.26.conv2.weight",
359
+ "first_stage_model.decoder.up.0.block.0.mix_factor": "blocks.27.mix_factor",
360
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "blocks.26.conv_shortcut.bias",
361
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "blocks.26.conv_shortcut.weight",
362
+ "first_stage_model.decoder.up.0.block.0.norm1.bias": "blocks.26.norm1.bias",
363
+ "first_stage_model.decoder.up.0.block.0.norm1.weight": "blocks.26.norm1.weight",
364
+ "first_stage_model.decoder.up.0.block.0.norm2.bias": "blocks.26.norm2.bias",
365
+ "first_stage_model.decoder.up.0.block.0.norm2.weight": "blocks.26.norm2.weight",
366
+ "first_stage_model.decoder.up.0.block.0.time_stack.in_layers.0.bias": "blocks.27.norm1.bias",
367
+ "first_stage_model.decoder.up.0.block.0.time_stack.in_layers.0.weight": "blocks.27.norm1.weight",
368
+ "first_stage_model.decoder.up.0.block.0.time_stack.in_layers.2.bias": "blocks.27.conv1.bias",
369
+ "first_stage_model.decoder.up.0.block.0.time_stack.in_layers.2.weight": "blocks.27.conv1.weight",
370
+ "first_stage_model.decoder.up.0.block.0.time_stack.out_layers.0.bias": "blocks.27.norm2.bias",
371
+ "first_stage_model.decoder.up.0.block.0.time_stack.out_layers.0.weight": "blocks.27.norm2.weight",
372
+ "first_stage_model.decoder.up.0.block.0.time_stack.out_layers.3.bias": "blocks.27.conv2.bias",
373
+ "first_stage_model.decoder.up.0.block.0.time_stack.out_layers.3.weight": "blocks.27.conv2.weight",
374
+ "first_stage_model.decoder.up.0.block.1.conv1.bias": "blocks.28.conv1.bias",
375
+ "first_stage_model.decoder.up.0.block.1.conv1.weight": "blocks.28.conv1.weight",
376
+ "first_stage_model.decoder.up.0.block.1.conv2.bias": "blocks.28.conv2.bias",
377
+ "first_stage_model.decoder.up.0.block.1.conv2.weight": "blocks.28.conv2.weight",
378
+ "first_stage_model.decoder.up.0.block.1.mix_factor": "blocks.29.mix_factor",
379
+ "first_stage_model.decoder.up.0.block.1.norm1.bias": "blocks.28.norm1.bias",
380
+ "first_stage_model.decoder.up.0.block.1.norm1.weight": "blocks.28.norm1.weight",
381
+ "first_stage_model.decoder.up.0.block.1.norm2.bias": "blocks.28.norm2.bias",
382
+ "first_stage_model.decoder.up.0.block.1.norm2.weight": "blocks.28.norm2.weight",
383
+ "first_stage_model.decoder.up.0.block.1.time_stack.in_layers.0.bias": "blocks.29.norm1.bias",
384
+ "first_stage_model.decoder.up.0.block.1.time_stack.in_layers.0.weight": "blocks.29.norm1.weight",
385
+ "first_stage_model.decoder.up.0.block.1.time_stack.in_layers.2.bias": "blocks.29.conv1.bias",
386
+ "first_stage_model.decoder.up.0.block.1.time_stack.in_layers.2.weight": "blocks.29.conv1.weight",
387
+ "first_stage_model.decoder.up.0.block.1.time_stack.out_layers.0.bias": "blocks.29.norm2.bias",
388
+ "first_stage_model.decoder.up.0.block.1.time_stack.out_layers.0.weight": "blocks.29.norm2.weight",
389
+ "first_stage_model.decoder.up.0.block.1.time_stack.out_layers.3.bias": "blocks.29.conv2.bias",
390
+ "first_stage_model.decoder.up.0.block.1.time_stack.out_layers.3.weight": "blocks.29.conv2.weight",
391
+ "first_stage_model.decoder.up.0.block.2.conv1.bias": "blocks.30.conv1.bias",
392
+ "first_stage_model.decoder.up.0.block.2.conv1.weight": "blocks.30.conv1.weight",
393
+ "first_stage_model.decoder.up.0.block.2.conv2.bias": "blocks.30.conv2.bias",
394
+ "first_stage_model.decoder.up.0.block.2.conv2.weight": "blocks.30.conv2.weight",
395
+ "first_stage_model.decoder.up.0.block.2.mix_factor": "blocks.31.mix_factor",
396
+ "first_stage_model.decoder.up.0.block.2.norm1.bias": "blocks.30.norm1.bias",
397
+ "first_stage_model.decoder.up.0.block.2.norm1.weight": "blocks.30.norm1.weight",
398
+ "first_stage_model.decoder.up.0.block.2.norm2.bias": "blocks.30.norm2.bias",
399
+ "first_stage_model.decoder.up.0.block.2.norm2.weight": "blocks.30.norm2.weight",
400
+ "first_stage_model.decoder.up.0.block.2.time_stack.in_layers.0.bias": "blocks.31.norm1.bias",
401
+ "first_stage_model.decoder.up.0.block.2.time_stack.in_layers.0.weight": "blocks.31.norm1.weight",
402
+ "first_stage_model.decoder.up.0.block.2.time_stack.in_layers.2.bias": "blocks.31.conv1.bias",
403
+ "first_stage_model.decoder.up.0.block.2.time_stack.in_layers.2.weight": "blocks.31.conv1.weight",
404
+ "first_stage_model.decoder.up.0.block.2.time_stack.out_layers.0.bias": "blocks.31.norm2.bias",
405
+ "first_stage_model.decoder.up.0.block.2.time_stack.out_layers.0.weight": "blocks.31.norm2.weight",
406
+ "first_stage_model.decoder.up.0.block.2.time_stack.out_layers.3.bias": "blocks.31.conv2.bias",
407
+ "first_stage_model.decoder.up.0.block.2.time_stack.out_layers.3.weight": "blocks.31.conv2.weight",
408
+ "first_stage_model.decoder.up.1.block.0.conv1.bias": "blocks.19.conv1.bias",
409
+ "first_stage_model.decoder.up.1.block.0.conv1.weight": "blocks.19.conv1.weight",
410
+ "first_stage_model.decoder.up.1.block.0.conv2.bias": "blocks.19.conv2.bias",
411
+ "first_stage_model.decoder.up.1.block.0.conv2.weight": "blocks.19.conv2.weight",
412
+ "first_stage_model.decoder.up.1.block.0.mix_factor": "blocks.20.mix_factor",
413
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "blocks.19.conv_shortcut.bias",
414
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "blocks.19.conv_shortcut.weight",
415
+ "first_stage_model.decoder.up.1.block.0.norm1.bias": "blocks.19.norm1.bias",
416
+ "first_stage_model.decoder.up.1.block.0.norm1.weight": "blocks.19.norm1.weight",
417
+ "first_stage_model.decoder.up.1.block.0.norm2.bias": "blocks.19.norm2.bias",
418
+ "first_stage_model.decoder.up.1.block.0.norm2.weight": "blocks.19.norm2.weight",
419
+ "first_stage_model.decoder.up.1.block.0.time_stack.in_layers.0.bias": "blocks.20.norm1.bias",
420
+ "first_stage_model.decoder.up.1.block.0.time_stack.in_layers.0.weight": "blocks.20.norm1.weight",
421
+ "first_stage_model.decoder.up.1.block.0.time_stack.in_layers.2.bias": "blocks.20.conv1.bias",
422
+ "first_stage_model.decoder.up.1.block.0.time_stack.in_layers.2.weight": "blocks.20.conv1.weight",
423
+ "first_stage_model.decoder.up.1.block.0.time_stack.out_layers.0.bias": "blocks.20.norm2.bias",
424
+ "first_stage_model.decoder.up.1.block.0.time_stack.out_layers.0.weight": "blocks.20.norm2.weight",
425
+ "first_stage_model.decoder.up.1.block.0.time_stack.out_layers.3.bias": "blocks.20.conv2.bias",
426
+ "first_stage_model.decoder.up.1.block.0.time_stack.out_layers.3.weight": "blocks.20.conv2.weight",
427
+ "first_stage_model.decoder.up.1.block.1.conv1.bias": "blocks.21.conv1.bias",
428
+ "first_stage_model.decoder.up.1.block.1.conv1.weight": "blocks.21.conv1.weight",
429
+ "first_stage_model.decoder.up.1.block.1.conv2.bias": "blocks.21.conv2.bias",
430
+ "first_stage_model.decoder.up.1.block.1.conv2.weight": "blocks.21.conv2.weight",
431
+ "first_stage_model.decoder.up.1.block.1.mix_factor": "blocks.22.mix_factor",
432
+ "first_stage_model.decoder.up.1.block.1.norm1.bias": "blocks.21.norm1.bias",
433
+ "first_stage_model.decoder.up.1.block.1.norm1.weight": "blocks.21.norm1.weight",
434
+ "first_stage_model.decoder.up.1.block.1.norm2.bias": "blocks.21.norm2.bias",
435
+ "first_stage_model.decoder.up.1.block.1.norm2.weight": "blocks.21.norm2.weight",
436
+ "first_stage_model.decoder.up.1.block.1.time_stack.in_layers.0.bias": "blocks.22.norm1.bias",
437
+ "first_stage_model.decoder.up.1.block.1.time_stack.in_layers.0.weight": "blocks.22.norm1.weight",
438
+ "first_stage_model.decoder.up.1.block.1.time_stack.in_layers.2.bias": "blocks.22.conv1.bias",
439
+ "first_stage_model.decoder.up.1.block.1.time_stack.in_layers.2.weight": "blocks.22.conv1.weight",
440
+ "first_stage_model.decoder.up.1.block.1.time_stack.out_layers.0.bias": "blocks.22.norm2.bias",
441
+ "first_stage_model.decoder.up.1.block.1.time_stack.out_layers.0.weight": "blocks.22.norm2.weight",
442
+ "first_stage_model.decoder.up.1.block.1.time_stack.out_layers.3.bias": "blocks.22.conv2.bias",
443
+ "first_stage_model.decoder.up.1.block.1.time_stack.out_layers.3.weight": "blocks.22.conv2.weight",
444
+ "first_stage_model.decoder.up.1.block.2.conv1.bias": "blocks.23.conv1.bias",
445
+ "first_stage_model.decoder.up.1.block.2.conv1.weight": "blocks.23.conv1.weight",
446
+ "first_stage_model.decoder.up.1.block.2.conv2.bias": "blocks.23.conv2.bias",
447
+ "first_stage_model.decoder.up.1.block.2.conv2.weight": "blocks.23.conv2.weight",
448
+ "first_stage_model.decoder.up.1.block.2.mix_factor": "blocks.24.mix_factor",
449
+ "first_stage_model.decoder.up.1.block.2.norm1.bias": "blocks.23.norm1.bias",
450
+ "first_stage_model.decoder.up.1.block.2.norm1.weight": "blocks.23.norm1.weight",
451
+ "first_stage_model.decoder.up.1.block.2.norm2.bias": "blocks.23.norm2.bias",
452
+ "first_stage_model.decoder.up.1.block.2.norm2.weight": "blocks.23.norm2.weight",
453
+ "first_stage_model.decoder.up.1.block.2.time_stack.in_layers.0.bias": "blocks.24.norm1.bias",
454
+ "first_stage_model.decoder.up.1.block.2.time_stack.in_layers.0.weight": "blocks.24.norm1.weight",
455
+ "first_stage_model.decoder.up.1.block.2.time_stack.in_layers.2.bias": "blocks.24.conv1.bias",
456
+ "first_stage_model.decoder.up.1.block.2.time_stack.in_layers.2.weight": "blocks.24.conv1.weight",
457
+ "first_stage_model.decoder.up.1.block.2.time_stack.out_layers.0.bias": "blocks.24.norm2.bias",
458
+ "first_stage_model.decoder.up.1.block.2.time_stack.out_layers.0.weight": "blocks.24.norm2.weight",
459
+ "first_stage_model.decoder.up.1.block.2.time_stack.out_layers.3.bias": "blocks.24.conv2.bias",
460
+ "first_stage_model.decoder.up.1.block.2.time_stack.out_layers.3.weight": "blocks.24.conv2.weight",
461
+ "first_stage_model.decoder.up.1.upsample.conv.bias": "blocks.25.conv.bias",
462
+ "first_stage_model.decoder.up.1.upsample.conv.weight": "blocks.25.conv.weight",
463
+ "first_stage_model.decoder.up.2.block.0.conv1.bias": "blocks.12.conv1.bias",
464
+ "first_stage_model.decoder.up.2.block.0.conv1.weight": "blocks.12.conv1.weight",
465
+ "first_stage_model.decoder.up.2.block.0.conv2.bias": "blocks.12.conv2.bias",
466
+ "first_stage_model.decoder.up.2.block.0.conv2.weight": "blocks.12.conv2.weight",
467
+ "first_stage_model.decoder.up.2.block.0.mix_factor": "blocks.13.mix_factor",
468
+ "first_stage_model.decoder.up.2.block.0.norm1.bias": "blocks.12.norm1.bias",
469
+ "first_stage_model.decoder.up.2.block.0.norm1.weight": "blocks.12.norm1.weight",
470
+ "first_stage_model.decoder.up.2.block.0.norm2.bias": "blocks.12.norm2.bias",
471
+ "first_stage_model.decoder.up.2.block.0.norm2.weight": "blocks.12.norm2.weight",
472
+ "first_stage_model.decoder.up.2.block.0.time_stack.in_layers.0.bias": "blocks.13.norm1.bias",
473
+ "first_stage_model.decoder.up.2.block.0.time_stack.in_layers.0.weight": "blocks.13.norm1.weight",
474
+ "first_stage_model.decoder.up.2.block.0.time_stack.in_layers.2.bias": "blocks.13.conv1.bias",
475
+ "first_stage_model.decoder.up.2.block.0.time_stack.in_layers.2.weight": "blocks.13.conv1.weight",
476
+ "first_stage_model.decoder.up.2.block.0.time_stack.out_layers.0.bias": "blocks.13.norm2.bias",
477
+ "first_stage_model.decoder.up.2.block.0.time_stack.out_layers.0.weight": "blocks.13.norm2.weight",
478
+ "first_stage_model.decoder.up.2.block.0.time_stack.out_layers.3.bias": "blocks.13.conv2.bias",
479
+ "first_stage_model.decoder.up.2.block.0.time_stack.out_layers.3.weight": "blocks.13.conv2.weight",
480
+ "first_stage_model.decoder.up.2.block.1.conv1.bias": "blocks.14.conv1.bias",
481
+ "first_stage_model.decoder.up.2.block.1.conv1.weight": "blocks.14.conv1.weight",
482
+ "first_stage_model.decoder.up.2.block.1.conv2.bias": "blocks.14.conv2.bias",
483
+ "first_stage_model.decoder.up.2.block.1.conv2.weight": "blocks.14.conv2.weight",
484
+ "first_stage_model.decoder.up.2.block.1.mix_factor": "blocks.15.mix_factor",
485
+ "first_stage_model.decoder.up.2.block.1.norm1.bias": "blocks.14.norm1.bias",
486
+ "first_stage_model.decoder.up.2.block.1.norm1.weight": "blocks.14.norm1.weight",
487
+ "first_stage_model.decoder.up.2.block.1.norm2.bias": "blocks.14.norm2.bias",
488
+ "first_stage_model.decoder.up.2.block.1.norm2.weight": "blocks.14.norm2.weight",
489
+ "first_stage_model.decoder.up.2.block.1.time_stack.in_layers.0.bias": "blocks.15.norm1.bias",
490
+ "first_stage_model.decoder.up.2.block.1.time_stack.in_layers.0.weight": "blocks.15.norm1.weight",
491
+ "first_stage_model.decoder.up.2.block.1.time_stack.in_layers.2.bias": "blocks.15.conv1.bias",
492
+ "first_stage_model.decoder.up.2.block.1.time_stack.in_layers.2.weight": "blocks.15.conv1.weight",
493
+ "first_stage_model.decoder.up.2.block.1.time_stack.out_layers.0.bias": "blocks.15.norm2.bias",
494
+ "first_stage_model.decoder.up.2.block.1.time_stack.out_layers.0.weight": "blocks.15.norm2.weight",
495
+ "first_stage_model.decoder.up.2.block.1.time_stack.out_layers.3.bias": "blocks.15.conv2.bias",
496
+ "first_stage_model.decoder.up.2.block.1.time_stack.out_layers.3.weight": "blocks.15.conv2.weight",
497
+ "first_stage_model.decoder.up.2.block.2.conv1.bias": "blocks.16.conv1.bias",
498
+ "first_stage_model.decoder.up.2.block.2.conv1.weight": "blocks.16.conv1.weight",
499
+ "first_stage_model.decoder.up.2.block.2.conv2.bias": "blocks.16.conv2.bias",
500
+ "first_stage_model.decoder.up.2.block.2.conv2.weight": "blocks.16.conv2.weight",
501
+ "first_stage_model.decoder.up.2.block.2.mix_factor": "blocks.17.mix_factor",
502
+ "first_stage_model.decoder.up.2.block.2.norm1.bias": "blocks.16.norm1.bias",
503
+ "first_stage_model.decoder.up.2.block.2.norm1.weight": "blocks.16.norm1.weight",
504
+ "first_stage_model.decoder.up.2.block.2.norm2.bias": "blocks.16.norm2.bias",
505
+ "first_stage_model.decoder.up.2.block.2.norm2.weight": "blocks.16.norm2.weight",
506
+ "first_stage_model.decoder.up.2.block.2.time_stack.in_layers.0.bias": "blocks.17.norm1.bias",
507
+ "first_stage_model.decoder.up.2.block.2.time_stack.in_layers.0.weight": "blocks.17.norm1.weight",
508
+ "first_stage_model.decoder.up.2.block.2.time_stack.in_layers.2.bias": "blocks.17.conv1.bias",
509
+ "first_stage_model.decoder.up.2.block.2.time_stack.in_layers.2.weight": "blocks.17.conv1.weight",
510
+ "first_stage_model.decoder.up.2.block.2.time_stack.out_layers.0.bias": "blocks.17.norm2.bias",
511
+ "first_stage_model.decoder.up.2.block.2.time_stack.out_layers.0.weight": "blocks.17.norm2.weight",
512
+ "first_stage_model.decoder.up.2.block.2.time_stack.out_layers.3.bias": "blocks.17.conv2.bias",
513
+ "first_stage_model.decoder.up.2.block.2.time_stack.out_layers.3.weight": "blocks.17.conv2.weight",
514
+ "first_stage_model.decoder.up.2.upsample.conv.bias": "blocks.18.conv.bias",
515
+ "first_stage_model.decoder.up.2.upsample.conv.weight": "blocks.18.conv.weight",
516
+ "first_stage_model.decoder.up.3.block.0.conv1.bias": "blocks.5.conv1.bias",
517
+ "first_stage_model.decoder.up.3.block.0.conv1.weight": "blocks.5.conv1.weight",
518
+ "first_stage_model.decoder.up.3.block.0.conv2.bias": "blocks.5.conv2.bias",
519
+ "first_stage_model.decoder.up.3.block.0.conv2.weight": "blocks.5.conv2.weight",
520
+ "first_stage_model.decoder.up.3.block.0.mix_factor": "blocks.6.mix_factor",
521
+ "first_stage_model.decoder.up.3.block.0.norm1.bias": "blocks.5.norm1.bias",
522
+ "first_stage_model.decoder.up.3.block.0.norm1.weight": "blocks.5.norm1.weight",
523
+ "first_stage_model.decoder.up.3.block.0.norm2.bias": "blocks.5.norm2.bias",
524
+ "first_stage_model.decoder.up.3.block.0.norm2.weight": "blocks.5.norm2.weight",
525
+ "first_stage_model.decoder.up.3.block.0.time_stack.in_layers.0.bias": "blocks.6.norm1.bias",
526
+ "first_stage_model.decoder.up.3.block.0.time_stack.in_layers.0.weight": "blocks.6.norm1.weight",
527
+ "first_stage_model.decoder.up.3.block.0.time_stack.in_layers.2.bias": "blocks.6.conv1.bias",
528
+ "first_stage_model.decoder.up.3.block.0.time_stack.in_layers.2.weight": "blocks.6.conv1.weight",
529
+ "first_stage_model.decoder.up.3.block.0.time_stack.out_layers.0.bias": "blocks.6.norm2.bias",
530
+ "first_stage_model.decoder.up.3.block.0.time_stack.out_layers.0.weight": "blocks.6.norm2.weight",
531
+ "first_stage_model.decoder.up.3.block.0.time_stack.out_layers.3.bias": "blocks.6.conv2.bias",
532
+ "first_stage_model.decoder.up.3.block.0.time_stack.out_layers.3.weight": "blocks.6.conv2.weight",
533
+ "first_stage_model.decoder.up.3.block.1.conv1.bias": "blocks.7.conv1.bias",
534
+ "first_stage_model.decoder.up.3.block.1.conv1.weight": "blocks.7.conv1.weight",
535
+ "first_stage_model.decoder.up.3.block.1.conv2.bias": "blocks.7.conv2.bias",
536
+ "first_stage_model.decoder.up.3.block.1.conv2.weight": "blocks.7.conv2.weight",
537
+ "first_stage_model.decoder.up.3.block.1.mix_factor": "blocks.8.mix_factor",
538
+ "first_stage_model.decoder.up.3.block.1.norm1.bias": "blocks.7.norm1.bias",
539
+ "first_stage_model.decoder.up.3.block.1.norm1.weight": "blocks.7.norm1.weight",
540
+ "first_stage_model.decoder.up.3.block.1.norm2.bias": "blocks.7.norm2.bias",
541
+ "first_stage_model.decoder.up.3.block.1.norm2.weight": "blocks.7.norm2.weight",
542
+ "first_stage_model.decoder.up.3.block.1.time_stack.in_layers.0.bias": "blocks.8.norm1.bias",
543
+ "first_stage_model.decoder.up.3.block.1.time_stack.in_layers.0.weight": "blocks.8.norm1.weight",
544
+ "first_stage_model.decoder.up.3.block.1.time_stack.in_layers.2.bias": "blocks.8.conv1.bias",
545
+ "first_stage_model.decoder.up.3.block.1.time_stack.in_layers.2.weight": "blocks.8.conv1.weight",
546
+ "first_stage_model.decoder.up.3.block.1.time_stack.out_layers.0.bias": "blocks.8.norm2.bias",
547
+ "first_stage_model.decoder.up.3.block.1.time_stack.out_layers.0.weight": "blocks.8.norm2.weight",
548
+ "first_stage_model.decoder.up.3.block.1.time_stack.out_layers.3.bias": "blocks.8.conv2.bias",
549
+ "first_stage_model.decoder.up.3.block.1.time_stack.out_layers.3.weight": "blocks.8.conv2.weight",
550
+ "first_stage_model.decoder.up.3.block.2.conv1.bias": "blocks.9.conv1.bias",
551
+ "first_stage_model.decoder.up.3.block.2.conv1.weight": "blocks.9.conv1.weight",
552
+ "first_stage_model.decoder.up.3.block.2.conv2.bias": "blocks.9.conv2.bias",
553
+ "first_stage_model.decoder.up.3.block.2.conv2.weight": "blocks.9.conv2.weight",
554
+ "first_stage_model.decoder.up.3.block.2.mix_factor": "blocks.10.mix_factor",
555
+ "first_stage_model.decoder.up.3.block.2.norm1.bias": "blocks.9.norm1.bias",
556
+ "first_stage_model.decoder.up.3.block.2.norm1.weight": "blocks.9.norm1.weight",
557
+ "first_stage_model.decoder.up.3.block.2.norm2.bias": "blocks.9.norm2.bias",
558
+ "first_stage_model.decoder.up.3.block.2.norm2.weight": "blocks.9.norm2.weight",
559
+ "first_stage_model.decoder.up.3.block.2.time_stack.in_layers.0.bias": "blocks.10.norm1.bias",
560
+ "first_stage_model.decoder.up.3.block.2.time_stack.in_layers.0.weight": "blocks.10.norm1.weight",
561
+ "first_stage_model.decoder.up.3.block.2.time_stack.in_layers.2.bias": "blocks.10.conv1.bias",
562
+ "first_stage_model.decoder.up.3.block.2.time_stack.in_layers.2.weight": "blocks.10.conv1.weight",
563
+ "first_stage_model.decoder.up.3.block.2.time_stack.out_layers.0.bias": "blocks.10.norm2.bias",
564
+ "first_stage_model.decoder.up.3.block.2.time_stack.out_layers.0.weight": "blocks.10.norm2.weight",
565
+ "first_stage_model.decoder.up.3.block.2.time_stack.out_layers.3.bias": "blocks.10.conv2.bias",
566
+ "first_stage_model.decoder.up.3.block.2.time_stack.out_layers.3.weight": "blocks.10.conv2.weight",
567
+ "first_stage_model.decoder.up.3.upsample.conv.bias": "blocks.11.conv.bias",
568
+ "first_stage_model.decoder.up.3.upsample.conv.weight": "blocks.11.conv.weight",
569
+ }
570
+ state_dict_ = {}
571
+ for name in state_dict:
572
+ if name in rename_dict:
573
+ param = state_dict[name]
574
+ if "blocks.2.transformer_blocks.0" in rename_dict[name]:
575
+ param = param.squeeze()
576
+ state_dict_[rename_dict[name]] = param
577
+ return state_dict_
diffsynth/models/svd_vae_encoder.py ADDED
@@ -0,0 +1,138 @@
1
+ from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
2
+
3
+
4
+ class SVDVAEEncoder(SDVAEEncoder):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.scaling_factor = 0.13025
8
+
9
+ def state_dict_converter(self):
10
+ return SVDVAEEncoderStateDictConverter()
11
+
12
+
13
+ class SVDVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
14
+ def __init__(self):
15
+ super().__init__()
16
+
17
+ def from_diffusers(self, state_dict):
18
+ return super().from_diffusers(state_dict)
19
+
20
+ def from_civitai(self, state_dict):
21
+ rename_dict = {
22
+ "conditioner.embedders.3.encoder.encoder.conv_in.bias": "conv_in.bias",
23
+ "conditioner.embedders.3.encoder.encoder.conv_in.weight": "conv_in.weight",
24
+ "conditioner.embedders.3.encoder.encoder.conv_out.bias": "conv_out.bias",
25
+ "conditioner.embedders.3.encoder.encoder.conv_out.weight": "conv_out.weight",
26
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
27
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
28
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
29
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
30
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
31
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
32
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
33
+ "conditioner.embedders.3.encoder.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
34
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
35
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
36
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
37
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
38
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
39
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
40
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
41
+ "conditioner.embedders.3.encoder.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
42
+ "conditioner.embedders.3.encoder.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
43
+ "conditioner.embedders.3.encoder.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
44
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
45
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
46
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
47
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
48
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
49
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
50
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
51
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
52
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
53
+ "conditioner.embedders.3.encoder.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
54
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
55
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
56
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
57
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
58
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
59
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
60
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
61
+ "conditioner.embedders.3.encoder.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
62
+ "conditioner.embedders.3.encoder.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
63
+ "conditioner.embedders.3.encoder.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
64
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
65
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
66
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
67
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
68
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
69
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
70
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
71
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
72
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
73
+ "conditioner.embedders.3.encoder.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
74
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
75
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
76
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
77
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
78
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
79
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
80
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
81
+ "conditioner.embedders.3.encoder.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
82
+ "conditioner.embedders.3.encoder.encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
83
+ "conditioner.embedders.3.encoder.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
84
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
85
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
86
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
87
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
88
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
89
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
90
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
91
+ "conditioner.embedders.3.encoder.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
92
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
93
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
94
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
95
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
96
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
97
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
98
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
99
+ "conditioner.embedders.3.encoder.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
100
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
101
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
102
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
103
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
104
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
105
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
106
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
107
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
108
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
109
+ "conditioner.embedders.3.encoder.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
110
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
111
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
112
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
113
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
114
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
115
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
116
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
117
+ "conditioner.embedders.3.encoder.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
118
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
119
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
120
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
121
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
122
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
123
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
124
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
125
+ "conditioner.embedders.3.encoder.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
126
+ "conditioner.embedders.3.encoder.encoder.norm_out.bias": "conv_norm_out.bias",
127
+ "conditioner.embedders.3.encoder.encoder.norm_out.weight": "conv_norm_out.weight",
128
+ "conditioner.embedders.3.encoder.quant_conv.bias": "quant_conv.bias",
129
+ "conditioner.embedders.3.encoder.quant_conv.weight": "quant_conv.weight",
130
+ }
131
+ state_dict_ = {}
132
+ for name in state_dict:
133
+ if name in rename_dict:
134
+ param = state_dict[name]
135
+ if "transformer_blocks" in rename_dict[name]:
136
+ param = param.squeeze()
137
+ state_dict_[rename_dict[name]] = param
138
+ return state_dict_
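For context, a hypothetical usage sketch of the converter defined above; the checkpoint filename and the `strict=False` choice are illustrative assumptions, not part of this commit:

from safetensors.torch import load_file
from diffsynth.models.svd_vae_encoder import SVDVAEEncoder

raw_state_dict = load_file("svd_xt.safetensors")  # original SVD checkpoint keys (filename assumed)
encoder = SVDVAEEncoder()
converted = encoder.state_dict_converter().from_civitai(raw_state_dict)  # rename keys and squeeze the attention params
encoder.load_state_dict(converted, strict=False)  # strict=False since only the VAE-encoder keys are converted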