# text2image / app.py
# Last commit by iohanngrig: "Update app.py" (d69cde5, verified)
# File size: 2.91 kB; Hugging Face security scan: no virus detected.
import streamlit as st
from io import BytesIO
from typing import Literal
from diffusers import StableDiffusionPipeline
import torch
import time
# Fixed seed so generations are repeatable. torch.manual_seed seeds torch's
# global RNG and returns the default torch.Generator, which text2image passes
# to the diffusion pipeline.
seed = 42
generator = torch.manual_seed(seed)
# How many times text2image runs the pipeline per request. Only the image from
# the last run is returned, so any value above 1 just repeats the (slow)
# generation and throws the earlier results away.
NUM_ITERS_TO_RUN = 1
# Denoising steps per generation: fewer is faster, more is usually higher quality.
NUM_INFERENCE_STEPS = 20
# Images produced per pipeline run (text2image returns only the first).
NUM_IMAGES_PER_PROMPT = 1
def _load_pipeline(repo_id: str) -> "StableDiffusionPipeline":
    """Load ``repo_id`` as a StableDiffusionPipeline on the best available device.

    When CUDA is available the weights are loaded in float16 and moved to the
    GPU; otherwise they are loaded in float32 and left on the CPU.
    ``use_safetensors=True`` makes loading require ``.safetensors`` weights.
    """
    use_gpu = torch.cuda.is_available()
    print("Using GPU" if use_gpu else "Using CPU")
    pipeline = StableDiffusionPipeline.from_pretrained(
        repo_id,
        # Half precision only on GPU; the CPU path keeps full float32 precision.
        torch_dtype=torch.float16 if use_gpu else torch.float32,
        use_safetensors=True,
    )
    return pipeline.to("cuda") if use_gpu else pipeline


def text2image(
    prompt: str,
    repo_id: Literal[
        "dreamlike-art/dreamlike-photoreal-2.0",
        "hakurei/waifu-diffusion",
        "prompthero/openjourney",
        "stabilityai/stable-diffusion-2-1",
        "runwayml/stable-diffusion-v1-5",
        "nota-ai/bk-sdm-small",
        "CompVis/stable-diffusion-v1-4",
    ],
):
    """Generate an image for ``prompt`` with the diffusion model ``repo_id``.

    The pipeline is loaded from ``repo_id`` on every call. It then runs
    ``NUM_ITERS_TO_RUN`` times (at least once), using ``NUM_INFERENCE_STEPS``
    denoising steps and the module-level seeded ``generator``.

    Args:
        prompt: Text prompt describing the desired image.
        repo_id: Hugging Face model repository to load.

    Returns:
        ``(image, start, end)``: the first image from the last pipeline run,
        plus ``time.time()`` timestamps taken before model loading and after
        generation (so ``end - start`` includes load time).
    """
    start = time.time()
    # NOTE(review): the model is reloaded on every request. Caching it (e.g. with
    # st.cache_resource) would avoid repeated multi-GB loads. However, the cached
    # pipeline would then be shared by concurrent sessions, so confirm the
    # pipeline/scheduler is safe to call concurrently before doing that.
    pipeline = _load_pipeline(repo_id)
    # Each run draws fresh noise from the shared generator and only the last
    # result is kept. max(1, ...) guarantees ``images`` is bound even if the
    # constant is set to 0.
    for _ in range(max(1, NUM_ITERS_TO_RUN)):
        images = pipeline(
            prompt,
            num_inference_steps=NUM_INFERENCE_STEPS,
            generator=generator,
            num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
        ).images
    end = time.time()
    return images[0], start, end
def _format_duration(seconds: float) -> str:
    """Format a duration in seconds as ``HH:MM:SS.ss``.

    The value is rounded to hundredths *before* it is split into units, so
    59.999 s renders as ``00:01:00.00`` rather than ``00:00:60.00``. Negative
    inputs (possible if the wall clock steps backwards) are clamped to zero.
    """
    total = max(0.0, round(seconds, 2))
    hours, rem = divmod(total, 3600)
    minutes, secs = divmod(rem, 60)
    return "{:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), secs)


def app():
    """Render the text-to-image page and generate an image on submit.

    The user enters a prompt and picks a model. On submit, the image is
    generated with ``text2image``, the elapsed time is reported, and the
    image is shown with a PNG download button. Empty prompts are rejected
    with a warning instead of starting a long generation.
    """
    st.header("Text-to-image Web App")
    st.subheader("Powered by Hugging Face")
    user_input = st.text_area(
        "Enter your text prompt below and click the button to submit."
    )
    option = st.selectbox(
        "Select model (in order of processing time)",
        (
            "nota-ai/bk-sdm-small",
            "CompVis/stable-diffusion-v1-4",
            "runwayml/stable-diffusion-v1-5",
            "prompthero/openjourney",
            "hakurei/waifu-diffusion",
            "stabilityai/stable-diffusion-2-1",
            "dreamlike-art/dreamlike-photoreal-2.0",
        ),
    )
    with st.form("my_form"):
        submit = st.form_submit_button(label="Submit text prompt")
    # Results are rendered outside the form: st.download_button is not
    # allowed inside an st.form.
    if not submit:
        return
    if not (user_input and user_input.strip()):
        st.warning("Please enter a text prompt before submitting.")
        return
    with st.spinner(text="Generating image ... It may take up to 20 minutes."):
        im, start, end = text2image(prompt=user_input, repo_id=option)
    with BytesIO() as buf:
        im.save(buf, format="PNG")
        byte_im = buf.getvalue()
    st.success(f"Processing time: {_format_duration(end - start)}.")
    st.image(im)
    st.download_button(
        label="Click here to download",
        data=byte_im,
        file_name="generated_image.png",
        mime="image/png",
    )
if __name__ == "__main__":
    # Build the page only when this file is the entry script (presumably
    # launched via `streamlit run`), never as a side effect of importing it.
    app()