Skip to content

Commit

Permalink
SEC-39: Check if JWT is associated with session (#357)
Browse files Browse the repository at this point in the history
  • Loading branch information
jr22 authored Sep 27, 2024
1 parent c7c7522 commit a22af49
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 11 deletions.
6 changes: 4 additions & 2 deletions rpc/server_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ type JWTClaims struct {
jwt.RegisteredClaims
AuthCredentialsType CredentialsType `json:"rpc_creds_type,omitempty"`
AuthMetadata map[string]string `json:"rpc_auth_md,omitempty"`
ApplicationID string `json:"applicationId,omitempty"`
}

// Entity returns the entity from the claims' Subject.
Expand Down Expand Up @@ -238,7 +239,8 @@ func (wrapped ctxWrappedServerStream) Context() context.Context {
return wrapped.ctx
}

func tokenFromContext(ctx context.Context) (string, error) {
// TokenFromContext returns the bearer token from the authorization header and errors if it does not exist.
func TokenFromContext(ctx context.Context) (string, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return "", status.Error(codes.Unauthenticated, "authentication required")
Expand Down Expand Up @@ -289,7 +291,7 @@ func (ss *simpleServer) ensureAuthed(ctx context.Context) (context.Context, erro
return ss.ensureAuthedHandler(ctx)
}

tokenString, err := tokenFromContext(ctx)
tokenString, err := TokenFromContext(ctx)
if err != nil {
// check TLS state
if ss.tlsAuthHandler == nil {
Expand Down
16 changes: 8 additions & 8 deletions web/auth0.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,20 +196,20 @@ func installAuthProviderRoutes(
logger utils.ZapCompatibleLogger,
) {
mux.Handle(pat.New("/login"), &loginHandler{
authProvider,
logger,
redirectStateCookieName,
redirectStateCookieMaxAge,
state: authProvider,
logger: logger,
redirectStateCookieName: redirectStateCookieName,
redirectStateCookieMaxAge: redirectStateCookieMaxAge,
})
mux.Handle(pat.New(redirectURL), &callbackHandler{
authProvider,
logger,
redirectStateCookieName,
})
mux.Handle(pat.New("/logout"), &logoutHandler{
authProvider,
logger,
providerLogoutURL,
state: authProvider,
logger: logger,
providerLogoutURL: providerLogoutURL,
})
mux.Handle(pat.New("/token"), &tokenHandler{
authProvider,
Expand Down Expand Up @@ -425,7 +425,7 @@ func verifyAndSaveToken(ctx context.Context, state *AuthProvider, session *Sessi
}

session.Data["id_token"] = rawIDToken
session.Data["access_token"] = token.AccessToken
session.Data[accessTokenSessionDataField] = token.AccessToken
session.Data["profile"] = profile

return session, nil
Expand Down
49 changes: 49 additions & 0 deletions web/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,23 @@ import (
mongoutils "go.viam.com/utils/mongo"
)

var (
accessTokenSessionDataField = "access_token"
accessTokenFieldPath = fmt.Sprintf("data.%s", accessTokenSessionDataField)
)

var webSessionsIndex = []mongo.IndexModel{
{
Keys: bson.D{
{Key: "lastUpdate", Value: 1},
},
Options: options.Index().SetExpireAfterSeconds(30 * 24 * 3600),
},
{
Keys: bson.D{
{Key: accessTokenFieldPath, Value: 1},
},
},
}

// SessionManager handles working with sessions from http.
Expand All @@ -51,6 +61,7 @@ type Store interface {
Get(ctx context.Context, id string) (*Session, error)
Save(ctx context.Context, s *Session) error
SetSessionManager(*SessionManager)
HasSessionWithToken(ctx context.Context, token string) (bool, error)
}

// ----
Expand Down Expand Up @@ -130,6 +141,17 @@ func (sm *SessionManager) DeleteSession(ctx context.Context, r *http.Request, w
}
}

// HasSessionWithAccessToken returns true if there is an active session associated with that access token.
func (sm *SessionManager) HasSessionWithAccessToken(ctx context.Context, token string) bool {
hasSessionWithToken, err := sm.store.HasSessionWithToken(ctx, token)
if err != nil {
sm.logger.Errorw("error finding session with access token", "err", err)
return false
}

return hasSessionWithToken
}

func (sm *SessionManager) newID() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
Expand Down Expand Up @@ -216,6 +238,18 @@ func (mss *mongoDBSessionStore) Get(ctx context.Context, id string) (*Session, e
return s, nil
}

func (mss *mongoDBSessionStore) HasSessionWithToken(ctx context.Context, token string) (bool, error) {
ctx, span := trace.StartSpan(ctx, "MongoDBSessionStore::HasSessionWithToken")
defer span.End()

count, err := mss.collection.CountDocuments(ctx, bson.M{accessTokenFieldPath: token}, options.Count().SetLimit(1))
if err != nil {
return false, fmt.Errorf("failed to check if token session exists: %w", err)
}

return count > 0, nil
}

func (mss *mongoDBSessionStore) Save(ctx context.Context, s *Session) error {
ctx, span := trace.StartSpan(ctx, "MongoDBSessionStore::Save")
defer span.End()
Expand Down Expand Up @@ -258,6 +292,21 @@ func (mss *memorySessionStore) Delete(ctx context.Context, id string) error {
return nil
}

func (mss *memorySessionStore) HasSessionWithToken(ctx context.Context, token string) (bool, error) {
if mss.data != nil {
for _, session := range mss.data {
savedToken, ok := session.Data[accessTokenSessionDataField]
if !ok {
continue
}
if token == savedToken {
return true, nil
}
}
}
return false, errNoSession
}

func (mss *memorySessionStore) Get(ctx context.Context, id string) (*Session, error) {
if mss.data == nil {
return nil, errNoSession
Expand Down
28 changes: 27 additions & 1 deletion web/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
)

func TestSession1(t *testing.T) {
ctx := context.Background()
sm := NewSessionManager(&memorySessionStore{}, golog.NewTestLogger(t))

r, err := http.NewRequest(http.MethodGet, "http://localhost/", nil)
Expand All @@ -36,7 +37,16 @@ func TestSession1(t *testing.T) {
w := &DummyWriter{}

s.Data["a"] = 1
s.Data["access_token"] = "the_access_token"
s.Save(context.TODO(), r, w)

if hasSession := sm.HasSessionWithAccessToken(ctx, "the_wrong_access_token"); hasSession {
t.Fatal("should not have session with token")
}

if hasSession := sm.HasSessionWithAccessToken(ctx, "the_access_token"); !hasSession {
t.Fatal("should have session with token")
}
}

// ----
Expand Down Expand Up @@ -64,7 +74,7 @@ func TestMongoStore(t *testing.T) {

s1 := &Session{}
s1.id = "foo"
s1.Data = bson.M{"a": 1, "b": 2}
s1.Data = bson.M{"a": 1, "b": 2, "access_token": "testToken"}
err = store.Save(ctx, s1)
if err != nil {
t.Fatal(err)
Expand All @@ -84,6 +94,22 @@ func TestMongoStore(t *testing.T) {
if _, err := store.Get(ctx, "something"); !errors.Is(err, errNoSession) {
t.Fatal(err)
}

hasSession, err := store.HasSessionWithToken(ctx, "no_token")
if err != nil {
t.Fatal(err)
}
if hasSession {
t.Fatal("should not have session")
}

hasSession, err = store.HasSessionWithToken(ctx, "testToken")
if err != nil {
t.Fatal(err)
}
if !hasSession {
t.Fatal("should have session")
}
}

// ----
Expand Down

0 comments on commit a22af49

Please sign in to comment.