# SambaNova-fast / sambanova.py
# Author: Xianbao QIAN
# Commit: d76943e ("initial version")
# Hosting-page residue: raw | history | blame | No virus | 3.01 kB
import requests
import json
import os
def _extract_delta_content(chunk):
    """Return the text fragment carried by one decoded stream chunk, or '' if none."""
    if not isinstance(chunk, dict):
        return ''
    # `or [{}]` also covers an explicit empty list or null, which plain
    # `.get('choices', [{}])[0]` would turn into an IndexError/TypeError.
    choices = chunk.get('choices') or [{}]
    first_choice = choices[0]
    if not isinstance(first_choice, dict):
        return ''
    delta = first_choice.get('delta') or {}
    if not isinstance(delta, dict):
        return ''
    return delta.get('content') or ''


def _stream_chat_response(url, headers, payload, timeout=None):
    """
    Streams the chat response from the given URL with the specified headers and payload.

    Args:
        url (str): The URL to send the POST request to.
        headers (dict): The headers for the POST request.
        payload (dict): The payload for the POST request. Its 'stream' entry
            must be truthy.
        timeout (float | tuple | None): Forwarded to ``requests.post``. The
            default ``None`` sets no timeout, matching the previous behavior.

    Raises:
        ValueError: If the payload does not have a truthy 'stream' entry.
        ConnectionError: If the request fails.

    Yields:
        str: The content of the streamed response.
    """
    if not payload.get('stream'):
        raise ValueError('This method can only handle stream payload')

    data_prefix = "data: "
    try:
        # The context manager closes the response (releasing the connection)
        # even if the consumer stops iterating early or an error is raised.
        with requests.post(
            url, headers=headers, json=payload, stream=True, timeout=timeout
        ) as response:
            response.raise_for_status()  # Raise an error for bad status codes
            for line in response.iter_lines():
                if not line:
                    continue  # Skip empty lines between events
                decoded_line = line.decode('utf-8')
                if decoded_line.startswith(data_prefix):
                    decoded_line = decoded_line[len(data_prefix):]
                if decoded_line.strip() == "[DONE]":
                    break
                try:
                    json_data = json.loads(decoded_line)
                except json.JSONDecodeError:
                    print(f"Warning: Error decoding JSON: {decoded_line}. Skipping this line.")
                    continue
                content = _extract_delta_content(json_data)
                if content:
                    yield content
    except requests.RequestException as e:
        raise ConnectionError(f"Request failed: {e}") from e
def Streamer(history, **kwargs):
    """
    Yield streamed chat-completion fragments for the given conversation.

    The endpoint and credential are read from the ``URL`` and ``TOKEN``
    environment variables.

    Args:
        history: The chat messages, sent as the payload's "messages" value.
        **kwargs: Extra payload entries; on a key collision they override
            the defaults built here.

    Raises:
        EnvironmentError: If ``URL`` or ``TOKEN`` is unset or empty.

    Yields:
        str: The content of the streamed response.
    """
    endpoint = os.getenv('URL')
    credential = os.getenv('TOKEN')
    if not (endpoint and credential):
        raise EnvironmentError("URL or TOKEN environment variable is not set.")

    request_headers = {
        "Authorization": f"Basic {credential}",
        "Content-Type": "application/json",
    }
    default_body = {
        "messages": history,
        "max_tokens": 1000,
        "stop": ["<|eot_id|>"],
        "model": "llama3-405b",
        "stream": True,
    }
    # Caller-supplied entries win over the defaults.
    request_body = {**default_body, **kwargs}

    yield from _stream_chat_response(endpoint, request_headers, request_body)
# Demo: when run as a script, stream the reply to a single prompt to stdout.
if __name__ == "__main__":
    demo_history = [{"role": "user", "content": "Tell me a joke"}]
    try:
        for piece in Streamer(demo_history):
            # Fragments are printed back-to-back so they read as one message.
            print(piece, end='')
    except Exception as err:
        print(f"An error occurred: {err}")