Joshua Sundance Bailey committed on
Commit
68a3064
1 Parent(s): 030f7e5
docker-compose.yml CHANGED
@@ -4,6 +4,8 @@ services:
4
  langchain-streamlit-demo:
5
  image: langchain-streamlit-demo:latest
6
  build: .
 
 
7
  ports:
8
  - "${APP_PORT:-7860}:${APP_PORT:-7860}"
9
  command: [
 
4
  langchain-streamlit-demo:
5
  image: langchain-streamlit-demo:latest
6
  build: .
7
+ env_file:
8
+ - .env
9
  ports:
10
  - "${APP_PORT:-7860}:${APP_PORT:-7860}"
11
  command: [
langchain-streamlit-demo/app.py CHANGED
@@ -1,17 +1,25 @@
1
  import os
2
  from datetime import datetime
 
3
  from typing import Union
4
 
5
  import anthropic
6
  import openai
7
  import streamlit as st
8
  from langchain import LLMChain
 
9
  from langchain.callbacks.base import BaseCallbackHandler
10
  from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
11
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
 
12
  from langchain.chat_models import ChatOpenAI, ChatAnyscale, ChatAnthropic
 
 
13
  from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
14
  from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
 
 
 
15
  from langsmith.client import Client
16
  from streamlit_feedback import streamlit_feedback
17
 
@@ -31,8 +39,10 @@ def st_init_null(*variable_names) -> None:
31
  st_init_null(
32
  "chain",
33
  "client",
 
34
  "llm",
35
  "ls_tracer",
 
36
  "run",
37
  "run_id",
38
  "trace_link",
@@ -93,6 +103,22 @@ PROVIDER_KEY_DICT = {
93
  "Anyscale Endpoints": os.environ.get("ANYSCALE_API_KEY", ""),
94
  "LANGSMITH": os.environ.get("LANGCHAIN_API_KEY", ""),
95
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
 
98
  # --- Sidebar ---
@@ -106,11 +132,35 @@ with sidebar:
106
  index=SUPPORTED_MODELS.index(DEFAULT_MODEL),
107
  )
108
 
109
- # document_chat = st.checkbox(
110
- # "Document Chat",
111
- # value=False,
112
- # help="Upload a document",
113
- # )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  if st.button("Clear message history"):
116
  STMEMORY.clear()
@@ -150,13 +200,6 @@ with sidebar:
150
  )
151
 
152
  # --- API Keys ---
153
- provider = MODEL_DICT[model]
154
-
155
- provider_api_key = PROVIDER_KEY_DICT.get(provider) or st.text_input(
156
- f"{provider} API key",
157
- type="password",
158
- )
159
-
160
  LANGSMITH_API_KEY = PROVIDER_KEY_DICT.get("LANGSMITH") or st.text_input(
161
  "LangSmith API Key (optional)",
162
  type="password",
@@ -217,30 +260,38 @@ for msg in STMEMORY.messages:
217
 
218
  # --- Current Chat ---
219
  if st.session_state.llm:
220
- # if isinstance(retriever, BaseRetriever):
221
- # # --- Document Chat ---
222
- # chain = ConversationalRetrievalChain.from_llm(
223
- # st.session_state.llm,
224
- # retriever,
225
- # memory=_MEMORY,
226
- # )
227
- # else:
228
- # --- Regular Chat ---
229
- chat_prompt = ChatPromptTemplate.from_messages(
230
- [
231
- (
232
- "system",
233
- system_prompt + "\nIt's currently {time}.",
234
- ),
235
- MessagesPlaceholder(variable_name="chat_history"),
236
- ("human", "{input}"),
237
- ],
238
- ).partial(time=lambda: str(datetime.now()))
239
- st.session_state.chain = LLMChain(
240
- prompt=chat_prompt,
241
- llm=st.session_state.llm,
242
- memory=MEMORY,
243
- )
 
 
 
 
 
 
 
 
244
 
245
  # --- Chat Input ---
246
  prompt = st.chat_input(placeholder="Ask me a question!")
@@ -251,18 +302,42 @@ if st.session_state.llm:
251
 
252
  # --- Chat Output ---
253
  with st.chat_message("assistant", avatar="🦜"):
254
- message_placeholder = st.empty()
255
- stream_handler = StreamHandler(message_placeholder)
256
- callbacks = [RUN_COLLECTOR, stream_handler]
257
  if st.session_state.ls_tracer:
258
  callbacks.append(st.session_state.ls_tracer)
259
 
 
 
 
 
 
 
 
 
260
  try:
261
- full_response = st.session_state.chain(
262
- {"input": prompt},
263
- callbacks=callbacks,
264
- tags=["Streamlit Chat"],
265
- )["text"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  except (openai.error.AuthenticationError, anthropic.AuthenticationError):
267
  st.error(
268
  f"Please enter a valid {provider} API key.",
@@ -270,8 +345,6 @@ if st.session_state.llm:
270
  )
271
  full_response = None
272
  if full_response:
273
- message_placeholder.markdown(full_response)
274
-
275
  # --- Tracing ---
276
  if st.session_state.client:
277
  st.session_state.run = RUN_COLLECTOR.traced_runs[0]
 
1
  import os
2
  from datetime import datetime
3
+ from tempfile import NamedTemporaryFile
4
  from typing import Union
5
 
6
  import anthropic
7
  import openai
8
  import streamlit as st
9
  from langchain import LLMChain
10
+ from langchain.callbacks import StreamlitCallbackHandler
11
  from langchain.callbacks.base import BaseCallbackHandler
12
  from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
13
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
14
+ from langchain.chains import RetrievalQA
15
  from langchain.chat_models import ChatOpenAI, ChatAnyscale, ChatAnthropic
16
+ from langchain.document_loaders import PyPDFLoader
17
+ from langchain.embeddings import OpenAIEmbeddings
18
  from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
19
  from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
20
+ from langchain.schema.retriever import BaseRetriever
21
+ from langchain.text_splitter import CharacterTextSplitter
22
+ from langchain.vectorstores import FAISS
23
  from langsmith.client import Client
24
  from streamlit_feedback import streamlit_feedback
25
 
 
39
  st_init_null(
40
  "chain",
41
  "client",
42
+ "doc_chain",
43
  "llm",
44
  "ls_tracer",
45
+ "retriever",
46
  "run",
47
  "run_id",
48
  "trace_link",
 
103
  "Anyscale Endpoints": os.environ.get("ANYSCALE_API_KEY", ""),
104
  "LANGSMITH": os.environ.get("LANGCHAIN_API_KEY", ""),
105
  }
106
+ OPENAI_API_KEY = PROVIDER_KEY_DICT["OpenAI"]
107
+
108
+
109
@st.cache_data
def get_retriever(
    uploaded_file_bytes: bytes,
    api_key: Union[str, None] = None,
) -> BaseRetriever:
    """Build a FAISS similarity retriever over an uploaded PDF.

    Args:
        uploaded_file_bytes: Raw bytes of the uploaded PDF. This (together
            with ``api_key`` when given) forms the ``st.cache_data`` cache
            key, so the index is only rebuilt when the document changes.
        api_key: OpenAI key used for the embeddings. Defaults to ``None``,
            which falls back to the module-level ``openai_api_key`` collected
            in the sidebar — this keeps the original zero-extra-argument call
            sites working. Passing the key explicitly makes it part of the
            cache key, so a corrected key invalidates a retriever that was
            built (and cached) with a bad one.

    Returns:
        A retriever backed by an in-memory FAISS index of 1000-character,
        non-overlapping chunks of the document.
    """
    # Backward-compatible fallback to the key gathered in the sidebar.
    embeddings_key = api_key if api_key is not None else openai_api_key

    # PyPDFLoader needs a filesystem path, so spill the upload to a temp
    # file first.  NOTE(review): re-opening a NamedTemporaryFile by name
    # does not work on Windows — confirm deployment targets (Docker/Linux
    # per the compose file, so this should be fine there).
    with NamedTemporaryFile(suffix=".pdf") as temp_file:
        temp_file.write(uploaded_file_bytes)
        temp_file.seek(0)
        documents = PyPDFLoader(temp_file.name).load()

    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_documents(documents)
    embeddings = OpenAIEmbeddings(openai_api_key=embeddings_key)
    return FAISS.from_documents(texts, embeddings).as_retriever()
122
 
123
 
124
  # --- Sidebar ---
 
132
  index=SUPPORTED_MODELS.index(DEFAULT_MODEL),
133
  )
134
 
135
+ provider = MODEL_DICT[model]
136
+
137
+ provider_api_key = PROVIDER_KEY_DICT.get(provider) or st.text_input(
138
+ f"{provider} API key",
139
+ type="password",
140
+ )
141
+
142
+ uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
143
+
144
+ openai_api_key = (
145
+ provider_api_key
146
+ if provider == "OpenAI"
147
+ else OPENAI_API_KEY
148
+ or st.sidebar.text_input("OpenAI API Key: ", type="password")
149
+ )
150
+
151
+ if uploaded_file:
152
+ if openai_api_key:
153
+ st.session_state.retriever = get_retriever(
154
+ uploaded_file_bytes=uploaded_file.getvalue(),
155
+ )
156
+ else:
157
+ st.error("Please enter a valid OpenAI API key.", icon="❌")
158
+
159
+ document_chat = st.checkbox(
160
+ "Document Chat",
161
+ value=False,
162
+ help="Uploaded document will provide context for the chat.",
163
+ )
164
 
165
  if st.button("Clear message history"):
166
  STMEMORY.clear()
 
200
  )
201
 
202
  # --- API Keys ---
 
 
 
 
 
 
 
203
  LANGSMITH_API_KEY = PROVIDER_KEY_DICT.get("LANGSMITH") or st.text_input(
204
  "LangSmith API Key (optional)",
205
  type="password",
 
260
 
261
  # --- Current Chat ---
262
  if st.session_state.llm:
263
+ # --- Document Chat ---
264
+ if st.session_state.retriever:
265
+ # st.session_state.doc_chain = ConversationalRetrievalChain.from_llm(
266
+ # st.session_state.llm,
267
+ # st.session_state.retriever,
268
+ # memory=MEMORY,
269
+ # )
270
+
271
+ st.session_state.doc_chain = RetrievalQA.from_chain_type(
272
+ llm=st.session_state.llm,
273
+ chain_type="stuff",
274
+ retriever=st.session_state.retriever,
275
+ memory=MEMORY,
276
+ )
277
+
278
+ else:
279
+ # --- Regular Chat ---
280
+ chat_prompt = ChatPromptTemplate.from_messages(
281
+ [
282
+ (
283
+ "system",
284
+ system_prompt + "\nIt's currently {time}.",
285
+ ),
286
+ MessagesPlaceholder(variable_name="chat_history"),
287
+ ("human", "{query}"),
288
+ ],
289
+ ).partial(time=lambda: str(datetime.now()))
290
+ st.session_state.chain = LLMChain(
291
+ prompt=chat_prompt,
292
+ llm=st.session_state.llm,
293
+ memory=MEMORY,
294
+ )
295
 
296
  # --- Chat Input ---
297
  prompt = st.chat_input(placeholder="Ask me a question!")
 
302
 
303
  # --- Chat Output ---
304
  with st.chat_message("assistant", avatar="🦜"):
305
+ callbacks = [RUN_COLLECTOR]
306
+
 
307
  if st.session_state.ls_tracer:
308
  callbacks.append(st.session_state.ls_tracer)
309
 
310
+ use_document_chat = all(
311
+ [
312
+ document_chat,
313
+ st.session_state.doc_chain,
314
+ st.session_state.retriever,
315
+ ],
316
+ )
317
+
318
  try:
319
+ if use_document_chat:
320
+ st_handler = StreamlitCallbackHandler(st.container())
321
+ callbacks.append(st_handler)
322
+ full_response = st.session_state.doc_chain(
323
+ {"query": prompt},
324
+ callbacks=callbacks,
325
+ tags=["Streamlit Chat"],
326
+ return_only_outputs=True,
327
+ )[st.session_state.doc_chain.output_key]
328
+ st_handler._complete_current_thought()
329
+ st.markdown(full_response)
330
+ else:
331
+ message_placeholder = st.empty()
332
+ stream_handler = StreamHandler(message_placeholder)
333
+ callbacks.append(stream_handler)
334
+ full_response = st.session_state.chain(
335
+ {"query": prompt},
336
+ callbacks=callbacks,
337
+ tags=["Streamlit Chat"],
338
+ return_only_outputs=True,
339
+ )[st.session_state.chain.output_key]
340
+ message_placeholder.markdown(full_response)
341
  except (openai.error.AuthenticationError, anthropic.AuthenticationError):
342
  st.error(
343
  f"Please enter a valid {provider} API key.",
 
345
  )
346
  full_response = None
347
  if full_response:
 
 
348
  # --- Tracing ---
349
  if st.session_state.client:
350
  st.session_state.run = RUN_COLLECTOR.traced_runs[0]
requirements.txt CHANGED
@@ -1,7 +1,9 @@
1
  anthropic==0.3.11
 
2
  langchain==0.0.293
3
  langsmith==0.0.38
4
  openai==0.28.0
 
5
  streamlit==1.26.0
6
  streamlit-feedback==0.1.2
7
  tiktoken==0.5.1
 
1
  anthropic==0.3.11
2
+ faiss-cpu==1.7.4
3
  langchain==0.0.293
4
  langsmith==0.0.38
5
  openai==0.28.0
6
+ pypdf==3.16.1
7
  streamlit==1.26.0
8
  streamlit-feedback==0.1.2
9
  tiktoken==0.5.1