From 40c4a0d594b17cbe1ba86d08ceef2d8508c65b84 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Tue, 26 Nov 2024 13:57:53 -0800 Subject: [PATCH] combine Handle and CallObject --- .../Abstractions/IAgentBase.cs | 2 - .../src/Microsoft.AutoGen/Agents/AgentBase.cs | 79 ++++++++----------- .../AgentBaseTests.cs | 3 +- 3 files changed, 35 insertions(+), 49 deletions(-) diff --git a/dotnet/src/Microsoft.AutoGen/Abstractions/IAgentBase.cs b/dotnet/src/Microsoft.AutoGen/Abstractions/IAgentBase.cs index ee7b9e74583c..b0d4777dfc26 100644 --- a/dotnet/src/Microsoft.AutoGen/Abstractions/IAgentBase.cs +++ b/dotnet/src/Microsoft.AutoGen/Abstractions/IAgentBase.cs @@ -11,8 +11,6 @@ public interface IAgentBase AgentId AgentId { get; } IAgentRuntime Context { get; } - // Methods - Task CallHandler(CloudEvent item); Task HandleRequest(RpcRequest request); void ReceiveMessage(Message message); Task StoreAsync(AgentState state, CancellationToken cancellationToken = default); diff --git a/dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs b/dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs index 345e6d34c826..b29bb2b7222b 100644 --- a/dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs +++ b/dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs @@ -12,7 +12,7 @@ namespace Microsoft.AutoGen.Agents; -public abstract class AgentBase : IAgentBase, IHandle +public abstract class AgentBase : IAgentBase, IHandle, IHandle { public static readonly ActivitySource s_source = new("AutoGen.Agent"); public AgentId AgentId => _context.AgentId; @@ -93,7 +93,7 @@ protected internal async Task HandleRpcMessage(Message msg, CancellationToken ca { var activity = this.ExtractActivity(msg.CloudEvent.Type, msg.CloudEvent.Metadata); await this.InvokeWithActivityAsync( - static ((AgentBase Agent, CloudEvent Item) state, CancellationToken _) => state.Agent.CallHandler(state.Item), + static ((AgentBase Agent, CloudEvent Item) state, CancellationToken _) => state.Agent.HandleObject(state.Item), (this, msg.CloudEvent), activity, msg.CloudEvent.Type, cancellationToken).ConfigureAwait(false); @@ -242,7 +242,35 @@ static async ((AgentBase Agent, CloudEvent Event) state, CancellationToken ct) = item.Type, cancellationToken).ConfigureAwait(false); } - public Task CallHandler(CloudEvent item) + public Task HandleRequest(RpcRequest request) => Task.FromResult(new RpcResponse { Error = "Not implemented" }); + + public virtual Task HandleObject(object item) + { + if (item is CloudEvent ce) + { + return Handle(ce); + } + + var genericInterfaceType = typeof(IHandle<>).MakeGenericType(item.GetType()); + + // check that our target actually implements this interface, otherwise call the default static + if (genericInterfaceType.IsAssignableFrom(this.GetType())) + { + var methodInfo = genericInterfaceType.GetMethod(nameof(IHandle.Handle), BindingFlags.Public | BindingFlags.Instance) + ?? throw new InvalidOperationException($"Method not found on type {genericInterfaceType.FullName}"); + + return methodInfo.Invoke(this, [item]) as Task ?? throw new InvalidOperationException("Method did not return a Task"); + } + + // otherwise, complain + throw new InvalidOperationException($"No handler found for type {item.GetType().FullName}"); + } + public async ValueTask PublishEventAsync(string topic, IMessage evt, CancellationToken cancellationToken = default) + { + await PublishEventAsync(evt.ToCloudEvent(topic), cancellationToken).ConfigureAwait(false); + } + + public virtual Task Handle(CloudEvent item) { // Only send the event to the handler if the agent type is handling that type // foreach of the keys in the EventTypes.EventsMap[] if it contains the item.type @@ -250,25 +278,11 @@ public Task CallHandler(CloudEvent item) { if (EventTypes.EventsMap[key].Contains(item.Type)) { - var payload = item.ProtoData.Unpack(EventTypes.TypeRegistry); - var convertedPayload = Convert.ChangeType(payload, EventTypes.Types[item.Type]); - var genericInterfaceType = typeof(IHandle<>).MakeGenericType(EventTypes.Types[item.Type]); - - MethodInfo methodInfo; try { - // check that our target actually implements this interface, otherwise call the default static - if (genericInterfaceType.IsAssignableFrom(this.GetType())) - { - methodInfo = genericInterfaceType.GetMethod(nameof(IHandle.Handle), BindingFlags.Public | BindingFlags.Instance) - ?? throw new InvalidOperationException($"Method not found on type {genericInterfaceType.FullName}"); - return methodInfo.Invoke(this, [payload]) as Task ?? Task.CompletedTask; - } - else - { - // The error here is we have registered for an event that we do not have code to listen to - throw new InvalidOperationException($"No handler found for event '{item.Type}'; expecting IHandle<{item.Type}> implementation."); - } + var payload = item.ProtoData.Unpack(EventTypes.TypeRegistry); + var convertedPayload = Convert.ChangeType(payload, EventTypes.Types[item.Type]); + return this.HandleObject(convertedPayload); } catch (Exception ex) { @@ -280,29 +294,4 @@ public Task CallHandler(CloudEvent item) return Task.CompletedTask; } - - public Task HandleRequest(RpcRequest request) => Task.FromResult(new RpcResponse { Error = "Not implemented" }); - - //TODO: should this be async and cancellable? - public virtual Task HandleObject(object item) - { - // get all Handle methods - var handleTMethods = this.GetType().GetMethods().Where(m => m.Name == "Handle" && m.GetParameters().Length == 1).ToList(); - - // get the one that matches the type of the item - var handleTMethod = handleTMethods.FirstOrDefault(m => m.GetParameters()[0].ParameterType == item.GetType()); - - // if we found one, invoke it - if (handleTMethod != null) - { - return (Task)handleTMethod.Invoke(this, [item])!; - } - - // otherwise, complain - throw new InvalidOperationException($"No handler found for type {item.GetType().FullName}"); - } - public async ValueTask PublishEventAsync(string topic, IMessage evt, CancellationToken cancellationToken = default) - { - await PublishEventAsync(evt.ToCloudEvent(topic), cancellationToken).ConfigureAwait(false); - } } diff --git a/dotnet/test/Microsoft.AutoGen.Agents.Tests/AgentBaseTests.cs b/dotnet/test/Microsoft.AutoGen.Agents.Tests/AgentBaseTests.cs index e58fdb00f0a0..365cecd55ecd 100644 --- a/dotnet/test/Microsoft.AutoGen.Agents.Tests/AgentBaseTests.cs +++ b/dotnet/test/Microsoft.AutoGen.Agents.Tests/AgentBaseTests.cs @@ -24,7 +24,6 @@ public async Task ItInvokeRightHandlerTestAsync() { var mockContext = new Mock(); var agent = new TestAgent(mockContext.Object, new EventTypes(TypeRegistry.Empty, [], []), new Logger(new LoggerFactory())); - await agent.HandleObject("hello world"); await agent.HandleObject(42); @@ -57,7 +56,7 @@ await client.PublishMessageAsync(new TextMessage() /// /// The test agent is a simple agent that is used for testing purposes. /// - public class TestAgent : AgentBase, IHandle, IHandle, IHandle + public class TestAgent : AgentBase, IHandle, IHandle, IHandle, IHandleConsole { public TestAgent( IAgentRuntime context,