Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Supports callbacks when reading a message fails #331

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions server/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ func (c CallbacksStruct) OnConnecting(request *http.Request) types.ConnectionRes
// ConnectionCallbacksStruct is a struct that implements ConnectionCallbacks interface and allows
// to override only the methods that are needed.
type ConnectionCallbacksStruct struct {
OnConnectedFunc func(ctx context.Context, conn types.Connection)
OnMessageFunc func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent
OnConnectionCloseFunc func(conn types.Connection)
OnConnectedFunc func(ctx context.Context, conn types.Connection)
OnMessageFunc func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent
OnConnectionCloseFunc func(conn types.Connection)
OnReadMessageErrorFunc func(conn types.Connection, mt int, msgByte []byte, err error)
}

var _ types.ConnectionCallbacks = (*ConnectionCallbacksStruct)(nil)
Expand Down Expand Up @@ -61,3 +62,10 @@ func (c ConnectionCallbacksStruct) OnConnectionClose(conn types.Connection) {
c.OnConnectionCloseFunc(conn)
}
}

// OnReadMessageError implements types.ConnectionCallbacks.
func (c ConnectionCallbacksStruct) OnReadMessageError(conn types.Connection, mt int, msgByte []byte, err error) {
if c.OnReadMessageErrorFunc != nil {
c.OnReadMessageErrorFunc(conn, mt, msgByte, err)
}
}
64 changes: 38 additions & 26 deletions server/serverimpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"compress/gzip"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
Expand All @@ -19,16 +20,16 @@ import (
serverTypes "github.com/open-telemetry/opamp-go/server/types"
)

var (
errAlreadyStarted = errors.New("already started")
)
var errAlreadyStarted = errors.New("already started")

const defaultOpAMPPath = "/v1/opamp"
const headerContentType = "Content-Type"
const headerContentEncoding = "Content-Encoding"
const headerAcceptEncoding = "Accept-Encoding"
const contentEncodingGzip = "gzip"
const contentTypeProtobuf = "application/x-protobuf"
const (
defaultOpAMPPath = "/v1/opamp"
headerContentType = "Content-Type"
headerContentEncoding = "Content-Encoding"
headerAcceptEncoding = "Accept-Encoding"
contentEncodingGzip = "gzip"
contentTypeProtobuf = "application/x-protobuf"
)

type server struct {
logger types.Logger
Expand Down Expand Up @@ -230,27 +231,39 @@ func (s *server) handleWSConnection(reqCtx context.Context, wsConn *websocket.Co
// Loop until fail to read from the WebSocket connection.
for {
msgContext := context.Background()
request := protobufs.AgentToServer{}

// Block until the next message can be read.
mt, msgBytes, err := wsConn.ReadMessage()
if err != nil {
if !websocket.IsUnexpectedCloseError(err) {
s.logger.Errorf(msgContext, "Cannot read a message from WebSocket: %v", err)
break
isBreak, err := func() (bool, error) {
if err != nil {
if !websocket.IsUnexpectedCloseError(err) {
s.logger.Errorf(msgContext, "Cannot read a message from WebSocket: %v", err)
return true, err
}
// This is a normal closing of the WebSocket connection.
s.logger.Debugf(msgContext, "Agent disconnected: %v", err)
return true, err
}
if mt != websocket.BinaryMessage {
err = fmt.Errorf("Received unexpected message type from WebSocket: %v", mt)
s.logger.Errorf(msgContext, err.Error())
return false, err
}
// This is a normal closing of the WebSocket connection.
s.logger.Debugf(msgContext, "Agent disconnected: %v", err)
break
}
if mt != websocket.BinaryMessage {
s.logger.Errorf(msgContext, "Received unexpected message type from WebSocket: %v", mt)
continue
}

// Decode WebSocket message as a Protobuf message.
var request protobufs.AgentToServer
err = internal.DecodeWSMessage(msgBytes, &request)
// Decode WebSocket message as a Protobuf message.
err = internal.DecodeWSMessage(msgBytes, &request)
if err != nil {
s.logger.Errorf(msgContext, "Cannot decode message from WebSocket: %v", err)
return false, err
}
return false, nil
}()
if err != nil {
s.logger.Errorf(msgContext, "Cannot decode message from WebSocket: %v", err)
connectionCallbacks.OnReadMessageError(agentConn, mt, msgBytes, err)
if isBreak {
break
}
continue
}

Expand Down Expand Up @@ -377,7 +390,6 @@ func (s *server) handlePlainHTTPRequest(req *http.Request, w http.ResponseWriter
w.Header().Set(headerContentEncoding, contentEncodingGzip)
}
_, err = w.Write(bodyBytes)

if err != nil {
s.logger.Debugf(req.Context(), "Cannot send HTTP response: %v", err)
}
Expand Down
51 changes: 46 additions & 5 deletions server/serverimpl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,52 @@ func TestServerReceiveSendMessage(t *testing.T) {
assert.EqualValues(t, settings.CustomCapabilities, response.CustomCapabilities.Capabilities)
}

func TestServerReceiveSendErrorMessage(t *testing.T) {
var rcvMsg atomic.Value
type ErrorInfo struct {
mt int
msgByte []byte
err error
}
callbacks := CallbacksStruct{
OnConnectingFunc: func(request *http.Request) types.ConnectionResponse {
return types.ConnectionResponse{Accept: true, ConnectionCallbacks: ConnectionCallbacksStruct{
OnReadMessageErrorFunc: func(conn types.Connection, mt int, msgByte []byte, err error) {
rcvMsg.Store(ErrorInfo{
mt: mt,
msgByte: msgByte,
err: err,
})
},
}}
},
}

// Start a Server.
settings := &StartSettings{Settings: Settings{
Callbacks: callbacks,
CustomCapabilities: []string{"local.test.capability"},
}}
srv := startServer(t, settings)
defer srv.Stop(context.Background())

// Connect using a WebSocket client.
conn, _, _ := dialClient(settings)
require.NotNil(t, conn)
defer conn.Close()

// Send a message to the Server.
err := conn.WriteMessage(websocket.TextMessage, []byte(""))
require.NoError(t, err)

// Wait until Server receives the message.
eventually(t, func() bool { return rcvMsg.Load() != nil })
errInfo := rcvMsg.Load().(ErrorInfo)
assert.EqualValues(t, websocket.TextMessage, errInfo.mt)
assert.EqualValues(t, []byte(""), errInfo.msgByte)
assert.NotNil(t, errInfo.err)
}

func TestServerReceiveSendMessageWithCompression(t *testing.T) {
// Use highly compressible config body.
uncompressedCfg := []byte(strings.Repeat("test", 10000))
Expand Down Expand Up @@ -620,7 +666,6 @@ func TestServerAttachSendMessagePlainHTTP(t *testing.T) {
}

func TestServerHonoursClientRequestContentEncoding(t *testing.T) {

hc := http.Client{}
var rcvMsg atomic.Value
var onConnectedCalled, onCloseCalled int32
Expand Down Expand Up @@ -698,7 +743,6 @@ func TestServerHonoursClientRequestContentEncoding(t *testing.T) {
}

func TestServerHonoursAcceptEncoding(t *testing.T) {

hc := http.Client{}
var rcvMsg atomic.Value
var onConnectedCalled, onCloseCalled int32
Expand Down Expand Up @@ -985,7 +1029,6 @@ func BenchmarkSendToClient(b *testing.B) {
}
srv := New(&sharedinternal.NopLogger{})
err := srv.Start(*settings)

if err != nil {
b.Error(err)
}
Expand Down Expand Up @@ -1017,7 +1060,6 @@ func BenchmarkSendToClient(b *testing.B) {

for _, conn := range serverConnections {
err := conn.Send(context.Background(), &protobufs.ServerToAgent{})

if err != nil {
b.Error(err)
}
Expand All @@ -1026,5 +1068,4 @@ func BenchmarkSendToClient(b *testing.B) {
for _, conn := range clientConnections {
conn.Close()
}

}
3 changes: 3 additions & 0 deletions server/types/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,7 @@ type ConnectionCallbacks interface {

// OnConnectionClose is called when the OpAMP connection is closed.
OnConnectionClose(conn Connection)

// OnConnectionError is called when an error occurs while reading or serializing a message.
OnReadMessageError(conn Connection, mt int, msgByte []byte, err error)
}