Skip to content

Commit

Permalink
Push more of the LLM conversation agent loop into ChatSession (#136602)
Browse files Browse the repository at this point in the history
* Push more of the LLM conversation agent loop into ChatSession

* Revert unnecessary changes

* Revert changes to agent id filtering
Loading branch information
allenporter authored Jan 27, 2025
1 parent dfbb485 commit 6993854
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 56 deletions.
17 changes: 15 additions & 2 deletions homeassistant/components/conversation/agent_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import voluptuous as vol

from homeassistant.core import Context, HomeAssistant, async_get_hass, callback
from homeassistant.helpers import config_validation as cv, singleton
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv, intent, singleton

from .const import (
DATA_COMPONENT,
Expand Down Expand Up @@ -109,7 +110,19 @@ async def async_converse(
dataclasses.asdict(conversation_input),
)
)
result = await method(conversation_input)
try:
result = await method(conversation_input)
except HomeAssistantError as err:
intent_response = intent.IntentResponse(language=language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
str(err),
)
result = ConversationResult(
response=intent_response,
conversation_id=conversation_id,
)

trace.set_result(**result.as_dict())
return result

Expand Down
37 changes: 35 additions & 2 deletions homeassistant/components/conversation/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import logging
from typing import Literal

import voluptuous as vol

from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import (
CALLBACK_TYPE,
Expand All @@ -23,7 +25,9 @@
from homeassistant.helpers.event import async_call_later
from homeassistant.util import dt as dt_util, ulid as ulid_util
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JsonObjectType

from . import trace
from .const import DOMAIN
from .models import ConversationInput, ConversationResult

Expand Down Expand Up @@ -120,7 +124,7 @@ async def async_get_chat_session(
if history:
history = replace(history, messages=history.messages.copy())
else:
history = ChatSession(hass, conversation_id)
history = ChatSession(hass, conversation_id, user_input.agent_id)

message: ChatMessage = ChatMessage(
role="user",
Expand Down Expand Up @@ -190,6 +194,7 @@ class ChatSession[_NativeT]:

hass: HomeAssistant
conversation_id: str
agent_id: str | None
user_name: str | None = None
messages: list[ChatMessage[_NativeT]] = field(
default_factory=lambda: [ChatMessage(role="system", agent_id=None, content="")]
Expand All @@ -209,7 +214,9 @@ def async_add_message(self, message: ChatMessage[_NativeT]) -> None:
self.messages.append(message)

@callback
def async_get_messages(self, agent_id: str | None) -> list[ChatMessage[_NativeT]]:
def async_get_messages(
self, agent_id: str | None = None
) -> list[ChatMessage[_NativeT]]:
"""Get messages for a specific agent ID.
This will filter out any native message tied to other agent IDs.
Expand Down Expand Up @@ -326,3 +333,29 @@ async def async_update_llm_data(
agent_id=user_input.agent_id,
content=prompt,
)

LOGGER.debug("Prompt: %s", self.messages)
LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None)

trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL,
{
"messages": self.messages,
"tools": self.llm_api.tools if self.llm_api else None,
},
)

async def async_call_tool(self, tool_input: llm.ToolInput) -> JsonObjectType:
    """Invoke LLM tool for the configured LLM API.

    Raises ValueError when no LLM API is configured for this session.
    Failures inside the tool itself are not raised; they are converted
    into an error object so the LLM can see what went wrong.
    """
    if not self.llm_api:
        raise ValueError("No LLM API configured")

    LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args)
    try:
        response: JsonObjectType = await self.llm_api.async_call_tool(tool_input)
    except (HomeAssistantError, vol.Invalid) as err:
        # Report the exception class, plus its message when it has one.
        response = {"error": type(err).__name__}
        message = str(err)
        if message:
            response["error_text"] = message
    LOGGER.debug("Tool response: %s", response)
    return response
62 changes: 19 additions & 43 deletions homeassistant/components/openai_conversation/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@
)
from openai.types.chat.chat_completion_message_tool_call_param import Function
from openai.types.shared_params import FunctionDefinition
import voluptuous as vol
from voluptuous_openapi import convert

from homeassistant.components import assist_pipeline, conversation
from homeassistant.components.conversation import trace
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_LLM_HASS_API, MATCH_ALL
from homeassistant.core import HomeAssistant
Expand Down Expand Up @@ -94,6 +92,19 @@ def _message_convert(message: ChatCompletionMessage) -> ChatCompletionMessagePar
return param


def _chat_message_convert(
    message: conversation.ChatMessage[ChatCompletionMessageParam],
    agent_id: str | None,
) -> ChatCompletionMessageParam:
    """Convert any native chat message for this agent to the native format."""
    native = message.native
    # Only reuse the stored native payload when it was produced by this agent.
    if native is None or message.agent_id != agent_id:
        return cast(
            ChatCompletionMessageParam,
            {"role": message.role, "content": message.content},
        )
    return native


class OpenAIConversationEntity(
conversation.ConversationEntity, conversation.AbstractConversationAgent
):
Expand Down Expand Up @@ -173,27 +184,10 @@ async def _async_call_api(
for tool in session.llm_api.tools
]

messages: list[ChatCompletionMessageParam] = []
for message in session.async_get_messages(user_input.agent_id):
if message.native is not None and message.agent_id == user_input.agent_id:
messages.append(message.native)
else:
messages.append(
cast(
ChatCompletionMessageParam,
{"role": message.role, "content": message.content},
)
)

LOGGER.debug("Prompt: %s", messages)
LOGGER.debug("Tools: %s", tools)
trace.async_conversation_trace_append(
trace.ConversationTraceEventType.AGENT_DETAIL,
{
"messages": session.messages,
"tools": session.llm_api.tools if session.llm_api else None,
},
)
messages = [
_chat_message_convert(message, user_input.agent_id)
for message in session.async_get_messages()
]

client = self.entry.runtime_data

Expand All @@ -211,14 +205,7 @@ async def _async_call_api(
)
except openai.OpenAIError as err:
LOGGER.error("Error talking to OpenAI: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
"Sorry, I had a problem talking to OpenAI",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=session.conversation_id
)
raise HomeAssistantError("Error talking to OpenAI") from err

LOGGER.debug("Response %s", result)
response = result.choices[0].message
Expand All @@ -241,18 +228,7 @@ async def _async_call_api(
tool_name=tool_call.function.name,
tool_args=json.loads(tool_call.function.arguments),
)
LOGGER.debug(
"Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
)

try:
tool_response = await session.llm_api.async_call_tool(tool_input)
except (HomeAssistantError, vol.Invalid) as e:
tool_response = {"error": type(e).__name__}
if str(e):
tool_response["error_text"] = str(e)

LOGGER.debug("Tool response: %s", tool_response)
tool_response = await session.async_call_tool(tool_input)
messages.append(
ChatCompletionToolMessageParam(
role="tool",
Expand Down
102 changes: 97 additions & 5 deletions tests/components/conversation/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

from collections.abc import Generator
from datetime import timedelta
from unittest.mock import Mock, patch
from unittest.mock import AsyncMock, Mock, patch

import pytest
from syrupy.assertion import SnapshotAssertion
import voluptuous as vol

from homeassistant.components.conversation import ConversationInput, session
from homeassistant.core import Context, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.util import dt as dt_util

Expand Down Expand Up @@ -182,7 +184,7 @@ async def test_message_filtering(
)
assert messages[1] == session.ChatMessage(
role="user",
agent_id=mock_conversation_input.agent_id,
agent_id="mock-agent-id",
content=mock_conversation_input.text,
)
# Cannot add a second user message in a row
Expand All @@ -203,7 +205,7 @@ async def test_message_filtering(
native="assistant-reply-native",
)
)
# Different agent, will be filtered out.
# Different agent, native messages will be filtered out.
chat_session.async_add_message(
session.ChatMessage(
role="native", agent_id="another-mock-agent-id", content="", native=1
Expand All @@ -214,11 +216,20 @@ async def test_message_filtering(
role="native", agent_id="mock-agent-id", content="", native=1
)
)
# A non-native message from another agent is not filtered out.
chat_session.async_add_message(
session.ChatMessage(
role="assistant",
agent_id="another-mock-agent-id",
content="Hi!",
native=1,
)
)

assert len(chat_session.messages) == 5
assert len(chat_session.messages) == 6

messages = chat_session.async_get_messages(agent_id="mock-agent-id")
assert len(messages) == 4
assert len(messages) == 5

assert messages[2] == session.ChatMessage(
role="assistant",
Expand All @@ -229,6 +240,9 @@ async def test_message_filtering(
assert messages[3] == session.ChatMessage(
role="native", agent_id="mock-agent-id", content="", native=1
)
assert messages[4] == session.ChatMessage(
role="assistant", agent_id="another-mock-agent-id", content="Hi!", native=1
)


async def test_llm_api(
Expand Down Expand Up @@ -413,3 +427,81 @@ async def test_extra_systen_prompt(

assert chat_session.extra_system_prompt == extra_system_prompt2
assert chat_session.messages[0].content.endswith(extra_system_prompt2)


async def test_tool_call(
    hass: HomeAssistant,
    mock_conversation_input: ConversationInput,
) -> None:
    """Test calling an LLM tool through the chat session API."""
    # Build a fake tool that the Assist API will expose to the session.
    tool = AsyncMock()
    tool.name = "test_tool"
    tool.description = "Test function"
    tool.parameters = vol.Schema(
        {vol.Optional("param1", description="Test parameters"): str}
    )
    tool.async_call.return_value = "Test response"

    with patch(
        "homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools",
        return_value=[tool],
    ):
        async with session.async_get_chat_session(
            hass, mock_conversation_input
        ) as chat_session:
            await chat_session.async_update_llm_data(
                conversing_domain="test",
                user_input=mock_conversation_input,
                user_llm_hass_api="assist",
                user_llm_prompt=None,
            )
            result = await chat_session.async_call_tool(
                llm.ToolInput(
                    tool_name="test_tool",
                    tool_args={"param1": "Test Param"},
                )
            )

    # The session returns the tool's result unchanged on success.
    assert result == "Test response"


async def test_tool_call_exception(
    hass: HomeAssistant,
    mock_conversation_input: ConversationInput,
) -> None:
    """Test that a tool call error is converted into an error response."""

    # Fake tool whose invocation always fails with a HomeAssistantError.
    mock_tool = AsyncMock()
    mock_tool.name = "test_tool"
    mock_tool.description = "Test function"
    mock_tool.parameters = vol.Schema(
        {vol.Optional("param1", description="Test parameters"): str}
    )
    mock_tool.async_call.side_effect = HomeAssistantError("Test error")

    with patch(
        "homeassistant.components.conversation.session.llm.AssistAPI._async_get_tools",
        return_value=[],
    ) as mock_get_tools:
        # Expose the failing tool through the Assist API.
        mock_get_tools.return_value = [mock_tool]

        async with session.async_get_chat_session(
            hass, mock_conversation_input
        ) as chat_session:
            await chat_session.async_update_llm_data(
                conversing_domain="test",
                user_input=mock_conversation_input,
                user_llm_hass_api="assist",
                user_llm_prompt=None,
            )
            result = await chat_session.async_call_tool(
                llm.ToolInput(
                    tool_name="test_tool",
                    tool_args={"param1": "Test Param"},
                )
            )

    # The session swallows the exception and reports it as structured data.
    assert result == {"error": "HomeAssistantError", "error_text": "Test error"}
Loading

0 comments on commit 6993854

Please sign in to comment.