From 2071e66a606d67f311eb11dc7d2c453832aa4aa7 Mon Sep 17 00:00:00 2001 From: toad Date: Fri, 27 Dec 2024 15:50:29 +0800 Subject: [PATCH] feat: Supports callbacks when reading a message fails --- server/callbacks.go | 14 +++++++-- server/serverimpl.go | 64 +++++++++++++++++++++++---------------- server/serverimpl_test.go | 51 ++++++++++++++++++++++++++++--- server/types/callbacks.go | 3 ++ 4 files changed, 98 insertions(+), 34 deletions(-) diff --git a/server/callbacks.go b/server/callbacks.go index c3e75a96..8849910b 100644 --- a/server/callbacks.go +++ b/server/callbacks.go @@ -28,9 +28,10 @@ func (c CallbacksStruct) OnConnecting(request *http.Request) types.ConnectionRes // ConnectionCallbacksStruct is a struct that implements ConnectionCallbacks interface and allows // to override only the methods that are needed. type ConnectionCallbacksStruct struct { - OnConnectedFunc func(ctx context.Context, conn types.Connection) - OnMessageFunc func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent - OnConnectionCloseFunc func(conn types.Connection) + OnConnectedFunc func(ctx context.Context, conn types.Connection) + OnMessageFunc func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent + OnConnectionCloseFunc func(conn types.Connection) + OnReadMessageErrorFunc func(conn types.Connection, mt int, msgByte []byte, err error) } var _ types.ConnectionCallbacks = (*ConnectionCallbacksStruct)(nil) @@ -61,3 +62,10 @@ func (c ConnectionCallbacksStruct) OnConnectionClose(conn types.Connection) { c.OnConnectionCloseFunc(conn) } } + +// OnReadMessageError implements types.ConnectionCallbacks. +func (c ConnectionCallbacksStruct) OnReadMessageError(conn types.Connection, mt int, msgByte []byte, err error) { + if c.OnReadMessageErrorFunc != nil { + c.OnReadMessageErrorFunc(conn, mt, msgByte, err) + } +} diff --git a/server/serverimpl.go b/server/serverimpl.go index 815c0528..6c0d73c5 100644 --- a/server/serverimpl.go +++ b/server/serverimpl.go @@ -5,6 +5,7 @@ import ( "compress/gzip" "context" "errors" + "fmt" "io" "net" "net/http" @@ -19,16 +20,16 @@ import ( serverTypes "github.com/open-telemetry/opamp-go/server/types" ) -var ( - errAlreadyStarted = errors.New("already started") -) +var errAlreadyStarted = errors.New("already started") -const defaultOpAMPPath = "/v1/opamp" -const headerContentType = "Content-Type" -const headerContentEncoding = "Content-Encoding" -const headerAcceptEncoding = "Accept-Encoding" -const contentEncodingGzip = "gzip" -const contentTypeProtobuf = "application/x-protobuf" +const ( + defaultOpAMPPath = "/v1/opamp" + headerContentType = "Content-Type" + headerContentEncoding = "Content-Encoding" + headerAcceptEncoding = "Accept-Encoding" + contentEncodingGzip = "gzip" + contentTypeProtobuf = "application/x-protobuf" +) type server struct { logger types.Logger @@ -230,27 +231,39 @@ func (s *server) handleWSConnection(reqCtx context.Context, wsConn *websocket.Co // Loop until fail to read from the WebSocket connection. for { msgContext := context.Background() + request := protobufs.AgentToServer{} + // Block until the next message can be read. mt, msgBytes, err := wsConn.ReadMessage() - if err != nil { - if !websocket.IsUnexpectedCloseError(err) { - s.logger.Errorf(msgContext, "Cannot read a message from WebSocket: %v", err) - break + isBreak, err := func() (bool, error) { + if err != nil { + if !websocket.IsUnexpectedCloseError(err) { + s.logger.Errorf(msgContext, "Cannot read a message from WebSocket: %v", err) + return true, err + } + // This is a normal closing of the WebSocket connection. + s.logger.Debugf(msgContext, "Agent disconnected: %v", err) + return true, err + } + if mt != websocket.BinaryMessage { + err = fmt.Errorf("Received unexpected message type from WebSocket: %v", mt) + s.logger.Errorf(msgContext, err.Error()) + return false, err } - // This is a normal closing of the WebSocket connection. - s.logger.Debugf(msgContext, "Agent disconnected: %v", err) - break - } - if mt != websocket.BinaryMessage { - s.logger.Errorf(msgContext, "Received unexpected message type from WebSocket: %v", mt) - continue - } - // Decode WebSocket message as a Protobuf message. - var request protobufs.AgentToServer - err = internal.DecodeWSMessage(msgBytes, &request) + // Decode WebSocket message as a Protobuf message. + err = internal.DecodeWSMessage(msgBytes, &request) + if err != nil { + s.logger.Errorf(msgContext, "Cannot decode message from WebSocket: %v", err) + return false, err + } + return false, nil + }() if err != nil { - s.logger.Errorf(msgContext, "Cannot decode message from WebSocket: %v", err) + connectionCallbacks.OnReadMessageError(agentConn, mt, msgBytes, err) + if isBreak { + break + } continue } @@ -377,7 +390,6 @@ func (s *server) handlePlainHTTPRequest(req *http.Request, w http.ResponseWriter w.Header().Set(headerContentEncoding, contentEncodingGzip) } _, err = w.Write(bodyBytes) - if err != nil { s.logger.Debugf(req.Context(), "Cannot send HTTP response: %v", err) } diff --git a/server/serverimpl_test.go b/server/serverimpl_test.go index d7a9098e..d20014b1 100644 --- a/server/serverimpl_test.go +++ b/server/serverimpl_test.go @@ -315,6 +315,52 @@ func TestServerReceiveSendMessage(t *testing.T) { assert.EqualValues(t, settings.CustomCapabilities, response.CustomCapabilities.Capabilities) } +func TestServerReceiveSendErrorMessage(t *testing.T) { + var rcvMsg atomic.Value + type ErrorInfo struct { + mt int + msgByte []byte + err error + } + callbacks := CallbacksStruct{ + OnConnectingFunc: func(request *http.Request) types.ConnectionResponse { + return types.ConnectionResponse{Accept: true, ConnectionCallbacks: ConnectionCallbacksStruct{ + OnReadMessageErrorFunc: func(conn types.Connection, mt int, msgByte []byte, err error) { + rcvMsg.Store(ErrorInfo{ + mt: mt, + msgByte: msgByte, + err: err, + }) + }, + }} + }, + } + + // Start a Server. + settings := &StartSettings{Settings: Settings{ + Callbacks: callbacks, + CustomCapabilities: []string{"local.test.capability"}, + }} + srv := startServer(t, settings) + defer srv.Stop(context.Background()) + + // Connect using a WebSocket client. + conn, _, _ := dialClient(settings) + require.NotNil(t, conn) + defer conn.Close() + + // Send a message to the Server. + err := conn.WriteMessage(websocket.TextMessage, []byte("")) + require.NoError(t, err) + + // Wait until Server receives the message. + eventually(t, func() bool { return rcvMsg.Load() != nil }) + errInfo := rcvMsg.Load().(ErrorInfo) + assert.EqualValues(t, websocket.TextMessage, errInfo.mt) + assert.EqualValues(t, []byte(""), errInfo.msgByte) + assert.NotNil(t, errInfo.err) +} + func TestServerReceiveSendMessageWithCompression(t *testing.T) { // Use highly compressible config body. uncompressedCfg := []byte(strings.Repeat("test", 10000)) @@ -620,7 +666,6 @@ func TestServerAttachSendMessagePlainHTTP(t *testing.T) { } func TestServerHonoursClientRequestContentEncoding(t *testing.T) { - hc := http.Client{} var rcvMsg atomic.Value var onConnectedCalled, onCloseCalled int32 @@ -698,7 +743,6 @@ func TestServerHonoursClientRequestContentEncoding(t *testing.T) { } func TestServerHonoursAcceptEncoding(t *testing.T) { - hc := http.Client{} var rcvMsg atomic.Value var onConnectedCalled, onCloseCalled int32 @@ -985,7 +1029,6 @@ func BenchmarkSendToClient(b *testing.B) { } srv := New(&sharedinternal.NopLogger{}) err := srv.Start(*settings) - if err != nil { b.Error(err) } @@ -1017,7 +1060,6 @@ func BenchmarkSendToClient(b *testing.B) { for _, conn := range serverConnections { err := conn.Send(context.Background(), &protobufs.ServerToAgent{}) - if err != nil { b.Error(err) } @@ -1026,5 +1068,4 @@ func BenchmarkSendToClient(b *testing.B) { for _, conn := range clientConnections { conn.Close() } - } diff --git a/server/types/callbacks.go b/server/types/callbacks.go index 0546903d..a9c0417e 100644 --- a/server/types/callbacks.go +++ b/server/types/callbacks.go @@ -52,4 +52,7 @@ type ConnectionCallbacks interface { // OnConnectionClose is called when the OpAMP connection is closed. OnConnectionClose(conn Connection) + + // OnConnectionError is called when an error occurs while reading or serializing a message. + OnReadMessageError(conn Connection, mt int, msgByte []byte, err error) }