Adir Gozlan
committed on
Commit
•
6d5ec26
1
Parent(s):
a200fe6
late commit
Browse files- app.py +33 -8
- backend/cross_encoder.py +29 -0
app.py
CHANGED
@@ -11,9 +11,13 @@ from jinja2 import Environment, FileSystemLoader
|
|
11 |
|
12 |
from backend.query_llm import generate_hf, generate_openai
|
13 |
from backend.semantic_search import retrieve
|
|
|
|
|
|
|
14 |
|
15 |
|
16 |
TOP_K = int(os.getenv("TOP_K", 4))
|
|
|
17 |
|
18 |
proj_dir = Path(__file__).parent
|
19 |
# Setting up the logging
|
@@ -34,7 +38,7 @@ def add_text(history, text):
|
|
34 |
return history, gr.Textbox(value="", interactive=False)
|
35 |
|
36 |
|
37 |
-
def bot(history, api_kind):
|
38 |
query = history[-1][0]
|
39 |
|
40 |
if not query:
|
@@ -42,12 +46,32 @@ def bot(history, api_kind):
|
|
42 |
|
43 |
logger.info('Retrieving documents...')
|
44 |
# Retrieve documents relevant to query
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
-
|
48 |
|
49 |
-
document_time = perf_counter() - document_start
|
50 |
-
logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
|
51 |
|
52 |
# Create Prompt
|
53 |
prompt = template.render(documents=documents, query=query)
|
@@ -86,19 +110,20 @@ with gr.Blocks() as demo:
|
|
86 |
)
|
87 |
txt_btn = gr.Button(value="Submit text", scale=1)
|
88 |
|
89 |
-
api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace")
|
|
|
90 |
|
91 |
prompt_html = gr.HTML()
|
92 |
# Turn off interactivity while generating if you click
|
93 |
txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
94 |
-
bot, [chatbot, api_kind], [chatbot, prompt_html])
|
95 |
|
96 |
# Turn it back on
|
97 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
98 |
|
99 |
# Turn off interactivity while generating if you hit enter
|
100 |
txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
101 |
-
bot, [chatbot, api_kind], [chatbot, prompt_html])
|
102 |
|
103 |
# Turn it back on
|
104 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
|
|
11 |
|
12 |
from backend.query_llm import generate_hf, generate_openai
|
13 |
from backend.semantic_search import retrieve
|
14 |
+
from backend.cross_encoder import rerank_with_cross_encoder
|
15 |
+
|
16 |
+
|
17 |
|
18 |
|
19 |
TOP_K = int(os.getenv("TOP_K", 4))
|
20 |
+
TOP_K_RERANK = int(os.getenv("TOP_K_RERANK", 40))
|
21 |
|
22 |
proj_dir = Path(__file__).parent
|
23 |
# Setting up the logging
|
|
|
38 |
return history, gr.Textbox(value="", interactive=False)
|
39 |
|
40 |
|
41 |
+
def bot(history, api_kind, cross_enc):
|
42 |
query = history[-1][0]
|
43 |
|
44 |
if not query:
|
|
|
46 |
|
47 |
logger.info('Retrieving documents...')
|
48 |
# Retrieve documents relevant to query
|
49 |
+
documents = []
|
50 |
+
if not cross_enc:
|
51 |
+
document_start = perf_counter()
|
52 |
+
|
53 |
+
documents = retrieve(query, TOP_K)
|
54 |
+
|
55 |
+
document_time = perf_counter() - document_start
|
56 |
+
logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
|
57 |
+
|
58 |
+
else:
|
59 |
+
document_start = perf_counter()
|
60 |
+
|
61 |
+
documents = retrieve(query, TOP_K_RERANK)
|
62 |
+
|
63 |
+
document_time = perf_counter() - document_start
|
64 |
+
logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
|
65 |
+
|
66 |
+
logger.info('Reranking documents')
|
67 |
+
document_start = perf_counter()
|
68 |
+
|
69 |
+
documents = rerank_with_cross_encoder(cross_enc, documents, query)
|
70 |
+
|
71 |
+
document_time = perf_counter() - document_start
|
72 |
|
73 |
+
logger.info(f'Finished Reranking documents in {round(document_time, 2)} seconds...')
|
74 |
|
|
|
|
|
75 |
|
76 |
# Create Prompt
|
77 |
prompt = template.render(documents=documents, query=query)
|
|
|
110 |
)
|
111 |
txt_btn = gr.Button(value="Submit text", scale=1)
|
112 |
|
113 |
+
api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace", label="LLM")
|
114 |
+
cross_enc = gr.Radio(choices=["None", "cross-encoder/ms-marco-MiniLM-L-6-v2", "BAAI/bge-reranker-large"], value=None, label="Cross Encoder")
|
115 |
|
116 |
prompt_html = gr.HTML()
|
117 |
# Turn off interactivity while generating if you click
|
118 |
txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
119 |
+
bot, [chatbot, api_kind, cross_enc], [chatbot, prompt_html])
|
120 |
|
121 |
# Turn it back on
|
122 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
123 |
|
124 |
# Turn off interactivity while generating if you hit enter
|
125 |
txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
|
126 |
+
bot, [chatbot, api_kind, cross_enc], [chatbot, prompt_html])
|
127 |
|
128 |
# Turn it back on
|
129 |
txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
|
backend/cross_encoder.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Lazily-loaded cross-encoder model/tokenizer, cached at module level so
# repeated calls reuse the same weights; reloaded only when the model name
# changes (see rerank_with_cross_encoder).
cross_encoder = None
cross_enc_tokenizer = None

# Max documents kept after reranking. Cast to int: os.getenv returns a str
# when the variable is set, and this value is used as a slice bound.
# (Matches app.py, which also does int(os.getenv("TOP_K_RERANK", 40)).)
TOP_K_RERANK = int(os.getenv("TOP_K_RERANK", 40))
@torch.no_grad()
def rerank_with_cross_encoder(cross_enc_name, documents, query):
    """Rerank *documents* by relevance to *query* with a HF cross-encoder.

    Args:
        cross_enc_name: Hub id of the cross-encoder model, or None / "None"
            to skip reranking.
        documents: list of document strings to rerank.
        query: the user query string.

    Returns:
        The documents sorted by descending cross-encoder score, truncated to
        TOP_K_RERANK entries; the input list unchanged when reranking is
        skipped or there is at most one document.
    """
    # NOTE(review): the Gradio Radio in app.py offers the literal string
    # "None" as a choice, so it arrives here as a str, not the None object —
    # treat it as "no reranker" instead of trying to load a model named "None".
    if cross_enc_name in (None, "None") or len(documents) <= 1:
        return documents

    global cross_encoder, cross_enc_tokenizer
    # Load (and cache) the model/tokenizer only when the requested name
    # differs from what is already loaded.
    if cross_encoder is None or cross_encoder.name_or_path != cross_enc_name:
        cross_encoder = AutoModelForSequenceClassification.from_pretrained(cross_enc_name)
        cross_encoder.eval()
        cross_enc_tokenizer = AutoTokenizer.from_pretrained(cross_enc_name)

    # Score every (query, document) pair in one batch.
    features = cross_enc_tokenizer(
        [query] * len(documents), documents, padding=True, truncation=True, return_tensors="pt"
    )
    scores = cross_encoder(**features).logits.squeeze()
    ranks = torch.argsort(scores, descending=True)
    # int() guards against TOP_K_RERANK being a str when it was read from the
    # environment without a cast.
    documents = [documents[i] for i in ranks[: int(TOP_K_RERANK)]]
    return documents