DNS Rebinding vulnerability present when running MCP Gateway in sse or streaming mode
Description
MCP Gateway allows easy and secure running and deployment of MCP servers. In versions 0.27.0 and earlier, when MCP Gateway runs in sse or streaming transport mode, it is vulnerable to DNS rebinding. An attacker who can get a victim to visit a malicious website or be served a malicious advertisement can perform browser-based exploitation of MCP servers executing behind the gateway, including manipulating tools or other features exposed by those MCP servers. MCP Gateway is not affected when running in the default stdio mode, which does not listen on network ports. Version 0.28.0 fixes this issue.
AI Insight
LLM-synthesized narrative grounded in this CVE's description and references.
MCP Gateway versions ≤0.27.0 in SSE/streaming mode are vulnerable to DNS rebinding, allowing browser-based exploitation of MCP servers.
Vulnerability
Overview
CVE-2025-64443 affects MCP Gateway (the docker-mcp CLI plugin) in versions 0.27.0 and earlier when it runs in SSE or streaming transport mode. The gateway is vulnerable to DNS rebinding because it does not validate the Origin header on incoming HTTP requests. The MCP specification itself warns that servers implementing the Streamable HTTP transport must validate the Origin header to prevent such attacks [2]. In the affected versions, the gateway did not perform this validation, so it accepted requests issued by scripts running on malicious websites.
Exploitation
Method of Exploitation
An attacker can exploit this by luring a victim to a malicious website or by serving a malicious advertisement. In a DNS rebinding attack, the attacker's domain first resolves to the attacker's server. It is then rebound to an address where the gateway listens, such as 127.0.0.1. The victim's browser then sends the attacker's requests to the MCP Gateway, which listens on a network port in SSE/streaming mode, while still treating those requests as same-origin with the attacker's page. This sidesteps the browser's same-origin policy. Because the vulnerable gateway accepts requests without verifying the Origin header, the attacker's script can interact with the MCP servers behind the gateway. In the vulnerable versions, these requests also require no authentication.
Impact
Successful exploitation allows an attacker to manipulate tools or other features exposed by MCP servers running behind the gateway. This could include invoking arbitrary MCP tools, reading resources, or performing any action the MCP server is authorized to perform. The impact is bounded by the capabilities of the MCP servers configured in the gateway. Depending on what those servers do, however, it could include data exfiltration or unauthorized operations.
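This sketch makes "invoking arbitrary MCP tools" concrete. MCP uses JSON-RPC 2.0, and a tool invocation is a `tools/call` request that carries the tool name and its arguments. The code below builds such a message in Go. The tool name `read_file` and its `path` argument are hypothetical and stand in for whatever tools the configured servers expose. A rebinding script could send messages of this shape to a vulnerable gateway after the usual MCP initialization handshake.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// toolCallParams is the params object of an MCP "tools/call" request.
type toolCallParams struct {
	Name      string         `json:"name"`
	Arguments map[string]any `json:"arguments"`
}

// rpcRequest is a minimal JSON-RPC 2.0 request envelope.
type rpcRequest struct {
	JSONRPC string `json:"jsonrpc"`
	ID      int    `json:"id"`
	Method  string `json:"method"`
	Params  any    `json:"params"`
}

// buildToolCall encodes a tools/call request for the given tool name.
func buildToolCall(id int, tool string, args map[string]any) ([]byte, error) {
	return json.Marshal(rpcRequest{
		JSONRPC: "2.0",
		ID:      id,
		Method:  "tools/call",
		Params:  toolCallParams{Name: tool, Arguments: args},
	})
}

func main() {
	// "read_file" and "path" are hypothetical placeholders.
	msg, err := buildToolCall(1, "read_file", map[string]any{"path": "/etc/hosts"})
	if err != nil {
		panic(err)
	}
	fmt.Println(string(msg))
}
```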
Mitigation
Version 0.28.0 fixes this issue by adding Origin header validation [1]. The fix adds an isAllowedOrigin function that accepts only http or https origins whose hostname is exactly localhost, 127.0.0.1, or ::1, on any port. It rejects all other origins, including lookalike subdomains such as localhost.evil.com [4]. Requests that carry no Origin header, such as those from curl or SDK clients, are still allowed through. The fix also introduces authentication for SSE and streaming modes. Clients must present a Bearer token in the Authorization header. The token is taken from the MCP_GATEWAY_AUTH_TOKEN environment variable, or auto-generated when that variable is unset [1]. Users should upgrade to version 0.28.0 or later. The default stdio mode is not affected because it does not listen on network ports.
AI Insight generated on May 19, 2026. Synthesized from this CVE's description and the cited reference URLs; citations are validated against the source bundle.
Affected packages
Versions sourced from the GitHub Security Advisory.
| Package | Affected versions | Patched versions |
|---|---|---|
| github.com/docker/mcp-gateway (Go) | < 0.28.0 | 0.28.0 |
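The following sketch checks a plain version string against the affected range in the table (`< 0.28.0`). It parses `MAJOR.MINOR.PATCH` with an optional leading `v` and compares the components numerically. The helper names are hypothetical. Pre-release or build suffixes are deliberately rejected rather than guessed at. For full semantic-version handling, a dedicated semver library would be the more robust choice.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// fixedVersion is the first patched release listed in the advisory table.
var fixedVersion = [3]int{0, 28, 0}

// parseVersion parses "MAJOR.MINOR.PATCH" with an optional leading "v".
// Pre-release/build suffixes (e.g. "-rc1") are rejected.
func parseVersion(s string) ([3]int, error) {
	var v [3]int
	parts := strings.Split(strings.TrimPrefix(s, "v"), ".")
	if len(parts) != 3 {
		return v, fmt.Errorf("unsupported version %q", s)
	}
	for i, p := range parts {
		n, err := strconv.Atoi(p)
		if err != nil || n < 0 {
			return v, fmt.Errorf("unsupported version %q", s)
		}
		v[i] = n
	}
	return v, nil
}

// isAffected reports whether version falls in the affected range "< 0.28.0".
func isAffected(version string) (bool, error) {
	v, err := parseVersion(version)
	if err != nil {
		return false, err
	}
	for i := range v {
		if v[i] != fixedVersion[i] {
			return v[i] < fixedVersion[i], nil
		}
	}
	return false, nil // exactly 0.28.0, the patched release
}

func main() {
	for _, s := range []string{"0.27.0", "v0.28.0", "0.28.1", "1.0.0"} {
		affected, err := isAffected(s)
		fmt.Println(s, affected, err)
	}
}
```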
Affected products
- docker/mcp-gateway (range: < 0.28.0)
Patches
26b076b2479d8: Merge commit from fork
3 files changed · +285 −2
pkg/gateway/auth_test.go · +29 −0 · modified

```diff
@@ -228,3 +228,32 @@ func TestFormatBearerToken(t *testing.T) {
 		t.Errorf("expected result to start with 'Authorization: Bearer ', got %q", result)
 	}
 }
+
+func TestAuthenticationMiddleware_ContainerMode(t *testing.T) {
+	// Set container environment variable
+	os.Setenv("DOCKER_MCP_IN_CONTAINER", "1")
+	defer os.Unsetenv("DOCKER_MCP_IN_CONTAINER")
+
+	authToken := "test-token-123"
+	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte("success"))
+	})
+
+	secured := authenticationMiddleware(authToken, handler)
+
+	// Should allow request without auth token when in container mode
+	req := httptest.NewRequest(http.MethodPost, "/mcp", nil)
+	// No Authorization header
+
+	rr := httptest.NewRecorder()
+	secured.ServeHTTP(rr, req)
+
+	if rr.Code != http.StatusOK {
+		t.Errorf("Expected 200 in container mode, got %d", rr.Code)
+	}
+
+	if rr.Body.String() != "success" {
+		t.Errorf("Expected 'success', got %q", rr.Body.String())
+	}
+}
```
pkg/gateway/transport.go · +43 −2 · modified

```diff
@@ -2,9 +2,11 @@ package gateway
 
 import (
 	"context"
+	"fmt"
 	"io"
 	"net"
 	"net/http"
+	"net/url"
 
 	"github.com/modelcontextprotocol/go-sdk/mcp"
 
@@ -23,7 +25,7 @@ func (g *Gateway) startSseServer(ctx context.Context, ln net.Listener) error {
 	sseHandler := mcp.NewSSEHandler(func(_ *http.Request) *mcp.Server {
 		return g.mcpServer
 	}, nil)
-	mux.Handle("/sse", sseHandler)
+	mux.Handle("/sse", originSecurityHandler(sseHandler))
 
 	// Wrap with authentication middleware
 	var handler http.Handler = mux
@@ -48,7 +50,7 @@ func (g *Gateway) startStreamingServer(ctx context.Context, ln net.Listener) err
 	streamHandler := mcp.NewStreamableHTTPHandler(func(_ *http.Request) *mcp.Server {
 		return g.mcpServer
 	}, nil)
-	mux.Handle("/mcp", streamHandler)
+	mux.Handle("/mcp", originSecurityHandler(streamHandler))
 
 	// Wrap with authentication middleware
 	var handler http.Handler = mux
@@ -82,3 +84,42 @@ func healthHandler(state *health.State) http.HandlerFunc {
 		}
 	}
 }
+
+// isAllowedOrigin validates that the origin is from localhost.
+// Returns true if the origin's hostname is "localhost", "127.0.0.1", or "::1" (IPv6 localhost).
+func isAllowedOrigin(origin string) bool {
+	u, err := url.Parse(origin)
+	if err != nil {
+		return false // Invalid URL format
+	}
+
+	// Only allow http or https schemes
+	if u.Scheme != "http" && u.Scheme != "https" {
+		return false
+	}
+
+	// Extract hostname (without port)
+	host := u.Hostname()
+
+	// Only allow localhost, IPv4 loopback, or IPv6 loopback
+	return host == "localhost" || host == "127.0.0.1" || host == "::1"
+}
+
+// originSecurityHandler validates Origin header to prevent DNS rebinding attacks.
+func originSecurityHandler(next http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		origin := r.Header.Get("Origin")
+
+		// Allow requests with no Origin header
+		// This handles:
+		// - Non-browser clients (curl, SDKs) - no Origin header sent
+		// - Same-origin requests - browsers don't send Origin for same-origin
+		if origin != "" && !isAllowedOrigin(origin) {
+			msg := fmt.Sprintf("Forbidden: Origin, if set, must be localhost, 127.0.0.1, or ::1, got: %s", origin)
+			http.Error(w, msg, http.StatusForbidden)
+			return
+		}
+
+		next.ServeHTTP(w, r)
+	})
+}
```
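The key design choice in `isAllowedOrigin` is to parse the Origin and compare the exact hostname rather than match a string prefix. The contrast sketch below illustrates why. `naivePrefixCheck` is a hypothetical, deliberately flawed validator. It accepts `http://localhost.evil.com` because that string begins with `http://localhost`. `hostnameCheck` follows the same approach as the patch (`url.Parse`, scheme restriction, exact `Hostname()` match), so it rejects lookalike subdomains and still accepts bracketed IPv6 loopback origins.

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// naivePrefixCheck is a hypothetical, flawed validator shown only for contrast.
func naivePrefixCheck(origin string) bool {
	return strings.HasPrefix(origin, "http://localhost") ||
		strings.HasPrefix(origin, "http://127.0.0.1")
}

// hostnameCheck mirrors the approach of isAllowedOrigin: parse the Origin,
// allow only http/https, and compare the exact hostname (port stripped).
func hostnameCheck(origin string) bool {
	u, err := url.Parse(origin)
	if err != nil || (u.Scheme != "http" && u.Scheme != "https") {
		return false
	}
	switch u.Hostname() {
	case "localhost", "127.0.0.1", "::1":
		return true
	}
	return false
}

func main() {
	origins := []string{
		"http://localhost:3000",
		"http://localhost.evil.com",
		"http://127.0.0.1.evil.com",
		"http://[::1]:8080",
	}
	for _, o := range origins {
		fmt.Printf("%-28s naive=%v parsed=%v\n", o, naivePrefixCheck(o), hostnameCheck(o))
	}
}
```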
pkg/gateway/transport_test.go · +213 −0 · added

```diff
@@ -0,0 +1,213 @@
+package gateway
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+)
+
+// TestIsAllowedOrigin tests the isAllowedOrigin helper function with various inputs.
+func TestIsAllowedOrigin(t *testing.T) {
+	tests := []struct {
+		name     string
+		origin   string
+		expected bool
+	}{
+		// Valid localhost origins
+		{"http localhost no port", "http://localhost", true},
+		{"https localhost no port", "https://localhost", true},
+		{"http localhost with port", "http://localhost:3000", true},
+		{"https localhost with port", "https://localhost:8080", true},
+		{"http 127.0.0.1 no port", "http://127.0.0.1", true},
+		{"https 127.0.0.1 no port", "https://127.0.0.1", true},
+		{"http 127.0.0.1 with port", "http://127.0.0.1:8080", true},
+		{"https 127.0.0.1 with port", "https://127.0.0.1:5000", true},
+		{"http IPv6 localhost", "http://[::1]", true},
+		{"https IPv6 localhost", "https://[::1]", true},
+		{"http IPv6 localhost with port", "http://[::1]:8080", true},
+		{"https IPv6 localhost with port", "https://[::1]:3000", true},
+
+		// Invalid origins - malicious domains
+		{"evil domain", "https://evil.com", false},
+		{"evil domain with port", "https://evil.com:8080", false},
+		{"subdomain attack", "http://localhost.evil.com", false},
+		{"subdomain with 127", "http://127.0.0.1.evil.com", false},
+
+		// Invalid origins - RFC 1122 prohibits 0.0.0.0 as destination
+		{"0.0.0.0 address", "http://0.0.0.0:8080", false},
+		{"0.0.0.0 no port", "http://0.0.0.0", false},
+		{"all zeros IPv6", "http://[::]:8080", false},
+
+		// Invalid schemes
+		{"ftp scheme", "ftp://localhost", false},
+		{"ws scheme", "ws://localhost", false},
+		{"file scheme", "file://localhost", false},
+
+		// Malformed URLs
+		{"not a URL", "not-a-url", false},
+		{"missing scheme", "localhost:8080", false},
+		{"single slash", "http:/localhost", false},
+		{"empty string", "", false},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := isAllowedOrigin(tt.origin)
+			if result != tt.expected {
+				t.Errorf("isAllowedOrigin(%q) = %v, expected %v", tt.origin, result, tt.expected)
+			}
+		})
+	}
+}
+
+func TestOriginSecurityHandler(t *testing.T) {
+	// Create a simple handler that always succeeds if reached
+	successHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte("success"))
+	})
+
+	// Wrap it with our security handler
+	secureHandler := originSecurityHandler(successHandler)
+
+	tests := []struct {
+		name           string
+		origin         string
+		expectedStatus int
+		reason         string
+	}{
+		{
+			name:           "No Origin header (non-browser clients)",
+			origin:         "",
+			expectedStatus: http.StatusOK,
+			reason:         "CRITICAL: curl, SDKs, and same-origin browser requests must work. Browsers don't send Origin for same-origin requests.",
+		},
+		{
+			name:           "Malicious origin (the actual attack)",
+			origin:         "https://evil.com",
+			expectedStatus: http.StatusForbidden,
+			reason:         "CRITICAL: This is the vulnerability we're fixing. Block cross-origin requests from non-localhost origins.",
+		},
+		{
+			name:           "Localhost origin (legitimate browser client)",
+			origin:         "http://localhost:3000",
+			expectedStatus: http.StatusOK,
+			reason:         "CRITICAL: Developer running local frontend on different port must work. Common development scenario.",
+		},
+		{
+			name:           "Non-localhost IP origin",
+			origin:         "http://0.0.0.0:8080",
+			expectedStatus: http.StatusForbidden,
+			reason:         "Block non-localhost IPs. Note: In DNS rebinding, evil.com resolves to 0.0.0.0 but Origin would be http://evil.com",
+		},
+		{
+			name:           "Subdomain bypass attempt",
+			origin:         "http://localhost.evil.com",
+			expectedStatus: http.StatusForbidden,
+			reason:         "IMPORTANT: Prevent validation bypass using subdomain that contains 'localhost'. Common attack technique.",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			req := httptest.NewRequest(http.MethodPost, "/mcp", nil)
+			if tt.origin != "" {
+				req.Header.Set("Origin", tt.origin)
+			}
+
+			rr := httptest.NewRecorder()
+			secureHandler.ServeHTTP(rr, req)
+
+			if rr.Code != tt.expectedStatus {
+				t.Errorf("Expected status %d, got %d\nReason: %s", tt.expectedStatus, rr.Code, tt.reason)
+			}
+
+			// Verify response body for blocked requests contains helpful error message
+			if tt.expectedStatus == http.StatusForbidden {
+				body := rr.Body.String()
+				if !strings.Contains(body, "Forbidden") || !strings.Contains(body, "Origin") {
+					t.Errorf("Expected error body to contain 'Forbidden' and 'Origin', got %q", body)
+				}
+			}
+		})
+	}
+}
+
+// TestCombinedSecurityLayers verifies that both Origin validation and authentication work together.
+// This ensures defense-in-depth: both layers must pass for a request to succeed.
+func TestCombinedSecurityLayers(t *testing.T) {
+	authToken := "test-token-secure-123"
+
+	successHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte("success"))
+	})
+
+	// Wrap with both layers (same as production code)
+	withOriginCheck := originSecurityHandler(successHandler)
+	withBothLayers := authenticationMiddleware(authToken, withOriginCheck)
+
+	tests := []struct {
+		name           string
+		origin         string
+		authHeader     string
+		expectedStatus int
+		reason         string
+	}{
+		{
+			name:           "Valid origin + valid token",
+			origin:         "http://localhost:3000",
+			authHeader:     "Bearer test-token-secure-123",
+			expectedStatus: http.StatusOK,
+			reason:         "Both security layers pass - request should succeed",
+		},
+		{
+			name:           "Valid origin + invalid token",
+			origin:         "http://localhost:3000",
+			authHeader:     "Bearer wrong-token",
+			expectedStatus: http.StatusUnauthorized,
+			reason:         "Origin is valid but auth fails - should block at auth layer",
+		},
+		{
+			name:           "Invalid origin + valid token",
+			origin:         "https://evil.com",
+			authHeader:     "Bearer test-token-secure-123",
+			expectedStatus: http.StatusForbidden,
+			reason:         "Token is valid but origin is malicious - should block at origin layer",
+		},
+		{
+			name:           "Invalid origin + no token",
+			origin:         "https://evil.com",
+			authHeader:     "",
+			expectedStatus: http.StatusUnauthorized,
+			reason:         "Both layers fail - auth middleware (outer) checks first, blocks with 401",
+		},
+		{
+			name:           "No origin + valid token (CLI client)",
+			origin:         "",
+			authHeader:     "Bearer test-token-secure-123",
+			expectedStatus: http.StatusOK,
+			reason:         "CLI tools with valid auth should work",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			req := httptest.NewRequest(http.MethodPost, "/mcp", nil)
+			if tt.origin != "" {
+				req.Header.Set("Origin", tt.origin)
+			}
+			if tt.authHeader != "" {
+				req.Header.Set("Authorization", tt.authHeader)
+			}
+
+			rr := httptest.NewRecorder()
+			withBothLayers.ServeHTTP(rr, req)
+
+			if rr.Code != tt.expectedStatus {
+				t.Errorf("Expected status %d, got %d\nReason: %s", tt.expectedStatus, rr.Code, tt.reason)
+			}
+		})
+	}
+}
```
fe073985c8eb: Add authentication for SSE and streaming gateway modes (#190)
13 files changed · +989 −78
CLAUDE.md · +145 −0 · added

```diff
@@ -0,0 +1,145 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Overview
+
+This is the **Docker MCP Gateway** - a CLI plugin that enables easy and secure running of Model Context Protocol (MCP) servers through Docker containers. The plugin acts as a gateway between AI clients and containerized MCP servers, providing isolation, security, and management capabilities.
+
+## Architecture
+
+The codebase follows a gateway pattern where:
+- **AI Client** connects to the **MCP Gateway**
+- **MCP Gateway** (this CLI) manages multiple **MCP Servers** running in Docker containers
+
+Key architectural components:
+- **Gateway**: Core routing and protocol translation (`cmd/docker-mcp/internal/gateway/`)
+- **Client Management**: Handles connections to AI clients (`cmd/docker-mcp/client/`)
+- **Server Management**: Manages MCP server lifecycle (`cmd/docker-mcp/server/`)
+- **Catalog System**: Manages available MCP servers (`cmd/docker-mcp/catalog/`)
+- **Security**: Secrets management and OAuth flows (`cmd/docker-mcp/secret-management/`, `cmd/docker-mcp/oauth/`)
+
+## Development Commands
+
+### Building
+```bash
+# Build the CLI plugin locally
+make docker-mcp
+
+# Cross-compile for all platforms
+make docker-mcp-cross
+```
+
+### Testing
+```bash
+# Run all tests
+make test
+
+# Run integration tests specifically
+make integration
+
+# Run long-lived integration tests
+go test -count=1 ./... -run 'TestLongLived'
+
+# Run specific tests by pattern
+go test -count=1 ./... -run 'TestIntegration'
+
+# Run a single test file
+go test ./cmd/docker-mcp/server/server_test.go
+
+# Run tests with coverage
+go test -cover ./...
+```
+
+### Code Quality
+```bash
+# Format code
+make format
+
+# Run linter
+make lint
+
+# Run linter for specific platform
+make lint-linux
+make lint-darwin
+
+# Run Go vet (static analysis)
+go vet ./...
+
+# Run Go mod tidy to clean dependencies
+go mod tidy
+```
+
+## Project Structure
+
+- `cmd/docker-mcp/` - Main CLI application entry point
+- `cmd/docker-mcp/internal/gateway/` - Core gateway implementation with client pooling, proxy management, and transport handling
+- `cmd/docker-mcp/internal/docker/` - Docker integration for container management
+- `cmd/docker-mcp/internal/mcp/` - MCP protocol implementations (stdio, SSE)
+- `cmd/docker-mcp/internal/desktop/` - Docker Desktop integration and authentication
+- `cmd/docker-mcp/catalog/` - Server catalog management commands
+- `cmd/docker-mcp/client/` - Client configuration and connection management
+- `cmd/docker-mcp/server/` - Server lifecycle management commands
+- `cmd/docker-mcp/tools/` - Tool execution and management
+- `examples/` - Usage examples and compose configurations
+- `docs/` - Technical documentation
+
+## Key Configuration Files
+
+The CLI uses these configuration files (typically in `~/.docker/mcp/`):
+- `docker-mcp.yaml` - Server catalog definitions
+- `registry.yaml` - Registry of enabled servers
+- `config.yaml` - Gateway configuration and options
+
+## Important Patterns
+
+### Transport Modes
+The gateway supports multiple transport modes:
+- `stdio` - Standard input/output (default)
+- `streaming` - HTTP streaming for multiple clients
+- `sse` - Server-sent events
+
+### Security Features
+- Container isolation for MCP servers
+- Secrets management via Docker Desktop
+- OAuth flow handling
+- API key and credential protection
+- Call interception and logging
+
+### Client Integration
+The system integrates with various AI clients:
+- VS Code / Cursor
+- Claude Desktop
+- Continue Dev
+- Custom MCP clients
+
+Configuration files for different clients are automatically managed in `cmd/docker-mcp/client/testdata/`.
+
+## CLI Plugin Development
+
+This is a Docker CLI plugin written in Go 1.24+. Key development patterns:
+
+### Plugin Installation
+The plugin is installed as `docker-mcp` and becomes available as `docker mcp <command>`. The Makefile handles building and installing to the correct Docker CLI plugins directory (`~/.docker/cli-plugins/`).
+
+### Command Structure
+Commands follow the Cobra CLI pattern with the main command tree defined in `cmd/docker-mcp/commands/`. Each major command area (catalog, server, client, etc.) has its own file.
+
+### Configuration Management
+The CLI uses YAML configuration files stored in `~/.docker/mcp/`:
+- Server definitions are loaded from catalog files
+- Runtime configuration is managed through config.yaml
+- Server enablement tracked in registry.yaml
+
+### Container Lifecycle
+MCP servers run as Docker containers with proper lifecycle management:
+- Images are pulled and validated before use
+- Containers have consistent naming patterns
+- Health checks and logging are built-in
+- Proper cleanup on shutdown
+
+### Testing Patterns
+- Integration tests require Docker daemon
+- Long-lived tests run actual container scenarios
+- Mock configurations in testdata directories
+- Use `go test -count=1` to disable test caching
\ No newline at end of file
```
cmd/docker-mcp/commands/gateway.go · +19 −2 · modified

```diff
@@ -77,6 +77,9 @@ func gatewayCommand(docker docker.Client, dockerCli command.Cli) *cobra.Command
 			// Check if dynamic tools feature is enabled
 			options.DynamicTools = isDynamicToolsFeatureEnabled(dockerCli)
 
+			// Check if tool name prefix feature is enabled
+			options.ToolNamePrefix = isToolNamePrefixFeatureEnabled(dockerCli)
+
 			// Update catalog URL based on mcp-oauth-dcr flag if using default Docker catalog URL
 			if len(options.CatalogPath) == 1 && (options.CatalogPath[0] == catalog.DockerCatalogURLV2 || options.CatalogPath[0] == catalog.DockerCatalogURLV3) {
 				options.CatalogPath[0] = catalog.GetDockerCatalogURL(options.McpOAuthDcrEnabled)
@@ -160,7 +163,7 @@ func gatewayCommand(docker docker.Client, dockerCli command.Cli) *cobra.Command
 	runCmd.Flags().StringArrayVar(&options.OciRef, "oci-ref", options.OciRef, "OCI image references to use")
 	runCmd.Flags().StringSliceVar(&mcpRegistryUrls, "mcp-registry", nil, "MCP registry URLs to fetch servers from (can be repeated)")
 	runCmd.Flags().IntVar(&options.Port, "port", options.Port, "TCP port to listen on (default is to listen on stdio)")
-	runCmd.Flags().StringVar(&options.Transport, "transport", options.Transport, "stdio, sse or streaming (default is stdio)")
+	runCmd.Flags().StringVar(&options.Transport, "transport", options.Transport, "stdio, sse or streaming. Uses MCP_GATEWAY_AUTH_TOKEN environment variable for localhost authentication to prevent dns rebinding attacks.")
 	runCmd.Flags().BoolVar(&options.LogCalls, "log-calls", options.LogCalls, "Log calls to the tools")
 	runCmd.Flags().BoolVar(&options.BlockSecrets, "block-secrets", options.BlockSecrets, "Block secrets from being/received sent to/from tools")
 	runCmd.Flags().BoolVar(&options.BlockNetwork, "block-network", options.BlockNetwork, "Block tools from accessing forbidden network resources")
@@ -176,7 +179,6 @@ func gatewayCommand(docker docker.Client, dockerCli command.Cli) *cobra.Command
 	runCmd.Flags().StringVar(&options.LogFilePath, "log", options.LogFilePath, "Path to log file for stderr output (relative or absolute)")
 
 	// Very experimental features
-	_ = runCmd.Flags().MarkHidden("transport")
 	_ = runCmd.Flags().MarkHidden("log")
 
 	cmd.AddCommand(runCmd)
@@ -303,3 +305,18 @@ func isDynamicToolsFeatureEnabled(dockerCli command.Cli) bool {
 
 	return value == "enabled"
 }
+
+// isToolNamePrefixFeatureEnabled checks if the tool-name-prefix feature is enabled
+func isToolNamePrefixFeatureEnabled(dockerCli command.Cli) bool {
+	configFile := dockerCli.ConfigFile()
+	if configFile == nil || configFile.Features == nil {
+		return false
+	}
+
+	value, exists := configFile.Features["tool-name-prefix"]
+	if !exists {
+		return false
+	}
+
+	return value == "enabled"
+}
```
docs/generator/reference/docker_mcp_gateway_run.yaml · +3 −2 · modified

```diff
@@ -275,9 +275,10 @@ options:
   - option: transport
     value_type: string
     default_value: stdio
-    description: stdio, sse or streaming (default is stdio)
+    description: |
+      stdio, sse or streaming. Uses MCP_GATEWAY_AUTH_TOKEN environment variable for localhost authentication to prevent dns rebinding attacks.
     deprecated: false
-    hidden: true
+    hidden: false
     experimental: false
     experimentalcli: false
     kubernetes: false
```
docs/generator/reference/mcp_gateway_run.md · +1 −0 · modified

```diff
@@ -32,6 +32,7 @@ Run the gateway
 | `--static` | `bool` | | Enable static mode (aka pre-started servers) |
 | `--tools` | `stringSlice` | | List of tools to enable |
 | `--tools-config` | `stringSlice` | `[tools.yaml]` | Paths to the tools files (absolute or relative to ~/.docker/mcp/) |
+| `--transport` | `string` | `stdio` | stdio, sse or streaming. Uses MCP_GATEWAY_AUTH_TOKEN environment variable for localhost authentication to prevent dns rebinding attacks. |
 | `--verbose` | `bool` | | Verbose output |
 | `--verify-signatures` | `bool` | | Verify signatures of the server images |
 | `--watch` | `bool` | `true` | Watch for changes and reconfigure the gateway |
```
pkg/catalog/types.go · +1 −0 · modified

```diff
@@ -31,6 +31,7 @@ type Server struct {
 	AllowHosts []string `yaml:"allowHosts,omitempty" json:"allowHosts,omitempty"`
 	Tools []Tool `yaml:"tools,omitempty" json:"tools,omitempty"`
 	Config []any `yaml:"config,omitempty" json:"config,omitempty"`
+	Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
 }
 
 type Secret struct {
```
pkg/gateway/auth.go · +97 −0 · added

```diff
@@ -0,0 +1,97 @@
+package gateway
+
+import (
+	"crypto/rand"
+	"crypto/subtle"
+	"fmt"
+	"math/big"
+	"net/http"
+	"os"
+)
+
+const (
+	tokenLength = 50
+	// Characters to use for random token generation (lowercase letters and numbers)
+	tokenCharset = "abcdefghijklmnopqrstuvwxyz0123456789"
+)
+
+// generateAuthToken generates a random 50-character string using lowercase letters and numbers
+func generateAuthToken() (string, error) {
+	token := make([]byte, tokenLength)
+	charsetLen := big.NewInt(int64(len(tokenCharset)))
+
+	for i := range tokenLength {
+		num, err := rand.Int(rand.Reader, charsetLen)
+		if err != nil {
+			return "", fmt.Errorf("failed to generate random token: %w", err)
+		}
+		token[i] = tokenCharset[num.Int64()]
+	}
+
+	return string(token), nil
+}
+
+// getOrGenerateAuthToken retrieves the auth token from environment variable MCP_GATEWAY_AUTH_TOKEN
+// or generates a new one if not set or empty
+func getOrGenerateAuthToken() (string, bool, error) {
+	envToken := os.Getenv("MCP_GATEWAY_AUTH_TOKEN")
+	if envToken != "" {
+		return envToken, false, nil // false indicates token was from environment
+	}
+
+	token, err := generateAuthToken()
+	if err != nil {
+		return "", false, err
+	}
+	return token, true, nil // true indicates token was generated
+}
+
+// authenticationMiddleware creates an HTTP middleware that validates requests using
+// Bearer token in the Authorization header.
+//
+// The /health endpoint is excluded from authentication.
+func authenticationMiddleware(authToken string, next http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// Skip authentication for health check endpoint
+		if r.URL.Path == "/health" {
+			next.ServeHTTP(w, r)
+			return
+		}
+
+		authenticated := false
+
+		// Check for Bearer token in Authorization header
+		authHeader := r.Header.Get("Authorization")
+		if authHeader != "" {
+			// Extract Bearer token from "Bearer <token>" format
+			const bearerPrefix = "Bearer "
+			if len(authHeader) > len(bearerPrefix) && authHeader[:len(bearerPrefix)] == bearerPrefix {
+				bearerToken := authHeader[len(bearerPrefix):]
+				// Use constant-time comparison to prevent timing attacks
+				if subtle.ConstantTimeCompare([]byte(bearerToken), []byte(authToken)) == 1 {
+					authenticated = true
+				}
+			}
+		}
+
+		if !authenticated {
+			// Return 401 Unauthorized with WWW-Authenticate header
+			w.Header().Set("WWW-Authenticate", `Bearer realm="MCP Gateway"`)
+			http.Error(w, "Unauthorized", http.StatusUnauthorized)
+			return
+		}
+
+		// Authentication successful, proceed to next handler
+		next.ServeHTTP(w, r)
+	})
+}
+
+// formatGatewayURL formats the gateway URL without authentication info
+func formatGatewayURL(port int, endpoint string) string {
+	return fmt.Sprintf("http://localhost:%d%s", port, endpoint)
+}
+
+// formatBearerToken formats the Bearer token for display in the Authorization header
+func formatBearerToken(authToken string) string {
+	return fmt.Sprintf("Authorization: Bearer %s", authToken)
+}
```
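The auto-generated token's strength follows directly from `generateAuthToken`. It draws each of its 50 characters independently from a 36-symbol alphabet (`a`-`z`, `0`-`9`) using `crypto/rand.Int`. That function returns a uniform value in `[0, max)` and so avoids modulo bias. Under that reading of the code, the entropy of a generated token is:

$$
H = 50 \cdot \log_2 36 \approx 50 \times 5.17 \approx 258.5 \text{ bits}
$$

This figure applies only to generated tokens. A value supplied through `MCP_GATEWAY_AUTH_TOKEN` is only as strong as whatever the user sets.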
pkg/gateway/auth_integration_test.go · +323 −0 · added

```diff
@@ -0,0 +1,323 @@
+package gateway
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"net"
+	"net/http"
+	"os"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/modelcontextprotocol/go-sdk/mcp"
+)
+
+// TestSSEServerAuthentication tests that the SSE server properly enforces authentication
+func TestSSEServerAuthentication(t *testing.T) {
+	if testing.Short() {
+		t.Skip("skipping integration test")
+	}
+
+	// Create a minimal gateway with SSE transport
+	g := &Gateway{
+		Options: Options{
+			Port:      0, // Let the OS assign a port
+			Transport: "sse",
+		},
+	}
+	g.health.SetHealthy() // Mark as healthy for testing
+
+	// Initialize an empty MCP server
+	g.mcpServer = mcp.NewServer(&mcp.Implementation{Name: "test-auth-gateway", Version: "1.0.0"}, nil)
+
+	// Generate auth token
+	token, wasGenerated, err := getOrGenerateAuthToken()
+	if err != nil {
+		t.Fatalf("failed to generate auth token: %v", err)
+	}
+	g.authToken = token
+	g.authTokenWasGenerated = wasGenerated
+
+	// Create a listener on a random available port
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("failed to create listener: %v", err)
+	}
+	defer ln.Close()
+
+	port := ln.Addr().(*net.TCPAddr).Port
+
+	// Start the SSE server in a goroutine
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	serverErr := make(chan error, 1)
+	go func() {
+		serverErr <- g.startSseServer(ctx, ln)
+	}()
+
+	// Give the server time to start
+	time.Sleep(100 * time.Millisecond)
+
+	// Test 1: Health endpoint should be accessible without auth
+	t.Run("HealthEndpointNoAuth", func(t *testing.T) {
+		req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/health", port), nil)
+		if err != nil {
+			t.Fatalf("failed to create request: %v", err)
+		}
+		resp, err := http.DefaultClient.Do(req)
+		if err != nil {
+			t.Fatalf("health check failed: %v", err)
+		}
+		defer resp.Body.Close()
+
+		if resp.StatusCode != http.StatusOK {
+			t.Errorf("expected status 200 for /health, got %d", resp.StatusCode)
+		}
+	})
+
+	// Test 2: SSE endpoint should reject requests without auth
+	t.Run("SSEEndpointNoAuth", func(t *testing.T) {
+		req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/sse", port), nil)
+		if err != nil {
+			t.Fatalf("failed to create request: %v", err)
+		}
+		resp, err := http.DefaultClient.Do(req)
+		if err != nil {
+			t.Fatalf("request failed: %v", err)
+		}
+		defer resp.Body.Close()
+
+		if resp.StatusCode != http.StatusUnauthorized {
+			t.Errorf("expected status 401 for /sse without auth, got %d", resp.StatusCode)
+		}
+	})
+
+	// Test 3: SSE endpoint should accept valid bearer auth
+	t.Run("SSEEndpointBearerAuth", func(t *testing.T) {
+		client := &http.Client{}
+		req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/sse", port), nil)
+		if err != nil {
+			t.Fatalf("failed to create request: %v", err)
+		}
+		req.Header.Set("Authorization", "Bearer "+token)
+
+		resp, err := client.Do(req)
+		if err != nil {
+			t.Fatalf("request failed: %v", err)
+		}
+		defer resp.Body.Close()
+
+		if resp.StatusCode != http.StatusOK {
+			body, _ := io.ReadAll(resp.Body)
+			t.Errorf("expected successful response for /sse with valid bearer auth, got %d: %s", resp.StatusCode, string(body))
+		}
+	})
+
+	// Test 4: SSE endpoint should reject invalid bearer auth
+	t.Run("SSEEndpointInvalidBearerAuth", func(t *testing.T) {
+		client := &http.Client{}
+		req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/sse", port), nil)
+		if err != nil {
+			t.Fatalf("failed to create request: %v", err)
+		}
+		req.Header.Set("Authorization", "Bearer wrong-token")
+
+		resp, err := client.Do(req)
+		if err != nil {
+			t.Fatalf("request failed: %v", err)
+		}
+		defer resp.Body.Close()
+
+		if resp.StatusCode != http.StatusUnauthorized {
+			t.Errorf("expected status 401 for /sse with invalid bearer auth, got %d", resp.StatusCode)
+		}
+	})
+
+	// Cancel context to stop server
+	cancel()
+
+	// Wait a bit for cleanup
+	select {
+	case <-serverErr:
+		// Server stopped
+	case <-time.After(1 * time.Second):
+		t.Error("server did not stop in time")
+	}
+}
+
+// TestStreamingServerAuthentication tests that the streaming server properly enforces authentication
+func TestStreamingServerAuthentication(t *testing.T) {
+	if testing.Short() {
+		t.Skip("skipping integration test")
+	}
+
+	// Create a minimal gateway with streaming transport
+	g := &Gateway{
+		Options: Options{
+			Port:      0,
+			Transport: "streaming",
+		},
+	}
+	g.health.SetHealthy() // Mark as healthy for testing
+
+	// Initialize an empty MCP server
+	g.mcpServer = mcp.NewServer(&mcp.Implementation{Name: "test-auth-gateway", Version: "1.0.0"}, nil)
+
+	// Generate auth token
+	token, wasGenerated, err := getOrGenerateAuthToken()
+	if err != nil {
+		t.Fatalf("failed to generate auth token: %v", err)
+	}
+	g.authToken = token
+	g.authTokenWasGenerated = wasGenerated
+
+	// Create a listener on a random available port
+	ln, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("failed to create listener: %v", err)
+	}
+	defer ln.Close()
+
+	port := ln.Addr().(*net.TCPAddr).Port
+
+	// Start the streaming server in a goroutine
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	serverErr := make(chan error, 1)
+	go func() {
+		serverErr <- g.startStreamingServer(ctx, ln)
+	}()
+
+	// Give the server time to start
+	time.Sleep(100 * time.Millisecond)
+
+	// Test 1: Health endpoint should be accessible without auth
+	t.Run("HealthEndpointNoAuth", func(t *testing.T) {
+		req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/health", port), nil)
+		if err != nil {
+			t.Fatalf("failed to create request: %v", err)
+		}
+		resp, err := http.DefaultClient.Do(req)
+		if err != nil {
+			t.Fatalf("health check failed: %v", err)
+		}
+		defer resp.Body.Close()
+
+		if resp.StatusCode != http.StatusOK {
+			t.Errorf("expected status 200 for /health, got %d", resp.StatusCode)
+		}
+	})
+
+	// Test 2: MCP endpoint should reject requests without auth
+	t.Run("MCPEndpointNoAuth", func(t *testing.T) {
+		req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/mcp", port), nil)
+		if err != nil {
+			t.Fatalf("failed to create request: %v", err)
+		}
+		resp, err := http.DefaultClient.Do(req)
+		if err != nil {
+			t.Fatalf("request failed: %v", err)
+		}
+		defer resp.Body.Close()
+
+		if resp.StatusCode != http.StatusUnauthorized {
+			t.Errorf("expected status 401 for /mcp without auth, got %d", resp.StatusCode)
+		}
+	})
+
+	// Test 3: MCP endpoint should accept valid bearer auth
+	t.Run("MCPEndpointBearerAuth", func(t *testing.T) {
+		client := &http.Client{}
+		req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/mcp", port), nil)
+		if err != nil {
+			t.Fatalf("failed to create request: %v", err)
+		}
+		req.Header.Set("Authorization", "Bearer "+token)
+
+		resp, err := client.Do(req)
+		if err != nil {
+			t.Fatalf("request failed: %v", err)
+		}
+		defer resp.Body.Close()
+
+		if resp.StatusCode == http.StatusUnauthorized {
+			t.Errorf("expected non-401 status for /mcp with valid bearer auth, got %d", resp.StatusCode)
+		}
+	})
+
+	// Cancel context to stop server
+	cancel()
+
+	// Wait a bit for cleanup
+	select {
+	case <-serverErr:
+		// Server stopped
+	case <-time.After(1 * time.Second):
+		t.Error("server did not stop in time")
+	}
+}
+
+// TestAuthTokenFromEnvironment tests that the auth token is read from the environment
+func TestAuthTokenFromEnvironment(t *testing.T) {
+	expectedToken := "my-custom-token-from-env"
+	os.Setenv("MCP_GATEWAY_AUTH_TOKEN", expectedToken)
+	defer os.Unsetenv("MCP_GATEWAY_AUTH_TOKEN")
+
+	token, wasGenerated, err := getOrGenerateAuthToken()
+	if err != nil {
+		t.Fatalf("failed to get auth token: %v", err)
+	}
+
+	if token != expectedToken {
+		t.Errorf("expected token %q, got %q", expectedToken, token)
+	}
+
+	if wasGenerated {
+		t.Error("expected wasGenerated to be false when token is from environment")
+	}
+}
+
+// TestFormatGatewayURLIntegration tests that the gateway URL is formatted correctly
+func TestFormatGatewayURLIntegration(t *testing.T) {
+	port := 8811
+	endpoint := "/sse"
+
+	url := formatGatewayURL(port, endpoint)
+
+	expected := fmt.Sprintf("http://localhost:%d%s", port, endpoint)
+	if url != expected {
+		t.Errorf("expected URL %q, got %q", expected, url)
+	}
+}
+
+// TestFormatBearerTokenEncoding tests that bearer token is properly formatted
+func TestFormatBearerTokenEncoding(t *testing.T) {
+	token := "test-token-abc123"
+	authHeader := formatBearerToken(token)
+
+	// Should start with "Authorization: Bearer "
+	if !strings.HasPrefix(authHeader, "Authorization: Bearer ") {
+		t.Errorf("auth header should start with 'Authorization: Bearer ', got %q", authHeader)
+	}
+
+	// Extract the token part
+	parts := strings.SplitN(authHeader, " ", 3)
+	if len(parts) != 3 {
+		t.Fatalf("expected 3 parts in auth header, got %d", len(parts))
+	}
+
+	// The third part should be the token
+	if parts[2] != token {
+		t.Errorf("expected token %q, got %q", token, parts[2])
+	}
+
+	// Verify the complete format
+	expected := fmt.Sprintf("Authorization: Bearer %s", token)
+	if authHeader != expected {
+		t.Errorf("expected auth header %q, got %q", expected, authHeader)
+	}
+}
```
pkg/gateway/auth_test.go +230 −0 added
@@ -0,0 +1,230 @@
+package gateway
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"strings"
+	"testing"
+)
+
+func TestGenerateAuthToken(t *testing.T) {
+	token, err := generateAuthToken()
+	if err != nil {
+		t.Fatalf("generateAuthToken() failed: %v", err)
+	}
+
+	if len(token) != tokenLength {
+		t.Errorf("expected token length %d, got %d", tokenLength, len(token))
+	}
+
+	// Check that token only contains allowed characters
+	for _, ch := range token {
+		if !strings.ContainsRune(tokenCharset, ch) {
+			t.Errorf("token contains invalid character: %c", ch)
+		}
+	}
+}
+
+func TestGetOrGenerateAuthToken_FromEnvironment(t *testing.T) {
+	expectedToken := "test-token-from-env"
+	os.Setenv("MCP_GATEWAY_AUTH_TOKEN", expectedToken)
+	defer os.Unsetenv("MCP_GATEWAY_AUTH_TOKEN")
+
+	token, wasGenerated, err := getOrGenerateAuthToken()
+	if err != nil {
+		t.Fatalf("getOrGenerateAuthToken() failed: %v", err)
+	}
+
+	if token != expectedToken {
+		t.Errorf("expected token %q, got %q", expectedToken, token)
+	}
+
+	if wasGenerated {
+		t.Error("expected wasGenerated to be false when token is from environment")
+	}
+}
+
+func TestGetOrGenerateAuthToken_Generated(t *testing.T) {
+	os.Unsetenv("MCP_GATEWAY_AUTH_TOKEN")
+
+	token, wasGenerated, err := getOrGenerateAuthToken()
+	if err != nil {
+		t.Fatalf("getOrGenerateAuthToken() failed: %v", err)
+	}
+
+	if len(token) != tokenLength {
+		t.Errorf("expected token length %d, got %d", tokenLength, len(token))
+	}
+
+	if !wasGenerated {
+		t.Error("expected wasGenerated to be true when token is generated")
+	}
+}
+
+func TestGetOrGenerateAuthToken_EmptyEnvironment(t *testing.T) {
+	os.Setenv("MCP_GATEWAY_AUTH_TOKEN", "")
+	defer os.Unsetenv("MCP_GATEWAY_AUTH_TOKEN")
+
+	token, wasGenerated, err := getOrGenerateAuthToken()
+	if err != nil {
+		t.Fatalf("getOrGenerateAuthToken() failed: %v", err)
+	}
+
+	if len(token) != tokenLength {
+		t.Errorf("expected token length %d, got %d", tokenLength, len(token))
+	}
+
+	if !wasGenerated {
+		t.Error("expected wasGenerated to be true when environment token is empty")
+	}
+}
+
+func TestAuthenticationMiddleware_HealthEndpoint(t *testing.T) {
+	authToken := "test-token-123"
+	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte("healthy"))
+	})
+
+	middleware := authenticationMiddleware(authToken, handler)
+
+	req := httptest.NewRequest(http.MethodGet, "/health", nil)
+	w := httptest.NewRecorder()
+
+	middleware.ServeHTTP(w, req)
+
+	if w.Code != http.StatusOK {
+		t.Errorf("expected status %d for /health, got %d", http.StatusOK, w.Code)
+	}
+}
+
+func TestAuthenticationMiddleware_BearerAuth_Valid(t *testing.T) {
+	authToken := "test-token-123"
+	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte("success"))
+	})
+
+	middleware := authenticationMiddleware(authToken, handler)
+
+	req := httptest.NewRequest(http.MethodGet, "/sse", nil)
+	// Set Bearer token in Authorization header
+	req.Header.Set("Authorization", "Bearer "+authToken)
+	w := httptest.NewRecorder()
+
+	middleware.ServeHTTP(w, req)
+
+	if w.Code != http.StatusOK {
+		t.Errorf("expected status %d with valid bearer auth, got %d", http.StatusOK, w.Code)
+	}
+}
+
+func TestAuthenticationMiddleware_BearerAuth_Invalid(t *testing.T) {
+	authToken := "test-token-123"
+	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte("success"))
+	})
+
+	middleware := authenticationMiddleware(authToken, handler)
+
+	req := httptest.NewRequest(http.MethodGet, "/sse", nil)
+	req.Header.Set("Authorization", "Bearer wrong-token")
+	w := httptest.NewRecorder()
+
+	middleware.ServeHTTP(w, req)
+
+	if w.Code != http.StatusUnauthorized {
+		t.Errorf("expected status %d with invalid bearer auth, got %d", http.StatusUnauthorized, w.Code)
+	}
+}
+
+func TestAuthenticationMiddleware_NoAuth(t *testing.T) {
+	authToken := "test-token-123"
+	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte("success"))
+	})
+
+	middleware := authenticationMiddleware(authToken, handler)
+
+	req := httptest.NewRequest(http.MethodGet, "/sse", nil)
+	w := httptest.NewRecorder()
+
+	middleware.ServeHTTP(w, req)
+
+	if w.Code != http.StatusUnauthorized {
+		t.Errorf("expected status %d with no auth, got %d", http.StatusUnauthorized, w.Code)
+	}
+
+	// Check for WWW-Authenticate header
+	if w.Header().Get("WWW-Authenticate") == "" {
+		t.Error("expected WWW-Authenticate header to be set")
+	}
+}
+
+func TestAuthenticationMiddleware_BearerAuth_MalformedHeader(t *testing.T) {
+	authToken := "test-token-123"
+	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte("success"))
+	})
+
+	middleware := authenticationMiddleware(authToken, handler)
+
+	// Test with malformed Authorization headers - all should fail
+	malformedHeaders := []string{
+		"bearer " + authToken,  // lowercase bearer
+		"Basic " + authToken,   // wrong auth type
+		"Bearer",               // missing token
+		authToken,              // missing Bearer prefix
+		"Bearer  " + authToken, // extra space
+	}
+
+	for _, header := range malformedHeaders {
+		req := httptest.NewRequest(http.MethodGet, "/sse", nil)
+		req.Header.Set("Authorization", header)
+		w := httptest.NewRecorder()
+
+		middleware.ServeHTTP(w, req)
+
+		if w.Code != http.StatusUnauthorized {
+			t.Errorf("expected status %d with malformed header %q, got %d", http.StatusUnauthorized, header, w.Code)
+		}
+	}
+}
+
+func TestFormatGatewayURL(t *testing.T) {
+	tests := []struct {
+		port     int
+		endpoint string
+		expected string
+	}{
+		{8811, "/sse", "http://localhost:8811/sse"},
+		{3000, "/mcp", "http://localhost:3000/mcp"},
+		{80, "/test", "http://localhost:80/test"},
+	}
+
+	for _, tt := range tests {
+		result := formatGatewayURL(tt.port, tt.endpoint)
+		if result != tt.expected {
+			t.Errorf("formatGatewayURL(%d, %q) = %q, want %q", tt.port, tt.endpoint, result, tt.expected)
+		}
+	}
+}
+
+func TestFormatBearerToken(t *testing.T) {
+	authToken := "test-token-123"
+	result := formatBearerToken(authToken)
+
+	expected := "Authorization: Bearer " + authToken
+	if result != expected {
+		t.Errorf("expected %q, got %q", expected, result)
+	}
+
+	// Verify it has the correct prefix
+	if !strings.HasPrefix(result, "Authorization: Bearer ") {
+		t.Errorf("expected result to start with 'Authorization: Bearer ', got %q", result)
+	}
+}
pkg/gateway/capabilitites.go +43 −2 modified
@@ -12,6 +12,7 @@ import (
 	"github.com/modelcontextprotocol/go-sdk/mcp"
 	"golang.org/x/sync/errgroup"
 
+	"github.com/docker/mcp-gateway/pkg/catalog"
 	"github.com/docker/mcp-gateway/pkg/log"
 	"github.com/docker/mcp-gateway/pkg/telemetry"
 )
@@ -56,6 +57,32 @@ func (caps *Capabilities) getToolByName(toolName string) (ToolRegistration, erro
 	return ToolRegistration{}, fmt.Errorf("unable to find tool")
 }
 
+// getToolNamePrefix returns the prefix to use for tool names based on server configuration
+// and gateway options. If ServerSpec.Prefix is set, it always uses that. Otherwise, it
+// uses the server name if ToolNamePrefix feature flag is enabled.
+func (g *Gateway) getToolNamePrefix(serverConfig *catalog.ServerConfig) string {
+	// If explicit prefix is set in server config, always use it
+	if serverConfig.Spec.Prefix != "" {
+		return serverConfig.Spec.Prefix
+	}
+
+	// Otherwise, use server name if tool-name-prefix feature is enabled
+	if g.ToolNamePrefix {
+		return serverConfig.Name
+	}
+
+	// No prefix
+	return ""
+}
+
+// prefixToolName adds a prefix to a tool name if prefix is not empty
+func prefixToolName(prefix, toolName string) string {
+	if prefix == "" {
+		return toolName
+	}
+	return prefix + ":" + toolName
+}
+
 func (caps *Capabilities) getPromptByName(promptName string) (PromptRegistration, error) {
 	for _, prompt := range caps.Prompts {
 		if prompt.Prompt.Name == promptName {
@@ -117,13 +144,21 @@ func (g *Gateway) listCapabilities(ctx context.Context, configuration Configurat
 		// Record the number of tools discovered from this server
 		telemetry.RecordToolList(ctx, serverConfig.Name, len(tools.Tools))
 
+		// Determine the prefix to use for this server's tools
+		prefix := g.getToolNamePrefix(serverConfig)
+
 		for _, tool := range tools.Tools {
 			if !isToolEnabled(configuration, serverConfig.Name, serverConfig.Spec.Image, tool.Name, g.ToolNames) {
 				continue
 			}
+
+			// Create a copy of the tool and apply prefix to its name
+			prefixedTool := *tool
+			prefixedTool.Name = prefixToolName(prefix, tool.Name)
+
 			capabilities.Tools = append(capabilities.Tools, ToolRegistration{
 				ServerName: serverConfig.Name,
-				Tool:       tool,
+				Tool:       &prefixedTool,
 				Handler:    g.mcpServerToolHandler(serverConfig, g.mcpServer, tool.Annotations),
 			})
 		}
@@ -199,6 +234,12 @@ func (g *Gateway) listCapabilities(ctx context.Context, configuration Configurat
 	case toolGroup != nil:
 		var capabilities Capabilities
 
+		// For POCI tools, use server name as prefix if feature flag is enabled
+		var prefix string
+		if g.ToolNamePrefix {
+			prefix = serverName
+		}
+
 		for _, tool := range *toolGroup {
 			if !isToolEnabled(configuration, serverName, "", tool.Name, g.ToolNames) {
 				continue
@@ -218,7 +259,7 @@ func (g *Gateway) listCapabilities(ctx context.Context, configuration Configurat
 			}
 
 			mcpTool := mcp.Tool{
-				Name:        tool.Name,
+				Name:        prefixToolName(prefix, tool.Name),
 				Description: tool.Description,
 				InputSchema: schema,
 			}
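The same commit adds optional tool-name prefixing. The precedence rule in `getToolNamePrefix` can be exercised in isolation with simplified stand-ins for `catalog.ServerConfig` and the `ToolNamePrefix` option; the types, the `github` server name, and the `create_issue` tool below are hypothetical, and only `prefixToolName` is copied from the diff.

```go
package main

import "fmt"

// Hypothetical stand-ins for catalog.ServerConfig; only the fields the
// prefix logic reads are modeled.
type serverSpec struct{ Prefix string }

type serverConfig struct {
	Name string
	Spec serverSpec
}

// toolNamePrefix mirrors getToolNamePrefix: an explicit Spec.Prefix wins;
// otherwise the server name is used only when the feature flag is on.
func toolNamePrefix(cfg serverConfig, flagEnabled bool) string {
	if cfg.Spec.Prefix != "" {
		return cfg.Spec.Prefix
	}
	if flagEnabled {
		return cfg.Name
	}
	return ""
}

// prefixToolName is copied from the diff.
func prefixToolName(prefix, toolName string) string {
	if prefix == "" {
		return toolName
	}
	return prefix + ":" + toolName
}

func main() {
	gh := serverConfig{Name: "github"}
	fmt.Println(prefixToolName(toolNamePrefix(gh, false), "create_issue")) // create_issue
	fmt.Println(prefixToolName(toolNamePrefix(gh, true), "create_issue"))  // github:create_issue
	gh.Spec.Prefix = "gh"
	fmt.Println(prefixToolName(toolNamePrefix(gh, false), "create_issue")) // gh:create_issue
}
```

In the POCI tool-group hunks, the prefix comes only from the server name and the flag; an explicit `Spec.Prefix` is not consulted on that path.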
pkg/gateway/config.go +1 −0 modified
@@ -34,5 +34,6 @@ type Options struct {
 	OAuthInterceptorEnabled bool
 	McpOAuthDcrEnabled      bool
 	DynamicTools            bool
+	ToolNamePrefix          bool
 	LogFilePath             string
 }
pkg/gateway/dynamic_mcps.go +75 −69 modified
@@ -370,6 +370,80 @@ func (a *serverToolSetAdapter) Tools(ctx context.Context) ([]*codemode.ToolWithH
 	return result, nil
 }
 
+// addRemoteOAuthServer handles the OAuth setup for a remote OAuth server
+// It registers the provider, starts it, and handles authorization through elicitation or direct URL
+func (g *Gateway) addRemoteOAuthServer(ctx context.Context, serverName string, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+	// Register DCR client with DD so user can authorize
+	if err := oauth.RegisterProviderForLazySetup(ctx, serverName); err != nil {
+		log.Logf("Warning: Failed to register OAuth provider for %s: %v", serverName, err)
+	}
+
+	// Start provider
+	g.startProvider(ctx, serverName)
+
+	// Check if current serverSession supports elicitations
+	if req.Session.InitializeParams().Capabilities != nil && req.Session.InitializeParams().Capabilities.Elicitation != nil {
+		// Elicit a response from the client asking whether to open a browser for authorization
+		elicitResult, err := req.Session.Elicit(ctx, &mcp.ElicitParams{
+			Message: fmt.Sprintf("Would you like to open a browser to authorize the '%s' server?", serverName),
+			RequestedSchema: &jsonschema.Schema{
+				Type: "object",
+				Properties: map[string]*jsonschema.Schema{
+					"authorize": {
+						Type:        "boolean",
+						Description: "Whether to open the browser for authorization",
+					},
+				},
+				Required: []string{"authorize"},
+			},
+		})
+		if err != nil {
+			log.Logf("Warning: Failed to elicit authorization response for %s: %v", serverName, err)
+		} else if elicitResult.Action == "accept" && elicitResult.Content != nil {
+			// Check if user authorized
+			if authorize, ok := elicitResult.Content["authorize"].(bool); ok && authorize {
+				// User agreed to authorize, call the OAuth authorize function
+				client := desktop.NewAuthClient()
+				authResponse, err := client.PostOAuthApp(ctx, serverName, "", false)
+				if err != nil {
+					log.Logf("Warning: Failed to start OAuth flow for %s: %v", serverName, err)
+				} else if authResponse.BrowserURL != "" {
+					log.Logf("Opening browser for authentication: %s", authResponse.BrowserURL)
+				} else {
+					log.Logf("Warning: OAuth provider for %s does not exist", serverName)
+				}
+			}
+		}
+
+		return &mcp.CallToolResult{
+			Content: []mcp.Content{&mcp.TextContent{
+				Text: fmt.Sprintf("Successfully added server '%s'. Authorization completed.", serverName),
+			}},
+		}, nil
+	}
+
+	// Client doesn't support elicitations, get the login link and include it in the response
+	client := desktop.NewAuthClient()
+	// Set context flag to enable disableAutoOpen parameter
+	ctxWithFlag := context.WithValue(ctx, contextkeys.OAuthInterceptorEnabledKey, true)
+	authResponse, err := client.PostOAuthApp(ctxWithFlag, serverName, "", true)
+	if err != nil {
+		log.Logf("Warning: Failed to get OAuth URL for %s: %v", serverName, err)
+	} else if authResponse.BrowserURL != "" {
+		return &mcp.CallToolResult{
+			Content: []mcp.Content{&mcp.TextContent{
+				Text: fmt.Sprintf("Successfully added server '%s'. To authorize this server, please open the following URL in your browser:\n\n%s", serverName, authResponse.BrowserURL),
+			}},
+		}, nil
+	}
+
+	return &mcp.CallToolResult{
+		Content: []mcp.Content{&mcp.TextContent{
+			Text: fmt.Sprintf("Successfully added server '%s'. You will need to authorize this server with: docker mcp oauth authorize %s", serverName, serverName),
+		}},
+	}, nil
+}
+
 // mcpAddTool implements a tool for adding new servers to the registry
 func (g *Gateway) createMcpAddTool(clientConfig *clientConfig) *ToolRegistration {
 	tool := &mcp.Tool{
@@ -468,75 +542,7 @@ func (g *Gateway) createMcpAddTool(clientConfig *clientConfig) *ToolRegistration
 
 		// Register DCR client and start OAuth provider if this is a remote OAuth server
 		if g.McpOAuthDcrEnabled && serverConfig != nil && serverConfig.Spec.IsRemoteOAuthServer() {
-			// Register DCR client with DD so user can authorize
-			if err := oauth.RegisterProviderForLazySetup(ctx, serverName); err != nil {
-				log.Logf("Warning: Failed to register OAuth provider for %s: %v", serverName, err)
-			}
-
-			// Start provider
-			g.startProvider(ctx, serverName)
-
-			// Check if current serverSession supports elicitations
-			if req.Session.InitializeParams().Capabilities != nil && req.Session.InitializeParams().Capabilities.Elicitation != nil {
-				// Elicit a response from the client asking whether to open a browser for authorization
-				elicitResult, err := req.Session.Elicit(ctx, &mcp.ElicitParams{
-					Message: fmt.Sprintf("Would you like to open a browser to authorize the '%s' server?", serverName),
-					RequestedSchema: &jsonschema.Schema{
-						Type: "object",
-						Properties: map[string]*jsonschema.Schema{
-							"authorize": {
-								Type:        "boolean",
-								Description: "Whether to open the browser for authorization",
-							},
-						},
-						Required: []string{"authorize"},
-					},
-				})
-				if err != nil {
-					log.Logf("Warning: Failed to elicit authorization response for %s: %v", serverName, err)
-				} else if elicitResult.Action == "accept" && elicitResult.Content != nil {
-					// Check if user authorized
-					if authorize, ok := elicitResult.Content["authorize"].(bool); ok && authorize {
-						// User agreed to authorize, call the OAuth authorize function
-						client := desktop.NewAuthClient()
-						authResponse, err := client.PostOAuthApp(ctx, serverName, "", false)
-						if err != nil {
-							log.Logf("Warning: Failed to start OAuth flow for %s: %v", serverName, err)
-						} else if authResponse.BrowserURL != "" {
-							log.Logf("Opening browser for authentication: %s", authResponse.BrowserURL)
-						} else {
-							log.Logf("Warning: OAuth provider for %s does not exist", serverName)
-						}
-					}
-				}
-
-				return &mcp.CallToolResult{
-					Content: []mcp.Content{&mcp.TextContent{
-						Text: fmt.Sprintf("Successfully added server '%s'. Authorization completed.", serverName),
-					}},
-				}, nil
-			}
-
-			// Client doesn't support elicitations, get the login link and include it in the response
-			client := desktop.NewAuthClient()
-			// Set context flag to enable disableAutoOpen parameter
-			ctxWithFlag := context.WithValue(ctx, contextkeys.OAuthInterceptorEnabledKey, true)
-			authResponse, err := client.PostOAuthApp(ctxWithFlag, serverName, "", true)
-			if err != nil {
-				log.Logf("Warning: Failed to get OAuth URL for %s: %v", serverName, err)
-			} else if authResponse.BrowserURL != "" {
-				return &mcp.CallToolResult{
-					Content: []mcp.Content{&mcp.TextContent{
-						Text: fmt.Sprintf("Successfully added server '%s'. To authorize this server, please open the following URL in your browser:\n\n%s", serverName, authResponse.BrowserURL),
-					}},
-				}, nil
-			}
-
-			return &mcp.CallToolResult{
-				Content: []mcp.Content{&mcp.TextContent{
-					Text: fmt.Sprintf("Successfully added server '%s'. You will need to authorize this server with: docker mcp oauth authorize %s", serverName, serverName),
-				}},
-			}, nil
+			return g.addRemoteOAuthServer(ctx, serverName, req)
 		}
 
 		return &mcp.CallToolResult{
pkg/gateway/run.go +35 −1 modified
@@ -64,6 +64,11 @@ type Gateway struct {
 	// Track registered capabilities per server for proper reload handling
 	capabilitiesMu     sync.RWMutex
 	serverCapabilities map[string]*ServerCapabilities
+
+	// authToken stores the authentication token for SSE/streaming modes
+	authToken             string
+	// authTokenWasGenerated indicates whether the token was auto-generated or from environment
+	authTokenWasGenerated bool
 }
 
 func NewGateway(config Config, docker docker.Client) *Gateway {
@@ -282,18 +287,47 @@ func (g *Gateway) Run(ctx context.Context) error {
 		return nil
 	}
 
+	// Initialize authentication token for SSE and streaming modes
+	transport := strings.ToLower(g.Transport)
+	if transport == "sse" || transport == "http" || transport == "streamable" || transport == "streaming" || transport == "streamable-http" {
+		token, wasGenerated, err := getOrGenerateAuthToken()
+		if err != nil {
+			return fmt.Errorf("failed to initialize auth token: %w", err)
+		}
+		g.authToken = token
+		g.authTokenWasGenerated = wasGenerated
+	}
+
 	// Start the server
-	switch strings.ToLower(g.Transport) {
+	switch transport {
 	case "stdio":
 		log.Log("> Start stdio server")
 		return g.startStdioServer(ctx, os.Stdin, os.Stdout)
 
 	case "sse":
 		log.Log("> Start sse server on port", g.Port)
+		endpoint := "/sse"
+		url := formatGatewayURL(g.Port, endpoint)
+		if g.authTokenWasGenerated {
+			log.Logf("> Gateway URL: %s", url)
+			log.Logf("> Use Bearer token: %s", formatBearerToken(g.authToken))
+		} else {
+			log.Logf("> Gateway URL: %s", url)
+			log.Logf("> Use Bearer token from MCP_GATEWAY_AUTH_TOKEN environment variable")
+		}
 		return g.startSseServer(ctx, ln)
 
 	case "http", "streamable", "streaming", "streamable-http":
 		log.Log("> Start streaming server on port", g.Port)
+		endpoint := "/mcp"
+		url := formatGatewayURL(g.Port, endpoint)
+		if g.authTokenWasGenerated {
+			log.Logf("> Gateway URL: %s", url)
+			log.Logf("> Use Bearer token: %s", formatBearerToken(g.authToken))
+		} else {
+			log.Logf("> Gateway URL: %s", url)
+			log.Logf("> Use Bearer token from MCP_GATEWAY_AUTH_TOKEN environment variable")
+		}
 		return g.startStreamingServer(ctx, ln)
 
 	default:
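In `Run`, the list of network transport names appears twice: once in the token-initialization condition and once in the `switch`. Because `transport.go` installs the middleware only when `g.authToken != ""`, an alias added to the `switch` but not to the condition would start a network server without authentication. One way to keep both decisions in a single table is a helper like the one below; `endpointFor` is hypothetical and not part of the patch.

```go
package main

import (
	"fmt"
	"strings"
)

// endpointFor is a hypothetical helper, not part of the patch. It returns the
// HTTP endpoint for a network transport and ok=true when the gateway would
// listen on a port (and therefore needs a token). stdio and unknown
// transports return ok=false.
func endpointFor(transport string) (endpoint string, ok bool) {
	switch strings.ToLower(transport) {
	case "sse":
		return "/sse", true
	case "http", "streamable", "streaming", "streamable-http":
		return "/mcp", true
	default:
		return "", false
	}
}

func main() {
	for _, t := range []string{"stdio", "sse", "Streamable-HTTP"} {
		endpoint, needsToken := endpointFor(t)
		fmt.Printf("%-16s endpoint=%q needsToken=%v\n", t, endpoint, needsToken)
	}
}
```

With this shape, `Run` could initialize the token whenever `needsToken` is true and reuse `endpoint` when logging the gateway URL, so the two lists cannot drift apart.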
pkg/gateway/transport.go +16 −2 modified
@@ -24,8 +24,15 @@ func (g *Gateway) startSseServer(ctx context.Context, ln net.Listener) error {
 		return g.mcpServer
 	}, nil)
 	mux.Handle("/sse", sseHandler)
+
+	// Wrap with authentication middleware
+	var handler http.Handler = mux
+	if g.authToken != "" {
+		handler = authenticationMiddleware(g.authToken, mux)
+	}
+
 	httpServer := &http.Server{
-		Handler: mux,
+		Handler: handler,
 	}
 	go func() {
 		<-ctx.Done()
@@ -42,8 +49,15 @@ func (g *Gateway) startStreamingServer(ctx context.Context, ln net.Listener) err
 		return g.mcpServer
 	}, nil)
 	mux.Handle("/mcp", streamHandler)
+
+	// Wrap with authentication middleware
+	var handler http.Handler = mux
+	if g.authToken != "" {
+		handler = authenticationMiddleware(g.authToken, mux)
+	}
+
 	httpServer := &http.Server{
-		Handler: mux,
+		Handler: handler,
 	}
 	go func() {
Vulnerability mechanics
Generated on May 9, 2026. Inputs: CWE entries + fix-commit diffs from this CVE's patches. Citations validated against bundle.
References (7)
- github.com/advisories/GHSA-46gc-mwh4-cc5r (GHSA; ADVISORY)
- nvd.nist.gov/vuln/detail/CVE-2025-64443 (GHSA; ADVISORY)
- github.com/docker/mcp-gateway/commit/6b076b2479d8d1345c50c112119c62978d46858e (GHSA, x_refsource_MISC; WEB)
- github.com/docker/mcp-gateway/commit/fe073985c8eb6e0c9317d2f198c07686f70ea06d (GHSA; WEB)
- github.com/docker/mcp-gateway/pull/190 (GHSA; WEB)
- github.com/docker/mcp-gateway/security/advisories/GHSA-46gc-mwh4-cc5r (GHSA, x_refsource_CONFIRM; WEB)
- modelcontextprotocol.io/specification/2025-06-18/basic/transports (GHSA; WEB)