Uhhy committed on
Commit 4f21ff8
Parent: 84e0fec

Update app.py

Files changed (1)
  1. app.py +49 -74
app.py CHANGED
@@ -4,10 +4,7 @@ from llama_cpp import Llama
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import uvicorn
 import re
-from dotenv import load_dotenv
-import spaces
-
-load_dotenv()
 
 app = FastAPI()
 
+from spaces import GPU
@@ -50,40 +47,29 @@ model_configs = [
 class ModelManager:
     def __init__(self):
         self.loaded = False
+        self.models = {}
 
     def load_model(self, model_config):
-        try:
-            return {"model": Llama.from_pretrained(repo_id=model_config['repo_id'], filename=model_config['filename']), "name": model_config['name']}
-        except Exception:
-            pass
+        if model_config['name'] not in self.models:
+            try:
+                self.models[model_config['name']] = Llama.from_pretrained(repo_id=model_config['repo_id'], filename=model_config['filename'])
+            except Exception as e:
+                print(f"Error loading model {model_config['name']}: {e}")
 
     def load_all_models(self):
-        if self.loaded:
-            return global_data['models']
-
-        try:
+        if not self.loaded:
             with ThreadPoolExecutor() as executor:
-                futures = [executor.submit(self.load_model, config) for config in model_configs]
-                models = []
-                for future in as_completed(futures):
-                    model = future.result()
-                    if model:
-                        models.append(model)
-
-                global_data['models'] = models
+                for config in model_configs:
+                    executor.submit(self.load_model, config)
             self.loaded = True
-            return models
-        except Exception:
-            pass
+
+        return self.models
 
 model_manager = ModelManager()
-model_manager.load_all_models()
+global_data['models'] = model_manager.load_all_models()
 
 class ChatRequest(BaseModel):
     message: str
-    top_k: int = 50
-    top_p: float = 0.95
-    temperature: float = 0.7
 
 def normalize_input(input_text):
     return input_text.strip()
@@ -97,61 +83,50 @@ def remove_duplicates(text):
     seen_lines = set()
     for line in lines:
         if line not in seen_lines:
-            seen_lines.add(line)
             unique_lines.append(line)
+            seen_lines.add(line)
     return '\n'.join(unique_lines)
 
-def remove_repetitive_responses(responses):
-    seen = set()
-    unique_responses = []
-    for response in responses:
-        normalized_response = remove_duplicates(response['response'])
-        if normalized_response not in seen:
-            seen.add(normalized_response)
-            unique_responses.append(response)
-    return unique_responses
-
-def generate_chat_response(request, model_data):
-    model = model_data['model']
+@GPU(duration=0)
+def generate_model_response(model, inputs):
     try:
-        user_input = normalize_input(request.message)
-        response = model(user_input, top_k=request.top_k, top_p=request.top_p, temperature=request.temperature)
-        return response
-    except Exception:
-        pass
+        response = model(inputs)
+        return remove_duplicates(response['choices'][0]['text'])
+    except Exception as e:
+        print(f"Error generating model response: {e}")
+        return ""
 
 @app.post("/generate")
 async def generate(request: ChatRequest):
     try:
-        responses = []
-        models = global_data['models']
-        for model_data in models:
-            response = generate_chat_response(request, model_data)
-            if response:
-                responses.append({
-                    "model": model_data['name'],
-                    "response": response
-                })
-
-        if not responses:
-            raise HTTPException(status_code=500, detail="Error: No responses generated.")
-
-        responses = remove_repetitive_responses(responses)
-        best_response = responses[0] if responses else {}
-        return {
-            "best_response": best_response,
-            "all_responses": responses
-        }
-    except Exception:
-        pass
-
-@app.api_route("/{method_name:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
-async def handle_request(method_name: str, request: Request):
+        inputs = normalize_input(request.message)
+        with ThreadPoolExecutor() as executor:
+            futures = [
+                executor.submit(generate_model_response, model, inputs)
+                for model in global_data['models'].values()
+            ]
+        responses = [{'model': model_name, 'response': future.result()} for model_name, future in zip(global_data['models'].keys(), as_completed(futures))]
+        unique_responses = remove_repetitive_responses(responses)
+        return unique_responses
+    except Exception as e:
+        print(f"Error generating responses: {e}")
+        raise HTTPException(status_code=500, detail="Error generating responses")
+
+@app.middleware("http")
+async def process_request(request: Request, call_next):
     try:
-        body = await request.json()
-        return {"message": "Request handled successfully", "body": body}
-    except Exception:
-        pass
+        response = await call_next(request)
+        return response
+    except Exception as e:
+        print(f"Request error: {e}")
+        raise HTTPException(status_code=500, detail="Internal Server Error")
+
+def remove_repetitive_responses(responses):
+    unique_responses = {}
+    for response in responses:
+        if response['model'] not in unique_responses:
+            unique_responses[response['model']] = response['response']
+    return unique_responses
 
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=7860)
 
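The new /generate endpoint fans a single prompt out to every loaded model through a ThreadPoolExecutor. Note that as_completed() yields futures in completion order rather than submission order, so zipping it against global_data['models'].keys() can pair a response with the wrong model name. A minimal sketch of the same fan-out that keeps each future tied to its model name (fan_out is a hypothetical helper, not part of the commit; generate_model_response and the models dict are as in the diff):

from concurrent.futures import ThreadPoolExecutor, as_completed

def fan_out(models, inputs):
    # Hypothetical sketch: map each future back to the model name it was
    # submitted for, so completion order cannot misattribute responses.
    with ThreadPoolExecutor() as executor:
        future_to_name = {
            executor.submit(generate_model_response, model, inputs): name
            for name, model in models.items()
        }
        return [
            {"model": future_to_name[future], "response": future.result()}
            for future in as_completed(future_to_name)
        ]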
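For a quick smoke test of the updated endpoint, a hypothetical client call, assuming the server is running locally on port 7860 as in the __main__ block (the request body now carries only message, since top_k, top_p, and temperature were dropped from ChatRequest):

import requests

# Hypothetical smoke test against the updated /generate endpoint.
resp = requests.post(
    "http://localhost:7860/generate",
    json={"message": "Hello, world!"},  # ChatRequest now has only `message`
    timeout=120,
)
resp.raise_for_status()
# remove_repetitive_responses returns a dict keyed by model name,
# so the response body looks like {model_name: deduplicated_text, ...}
print(resp.json())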