Account Takeover via Code Exchange Endpoint
Description
Mattermost versions 11.0.x <= 11.0.2, 10.12.x <= 10.12.1, 10.11.x <= 10.11.4, and 10.5.x <= 10.5.12 fail to verify that the token used during the code exchange originates from the same authentication flow. This allows an authenticated user to perform an account takeover by using a specially crafted email address when switching authentication methods and then sending a request to the `/users/login/sso/code-exchange` endpoint. The vulnerability requires `ExperimentalEnableAuthenticationTransfer` to be enabled (default: enabled) and `RequireEmailVerification` to be disabled (default: disabled).
Affected packages
Versions sourced from the GitHub Security Advisory.
| Package | Affected versions | Patched versions |
|---|---|---|
| github.com/mattermost/mattermost/server/v8 (Go) | < 8.0.0-20251022210333-acda1fb5dd46 | 8.0.0-20251022210333-acda1fb5dd46 |
| github.com/mattermost/mattermost-server (Go) | >= 11.0.0, < 11.0.3 | 11.0.3 |
| github.com/mattermost/mattermost-server (Go) | >= 10.12.0, < 10.12.2 | 10.12.2 |
| github.com/mattermost/mattermost-server (Go) | >= 10.11.0, < 10.11.5 | 10.11.5 |
| github.com/mattermost/mattermost-server (Go) | >= 10.5.0, < 10.5.13 | 10.5.13 |
Affected products
1- Range: 11.0.0
Patches
4f361e7d75a7a — Automated cherry pick of #34247 (#34257)
14 files changed · +352 −32
server/channels/api4/user.go+1 −1 modified@@ -130,7 +130,7 @@ func loginSSOCodeExchange(c *Context, w http.ResponseWriter, r *http.Request) { } // Consume one-time code atomically - token, appErr := c.App.ConsumeTokenOnce(loginCode) + token, appErr := c.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, loginCode) if appErr != nil { c.Err = appErr return
server/channels/api4/user_test.go+81 −0 modified@@ -6,6 +6,8 @@ package api4 import ( "bytes" "context" + "crypto/sha256" + "encoding/base64" "encoding/json" "fmt" "image/png" @@ -8673,6 +8675,85 @@ func TestLoginWithDesktopToken(t *testing.T) { }) } +func TestLoginSSOCodeExchange(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic() + defer th.TearDown() + + t.Run("wrong token type cannot be used for code exchange", func(t *testing.T) { + th.App.UpdateConfig(func(cfg *model.Config) { + cfg.FeatureFlags.MobileSSOCodeExchange = true + }) + + token := model.NewToken(model.TokenTypeOAuth, "extra-data") + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + defer func() { + _ = th.App.Srv().Store().Token().Delete(token.Token) + }() + + props := map[string]string{ + "login_code": token.Token, + "code_verifier": "test_verifier", + "state": "test_state", + } + + resp, err := th.Client.DoAPIPost(context.Background(), "/users/login/sso/code-exchange", model.MapToJSON(props)) + require.Error(t, err) + require.Equal(t, http.StatusNotFound, resp.StatusCode) + }) + + t.Run("successful code exchange with S256 challenge", func(t *testing.T) { + th.App.UpdateConfig(func(cfg *model.Config) { + cfg.FeatureFlags.MobileSSOCodeExchange = true + }) + + samlUser := th.CreateUserWithAuth(model.UserAuthServiceSaml) + + codeVerifier := "test_code_verifier_123456789" + state := "test_state_value" + + sum := sha256.Sum256([]byte(codeVerifier)) + codeChallenge := base64.RawURLEncoding.EncodeToString(sum[:]) + + extra := map[string]string{ + "user_id": samlUser.Id, + "code_challenge": codeChallenge, + "code_challenge_method": "S256", + "state": state, + } + + token := model.NewToken(model.TokenTypeSSOCodeExchange, model.MapToJSON(extra)) + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + + props := map[string]string{ + "login_code": token.Token, + "code_verifier": codeVerifier, + "state": state, + } + + resp, err := 
th.Client.DoAPIPost(context.Background(), "/users/login/sso/code-exchange", model.MapToJSON(props)) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + var result map[string]string + require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) + assert.NotEmpty(t, result["token"]) + assert.NotEmpty(t, result["csrf"]) + + _, err = th.App.Srv().Store().Token().GetByToken(token.Token) + require.Error(t, err) + + authenticatedClient := model.NewAPIv4Client(th.Client.URL) + authenticatedClient.SetToken(result["token"]) + + user, _, err := authenticatedClient.GetMe(context.Background(), "") + require.NoError(t, err) + assert.Equal(t, samlUser.Id, user.Id) + assert.Equal(t, samlUser.Email, user.Email) + assert.Equal(t, samlUser.Username, user.Username) + }) +} + func TestGetUsersByNames(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic()
server/channels/app/oauth.go+1 −1 modified@@ -977,7 +977,7 @@ func (a *App) SwitchEmailToOAuth(c request.CTX, w http.ResponseWriter, r *http.R stateProps["email"] = email if service == model.UserAuthServiceSaml { - samlToken, samlErr := a.CreateSamlRelayToken(email) + samlToken, samlErr := a.CreateSamlRelayToken(model.TokenTypeSaml, email) if samlErr != nil { return "", samlErr }
server/channels/app/saml.go+2 −2 modified@@ -298,8 +298,8 @@ func (a *App) ResetSamlAuthDataToEmail(includeDeleted bool, dryRun bool, userIDs return } -func (a *App) CreateSamlRelayToken(extra string) (*model.Token, *model.AppError) { - token := model.NewToken(model.TokenTypeSaml, extra) +func (a *App) CreateSamlRelayToken(tokenType string, extra string) (*model.Token, *model.AppError) { + token := model.NewToken(tokenType, extra) if err := a.Srv().Store().Token().Save(token); err != nil { var appErr *model.AppError
server/channels/app/user.go+2 −2 modified@@ -1750,8 +1750,8 @@ func (a *App) GetTokenById(token string) (*model.Token, *model.AppError) { return rtoken, nil } -func (a *App) ConsumeTokenOnce(tokenStr string) (*model.Token, *model.AppError) { - token, err := a.Srv().Store().Token().ConsumeOnce(tokenStr) +func (a *App) ConsumeTokenOnce(tokenType, tokenStr string) (*model.Token, *model.AppError) { + token, err := a.Srv().Store().Token().ConsumeOnce(tokenType, tokenStr) if err != nil { var status int switch err.(type) {
server/channels/app/user_test.go+82 −0 modified@@ -9,6 +9,7 @@ import ( "database/sql" "encoding/json" "errors" + "net/http" "os" "path/filepath" "strings" @@ -2484,3 +2485,84 @@ func TestRemoteUserDirectChannelCreation(t *testing.T) { assert.Equal(t, model.ChannelTypeDirect, channel.Type) }) } + +func TestConsumeTokenOnce(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic() + defer th.TearDown() + + t.Run("successfully consume valid token", func(t *testing.T) { + token := model.NewToken(model.TokenTypeOAuth, "extra-data") + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + + consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, token.Token) + require.Nil(t, appErr) + require.NotNil(t, consumedToken) + assert.Equal(t, token.Token, consumedToken.Token) + assert.Equal(t, model.TokenTypeOAuth, consumedToken.Type) + assert.Equal(t, "extra-data", consumedToken.Extra) + + _, err := th.App.Srv().Store().Token().GetByToken(token.Token) + require.Error(t, err) + }) + + t.Run("token not found returns 404", func(t *testing.T) { + nonExistentToken := model.NewRandomString(model.TokenSize) + + consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, nonExistentToken) + require.NotNil(t, appErr) + require.Nil(t, consumedToken) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + assert.Equal(t, "ConsumeTokenOnce", appErr.Where) + }) + + t.Run("wrong token type returns not found", func(t *testing.T) { + token := model.NewToken(model.TokenTypeOAuth, "extra-data") + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + defer func() { + _ = th.App.Srv().Store().Token().Delete(token.Token) + }() + + consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSaml, token.Token) + require.NotNil(t, appErr) + require.Nil(t, consumedToken) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + + _, err := th.App.Srv().Store().Token().GetByToken(token.Token) + require.NoError(t, err) + }) + + 
t.Run("token can only be consumed once", func(t *testing.T) { + token := model.NewToken(model.TokenTypeSSOCodeExchange, "extra-data") + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + + consumedToken1, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, token.Token) + require.Nil(t, appErr) + require.NotNil(t, consumedToken1) + + consumedToken2, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, token.Token) + require.NotNil(t, appErr) + require.Nil(t, consumedToken2) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + }) + + t.Run("empty token string returns not found", func(t *testing.T) { + consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, "") + require.NotNil(t, appErr) + require.Nil(t, consumedToken) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + }) + + t.Run("empty token type returns not found", func(t *testing.T) { + token := model.NewToken(model.TokenTypeOAuth, "extra-data") + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + defer func() { + _ = th.App.Srv().Store().Token().Delete(token.Token) + }() + + consumedToken, appErr := th.App.ConsumeTokenOnce("", token.Token) + require.NotNil(t, appErr) + require.Nil(t, consumedToken) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + }) +}
server/channels/store/retrylayer/retrylayer.go+2 −2 modified@@ -14138,11 +14138,11 @@ func (s *RetryLayerTokenStore) Cleanup(expiryTime int64) { } -func (s *RetryLayerTokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) { +func (s *RetryLayerTokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) { tries := 0 for { - result, err := s.TokenStore.ConsumeOnce(tokenStr) + result, err := s.TokenStore.ConsumeOnce(tokenType, tokenStr) if err == nil { return result, nil }
server/channels/store/sqlstore/tokens_store.go+29 −4 modified@@ -78,16 +78,41 @@ func (s SqlTokenStore) GetByToken(tokenString string) (*model.Token, error) { return &token, nil } -func (s SqlTokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) { +func (s SqlTokenStore) ConsumeOnce(tokenType, tokenStr string) (*model.Token, error) { var token model.Token - query := `DELETE FROM Tokens WHERE Token = ? RETURNING *` + if s.DriverName() == model.DatabaseDriverPostgres { + query := `DELETE FROM Tokens WHERE Type = ? AND Token = ? RETURNING *` + if err := s.GetMaster().Get(&token, query, tokenType, tokenStr); err != nil { + if err == sql.ErrNoRows { + return nil, store.NewErrNotFound("Token", tokenStr) + } + return nil, errors.Wrapf(err, "failed to consume token with type %s", tokenType) + } + return &token, nil + } + + transaction, err := s.GetMaster().Beginx() + if err != nil { + return nil, errors.Wrap(err, "failed to begin transaction") + } + defer finalizeTransactionX(transaction, &err) - if err := s.GetMaster().Get(&token, query, tokenStr); err != nil { + query := `SELECT * FROM Tokens WHERE Type = ? AND Token = ? FOR UPDATE` + if err = transaction.Get(&token, query, tokenType, tokenStr); err != nil { if err == sql.ErrNoRows { return nil, store.NewErrNotFound("Token", tokenStr) } - return nil, errors.Wrapf(err, "failed to consume token") + return nil, errors.Wrapf(err, "failed to select token with type %s", tokenType) + } + + deleteQuery := `DELETE FROM Tokens WHERE Type = ? AND Token = ?` + if _, err = transaction.Exec(deleteQuery, tokenType, tokenStr); err != nil { + return nil, errors.Wrapf(err, "failed to delete token with type %s", tokenType) + } + + if err = transaction.Commit(); err != nil { + return nil, errors.Wrap(err, "failed to commit transaction") } return &token, nil
server/channels/store/store.go+1 −1 modified@@ -693,7 +693,7 @@ type TokenStore interface { Save(recovery *model.Token) error Delete(token string) error GetByToken(token string) (*model.Token, error) - ConsumeOnce(tokenStr string) (*model.Token, error) + ConsumeOnce(tokenType, tokenStr string) (*model.Token, error) Cleanup(expiryTime int64) GetAllTokensByType(tokenType string) ([]*model.Token, error) RemoveAllTokensByType(tokenType string) error
server/channels/store/storetest/mocks/TokenStore.go+9 −9 modified@@ -19,29 +19,29 @@ func (_m *TokenStore) Cleanup(expiryTime int64) { _m.Called(expiryTime) } -// ConsumeOnce provides a mock function with given fields: tokenStr -func (_m *TokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) { - ret := _m.Called(tokenStr) +// ConsumeOnce provides a mock function with given fields: tokenType, tokenStr +func (_m *TokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) { + ret := _m.Called(tokenType, tokenStr) if len(ret) == 0 { panic("no return value specified for ConsumeOnce") } var r0 *model.Token var r1 error - if rf, ok := ret.Get(0).(func(string) (*model.Token, error)); ok { - return rf(tokenStr) + if rf, ok := ret.Get(0).(func(string, string) (*model.Token, error)); ok { + return rf(tokenType, tokenStr) } - if rf, ok := ret.Get(0).(func(string) *model.Token); ok { - r0 = rf(tokenStr) + if rf, ok := ret.Get(0).(func(string, string) *model.Token); ok { + r0 = rf(tokenType, tokenStr) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*model.Token) } } - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(tokenStr) + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(tokenType, tokenStr) } else { r1 = ret.Error(1) }
server/channels/store/storetest/tokens_store.go+128 −0 modified@@ -16,6 +16,7 @@ import ( func TestTokensStore(t *testing.T, rctx request.CTX, ss store.Store) { t.Run("TokensCleanup", func(t *testing.T) { testTokensCleanup(t, rctx, ss) }) + t.Run("ConsumeOnce", func(t *testing.T) { testConsumeOnce(t, rctx, ss) }) } func testTokensCleanup(t *testing.T, rctx request.CTX, ss store.Store) { @@ -41,3 +42,130 @@ func testTokensCleanup(t *testing.T, rctx request.CTX, ss store.Store) { require.NoError(t, err) assert.Len(t, tokens, 0) } + +func testConsumeOnce(t *testing.T, rctx request.CTX, ss store.Store) { + t.Run("successfully consume token once", func(t *testing.T) { + token := &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: model.GetMillis(), + Type: model.TokenTypeOAuth, + Extra: "test-extra", + } + err := ss.Token().Save(token) + require.NoError(t, err) + + consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token) + require.NoError(t, err) + assert.Equal(t, token.Token, consumedToken.Token) + assert.Equal(t, token.Type, consumedToken.Type) + assert.Equal(t, token.Extra, consumedToken.Extra) + + tokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth) + require.NoError(t, err) + assert.Len(t, tokens, 0) + }) + + t.Run("second consumption of same token fails", func(t *testing.T) { + token := &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: model.GetMillis(), + Type: model.TokenTypeOAuth, + Extra: "test-extra", + } + err := ss.Token().Save(token) + require.NoError(t, err) + + _, err = ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token) + require.NoError(t, err) + + _, err = ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token) + require.Error(t, err) + var nfErr *store.ErrNotFound + assert.ErrorAs(t, err, &nfErr) + }) + + t.Run("consume with wrong type fails", func(t *testing.T) { + token := &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: 
model.GetMillis(), + Type: model.TokenTypeOAuth, + Extra: "test-extra", + } + err := ss.Token().Save(token) + require.NoError(t, err) + + _, err = ss.Token().ConsumeOnce(model.TokenTypeSSOCodeExchange, token.Token) + require.Error(t, err) + var nfErr *store.ErrNotFound + assert.ErrorAs(t, err, &nfErr) + + tokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth) + require.NoError(t, err) + assert.Len(t, tokens, 1) + + err = ss.Token().Delete(token.Token) + require.NoError(t, err) + }) + + t.Run("consume non-existent token fails", func(t *testing.T) { + nonExistentToken := model.NewRandomString(model.TokenSize) + _, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, nonExistentToken) + require.Error(t, err) + var nfErr *store.ErrNotFound + assert.ErrorAs(t, err, &nfErr) + }) + + t.Run("multiple tokens with same type can each be consumed once", func(t *testing.T) { + tokens := make([]*model.Token, 3) + for i := range tokens { + tokens[i] = &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: model.GetMillis(), + Type: model.TokenTypeOAuth, + Extra: "test-extra", + } + err := ss.Token().Save(tokens[i]) + require.NoError(t, err) + } + + for _, token := range tokens { + consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token) + require.NoError(t, err) + assert.Equal(t, token.Token, consumedToken.Token) + } + + allTokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth) + require.NoError(t, err) + assert.Len(t, allTokens, 0) + }) + + t.Run("consuming token of different type leaves others intact", func(t *testing.T) { + oauthToken := &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: model.GetMillis(), + Type: model.TokenTypeOAuth, + Extra: "oauth-extra", + } + codeExchangeToken := &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: model.GetMillis(), + Type: model.TokenTypeSSOCodeExchange, + Extra: "password-extra", + } + err := ss.Token().Save(oauthToken) + 
require.NoError(t, err) + err = ss.Token().Save(codeExchangeToken) + require.NoError(t, err) + + consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, oauthToken.Token) + require.NoError(t, err) + assert.Equal(t, oauthToken.Token, consumedToken.Token) + + codeExchangeTokens, err := ss.Token().GetAllTokensByType(model.TokenTypeSSOCodeExchange) + require.NoError(t, err) + assert.Len(t, codeExchangeTokens, 1) + + err = ss.Token().Delete(codeExchangeToken.Token) + require.NoError(t, err) + }) +}
server/channels/store/timerlayer/timerlayer.go+2 −2 modified@@ -11106,10 +11106,10 @@ func (s *TimerLayerTokenStore) Cleanup(expiryTime int64) { } } -func (s *TimerLayerTokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) { +func (s *TimerLayerTokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) { start := time.Now() - result, err := s.TokenStore.ConsumeOnce(tokenStr) + result, err := s.TokenStore.ConsumeOnce(tokenType, tokenStr) elapsed := float64(time.Since(start)) / float64(time.Second) if s.Root.Metrics != nil {
server/channels/web/saml.go+7 −4 modified@@ -106,7 +106,7 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) { return } - //Validate that the user is with SAML and all that + // Validate that the user is with SAML and all that encodedXML := r.FormValue("SAMLResponse") relayState := r.FormValue("RelayState") @@ -161,7 +161,8 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) { return } - if err = c.App.CheckUserAllAuthenticationCriteria(c.AppContext, user, ""); err != nil { + err = c.App.CheckUserAllAuthenticationCriteria(c.AppContext, user, "") + if err != nil { handleError(err) return } @@ -250,8 +251,10 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) { "code_challenge": samlChallenge, "code_challenge_method": samlMethod, }) - code := model.NewToken(model.TokenTypeSaml, extra) - if err := c.App.Srv().Store().Token().Save(code); err != nil { + + var code *model.Token + code, err = c.App.CreateSamlRelayToken(model.TokenTypeSSOCodeExchange, extra) + if err != nil { handleError(model.NewAppError("completeSaml", "app.recover.save.app_error", nil, "", http.StatusInternalServerError).Wrap(err)) return }
server/public/model/token.go+5 −4 modified@@ -8,10 +8,11 @@ import ( ) const ( - TokenSize = 64 - MaxTokenExipryTime = 1000 * 60 * 60 * 48 // 48 hour - TokenTypeOAuth = "oauth" - TokenTypeSaml = "saml" + TokenSize = 64 + MaxTokenExipryTime = 1000 * 60 * 60 * 48 // 48 hour + TokenTypeOAuth = "oauth" + TokenTypeSaml = "saml" + TokenTypeSSOCodeExchange = "sso-code-exchange" ) type Token struct {
5072bbf689a4 — Automated cherry pick of #34247 (#34256)
14 files changed · +352 −32
server/channels/api4/user.go+1 −1 modified@@ -130,7 +130,7 @@ func loginSSOCodeExchange(c *Context, w http.ResponseWriter, r *http.Request) { } // Consume one-time code atomically - token, appErr := c.App.ConsumeTokenOnce(loginCode) + token, appErr := c.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, loginCode) if appErr != nil { c.Err = appErr return
server/channels/api4/user_test.go+81 −0 modified@@ -6,6 +6,8 @@ package api4 import ( "bytes" "context" + "crypto/sha256" + "encoding/base64" "encoding/json" "fmt" "image/png" @@ -8673,6 +8675,85 @@ func TestLoginWithDesktopToken(t *testing.T) { }) } +func TestLoginSSOCodeExchange(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic() + defer th.TearDown() + + t.Run("wrong token type cannot be used for code exchange", func(t *testing.T) { + th.App.UpdateConfig(func(cfg *model.Config) { + cfg.FeatureFlags.MobileSSOCodeExchange = true + }) + + token := model.NewToken(model.TokenTypeOAuth, "extra-data") + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + defer func() { + _ = th.App.Srv().Store().Token().Delete(token.Token) + }() + + props := map[string]string{ + "login_code": token.Token, + "code_verifier": "test_verifier", + "state": "test_state", + } + + resp, err := th.Client.DoAPIPost(context.Background(), "/users/login/sso/code-exchange", model.MapToJSON(props)) + require.Error(t, err) + require.Equal(t, http.StatusNotFound, resp.StatusCode) + }) + + t.Run("successful code exchange with S256 challenge", func(t *testing.T) { + th.App.UpdateConfig(func(cfg *model.Config) { + cfg.FeatureFlags.MobileSSOCodeExchange = true + }) + + samlUser := th.CreateUserWithAuth(model.UserAuthServiceSaml) + + codeVerifier := "test_code_verifier_123456789" + state := "test_state_value" + + sum := sha256.Sum256([]byte(codeVerifier)) + codeChallenge := base64.RawURLEncoding.EncodeToString(sum[:]) + + extra := map[string]string{ + "user_id": samlUser.Id, + "code_challenge": codeChallenge, + "code_challenge_method": "S256", + "state": state, + } + + token := model.NewToken(model.TokenTypeSSOCodeExchange, model.MapToJSON(extra)) + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + + props := map[string]string{ + "login_code": token.Token, + "code_verifier": codeVerifier, + "state": state, + } + + resp, err := 
th.Client.DoAPIPost(context.Background(), "/users/login/sso/code-exchange", model.MapToJSON(props)) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + var result map[string]string + require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) + assert.NotEmpty(t, result["token"]) + assert.NotEmpty(t, result["csrf"]) + + _, err = th.App.Srv().Store().Token().GetByToken(token.Token) + require.Error(t, err) + + authenticatedClient := model.NewAPIv4Client(th.Client.URL) + authenticatedClient.SetToken(result["token"]) + + user, _, err := authenticatedClient.GetMe(context.Background(), "") + require.NoError(t, err) + assert.Equal(t, samlUser.Id, user.Id) + assert.Equal(t, samlUser.Email, user.Email) + assert.Equal(t, samlUser.Username, user.Username) + }) +} + func TestGetUsersByNames(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic()
server/channels/app/oauth.go+1 −1 modified@@ -977,7 +977,7 @@ func (a *App) SwitchEmailToOAuth(c request.CTX, w http.ResponseWriter, r *http.R stateProps["email"] = email if service == model.UserAuthServiceSaml { - samlToken, samlErr := a.CreateSamlRelayToken(email) + samlToken, samlErr := a.CreateSamlRelayToken(model.TokenTypeSaml, email) if samlErr != nil { return "", samlErr }
server/channels/app/saml.go+2 −2 modified@@ -298,8 +298,8 @@ func (a *App) ResetSamlAuthDataToEmail(includeDeleted bool, dryRun bool, userIDs return } -func (a *App) CreateSamlRelayToken(extra string) (*model.Token, *model.AppError) { - token := model.NewToken(model.TokenTypeSaml, extra) +func (a *App) CreateSamlRelayToken(tokenType string, extra string) (*model.Token, *model.AppError) { + token := model.NewToken(tokenType, extra) if err := a.Srv().Store().Token().Save(token); err != nil { var appErr *model.AppError
server/channels/app/user.go+2 −2 modified@@ -1750,8 +1750,8 @@ func (a *App) GetTokenById(token string) (*model.Token, *model.AppError) { return rtoken, nil } -func (a *App) ConsumeTokenOnce(tokenStr string) (*model.Token, *model.AppError) { - token, err := a.Srv().Store().Token().ConsumeOnce(tokenStr) +func (a *App) ConsumeTokenOnce(tokenType, tokenStr string) (*model.Token, *model.AppError) { + token, err := a.Srv().Store().Token().ConsumeOnce(tokenType, tokenStr) if err != nil { var status int switch err.(type) {
server/channels/app/user_test.go+82 −0 modified@@ -9,6 +9,7 @@ import ( "database/sql" "encoding/json" "errors" + "net/http" "os" "path/filepath" "strings" @@ -2484,3 +2485,84 @@ func TestRemoteUserDirectChannelCreation(t *testing.T) { assert.Equal(t, model.ChannelTypeDirect, channel.Type) }) } + +func TestConsumeTokenOnce(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic() + defer th.TearDown() + + t.Run("successfully consume valid token", func(t *testing.T) { + token := model.NewToken(model.TokenTypeOAuth, "extra-data") + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + + consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, token.Token) + require.Nil(t, appErr) + require.NotNil(t, consumedToken) + assert.Equal(t, token.Token, consumedToken.Token) + assert.Equal(t, model.TokenTypeOAuth, consumedToken.Type) + assert.Equal(t, "extra-data", consumedToken.Extra) + + _, err := th.App.Srv().Store().Token().GetByToken(token.Token) + require.Error(t, err) + }) + + t.Run("token not found returns 404", func(t *testing.T) { + nonExistentToken := model.NewRandomString(model.TokenSize) + + consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, nonExistentToken) + require.NotNil(t, appErr) + require.Nil(t, consumedToken) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + assert.Equal(t, "ConsumeTokenOnce", appErr.Where) + }) + + t.Run("wrong token type returns not found", func(t *testing.T) { + token := model.NewToken(model.TokenTypeOAuth, "extra-data") + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + defer func() { + _ = th.App.Srv().Store().Token().Delete(token.Token) + }() + + consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSaml, token.Token) + require.NotNil(t, appErr) + require.Nil(t, consumedToken) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + + _, err := th.App.Srv().Store().Token().GetByToken(token.Token) + require.NoError(t, err) + }) + + 
t.Run("token can only be consumed once", func(t *testing.T) { + token := model.NewToken(model.TokenTypeSSOCodeExchange, "extra-data") + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + + consumedToken1, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, token.Token) + require.Nil(t, appErr) + require.NotNil(t, consumedToken1) + + consumedToken2, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, token.Token) + require.NotNil(t, appErr) + require.Nil(t, consumedToken2) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + }) + + t.Run("empty token string returns not found", func(t *testing.T) { + consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, "") + require.NotNil(t, appErr) + require.Nil(t, consumedToken) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + }) + + t.Run("empty token type returns not found", func(t *testing.T) { + token := model.NewToken(model.TokenTypeOAuth, "extra-data") + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + defer func() { + _ = th.App.Srv().Store().Token().Delete(token.Token) + }() + + consumedToken, appErr := th.App.ConsumeTokenOnce("", token.Token) + require.NotNil(t, appErr) + require.Nil(t, consumedToken) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + }) +}
server/channels/store/retrylayer/retrylayer.go+2 −2 modified@@ -14138,11 +14138,11 @@ func (s *RetryLayerTokenStore) Cleanup(expiryTime int64) { } -func (s *RetryLayerTokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) { +func (s *RetryLayerTokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) { tries := 0 for { - result, err := s.TokenStore.ConsumeOnce(tokenStr) + result, err := s.TokenStore.ConsumeOnce(tokenType, tokenStr) if err == nil { return result, nil }
server/channels/store/sqlstore/tokens_store.go+29 −4 modified@@ -78,16 +78,41 @@ func (s SqlTokenStore) GetByToken(tokenString string) (*model.Token, error) { return &token, nil } -func (s SqlTokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) { +func (s SqlTokenStore) ConsumeOnce(tokenType, tokenStr string) (*model.Token, error) { var token model.Token - query := `DELETE FROM Tokens WHERE Token = ? RETURNING *` + if s.DriverName() == model.DatabaseDriverPostgres { + query := `DELETE FROM Tokens WHERE Type = ? AND Token = ? RETURNING *` + if err := s.GetMaster().Get(&token, query, tokenType, tokenStr); err != nil { + if err == sql.ErrNoRows { + return nil, store.NewErrNotFound("Token", tokenStr) + } + return nil, errors.Wrapf(err, "failed to consume token with type %s", tokenType) + } + return &token, nil + } + + transaction, err := s.GetMaster().Beginx() + if err != nil { + return nil, errors.Wrap(err, "failed to begin transaction") + } + defer finalizeTransactionX(transaction, &err) - if err := s.GetMaster().Get(&token, query, tokenStr); err != nil { + query := `SELECT * FROM Tokens WHERE Type = ? AND Token = ? FOR UPDATE` + if err = transaction.Get(&token, query, tokenType, tokenStr); err != nil { if err == sql.ErrNoRows { return nil, store.NewErrNotFound("Token", tokenStr) } - return nil, errors.Wrapf(err, "failed to consume token") + return nil, errors.Wrapf(err, "failed to select token with type %s", tokenType) + } + + deleteQuery := `DELETE FROM Tokens WHERE Type = ? AND Token = ?` + if _, err = transaction.Exec(deleteQuery, tokenType, tokenStr); err != nil { + return nil, errors.Wrapf(err, "failed to delete token with type %s", tokenType) + } + + if err = transaction.Commit(); err != nil { + return nil, errors.Wrap(err, "failed to commit transaction") } return &token, nil
server/channels/store/store.go+1 −1 modified@@ -693,7 +693,7 @@ type TokenStore interface { Save(recovery *model.Token) error Delete(token string) error GetByToken(token string) (*model.Token, error) - ConsumeOnce(tokenStr string) (*model.Token, error) + ConsumeOnce(tokenType, tokenStr string) (*model.Token, error) Cleanup(expiryTime int64) GetAllTokensByType(tokenType string) ([]*model.Token, error) RemoveAllTokensByType(tokenType string) error
server/channels/store/storetest/mocks/TokenStore.go+9 −9 modified@@ -19,29 +19,29 @@ func (_m *TokenStore) Cleanup(expiryTime int64) { _m.Called(expiryTime) } -// ConsumeOnce provides a mock function with given fields: tokenStr -func (_m *TokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) { - ret := _m.Called(tokenStr) +// ConsumeOnce provides a mock function with given fields: tokenType, tokenStr +func (_m *TokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) { + ret := _m.Called(tokenType, tokenStr) if len(ret) == 0 { panic("no return value specified for ConsumeOnce") } var r0 *model.Token var r1 error - if rf, ok := ret.Get(0).(func(string) (*model.Token, error)); ok { - return rf(tokenStr) + if rf, ok := ret.Get(0).(func(string, string) (*model.Token, error)); ok { + return rf(tokenType, tokenStr) } - if rf, ok := ret.Get(0).(func(string) *model.Token); ok { - r0 = rf(tokenStr) + if rf, ok := ret.Get(0).(func(string, string) *model.Token); ok { + r0 = rf(tokenType, tokenStr) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*model.Token) } } - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(tokenStr) + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(tokenType, tokenStr) } else { r1 = ret.Error(1) }
server/channels/store/storetest/tokens_store.go+128 −0 modified@@ -16,6 +16,7 @@ import ( func TestTokensStore(t *testing.T, rctx request.CTX, ss store.Store) { t.Run("TokensCleanup", func(t *testing.T) { testTokensCleanup(t, rctx, ss) }) + t.Run("ConsumeOnce", func(t *testing.T) { testConsumeOnce(t, rctx, ss) }) } func testTokensCleanup(t *testing.T, rctx request.CTX, ss store.Store) { @@ -41,3 +42,130 @@ func testTokensCleanup(t *testing.T, rctx request.CTX, ss store.Store) { require.NoError(t, err) assert.Len(t, tokens, 0) } + +func testConsumeOnce(t *testing.T, rctx request.CTX, ss store.Store) { + t.Run("successfully consume token once", func(t *testing.T) { + token := &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: model.GetMillis(), + Type: model.TokenTypeOAuth, + Extra: "test-extra", + } + err := ss.Token().Save(token) + require.NoError(t, err) + + consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token) + require.NoError(t, err) + assert.Equal(t, token.Token, consumedToken.Token) + assert.Equal(t, token.Type, consumedToken.Type) + assert.Equal(t, token.Extra, consumedToken.Extra) + + tokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth) + require.NoError(t, err) + assert.Len(t, tokens, 0) + }) + + t.Run("second consumption of same token fails", func(t *testing.T) { + token := &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: model.GetMillis(), + Type: model.TokenTypeOAuth, + Extra: "test-extra", + } + err := ss.Token().Save(token) + require.NoError(t, err) + + _, err = ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token) + require.NoError(t, err) + + _, err = ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token) + require.Error(t, err) + var nfErr *store.ErrNotFound + assert.ErrorAs(t, err, &nfErr) + }) + + t.Run("consume with wrong type fails", func(t *testing.T) { + token := &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: 
model.GetMillis(), + Type: model.TokenTypeOAuth, + Extra: "test-extra", + } + err := ss.Token().Save(token) + require.NoError(t, err) + + _, err = ss.Token().ConsumeOnce(model.TokenTypeSSOCodeExchange, token.Token) + require.Error(t, err) + var nfErr *store.ErrNotFound + assert.ErrorAs(t, err, &nfErr) + + tokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth) + require.NoError(t, err) + assert.Len(t, tokens, 1) + + err = ss.Token().Delete(token.Token) + require.NoError(t, err) + }) + + t.Run("consume non-existent token fails", func(t *testing.T) { + nonExistentToken := model.NewRandomString(model.TokenSize) + _, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, nonExistentToken) + require.Error(t, err) + var nfErr *store.ErrNotFound + assert.ErrorAs(t, err, &nfErr) + }) + + t.Run("multiple tokens with same type can each be consumed once", func(t *testing.T) { + tokens := make([]*model.Token, 3) + for i := range tokens { + tokens[i] = &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: model.GetMillis(), + Type: model.TokenTypeOAuth, + Extra: "test-extra", + } + err := ss.Token().Save(tokens[i]) + require.NoError(t, err) + } + + for _, token := range tokens { + consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token) + require.NoError(t, err) + assert.Equal(t, token.Token, consumedToken.Token) + } + + allTokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth) + require.NoError(t, err) + assert.Len(t, allTokens, 0) + }) + + t.Run("consuming token of different type leaves others intact", func(t *testing.T) { + oauthToken := &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: model.GetMillis(), + Type: model.TokenTypeOAuth, + Extra: "oauth-extra", + } + codeExchangeToken := &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: model.GetMillis(), + Type: model.TokenTypeSSOCodeExchange, + Extra: "password-extra", + } + err := ss.Token().Save(oauthToken) + 
require.NoError(t, err) + err = ss.Token().Save(codeExchangeToken) + require.NoError(t, err) + + consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, oauthToken.Token) + require.NoError(t, err) + assert.Equal(t, oauthToken.Token, consumedToken.Token) + + codeExchangeTokens, err := ss.Token().GetAllTokensByType(model.TokenTypeSSOCodeExchange) + require.NoError(t, err) + assert.Len(t, codeExchangeTokens, 1) + + err = ss.Token().Delete(codeExchangeToken.Token) + require.NoError(t, err) + }) +}
server/channels/store/timerlayer/timerlayer.go+2 −2 modified@@ -11106,10 +11106,10 @@ func (s *TimerLayerTokenStore) Cleanup(expiryTime int64) { } } -func (s *TimerLayerTokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) { +func (s *TimerLayerTokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) { start := time.Now() - result, err := s.TokenStore.ConsumeOnce(tokenStr) + result, err := s.TokenStore.ConsumeOnce(tokenType, tokenStr) elapsed := float64(time.Since(start)) / float64(time.Second) if s.Root.Metrics != nil {
server/channels/web/saml.go+7 −4 modified@@ -106,7 +106,7 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) { return } - //Validate that the user is with SAML and all that + // Validate that the user is with SAML and all that encodedXML := r.FormValue("SAMLResponse") relayState := r.FormValue("RelayState") @@ -161,7 +161,8 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) { return } - if err = c.App.CheckUserAllAuthenticationCriteria(c.AppContext, user, ""); err != nil { + err = c.App.CheckUserAllAuthenticationCriteria(c.AppContext, user, "") + if err != nil { handleError(err) return } @@ -250,8 +251,10 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) { "code_challenge": samlChallenge, "code_challenge_method": samlMethod, }) - code := model.NewToken(model.TokenTypeSaml, extra) - if err := c.App.Srv().Store().Token().Save(code); err != nil { + + var code *model.Token + code, err = c.App.CreateSamlRelayToken(model.TokenTypeSSOCodeExchange, extra) + if err != nil { handleError(model.NewAppError("completeSaml", "app.recover.save.app_error", nil, "", http.StatusInternalServerError).Wrap(err)) return }
server/public/model/token.go+5 −4 modified@@ -8,10 +8,11 @@ import ( ) const ( - TokenSize = 64 - MaxTokenExipryTime = 1000 * 60 * 60 * 48 // 48 hour - TokenTypeOAuth = "oauth" - TokenTypeSaml = "saml" + TokenSize = 64 + MaxTokenExipryTime = 1000 * 60 * 60 * 48 // 48 hour + TokenTypeOAuth = "oauth" + TokenTypeSaml = "saml" + TokenTypeSSOCodeExchange = "sso-code-exchange" ) type Token struct {
feb598ed2b7a — MM-66299: type handling for ConsumeTokenOnce (#34247) (#34261)
14 files changed · +397 −28
server/channels/api4/user.go+1 −1 modified@@ -130,7 +130,7 @@ func loginSSOCodeExchange(c *Context, w http.ResponseWriter, r *http.Request) { } // Consume one-time code atomically - token, appErr := c.App.ConsumeTokenOnce(loginCode) + token, appErr := c.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, loginCode) if appErr != nil { c.Err = appErr return
server/channels/api4/user_test.go+81 −0 modified@@ -6,6 +6,8 @@ package api4 import ( "bytes" "context" + "crypto/sha256" + "encoding/base64" "encoding/json" "fmt" "image/png" @@ -8490,6 +8492,85 @@ func TestLoginWithDesktopToken(t *testing.T) { }) } +func TestLoginSSOCodeExchange(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic() + defer th.TearDown() + + t.Run("wrong token type cannot be used for code exchange", func(t *testing.T) { + th.App.UpdateConfig(func(cfg *model.Config) { + cfg.FeatureFlags.MobileSSOCodeExchange = true + }) + + token := model.NewToken(model.TokenTypeOAuth, "extra-data") + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + defer func() { + _ = th.App.Srv().Store().Token().Delete(token.Token) + }() + + props := map[string]string{ + "login_code": token.Token, + "code_verifier": "test_verifier", + "state": "test_state", + } + + resp, err := th.Client.DoAPIPost(context.Background(), "/users/login/sso/code-exchange", model.MapToJSON(props)) + require.Error(t, err) + require.Equal(t, http.StatusNotFound, resp.StatusCode) + }) + + t.Run("successful code exchange with S256 challenge", func(t *testing.T) { + th.App.UpdateConfig(func(cfg *model.Config) { + cfg.FeatureFlags.MobileSSOCodeExchange = true + }) + + samlUser := th.CreateUserWithAuth(model.UserAuthServiceSaml) + + codeVerifier := "test_code_verifier_123456789" + state := "test_state_value" + + sum := sha256.Sum256([]byte(codeVerifier)) + codeChallenge := base64.RawURLEncoding.EncodeToString(sum[:]) + + extra := map[string]string{ + "user_id": samlUser.Id, + "code_challenge": codeChallenge, + "code_challenge_method": "S256", + "state": state, + } + + token := model.NewToken(model.TokenTypeSSOCodeExchange, model.MapToJSON(extra)) + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + + props := map[string]string{ + "login_code": token.Token, + "code_verifier": codeVerifier, + "state": state, + } + + resp, err := 
th.Client.DoAPIPost(context.Background(), "/users/login/sso/code-exchange", model.MapToJSON(props)) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + var result map[string]string + require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) + assert.NotEmpty(t, result["token"]) + assert.NotEmpty(t, result["csrf"]) + + _, err = th.App.Srv().Store().Token().GetByToken(token.Token) + require.Error(t, err) + + authenticatedClient := model.NewAPIv4Client(th.Client.URL) + authenticatedClient.SetToken(result["token"]) + + user, _, err := authenticatedClient.GetMe(context.Background(), "") + require.NoError(t, err) + assert.Equal(t, samlUser.Id, user.Id) + assert.Equal(t, samlUser.Email, user.Email) + assert.Equal(t, samlUser.Username, user.Username) + }) +} + func TestGetUsersByNames(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic()
server/channels/app/oauth.go+1 −1 modified@@ -977,7 +977,7 @@ func (a *App) SwitchEmailToOAuth(rctx request.CTX, w http.ResponseWriter, r *htt stateProps["email"] = email if service == model.UserAuthServiceSaml { - samlToken, samlErr := a.CreateSamlRelayToken(email) + samlToken, samlErr := a.CreateSamlRelayToken(model.TokenTypeSaml, email) if samlErr != nil { return "", samlErr }
server/channels/app/saml.go+2 −2 modified@@ -298,8 +298,8 @@ func (a *App) ResetSamlAuthDataToEmail(includeDeleted bool, dryRun bool, userIDs return } -func (a *App) CreateSamlRelayToken(extra string) (*model.Token, *model.AppError) { - token := model.NewToken(model.TokenTypeSaml, extra) +func (a *App) CreateSamlRelayToken(tokenType string, extra string) (*model.Token, *model.AppError) { + token := model.NewToken(tokenType, extra) if err := a.Srv().Store().Token().Save(token); err != nil { var appErr *model.AppError
server/channels/app/user.go+2 −2 modified@@ -1751,8 +1751,8 @@ func (a *App) GetTokenById(token string) (*model.Token, *model.AppError) { return rtoken, nil } -func (a *App) ConsumeTokenOnce(tokenStr string) (*model.Token, *model.AppError) { - token, err := a.Srv().Store().Token().ConsumeOnce(tokenStr) +func (a *App) ConsumeTokenOnce(tokenType, tokenStr string) (*model.Token, *model.AppError) { + token, err := a.Srv().Store().Token().ConsumeOnce(tokenType, tokenStr) if err != nil { var status int switch err.(type) {
server/channels/app/user_test.go+82 −0 modified@@ -9,6 +9,7 @@ import ( "database/sql" "encoding/json" "errors" + "net/http" "os" "path/filepath" "strings" @@ -2484,3 +2485,84 @@ func TestRemoteUserDirectChannelCreation(t *testing.T) { assert.Equal(t, model.ChannelTypeDirect, channel.Type) }) } + +func TestConsumeTokenOnce(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic() + defer th.TearDown() + + t.Run("successfully consume valid token", func(t *testing.T) { + token := model.NewToken(model.TokenTypeOAuth, "extra-data") + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + + consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, token.Token) + require.Nil(t, appErr) + require.NotNil(t, consumedToken) + assert.Equal(t, token.Token, consumedToken.Token) + assert.Equal(t, model.TokenTypeOAuth, consumedToken.Type) + assert.Equal(t, "extra-data", consumedToken.Extra) + + _, err := th.App.Srv().Store().Token().GetByToken(token.Token) + require.Error(t, err) + }) + + t.Run("token not found returns 404", func(t *testing.T) { + nonExistentToken := model.NewRandomString(model.TokenSize) + + consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, nonExistentToken) + require.NotNil(t, appErr) + require.Nil(t, consumedToken) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + assert.Equal(t, "ConsumeTokenOnce", appErr.Where) + }) + + t.Run("wrong token type returns not found", func(t *testing.T) { + token := model.NewToken(model.TokenTypeOAuth, "extra-data") + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + defer func() { + _ = th.App.Srv().Store().Token().Delete(token.Token) + }() + + consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSaml, token.Token) + require.NotNil(t, appErr) + require.Nil(t, consumedToken) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + + _, err := th.App.Srv().Store().Token().GetByToken(token.Token) + require.NoError(t, err) + }) + + 
t.Run("token can only be consumed once", func(t *testing.T) { + token := model.NewToken(model.TokenTypeSSOCodeExchange, "extra-data") + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + + consumedToken1, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, token.Token) + require.Nil(t, appErr) + require.NotNil(t, consumedToken1) + + consumedToken2, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, token.Token) + require.NotNil(t, appErr) + require.Nil(t, consumedToken2) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + }) + + t.Run("empty token string returns not found", func(t *testing.T) { + consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, "") + require.NotNil(t, appErr) + require.Nil(t, consumedToken) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + }) + + t.Run("empty token type returns not found", func(t *testing.T) { + token := model.NewToken(model.TokenTypeOAuth, "extra-data") + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + defer func() { + _ = th.App.Srv().Store().Token().Delete(token.Token) + }() + + consumedToken, appErr := th.App.ConsumeTokenOnce("", token.Token) + require.NotNil(t, appErr) + require.Nil(t, consumedToken) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + }) +}
server/channels/store/retrylayer/retrylayer.go+42 −0 modified@@ -6438,6 +6438,27 @@ func (s *RetryLayerJobStore) GetAllByTypesPage(rctx request.CTX, jobTypes []stri } +func (s *RetryLayerJobStore) GetByTypeAndData(rctx request.CTX, jobType string, data map[string]string, useMaster bool, statuses ...string) ([]*model.Job, error) { + + tries := 0 + for { + result, err := s.JobStore.GetByTypeAndData(rctx, jobType, data, useMaster, statuses...) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + func (s *RetryLayerJobStore) GetCountByStatusAndType(status string, jobType string) (int64, error) { tries := 0 @@ -14130,6 +14151,27 @@ func (s *RetryLayerTokenStore) Cleanup(expiryTime int64) { } +func (s *RetryLayerTokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) { + + tries := 0 + for { + result, err := s.TokenStore.ConsumeOnce(tokenType, tokenStr) + if err == nil { + return result, nil + } + if !isRepeatableError(err) { + return result, err + } + tries++ + if tries >= 3 { + err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures") + return result, err + } + timepkg.Sleep(100 * timepkg.Millisecond) + } + +} + func (s *RetryLayerTokenStore) Delete(token string) error { tries := 0
server/channels/store/sqlstore/tokens_store.go+4 −4 modified@@ -78,16 +78,16 @@ func (s SqlTokenStore) GetByToken(tokenString string) (*model.Token, error) { return &token, nil } -func (s SqlTokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) { +func (s SqlTokenStore) ConsumeOnce(tokenType, tokenStr string) (*model.Token, error) { var token model.Token - query := `DELETE FROM Tokens WHERE Token = ? RETURNING *` + query := `DELETE FROM Tokens WHERE Type = ? AND Token = ? RETURNING *` - if err := s.GetMaster().Get(&token, query, tokenStr); err != nil { + if err := s.GetMaster().Get(&token, query, tokenType, tokenStr); err != nil { if err == sql.ErrNoRows { return nil, store.NewErrNotFound("Token", tokenStr) } - return nil, errors.Wrapf(err, "failed to consume token") + return nil, errors.Wrapf(err, "failed to consume token with type %s", tokenType) } return &token, nil
server/channels/store/store.go+1 −1 modified@@ -692,7 +692,7 @@ type TokenStore interface { Save(recovery *model.Token) error Delete(token string) error GetByToken(token string) (*model.Token, error) - ConsumeOnce(tokenStr string) (*model.Token, error) + ConsumeOnce(tokenType, tokenStr string) (*model.Token, error) Cleanup(expiryTime int64) GetAllTokensByType(tokenType string) ([]*model.Token, error) RemoveAllTokensByType(tokenType string) error
server/channels/store/storetest/mocks/TokenStore.go+9 −9 modified@@ -19,29 +19,29 @@ func (_m *TokenStore) Cleanup(expiryTime int64) { _m.Called(expiryTime) } -// ConsumeOnce provides a mock function with given fields: tokenStr -func (_m *TokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) { - ret := _m.Called(tokenStr) +// ConsumeOnce provides a mock function with given fields: tokenType, tokenStr +func (_m *TokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) { + ret := _m.Called(tokenType, tokenStr) if len(ret) == 0 { panic("no return value specified for ConsumeOnce") } var r0 *model.Token var r1 error - if rf, ok := ret.Get(0).(func(string) (*model.Token, error)); ok { - return rf(tokenStr) + if rf, ok := ret.Get(0).(func(string, string) (*model.Token, error)); ok { + return rf(tokenType, tokenStr) } - if rf, ok := ret.Get(0).(func(string) *model.Token); ok { - r0 = rf(tokenStr) + if rf, ok := ret.Get(0).(func(string, string) *model.Token); ok { + r0 = rf(tokenType, tokenStr) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*model.Token) } } - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(tokenStr) + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(tokenType, tokenStr) } else { r1 = ret.Error(1) }
server/channels/store/storetest/tokens_store.go+128 −0 modified@@ -16,6 +16,7 @@ import ( func TestTokensStore(t *testing.T, rctx request.CTX, ss store.Store) { t.Run("TokensCleanup", func(t *testing.T) { testTokensCleanup(t, rctx, ss) }) + t.Run("ConsumeOnce", func(t *testing.T) { testConsumeOnce(t, rctx, ss) }) } func testTokensCleanup(t *testing.T, rctx request.CTX, ss store.Store) { @@ -41,3 +42,130 @@ func testTokensCleanup(t *testing.T, rctx request.CTX, ss store.Store) { require.NoError(t, err) assert.Len(t, tokens, 0) } + +func testConsumeOnce(t *testing.T, rctx request.CTX, ss store.Store) { + t.Run("successfully consume token once", func(t *testing.T) { + token := &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: model.GetMillis(), + Type: model.TokenTypeOAuth, + Extra: "test-extra", + } + err := ss.Token().Save(token) + require.NoError(t, err) + + consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token) + require.NoError(t, err) + assert.Equal(t, token.Token, consumedToken.Token) + assert.Equal(t, token.Type, consumedToken.Type) + assert.Equal(t, token.Extra, consumedToken.Extra) + + tokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth) + require.NoError(t, err) + assert.Len(t, tokens, 0) + }) + + t.Run("second consumption of same token fails", func(t *testing.T) { + token := &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: model.GetMillis(), + Type: model.TokenTypeOAuth, + Extra: "test-extra", + } + err := ss.Token().Save(token) + require.NoError(t, err) + + _, err = ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token) + require.NoError(t, err) + + _, err = ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token) + require.Error(t, err) + var nfErr *store.ErrNotFound + assert.ErrorAs(t, err, &nfErr) + }) + + t.Run("consume with wrong type fails", func(t *testing.T) { + token := &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: 
model.GetMillis(), + Type: model.TokenTypeOAuth, + Extra: "test-extra", + } + err := ss.Token().Save(token) + require.NoError(t, err) + + _, err = ss.Token().ConsumeOnce(model.TokenTypeSSOCodeExchange, token.Token) + require.Error(t, err) + var nfErr *store.ErrNotFound + assert.ErrorAs(t, err, &nfErr) + + tokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth) + require.NoError(t, err) + assert.Len(t, tokens, 1) + + err = ss.Token().Delete(token.Token) + require.NoError(t, err) + }) + + t.Run("consume non-existent token fails", func(t *testing.T) { + nonExistentToken := model.NewRandomString(model.TokenSize) + _, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, nonExistentToken) + require.Error(t, err) + var nfErr *store.ErrNotFound + assert.ErrorAs(t, err, &nfErr) + }) + + t.Run("multiple tokens with same type can each be consumed once", func(t *testing.T) { + tokens := make([]*model.Token, 3) + for i := range tokens { + tokens[i] = &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: model.GetMillis(), + Type: model.TokenTypeOAuth, + Extra: "test-extra", + } + err := ss.Token().Save(tokens[i]) + require.NoError(t, err) + } + + for _, token := range tokens { + consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token) + require.NoError(t, err) + assert.Equal(t, token.Token, consumedToken.Token) + } + + allTokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth) + require.NoError(t, err) + assert.Len(t, allTokens, 0) + }) + + t.Run("consuming token of different type leaves others intact", func(t *testing.T) { + oauthToken := &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: model.GetMillis(), + Type: model.TokenTypeOAuth, + Extra: "oauth-extra", + } + codeExchangeToken := &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: model.GetMillis(), + Type: model.TokenTypeSSOCodeExchange, + Extra: "password-extra", + } + err := ss.Token().Save(oauthToken) + 
require.NoError(t, err) + err = ss.Token().Save(codeExchangeToken) + require.NoError(t, err) + + consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, oauthToken.Token) + require.NoError(t, err) + assert.Equal(t, oauthToken.Token, consumedToken.Token) + + codeExchangeTokens, err := ss.Token().GetAllTokensByType(model.TokenTypeSSOCodeExchange) + require.NoError(t, err) + assert.Len(t, codeExchangeTokens, 1) + + err = ss.Token().Delete(codeExchangeToken.Token) + require.NoError(t, err) + }) +}
server/channels/store/timerlayer/timerlayer.go+32 −0 modified@@ -5159,6 +5159,22 @@ func (s *TimerLayerJobStore) GetAllByTypesPage(rctx request.CTX, jobTypes []stri return result, err } +func (s *TimerLayerJobStore) GetByTypeAndData(rctx request.CTX, jobType string, data map[string]string, useMaster bool, statuses ...string) ([]*model.Job, error) { + start := time.Now() + + result, err := s.JobStore.GetByTypeAndData(rctx, jobType, data, useMaster, statuses...) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("JobStore.GetByTypeAndData", success, elapsed) + } + return result, err +} + func (s *TimerLayerJobStore) GetCountByStatusAndType(status string, jobType string) (int64, error) { start := time.Now() @@ -11106,6 +11122,22 @@ func (s *TimerLayerTokenStore) Cleanup(expiryTime int64) { } } +func (s *TimerLayerTokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) { + start := time.Now() + + result, err := s.TokenStore.ConsumeOnce(tokenType, tokenStr) + + elapsed := float64(time.Since(start)) / float64(time.Second) + if s.Root.Metrics != nil { + success := "false" + if err == nil { + success = "true" + } + s.Root.Metrics.ObserveStoreMethodDuration("TokenStore.ConsumeOnce", success, elapsed) + } + return result, err +} + func (s *TimerLayerTokenStore) Delete(token string) error { start := time.Now()
server/channels/web/saml.go+7 −4 modified@@ -106,7 +106,7 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) { return } - //Validate that the user is with SAML and all that + // Validate that the user is with SAML and all that encodedXML := r.FormValue("SAMLResponse") relayState := r.FormValue("RelayState") @@ -161,7 +161,8 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) { return } - if err = c.App.CheckUserAllAuthenticationCriteria(c.AppContext, user, ""); err != nil { + err = c.App.CheckUserAllAuthenticationCriteria(c.AppContext, user, "") + if err != nil { handleError(err) return } @@ -250,8 +251,10 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) { "code_challenge": samlChallenge, "code_challenge_method": samlMethod, }) - code := model.NewToken(model.TokenTypeSaml, extra) - if err := c.App.Srv().Store().Token().Save(code); err != nil { + + var code *model.Token + code, err = c.App.CreateSamlRelayToken(model.TokenTypeSSOCodeExchange, extra) + if err != nil { handleError(model.NewAppError("completeSaml", "app.recover.save.app_error", nil, "", http.StatusInternalServerError).Wrap(err)) return }
server/public/model/token.go+5 −4 modified@@ -8,10 +8,11 @@ import ( ) const ( - TokenSize = 64 - MaxTokenExipryTime = 1000 * 60 * 60 * 48 // 48 hour - TokenTypeOAuth = "oauth" - TokenTypeSaml = "saml" + TokenSize = 64 + MaxTokenExipryTime = 1000 * 60 * 60 * 48 // 48 hour + TokenTypeOAuth = "oauth" + TokenTypeSaml = "saml" + TokenTypeSSOCodeExchange = "sso-code-exchange" ) type Token struct {
acda1fb5dd46 — MM-66299: type handling for ConsumeTokenOnce (#34247)
14 files changed · +327 −32
server/channels/api4/user.go+1 −1 modified@@ -130,7 +130,7 @@ func loginSSOCodeExchange(c *Context, w http.ResponseWriter, r *http.Request) { } // Consume one-time code atomically - token, appErr := c.App.ConsumeTokenOnce(loginCode) + token, appErr := c.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, loginCode) if appErr != nil { c.Err = appErr return
server/channels/api4/user_test.go+81 −0 modified@@ -6,6 +6,8 @@ package api4 import ( "bytes" "context" + "crypto/sha256" + "encoding/base64" "encoding/json" "fmt" "image/png" @@ -8490,6 +8492,85 @@ func TestLoginWithDesktopToken(t *testing.T) { }) } +func TestLoginSSOCodeExchange(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic() + defer th.TearDown() + + t.Run("wrong token type cannot be used for code exchange", func(t *testing.T) { + th.App.UpdateConfig(func(cfg *model.Config) { + cfg.FeatureFlags.MobileSSOCodeExchange = true + }) + + token := model.NewToken(model.TokenTypeOAuth, "extra-data") + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + defer func() { + _ = th.App.Srv().Store().Token().Delete(token.Token) + }() + + props := map[string]string{ + "login_code": token.Token, + "code_verifier": "test_verifier", + "state": "test_state", + } + + resp, err := th.Client.DoAPIPost(context.Background(), "/users/login/sso/code-exchange", model.MapToJSON(props)) + require.Error(t, err) + require.Equal(t, http.StatusNotFound, resp.StatusCode) + }) + + t.Run("successful code exchange with S256 challenge", func(t *testing.T) { + th.App.UpdateConfig(func(cfg *model.Config) { + cfg.FeatureFlags.MobileSSOCodeExchange = true + }) + + samlUser := th.CreateUserWithAuth(model.UserAuthServiceSaml) + + codeVerifier := "test_code_verifier_123456789" + state := "test_state_value" + + sum := sha256.Sum256([]byte(codeVerifier)) + codeChallenge := base64.RawURLEncoding.EncodeToString(sum[:]) + + extra := map[string]string{ + "user_id": samlUser.Id, + "code_challenge": codeChallenge, + "code_challenge_method": "S256", + "state": state, + } + + token := model.NewToken(model.TokenTypeSSOCodeExchange, model.MapToJSON(extra)) + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + + props := map[string]string{ + "login_code": token.Token, + "code_verifier": codeVerifier, + "state": state, + } + + resp, err := 
th.Client.DoAPIPost(context.Background(), "/users/login/sso/code-exchange", model.MapToJSON(props)) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + var result map[string]string + require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) + assert.NotEmpty(t, result["token"]) + assert.NotEmpty(t, result["csrf"]) + + _, err = th.App.Srv().Store().Token().GetByToken(token.Token) + require.Error(t, err) + + authenticatedClient := model.NewAPIv4Client(th.Client.URL) + authenticatedClient.SetToken(result["token"]) + + user, _, err := authenticatedClient.GetMe(context.Background(), "") + require.NoError(t, err) + assert.Equal(t, samlUser.Id, user.Id) + assert.Equal(t, samlUser.Email, user.Email) + assert.Equal(t, samlUser.Username, user.Username) + }) +} + func TestGetUsersByNames(t *testing.T) { mainHelper.Parallel(t) th := Setup(t).InitBasic()
server/channels/app/oauth.go+1 −1 modified@@ -977,7 +977,7 @@ func (a *App) SwitchEmailToOAuth(rctx request.CTX, w http.ResponseWriter, r *htt stateProps["email"] = email if service == model.UserAuthServiceSaml { - samlToken, samlErr := a.CreateSamlRelayToken(email) + samlToken, samlErr := a.CreateSamlRelayToken(model.TokenTypeSaml, email) if samlErr != nil { return "", samlErr }
server/channels/app/saml.go+2 −2 modified@@ -298,8 +298,8 @@ func (a *App) ResetSamlAuthDataToEmail(includeDeleted bool, dryRun bool, userIDs return } -func (a *App) CreateSamlRelayToken(extra string) (*model.Token, *model.AppError) { - token := model.NewToken(model.TokenTypeSaml, extra) +func (a *App) CreateSamlRelayToken(tokenType string, extra string) (*model.Token, *model.AppError) { + token := model.NewToken(tokenType, extra) if err := a.Srv().Store().Token().Save(token); err != nil { var appErr *model.AppError
server/channels/app/user.go+2 −2 modified@@ -1750,8 +1750,8 @@ func (a *App) GetTokenById(token string) (*model.Token, *model.AppError) { return rtoken, nil } -func (a *App) ConsumeTokenOnce(tokenStr string) (*model.Token, *model.AppError) { - token, err := a.Srv().Store().Token().ConsumeOnce(tokenStr) +func (a *App) ConsumeTokenOnce(tokenType, tokenStr string) (*model.Token, *model.AppError) { + token, err := a.Srv().Store().Token().ConsumeOnce(tokenType, tokenStr) if err != nil { var status int switch err.(type) {
server/channels/app/user_test.go+82 −0 modified@@ -8,6 +8,7 @@ import ( "database/sql" "encoding/json" "errors" + "net/http" "os" "path/filepath" "strings" @@ -2483,3 +2484,84 @@ func TestRemoteUserDirectChannelCreation(t *testing.T) { assert.Equal(t, model.ChannelTypeDirect, channel.Type) }) } + +func TestConsumeTokenOnce(t *testing.T) { + mainHelper.Parallel(t) + th := Setup(t).InitBasic() + defer th.TearDown() + + t.Run("successfully consume valid token", func(t *testing.T) { + token := model.NewToken(model.TokenTypeOAuth, "extra-data") + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + + consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, token.Token) + require.Nil(t, appErr) + require.NotNil(t, consumedToken) + assert.Equal(t, token.Token, consumedToken.Token) + assert.Equal(t, model.TokenTypeOAuth, consumedToken.Type) + assert.Equal(t, "extra-data", consumedToken.Extra) + + _, err := th.App.Srv().Store().Token().GetByToken(token.Token) + require.Error(t, err) + }) + + t.Run("token not found returns 404", func(t *testing.T) { + nonExistentToken := model.NewRandomString(model.TokenSize) + + consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, nonExistentToken) + require.NotNil(t, appErr) + require.Nil(t, consumedToken) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + assert.Equal(t, "ConsumeTokenOnce", appErr.Where) + }) + + t.Run("wrong token type returns not found", func(t *testing.T) { + token := model.NewToken(model.TokenTypeOAuth, "extra-data") + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + defer func() { + _ = th.App.Srv().Store().Token().Delete(token.Token) + }() + + consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSaml, token.Token) + require.NotNil(t, appErr) + require.Nil(t, consumedToken) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + + _, err := th.App.Srv().Store().Token().GetByToken(token.Token) + require.NoError(t, err) + }) + + 
t.Run("token can only be consumed once", func(t *testing.T) { + token := model.NewToken(model.TokenTypeSSOCodeExchange, "extra-data") + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + + consumedToken1, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, token.Token) + require.Nil(t, appErr) + require.NotNil(t, consumedToken1) + + consumedToken2, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, token.Token) + require.NotNil(t, appErr) + require.Nil(t, consumedToken2) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + }) + + t.Run("empty token string returns not found", func(t *testing.T) { + consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, "") + require.NotNil(t, appErr) + require.Nil(t, consumedToken) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + }) + + t.Run("empty token type returns not found", func(t *testing.T) { + token := model.NewToken(model.TokenTypeOAuth, "extra-data") + require.NoError(t, th.App.Srv().Store().Token().Save(token)) + defer func() { + _ = th.App.Srv().Store().Token().Delete(token.Token) + }() + + consumedToken, appErr := th.App.ConsumeTokenOnce("", token.Token) + require.NotNil(t, appErr) + require.Nil(t, consumedToken) + assert.Equal(t, http.StatusNotFound, appErr.StatusCode) + }) +}
server/channels/store/retrylayer/retrylayer.go+2 −2 modified@@ -14293,11 +14293,11 @@ func (s *RetryLayerTokenStore) Cleanup(expiryTime int64) { } -func (s *RetryLayerTokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) { +func (s *RetryLayerTokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) { tries := 0 for { - result, err := s.TokenStore.ConsumeOnce(tokenStr) + result, err := s.TokenStore.ConsumeOnce(tokenType, tokenStr) if err == nil { return result, nil }
server/channels/store/sqlstore/tokens_store.go+4 −4 modified@@ -78,16 +78,16 @@ func (s SqlTokenStore) GetByToken(tokenString string) (*model.Token, error) { return &token, nil } -func (s SqlTokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) { +func (s SqlTokenStore) ConsumeOnce(tokenType, tokenStr string) (*model.Token, error) { var token model.Token - query := `DELETE FROM Tokens WHERE Token = ? RETURNING *` + query := `DELETE FROM Tokens WHERE Type = ? AND Token = ? RETURNING *` - if err := s.GetMaster().Get(&token, query, tokenStr); err != nil { + if err := s.GetMaster().Get(&token, query, tokenType, tokenStr); err != nil { if err == sql.ErrNoRows { return nil, store.NewErrNotFound("Token", tokenStr) } - return nil, errors.Wrapf(err, "failed to consume token") + return nil, errors.Wrapf(err, "failed to consume token with type %s", tokenType) } return &token, nil
server/channels/store/store.go+1 −1 modified@@ -696,7 +696,7 @@ type TokenStore interface { Save(recovery *model.Token) error Delete(token string) error GetByToken(token string) (*model.Token, error) - ConsumeOnce(tokenStr string) (*model.Token, error) + ConsumeOnce(tokenType, tokenStr string) (*model.Token, error) Cleanup(expiryTime int64) GetAllTokensByType(tokenType string) ([]*model.Token, error) RemoveAllTokensByType(tokenType string) error
server/channels/store/storetest/mocks/TokenStore.go+9 −9 modified@@ -19,29 +19,29 @@ func (_m *TokenStore) Cleanup(expiryTime int64) { _m.Called(expiryTime) } -// ConsumeOnce provides a mock function with given fields: tokenStr -func (_m *TokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) { - ret := _m.Called(tokenStr) +// ConsumeOnce provides a mock function with given fields: tokenType, tokenStr +func (_m *TokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) { + ret := _m.Called(tokenType, tokenStr) if len(ret) == 0 { panic("no return value specified for ConsumeOnce") } var r0 *model.Token var r1 error - if rf, ok := ret.Get(0).(func(string) (*model.Token, error)); ok { - return rf(tokenStr) + if rf, ok := ret.Get(0).(func(string, string) (*model.Token, error)); ok { + return rf(tokenType, tokenStr) } - if rf, ok := ret.Get(0).(func(string) *model.Token); ok { - r0 = rf(tokenStr) + if rf, ok := ret.Get(0).(func(string, string) *model.Token); ok { + r0 = rf(tokenType, tokenStr) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*model.Token) } } - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(tokenStr) + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(tokenType, tokenStr) } else { r1 = ret.Error(1) }
server/channels/store/storetest/tokens_store.go+128 −0 modified@@ -16,6 +16,7 @@ import ( func TestTokensStore(t *testing.T, rctx request.CTX, ss store.Store) { t.Run("TokensCleanup", func(t *testing.T) { testTokensCleanup(t, rctx, ss) }) + t.Run("ConsumeOnce", func(t *testing.T) { testConsumeOnce(t, rctx, ss) }) } func testTokensCleanup(t *testing.T, rctx request.CTX, ss store.Store) { @@ -41,3 +42,130 @@ func testTokensCleanup(t *testing.T, rctx request.CTX, ss store.Store) { require.NoError(t, err) assert.Len(t, tokens, 0) } + +func testConsumeOnce(t *testing.T, rctx request.CTX, ss store.Store) { + t.Run("successfully consume token once", func(t *testing.T) { + token := &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: model.GetMillis(), + Type: model.TokenTypeOAuth, + Extra: "test-extra", + } + err := ss.Token().Save(token) + require.NoError(t, err) + + consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token) + require.NoError(t, err) + assert.Equal(t, token.Token, consumedToken.Token) + assert.Equal(t, token.Type, consumedToken.Type) + assert.Equal(t, token.Extra, consumedToken.Extra) + + tokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth) + require.NoError(t, err) + assert.Len(t, tokens, 0) + }) + + t.Run("second consumption of same token fails", func(t *testing.T) { + token := &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: model.GetMillis(), + Type: model.TokenTypeOAuth, + Extra: "test-extra", + } + err := ss.Token().Save(token) + require.NoError(t, err) + + _, err = ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token) + require.NoError(t, err) + + _, err = ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token) + require.Error(t, err) + var nfErr *store.ErrNotFound + assert.ErrorAs(t, err, &nfErr) + }) + + t.Run("consume with wrong type fails", func(t *testing.T) { + token := &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: 
model.GetMillis(), + Type: model.TokenTypeOAuth, + Extra: "test-extra", + } + err := ss.Token().Save(token) + require.NoError(t, err) + + _, err = ss.Token().ConsumeOnce(model.TokenTypeSSOCodeExchange, token.Token) + require.Error(t, err) + var nfErr *store.ErrNotFound + assert.ErrorAs(t, err, &nfErr) + + tokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth) + require.NoError(t, err) + assert.Len(t, tokens, 1) + + err = ss.Token().Delete(token.Token) + require.NoError(t, err) + }) + + t.Run("consume non-existent token fails", func(t *testing.T) { + nonExistentToken := model.NewRandomString(model.TokenSize) + _, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, nonExistentToken) + require.Error(t, err) + var nfErr *store.ErrNotFound + assert.ErrorAs(t, err, &nfErr) + }) + + t.Run("multiple tokens with same type can each be consumed once", func(t *testing.T) { + tokens := make([]*model.Token, 3) + for i := range tokens { + tokens[i] = &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: model.GetMillis(), + Type: model.TokenTypeOAuth, + Extra: "test-extra", + } + err := ss.Token().Save(tokens[i]) + require.NoError(t, err) + } + + for _, token := range tokens { + consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token) + require.NoError(t, err) + assert.Equal(t, token.Token, consumedToken.Token) + } + + allTokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth) + require.NoError(t, err) + assert.Len(t, allTokens, 0) + }) + + t.Run("consuming token of different type leaves others intact", func(t *testing.T) { + oauthToken := &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: model.GetMillis(), + Type: model.TokenTypeOAuth, + Extra: "oauth-extra", + } + codeExchangeToken := &model.Token{ + Token: model.NewRandomString(model.TokenSize), + CreateAt: model.GetMillis(), + Type: model.TokenTypeSSOCodeExchange, + Extra: "password-extra", + } + err := ss.Token().Save(oauthToken) + 
require.NoError(t, err) + err = ss.Token().Save(codeExchangeToken) + require.NoError(t, err) + + consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, oauthToken.Token) + require.NoError(t, err) + assert.Equal(t, oauthToken.Token, consumedToken.Token) + + codeExchangeTokens, err := ss.Token().GetAllTokensByType(model.TokenTypeSSOCodeExchange) + require.NoError(t, err) + assert.Len(t, codeExchangeTokens, 1) + + err = ss.Token().Delete(codeExchangeToken.Token) + require.NoError(t, err) + }) +}
server/channels/store/timerlayer/timerlayer.go+2 −2 modified@@ -11243,10 +11243,10 @@ func (s *TimerLayerTokenStore) Cleanup(expiryTime int64) { } } -func (s *TimerLayerTokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) { +func (s *TimerLayerTokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) { start := time.Now() - result, err := s.TokenStore.ConsumeOnce(tokenStr) + result, err := s.TokenStore.ConsumeOnce(tokenType, tokenStr) elapsed := float64(time.Since(start)) / float64(time.Second) if s.Root.Metrics != nil {
server/channels/web/saml.go+7 −4 modified@@ -106,7 +106,7 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) { return } - //Validate that the user is with SAML and all that + // Validate that the user is with SAML and all that encodedXML := r.FormValue("SAMLResponse") relayState := r.FormValue("RelayState") @@ -161,7 +161,8 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) { return } - if err = c.App.CheckUserAllAuthenticationCriteria(c.AppContext, user, ""); err != nil { + err = c.App.CheckUserAllAuthenticationCriteria(c.AppContext, user, "") + if err != nil { handleError(err) return } @@ -250,8 +251,10 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) { "code_challenge": samlChallenge, "code_challenge_method": samlMethod, }) - code := model.NewToken(model.TokenTypeSaml, extra) - if err := c.App.Srv().Store().Token().Save(code); err != nil { + + var code *model.Token + code, err = c.App.CreateSamlRelayToken(model.TokenTypeSSOCodeExchange, extra) + if err != nil { handleError(model.NewAppError("completeSaml", "app.recover.save.app_error", nil, "", http.StatusInternalServerError).Wrap(err)) return }
server/public/model/token.go+5 −4 modified@@ -8,10 +8,11 @@ import ( ) const ( - TokenSize = 64 - MaxTokenExipryTime = 1000 * 60 * 60 * 48 // 48 hour - TokenTypeOAuth = "oauth" - TokenTypeSaml = "saml" + TokenSize = 64 + MaxTokenExipryTime = 1000 * 60 * 60 * 48 // 48 hour + TokenTypeOAuth = "oauth" + TokenTypeSaml = "saml" + TokenTypeSSOCodeExchange = "sso-code-exchange" ) type Token struct {
Vulnerability mechanics
Generated by null/stub on May 9, 2026. Inputs: the CWE entries and the fix-commit diffs from this CVE's patches. Citations were validated against the source bundle.
References (7)
- github.com/advisories/GHSA-mp6x-97xj-9x62 (source: GHSA; type: advisory)
- nvd.nist.gov/vuln/detail/CVE-2025-12421 (source: GHSA; type: advisory)
- github.com/mattermost/mattermost/commit/5072bbf689a46b3b97b2373f26149f5891396e6b (source: GHSA; type: web)
- github.com/mattermost/mattermost/commit/acda1fb5dd46a2f46c76ae67012423c760525eaa (source: GHSA; type: web)
- github.com/mattermost/mattermost/commit/f361e7d75a7ab9df5a2106e1ceb919b94b55e41d (source: GHSA; type: web)
- github.com/mattermost/mattermost/commit/feb598ed2b7ac3cb78a253b032ad7e4628b0de00 (source: GHSA; type: web)
- mattermost.com/security-updates (source: GHSA; type: web)
News mentions (0)
No linked articles in our index yet.