import base64
import os
import subprocess
from io import BytesIO

import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoProcessor

# Install flash-attn at startup, skipping the CUDA build step.
# Merge with os.environ so pip and CUDA tools stay on PATH in the child process.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

app = FastAPI()

# Preload each supported model and its processor once at startup, keyed by model_id.
models = {
    "microsoft/Phi-3.5-vision-instruct": AutoModelForCausalLM.from_pretrained(
        "microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        torch_dtype="auto",
        attn_implementation="flash_attention_2",
    ).cuda().eval()
}

processors = {
    "microsoft/Phi-3.5-vision-instruct": AutoProcessor.from_pretrained(
        "microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
    )
}


class InputData(BaseModel):
    image: str  # base64-encoded image bytes
    text_input: str
    model_id: str = "microsoft/Phi-3.5-vision-instruct"


@app.post("/run_example")
async def run_example(input_data: InputData):
    try:
        model = models[input_data.model_id]
        processor = processors[input_data.model_id]

        # Decode the base64 payload into a PIL image.
        image_data = base64.b64decode(input_data.image)
        image = Image.open(BytesIO(image_data)).convert("RGB")

        # Build the Phi-3.5-vision chat prompt with a single image placeholder.
        user_prompt = '<|user|>\n'
        assistant_prompt = '<|assistant|>\n'
        prompt_suffix = "<|end|>\n"
        prompt = f"{user_prompt}<|image_1|>\n{input_data.text_input}{prompt_suffix}{assistant_prompt}"

        inputs = processor(prompt, image, return_tensors="pt").to("cuda:0")
        generate_ids = model.generate(
            **inputs,
            max_new_tokens=1000,
            eos_token_id=processor.tokenizer.eos_token_id,
        )
        # Drop the prompt tokens so only the newly generated text is decoded.
        generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
        response = processor.batch_decode(
            generate_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
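

# A minimal local-launch and client sketch, kept as comments so it never runs on the
# Space itself. Assumptions not present in the file above: `uvicorn` and `requests`
# are installed, the server listens on port 7860 (the usual Spaces port), and
# "image.jpg" is a placeholder path; adjust these to your setup.
#
#     if __name__ == "__main__":
#         import uvicorn
#         uvicorn.run(app, host="0.0.0.0", port=7860)
#
#     # Client side: base64-encode an image and POST it to /run_example.
#     import base64, requests
#     with open("image.jpg", "rb") as f:
#         b64_image = base64.b64encode(f.read()).decode("utf-8")
#     resp = requests.post(
#         "http://localhost:7860/run_example",
#         json={"image": b64_image, "text_input": "Describe this image."},
#     )
#     print(resp.json()["response"])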