narugo1992
dev(narugo): use new models
e8760fb
raw
history blame contribute delete
No virus
7.53 kB
import json
import logging
import os
import re
import shutil
from functools import lru_cache
from typing import Optional, List, Tuple, Mapping
import gradio as gr
import numpy as np
from PIL import Image
from hbutils.system import pip_install
from huggingface_hub import hf_hub_download
def _ensure_onnxruntime():
try:
import onnxruntime
except (ImportError, ModuleNotFoundError):
logging.warning('Onnx runtime not installed, preparing to install ...')
if shutil.which('nvidia-smi'):
logging.info('Installing onnxruntime-gpu ...')
pip_install(['onnxruntime-gpu'], silent=True)
else:
logging.info('Installing onnxruntime (cpu) ...')
pip_install(['onnxruntime'], silent=True)
_ensure_onnxruntime()
from onnxruntime import get_available_providers, get_all_providers, InferenceSession, SessionOptions, \
GraphOptimizationLevel
alias = {
'gpu': "CUDAExecutionProvider",
"trt": "TensorrtExecutionProvider",
}
def get_onnx_provider(provider: Optional[str] = None):
if not provider:
if "CUDAExecutionProvider" in get_available_providers():
return "CUDAExecutionProvider"
else:
return "CPUExecutionProvider"
elif provider.lower() in alias:
return alias[provider.lower()]
else:
for p in get_all_providers():
if provider.lower() == p.lower() or f'{provider}ExecutionProvider'.lower() == p.lower():
return p
raise ValueError(f'One of the {get_all_providers()!r} expected, '
f'but unsupported provider {provider!r} found.')
def resize(pic: Image.Image, size: int, keep_ratio: float = True) -> Image.Image:
if not keep_ratio:
target_size = (size, size)
else:
min_edge = min(pic.size)
target_size = (
int(pic.size[0] / min_edge * size),
int(pic.size[1] / min_edge * size),
)
target_size = (
(target_size[0] // 4) * 4,
(target_size[1] // 4) * 4,
)
return pic.resize(target_size, resample=Image.Resampling.BILINEAR)
def to_tensor(pic: Image.Image):
img: np.ndarray = np.array(pic, np.uint8, copy=True)
img = img.reshape(pic.size[1], pic.size[0], len(pic.getbands()))
# put it from HWC to CHW format
img = img.transpose((2, 0, 1))
return img.astype(np.float32) / 255
def fill_background(pic: Image.Image, background: str = 'white') -> Image.Image:
if pic.mode == 'RGB':
return pic
if pic.mode != 'RGBA':
pic = pic.convert('RGBA')
background = background or 'white'
result = Image.new('RGBA', pic.size, background)
result.paste(pic, (0, 0), pic)
return result.convert('RGB')
def image_to_tensor(pic: Image.Image, size: int = 512, keep_ratio: float = True, background: str = 'white'):
return to_tensor(resize(fill_background(pic, background), size, keep_ratio))
MODELS = [
'ml_caformer_m36_dec-5-97527.onnx',
'ml_caformer_m36_dec-3-80000.onnx',
'TResnet-D-FLq_ema_6-30000.onnx',
'TResnet-D-FLq_ema_6-10000.onnx',
'TResnet-D-FLq_ema_4-10000.onnx',
'TResnet-D-FLq_ema_2-40000.onnx',
]
DEFAULT_MODEL = MODELS[0]
def get_onnx_model_file(name=DEFAULT_MODEL):
return hf_hub_download(
repo_id='deepghs/ml-danbooru-onnx',
filename=name,
)
@lru_cache()
def _open_onnx_model(ckpt: str, provider: str) -> InferenceSession:
options = SessionOptions()
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
if provider == "CPUExecutionProvider":
options.intra_op_num_threads = os.cpu_count()
logging.info(f'Model {ckpt!r} loaded with provider {provider!r}')
return InferenceSession(ckpt, options, [provider])
def load_classes() -> List[str]:
classes_file = hf_hub_download(
repo_id='deepghs/ml-danbooru-onnx',
filename='classes.json',
)
with open(classes_file, 'r', encoding='utf-8') as f:
return json.load(f)
def get_tags_from_image(pic: Image.Image, threshold: float = 0.7, size: int = 512, keep_ratio: bool = False,
model_name=DEFAULT_MODEL):
real_input = image_to_tensor(pic, size, keep_ratio)
real_input = real_input.reshape(1, *real_input.shape)
model = _open_onnx_model(get_onnx_model_file(model_name), get_onnx_provider('cpu'))
native_output, = model.run(['output'], {'input': real_input})
output = (1 / (1 + np.exp(-native_output))).reshape(-1)
tags = load_classes()
pairs = sorted([(tags[i], ratio) for i, ratio in enumerate(output)], key=lambda x: (-x[1], x[0]))
return {tag: float(ratio) for tag, ratio in pairs if ratio >= threshold}
RE_SPECIAL = re.compile(r'([\\()])')
def image_to_mldanbooru_tags(pic: Image.Image, threshold: float, size: int, keep_ratio: bool, model: str,
use_spaces: bool, use_escape: bool, include_ranks: bool, score_descend: bool) \
-> Tuple[str, Mapping[str, float]]:
filtered_tags = get_tags_from_image(pic, threshold, size, keep_ratio, model)
text_items = []
tags_pairs = filtered_tags.items()
if score_descend:
tags_pairs = sorted(tags_pairs, key=lambda x: (-x[1], x[0]))
for tag, score in tags_pairs:
tag_outformat = tag
if use_spaces:
tag_outformat = tag_outformat.replace('_', ' ')
if use_escape:
tag_outformat = re.sub(RE_SPECIAL, r'\\\1', tag_outformat)
if include_ranks:
tag_outformat = f"({tag_outformat}:{score:.3f})"
text_items.append(tag_outformat)
output_text = ', '.join(text_items)
return output_text, filtered_tags
if __name__ == '__main__':
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
gr_input_image = gr.Image(type='pil', label='Original Image')
with gr.Row():
gr_threshold = gr.Slider(0.0, 1.0, 0.7, label='Tagging Confidence Threshold')
# gr_image_size = gr.Slider(128, 960, 640, step=32, label='Image for Recognition')
gr_image_size = gr.Slider(128, 960, 448, step=32, label='Image for Recognition')
gr_keep_ratio = gr.Checkbox(value=False, label='Keep the Ratio')
with gr.Row():
gr_model = gr.Dropdown(MODELS, value=DEFAULT_MODEL, label='Model')
with gr.Row():
gr_space = gr.Checkbox(value=False, label='Use Space Instead Of _')
gr_escape = gr.Checkbox(value=True, label='Use Text Escape')
gr_confidence = gr.Checkbox(value=False, label='Keep Confidences')
gr_order = gr.Checkbox(value=True, label='Descend By Confidence')
gr_btn_submit = gr.Button(value='Tagging', variant='primary')
with gr.Column():
with gr.Tabs():
with gr.Tab("Tags"):
gr_tags = gr.Label(label='Tags')
with gr.Tab("Exported Text"):
gr_output_text = gr.TextArea(label='Exported Text')
gr_btn_submit.click(
image_to_mldanbooru_tags,
inputs=[
gr_input_image, gr_threshold, gr_image_size,
gr_keep_ratio, gr_model,
gr_space, gr_escape, gr_confidence, gr_order
],
outputs=[gr_output_text, gr_tags],
)
demo.queue(os.cpu_count()).launch()