Skip to content

Commit

Permalink
RSDK-8566 Send gRPC heartbeats from signaling server to answerer (#356)
Browse files Browse the repository at this point in the history
  • Loading branch information
benjirewis authored Oct 7, 2024
1 parent 1bbe91e commit 598a0ed
Show file tree
Hide file tree
Showing 9 changed files with 408 additions and 183 deletions.
406 changes: 241 additions & 165 deletions proto/rpc/webrtc/v1/signaling.pb.go

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions proto/rpc/webrtc/v1/signaling.proto
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ message AnswerRequestErrorStage {
google.rpc.Status status = 1;
}

// AnswerRequestHeartbeatStage is sent periodically to verify liveness of answerer.
message AnswerRequestHeartbeatStage {
}

// AnswerRequest is the SDP offer that the controlling side is making via the answering
// stream.
message AnswerRequest {
Expand All @@ -153,6 +157,9 @@ message AnswerRequest {

// error is sent any time before done
AnswerRequestErrorStage error = 5;

// heartbeat is sent periodically to verify liveness of answerer
AnswerRequestHeartbeatStage heartbeat = 6;
}
}

Expand All @@ -177,6 +184,7 @@ message AnswerResponseErrorStage {
google.rpc.Status status = 1;
}


// AnswerResponse is the SDP answer that an answerer responds with.
message AnswerResponse {
string uuid = 1;
Expand Down
3 changes: 2 additions & 1 deletion rpc/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ func testDial(t *testing.T, signalingCallQueue WebRTCCallQueue, logger utils.Zap
)
test.That(t, err, test.ShouldBeNil)

signalingServer := NewWebRTCSignalingServer(signalingCallQueue, nil, logger, externalSignalingHosts...)
signalingServer := NewWebRTCSignalingServer(signalingCallQueue, nil, logger,
defaultHeartbeatInterval, externalSignalingHosts...)
test.That(t, rpcServerExternal.RegisterServiceServer(
context.Background(),
&webrtcpb.SignalingService_ServiceDesc,
Expand Down
3 changes: 2 additions & 1 deletion rpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,8 @@ func NewServer(logger utils.ZapCompatibleLogger, opts ...ServerOption) (Server,
logger.Debug("will run internal signaling service")
signalingCallQueue := NewMemoryWebRTCCallQueue(logger)
server.signalingCallQueue = signalingCallQueue
server.signalingServer = NewWebRTCSignalingServer(signalingCallQueue, nil, logger, internalSignalingHosts...)
server.signalingServer = NewWebRTCSignalingServer(signalingCallQueue, nil, logger,
defaultHeartbeatInterval, internalSignalingHosts...)
if err := server.RegisterServiceServer(
context.Background(),
&webrtcpb.SignalingService_ServiceDesc,
Expand Down
15 changes: 10 additions & 5 deletions rpc/wrtc_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ func TestWebRTCClientServerWithMongoDBQueue(t *testing.T) {

//nolint:thelper
func testWebRTCClientServer(t *testing.T, signalingCallQueue WebRTCCallQueue, logger utils.ZapCompatibleLogger) {
signalingServer := NewWebRTCSignalingServer(signalingCallQueue, nil, logger)
signalingServer := NewWebRTCSignalingServer(signalingCallQueue, nil, logger,
defaultHeartbeatInterval)
defer signalingServer.Close()

grpcListener, err := net.Listen("tcp", "localhost:0")
Expand Down Expand Up @@ -145,7 +146,8 @@ func TestWebRTCClientDialCancelWithMongoDBQueue(t *testing.T) {

//nolint:thelper
func testWebRTCClientDialCancel(t *testing.T, signalingCallQueue WebRTCCallQueue, logger utils.ZapCompatibleLogger) {
signalingServer := NewWebRTCSignalingServer(signalingCallQueue, nil, logger)
signalingServer := NewWebRTCSignalingServer(signalingCallQueue, nil, logger,
defaultHeartbeatInterval)
defer signalingServer.Close()

grpcListener, err := net.Listen("tcp", "localhost:0")
Expand Down Expand Up @@ -232,7 +234,8 @@ func TestWebRTCClientDialReflectAnswererErrorWithMongoDBQueue(t *testing.T) {

//nolint:thelper
func testWebRTCClientDialReflectAnswererError(t *testing.T, signalingCallQueue WebRTCCallQueue, logger utils.ZapCompatibleLogger) {
signalingServer := NewWebRTCSignalingServer(signalingCallQueue, nil, logger)
signalingServer := NewWebRTCSignalingServer(signalingCallQueue, nil, logger,
defaultHeartbeatInterval)
defer signalingServer.Close()

grpcListener, err := net.Listen("tcp", "localhost:0")
Expand Down Expand Up @@ -326,7 +329,8 @@ func TestWebRTCClientDialConcurrentWithMongoDBQueue(t *testing.T) {
//
//nolint:thelper
func testWebRTCClientDialConcurrent(t *testing.T, signalingCallQueue WebRTCCallQueue, logger utils.ZapCompatibleLogger) {
signalingServer := NewWebRTCSignalingServer(signalingCallQueue, nil, logger)
signalingServer := NewWebRTCSignalingServer(signalingCallQueue, nil, logger,
defaultHeartbeatInterval)
defer signalingServer.Close()

grpcListener, err := net.Listen("tcp", "localhost:0")
Expand Down Expand Up @@ -464,7 +468,8 @@ func TestWebRTCClientAnswerConcurrentWithMongoDBQueue(t *testing.T) {

//nolint:thelper
func testWebRTCClientAnswerConcurrent(t *testing.T, signalingCallQueue WebRTCCallQueue, logger utils.ZapCompatibleLogger) {
signalingServer := NewWebRTCSignalingServer(signalingCallQueue, nil, logger)
signalingServer := NewWebRTCSignalingServer(signalingCallQueue, nil, logger,
defaultHeartbeatInterval)
defer signalingServer.Close()

grpcListener, err := net.Listen("tcp", "localhost:0")
Expand Down
4 changes: 2 additions & 2 deletions rpc/wrtc_server_channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestWebRTCServerChannel(t *testing.T) {
// It helps that it is in our package.
queue := newMemoryWebRTCCallQueueTest(logger)
defer queue.Close()
signalServer := NewWebRTCSignalingServer(queue, nil, logger)
signalServer := NewWebRTCSignalingServer(queue, nil, logger, defaultHeartbeatInterval)
defer signalServer.Close()
server.RegisterService(
&webrtcpb.SignalingService_ServiceDesc,
Expand Down Expand Up @@ -265,7 +265,7 @@ func TestWebRTCServerChannelResetStream(t *testing.T) {
// It helps that it is in our package.
queue := newMemoryWebRTCCallQueueTest(logger)
defer queue.Close()
signalServer := NewWebRTCSignalingServer(queue, nil, logger)
signalServer := NewWebRTCSignalingServer(queue, nil, logger, defaultHeartbeatInterval)
defer signalServer.Close()
server.RegisterService(
&webrtcpb.SignalingService_ServiceDesc,
Expand Down
35 changes: 30 additions & 5 deletions rpc/wrtc_signaling_answerer.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (

const testDelayAnswererNegotiationVar = "TEST_DELAY_ANSWERER_NEGOTIATION"

const heartbeatReceivedLog = "Received a heartbeat from the signaling server"

// A webrtcSignalingAnswerer listens for and answers calls with a given signaling service. It is
// directly connected to a Server that will handle the actual calls/connections over WebRTC
// data channels.
Expand Down Expand Up @@ -157,6 +159,7 @@ func (ans *webrtcSignalingAnswerer) startAnswerer() {
client := webrtcpb.NewSignalingServiceClient(conn)
md := metadata.New(nil)
md.Append(RPCHostMetadataField, ans.hosts...)
md.Append(HeartbeatsAllowedMetadataField, "true")
answerCtx := metadata.NewOutgoingContext(ans.closeCtx, md)
answerClient, err := client.Answer(answerCtx)
if err != nil {
Expand Down Expand Up @@ -212,9 +215,28 @@ func (ans *webrtcSignalingAnswerer) startAnswerer() {
continue
}

// `client.Recv` waits, typically for a long time, for a caller to show up. Which is
// when the signaling server will send a response saying someone wants to connect.
incomingCallerReq, err := client.Recv()
var incomingCallerReq *webrtcpb.AnswerRequest
for {
// `client.Recv` waits, typically for a long time, for a caller to show
// up. Which is when the signaling server will send a response saying
// someone wants to connect. It can also receive heartbeats every 15s.
//
// The answerer does not respond to heartbeats. The signaling server is
// only using heartbeats to ensure the answerer is reachable. If the
// answerer is down, the heartbeat will error in the server's
// heartbeating goroutine, the server's stream's context will be
// canceled, and the server will stop handling interactions for this
// answerer.
incomingCallerReq, err = client.Recv()
if err != nil {
break
}
if _, ok := incomingCallerReq.Stage.(*webrtcpb.AnswerRequest_Heartbeat); ok {
ans.logger.Debug(heartbeatReceivedLog)
continue
}
break // not a heartbeat
}
if err != nil {
if checkExceptionalError(err) != nil {
ans.logger.Warnw("error communicating with signaling server", "error", err)
Expand All @@ -236,8 +258,9 @@ func (ans *webrtcSignalingAnswerer) startAnswerer() {

initStage, ok := incomingCallerReq.Stage.(*webrtcpb.AnswerRequest_Init)
if !ok {
aa.sendError(fmt.Errorf("expected first stage to be init; got %T", incomingCallerReq.Stage))
ans.logger.Warnw("error communicating with signaling server", "error", err)
err := fmt.Errorf("expected first stage to be init or heartbeat; got %T", incomingCallerReq.Stage)
aa.sendError(err)
ans.logger.Warn(err.Error())
continue
}

Expand Down Expand Up @@ -535,6 +558,8 @@ func (aa *answerAttempt) connect(ctx context.Context) (err error) {
respStatus := status.FromProto(stage.Error.Status)
aa.sendError(fmt.Errorf("error from requester: %w", respStatus.Err()))
return
case *webrtcpb.AnswerRequest_Heartbeat:
aa.logger.Debug(heartbeatReceivedLog)
default:
aa.sendError(fmt.Errorf("unexpected stage %T", stage))
return
Expand Down
57 changes: 56 additions & 1 deletion rpc/wrtc_signaling_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,22 @@ type WebRTCSignalingServer struct {
cancelCtx context.Context
cancelFunc func()
logger utils.ZapCompatibleLogger

// Interval at which to send heartbeats.
heartbeatInterval time.Duration
}

// NewWebRTCSignalingServer makes a new signaling server that uses the given
// call queue and looks routes based on a given robot host. If forHosts is
// non-empty, the server will only accept the given hosts and reject all
// others.
// others. The signaling server will send heartbeats to answerers at the
// provided heartbeatInterval if the answerer requests heartbeats through
// the initial Answer metadata.
func NewWebRTCSignalingServer(
callQueue WebRTCCallQueue,
webrtcConfigProvider WebRTCConfigProvider,
logger utils.ZapCompatibleLogger,
heartbeatInterval time.Duration,
forHosts ...string,
) *WebRTCSignalingServer {
forHostsSet := make(map[string]struct{}, len(forHosts))
Expand All @@ -69,12 +75,20 @@ func NewWebRTCSignalingServer(
cancelCtx: cancelCtx,
cancelFunc: cancelFunc,
logger: logger,
heartbeatInterval: heartbeatInterval,
}
}

// RPCHostMetadataField is the identifier of a host.
const RPCHostMetadataField = "rpc-host"

// HeartbeatsAllowedMetadataField is the identifier for allowing heartbeats
// from a signaling server to answerers.
const HeartbeatsAllowedMetadataField = "heartbeats-allowed"

// Default interval at which to send heartbeats.
const defaultHeartbeatInterval = 15 * time.Second

// HostFromCtx gets the host being called/answered for from the context.
func HostFromCtx(ctx context.Context) (string, error) {
hosts, err := HostsFromCtx(ctx)
Expand Down Expand Up @@ -127,6 +141,17 @@ func (srv *WebRTCSignalingServer) validateHosts(hosts ...string) error {
return nil
}

// HeartbeatsAllowedFromCtx checks if heartbeats are allowed with respect to
// the context.
func HeartbeatsAllowedFromCtx(ctx context.Context) bool {
md, ok := metadata.FromIncomingContext(ctx)
if !ok || len(md[HeartbeatsAllowedMetadataField]) == 0 {
return false
}
// Only allow "true" as a value for now.
return md[HeartbeatsAllowedMetadataField][0] == "true"
}

// Call is a request/offer to start a caller with the connected answerer.
func (srv *WebRTCSignalingServer) Call(req *webrtcpb.CallRequest, server webrtcpb.SignalingService_CallServer) (callErr error) {
ctx := server.Context()
Expand Down Expand Up @@ -331,6 +356,36 @@ func (srv *WebRTCSignalingServer) Answer(server webrtcpb.SignalingService_Answer
}
defer srv.clearAdditionalICEServers(hosts)

// If heartbeats allowed (indicated by answerer), start goroutine to send
// heartbeats.
//
// The answerer does not respond to heartbeats. The signaling server is only
// using heartbeats to ensure the answerer is reachable. If the answerer is
// down, the heartbeat will error in the heartbeating goroutine below, the
// stream's context will be canceled, and we will stop handling interactions
// for this answerer. We stop handling interactions because the stream's
// context (`ctx` here and below) is used in the `RecvOffer` call below this
// goroutine that waits for a caller to attempt to establish a connection.
if HeartbeatsAllowedFromCtx(ctx) {
utils.PanicCapturingGo(func() {
for {
select {
case <-time.After(srv.heartbeatInterval):
if err := server.Send(&webrtcpb.AnswerRequest{
Stage: &webrtcpb.AnswerRequest_Heartbeat{},
}); err != nil {
srv.logger.Debugw(
"error sending answer heartbeat",
"error", err,
)
}
case <-ctx.Done():
return
}
}
})
}

offer, err := srv.callQueue.RecvOffer(ctx, hosts)
if err != nil {
return err
Expand Down
60 changes: 57 additions & 3 deletions rpc/wrtc_signaling_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ func testWebRTCSignaling(t *testing.T, signalingCallQueue WebRTCCallQueue, logge
hosts := []string{"yeehaw", "woahthere"}
for _, host := range hosts {
t.Run(host, func(t *testing.T) {
signalingServer := NewWebRTCSignalingServer(signalingCallQueue, nil, logger)
signalingServer := NewWebRTCSignalingServer(signalingCallQueue, nil, logger,
defaultHeartbeatInterval)
defer signalingServer.Close()

grpcListener, err := net.Listen("tcp", "localhost:0")
Expand Down Expand Up @@ -147,9 +148,11 @@ func testWebRTCSignaling(t *testing.T, signalingCallQueue WebRTCCallQueue, logge
})
}

webrtcServer.Stop()
// Mimic order of stopping used in `simpleServer.Stop()` (answerer, sig
// server's gRPC listener, then machine).
answerer.Stop()
grpcServer.Stop()
webrtcServer.Stop()
test.That(t, <-serveDone, test.ShouldBeNil)
})
}
Expand All @@ -164,7 +167,8 @@ func TestWebRTCAnswererImmediateStop(t *testing.T) {
test.That(t, signalingCallQueue.Close(), test.ShouldBeNil)
}()

signalingServer := NewWebRTCSignalingServer(signalingCallQueue, nil, logger)
signalingServer := NewWebRTCSignalingServer(signalingCallQueue, nil, logger,
defaultHeartbeatInterval)
defer signalingServer.Close()

grpcListener, err := net.Listen("tcp", "localhost:0")
Expand Down Expand Up @@ -199,3 +203,53 @@ func TestWebRTCAnswererImmediateStop(t *testing.T) {
}()
wg.Wait()
}

func TestSignalingHeartbeats(t *testing.T) {
logger, observer := golog.NewObservedTestLogger(t)

// Create a simple signaling server with an in-memory call-queue.
signalingCallQueue := NewMemoryWebRTCCallQueue(logger)
defer func() {
test.That(t, signalingCallQueue.Close(), test.ShouldBeNil)
}()
// Use a lowered heartbeatInterval (500ms instead of 15s).
signalingServer := NewWebRTCSignalingServer(signalingCallQueue, nil, logger,
500*time.Millisecond)
defer signalingServer.Close()
grpcListener, err := net.Listen("tcp", "localhost:0")
test.That(t, err, test.ShouldBeNil)
grpcServer := grpc.NewServer()
grpcServer.RegisterService(&webrtcpb.SignalingService_ServiceDesc, signalingServer)
serveDone := make(chan error)
go func() {
serveDone <- grpcServer.Serve(grpcListener)
}()

// Create a simple WebRTC server that (needlessly) serves the Echo service.
// Start an answerer with it.
webrtcServer := newWebRTCServer(logger)
webrtcServer.RegisterService(&echopb.EchoService_ServiceDesc, &echoserver.Server{})
answerer := newWebRTCSignalingAnswerer(
grpcListener.Addr().String(),
[]string{"foo"},
webrtcServer,
[]DialOption{WithInsecure()},
webrtc.Configuration{},
logger,
)
answerer.Start()

// Assert that the answerer eventually logs received heartbeats.
testutils.WaitForAssertion(t, func(tb testing.TB) {
t.Helper()
test.That(tb, observer.FilterMessageSnippet(heartbeatReceivedLog).Len(),
test.ShouldBeGreaterThan, 0)
})

// Mimic order of stopping used in `simpleServer.Stop()` (answerer, sig
// server's gRPC listener, then machine).
answerer.Stop()
grpcServer.Stop()
webrtcServer.Stop()
test.That(t, <-serveDone, test.ShouldBeNil)
}

0 comments on commit 598a0ed

Please sign in to comment.