From da1c2bf12e7f3eab10c08690fb0566de5914ef83 Mon Sep 17 00:00:00 2001
From: Eric Zhu
Date: Tue, 21 Jan 2025 06:06:19 -0800
Subject: [PATCH] fix: use tool_calls field to detect tool calls in OpenAI client; add integration tests for OpenAI and Gemini (#5122)

* fix: use tool_calls field to detect tool calls in OpenAI client

* Add unit tests for tool calling; and integration tests for openai and gemini
---
 .../models/openai/_openai_client.py          |  30 +-
 .../tests/models/test_openai_model_client.py | 331 +++++++++++++++++-
 2 files changed, 349 insertions(+), 12 deletions(-)

diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
index c44d03711b7..79c13442c7d 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
@@ -539,20 +539,33 @@ async def create(
         if self._resolved_model is not None:
             if self._resolved_model != result.model:
                 warnings.warn(
-                    f"Resolved model mismatch: {self._resolved_model} != {result.model}. Model mapping may be incorrect.",
+                    f"Resolved model mismatch: {self._resolved_model} != {result.model}. "
+                    "Model mapping in autogen_ext.models.openai may be incorrect.",
                     stacklevel=2,
                 )
 
         # Limited to a single choice currently.
         choice: Union[ParsedChoice[Any], ParsedChoice[BaseModel], Choice] = result.choices[0]
-        if choice.finish_reason == "function_call":
-            raise ValueError("Function calls are not supported in this context")
 
+        # Detect whether it is a function call or not.
+        # We don't rely on choice.finish_reason as it is not always accurate, depending on the API used.
         content: Union[str, List[FunctionCall]]
-        if choice.finish_reason == "tool_calls":
-            assert choice.message.tool_calls is not None
-            assert choice.message.function_call is None
-
+        if choice.message.function_call is not None:
+            raise ValueError("function_call is deprecated and is not supported by this model client.")
+        elif choice.message.tool_calls is not None:
+            if choice.finish_reason != "tool_calls":
+                warnings.warn(
+                    f"Finish reason mismatch: {choice.finish_reason} != tool_calls "
+                    "when tool_calls are present. Finish reason may not be accurate. "
+                    "This may be due to the API used that is not returning the correct finish reason.",
+                    stacklevel=2,
+                )
+            if choice.message.content is not None and choice.message.content != "":
+                warnings.warn(
+                    "Both tool_calls and content are present in the message. "
+                    "This is unexpected. content will be ignored, tool_calls will be used.",
+                    stacklevel=2,
+                )
             # NOTE: If OAI response type changes, this will need to be updated
             content = [
                 FunctionCall(
@@ -562,10 +575,11 @@
                 )
                 for x in choice.message.tool_calls
             ]
-            finish_reason = "function_calls"
+            finish_reason = "tool_calls"
         else:
             finish_reason = choice.finish_reason
             content = choice.message.content or ""
+
         logprobs: Optional[List[ChatCompletionTokenLogprob]] = None
         if choice.logprobs and choice.logprobs.content:
             logprobs = [
diff --git a/python/packages/autogen-ext/tests/models/test_openai_model_client.py b/python/packages/autogen-ext/tests/models/test_openai_model_client.py
index 18312e1c161..d629cdc428a 100644
--- a/python/packages/autogen-ext/tests/models/test_openai_model_client.py
+++ b/python/packages/autogen-ext/tests/models/test_openai_model_client.py
@@ -1,10 +1,11 @@
 import asyncio
 import json
-from typing import Annotated, Any, AsyncGenerator, Generic, List, Literal, Tuple, TypeVar
+import os
+from typing import Annotated, Any, AsyncGenerator, Dict, Generic, List, Literal, Tuple, TypeVar
 from unittest.mock import MagicMock
 
 import pytest
-from autogen_core import CancellationToken, Image
+from autogen_core import CancellationToken, FunctionCall, Image
 from autogen_core.models import (
     AssistantMessage,
     CreateResult,
@@ -26,10 +27,31 @@
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta
 from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
+from openai.types.chat.chat_completion_message_tool_call import (
+    ChatCompletionMessageToolCall,
+    Function,
+)
 from openai.types.chat.parsed_chat_completion import ParsedChatCompletion, ParsedChatCompletionMessage, ParsedChoice
 from openai.types.completion_usage import CompletionUsage
 from pydantic import BaseModel, Field
 
+
+class _MockChatCompletion:
+    def __init__(self, chat_completions: List[ChatCompletion]) -> None:
+        self._saved_chat_completions = chat_completions
+        self.curr_index = 0
+        self.calls: List[Dict[str, Any]] = []
+
+    async def mock_create(
+        self, *args: Any, **kwargs: Any
+    ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
+        self.calls.append(kwargs)  # Save the call
+        await asyncio.sleep(0.1)
+        completion = self._saved_chat_completions[self.curr_index]
+        self.curr_index += 1
+        return completion
+
+
 ResponseFormatT = TypeVar("ResponseFormatT", bound=BaseModel)
 
@@ -37,20 +59,32 @@ class _MockBetaChatCompletion(Generic[ResponseFormatT]):
     def __init__(self, chat_completions: List[ParsedChatCompletion[ResponseFormatT]]) -> None:
         self._saved_chat_completions = chat_completions
         self.curr_index = 0
-        self.calls: List[List[LLMMessage]] = []
+        self.calls: List[Dict[str, Any]] = []
 
     async def mock_parse(
         self,
         *args: Any,
         **kwargs: Any,
     ) -> ParsedChatCompletion[ResponseFormatT]:
-        self.calls.append(kwargs["messages"])
+        self.calls.append(kwargs)  # Save the call
         await asyncio.sleep(0.1)
         completion = self._saved_chat_completions[self.curr_index]
         self.curr_index += 1
         return completion
 
 
+def _pass_function(input: str) -> str:
+    return "pass"
+
+
+async def _fail_function(input: str) -> str:
+    return "fail"
+
+
+async def _echo_function(input: str) -> str:
+    return input
+
+
 class MyResult(BaseModel):
     result: str = Field(description="The other description.")
 
@@ -432,3 +466,292 @@ class AgentResponse(BaseModel):
         == "The user explicitly states that they are happy without any indication of sadness or neutrality."
     )
     assert response.response == "happy"
+
+
+@pytest.mark.asyncio
+async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
+    model = "gpt-4o-2024-05-13"
+    chat_completions = [
+        # Successful completion, single tool call
+        ChatCompletion(
+            id="id1",
+            choices=[
+                Choice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content=None,
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="1",
+                                type="function",
+                                function=Function(
+                                    name="_pass_function",
+                                    arguments=json.dumps({"input": "task"}),
+                                ),
+                            )
+                        ],
+                        role="assistant",
+                    ),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        ),
+        # Successful completion, parallel tool calls
+        ChatCompletion(
+            id="id2",
+            choices=[
+                Choice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content=None,
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="1",
+                                type="function",
+                                function=Function(
+                                    name="_pass_function",
+                                    arguments=json.dumps({"input": "task"}),
+                                ),
+                            ),
+                            ChatCompletionMessageToolCall(
+                                id="2",
+                                type="function",
+                                function=Function(
+                                    name="_fail_function",
+                                    arguments=json.dumps({"input": "task"}),
+                                ),
+                            ),
+                            ChatCompletionMessageToolCall(
+                                id="3",
+                                type="function",
+                                function=Function(
+                                    name="_echo_function",
+                                    arguments=json.dumps({"input": "task"}),
+                                ),
+                            ),
+                        ],
+                        role="assistant",
+                    ),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        ),
+        # Warning completion when finish reason is not tool_calls.
+        ChatCompletion(
+            id="id3",
+            choices=[
+                Choice(
+                    finish_reason="stop",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content=None,
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="1",
+                                type="function",
+                                function=Function(
+                                    name="_pass_function",
+                                    arguments=json.dumps({"input": "task"}),
+                                ),
+                            )
+                        ],
+                        role="assistant",
+                    ),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        ),
+        # Warning completion when content is not None.
+        ChatCompletion(
+            id="id4",
+            choices=[
+                Choice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content="I should make a tool call.",
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="1",
+                                type="function",
+                                function=Function(
+                                    name="_pass_function",
+                                    arguments=json.dumps({"input": "task"}),
+                                ),
+                            )
+                        ],
+                        role="assistant",
+                    ),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        ),
+    ]
+    mock = _MockChatCompletion(chat_completions)
+    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
+    pass_tool = FunctionTool(_pass_function, description="pass tool.")
+    fail_tool = FunctionTool(_fail_function, description="fail tool.")
+    echo_tool = FunctionTool(_echo_function, description="echo tool.")
+    model_client = OpenAIChatCompletionClient(model=model, api_key="")
+
+    # Single tool call
+    create_result = await model_client.create(messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool])
+    assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
+    # Verify that the tool schema was passed to the model client.
+    kwargs = mock.calls[0]
+    assert kwargs["tools"] == [{"function": pass_tool.schema, "type": "function"}]
+    # Verify finish reason
+    assert create_result.finish_reason == "function_calls"
+
+    # Parallel tool calls
+    create_result = await model_client.create(
+        messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool, fail_tool, echo_tool]
+    )
+    assert create_result.content == [
+        FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function"),
+        FunctionCall(id="2", arguments=r'{"input": "task"}', name="_fail_function"),
+        FunctionCall(id="3", arguments=r'{"input": "task"}', name="_echo_function"),
+    ]
+    # Verify that the tool schema was passed to the model client.
+    kwargs = mock.calls[1]
+    assert kwargs["tools"] == [
+        {"function": pass_tool.schema, "type": "function"},
+        {"function": fail_tool.schema, "type": "function"},
+        {"function": echo_tool.schema, "type": "function"},
+    ]
+    # Verify finish reason
+    assert create_result.finish_reason == "function_calls"
+
+    # Warning completion when finish reason is not tool_calls.
+    with pytest.warns(UserWarning, match="Finish reason mismatch"):
+        create_result = await model_client.create(
+            messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool]
+        )
+        assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
+        assert create_result.finish_reason == "function_calls"
+
+    # Warning completion when content is not None.
+    with pytest.warns(UserWarning, match="Both tool_calls and content are present in the message"):
+        create_result = await model_client.create(
+            messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool]
+        )
+        assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
+        assert create_result.finish_reason == "function_calls"
+
+
+async def _test_model_client(model_client: OpenAIChatCompletionClient) -> None:
+    # Test basic completion
+    create_result = await model_client.create(
+        messages=[
+            SystemMessage(content="You are a helpful assistant."),
+            UserMessage(content="Explain to me how AI works.", source="user"),
+        ]
+    )
+    assert isinstance(create_result.content, str)
+    assert len(create_result.content) > 0
+
+    # Test tool calling
+    pass_tool = FunctionTool(_pass_function, name="pass_tool", description="pass session.")
+    fail_tool = FunctionTool(_fail_function, name="fail_tool", description="fail session.")
+    messages: List[LLMMessage] = [UserMessage(content="Call the pass tool with input 'task'", source="user")]
+    create_result = await model_client.create(messages=messages, tools=[pass_tool, fail_tool])
+    assert isinstance(create_result.content, list)
+    assert len(create_result.content) == 1
+    assert isinstance(create_result.content[0], FunctionCall)
+    assert create_result.content[0].name == "pass_tool"
+    assert json.loads(create_result.content[0].arguments) == {"input": "task"}
+    assert create_result.finish_reason == "function_calls"
+    assert create_result.usage is not None
+
+    # Test reflection on tool call response.
+    messages.append(AssistantMessage(content=create_result.content, source="assistant"))
+    messages.append(
+        FunctionExecutionResultMessage(
+            content=[FunctionExecutionResult(content="passed", call_id=create_result.content[0].id)]
+        )
+    )
+    create_result = await model_client.create(messages=messages)
+    assert isinstance(create_result.content, str)
+    assert len(create_result.content) > 0
+
+    # Test parallel tool calling
+    messages = [
+        UserMessage(
+            content="Call both the pass tool with input 'task' and the fail tool also with input 'task'", source="user"
+        )
+    ]
+    create_result = await model_client.create(messages=messages, tools=[pass_tool, fail_tool])
+    assert isinstance(create_result.content, list)
+    assert len(create_result.content) == 2
+    assert isinstance(create_result.content[0], FunctionCall)
+    assert create_result.content[0].name == "pass_tool"
+    assert json.loads(create_result.content[0].arguments) == {"input": "task"}
+    assert isinstance(create_result.content[1], FunctionCall)
+    assert create_result.content[1].name == "fail_tool"
+    assert json.loads(create_result.content[1].arguments) == {"input": "task"}
+    assert create_result.finish_reason == "function_calls"
+    assert create_result.usage is not None
+
+    # Test reflection on parallel tool call response.
+    messages.append(AssistantMessage(content=create_result.content, source="assistant"))
+    messages.append(
+        FunctionExecutionResultMessage(
+            content=[
+                FunctionExecutionResult(content="passed", call_id=create_result.content[0].id),
+                FunctionExecutionResult(content="failed", call_id=create_result.content[1].id),
+            ]
+        )
+    )
+    create_result = await model_client.create(messages=messages)
+    assert isinstance(create_result.content, str)
+    assert len(create_result.content) > 0
+
+
+@pytest.mark.asyncio
+async def test_openai() -> None:
+    api_key = os.getenv("OPENAI_API_KEY")
+    if not api_key:
+        pytest.skip("OPENAI_API_KEY not found in environment variables")
+
+    model_client = OpenAIChatCompletionClient(
+        model="gpt-4o-mini",
+        api_key=api_key,
+    )
+    await _test_model_client(model_client)
+
+
+@pytest.mark.asyncio
+async def test_gemini() -> None:
+    api_key = os.getenv("GEMINI_API_KEY")
+    if not api_key:
+        pytest.skip("GEMINI_API_KEY not found in environment variables")
+
+    model_client = OpenAIChatCompletionClient(
+        model="gemini-1.5-flash",
+        api_key=api_key,
+        base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
+        model_info={
+            "function_calling": True,
+            "json_output": True,
+            "vision": True,
+            "family": ModelFamily.UNKNOWN,
+        },
+    )
+    await _test_model_client(model_client)
+
+
+# TODO: add integration tests for Azure OpenAI using AAD token.
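For reference, a minimal caller-side sketch of the behavior this patch pins down, mirroring the integration helper added above. It is not part of the patch: it assumes OPENAI_API_KEY is set, uses a hypothetical get_weather tool, and assumes FunctionTool is importable from autogen_core.tools (the test file's existing imports are not shown in this diff).

import asyncio
import json
import os

from autogen_core import FunctionCall
from autogen_core.models import UserMessage
from autogen_core.tools import FunctionTool  # assumed import path; not shown in this diff
from autogen_ext.models.openai import OpenAIChatCompletionClient


def get_weather(city: str) -> str:
    # Hypothetical tool used only for this sketch.
    return f"It is sunny in {city}."


async def main() -> None:
    client = OpenAIChatCompletionClient(model="gpt-4o-mini", api_key=os.environ["OPENAI_API_KEY"])
    weather_tool = FunctionTool(get_weather, description="Get the weather for a city.")
    result = await client.create(
        messages=[UserMessage(content="What is the weather in Paris?", source="user")],
        tools=[weather_tool],
    )
    # With this fix, tool calls are detected from message.tool_calls rather than from
    # finish_reason, so content is a list of FunctionCall and the reported
    # finish_reason is "function_calls" (as asserted in the tests above).
    if isinstance(result.content, list):
        assert result.finish_reason == "function_calls"
        for call in result.content:
            assert isinstance(call, FunctionCall)
            print(call.name, json.loads(call.arguments))
    else:
        print(result.content)


if __name__ == "__main__":
    asyncio.run(main())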