Skip to content

Commit

Permalink
.Net: Add factory for customizing OpenAPI plugins responses (#10106)
Browse files Browse the repository at this point in the history
### Motivation and Context  
Currently, it's impossible to access the HTTP response and response
content headers returned by a REST API requested from OpenAPI plugins.
   
### Description  
This PR adds `RestApiOperationResponseFactory`, which can be used to
customize the responses of OpenAPI plugins before returning them to the
caller. The customization may include modifying the original response by
adding response headers, changing the response content, adjusting the
schema, or providing a completely new response.

Closes: #9986
  • Loading branch information
SergeyMenshykh authored Jan 7, 2025
1 parent 4a70658 commit 941ee64
Show file tree
Hide file tree
Showing 12 changed files with 428 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Net;
using System.Text;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Plugins.OpenApi;

namespace Plugins;

/// <summary>
/// Sample shows how to register the <see cref="RestApiOperationResponseFactory"/> to transform existing or create new <see cref="RestApiOperationResponse"/>.
/// </summary>
public sealed class OpenApiPlugin_RestApiOperationResponseFactory(ITestOutputHelper output) : BaseTest(output)
{
private readonly HttpClient _httpClient = new(new StubHttpHandler(InterceptRequestAndCustomizeResponseAsync));

[Fact]
public async Task IncludeResponseHeadersToOperationResponseAsync()
{
Kernel kernel = new();

// Register the operation response factory and the custom HTTP client
OpenApiFunctionExecutionParameters executionParameters = new()
{
RestApiOperationResponseFactory = IncludeHeadersIntoRestApiOperationResponseAsync,
HttpClient = this._httpClient
};

// Create OpenAPI plugin
KernelPlugin plugin = await OpenApiKernelPluginFactory.CreateFromOpenApiAsync("RepairService", "Resources/Plugins/RepairServicePlugin/repair-service.json", executionParameters);

// Create arguments for a new repair
KernelArguments arguments = new()
{
["title"] = "The Case of the Broken Gizmo",
["description"] = "It's broken. Send help!",
["assignedTo"] = "Tech Magician"
};

// Create the repair
FunctionResult createResult = await plugin["createRepair"].InvokeAsync(kernel, arguments);

// Get operation response that was modified
RestApiOperationResponse response = createResult.GetValue<RestApiOperationResponse>()!;

// Display the 'repair-id' header value
Console.WriteLine(response.Headers!["repair-id"].First());
}

/// <summary>
/// A custom factory to transform the operation response.
/// </summary>
/// <param name="context">The context for the <see cref="RestApiOperationResponseFactory"/>.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>The transformed operation response.</returns>
private static async Task<RestApiOperationResponse> IncludeHeadersIntoRestApiOperationResponseAsync(RestApiOperationResponseFactoryContext context, CancellationToken cancellationToken)
{
// Create the response using the internal factory
RestApiOperationResponse response = await context.InternalFactory(context, cancellationToken);

// Obtain the 'repair-id' header value from the HTTP response and include it in the operation response only for the 'createRepair' operation
if (context.Operation.Id == "createRepair" && context.Response.Headers.TryGetValues("repair-id", out IEnumerable<string>? values))
{
response.Headers ??= new Dictionary<string, IEnumerable<string>>();
response.Headers["repair-id"] = values;
}

// Return the modified response that will be returned to the caller
return response;
}

/// <summary>
/// A custom HTTP handler to intercept HTTP requests and return custom responses.
/// </summary>
/// <param name="request">The original HTTP request.</param>
/// <returns>The custom HTTP response.</returns>
private static async Task<HttpResponseMessage> InterceptRequestAndCustomizeResponseAsync(HttpRequestMessage request)
{
// Return a mock response that includes the 'repair-id' header for the 'createRepair' operation
if (request.RequestUri!.AbsolutePath == "/repairs" && request.Method == HttpMethod.Post)
{
return new HttpResponseMessage(HttpStatusCode.Created)
{
Content = new StringContent("Success", Encoding.UTF8, "application/json"),
Headers =
{
{ "repair-id", "repair-12345" }
}
};
}

return new HttpResponseMessage(HttpStatusCode.NoContent);
}

private sealed class StubHttpHandler(Func<HttpRequestMessage, Task<HttpResponseMessage>> requestHandler) : DelegatingHandler()
{
private readonly Func<HttpRequestMessage, Task<HttpResponseMessage>> _requestHandler = requestHandler;

protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
return await this._requestHandler(request);
}
}

protected override void Dispose(bool disposing)
{
base.Dispose(disposing);
this._httpClient.Dispose();
}
}
1 change: 1 addition & 0 deletions dotnet/samples/Concepts/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ dotnet test -l "console;verbosity=detailed" --filter "FullyQualifiedName=ChatCom
- [OpenApiPlugin_Customization](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Plugins/OpenApiPlugin_Customization.cs)
- [OpenApiPlugin_Filtering](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Plugins/OpenApiPlugin_Filtering.cs)
- [OpenApiPlugin_Telemetry](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Plugins/OpenApiPlugin_Telemetry.cs)
- [OpenApiPlugin_RestApiOperationResponseFactory](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Plugins/OpenApiPlugin_RestApiOperationResponseFactory.cs)
- [CustomMutablePlugin](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Plugins/CustomMutablePlugin.cs)
- [DescribeAllPluginsAndFunctions](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Plugins/DescribeAllPluginsAndFunctions.cs)
- [GroundednessChecks](https://github.com/microsoft/semantic-kernel/blob/main/dotnet/samples/Concepts/Plugins/GroundednessChecks.cs)
Expand Down
6 changes: 5 additions & 1 deletion dotnet/src/Functions/Functions.Grpc/Functions.Grpc.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

<Import Project="$(RepoRoot)/dotnet/nuget/nuget-package.props" />
<Import Project="$(RepoRoot)/dotnet/src/InternalUtilities/src/InternalUtilities.props" />

<PropertyGroup>
<!-- NuGet Package Settings -->
<Title>Semantic Kernel - gRPC Plugins</Title>
Expand All @@ -36,4 +35,9 @@
<ProjectReference Include="..\..\SemanticKernel.Core\SemanticKernel.Core.csproj" />
</ItemGroup>

<ItemGroup>
<!-- Exclude utilities that are not used by the project -->
<Compile Remove="$(RepoRoot)/dotnet/src/InternalUtilities/src/Http/HttpResponseStream.cs" Link="%(RecursiveDir)%(Filename)%(Extension)" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
</PropertyGroup>

<Import Project="$(RepoRoot)/dotnet/nuget/nuget-package.props" />
<Import Project="$(RepoRoot)/dotnet/src/InternalUtilities/src/InternalUtilities.props" />


<PropertyGroup>
<!-- NuGet Package Settings -->
Expand All @@ -30,4 +28,9 @@
<ProjectReference Include="..\..\SemanticKernel.Core\SemanticKernel.Core.csproj" />
</ItemGroup>

<ItemGroup>
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/src/Diagnostics/**/*.cs" Link="%(RecursiveDir)%(Filename)%(Extension)" />
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/src/System/AppContextSwitchHelper.cs" Link="%(RecursiveDir)%(Filename)%(Extension)" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,15 @@ public class OpenApiFunctionExecutionParameters
[Experimental("SKEXP0040")]
public HttpResponseContentReader? HttpResponseContentReader { get; set; }

/// <summary>
/// A custom factory for the <see cref="RestApiOperationResponse"/>.
/// It allows modifications of various aspects of the original response, such as adding response headers,
/// changing response content, adjusting the schema, or providing a completely new response.
/// If a custom factory is not supplied, the internal factory will be used by default.
/// </summary>
[Experimental("SKEXP0040")]
public RestApiOperationResponseFactory? RestApiOperationResponseFactory { get; set; }

/// <summary>
/// A custom REST API parameter filter.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ internal static KernelPlugin CreateOpenApiPlugin(
executionParameters?.UserAgent,
executionParameters?.EnableDynamicPayload ?? true,
executionParameters?.EnablePayloadNamespacing ?? false,
executionParameters?.HttpResponseContentReader);
executionParameters?.HttpResponseContentReader,
executionParameters?.RestApiOperationResponseFactory);

var functions = new List<KernelFunction>();
ILogger logger = loggerFactory.CreateLogger(typeof(OpenApiKernelExtensions)) ?? NullLogger.Instance;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.SemanticKernel.Plugins.OpenApi;

/// <summary>
/// Represents a factory for creating instances of the <see cref="RestApiOperationResponse"/>.
/// </summary>
/// <param name="context">The context that contains the operation details.</param>
/// <param name="cancellationToken">The cancellation token used to signal cancellation.</param>
/// <returns>A task that represents the asynchronous operation, containing an instance of <see cref="RestApiOperationResponse"/>.</returns>
[Experimental("SKEXP0040")]
public delegate Task<RestApiOperationResponse> RestApiOperationResponseFactory(RestApiOperationResponseFactoryContext context, CancellationToken cancellationToken = default);
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Diagnostics.CodeAnalysis;
using System.Net.Http;

namespace Microsoft.SemanticKernel.Plugins.OpenApi;

/// <summary>
/// Represents the context for the <see cref="RestApiOperationResponseFactory"/>."/>
/// </summary>
[Experimental("SKEXP0040")]
public sealed class RestApiOperationResponseFactoryContext
{
/// <summary>
/// Initializes a new instance of the <see cref="RestApiOperationResponseFactoryContext"/> class.
/// </summary>
/// <param name="operation">The REST API operation.</param>
/// <param name="request">The HTTP request message.</param>
/// <param name="response">The HTTP response message.</param>
/// <param name="internalFactory">The internal factory to create instances of the <see cref="RestApiOperationResponse"/>.</param>
internal RestApiOperationResponseFactoryContext(RestApiOperation operation, HttpRequestMessage request, HttpResponseMessage response, RestApiOperationResponseFactory internalFactory)
{
this.InternalFactory = internalFactory;
this.Operation = operation;
this.Request = request;
this.Response = response;
}

/// <summary>
/// The REST API operation.
/// </summary>
public RestApiOperation Operation { get; }

/// <summary>
/// The HTTP request message.
/// </summary>
public HttpRequestMessage Request { get; }

/// <summary>
/// The HTTP response message.
/// </summary>
public HttpResponseMessage Response { get; }

/// <summary>
/// The internal factory to create instances of the <see cref="RestApiOperationResponse"/>.
/// </summary>
public RestApiOperationResponseFactory InternalFactory { get; }
}
34 changes: 31 additions & 3 deletions dotnet/src/Functions/Functions.OpenApi/RestApiOperationRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ internal sealed class RestApiOperationRunner
/// </summary>
private readonly HttpResponseContentReader? _httpResponseContentReader;

/// <summary>
/// The external response factory for creating <see cref="RestApiOperationResponse"/>.
/// </summary>
private readonly RestApiOperationResponseFactory? _responseFactory;

/// <summary>
/// The external URL factory to use if provided, instead of the default one.
/// </summary>
Expand Down Expand Up @@ -115,6 +120,7 @@ internal sealed class RestApiOperationRunner
/// <param name="enablePayloadNamespacing">Determines whether payload parameters are resolved from the arguments by
/// full name (parameter name prefixed with the parent property name).</param>
/// <param name="httpResponseContentReader">Custom HTTP response content reader.</param>
/// <param name="responseFactory">The external response factory for creating <see cref="RestApiOperationResponse"/>.</param>
/// <param name="urlFactory">The external URL factory to use if provided if provided instead of the default one.</param>
/// <param name="headersFactory">The external headers factory to use if provided instead of the default one.</param>
/// <param name="payloadFactory">The external payload factory to use if provided instead of the default one.</param>
Expand All @@ -125,6 +131,7 @@ public RestApiOperationRunner(
bool enableDynamicPayload = false,
bool enablePayloadNamespacing = false,
HttpResponseContentReader? httpResponseContentReader = null,
RestApiOperationResponseFactory? responseFactory = null,
RestApiOperationUrlFactory? urlFactory = null,
RestApiOperationHeadersFactory? headersFactory = null,
RestApiOperationPayloadFactory? payloadFactory = null)
Expand All @@ -134,6 +141,7 @@ public RestApiOperationRunner(
this._enableDynamicPayload = enableDynamicPayload;
this._enablePayloadNamespacing = enablePayloadNamespacing;
this._httpResponseContentReader = httpResponseContentReader;
this._responseFactory = responseFactory;
this._urlFactory = urlFactory;
this._headersFactory = headersFactory;
this._payloadFactory = payloadFactory;
Expand Down Expand Up @@ -577,11 +585,31 @@ private Uri BuildsOperationUrl(RestApiOperation operation, IDictionary<string, o
/// <returns>The operation response.</returns>
private async Task<RestApiOperationResponse> BuildResponseAsync(RestApiOperation operation, HttpRequestMessage requestMessage, HttpResponseMessage responseMessage, object? payload, CancellationToken cancellationToken)
{
var response = await this.ReadContentAndCreateOperationResponseAsync(requestMessage, responseMessage, payload, cancellationToken).ConfigureAwait(false);
async Task<RestApiOperationResponse> Build(RestApiOperationResponseFactoryContext context, CancellationToken ct)
{
var response = await this.ReadContentAndCreateOperationResponseAsync(context.Request, context.Response, payload, ct).ConfigureAwait(false);

response.ExpectedSchema ??= GetExpectedSchema(context.Operation.Responses.ToDictionary(item => item.Key, item => item.Value.Schema), context.Response.StatusCode);

return response;
}

response.ExpectedSchema ??= GetExpectedSchema(operation.Responses.ToDictionary(item => item.Key, item => item.Value.Schema), responseMessage.StatusCode);
// Delegate the response building to the custom response factory if provided.
if (this._responseFactory is not null)
{
var response = await this._responseFactory(new(operation: operation, request: requestMessage, response: responseMessage, internalFactory: Build), cancellationToken).ConfigureAwait(false);

// Handling the case when the content is a stream
if (response.Content is Stream stream and not HttpResponseStream)
{
// Wrap the stream content to capture the HTTP response message, delegating its disposal to the caller.
response.Content = new HttpResponseStream(stream, responseMessage);
}

return response;
}

return response;
return await Build(new(operation: operation, request: requestMessage, response: responseMessage, internalFactory: null!), cancellationToken).ConfigureAwait(false);
}

#endregion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Net.Mime;
using System.Text;
using System.Text.Json.Nodes;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Plugins.OpenApi;
Expand Down Expand Up @@ -606,6 +607,42 @@ public async Task ItShouldResolveArgumentsBySanitizedParameterNamesAsync()
Assert.Equal(23.4f, deserializedPayload["float?parameter"]!.GetValue<float>());
}

[Fact]
public async Task ItShouldPropagateRestApiOperationResponseFactoryToRunnerAsync()
{
// Arrange
bool restApiOperationResponseFactoryIsInvoked = false;

async Task<RestApiOperationResponse> RestApiOperationResponseFactory(RestApiOperationResponseFactoryContext context, CancellationToken cancellationToken)
{
restApiOperationResponseFactoryIsInvoked = true;

return await context.InternalFactory(context, cancellationToken);
}

using var messageHandlerStub = new HttpMessageHandlerStub();
using var httpClient = new HttpClient(messageHandlerStub, false);

this._executionParameters.HttpClient = httpClient;
this._executionParameters.RestApiOperationResponseFactory = RestApiOperationResponseFactory;

var openApiPlugins = await OpenApiKernelPluginFactory.CreateFromOpenApiAsync("fakePlugin", this._openApiDocument, this._executionParameters);

var kernel = new Kernel();

var arguments = new KernelArguments
{
{ "secret-name", "fake-secret-name" },
{ "api-version", "fake-api-version" }
};

// Act
await kernel.InvokeAsync(openApiPlugins["GetSecret"], arguments);

// Assert
Assert.True(restApiOperationResponseFactoryIsInvoked);
}

/// <summary>
/// Generate theory data for ItAddSecurityMetadataToOperationAsync
/// </summary>
Expand Down
Loading

0 comments on commit 941ee64

Please sign in to comment.