"""Gradio demo: vector search over a sample of fineweb-edu-fortified.

A background thread streams rows of the ``airtrain-ai/fineweb-edu-fortified``
dataset into a local LanceDB table while the Gradio UI serves searches
against whatever has been loaded so far.
"""

import os
import shutil
import time
from itertools import islice
from threading import Thread

import gradio as gr
import lancedb
import pandas as pd
import polars as pl
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

# Extra CSS passed to ``gr.Blocks``.
STYLE = """ .gradio-container td span { overflow: auto !important; } """.strip()

# Model used by ``search`` to embed the search phrase.
# NOTE(review): presumably the same model that produced the dataset's
# "embedding" column -- confirm against the dataset card.
EMBEDDING_MODEL = SentenceTransformer("TaylorAI/bge-micro")

# Upper bound on the number of dataset rows loaded into the table.
MAX_N_ROWS = 3_000_000
# Number of rows inserted into the table per batch.
N_ROWS_BATCH = 5_000
# Number of nearest neighbours returned (and displayed) per search.
N_SEARCH_RESULTS = 15
# Dataset configuration name; also substituted into the README header text.
CRAWL_DUMP = "CC-MAIN-2020-05"
# Lazily created LanceDB connection; access it through ``db()``.
DB = None

# The three lists below are parallel: one entry per displayed column.
DISPLAY_COLUMNS = [
    "text",
    "url",
    "token_count",
    "count",
]
DISPLAY_COLUMN_TYPES = [
    "str",
    "str",
    "number",
    "number",
]
DISPLAY_COLUMN_WIDTHS = [
    "300px",
    "100px",
    "50px",
    "25px",
]


def rename_embedding_column(row):
    """Move the value stored under ``"embedding"`` to the key ``"vector"``.

    The mapping is modified in place and also returned.

    Raises:
        KeyError: if ``row`` has no ``"embedding"`` key.
    """
    row["vector"] = row.pop("embedding")
    return row


def read_header_markdown() -> str:
    """Return the README body to show at the top of the page.

    Reads ``./README.md``, keeps only the text after the last ``---``
    separator line and replaces the ``{{CRAWL_DUMP}}`` placeholder with
    ``CRAWL_DUMP``.
    """
    with open("./README.md", "r", encoding="utf-8") as fp:
        text = fp.read()
    # Get only the markdown following the HF metadata section.
    text = text.split("\n---\n")[-1]
    return text.replace("{{CRAWL_DUMP}}", CRAWL_DUMP)


def db():
    """Return the shared LanceDB connection, connecting on first use."""
    global DB
    if DB is None:
        DB = lancedb.connect("data")
    return DB


def load_data_sample():
    """Stream up to ``MAX_N_ROWS`` dataset rows into the ``"data"`` table.

    Any existing ``data`` directory is deleted first. Rows are inserted in
    batches of ``N_ROWS_BATCH``; the table is created and indexed from the
    first batch, and later batches are appended to it.
    """
    # NOTE(review): the reason for this initial delay is not stated in the
    # file; presumably it lets the app start up first -- confirm.
    time.sleep(5)
    # remove any data that was already there; we want to replace it.
    if os.path.exists("data"):
        shutil.rmtree("data")

    rows = load_dataset(
        "airtrain-ai/fineweb-edu-fortified",
        name=CRAWL_DUMP,
        split="train",
        streaming=True,
    )
    print("Loading data")
    # at this point you could iterate over the rows.
    # Here, we'll take a sample of rows with size
    # MAX_N_ROWS. Using islice will load only the amount
    # we asked for and no extras.
    sample = islice(rows, MAX_N_ROWS)

    table = None
    n_rows_loaded = 0
    # Each pass pulls the next N_ROWS_BATCH rows; an empty batch means the
    # sample is exhausted.
    while batch := list(islice(sample, N_ROWS_BATCH)):
        # We'll put it in a vector DB for easy vector search.
        # rename "embedding" column to "vector"
        data = [rename_embedding_column(row) for row in batch]
        n_rows_loaded += len(data)

        if table is None:
            print("Creating table")
            table = db().create_table("data", data=data)
            # index the embedding column for fast search.
            print("Indexing table")
            table.create_index(num_sub_vectors=1)
        else:
            table.add(data)

        print(f"Loaded {n_rows_loaded} rows")

    print("Done loading data")


def search(search_phrase: str) -> tuple[pd.DataFrame, int]:
    """Return the rows nearest to ``search_phrase`` and the table's row count.

    Blocks, polling once per second, until the ``"data"`` table exists.

    Returns:
        A pandas DataFrame holding ``DISPLAY_COLUMNS`` for up to
        ``N_SEARCH_RESULTS`` matches, and the number of rows currently in
        the table.
    """
    while "data" not in db().table_names():
        # Data is loaded asynchronously. Make sure there is at least
        # some in the table before searching.
        time.sleep(1)

    # Create our search vector
    embedding = EMBEDDING_MODEL.encode([search_phrase])[0]

    # Search
    table = db().open_table("data")
    data_frame = table.search(embedding).limit(N_SEARCH_RESULTS).to_polars()

    return (
        # Return only what we want to display
        data_frame.select(DISPLAY_COLUMNS).to_pandas(),
        table.count_rows(),
    )


with gr.Blocks(css=STYLE) as demo:
    # NOTE(review): this HTML payload is empty in the file as received;
    # its content may have been lost -- confirm against the upstream source.
    gr.HTML(f"")
    with gr.Row():
        gr.Markdown(read_header_markdown())
    with gr.Row():
        input_text = gr.Textbox(label="Search phrase", scale=100)
        search_button = gr.Button("Search", scale=1, min_width=100)
    with gr.Row():
        rows_searched = gr.Number(
            label="Rows searched",
            show_label=True,
        )
    with gr.Row():
        search_results = gr.DataFrame(
            headers=DISPLAY_COLUMNS,
            type="pandas",
            datatype=DISPLAY_COLUMN_TYPES,
            row_count=N_SEARCH_RESULTS,
            col_count=(len(DISPLAY_COLUMNS), "fixed"),
            column_widths=DISPLAY_COLUMN_WIDTHS,
            elem_classes=".df-text-col",
        )

    search_button.click(
        search,
        [input_text],
        [search_results, rows_searched],
    )


# load data on another thread so we can start searching even before it's
# all loaded.
data_load_thread = Thread(target=load_data_sample, daemon=True)
data_load_thread.start()

print("Launching app")
demo.launch()