How to use ONNX model in Triton efficiently?

#11
by khaerens - opened

I'm using the Phi-3 ONNX model in a Triton model server.
It's running significantly slower than the PyTorch model, probably because I'm making some obvious mistake in the setup. Any help would be welcome.

I want to set up the Triton server to take in a user prompt as text and return the generated text. To this end, I'm using a Python backend that does the following:

  • tokenize the input text and create the input_ids tensor using Hugging Face transformers
  • create the attention_mask tensor
  • create zero-initialized past_key_values tensors (a sketch of this setup follows the list)
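Here is a minimal sketch of that setup, assuming a Hugging Face tokenizer and numpy arrays. The model id, the KV-cache input names (past_key_values.{i}.key/value), and the layer/head/dim values are assumptions based on a typical Phi-3-mini export, not verified against this exact model. Note that the usual convention for the first call is an empty (sequence-length-0) cache rather than a cache of zeros, which may also explain why the initial contents seem to be ignored:

```python
import numpy as np
from transformers import AutoTokenizer

# Model id is illustrative; use the checkpoint that matches your export.
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

prompt = "What is the capital of France?"
encoded = tokenizer(prompt, return_tensors="np")
input_ids = encoded["input_ids"].astype(np.int64)        # shape (1, seq_len)
attention_mask = encoded["attention_mask"].astype(np.int64)

# Empty (sequence-length-0) KV cache for the first call. Layer count,
# head count, head dim, and dtype are assumptions for Phi-3-mini
# (fp16 exports use float16); read them from the exported model's
# config/inputs in practice.
num_layers, num_kv_heads, head_dim = 32, 32, 96
past_key_values = {
    f"past_key_values.{i}.{kind}": np.zeros(
        (1, num_kv_heads, 0, head_dim), dtype=np.float32
    )
    for i in range(num_layers)
    for kind in ("key", "value")
}
```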

Then, repeat until the EOS token is detected:

  • send input_ids, attention_mask, and past_key_values to the Phi-3 ONNX model
  • take the argmax of the returned logits and append it to input_ids
  • copy the returned present_key_values into new past_key_values tensors (a standalone sketch of this loop follows the list)
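For reference, here is roughly how that loop looks as a plain onnxruntime script outside Triton, continuing from the setup sketch above. The output names (logits, present.{i}.key/value) follow a common ONNX Runtime export convention and are assumptions; check them with session.get_inputs() / session.get_outputs() for your export:

```python
import numpy as np
import onnxruntime as ort

# Continues from the setup sketch above (tokenizer, input_ids,
# attention_mask, past_key_values). The model path is illustrative.
session = ort.InferenceSession("model.onnx")
output_names = [o.name for o in session.get_outputs()]
eos_token_id = tokenizer.eos_token_id
generated = input_ids

for _ in range(128):  # cap the number of new tokens
    feeds = {"input_ids": input_ids, "attention_mask": attention_mask}
    feeds.update(past_key_values)
    out = dict(zip(output_names, session.run(None, feeds)))

    # Greedy decoding: argmax over the logits of the last position.
    next_token = int(np.argmax(out["logits"][0, -1]))
    generated = np.concatenate(
        [generated, np.array([[next_token]], dtype=np.int64)], axis=1
    )
    if next_token == eos_token_id:
        break

    # Rebind rather than copy: the present outputs become the next
    # step's past inputs (name convention assumed; check your export).
    past_key_values = {
        name.replace("present", "past_key_values"): value
        for name, value in out.items()
        if name.startswith("present")
    }
    # With a warm cache, feed only the newly generated token; the
    # attention mask still covers the full sequence.
    input_ids = np.array([[next_token]], dtype=np.int64)
    attention_mask = np.concatenate(
        [attention_mask, np.ones((1, 1), dtype=np.int64)], axis=1
    )
```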

In a diagram:

[image: the setup and generation loop described above]

Some questions I have:

  • I'm currently using the Phi-3 model in Triton by simply putting model.onnx and its model.onnx.data in the Triton model repository. I can't figure out how to use the other config and Python files included in the ONNX repository with Triton; am I missing something here?

  • I'm "manually" copying the present_key_values into new past_key_values tensors. This seems very inefficient. I'm not even sure the Phi3 model is using the past_key_values for anything, since on the first call I'm sending over zeros, and that works fine, which seems strange to me.

Any pointers would be helpful; this is my first time setting this up, so I'm probably doing things wrong.

Microsoft org

Closing this since it has already been posted as an issue on the ORT GenAI repo and is being tracked there.

kvaishnavi changed discussion status to closed
