VYPR
Critical severity · NVD Advisory · Published Jul 1, 2024 · Updated Aug 2, 2024

Fiber Session Middleware Token Injection Vulnerability

CVE-2024-38513

Description

Fiber is an Express-inspired web framework written in Go. A vulnerability in the session middleware, present in GoFiber versions 2 and above prior to 2.52.5, allows users to supply their own session_id value, resulting in the creation of a session with that key. If a website relies on the mere presence of a session for security purposes, this can lead to significant security risks, including unauthorized access and session fixation attacks. All users of GoFiber's session middleware in the affected versions are impacted. The issue has been addressed in version 2.52.5, and users are strongly encouraged to upgrade to version 2.52.5 or higher to mitigate this vulnerability. Users who are unable to upgrade immediately can apply either of the following workarounds to reduce the risk: implement additional validation to ensure session IDs are generated securely by the server rather than supplied by the user, or regularly rotate session IDs and enforce strict session expiration policies.

Affected packages

Versions sourced from the GitHub Security Advisory.

| Package | Ecosystem | Affected versions | Patched versions |
|---|---|---|---|
| github.com/gofiber/fiber | Go | < 2.52.5 | 2.52.5 |
| github.com/gofiber/fiber/v2 | Go | < 2.52.5 | 2.52.5 |
| github.com/gofiber/fiber/v2/middleware/session | Go | < 2.52.5 | 2.52.5 |

Affected products

1

Patches

2
66a881441b27

fix(middleware/session): mutex for thread safety (#3050)

https://github.com/gofiber/fiber · Jason McNeil · Jun 30, 2024 · via GHSA
3 files changed · +276 −24
  • middleware/session/session.go · +41 −8 · modified
    @@ -14,6 +14,7 @@ import (
     )
     
     type Session struct {
    +	mu         sync.RWMutex  // Mutex to protect non-data fields
     	id         string        // session id
     	fresh      bool          // if new session
     	ctx        *fiber.Ctx    // fiber context
    @@ -42,6 +43,7 @@ func acquireSession() *Session {
     }
     
     func releaseSession(s *Session) {
    +	s.mu.Lock()
     	s.id = ""
     	s.exp = 0
     	s.ctx = nil
    @@ -52,16 +54,21 @@ func releaseSession(s *Session) {
     	if s.byteBuffer != nil {
     		s.byteBuffer.Reset()
     	}
    +	s.mu.Unlock()
     	sessionPool.Put(s)
     }
     
     // Fresh is true if the current session is new
     func (s *Session) Fresh() bool {
    +	s.mu.RLock()
    +	defer s.mu.RUnlock()
     	return s.fresh
     }
     
     // ID returns the session id
     func (s *Session) ID() string {
    +	s.mu.RLock()
    +	defer s.mu.RUnlock()
     	return s.id
     }
     
    @@ -102,6 +109,9 @@ func (s *Session) Destroy() error {
     	// Reset local data
     	s.data.Reset()
     
    +	s.mu.RLock()
    +	defer s.mu.RUnlock()
    +
     	// Use external Storage if exist
     	if err := s.config.Storage.Delete(s.id); err != nil {
     		return err
    @@ -114,6 +124,9 @@ func (s *Session) Destroy() error {
     
     // Regenerate generates a new session id and delete the old one from Storage
     func (s *Session) Regenerate() error {
    +	s.mu.Lock()
    +	defer s.mu.Unlock()
    +
     	// Delete old id from storage
     	if err := s.config.Storage.Delete(s.id); err != nil {
     		return err
    @@ -131,6 +144,10 @@ func (s *Session) Reset() error {
     	if s.data != nil {
     		s.data.Reset()
     	}
    +
    +	s.mu.Lock()
    +	defer s.mu.Unlock()
    +
     	// Reset byte buffer
     	if s.byteBuffer != nil {
     		s.byteBuffer.Reset()
    @@ -154,20 +171,24 @@ func (s *Session) Reset() error {
     
     // refresh generates a new session, and set session.fresh to be true
     func (s *Session) refresh() {
    -	// Create a new id
     	s.id = s.config.KeyGenerator()
    -
    -	// We assign a new id to the session, so the session must be fresh
     	s.fresh = true
     }
     
     // Save will update the storage and client cookie
    +//
    +// sess.Save() will save the session data to the storage and update the
    +// client cookie, and it will release the session after saving.
    +//
    +// It's not safe to use the session after calling Save().
     func (s *Session) Save() error {
     	// Better safe than sorry
     	if s.data == nil {
     		return nil
     	}
     
    +	s.mu.Lock()
    +
     	// Check if session has your own expiration, otherwise use default value
     	if s.exp <= 0 {
     		s.exp = s.config.Expiration
    @@ -177,25 +198,25 @@ func (s *Session) Save() error {
     	s.setSession()
     
     	// Convert data to bytes
    -	mux.Lock()
    -	defer mux.Unlock()
     	encCache := gob.NewEncoder(s.byteBuffer)
     	err := encCache.Encode(&s.data.Data)
     	if err != nil {
     		return fmt.Errorf("failed to encode data: %w", err)
     	}
     
    -	// copy the data in buffer
    +	// Copy the data in buffer
     	encodedBytes := make([]byte, s.byteBuffer.Len())
     	copy(encodedBytes, s.byteBuffer.Bytes())
     
    -	// pass copied bytes with session id to provider
    +	// Pass copied bytes with session id to provider
     	if err := s.config.Storage.Set(s.id, encodedBytes, s.exp); err != nil {
     		return err
     	}
     
    +	s.mu.Unlock()
    +
     	// Release session
    -	// TODO: It's not safe to use the Session after called Save()
    +	// TODO: It's not safe to use the Session after calling Save()
     	releaseSession(s)
     
     	return nil
    @@ -211,6 +232,8 @@ func (s *Session) Keys() []string {
     
     // SetExpiry sets a specific expiration for this session
     func (s *Session) SetExpiry(exp time.Duration) {
    +	s.mu.Lock()
    +	defer s.mu.Unlock()
     	s.exp = exp
     }
     
    @@ -276,3 +299,13 @@ func (s *Session) delSession() {
     		fasthttp.ReleaseCookie(fcookie)
     	}
     }
    +
    +// decodeSessionData decodes the session data from raw bytes.
    +func (s *Session) decodeSessionData(rawData []byte) error {
    +	_, _ = s.byteBuffer.Write(rawData) //nolint:errcheck // This will never fail
    +	encCache := gob.NewDecoder(s.byteBuffer)
    +	if err := encCache.Decode(&s.data.Data); err != nil {
    +		return fmt.Errorf("failed to decode session data: %w", err)
    +	}
    +	return nil
    +}
    
  • middleware/session/session_test.go · +229 −0 · modified
    @@ -1,6 +1,8 @@
     package session
     
     import (
    +	"errors"
    +	"sync"
     	"testing"
     	"time"
     
    @@ -673,3 +675,230 @@ func Benchmark_Session(b *testing.B) {
     		utils.AssertEqual(b, nil, err)
     	})
     }
    +
    +// go test -v -run=^$ -bench=Benchmark_Session_Parallel -benchmem -count=4
    +func Benchmark_Session_Parallel(b *testing.B) {
    +	b.Run("default", func(b *testing.B) {
    +		app, store := fiber.New(), New()
    +		b.ReportAllocs()
    +		b.ResetTimer()
    +		b.RunParallel(func(pb *testing.PB) {
    +			for pb.Next() {
    +				c := app.AcquireCtx(&fasthttp.RequestCtx{})
    +				c.Request().Header.SetCookie(store.sessionName, "12356789")
    +
    +				sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
    +				sess.Set("john", "doe")
    +				_ = sess.Save() //nolint:errcheck // We're inside a benchmark
    +				app.ReleaseCtx(c)
    +			}
    +		})
    +	})
    +
    +	b.Run("storage", func(b *testing.B) {
    +		app := fiber.New()
    +		store := New(Config{
    +			Storage: memory.New(),
    +		})
    +		b.ReportAllocs()
    +		b.ResetTimer()
    +		b.RunParallel(func(pb *testing.PB) {
    +			for pb.Next() {
    +				c := app.AcquireCtx(&fasthttp.RequestCtx{})
    +				c.Request().Header.SetCookie(store.sessionName, "12356789")
    +
    +				sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark
    +				sess.Set("john", "doe")
    +				_ = sess.Save() //nolint:errcheck // We're inside a benchmark
    +				app.ReleaseCtx(c)
    +			}
    +		})
    +	})
    +}
    +
    +// go test -v -run=^$ -bench=Benchmark_Session_Asserted -benchmem -count=4
    +func Benchmark_Session_Asserted(b *testing.B) {
    +	b.Run("default", func(b *testing.B) {
    +		app, store := fiber.New(), New()
    +		c := app.AcquireCtx(&fasthttp.RequestCtx{})
    +		defer app.ReleaseCtx(c)
    +		c.Request().Header.SetCookie(store.sessionName, "12356789")
    +
    +		b.ReportAllocs()
    +		b.ResetTimer()
    +		for n := 0; n < b.N; n++ {
    +			sess, err := store.Get(c)
    +			utils.AssertEqual(b, nil, err)
    +			sess.Set("john", "doe")
    +			err = sess.Save()
    +			utils.AssertEqual(b, nil, err)
    +		}
    +	})
    +
    +	b.Run("storage", func(b *testing.B) {
    +		app := fiber.New()
    +		store := New(Config{
    +			Storage: memory.New(),
    +		})
    +		c := app.AcquireCtx(&fasthttp.RequestCtx{})
    +		defer app.ReleaseCtx(c)
    +		c.Request().Header.SetCookie(store.sessionName, "12356789")
    +
    +		b.ReportAllocs()
    +		b.ResetTimer()
    +		for n := 0; n < b.N; n++ {
    +			sess, err := store.Get(c)
    +			utils.AssertEqual(b, nil, err)
    +			sess.Set("john", "doe")
    +			err = sess.Save()
    +			utils.AssertEqual(b, nil, err)
    +		}
    +	})
    +}
    +
    +// go test -v -run=^$ -bench=Benchmark_Session_Asserted_Parallel -benchmem -count=4
    +func Benchmark_Session_Asserted_Parallel(b *testing.B) {
    +	b.Run("default", func(b *testing.B) {
    +		app, store := fiber.New(), New()
    +		b.ReportAllocs()
    +		b.ResetTimer()
    +		b.RunParallel(func(pb *testing.PB) {
    +			for pb.Next() {
    +				c := app.AcquireCtx(&fasthttp.RequestCtx{})
    +				c.Request().Header.SetCookie(store.sessionName, "12356789")
    +
    +				sess, err := store.Get(c)
    +				utils.AssertEqual(b, nil, err)
    +				sess.Set("john", "doe")
    +				utils.AssertEqual(b, nil, sess.Save())
    +				app.ReleaseCtx(c)
    +			}
    +		})
    +	})
    +
    +	b.Run("storage", func(b *testing.B) {
    +		app := fiber.New()
    +		store := New(Config{
    +			Storage: memory.New(),
    +		})
    +		b.ReportAllocs()
    +		b.ResetTimer()
    +		b.RunParallel(func(pb *testing.PB) {
    +			for pb.Next() {
    +				c := app.AcquireCtx(&fasthttp.RequestCtx{})
    +				c.Request().Header.SetCookie(store.sessionName, "12356789")
    +
    +				sess, err := store.Get(c)
    +				utils.AssertEqual(b, nil, err)
    +				sess.Set("john", "doe")
    +				utils.AssertEqual(b, nil, sess.Save())
    +				app.ReleaseCtx(c)
    +			}
    +		})
    +	})
    +}
    +
    +// go test -v -race -run Test_Session_Concurrency ./...
    +func Test_Session_Concurrency(t *testing.T) {
    +	t.Parallel()
    +	app := fiber.New()
    +	store := New()
    +
    +	var wg sync.WaitGroup
    +	errChan := make(chan error, 10) // Buffered channel to collect errors
    +	const numGoroutines = 10        // Number of concurrent goroutines to test
    +
    +	// Start numGoroutines goroutines
    +	for i := 0; i < numGoroutines; i++ {
    +		wg.Add(1)
    +		go func() {
    +			defer wg.Done()
    +
    +			localCtx := app.AcquireCtx(&fasthttp.RequestCtx{})
    +
    +			sess, err := store.Get(localCtx)
    +			if err != nil {
    +				errChan <- err
    +				return
    +			}
    +
    +			// Set a value
    +			sess.Set("name", "john")
    +
    +			// get the session id
    +			id := sess.ID()
    +
    +			// Check if the session is fresh
    +			if !sess.Fresh() {
    +				errChan <- errors.New("session should be fresh")
    +				return
    +			}
    +
    +			// Save the session
    +			if err := sess.Save(); err != nil {
    +				errChan <- err
    +				return
    +			}
    +
    +			// Release the context
    +			app.ReleaseCtx(localCtx)
    +
    +			// Acquire a new context
    +			localCtx = app.AcquireCtx(&fasthttp.RequestCtx{})
    +			defer app.ReleaseCtx(localCtx)
    +
    +			// Set the session id in the header
    +			localCtx.Request().Header.SetCookie(store.sessionName, id)
    +
    +			// Get the session
    +			sess, err = store.Get(localCtx)
    +			if err != nil {
    +				errChan <- err
    +				return
    +			}
    +
    +			// Get the value
    +			name := sess.Get("name")
    +			if name != "john" {
    +				errChan <- errors.New("name should be john")
    +				return
    +			}
    +
    +			// Get ID from the session
    +			if sess.ID() != id {
    +				errChan <- errors.New("id should be the same")
    +				return
    +			}
    +
    +			// Check if the session is fresh
    +			if sess.Fresh() {
    +				errChan <- errors.New("session should not be fresh")
    +				return
    +			}
    +
    +			// Delete the key
    +			sess.Delete("name")
    +
    +			// Get the value
    +			name = sess.Get("name")
    +			if name != nil {
    +				errChan <- errors.New("name should be nil")
    +				return
    +			}
    +
    +			// Destroy the session
    +			if err := sess.Destroy(); err != nil {
    +				errChan <- err
    +				return
    +			}
    +		}()
    +	}
    +
    +	wg.Wait()      // Wait for all goroutines to finish
    +	close(errChan) // Close the channel to signal no more errors will be sent
    +
    +	// Check for errors sent to errChan
    +	for err := range errChan {
    +		utils.AssertEqual(t, nil, err)
    +	}
    +}
    
  • middleware/session/store.go · +6 −16 · modified
    @@ -4,7 +4,6 @@ import (
     	"encoding/gob"
     	"errors"
     	"fmt"
    -	"sync"
     
     	"github.com/gofiber/fiber/v2"
     	"github.com/gofiber/fiber/v2/internal/storage/memory"
    @@ -14,9 +13,6 @@ import (
     // ErrEmptySessionID is an error that occurs when the session ID is empty.
     var ErrEmptySessionID = errors.New("session id cannot be empty")
     
    -// mux is a global mutex for session operations.
    -var mux sync.Mutex
    -
     // sessionIDKey is the local key type used to store and retrieve the session ID in context.
     type sessionIDKey int
     
    @@ -81,13 +77,19 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
     
     	// Create session object
     	sess := acquireSession()
    +
    +	sess.mu.Lock()
    +	defer sess.mu.Unlock()
    +
     	sess.ctx = c
     	sess.config = s
     	sess.id = id
     	sess.fresh = fresh
     
     	// Decode session data if found
     	if rawData != nil {
    +		sess.data.Lock()
    +		defer sess.data.Unlock()
     		if err := sess.decodeSessionData(rawData); err != nil {
     			return nil, fmt.Errorf("failed to decode session data: %w", err)
     		}
    @@ -132,15 +134,3 @@ func (s *Store) Delete(id string) error {
     	}
     	return s.Storage.Delete(id)
     }
    -
    -// decodeSessionData decodes the session data from raw bytes.
    -func (s *Session) decodeSessionData(rawData []byte) error {
    -	mux.Lock()
    -	defer mux.Unlock()
    -	_, _ = s.byteBuffer.Write(rawData) //nolint:errcheck // This will never fail
    -	encCache := gob.NewDecoder(s.byteBuffer)
    -	if err := encCache.Decode(&s.data.Data); err != nil {
    -		return fmt.Errorf("failed to decode session data: %w", err)
    -	}
    -	return nil
    -}
    
7926e5bf4da0

Merge pull request from GHSA-98j2-3j3p-fw2v

https://github.com/gofiber/fiber · Jason McNeil · Jun 26, 2024 · via GHSA
4 files changed · +144 −74
  • middleware/csrf/csrf_test.go · +4 −0 · modified
    @@ -88,6 +88,8 @@ func Test_CSRF_WithSession(t *testing.T) {
     
     	// the session string is no longer be 123
     	newSessionIDString := sess.ID()
    +	sess.Save()
    +
     	app.AcquireCtx(ctx).Request().Header.SetCookie("_session", newSessionIDString)
     
     	// middleware config
    @@ -221,6 +223,8 @@ func Test_CSRF_ExpiredToken_WithSession(t *testing.T) {
     
     	// get session id
     	newSessionIDString := sess.ID()
    +	sess.Save()
    +
     	app.AcquireCtx(ctx).Request().Header.SetCookie("_session", newSessionIDString)
     
     	// middleware config
    
  • middleware/session/session_test.go · +84 −11 · modified
    @@ -25,13 +25,23 @@ func Test_Session(t *testing.T) {
     	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
     	defer app.ReleaseCtx(ctx)
     
    +	// Get a new session
    +	sess, err := store.Get(ctx)
    +	utils.AssertEqual(t, nil, err)
    +	utils.AssertEqual(t, true, sess.Fresh())
    +	token := sess.ID()
    +	sess.Save()
    +
    +	app.ReleaseCtx(ctx)
    +	ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
    +
     	// set session
    -	ctx.Request().Header.SetCookie(store.sessionName, "123")
    +	ctx.Request().Header.SetCookie(store.sessionName, token)
     
     	// get session
    -	sess, err := store.Get(ctx)
    +	sess, err = store.Get(ctx)
     	utils.AssertEqual(t, nil, err)
    -	utils.AssertEqual(t, true, sess.Fresh())
    +	utils.AssertEqual(t, false, sess.Fresh())
     
     	// get keys
     	keys := sess.Keys()
    @@ -64,12 +74,14 @@ func Test_Session(t *testing.T) {
     
     	// get id
     	id := sess.ID()
    -	utils.AssertEqual(t, "123", id)
    +	utils.AssertEqual(t, token, id)
     
     	// save the old session first
     	err = sess.Save()
     	utils.AssertEqual(t, nil, err)
     
    +	app.ReleaseCtx(ctx)
    +
     	// requesting entirely new context to prevent falsy tests
     	ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
     	defer app.ReleaseCtx(ctx)
    @@ -108,7 +120,6 @@ func Test_Session_Types(t *testing.T) {
     
     	// fiber context
     	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
    -	defer app.ReleaseCtx(ctx)
     
     	// set cookie
     	ctx.Request().Header.SetCookie(store.sessionName, "123")
    @@ -120,6 +131,10 @@ func Test_Session_Types(t *testing.T) {
     
     	// the session string is no longer be 123
     	newSessionIDString := sess.ID()
    +
    +	app.ReleaseCtx(ctx)
    +	ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
    +
     	ctx.Request().Header.SetCookie(store.sessionName, newSessionIDString)
     
     	type User struct {
    @@ -177,6 +192,11 @@ func Test_Session_Types(t *testing.T) {
     	err = sess.Save()
     	utils.AssertEqual(t, nil, err)
     
    +	app.ReleaseCtx(ctx)
    +	ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
    +
    +	ctx.Request().Header.SetCookie(store.sessionName, newSessionIDString)
    +
     	// get session
     	sess, err = store.Get(ctx)
     	utils.AssertEqual(t, nil, err)
    @@ -203,6 +223,7 @@ func Test_Session_Types(t *testing.T) {
     	utils.AssertEqual(t, vfloat64, sess.Get("vfloat64").(float64))
     	utils.AssertEqual(t, vcomplex64, sess.Get("vcomplex64").(complex64))
     	utils.AssertEqual(t, vcomplex128, sess.Get("vcomplex128").(complex128))
    +	app.ReleaseCtx(ctx)
     }
     
     // go test -run Test_Session_Store_Reset
    @@ -214,7 +235,6 @@ func Test_Session_Store_Reset(t *testing.T) {
     	app := fiber.New()
     	// fiber context
     	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
    -	defer app.ReleaseCtx(ctx)
     
     	// get session
     	sess, err := store.Get(ctx)
    @@ -228,6 +248,12 @@ func Test_Session_Store_Reset(t *testing.T) {
     
     	// reset store
     	utils.AssertEqual(t, nil, store.Reset())
    +	id := sess.ID()
    +
    +	app.ReleaseCtx(ctx)
    +	ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
    +	defer app.ReleaseCtx(ctx)
    +	ctx.Request().Header.SetCookie(store.sessionName, id)
     
     	// make sure the session is recreated
     	sess, err = store.Get(ctx)
    @@ -302,25 +328,37 @@ func Test_Session_Save_Expiration(t *testing.T) {
     		// set value
     		sess.Set("name", "john")
     
    +		token := sess.ID()
    +
     		// expire this session in 5 seconds
     		sess.SetExpiry(sessionDuration)
     
     		// save session
     		err = sess.Save()
     		utils.AssertEqual(t, nil, err)
     
    +		app.ReleaseCtx(ctx)
    +		ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
    +
     		// here you need to get the old session yet
    +		ctx.Request().Header.SetCookie(store.sessionName, token)
     		sess, err = store.Get(ctx)
     		utils.AssertEqual(t, nil, err)
     		utils.AssertEqual(t, "john", sess.Get("name"))
     
     		// just to make sure the session has been expired
     		time.Sleep(sessionDuration + (10 * time.Millisecond))
     
    +		app.ReleaseCtx(ctx)
    +		ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
    +		defer app.ReleaseCtx(ctx)
    +
     		// here you should get a new session
    +		ctx.Request().Header.SetCookie(store.sessionName, token)
     		sess, err = store.Get(ctx)
     		utils.AssertEqual(t, nil, err)
     		utils.AssertEqual(t, nil, sess.Get("name"))
    +		utils.AssertEqual(t, true, sess.ID() != token)
     	})
     }
     
    @@ -364,7 +402,15 @@ func Test_Session_Destroy(t *testing.T) {
     
     		// set value & save
     		sess.Set("name", "fenny")
    +		id := sess.ID()
     		utils.AssertEqual(t, nil, sess.Save())
    +
    +		app.ReleaseCtx(ctx)
    +		ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
    +		defer app.ReleaseCtx(ctx)
    +
    +		// get session
    +		ctx.Request().Header.Set(store.sessionName, id)
     		sess, err = store.Get(ctx)
     		utils.AssertEqual(t, nil, err)
     
    @@ -408,7 +454,8 @@ func Test_Session_Cookie(t *testing.T) {
     }
     
     // go test -run Test_Session_Cookie_In_Response
    -func Test_Session_Cookie_In_Response(t *testing.T) {
    +// Regression: https://github.com/gofiber/fiber/pull/1191
    +func Test_Session_Cookie_In_Middleware_Chain(t *testing.T) {
     	t.Parallel()
     	store := New()
     	app := fiber.New()
    @@ -421,15 +468,17 @@ func Test_Session_Cookie_In_Response(t *testing.T) {
     	sess, err := store.Get(ctx)
     	utils.AssertEqual(t, nil, err)
     	sess.Set("id", "1")
    +	id := sess.ID()
     	utils.AssertEqual(t, true, sess.Fresh())
     	utils.AssertEqual(t, nil, sess.Save())
     
     	sess, err = store.Get(ctx)
     	utils.AssertEqual(t, nil, err)
     	sess.Set("name", "john")
     	utils.AssertEqual(t, true, sess.Fresh())
    +	utils.AssertEqual(t, id, sess.ID()) // session id should be the same
     
    -	utils.AssertEqual(t, "1", sess.Get("id"))
    +	utils.AssertEqual(t, sess.ID() != "1", true)
     	utils.AssertEqual(t, "john", sess.Get("name"))
     }
     
    @@ -441,24 +490,31 @@ func Test_Session_Deletes_Single_Key(t *testing.T) {
     	app := fiber.New()
     
     	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
    -	defer app.ReleaseCtx(ctx)
     
     	sess, err := store.Get(ctx)
     	utils.AssertEqual(t, nil, err)
    -	ctx.Request().Header.SetCookie(store.sessionName, sess.ID())
    -
    +	id := sess.ID()
     	sess.Set("id", "1")
     	utils.AssertEqual(t, nil, sess.Save())
     
    +	app.ReleaseCtx(ctx)
    +	ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
    +	ctx.Request().Header.SetCookie(store.sessionName, id)
    +
     	sess, err = store.Get(ctx)
     	utils.AssertEqual(t, nil, err)
     	sess.Delete("id")
     	utils.AssertEqual(t, nil, sess.Save())
     
    +	app.ReleaseCtx(ctx)
    +	ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
    +	ctx.Request().Header.SetCookie(store.sessionName, id)
    +
     	sess, err = store.Get(ctx)
     	utils.AssertEqual(t, nil, err)
     	utils.AssertEqual(t, false, sess.Fresh())
     	utils.AssertEqual(t, nil, sess.Get("id"))
    +	app.ReleaseCtx(ctx)
     }
     
     // go test -run Test_Session_Reset
    @@ -475,6 +531,9 @@ func Test_Session_Reset(t *testing.T) {
     	defer app.ReleaseCtx(ctx)
     
     	t.Run("reset session data and id, and set fresh to be true", func(t *testing.T) {
    +		t.Parallel()
    +		// fiber context
    +		ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
     		// a random session uuid
     		originalSessionUUIDString := ""
     
    @@ -491,6 +550,9 @@ func Test_Session_Reset(t *testing.T) {
     		err = freshSession.Save()
     		utils.AssertEqual(t, nil, err)
     
    +		app.ReleaseCtx(ctx)
    +		ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
    +
     		// set cookie
     		ctx.Request().Header.SetCookie(store.sessionName, originalSessionUUIDString)
     
    @@ -524,6 +586,8 @@ func Test_Session_Reset(t *testing.T) {
     		// Check that the session id is not in the header or cookie anymore
     		utils.AssertEqual(t, "", string(ctx.Response().Header.Peek(store.sessionName)))
     		utils.AssertEqual(t, "", string(ctx.Request().Header.Peek(store.sessionName)))
    +
    +		app.ReleaseCtx(ctx)
     	})
     }
     
    @@ -551,6 +615,12 @@ func Test_Session_Regenerate(t *testing.T) {
     		err = freshSession.Save()
     		utils.AssertEqual(t, nil, err)
     
    +		// release the context
    +		app.ReleaseCtx(ctx)
    +
    +		// acquire a new context
    +		ctx = app.AcquireCtx(&fasthttp.RequestCtx{})
    +
     		// set cookie
     		ctx.Request().Header.SetCookie(store.sessionName, originalSessionUUIDString)
     
    @@ -566,6 +636,9 @@ func Test_Session_Regenerate(t *testing.T) {
     
     		// acquiredSession.fresh should be true after regenerating
     		utils.AssertEqual(t, true, acquiredSession.Fresh())
    +
    +		// release the context
    +		app.ReleaseCtx(ctx)
     	})
     }
     
    
  • middleware/session/store.go · +53 −61 · modified
    @@ -9,19 +9,27 @@ import (
     	"github.com/gofiber/fiber/v2"
     	"github.com/gofiber/fiber/v2/internal/storage/memory"
     	"github.com/gofiber/fiber/v2/utils"
    -
    -	"github.com/valyala/fasthttp"
     )
     
     // ErrEmptySessionID is an error that occurs when the session ID is empty.
     var ErrEmptySessionID = errors.New("session id cannot be empty")
     
    +// mux is a global mutex for session operations.
    +var mux sync.Mutex
    +
    +// sessionIDKey is the local key type used to store and retrieve the session ID in context.
    +type sessionIDKey int
    +
    +const (
    +	// sessionIDContextKey is the key used to store the session ID in the context locals.
    +	sessionIDContextKey sessionIDKey = iota
    +)
    +
     type Store struct {
     	Config
     }
     
    -var mux sync.Mutex
    -
    +// New creates a new session store with the provided configuration.
     func New(config ...Config) *Store {
     	// Set default config
     	cfg := configDefault(config...)
    @@ -35,31 +43,40 @@ func New(config ...Config) *Store {
     	}
     }
     
    -// RegisterType will allow you to encode/decode custom types
    -// into any Storage provider
    +// RegisterType registers a custom type for encoding/decoding into any storage provider.
     func (*Store) RegisterType(i interface{}) {
     	gob.Register(i)
     }
     
    -// Get will get/create a session
    +// Get retrieves or creates a session for the given context.
     func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
    -	var fresh bool
    -	loadData := true
    +	var rawData []byte
    +	var err error
     
    -	id := s.getSessionID(c)
    +	id, ok := c.Locals(sessionIDContextKey).(string)
    +	if !ok {
    +		id = s.getSessionID(c)
    +	}
     
    -	if len(id) == 0 {
    -		fresh = true
    -		var err error
    -		if id, err = s.responseCookies(c); err != nil {
    +	fresh := ok // Assume the session is fresh if the ID is found in locals
    +
    +	// Attempt to fetch session data if an ID is provided
    +	if id != "" {
    +		rawData, err = s.Storage.Get(id)
    +		if err != nil {
     			return nil, err
     		}
    +		if rawData == nil {
    +			// Data not found, prepare to generate a new session
    +			id = ""
    +		}
     	}
     
    -	// If no key exist, create new one
    -	if len(id) == 0 {
    -		loadData = false
    +	// Generate a new ID if needed
    +	if id == "" {
    +		fresh = true // The session is fresh if a new ID is generated
     		id = s.KeyGenerator()
    +		c.Locals(sessionIDContextKey, id)
     	}
     
     	// Create session object
    @@ -69,34 +86,17 @@ func (s *Store) Get(c *fiber.Ctx) (*Session, error) {
     	sess.id = id
     	sess.fresh = fresh
     
    -	// Fetch existing data
    -	if loadData {
    -		raw, err := s.Storage.Get(id)
    -		// Unmarshal if we found data
    -		if raw != nil && err == nil {
    -			mux.Lock()
    -			defer mux.Unlock()
    -			_, _ = sess.byteBuffer.Write(raw) //nolint:errcheck // This will never fail
    -			encCache := gob.NewDecoder(sess.byteBuffer)
    -			err := encCache.Decode(&sess.data.Data)
    -			if err != nil {
    -				return nil, fmt.Errorf("failed to decode session data: %w", err)
    -			}
    -		} else if err != nil {
    -			return nil, err
    -		} else {
    -			// both raw and err is nil, which means id is not in the storage
    -			sess.fresh = true
    +	// Decode session data if found
    +	if rawData != nil {
    +		if err := sess.decodeSessionData(rawData); err != nil {
    +			return nil, fmt.Errorf("failed to decode session data: %w", err)
     		}
     	}
     
     	return sess, nil
     }
     
    -// getSessionID will return the session id from:
    -// 1. cookie
    -// 2. http headers
    -// 3. query string
    +// getSessionID returns the session ID from cookies, headers, or query string.
     func (s *Store) getSessionID(c *fiber.Ctx) string {
     	id := c.Cookies(s.sessionName)
     	if len(id) > 0 {
    @@ -120,35 +120,27 @@ func (s *Store) getSessionID(c *fiber.Ctx) string {
     	return ""
     }
     
    -func (s *Store) responseCookies(c *fiber.Ctx) (string, error) {
    -	// Get key from response cookie
    -	cookieValue := c.Response().Header.PeekCookie(s.sessionName)
    -	if len(cookieValue) == 0 {
    -		return "", nil
    -	}
    -
    -	cookie := fasthttp.AcquireCookie()
    -	defer fasthttp.ReleaseCookie(cookie)
    -	err := cookie.ParseBytes(cookieValue)
    -	if err != nil {
    -		return "", err
    -	}
    -
    -	value := make([]byte, len(cookie.Value()))
    -	copy(value, cookie.Value())
    -	id := string(value)
    -	return id, nil
    -}
    -
    -// Reset will delete all session from the storage
    +// Reset deletes all sessions from the storage.
     func (s *Store) Reset() error {
     	return s.Storage.Reset()
     }
     
    -// Delete deletes a session by its id.
    +// Delete deletes a session by its ID.
     func (s *Store) Delete(id string) error {
     	if id == "" {
     		return ErrEmptySessionID
     	}
     	return s.Storage.Delete(id)
     }
    +
    +// decodeSessionData decodes the session data from raw bytes.
    +func (s *Session) decodeSessionData(rawData []byte) error {
    +	mux.Lock()
    +	defer mux.Unlock()
    +	_, _ = s.byteBuffer.Write(rawData) //nolint:errcheck // This will never fail
    +	encCache := gob.NewDecoder(s.byteBuffer)
    +	if err := encCache.Decode(&s.data.Data); err != nil {
    +		return fmt.Errorf("failed to decode session data: %w", err)
    +	}
    +	return nil
    +}
    
  • middleware/session/store_test.go · +3 −2 · modified
    @@ -64,12 +64,13 @@ func TestStore_getSessionID(t *testing.T) {
     
     // go test -run TestStore_Get
     // Regression: https://github.com/gofiber/fiber/issues/1408
    +// Regression: https://github.com/gofiber/fiber/security/advisories/GHSA-98j2-3j3p-fw2v
     func TestStore_Get(t *testing.T) {
     	t.Parallel()
     	unexpectedID := "test-session-id"
     	// fiber instance
     	app := fiber.New()
    -	t.Run("session should persisted even session is invalid", func(t *testing.T) {
    +	t.Run("session should be re-generated if it is invalid", func(t *testing.T) {
     		t.Parallel()
     		// session store
     		store := New()
    @@ -82,7 +83,7 @@ func TestStore_Get(t *testing.T) {
     		acquiredSession, err := store.Get(ctx)
     		utils.AssertEqual(t, err, nil)
     
    -		utils.AssertEqual(t, unexpectedID, acquiredSession.ID())
    +		utils.AssertEqual(t, acquiredSession.ID() != unexpectedID, true)
     	})
     }
     
    

Vulnerability mechanics

Generated by null/stub on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.

References

5

News mentions

0

No linked articles in our index yet.