# LongCite / app.py
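
# Install flash-attn at Space startup; FLASH_ATTENTION_SKIP_CUDA_BUILD avoids
# compiling the CUDA extension from source.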
import subprocess
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
import docx
import PyPDF2
import spaces
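
# Convert an uploaded document (txt/md/py, pdf, or docx) into plain text.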
def convert_to_txt(file):
    # `file` is a filesystem path (render_context passes file.name).
    doc_type = file.split(".")[-1].strip()
    if doc_type in ["txt", "md", "py"]:
        # Read plain-text files from the path (a path string has no .read()).
        with open(file, "rb") as f:
            data = [f.read().decode("utf-8")]
    elif doc_type in ["pdf"]:
        pdf_reader = PyPDF2.PdfReader(file)
        data = [
            pdf_reader.pages[i].extract_text() for i in range(len(pdf_reader.pages))
        ]
    elif doc_type in ["docx"]:
        doc = docx.Document(file)
        data = [p.text for p in doc.paragraphs]
    else:
        raise gr.Error(f"ERROR: unsupported document type: {doc_type}")
    text = "\n\n".join(data)
    return text
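
# Load the LongCite tokenizer and model (bfloat16 + FlashAttention-2 on the GPU).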
model_name = "THUDM/LongCite-glm4-9b"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device="cuda",
    attn_implementation="flash_attention_2",
)
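
# CSS prepended to every HTML fragment rendered below.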
html_styles = """<style>
.reference {
    color: blue;
    text-decoration: underline;
}
.highlight {
    background-color: yellow;
}
.label {
    font-family: sans-serif;
    font-size: 16px;
    font-weight: bold;
}
.Bold {
    font-weight: bold;
}
.statement {
    background-color: lightgrey;
}
</style>\n"""
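
# Escape HTML special characters and turn newlines into <br> tags.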
def process_text(text):
    special_char = {
        "&": "&amp;",
        "'": "&apos;",
        '"': "&quot;",
        "<": "&lt;",
        ">": "&gt;",
        "\n": "<br>",
    }
    for x, y in special_char.items():
        text = text.replace(x, y)
    return text
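
# Render the model's statements as HTML. Returns the answer HTML, the HTML of all
# citations, the HTML of the clicked statement's citations (if any), and a mapping
# from citation markers such as "[1,2]" to the index of the statement they follow.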
def convert_to_html(statements, clicked=-1):
    html = html_styles + '<br><span class="label">Answer:</span><br>\n'
    all_cite_html = []
    clicked_cite_html = None
    cite_num2idx = {}
    idx = 0
    for i, js in enumerate(statements):
        statement, citations = process_text(js["statement"]), js["citation"]
        if clicked == i:
            html += f"""<span class="statement">{statement}</span>"""
        else:
            html += f"<span>{statement}</span>"
        if citations:
            cite_html = []
            idxs = []
            for c in citations:
                idx += 1
                idxs.append(str(idx))
                cite = (
                    "[Sentence: {}-{}\t|\tChar: {}-{}]<br>\n<span {}>{}</span>".format(
                        c["start_sentence_idx"],
                        c["end_sentence_idx"],
                        c["start_char_idx"],
                        c["end_char_idx"],
                        'class="highlight"' if clicked == i else "",
                        process_text(c["cite"].strip()),
                    )
                )
                cite_html.append(
                    f"""<span><span class="Bold">Snippet [{idx}]:</span><br>{cite}</span>"""
                )
            all_cite_html.extend(cite_html)
            cite_num = "[{}]".format(",".join(idxs))
            cite_num2idx[cite_num] = i
            cite_num_html = """ <span class="reference" style="color: blue" id={}>{}</span>""".format(
                i, cite_num
            )
            html += cite_num_html
            html += "\n"
            if clicked == i:
                clicked_cite_html = (
                    html_styles
                    + """<br><span class="label">Citations of current statement:</span><br><div style="overflow-y: auto; padding: 20px; border: 0px dashed black; border-radius: 6px; background-color: #EFF2F6;">{}</div>""".format(
                        "<br><br>\n".join(cite_html)
                    )
                )
    all_cite_html = (
        html_styles
        + """<br><span class="label">All citations:</span><br>\n<div style="overflow-y: auto; padding: 20px; border: 0px dashed black; border-radius: 6px; background-color: #EFF2F6;">{}</div>""".format(
            "<br><br>\n".join(all_cite_html).replace(
                '<span class="highlight">', "<span>"
            )
            if len(all_cite_html)
            else "No citation in the answer"
        )
    )
    return html, all_cite_html, clicked_cite_html, cite_num2idx
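
# Show the extracted text of an uploaded file in the "Document content" textbox.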
def render_context(file):
    if hasattr(file, "name"):
        context = convert_to_txt(file.name)
        return gr.Textbox(context, visible=True)
    else:
        raise gr.Error("ERROR: no uploaded document")
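
# Run LongCite inference; @spaces.GPU requests a ZeroGPU allocation of up to 120 s.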
@spaces.GPU(duration=120)
def infer(context, query):
    return model.query_longcite(
        context=context,
        query=query,
        tokenizer=tokenizer,
        max_input_length=128000,
        max_new_tokens=1024,
    )
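
# Handler for the Submit button: run the model and build all HTML views.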
def run_llm(context, query):
    if not context:
        raise gr.Error("Error: no uploaded document")
    if not query:
        raise gr.Error("Error: no query")
    result = infer(context=context, query=query)
    all_statements = result["all_statements"]
    answer_html, all_cite_html, clicked_cite_html, cite_num2idx_dict = convert_to_html(
        all_statements
    )
    cite_nums = list(cite_num2idx_dict.keys())
    return {
        statements: gr.JSON(all_statements),
        answer: gr.HTML(answer_html, visible=True),
        all_citations: gr.HTML(all_cite_html, visible=True),
        cite_num2idx: gr.JSON(cite_num2idx_dict),
        citation_choices: gr.Radio(cite_nums, visible=len(cite_nums) > 0),
        clicked_citations: gr.HTML(visible=False),
    }
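
# Handler for the citation radio: highlight the chosen statement and show its citations.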
def choose_citation(statements, cite_num2idx, clicked_cite_num):
    clicked = cite_num2idx[clicked_cite_num]
    answer_html, _, clicked_cite_html, _ = convert_to_html(statements, clicked=clicked)
    return {
        answer: gr.HTML(answer_html, visible=True),
        clicked_citations: gr.HTML(clicked_cite_html, visible=True),
    }
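
# Gradio UI.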
with gr.Blocks() as demo:
    gr.Markdown(
        """
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
LongCite-glm4-9b Huggingface Space🤗
</div>
<div style="text-align: center;">
<a href="https://huggingface.co/THUDM/LongCite-glm4-9b">🤗 Model Hub</a> |
<a href="https://github.com/THUDM/LongCite">🌐 Github</a> |
<a href="https://arxiv.org/abs/2409.02897">📜 arxiv</a>
</div>
<br>
<div style="text-align: center; font-size: 15px; font-weight: bold; margin-bottom: 20px; line-height: 1.5;">
If you plan to use it long-term, please consider deploying the model or forking this space yourself.
</div>
"""
    )
    with gr.Row():
        with gr.Column(scale=4):
            file = gr.File(
                label="Upload a document (supported types: pdf, docx, txt, md, py)"
            )
            query = gr.Textbox(label="Question")
            submit_btn = gr.Button("Submit")
        with gr.Column(scale=4):
            context = gr.Textbox(
                label="Document content",
                autoscroll=False,
                placeholder="No uploaded document.",
                max_lines=10,
                visible=False,
            )
    file.upload(render_context, [file], [context])
    with gr.Row():
        with gr.Column(scale=4):
            statements = gr.JSON(label="statements", visible=False)
            answer = gr.HTML(label="Answer", visible=True)
            cite_num2idx = gr.JSON(label="cite_num2idx", visible=False)
            citation_choices = gr.Radio(
                label="Choose citations for details", visible=False, interactive=True
            )
        with gr.Column(scale=4):
            clicked_citations = gr.HTML(
                label="Citations of the chosen statement", visible=False
            )
            all_citations = gr.HTML(label="All citations", visible=False)
    submit_btn.click(
        run_llm,
        [context, query],
        [
            statements,
            answer,
            all_citations,
            cite_num2idx,
            citation_choices,
            clicked_citations,
        ],
    )
    citation_choices.change(
        choose_citation,
        [statements, cite_num2idx, citation_choices],
        [answer, clicked_citations],
    )
demo.queue()
demo.launch()