multimodalart HF staff commited on
Commit
ad569d5
1 Parent(s): 2a2c118

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -15
app.py CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
2
  import torch
3
  from diffusers import StableDiffusionXLPipeline, AutoencoderKL
4
  from huggingface_hub import hf_hub_download
 
5
  from share_btn import community_icon_html, loading_icon_html, share_js
6
  from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
7
  import lora
@@ -26,12 +27,19 @@ with open("sdxl_loras.json", "r") as file:
26
  }
27
  for item in data
28
  ]
29
- print(sdxl_loras)
30
- saved_names = [
31
- hf_hub_download(item["repo"], item["weights"]) for item in sdxl_loras
32
- ]
33
 
34
- device = "cuda" # replace this to `mps` if on a MacOS Silicon
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  vae = AutoencoderKL.from_pretrained(
37
  "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
@@ -40,14 +48,13 @@ pipe = StableDiffusionXLPipeline.from_pretrained(
40
  "stabilityai/stable-diffusion-xl-base-1.0",
41
  vae=vae,
42
  torch_dtype=torch.float16,
43
- ).to("cpu")
44
  original_pipe = copy.deepcopy(pipe)
45
  pipe.to(device)
46
 
47
  last_lora = ""
48
  last_merged = False
49
 
50
-
51
  def update_selection(selected_state: gr.SelectData):
52
  lora_repo = sdxl_loras[selected_state.index]["repo"]
53
  instance_prompt = sdxl_loras[selected_state.index]["trigger_word"]
@@ -128,7 +135,7 @@ def merge_incompatible_lora(full_path_lora, lora_scale):
128
  del lora_model
129
  gc.collect()
130
 
131
- def run_lora(prompt, negative, lora_scale, selected_state):
132
  global last_lora, last_merged, pipe
133
 
134
  if negative == "":
@@ -138,7 +145,8 @@ def run_lora(prompt, negative, lora_scale, selected_state):
138
  raise gr.Error("You must select a LoRA")
139
  repo_name = sdxl_loras[selected_state.index]["repo"]
140
  weight_name = sdxl_loras[selected_state.index]["weights"]
141
- full_path_lora = saved_names[selected_state.index]
 
142
  cross_attention_kwargs = None
143
  if last_lora != repo_name:
144
  if last_merged:
@@ -148,17 +156,17 @@ def run_lora(prompt, negative, lora_scale, selected_state):
148
  pipe.to(device)
149
  else:
150
  pipe.unload_lora_weights()
 
151
  is_compatible = sdxl_loras[selected_state.index]["is_compatible"]
152
  if is_compatible:
153
- pipe.load_lora_weights(full_path_lora)
154
- cross_attention_kwargs = {"scale": lora_scale}
155
  else:
156
  is_pivotal = sdxl_loras[selected_state.index]["is_pivotal"]
157
  if(is_pivotal):
 
 
158
 
159
- pipe.load_lora_weights(full_path_lora)
160
- cross_attention_kwargs = {"scale": lora_scale}
161
-
162
  #Add the textual inversion embeddings from pivotal tuning models
163
  text_embedding_name = sdxl_loras[selected_state.index]["text_embedding_weights"]
164
  text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
@@ -177,7 +185,6 @@ def run_lora(prompt, negative, lora_scale, selected_state):
177
  height=768,
178
  num_inference_steps=20,
179
  guidance_scale=7.5,
180
- cross_attention_kwargs=cross_attention_kwargs,
181
  ).images[0]
182
  last_lora = repo_name
183
  gc.collect()
 
2
  import torch
3
  from diffusers import StableDiffusionXLPipeline, AutoencoderKL
4
  from huggingface_hub import hf_hub_download
5
+ from safetensors.torch import load_file
6
  from share_btn import community_icon_html, loading_icon_html, share_js
7
  from cog_sdxl_dataset_and_utils import TokenEmbeddingsHandler
8
  import lora
 
27
  }
28
  for item in data
29
  ]
 
 
 
 
30
 
31
+ device = "cuda"
32
+
33
+ for item in sdxl_loras:
34
+ saved_name = hf_hub_download(item["repo"], item["weights"])
35
+
36
+ if not saved_name.endswith('.safetensors'):
37
+ state_dict = torch.load(saved_name)
38
+ else:
39
+ state_dict = load_file(saved_name)
40
+
41
+ item["saved_name"] = saved_name
42
+ item["state_dict"] = state_dict #{k: v.to(device=device, dtype=torch.float16) for k, v in state_dict.items() if torch.is_tensor(v)}
43
 
44
  vae = AutoencoderKL.from_pretrained(
45
  "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
 
48
  "stabilityai/stable-diffusion-xl-base-1.0",
49
  vae=vae,
50
  torch_dtype=torch.float16,
51
+ )
52
  original_pipe = copy.deepcopy(pipe)
53
  pipe.to(device)
54
 
55
  last_lora = ""
56
  last_merged = False
57
 
 
58
  def update_selection(selected_state: gr.SelectData):
59
  lora_repo = sdxl_loras[selected_state.index]["repo"]
60
  instance_prompt = sdxl_loras[selected_state.index]["trigger_word"]
 
135
  del lora_model
136
  gc.collect()
137
 
138
+ def run_lora(prompt, negative, lora_scale, selected_state, progress=gr.Progress(track_tqdm=True)):
139
  global last_lora, last_merged, pipe
140
 
141
  if negative == "":
 
145
  raise gr.Error("You must select a LoRA")
146
  repo_name = sdxl_loras[selected_state.index]["repo"]
147
  weight_name = sdxl_loras[selected_state.index]["weights"]
148
+ full_path_lora = sdxl_loras[selected_state.index]["saved_name"]
149
+ loaded_state_dict = sdxl_loras[selected_state.index]["state_dict"]
150
  cross_attention_kwargs = None
151
  if last_lora != repo_name:
152
  if last_merged:
 
156
  pipe.to(device)
157
  else:
158
  pipe.unload_lora_weights()
159
+ pipe.unfuse_lora()
160
  is_compatible = sdxl_loras[selected_state.index]["is_compatible"]
161
  if is_compatible:
162
+ pipe.load_lora_weights(loaded_state_dict)
163
+ pipe.fuse_lora(lora_scale)
164
  else:
165
  is_pivotal = sdxl_loras[selected_state.index]["is_pivotal"]
166
  if(is_pivotal):
167
+ pipe.load_lora_weights(loaded_state_dict)
168
+ pipe.fuse_lora(lora_scale)
169
 
 
 
 
170
  #Add the textual inversion embeddings from pivotal tuning models
171
  text_embedding_name = sdxl_loras[selected_state.index]["text_embedding_weights"]
172
  text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
 
185
  height=768,
186
  num_inference_steps=20,
187
  guidance_scale=7.5,
 
188
  ).images[0]
189
  last_lora = repo_name
190
  gc.collect()