|
import gradio as gr |
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel |
|
import requests |
|
from PIL import Image |
|
import re |
|
|
|
# Hugging Face checkpoint shared by the processor and the model: the processor
# prepares pixel_values for the model and later decodes the generated ids, so
# both halves must come from the same checkpoint.
_TROCR_CHECKPOINT = "microsoft/trocr-base-handwritten"
processor = TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT)
|
|
|
# Compiled once at import time; matches any single character that is not an
# ASCII letter.
_NON_ALPHABET_RE = re.compile('[^a-zA-Z]')


def remove_non_alphabet_chars(s):
    """Return *s* with every character outside ASCII ``A-Z`` / ``a-z`` deleted.

    Whitespace, digits, punctuation and non-ASCII letters (e.g. ``é``) are all
    removed, so ``"Hello, World 42!"`` becomes ``"HelloWorld"``.
    """
    return _NON_ALPHABET_RE.sub('', s)
|
|
|
|
|
# Sample handwriting images. Each one is downloaded at startup and saved as
# image_<index>.png so the Gradio UI can offer it as a clickable example.
urls = [
    'https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg',
    'https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSoolxi9yWGAT5SLZShv8vVd0bz47UWRzQC19fDTeE8GmGv_Rn-PCF1pP1rrUx8kOjA4gg&usqp=CAU',
    'https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRNYtTuSBpZPV_nkBYPMFwVVD9asZOPgHww4epu9EqWgDmXW--sE2o8og40ZfDGo87j5w&usqp=CAU',
]

# Seconds to wait on each example download. Without a timeout a stalled
# server would hang app startup indefinitely.
_DOWNLOAD_TIMEOUT = 30

for idx, url in enumerate(urls):
    # `with` releases the HTTP connection once the image has been written.
    with requests.get(url, stream=True, timeout=_DOWNLOAD_TIMEOUT) as response:
        # Fail with a clear HTTPError on 4xx/5xx. Otherwise PIL would try to
        # parse an HTML error page and raise a confusing
        # UnidentifiedImageError.
        response.raise_for_status()
        # Have urllib3 undo any gzip/deflate Content-Encoding before PIL reads
        # the raw byte stream.
        response.raw.decode_content = True
        with Image.open(response.raw) as image:
            image.save(f"image_{idx}.png")
|
|
|
def process_image(image):
    """Transcribe the handwriting in *image* with TrOCR and return the text.

    The model output is cleaned with ``remove_non_alphabet_chars`` one word
    at a time, so only ASCII letters survive but the single spaces between
    words are kept. Cleaning the whole string at once would delete the
    spaces and glue a multi-word line into one token.

    Args:
        image: A ``PIL.Image.Image``; the Gradio input component below uses
            ``type="pil"``. ``None`` when the user submits without uploading
            an image.

    Returns:
        The cleaned transcription (words separated by single spaces), or
        ``""`` when *image* is ``None``.
    """
    # Gradio calls the function with None when no image was provided; return
    # an empty result instead of crashing inside the processor.
    if image is None:
        return ""

    # Screenshots are frequently RGBA, greyscale or palette PNGs. TrOCR's
    # image preprocessing works on 3-channel input (its usage examples call
    # .convert("RGB")), so normalise the mode first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    pixel_values = processor(image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0]

    # Clean each whitespace-separated word independently and drop words that
    # end up empty (e.g. a standalone "." or "42").
    cleaned_words = (
        remove_non_alphabet_chars(word) for word in generated_text.split()
    )
    return " ".join(word for word in cleaned_words if word)
|
|
|
# Static text shown on the Gradio page, plus the example inputs.
title = "Interactive demo: TrOCR - Handwriting Text Recognition"

description = (
    "Short Demo for the handwriting recognition component of the Receipt OCR tool. "
    "Upload an image of handwriting which you want the model to read. "
    "I recommend screenshotting words on here: https://www.calligrapher.ai/"
)

# HTML footer linking the TrOCR paper and the upstream repository.
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2109.10282'>TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models</a>"
    " | "
    "<a href='https://github.com/microsoft/unilm/tree/master/trocr'>Github Repo</a>"
    "</p>"
)

# One row per sample image saved at startup (image_0.png .. image_2.png).
# Each row holds the value for the interface's single image input.
examples = [[f"image_{n}.png"] for n in range(3)]
|
|
|
# Image in, text out: Gradio hands the uploaded picture to process_image as a
# PIL image and shows the returned string in a textbox.
image_input = gr.Image(type="pil")
text_output = gr.Textbox()

iface = gr.Interface(
    fn=process_image,
    inputs=image_input,
    outputs=text_output,
    title=title,
    description=description,
    article=article,
    examples=examples,
)

# debug=True keeps launch() blocking and surfaces errors in the console;
# share=True requests a temporary public Gradio link.
iface.launch(share=True, debug=True)