LLMAgora / app.py
Cartinoe5930's picture
Update app.py
8ce66bc
raw
history blame contribute delete
No virus
20.3 kB
import gradio as gr
import json
import requests
import os
from model_inference import Inference
import time
# Hugging Face access token used to authenticate against the model Inference Endpoints.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
    # Fail with an actionable message instead of the opaque
    # "can only concatenate str (not NoneType)" TypeError raised below.
    raise RuntimeError(
        "The HF_TOKEN environment variable must be set to query the model Inference Endpoints."
    )
# NOTE(review): this module-level map is never read in this file
# (build_question_selector_map returns its own dict); kept in case something imports it.
question_selector_map = {}
# Keys of src/inference_endpoint.json, one per supported model.
every_model = ["llama2", "llama2-chat", "vicuna", "falcon", "falcon-instruct", "orca", "wizardlm"]
with open("src/inference_endpoint.json", "r", encoding="utf-8") as f:
    inference_endpoint = json.load(f)
# Append the token to every model's Authorization header exactly once, at import time.
# NOTE(review): presumably the JSON stores a prefix such as "Bearer " — confirm.
for model_name in every_model:
    inference_endpoint[model_name]["headers"]["Authorization"] += HF_TOKEN
def build_question_selector_map(questions):
    """Index result records by the label shown in the question dropdown.

    Each label has the form ``"<question_id + 1>: <first 128 chars>..."``.
    When two records produce the same label, the later record wins.

    Args:
        questions: Iterable of result records, each with ``"question_id"`` and
            ``"question"`` keys.

    Returns:
        Dict mapping label -> the original record object.
    """
    return {
        f"{record['question_id']+1}: " + record["question"][:128] + "...": record
        for record in questions
    }
# Agents whose stored transcripts are shown on the Math / GSM8K / MMLU tabs,
# in the column order used by those tabs (Llama2, WizardLM, Orca).
_RESULT_AGENTS = ("llama", "wizardlm", "orca")
# 3 rounds x 3 agents + 2 summaries = the 11 output Textboxes of a result tab.
_RESULT_OUTPUT_COUNT = 11


def _flatten_result_record(q):
    """Return a stored debate record as the 11 values shown on a result tab.

    Order: round-1 responses (llama, wizardlm, orca), summary 1,
    round-2 responses, summary 2, round-3 responses.
    """
    values = []
    for round_idx in range(3):
        values.extend(q["agent_response"][agent][round_idx] for agent in _RESULT_AGENTS)
        # Only the first two rounds are followed by a summarization.
        if round_idx < 2:
            values.append(q["summarization"][round_idx])
    return tuple(values)


def _display_result(plain_map, cot_map, question, cot):
    """Look up *question* in the CoT or plain map and flatten it for display.

    Returns blanks when no question is selected yet (e.g. the CoT checkbox is
    toggled before a question is picked) instead of raising KeyError.
    """
    if not question:
        return ("",) * _RESULT_OUTPUT_COUNT
    selector_map = cot_map if cot else plain_map
    return _flatten_result_record(selector_map[question])


def math_display_question_answer(question, cot, request: gr.Request):
    """Return the stored Math transcript for *question* (the CoT run if *cot*)."""
    return _display_result(math_question_selector_map, math_cot_question_selector_map, question, cot)


def gsm_display_question_answer(question, cot, request: gr.Request):
    """Return the stored GSM8K transcript for *question* (the CoT run if *cot*)."""
    return _display_result(gsm_question_selector_map, gsm_cot_question_selector_map, question, cot)


def mmlu_display_question_answer(question, cot, request: gr.Request):
    """Return the stored MMLU transcript for *question* (the CoT run if *cot*)."""
    return _display_result(mmlu_question_selector_map, mmlu_cot_question_selector_map, question, cot)
def warmup(list_model, model_inference_endpoints=inference_endpoint, wait_seconds=300):
    """Wake the selected models' Inference Endpoints, then reveal the inference UI.

    Sends one throwaway request to each selected endpoint, sleeps
    ``wait_seconds`` (5 minutes by default, matching the on-page instructions),
    and returns Gradio updates that hide the model selector and warm-up button
    and show the question widgets.

    Args:
        list_model: Model display names chosen in the CheckboxGroup; they are
            lower-cased to look up ``model_inference_endpoints``.
        model_inference_endpoints: Mapping of model key -> dict with
            ``"API_URL"`` and ``"headers"``. The default is the module-level
            config, whose Authorization headers already contain HF_TOKEN.
        wait_seconds: Seconds to sleep after the pings.

    Returns:
        Dict mapping Gradio components to ``gr.update`` visibility changes.
    """
    for model in list_model:
        endpoint = model_inference_endpoints[model.lower()]
        # Use the stored headers unchanged. HF_TOKEN was already appended at
        # import time; re-appending it here mutated the shared config and
        # doubled the token, breaking authentication.
        try:
            requests.post(
                endpoint["API_URL"],
                headers=endpoint["headers"],
                json={"inputs": "Hello. "},
                timeout=60,
            )
        except requests.RequestException:
            # Best-effort wake-up ping: a cold endpoint may be slow or reject
            # the request; the wait below still gives it time to start.
            pass
    time.sleep(wait_seconds)
    return {
        model_list: gr.update(visible=False),
        options: gr.update(visible=True),
        inputbox: gr.update(visible=True),
        submit: gr.update(visible=True),
        warmup_button: gr.update(visible=False),
        welcome_message: gr.update(visible=True)
    }
def inference(model_list, question, API_KEY, cot, hf_token=HF_TOKEN):
    """Run a live LLM Agora debate and map the results onto the Inference tab.

    Args:
        model_list: Model display names selected in the CheckboxGroup; exactly
            three are required.
        question: The user's question.
        API_KEY: OpenAI API key, forwarded to ``Inference``.
        cot: Whether chain-of-thought prompting was requested.
        hf_token: Hugging Face token, forwarded to ``Inference``.

    Returns:
        Dict mapping output components to visibility updates or response text.

    Raises:
        gr.Error: If not exactly three models are selected.
    """
    if len(model_list) != 3:
        raise gr.Error("Please choose just '3' models! Neither more nor less!")
    # Build a new lower-cased list (the model keys in every_model are
    # lower-case) instead of rewriting the caller's list in place.
    models = [name.lower() for name in model_list]
    model_response = Inference(models, question, API_KEY, cot, hf_token)
    agents = model_response["agent_response"]
    summaries = model_response["summarization"]
    # Per-model lists of responses, one entry per debate round.
    first, second, third = (agents[name] for name in models)
    return {
        output_msg: gr.update(visible=True),
        output_col: gr.update(visible=True),
        model1_output1: first[0],
        model2_output1: second[0],
        model3_output1: third[0],
        summarization_text1: summaries[0],
        model1_output2: first[1],
        model2_output2: second[1],
        model3_output2: third[1],
        summarization_text2: summaries[1],
        model1_output3: first[2],
        model2_output3: second[2],
        model3_output3: third[2]
    }
def load_responses(base_dir="result"):
    """Load the stored debate transcripts for every task, with and without CoT.

    Reads ``<base_dir>/<Task>/<prefix>_result.json`` and
    ``<base_dir>/<Task>/<prefix>_result_cot.json`` for Math, GSM8K and MMLU.
    The files are decoded as UTF-8, the JSON standard encoding, so they load
    identically regardless of the platform's default locale.

    Args:
        base_dir: Directory holding the ``Math``, ``GSM8K`` and ``MMLU``
            result folders (default ``"result"``).

    Returns:
        Tuple ``(math, math_cot, gsm, gsm_cot, mmlu, mmlu_cot)`` of parsed JSON.

    Raises:
        OSError: If a result file is missing or unreadable.
        json.JSONDecodeError: If a result file is not valid JSON.
    """
    tasks = (("Math", "math"), ("GSM8K", "gsm"), ("MMLU", "mmlu"))
    loaded = []
    for folder, prefix in tasks:
        for suffix in ("", "_cot"):
            path = os.path.join(base_dir, folder, f"{prefix}_result{suffix}.json")
            with open(path, "r", encoding="utf-8") as result_file:
                loaded.append(json.load(result_file))
    return tuple(loaded)
def load_questions(math, gsm, mmlu, limit=100):
    """Build the dropdown choices for the Math, GSM8K and MMLU result tabs.

    Labels use the same ``"<question_id + 1>: <first 128 chars>..."`` format as
    the keys built by ``build_question_selector_map``, so every choice is a
    valid key of the selector maps. Numbering by list position only matched
    when ``question_id`` equalled the index.

    Args:
        math, gsm, mmlu: Lists of result records with ``"question_id"`` and
            ``"question"`` keys.
        limit: Maximum number of questions per task (default 100). Shorter
            lists are used in full instead of raising IndexError; pass ``None``
            to use every record.

    Returns:
        Tuple ``(math_questions, gsm_questions, mmlu_questions)`` of label lists.
    """
    def _labels(records):
        return [f"{q['question_id']+1}: " + q["question"][:128] + "..." for q in records[:limit]]

    return _labels(math), _labels(gsm), _labels(mmlu)
# Load every stored debate transcript once, at import time.
(
    math_result,
    math_cot_result,
    gsm_result,
    gsm_cot_result,
    mmlu_result,
    mmlu_cot_result,
) = load_responses()
# Dropdown choices come from the non-CoT transcripts of each task.
math_questions, gsm_questions, mmlu_questions = load_questions(
    math_result, gsm_result, mmlu_result
)
# Label -> record lookup tables read by the *_display_question_answer callbacks.
(
    math_question_selector_map,
    math_cot_question_selector_map,
    gsm_question_selector_map,
    gsm_cot_question_selector_map,
    mmlu_question_selector_map,
    mmlu_cot_question_selector_map,
) = [
    build_question_selector_map(result_set)
    for result_set in (
        math_result,
        math_cot_result,
        gsm_result,
        gsm_cot_result,
        mmlu_result,
        mmlu_cot_result,
    )
]
# User-facing page text. These strings are rendered verbatim by the Gradio UI.

# Page header, rendered with gr.HTML at the top of the app.
TITLE = """<h1 align="center">LLM Agora 🗣️🏦</h1>"""
# Markdown introduction and usage guide shown under the title.
# NOTE(review): this text names the experiment models "WizardLM2, Orca2", while the
# model table in SPECIFIC_INFORMATION lists "WizardLM" and "Orca" — confirm which is intended.
INTRODUCTION_TEXT = """
The **LLM Agora** 🗣️🏦 aims to improve the quality of open-source LMs' responses through debate & revision introduced in [Improving Factuality and Reasoning in Language Models through Multiagent Debate](https://arxiv.org/abs/2305.14325).
Thank you to the authors of this paper for suggesting a great idea!
Do you know that? 🤔 **LLMs can also improve their responses by debating with other LLMs**! 😮 We applied this concept to several open-source LMs to verify that the open-source model, not the proprietary one, can sufficiently improve the response through discussion. 🤗
For more details, please refer to the [GitHub Repository](https://github.com/gauss5930/LLM-Agora).
You can also check the results in this Space!
You can use LLM Agora with your own questions if the response of open-source LM is not satisfactory and you want to improve the quality!
The Math, GSM8K, and MMLU Tabs show the results of the experiment(Llama2, WizardLM2, Orca2), and for inference, please use the 'Inference' tab.
Here's how to use LLM Agora!
1. Before starting, choose just 3 models and click the 'Warm-up LLM Agora 🔥' button and wait until '🤗🔥 Welcome to LLM Agora 🔥🤗' appears. (Suggest to go grab a coffee☕ since it takes 5 minutes!)
2. Once the interaction space is available, proceed with the following process.
3. Check the CoT box if you want to utilize the Chain-of-Thought while inferencing.
4. Please fill in your OpenAI API KEY, it will be used to use ChatGPT to summarize the responses.
5. Type your question in the Question box and click the 'Submit' button! If you do so, LLM Agora will show you improved answers! 🤗 (It will take roughly a minute! Please wait for an answer!)
For more detailed information, please check '※ Specific information about LLM Agora' at the bottom of the page.
※ Due to quota limitations, 'Llama2-Chat' and 'Falcon-Instruct' are currently unavailable. We will provide additional updates in the future.
"""
# Banner in the Inference tab; hidden initially and revealed by warmup().
WELCOME_TEXT = """<h1 align="center">🤗🔥 Welcome to LLM Agora 🔥🤗</h1>"""
# Heading above the live inference results; its row is revealed by inference().
RESPONSE_TEXT = """<h1 align="center">🤗 Here are the responses to each model!! 🤗</h1>"""
# Markdown details (tasks, model sizes, debate setup, citation) shown in the
# collapsed "Specific information" accordion at the bottom of the page.
SPECIFIC_INFORMATION = """
This is the specific information about LLM Agora!
**Tasks**
- Math: The problem of arithmetic operations on six randomly selected numbers. The format is '{}+{}*{}+{}-{}*{}=?'
- GSM8K: GSM8K is a dataset of 8.5K high quality linguistically diverse grade school math word problems created by human problem writers.
- MMLU: MMLU (Massive Multitask Language Understanding) is a new benchmark designed to measure knowledge acquired during pretraining by evaluating models exclusively in zero-shot and few-shot settings.
**Model size**
Besides Falcon, all other models are based on Llama2.
|Model name|Model size|
|---|---|
|Llama2|13B|
|Llama2-Chat|13B|
|Vicuna|13B|
|Falcon|7B|
|Falcon-Instruct|7B|
|WizardLM|13B|
|Orca|13B|
**Agent numbers & Debate rounds**
- We limit the number of agents and debate rounds because of the limitation of resources. As a result, we decided to use 3 agents and 2 rounds of debate!
**GitHub Repository**
- If you want to see more specific information, please check the [GitHub Repository](https://github.com/gauss5930/LLM-Agora) of LLM Agora!
**Citation**
```
@article{du2023improving,
title={Improving Factuality and Reasoning in Language Models through Multiagent Debate},
author={Du, Yilun and Li, Shuang and Torralba, Antonio and Tenenbaum, Joshua B and Mordatch, Igor},
journal={arXiv preprint arXiv:2305.14325},
year={2023}
}
```
"""
# Build the Gradio interface. In Blocks, component creation order and `with`
# nesting define the page layout, so these statements must stay in this order.
# NOTE(review): the elem_ids "model1_response"/"model2_response"/"model3_response"
# are reused in every tab, producing duplicate HTML element ids — confirm intended.
with gr.Blocks() as demo:
    gr.HTML(TITLE)
    gr.Markdown(INTRODUCTION_TEXT)
    with gr.Column():
        # --- Inference tab: live debate between three user-selected models ---
        with gr.Tab("Inference"):
            model_list = gr.CheckboxGroup(["Llama2", "Vicuna", "Falcon", "WizardLM", "Orca"], label="Model Selection", info="Choose 3 LMs to participate in LLM Agora.", type="value", visible=True)
            warmup_button = gr.Button("Warm-up LLM Agora 🔥", visible=True)
            # The widgets below start hidden; warmup() reveals them and hides
            # the model selector and warm-up button.
            welcome_message = gr.HTML(WELCOME_TEXT, visible=False)
            with gr.Row(visible=False) as options:
                cot = gr.Checkbox(label="CoT", info="Do you want to use CoT for inference?")
                API_KEY = gr.Textbox(label="OpenAI API Key", value="", info="Please fill in your OpenAI API token.", placeholder="sk..", type="password")
            with gr.Column(visible=False) as inputbox:
                question = gr.Textbox(label="Question", value="", info="Please type your question!", placeholder="")
            submit = gr.Button("Submit", visible=False)
            # Result area, revealed by inference() together with the responses.
            with gr.Row(visible=False) as output_msg:
                gr.HTML(RESPONSE_TEXT)
            with gr.Column(visible=False) as output_col:
                # One row per round: each model's response, plus a summarization
                # after the first two rounds.
                with gr.Row(elem_id="model1_response"):
                    model1_output1 = gr.Textbox(label="1️⃣ model's initial response")
                    model2_output1 = gr.Textbox(label="2️⃣ model's initial response")
                    model3_output1 = gr.Textbox(label="3️⃣ model's initial response")
                    summarization_text1 = gr.Textbox(label="Summarization 1")
                with gr.Row(elem_id="model2_response"):
                    model1_output2 = gr.Textbox(label="1️⃣ model's revised response")
                    model2_output2 = gr.Textbox(label="2️⃣ model's revised response")
                    model3_output2 = gr.Textbox(label="3️⃣ model's revised response")
                    summarization_text2 = gr.Textbox(label="Summarization 2")
                with gr.Row(elem_id="model3_response"):
                    model1_output3 = gr.Textbox(label="1️⃣ model's final response")
                    model2_output3 = gr.Textbox(label="2️⃣ model's final response")
                    model3_output3 = gr.Textbox(label="3️⃣ model's final response")
        # --- Math tab: browse stored transcripts from result/Math ---
        with gr.Tab("Math"):
            math_cot = gr.Checkbox(label="CoT", info="If you want to see CoT result, please check the box.")
            math_question_list = gr.Dropdown(math_questions, label="Math Question")
            with gr.Column():
                with gr.Row(elem_id="model1_response"):
                    math_model1_output1 = gr.Textbox(label="Llama2🦙's 1️⃣st response")
                    math_model2_output1 = gr.Textbox(label="WizardLM🧙‍♂️'s 1️⃣st response")
                    math_model3_output1 = gr.Textbox(label="Orca🐬's 1️⃣st response")
                    math_summarization_text1 = gr.Textbox(label="Summarization 1️⃣")
                with gr.Row(elem_id="model2_response"):
                    math_model1_output2 = gr.Textbox(label="Llama2🦙's 2️⃣nd response")
                    math_model2_output2 = gr.Textbox(label="WizardLM🧙‍♂️'s 2️⃣nd response")
                    math_model3_output2 = gr.Textbox(label="Orca🐬's 2️⃣nd response")
                    math_summarization_text2 = gr.Textbox(label="Summarization 2️⃣")
                with gr.Row(elem_id="model3_response"):
                    math_model1_output3 = gr.Textbox(label="Llama2🦙's 3️⃣rd response")
                    math_model2_output3 = gr.Textbox(label="WizardLM🧙‍♂️'s 3️⃣rd response")
                    math_model3_output3 = gr.Textbox(label="Orca🐬's 3️⃣rd response")
            gr.HTML("""<h1 align="center"> The result of Math </h1>""")
            gr.HTML("""<p align="center"><img src='https://github.com/gauss5930/LLM-Agora/assets/80087878/4fc22896-1306-4a93-bd54-a7a2ff184c98'></p>""")
            # Refresh the 11 transcript boxes when the CoT toggle or the selected question changes.
            math_cot.select(
                math_display_question_answer,
                [math_question_list, math_cot],
                [math_model1_output1, math_model2_output1, math_model3_output1, math_summarization_text1, math_model1_output2, math_model2_output2, math_model3_output2, math_summarization_text2, math_model1_output3, math_model2_output3, math_model3_output3]
            )
            math_question_list.change(
                math_display_question_answer,
                [math_question_list, math_cot],
                [math_model1_output1, math_model2_output1, math_model3_output1, math_summarization_text1, math_model1_output2, math_model2_output2, math_model3_output2, math_summarization_text2, math_model1_output3, math_model2_output3, math_model3_output3]
            )
        # --- GSM8K tab: same layout as the Math tab, backed by result/GSM8K ---
        with gr.Tab("GSM8K"):
            gsm_cot = gr.Checkbox(label="CoT", info="If you want to see CoT result, please check the box.")
            gsm_question_list = gr.Dropdown(gsm_questions, label="GSM8K Question")
            with gr.Column():
                with gr.Row(elem_id="model1_response"):
                    gsm_model1_output1 = gr.Textbox(label="Llama2🦙's 1️⃣st response")
                    gsm_model2_output1 = gr.Textbox(label="WizardLM🧙‍♂️'s 1️⃣st response")
                    gsm_model3_output1 = gr.Textbox(label="Orca🐬's 1️⃣st response")
                    gsm_summarization_text1 = gr.Textbox(label="Summarization 1️⃣")
                with gr.Row(elem_id="model2_response"):
                    gsm_model1_output2 = gr.Textbox(label="Llama2🦙's 2️⃣nd response")
                    gsm_model2_output2 = gr.Textbox(label="WizardLM🧙‍♂️'s 2️⃣nd response")
                    gsm_model3_output2 = gr.Textbox(label="Orca🐬's 2️⃣nd response")
                    gsm_summarization_text2 = gr.Textbox(label="Summarization 2️⃣")
                with gr.Row(elem_id="model3_response"):
                    gsm_model1_output3 = gr.Textbox(label="Llama2🦙's 3️⃣rd response")
                    gsm_model2_output3 = gr.Textbox(label="WizardLM🧙‍♂️'s 3️⃣rd response")
                    gsm_model3_output3 = gr.Textbox(label="Orca🐬's 3️⃣rd response")
            gr.HTML("""<h1 align="center"> The result of GSM8K </h1>""")
            gr.HTML("""<p align="center"><img src="https://github.com/gauss5930/LLM-Agora/assets/80087878/64f05ea4-5bec-41e4-83d7-d8855e753290"></p>""")
            # Refresh the 11 transcript boxes when the CoT toggle or the selected question changes.
            gsm_cot.select(
                gsm_display_question_answer,
                [gsm_question_list, gsm_cot],
                [gsm_model1_output1, gsm_model2_output1, gsm_model3_output1, gsm_summarization_text1, gsm_model1_output2, gsm_model2_output2, gsm_model3_output2, gsm_summarization_text2, gsm_model1_output3, gsm_model2_output3, gsm_model3_output3]
            )
            gsm_question_list.change(
                gsm_display_question_answer,
                [gsm_question_list, gsm_cot],
                [gsm_model1_output1, gsm_model2_output1, gsm_model3_output1, gsm_summarization_text1, gsm_model1_output2, gsm_model2_output2, gsm_model3_output2, gsm_summarization_text2, gsm_model1_output3, gsm_model2_output3, gsm_model3_output3]
            )
        # --- MMLU tab: same layout as the Math tab, backed by result/MMLU ---
        with gr.Tab("MMLU"):
            mmlu_cot = gr.Checkbox(label="CoT", info="If you want to see CoT result, please check the box.")
            mmlu_question_list = gr.Dropdown(mmlu_questions, label="MMLU Question")
            with gr.Column():
                with gr.Row(elem_id="model1_response"):
                    mmlu_model1_output1 = gr.Textbox(label="Llama2🦙's 1️⃣st response")
                    mmlu_model2_output1 = gr.Textbox(label="WizardLM🧙‍♂️'s 1️⃣st response")
                    mmlu_model3_output1 = gr.Textbox(label="Orca🐬's 1️⃣st response")
                    mmlu_summarization_text1 = gr.Textbox(label="Summarization 1️⃣")
                with gr.Row(elem_id="model2_response"):
                    mmlu_model1_output2 = gr.Textbox(label="Llama2🦙's 2️⃣nd response")
                    mmlu_model2_output2 = gr.Textbox(label="WizardLM🧙‍♂️'s 2️⃣nd response")
                    mmlu_model3_output2 = gr.Textbox(label="Orca🐬's 2️⃣nd response")
                    mmlu_summarization_text2 = gr.Textbox(label="Summarization 2️⃣")
                with gr.Row(elem_id="model3_response"):
                    mmlu_model1_output3 = gr.Textbox(label="Llama2🦙's 3️⃣rd response")
                    mmlu_model2_output3 = gr.Textbox(label="WizardLM🧙‍♂️'s 3️⃣rd response")
                    mmlu_model3_output3 = gr.Textbox(label="Orca🐬's 3️⃣rd response")
            gr.HTML("""<h1 align="center"> The result of MMLU </h1>""")
            gr.HTML("""<p align="center"><img src="https://github.com/composable-models/llm_multiagent_debate/assets/80087878/963571aa-228b-4d73-9082-5f528552383e"></p>""")
            # Refresh the 11 transcript boxes when the CoT toggle or the selected question changes.
            mmlu_cot.select(
                mmlu_display_question_answer,
                [mmlu_question_list, mmlu_cot],
                [mmlu_model1_output1, mmlu_model2_output1, mmlu_model3_output1, mmlu_summarization_text1, mmlu_model1_output2, mmlu_model2_output2, mmlu_model3_output2, mmlu_summarization_text2, mmlu_model1_output3, mmlu_model2_output3, mmlu_model3_output3]
            )
            mmlu_question_list.change(
                mmlu_display_question_answer,
                [mmlu_question_list, mmlu_cot],
                [mmlu_model1_output1, mmlu_model2_output1, mmlu_model3_output1, mmlu_summarization_text1, mmlu_model1_output2, mmlu_model2_output2, mmlu_model3_output2, mmlu_summarization_text2, mmlu_model1_output3, mmlu_model2_output3, mmlu_model3_output3]
            )
    with gr.Accordion("※ Specific information about LLM Agora", open=False):
        gr.Markdown(SPECIFIC_INFORMATION)
    # warmup() and inference() return dicts keyed by component, so every
    # component they update must appear in the corresponding outputs list.
    warmup_button.click(warmup, [model_list], [model_list, options, inputbox, submit, warmup_button, welcome_message])
    submit.click(inference, [model_list, question, API_KEY, cot], [output_msg, output_col, model1_output1, model2_output1, model3_output1, summarization_text1, model1_output2, model2_output2, model3_output2, summarization_text2, model1_output3, model2_output3, model3_output3])
demo.launch()