CVE-2024-4183
Description
Mattermost versions 8.1.x before 8.1.12, 9.6.x before 9.6.1, 9.5.x before 9.5.3, 9.4.x before 9.4.5 fail to limit the number of active sessions, which allows an authenticated attacker to crash the server via repeated requests to the getSessions API after flooding the sessions table.
Affected packages
Versions sourced from the GitHub Security Advisory.
| Package | Affected versions | Patched versions |
|---|---|---|
| github.com/mattermost/mattermost-server (Go) | >= 9.6.0-rc1, < 9.6.1 | 9.6.1 |
| github.com/mattermost/mattermost-server (Go) | >= 9.5.0, < 9.5.3 | 9.5.3 |
| github.com/mattermost/mattermost-server (Go) | >= 9.4.0, < 9.4.5 | 9.4.5 |
| github.com/mattermost/mattermost-server (Go) | >= 8.1.0, < 8.1.12 | 8.1.12 |
Affected products
1- Range: 9.6.0
Patches
49d81eee979ae [MM-55320] Cherry pick of #25900 (#26569)
16 files changed · +336 −5
server/channels/app/app_iface.go+3 −0 modified@@ -195,6 +195,9 @@ type AppIface interface { // relationship with a user. That means any user sharing any channel, including // direct and group channels. GetKnownUsers(userID string) ([]string, *model.AppError) + // GetLRUSessions returns the Least Recently Used sessions for userID, skipping over the newest 'offset' + // number of sessions. E.g., if userID has 100 sessions, offset 98 will return the oldest 2 sessions. + GetLRUSessions(userID string, limit uint64, offset uint64) ([]*model.Session, *model.AppError) // GetLastAccessibleFileTime returns CreateAt time(from cache) of the last accessible post as per the cloud limit GetLastAccessibleFileTime() (int64, *model.AppError) // GetLastAccessiblePostTime returns CreateAt time(from cache) of the last accessible post as per the cloud limit
server/channels/app/oauth.go+5 −0 modified@@ -383,6 +383,11 @@ func (a *App) GetOAuthAccessTokenForCodeFlow(clientId, grantType, redirectURI, c } func (a *App) newSession(app *model.OAuthApp, user *model.User) (*model.Session, *model.AppError) { + if err := a.limitNumberOfSessions(user.Id); err != nil { + return nil, model.NewAppError("newSession", "api.oauth.get_access_token.internal_session.app_error", nil, + "", http.StatusInternalServerError).Wrap(err) + } + // Set new token an session session := &model.Session{UserId: user.Id, Roles: user.Roles, IsOAuth: true} session.GenerateCSRF()
server/channels/app/opentracing/opentracing_layer.go+22 −0 modified@@ -7144,6 +7144,28 @@ func (a *OpenTracingAppLayer) GetKnownUsers(userID string) ([]string, *model.App return resultVar0, resultVar1 } +func (a *OpenTracingAppLayer) GetLRUSessions(userID string, limit uint64, offset uint64) ([]*model.Session, *model.AppError) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetLRUSessions") + + a.ctx = newCtx + a.app.Srv().Store().SetContext(newCtx) + defer func() { + a.app.Srv().Store().SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.GetLRUSessions(userID, limit, offset) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + func (a *OpenTracingAppLayer) GetLastAccessibleFileTime() (int64, *model.AppError) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetLastAccessibleFileTime")
server/channels/app/platform/session.go+4 −0 modified@@ -41,6 +41,10 @@ func (ps *PlatformService) GetSessions(userID string) ([]*model.Session, error) return ps.Store.Session().GetSessions(userID) } +func (ps *PlatformService) GetLRUSessions(userID string, limit uint64, offset uint64) ([]*model.Session, error) { + return ps.Store.Session().GetLRUSessions(userID, limit, offset) +} + func (ps *PlatformService) AddSessionToCache(session *model.Session) { ps.sessionCache.SetWithExpiry(session.Token, session, time.Duration(int64(*ps.Config().ServiceSettings.SessionCacheInMinutes))*time.Minute) }
server/channels/app/session.go+45 −0 modified@@ -18,7 +18,14 @@ import ( "github.com/mattermost/mattermost/server/v8/channels/store" ) +// maxSessionsLimit prevents a potential DOS caused by creating an unbounded number of sessions; MM-55320 +const maxSessionsLimit = 500 + func (a *App) CreateSession(session *model.Session) (*model.Session, *model.AppError) { + if appErr := a.limitNumberOfSessions(session.UserId); appErr != nil { + return nil, appErr + } + session, err := a.ch.srv.platform.CreateSession(session) if err != nil { var invErr *store.ErrInvalidInput @@ -133,6 +140,40 @@ func (a *App) GetSessions(userID string) ([]*model.Session, *model.AppError) { return sessions, nil } +// limitNumberOfSessions revokes userId's least recently used sessions to keep the number below +// maxSessionsLimit; MM-55320 +func (a *App) limitNumberOfSessions(userId string) *model.AppError { + const returnLimit = 100 + sessions, appErr := a.GetLRUSessions(userId, returnLimit, maxSessionsLimit-1) + if appErr != nil { + return model.NewAppError("limitNumberOfSessions", "app.session.save.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) + } + + // Revoke any sessions over the limit to make room for new sessions + for _, sess := range sessions { + if err := a.RevokeSession(sess); err != nil { + return model.NewAppError("limitNumberOfSessions", "app.session.save.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + + mlog.Debug("Session revoked; user's number of sessions were over the maxSessionsLimit", + mlog.String("user_id", userId), + mlog.String("session_id", sess.Id)) + } + + return nil +} + +// GetLRUSessions returns the Least Recently Used sessions for userID, skipping over the newest 'offset' +// number of sessions. E.g., if userID has 100 sessions, offset 98 will return the oldest 2 sessions. 
+func (a *App) GetLRUSessions(userID string, limit uint64, offset uint64) ([]*model.Session, *model.AppError) { + sessions, err := a.ch.srv.platform.GetLRUSessions(userID, limit, offset) + if err != nil { + return nil, model.NewAppError("GetLRUSessions", "app.session.get_lru_sessions.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + + return sessions, nil +} + func (a *App) RevokeAllSessions(userID string) *model.AppError { if err := a.ch.srv.platform.RevokeAllSessions(userID); err != nil { switch { @@ -378,6 +419,10 @@ func (a *App) createSessionForUserAccessToken(tokenString string) (*model.Sessio return nil, model.NewAppError("createSessionForUserAccessToken", "app.user_access_token.invalid_or_missing", nil, "inactive_user_id="+user.Id, http.StatusUnauthorized) } + if appErr := a.limitNumberOfSessions(user.Id); appErr != nil { + return nil, appErr + } + session := &model.Session{ Token: token.Token, UserId: user.Id,
server/channels/app/session_test.go+57 −0 modified@@ -6,8 +6,11 @@ package app import ( "context" "fmt" + "net/http" + "net/http/httptest" "os" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -396,3 +399,57 @@ func TestGetRemoteClusterSession(t *testing.T) { require.Nil(t, session) }) } + +func TestSessionsLimit(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + + user := th.BasicUser + var sessionIds []string + + r := &http.Request{} + w := httptest.NewRecorder() + for i := 0; i < maxSessionsLimit; i++ { + err := th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false) + require.Nil(t, err) + sessionIds = append(sessionIds, th.Context.Session().Id) + time.Sleep(1 * time.Millisecond) + } + + gotSessions, _ := th.App.GetSessions(user.Id) + require.Equal(t, maxSessionsLimit, len(gotSessions), "should have maxSessionsLimit number of sessions") + + // Ensure we are retrieving the same sessions. + reverse(gotSessions) + for i, sess := range gotSessions { + require.Equal(t, sessionIds[i], sess.Id) + } + + // Now add 10 more. + for i := 0; i < 10; i++ { + err := th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false) + require.Nil(t, err, "should not have an error creating user sessions") + + // Remove oldest, append newest. + sessionIds = sessionIds[1:] + sessionIds = append(sessionIds, th.Context.Session().Id) + time.Sleep(1 * time.Millisecond) + } + + // Ensure that we still only have the max allowed. + gotSessions, _ = th.App.GetSessions(user.Id) + require.Equal(t, maxSessionsLimit, len(gotSessions), "should have maxSessionsLimit number of sessions") + + // Ensure the the oldest sessions were removed first. 
+ reverse(gotSessions) + for i, sess := range gotSessions { + require.Equal(t, sessionIds[i], sess.Id) + } +} + +// reverse can be replaced by the slices version when we move to 1.21+ +func reverse[S ~[]E, E any](s S) { + for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { + s[i], s[j] = s[j], s[i] + } +}
server/channels/app/user_agent.go+29 −4 modified@@ -10,6 +10,8 @@ import ( "github.com/avct/uasurfer" ) +const maxUserAgentVersionLength = 128 + var platformNames = map[uasurfer.Platform]string{ uasurfer.PlatformUnknown: "Windows", uasurfer.PlatformWindows: "Windows", @@ -84,24 +86,35 @@ func getOSName(ua *uasurfer.UserAgent) string { } func getBrowserVersion(ua *uasurfer.UserAgent, userAgentString string) string { + if index := strings.Index(userAgentString, "Mattermost Mobile/"); index != -1 { + afterVersion := userAgentString[index+len("Mattermost Mobile/"):] + // MM-55320: limitStringLength prevents potential DOS caused by filling an unbounded string with junk data + return limitStringLength(strings.Fields(afterVersion)[0], maxUserAgentVersionLength) + } + if index := strings.Index(userAgentString, "Mattermost/"); index != -1 { afterVersion := userAgentString[index+len("Mattermost/"):] - return strings.Fields(afterVersion)[0] + return limitStringLength(strings.Fields(afterVersion)[0], maxUserAgentVersionLength) } if index := strings.Index(userAgentString, "mmctl/"); index != -1 { afterVersion := userAgentString[index+len("mmctl/"):] - return strings.Fields(afterVersion)[0] + return limitStringLength(strings.Fields(afterVersion)[0], maxUserAgentVersionLength) } if index := strings.Index(userAgentString, "Franz/"); index != -1 { afterVersion := userAgentString[index+len("Franz/"):] - return strings.Fields(afterVersion)[0] + return limitStringLength(strings.Fields(afterVersion)[0], maxUserAgentVersionLength) } return getUAVersion(ua.Browser.Version) } +func limitStringLength(field string, limit int) string { + endPos := min(len(field), limit) + return field[:endPos] +} + func getUAVersion(version uasurfer.Version) string { if version.Patch == 0 { return fmt.Sprintf("%v.%v", version.Major, version.Minor) @@ -123,10 +136,15 @@ var browserNames = map[uasurfer.BrowserName]string{ func getBrowserName(ua *uasurfer.UserAgent, userAgentString string) string { browser := 
ua.Browser.Name - if strings.Contains(userAgentString, "Mattermost") { + if strings.Contains(userAgentString, "Electron") || + (strings.Contains(userAgentString, "Mattermost") && !strings.Contains(userAgentString, "Mattermost Mobile")) { return "Desktop App" } + if strings.Contains(userAgentString, "Mattermost Mobile") { + return "Mobile App" + } + if strings.Contains(userAgentString, "mmctl") { return "mmctl" } @@ -140,5 +158,12 @@ func getBrowserName(ua *uasurfer.UserAgent, userAgentString string) string { } return browserNames[uasurfer.BrowserUnknown] +} +// min should be replaced by to go 1.21 built-in generic function, see MM-57356. +func min(a, b int) int { + if a < b { + return a + } + return b }
server/channels/app/user_agent_test.go+11 −1 modified@@ -33,6 +33,8 @@ var testUserAgents = []testUserAgent{ {"Safari 9", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6) AppleWebKit/604.1.38 (KHTML, like Gecko) Version/11.0 Safari/604.1.38"}, {"Safari 8", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_4) AppleWebKit/600.7.12 (KHTML, like Gecko) Version/8.0.7 Safari/600.7.12"}, {"Safari Mobile", "Mozilla/5.0 (iPhone; CPU iPhone OS 9_1 like Mac OS X) AppleWebKit/601.1.46 (KHTML, like Gecko) Version/9.0 Mobile/13B137 Safari/601.1"}, + {"Mobile App", "Mattermost Mobile/2.7.0+482 (Android; 13; sdk_gphone64_arm64)"}, + {"Mobile App", "Mattermost Mobile/233.234441.341234223421341234529099823109834440981234+abcdef3214eafeabc3242331129857301afesfffff1930a84e4bd2348fe129ac1309bd929dca3419af934bfe3089fcd (Android; 13; sdk_gphone64_arm64)"}, } func TestGetPlatformName(t *testing.T) { @@ -53,6 +55,8 @@ func TestGetPlatformName(t *testing.T) { "Macintosh", "Macintosh", "iPhone", + "Linux", + "Linux", } for i, userAgent := range testUserAgents { @@ -83,6 +87,8 @@ func TestGetOSName(t *testing.T) { "Mac OS", "Mac OS", "iOS", + "Android", + "Android", } for i, userAgent := range testUserAgents { @@ -103,7 +109,7 @@ func TestGetBrowserName(t *testing.T) { "Chrome", "mmctl", "Desktop App", - "Chrome", + "Desktop App", "Edge", "Internet Explorer", "Internet Explorer", @@ -113,6 +119,8 @@ func TestGetBrowserName(t *testing.T) { "Safari", "Safari", "Safari", + "Mobile App", + "Mobile App", } for i, userAgent := range testUserAgents { @@ -143,6 +151,8 @@ func TestGetBrowserVersion(t *testing.T) { "11.0", "8.0.7", "9.0", + "2.7.0+482", + "233.234441.341234223421341234529099823109834440981234+abcdef3214eafeabc3242331129857301afesfffff1930a84e4bd2348fe129ac1309bd929d", // cut off at len 128 } for i, userAgent := range testUserAgents {
server/channels/store/opentracinglayer/opentracinglayer.go+18 −0 modified@@ -8343,6 +8343,24 @@ func (s *OpenTracingLayerSessionStore) Get(ctx context.Context, sessionIDOrToken return result, err } +func (s *OpenTracingLayerSessionStore) GetLRUSessions(userID string, limit uint64, offset uint64) ([]*model.Session, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SessionStore.GetLRUSessions") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SessionStore.GetLRUSessions(userID, limit, offset) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + func (s *OpenTracingLayerSessionStore) GetSessions(userID string) ([]*model.Session, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SessionStore.GetSessions")
server/channels/store/retrylayer/retrylayer.go+21 −0 modified@@ -9503,6 +9503,27 @@ func (s *RetryLayerSessionStore) Get(ctx context.Context, sessionIDOrToken strin } +func (s *RetryLayerSessionStore) GetLRUSessions(userID string, limit uint64, offset uint64) ([]*model.Session, error) { + + tries := 0 + for { + result, err := s.SessionStore.GetLRUSessions(userID, limit, offset) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + func (s *RetryLayerSessionStore) GetSessions(userID string) ([]*model.Session, error) { tries := 0
server/channels/store/sqlstore/session_store.go+22 −0 modified@@ -123,6 +123,28 @@ func (me SqlSessionStore) GetSessions(userId string) ([]*model.Session, error) { return sessions, nil } +// GetLRUSessions gets the Least Recently Used sessions from the store. Note: the use of limit and offset +// are intentional; they are hardcoded from the app layer (i.e., will not result in a non-performant query). +func (me SqlSessionStore) GetLRUSessions(userId string, limit uint64, offset uint64) ([]*model.Session, error) { + builder := me.getQueryBuilder(). + Select("*"). + From("Sessions"). + Where(sq.Eq{"UserId": userId}). + OrderBy("LastActivityAt DESC"). + Limit(limit). + Offset(offset) + query, args, err := builder.ToSql() + if err != nil { + return nil, errors.Wrap(err, "get_lru_sessions_tosql") + } + + var sessions []*model.Session + if err := me.GetReplicaX().Select(&sessions, query, args...); err != nil { + return nil, errors.Wrapf(err, "failed to find Sessions with userId=%s", userId) + } + return sessions, nil +} + func (me SqlSessionStore) GetSessionsWithActiveDeviceIds(userId string) ([]*model.Session, error) { query := `SELECT *
server/channels/store/store.go+1 −0 modified@@ -501,6 +501,7 @@ type SessionStore interface { Get(ctx context.Context, sessionIDOrToken string) (*model.Session, error) Save(session *model.Session) (*model.Session, error) GetSessions(userID string) ([]*model.Session, error) + GetLRUSessions(userID string, limit uint64, offset uint64) ([]*model.Session, error) GetSessionsWithActiveDeviceIds(userID string) ([]*model.Session, error) GetSessionsExpired(thresholdMillis int64, mobileOnly bool, unnotifiedOnly bool) ([]*model.Session, error) UpdateExpiredNotify(sessionid string, notified bool) error
server/channels/store/storetest/mocks/SessionStore.go+26 −0 modified@@ -80,6 +80,32 @@ func (_m *SessionStore) Get(ctx context.Context, sessionIDOrToken string) (*mode return r0, r1 } +// GetLRUSessions provides a mock function with given fields: userID, limit, offset +func (_m *SessionStore) GetLRUSessions(userID string, limit uint64, offset uint64) ([]*model.Session, error) { + ret := _m.Called(userID, limit, offset) + + var r0 []*model.Session + var r1 error + if rf, ok := ret.Get(0).(func(string, uint64, uint64) ([]*model.Session, error)); ok { + return rf(userID, limit, offset) + } + if rf, ok := ret.Get(0).(func(string, uint64, uint64) []*model.Session); ok { + r0 = rf(userID, limit, offset) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.Session) + } + } + + if rf, ok := ret.Get(1).(func(string, uint64, uint64) error); ok { + r1 = rf(userID, limit, offset) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetSessions provides a mock function with given fields: userID func (_m *SessionStore) GetSessions(userID string) ([]*model.Session, error) { ret := _m.Called(userID)
server/channels/store/storetest/session_store.go+52 −0 modified@@ -6,6 +6,7 @@ package storetest import ( "context" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -36,6 +37,7 @@ func TestSessionStore(t *testing.T, ss store.Store) { t.Run("SessionCount", func(t *testing.T) { testSessionCount(t, ss) }) t.Run("GetSessionsExpired", func(t *testing.T) { testGetSessionsExpired(t, ss) }) t.Run("UpdateExpiredNotify", func(t *testing.T) { testUpdateExpiredNotify(t, ss) }) + t.Run("GetLRUSessions", func(t *testing.T) { testGetLRUSessions(t, ss) }) } func testSessionStoreSave(t *testing.T, ss store.Store) { @@ -404,3 +406,53 @@ func testUpdateExpiredNotify(t *testing.T, ss store.Store) { require.NoError(t, err) require.False(t, session.ExpiredNotify) } + +func testGetLRUSessions(t *testing.T, ss store.Store) { + userId := model.NewId() + + // Clear existing sessions. + err := ss.Session().RemoveAllSessions() + require.NoError(t, err) + + s1 := &model.Session{} + s1.UserId = userId + s1.DeviceId = model.NewId() + _, err = ss.Session().Save(s1) + require.NoError(t, err) + time.Sleep(1 * time.Millisecond) + + s2 := &model.Session{} + s2.UserId = userId + s2.DeviceId = model.NewId() + s2, err = ss.Session().Save(s2) + require.NoError(t, err) + time.Sleep(1 * time.Millisecond) + + s3 := &model.Session{} + s3.UserId = userId + s3.DeviceId = model.NewId() + s3, err = ss.Session().Save(s3) + require.NoError(t, err) + + sessions, err := ss.Session().GetLRUSessions(userId, 3, 3) + require.NoError(t, err) + require.Len(t, sessions, 0) + + sessions, err = ss.Session().GetLRUSessions(userId, 3, 2) + require.NoError(t, err) + require.Len(t, sessions, 1) + require.Equal(t, s1.Id, sessions[0].Id) + + sessions, err = ss.Session().GetLRUSessions(userId, 3, 1) + require.NoError(t, err) + require.Len(t, sessions, 2) + require.Equal(t, s2.Id, sessions[0].Id) + require.Equal(t, s1.Id, sessions[1].Id) + + sessions, err = 
ss.Session().GetLRUSessions(userId, 3, 0) + require.NoError(t, err) + require.Len(t, sessions, 3) + require.Equal(t, s3.Id, sessions[0].Id) + require.Equal(t, s2.Id, sessions[1].Id) + require.Equal(t, s1.Id, sessions[2].Id) +}
server/channels/store/timerlayer/timerlayer.go+16 −0 modified@@ -7524,6 +7524,22 @@ func (s *TimerLayerSessionStore) Get(ctx context.Context, sessionIDOrToken strin return result, err } +func (s *TimerLayerSessionStore) GetLRUSessions(userID string, limit uint64, offset uint64) ([]*model.Session, error) { + start := time.Now() + + result, err := s.SessionStore.GetLRUSessions(userID, limit, offset) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SessionStore.GetLRUSessions", success, elapsed) + } + return result, err +} + func (s *TimerLayerSessionStore) GetSessions(userID string) ([]*model.Session, error) { start := time.Now()
server/i18n/en.json+4 −0 modified@@ -6243,6 +6243,10 @@ "id": "app.session.get.app_error", "translation": "We encountered an error finding the session." }, + { + "id": "app.session.get_lru_sessions.app_error", + "translation": "Unable to get least recently used sessions." + }, { "id": "app.session.get_sessions.app_error", "translation": "We encountered an error while finding user sessions."
86920d641760 MM-55320 - Limit length of browser user agent version; ratelimit the /sessions endpoint (#25900) (#26549)
16 files changed · +321 −4
server/channels/app/app_iface.go+3 −0 modified@@ -194,6 +194,9 @@ type AppIface interface { // relationship with a user. That means any user sharing any channel, including // direct and group channels. GetKnownUsers(userID string) ([]string, *model.AppError) + // GetLRUSessions returns the Least Recently Used sessions for userID, skipping over the newest 'offset' + // number of sessions. E.g., if userID has 100 sessions, offset 98 will return the oldest 2 sessions. + GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, *model.AppError) // GetLastAccessibleFileTime returns CreateAt time(from cache) of the last accessible post as per the cloud limit GetLastAccessibleFileTime() (int64, *model.AppError) // GetLastAccessiblePostTime returns CreateAt time(from cache) of the last accessible post as per the cloud limit
server/channels/app/oauth.go+5 −0 modified@@ -383,6 +383,11 @@ func (a *App) GetOAuthAccessTokenForCodeFlow(c request.CTX, clientId, grantType, } func (a *App) newSession(c request.CTX, app *model.OAuthApp, user *model.User) (*model.Session, *model.AppError) { + if err := a.limitNumberOfSessions(c, user.Id); err != nil { + return nil, model.NewAppError("newSession", "api.oauth.get_access_token.internal_session.app_error", nil, + "", http.StatusInternalServerError).Wrap(err) + } + // Set new token an session session := &model.Session{UserId: user.Id, Roles: user.Roles, IsOAuth: true} session.GenerateCSRF()
server/channels/app/opentracing/opentracing_layer.go+22 −0 modified@@ -7166,6 +7166,28 @@ func (a *OpenTracingAppLayer) GetKnownUsers(userID string) ([]string, *model.App return resultVar0, resultVar1 } +func (a *OpenTracingAppLayer) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, *model.AppError) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetLRUSessions") + + a.ctx = newCtx + a.app.Srv().Store().SetContext(newCtx) + defer func() { + a.app.Srv().Store().SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.GetLRUSessions(c, userID, limit, offset) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + func (a *OpenTracingAppLayer) GetLastAccessibleFileTime() (int64, *model.AppError) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetLastAccessibleFileTime")
server/channels/app/platform/session.go+4 −0 modified@@ -41,6 +41,10 @@ func (ps *PlatformService) GetSessions(c request.CTX, userID string) ([]*model.S return ps.Store.Session().GetSessions(c, userID) } +func (ps *PlatformService) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, error) { + return ps.Store.Session().GetLRUSessions(c, userID, limit, offset) +} + func (ps *PlatformService) AddSessionToCache(session *model.Session) { ps.sessionCache.SetWithExpiry(session.Token, session, time.Duration(int64(*ps.Config().ServiceSettings.SessionCacheInMinutes))*time.Minute) }
server/channels/app/session.go+45 −0 modified@@ -18,7 +18,14 @@ import ( "github.com/mattermost/mattermost/server/v8/channels/store" ) +// maxSessionsLimit prevents a potential DOS caused by creating an unbounded number of sessions; MM-55320 +const maxSessionsLimit = 500 + func (a *App) CreateSession(c request.CTX, session *model.Session) (*model.Session, *model.AppError) { + if appErr := a.limitNumberOfSessions(c, session.UserId); appErr != nil { + return nil, appErr + } + session, err := a.ch.srv.platform.CreateSession(c, session) if err != nil { var invErr *store.ErrInvalidInput @@ -136,6 +143,40 @@ func (a *App) GetSessions(c request.CTX, userID string) ([]*model.Session, *mode return sessions, nil } +// limitNumberOfSessions revokes userId's least recently used sessions to keep the number below +// maxSessionsLimit; MM-55320 +func (a *App) limitNumberOfSessions(c request.CTX, userId string) *model.AppError { + const returnLimit = 100 + sessions, appErr := a.GetLRUSessions(c, userId, returnLimit, maxSessionsLimit-1) + if appErr != nil { + return model.NewAppError("limitNumberOfSessions", "app.session.save.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) + } + + // Revoke any sessions over the limit to make room for new sessions + for _, sess := range sessions { + if err := a.RevokeSession(c, sess); err != nil { + return model.NewAppError("limitNumberOfSessions", "app.session.save.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + + c.Logger().Debug("Session revoked; user's number of sessions were over the maxSessionsLimit", + mlog.String("user_id", userId), + mlog.String("session_id", sess.Id)) + } + + return nil +} + +// GetLRUSessions returns the Least Recently Used sessions for userID, skipping over the newest 'offset' +// number of sessions. E.g., if userID has 100 sessions, offset 98 will return the oldest 2 sessions. 
+func (a *App) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, *model.AppError) { + sessions, err := a.ch.srv.platform.GetLRUSessions(c, userID, limit, offset) + if err != nil { + return nil, model.NewAppError("GetLRUSessions", "app.session.get_lru_sessions.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + + return sessions, nil +} + func (a *App) RevokeAllSessions(c request.CTX, userID string) *model.AppError { if err := a.ch.srv.platform.RevokeAllSessions(c, userID); err != nil { switch { @@ -384,6 +425,10 @@ func (a *App) createSessionForUserAccessToken(c request.CTX, tokenString string) return nil, model.NewAppError("createSessionForUserAccessToken", "app.user_access_token.invalid_or_missing", nil, "inactive_user_id="+user.Id, http.StatusUnauthorized) } + if appErr := a.limitNumberOfSessions(c, user.Id); appErr != nil { + return nil, appErr + } + session := &model.Session{ Token: token.Token, UserId: user.Id,
server/channels/app/session_test.go+57 −0 modified@@ -5,8 +5,11 @@ package app import ( "fmt" + "net/http" + "net/http/httptest" "os" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -395,3 +398,57 @@ func TestGetRemoteClusterSession(t *testing.T) { require.Nil(t, session) }) } + +func TestSessionsLimit(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + + user := th.BasicUser + var sessions []*model.Session + + r := &http.Request{} + w := httptest.NewRecorder() + for i := 0; i < maxSessionsLimit; i++ { + session, err := th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false) + require.Nil(t, err) + sessions = append(sessions, session) + time.Sleep(1 * time.Millisecond) + } + + gotSessions, _ := th.App.GetSessions(th.Context, user.Id) + require.Equal(t, maxSessionsLimit, len(gotSessions), "should have maxSessionsLimit number of sessions") + + // Ensure we are retrieving the same sessions. + reverse(gotSessions) + for i, sess := range gotSessions { + require.Equal(t, sessions[i].Id, sess.Id) + } + + // Now add 10 more. + for i := 0; i < 10; i++ { + session, err := th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false) + require.Nil(t, err, "should not have an error creating user sessions") + + // Remove oldest, append newest. + sessions = sessions[1:] + sessions = append(sessions, session) + time.Sleep(1 * time.Millisecond) + } + + // Ensure that we still only have the max allowed. + gotSessions, _ = th.App.GetSessions(th.Context, user.Id) + require.Equal(t, maxSessionsLimit, len(gotSessions), "should have maxSessionsLimit number of sessions") + + // Ensure the the oldest sessions were removed first. 
+ reverse(gotSessions) + for i, sess := range gotSessions { + require.Equal(t, sessions[i].Id, sess.Id) + } +} + +// reverse can be replaced by the slices version when we move to 1.21+ +func reverse[S ~[]E, E any](s S) { + for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { + s[i], s[j] = s[j], s[i] + } +}
server/channels/app/user_agent.go+20 −4 modified@@ -10,6 +10,8 @@ import ( "github.com/avct/uasurfer" ) +const maxUserAgentVersionLength = 128 + var platformNames = map[uasurfer.Platform]string{ uasurfer.PlatformUnknown: "Windows", uasurfer.PlatformWindows: "Windows", @@ -86,27 +88,33 @@ func getOSName(ua *uasurfer.UserAgent) string { func getBrowserVersion(ua *uasurfer.UserAgent, userAgentString string) string { if index := strings.Index(userAgentString, "Mattermost Mobile/"); index != -1 { afterVersion := userAgentString[index+len("Mattermost Mobile/"):] - return strings.Fields(afterVersion)[0] + // MM-55320: limitStringLength prevents potential DOS caused by filling an unbounded string with junk data + return limitStringLength(strings.Fields(afterVersion)[0], maxUserAgentVersionLength) } if index := strings.Index(userAgentString, "Mattermost/"); index != -1 { afterVersion := userAgentString[index+len("Mattermost/"):] - return strings.Fields(afterVersion)[0] + return limitStringLength(strings.Fields(afterVersion)[0], maxUserAgentVersionLength) } if index := strings.Index(userAgentString, "mmctl/"); index != -1 { afterVersion := userAgentString[index+len("mmctl/"):] - return strings.Fields(afterVersion)[0] + return limitStringLength(strings.Fields(afterVersion)[0], maxUserAgentVersionLength) } if index := strings.Index(userAgentString, "Franz/"); index != -1 { afterVersion := userAgentString[index+len("Franz/"):] - return strings.Fields(afterVersion)[0] + return limitStringLength(strings.Fields(afterVersion)[0], maxUserAgentVersionLength) } return getUAVersion(ua.Browser.Version) } +func limitStringLength(field string, limit int) string { + endPos := min(len(field), limit) + return field[:endPos] +} + func getUAVersion(version uasurfer.Version) string { if version.Patch == 0 { return fmt.Sprintf("%v.%v", version.Major, version.Minor) @@ -151,3 +159,11 @@ func getBrowserName(ua *uasurfer.UserAgent, userAgentString string) string { return 
browserNames[uasurfer.BrowserUnknown] } + +// min should be replaced by to go 1.21 built-in generic function, see MM-57356. +func min(a, b int) int { + if a < b { + return a + } + return b +}
server/channels/app/user_agent_test.go+5 −0 modified@@ -34,6 +34,7 @@ var testUserAgents = []testUserAgent{ {"Safari 8", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_4) AppleWebKit/600.7.12 (KHTML, like Gecko) Version/8.0.7 Safari/600.7.12"}, {"Safari Mobile", "Mozilla/5.0 (iPhone; CPU iPhone OS 9_1 like Mac OS X) AppleWebKit/601.1.46 (KHTML, like Gecko) Version/9.0 Mobile/13B137 Safari/601.1"}, {"Mobile App", "Mattermost Mobile/2.7.0+482 (Android; 13; sdk_gphone64_arm64)"}, + {"Mobile App", "Mattermost Mobile/233.234441.341234223421341234529099823109834440981234+abcdef3214eafeabc3242331129857301afesfffff1930a84e4bd2348fe129ac1309bd929dca3419af934bfe3089fcd (Android; 13; sdk_gphone64_arm64)"}, } func TestGetPlatformName(t *testing.T) { @@ -55,6 +56,7 @@ func TestGetPlatformName(t *testing.T) { "Macintosh", "iPhone", "Linux", + "Linux", } for i, userAgent := range testUserAgents { @@ -86,6 +88,7 @@ func TestGetOSName(t *testing.T) { "Mac OS", "iOS", "Android", + "Android", } for i, userAgent := range testUserAgents { @@ -117,6 +120,7 @@ func TestGetBrowserName(t *testing.T) { "Safari", "Safari", "Mobile App", + "Mobile App", } for i, userAgent := range testUserAgents { @@ -148,6 +152,7 @@ func TestGetBrowserVersion(t *testing.T) { "8.0.7", "9.0", "2.7.0+482", + "233.234441.341234223421341234529099823109834440981234+abcdef3214eafeabc3242331129857301afesfffff1930a84e4bd2348fe129ac1309bd929d", // cut off at len 128 } for i, userAgent := range testUserAgents {
server/channels/store/opentracinglayer/opentracinglayer.go+18 −0 modified@@ -8435,6 +8435,24 @@ func (s *OpenTracingLayerSessionStore) Get(c request.CTX, sessionIDOrToken strin return result, err } +func (s *OpenTracingLayerSessionStore) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SessionStore.GetLRUSessions") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SessionStore.GetLRUSessions(c, userID, limit, offset) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + func (s *OpenTracingLayerSessionStore) GetSessions(c request.CTX, userID string) ([]*model.Session, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SessionStore.GetSessions")
server/channels/store/retrylayer/retrylayer.go+21 −0 modified@@ -9608,6 +9608,27 @@ func (s *RetryLayerSessionStore) Get(c request.CTX, sessionIDOrToken string) (*m } +func (s *RetryLayerSessionStore) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, error) { + + tries := 0 + for { + result, err := s.SessionStore.GetLRUSessions(c, userID, limit, offset) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + func (s *RetryLayerSessionStore) GetSessions(c request.CTX, userID string) ([]*model.Session, error) { tries := 0
server/channels/store/sqlstore/session_store.go+22 −0 modified@@ -123,6 +123,28 @@ func (me SqlSessionStore) GetSessions(c request.CTX, userId string) ([]*model.Se return sessions, nil } +// GetLRUSessions gets the Least Recently Used sessions from the store. Note: the use of limit and offset +// are intentional; they are hardcoded from the app layer (i.e., will not result in a non-performant query). +func (me SqlSessionStore) GetLRUSessions(c request.CTX, userId string, limit uint64, offset uint64) ([]*model.Session, error) { + builder := me.getQueryBuilder(). + Select("*"). + From("Sessions"). + Where(sq.Eq{"UserId": userId}). + OrderBy("LastActivityAt DESC"). + Limit(limit). + Offset(offset) + query, args, err := builder.ToSql() + if err != nil { + return nil, errors.Wrap(err, "get_lru_sessions_tosql") + } + + var sessions []*model.Session + if err := me.GetReplicaX().Select(&sessions, query, args...); err != nil { + return nil, errors.Wrapf(err, "failed to find Sessions with userId=%s", userId) + } + return sessions, nil +} + func (me SqlSessionStore) GetSessionsWithActiveDeviceIds(userId string) ([]*model.Session, error) { query := `SELECT *
server/channels/store/store.go+1 −0 modified@@ -504,6 +504,7 @@ type SessionStore interface { Get(c request.CTX, sessionIDOrToken string) (*model.Session, error) Save(c request.CTX, session *model.Session) (*model.Session, error) GetSessions(c request.CTX, userID string) ([]*model.Session, error) + GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, error) GetSessionsWithActiveDeviceIds(userID string) ([]*model.Session, error) GetSessionsExpired(thresholdMillis int64, mobileOnly bool, unnotifiedOnly bool) ([]*model.Session, error) UpdateExpiredNotify(sessionid string, notified bool) error
server/channels/store/storetest/mocks/SessionStore.go+26 −0 modified@@ -79,6 +79,32 @@ func (_m *SessionStore) Get(c request.CTX, sessionIDOrToken string) (*model.Sess return r0, r1 } +// GetLRUSessions provides a mock function with given fields: c, userID, limit, offset +func (_m *SessionStore) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, error) { + ret := _m.Called(c, userID, limit, offset) + + var r0 []*model.Session + var r1 error + if rf, ok := ret.Get(0).(func(request.CTX, string, uint64, uint64) ([]*model.Session, error)); ok { + return rf(c, userID, limit, offset) + } + if rf, ok := ret.Get(0).(func(request.CTX, string, uint64, uint64) []*model.Session); ok { + r0 = rf(c, userID, limit, offset) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.Session) + } + } + + if rf, ok := ret.Get(1).(func(request.CTX, string, uint64, uint64) error); ok { + r1 = rf(c, userID, limit, offset) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetSessions provides a mock function with given fields: c, userID func (_m *SessionStore) GetSessions(c request.CTX, userID string) ([]*model.Session, error) { ret := _m.Called(c, userID)
server/channels/store/storetest/session_store.go+52 −0 modified@@ -5,6 +5,7 @@ package storetest import ( "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -36,6 +37,7 @@ func TestSessionStore(t *testing.T, rctx request.CTX, ss store.Store) { t.Run("SessionCount", func(t *testing.T) { testSessionCount(t, rctx, ss) }) t.Run("GetSessionsExpired", func(t *testing.T) { testGetSessionsExpired(t, rctx, ss) }) t.Run("UpdateExpiredNotify", func(t *testing.T) { testUpdateExpiredNotify(t, rctx, ss) }) + t.Run("GetLRUSessions", func(t *testing.T) { testGetLRUSessions(t, rctx, ss) }) } func testSessionStoreSave(t *testing.T, rctx request.CTX, ss store.Store) { @@ -404,3 +406,53 @@ func testUpdateExpiredNotify(t *testing.T, rctx request.CTX, ss store.Store) { require.NoError(t, err) require.False(t, session.ExpiredNotify) } + +func testGetLRUSessions(t *testing.T, rctx request.CTX, ss store.Store) { + userId := model.NewId() + + // Clear existing sessions. 
+ err := ss.Session().RemoveAllSessions() + require.NoError(t, err) + + s1 := &model.Session{} + s1.UserId = userId + s1.DeviceId = model.NewId() + _, err = ss.Session().Save(rctx, s1) + require.NoError(t, err) + time.Sleep(1 * time.Millisecond) + + s2 := &model.Session{} + s2.UserId = userId + s2.DeviceId = model.NewId() + s2, err = ss.Session().Save(rctx, s2) + require.NoError(t, err) + time.Sleep(1 * time.Millisecond) + + s3 := &model.Session{} + s3.UserId = userId + s3.DeviceId = model.NewId() + s3, err = ss.Session().Save(rctx, s3) + require.NoError(t, err) + + sessions, err := ss.Session().GetLRUSessions(rctx, userId, 3, 3) + require.NoError(t, err) + require.Len(t, sessions, 0) + + sessions, err = ss.Session().GetLRUSessions(rctx, userId, 3, 2) + require.NoError(t, err) + require.Len(t, sessions, 1) + require.Equal(t, s1.Id, sessions[0].Id) + + sessions, err = ss.Session().GetLRUSessions(rctx, userId, 3, 1) + require.NoError(t, err) + require.Len(t, sessions, 2) + require.Equal(t, s2.Id, sessions[0].Id) + require.Equal(t, s1.Id, sessions[1].Id) + + sessions, err = ss.Session().GetLRUSessions(rctx, userId, 3, 0) + require.NoError(t, err) + require.Len(t, sessions, 3) + require.Equal(t, s3.Id, sessions[0].Id) + require.Equal(t, s2.Id, sessions[1].Id) + require.Equal(t, s1.Id, sessions[2].Id) +}
server/channels/store/timerlayer/timerlayer.go+16 −0 modified@@ -7609,6 +7609,22 @@ func (s *TimerLayerSessionStore) Get(c request.CTX, sessionIDOrToken string) (*m return result, err } +func (s *TimerLayerSessionStore) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, error) { + start := time.Now() + + result, err := s.SessionStore.GetLRUSessions(c, userID, limit, offset) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SessionStore.GetLRUSessions", success, elapsed) + } + return result, err +} + func (s *TimerLayerSessionStore) GetSessions(c request.CTX, userID string) ([]*model.Session, error) { start := time.Now()
server/i18n/en.json+4 −0 modified@@ -6358,6 +6358,10 @@ "id": "app.session.get.app_error", "translation": "We encountered an error finding the session." }, + { + "id": "app.session.get_lru_sessions.app_error", + "translation": "Unable to get least recently used sessions." + }, { "id": "app.session.get_sessions.app_error", "translation": "We encountered an error while finding user sessions."
b45c3dac4c16 — MM-55320 - Limit length of browser user agent version; ratelimit the /sessions endpoint (#25900) (#26548)
16 files changed · +321 −4
server/channels/app/app_iface.go+3 −0 modified@@ -194,6 +194,9 @@ type AppIface interface { // relationship with a user. That means any user sharing any channel, including // direct and group channels. GetKnownUsers(userID string) ([]string, *model.AppError) + // GetLRUSessions returns the Least Recently Used sessions for userID, skipping over the newest 'offset' + // number of sessions. E.g., if userID has 100 sessions, offset 98 will return the oldest 2 sessions. + GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, *model.AppError) // GetLastAccessibleFileTime returns CreateAt time(from cache) of the last accessible post as per the cloud limit GetLastAccessibleFileTime() (int64, *model.AppError) // GetLastAccessiblePostTime returns CreateAt time(from cache) of the last accessible post as per the cloud limit
server/channels/app/oauth.go+5 −0 modified@@ -383,6 +383,11 @@ func (a *App) GetOAuthAccessTokenForCodeFlow(c request.CTX, clientId, grantType, } func (a *App) newSession(c request.CTX, app *model.OAuthApp, user *model.User) (*model.Session, *model.AppError) { + if err := a.limitNumberOfSessions(c, user.Id); err != nil { + return nil, model.NewAppError("newSession", "api.oauth.get_access_token.internal_session.app_error", nil, + "", http.StatusInternalServerError).Wrap(err) + } + // Set new token an session session := &model.Session{UserId: user.Id, Roles: user.Roles, IsOAuth: true} session.GenerateCSRF()
server/channels/app/opentracing/opentracing_layer.go+22 −0 modified@@ -7159,6 +7159,28 @@ func (a *OpenTracingAppLayer) GetKnownUsers(userID string) ([]string, *model.App return resultVar0, resultVar1 } +func (a *OpenTracingAppLayer) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, *model.AppError) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetLRUSessions") + + a.ctx = newCtx + a.app.Srv().Store().SetContext(newCtx) + defer func() { + a.app.Srv().Store().SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.GetLRUSessions(c, userID, limit, offset) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + func (a *OpenTracingAppLayer) GetLastAccessibleFileTime() (int64, *model.AppError) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetLastAccessibleFileTime")
server/channels/app/platform/session.go+4 −0 modified@@ -41,6 +41,10 @@ func (ps *PlatformService) GetSessions(c request.CTX, userID string) ([]*model.S return ps.Store.Session().GetSessions(c, userID) } +func (ps *PlatformService) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, error) { + return ps.Store.Session().GetLRUSessions(c, userID, limit, offset) +} + func (ps *PlatformService) AddSessionToCache(session *model.Session) { ps.sessionCache.SetWithExpiry(session.Token, session, time.Duration(int64(*ps.Config().ServiceSettings.SessionCacheInMinutes))*time.Minute) }
server/channels/app/session.go+45 −0 modified@@ -18,7 +18,14 @@ import ( "github.com/mattermost/mattermost/server/v8/channels/store" ) +// maxSessionsLimit prevents a potential DOS caused by creating an unbounded number of sessions; MM-55320 +const maxSessionsLimit = 500 + func (a *App) CreateSession(c request.CTX, session *model.Session) (*model.Session, *model.AppError) { + if appErr := a.limitNumberOfSessions(c, session.UserId); appErr != nil { + return nil, appErr + } + session, err := a.ch.srv.platform.CreateSession(c, session) if err != nil { var invErr *store.ErrInvalidInput @@ -136,6 +143,40 @@ func (a *App) GetSessions(c request.CTX, userID string) ([]*model.Session, *mode return sessions, nil } +// limitNumberOfSessions revokes userId's least recently used sessions to keep the number below +// maxSessionsLimit; MM-55320 +func (a *App) limitNumberOfSessions(c request.CTX, userId string) *model.AppError { + const returnLimit = 100 + sessions, appErr := a.GetLRUSessions(c, userId, returnLimit, maxSessionsLimit-1) + if appErr != nil { + return model.NewAppError("limitNumberOfSessions", "app.session.save.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) + } + + // Revoke any sessions over the limit to make room for new sessions + for _, sess := range sessions { + if err := a.RevokeSession(c, sess); err != nil { + return model.NewAppError("limitNumberOfSessions", "app.session.save.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + + c.Logger().Debug("Session revoked; user's number of sessions were over the maxSessionsLimit", + mlog.String("user_id", userId), + mlog.String("session_id", sess.Id)) + } + + return nil +} + +// GetLRUSessions returns the Least Recently Used sessions for userID, skipping over the newest 'offset' +// number of sessions. E.g., if userID has 100 sessions, offset 98 will return the oldest 2 sessions. 
+func (a *App) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, *model.AppError) { + sessions, err := a.ch.srv.platform.GetLRUSessions(c, userID, limit, offset) + if err != nil { + return nil, model.NewAppError("GetLRUSessions", "app.session.get_lru_sessions.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + + return sessions, nil +} + func (a *App) RevokeAllSessions(c request.CTX, userID string) *model.AppError { if err := a.ch.srv.platform.RevokeAllSessions(c, userID); err != nil { switch { @@ -384,6 +425,10 @@ func (a *App) createSessionForUserAccessToken(c request.CTX, tokenString string) return nil, model.NewAppError("createSessionForUserAccessToken", "app.user_access_token.invalid_or_missing", nil, "inactive_user_id="+user.Id, http.StatusUnauthorized) } + if appErr := a.limitNumberOfSessions(c, user.Id); appErr != nil { + return nil, appErr + } + session := &model.Session{ Token: token.Token, UserId: user.Id,
server/channels/app/session_test.go+57 −0 modified@@ -5,8 +5,11 @@ package app import ( "fmt" + "net/http" + "net/http/httptest" "os" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -395,3 +398,57 @@ func TestGetRemoteClusterSession(t *testing.T) { require.Nil(t, session) }) } + +func TestSessionsLimit(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + + user := th.BasicUser + var sessions []*model.Session + + r := &http.Request{} + w := httptest.NewRecorder() + for i := 0; i < maxSessionsLimit; i++ { + session, err := th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false) + require.Nil(t, err) + sessions = append(sessions, session) + time.Sleep(1 * time.Millisecond) + } + + gotSessions, _ := th.App.GetSessions(th.Context, user.Id) + require.Equal(t, maxSessionsLimit, len(gotSessions), "should have maxSessionsLimit number of sessions") + + // Ensure we are retrieving the same sessions. + reverse(gotSessions) + for i, sess := range gotSessions { + require.Equal(t, sessions[i].Id, sess.Id) + } + + // Now add 10 more. + for i := 0; i < 10; i++ { + session, err := th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false) + require.Nil(t, err, "should not have an error creating user sessions") + + // Remove oldest, append newest. + sessions = sessions[1:] + sessions = append(sessions, session) + time.Sleep(1 * time.Millisecond) + } + + // Ensure that we still only have the max allowed. + gotSessions, _ = th.App.GetSessions(th.Context, user.Id) + require.Equal(t, maxSessionsLimit, len(gotSessions), "should have maxSessionsLimit number of sessions") + + // Ensure the the oldest sessions were removed first. 
+ reverse(gotSessions) + for i, sess := range gotSessions { + require.Equal(t, sessions[i].Id, sess.Id) + } +} + +// reverse can be replaced by the slices version when we move to 1.21+ +func reverse[S ~[]E, E any](s S) { + for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { + s[i], s[j] = s[j], s[i] + } +}
server/channels/app/user_agent.go+20 −4 modified@@ -10,6 +10,8 @@ import ( "github.com/avct/uasurfer" ) +const maxUserAgentVersionLength = 128 + var platformNames = map[uasurfer.Platform]string{ uasurfer.PlatformUnknown: "Windows", uasurfer.PlatformWindows: "Windows", @@ -86,27 +88,33 @@ func getOSName(ua *uasurfer.UserAgent) string { func getBrowserVersion(ua *uasurfer.UserAgent, userAgentString string) string { if index := strings.Index(userAgentString, "Mattermost Mobile/"); index != -1 { afterVersion := userAgentString[index+len("Mattermost Mobile/"):] - return strings.Fields(afterVersion)[0] + // MM-55320: limitStringLength prevents potential DOS caused by filling an unbounded string with junk data + return limitStringLength(strings.Fields(afterVersion)[0], maxUserAgentVersionLength) } if index := strings.Index(userAgentString, "Mattermost/"); index != -1 { afterVersion := userAgentString[index+len("Mattermost/"):] - return strings.Fields(afterVersion)[0] + return limitStringLength(strings.Fields(afterVersion)[0], maxUserAgentVersionLength) } if index := strings.Index(userAgentString, "mmctl/"); index != -1 { afterVersion := userAgentString[index+len("mmctl/"):] - return strings.Fields(afterVersion)[0] + return limitStringLength(strings.Fields(afterVersion)[0], maxUserAgentVersionLength) } if index := strings.Index(userAgentString, "Franz/"); index != -1 { afterVersion := userAgentString[index+len("Franz/"):] - return strings.Fields(afterVersion)[0] + return limitStringLength(strings.Fields(afterVersion)[0], maxUserAgentVersionLength) } return getUAVersion(ua.Browser.Version) } +func limitStringLength(field string, limit int) string { + endPos := min(len(field), limit) + return field[:endPos] +} + func getUAVersion(version uasurfer.Version) string { if version.Patch == 0 { return fmt.Sprintf("%v.%v", version.Major, version.Minor) @@ -151,3 +159,11 @@ func getBrowserName(ua *uasurfer.UserAgent, userAgentString string) string { return 
browserNames[uasurfer.BrowserUnknown] } + +// min should be replaced by to go 1.21 built-in generic function, see MM-57356. +func min(a, b int) int { + if a < b { + return a + } + return b +}
server/channels/app/user_agent_test.go+5 −0 modified@@ -34,6 +34,7 @@ var testUserAgents = []testUserAgent{ {"Safari 8", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_4) AppleWebKit/600.7.12 (KHTML, like Gecko) Version/8.0.7 Safari/600.7.12"}, {"Safari Mobile", "Mozilla/5.0 (iPhone; CPU iPhone OS 9_1 like Mac OS X) AppleWebKit/601.1.46 (KHTML, like Gecko) Version/9.0 Mobile/13B137 Safari/601.1"}, {"Mobile App", "Mattermost Mobile/2.7.0+482 (Android; 13; sdk_gphone64_arm64)"}, + {"Mobile App", "Mattermost Mobile/233.234441.341234223421341234529099823109834440981234+abcdef3214eafeabc3242331129857301afesfffff1930a84e4bd2348fe129ac1309bd929dca3419af934bfe3089fcd (Android; 13; sdk_gphone64_arm64)"}, } func TestGetPlatformName(t *testing.T) { @@ -55,6 +56,7 @@ func TestGetPlatformName(t *testing.T) { "Macintosh", "iPhone", "Linux", + "Linux", } for i, userAgent := range testUserAgents { @@ -86,6 +88,7 @@ func TestGetOSName(t *testing.T) { "Mac OS", "iOS", "Android", + "Android", } for i, userAgent := range testUserAgents { @@ -117,6 +120,7 @@ func TestGetBrowserName(t *testing.T) { "Safari", "Safari", "Mobile App", + "Mobile App", } for i, userAgent := range testUserAgents { @@ -148,6 +152,7 @@ func TestGetBrowserVersion(t *testing.T) { "8.0.7", "9.0", "2.7.0+482", + "233.234441.341234223421341234529099823109834440981234+abcdef3214eafeabc3242331129857301afesfffff1930a84e4bd2348fe129ac1309bd929d", // cut off at len 128 } for i, userAgent := range testUserAgents {
server/channels/store/opentracinglayer/opentracinglayer.go+18 −0 modified@@ -8489,6 +8489,24 @@ func (s *OpenTracingLayerSessionStore) Get(c request.CTX, sessionIDOrToken strin return result, err } +func (s *OpenTracingLayerSessionStore) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SessionStore.GetLRUSessions") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SessionStore.GetLRUSessions(c, userID, limit, offset) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + func (s *OpenTracingLayerSessionStore) GetSessions(c request.CTX, userID string) ([]*model.Session, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SessionStore.GetSessions")
server/channels/store/retrylayer/retrylayer.go+21 −0 modified@@ -9671,6 +9671,27 @@ func (s *RetryLayerSessionStore) Get(c request.CTX, sessionIDOrToken string) (*m } +func (s *RetryLayerSessionStore) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, error) { + + tries := 0 + for { + result, err := s.SessionStore.GetLRUSessions(c, userID, limit, offset) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + func (s *RetryLayerSessionStore) GetSessions(c request.CTX, userID string) ([]*model.Session, error) { tries := 0
server/channels/store/sqlstore/session_store.go+22 −0 modified@@ -123,6 +123,28 @@ func (me SqlSessionStore) GetSessions(c request.CTX, userId string) ([]*model.Se return sessions, nil } +// GetLRUSessions gets the Least Recently Used sessions from the store. Note: the use of limit and offset +// are intentional; they are hardcoded from the app layer (i.e., will not result in a non-performant query). +func (me SqlSessionStore) GetLRUSessions(c request.CTX, userId string, limit uint64, offset uint64) ([]*model.Session, error) { + builder := me.getQueryBuilder(). + Select("*"). + From("Sessions"). + Where(sq.Eq{"UserId": userId}). + OrderBy("LastActivityAt DESC"). + Limit(limit). + Offset(offset) + query, args, err := builder.ToSql() + if err != nil { + return nil, errors.Wrap(err, "get_lru_sessions_tosql") + } + + var sessions []*model.Session + if err := me.GetReplicaX().Select(&sessions, query, args...); err != nil { + return nil, errors.Wrapf(err, "failed to find Sessions with userId=%s", userId) + } + return sessions, nil +} + func (me SqlSessionStore) GetSessionsWithActiveDeviceIds(userId string) ([]*model.Session, error) { query := `SELECT *
server/channels/store/store.go+1 −0 modified@@ -498,6 +498,7 @@ type SessionStore interface { Get(c request.CTX, sessionIDOrToken string) (*model.Session, error) Save(c request.CTX, session *model.Session) (*model.Session, error) GetSessions(c request.CTX, userID string) ([]*model.Session, error) + GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, error) GetSessionsWithActiveDeviceIds(userID string) ([]*model.Session, error) GetSessionsExpired(thresholdMillis int64, mobileOnly bool, unnotifiedOnly bool) ([]*model.Session, error) UpdateExpiredNotify(sessionid string, notified bool) error
server/channels/store/storetest/mocks/SessionStore.go+26 −0 modified@@ -79,6 +79,32 @@ func (_m *SessionStore) Get(c request.CTX, sessionIDOrToken string) (*model.Sess return r0, r1 } +// GetLRUSessions provides a mock function with given fields: c, userID, limit, offset +func (_m *SessionStore) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, error) { + ret := _m.Called(c, userID, limit, offset) + + var r0 []*model.Session + var r1 error + if rf, ok := ret.Get(0).(func(request.CTX, string, uint64, uint64) ([]*model.Session, error)); ok { + return rf(c, userID, limit, offset) + } + if rf, ok := ret.Get(0).(func(request.CTX, string, uint64, uint64) []*model.Session); ok { + r0 = rf(c, userID, limit, offset) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.Session) + } + } + + if rf, ok := ret.Get(1).(func(request.CTX, string, uint64, uint64) error); ok { + r1 = rf(c, userID, limit, offset) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetSessions provides a mock function with given fields: c, userID func (_m *SessionStore) GetSessions(c request.CTX, userID string) ([]*model.Session, error) { ret := _m.Called(c, userID)
server/channels/store/storetest/session_store.go+52 −0 modified@@ -5,6 +5,7 @@ package storetest import ( "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -36,6 +37,7 @@ func TestSessionStore(t *testing.T, rctx request.CTX, ss store.Store) { t.Run("SessionCount", func(t *testing.T) { testSessionCount(t, rctx, ss) }) t.Run("GetSessionsExpired", func(t *testing.T) { testGetSessionsExpired(t, rctx, ss) }) t.Run("UpdateExpiredNotify", func(t *testing.T) { testUpdateExpiredNotify(t, rctx, ss) }) + t.Run("GetLRUSessions", func(t *testing.T) { testGetLRUSessions(t, rctx, ss) }) } func testSessionStoreSave(t *testing.T, rctx request.CTX, ss store.Store) { @@ -404,3 +406,53 @@ func testUpdateExpiredNotify(t *testing.T, rctx request.CTX, ss store.Store) { require.NoError(t, err) require.False(t, session.ExpiredNotify) } + +func testGetLRUSessions(t *testing.T, rctx request.CTX, ss store.Store) { + userId := model.NewId() + + // Clear existing sessions. 
+ err := ss.Session().RemoveAllSessions() + require.NoError(t, err) + + s1 := &model.Session{} + s1.UserId = userId + s1.DeviceId = model.NewId() + _, err = ss.Session().Save(rctx, s1) + require.NoError(t, err) + time.Sleep(1 * time.Millisecond) + + s2 := &model.Session{} + s2.UserId = userId + s2.DeviceId = model.NewId() + s2, err = ss.Session().Save(rctx, s2) + require.NoError(t, err) + time.Sleep(1 * time.Millisecond) + + s3 := &model.Session{} + s3.UserId = userId + s3.DeviceId = model.NewId() + s3, err = ss.Session().Save(rctx, s3) + require.NoError(t, err) + + sessions, err := ss.Session().GetLRUSessions(rctx, userId, 3, 3) + require.NoError(t, err) + require.Len(t, sessions, 0) + + sessions, err = ss.Session().GetLRUSessions(rctx, userId, 3, 2) + require.NoError(t, err) + require.Len(t, sessions, 1) + require.Equal(t, s1.Id, sessions[0].Id) + + sessions, err = ss.Session().GetLRUSessions(rctx, userId, 3, 1) + require.NoError(t, err) + require.Len(t, sessions, 2) + require.Equal(t, s2.Id, sessions[0].Id) + require.Equal(t, s1.Id, sessions[1].Id) + + sessions, err = ss.Session().GetLRUSessions(rctx, userId, 3, 0) + require.NoError(t, err) + require.Len(t, sessions, 3) + require.Equal(t, s3.Id, sessions[0].Id) + require.Equal(t, s2.Id, sessions[1].Id) + require.Equal(t, s1.Id, sessions[2].Id) +}
server/channels/store/timerlayer/timerlayer.go+16 −0 modified@@ -7657,6 +7657,22 @@ func (s *TimerLayerSessionStore) Get(c request.CTX, sessionIDOrToken string) (*m return result, err } +func (s *TimerLayerSessionStore) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, error) { + start := time.Now() + + result, err := s.SessionStore.GetLRUSessions(c, userID, limit, offset) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SessionStore.GetLRUSessions", success, elapsed) + } + return result, err +} + func (s *TimerLayerSessionStore) GetSessions(c request.CTX, userID string) ([]*model.Session, error) { start := time.Now()
server/i18n/en.json+4 −0 modified@@ -6494,6 +6494,10 @@ "id": "app.session.get.app_error", "translation": "We encountered an error finding the session." }, + { + "id": "app.session.get_lru_sessions.app_error", + "translation": "Unable to get least recently used sessions." + }, { "id": "app.session.get_sessions.app_error", "translation": "We encountered an error while finding user sessions."
bc699e6789cf — MM-55320 - Limit length of browser user agent version; ratelimit the /sessions endpoint (#25900) (#26547)
16 files changed · +321 −4
server/channels/app/app_iface.go+3 −0 modified@@ -194,6 +194,9 @@ type AppIface interface { // relationship with a user. That means any user sharing any channel, including // direct and group channels. GetKnownUsers(userID string) ([]string, *model.AppError) + // GetLRUSessions returns the Least Recently Used sessions for userID, skipping over the newest 'offset' + // number of sessions. E.g., if userID has 100 sessions, offset 98 will return the oldest 2 sessions. + GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, *model.AppError) // GetLastAccessibleFileTime returns CreateAt time(from cache) of the last accessible post as per the cloud limit GetLastAccessibleFileTime() (int64, *model.AppError) // GetLastAccessiblePostTime returns CreateAt time(from cache) of the last accessible post as per the cloud limit
server/channels/app/oauth.go+5 −0 modified@@ -383,6 +383,11 @@ func (a *App) GetOAuthAccessTokenForCodeFlow(c request.CTX, clientId, grantType, } func (a *App) newSession(c request.CTX, app *model.OAuthApp, user *model.User) (*model.Session, *model.AppError) { + if err := a.limitNumberOfSessions(c, user.Id); err != nil { + return nil, model.NewAppError("newSession", "api.oauth.get_access_token.internal_session.app_error", nil, + "", http.StatusInternalServerError).Wrap(err) + } + // Set new token an session session := &model.Session{UserId: user.Id, Roles: user.Roles, IsOAuth: true} session.GenerateCSRF()
server/channels/app/opentracing/opentracing_layer.go+22 −0 modified@@ -7203,6 +7203,28 @@ func (a *OpenTracingAppLayer) GetKnownUsers(userID string) ([]string, *model.App return resultVar0, resultVar1 } +func (a *OpenTracingAppLayer) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, *model.AppError) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetLRUSessions") + + a.ctx = newCtx + a.app.Srv().Store().SetContext(newCtx) + defer func() { + a.app.Srv().Store().SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.GetLRUSessions(c, userID, limit, offset) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + func (a *OpenTracingAppLayer) GetLastAccessibleFileTime() (int64, *model.AppError) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetLastAccessibleFileTime")
server/channels/app/platform/session.go+4 −0 modified@@ -41,6 +41,10 @@ func (ps *PlatformService) GetSessions(c request.CTX, userID string) ([]*model.S return ps.Store.Session().GetSessions(c, userID) } +func (ps *PlatformService) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, error) { + return ps.Store.Session().GetLRUSessions(c, userID, limit, offset) +} + func (ps *PlatformService) AddSessionToCache(session *model.Session) { ps.sessionCache.SetWithExpiry(session.Token, session, time.Duration(int64(*ps.Config().ServiceSettings.SessionCacheInMinutes))*time.Minute) }
server/channels/app/session.go+45 −0 modified@@ -18,7 +18,14 @@ import ( "github.com/mattermost/mattermost/server/v8/channels/store" ) +// maxSessionsLimit prevents a potential DOS caused by creating an unbounded number of sessions; MM-55320 +const maxSessionsLimit = 500 + func (a *App) CreateSession(c request.CTX, session *model.Session) (*model.Session, *model.AppError) { + if appErr := a.limitNumberOfSessions(c, session.UserId); appErr != nil { + return nil, appErr + } + session, err := a.ch.srv.platform.CreateSession(c, session) if err != nil { var invErr *store.ErrInvalidInput @@ -136,6 +143,40 @@ func (a *App) GetSessions(c request.CTX, userID string) ([]*model.Session, *mode return sessions, nil } +// limitNumberOfSessions revokes userId's least recently used sessions to keep the number below +// maxSessionsLimit; MM-55320 +func (a *App) limitNumberOfSessions(c request.CTX, userId string) *model.AppError { + const returnLimit = 100 + sessions, appErr := a.GetLRUSessions(c, userId, returnLimit, maxSessionsLimit-1) + if appErr != nil { + return model.NewAppError("limitNumberOfSessions", "app.session.save.app_error", nil, "", http.StatusInternalServerError).Wrap(appErr) + } + + // Revoke any sessions over the limit to make room for new sessions + for _, sess := range sessions { + if err := a.RevokeSession(c, sess); err != nil { + return model.NewAppError("limitNumberOfSessions", "app.session.save.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + + c.Logger().Debug("Session revoked; user's number of sessions were over the maxSessionsLimit", + mlog.String("user_id", userId), + mlog.String("session_id", sess.Id)) + } + + return nil +} + +// GetLRUSessions returns the Least Recently Used sessions for userID, skipping over the newest 'offset' +// number of sessions. E.g., if userID has 100 sessions, offset 98 will return the oldest 2 sessions. 
+func (a *App) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, *model.AppError) { + sessions, err := a.ch.srv.platform.GetLRUSessions(c, userID, limit, offset) + if err != nil { + return nil, model.NewAppError("GetLRUSessions", "app.session.get_lru_sessions.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + + return sessions, nil +} + func (a *App) RevokeAllSessions(c request.CTX, userID string) *model.AppError { if err := a.ch.srv.platform.RevokeAllSessions(c, userID); err != nil { switch { @@ -384,6 +425,10 @@ func (a *App) createSessionForUserAccessToken(c request.CTX, tokenString string) return nil, model.NewAppError("createSessionForUserAccessToken", "app.user_access_token.invalid_or_missing", nil, "inactive_user_id="+user.Id, http.StatusUnauthorized) } + if appErr := a.limitNumberOfSessions(c, user.Id); appErr != nil { + return nil, appErr + } + session := &model.Session{ Token: token.Token, UserId: user.Id,
server/channels/app/session_test.go+57 −0 modified@@ -5,8 +5,11 @@ package app import ( "fmt" + "net/http" + "net/http/httptest" "os" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -395,3 +398,57 @@ func TestGetRemoteClusterSession(t *testing.T) { require.Nil(t, session) }) } + +func TestSessionsLimit(t *testing.T) { + th := Setup(t).InitBasic() + defer th.TearDown() + + user := th.BasicUser + var sessions []*model.Session + + r := &http.Request{} + w := httptest.NewRecorder() + for i := 0; i < maxSessionsLimit; i++ { + session, err := th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false) + require.Nil(t, err) + sessions = append(sessions, session) + time.Sleep(1 * time.Millisecond) + } + + gotSessions, _ := th.App.GetSessions(th.Context, user.Id) + require.Equal(t, maxSessionsLimit, len(gotSessions), "should have maxSessionsLimit number of sessions") + + // Ensure we are retrieving the same sessions. + reverse(gotSessions) + for i, sess := range gotSessions { + require.Equal(t, sessions[i].Id, sess.Id) + } + + // Now add 10 more. + for i := 0; i < 10; i++ { + session, err := th.App.DoLogin(th.Context, w, r, th.BasicUser, "", false, false, false) + require.Nil(t, err, "should not have an error creating user sessions") + + // Remove oldest, append newest. + sessions = sessions[1:] + sessions = append(sessions, session) + time.Sleep(1 * time.Millisecond) + } + + // Ensure that we still only have the max allowed. + gotSessions, _ = th.App.GetSessions(th.Context, user.Id) + require.Equal(t, maxSessionsLimit, len(gotSessions), "should have maxSessionsLimit number of sessions") + + // Ensure the the oldest sessions were removed first. 
+ reverse(gotSessions) + for i, sess := range gotSessions { + require.Equal(t, sessions[i].Id, sess.Id) + } +} + +// reverse can be replaced by the slices version when we move to 1.21+ +func reverse[S ~[]E, E any](s S) { + for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { + s[i], s[j] = s[j], s[i] + } +}
server/channels/app/user_agent.go+20 −4 modified@@ -10,6 +10,8 @@ import ( "github.com/avct/uasurfer" ) +const maxUserAgentVersionLength = 128 + var platformNames = map[uasurfer.Platform]string{ uasurfer.PlatformUnknown: "Windows", uasurfer.PlatformWindows: "Windows", @@ -86,27 +88,33 @@ func getOSName(ua *uasurfer.UserAgent) string { func getBrowserVersion(ua *uasurfer.UserAgent, userAgentString string) string { if index := strings.Index(userAgentString, "Mattermost Mobile/"); index != -1 { afterVersion := userAgentString[index+len("Mattermost Mobile/"):] - return strings.Fields(afterVersion)[0] + // MM-55320: limitStringLength prevents potential DOS caused by filling an unbounded string with junk data + return limitStringLength(strings.Fields(afterVersion)[0], maxUserAgentVersionLength) } if index := strings.Index(userAgentString, "Mattermost/"); index != -1 { afterVersion := userAgentString[index+len("Mattermost/"):] - return strings.Fields(afterVersion)[0] + return limitStringLength(strings.Fields(afterVersion)[0], maxUserAgentVersionLength) } if index := strings.Index(userAgentString, "mmctl/"); index != -1 { afterVersion := userAgentString[index+len("mmctl/"):] - return strings.Fields(afterVersion)[0] + return limitStringLength(strings.Fields(afterVersion)[0], maxUserAgentVersionLength) } if index := strings.Index(userAgentString, "Franz/"); index != -1 { afterVersion := userAgentString[index+len("Franz/"):] - return strings.Fields(afterVersion)[0] + return limitStringLength(strings.Fields(afterVersion)[0], maxUserAgentVersionLength) } return getUAVersion(ua.Browser.Version) } +func limitStringLength(field string, limit int) string { + endPos := min(len(field), limit) + return field[:endPos] +} + func getUAVersion(version uasurfer.Version) string { if version.Patch == 0 { return fmt.Sprintf("%v.%v", version.Major, version.Minor) @@ -151,3 +159,11 @@ func getBrowserName(ua *uasurfer.UserAgent, userAgentString string) string { return 
browserNames[uasurfer.BrowserUnknown] } + +// min should be replaced by to go 1.21 built-in generic function, see MM-57356. +func min(a, b int) int { + if a < b { + return a + } + return b +}
server/channels/app/user_agent_test.go+5 −0 modified@@ -34,6 +34,7 @@ var testUserAgents = []testUserAgent{ {"Safari 8", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_4) AppleWebKit/600.7.12 (KHTML, like Gecko) Version/8.0.7 Safari/600.7.12"}, {"Safari Mobile", "Mozilla/5.0 (iPhone; CPU iPhone OS 9_1 like Mac OS X) AppleWebKit/601.1.46 (KHTML, like Gecko) Version/9.0 Mobile/13B137 Safari/601.1"}, {"Mobile App", "Mattermost Mobile/2.7.0+482 (Android; 13; sdk_gphone64_arm64)"}, + {"Mobile App", "Mattermost Mobile/233.234441.341234223421341234529099823109834440981234+abcdef3214eafeabc3242331129857301afesfffff1930a84e4bd2348fe129ac1309bd929dca3419af934bfe3089fcd (Android; 13; sdk_gphone64_arm64)"}, } func TestGetPlatformName(t *testing.T) { @@ -55,6 +56,7 @@ func TestGetPlatformName(t *testing.T) { "Macintosh", "iPhone", "Linux", + "Linux", } for i, userAgent := range testUserAgents { @@ -86,6 +88,7 @@ func TestGetOSName(t *testing.T) { "Mac OS", "iOS", "Android", + "Android", } for i, userAgent := range testUserAgents { @@ -117,6 +120,7 @@ func TestGetBrowserName(t *testing.T) { "Safari", "Safari", "Mobile App", + "Mobile App", } for i, userAgent := range testUserAgents { @@ -148,6 +152,7 @@ func TestGetBrowserVersion(t *testing.T) { "8.0.7", "9.0", "2.7.0+482", + "233.234441.341234223421341234529099823109834440981234+abcdef3214eafeabc3242331129857301afesfffff1930a84e4bd2348fe129ac1309bd929d", // cut off at len 128 } for i, userAgent := range testUserAgents {
server/channels/store/opentracinglayer/opentracinglayer.go+18 −0 modified@@ -8525,6 +8525,24 @@ func (s *OpenTracingLayerSessionStore) Get(c request.CTX, sessionIDOrToken strin return result, err } +func (s *OpenTracingLayerSessionStore) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, error) { + origCtx := s.Root.Store.Context() + span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SessionStore.GetLRUSessions") + s.Root.Store.SetContext(newCtx) + defer func() { + s.Root.Store.SetContext(origCtx) + }() + + defer span.Finish() + result, err := s.SessionStore.GetLRUSessions(c, userID, limit, offset) + if err != nil { + span.LogFields(spanlog.Error(err)) + ext.Error.Set(span, true) + } + + return result, err +} + func (s *OpenTracingLayerSessionStore) GetSessions(c request.CTX, userID string) ([]*model.Session, error) { origCtx := s.Root.Store.Context() span, newCtx := tracing.StartSpanWithParentByContext(s.Root.Store.Context(), "SessionStore.GetSessions")
server/channels/store/retrylayer/retrylayer.go+21 −0 modified@@ -9713,6 +9713,27 @@ func (s *RetryLayerSessionStore) Get(c request.CTX, sessionIDOrToken string) (*m } +func (s *RetryLayerSessionStore) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, error) { + + tries := 0 + for { + result, err := s.SessionStore.GetLRUSessions(c, userID, limit, offset) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + func (s *RetryLayerSessionStore) GetSessions(c request.CTX, userID string) ([]*model.Session, error) { tries := 0
server/channels/store/sqlstore/session_store.go+22 −0 modified@@ -123,6 +123,28 @@ func (me SqlSessionStore) GetSessions(c request.CTX, userId string) ([]*model.Se return sessions, nil } +// GetLRUSessions gets the Least Recently Used sessions from the store. Note: the use of limit and offset +// are intentional; they are hardcoded from the app layer (i.e., will not result in a non-performant query). +func (me SqlSessionStore) GetLRUSessions(c request.CTX, userId string, limit uint64, offset uint64) ([]*model.Session, error) { + builder := me.getQueryBuilder(). + Select("*"). + From("Sessions"). + Where(sq.Eq{"UserId": userId}). + OrderBy("LastActivityAt DESC"). + Limit(limit). + Offset(offset) + query, args, err := builder.ToSql() + if err != nil { + return nil, errors.Wrap(err, "get_lru_sessions_tosql") + } + + var sessions []*model.Session + if err := me.GetReplicaX().Select(&sessions, query, args...); err != nil { + return nil, errors.Wrapf(err, "failed to find Sessions with userId=%s", userId) + } + return sessions, nil +} + func (me SqlSessionStore) GetSessionsWithActiveDeviceIds(userId string) ([]*model.Session, error) { query := `SELECT *
server/channels/store/store.go+1 −0 modified@@ -498,6 +498,7 @@ type SessionStore interface { Get(c request.CTX, sessionIDOrToken string) (*model.Session, error) Save(c request.CTX, session *model.Session) (*model.Session, error) GetSessions(c request.CTX, userID string) ([]*model.Session, error) + GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, error) GetSessionsWithActiveDeviceIds(userID string) ([]*model.Session, error) GetSessionsExpired(thresholdMillis int64, mobileOnly bool, unnotifiedOnly bool) ([]*model.Session, error) UpdateExpiredNotify(sessionid string, notified bool) error
server/channels/store/storetest/mocks/SessionStore.go+26 −0 modified@@ -79,6 +79,32 @@ func (_m *SessionStore) Get(c request.CTX, sessionIDOrToken string) (*model.Sess return r0, r1 } +// GetLRUSessions provides a mock function with given fields: c, userID, limit, offset +func (_m *SessionStore) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, error) { + ret := _m.Called(c, userID, limit, offset) + + var r0 []*model.Session + var r1 error + if rf, ok := ret.Get(0).(func(request.CTX, string, uint64, uint64) ([]*model.Session, error)); ok { + return rf(c, userID, limit, offset) + } + if rf, ok := ret.Get(0).(func(request.CTX, string, uint64, uint64) []*model.Session); ok { + r0 = rf(c, userID, limit, offset) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*model.Session) + } + } + + if rf, ok := ret.Get(1).(func(request.CTX, string, uint64, uint64) error); ok { + r1 = rf(c, userID, limit, offset) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetSessions provides a mock function with given fields: c, userID func (_m *SessionStore) GetSessions(c request.CTX, userID string) ([]*model.Session, error) { ret := _m.Called(c, userID)
server/channels/store/storetest/session_store.go+52 −0 modified@@ -5,6 +5,7 @@ package storetest import ( "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -36,6 +37,7 @@ func TestSessionStore(t *testing.T, rctx request.CTX, ss store.Store) { t.Run("SessionCount", func(t *testing.T) { testSessionCount(t, rctx, ss) }) t.Run("GetSessionsExpired", func(t *testing.T) { testGetSessionsExpired(t, rctx, ss) }) t.Run("UpdateExpiredNotify", func(t *testing.T) { testUpdateExpiredNotify(t, rctx, ss) }) + t.Run("GetLRUSessions", func(t *testing.T) { testGetLRUSessions(t, rctx, ss) }) } func testSessionStoreSave(t *testing.T, rctx request.CTX, ss store.Store) { @@ -404,3 +406,53 @@ func testUpdateExpiredNotify(t *testing.T, rctx request.CTX, ss store.Store) { require.NoError(t, err) require.False(t, session.ExpiredNotify) } + +func testGetLRUSessions(t *testing.T, rctx request.CTX, ss store.Store) { + userId := model.NewId() + + // Clear existing sessions. 
+ err := ss.Session().RemoveAllSessions() + require.NoError(t, err) + + s1 := &model.Session{} + s1.UserId = userId + s1.DeviceId = model.NewId() + _, err = ss.Session().Save(rctx, s1) + require.NoError(t, err) + time.Sleep(1 * time.Millisecond) + + s2 := &model.Session{} + s2.UserId = userId + s2.DeviceId = model.NewId() + s2, err = ss.Session().Save(rctx, s2) + require.NoError(t, err) + time.Sleep(1 * time.Millisecond) + + s3 := &model.Session{} + s3.UserId = userId + s3.DeviceId = model.NewId() + s3, err = ss.Session().Save(rctx, s3) + require.NoError(t, err) + + sessions, err := ss.Session().GetLRUSessions(rctx, userId, 3, 3) + require.NoError(t, err) + require.Len(t, sessions, 0) + + sessions, err = ss.Session().GetLRUSessions(rctx, userId, 3, 2) + require.NoError(t, err) + require.Len(t, sessions, 1) + require.Equal(t, s1.Id, sessions[0].Id) + + sessions, err = ss.Session().GetLRUSessions(rctx, userId, 3, 1) + require.NoError(t, err) + require.Len(t, sessions, 2) + require.Equal(t, s2.Id, sessions[0].Id) + require.Equal(t, s1.Id, sessions[1].Id) + + sessions, err = ss.Session().GetLRUSessions(rctx, userId, 3, 0) + require.NoError(t, err) + require.Len(t, sessions, 3) + require.Equal(t, s3.Id, sessions[0].Id) + require.Equal(t, s2.Id, sessions[1].Id) + require.Equal(t, s1.Id, sessions[2].Id) +}
server/channels/store/timerlayer/timerlayer.go+16 −0 modified@@ -7689,6 +7689,22 @@ func (s *TimerLayerSessionStore) Get(c request.CTX, sessionIDOrToken string) (*m return result, err } +func (s *TimerLayerSessionStore) GetLRUSessions(c request.CTX, userID string, limit uint64, offset uint64) ([]*model.Session, error) { + start := time.Now() + + result, err := s.SessionStore.GetLRUSessions(c, userID, limit, offset) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("SessionStore.GetLRUSessions", success, elapsed) + } + return result, err +} + func (s *TimerLayerSessionStore) GetSessions(c request.CTX, userID string) ([]*model.Session, error) { start := time.Now()
server/i18n/en.json+4 −0 modified@@ -6590,6 +6590,10 @@ "id": "app.session.get.app_error", "translation": "We encountered an error finding the session." }, + { + "id": "app.session.get_lru_sessions.app_error", + "translation": "Unable to get least recently used sessions." + }, { "id": "app.session.get_sessions.app_error", "translation": "We encountered an error while finding user sessions."
Vulnerability mechanics
Generated by null/stub on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.
References (7)
- github.com/advisories/GHSA-wj37-mpq9-xrcm (source: ghsa, type: ADVISORY)
- nvd.nist.gov/vuln/detail/CVE-2024-4183 (source: ghsa, type: ADVISORY)
- github.com/mattermost/mattermost/commit/86920d641760552c5aafa5e1d14c93bd30039bc4 (source: ghsa, type: WEB)
- github.com/mattermost/mattermost/commit/9d81eee979aee93374bff8ba6714d805e12ffb03 (source: ghsa, type: WEB)
- github.com/mattermost/mattermost/commit/b45c3dac4c160992a1ce757ade968e8f5ec506c1 (source: ghsa, type: WEB)
- github.com/mattermost/mattermost/commit/bc699e6789cf3ba1544235087897699aaa639e7d (source: ghsa, type: WEB)
- mattermost.com/security-updates (source: ghsa, type: WEB)
News mentions (0)
No linked articles in our index yet.