|
import os
import subprocess
|
|
|
# Install flash-attn at start-up, before the model (which requests
# attn_implementation="flash_attention_2") is loaded further down.
# FLASH_ATTENTION_SKIP_CUDA_BUILD is added on top of the inherited environment:
# passing a bare dict as `env` would replace the whole environment of the child
# process (PATH, HOME, proxy settings, ...) rather than extend it.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
|
|
|
import gradio as gr |
|
import torch |
|
from transformers import ( |
|
AutoModelForCausalLM, |
|
AutoTokenizer, |
|
) |
|
import docx |
|
import PyPDF2 |
|
import spaces |
|
|
|
|
|
def convert_to_txt(file):
    """Extract the plain text of a document stored on disk.

    Args:
        file: Path of the document as a string. The part after the last "."
            (compared case-insensitively) selects the parser.

    Returns:
        The document text. PDF pages, or DOCX paragraphs, are joined with a
        blank line; text-like files (txt, md, py) are returned as decoded.

    Raises:
        gr.Error: If the extension is not txt, md, py, pdf or docx.
        UnicodeDecodeError: If a text-like file is not valid UTF-8.
    """
    doc_type = file.rsplit(".", 1)[-1].strip().lower()
    if doc_type in ("txt", "md", "py"):
        # `file` is a path, not a file object, so it has to be opened here.
        with open(file, "rb") as handle:
            data = [handle.read().decode("utf-8")]
    elif doc_type == "pdf":
        pdf_reader = PyPDF2.PdfReader(file)
        data = [page.extract_text() for page in pdf_reader.pages]
    elif doc_type == "docx":
        data = [p.text for p in docx.Document(file).paragraphs]
    else:
        raise gr.Error(f"ERROR: unsupported document type: {doc_type}")
    return "\n\n".join(data)
|
|
|
|
|
# Checkpoint identifier shared by the tokenizer and the model.
model_name = "THUDM/LongCite-glm4-9b"

# Both loaders run the checkpoint's own Python code (trust_remote_code=True).
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# NOTE(review): `device` is not one of the usual from_pretrained arguments;
# presumably it is consumed by the checkpoint's remote model code - confirm.
_model_kwargs = {
    "torch_dtype": torch.bfloat16,
    "trust_remote_code": True,
    "device": "cuda",
    "attn_implementation": "flash_attention_2",
}
model = AutoModelForCausalLM.from_pretrained(model_name, **_model_kwargs)
|
|
|
# Stylesheet prepended to every HTML fragment built by convert_to_html.
# The class names (reference, highlight, label, Bold, statement) are the ones
# used in the markup generated there.
html_styles = """<style>
.reference {
    color: blue;
    text-decoration: underline;
}
.highlight {
    background-color: yellow;
}
.label {
    font-family: sans-serif;
    font-size: 16px;
    font-weight: bold;
}
.Bold {
    font-weight: bold;
}
.statement {
    background-color: lightgrey;
}
</style>\n"""
|
|
|
|
|
# Translation table for process_text: HTML-significant characters map to
# entities, and a newline maps to an explicit line break.
_HTML_ESCAPES = str.maketrans(
    {
        "&": "&amp;",
        "'": "&apos;",
        '"': "&quot;",
        "<": "&lt;",
        ">": "&gt;",
        "\n": "<br>",
    }
)


def process_text(text):
    """Escape *text* for embedding in HTML and turn newlines into ``<br>``.

    The replacement happens in a single pass over the input, so the "&" of an
    entity produced for one character is never escaped a second time and the
    generated ``<br>`` tags are left intact.

    Args:
        text: Plain text to embed in markup.

    Returns:
        The text with ``&``, ``'``, ``"``, ``<`` and ``>`` replaced by HTML
        entities and every newline replaced by ``<br>``.
    """
    return text.translate(_HTML_ESCAPES)
|
|
|
|
|
def convert_to_html(statements, clicked=-1):
    """Render answer statements and their citations as HTML fragments.

    Args:
        statements: Sequence of dicts, each with a "statement" string and a
            "citation" list. Every citation dict is read for the keys
            "start_sentence_idx", "end_sentence_idx", "start_char_idx",
            "end_char_idx" and "cite".
        clicked: Index of the statement to emphasise. The default of -1
            matches no statement.

    Returns:
        A 4-tuple ``(html, all_cite_html, clicked_cite_html, cite_num2idx)``:
        the answer markup; a panel listing every snippet; a panel with the
        snippets of the clicked statement (``None`` when no statement index
        equals ``clicked``); and a dict mapping each citation label such as
        ``"[1,2]"`` to the index of the statement it is attached to.
    """
    html = html_styles + '<br><span class="label">Answer:</span><br>\n'
    all_cite_html = []
    clicked_cite_html = None
    cite_num2idx = {}
    idx = 0  # running snippet number, continued across statements

    for i, js in enumerate(statements):
        is_clicked = clicked == i
        statement, citations = process_text(js["statement"]), js["citation"]
        if is_clicked:
            html += f"""<span class="statement">{statement}</span>"""
        else:
            html += f"<span>{statement}</span>"

        # Start from an empty list for every statement so that the clicked
        # panel below can never pick up snippets left over from an earlier
        # statement (or an unset name) when this one has no citations.
        cite_html = []
        if citations:
            idxs = []
            for c in citations:
                idx += 1
                idxs.append(str(idx))
                cite = (
                    "[Sentence: {}-{}\t|\tChar: {}-{}]<br>\n<span {}>{}</span>".format(
                        c["start_sentence_idx"],
                        c["end_sentence_idx"],
                        c["start_char_idx"],
                        c["end_char_idx"],
                        'class="highlight"' if is_clicked else "",
                        process_text(c["cite"].strip()),
                    )
                )
                cite_html.append(
                    f"""<span><span class="Bold">Snippet [{idx}]:</span><br>{cite}</span>"""
                )
            all_cite_html.extend(cite_html)
            cite_num = "[{}]".format(",".join(idxs))
            cite_num2idx[cite_num] = i
            html += """ <span class="reference" style="color: blue" id={}>{}</span>""".format(
                i, cite_num
            )
        html += "\n"

        if is_clicked:
            clicked_cite_html = (
                html_styles
                + """<br><span class="label">Citations of current statement:</span><br><div style="overflow-y: auto; padding: 20px; border: 0px dashed black; border-radius: 6px; background-color: #EFF2F6;">{}</div>""".format(
                    "<br><br>\n".join(cite_html)
                )
            )

    if all_cite_html:
        # The overview panel never highlights, so drop the highlight class
        # that was added to the clicked statement's snippets.
        all_body = "<br><br>\n".join(all_cite_html).replace(
            '<span class="highlight">', "<span>"
        )
    else:
        all_body = "No citation in the answer"
    all_cite_html = (
        html_styles
        + """<br><span class="label">All citations:</span><br>\n<div style="overflow-y: auto; padding: 20px; border: 0px dashed black; border-radius: 6px; background-color: #EFF2F6;">{}</div>""".format(
            all_body
        )
    )
    return html, all_cite_html, clicked_cite_html, cite_num2idx
|
|
|
|
|
def render_context(file):
    """Show the text of the uploaded document.

    Args:
        file: Upload value; its ``name`` attribute is passed to
            convert_to_txt as the document path.

    Returns:
        A visible ``gr.Textbox`` holding the extracted text.

    Raises:
        gr.Error: If ``file`` has no ``name`` attribute.
    """
    if not hasattr(file, "name"):
        raise gr.Error("ERROR: no uploaded document")
    document_text = convert_to_txt(file.name)
    return gr.Textbox(document_text, visible=True)
|
|
|
|
|
@spaces.GPU(duration=120)
def infer(context, query):
    """Ask the model a question about a document.

    Args:
        context: Document text.
        query: Question about the document.

    Returns:
        Whatever ``model.query_longcite`` returns; run_llm in this file reads
        its "all_statements" entry.
    """
    request = {
        "context": context,
        "query": query,
        "tokenizer": tokenizer,
        "max_input_length": 128000,
        "max_new_tokens": 1024,
    }
    return model.query_longcite(**request)
|
|
|
def run_llm(context, query):
    """Run the model on the document and build the component updates.

    Args:
        context: Document text; must be non-empty.
        query: User question; must be non-empty.

    Returns:
        A dict keyed by the output components of the UI (statements, answer,
        all_citations, cite_num2idx, citation_choices, clicked_citations)
        with the updated component for each.

    Raises:
        gr.Error: If ``context`` or ``query`` is empty.
    """
    if not context:
        raise gr.Error("Error: no uploaded document")
    if not query:
        raise gr.Error("Error: no query")

    all_statements = infer(context=context, query=query)["all_statements"]
    # No statement is selected yet, so the clicked-citation panel is unused.
    answer_html, all_cite_html, _, label_to_statement = convert_to_html(
        all_statements
    )
    labels = list(label_to_statement)
    has_citations = bool(labels)

    return {
        statements: gr.JSON(all_statements),
        answer: gr.HTML(answer_html, visible=True),
        all_citations: gr.HTML(all_cite_html, visible=True),
        cite_num2idx: gr.JSON(label_to_statement),
        # The selector is only shown when there is something to select.
        citation_choices: gr.Radio(labels, visible=has_citations),
        clicked_citations: gr.HTML(visible=False),
    }
|
|
|
|
|
def chose_citation(statements, cite_num2idx, clicked_cite_num):
    """Re-render the answer with one statement's citations emphasised.

    Args:
        statements: Statement list handed to convert_to_html.
        cite_num2idx: Mapping from citation label to statement index.
        clicked_cite_num: The citation label that was selected; it must be a
            key of ``cite_num2idx``.

    Returns:
        A dict keyed by the ``answer`` and ``clicked_citations`` components
        with their updated, visible HTML.
    """
    statement_idx = cite_num2idx[clicked_cite_num]
    rendered = convert_to_html(statements, clicked=statement_idx)
    # convert_to_html returns (answer, all citations, clicked citations, map).
    answer_html, clicked_cite_html = rendered[0], rendered[2]
    return {
        answer: gr.HTML(answer_html, visible=True),
        clicked_citations: gr.HTML(clicked_cite_html, visible=True),
    }
|
|
|
|
|
with gr.Blocks() as demo:
    # Page header: title, project links and a usage note. The literal is kept
    # flush-left and free of blank lines so each <div> stays one HTML block.
    gr.Markdown(
        """
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
LongCite-glm4-9b Huggingface Space🤗
</div>
<div style="text-align: center;">
<a href="https://huggingface.co/THUDM/LongCite-glm4-9b">🤗 Model Hub</a> |
<a href="https://github.com/THUDM/LongCite">🌐 Github</a> |
<a href="https://arxiv.org/abs/2409.02897">📜 arxiv </a>
</div>
<br>
<div style="text-align: center; font-size: 15px; font-weight: bold; margin-bottom: 20px; line-height: 1.5;">
If you plan to use it long-term, please consider deploying the model or forking this space yourself.
</div>
"""
    )

    # Input row: upload + question on the left, extracted text on the right.
    with gr.Row():
        with gr.Column(scale=4):
            file = gr.File(
                label="Upload a document (supported type: pdf, docx, txt, md, py)"
            )
            query = gr.Textbox(label="Question")
            submit_btn = gr.Button("Submit")

        with gr.Column(scale=4):
            # Hidden until render_context returns a visible Textbox.
            context = gr.Textbox(
                label="Document content",
                autoscroll=False,
                placeholder="No uploaded document.",
                max_lines=10,
                visible=False,
            )

    file.upload(fn=render_context, inputs=[file], outputs=[context])

    # Output row: answer and citation selector on the left, citation panels
    # on the right. The two hidden JSON components carry state between the
    # run_llm and chose_citation callbacks.
    with gr.Row():
        with gr.Column(scale=4):
            statements = gr.JSON(label="statements", visible=False)
            answer = gr.HTML(label="Answer", visible=True)
            cite_num2idx = gr.JSON(label="cite_num2idx", visible=False)
            citation_choices = gr.Radio(
                label="Chose citations for details", visible=False, interactive=True
            )

        with gr.Column(scale=4):
            clicked_citations = gr.HTML(
                label="Citations of the chosen statement", visible=False
            )
            all_citations = gr.HTML(label="All citations", visible=False)

    submit_btn.click(
        fn=run_llm,
        inputs=[context, query],
        outputs=[
            statements,
            answer,
            all_citations,
            cite_num2idx,
            citation_choices,
            clicked_citations,
        ],
    )
    citation_choices.change(
        fn=chose_citation,
        inputs=[statements, cite_num2idx, citation_choices],
        outputs=[answer, clicked_citations],
    )

demo.queue()
demo.launch()
|
|