|
import os |
|
import time |
|
from itertools import islice |
|
import shutil |
|
from threading import Thread |
|
|
|
import lancedb |
|
import gradio as gr |
|
import polars as pl |
|
from datasets import load_dataset |
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
STYLE = """ |
|
.gradio-container td span { |
|
overflow: auto !important; |
|
} |
|
""".strip() |
|
|
|
|
|
EMBEDDING_MODEL = SentenceTransformer("TaylorAI/bge-micro") |
|
|
|
MAX_N_ROWS = 3_000_000 |
|
N_ROWS_BATCH = 5_000 |
|
N_SEARCH_RESULTS = 15 |
|
CRAWL_DUMP = "CC-MAIN-2020-05" |
|
DB = None |
|
DISPLAY_COLUMNS = [ |
|
"text", |
|
"url", |
|
"token_count", |
|
"count", |
|
] |
|
DISPLAY_COLUMN_TYPES = [ |
|
"str", |
|
"str", |
|
"number", |
|
"number", |
|
] |
|
DISPLAY_COLUMN_WIDTHS = [ |
|
"300px", |
|
"100px", |
|
"50px", |
|
"25px", |
|
] |
|
|
|
def rename_embedding_column(row): |
|
vector = row["embedding"] |
|
row["vector"] = vector |
|
del row["embedding"] |
|
return row |
|
|
|
|
|
def read_header_markdown() -> str: |
|
with open("./README.md", "r") as fp: |
|
text = fp.read(-1) |
|
|
|
|
|
text = text.split("\n---\n")[-1] |
|
return text.replace("{{CRAWL_DUMP}}", CRAWL_DUMP) |
|
|
|
|
|
def db(): |
|
global DB |
|
if DB is None: |
|
DB = lancedb.connect("data") |
|
return DB |
|
|
|
def load_data_sample(): |
|
time.sleep(5) |
|
|
|
|
|
if os.path.exists("data"): |
|
shutil.rmtree("data") |
|
|
|
rows = load_dataset( |
|
"airtrain-ai/fineweb-edu-fortified", |
|
name=CRAWL_DUMP, |
|
split="train", |
|
streaming=True, |
|
) |
|
|
|
print("Loading data") |
|
|
|
|
|
|
|
|
|
|
|
sample = islice(rows, MAX_N_ROWS) |
|
|
|
table = None |
|
n_rows_loaded = 0 |
|
while True: |
|
batch = list(islice(sample, N_ROWS_BATCH)) |
|
if len(batch) == 0: |
|
break |
|
|
|
|
|
|
|
data = [rename_embedding_column(row) for row in batch] |
|
n_rows_loaded += len(data) |
|
|
|
if table is None: |
|
print("Creating table") |
|
table = db().create_table("data", data=data) |
|
|
|
|
|
print("Indexing table") |
|
table.create_index(num_sub_vectors=1) |
|
else: |
|
table.add(data) |
|
|
|
print(f"Loaded {n_rows_loaded} rows") |
|
print("Done loading data") |
|
|
|
|
|
|
|
def search(search_phrase: str) -> tuple[pl.DataFrame, int]: |
|
while "data" not in db().table_names(): |
|
|
|
|
|
time.sleep(1) |
|
|
|
|
|
embedding = EMBEDDING_MODEL.encode([search_phrase])[0] |
|
|
|
|
|
table = db().open_table("data") |
|
data_frame = table.search(embedding).limit(N_SEARCH_RESULTS).to_polars() |
|
|
|
return ( |
|
|
|
data_frame.select(*[pl.col(c) for c in DISPLAY_COLUMNS]).to_pandas(), |
|
table.count_rows(), |
|
) |
|
|
|
|
|
|
|
with gr.Blocks(css=STYLE) as demo: |
|
gr.HTML(f"<style>{STYLE}</style>") |
|
with gr.Row(): |
|
gr.Markdown(read_header_markdown()) |
|
with gr.Row(): |
|
input_text = gr.Textbox(label="Search phrase", scale=100) |
|
search_button = gr.Button("Search", scale=1, min_width=100) |
|
with gr.Row(): |
|
rows_searched = gr.Number( |
|
label="Rows searched", |
|
show_label=True, |
|
) |
|
with gr.Row(): |
|
search_results = gr.DataFrame( |
|
headers=DISPLAY_COLUMNS, |
|
type="pandas", |
|
datatype=DISPLAY_COLUMN_TYPES, |
|
row_count=N_SEARCH_RESULTS, |
|
col_count=(len(DISPLAY_COLUMNS), "fixed"), |
|
column_widths=DISPLAY_COLUMN_WIDTHS, |
|
elem_classes=".df-text-col", |
|
) |
|
search_button.click( |
|
search, |
|
[input_text], |
|
[search_results, rows_searched], |
|
) |
|
|
|
|
|
|
|
|
|
data_load_thread = Thread(target=load_data_sample, daemon=True) |
|
data_load_thread.start() |
|
|
|
print("Launching app") |
|
demo.launch() |
|
|