curry tang commited on
Commit
b6ec8b9
1 Parent(s): 99a9a6e
Files changed (2) hide show
  1. app.py +40 -26
  2. llm.py +2 -2
app.py CHANGED
@@ -2,7 +2,10 @@ import gradio as gr
2
  from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
3
  from llm import DeepSeekLLM, OpenRouterLLM, TongYiLLM
4
  from config import settings
5
- from prompts import web_prompt, explain_code_template, optimize_code_template, debug_code_template, function_gen_template, translate_doc_template, backend_developer_prompt, analyst_prompt
 
 
 
6
  from langchain_core.prompts import PromptTemplate
7
  from log import logging
8
  from utils import convert_image_to_base64
@@ -21,6 +24,12 @@ provider_model_map = dict(
21
  Tongyi=tongyi_llm,
22
  )
23
 
 
 
 
 
 
 
24
  support_vision_models = [
25
  'openai/gpt-4o-mini', 'anthropic/claude-3.5-sonnet', 'google/gemini-pro-1.5-exp',
26
  'openai/gpt-4o', 'google/gemini-flash-1.5', 'liuhaotian/llava-yi-34b', 'anthropic/claude-3-haiku',
@@ -33,29 +42,39 @@ def get_default_chat():
33
  return _llm.get_chat_engine()
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def predict(message, history, _chat, _current_assistant: str):
37
  logger.info(f"chat predict: {message}, {history}, {_chat}, {_current_assistant}")
38
  files_len = len(message.files)
39
- if _chat is None:
40
- _chat = get_default_chat()
41
  if files_len > 0:
42
  if _chat.model_name not in support_vision_models:
43
  raise gr.Error("当前模型不支持图片,请更换模型。")
44
 
45
  _lc_history = []
46
- assistant_prompt = web_prompt
47
- if _current_assistant == '后端开发助手':
48
- assistant_prompt = backend_developer_prompt
49
- if _current_assistant == '数据分析师':
50
- assistant_prompt = analyst_prompt
51
- _lc_history.append(SystemMessage(content=assistant_prompt))
52
-
53
- for his_msg in history:
54
- if his_msg['role'] == 'user':
55
- if not hasattr(his_msg['content'], 'file'):
56
- _lc_history.append(HumanMessage(content=his_msg['content']))
57
- if his_msg['role'] == 'assistant':
58
- _lc_history.append(AIMessage(content=his_msg['content']))
59
 
60
  if files_len == 0:
61
  _lc_history.append(HumanMessage(content=message.text))
@@ -81,8 +100,7 @@ def update_chat(_provider: str, _model: str, _temperature: float, _max_tokens: i
81
 
82
 
83
  def explain_code(_code_type: str, _code: str, _chat):
84
- if _chat is None:
85
- _chat = get_default_chat()
86
  chat_messages = [
87
  SystemMessage(content=explain_code_template),
88
  HumanMessage(content=_code),
@@ -94,8 +112,7 @@ def explain_code(_code_type: str, _code: str, _chat):
94
 
95
 
96
  def optimize_code(_code_type: str, _code: str, _chat):
97
- if _chat is None:
98
- _chat = get_default_chat()
99
  prompt = PromptTemplate.from_template(optimize_code_template)
100
  prompt = prompt.format(code_type=_code_type)
101
  chat_messages = [
@@ -109,8 +126,7 @@ def optimize_code(_code_type: str, _code: str, _chat):
109
 
110
 
111
  def debug_code(_code_type: str, _code: str, _chat):
112
- if _chat is None:
113
- _chat = get_default_chat()
114
  prompt = PromptTemplate.from_template(debug_code_template)
115
  prompt = prompt.format(code_type=_code_type)
116
  chat_messages = [
@@ -124,8 +140,7 @@ def debug_code(_code_type: str, _code: str, _chat):
124
 
125
 
126
  def function_gen(_code_type: str, _code: str, _chat):
127
- if _chat is None:
128
- _chat = get_default_chat()
129
  prompt = PromptTemplate.from_template(function_gen_template)
130
  prompt = prompt.format(code_type=_code_type)
131
  chat_messages = [
@@ -139,8 +154,7 @@ def function_gen(_code_type: str, _code: str, _chat):
139
 
140
 
141
  def translate_doc(_language_input, _language_output, _doc, _chat):
142
- if _chat is None:
143
- _chat = get_default_chat()
144
  prompt = PromptTemplate.from_template(translate_doc_template)
145
  prompt = prompt.format(language_input=_language_input, language_output=_language_output)
146
  chat_messages = [
 
2
  from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
3
  from llm import DeepSeekLLM, OpenRouterLLM, TongYiLLM
4
  from config import settings
5
+ from prompts import (
6
+ web_prompt, explain_code_template, optimize_code_template, debug_code_template,
7
+ function_gen_template, translate_doc_template, backend_developer_prompt, analyst_prompt
8
+ )
9
  from langchain_core.prompts import PromptTemplate
10
  from log import logging
11
  from utils import convert_image_to_base64
 
24
  Tongyi=tongyi_llm,
25
  )
26
 
27
+ system_prompt_map = {
28
+ "前端开发助手": web_prompt,
29
+ "后端开发助手": backend_developer_prompt,
30
+ "数据分析师": analyst_prompt,
31
+ }
32
+
33
  support_vision_models = [
34
  'openai/gpt-4o-mini', 'anthropic/claude-3.5-sonnet', 'google/gemini-pro-1.5-exp',
35
  'openai/gpt-4o', 'google/gemini-flash-1.5', 'liuhaotian/llava-yi-34b', 'anthropic/claude-3-haiku',
 
42
  return _llm.get_chat_engine()
43
 
44
 
45
+ def get_chat_or_default(chat):
46
+ if chat is None:
47
+ chat = get_default_chat()
48
+ return chat
49
+
50
+
51
+ def convert_history_to_langchain_history(history, lc_history):
52
+ for his_msg in history:
53
+ if his_msg['role'] == 'user':
54
+ if not hasattr(his_msg['content'], 'file'):
55
+ lc_history.append(HumanMessage(content=his_msg['content']))
56
+ if his_msg['role'] == 'assistant':
57
+ lc_history.append(AIMessage(content=his_msg['content']))
58
+ return lc_history
59
+
60
+
61
+ def append_system_prompt(key: str, lc_history):
62
+ prompt = system_prompt_map[key]
63
+ lc_history.append(SystemMessage(content=prompt))
64
+ return lc_history
65
+
66
+
67
  def predict(message, history, _chat, _current_assistant: str):
68
  logger.info(f"chat predict: {message}, {history}, {_chat}, {_current_assistant}")
69
  files_len = len(message.files)
70
+ _chat = get_chat_or_default(_chat)
 
71
  if files_len > 0:
72
  if _chat.model_name not in support_vision_models:
73
  raise gr.Error("当前模型不支持图片,请更换模型。")
74
 
75
  _lc_history = []
76
+ _lc_history = append_system_prompt(_current_assistant, _lc_history)
77
+ _lc_history = convert_history_to_langchain_history(history, _lc_history)
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  if files_len == 0:
80
  _lc_history.append(HumanMessage(content=message.text))
 
100
 
101
 
102
  def explain_code(_code_type: str, _code: str, _chat):
103
+ _chat = get_chat_or_default(_chat)
 
104
  chat_messages = [
105
  SystemMessage(content=explain_code_template),
106
  HumanMessage(content=_code),
 
112
 
113
 
114
  def optimize_code(_code_type: str, _code: str, _chat):
115
+ _chat = get_chat_or_default(_chat)
 
116
  prompt = PromptTemplate.from_template(optimize_code_template)
117
  prompt = prompt.format(code_type=_code_type)
118
  chat_messages = [
 
126
 
127
 
128
  def debug_code(_code_type: str, _code: str, _chat):
129
+ _chat = get_chat_or_default(_chat)
 
130
  prompt = PromptTemplate.from_template(debug_code_template)
131
  prompt = prompt.format(code_type=_code_type)
132
  chat_messages = [
 
140
 
141
 
142
  def function_gen(_code_type: str, _code: str, _chat):
143
+ _chat = get_chat_or_default(_chat)
 
144
  prompt = PromptTemplate.from_template(function_gen_template)
145
  prompt = prompt.format(code_type=_code_type)
146
  chat_messages = [
 
154
 
155
 
156
  def translate_doc(_language_input, _language_output, _doc, _chat):
157
+ _chat = get_chat_or_default(_chat)
 
158
  prompt = PromptTemplate.from_template(translate_doc_template)
159
  prompt = prompt.format(language_input=_language_input, language_output=_language_output)
160
  chat_messages = [
llm.py CHANGED
@@ -60,8 +60,8 @@ class DeepSeekLLM(BaseLLM):
60
 
61
  class OpenRouterLLM(BaseLLM):
62
  _support_models = [
63
- 'openai/gpt-4o-mini', 'anthropic/claude-3.5-sonnet', 'google/gemini-pro-1.5-exp',
64
- 'mistralai/mistral-large', 'meta-llama/llama-3.1-405b-instruct', 'openai/gpt-4o',
65
  'nvidia/nemotron-4-340b-instruct', 'deepseek/deepseek-coder', 'google/gemma-2-27b-it',
66
  'google/gemini-flash-1.5', 'deepseek/deepseek-chat', 'qwen/qwen-2-72b-instruct',
67
  'liuhaotian/llava-yi-34b', 'qwen/qwen-110b-chat',
 
60
 
61
  class OpenRouterLLM(BaseLLM):
62
  _support_models = [
63
+ 'openai/gpt-4o-mini', 'anthropic/claude-3.5-sonnet', 'openai/gpt-4o-2024-08-06',
64
+ 'google/gemini-pro-1.5-exp', 'mistralai/mistral-large', 'meta-llama/llama-3.1-405b-instruct',
65
  'nvidia/nemotron-4-340b-instruct', 'deepseek/deepseek-coder', 'google/gemma-2-27b-it',
66
  'google/gemini-flash-1.5', 'deepseek/deepseek-chat', 'qwen/qwen-2-72b-instruct',
67
  'liuhaotian/llava-yi-34b', 'qwen/qwen-110b-chat',