Add tool calling support to chat template

#68
by Rocketknight1 HF staff - opened
No description provided.

Test script to confirm equivalence:

from transformers import AutoTokenizer
from pathlib import Path
from mistral_common.protocol.instruct.tool_calls import Function, Tool

from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage, AssistantMessage, ToolMessage
from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall
from mistral_common.protocol.instruct.request import ChatCompletionRequest

hf_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", revision="pr/68")

hf_tool = {
                "name": "get_current_weather",
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "format": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                            "description": "The temperature unit to use. Infer this from the users location.",
                        },
                    },
                    "required": ["location", "format"],
                },
}

hf_tool = {"type": "function", "function": hf_tool}

test_chat = [{"role": "user", "content": "What's the weather like today in Paris"}]
tool_call = {"name": "get_current_weather", "arguments": {"location": "Paris, France"}}
test_chat.append({"role": "assistant", "tool_calls": [{"type": "function", "function": tool_call, "id": "abcdef123"}]})
test_chat.append({"role": "tool", "name": "get_current_temperature", "tool_call_id": "abcdef123", "content": "22.0"})

hf_text = hf_tokenizer.apply_chat_template(test_chat, tokenize=False, tools=[hf_tool])
hf_tokens = hf_tokenizer.apply_chat_template(test_chat, tokenize=True, tools=[hf_tool])

mistral_models_path = Path.home().joinpath('mistral_models', '7B-Instruct-v0.3')
mistral_tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tokenizer.model.v3")

mistral_tool = Tool(
            function=Function(
                name="get_current_weather",
                description="Get the current weather",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        },
                        "format": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"],
                            "description": "The temperature unit to use. Infer this from the users location.",
                        },
                    },
                    "required": ["location", "format"],
                },
            )
        )

mistral_query = ChatCompletionRequest(
    tools=[mistral_tool],
    messages=[
        UserMessage(content="What's the weather like today in Paris"),
        AssistantMessage(tool_calls=[ToolCall(type="function", function=FunctionCall(
            name="get_current_weather", arguments={"location": "Paris, France"}), id="abcdef123"
        )]),
        ToolMessage(content="22.0", tool_call_id="abcdef123")
    ],
    model="test",
)
encodeds = mistral_tokenizer.encode_chat_completion(mistral_query).text
mistral_text = encodeds.replace("▁", " ")
mistral_tokens = mistral_tokenizer.encode_chat_completion(mistral_query).tokens

print(hf_text == mistral_text)
print(hf_tokens == mistral_tokens)
pandora-s changed pull request status to merged

The changes to the chat template cause an issue when used with OpenAI/OpenAPI-specced tools.
The line describing tool_calls, which is part of ChatCompletionRequestAssistantMessage, is the cause of the issue. For assistant responses the tool_calls field exists but is set to None, so those messages are filtered out when we do
| selectattr("tool_calls", "undefined").
The resulting loop_messages would therefore not contain the assistant responses. The 2nd instance of the user message would then be indexed 1 (instead of 2), which throws the error "After the optional system message, conversation roles must alternate user/assistant/user/assistant/...".

Steps to reproduce:

from transformers import AutoTokenizer
from kserve.protocol.rest.openai.types.openapi import ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, ChatCompletionRequestAssistantMessage

messages = [
    ChatCompletionRequestSystemMessage(
        content="You are a pirate chatbot who always responds in pirate speak!",
        role="system",
    ),
    ChatCompletionRequestUserMessage(
        content="Hi, who are you?",
        role="user"
    ),
    ChatCompletionRequestAssistantMessage(
        content="I am an AI model created by MistralAI",
        role="assistant"
    ),
    ChatCompletionRequestUserMessage(
        content="Tell me about robots",
        role="user"
    ),
]

messages_as_list = [
    {
        "content":"You are a pirate chatbot who always responds in pirate speak!",
        "role":"system",
    },
    {
        "content":"Hi, who are you?",
        "role":"user"
    },
    {
        "content":"I am an AI model created by MistralAI",
        "role":"assistant"
    },
    {
        "content":"Tell me about robots",
        "role":"user"
    }
]

tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.3')
templated_list = tokenizer.apply_chat_template(messages_as_list,tokenize=False)
print(templated_list)
templated_messages = tokenizer.apply_chat_template(messages,tokenize=False)

Response:

<s>[INST] Hi, who are you?[/INST] I am an AI model created by MistralAI</s>[INST] You are a pirate chatbot who always responds in pirate speak!

Tell me about robots[/INST]
{
    "name": "TemplateError",
    "message": "After the optional system message, conversation roles must alternate user/assistant/user/assistant/...",
    "stack": "---------------------------------------------------------------------------
TemplateError                             Traceback (most recent call last)
Cell In[3], line 45
     43 templated_list = tokenizer.apply_chat_template(messages_as_list,tokenize=False)
     44 print(templated_list)
---> 45 templated_messages = tokenizer.apply_chat_template(messages,tokenize=False)

File ~/venvs/kserve/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:1833, in PreTrainedTokenizerBase.apply_chat_template(self, conversation, tools, documents, chat_template, add_generation_prompt, tokenize, padding, truncation, max_length, return_tensors, return_dict, return_assistant_tokens_mask, tokenizer_kwargs, **kwargs)
   1831         all_generation_indices.append(generation_indices)
   1832     else:
-> 1833         rendered_chat = compiled_template.render(
   1834             messages=chat,
   1835             tools=tool_schemas,
   1836             documents=documents,
   1837             add_generation_prompt=add_generation_prompt,
   1838             **template_kwargs,
   1839         )
   1840     rendered.append(rendered_chat)
   1842 if not is_batched:

File ~/venvs/kserve/lib/python3.9/site-packages/jinja2/environment.py:1304, in Template.render(self, *args, **kwargs)
   1302     return self.environment.concat(self.root_render_func(ctx))  # type: ignore
   1303 except Exception:
-> 1304     self.environment.handle_exception()

File ~/venvs/kserve/lib/python3.9/site-packages/jinja2/environment.py:939, in Environment.handle_exception(self, source)
    934 \"\"\"Exception handling helper.  This is used internally to either raise
    935 rewritten exceptions or return a rendered traceback for the template.
    936 \"\"\"
    937 from .debug import rewrite_traceback_stack
--> 939 raise rewrite_traceback_stack(source=source)

File <template>:14, in top-level template code()

File ~/venvs/kserve/lib/python3.9/site-packages/jinja2/sandbox.py:394, in SandboxedEnvironment.call(_SandboxedEnvironment__self, _SandboxedEnvironment__context, _SandboxedEnvironment__obj, *args, **kwargs)
    392 if not __self.is_safe_callable(__obj):
    393     raise SecurityError(f\"{__obj!r} is not safely callable\")
--> 394 return __context.call(__obj, *args, **kwargs)

File ~/venvs/kserve/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:1914, in PreTrainedTokenizerBase._compile_jinja_template.<locals>.raise_exception(message)
   1913 def raise_exception(message):
-> 1914     raise TemplateError(message)

TemplateError: After the optional system message, conversation roles must alternate user/assistant/user/assistant/..."
}

@imdatta0 I see, let me see if I can update the template to handle cases where tool_calls is present but null.

@imdatta0 , can you try running this snippet to update your local template and then rerunning your code to check it works okay?

import json

tokenizer.chat_template = json.loads('"{%- if messages[0][\\"role\\"] == \\"system\\" %}\\n    {%- set system_message = messages[0][\\"content\\"] %}\\n    {%- set loop_messages = messages[1:] %}\\n{%- else %}\\n    {%- set loop_messages = messages %}\\n{%- endif %}\\n{%- if not tools is defined %}\\n    {%- set tools = none %}\\n{%- endif %}\\n{%- set user_messages = loop_messages | selectattr(\\"role\\", \\"equalto\\", \\"user\\") | list %}\\n\\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\\n{%- set ns = namespace() %}\\n{%- set ns.index = 0 %}\\n{%- for message in loop_messages %}\\n    {%- if not (message.role == \\"tool\\" or message.role == \\"tool_results\\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\\n        {%- if (message[\\"role\\"] == \\"user\\") != (ns.index % 2 == 0) %}\\n            {{- raise_exception(\\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\\") }}\\n        {%- endif %}\\n        {%- set ns.index = ns.index + 1 %}\\n    {%- endif %}\\n{%- endfor %}\\n\\n{{- bos_token }}\\n{%- for message in loop_messages %}\\n    {%- if message[\\"role\\"] == \\"user\\" %}\\n        {%- if tools is not none and (message == user_messages[-1]) %}\\n            {{- \\"[AVAILABLE_TOOLS] [\\" }}\\n            {%- for tool in tools %}\\n        {%- set tool = tool.function %}\\n        {{- \'{\\"type\\": \\"function\\", \\"function\\": {\' }}\\n        {%- for key, val in tool.items() if key != \\"return\\" %}\\n            {%- if val is string %}\\n            {{- \'\\"\' + key + \'\\": \\"\' + val + \'\\"\' }}\\n            {%- else %}\\n            {{- \'\\"\' + key + \'\\": \' + val|tojson }}\\n            {%- endif %}\\n            {%- if not loop.last %}\\n            {{- \\", \\" }}\\n            {%- endif %}\\n        {%- endfor %}\\n        {{- \\"}}\\" }}\\n                {%- if not loop.last %}\\n                 
   {{- \\", \\" }}\\n                {%- else %}\\n                    {{- \\"]\\" }}\\n                {%- endif %}\\n            {%- endfor %}\\n            {{- \\"[/AVAILABLE_TOOLS]\\" }}\\n            {%- endif %}\\n        {%- if loop.last and system_message is defined %}\\n            {{- \\"[INST] \\" + system_message + \\"\\\\n\\\\n\\" + message[\\"content\\"] + \\"[/INST]\\" }}\\n        {%- else %}\\n            {{- \\"[INST] \\" + message[\\"content\\"] + \\"[/INST]\\" }}\\n        {%- endif %}\\n    {%- elif message[\\"role\\"] == \\"tool_calls\\" or message.tool_calls is defined %}\\n        {%- if message.tool_calls is defined %}\\n            {%- set tool_calls = message.tool_calls %}\\n        {%- else %}\\n            {%- set tool_calls = message.content %}\\n        {%- endif %}\\n        {{- \\"[TOOL_CALLS] [\\" }}\\n        {%- for tool_call in tool_calls %}\\n            {%- set out = tool_call.function|tojson %}\\n            {{- out[:-1] }}\\n            {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\\n                {{- raise_exception(\\"Tool call IDs should be alphanumeric strings with length 9!\\") }}\\n            {%- endif %}\\n            {{- \', \\"id\\": \\"\' + tool_call.id + \'\\"}\' }}\\n            {%- if not loop.last %}\\n                {{- \\", \\" }}\\n            {%- else %}\\n                {{- \\"]\\" + eos_token }}\\n            {%- endif %}\\n        {%- endfor %}\\n    {%- elif message[\\"role\\"] == \\"assistant\\" %}\\n        {{- \\" \\" + message[\\"content\\"] + eos_token}}\\n    {%- elif message[\\"role\\"] == \\"tool_results\\" or message[\\"role\\"] == \\"tool\\" %}\\n        {%- if message.content is defined and message.content.content is defined %}\\n            {%- set content = message.content.content %}\\n        {%- else %}\\n            {%- set content = message.content %}\\n        {%- endif %}\\n        {{- \'[TOOL_RESULTS] {\\"content\\": \' + content|string + \\", \\" }}\\n       
 {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\\n            {{- raise_exception(\\"Tool call IDs should be alphanumeric strings with length 9!\\") }}\\n        {%- endif %}\\n        {{- \'\\"call_id\\": \\"\' + message.tool_call_id + \'\\"}[/TOOL_RESULTS]\' }}\\n    {%- else %}\\n        {{- raise_exception(\\"Only user and assistant roles are supported, with the exception of an initial optional system message!\\") }}\\n    {%- endif %}\\n{%- endfor %}\\n"')

Hey guys - can you clarify what the expected result is for parallel tool calls for a request like this?

{
  "model": "mistralai/Mistral-7B-Instruct-v0.3",
  "messages": [
    {
      "role": "user",
      "content": "Hi! How are you doing today?"
    },
    {
      "role": "assistant",
      "content": "I'm doing well! How can I help you?"
    },
    {
      "role": "user",
      "content": "Can you tell me what the weather will be in Dallas and Orlando in fahrenheit?"
    }
  ],
  "stream": false,
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
          "type": "object",
          "properties": {
            "city": {
              "type": "string",
              "description": "The city to find the weather for, e.g. 'San Francisco'"
            },
            "state": {
              "type": "string",
              "description": "the two-letter abbreviation for the state that the city is in, e.g. 'CA' which would mean 'California'"
            },
            "unit": {
              "type": "string",
              "description": "The unit to fetch the temperature in",
              "enum": [
                "celsius",
                "fahrenheit"
              ]
            }
          }
        }
      }
    }
  ]
}

I would expect something like
[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}], but I keep getting this using this chat template:

[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]

[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]

Can you confirm if this is the expected result, or if the model is not intended to support parallel tool calls?

Hi @Rocketknight1 , thanks for the swift response. Yeah this new change seems to fix the alternating message issue. But it throws a new error

{
    "name": "TypeError",
    "message": "'NoneType' object is not iterable",
    "stack": "---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[5], line 3
      1 # templated_list = tokenizer.apply_chat_template(messages_as_list,tokenize=False)
      2 # print(templated_list)
----> 3 templated_messages = tokenizer.apply_chat_template(messages,tokenize=False)

File ~/venvs/kserve/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:1833, in PreTrainedTokenizerBase.apply_chat_template(self, conversation, tools, documents, chat_template, add_generation_prompt, tokenize, padding, truncation, max_length, return_tensors, return_dict, return_assistant_tokens_mask, tokenizer_kwargs, **kwargs)
   1831         all_generation_indices.append(generation_indices)
   1832     else:
-> 1833         rendered_chat = compiled_template.render(
   1834             messages=chat,
   1835             tools=tool_schemas,
   1836             documents=documents,
   1837             add_generation_prompt=add_generation_prompt,
   1838             **template_kwargs,
   1839         )
   1840     rendered.append(rendered_chat)
   1842 if not is_batched:

File ~/venvs/kserve/lib/python3.9/site-packages/jinja2/environment.py:1304, in Template.render(self, *args, **kwargs)
   1302     return self.environment.concat(self.root_render_func(ctx))  # type: ignore
   1303 except Exception:
-> 1304     self.environment.handle_exception()

File ~/venvs/kserve/lib/python3.9/site-packages/jinja2/environment.py:939, in Environment.handle_exception(self, source)
    934 \"\"\"Exception handling helper.  This is used internally to either raise
    935 rewritten exceptions or return a rendered traceback for the template.
    936 \"\"\"
    937 from .debug import rewrite_traceback_stack
--> 939 raise rewrite_traceback_stack(source=source)

File <template>:71, in top-level template code()

File ~/venvs/kserve/lib/python3.9/site-packages/jinja2/runtime.py:422, in LoopContext.__init__(self, iterable, undefined, recurse, depth0)
    413 \"\"\"
    414 :param iterable: Iterable to wrap.
    415 :param undefined: :class:`Undefined` class to use for next and
   (...)
    419 :param depth0: Incremented when looping recursively.
    420 \"\"\"
    421 self._iterable = iterable
--> 422 self._iterator = self._to_iterator(iterable)
    423 self._undefined = undefined
    424 self._recurse = recurse

File ~/venvs/kserve/lib/python3.9/site-packages/jinja2/runtime.py:430, in LoopContext._to_iterator(iterable)
    428 @staticmethod
    429 def _to_iterator(iterable: t.Iterable[V]) -> t.Iterator[V]:
--> 430     return iter(iterable)

TypeError: 'NoneType' object is not iterable"
}

I tried to debug it

{%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}
    {%- if message.tool_calls is defined %}
        {%- set tool_calls = message.tool_calls %}
    {%- else %}
        {%- set tool_calls = message.content %}
    {%- endif %}
    {{- "[TOOL_CALLS] [" }}
    {%- for tool_call in tool_calls %}
        {%- set out = tool_call.function|tojson %}
        {{- out[:-1] }}
        {%- if not tool_call.id is defined or tool_call.id|length != 9 %}
            {{- raise_exception("Tool call IDs should be alphanumeric strings with length 9!") }}
        {%- endif %}
        {{- ', "id": "' + tool_call.id + '"}' }}
        {%- if not loop.last %}
            {{- ", " }}
        {%- else %}
            {{- "]" + eos_token }}
        {%- endif %}
    {%- endfor %}

This seems to be causing the issue. Here, message is the assistant message when this breaks: content='I am an AI model created by MistralAI' refusal=None role='assistant' name=None tool_calls=None function_call=None

So updating the first line check to
elif message["role"] == "tool_calls" or message.tool_calls is defined and message.tool_calls is not none seems to work for me.

Would be great if we can include the above changes too :)

Hi @imdatta0 , good spot! I'll include that as well.

Hi @imdatta0, I've made some PRs - let me know if these work for you, and if so I'll ping the Mistral team and try to get them merged!

7B
8x22B
Large
Nemo

Hey @Rocketknight1 ,
Yeah I was checking in on the repo frequently and noticed your PR. I tried the 7B changes and they do work for me.
I really appreciate your help with this so far.

Sign up or log in to comment