Skip to content

Commit

Permalink
Python: Improve azure assistant agent settings and retrieval operatio…
Browse files Browse the repository at this point in the history
…ns. (#10063)

### Motivation and Context

The AzureAssistantAgent retrieval path was not handling the optional
ad_token. The agent constructor was handling it, to a degree, but there
were improvements to make and a helper function to introduce so that we
remove some code duplication.

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

### Description

This PR:
- Cleans up the code used to create the azure settings for an azure
assistant agent by introducing a helper function to streamline the logic
- Allows the AzureAssistantAgent.retrieve() method to use an ad_token,
if desired.
- Adds unit tests for the new logic. Other unit tests are still passing
after the refactor.

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [X] The code builds clean without any errors or warnings
- [X] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [X] All unit tests pass, and I have added new tests where possible
- [X] I didn't break anyone 😄
  • Loading branch information
moonbox3 authored Jan 7, 2025
1 parent 5e4012f commit 5830081
Show file tree
Hide file tree
Showing 4 changed files with 262 additions and 100 deletions.
12 changes: 12 additions & 0 deletions python/DEV_SETUP.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,18 @@ If prompted, install `ruff`. (It should have been installed as part of `uv sync
You also need to install the `ruff` extension in VSCode so that auto-formatting uses the `ruff` formatter on save.
Read more about the extension [here](https://github.com/astral-sh/ruff-vscode).

### Configuring Unit Testing in VSCode

- We have removed the strict dependency on forcing `pytest` usage via the `.vscode/settings.json` file.
- Developers are free to set up unit tests using their preferred framework, whether it is `pytest` or `unittest`.
- If needed, adjust VSCode's local `settings.json` (accessed via the Command Palette: **Open User Settings (JSON)**) to configure the test framework. For example:
```json
"pythonTestExplorer.testFramework": "pytest"
```
Or, for `unittest`:
```json
"pythonTestExplorer.testFramework": "unittest"

## LLM setup

Make sure you have an
Expand Down
21 changes: 14 additions & 7 deletions python/samples/concepts/agents/assistant_agent_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ async def main():
# Define a service_id for the sample
service_id = "agent"

# Specify an assistant ID which is used
# to retrieve the agent
assistant_id: str = None

# Create the agent configuration
if use_azure_openai:
agent = await AzureAssistantAgent.create(
Expand All @@ -53,6 +57,11 @@ async def main():
instructions=AGENT_INSTRUCTIONS,
enable_code_interpreter=True,
)

retrieved_agent: AzureAssistantAgent = await AzureAssistantAgent.retrieve(
id=assistant_id,
kernel=kernel,
)
else:
agent = await OpenAIAssistantAgent.create(
kernel=kernel,
Expand All @@ -62,13 +71,11 @@ async def main():
enable_code_interpreter=True,
)

assistant_id = agent.assistant.id

# Retrieve the agent using the assistant_id
retrieved_agent: OpenAIAssistantAgent = await OpenAIAssistantAgent.retrieve(
id=assistant_id,
kernel=kernel,
)
# Retrieve the agent using the assistant_id
retrieved_agent: OpenAIAssistantAgent = await OpenAIAssistantAgent.retrieve(
id=assistant_id,
kernel=kernel,
)

# Define a thread and invoke the agent with the user input
thread_id = await retrieved_agent.create_thread()
Expand Down
211 changes: 119 additions & 92 deletions python/semantic_kernel/agents/open_ai/azure_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(
Raises:
AgentInitializationError: If the api_key is not provided in the configuration.
"""
azure_openai_settings = AzureAssistantAgent._create_azure_openai_settings(
azure_openai_settings = self._create_azure_openai_settings(
api_key=api_key,
endpoint=endpoint,
deployment_name=deployment_name,
Expand All @@ -113,30 +113,14 @@ def __init__(
token_endpoint=token_endpoint,
)

if not azure_openai_settings.chat_deployment_name:
raise AgentInitializationException("The Azure OpenAI chat_deployment_name is required.")

if (
client is None
and azure_openai_settings.api_key is None
and ad_token_provider is None
and ad_token is None
and azure_openai_settings.token_endpoint
):
ad_token = get_entra_auth_token(azure_openai_settings.token_endpoint)

if not client and not azure_openai_settings.api_key and not ad_token and not ad_token_provider:
raise AgentInitializationException("Please provide either api_key, ad_token or ad_token_provider.")
client, ad_token = self._setup_client_and_token(
azure_openai_settings=azure_openai_settings,
ad_token=ad_token,
ad_token_provider=ad_token_provider,
client=client,
default_headers=default_headers,
)

if not client:
client = self._create_client(
api_key=azure_openai_settings.api_key.get_secret_value() if azure_openai_settings.api_key else None,
endpoint=azure_openai_settings.endpoint,
api_version=azure_openai_settings.api_version,
ad_token=ad_token,
ad_token_provider=ad_token_provider,
default_headers=default_headers,
)
service_id = service_id if service_id else DEFAULT_SERVICE_NAME

args: dict[str, Any] = {
Expand Down Expand Up @@ -167,6 +151,7 @@ def __init__(
args["kernel"] = kernel
if kwargs:
args.update(kwargs)

super().__init__(**args)

@classmethod
Expand Down Expand Up @@ -204,6 +189,7 @@ async def create(
max_prompt_tokens: int | None = None,
parallel_tool_calls_enabled: bool | None = True,
truncation_message_count: int | None = None,
token_endpoint: str | None = None,
**kwargs: Any,
) -> "AzureAssistantAgent":
"""Asynchronous class method used to create the OpenAI Assistant Agent.
Expand Down Expand Up @@ -240,10 +226,11 @@ async def create(
max_prompt_tokens: The maximum prompt tokens. (optional)
parallel_tool_calls_enabled: Enable parallel tool calls. (optional)
truncation_message_count: The truncation message count. (optional)
token_endpoint: The Azure AD token endpoint. (optional)
**kwargs: Additional keyword arguments.
Returns:
An instance of the AzureOpenAIAssistantAgent
An instance of the AzureAssistantAgent
"""
agent = cls(
kernel=kernel,
Expand Down Expand Up @@ -273,16 +260,15 @@ async def create(
max_prompt_tokens=max_prompt_tokens,
parallel_tool_calls_enabled=parallel_tool_calls_enabled,
truncation_message_count=truncation_message_count,
token_endpoint=token_endpoint,
**kwargs,
)

assistant_create_kwargs: dict[str, Any] = {}

code_interpreter_file_ids_combined: list[str] = []

if code_interpreter_file_ids is not None:
code_interpreter_file_ids_combined.extend(code_interpreter_file_ids)

if code_interpreter_filenames is not None:
for file_path in code_interpreter_filenames:
try:
Expand All @@ -293,16 +279,13 @@ async def create(
f"Failed to upload code interpreter file with path: `{file_path}` with exception: {ex}"
)
raise AgentInitializationException("Failed to upload code interpreter files.", ex) from ex

if code_interpreter_file_ids_combined:
agent.code_interpreter_file_ids = code_interpreter_file_ids_combined
assistant_create_kwargs["code_interpreter_file_ids"] = code_interpreter_file_ids_combined

vector_store_file_ids_combined: list[str] = []

if vector_store_file_ids is not None:
vector_store_file_ids_combined.extend(vector_store_file_ids)

if vector_store_filenames is not None:
for file_path in vector_store_filenames:
try:
Expand All @@ -311,7 +294,6 @@ async def create(
except FileNotFoundError as ex:
logger.error(f"Failed to upload vector store file with path: `{file_path}` with exception: {ex}")
raise AgentInitializationException("Failed to upload vector store files.", ex) from ex

if vector_store_file_ids_combined:
agent.file_search_file_ids = vector_store_file_ids_combined
if enable_file_search or agent.enable_file_search:
Expand All @@ -322,6 +304,110 @@ async def create(
agent.assistant = await agent.create_assistant(**assistant_create_kwargs)
return agent

@classmethod
async def retrieve(
cls,
*,
id: str,
api_key: str | None = None,
endpoint: HttpsUrl | None = None,
api_version: str | None = None,
ad_token: str | None = None,
ad_token_provider: Callable[[], str | Awaitable[str]] | None = None,
client: AsyncAzureOpenAI | None = None,
kernel: "Kernel | None" = None,
default_headers: dict[str, str] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
token_endpoint: str | None = None,
) -> "AzureAssistantAgent":
"""Retrieve an assistant by ID.
Args:
id: The assistant ID.
api_key: The Azure OpenAI API key. (optional)
endpoint: The Azure OpenAI endpoint. (optional)
api_version: The Azure OpenAI API version. (optional)
ad_token: The Azure AD token. (optional)
ad_token_provider: The Azure AD token provider. (optional)
client: The Azure OpenAI client. (optional)
kernel: The Kernel instance. (optional)
default_headers: The default headers. (optional)
env_file_path: The environment file path. (optional)
env_file_encoding: The environment file encoding. (optional)
token_endpoint: The Azure AD token endpoint. (optional)
Returns:
An AzureAssistantAgent instance.
"""
azure_openai_settings = cls._create_azure_openai_settings(
api_key=api_key,
endpoint=endpoint,
deployment_name=None, # Not required for retrieving an existing assistant
api_version=api_version,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
token_endpoint=token_endpoint,
)

client, ad_token = cls._setup_client_and_token(
azure_openai_settings=azure_openai_settings,
ad_token=ad_token,
ad_token_provider=ad_token_provider,
client=client,
default_headers=default_headers,
)

assistant = await client.beta.assistants.retrieve(id)
assistant_definition = OpenAIAssistantBase._create_open_ai_assistant_definition(assistant)

return AzureAssistantAgent(kernel=kernel, assistant=assistant, **assistant_definition)

@staticmethod
def _setup_client_and_token(
azure_openai_settings: AzureOpenAISettings,
ad_token: str | None,
ad_token_provider: Callable[[], str | Awaitable[str]] | None,
client: AsyncAzureOpenAI | None,
default_headers: dict[str, str] | None,
) -> tuple[AsyncAzureOpenAI, str | None]:
"""Helper method that ensures either an AD token or an API key is present.
Retrieves a new AD token if needed, and configures the AsyncAzureOpenAI client.
Returns:
A tuple of (client, ad_token), where client is guaranteed not to be None.
"""
if not azure_openai_settings.chat_deployment_name:
raise AgentInitializationException("The Azure OpenAI chat_deployment_name is required.")

# If everything is missing, but there is a token_endpoint, try to get the token.
if (
client is None
and azure_openai_settings.api_key is None
and ad_token_provider is None
and ad_token is None
and azure_openai_settings.token_endpoint
):
ad_token = get_entra_auth_token(azure_openai_settings.token_endpoint)

# If we still have no credentials, we can't proceed
if not client and not azure_openai_settings.api_key and not ad_token and not ad_token_provider:
raise AgentInitializationException("Please provide either api_key, ad_token or ad_token_provider.")

# Build the client if it's not supplied
if not client:
client = AzureAssistantAgent._create_client(
api_key=azure_openai_settings.api_key.get_secret_value() if azure_openai_settings.api_key else None,
endpoint=azure_openai_settings.endpoint,
api_version=azure_openai_settings.api_version,
ad_token=ad_token,
ad_token_provider=ad_token_provider,
default_headers=default_headers,
)

return client, ad_token

@staticmethod
def _create_client(
api_key: str | None = None,
Expand Down Expand Up @@ -351,10 +437,11 @@ def _create_client(

if not api_key and not ad_token and not ad_token_provider:
raise AgentInitializationException(
"Please provide either AzureOpenAI api_key, an ad_token or an ad_token_provider or a client."
"Please provide either AzureOpenAI api_key, an ad_token, ad_token_provider, or a client."
)
if not endpoint:
raise AgentInitializationException("Please provide an AzureOpenAI endpoint.")

return AsyncAzureOpenAI(
azure_endpoint=str(endpoint),
api_version=api_version,
Expand Down Expand Up @@ -413,64 +500,4 @@ async def list_definitions(self) -> AsyncIterable[dict[str, Any]]:
for assistant in assistants.data:
yield OpenAIAssistantBase._create_open_ai_assistant_definition(assistant)

@classmethod
async def retrieve(
cls,
*,
id: str,
api_key: str | None = None,
endpoint: HttpsUrl | None = None,
api_version: str | None = None,
ad_token: str | None = None,
ad_token_provider: Callable[[], str | Awaitable[str]] | None = None,
client: AsyncAzureOpenAI | None = None,
kernel: "Kernel | None" = None,
default_headers: dict[str, str] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> "AzureAssistantAgent":
"""Retrieve an assistant by ID.
Args:
id: The assistant ID.
api_key: The Azure OpenAI API
endpoint: The Azure OpenAI endpoint. (optional)
api_version: The Azure OpenAI API version. (optional)
ad_token: The Azure AD token. (optional)
ad_token_provider: The Azure AD token provider. (optional)
client: The Azure OpenAI client. (optional)
kernel: The Kernel instance. (optional)
default_headers: The default headers. (optional)
env_file_path: The environment file path. (optional)
env_file_encoding: The environment file encoding. (optional)
Returns:
An AzureAssistantAgent instance.
"""
azure_openai_settings = AzureAssistantAgent._create_azure_openai_settings(
api_key=api_key,
endpoint=endpoint,
api_version=api_version,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)

if not azure_openai_settings.chat_deployment_name:
raise AgentInitializationException("The Azure OpenAI chat_deployment_name is required.")
if not azure_openai_settings.api_key and not ad_token and not ad_token_provider:
raise AgentInitializationException("Please provide either api_key, ad_token or ad_token_provider.")

if not client:
client = AzureAssistantAgent._create_client(
api_key=api_key,
endpoint=endpoint,
api_version=api_version,
ad_token=ad_token,
ad_token_provider=ad_token_provider,
default_headers=default_headers,
)
assistant = await client.beta.assistants.retrieve(id)
assistant_definition = OpenAIAssistantBase._create_open_ai_assistant_definition(assistant)
return AzureAssistantAgent(kernel=kernel, assistant=assistant, **assistant_definition)

# endregion
Loading

0 comments on commit 5830081

Please sign in to comment.