VYPR
Critical severity · NVD Advisory · Published Nov 27, 2025 · Updated Feb 26, 2026

Account Takeover via Code Exchange Endpoint

CVE-2025-12421

Description

Mattermost versions 11.0.x <= 11.0.2, 10.12.x <= 10.12.1, 10.11.x <= 10.11.4, and 10.5.x <= 10.5.12 fail to verify that the token used during the code exchange originates from the same authentication flow. This allows an authenticated user to perform an account takeover by using a specially crafted email address when switching authentication methods and then sending a request to the /users/login/sso/code-exchange endpoint. The fix gives code-exchange tokens their own token type (`sso-code-exchange`) and makes the endpoint consume only tokens of that type. As a result, SAML relay tokens created while switching from email to SAML authentication can no longer be redeemed there. Exploitation requires ExperimentalEnableAuthenticationTransfer to be enabled (default: enabled) and RequireEmailVerification to be disabled (default: disabled).

Affected packages

Versions sourced from the GitHub Security Advisory.

| Package | Ecosystem | Affected versions | Patched versions |
|---|---|---|---|
| github.com/mattermost/mattermost/server/v8 | Go | < 8.0.0-20251022210333-acda1fb5dd46 | 8.0.0-20251022210333-acda1fb5dd46 |
| github.com/mattermost/mattermost-server | Go | >= 11.0.0, < 11.0.3 | 11.0.3 |
| github.com/mattermost/mattermost-server | Go | >= 10.12.0, < 10.12.2 | 10.12.2 |
| github.com/mattermost/mattermost-server | Go | >= 10.11.0, < 10.11.5 | 10.11.5 |
| github.com/mattermost/mattermost-server | Go | >= 10.5.0, < 10.5.13 | 10.5.13 |

Affected products: 1

Patches: 4
f361e7d75a7a

Automated cherry pick of #34247 (#34257)

https://github.com/mattermost/mattermost · Mattermost Build · Oct 27, 2025 · via ghsa
14 files changed · +352 -32
  • server/channels/api4/user.go +1 -1 modified
    @@ -130,7 +130,7 @@ func loginSSOCodeExchange(c *Context, w http.ResponseWriter, r *http.Request) {
     	}
     
     	// Consume one-time code atomically
    -	token, appErr := c.App.ConsumeTokenOnce(loginCode)
    +	token, appErr := c.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, loginCode)
     	if appErr != nil {
     		c.Err = appErr
     		return
    
  • server/channels/api4/user_test.go +81 -0 modified
    @@ -6,6 +6,8 @@ package api4
     import (
     	"bytes"
     	"context"
    +	"crypto/sha256"
    +	"encoding/base64"
     	"encoding/json"
     	"fmt"
     	"image/png"
    @@ -8673,6 +8675,85 @@ func TestLoginWithDesktopToken(t *testing.T) {
     	})
     }
     
    +func TestLoginSSOCodeExchange(t *testing.T) {
    +	mainHelper.Parallel(t)
    +	th := Setup(t).InitBasic()
    +	defer th.TearDown()
    +
    +	t.Run("wrong token type cannot be used for code exchange", func(t *testing.T) {
    +		th.App.UpdateConfig(func(cfg *model.Config) {
    +			cfg.FeatureFlags.MobileSSOCodeExchange = true
    +		})
    +
    +		token := model.NewToken(model.TokenTypeOAuth, "extra-data")
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +		defer func() {
    +			_ = th.App.Srv().Store().Token().Delete(token.Token)
    +		}()
    +
    +		props := map[string]string{
    +			"login_code":    token.Token,
    +			"code_verifier": "test_verifier",
    +			"state":         "test_state",
    +		}
    +
    +		resp, err := th.Client.DoAPIPost(context.Background(), "/users/login/sso/code-exchange", model.MapToJSON(props))
    +		require.Error(t, err)
    +		require.Equal(t, http.StatusNotFound, resp.StatusCode)
    +	})
    +
    +	t.Run("successful code exchange with S256 challenge", func(t *testing.T) {
    +		th.App.UpdateConfig(func(cfg *model.Config) {
    +			cfg.FeatureFlags.MobileSSOCodeExchange = true
    +		})
    +
    +		samlUser := th.CreateUserWithAuth(model.UserAuthServiceSaml)
    +
    +		codeVerifier := "test_code_verifier_123456789"
    +		state := "test_state_value"
    +
    +		sum := sha256.Sum256([]byte(codeVerifier))
    +		codeChallenge := base64.RawURLEncoding.EncodeToString(sum[:])
    +
    +		extra := map[string]string{
    +			"user_id":               samlUser.Id,
    +			"code_challenge":        codeChallenge,
    +			"code_challenge_method": "S256",
    +			"state":                 state,
    +		}
    +
    +		token := model.NewToken(model.TokenTypeSSOCodeExchange, model.MapToJSON(extra))
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +
    +		props := map[string]string{
    +			"login_code":    token.Token,
    +			"code_verifier": codeVerifier,
    +			"state":         state,
    +		}
    +
    +		resp, err := th.Client.DoAPIPost(context.Background(), "/users/login/sso/code-exchange", model.MapToJSON(props))
    +		require.NoError(t, err)
    +		require.Equal(t, http.StatusOK, resp.StatusCode)
    +
    +		var result map[string]string
    +		require.NoError(t, json.NewDecoder(resp.Body).Decode(&result))
    +		assert.NotEmpty(t, result["token"])
    +		assert.NotEmpty(t, result["csrf"])
    +
    +		_, err = th.App.Srv().Store().Token().GetByToken(token.Token)
    +		require.Error(t, err)
    +
    +		authenticatedClient := model.NewAPIv4Client(th.Client.URL)
    +		authenticatedClient.SetToken(result["token"])
    +
    +		user, _, err := authenticatedClient.GetMe(context.Background(), "")
    +		require.NoError(t, err)
    +		assert.Equal(t, samlUser.Id, user.Id)
    +		assert.Equal(t, samlUser.Email, user.Email)
    +		assert.Equal(t, samlUser.Username, user.Username)
    +	})
    +}
    +
     func TestGetUsersByNames(t *testing.T) {
     	mainHelper.Parallel(t)
     	th := Setup(t).InitBasic()
    
  • server/channels/app/oauth.go +1 -1 modified
    @@ -977,7 +977,7 @@ func (a *App) SwitchEmailToOAuth(c request.CTX, w http.ResponseWriter, r *http.R
     	stateProps["email"] = email
     
     	if service == model.UserAuthServiceSaml {
    -		samlToken, samlErr := a.CreateSamlRelayToken(email)
    +		samlToken, samlErr := a.CreateSamlRelayToken(model.TokenTypeSaml, email)
     		if samlErr != nil {
     			return "", samlErr
     		}
    
  • server/channels/app/saml.go +2 -2 modified
    @@ -298,8 +298,8 @@ func (a *App) ResetSamlAuthDataToEmail(includeDeleted bool, dryRun bool, userIDs
     	return
     }
     
    -func (a *App) CreateSamlRelayToken(extra string) (*model.Token, *model.AppError) {
    -	token := model.NewToken(model.TokenTypeSaml, extra)
    +func (a *App) CreateSamlRelayToken(tokenType string, extra string) (*model.Token, *model.AppError) {
    +	token := model.NewToken(tokenType, extra)
     
     	if err := a.Srv().Store().Token().Save(token); err != nil {
     		var appErr *model.AppError
    
  • server/channels/app/user.go +2 -2 modified
    @@ -1750,8 +1750,8 @@ func (a *App) GetTokenById(token string) (*model.Token, *model.AppError) {
     	return rtoken, nil
     }
     
    -func (a *App) ConsumeTokenOnce(tokenStr string) (*model.Token, *model.AppError) {
    -	token, err := a.Srv().Store().Token().ConsumeOnce(tokenStr)
    +func (a *App) ConsumeTokenOnce(tokenType, tokenStr string) (*model.Token, *model.AppError) {
    +	token, err := a.Srv().Store().Token().ConsumeOnce(tokenType, tokenStr)
     	if err != nil {
     		var status int
     		switch err.(type) {
    
  • server/channels/app/user_test.go +82 -0 modified
    @@ -9,6 +9,7 @@ import (
     	"database/sql"
     	"encoding/json"
     	"errors"
    +	"net/http"
     	"os"
     	"path/filepath"
     	"strings"
    @@ -2484,3 +2485,84 @@ func TestRemoteUserDirectChannelCreation(t *testing.T) {
     		assert.Equal(t, model.ChannelTypeDirect, channel.Type)
     	})
     }
    +
    +func TestConsumeTokenOnce(t *testing.T) {
    +	mainHelper.Parallel(t)
    +	th := Setup(t).InitBasic()
    +	defer th.TearDown()
    +
    +	t.Run("successfully consume valid token", func(t *testing.T) {
    +		token := model.NewToken(model.TokenTypeOAuth, "extra-data")
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +
    +		consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, token.Token)
    +		require.Nil(t, appErr)
    +		require.NotNil(t, consumedToken)
    +		assert.Equal(t, token.Token, consumedToken.Token)
    +		assert.Equal(t, model.TokenTypeOAuth, consumedToken.Type)
    +		assert.Equal(t, "extra-data", consumedToken.Extra)
    +
    +		_, err := th.App.Srv().Store().Token().GetByToken(token.Token)
    +		require.Error(t, err)
    +	})
    +
    +	t.Run("token not found returns 404", func(t *testing.T) {
    +		nonExistentToken := model.NewRandomString(model.TokenSize)
    +
    +		consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, nonExistentToken)
    +		require.NotNil(t, appErr)
    +		require.Nil(t, consumedToken)
    +		assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
    +		assert.Equal(t, "ConsumeTokenOnce", appErr.Where)
    +	})
    +
    +	t.Run("wrong token type returns not found", func(t *testing.T) {
    +		token := model.NewToken(model.TokenTypeOAuth, "extra-data")
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +		defer func() {
    +			_ = th.App.Srv().Store().Token().Delete(token.Token)
    +		}()
    +
    +		consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSaml, token.Token)
    +		require.NotNil(t, appErr)
    +		require.Nil(t, consumedToken)
    +		assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
    +
    +		_, err := th.App.Srv().Store().Token().GetByToken(token.Token)
    +		require.NoError(t, err)
    +	})
    +
    +	t.Run("token can only be consumed once", func(t *testing.T) {
    +		token := model.NewToken(model.TokenTypeSSOCodeExchange, "extra-data")
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +
    +		consumedToken1, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, token.Token)
    +		require.Nil(t, appErr)
    +		require.NotNil(t, consumedToken1)
    +
    +		consumedToken2, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, token.Token)
    +		require.NotNil(t, appErr)
    +		require.Nil(t, consumedToken2)
    +		assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
    +	})
    +
    +	t.Run("empty token string returns not found", func(t *testing.T) {
    +		consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, "")
    +		require.NotNil(t, appErr)
    +		require.Nil(t, consumedToken)
    +		assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
    +	})
    +
    +	t.Run("empty token type returns not found", func(t *testing.T) {
    +		token := model.NewToken(model.TokenTypeOAuth, "extra-data")
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +		defer func() {
    +			_ = th.App.Srv().Store().Token().Delete(token.Token)
    +		}()
    +
    +		consumedToken, appErr := th.App.ConsumeTokenOnce("", token.Token)
    +		require.NotNil(t, appErr)
    +		require.Nil(t, consumedToken)
    +		assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
    +	})
    +}
    
  • server/channels/store/retrylayer/retrylayer.go +2 -2 modified
    @@ -14138,11 +14138,11 @@ func (s *RetryLayerTokenStore) Cleanup(expiryTime int64) {
     
     }
     
    -func (s *RetryLayerTokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) {
    +func (s *RetryLayerTokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) {
     
     	tries := 0
     	for {
    -		result, err := s.TokenStore.ConsumeOnce(tokenStr)
    +		result, err := s.TokenStore.ConsumeOnce(tokenType, tokenStr)
     		if err == nil {
     			return result, nil
     		}
    
  • server/channels/store/sqlstore/tokens_store.go +29 -4 modified
    @@ -78,16 +78,41 @@ func (s SqlTokenStore) GetByToken(tokenString string) (*model.Token, error) {
     	return &token, nil
     }
     
    -func (s SqlTokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) {
    +func (s SqlTokenStore) ConsumeOnce(tokenType, tokenStr string) (*model.Token, error) {
     	var token model.Token
     
    -	query := `DELETE FROM Tokens WHERE Token = ? RETURNING *`
    +	if s.DriverName() == model.DatabaseDriverPostgres {
    +		query := `DELETE FROM Tokens WHERE Type = ? AND Token = ? RETURNING *`
    +		if err := s.GetMaster().Get(&token, query, tokenType, tokenStr); err != nil {
    +			if err == sql.ErrNoRows {
    +				return nil, store.NewErrNotFound("Token", tokenStr)
    +			}
    +			return nil, errors.Wrapf(err, "failed to consume token with type %s", tokenType)
    +		}
    +		return &token, nil
    +	}
    +
    +	transaction, err := s.GetMaster().Beginx()
    +	if err != nil {
    +		return nil, errors.Wrap(err, "failed to begin transaction")
    +	}
    +	defer finalizeTransactionX(transaction, &err)
     
    -	if err := s.GetMaster().Get(&token, query, tokenStr); err != nil {
    +	query := `SELECT * FROM Tokens WHERE Type = ? AND Token = ? FOR UPDATE`
    +	if err = transaction.Get(&token, query, tokenType, tokenStr); err != nil {
     		if err == sql.ErrNoRows {
     			return nil, store.NewErrNotFound("Token", tokenStr)
     		}
    -		return nil, errors.Wrapf(err, "failed to consume token")
    +		return nil, errors.Wrapf(err, "failed to select token with type %s", tokenType)
    +	}
    +
    +	deleteQuery := `DELETE FROM Tokens WHERE Type = ? AND Token = ?`
    +	if _, err = transaction.Exec(deleteQuery, tokenType, tokenStr); err != nil {
    +		return nil, errors.Wrapf(err, "failed to delete token with type %s", tokenType)
    +	}
    +
    +	if err = transaction.Commit(); err != nil {
    +		return nil, errors.Wrap(err, "failed to commit transaction")
     	}
     
     	return &token, nil
    
  • server/channels/store/store.go +1 -1 modified
    @@ -693,7 +693,7 @@ type TokenStore interface {
     	Save(recovery *model.Token) error
     	Delete(token string) error
     	GetByToken(token string) (*model.Token, error)
    -	ConsumeOnce(tokenStr string) (*model.Token, error)
    +	ConsumeOnce(tokenType, tokenStr string) (*model.Token, error)
     	Cleanup(expiryTime int64)
     	GetAllTokensByType(tokenType string) ([]*model.Token, error)
     	RemoveAllTokensByType(tokenType string) error
    
  • server/channels/store/storetest/mocks/TokenStore.go +9 -9 modified
    @@ -19,29 +19,29 @@ func (_m *TokenStore) Cleanup(expiryTime int64) {
     	_m.Called(expiryTime)
     }
     
    -// ConsumeOnce provides a mock function with given fields: tokenStr
    -func (_m *TokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) {
    -	ret := _m.Called(tokenStr)
    +// ConsumeOnce provides a mock function with given fields: tokenType, tokenStr
    +func (_m *TokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) {
    +	ret := _m.Called(tokenType, tokenStr)
     
     	if len(ret) == 0 {
     		panic("no return value specified for ConsumeOnce")
     	}
     
     	var r0 *model.Token
     	var r1 error
    -	if rf, ok := ret.Get(0).(func(string) (*model.Token, error)); ok {
    -		return rf(tokenStr)
    +	if rf, ok := ret.Get(0).(func(string, string) (*model.Token, error)); ok {
    +		return rf(tokenType, tokenStr)
     	}
    -	if rf, ok := ret.Get(0).(func(string) *model.Token); ok {
    -		r0 = rf(tokenStr)
    +	if rf, ok := ret.Get(0).(func(string, string) *model.Token); ok {
    +		r0 = rf(tokenType, tokenStr)
     	} else {
     		if ret.Get(0) != nil {
     			r0 = ret.Get(0).(*model.Token)
     		}
     	}
     
    -	if rf, ok := ret.Get(1).(func(string) error); ok {
    -		r1 = rf(tokenStr)
    +	if rf, ok := ret.Get(1).(func(string, string) error); ok {
    +		r1 = rf(tokenType, tokenStr)
     	} else {
     		r1 = ret.Error(1)
     	}
    
  • server/channels/store/storetest/tokens_store.go +128 -0 modified
    @@ -16,6 +16,7 @@ import (
     
     func TestTokensStore(t *testing.T, rctx request.CTX, ss store.Store) {
     	t.Run("TokensCleanup", func(t *testing.T) { testTokensCleanup(t, rctx, ss) })
    +	t.Run("ConsumeOnce", func(t *testing.T) { testConsumeOnce(t, rctx, ss) })
     }
     
     func testTokensCleanup(t *testing.T, rctx request.CTX, ss store.Store) {
    @@ -41,3 +42,130 @@ func testTokensCleanup(t *testing.T, rctx request.CTX, ss store.Store) {
     	require.NoError(t, err)
     	assert.Len(t, tokens, 0)
     }
    +
    +func testConsumeOnce(t *testing.T, rctx request.CTX, ss store.Store) {
    +	t.Run("successfully consume token once", func(t *testing.T) {
    +		token := &model.Token{
    +			Token:    model.NewRandomString(model.TokenSize),
    +			CreateAt: model.GetMillis(),
    +			Type:     model.TokenTypeOAuth,
    +			Extra:    "test-extra",
    +		}
    +		err := ss.Token().Save(token)
    +		require.NoError(t, err)
    +
    +		consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token)
    +		require.NoError(t, err)
    +		assert.Equal(t, token.Token, consumedToken.Token)
    +		assert.Equal(t, token.Type, consumedToken.Type)
    +		assert.Equal(t, token.Extra, consumedToken.Extra)
    +
    +		tokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth)
    +		require.NoError(t, err)
    +		assert.Len(t, tokens, 0)
    +	})
    +
    +	t.Run("second consumption of same token fails", func(t *testing.T) {
    +		token := &model.Token{
    +			Token:    model.NewRandomString(model.TokenSize),
    +			CreateAt: model.GetMillis(),
    +			Type:     model.TokenTypeOAuth,
    +			Extra:    "test-extra",
    +		}
    +		err := ss.Token().Save(token)
    +		require.NoError(t, err)
    +
    +		_, err = ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token)
    +		require.NoError(t, err)
    +
    +		_, err = ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token)
    +		require.Error(t, err)
    +		var nfErr *store.ErrNotFound
    +		assert.ErrorAs(t, err, &nfErr)
    +	})
    +
    +	t.Run("consume with wrong type fails", func(t *testing.T) {
    +		token := &model.Token{
    +			Token:    model.NewRandomString(model.TokenSize),
    +			CreateAt: model.GetMillis(),
    +			Type:     model.TokenTypeOAuth,
    +			Extra:    "test-extra",
    +		}
    +		err := ss.Token().Save(token)
    +		require.NoError(t, err)
    +
    +		_, err = ss.Token().ConsumeOnce(model.TokenTypeSSOCodeExchange, token.Token)
    +		require.Error(t, err)
    +		var nfErr *store.ErrNotFound
    +		assert.ErrorAs(t, err, &nfErr)
    +
    +		tokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth)
    +		require.NoError(t, err)
    +		assert.Len(t, tokens, 1)
    +
    +		err = ss.Token().Delete(token.Token)
    +		require.NoError(t, err)
    +	})
    +
    +	t.Run("consume non-existent token fails", func(t *testing.T) {
    +		nonExistentToken := model.NewRandomString(model.TokenSize)
    +		_, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, nonExistentToken)
    +		require.Error(t, err)
    +		var nfErr *store.ErrNotFound
    +		assert.ErrorAs(t, err, &nfErr)
    +	})
    +
    +	t.Run("multiple tokens with same type can each be consumed once", func(t *testing.T) {
    +		tokens := make([]*model.Token, 3)
    +		for i := range tokens {
    +			tokens[i] = &model.Token{
    +				Token:    model.NewRandomString(model.TokenSize),
    +				CreateAt: model.GetMillis(),
    +				Type:     model.TokenTypeOAuth,
    +				Extra:    "test-extra",
    +			}
    +			err := ss.Token().Save(tokens[i])
    +			require.NoError(t, err)
    +		}
    +
    +		for _, token := range tokens {
    +			consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token)
    +			require.NoError(t, err)
    +			assert.Equal(t, token.Token, consumedToken.Token)
    +		}
    +
    +		allTokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth)
    +		require.NoError(t, err)
    +		assert.Len(t, allTokens, 0)
    +	})
    +
    +	t.Run("consuming token of different type leaves others intact", func(t *testing.T) {
    +		oauthToken := &model.Token{
    +			Token:    model.NewRandomString(model.TokenSize),
    +			CreateAt: model.GetMillis(),
    +			Type:     model.TokenTypeOAuth,
    +			Extra:    "oauth-extra",
    +		}
    +		codeExchangeToken := &model.Token{
    +			Token:    model.NewRandomString(model.TokenSize),
    +			CreateAt: model.GetMillis(),
    +			Type:     model.TokenTypeSSOCodeExchange,
    +			Extra:    "password-extra",
    +		}
    +		err := ss.Token().Save(oauthToken)
    +		require.NoError(t, err)
    +		err = ss.Token().Save(codeExchangeToken)
    +		require.NoError(t, err)
    +
    +		consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, oauthToken.Token)
    +		require.NoError(t, err)
    +		assert.Equal(t, oauthToken.Token, consumedToken.Token)
    +
    +		codeExchangeTokens, err := ss.Token().GetAllTokensByType(model.TokenTypeSSOCodeExchange)
    +		require.NoError(t, err)
    +		assert.Len(t, codeExchangeTokens, 1)
    +
    +		err = ss.Token().Delete(codeExchangeToken.Token)
    +		require.NoError(t, err)
    +	})
    +}
    
  • server/channels/store/timerlayer/timerlayer.go +2 -2 modified
    @@ -11106,10 +11106,10 @@ func (s *TimerLayerTokenStore) Cleanup(expiryTime int64) {
     	}
     }
     
    -func (s *TimerLayerTokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) {
    +func (s *TimerLayerTokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) {
     	start := time.Now()
     
    -	result, err := s.TokenStore.ConsumeOnce(tokenStr)
    +	result, err := s.TokenStore.ConsumeOnce(tokenType, tokenStr)
     
     	elapsed := float64(time.Since(start)) / float64(time.Second)
     	if s.Root.Metrics != nil {
    
  • server/channels/web/saml.go +7 -4 modified
    @@ -106,7 +106,7 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) {
     		return
     	}
     
    -	//Validate that the user is with SAML and all that
    +	// Validate that the user is with SAML and all that
     	encodedXML := r.FormValue("SAMLResponse")
     	relayState := r.FormValue("RelayState")
     
    @@ -161,7 +161,8 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) {
     		return
     	}
     
    -	if err = c.App.CheckUserAllAuthenticationCriteria(c.AppContext, user, ""); err != nil {
    +	err = c.App.CheckUserAllAuthenticationCriteria(c.AppContext, user, "")
    +	if err != nil {
     		handleError(err)
     		return
     	}
    @@ -250,8 +251,10 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) {
     			"code_challenge":        samlChallenge,
     			"code_challenge_method": samlMethod,
     		})
    -		code := model.NewToken(model.TokenTypeSaml, extra)
    -		if err := c.App.Srv().Store().Token().Save(code); err != nil {
    +
    +		var code *model.Token
    +		code, err = c.App.CreateSamlRelayToken(model.TokenTypeSSOCodeExchange, extra)
    +		if err != nil {
     			handleError(model.NewAppError("completeSaml", "app.recover.save.app_error", nil, "", http.StatusInternalServerError).Wrap(err))
     			return
     		}
    
  • server/public/model/token.go +5 -4 modified
    @@ -8,10 +8,11 @@ import (
     )
     
     const (
    -	TokenSize          = 64
    -	MaxTokenExipryTime = 1000 * 60 * 60 * 48 // 48 hour
    -	TokenTypeOAuth     = "oauth"
    -	TokenTypeSaml      = "saml"
    +	TokenSize                = 64
    +	MaxTokenExipryTime       = 1000 * 60 * 60 * 48 // 48 hour
    +	TokenTypeOAuth           = "oauth"
    +	TokenTypeSaml            = "saml"
    +	TokenTypeSSOCodeExchange = "sso-code-exchange"
     )
     
     type Token struct {
    
5072bbf689a4

Automated cherry pick of #34247 (#34256)

https://github.com/mattermost/mattermost · Mattermost Build · Oct 27, 2025 · via ghsa
14 files changed · +352 -32
  • server/channels/api4/user.go +1 -1 modified
    @@ -130,7 +130,7 @@ func loginSSOCodeExchange(c *Context, w http.ResponseWriter, r *http.Request) {
     	}
     
     	// Consume one-time code atomically
    -	token, appErr := c.App.ConsumeTokenOnce(loginCode)
    +	token, appErr := c.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, loginCode)
     	if appErr != nil {
     		c.Err = appErr
     		return
    
  • server/channels/api4/user_test.go +81 -0 modified
    @@ -6,6 +6,8 @@ package api4
     import (
     	"bytes"
     	"context"
    +	"crypto/sha256"
    +	"encoding/base64"
     	"encoding/json"
     	"fmt"
     	"image/png"
    @@ -8673,6 +8675,85 @@ func TestLoginWithDesktopToken(t *testing.T) {
     	})
     }
     
    +func TestLoginSSOCodeExchange(t *testing.T) {
    +	mainHelper.Parallel(t)
    +	th := Setup(t).InitBasic()
    +	defer th.TearDown()
    +
    +	t.Run("wrong token type cannot be used for code exchange", func(t *testing.T) {
    +		th.App.UpdateConfig(func(cfg *model.Config) {
    +			cfg.FeatureFlags.MobileSSOCodeExchange = true
    +		})
    +
    +		token := model.NewToken(model.TokenTypeOAuth, "extra-data")
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +		defer func() {
    +			_ = th.App.Srv().Store().Token().Delete(token.Token)
    +		}()
    +
    +		props := map[string]string{
    +			"login_code":    token.Token,
    +			"code_verifier": "test_verifier",
    +			"state":         "test_state",
    +		}
    +
    +		resp, err := th.Client.DoAPIPost(context.Background(), "/users/login/sso/code-exchange", model.MapToJSON(props))
    +		require.Error(t, err)
    +		require.Equal(t, http.StatusNotFound, resp.StatusCode)
    +	})
    +
    +	t.Run("successful code exchange with S256 challenge", func(t *testing.T) {
    +		th.App.UpdateConfig(func(cfg *model.Config) {
    +			cfg.FeatureFlags.MobileSSOCodeExchange = true
    +		})
    +
    +		samlUser := th.CreateUserWithAuth(model.UserAuthServiceSaml)
    +
    +		codeVerifier := "test_code_verifier_123456789"
    +		state := "test_state_value"
    +
    +		sum := sha256.Sum256([]byte(codeVerifier))
    +		codeChallenge := base64.RawURLEncoding.EncodeToString(sum[:])
    +
    +		extra := map[string]string{
    +			"user_id":               samlUser.Id,
    +			"code_challenge":        codeChallenge,
    +			"code_challenge_method": "S256",
    +			"state":                 state,
    +		}
    +
    +		token := model.NewToken(model.TokenTypeSSOCodeExchange, model.MapToJSON(extra))
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +
    +		props := map[string]string{
    +			"login_code":    token.Token,
    +			"code_verifier": codeVerifier,
    +			"state":         state,
    +		}
    +
    +		resp, err := th.Client.DoAPIPost(context.Background(), "/users/login/sso/code-exchange", model.MapToJSON(props))
    +		require.NoError(t, err)
    +		require.Equal(t, http.StatusOK, resp.StatusCode)
    +
    +		var result map[string]string
    +		require.NoError(t, json.NewDecoder(resp.Body).Decode(&result))
    +		assert.NotEmpty(t, result["token"])
    +		assert.NotEmpty(t, result["csrf"])
    +
    +		_, err = th.App.Srv().Store().Token().GetByToken(token.Token)
    +		require.Error(t, err)
    +
    +		authenticatedClient := model.NewAPIv4Client(th.Client.URL)
    +		authenticatedClient.SetToken(result["token"])
    +
    +		user, _, err := authenticatedClient.GetMe(context.Background(), "")
    +		require.NoError(t, err)
    +		assert.Equal(t, samlUser.Id, user.Id)
    +		assert.Equal(t, samlUser.Email, user.Email)
    +		assert.Equal(t, samlUser.Username, user.Username)
    +	})
    +}
    +
     func TestGetUsersByNames(t *testing.T) {
     	mainHelper.Parallel(t)
     	th := Setup(t).InitBasic()
    
  • server/channels/app/oauth.go +1 -1 modified
    @@ -977,7 +977,7 @@ func (a *App) SwitchEmailToOAuth(c request.CTX, w http.ResponseWriter, r *http.R
     	stateProps["email"] = email
     
     	if service == model.UserAuthServiceSaml {
    -		samlToken, samlErr := a.CreateSamlRelayToken(email)
    +		samlToken, samlErr := a.CreateSamlRelayToken(model.TokenTypeSaml, email)
     		if samlErr != nil {
     			return "", samlErr
     		}
    
  • server/channels/app/saml.go +2 -2 modified
    @@ -298,8 +298,8 @@ func (a *App) ResetSamlAuthDataToEmail(includeDeleted bool, dryRun bool, userIDs
     	return
     }
     
    -func (a *App) CreateSamlRelayToken(extra string) (*model.Token, *model.AppError) {
    -	token := model.NewToken(model.TokenTypeSaml, extra)
    +func (a *App) CreateSamlRelayToken(tokenType string, extra string) (*model.Token, *model.AppError) {
    +	token := model.NewToken(tokenType, extra)
     
     	if err := a.Srv().Store().Token().Save(token); err != nil {
     		var appErr *model.AppError
    
  • server/channels/app/user.go +2 -2 modified
    @@ -1750,8 +1750,8 @@ func (a *App) GetTokenById(token string) (*model.Token, *model.AppError) {
     	return rtoken, nil
     }
     
    -func (a *App) ConsumeTokenOnce(tokenStr string) (*model.Token, *model.AppError) {
    -	token, err := a.Srv().Store().Token().ConsumeOnce(tokenStr)
    +func (a *App) ConsumeTokenOnce(tokenType, tokenStr string) (*model.Token, *model.AppError) {
    +	token, err := a.Srv().Store().Token().ConsumeOnce(tokenType, tokenStr)
     	if err != nil {
     		var status int
     		switch err.(type) {
    
  • server/channels/app/user_test.go +82 -0 modified
    @@ -9,6 +9,7 @@ import (
     	"database/sql"
     	"encoding/json"
     	"errors"
    +	"net/http"
     	"os"
     	"path/filepath"
     	"strings"
    @@ -2484,3 +2485,84 @@ func TestRemoteUserDirectChannelCreation(t *testing.T) {
     		assert.Equal(t, model.ChannelTypeDirect, channel.Type)
     	})
     }
    +
    +func TestConsumeTokenOnce(t *testing.T) {
    +	mainHelper.Parallel(t)
    +	th := Setup(t).InitBasic()
    +	defer th.TearDown()
    +
    +	t.Run("successfully consume valid token", func(t *testing.T) {
    +		token := model.NewToken(model.TokenTypeOAuth, "extra-data")
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +
    +		consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, token.Token)
    +		require.Nil(t, appErr)
    +		require.NotNil(t, consumedToken)
    +		assert.Equal(t, token.Token, consumedToken.Token)
    +		assert.Equal(t, model.TokenTypeOAuth, consumedToken.Type)
    +		assert.Equal(t, "extra-data", consumedToken.Extra)
    +
    +		_, err := th.App.Srv().Store().Token().GetByToken(token.Token)
    +		require.Error(t, err)
    +	})
    +
    +	t.Run("token not found returns 404", func(t *testing.T) {
    +		nonExistentToken := model.NewRandomString(model.TokenSize)
    +
    +		consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, nonExistentToken)
    +		require.NotNil(t, appErr)
    +		require.Nil(t, consumedToken)
    +		assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
    +		assert.Equal(t, "ConsumeTokenOnce", appErr.Where)
    +	})
    +
    +	t.Run("wrong token type returns not found", func(t *testing.T) {
    +		token := model.NewToken(model.TokenTypeOAuth, "extra-data")
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +		defer func() {
    +			_ = th.App.Srv().Store().Token().Delete(token.Token)
    +		}()
    +
    +		consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSaml, token.Token)
    +		require.NotNil(t, appErr)
    +		require.Nil(t, consumedToken)
    +		assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
    +
    +		_, err := th.App.Srv().Store().Token().GetByToken(token.Token)
    +		require.NoError(t, err)
    +	})
    +
    +	t.Run("token can only be consumed once", func(t *testing.T) {
    +		token := model.NewToken(model.TokenTypeSSOCodeExchange, "extra-data")
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +
    +		consumedToken1, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, token.Token)
    +		require.Nil(t, appErr)
    +		require.NotNil(t, consumedToken1)
    +
    +		consumedToken2, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, token.Token)
    +		require.NotNil(t, appErr)
    +		require.Nil(t, consumedToken2)
    +		assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
    +	})
    +
    +	t.Run("empty token string returns not found", func(t *testing.T) {
    +		consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, "")
    +		require.NotNil(t, appErr)
    +		require.Nil(t, consumedToken)
    +		assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
    +	})
    +
    +	t.Run("empty token type returns not found", func(t *testing.T) {
    +		token := model.NewToken(model.TokenTypeOAuth, "extra-data")
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +		defer func() {
    +			_ = th.App.Srv().Store().Token().Delete(token.Token)
    +		}()
    +
    +		consumedToken, appErr := th.App.ConsumeTokenOnce("", token.Token)
    +		require.NotNil(t, appErr)
    +		require.Nil(t, consumedToken)
    +		assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
    +	})
    +}
    
  • server/channels/store/retrylayer/retrylayer.go (+2 −2, modified)
    @@ -14138,11 +14138,11 @@ func (s *RetryLayerTokenStore) Cleanup(expiryTime int64) {
     
     }
     
    -func (s *RetryLayerTokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) {
    +func (s *RetryLayerTokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) {
     
     	tries := 0
     	for {
    -		result, err := s.TokenStore.ConsumeOnce(tokenStr)
    +		result, err := s.TokenStore.ConsumeOnce(tokenType, tokenStr)
     		if err == nil {
     			return result, nil
     		}
    
  • server/channels/store/sqlstore/tokens_store.go (+29 −4, modified)
    @@ -78,16 +78,41 @@ func (s SqlTokenStore) GetByToken(tokenString string) (*model.Token, error) {
     	return &token, nil
     }
     
    -func (s SqlTokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) {
    +func (s SqlTokenStore) ConsumeOnce(tokenType, tokenStr string) (*model.Token, error) {
     	var token model.Token
     
    -	query := `DELETE FROM Tokens WHERE Token = ? RETURNING *`
    +	if s.DriverName() == model.DatabaseDriverPostgres {
    +		query := `DELETE FROM Tokens WHERE Type = ? AND Token = ? RETURNING *`
    +		if err := s.GetMaster().Get(&token, query, tokenType, tokenStr); err != nil {
    +			if err == sql.ErrNoRows {
    +				return nil, store.NewErrNotFound("Token", tokenStr)
    +			}
    +			return nil, errors.Wrapf(err, "failed to consume token with type %s", tokenType)
    +		}
    +		return &token, nil
    +	}
    +
    +	transaction, err := s.GetMaster().Beginx()
    +	if err != nil {
    +		return nil, errors.Wrap(err, "failed to begin transaction")
    +	}
    +	defer finalizeTransactionX(transaction, &err)
     
    -	if err := s.GetMaster().Get(&token, query, tokenStr); err != nil {
    +	query := `SELECT * FROM Tokens WHERE Type = ? AND Token = ? FOR UPDATE`
    +	if err = transaction.Get(&token, query, tokenType, tokenStr); err != nil {
     		if err == sql.ErrNoRows {
     			return nil, store.NewErrNotFound("Token", tokenStr)
     		}
    -		return nil, errors.Wrapf(err, "failed to consume token")
    +		return nil, errors.Wrapf(err, "failed to select token with type %s", tokenType)
    +	}
    +
    +	deleteQuery := `DELETE FROM Tokens WHERE Type = ? AND Token = ?`
    +	if _, err = transaction.Exec(deleteQuery, tokenType, tokenStr); err != nil {
    +		return nil, errors.Wrapf(err, "failed to delete token with type %s", tokenType)
    +	}
    +
    +	if err = transaction.Commit(); err != nil {
    +		return nil, errors.Wrap(err, "failed to commit transaction")
     	}
     
     	return &token, nil
    
  • server/channels/store/store.go (+1 −1, modified)
    @@ -693,7 +693,7 @@ type TokenStore interface {
     	Save(recovery *model.Token) error
     	Delete(token string) error
     	GetByToken(token string) (*model.Token, error)
    -	ConsumeOnce(tokenStr string) (*model.Token, error)
    +	ConsumeOnce(tokenType, tokenStr string) (*model.Token, error)
     	Cleanup(expiryTime int64)
     	GetAllTokensByType(tokenType string) ([]*model.Token, error)
     	RemoveAllTokensByType(tokenType string) error
    
  • server/channels/store/storetest/mocks/TokenStore.go (+9 −9, modified)
    @@ -19,29 +19,29 @@ func (_m *TokenStore) Cleanup(expiryTime int64) {
     	_m.Called(expiryTime)
     }
     
    -// ConsumeOnce provides a mock function with given fields: tokenStr
    -func (_m *TokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) {
    -	ret := _m.Called(tokenStr)
    +// ConsumeOnce provides a mock function with given fields: tokenType, tokenStr
    +func (_m *TokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) {
    +	ret := _m.Called(tokenType, tokenStr)
     
     	if len(ret) == 0 {
     		panic("no return value specified for ConsumeOnce")
     	}
     
     	var r0 *model.Token
     	var r1 error
    -	if rf, ok := ret.Get(0).(func(string) (*model.Token, error)); ok {
    -		return rf(tokenStr)
    +	if rf, ok := ret.Get(0).(func(string, string) (*model.Token, error)); ok {
    +		return rf(tokenType, tokenStr)
     	}
    -	if rf, ok := ret.Get(0).(func(string) *model.Token); ok {
    -		r0 = rf(tokenStr)
    +	if rf, ok := ret.Get(0).(func(string, string) *model.Token); ok {
    +		r0 = rf(tokenType, tokenStr)
     	} else {
     		if ret.Get(0) != nil {
     			r0 = ret.Get(0).(*model.Token)
     		}
     	}
     
    -	if rf, ok := ret.Get(1).(func(string) error); ok {
    -		r1 = rf(tokenStr)
    +	if rf, ok := ret.Get(1).(func(string, string) error); ok {
    +		r1 = rf(tokenType, tokenStr)
     	} else {
     		r1 = ret.Error(1)
     	}
    
  • server/channels/store/storetest/tokens_store.go (+128 −0, modified)
    @@ -16,6 +16,7 @@ import (
     
     func TestTokensStore(t *testing.T, rctx request.CTX, ss store.Store) {
     	t.Run("TokensCleanup", func(t *testing.T) { testTokensCleanup(t, rctx, ss) })
    +	t.Run("ConsumeOnce", func(t *testing.T) { testConsumeOnce(t, rctx, ss) })
     }
     
     func testTokensCleanup(t *testing.T, rctx request.CTX, ss store.Store) {
    @@ -41,3 +42,130 @@ func testTokensCleanup(t *testing.T, rctx request.CTX, ss store.Store) {
     	require.NoError(t, err)
     	assert.Len(t, tokens, 0)
     }
    +
    +func testConsumeOnce(t *testing.T, rctx request.CTX, ss store.Store) {
    +	t.Run("successfully consume token once", func(t *testing.T) {
    +		token := &model.Token{
    +			Token:    model.NewRandomString(model.TokenSize),
    +			CreateAt: model.GetMillis(),
    +			Type:     model.TokenTypeOAuth,
    +			Extra:    "test-extra",
    +		}
    +		err := ss.Token().Save(token)
    +		require.NoError(t, err)
    +
    +		consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token)
    +		require.NoError(t, err)
    +		assert.Equal(t, token.Token, consumedToken.Token)
    +		assert.Equal(t, token.Type, consumedToken.Type)
    +		assert.Equal(t, token.Extra, consumedToken.Extra)
    +
    +		tokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth)
    +		require.NoError(t, err)
    +		assert.Len(t, tokens, 0)
    +	})
    +
    +	t.Run("second consumption of same token fails", func(t *testing.T) {
    +		token := &model.Token{
    +			Token:    model.NewRandomString(model.TokenSize),
    +			CreateAt: model.GetMillis(),
    +			Type:     model.TokenTypeOAuth,
    +			Extra:    "test-extra",
    +		}
    +		err := ss.Token().Save(token)
    +		require.NoError(t, err)
    +
    +		_, err = ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token)
    +		require.NoError(t, err)
    +
    +		_, err = ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token)
    +		require.Error(t, err)
    +		var nfErr *store.ErrNotFound
    +		assert.ErrorAs(t, err, &nfErr)
    +	})
    +
    +	t.Run("consume with wrong type fails", func(t *testing.T) {
    +		token := &model.Token{
    +			Token:    model.NewRandomString(model.TokenSize),
    +			CreateAt: model.GetMillis(),
    +			Type:     model.TokenTypeOAuth,
    +			Extra:    "test-extra",
    +		}
    +		err := ss.Token().Save(token)
    +		require.NoError(t, err)
    +
    +		_, err = ss.Token().ConsumeOnce(model.TokenTypeSSOCodeExchange, token.Token)
    +		require.Error(t, err)
    +		var nfErr *store.ErrNotFound
    +		assert.ErrorAs(t, err, &nfErr)
    +
    +		tokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth)
    +		require.NoError(t, err)
    +		assert.Len(t, tokens, 1)
    +
    +		err = ss.Token().Delete(token.Token)
    +		require.NoError(t, err)
    +	})
    +
    +	t.Run("consume non-existent token fails", func(t *testing.T) {
    +		nonExistentToken := model.NewRandomString(model.TokenSize)
    +		_, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, nonExistentToken)
    +		require.Error(t, err)
    +		var nfErr *store.ErrNotFound
    +		assert.ErrorAs(t, err, &nfErr)
    +	})
    +
    +	t.Run("multiple tokens with same type can each be consumed once", func(t *testing.T) {
    +		tokens := make([]*model.Token, 3)
    +		for i := range tokens {
    +			tokens[i] = &model.Token{
    +				Token:    model.NewRandomString(model.TokenSize),
    +				CreateAt: model.GetMillis(),
    +				Type:     model.TokenTypeOAuth,
    +				Extra:    "test-extra",
    +			}
    +			err := ss.Token().Save(tokens[i])
    +			require.NoError(t, err)
    +		}
    +
    +		for _, token := range tokens {
    +			consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token)
    +			require.NoError(t, err)
    +			assert.Equal(t, token.Token, consumedToken.Token)
    +		}
    +
    +		allTokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth)
    +		require.NoError(t, err)
    +		assert.Len(t, allTokens, 0)
    +	})
    +
    +	t.Run("consuming token of different type leaves others intact", func(t *testing.T) {
    +		oauthToken := &model.Token{
    +			Token:    model.NewRandomString(model.TokenSize),
    +			CreateAt: model.GetMillis(),
    +			Type:     model.TokenTypeOAuth,
    +			Extra:    "oauth-extra",
    +		}
    +		codeExchangeToken := &model.Token{
    +			Token:    model.NewRandomString(model.TokenSize),
    +			CreateAt: model.GetMillis(),
    +			Type:     model.TokenTypeSSOCodeExchange,
    +			Extra:    "password-extra",
    +		}
    +		err := ss.Token().Save(oauthToken)
    +		require.NoError(t, err)
    +		err = ss.Token().Save(codeExchangeToken)
    +		require.NoError(t, err)
    +
    +		consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, oauthToken.Token)
    +		require.NoError(t, err)
    +		assert.Equal(t, oauthToken.Token, consumedToken.Token)
    +
    +		codeExchangeTokens, err := ss.Token().GetAllTokensByType(model.TokenTypeSSOCodeExchange)
    +		require.NoError(t, err)
    +		assert.Len(t, codeExchangeTokens, 1)
    +
    +		err = ss.Token().Delete(codeExchangeToken.Token)
    +		require.NoError(t, err)
    +	})
    +}
    
  • server/channels/store/timerlayer/timerlayer.go (+2 −2, modified)
    @@ -11106,10 +11106,10 @@ func (s *TimerLayerTokenStore) Cleanup(expiryTime int64) {
     	}
     }
     
    -func (s *TimerLayerTokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) {
    +func (s *TimerLayerTokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) {
     	start := time.Now()
     
    -	result, err := s.TokenStore.ConsumeOnce(tokenStr)
    +	result, err := s.TokenStore.ConsumeOnce(tokenType, tokenStr)
     
     	elapsed := float64(time.Since(start)) / float64(time.Second)
     	if s.Root.Metrics != nil {
    
  • server/channels/web/saml.go (+7 −4, modified)
    @@ -106,7 +106,7 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) {
     		return
     	}
     
    -	//Validate that the user is with SAML and all that
    +	// Validate that the user is with SAML and all that
     	encodedXML := r.FormValue("SAMLResponse")
     	relayState := r.FormValue("RelayState")
     
    @@ -161,7 +161,8 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) {
     		return
     	}
     
    -	if err = c.App.CheckUserAllAuthenticationCriteria(c.AppContext, user, ""); err != nil {
    +	err = c.App.CheckUserAllAuthenticationCriteria(c.AppContext, user, "")
    +	if err != nil {
     		handleError(err)
     		return
     	}
    @@ -250,8 +251,10 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) {
     			"code_challenge":        samlChallenge,
     			"code_challenge_method": samlMethod,
     		})
    -		code := model.NewToken(model.TokenTypeSaml, extra)
    -		if err := c.App.Srv().Store().Token().Save(code); err != nil {
    +
    +		var code *model.Token
    +		code, err = c.App.CreateSamlRelayToken(model.TokenTypeSSOCodeExchange, extra)
    +		if err != nil {
     			handleError(model.NewAppError("completeSaml", "app.recover.save.app_error", nil, "", http.StatusInternalServerError).Wrap(err))
     			return
     		}
    
  • server/public/model/token.go (+5 −4, modified)
    @@ -8,10 +8,11 @@ import (
     )
     
     const (
    -	TokenSize          = 64
    -	MaxTokenExipryTime = 1000 * 60 * 60 * 48 // 48 hour
    -	TokenTypeOAuth     = "oauth"
    -	TokenTypeSaml      = "saml"
    +	TokenSize                = 64
    +	MaxTokenExipryTime       = 1000 * 60 * 60 * 48 // 48 hour
    +	TokenTypeOAuth           = "oauth"
    +	TokenTypeSaml            = "saml"
    +	TokenTypeSSOCodeExchange = "sso-code-exchange"
     )
     
     type Token struct {
    
feb598ed2b7a

MM-66299: type handling for ConsumeTokenOnce (#34247) (#34261)

https://github.com/mattermost/mattermost · Jesse Hallam · Oct 24, 2025 · via GHSA
14 files changed · +397 −28
  • server/channels/api4/user.go (+1 −1, modified)
    @@ -130,7 +130,7 @@ func loginSSOCodeExchange(c *Context, w http.ResponseWriter, r *http.Request) {
     	}
     
     	// Consume one-time code atomically
    -	token, appErr := c.App.ConsumeTokenOnce(loginCode)
    +	token, appErr := c.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, loginCode)
     	if appErr != nil {
     		c.Err = appErr
     		return
    
  • server/channels/api4/user_test.go (+81 −0, modified)
    @@ -6,6 +6,8 @@ package api4
     import (
     	"bytes"
     	"context"
    +	"crypto/sha256"
    +	"encoding/base64"
     	"encoding/json"
     	"fmt"
     	"image/png"
    @@ -8490,6 +8492,85 @@ func TestLoginWithDesktopToken(t *testing.T) {
     	})
     }
     
    +func TestLoginSSOCodeExchange(t *testing.T) {
    +	mainHelper.Parallel(t)
    +	th := Setup(t).InitBasic()
    +	defer th.TearDown()
    +
    +	t.Run("wrong token type cannot be used for code exchange", func(t *testing.T) {
    +		th.App.UpdateConfig(func(cfg *model.Config) {
    +			cfg.FeatureFlags.MobileSSOCodeExchange = true
    +		})
    +
    +		token := model.NewToken(model.TokenTypeOAuth, "extra-data")
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +		defer func() {
    +			_ = th.App.Srv().Store().Token().Delete(token.Token)
    +		}()
    +
    +		props := map[string]string{
    +			"login_code":    token.Token,
    +			"code_verifier": "test_verifier",
    +			"state":         "test_state",
    +		}
    +
    +		resp, err := th.Client.DoAPIPost(context.Background(), "/users/login/sso/code-exchange", model.MapToJSON(props))
    +		require.Error(t, err)
    +		require.Equal(t, http.StatusNotFound, resp.StatusCode)
    +	})
    +
    +	t.Run("successful code exchange with S256 challenge", func(t *testing.T) {
    +		th.App.UpdateConfig(func(cfg *model.Config) {
    +			cfg.FeatureFlags.MobileSSOCodeExchange = true
    +		})
    +
    +		samlUser := th.CreateUserWithAuth(model.UserAuthServiceSaml)
    +
    +		codeVerifier := "test_code_verifier_123456789"
    +		state := "test_state_value"
    +
    +		sum := sha256.Sum256([]byte(codeVerifier))
    +		codeChallenge := base64.RawURLEncoding.EncodeToString(sum[:])
    +
    +		extra := map[string]string{
    +			"user_id":               samlUser.Id,
    +			"code_challenge":        codeChallenge,
    +			"code_challenge_method": "S256",
    +			"state":                 state,
    +		}
    +
    +		token := model.NewToken(model.TokenTypeSSOCodeExchange, model.MapToJSON(extra))
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +
    +		props := map[string]string{
    +			"login_code":    token.Token,
    +			"code_verifier": codeVerifier,
    +			"state":         state,
    +		}
    +
    +		resp, err := th.Client.DoAPIPost(context.Background(), "/users/login/sso/code-exchange", model.MapToJSON(props))
    +		require.NoError(t, err)
    +		require.Equal(t, http.StatusOK, resp.StatusCode)
    +
    +		var result map[string]string
    +		require.NoError(t, json.NewDecoder(resp.Body).Decode(&result))
    +		assert.NotEmpty(t, result["token"])
    +		assert.NotEmpty(t, result["csrf"])
    +
    +		_, err = th.App.Srv().Store().Token().GetByToken(token.Token)
    +		require.Error(t, err)
    +
    +		authenticatedClient := model.NewAPIv4Client(th.Client.URL)
    +		authenticatedClient.SetToken(result["token"])
    +
    +		user, _, err := authenticatedClient.GetMe(context.Background(), "")
    +		require.NoError(t, err)
    +		assert.Equal(t, samlUser.Id, user.Id)
    +		assert.Equal(t, samlUser.Email, user.Email)
    +		assert.Equal(t, samlUser.Username, user.Username)
    +	})
    +}
    +
     func TestGetUsersByNames(t *testing.T) {
     	mainHelper.Parallel(t)
     	th := Setup(t).InitBasic()
    
  • server/channels/app/oauth.go (+1 −1, modified)
    @@ -977,7 +977,7 @@ func (a *App) SwitchEmailToOAuth(rctx request.CTX, w http.ResponseWriter, r *htt
     	stateProps["email"] = email
     
     	if service == model.UserAuthServiceSaml {
    -		samlToken, samlErr := a.CreateSamlRelayToken(email)
    +		samlToken, samlErr := a.CreateSamlRelayToken(model.TokenTypeSaml, email)
     		if samlErr != nil {
     			return "", samlErr
     		}
    
  • server/channels/app/saml.go (+2 −2, modified)
    @@ -298,8 +298,8 @@ func (a *App) ResetSamlAuthDataToEmail(includeDeleted bool, dryRun bool, userIDs
     	return
     }
     
    -func (a *App) CreateSamlRelayToken(extra string) (*model.Token, *model.AppError) {
    -	token := model.NewToken(model.TokenTypeSaml, extra)
    +func (a *App) CreateSamlRelayToken(tokenType string, extra string) (*model.Token, *model.AppError) {
    +	token := model.NewToken(tokenType, extra)
     
     	if err := a.Srv().Store().Token().Save(token); err != nil {
     		var appErr *model.AppError
    
  • server/channels/app/user.go (+2 −2, modified)
    @@ -1751,8 +1751,8 @@ func (a *App) GetTokenById(token string) (*model.Token, *model.AppError) {
     	return rtoken, nil
     }
     
    -func (a *App) ConsumeTokenOnce(tokenStr string) (*model.Token, *model.AppError) {
    -	token, err := a.Srv().Store().Token().ConsumeOnce(tokenStr)
    +func (a *App) ConsumeTokenOnce(tokenType, tokenStr string) (*model.Token, *model.AppError) {
    +	token, err := a.Srv().Store().Token().ConsumeOnce(tokenType, tokenStr)
     	if err != nil {
     		var status int
     		switch err.(type) {
    
  • server/channels/app/user_test.go (+82 −0, modified)
    @@ -9,6 +9,7 @@ import (
     	"database/sql"
     	"encoding/json"
     	"errors"
    +	"net/http"
     	"os"
     	"path/filepath"
     	"strings"
    @@ -2484,3 +2485,84 @@ func TestRemoteUserDirectChannelCreation(t *testing.T) {
     		assert.Equal(t, model.ChannelTypeDirect, channel.Type)
     	})
     }
    +
    +func TestConsumeTokenOnce(t *testing.T) {
    +	mainHelper.Parallel(t)
    +	th := Setup(t).InitBasic()
    +	defer th.TearDown()
    +
    +	t.Run("successfully consume valid token", func(t *testing.T) {
    +		token := model.NewToken(model.TokenTypeOAuth, "extra-data")
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +
    +		consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, token.Token)
    +		require.Nil(t, appErr)
    +		require.NotNil(t, consumedToken)
    +		assert.Equal(t, token.Token, consumedToken.Token)
    +		assert.Equal(t, model.TokenTypeOAuth, consumedToken.Type)
    +		assert.Equal(t, "extra-data", consumedToken.Extra)
    +
    +		_, err := th.App.Srv().Store().Token().GetByToken(token.Token)
    +		require.Error(t, err)
    +	})
    +
    +	t.Run("token not found returns 404", func(t *testing.T) {
    +		nonExistentToken := model.NewRandomString(model.TokenSize)
    +
    +		consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, nonExistentToken)
    +		require.NotNil(t, appErr)
    +		require.Nil(t, consumedToken)
    +		assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
    +		assert.Equal(t, "ConsumeTokenOnce", appErr.Where)
    +	})
    +
    +	t.Run("wrong token type returns not found", func(t *testing.T) {
    +		token := model.NewToken(model.TokenTypeOAuth, "extra-data")
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +		defer func() {
    +			_ = th.App.Srv().Store().Token().Delete(token.Token)
    +		}()
    +
    +		consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSaml, token.Token)
    +		require.NotNil(t, appErr)
    +		require.Nil(t, consumedToken)
    +		assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
    +
    +		_, err := th.App.Srv().Store().Token().GetByToken(token.Token)
    +		require.NoError(t, err)
    +	})
    +
    +	t.Run("token can only be consumed once", func(t *testing.T) {
    +		token := model.NewToken(model.TokenTypeSSOCodeExchange, "extra-data")
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +
    +		consumedToken1, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, token.Token)
    +		require.Nil(t, appErr)
    +		require.NotNil(t, consumedToken1)
    +
    +		consumedToken2, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, token.Token)
    +		require.NotNil(t, appErr)
    +		require.Nil(t, consumedToken2)
    +		assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
    +	})
    +
    +	t.Run("empty token string returns not found", func(t *testing.T) {
    +		consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, "")
    +		require.NotNil(t, appErr)
    +		require.Nil(t, consumedToken)
    +		assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
    +	})
    +
    +	t.Run("empty token type returns not found", func(t *testing.T) {
    +		token := model.NewToken(model.TokenTypeOAuth, "extra-data")
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +		defer func() {
    +			_ = th.App.Srv().Store().Token().Delete(token.Token)
    +		}()
    +
    +		consumedToken, appErr := th.App.ConsumeTokenOnce("", token.Token)
    +		require.NotNil(t, appErr)
    +		require.Nil(t, consumedToken)
    +		assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
    +	})
    +}
    
  • server/channels/store/retrylayer/retrylayer.go (+42 −0, modified)
    @@ -6438,6 +6438,27 @@ func (s *RetryLayerJobStore) GetAllByTypesPage(rctx request.CTX, jobTypes []stri
     
     }
     
    +func (s *RetryLayerJobStore) GetByTypeAndData(rctx request.CTX, jobType string, data map[string]string, useMaster bool, statuses ...string) ([]*model.Job, error) {
    +
    +	tries := 0
    +	for {
    +		result, err := s.JobStore.GetByTypeAndData(rctx, jobType, data, useMaster, statuses...)
    +		if err == nil {
    +			return result, nil
    +		}
    +		if !isRepeatableError(err) {
    +			return result, err
    +		}
    +		tries++
    +		if tries >= 3 {
    +			err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures")
    +			return result, err
    +		}
    +		timepkg.Sleep(100 * timepkg.Millisecond)
    +	}
    +
    +}
    +
     func (s *RetryLayerJobStore) GetCountByStatusAndType(status string, jobType string) (int64, error) {
     
     	tries := 0
    @@ -14130,6 +14151,27 @@ func (s *RetryLayerTokenStore) Cleanup(expiryTime int64) {
     
     }
     
    +func (s *RetryLayerTokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) {
    +
    +	tries := 0
    +	for {
    +		result, err := s.TokenStore.ConsumeOnce(tokenType, tokenStr)
    +		if err == nil {
    +			return result, nil
    +		}
    +		if !isRepeatableError(err) {
    +			return result, err
    +		}
    +		tries++
    +		if tries >= 3 {
    +			err = errors.Wrap(err, "giving up after 3 consecutive repeatable transaction failures")
    +			return result, err
    +		}
    +		timepkg.Sleep(100 * timepkg.Millisecond)
    +	}
    +
    +}
    +
     func (s *RetryLayerTokenStore) Delete(token string) error {
     
     	tries := 0
    
  • server/channels/store/sqlstore/tokens_store.go (+4 −4, modified)
    @@ -78,16 +78,16 @@ func (s SqlTokenStore) GetByToken(tokenString string) (*model.Token, error) {
     	return &token, nil
     }
     
    -func (s SqlTokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) {
    +func (s SqlTokenStore) ConsumeOnce(tokenType, tokenStr string) (*model.Token, error) {
     	var token model.Token
     
    -	query := `DELETE FROM Tokens WHERE Token = ? RETURNING *`
    +	query := `DELETE FROM Tokens WHERE Type = ? AND Token = ? RETURNING *`
     
    -	if err := s.GetMaster().Get(&token, query, tokenStr); err != nil {
    +	if err := s.GetMaster().Get(&token, query, tokenType, tokenStr); err != nil {
     		if err == sql.ErrNoRows {
     			return nil, store.NewErrNotFound("Token", tokenStr)
     		}
    -		return nil, errors.Wrapf(err, "failed to consume token")
    +		return nil, errors.Wrapf(err, "failed to consume token with type %s", tokenType)
     	}
     
     	return &token, nil
    
  • server/channels/store/store.go (+1 −1, modified)
    @@ -692,7 +692,7 @@ type TokenStore interface {
     	Save(recovery *model.Token) error
     	Delete(token string) error
     	GetByToken(token string) (*model.Token, error)
    -	ConsumeOnce(tokenStr string) (*model.Token, error)
    +	ConsumeOnce(tokenType, tokenStr string) (*model.Token, error)
     	Cleanup(expiryTime int64)
     	GetAllTokensByType(tokenType string) ([]*model.Token, error)
     	RemoveAllTokensByType(tokenType string) error
    
  • server/channels/store/storetest/mocks/TokenStore.go (+9 −9, modified)
    @@ -19,29 +19,29 @@ func (_m *TokenStore) Cleanup(expiryTime int64) {
     	_m.Called(expiryTime)
     }
     
    -// ConsumeOnce provides a mock function with given fields: tokenStr
    -func (_m *TokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) {
    -	ret := _m.Called(tokenStr)
    +// ConsumeOnce provides a mock function with given fields: tokenType, tokenStr
    +func (_m *TokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) {
    +	ret := _m.Called(tokenType, tokenStr)
     
     	if len(ret) == 0 {
     		panic("no return value specified for ConsumeOnce")
     	}
     
     	var r0 *model.Token
     	var r1 error
    -	if rf, ok := ret.Get(0).(func(string) (*model.Token, error)); ok {
    -		return rf(tokenStr)
    +	if rf, ok := ret.Get(0).(func(string, string) (*model.Token, error)); ok {
    +		return rf(tokenType, tokenStr)
     	}
    -	if rf, ok := ret.Get(0).(func(string) *model.Token); ok {
    -		r0 = rf(tokenStr)
    +	if rf, ok := ret.Get(0).(func(string, string) *model.Token); ok {
    +		r0 = rf(tokenType, tokenStr)
     	} else {
     		if ret.Get(0) != nil {
     			r0 = ret.Get(0).(*model.Token)
     		}
     	}
     
    -	if rf, ok := ret.Get(1).(func(string) error); ok {
    -		r1 = rf(tokenStr)
    +	if rf, ok := ret.Get(1).(func(string, string) error); ok {
    +		r1 = rf(tokenType, tokenStr)
     	} else {
     		r1 = ret.Error(1)
     	}
    
  • server/channels/store/storetest/tokens_store.go (+128 −0, modified)
    @@ -16,6 +16,7 @@ import (
     
     func TestTokensStore(t *testing.T, rctx request.CTX, ss store.Store) {
     	t.Run("TokensCleanup", func(t *testing.T) { testTokensCleanup(t, rctx, ss) })
    +	t.Run("ConsumeOnce", func(t *testing.T) { testConsumeOnce(t, rctx, ss) })
     }
     
     func testTokensCleanup(t *testing.T, rctx request.CTX, ss store.Store) {
    @@ -41,3 +42,130 @@ func testTokensCleanup(t *testing.T, rctx request.CTX, ss store.Store) {
     	require.NoError(t, err)
     	assert.Len(t, tokens, 0)
     }
    +
    +func testConsumeOnce(t *testing.T, rctx request.CTX, ss store.Store) {
    +	t.Run("successfully consume token once", func(t *testing.T) {
    +		token := &model.Token{
    +			Token:    model.NewRandomString(model.TokenSize),
    +			CreateAt: model.GetMillis(),
    +			Type:     model.TokenTypeOAuth,
    +			Extra:    "test-extra",
    +		}
    +		err := ss.Token().Save(token)
    +		require.NoError(t, err)
    +
    +		consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token)
    +		require.NoError(t, err)
    +		assert.Equal(t, token.Token, consumedToken.Token)
    +		assert.Equal(t, token.Type, consumedToken.Type)
    +		assert.Equal(t, token.Extra, consumedToken.Extra)
    +
    +		tokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth)
    +		require.NoError(t, err)
    +		assert.Len(t, tokens, 0)
    +	})
    +
    +	t.Run("second consumption of same token fails", func(t *testing.T) {
    +		token := &model.Token{
    +			Token:    model.NewRandomString(model.TokenSize),
    +			CreateAt: model.GetMillis(),
    +			Type:     model.TokenTypeOAuth,
    +			Extra:    "test-extra",
    +		}
    +		err := ss.Token().Save(token)
    +		require.NoError(t, err)
    +
    +		_, err = ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token)
    +		require.NoError(t, err)
    +
    +		_, err = ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token)
    +		require.Error(t, err)
    +		var nfErr *store.ErrNotFound
    +		assert.ErrorAs(t, err, &nfErr)
    +	})
    +
    +	t.Run("consume with wrong type fails", func(t *testing.T) {
    +		token := &model.Token{
    +			Token:    model.NewRandomString(model.TokenSize),
    +			CreateAt: model.GetMillis(),
    +			Type:     model.TokenTypeOAuth,
    +			Extra:    "test-extra",
    +		}
    +		err := ss.Token().Save(token)
    +		require.NoError(t, err)
    +
    +		_, err = ss.Token().ConsumeOnce(model.TokenTypeSSOCodeExchange, token.Token)
    +		require.Error(t, err)
    +		var nfErr *store.ErrNotFound
    +		assert.ErrorAs(t, err, &nfErr)
    +
    +		tokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth)
    +		require.NoError(t, err)
    +		assert.Len(t, tokens, 1)
    +
    +		err = ss.Token().Delete(token.Token)
    +		require.NoError(t, err)
    +	})
    +
    +	t.Run("consume non-existent token fails", func(t *testing.T) {
    +		nonExistentToken := model.NewRandomString(model.TokenSize)
    +		_, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, nonExistentToken)
    +		require.Error(t, err)
    +		var nfErr *store.ErrNotFound
    +		assert.ErrorAs(t, err, &nfErr)
    +	})
    +
    +	t.Run("multiple tokens with same type can each be consumed once", func(t *testing.T) {
    +		tokens := make([]*model.Token, 3)
    +		for i := range tokens {
    +			tokens[i] = &model.Token{
    +				Token:    model.NewRandomString(model.TokenSize),
    +				CreateAt: model.GetMillis(),
    +				Type:     model.TokenTypeOAuth,
    +				Extra:    "test-extra",
    +			}
    +			err := ss.Token().Save(tokens[i])
    +			require.NoError(t, err)
    +		}
    +
    +		for _, token := range tokens {
    +			consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token)
    +			require.NoError(t, err)
    +			assert.Equal(t, token.Token, consumedToken.Token)
    +		}
    +
    +		allTokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth)
    +		require.NoError(t, err)
    +		assert.Len(t, allTokens, 0)
    +	})
    +
    +	t.Run("consuming token of different type leaves others intact", func(t *testing.T) {
    +		oauthToken := &model.Token{
    +			Token:    model.NewRandomString(model.TokenSize),
    +			CreateAt: model.GetMillis(),
    +			Type:     model.TokenTypeOAuth,
    +			Extra:    "oauth-extra",
    +		}
    +		codeExchangeToken := &model.Token{
    +			Token:    model.NewRandomString(model.TokenSize),
    +			CreateAt: model.GetMillis(),
    +			Type:     model.TokenTypeSSOCodeExchange,
    +			Extra:    "password-extra",
    +		}
    +		err := ss.Token().Save(oauthToken)
    +		require.NoError(t, err)
    +		err = ss.Token().Save(codeExchangeToken)
    +		require.NoError(t, err)
    +
    +		consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, oauthToken.Token)
    +		require.NoError(t, err)
    +		assert.Equal(t, oauthToken.Token, consumedToken.Token)
    +
    +		codeExchangeTokens, err := ss.Token().GetAllTokensByType(model.TokenTypeSSOCodeExchange)
    +		require.NoError(t, err)
    +		assert.Len(t, codeExchangeTokens, 1)
    +
    +		err = ss.Token().Delete(codeExchangeToken.Token)
    +		require.NoError(t, err)
    +	})
    +}
    
  • server/channels/store/timerlayer/timerlayer.go (+32 -0, modified)
    @@ -5159,6 +5159,22 @@ func (s *TimerLayerJobStore) GetAllByTypesPage(rctx request.CTX, jobTypes []stri
     	return result, err
     }
     
    +func (s *TimerLayerJobStore) GetByTypeAndData(rctx request.CTX, jobType string, data map[string]string, useMaster bool, statuses ...string) ([]*model.Job, error) {
    +	start := time.Now()
    +
    +	result, err := s.JobStore.GetByTypeAndData(rctx, jobType, data, useMaster, statuses...)
    +
    +	elapsed := float64(time.Since(start)) / float64(time.Second)
    +	if s.Root.Metrics != nil {
    +		success := "false"
    +		if err == nil {
    +			success = "true"
    +		}
    +		s.Root.Metrics.ObserveStoreMethodDuration("JobStore.GetByTypeAndData", success, elapsed)
    +	}
    +	return result, err
    +}
    +
     func (s *TimerLayerJobStore) GetCountByStatusAndType(status string, jobType string) (int64, error) {
     	start := time.Now()
     
    @@ -11106,6 +11122,22 @@ func (s *TimerLayerTokenStore) Cleanup(expiryTime int64) {
     	}
     }
     
    +func (s *TimerLayerTokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) {
    +	start := time.Now()
    +
    +	result, err := s.TokenStore.ConsumeOnce(tokenType, tokenStr)
    +
    +	elapsed := float64(time.Since(start)) / float64(time.Second)
    +	if s.Root.Metrics != nil {
    +		success := "false"
    +		if err == nil {
    +			success = "true"
    +		}
    +		s.Root.Metrics.ObserveStoreMethodDuration("TokenStore.ConsumeOnce", success, elapsed)
    +	}
    +	return result, err
    +}
    +
     func (s *TimerLayerTokenStore) Delete(token string) error {
     	start := time.Now()
     
    
  • server/channels/web/saml.go (+7 -4, modified)
    @@ -106,7 +106,7 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) {
     		return
     	}
     
    -	//Validate that the user is with SAML and all that
    +	// Validate that the user is with SAML and all that
     	encodedXML := r.FormValue("SAMLResponse")
     	relayState := r.FormValue("RelayState")
     
    @@ -161,7 +161,8 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) {
     		return
     	}
     
    -	if err = c.App.CheckUserAllAuthenticationCriteria(c.AppContext, user, ""); err != nil {
    +	err = c.App.CheckUserAllAuthenticationCriteria(c.AppContext, user, "")
    +	if err != nil {
     		handleError(err)
     		return
     	}
    @@ -250,8 +251,10 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) {
     			"code_challenge":        samlChallenge,
     			"code_challenge_method": samlMethod,
     		})
    -		code := model.NewToken(model.TokenTypeSaml, extra)
    -		if err := c.App.Srv().Store().Token().Save(code); err != nil {
    +
    +		var code *model.Token
    +		code, err = c.App.CreateSamlRelayToken(model.TokenTypeSSOCodeExchange, extra)
    +		if err != nil {
     			handleError(model.NewAppError("completeSaml", "app.recover.save.app_error", nil, "", http.StatusInternalServerError).Wrap(err))
     			return
     		}
    
  • server/public/model/token.go (+5 -4, modified)
    @@ -8,10 +8,11 @@ import (
     )
     
     const (
    -	TokenSize          = 64
    -	MaxTokenExipryTime = 1000 * 60 * 60 * 48 // 48 hour
    -	TokenTypeOAuth     = "oauth"
    -	TokenTypeSaml      = "saml"
    +	TokenSize                = 64
    +	MaxTokenExipryTime       = 1000 * 60 * 60 * 48 // 48 hour
    +	TokenTypeOAuth           = "oauth"
    +	TokenTypeSaml            = "saml"
    +	TokenTypeSSOCodeExchange = "sso-code-exchange"
     )
     
     type Token struct {
    
acda1fb5dd46

MM-66299: type handling for ConsumeTokenOnce (#34247)

https://github.com/mattermost/mattermost · Jesse Hallam · Oct 22, 2025 · via GHSA
14 files changed · +327 -32
  • server/channels/api4/user.go (+1 -1, modified)
    @@ -130,7 +130,7 @@ func loginSSOCodeExchange(c *Context, w http.ResponseWriter, r *http.Request) {
     	}
     
     	// Consume one-time code atomically
    -	token, appErr := c.App.ConsumeTokenOnce(loginCode)
    +	token, appErr := c.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, loginCode)
     	if appErr != nil {
     		c.Err = appErr
     		return
    
  • server/channels/api4/user_test.go (+81 -0, modified)
    @@ -6,6 +6,8 @@ package api4
     import (
     	"bytes"
     	"context"
    +	"crypto/sha256"
    +	"encoding/base64"
     	"encoding/json"
     	"fmt"
     	"image/png"
    @@ -8490,6 +8492,85 @@ func TestLoginWithDesktopToken(t *testing.T) {
     	})
     }
     
    +func TestLoginSSOCodeExchange(t *testing.T) {
    +	mainHelper.Parallel(t)
    +	th := Setup(t).InitBasic()
    +	defer th.TearDown()
    +
    +	t.Run("wrong token type cannot be used for code exchange", func(t *testing.T) {
    +		th.App.UpdateConfig(func(cfg *model.Config) {
    +			cfg.FeatureFlags.MobileSSOCodeExchange = true
    +		})
    +
    +		token := model.NewToken(model.TokenTypeOAuth, "extra-data")
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +		defer func() {
    +			_ = th.App.Srv().Store().Token().Delete(token.Token)
    +		}()
    +
    +		props := map[string]string{
    +			"login_code":    token.Token,
    +			"code_verifier": "test_verifier",
    +			"state":         "test_state",
    +		}
    +
    +		resp, err := th.Client.DoAPIPost(context.Background(), "/users/login/sso/code-exchange", model.MapToJSON(props))
    +		require.Error(t, err)
    +		require.Equal(t, http.StatusNotFound, resp.StatusCode)
    +	})
    +
    +	t.Run("successful code exchange with S256 challenge", func(t *testing.T) {
    +		th.App.UpdateConfig(func(cfg *model.Config) {
    +			cfg.FeatureFlags.MobileSSOCodeExchange = true
    +		})
    +
    +		samlUser := th.CreateUserWithAuth(model.UserAuthServiceSaml)
    +
    +		codeVerifier := "test_code_verifier_123456789"
    +		state := "test_state_value"
    +
    +		sum := sha256.Sum256([]byte(codeVerifier))
    +		codeChallenge := base64.RawURLEncoding.EncodeToString(sum[:])
    +
    +		extra := map[string]string{
    +			"user_id":               samlUser.Id,
    +			"code_challenge":        codeChallenge,
    +			"code_challenge_method": "S256",
    +			"state":                 state,
    +		}
    +
    +		token := model.NewToken(model.TokenTypeSSOCodeExchange, model.MapToJSON(extra))
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +
    +		props := map[string]string{
    +			"login_code":    token.Token,
    +			"code_verifier": codeVerifier,
    +			"state":         state,
    +		}
    +
    +		resp, err := th.Client.DoAPIPost(context.Background(), "/users/login/sso/code-exchange", model.MapToJSON(props))
    +		require.NoError(t, err)
    +		require.Equal(t, http.StatusOK, resp.StatusCode)
    +
    +		var result map[string]string
    +		require.NoError(t, json.NewDecoder(resp.Body).Decode(&result))
    +		assert.NotEmpty(t, result["token"])
    +		assert.NotEmpty(t, result["csrf"])
    +
    +		_, err = th.App.Srv().Store().Token().GetByToken(token.Token)
    +		require.Error(t, err)
    +
    +		authenticatedClient := model.NewAPIv4Client(th.Client.URL)
    +		authenticatedClient.SetToken(result["token"])
    +
    +		user, _, err := authenticatedClient.GetMe(context.Background(), "")
    +		require.NoError(t, err)
    +		assert.Equal(t, samlUser.Id, user.Id)
    +		assert.Equal(t, samlUser.Email, user.Email)
    +		assert.Equal(t, samlUser.Username, user.Username)
    +	})
    +}
    +
     func TestGetUsersByNames(t *testing.T) {
     	mainHelper.Parallel(t)
     	th := Setup(t).InitBasic()
    
  • server/channels/app/oauth.go (+1 -1, modified)
    @@ -977,7 +977,7 @@ func (a *App) SwitchEmailToOAuth(rctx request.CTX, w http.ResponseWriter, r *htt
     	stateProps["email"] = email
     
     	if service == model.UserAuthServiceSaml {
    -		samlToken, samlErr := a.CreateSamlRelayToken(email)
    +		samlToken, samlErr := a.CreateSamlRelayToken(model.TokenTypeSaml, email)
     		if samlErr != nil {
     			return "", samlErr
     		}
    
  • server/channels/app/saml.go (+2 -2, modified)
    @@ -298,8 +298,8 @@ func (a *App) ResetSamlAuthDataToEmail(includeDeleted bool, dryRun bool, userIDs
     	return
     }
     
    -func (a *App) CreateSamlRelayToken(extra string) (*model.Token, *model.AppError) {
    -	token := model.NewToken(model.TokenTypeSaml, extra)
    +func (a *App) CreateSamlRelayToken(tokenType string, extra string) (*model.Token, *model.AppError) {
    +	token := model.NewToken(tokenType, extra)
     
     	if err := a.Srv().Store().Token().Save(token); err != nil {
     		var appErr *model.AppError
    
  • server/channels/app/user.go (+2 -2, modified)
    @@ -1750,8 +1750,8 @@ func (a *App) GetTokenById(token string) (*model.Token, *model.AppError) {
     	return rtoken, nil
     }
     
    -func (a *App) ConsumeTokenOnce(tokenStr string) (*model.Token, *model.AppError) {
    -	token, err := a.Srv().Store().Token().ConsumeOnce(tokenStr)
    +func (a *App) ConsumeTokenOnce(tokenType, tokenStr string) (*model.Token, *model.AppError) {
    +	token, err := a.Srv().Store().Token().ConsumeOnce(tokenType, tokenStr)
     	if err != nil {
     		var status int
     		switch err.(type) {
    
  • server/channels/app/user_test.go (+82 -0, modified)
    @@ -8,6 +8,7 @@ import (
     	"database/sql"
     	"encoding/json"
     	"errors"
    +	"net/http"
     	"os"
     	"path/filepath"
     	"strings"
    @@ -2483,3 +2484,84 @@ func TestRemoteUserDirectChannelCreation(t *testing.T) {
     		assert.Equal(t, model.ChannelTypeDirect, channel.Type)
     	})
     }
    +
    +func TestConsumeTokenOnce(t *testing.T) {
    +	mainHelper.Parallel(t)
    +	th := Setup(t).InitBasic()
    +	defer th.TearDown()
    +
    +	t.Run("successfully consume valid token", func(t *testing.T) {
    +		token := model.NewToken(model.TokenTypeOAuth, "extra-data")
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +
    +		consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, token.Token)
    +		require.Nil(t, appErr)
    +		require.NotNil(t, consumedToken)
    +		assert.Equal(t, token.Token, consumedToken.Token)
    +		assert.Equal(t, model.TokenTypeOAuth, consumedToken.Type)
    +		assert.Equal(t, "extra-data", consumedToken.Extra)
    +
    +		_, err := th.App.Srv().Store().Token().GetByToken(token.Token)
    +		require.Error(t, err)
    +	})
    +
    +	t.Run("token not found returns 404", func(t *testing.T) {
    +		nonExistentToken := model.NewRandomString(model.TokenSize)
    +
    +		consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, nonExistentToken)
    +		require.NotNil(t, appErr)
    +		require.Nil(t, consumedToken)
    +		assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
    +		assert.Equal(t, "ConsumeTokenOnce", appErr.Where)
    +	})
    +
    +	t.Run("wrong token type returns not found", func(t *testing.T) {
    +		token := model.NewToken(model.TokenTypeOAuth, "extra-data")
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +		defer func() {
    +			_ = th.App.Srv().Store().Token().Delete(token.Token)
    +		}()
    +
    +		consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSaml, token.Token)
    +		require.NotNil(t, appErr)
    +		require.Nil(t, consumedToken)
    +		assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
    +
    +		_, err := th.App.Srv().Store().Token().GetByToken(token.Token)
    +		require.NoError(t, err)
    +	})
    +
    +	t.Run("token can only be consumed once", func(t *testing.T) {
    +		token := model.NewToken(model.TokenTypeSSOCodeExchange, "extra-data")
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +
    +		consumedToken1, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, token.Token)
    +		require.Nil(t, appErr)
    +		require.NotNil(t, consumedToken1)
    +
    +		consumedToken2, appErr := th.App.ConsumeTokenOnce(model.TokenTypeSSOCodeExchange, token.Token)
    +		require.NotNil(t, appErr)
    +		require.Nil(t, consumedToken2)
    +		assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
    +	})
    +
    +	t.Run("empty token string returns not found", func(t *testing.T) {
    +		consumedToken, appErr := th.App.ConsumeTokenOnce(model.TokenTypeOAuth, "")
    +		require.NotNil(t, appErr)
    +		require.Nil(t, consumedToken)
    +		assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
    +	})
    +
    +	t.Run("empty token type returns not found", func(t *testing.T) {
    +		token := model.NewToken(model.TokenTypeOAuth, "extra-data")
    +		require.NoError(t, th.App.Srv().Store().Token().Save(token))
    +		defer func() {
    +			_ = th.App.Srv().Store().Token().Delete(token.Token)
    +		}()
    +
    +		consumedToken, appErr := th.App.ConsumeTokenOnce("", token.Token)
    +		require.NotNil(t, appErr)
    +		require.Nil(t, consumedToken)
    +		assert.Equal(t, http.StatusNotFound, appErr.StatusCode)
    +	})
    +}
    
  • server/channels/store/retrylayer/retrylayer.go (+2 -2, modified)
    @@ -14293,11 +14293,11 @@ func (s *RetryLayerTokenStore) Cleanup(expiryTime int64) {
     
     }
     
    -func (s *RetryLayerTokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) {
    +func (s *RetryLayerTokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) {
     
     	tries := 0
     	for {
    -		result, err := s.TokenStore.ConsumeOnce(tokenStr)
    +		result, err := s.TokenStore.ConsumeOnce(tokenType, tokenStr)
     		if err == nil {
     			return result, nil
     		}
    
  • server/channels/store/sqlstore/tokens_store.go (+4 -4, modified)
    @@ -78,16 +78,16 @@ func (s SqlTokenStore) GetByToken(tokenString string) (*model.Token, error) {
     	return &token, nil
     }
     
    -func (s SqlTokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) {
    +func (s SqlTokenStore) ConsumeOnce(tokenType, tokenStr string) (*model.Token, error) {
     	var token model.Token
     
    -	query := `DELETE FROM Tokens WHERE Token = ? RETURNING *`
    +	query := `DELETE FROM Tokens WHERE Type = ? AND Token = ? RETURNING *`
     
    -	if err := s.GetMaster().Get(&token, query, tokenStr); err != nil {
    +	if err := s.GetMaster().Get(&token, query, tokenType, tokenStr); err != nil {
     		if err == sql.ErrNoRows {
     			return nil, store.NewErrNotFound("Token", tokenStr)
     		}
    -		return nil, errors.Wrapf(err, "failed to consume token")
    +		return nil, errors.Wrapf(err, "failed to consume token with type %s", tokenType)
     	}
     
     	return &token, nil
    
  • server/channels/store/store.go (+1 -1, modified)
    @@ -696,7 +696,7 @@ type TokenStore interface {
     	Save(recovery *model.Token) error
     	Delete(token string) error
     	GetByToken(token string) (*model.Token, error)
    -	ConsumeOnce(tokenStr string) (*model.Token, error)
    +	ConsumeOnce(tokenType, tokenStr string) (*model.Token, error)
     	Cleanup(expiryTime int64)
     	GetAllTokensByType(tokenType string) ([]*model.Token, error)
     	RemoveAllTokensByType(tokenType string) error
    
  • server/channels/store/storetest/mocks/TokenStore.go (+9 -9, modified)
    @@ -19,29 +19,29 @@ func (_m *TokenStore) Cleanup(expiryTime int64) {
     	_m.Called(expiryTime)
     }
     
    -// ConsumeOnce provides a mock function with given fields: tokenStr
    -func (_m *TokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) {
    -	ret := _m.Called(tokenStr)
    +// ConsumeOnce provides a mock function with given fields: tokenType, tokenStr
    +func (_m *TokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) {
    +	ret := _m.Called(tokenType, tokenStr)
     
     	if len(ret) == 0 {
     		panic("no return value specified for ConsumeOnce")
     	}
     
     	var r0 *model.Token
     	var r1 error
    -	if rf, ok := ret.Get(0).(func(string) (*model.Token, error)); ok {
    -		return rf(tokenStr)
    +	if rf, ok := ret.Get(0).(func(string, string) (*model.Token, error)); ok {
    +		return rf(tokenType, tokenStr)
     	}
    -	if rf, ok := ret.Get(0).(func(string) *model.Token); ok {
    -		r0 = rf(tokenStr)
    +	if rf, ok := ret.Get(0).(func(string, string) *model.Token); ok {
    +		r0 = rf(tokenType, tokenStr)
     	} else {
     		if ret.Get(0) != nil {
     			r0 = ret.Get(0).(*model.Token)
     		}
     	}
     
    -	if rf, ok := ret.Get(1).(func(string) error); ok {
    -		r1 = rf(tokenStr)
    +	if rf, ok := ret.Get(1).(func(string, string) error); ok {
    +		r1 = rf(tokenType, tokenStr)
     	} else {
     		r1 = ret.Error(1)
     	}
    
  • server/channels/store/storetest/tokens_store.go (+128 -0, modified)
    @@ -16,6 +16,7 @@ import (
     
     func TestTokensStore(t *testing.T, rctx request.CTX, ss store.Store) {
     	t.Run("TokensCleanup", func(t *testing.T) { testTokensCleanup(t, rctx, ss) })
    +	t.Run("ConsumeOnce", func(t *testing.T) { testConsumeOnce(t, rctx, ss) })
     }
     
     func testTokensCleanup(t *testing.T, rctx request.CTX, ss store.Store) {
    @@ -41,3 +42,130 @@ func testTokensCleanup(t *testing.T, rctx request.CTX, ss store.Store) {
     	require.NoError(t, err)
     	assert.Len(t, tokens, 0)
     }
    +
    +func testConsumeOnce(t *testing.T, rctx request.CTX, ss store.Store) {
    +	t.Run("successfully consume token once", func(t *testing.T) {
    +		token := &model.Token{
    +			Token:    model.NewRandomString(model.TokenSize),
    +			CreateAt: model.GetMillis(),
    +			Type:     model.TokenTypeOAuth,
    +			Extra:    "test-extra",
    +		}
    +		err := ss.Token().Save(token)
    +		require.NoError(t, err)
    +
    +		consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token)
    +		require.NoError(t, err)
    +		assert.Equal(t, token.Token, consumedToken.Token)
    +		assert.Equal(t, token.Type, consumedToken.Type)
    +		assert.Equal(t, token.Extra, consumedToken.Extra)
    +
    +		tokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth)
    +		require.NoError(t, err)
    +		assert.Len(t, tokens, 0)
    +	})
    +
    +	t.Run("second consumption of same token fails", func(t *testing.T) {
    +		token := &model.Token{
    +			Token:    model.NewRandomString(model.TokenSize),
    +			CreateAt: model.GetMillis(),
    +			Type:     model.TokenTypeOAuth,
    +			Extra:    "test-extra",
    +		}
    +		err := ss.Token().Save(token)
    +		require.NoError(t, err)
    +
    +		_, err = ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token)
    +		require.NoError(t, err)
    +
    +		_, err = ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token)
    +		require.Error(t, err)
    +		var nfErr *store.ErrNotFound
    +		assert.ErrorAs(t, err, &nfErr)
    +	})
    +
    +	t.Run("consume with wrong type fails", func(t *testing.T) {
    +		token := &model.Token{
    +			Token:    model.NewRandomString(model.TokenSize),
    +			CreateAt: model.GetMillis(),
    +			Type:     model.TokenTypeOAuth,
    +			Extra:    "test-extra",
    +		}
    +		err := ss.Token().Save(token)
    +		require.NoError(t, err)
    +
    +		_, err = ss.Token().ConsumeOnce(model.TokenTypeSSOCodeExchange, token.Token)
    +		require.Error(t, err)
    +		var nfErr *store.ErrNotFound
    +		assert.ErrorAs(t, err, &nfErr)
    +
    +		tokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth)
    +		require.NoError(t, err)
    +		assert.Len(t, tokens, 1)
    +
    +		err = ss.Token().Delete(token.Token)
    +		require.NoError(t, err)
    +	})
    +
    +	t.Run("consume non-existent token fails", func(t *testing.T) {
    +		nonExistentToken := model.NewRandomString(model.TokenSize)
    +		_, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, nonExistentToken)
    +		require.Error(t, err)
    +		var nfErr *store.ErrNotFound
    +		assert.ErrorAs(t, err, &nfErr)
    +	})
    +
    +	t.Run("multiple tokens with same type can each be consumed once", func(t *testing.T) {
    +		tokens := make([]*model.Token, 3)
    +		for i := range tokens {
    +			tokens[i] = &model.Token{
    +				Token:    model.NewRandomString(model.TokenSize),
    +				CreateAt: model.GetMillis(),
    +				Type:     model.TokenTypeOAuth,
    +				Extra:    "test-extra",
    +			}
    +			err := ss.Token().Save(tokens[i])
    +			require.NoError(t, err)
    +		}
    +
    +		for _, token := range tokens {
    +			consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, token.Token)
    +			require.NoError(t, err)
    +			assert.Equal(t, token.Token, consumedToken.Token)
    +		}
    +
    +		allTokens, err := ss.Token().GetAllTokensByType(model.TokenTypeOAuth)
    +		require.NoError(t, err)
    +		assert.Len(t, allTokens, 0)
    +	})
    +
    +	t.Run("consuming token of different type leaves others intact", func(t *testing.T) {
    +		oauthToken := &model.Token{
    +			Token:    model.NewRandomString(model.TokenSize),
    +			CreateAt: model.GetMillis(),
    +			Type:     model.TokenTypeOAuth,
    +			Extra:    "oauth-extra",
    +		}
    +		codeExchangeToken := &model.Token{
    +			Token:    model.NewRandomString(model.TokenSize),
    +			CreateAt: model.GetMillis(),
    +			Type:     model.TokenTypeSSOCodeExchange,
    +			Extra:    "password-extra",
    +		}
    +		err := ss.Token().Save(oauthToken)
    +		require.NoError(t, err)
    +		err = ss.Token().Save(codeExchangeToken)
    +		require.NoError(t, err)
    +
    +		consumedToken, err := ss.Token().ConsumeOnce(model.TokenTypeOAuth, oauthToken.Token)
    +		require.NoError(t, err)
    +		assert.Equal(t, oauthToken.Token, consumedToken.Token)
    +
    +		codeExchangeTokens, err := ss.Token().GetAllTokensByType(model.TokenTypeSSOCodeExchange)
    +		require.NoError(t, err)
    +		assert.Len(t, codeExchangeTokens, 1)
    +
    +		err = ss.Token().Delete(codeExchangeToken.Token)
    +		require.NoError(t, err)
    +	})
    +}
    
  • server/channels/store/timerlayer/timerlayer.go (+2 -2, modified)
    @@ -11243,10 +11243,10 @@ func (s *TimerLayerTokenStore) Cleanup(expiryTime int64) {
     	}
     }
     
    -func (s *TimerLayerTokenStore) ConsumeOnce(tokenStr string) (*model.Token, error) {
    +func (s *TimerLayerTokenStore) ConsumeOnce(tokenType string, tokenStr string) (*model.Token, error) {
     	start := time.Now()
     
    -	result, err := s.TokenStore.ConsumeOnce(tokenStr)
    +	result, err := s.TokenStore.ConsumeOnce(tokenType, tokenStr)
     
     	elapsed := float64(time.Since(start)) / float64(time.Second)
     	if s.Root.Metrics != nil {
    
  • server/channels/web/saml.go (+7 -4, modified)
    @@ -106,7 +106,7 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) {
     		return
     	}
     
    -	//Validate that the user is with SAML and all that
    +	// Validate that the user is with SAML and all that
     	encodedXML := r.FormValue("SAMLResponse")
     	relayState := r.FormValue("RelayState")
     
    @@ -161,7 +161,8 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) {
     		return
     	}
     
    -	if err = c.App.CheckUserAllAuthenticationCriteria(c.AppContext, user, ""); err != nil {
    +	err = c.App.CheckUserAllAuthenticationCriteria(c.AppContext, user, "")
    +	if err != nil {
     		handleError(err)
     		return
     	}
    @@ -250,8 +251,10 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) {
     			"code_challenge":        samlChallenge,
     			"code_challenge_method": samlMethod,
     		})
    -		code := model.NewToken(model.TokenTypeSaml, extra)
    -		if err := c.App.Srv().Store().Token().Save(code); err != nil {
    +
    +		var code *model.Token
    +		code, err = c.App.CreateSamlRelayToken(model.TokenTypeSSOCodeExchange, extra)
    +		if err != nil {
     			handleError(model.NewAppError("completeSaml", "app.recover.save.app_error", nil, "", http.StatusInternalServerError).Wrap(err))
     			return
     		}
    
  • server/public/model/token.go (+5 -4, modified)
    @@ -8,10 +8,11 @@ import (
     )
     
     const (
    -	TokenSize          = 64
    -	MaxTokenExipryTime = 1000 * 60 * 60 * 48 // 48 hour
    -	TokenTypeOAuth     = "oauth"
    -	TokenTypeSaml      = "saml"
    +	TokenSize                = 64
    +	MaxTokenExipryTime       = 1000 * 60 * 60 * 48 // 48 hour
    +	TokenTypeOAuth           = "oauth"
    +	TokenTypeSaml            = "saml"
    +	TokenTypeSSOCodeExchange = "sso-code-exchange"
     )
     
     type Token struct {
    

Vulnerability mechanics

Generated by null/stub on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.

References: 7

News mentions: 0

No linked articles in our index yet.